Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove LiteralTypeForLiteral by creating IsInstance function #5909

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions flyteadmin/pkg/manager/impl/validation/execution_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,8 @@ func CheckAndFetchInputsForExecution(
}
executionInputMap[name] = expectedInput.GetDefault()
} else {
inputType := validators.LiteralTypeForLiteral(executionInputMap[name])
err := validators.ValidateLiteralType(inputType)
if err != nil {
return nil, errors.NewInvalidLiteralTypeError(name, err)
}
if !validators.AreTypesCastable(inputType, expectedInput.GetVar().GetType()) {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid %s input wrong type. Expected %s, but got %s", name, expectedInput.GetVar().GetType(), inputType)
if !validators.IsInstance(executionInputMap[name], expectedInput.GetVar().GetType()) {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "invalid %s input wrong type. Expected %s, but got literal %s", name, expectedInput.GetVar().GetType(), executionInputMap[name])
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ import (

var execConfig = testutils.GetApplicationConfigWithDefaultDomains()

const failedToValidateLiteralType = "Failed to validate literal type"

func TestValidateExecEmptyProject(t *testing.T) {
request := testutils.GetExecutionRequest()
request.Project = ""
Expand Down Expand Up @@ -154,7 +152,7 @@ func TestValidateExecInputsWrongType(t *testing.T) {
lpRequest.Spec.FixedInputs,
lpRequest.Spec.DefaultInputs,
)
utils.AssertEqualWithSanitizedRegex(t, "invalid foo input wrong type. Expected simple:STRING, but got simple:INTEGER", err.Error())
utils.AssertEqualWithSanitizedRegex(t, "invalid foo input wrong type. Expected simple:STRING, but got literal scalar: {primitive:{integer:1}}", err.Error())
}

func TestValidateExecInputsExtraInputs(t *testing.T) {
Expand Down Expand Up @@ -244,7 +242,7 @@ func TestValidateExecUnknownIDLInputs(t *testing.T) {
assert.NotNil(t, err)

// Expected error message
assert.Contains(t, err.Error(), failedToValidateLiteralType)
assert.Contains(t, err.Error(), "invalid foo input wrong type. Expected simple:1000, but got literal scalar:{}")
}

func TestValidExecutionId(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,9 @@ func checkAndFetchExpectedInputForLaunchPlan(
if !ok {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "unexpected fixed_input %s", name)
}
inputType := validators.LiteralTypeForLiteral(fixedInput)
err := validators.ValidateLiteralType(inputType)
if err != nil {
return nil, errors.NewInvalidLiteralTypeError(name, err)
}
if !validators.AreTypesCastable(inputType, value.GetType()) {
if !validators.IsInstance(fixedInput, value.GetType()) {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"invalid fixed_input wrong type %s, expected %v, got %v instead", name, value.GetType(), inputType)
"invalid fixed_input wrong type %s, expected %v, got literal %v instead", name, value.GetType(), fixedInput)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func TestValidateLpDefaultInputsWrongType(t *testing.T) {
request.Spec.DefaultInputs.Parameters["foo"].Var.Type = &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_FLOAT}}
err := ValidateLaunchPlan(context.Background(), request, testutils.GetRepoWithDefaultProject(), lpApplicationConfig, getWorkflowInterface())

expected := "Type mismatch for Parameter foo in default_inputs has type simple:FLOAT , expected simple:STRING "
expected := "Invalid default value for variable foo in default_inputs - expected type simple:FLOAT, but got literal scalar:{primitive:{string_value:\"foo-value\"}}"
utils.AssertEqualWithSanitizedRegex(t, expected, err.Error())
}

Expand Down Expand Up @@ -207,7 +207,7 @@ func TestGetLpExpectedInvalidFixedInputType(t *testing.T) {
request.GetSpec().GetFixedInputs(), request.GetSpec().GetDefaultInputs(),
)

utils.AssertEqualWithSanitizedRegex(t, "invalid fixed_input wrong type bar, expected simple:BINARY , got simple:STRING instead", err.Error())
utils.AssertEqualWithSanitizedRegex(t, "invalid fixed_input wrong type bar, expected simple:BINARY, got literal scalar: {primitive: {string_value: \"bar-value\"}} instead", err.Error())
assert.Nil(t, actualMap)
}

Expand Down Expand Up @@ -272,7 +272,7 @@ func TestGetLpExpectedInvalidFixedInputWithUnknownIDL(t *testing.T) {
assert.NotNil(t, err)

// Expected error message
assert.Contains(t, err.Error(), failedToValidateLiteralType)
assert.Contains(t, err.Error(), "invalid fixed_input wrong type foo, expected simple:1000, got literal scalar:{} instead")
}

func TestGetLpExpectedNoFixedInput(t *testing.T) {
Expand Down
7 changes: 1 addition & 6 deletions flyteadmin/pkg/manager/impl/validation/signal_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,11 @@ func ValidateSignalSetRequest(ctx context.Context, db repositoryInterfaces.Repos
"failed to validate that signal [%v] exists, err: [%+v]",
signalModel.SignalKey, err)
}
valueType := propellervalidators.LiteralTypeForLiteral(request.Value)
lookupSignal, err := transformers.FromSignalModel(lookupSignalModel)
if err != nil {
return err
}
err = propellervalidators.ValidateLiteralType(valueType)
if err != nil {
return errors.NewInvalidLiteralTypeError("", err)
}
if !propellervalidators.AreTypesCastable(lookupSignal.Type, valueType) {
if !propellervalidators.IsInstance(request.Value, lookupSignal.Type) {
return errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"requested signal value [%v] is not castable to existing signal type [%v]",
request.Value, lookupSignalModel.Type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,6 @@ func TestValidateSignalUpdateRequest(t *testing.T) {
assert.NotNil(t, err)

// Expected error message
assert.Contains(t, err.Error(), failedToValidateLiteralType)
assert.Contains(t, err.Error(), "requested signal value [scalar:{}] is not castable to existing signal type")
})
}
12 changes: 3 additions & 9 deletions flyteadmin/pkg/manager/impl/validation/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,16 +281,10 @@ func validateParameterMap(inputMap *core.ParameterMap, fieldName string) error {
}
defaultValue := defaultInput.GetDefault()
if defaultValue != nil {
inputType := validators.LiteralTypeForLiteral(defaultValue)
err := validators.ValidateLiteralType(inputType)
if err != nil {
return errors.NewInvalidLiteralTypeError(name, err)
}

if !validators.AreTypesCastable(inputType, defaultInput.GetVar().GetType()) {
if !validators.IsInstance(defaultValue, defaultInput.GetVar().GetType()) {
return errors.NewFlyteAdminErrorf(codes.InvalidArgument,
"Type mismatch for Parameter %s in %s has type %s, expected %s", name, fieldName,
defaultInput.GetVar().GetType().String(), inputType.String())
"Invalid default value for variable %s in %s - expected type %s, but got literal %s",
name, fieldName, defaultInput.GetVar().GetType(), defaultValue)
}

if defaultInput.GetVar().GetType().GetSimple() == core.SimpleType_DATETIME {
Expand Down
2 changes: 1 addition & 1 deletion flyteadmin/pkg/manager/impl/validation/validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ func TestValidateParameterMap(t *testing.T) {
err := validateParameterMap(&exampleMap, fieldName)
assert.Error(t, err)
fmt.Println(err.Error())
assert.Contains(t, err.Error(), failedToValidateLiteralType)
assert.Contains(t, err.Error(), "Invalid default value for variable foo in test_field_name - expected type simple:1000, but got literal scalar:{}")
})
}

Expand Down
2 changes: 1 addition & 1 deletion flytepropeller/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ require (
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.47.0
go.opentelemetry.io/otel v1.24.0
go.opentelemetry.io/otel/trace v1.24.0
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8
golang.org/x/sync v0.7.0
golang.org/x/time v0.5.0
google.golang.org/grpc v1.62.1
Expand Down Expand Up @@ -136,6 +135,7 @@ require (
go.opentelemetry.io/otel/sdk v1.24.0 // indirect
go.opentelemetry.io/proto/otlp v1.1.0 // indirect
golang.org/x/crypto v0.25.0 // indirect
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 // indirect
golang.org/x/net v0.27.0 // indirect
golang.org/x/oauth2 v0.16.0 // indirect
golang.org/x/sys v0.22.0 // indirect
Expand Down
8 changes: 8 additions & 0 deletions flytepropeller/pkg/compiler/errors/compiler_errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,14 @@
)
}

func NewMismatchingInstanceErr(nodeID, toVar, toType, fromVar string) *CompileError {
return newError(
MismatchingTypes,
fmt.Sprintf("Variable [%v] expected to be of type [%v], but got [%v].", toVar, toType, fromVar),
nodeID,
)

Check warning on line 221 in flytepropeller/pkg/compiler/errors/compiler_errors.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/errors/compiler_errors.go#L216-L221

Added lines #L216 - L221 were not covered by tests
}

func NewMismatchingVariablesErr(nodeID, fromVar, fromType, toVar, toType string) *CompileError {
return newError(
MismatchingTypes,
Expand Down
10 changes: 2 additions & 8 deletions flytepropeller/pkg/compiler/transformers/k8s/inputs.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,8 @@ func validateInputs(nodeID common.NodeID, iface *core.TypedInterface, inputs cor
continue
}

inputType := validators.LiteralTypeForLiteral(inputVal)
err := validators.ValidateLiteralType(inputType)
if err != nil {
errs.Collect(errors.NewInvalidLiteralTypeErr(nodeID, inputVar, err))
continue
}
if !validators.AreTypesCastable(inputType, v.Type) {
errs.Collect(errors.NewMismatchingTypesErr(nodeID, inputVar, v.Type.String(), inputType.String()))
if !validators.IsInstance(inputVal, v.Type) {
errs.Collect(errors.NewMismatchingInstanceErr(nodeID, inputVar, v.Type.String(), inputVal.String()))
continue
}

Expand Down
6 changes: 3 additions & 3 deletions flytepropeller/pkg/compiler/transformers/k8s/inputs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func TestValidateInputs_InvalidLiteralType(t *testing.T) {
"input1": {
Type: &core.LiteralType{
Type: &core.LiteralType_Simple{
Simple: 1000,
Simple: core.SimpleType_INTEGER,
},
},
},
Expand All @@ -42,14 +42,14 @@ func TestValidateInputs_InvalidLiteralType(t *testing.T) {
idlNotFound := false
var errMsg string
for _, err := range errs.Errors().List() {
if err.Code() == "InvalidLiteralType" {
if err.Code() == "MismatchingTypes" {
idlNotFound = true
errMsg = err.Error()
break
}
}
assert.True(t, idlNotFound, "Expected InvalidLiteralType error was not found in errors")

expectedContainedErrorMsg := "Failed to validate literal type"
expectedContainedErrorMsg := "Variable [input1] expected to be of type "
assert.Contains(t, errMsg, expectedContainedErrorMsg)
}
Loading
Loading