Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for UTF8 validity in serix binary and map serialization #610

Merged
merged 1 commit into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion serializer/serix/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"math/big"
"reflect"
"time"
"unicode/utf8"

"github.com/iotaledger/hive.go/ierrors"
"github.com/iotaledger/hive.go/serializer/v2"
Expand Down Expand Up @@ -166,7 +167,11 @@ func (api *API) decodeBasedOnType(ctx context.Context, b []byte, value reflect.V
deseri := serializer.NewDeserializer(b)
addrValue := value.Addr()
addrValue = addrValue.Convert(reflect.TypeOf((*string)(nil)))
minLen, maxLen := ts.MinMaxLen()

var minLen, maxLen int
if opts.validation {
minLen, maxLen = ts.MinMaxLen()
}

//nolint:forcetypeassert // false positive, we already checked the type via reflect
deseri.ReadString(
Expand All @@ -185,6 +190,13 @@ func (api *API) decodeBasedOnType(ctx context.Context, b []byte, value reflect.V
}
}, minLen, maxLen)

if opts.validation {
// check the string for UTF-8 validity
if !utf8.ValidString(value.String()) {
return 0, ierrors.Errorf("can't deserialize 'string' type: %w", ErrNonUTF8String)
}
}

return deseri.Done()

case reflect.Bool:
Expand Down
10 changes: 9 additions & 1 deletion serializer/serix/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"math/big"
"reflect"
"time"
"unicode/utf8"

"github.com/iotaledger/hive.go/ierrors"
"github.com/iotaledger/hive.go/serializer/v2"
Expand Down Expand Up @@ -110,6 +111,8 @@ func (api *API) encodeBasedOnType(
case reflect.Interface:
return api.encodeInterface(ctx, value, valueType, ts, opts)
case reflect.String:
str := value.String()

lengthPrefixType, set := ts.LengthPrefixType()
if !set {
return nil, ierrors.New("can't serialize 'string' type: no LengthPrefixType was provided")
Expand All @@ -118,11 +121,16 @@ func (api *API) encodeBasedOnType(
var minLen, maxLen int
if opts.validation {
minLen, maxLen = ts.MinMaxLen()

// check the string for UTF-8 validity
if !utf8.ValidString(str) {
return nil, ierrors.Errorf("can't serialize 'string' type: %w", ErrNonUTF8String)
}
}
seri := serializer.NewSerializer()

return seri.WriteString(
value.String(),
str,
serializer.SeriLengthPrefixType(lengthPrefixType),
func(err error) error {
return ierrors.Wrap(err, "failed to write string value to serializer")
Expand Down
12 changes: 7 additions & 5 deletions serializer/serix/map_decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,18 +163,20 @@ func (api *API) mapDecodeBasedOnType(ctx context.Context, mapVal any, value refl
if !ok {
return ierrors.New("non string value for string field")
}
if !utf8.ValidString(str) {
return ErrNonUTF8String
}
addrValue := value.Addr().Convert(reflect.TypeOf((*string)(nil)))
addrValue.Elem().Set(reflect.ValueOf(mapVal))

if opts.validation {
if err := api.checkMinMaxBoundsLength(len(str), ts); err != nil {
return ierrors.Wrapf(err, "can't deserialize '%s' type", value.Kind())
}
// check the string for UTF-8 validity
if !utf8.ValidString(str) {
return ErrNonUTF8String
}
}

addrValue := value.Addr().Convert(reflect.TypeOf((*string)(nil)))
addrValue.Elem().Set(reflect.ValueOf(mapVal))

return nil
case reflect.Bool:
addrValue := value.Addr().Convert(reflect.TypeOf((*bool)(nil)))
Expand Down
12 changes: 4 additions & 8 deletions serializer/serix/map_encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@ const (
keyDefaultSliceArray = "data"
)

var (
// ErrNonUTF8String gets returned when a non UTF-8 string is being encoded/decoded.
ErrNonUTF8String = ierrors.New("non UTF-8 string value")
)

func (api *API) mapEncode(ctx context.Context, value reflect.Value, ts TypeSettings, opts *options) (ele any, err error) {
valueI := value.Interface()
valueType := value.Type()
Expand Down Expand Up @@ -99,14 +94,15 @@ func (api *API) mapEncodeBasedOnType(
return api.mapEncodeInterface(ctx, value, valueType, opts)
case reflect.String:
str := value.String()
if !utf8.ValidString(str) {
return nil, ErrNonUTF8String
}

if opts.validation {
if err := api.checkMinMaxBoundsLength(len(str), ts); err != nil {
return nil, ierrors.Wrapf(err, "can't serialize '%s' type", value.Kind())
}
// check the string for UTF-8 validity
if !utf8.ValidString(str) {
return nil, ErrNonUTF8String
}
}

return value.String(), nil
Expand Down
2 changes: 2 additions & 0 deletions serializer/serix/serix.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ var (
ErrValidationMaxBytesExceeded = ierrors.New("max bytes size exceeded")
// ErrMapValidationViolatesUniqueness gets returned if the map elements are not unique.
ErrMapValidationViolatesUniqueness = ierrors.New("map elements must be unique")
// ErrNonUTF8String gets returned when a non UTF-8 string is being encoded/decoded.
ErrNonUTF8String = ierrors.New("non UTF-8 string value")
)

var (
Expand Down
136 changes: 110 additions & 26 deletions serializer/serix/serix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"reflect"
"testing"

"github.com/iancoleman/orderedmap"
"github.com/stretchr/testify/require"

"github.com/iotaledger/hive.go/serializer/v2"
Expand Down Expand Up @@ -145,38 +146,38 @@ func (test *serializeTest) run(t *testing.T) {
// binary serialize
serixData, err := testAPI.Encode(context.Background(), test.source, serix.WithValidation())
if test.seriErr != nil {
require.ErrorIs(t, err, test.seriErr)
require.ErrorIs(t, err, test.seriErr, "binary serialization failed")

// we also need to check the json serialization
_, err := testAPI.JSONEncode(context.Background(), test.source, serix.WithValidation())
require.ErrorIs(t, err, test.seriErr)
require.ErrorIs(t, err, test.seriErr, "json serialization failed")

return
}
require.NoError(t, err)
require.NoError(t, err, "binary serialization failed")

require.Equal(t, test.size, len(serixData))

// binary deserialize
serixTarget := reflect.New(reflect.TypeOf(test.target).Elem()).Interface()
bytesRead, err := testAPI.Decode(context.Background(), serixData, serixTarget)
require.NoError(t, err)
require.NoError(t, err, "binary deserialization failed")

require.Len(t, serixData, bytesRead)
require.EqualValues(t, test.source, serixTarget)
require.EqualValues(t, test.source, serixTarget, "binary")

// json serialize
sourceJSON, err := testAPI.JSONEncode(context.Background(), test.source, serix.WithValidation())
require.NoError(t, err)
require.NoError(t, err, "json serialization failed")

// json deserialize
jsonDest := reflect.New(reflect.TypeOf(test.target).Elem()).Interface()
require.NoError(t, testAPI.JSONDecode(context.Background(), sourceJSON, jsonDest, serix.WithValidation()))
require.NoError(t, testAPI.JSONDecode(context.Background(), sourceJSON, jsonDest, serix.WithValidation()), "json deserialization failed")

require.EqualValues(t, test.source, jsonDest)
require.EqualValues(t, test.source, jsonDest, "json")
}

func TestSerixMapSerialize(t *testing.T) {
func TestSerixSerializeMap(t *testing.T) {

type MyMapTypeKey string
type MyMapTypeValue string
Expand Down Expand Up @@ -297,6 +298,40 @@ func TestSerixMapSerialize(t *testing.T) {
}
}

// TestSerixSerializeString feeds two cases through the shared serializeTest
// driver: a plain ASCII string that is expected to serialize to 13 bytes, and
// a string built from bytes that are not valid UTF-8, for which serialization
// is expected to fail with serix.ErrNonUTF8String.
func TestSerixSerializeString(t *testing.T) {
	// TestStruct wraps a single string field that is tagged with a uint8
	// length prefix.
	type TestStruct struct {
		TestString string `serix:",lenPrefix=uint8"`
	}

	testAPI.RegisterTypeSettings(TestStruct{}, serix.TypeSettings{})

	// 0xff, 0xfe and 0xfd never appear in well-formed UTF-8.
	invalidUTF8 := string([]byte{0xff, 0xfe, 0xfd})

	cases := []serializeTest{
		{
			name:   "ok",
			source: &TestStruct{TestString: "hello world!"},
			target: &TestStruct{},
			// 12 bytes of text; the extra byte is presumably the uint8 length prefix.
			size:    13,
			seriErr: nil,
		},
		{
			name:    "fail - invalid utf8 string",
			source:  &TestStruct{TestString: invalidUTF8},
			target:  &TestStruct{},
			size:    0,
			seriErr: serix.ErrNonUTF8String,
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, tc.run)
	}
}

type deSerializeTest struct {
name string
source any
Expand All @@ -305,41 +340,56 @@ type deSerializeTest struct {
deSeriErr error
}

// convertOrderedMapToMap replaces every *orderedmap.OrderedMap stored directly
// as a value of m with the result of converting it recursively, and then
// returns m.Values(). The replacements are written back through m.Set, so the
// passed-in map is modified. Values of any other type are left untouched.
func convertOrderedMapToMap(m *orderedmap.OrderedMap) map[string]interface{} {
	for key, val := range m.Values() {
		nested, isOrdered := val.(*orderedmap.OrderedMap)
		if !isOrdered {
			continue
		}

		m.Set(key, convertOrderedMapToMap(nested))
	}

	return m.Values()
}

func (test *deSerializeTest) run(t *testing.T) {
// binary serialize test data
// binary serialize test data (without validation)
serixData, err := testAPI.Encode(context.Background(), test.source)
require.NoError(t, err)
require.NoError(t, err, "binary serialization failed")

// json serialize test data
sourceJSON, err := testAPI.JSONEncode(context.Background(), test.source)
require.NoError(t, err)
// "map" serialize test data (without validation)
// we don't use the json serialization here, because we want to test serix, and be able to inject malicious data
serixMapData, err := testAPI.MapEncode(context.Background(), test.source)
require.NoError(t, err, "map serialization failed")

// convert all *orderedmap.OrderedMap in serixMapData to map[string]interface{}
serixMapDataUnordered := convertOrderedMapToMap(serixMapData)

// binary deserialize
serixTarget := reflect.New(reflect.TypeOf(test.target).Elem()).Interface()
bytesRead, err := testAPI.Decode(context.Background(), serixData, serixTarget, serix.WithValidation())
if test.deSeriErr != nil {
require.ErrorIs(t, err, test.deSeriErr)
require.ErrorIs(t, err, test.deSeriErr, "binary deserialization failed")

// we also need to check the json deserialization
jsonDest := reflect.New(reflect.TypeOf(test.target).Elem()).Interface()
err := testAPI.JSONDecode(context.Background(), sourceJSON, jsonDest, serix.WithValidation())
require.ErrorIs(t, err, test.deSeriErr)
// we also need to check the "map" deserialization
mapDest := reflect.New(reflect.TypeOf(test.target).Elem()).Interface()
err := testAPI.MapDecode(context.Background(), serixMapDataUnordered, mapDest, serix.WithValidation())
require.ErrorIs(t, err, test.deSeriErr, "map deserialization failed")

return
}
require.NoError(t, err)
require.NoError(t, err, "binary deserialization failed")

require.Equal(t, test.size, bytesRead)
require.EqualValues(t, test.source, serixTarget)
require.EqualValues(t, test.source, serixTarget, "binary")

// json deserialize
jsonDest := reflect.New(reflect.TypeOf(test.target).Elem()).Interface()
require.NoError(t, testAPI.JSONDecode(context.Background(), sourceJSON, jsonDest, serix.WithValidation()))
// "map" deserialize
mapDest := reflect.New(reflect.TypeOf(test.target).Elem()).Interface()
require.NoError(t, testAPI.MapDecode(context.Background(), serixMapDataUnordered, mapDest, serix.WithValidation()), "map deserialization failed")

require.EqualValues(t, test.source, jsonDest)
require.EqualValues(t, test.source, mapDest, "map")
}

func TestSerixMapDeserialize(t *testing.T) {
func TestSerixDeserializeMap(t *testing.T) {

type MyMapTypeKey string
type MyMapTypeValue string
Expand Down Expand Up @@ -459,6 +509,40 @@ func TestSerixMapDeserialize(t *testing.T) {
}
}

// TestSerixDeserializeString feeds two cases through the shared
// deSerializeTest driver: a plain ASCII string that is expected to round-trip
// with 13 bytes read, and a string built from bytes that are not valid UTF-8,
// for which deserialization is expected to fail with serix.ErrNonUTF8String.
func TestSerixDeserializeString(t *testing.T) {
	// TestStruct wraps a single string field that is tagged with a uint8
	// length prefix.
	type TestStruct struct {
		TestString string `serix:",lenPrefix=uint8"`
	}

	testAPI.RegisterTypeSettings(TestStruct{}, serix.TypeSettings{})

	// 0xff, 0xfe and 0xfd never appear in well-formed UTF-8.
	invalidUTF8 := string([]byte{0xff, 0xfe, 0xfd})

	cases := []deSerializeTest{
		{
			name:   "ok",
			source: &TestStruct{TestString: "hello world!"},
			target: &TestStruct{},
			// 12 bytes of text; the extra byte is presumably the uint8 length prefix.
			size:      13,
			deSeriErr: nil,
		},
		{
			name:      "fail - invalid utf8 string",
			source:    &TestStruct{TestString: invalidUTF8},
			target:    &TestStruct{},
			size:      0,
			deSeriErr: serix.ErrNonUTF8String,
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, tc.run)
	}
}

func TestSerixFieldKeyString(t *testing.T) {
type test struct {
name string
Expand Down
Loading