Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Dec 7, 2023
1 parent c6914e1 commit 71ecfa5
Show file tree
Hide file tree
Showing 16 changed files with 84 additions and 78 deletions.
3 changes: 2 additions & 1 deletion bson/bson_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ func TestMapCodec(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mapCodec := bsoncodec.NewMapCodec(tc.opts)
mapRegistry := NewRegistryBuilder().RegisterDefaultEncoder(reflect.Map, mapCodec).Build()
mapRegistry := NewRegistry()
mapRegistry.RegisterKindEncoder(reflect.Map, mapCodec)
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err)
Expand Down
62 changes: 26 additions & 36 deletions bson/bsoncodec/default_value_decoders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,7 @@ func TestDefaultValueDecoders(t *testing.T) {
{
"wrong kind (non-string key)",
map[bool]interface{}{},
&DecodeContext{Registry: buildDefaultRegistry()},
&DecodeContext{Registry: NewRegistry()},
&bsonrwtest.ValueReaderWriter{},
bsonrwtest.ReadElement,
fmt.Errorf("unsupported key type: %T", false),
Expand All @@ -825,15 +825,15 @@ func TestDefaultValueDecoders(t *testing.T) {
{
"Lookup Error",
map[string]string{},
&DecodeContext{Registry: NewRegistryBuilder().Build()},
&DecodeContext{Registry: NewRegistry()},
&bsonrwtest.ValueReaderWriter{},
bsonrwtest.ReadDocument,
ErrNoDecoder{Type: reflect.TypeOf("")},
},
{
"ReadElement Error",
make(map[string]interface{}),
&DecodeContext{Registry: buildDefaultRegistry()},
&DecodeContext{Registry: NewRegistry()},
&bsonrwtest.ValueReaderWriter{Err: errors.New("re error"), ErrAfter: bsonrwtest.ReadElement},
bsonrwtest.ReadElement,
errors.New("re error"),
Expand Down Expand Up @@ -911,7 +911,7 @@ func TestDefaultValueDecoders(t *testing.T) {
{
"Lookup Error",
[1]string{},
&DecodeContext{Registry: NewRegistryBuilder().Build()},
&DecodeContext{Registry: NewRegistry()},
&bsonrwtest.ValueReaderWriter{BSONType: bsontype.Array},
bsonrwtest.ReadArray,
ErrNoDecoder{Type: reflect.TypeOf("")},
Expand Down Expand Up @@ -1005,7 +1005,7 @@ func TestDefaultValueDecoders(t *testing.T) {
{
"Lookup Error",
[]string{},
&DecodeContext{Registry: NewRegistryBuilder().Build()},
&DecodeContext{Registry: NewRegistry()},
&bsonrwtest.ValueReaderWriter{BSONType: bsontype.Array},
bsonrwtest.ReadArray,
ErrNoDecoder{Type: reflect.TypeOf("")},
Expand Down Expand Up @@ -3182,7 +3182,7 @@ func TestDefaultValueDecoders(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
vr := bsonrw.NewBSONDocumentReader(tc.b)
reg := buildDefaultRegistry()
reg := NewRegistry()
vtype := reflect.TypeOf(tc.value)
dec, err := reg.LookupDecoder(vtype)
noerr(t, err)
Expand Down Expand Up @@ -3311,7 +3311,7 @@ func TestDefaultValueDecoders(t *testing.T) {
t.Skip()
}
val := reflect.New(tEmpty).Elem()
dc := DecodeContext{Registry: NewRegistryBuilder().Build()}
dc := DecodeContext{Registry: NewRegistry()}
want := ErrNoTypeMapEntry{Type: tc.bsontype}
got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, val)
if !compareErrors(got, want) {
Expand All @@ -3324,11 +3324,9 @@ func TestDefaultValueDecoders(t *testing.T) {
t.Skip()
}
val := reflect.New(tEmpty).Elem()
dc := DecodeContext{
Registry: NewRegistryBuilder().
RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)).
Build(),
}
reg := NewRegistry()
reg.RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val))
dc := DecodeContext{Registry: reg}
want := ErrNoDecoder{Type: reflect.TypeOf(tc.val)}
got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, val)
if !compareErrors(got, want) {
Expand All @@ -3342,12 +3340,10 @@ func TestDefaultValueDecoders(t *testing.T) {
}
want := errors.New("DecodeValue failure error")
llc := &llCodec{t: t, err: want}
dc := DecodeContext{
Registry: NewRegistryBuilder().
RegisterTypeDecoder(reflect.TypeOf(tc.val), llc).
RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)).
Build(),
}
reg := NewRegistry()
reg.RegisterTypeDecoder(reflect.TypeOf(tc.val), llc)
reg.RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val))
dc := DecodeContext{Registry: reg}
got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, reflect.New(tEmpty).Elem())
if !compareErrors(got, want) {
t.Errorf("Errors are not equal. got %v; want %v", got, want)
Expand All @@ -3357,12 +3353,10 @@ func TestDefaultValueDecoders(t *testing.T) {
t.Run("Success", func(t *testing.T) {
want := tc.val
llc := &llCodec{t: t, decodeval: tc.val}
dc := DecodeContext{
Registry: NewRegistryBuilder().
RegisterTypeDecoder(reflect.TypeOf(tc.val), llc).
RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)).
Build(),
}
reg := NewRegistry()
reg.RegisterTypeDecoder(reflect.TypeOf(tc.val), llc)
reg.RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val))
dc := DecodeContext{Registry: reg}
got := reflect.New(tEmpty).Elem()
err := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, got)
noerr(t, err)
Expand Down Expand Up @@ -3396,7 +3390,7 @@ func TestDefaultValueDecoders(t *testing.T) {
llvr := &bsonrwtest.ValueReaderWriter{BSONType: bsontype.Double}
want := ErrNoTypeMapEntry{Type: bsontype.Double}
val := reflect.New(tEmpty).Elem()
got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: NewRegistryBuilder().Build()}, llvr, val)
got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: NewRegistry()}, llvr, val)
if !compareErrors(got, want) {
t.Errorf("Errors are not equal. got %v; want %v", got, want)
}
Expand Down Expand Up @@ -3502,8 +3496,8 @@ func TestDefaultValueDecoders(t *testing.T) {
emptyInterfaceErrorDecode := func(DecodeContext, bsonrw.ValueReader, reflect.Value) error {
return decodeValueError
}
emptyInterfaceErrorRegistry := NewRegistryBuilder().
RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)).Build()
emptyInterfaceErrorRegistry := NewRegistry()
emptyInterfaceErrorRegistry.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode))

// Set up a document {foo: 10} and an error that would happen if the value were decoded into interface{}
// using the registry defined above.
Expand Down Expand Up @@ -3555,11 +3549,8 @@ func TestDefaultValueDecoders(t *testing.T) {
outerDoc := buildDocument(bsoncore.AppendDocumentElement(nil, "first", inner1Doc))

// Use a registry that has all default decoders with the custom interface{} decoder that always errors.
nestedRegistryBuilder := NewRegistryBuilder()
defaultValueDecoders.RegisterDefaultDecoders(nestedRegistryBuilder)
nestedRegistry := nestedRegistryBuilder.
RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)).
Build()
nestedRegistry := NewRegistry()
nestedRegistry.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode))
nestedErr := &DecodeError{
keys: []string{"fourth", "1", "third", "randomKey", "second", "first"},
wrapped: decodeValueError,
Expand Down Expand Up @@ -3644,7 +3635,7 @@ func TestDefaultValueDecoders(t *testing.T) {
"struct - no decoder found",
stringStruct{},
bsonrw.NewBSONDocumentReader(docBytes),
NewRegistryBuilder().Build(),
NewRegistry(),
defaultTestStructCodec,
stringStructErr,
},
Expand Down Expand Up @@ -3709,9 +3700,8 @@ func TestDefaultValueDecoders(t *testing.T) {
bsoncore.BuildArrayElement(nil, "boolArray", trueValue),
)

rb := NewRegistryBuilder()
defaultValueDecoders.RegisterDefaultDecoders(rb)
reg := rb.RegisterTypeMapEntry(bsontype.Boolean, reflect.TypeOf(mybool(true))).Build()
reg := NewRegistry()
reg.RegisterTypeMapEntry(bsontype.Boolean, reflect.TypeOf(mybool(true)))

dc := DecodeContext{Registry: reg}
vr := bsonrw.NewBSONDocumentReader(docBytes)
Expand Down
14 changes: 7 additions & 7 deletions bson/bsoncodec/default_value_encoders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func TestDefaultValueEncoders(t *testing.T) {
{
"Lookup Error",
map[string]int{"foo": 1},
&EncodeContext{Registry: NewRegistryBuilder().Build()},
&EncodeContext{Registry: NewRegistry()},
&bsonrwtest.ValueReaderWriter{},
bsonrwtest.WriteDocument,
fmt.Errorf("no encoder found for int"),
Expand All @@ -262,7 +262,7 @@ func TestDefaultValueEncoders(t *testing.T) {
{
"empty map/success",
map[string]interface{}{},
&EncodeContext{Registry: NewRegistryBuilder().Build()},
&EncodeContext{Registry: NewRegistry()},
&bsonrwtest.ValueReaderWriter{},
bsonrwtest.WriteDocumentEnd,
nil,
Expand Down Expand Up @@ -318,7 +318,7 @@ func TestDefaultValueEncoders(t *testing.T) {
{
"Lookup Error",
[1]int{1},
&EncodeContext{Registry: NewRegistryBuilder().Build()},
&EncodeContext{Registry: NewRegistry()},
&bsonrwtest.ValueReaderWriter{},
bsonrwtest.WriteArray,
fmt.Errorf("no encoder found for int"),
Expand Down Expand Up @@ -396,7 +396,7 @@ func TestDefaultValueEncoders(t *testing.T) {
{
"Lookup Error",
[]int{1},
&EncodeContext{Registry: NewRegistryBuilder().Build()},
&EncodeContext{Registry: NewRegistry()},
&bsonrwtest.ValueReaderWriter{},
bsonrwtest.WriteArray,
fmt.Errorf("no encoder found for int"),
Expand Down Expand Up @@ -436,7 +436,7 @@ func TestDefaultValueEncoders(t *testing.T) {
{
"empty slice/success",
[]interface{}{},
&EncodeContext{Registry: NewRegistryBuilder().Build()},
&EncodeContext{Registry: NewRegistry()},
&bsonrwtest.ValueReaderWriter{},
bsonrwtest.WriteArrayEnd,
nil,
Expand Down Expand Up @@ -1837,7 +1837,7 @@ func TestDefaultValueEncoders(t *testing.T) {
t.Run("EmptyInterfaceEncodeValue/nil", func(t *testing.T) {
val := reflect.New(tEmpty).Elem()
llvrw := new(bsonrwtest.ValueReaderWriter)
err := dve.EmptyInterfaceEncodeValue(EncodeContext{Registry: NewRegistryBuilder().Build()}, llvrw, val)
err := dve.EmptyInterfaceEncodeValue(EncodeContext{Registry: NewRegistry()}, llvrw, val)
noerr(t, err)
if llvrw.Invoked != bsonrwtest.WriteNull {
t.Errorf("Incorrect method called. got %v; want %v", llvrw.Invoked, bsonrwtest.WriteNull)
Expand All @@ -1848,7 +1848,7 @@ func TestDefaultValueEncoders(t *testing.T) {
val := reflect.New(tEmpty).Elem()
val.Set(reflect.ValueOf(int64(1234567890)))
llvrw := new(bsonrwtest.ValueReaderWriter)
got := dve.EmptyInterfaceEncodeValue(EncodeContext{Registry: NewRegistryBuilder().Build()}, llvrw, val)
got := dve.EmptyInterfaceEncodeValue(EncodeContext{Registry: NewRegistry()}, llvrw, val)
want := ErrNoEncoder{Type: tInt64}
if !compareErrors(got, want) {
t.Errorf("Did not receive expected error. got %v; want %v", got, want)
Expand Down
2 changes: 1 addition & 1 deletion bson/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func TestDecoderv2(t *testing.T) {
t.Run("SetRegistry", func(t *testing.T) {
t.Parallel()

r1, r2 := DefaultRegistry, NewRegistryBuilder().Build()
r1, r2 := DefaultRegistry, NewRegistry()
dc1 := bsoncodec.DecodeContext{Registry: r1}
dc2 := bsoncodec.DecodeContext{Registry: r2}
dec, err := NewDecoder(bsonrw.NewBSONDocumentReader([]byte{}))
Expand Down
5 changes: 2 additions & 3 deletions bson/marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,8 @@ func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) {

return vw.WriteInt32(int32(val.Int()) * -1)
}
customReg := NewRegistryBuilder().
RegisterTypeEncoder(tInt32, encodeInt32).
Build()
customReg := NewRegistry()
customReg.RegisterTypeEncoder(tInt32, encodeInt32)

// Helper function to run the test and make assertions. The provided original value should result in the document
// {"x": {$numberInt: 1}} when marshalled with the default registry.
Expand Down
14 changes: 7 additions & 7 deletions bson/raw_value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func TestRawValue(t *testing.T) {
t.Run("Uses registry attached to value", func(t *testing.T) {
t.Parallel()

reg := bsoncodec.NewRegistryBuilder().Build()
reg := bsoncodec.NewRegistry()
val := RawValue{Type: bsontype.String, Value: bsoncore.AppendString(nil, "foobar"), r: reg}
var s string
want := bsoncodec.ErrNoDecoder{Type: reflect.TypeOf(s)}
Expand Down Expand Up @@ -64,7 +64,7 @@ func TestRawValue(t *testing.T) {
t.Run("Returns lookup error", func(t *testing.T) {
t.Parallel()

reg := bsoncodec.NewRegistryBuilder().Build()
reg := bsoncodec.NewRegistry()
var val RawValue
var s string
want := bsoncodec.ErrNoDecoder{Type: reflect.TypeOf(s)}
Expand All @@ -76,7 +76,7 @@ func TestRawValue(t *testing.T) {
t.Run("Returns DecodeValue error", func(t *testing.T) {
t.Parallel()

reg := NewRegistryBuilder().Build()
reg := NewRegistry()
val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, 3.14159)}
var s string
want := fmt.Errorf("cannot decode %v into a string type", bsontype.Double)
Expand All @@ -88,7 +88,7 @@ func TestRawValue(t *testing.T) {
t.Run("Success", func(t *testing.T) {
t.Parallel()

reg := NewRegistryBuilder().Build()
reg := NewRegistry()
want := float64(3.14159)
val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, want)}
var got float64
Expand All @@ -115,7 +115,7 @@ func TestRawValue(t *testing.T) {
t.Run("Returns lookup error", func(t *testing.T) {
t.Parallel()

dc := bsoncodec.DecodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()}
dc := bsoncodec.DecodeContext{Registry: bsoncodec.NewRegistry()}
var val RawValue
var s string
want := bsoncodec.ErrNoDecoder{Type: reflect.TypeOf(s)}
Expand All @@ -127,7 +127,7 @@ func TestRawValue(t *testing.T) {
t.Run("Returns DecodeValue error", func(t *testing.T) {
t.Parallel()

dc := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()}
dc := bsoncodec.DecodeContext{Registry: NewRegistry()}
val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, 3.14159)}
var s string
want := fmt.Errorf("cannot decode %v into a string type", bsontype.Double)
Expand All @@ -139,7 +139,7 @@ func TestRawValue(t *testing.T) {
t.Run("Success", func(t *testing.T) {
t.Parallel()

dc := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()}
dc := bsoncodec.DecodeContext{Registry: NewRegistry()}
want := float64(3.14159)
val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, want)}
var got float64
Expand Down
5 changes: 2 additions & 3 deletions bson/unmarshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,8 @@ func TestCachingDecodersNotSharedAcrossRegistries(t *testing.T) {
val.SetInt(int64(-1 * i32))
return nil
}
customReg := NewRegistryBuilder().
RegisterTypeDecoder(tInt32, decodeInt32).
Build()
customReg := NewRegistry()
customReg.RegisterTypeDecoder(tInt32, decodeInt32)

docBytes := bsoncore.BuildDocumentFromElements(
nil,
Expand Down
10 changes: 6 additions & 4 deletions bson/unmarshal_value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,16 @@ func TestUnmarshalValue(t *testing.T) {
bytes: bsoncore.AppendString(nil, "hello world"),
},
}
rb := NewRegistryBuilder().RegisterTypeDecoder(reflect.TypeOf([]byte{}), bsoncodec.NewSliceCodec()).Build()
reg := NewRegistry()
reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), bsoncodec.NewSliceCodec())
for _, tc := range testCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

gotValue := reflect.New(reflect.TypeOf(tc.val))
err := UnmarshalValueWithRegistry(rb, tc.bsontype, tc.bytes, gotValue.Interface())
err := UnmarshalValueWithRegistry(reg, tc.bsontype, tc.bytes, gotValue.Interface())
assert.Nil(t, err, "UnmarshalValueWithRegistry error: %v", err)
assert.Equal(t, tc.val, gotValue.Elem().Interface(), "value mismatch; expected %s, got %s", tc.val, gotValue.Elem())
})
Expand All @@ -111,12 +112,13 @@ func BenchmarkSliceCodecUnmarshal(b *testing.B) {
bytes: bsoncore.AppendString(nil, strings.Repeat("t", 4096)),
},
}
rb := NewRegistryBuilder().RegisterTypeDecoder(reflect.TypeOf([]byte{}), bsoncodec.NewSliceCodec()).Build()
reg := NewRegistry()
reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), bsoncodec.NewSliceCodec())
for _, bm := range benchmarks {
b.Run(bm.name, func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
err := UnmarshalValueWithRegistry(rb, bm.bsontype, bm.bytes, &[]byte{})
err := UnmarshalValueWithRegistry(reg, bm.bsontype, bm.bytes, &[]byte{})
if err != nil {
b.Fatal(err)
}
Expand Down
4 changes: 2 additions & 2 deletions mongo/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func TestDatabase(t *testing.T) {
wc2 := &writeconcern.WriteConcern{W: 10}
rcLocal := readconcern.Local()
rcMajority := readconcern.Majority()
reg := bsoncodec.NewRegistryBuilder().Build()
reg := bsoncodec.NewRegistry()

opts := options.Database().SetReadPreference(rpPrimary).SetReadConcern(rcLocal).SetWriteConcern(wc1).
SetReadPreference(rpSecondary).SetReadConcern(rcMajority).SetWriteConcern(wc2).SetRegistry(reg)
Expand All @@ -71,7 +71,7 @@ func TestDatabase(t *testing.T) {
rpPrimary := readpref.Primary()
rcLocal := readconcern.Local()
wc1 := &writeconcern.WriteConcern{W: 10}
reg := bsoncodec.NewRegistryBuilder().Build()
reg := bsoncodec.NewRegistry()

client := setupClient(options.Client().SetReadPreference(rpPrimary).SetReadConcern(rcLocal).SetRegistry(reg))
got := client.Database("foo", options.Database().SetWriteConcern(wc1))
Expand Down
Loading

0 comments on commit 71ecfa5

Please sign in to comment.