diff --git a/cache.go b/cache.go index b6bdd11a..d7c00d99 100644 --- a/cache.go +++ b/cache.go @@ -21,12 +21,14 @@ const ( typeKeys typeEndKeys typeOmitNil + typeSelect ) const ( invalidValidation = "Invalid validation tag on field '%s'" undefinedValidation = "Undefined validation function '%s' on field '%s'" keysTagNotDefined = "'" + endKeysTag + "' tag encountered without a corresponding '" + keysTag + "' tag" + invalidSelectTag = "'select' tags must have exactly one value" ) type structCache struct { @@ -266,6 +268,22 @@ func (v *Validate) parseFieldTagsRecursive(tag string, fieldName string, alias s continue default: + if strings.HasPrefix(t, selectTag) { + vals := strings.SplitN(t, tagKeySeparator, 2) + + // Check again for exact match to prevent future conflicts + if vals[0] == selectTag { + if len(vals) == 1 { + panic(invalidSelectTag) + } + + current.typeof = typeSelect + current.hasParam = true + current.param = vals[1] + continue + } + } + if t == isdefault { current.typeof = typeIsDefault } diff --git a/doc.go b/doc.go index b4740918..98e817d6 100644 --- a/doc.go +++ b/doc.go @@ -249,6 +249,20 @@ Example #2 // eq=1|eq=2 will be applied to each array element in the map keys // required will be applied to map values +# Select + +Selects a struct field or map value for which the following tags will be applied. +It is similar to the dive tags for arrays/slices/maps except that it only applies to a single struct field or map value. + + Usage: select=FieldName + +Example: + + // Validates that the field "V" of "MyStruct.Field" is greater than 10 when "Field" is valid + type MyStruct struct { + Field sql.Null[uint] `validate:"omitempty,select=V,gt=10"` + } + # Required This validates that the value is not the data types default zero value. diff --git a/validator.go b/validator.go index 901e7b50..1981ea84 100644 --- a/validator.go +++ b/validator.go @@ -445,8 +445,60 @@ OUTER: ct = ct.next } - default: + case typeSelect: + var name, altName string + var fieldValue reflect.Value + switch kind { + case reflect.Struct: + + v.misc = append(v.misc[0:0], cf.name...) + v.misc = append(v.misc, '.') + v.misc = append(v.misc, ct.param...) + name = string(v.misc) + + if cf.namesEqual { + altName = name + } else { + v.misc = append(v.misc[0:0], cf.altName...) + v.misc = append(v.misc, '.') + v.misc = append(v.misc, ct.param...) + altName = string(v.misc) + } + fieldValue = current.FieldByName(ct.param) + + case reflect.Map: + + v.misc = append(v.misc[0:0], cf.name...) + v.misc = append(v.misc, '[') + v.misc = append(v.misc, ct.param...) + v.misc = append(v.misc, ']') + name = string(v.misc) + + if cf.namesEqual { + altName = name + } else { + v.misc = append(v.misc[0:0], cf.altName...) + v.misc = append(v.misc, '[') + v.misc = append(v.misc, ct.param...) + v.misc = append(v.misc, ']') + altName = string(v.misc) + } + + fieldValue = current.MapIndex(reflect.ValueOf(ct.param)) + + default: + panic("can't select field on a non struct or map types") + } + + v.traverseField(ctx, parent, fieldValue, ns, structNs, &cField{ + altName: altName, + name: name, + }, ct.next) + + return + + default: // set Field Level fields v.slflParent = parent v.flField = current diff --git a/validator_instance.go b/validator_instance.go index 1a345138..f2b0d082 100644 --- a/validator_instance.go +++ b/validator_instance.go @@ -38,6 +38,7 @@ const ( excludedIfTag = "excluded_if" excludedUnlessTag = "excluded_unless" skipValidationTag = "-" + selectTag = "select" diveTag = "dive" keysTag = "keys" endKeysTag = "endkeys" diff --git a/validator_test.go b/validator_test.go index 3b6d2634..b1ad5de1 100644 --- a/validator_test.go +++ b/validator_test.go @@ -13794,3 +13794,49 @@ func TestPrivateFieldsStruct(t *testing.T) { Equal(t, len(errs), tc.errorNum) } } + +func TestSelectTag(t *testing.T) { + validator := New(WithRequiredStructEnabled()) + + t.Run("on struct", func(t *testing.T) { + type Test struct { + Int sql.NullInt64 `validate:"required,select=Int64,gt=1"` + } + + validCase := Test{sql.NullInt64{Int64: 2}} + zeroCase := Test{} + invalidCase := Test{sql.NullInt64{Int64: 1}} + + Equal(t, validator.Struct(validCase), nil) + AssertError(t, validator.Struct(zeroCase), "Test.Int", "Test.Int", "Int", "Int", "required") + AssertError(t, validator.Struct(invalidCase), "Test.Int.Int64", "Test.Int.Int64", "Int.Int64", "Int.Int64", "gt") + }) + + t.Run("on map", func(t *testing.T) { + type Test struct { + Map map[string]int `validate:"required,select=key,gt=1"` + } + + validCase := Test{map[string]int{"key": 2}} + zeroCase := Test{} + invalidCase := Test{map[string]int{"key": 1}} + + Equal(t, validator.Struct(validCase), nil) + AssertError(t, validator.Struct(zeroCase), "Test.Map", "Test.Map", "Map", "Map", "required") + AssertError(t, validator.Struct(invalidCase), "Test.Map[key]", "Test.Map[key]", "Map[key]", "Map[key]", "gt") + }) + + t.Run("missing select value", func(t *testing.T) { + type Test struct { + Int sql.NullInt64 `validate:"required,select"` + } + + defer func() { + if r := recover(); r != invalidSelectTag { + t.Errorf("Expected panic %q, got %v", invalidSelectTag, r) + } + }() + + _ = validator.Struct(Test{}) + }) +}