Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Must Occur serix validation #614

Merged
merged 7 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion serializer/serializable.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ type ArrayRules struct {
Min uint
// The max array bound.
Max uint
// A map of types which must occur within the array.
// A map of object types which must occur within the array.
// This is only checked on slices of types with an object type set.
// In particular, this means this is not checked for byte slices.
MustOccur TypePrefixes
// The guards applied while de/serializing Serializables.
Guards SerializableGuard
Expand Down
6 changes: 6 additions & 0 deletions serializer/serix/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,12 @@ func (api *API) decodeSlice(ctx context.Context, b []byte, value reflect.Value,
value.Set(reflect.MakeSlice(valueType, 0, 0))
}

if opts.validation {
if err := api.checkArrayMustOccur(value, ts); err != nil {
return bytesRead, ierrors.Wrapf(err, "can't deserialize '%s' type", value.Kind())
}
}

return bytesRead, nil
}

Expand Down
7 changes: 7 additions & 0 deletions serializer/serix/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,13 @@ func (api *API) encodeSlice(ctx context.Context, value reflect.Value, valueType

return seri.Serialize()
}

if opts.validation {
if err := api.checkArrayMustOccur(value, ts); err != nil {
return nil, ierrors.Wrapf(err, "can't serialize '%s' type", value.Kind())
}
}

sliceLen := value.Len()
data := make([][]byte, sliceLen)
for i := 0; i < sliceLen; i++ {
Expand Down
6 changes: 6 additions & 0 deletions serializer/serix/map_decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,12 @@ func (api *API) mapDecodeSlice(ctx context.Context, mapVal any, value reflect.Va
value.Set(reflect.MakeSlice(valueType, 0, 0))
}

if opts.validation {
if err := api.checkArrayMustOccur(value, ts); err != nil {
return ierrors.Wrapf(err, "can't deserialize '%s' type", value.Kind())
}
}

return nil
}

Expand Down
6 changes: 6 additions & 0 deletions serializer/serix/map_encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,12 @@ func (api *API) mapEncodeSlice(ctx context.Context, value reflect.Value, valueTy
return nil, ierrors.Wrapf(err, "can't serialize '%s' type", value.Kind())
}

if opts.validation {
if err := api.checkArrayMustOccur(value, ts); err != nil {
return nil, ierrors.Wrapf(err, "can't serialize '%s' type", value.Kind())
}
}

data := make([]any, sliceLen)
for i := 0; i < sliceLen; i++ {
elemValue := value.Index(i)
Expand Down
42 changes: 42 additions & 0 deletions serializer/serix/serix.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,48 @@ func (api *API) checkSerializedSize(ctx context.Context, value reflect.Value, ts
return api.checkMaxByteSize(len(bytes), ts)
}

// Checks the "Must Occur" array rules in the given slice.
func (api *API) checkArrayMustOccur(slice reflect.Value, ts TypeSettings) error {
if slice.Kind() != reflect.Slice {
return ierrors.Errorf("must occur can only be checked for a slice, got value of kind %v", slice.Kind())
}

if ts.arrayRules == nil || len(ts.arrayRules.MustOccur) == 0 {
return nil
}

mustOccurPrefixes := make(serializer.TypePrefixes, len(ts.arrayRules.MustOccur))
for key, value := range ts.arrayRules.MustOccur {
mustOccurPrefixes[key] = value
}

sliceLen := slice.Len()
for i := 0; i < sliceLen; i++ {
elemValue := slice.Index(i)

// Get the type prefix of the element by retrieving the type settings.
if elemValue.Kind() == reflect.Ptr || elemValue.Kind() == reflect.Interface {
elemValue = reflect.Indirect(elemValue.Elem())
}

elemTypeSettings, exists := api.getTypeSettings(elemValue.Type())
if !exists {
return ierrors.Errorf("missing type settings for %s; needed to check Must Occur rules", elemValue)
}
_, typePrefix, err := getTypeDenotationAndCode(elemTypeSettings.objectType)
if err != nil {
return ierrors.WithStack(err)
}
delete(mustOccurPrefixes, typePrefix)
}

if len(mustOccurPrefixes) != 0 {
return ierrors.Wrapf(serializer.ErrArrayValidationTypesNotOccurred, "expected type prefixes that did not occur: %v", mustOccurPrefixes)
}

return nil
}

// Encode serializes the provided object obj into bytes.
// serix traverses the object recursively and serializes everything based on the type.
// If a type implements the custom Serializable interface serix delegates the serialization to that type.
Expand Down
194 changes: 194 additions & 0 deletions serializer/serix/serix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,3 +614,197 @@ func TestSerixFieldKeyString(t *testing.T) {
})
}
}

func TestSerixMustOccur(t *testing.T) {
const (
ShapeSquare byte = 100
ShapeRectangle byte = 101
ShapeTriangle byte = 102
)

type (
Shape interface {
}
Square struct {
Size uint8 `serix:""`
}
Rectangle struct {
Size uint8 `serix:""`
}
Triangle struct {
Size uint16 `serix:""`
}
Container struct {
Shapes []Shape `serix:""`
}
)

var shapesArrRules = &serix.ArrayRules{
Min: 0,
Max: 10,
MustOccur: serializer.TypePrefixes{
uint32(ShapeSquare): struct{}{},
uint32(ShapeRectangle): struct{}{},
},
ValidationMode: serializer.ArrayValidationModeNoDuplicates |
serializer.ArrayValidationModeLexicalOrdering |
serializer.ArrayValidationModeAtMostOneOfEachTypeByte,
}

must(testAPI.RegisterTypeSettings(Triangle{}, serix.TypeSettings{}.WithObjectType(uint8(ShapeTriangle))))
must(testAPI.RegisterTypeSettings(Square{}, serix.TypeSettings{}.WithObjectType(uint8(ShapeSquare))))
must(testAPI.RegisterTypeSettings(Rectangle{}, serix.TypeSettings{}.WithObjectType(uint8(ShapeRectangle))))
must(testAPI.RegisterTypeSettings(Container{}, serix.TypeSettings{}.WithObjectType(uint8(5))))

must(testAPI.RegisterTypeSettings([]Shape{},
serix.TypeSettings{}.WithLengthPrefixType(serix.LengthPrefixTypeAsByte).WithArrayRules(shapesArrRules),
))

must(testAPI.RegisterInterfaceObjects((*Shape)(nil), (*Triangle)(nil)))
must(testAPI.RegisterInterfaceObjects((*Shape)(nil), (*Square)(nil)))
must(testAPI.RegisterInterfaceObjects((*Shape)(nil), (*Rectangle)(nil)))

tests := []encodingTest{
{
name: "ok encoding",
source: &Container{
Shapes: []Shape{
&Square{Size: 10},
&Rectangle{Size: 5},
&Triangle{Size: 3},
},
},
target: &Container{},
seriErr: nil,
},
{
name: "fail encoding - square must occur",
source: &Container{
Shapes: []Shape{
&Rectangle{Size: 5},
&Triangle{Size: 3},
},
},
target: &Container{},
seriErr: serializer.ErrArrayValidationTypesNotOccurred,
},
{
name: "fail encoding - square & rectangle must occur - empty slice",
source: &Container{
Shapes: []Shape{},
},
target: &Container{},
seriErr: serializer.ErrArrayValidationTypesNotOccurred,
},
}

for _, tt := range tests {
t.Run(tt.name, tt.run)
}

deSeriTests := []decodingTest{
{
name: "ok decoding",
source: &Container{
Shapes: []Shape{
&Square{Size: 10},
&Rectangle{Size: 5},
&Triangle{Size: 3},
},
},
target: &Container{},
deSeriErr: nil,
},
{
name: "fail decoding - square must occur",
source: &Container{
Shapes: []Shape{
&Rectangle{Size: 5},
&Triangle{Size: 3},
},
},
target: &Container{},
deSeriErr: serializer.ErrArrayValidationTypesNotOccurred,
},
{
name: "fail decoding - square & rectangle must occur - empty slice",
source: &Container{
Shapes: []Shape{},
},
target: &Container{},
deSeriErr: serializer.ErrArrayValidationTypesNotOccurred,
},
}

for _, tt := range deSeriTests {
t.Run(tt.name, tt.run)
}
}

type encodingTest struct {
name string
source any
target any
seriErr error
}

func (test *encodingTest) run(t *testing.T) {
serixData, err := testAPI.Encode(context.Background(), test.source, serix.WithValidation())
jsonData, jsonErr := testAPI.JSONEncode(context.Background(), test.source, serix.WithValidation())

if test.seriErr != nil {
require.ErrorIs(t, err, test.seriErr)
require.ErrorIs(t, jsonErr, test.seriErr)

return
}
require.NoError(t, err)
require.NoError(t, jsonErr)

serixTarget := reflect.New(reflect.TypeOf(test.target).Elem()).Interface()
bytesRead, err := testAPI.Decode(context.Background(), serixData, serixTarget)

require.NoError(t, err)
require.Len(t, serixData, bytesRead)
require.EqualValues(t, test.source, serixTarget)

jsonDest := reflect.New(reflect.TypeOf(test.target).Elem()).Interface()
require.NoError(t, testAPI.JSONDecode(context.Background(), jsonData, jsonDest))

require.EqualValues(t, test.source, jsonDest)
}

type decodingTest struct {
name string
source any
target any
deSeriErr error
}

func (test *decodingTest) run(t *testing.T) {
serixData, err := testAPI.Encode(context.Background(), test.source)
require.NoError(t, err)

sourceJSON, err := testAPI.JSONEncode(context.Background(), test.source)
require.NoError(t, err)

serixTarget := reflect.New(reflect.TypeOf(test.target).Elem()).Interface()
bytesRead, err := testAPI.Decode(context.Background(), serixData, serixTarget, serix.WithValidation())

jsonDest := reflect.New(reflect.TypeOf(test.target).Elem()).Interface()
jsonErr := testAPI.JSONDecode(context.Background(), sourceJSON, jsonDest, serix.WithValidation())

if test.deSeriErr != nil {
require.ErrorIs(t, err, test.deSeriErr)
require.ErrorIs(t, jsonErr, test.deSeriErr)

return
}
require.NoError(t, err)
require.Len(t, serixData, bytesRead)
require.EqualValues(t, test.source, serixTarget)

require.NoError(t, jsonErr)

require.EqualValues(t, test.source, jsonDest)
}
Loading