From 7a0b8bf799d3f02827f316d4209206ca90da0388 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Mon, 20 May 2024 14:13:43 -0400 Subject: [PATCH] WIP --- bson/bsoncodec.go | 28 +----- bson/cond_addr_codec_test.go | 4 +- bson/decoder.go | 78 +++++++++++--- bson/decoder_test.go | 12 +-- bson/default_value_decoders_test.go | 151 +++++++++++++--------------- bson/encoder.go | 16 +-- bson/mgocompat/bson_test.go | 12 +-- bson/primitive_codecs_test.go | 12 +-- bson/raw_value.go | 10 +- bson/raw_value_test.go | 12 +-- bson/registry.go | 10 +- bson/registry_test.go | 12 +-- bson/string_codec_test.go | 12 +-- bson/time_codec_test.go | 4 +- bson/truncation_test.go | 20 ++-- bson/unmarshal.go | 18 ++-- bson/unmarshal_test.go | 6 +- 17 files changed, 205 insertions(+), 212 deletions(-) diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index 6d70a99ca1..5e910fca88 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -72,31 +72,6 @@ func (vde ValueDecoderError) Error() string { return fmt.Sprintf("%s can only decode valid and settable %s, but got %s", vde.Name, strings.Join(typeKinds, ", "), received) } -// DecodeContext is the contextual information required for a Codec to decode a -// value. -type DecodeContext struct { - *Registry - - // defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the - // usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If DocumentType is - // set to a type that a BSON document cannot be unmarshaled into (e.g. "string"), unmarshalling will result in an - // error. DocumentType overrides the Ancestor field. - defaultDocumentType reflect.Type - - // truncate, if true, instructs decoders to to truncate the fractional part of BSON "double" - // values when attempting to unmarshal them into a Go integer (int, int8, int16, int32, int64, - // uint, uint8, uint16, uint32, or uint64) struct field. The truncation logic does not apply to - // BSON "decimal128" values. - truncate bool - - binaryAsSlice bool - decodeObjectIDAsHex bool - useJSONStructTags bool - useLocalTimeZone bool - zeroMaps bool - zeroStructs bool -} - // EncoderRegistry is an interface provides a ValueEncoder based on the given reflect.Type. type EncoderRegistry interface { LookupEncoder(reflect.Type) (ValueEncoder, error) @@ -166,8 +141,7 @@ var _ ValueDecoder = decodeAdapter{} var _ typeDecoder = decodeAdapter{} func decodeTypeOrValueWithInfo(vd ValueDecoder, reg DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { - td, ok := vd.(typeDecoder) - if ok && td != nil { + if td, _ := vd.(typeDecoder); td != nil { return td.decodeType(reg, vr, t) } diff --git a/bson/cond_addr_codec_test.go b/bson/cond_addr_codec_test.go index 15b4a8a333..6fd777ae77 100644 --- a/bson/cond_addr_codec_test.go +++ b/bson/cond_addr_codec_test.go @@ -78,7 +78,7 @@ func TestCondAddrCodec(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - err := condDecoder.DecodeValue(DecodeContext{}, rw, tc.val) + err := condDecoder.DecodeValue(nil, rw, tc.val) assert.Nil(t, err, "CondAddrDecoder error: %v", err) assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked) @@ -87,7 +87,7 @@ func TestCondAddrCodec(t *testing.T) { t.Run("error", func(t *testing.T) { errDecoder := &condAddrDecoder{canAddrDec: decode1, elseDec: nil} - err := errDecoder.DecodeValue(DecodeContext{}, rw, unaddressable) + err := errDecoder.DecodeValue(nil, rw, unaddressable) want := ErrNoDecoder{Type: unaddressable.Type()} assert.Equal(t, err, want, "expected error %v, got %v", want, err) }) diff --git a/bson/decoder.go b/bson/decoder.go index cc53f422f9..14335c5bb5 100644 --- a/bson/decoder.go +++ b/bson/decoder.go @@ -28,15 +28,15 @@ var decPool = sync.Pool{ // A Decoder reads and decodes BSON documents from a stream. It reads from a ValueReader as // the source of BSON data. type Decoder struct { - dc DecodeContext - vr ValueReader + reg *Registry + vr ValueReader } // NewDecoder returns a new decoder that uses the DefaultRegistry to read from vr. func NewDecoder(vr ValueReader) *Decoder { return &Decoder{ - dc: DecodeContext{Registry: DefaultRegistry}, - vr: vr, + reg: DefaultRegistry, + vr: vr, } } @@ -68,12 +68,12 @@ func (d *Decoder) Decode(val interface{}) error { default: return fmt.Errorf("argument to Decode must be a pointer or a map, but got %v", rval) } - decoder, err := d.dc.LookupDecoder(rval.Type()) + decoder, err := d.reg.LookupDecoder(rval.Type()) if err != nil { return err } - return decoder.DecodeValue(d.dc, d.vr, rval) + return decoder.DecodeValue(d.reg, d.vr, rval) } // Reset will reset the state of the decoder, using the same *DecodeContext used in @@ -84,59 +84,105 @@ func (d *Decoder) Reset(vr ValueReader) { // SetRegistry replaces the current registry of the decoder with r. func (d *Decoder) SetRegistry(r *Registry) { - d.dc.Registry = r + d.reg = r } // DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This // behavior is restricted to data typed as "interface{}" or "map[string]interface{}". func (d *Decoder) DefaultDocumentM() { - d.dc.defaultDocumentType = reflect.TypeOf(M{}) + t := reflect.TypeOf((*emptyInterfaceCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*emptyInterfaceCodec).defaultDocumentType = reflect.TypeOf(M{}) + } + } } // DefaultDocumentD causes the Decoder to always unmarshal documents into the primitive.D type. This // behavior is restricted to data typed as "interface{}" or "map[string]interface{}". func (d *Decoder) DefaultDocumentD() { - d.dc.defaultDocumentType = reflect.TypeOf(D{}) + t := reflect.TypeOf((*emptyInterfaceCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*emptyInterfaceCodec).defaultDocumentType = reflect.TypeOf(D{}) + } + } } // AllowTruncatingDoubles causes the Decoder to truncate the fractional part of BSON "double" values // when attempting to unmarshal them into a Go integer (int, int8, int16, int32, or int64) struct // field. The truncation logic does not apply to BSON "decimal128" values. func (d *Decoder) AllowTruncatingDoubles() { - d.dc.truncate = true + t := reflect.TypeOf((*intCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*intCodec).truncate = true + } + } + // TODO floatCodec } // BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or // "Old" BSON binary subtype as a Go byte slice instead of a primitive.Binary. func (d *Decoder) BinaryAsSlice() { - d.dc.binaryAsSlice = true + t := reflect.TypeOf((*emptyInterfaceCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*emptyInterfaceCodec).decodeBinaryAsSlice = true + } + } } // DecodeObjectIDAsHex causes the Decoder to unmarshal BSON ObjectID as a hexadecimal string. func (d *Decoder) DecodeObjectIDAsHex() { - d.dc.decodeObjectIDAsHex = true + t := reflect.TypeOf((*stringCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*stringCodec).decodeObjectIDAsHex = true + } + } } // UseJSONStructTags causes the Decoder to fall back to using the "json" struct tag if a "bson" // struct tag is not specified. func (d *Decoder) UseJSONStructTags() { - d.dc.useJSONStructTags = true + t := reflect.TypeOf((*structCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*structCodec).useJSONStructTags = true + } + } } // UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead // of the UTC timezone. func (d *Decoder) UseLocalTimeZone() { - d.dc.useLocalTimeZone = true + t := reflect.TypeOf((*timeCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*timeCodec).useLocalTimeZone = true + } + } } // ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value // passed to Decode before unmarshaling BSON documents into them. func (d *Decoder) ZeroMaps() { - d.dc.zeroMaps = true + t := reflect.TypeOf((*mapCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*mapCodec).decodeZerosMap = true + } + } } // ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination // value passed to Decode before unmarshaling BSON documents into them. func (d *Decoder) ZeroStructs() { - d.dc.zeroStructs = true + t := reflect.TypeOf((*structCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*structCodec).decodeZeroStruct = true + } + } } diff --git a/bson/decoder_test.go b/bson/decoder_test.go index 3b96f63559..b101b38d65 100644 --- a/bson/decoder_test.go +++ b/bson/decoder_test.go @@ -32,7 +32,7 @@ func TestBasicDecode(t *testing.T) { reg := NewRegistryBuilder().Build() decoder, err := reg.LookupDecoder(reflect.TypeOf(got)) noerr(t, err) - err = decoder.DecodeValue(DecodeContext{Registry: reg}, vr, got) + err = decoder.DecodeValue(reg, vr, got) noerr(t, err) assert.Equal(t, tc.want, got.Addr().Interface(), "Results do not match.") }) @@ -200,15 +200,13 @@ func TestDecoderv2(t *testing.T) { t.Parallel() r1, r2 := DefaultRegistry, NewRegistryBuilder().Build() - dc1 := DecodeContext{Registry: r1} - dc2 := DecodeContext{Registry: r2} dec := NewDecoder(NewValueReader([]byte{})) - if !reflect.DeepEqual(dec.dc, dc1) { - t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc1) + if !reflect.DeepEqual(dec.reg, r1) { + t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.reg, r1) } dec.SetRegistry(r2) - if !reflect.DeepEqual(dec.dc, dc2) { - t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc2) + if !reflect.DeepEqual(dec.reg, r2) { + t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.reg, r2) } }) t.Run("DecodeToNil", func(t *testing.T) { diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 24892adeae..42602e3cbe 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -57,7 +57,7 @@ func TestDefaultValueDecoders(t *testing.T) { type subtest struct { name string val interface{} - dctx *DecodeContext + reg *Registry llvrw *valueReaderWriter invoke invoked err error @@ -186,15 +186,15 @@ func TestDefaultValueDecoders(t *testing.T) { errors.New("ReadDouble error"), }, { - "ReadDouble", int64(3), &DecodeContext{}, + "ReadDouble", int64(3), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.00)}, readDouble, nil, }, - { - "ReadDouble (truncate)", int64(3), &DecodeContext{truncate: true}, - &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, - nil, - }, + // { + // "ReadDouble (truncate)", int64(3), &DecodeContext{truncate: true}, + // &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, + // nil, + // }, { "ReadDouble (no truncate)", int64(0), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, @@ -418,15 +418,15 @@ func TestDefaultValueDecoders(t *testing.T) { errors.New("ReadDouble error"), }, { - "ReadDouble", uint64(3), &DecodeContext{}, + "ReadDouble", uint64(3), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.00)}, readDouble, nil, }, - { - "ReadDouble (truncate)", uint64(3), &DecodeContext{truncate: true}, - &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, - nil, - }, + // { + // "ReadDouble (truncate)", uint64(3), &DecodeContext{truncate: true}, + // &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, + // nil, + // }, { "ReadDouble (no truncate)", uint64(0), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, @@ -673,11 +673,11 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14159)}, readDouble, nil, }, - { - "float32/fast path (truncate)", float32(3.14), &DecodeContext{truncate: true}, - &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, - nil, - }, + // { + // "float32/fast path (truncate)", float32(3.14), &DecodeContext{truncate: true}, + // &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, + // nil, + // }, { "float32/fast path (no truncate)", float32(0), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, @@ -711,11 +711,11 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14159)}, readDouble, nil, }, - { - "float32/reflection path (truncate)", myfloat32(3.14), &DecodeContext{truncate: true}, - &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, - nil, - }, + // { + // "float32/reflection path (truncate)", myfloat32(3.14), &DecodeContext{truncate: true}, + // &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, + // nil, + // }, { "float32/reflection path (no truncate)", myfloat32(0), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, @@ -803,7 +803,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "wrong kind (non-string key)", map[bool]interface{}{}, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{}, readElement, fmt.Errorf("unsupported key type: %T", false), @@ -819,7 +819,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", map[string]string{}, - &DecodeContext{Registry: newTestRegistryBuilder().Build()}, + newTestRegistryBuilder().Build(), &valueReaderWriter{}, readDocument, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -827,7 +827,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "ReadElement Error", make(map[string]interface{}), - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("re error"), ErrAfter: readElement}, readElement, errors.New("re error"), @@ -905,7 +905,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", [1]string{}, - &DecodeContext{Registry: newTestRegistryBuilder().Build()}, + newTestRegistryBuilder().Build(), &valueReaderWriter{BSONType: TypeArray}, readArray, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -913,7 +913,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "ReadValue Error", [1]string{}, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("rv error"), ErrAfter: readValue, BSONType: TypeArray}, readValue, errors.New("rv error"), @@ -921,7 +921,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "DecodeValue Error", [1]string{}, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{BSONType: TypeArray}, readValue, &DecodeError{keys: []string{"0"}, wrapped: errors.New("cannot decode array into a string type")}, @@ -999,7 +999,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", []string{}, - &DecodeContext{Registry: newTestRegistryBuilder().Build()}, + newTestRegistryBuilder().Build(), &valueReaderWriter{BSONType: TypeArray}, readArray, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -1007,7 +1007,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "ReadValue Error", []string{}, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("rv error"), ErrAfter: readValue, BSONType: TypeArray}, readValue, errors.New("rv error"), @@ -1015,7 +1015,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "DecodeValue Error", []string{}, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{BSONType: TypeArray}, readValue, &DecodeError{keys: []string{"0"}, wrapped: errors.New("cannot decode array into a string type")}, @@ -1561,7 +1561,7 @@ func TestDefaultValueDecoders(t *testing.T) { ValueDecoderError{Name: "pointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}}, }, { - "No Decoder", &wrong, &DecodeContext{Registry: buildDefaultRegistry()}, nil, nothing, + "No Decoder", &wrong, buildDefaultRegistry(), nil, nothing, ErrNoDecoder{Type: reflect.TypeOf(wrong)}, }, { @@ -2287,7 +2287,7 @@ func TestDefaultValueDecoders(t *testing.T) { Code: "var hello = 'world';", Scope: D{{"foo", nil}}, }, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{BSONType: TypeCodeWithScope, Err: errors.New("dd error"), ErrAfter: readElement}, readElement, errors.New("dd error"), @@ -2346,10 +2346,6 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run(tc.name, func(t *testing.T) { for _, rc := range tc.subtests { t.Run(rc.name, func(t *testing.T) { - var dc DecodeContext - if rc.dctx != nil { - dc = *rc.dctx - } llvrw := new(valueReaderWriter) if rc.llvrw != nil { llvrw = rc.llvrw @@ -2357,13 +2353,13 @@ func TestDefaultValueDecoders(t *testing.T) { llvrw.T = t // var got interface{} if rc.val == cansetreflectiontest { // We're doing a CanSet reflection test - err := tc.vd.DecodeValue(dc, llvrw, reflect.Value{}) + err := tc.vd.DecodeValue(rc.reg, llvrw, reflect.Value{}) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } val := reflect.New(reflect.TypeOf(rc.val)).Elem() - err = tc.vd.DecodeValue(dc, llvrw, val) + err = tc.vd.DecodeValue(rc.reg, llvrw, val) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } @@ -2375,13 +2371,13 @@ func TestDefaultValueDecoders(t *testing.T) { t.Fatalf("Error must be a DecodeValueError, but got a %T", rc.err) } - err := tc.vd.DecodeValue(dc, llvrw, reflect.Value{}) + err := tc.vd.DecodeValue(rc.reg, llvrw, reflect.Value{}) wanterr.Received = reflect.ValueOf(nil) if !assert.CompareErrors(err, wanterr) { t.Errorf("Errors do not match. got %v; want %v", err, wanterr) } - err = tc.vd.DecodeValue(dc, llvrw, reflect.ValueOf(int(12345))) + err = tc.vd.DecodeValue(rc.reg, llvrw, reflect.ValueOf(int(12345))) wanterr.Received = reflect.ValueOf(int(12345)) if !assert.CompareErrors(err, wanterr) { t.Errorf("Errors do not match. got %v; want %v", err, wanterr) @@ -2399,7 +2395,7 @@ func TestDefaultValueDecoders(t *testing.T) { panic(err) } }() - err := tc.vd.DecodeValue(dc, llvrw, val) + err := tc.vd.DecodeValue(rc.reg, llvrw, val) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } @@ -2420,7 +2416,7 @@ func TestDefaultValueDecoders(t *testing.T) { } t.Run("CodeWithScopeCodec/DecodeValue/success", func(t *testing.T) { - dc := DecodeContext{Registry: buildDefaultRegistry()} + reg := buildDefaultRegistry() b := bsoncore.BuildDocument(nil, bsoncore.AppendCodeWithScopeElement( nil, "foo", "var hello = 'world';", @@ -2438,7 +2434,7 @@ func TestDefaultValueDecoders(t *testing.T) { Scope: D{{"bar", nil}}, } val := reflect.New(tCodeWithScope).Elem() - err = codeWithScopeDecodeValue(dc, vr, val) + err = codeWithScopeDecodeValue(reg, vr, val) noerr(t, err) got := val.Interface().(CodeWithScope) @@ -2447,25 +2443,23 @@ func TestDefaultValueDecoders(t *testing.T) { } }) t.Run("ValueUnmarshalerDecodeValue/UnmarshalBSONValue error", func(t *testing.T) { - var dc DecodeContext llvrw := &valueReaderWriter{BSONType: TypeString, Return: string("hello, world!")} llvrw.T = t want := errors.New("ubsonv error") valUnmarshaler := &testValueUnmarshaler{err: want} - got := valueUnmarshalerDecodeValue(dc, llvrw, reflect.ValueOf(valUnmarshaler)) + got := valueUnmarshalerDecodeValue(nil, llvrw, reflect.ValueOf(valUnmarshaler)) if !assert.CompareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } }) t.Run("ValueUnmarshalerDecodeValue/Unaddressable value", func(t *testing.T) { - var dc DecodeContext llvrw := &valueReaderWriter{BSONType: TypeString, Return: string("hello, world!")} llvrw.T = t val := reflect.ValueOf(testValueUnmarshaler{}) want := ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} - got := valueUnmarshalerDecodeValue(dc, llvrw, val) + got := valueUnmarshalerDecodeValue(nil, llvrw, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } @@ -2488,8 +2482,8 @@ func TestDefaultValueDecoders(t *testing.T) { var val [1]string want := fmt.Errorf("more elements returned in array than can fit inside %T, got 2 elements", val) - dc := DecodeContext{Registry: buildDefaultRegistry()} - got := arrayDecodeValue(dc, vr, reflect.ValueOf(val)) + reg := buildDefaultRegistry() + got := arrayDecodeValue(reg, vr, reflect.ValueOf(val)) if !assert.CompareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } @@ -3137,7 +3131,7 @@ func TestDefaultValueDecoders(t *testing.T) { noerr(t, err) gotVal := reflect.New(reflect.TypeOf(tc.value)).Elem() - err = dec.DecodeValue(DecodeContext{Registry: reg}, vr, gotVal) + err = dec.DecodeValue(reg, vr, gotVal) noerr(t, err) got := gotVal.Interface() @@ -3186,7 +3180,7 @@ func TestDefaultValueDecoders(t *testing.T) { noerr(t, err) gotVal := reflect.New(reflect.TypeOf(tc.value)).Elem() - err = dec.DecodeValue(DecodeContext{Registry: reg}, vr, gotVal) + err = dec.DecodeValue(reg, vr, gotVal) if err == nil || !strings.Contains(err.Error(), tc.err.Error()) { t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) } @@ -3310,9 +3304,9 @@ func TestDefaultValueDecoders(t *testing.T) { t.Skip() } val := reflect.New(tEmpty).Elem() - dc := DecodeContext{Registry: newTestRegistryBuilder().Build()} + reg := newTestRegistryBuilder().Build() want := ErrNoTypeMapEntry{Type: tc.bsontype} - got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, val) + got := defaultEmptyInterfaceCodec.DecodeValue(reg, llvr, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3326,11 +3320,8 @@ func TestDefaultValueDecoders(t *testing.T) { reg := newTestRegistryBuilder(). RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). Build() - dc := DecodeContext{ - Registry: reg, - } want := ErrNoDecoder{Type: reflect.TypeOf(tc.val)} - got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, val) + got := defaultEmptyInterfaceCodec.DecodeValue(reg, llvr, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3346,10 +3337,7 @@ func TestDefaultValueDecoders(t *testing.T) { RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). Build() - dc := DecodeContext{ - Registry: reg, - } - got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, reflect.New(tEmpty).Elem()) + got := defaultEmptyInterfaceCodec.DecodeValue(reg, llvr, reflect.New(tEmpty).Elem()) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3362,11 +3350,8 @@ func TestDefaultValueDecoders(t *testing.T) { RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). Build() - dc := DecodeContext{ - Registry: reg, - } got := reflect.New(tEmpty).Elem() - err := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, got) + err := defaultEmptyInterfaceCodec.DecodeValue(reg, llvr, got) noerr(t, err) if !cmp.Equal(got.Interface(), want, cmp.Comparer(compareDecimal128)) { t.Errorf("Did not receive expected value. got %v; want %v", got.Interface(), want) @@ -3379,7 +3364,7 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run("non-interface{}", func(t *testing.T) { val := uint64(1234567890) want := ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.ValueOf(val)} - got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{}, nil, reflect.ValueOf(val)) + got := defaultEmptyInterfaceCodec.DecodeValue(nil, nil, reflect.ValueOf(val)) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3388,7 +3373,7 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run("nil *interface{}", func(t *testing.T) { var val interface{} want := ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.ValueOf(val)} - got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{}, nil, reflect.ValueOf(val)) + got := defaultEmptyInterfaceCodec.DecodeValue(nil, nil, reflect.ValueOf(val)) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3398,7 +3383,7 @@ func TestDefaultValueDecoders(t *testing.T) { llvr := &valueReaderWriter{BSONType: TypeDouble} want := ErrNoTypeMapEntry{Type: TypeDouble} val := reflect.New(tEmpty).Elem() - got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: newTestRegistryBuilder().Build()}, llvr, val) + got := defaultEmptyInterfaceCodec.DecodeValue(newTestRegistryBuilder().Build(), llvr, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3409,7 +3394,7 @@ func TestDefaultValueDecoders(t *testing.T) { want := D{{"pi", 3.14159}} var got interface{} val := reflect.ValueOf(&got).Elem() - err := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: buildDefaultRegistry()}, vr, val) + err := defaultEmptyInterfaceCodec.DecodeValue(buildDefaultRegistry(), vr, val) noerr(t, err) if !cmp.Equal(got, want) { t.Errorf("Did not get correct result. got %v; want %v", got, want) @@ -3456,7 +3441,7 @@ func TestDefaultValueDecoders(t *testing.T) { vr := NewValueReader(doc) val := reflect.ValueOf(&got).Elem() - err := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: tc.registry}, vr, val) + err := defaultEmptyInterfaceCodec.DecodeValue(tc.registry, vr, val) noerr(t, err) if !cmp.Equal(got, want) { t.Fatalf("got %v, want %v", got, want) @@ -3491,7 +3476,7 @@ func TestDefaultValueDecoders(t *testing.T) { var got D vr := NewValueReader(doc) val := reflect.ValueOf(&got).Elem() - err := (&sliceCodec{}).DecodeValue(DecodeContext{Registry: reg}, vr, val) + err := (&sliceCodec{}).DecodeValue(reg, vr, val) noerr(t, err) if !cmp.Equal(got, want) { t.Fatalf("got %v, want %v", got, want) @@ -3660,16 +3645,16 @@ func TestDefaultValueDecoders(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - dc := DecodeContext{Registry: tc.registry} - if dc.Registry == nil { - dc.Registry = buildDefaultRegistry() + reg := tc.registry + if reg == nil { + reg = buildDefaultRegistry() } var val reflect.Value if rtype := reflect.TypeOf(tc.val); rtype != nil { val = reflect.New(rtype).Elem() } - err := tc.decoder.DecodeValue(dc, tc.vr, val) + err := tc.decoder.DecodeValue(reg, tc.vr, val) assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err) }) } @@ -3681,10 +3666,10 @@ func TestDefaultValueDecoders(t *testing.T) { type inner struct{ Bar string } type outer struct{ Foo inner } - dc := DecodeContext{Registry: buildDefaultRegistry()} + reg := buildDefaultRegistry() vr := NewValueReader(outerBytes) val := reflect.New(reflect.TypeOf(outer{})).Elem() - err := defaultTestStructCodec.DecodeValue(dc, vr, val) + err := defaultTestStructCodec.DecodeValue(reg, vr, val) var decodeErr *DecodeError assert.True(t, errors.As(err, &decodeErr), "expected DecodeError, got %v of type %T", err, err) @@ -3714,10 +3699,10 @@ func TestDefaultValueDecoders(t *testing.T) { registerDefaultDecoders(rb) rb.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))) - dc := DecodeContext{Registry: rb.Build()} + reg := rb.Build() vr := NewValueReader(docBytes) val := reflect.New(tD).Elem() - err := dDecodeValue(dc, vr, val) + err := dDecodeValue(reg, vr, val) assert.Nil(t, err, "DDecodeValue error: %v", err) want := D{ @@ -3733,10 +3718,10 @@ func TestDefaultValueDecoders(t *testing.T) { ) type myMap map[string]mybool - dc := DecodeContext{Registry: buildDefaultRegistry()} + reg := buildDefaultRegistry() vr := NewValueReader(docBytes) val := reflect.New(reflect.TypeOf(myMap{})).Elem() - err := (&mapCodec{}).DecodeValue(dc, vr, val) + err := (&mapCodec{}).DecodeValue(reg, vr, val) assert.Nil(t, err, "DecodeValue error: %v", err) want := myMap{ diff --git a/bson/encoder.go b/bson/encoder.go index 42c1900006..dba91b7424 100644 --- a/bson/encoder.go +++ b/bson/encoder.go @@ -71,7 +71,7 @@ func (e *Encoder) SetRegistry(r *Registry) { // the marshaled BSON when the "inline" struct tag option is set. func (e *Encoder) ErrorOnInlineDuplicates() { t := reflect.TypeOf((*structCodec)(nil)) - if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { v[i].(*structCodec).overwriteDuplicatedInlinedFields = false } @@ -93,7 +93,7 @@ func (e *Encoder) IntMinSize() { // } // } t := reflect.TypeOf((*intCodec)(nil)) - if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { v[i].(*intCodec).encodeToMinSize = true } @@ -104,7 +104,7 @@ func (e *Encoder) IntMinSize() { // strings using fmt.Sprint instead of the default string conversion logic. func (e *Encoder) StringifyMapKeysWithFmt() { t := reflect.TypeOf((*mapCodec)(nil)) - if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { v[i].(*mapCodec).encodeKeysWithStringer = true } @@ -115,7 +115,7 @@ func (e *Encoder) StringifyMapKeysWithFmt() { // null. func (e *Encoder) NilMapAsEmpty() { t := reflect.TypeOf((*mapCodec)(nil)) - if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { v[i].(*mapCodec).encodeNilAsEmpty = true } @@ -126,7 +126,7 @@ func (e *Encoder) NilMapAsEmpty() { // null. func (e *Encoder) NilSliceAsEmpty() { t := reflect.TypeOf((*sliceCodec)(nil)) - if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { v[i].(*sliceCodec).encodeNilAsEmpty = true } @@ -142,7 +142,7 @@ func (e *Encoder) NilByteSliceAsEmpty() { // } // } t := reflect.TypeOf((*byteSliceCodec)(nil)) - if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { v[i].(*byteSliceCodec).encodeNilAsEmpty = true } @@ -159,7 +159,7 @@ func (e *Encoder) NilByteSliceAsEmpty() { // zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty. func (e *Encoder) OmitZeroStruct() { t := reflect.TypeOf((*structCodec)(nil)) - if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { v[i].(*structCodec).encodeOmitDefaultStruct = true } @@ -170,7 +170,7 @@ func (e *Encoder) OmitZeroStruct() { // struct tag is not specified. func (e *Encoder) UseJSONStructTags() { t := reflect.TypeOf((*structCodec)(nil)) - if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { v[i].(*structCodec).useJSONStructTags = true } diff --git a/bson/mgocompat/bson_test.go b/bson/mgocompat/bson_test.go index 7571abb19f..8a1ff811fd 100644 --- a/bson/mgocompat/bson_test.go +++ b/bson/mgocompat/bson_test.go @@ -479,7 +479,7 @@ func (t *prefixPtr) SetBSON(raw bson.RawValue) error { return err } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } @@ -506,7 +506,7 @@ func (t *prefixVal) SetBSON(raw bson.RawValue) error { return err } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } @@ -930,7 +930,7 @@ func (o *setterType) SetBSON(raw bson.RawValue) error { raw.Type = bson.TypeEmbeddedDocument } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } @@ -1289,7 +1289,7 @@ func (s *getterSetterD) SetBSON(raw bson.RawValue) error { raw.Type = bson.TypeEmbeddedDocument } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } @@ -1315,7 +1315,7 @@ func (i *getterSetterInt) SetBSON(raw bson.RawValue) error { raw.Type = bson.TypeEmbeddedDocument } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } @@ -1337,7 +1337,7 @@ func (s *ifaceSlice) SetBSON(raw bson.RawValue) error { return err } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } diff --git a/bson/primitive_codecs_test.go b/bson/primitive_codecs_test.go index a38113b72d..18dcfb71b3 100644 --- a/bson/primitive_codecs_test.go +++ b/bson/primitive_codecs_test.go @@ -468,7 +468,7 @@ func TestPrimitiveValueDecoders(t *testing.T) { type subtest struct { name string val interface{} - dctx *DecodeContext + reg *Registry llvrw *valueReaderWriter invoke invoked err error @@ -563,23 +563,19 @@ func TestPrimitiveValueDecoders(t *testing.T) { t.Run(tc.name, func(t *testing.T) { for _, rc := range tc.subtests { t.Run(rc.name, func(t *testing.T) { - var dc DecodeContext - if rc.dctx != nil { - dc = *rc.dctx - } llvrw := new(valueReaderWriter) if rc.llvrw != nil { llvrw = rc.llvrw } llvrw.T = t if rc.val == cansetreflectiontest { // We're doing a CanSet reflection test - err := tc.vd.DecodeValue(dc, llvrw, reflect.Value{}) + err := tc.vd.DecodeValue(rc.reg, llvrw, reflect.Value{}) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } val := reflect.New(reflect.TypeOf(rc.val)).Elem() - err = tc.vd.DecodeValue(dc, llvrw, val) + err = tc.vd.DecodeValue(rc.reg, llvrw, val) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } @@ -596,7 +592,7 @@ func TestPrimitiveValueDecoders(t *testing.T) { panic(err) } }() - err := tc.vd.DecodeValue(dc, llvrw, val) + err := tc.vd.DecodeValue(rc.reg, llvrw, val) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } diff --git a/bson/raw_value.go b/bson/raw_value.go index a32b82e41d..f119cbd9fe 100644 --- a/bson/raw_value.go +++ b/bson/raw_value.go @@ -81,13 +81,13 @@ func (rv RawValue) UnmarshalWithRegistry(r *Registry, val interface{}) error { if err != nil { return err } - return dec.DecodeValue(DecodeContext{Registry: r}, vr, rval) + return dec.DecodeValue(r, vr, rval) } // UnmarshalWithContext performs the same unmarshalling as Unmarshal but uses the provided DecodeContext // instead of the one attached or the default registry. -func (rv RawValue) UnmarshalWithContext(dc *DecodeContext, val interface{}) error { - if dc == nil { +func (rv RawValue) UnmarshalWithContext(reg *Registry, val interface{}) error { + if reg == nil { return ErrNilContext } @@ -97,11 +97,11 @@ func (rv RawValue) UnmarshalWithContext(dc *DecodeContext, val interface{}) erro return fmt.Errorf("argument to Unmarshal* must be a pointer to a type, but got %v", rval) } rval = rval.Elem() - dec, err := dc.LookupDecoder(rval.Type()) + dec, err := reg.LookupDecoder(rval.Type()) if err != nil { return err } - return dec.DecodeValue(*dc, vr, rval) + return dec.DecodeValue(reg, vr, rval) } func convertFromCoreValue(v bsoncore.Value) RawValue { diff --git a/bson/raw_value_test.go b/bson/raw_value_test.go index 18598ebe8f..aa2b9a0eb6 100644 --- a/bson/raw_value_test.go +++ b/bson/raw_value_test.go @@ -114,11 +114,11 @@ func TestRawValue(t *testing.T) { t.Run("Returns lookup error", func(t *testing.T) { t.Parallel() - dc := DecodeContext{Registry: newTestRegistryBuilder().Build()} + reg := newTestRegistryBuilder().Build() var val RawValue var s string want := ErrNoDecoder{Type: reflect.TypeOf(s)} - got := val.UnmarshalWithContext(&dc, &s) + got := val.UnmarshalWithContext(reg, &s) if !assert.CompareErrors(got, want) { t.Errorf("Expected errors to match. got %v; want %v", got, want) } @@ -126,11 +126,11 @@ func TestRawValue(t *testing.T) { t.Run("Returns DecodeValue error", func(t *testing.T) { t.Parallel() - dc := DecodeContext{Registry: NewRegistryBuilder().Build()} + reg := NewRegistryBuilder().Build() val := RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, 3.14159)} var s string want := fmt.Errorf("cannot decode %v into a string type", TypeDouble) - got := val.UnmarshalWithContext(&dc, &s) + got := val.UnmarshalWithContext(reg, &s) if !assert.CompareErrors(got, want) { t.Errorf("Expected errors to match. got %v; want %v", got, want) } @@ -138,11 +138,11 @@ func TestRawValue(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - dc := DecodeContext{Registry: NewRegistryBuilder().Build()} + reg := NewRegistryBuilder().Build() want := float64(3.14159) val := RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, want)} var got float64 - err := val.UnmarshalWithContext(&dc, &got) + err := val.UnmarshalWithContext(reg, &got) noerr(t, err) if got != want { t.Errorf("Expected results to match. got %g; want %g", got, want) diff --git a/bson/registry.go b/bson/registry.go index 41a652fa31..fa63a4c7eb 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -222,8 +222,7 @@ func (rb *RegistryBuilder) Build() *Registry { interfaceDecoders: make([]interfaceValueDecoder, 0, len(rb.interfaceDecoders)), typeMap: make(map[Type]reflect.Type), - encoderTypeMap: make(map[reflect.Type][]ValueEncoder), - decoderTypeMap: make(map[reflect.Type][]ValueDecoder), + codecTypeMap: make(map[reflect.Type][]interface{}), } encoderCache := make(map[reflect.Value]ValueEncoder) @@ -234,7 +233,7 @@ func (rb *RegistryBuilder) Build() *Registry { encoder := encFac() encoderCache[reflect.ValueOf(encFac)] = encoder t := reflect.ValueOf(encoder).Type() - r.encoderTypeMap[t] = append(r.encoderTypeMap[t], encoder) + r.codecTypeMap[t] = append(r.codecTypeMap[t], encoder) return encoder } for k, v := range rb.typeEncoders { @@ -261,7 +260,7 @@ func (rb *RegistryBuilder) Build() *Registry { decoder := decFac() decoderCache[reflect.ValueOf(decFac)] = decoder t := reflect.ValueOf(decoder).Type() - r.decoderTypeMap[t] = append(r.decoderTypeMap[t], decoder) + r.codecTypeMap[t] = append(r.codecTypeMap[t], decoder) return decoder } for k, v := range rb.typeDecoders { @@ -329,8 +328,7 @@ type Registry struct { kindDecoders [reflect.UnsafePointer + 1]ValueDecoder typeMap map[Type]reflect.Type - encoderTypeMap map[reflect.Type][]ValueEncoder - decoderTypeMap map[reflect.Type][]ValueDecoder + codecTypeMap map[reflect.Type][]interface{} } // LookupEncoder returns the first matching encoder in the Registry. It uses the following lookup diff --git a/bson/registry_test.go b/bson/registry_test.go index c321f48d88..fd66e0cc84 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -93,8 +93,8 @@ func TestRegistryBuilder(t *testing.T) { if !cmp.Equal(c4, 1) { t.Errorf("ef4 is called %d time(s); expected 1", c4) } - codecs, ok := reg.encoderTypeMap[reflect.TypeOf((*fakeCodec)(nil))] - if !cmp.Equal(len(reg.encoderTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { + codecs, ok := reg.codecTypeMap[reflect.TypeOf((*fakeCodec)(nil))] + if !cmp.Equal(len(reg.codecTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { t.Errorf("codecs were not cached correctly") } got := make(map[reflect.Type]ValueEncoder) @@ -172,8 +172,8 @@ func TestRegistryBuilder(t *testing.T) { if !cmp.Equal(c4, 1) { t.Errorf("ef4 is called %d time(s); expected 1", c4) } - codecs, ok := reg.encoderTypeMap[reflect.TypeOf((*fakeCodec)(nil))] - if !cmp.Equal(len(reg.encoderTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { + codecs, ok := reg.codecTypeMap[reflect.TypeOf((*fakeCodec)(nil))] + if !cmp.Equal(len(reg.codecTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { t.Errorf("codecs were not cached correctly") } got := reg.typeEncoders @@ -246,8 +246,8 @@ func TestRegistryBuilder(t *testing.T) { if !cmp.Equal(c4, 1) { t.Errorf("ef4 is called %d time(s); expected 1", c4) } - codecs, ok := reg.encoderTypeMap[reflect.TypeOf((*fakeCodec)(nil))] - if !cmp.Equal(len(reg.encoderTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { + codecs, ok := reg.codecTypeMap[reflect.TypeOf((*fakeCodec)(nil))] + if !cmp.Equal(len(reg.codecTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { t.Errorf("codecs were not cached correctly") } got := reg.kindEncoders diff --git a/bson/string_codec_test.go b/bson/string_codec_test.go index 16d1727d4f..c764af97dc 100644 --- a/bson/string_codec_test.go +++ b/bson/string_codec_test.go @@ -20,20 +20,18 @@ func TestStringCodec(t *testing.T) { reader := &valueReaderWriter{BSONType: TypeObjectID, Return: oid} testCases := []struct { name string - dctx DecodeContext + codec *stringCodec err error result string }{ - {"default", DecodeContext{}, errors.New("cannot decode ObjectID as string if DecodeObjectIDAsHex is not set"), ""}, - {"true", DecodeContext{decodeObjectIDAsHex: true}, nil, oid.Hex()}, - {"false", DecodeContext{decodeObjectIDAsHex: false}, errors.New("cannot decode ObjectID as string if DecodeObjectIDAsHex is not set"), ""}, + {"default", &stringCodec{}, errors.New("cannot decode ObjectID as string if DecodeObjectIDAsHex is not set"), ""}, + {"true", &stringCodec{decodeObjectIDAsHex: true}, nil, oid.Hex()}, + {"false", &stringCodec{decodeObjectIDAsHex: false}, errors.New("cannot decode ObjectID as string if DecodeObjectIDAsHex is not set"), ""}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - stringCodec := &stringCodec{} - actual := reflect.New(reflect.TypeOf("")).Elem() - err := stringCodec.DecodeValue(tc.dctx, reader, actual) + err := tc.codec.DecodeValue(nil, reader, actual) if tc.err == nil { assert.NoError(t, err) } else { diff --git a/bson/time_codec_test.go b/bson/time_codec_test.go index 70f52906b2..fc32339602 100644 --- a/bson/time_codec_test.go +++ b/bson/time_codec_test.go @@ -31,7 +31,7 @@ func TestTimeCodec(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { actual := reflect.New(reflect.TypeOf(now)).Elem() - err := tc.codec.DecodeValue(DecodeContext{}, reader, actual) + err := tc.codec.DecodeValue(nil, reader, actual) assert.Nil(t, err, "TimeCodec.DecodeValue error: %v", err) actualTime := actual.Interface().(time.Time) @@ -65,7 +65,7 @@ func TestTimeCodec(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { actual := reflect.New(reflect.TypeOf(now)).Elem() - err := (&timeCodec{}).DecodeValue(DecodeContext{}, tc.reader, actual) + err := (&timeCodec{}).DecodeValue(nil, tc.reader, actual) assert.Nil(t, err, "DecodeValue error: %v", err) actualTime := actual.Interface().(time.Time) diff --git a/bson/truncation_test.go b/bson/truncation_test.go index a9aeea278b..8f3301e85d 100644 --- a/bson/truncation_test.go +++ b/bson/truncation_test.go @@ -39,12 +39,12 @@ func TestTruncation(t *testing.T) { assert.Nil(t, err) var output outputArgs - dc := DecodeContext{ - Registry: DefaultRegistry, - truncate: true, - } + // dc := DecodeContext{ + // Registry: DefaultRegistry, + // truncate: true, + // } - err = UnmarshalWithContext(dc, buf.Bytes(), &output) + err = UnmarshalWithContext(DefaultRegistry, buf.Bytes(), &output) assert.Nil(t, err) assert.Equal(t, inputName, output.Name) @@ -65,13 +65,13 @@ func TestTruncation(t *testing.T) { assert.Nil(t, err) var output outputArgs - dc := DecodeContext{ - Registry: DefaultRegistry, - truncate: false, - } + // dc := DecodeContext{ + // Registry: DefaultRegistry, + // truncate: false, + // } // case throws an error when truncation is disabled - err = UnmarshalWithContext(dc, buf.Bytes(), &output) + err = UnmarshalWithContext(DefaultRegistry, buf.Bytes(), &output) assert.NotNil(t, err) }) } diff --git a/bson/unmarshal.go b/bson/unmarshal.go index 7caadc5dbc..a02582577e 100644 --- a/bson/unmarshal.go +++ b/bson/unmarshal.go @@ -56,7 +56,7 @@ func Unmarshal(data []byte, val interface{}) error { // See [Decoder] for more examples. func UnmarshalWithRegistry(r *Registry, data []byte, val interface{}) error { vr := NewValueReader(data) - return unmarshalFromReader(DecodeContext{Registry: r}, vr, val) + return unmarshalFromReader(r, vr, val) } // UnmarshalWithContext parses the BSON-encoded data using DecodeContext dc and @@ -73,9 +73,9 @@ func UnmarshalWithRegistry(r *Registry, data []byte, val interface{}) error { // dec.DefaultDocumentM() // // See [Decoder] for more examples. -func UnmarshalWithContext(dc DecodeContext, data []byte, val interface{}) error { +func UnmarshalWithContext(reg *Registry, data []byte, val interface{}) error { vr := NewValueReader(data) - return unmarshalFromReader(dc, vr, val) + return unmarshalFromReader(reg, vr, val) } // UnmarshalValue parses the BSON value of type t with bson.DefaultRegistry and @@ -93,7 +93,7 @@ func UnmarshalValue(t Type, data []byte, val interface{}) error { // Go Driver 2.0. func UnmarshalValueWithRegistry(r *Registry, t Type, data []byte, val interface{}) error { vr := NewBSONValueReader(t, data) - return unmarshalFromReader(DecodeContext{Registry: r}, vr, val) + return unmarshalFromReader(r, vr, val) } // UnmarshalExtJSON parses the extended JSON-encoded data and stores the result @@ -126,7 +126,7 @@ func UnmarshalExtJSONWithRegistry(r *Registry, data []byte, canonical bool, val return err } - return unmarshalFromReader(DecodeContext{Registry: r}, ejvr, val) + return unmarshalFromReader(r, ejvr, val) } // UnmarshalExtJSONWithContext parses the extended JSON-encoded data using @@ -147,21 +147,21 @@ func UnmarshalExtJSONWithRegistry(r *Registry, data []byte, canonical bool, val // dec.DefaultDocumentM() // // See [Decoder] for more examples. -func UnmarshalExtJSONWithContext(dc DecodeContext, data []byte, canonical bool, val interface{}) error { +func UnmarshalExtJSONWithContext(reg *Registry, data []byte, canonical bool, val interface{}) error { ejvr, err := NewExtJSONValueReader(bytes.NewReader(data), canonical) if err != nil { return err } - return unmarshalFromReader(dc, ejvr, val) + return unmarshalFromReader(reg, ejvr, val) } -func unmarshalFromReader(dc DecodeContext, vr ValueReader, val interface{}) error { +func unmarshalFromReader(reg *Registry, vr ValueReader, val interface{}) error { dec := decPool.Get().(*Decoder) defer decPool.Put(dec) dec.Reset(vr) - dec.dc = dc + dec.reg = reg return dec.Decode(val) } diff --git a/bson/unmarshal_test.go b/bson/unmarshal_test.go index dc8102b02c..9748d8a6db 100644 --- a/bson/unmarshal_test.go +++ b/bson/unmarshal_test.go @@ -69,9 +69,8 @@ func TestUnmarshalWithContext(t *testing.T) { copy(data, tc.data) // Assert that unmarshaling the input data results in the expected value. - dc := DecodeContext{Registry: DefaultRegistry} got := reflect.New(tc.sType).Interface() - err := UnmarshalWithContext(dc, data, got) + err := UnmarshalWithContext(DefaultRegistry, data, got) noerr(t, err) assert.Equal(t, tc.want, got, "Did not unmarshal as expected.") @@ -199,8 +198,7 @@ func TestUnmarshalExtJSONWithContext(t *testing.T) { // Assert that unmarshaling the input data results in the expected value. got := reflect.New(tc.sType).Interface() - dc := DecodeContext{Registry: DefaultRegistry} - err := UnmarshalExtJSONWithContext(dc, data, true, got) + err := UnmarshalExtJSONWithContext(DefaultRegistry, data, true, got) noerr(t, err) assert.Equal(t, tc.want, got, "Did not unmarshal as expected.")