diff --git a/serializer/serializable.go b/serializer/serializable.go index bad021b6..69ef2518 100644 --- a/serializer/serializable.go +++ b/serializer/serializable.go @@ -122,7 +122,9 @@ type ArrayRules struct { Min uint // The max array bound. Max uint - // A map of types which must occur within the array. + // A map of object types which must occur within the array. + // This is only checked on slices of types with an object type set. + // In particular, this means this is not checked for byte slices. MustOccur TypePrefixes // The guards applied while de/serializing Serializables. Guards SerializableGuard diff --git a/serializer/serix/decode.go b/serializer/serix/decode.go index a4f52b2d..1968d657 100644 --- a/serializer/serix/decode.go +++ b/serializer/serix/decode.go @@ -435,6 +435,12 @@ func (api *API) decodeSlice(ctx context.Context, b []byte, value reflect.Value, value.Set(reflect.MakeSlice(valueType, 0, 0)) } + if opts.validation { + if err := api.checkArrayMustOccur(value, ts); err != nil { + return bytesRead, ierrors.Wrapf(err, "can't deserialize '%s' type", value.Kind()) + } + } + return bytesRead, nil } diff --git a/serializer/serix/encode.go b/serializer/serix/encode.go index 916a9397..41f4f2c2 100644 --- a/serializer/serix/encode.go +++ b/serializer/serix/encode.go @@ -319,6 +319,13 @@ func (api *API) encodeSlice(ctx context.Context, value reflect.Value, valueType return seri.Serialize() } + + if opts.validation { + if err := api.checkArrayMustOccur(value, ts); err != nil { + return nil, ierrors.Wrapf(err, "can't serialize '%s' type", value.Kind()) + } + } + sliceLen := value.Len() data := make([][]byte, sliceLen) for i := 0; i < sliceLen; i++ { diff --git a/serializer/serix/map_decode.go b/serializer/serix/map_decode.go index 4c5f7f9e..329edbad 100644 --- a/serializer/serix/map_decode.go +++ b/serializer/serix/map_decode.go @@ -455,6 +455,12 @@ func (api *API) mapDecodeSlice(ctx context.Context, mapVal any, value reflect.Va value.Set(reflect.MakeSlice(valueType, 0, 0)) } + if opts.validation { + if err := api.checkArrayMustOccur(value, ts); err != nil { + return ierrors.Wrapf(err, "can't deserialize '%s' type", value.Kind()) + } + } + return nil } diff --git a/serializer/serix/map_encode.go b/serializer/serix/map_encode.go index 36cee5ad..2f2b4cea 100644 --- a/serializer/serix/map_encode.go +++ b/serializer/serix/map_encode.go @@ -265,6 +265,12 @@ func (api *API) mapEncodeSlice(ctx context.Context, value reflect.Value, valueTy return nil, ierrors.Wrapf(err, "can't serialize '%s' type", value.Kind()) } + if opts.validation { + if err := api.checkArrayMustOccur(value, ts); err != nil { + return nil, ierrors.Wrapf(err, "can't serialize '%s' type", value.Kind()) + } + } + data := make([]any, sliceLen) for i := 0; i < sliceLen; i++ { elemValue := value.Index(i) diff --git a/serializer/serix/serix.go b/serializer/serix/serix.go index 2130f644..82d0e9db 100644 --- a/serializer/serix/serix.go +++ b/serializer/serix/serix.go @@ -257,6 +257,48 @@ func (api *API) checkSerializedSize(ctx context.Context, value reflect.Value, ts return api.checkMaxByteSize(len(bytes), ts) } +// Checks the "Must Occur" array rules in the given slice. +func (api *API) checkArrayMustOccur(slice reflect.Value, ts TypeSettings) error { + if slice.Kind() != reflect.Slice { + return ierrors.Errorf("must occur can only be checked for a slice, got value of kind %v", slice.Kind()) + } + + if ts.arrayRules == nil || len(ts.arrayRules.MustOccur) == 0 { + return nil + } + + mustOccurPrefixes := make(serializer.TypePrefixes, len(ts.arrayRules.MustOccur)) + for key, value := range ts.arrayRules.MustOccur { + mustOccurPrefixes[key] = value + } + + sliceLen := slice.Len() + for i := 0; i < sliceLen; i++ { + elemValue := slice.Index(i) + + // Get the type prefix of the element by retrieving the type settings. + if elemValue.Kind() == reflect.Ptr || elemValue.Kind() == reflect.Interface { + elemValue = reflect.Indirect(elemValue.Elem()) + } + + elemTypeSettings, exists := api.getTypeSettings(elemValue.Type()) + if !exists { + return ierrors.Errorf("missing type settings for %s; needed to check Must Occur rules", elemValue) + } + _, typePrefix, err := getTypeDenotationAndCode(elemTypeSettings.objectType) + if err != nil { + return ierrors.WithStack(err) + } + delete(mustOccurPrefixes, typePrefix) + } + + if len(mustOccurPrefixes) != 0 { + return ierrors.Wrapf(serializer.ErrArrayValidationTypesNotOccurred, "expected type prefixes that did not occur: %v", mustOccurPrefixes) + } + + return nil +} + // Encode serializes the provided object obj into bytes. // serix traverses the object recursively and serializes everything based on the type. // If a type implements the custom Serializable interface serix delegates the serialization to that type. diff --git a/serializer/serix/serix_test.go b/serializer/serix/serix_test.go index 67045eed..6129e73b 100644 --- a/serializer/serix/serix_test.go +++ b/serializer/serix/serix_test.go @@ -614,3 +614,197 @@ func TestSerixFieldKeyString(t *testing.T) { }) } } + +func TestSerixMustOccur(t *testing.T) { + const ( + ShapeSquare byte = 100 + ShapeRectangle byte = 101 + ShapeTriangle byte = 102 + ) + + type ( + Shape interface { + } + Square struct { + Size uint8 `serix:""` + } + Rectangle struct { + Size uint8 `serix:""` + } + Triangle struct { + Size uint16 `serix:""` + } + Container struct { + Shapes []Shape `serix:""` + } + ) + + var shapesArrRules = &serix.ArrayRules{ + Min: 0, + Max: 10, + MustOccur: serializer.TypePrefixes{ + uint32(ShapeSquare): struct{}{}, + uint32(ShapeRectangle): struct{}{}, + }, + ValidationMode: serializer.ArrayValidationModeNoDuplicates | + serializer.ArrayValidationModeLexicalOrdering | + serializer.ArrayValidationModeAtMostOneOfEachTypeByte, + } + + must(testAPI.RegisterTypeSettings(Triangle{}, serix.TypeSettings{}.WithObjectType(uint8(ShapeTriangle)))) + must(testAPI.RegisterTypeSettings(Square{}, serix.TypeSettings{}.WithObjectType(uint8(ShapeSquare)))) + must(testAPI.RegisterTypeSettings(Rectangle{}, serix.TypeSettings{}.WithObjectType(uint8(ShapeRectangle)))) + must(testAPI.RegisterTypeSettings(Container{}, serix.TypeSettings{}.WithObjectType(uint8(5)))) + + must(testAPI.RegisterTypeSettings([]Shape{}, + serix.TypeSettings{}.WithLengthPrefixType(serix.LengthPrefixTypeAsByte).WithArrayRules(shapesArrRules), + )) + + must(testAPI.RegisterInterfaceObjects((*Shape)(nil), (*Triangle)(nil))) + must(testAPI.RegisterInterfaceObjects((*Shape)(nil), (*Square)(nil))) + must(testAPI.RegisterInterfaceObjects((*Shape)(nil), (*Rectangle)(nil))) + + tests := []encodingTest{ + { + name: "ok encoding", + source: &Container{ + Shapes: []Shape{ + &Square{Size: 10}, + &Rectangle{Size: 5}, + &Triangle{Size: 3}, + }, + }, + target: &Container{}, + seriErr: nil, + }, + { + name: "fail encoding - square must occur", + source: &Container{ + Shapes: []Shape{ + &Rectangle{Size: 5}, + &Triangle{Size: 3}, + }, + }, + target: &Container{}, + seriErr: serializer.ErrArrayValidationTypesNotOccurred, + }, + { + name: "fail encoding - square & rectangle must occur - empty slice", + source: &Container{ + Shapes: []Shape{}, + }, + target: &Container{}, + seriErr: serializer.ErrArrayValidationTypesNotOccurred, + }, + } + + for _, tt := range tests { + t.Run(tt.name, tt.run) + } + + deSeriTests := []decodingTest{ + { + name: "ok decoding", + source: &Container{ + Shapes: []Shape{ + &Square{Size: 10}, + &Rectangle{Size: 5}, + &Triangle{Size: 3}, + }, + }, + target: &Container{}, + deSeriErr: nil, + }, + { + name: "fail decoding - square must occur", + source: &Container{ + Shapes: []Shape{ + &Rectangle{Size: 5}, + &Triangle{Size: 3}, + }, + }, + target: &Container{}, + deSeriErr: serializer.ErrArrayValidationTypesNotOccurred, + }, + { + name: "fail decoding - square & rectangle must occur - empty slice", + source: &Container{ + Shapes: []Shape{}, + }, + target: &Container{}, + deSeriErr: serializer.ErrArrayValidationTypesNotOccurred, + }, + } + + for _, tt := range deSeriTests { + t.Run(tt.name, tt.run) + } +} + +type encodingTest struct { + name string + source any + target any + seriErr error +} + +func (test *encodingTest) run(t *testing.T) { + serixData, err := testAPI.Encode(context.Background(), test.source, serix.WithValidation()) + jsonData, jsonErr := testAPI.JSONEncode(context.Background(), test.source, serix.WithValidation()) + + if test.seriErr != nil { + require.ErrorIs(t, err, test.seriErr) + require.ErrorIs(t, jsonErr, test.seriErr) + + return + } + require.NoError(t, err) + require.NoError(t, jsonErr) + + serixTarget := reflect.New(reflect.TypeOf(test.target).Elem()).Interface() + bytesRead, err := testAPI.Decode(context.Background(), serixData, serixTarget) + + require.NoError(t, err) + require.Len(t, serixData, bytesRead) + require.EqualValues(t, test.source, serixTarget) + + jsonDest := reflect.New(reflect.TypeOf(test.target).Elem()).Interface() + require.NoError(t, testAPI.JSONDecode(context.Background(), jsonData, jsonDest)) + + require.EqualValues(t, test.source, jsonDest) +} + +type decodingTest struct { + name string + source any + target any + deSeriErr error +} + +func (test *decodingTest) run(t *testing.T) { + serixData, err := testAPI.Encode(context.Background(), test.source) + require.NoError(t, err) + + sourceJSON, err := testAPI.JSONEncode(context.Background(), test.source) + require.NoError(t, err) + + serixTarget := reflect.New(reflect.TypeOf(test.target).Elem()).Interface() + bytesRead, err := testAPI.Decode(context.Background(), serixData, serixTarget, serix.WithValidation()) + + jsonDest := reflect.New(reflect.TypeOf(test.target).Elem()).Interface() + jsonErr := testAPI.JSONDecode(context.Background(), sourceJSON, jsonDest, serix.WithValidation()) + + if test.deSeriErr != nil { + require.ErrorIs(t, err, test.deSeriErr) + require.ErrorIs(t, jsonErr, test.deSeriErr) + + return + } + require.NoError(t, err) + require.Len(t, serixData, bytesRead) + require.EqualValues(t, test.source, serixTarget) + + require.NoError(t, jsonErr) + + require.EqualValues(t, test.source, jsonDest) +}