Skip to content

Commit

Permalink
Merge pull request #614 from iotaledger/fix/must-occur
Browse files Browse the repository at this point in the history
Fix Must Occur serix validation
  • Loading branch information
PhilippGackstatter authored Nov 22, 2023
2 parents 4ca2b6c + 65f84e0 commit bdf1cc3
Show file tree
Hide file tree
Showing 7 changed files with 264 additions and 1 deletion.
4 changes: 3 additions & 1 deletion serializer/serializable.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ type ArrayRules struct {
Min uint
// The max array bound.
Max uint
// A map of types which must occur within the array.
// A map of object types which must occur within the array.
// This is only checked on slices of types with an object type set.
// In particular, this means this is not checked for byte slices.
MustOccur TypePrefixes
// The guards applied while de/serializing Serializables.
Guards SerializableGuard
Expand Down
6 changes: 6 additions & 0 deletions serializer/serix/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,12 @@ func (api *API) decodeSlice(ctx context.Context, b []byte, value reflect.Value,
value.Set(reflect.MakeSlice(valueType, 0, 0))
}

if opts.validation {
if err := api.checkArrayMustOccur(value, ts); err != nil {
return bytesRead, ierrors.Wrapf(err, "can't deserialize '%s' type", value.Kind())
}
}

return bytesRead, nil
}

Expand Down
7 changes: 7 additions & 0 deletions serializer/serix/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,13 @@ func (api *API) encodeSlice(ctx context.Context, value reflect.Value, valueType

return seri.Serialize()
}

if opts.validation {
if err := api.checkArrayMustOccur(value, ts); err != nil {
return nil, ierrors.Wrapf(err, "can't serialize '%s' type", value.Kind())
}
}

sliceLen := value.Len()
data := make([][]byte, sliceLen)
for i := 0; i < sliceLen; i++ {
Expand Down
6 changes: 6 additions & 0 deletions serializer/serix/map_decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,12 @@ func (api *API) mapDecodeSlice(ctx context.Context, mapVal any, value reflect.Va
value.Set(reflect.MakeSlice(valueType, 0, 0))
}

if opts.validation {
if err := api.checkArrayMustOccur(value, ts); err != nil {
return ierrors.Wrapf(err, "can't deserialize '%s' type", value.Kind())
}
}

return nil
}

Expand Down
6 changes: 6 additions & 0 deletions serializer/serix/map_encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,12 @@ func (api *API) mapEncodeSlice(ctx context.Context, value reflect.Value, valueTy
return nil, ierrors.Wrapf(err, "can't serialize '%s' type", value.Kind())
}

if opts.validation {
if err := api.checkArrayMustOccur(value, ts); err != nil {
return nil, ierrors.Wrapf(err, "can't serialize '%s' type", value.Kind())
}
}

data := make([]any, sliceLen)
for i := 0; i < sliceLen; i++ {
elemValue := value.Index(i)
Expand Down
42 changes: 42 additions & 0 deletions serializer/serix/serix.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,48 @@ func (api *API) checkSerializedSize(ctx context.Context, value reflect.Value, ts
return api.checkMaxByteSize(len(bytes), ts)
}

// Checks the "Must Occur" array rules in the given slice.
func (api *API) checkArrayMustOccur(slice reflect.Value, ts TypeSettings) error {
if slice.Kind() != reflect.Slice {
return ierrors.Errorf("must occur can only be checked for a slice, got value of kind %v", slice.Kind())
}

if ts.arrayRules == nil || len(ts.arrayRules.MustOccur) == 0 {
return nil
}

mustOccurPrefixes := make(serializer.TypePrefixes, len(ts.arrayRules.MustOccur))
for key, value := range ts.arrayRules.MustOccur {
mustOccurPrefixes[key] = value
}

sliceLen := slice.Len()
for i := 0; i < sliceLen; i++ {
elemValue := slice.Index(i)

// Get the type prefix of the element by retrieving the type settings.
if elemValue.Kind() == reflect.Ptr || elemValue.Kind() == reflect.Interface {
elemValue = reflect.Indirect(elemValue.Elem())
}

elemTypeSettings, exists := api.getTypeSettings(elemValue.Type())
if !exists {
return ierrors.Errorf("missing type settings for %s; needed to check Must Occur rules", elemValue)
}
_, typePrefix, err := getTypeDenotationAndCode(elemTypeSettings.objectType)
if err != nil {
return ierrors.WithStack(err)
}
delete(mustOccurPrefixes, typePrefix)
}

if len(mustOccurPrefixes) != 0 {
return ierrors.Wrapf(serializer.ErrArrayValidationTypesNotOccurred, "expected type prefixes that did not occur: %v", mustOccurPrefixes)
}

return nil
}

// Encode serializes the provided object obj into bytes.
// serix traverses the object recursively and serializes everything based on the type.
// If a type implements the custom Serializable interface serix delegates the serialization to that type.
Expand Down
194 changes: 194 additions & 0 deletions serializer/serix/serix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,3 +614,197 @@ func TestSerixFieldKeyString(t *testing.T) {
})
}
}

func TestSerixMustOccur(t *testing.T) {
const (
ShapeSquare byte = 100
ShapeRectangle byte = 101
ShapeTriangle byte = 102
)

type (
Shape interface {
}
Square struct {
Size uint8 `serix:""`
}
Rectangle struct {
Size uint8 `serix:""`
}
Triangle struct {
Size uint16 `serix:""`
}
Container struct {
Shapes []Shape `serix:""`
}
)

var shapesArrRules = &serix.ArrayRules{
Min: 0,
Max: 10,
MustOccur: serializer.TypePrefixes{
uint32(ShapeSquare): struct{}{},
uint32(ShapeRectangle): struct{}{},
},
ValidationMode: serializer.ArrayValidationModeNoDuplicates |
serializer.ArrayValidationModeLexicalOrdering |
serializer.ArrayValidationModeAtMostOneOfEachTypeByte,
}

must(testAPI.RegisterTypeSettings(Triangle{}, serix.TypeSettings{}.WithObjectType(uint8(ShapeTriangle))))
must(testAPI.RegisterTypeSettings(Square{}, serix.TypeSettings{}.WithObjectType(uint8(ShapeSquare))))
must(testAPI.RegisterTypeSettings(Rectangle{}, serix.TypeSettings{}.WithObjectType(uint8(ShapeRectangle))))
must(testAPI.RegisterTypeSettings(Container{}, serix.TypeSettings{}.WithObjectType(uint8(5))))

must(testAPI.RegisterTypeSettings([]Shape{},
serix.TypeSettings{}.WithLengthPrefixType(serix.LengthPrefixTypeAsByte).WithArrayRules(shapesArrRules),
))

must(testAPI.RegisterInterfaceObjects((*Shape)(nil), (*Triangle)(nil)))
must(testAPI.RegisterInterfaceObjects((*Shape)(nil), (*Square)(nil)))
must(testAPI.RegisterInterfaceObjects((*Shape)(nil), (*Rectangle)(nil)))

tests := []encodingTest{
{
name: "ok encoding",
source: &Container{
Shapes: []Shape{
&Square{Size: 10},
&Rectangle{Size: 5},
&Triangle{Size: 3},
},
},
target: &Container{},
seriErr: nil,
},
{
name: "fail encoding - square must occur",
source: &Container{
Shapes: []Shape{
&Rectangle{Size: 5},
&Triangle{Size: 3},
},
},
target: &Container{},
seriErr: serializer.ErrArrayValidationTypesNotOccurred,
},
{
name: "fail encoding - square & rectangle must occur - empty slice",
source: &Container{
Shapes: []Shape{},
},
target: &Container{},
seriErr: serializer.ErrArrayValidationTypesNotOccurred,
},
}

for _, tt := range tests {
t.Run(tt.name, tt.run)
}

deSeriTests := []decodingTest{
{
name: "ok decoding",
source: &Container{
Shapes: []Shape{
&Square{Size: 10},
&Rectangle{Size: 5},
&Triangle{Size: 3},
},
},
target: &Container{},
deSeriErr: nil,
},
{
name: "fail decoding - square must occur",
source: &Container{
Shapes: []Shape{
&Rectangle{Size: 5},
&Triangle{Size: 3},
},
},
target: &Container{},
deSeriErr: serializer.ErrArrayValidationTypesNotOccurred,
},
{
name: "fail decoding - square & rectangle must occur - empty slice",
source: &Container{
Shapes: []Shape{},
},
target: &Container{},
deSeriErr: serializer.ErrArrayValidationTypesNotOccurred,
},
}

for _, tt := range deSeriTests {
t.Run(tt.name, tt.run)
}
}

type encodingTest struct {
name string
source any
target any
seriErr error
}

func (test *encodingTest) run(t *testing.T) {
serixData, err := testAPI.Encode(context.Background(), test.source, serix.WithValidation())
jsonData, jsonErr := testAPI.JSONEncode(context.Background(), test.source, serix.WithValidation())

if test.seriErr != nil {
require.ErrorIs(t, err, test.seriErr)
require.ErrorIs(t, jsonErr, test.seriErr)

return
}
require.NoError(t, err)
require.NoError(t, jsonErr)

serixTarget := reflect.New(reflect.TypeOf(test.target).Elem()).Interface()
bytesRead, err := testAPI.Decode(context.Background(), serixData, serixTarget)

require.NoError(t, err)
require.Len(t, serixData, bytesRead)
require.EqualValues(t, test.source, serixTarget)

jsonDest := reflect.New(reflect.TypeOf(test.target).Elem()).Interface()
require.NoError(t, testAPI.JSONDecode(context.Background(), jsonData, jsonDest))

require.EqualValues(t, test.source, jsonDest)
}

type decodingTest struct {
name string
source any
target any
deSeriErr error
}

func (test *decodingTest) run(t *testing.T) {
serixData, err := testAPI.Encode(context.Background(), test.source)
require.NoError(t, err)

sourceJSON, err := testAPI.JSONEncode(context.Background(), test.source)
require.NoError(t, err)

serixTarget := reflect.New(reflect.TypeOf(test.target).Elem()).Interface()
bytesRead, err := testAPI.Decode(context.Background(), serixData, serixTarget, serix.WithValidation())

jsonDest := reflect.New(reflect.TypeOf(test.target).Elem()).Interface()
jsonErr := testAPI.JSONDecode(context.Background(), sourceJSON, jsonDest, serix.WithValidation())

if test.deSeriErr != nil {
require.ErrorIs(t, err, test.deSeriErr)
require.ErrorIs(t, jsonErr, test.deSeriErr)

return
}
require.NoError(t, err)
require.Len(t, serixData, bytesRead)
require.EqualValues(t, test.source, serixTarget)

require.NoError(t, jsonErr)

require.EqualValues(t, test.source, jsonDest)
}

0 comments on commit bdf1cc3

Please sign in to comment.