From 6ad2b4a17816b1e991f73e598885c07704aea7ef Mon Sep 17 00:00:00 2001 From: jiuker <2818723467@qq.com> Date: Thu, 18 Jan 2024 15:55:15 +0800 Subject: [PATCH] fix: support all types in StringSet JSON unmarshal (#1925) --- pkg/policy/bucket-policy_test.go | 45 +++++++++++++ pkg/set/stringset.go | 11 ++-- pkg/set/stringset_test.go | 105 +++++++++++++++++++++++++++++++ 3 files changed, 154 insertions(+), 7 deletions(-) diff --git a/pkg/policy/bucket-policy_test.go b/pkg/policy/bucket-policy_test.go index 94cfe9050..3d10fcee0 100644 --- a/pkg/policy/bucket-policy_test.go +++ b/pkg/policy/bucket-policy_test.go @@ -210,6 +210,51 @@ func TestUnmarshalBucketPolicy(t *testing.T) { } } ] +}`, shouldSucceed: true}, + // Test 10 + {policyData: `{ + "Version": "2012-10-17", + "Statement": [{ + "Effect": "Deny", + "Principal": { + "AWS": [ + "*" + ] + }, + "Action": [ + "s3:PutObject" + ], + "Resource": [ + "arn:aws:s3:::mytest/*" + ], + "Condition": { + "Null": { + "s3:x-amz-server-side-encryption": [ + true + ] + } + } + }] +}`, shouldSucceed: true}, + // Test 11 + {policyData: `{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Deny", + "Principal": "*", + "Action": "s3:PutObject", + "Resource": [ + "arn:aws:s3:::DOC-EXAMPLE-BUCKET1", + "arn:aws:s3:::DOC-EXAMPLE-BUCKET1/*" + ], + "Condition": { + "NumericLessThan": { + "s3:TlsVersion": 1.2 + } + } + } + ] }`, shouldSucceed: true}, } diff --git a/pkg/set/stringset.go b/pkg/set/stringset.go index c35e58e1a..2566a3df7 100644 --- a/pkg/set/stringset.go +++ b/pkg/set/stringset.go @@ -149,22 +149,19 @@ func (set StringSet) MarshalJSON() ([]byte, error) { } // UnmarshalJSON - parses JSON data and creates new set with it. -// If 'data' contains JSON string array, the set contains each string. -// If 'data' contains JSON string, the set contains the string as one element. -// If 'data' contains Other JSON types, JSON parse error is returned. func (set *StringSet) UnmarshalJSON(data []byte) error { - sl := []string{} + sl := []interface{}{} var err error if err = json.Unmarshal(data, &sl); err == nil { *set = make(StringSet) for _, s := range sl { - set.Add(s) + set.Add(fmt.Sprintf("%v", s)) } } else { - var s string + var s interface{} if err = json.Unmarshal(data, &s); err == nil { *set = make(StringSet) - set.Add(s) + set.Add(fmt.Sprintf("%v", s)) } } diff --git a/pkg/set/stringset_test.go b/pkg/set/stringset_test.go index ec6dfbe63..278313135 100644 --- a/pkg/set/stringset_test.go +++ b/pkg/set/stringset_test.go @@ -19,6 +19,8 @@ package set import ( "fmt" + "reflect" + "sort" "strings" "testing" ) @@ -346,3 +348,106 @@ func TestStringSetToSlice(t *testing.T) { } } } + +func TestStringSet_UnmarshalJSON(t *testing.T) { + type args struct { + data []byte + expectResult []string + } + tests := []struct { + name string + set StringSet + args args + wantErr bool + }{ + { + name: "test strings", + set: NewStringSet(), + args: args{ + data: []byte(`["foo","bar"]`), + expectResult: []string{"foo", "bar"}, + }, + wantErr: false, + }, + { + name: "test string", + set: NewStringSet(), + args: args{ + data: []byte(`"foo"`), + expectResult: []string{"foo"}, + }, + wantErr: false, + }, + { + name: "test bools", + set: NewStringSet(), + args: args{ + data: []byte(`[false,true]`), + expectResult: []string{"false", "true"}, + }, + wantErr: false, + }, + { + name: "test bool", + set: NewStringSet(), + args: args{ + data: []byte(`false`), + expectResult: []string{"false"}, + }, + wantErr: false, + }, + { + name: "test ints", + set: NewStringSet(), + args: args{ + data: []byte(`[1,2]`), + expectResult: []string{"1", "2"}, + }, + wantErr: false, + }, + { + name: "test int", + set: NewStringSet(), + args: args{ + data: []byte(`1`), + expectResult: []string{"1"}, + }, + wantErr: false, + }, + { + name: "test floats", + set: NewStringSet(), + args: args{ + data: []byte(`[1.1,2.2]`), + expectResult: []string{"1.1", "2.2"}, + }, + wantErr: false, + }, + { + name: "test float", + set: NewStringSet(), + args: args{ + data: []byte(`1.1`), + expectResult: []string{"1.1"}, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.set.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + slice := tt.set.ToSlice() + sort.Slice(slice, func(i, j int) bool { + return slice[i] < slice[j] + }) + sort.Slice(tt.args.expectResult, func(i, j int) bool { + return tt.args.expectResult[i] < tt.args.expectResult[j] + }) + if !reflect.DeepEqual(slice, tt.args.expectResult) { + t.Errorf("StringSet() get %v, want %v", tt.set.ToSlice(), tt.args.expectResult) + } + }) + } +}