diff --git a/bson/array_codec.go b/bson/array_codec.go index 5b07f4acd4..76b9a059f7 100644 --- a/bson/array_codec.go +++ b/bson/array_codec.go @@ -12,24 +12,11 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) -// ArrayCodec is the Codec used for bsoncore.Array values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// ArrayCodec registered. -type ArrayCodec struct{} - -var defaultArrayCodec = NewArrayCodec() - -// NewArrayCodec returns an ArrayCodec. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// ArrayCodec registered. -func NewArrayCodec() *ArrayCodec { - return &ArrayCodec{} -} +// arrayCodec is the Codec used for bsoncore.Array values. +type arrayCodec struct{} // EncodeValue is the ValueEncoder for bsoncore.Array values. -func (ac *ArrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func (ac *arrayCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCoreArray { return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val} } @@ -39,7 +26,7 @@ func (ac *ArrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.V } // DecodeValue is the ValueDecoder for bsoncore.Array values. -func (ac *ArrayCodec) DecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +func (ac *arrayCodec) DecodeValue(_ DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tCoreArray { return ValueDecoderError{Name: "CoreArrayDecodeValue", Types: []reflect.Type{tCoreArray}, Received: val} } diff --git a/bson/bson_test.go b/bson/bson_test.go index dcfc1037d9..e54b4dc865 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -17,7 +17,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - "go.mongodb.org/mongo-driver/bson/bsonoptions" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" @@ -349,23 +348,21 @@ func TestMapCodec(t *testing.T) { strstr := stringerString("foo") mapObj := map[stringerString]int{strstr: 1} testCases := []struct { - name string - opts *bsonoptions.MapCodecOptions - key string + name string + codec *mapCodec + key string }{ - {"default", bsonoptions.MapCodec(), "foo"}, - {"true", bsonoptions.MapCodec().SetEncodeKeysWithStringer(true), "bar"}, - {"false", bsonoptions.MapCodec().SetEncodeKeysWithStringer(false), "foo"}, + {"default", &mapCodec{}, "foo"}, + {"true", &mapCodec{encodeKeysWithStringer: true}, "bar"}, + {"false", &mapCodec{encodeKeysWithStringer: false}, "foo"}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - mapCodec := NewMapCodec(tc.opts) - mapRegistry := NewRegistry() - mapRegistry.RegisterKindEncoder(reflect.Map, mapCodec) + mapRegistry := NewRegistryBuilder() + mapRegistry.RegisterKindEncoder(reflect.Map, func(*Registry) ValueEncoder { return tc.codec }) buf := new(bytes.Buffer) vw := NewValueWriter(buf) - enc := NewEncoder(vw) - enc.SetRegistry(mapRegistry) + enc := NewEncoderWithRegistry(mapRegistry.Build(), vw) err := enc.Encode(mapObj) assert.Nil(t, err, "Encode error: %v", err) str := buf.String() diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index 860a6b82af..5e910fca88 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -72,202 +72,33 @@ func (vde ValueDecoderError) Error() string { return fmt.Sprintf("%s can only decode valid and settable %s, but got %s", vde.Name, strings.Join(typeKinds, ", "), received) } -// EncodeContext is the contextual information required for a Codec to encode a -// value. -type EncodeContext struct { - *Registry - - // MinSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, - // uint, uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) - // that can represent the integer value. - // - // Deprecated: Use bson.Encoder.IntMinSize instead. - MinSize bool - - errorOnInlineDuplicates bool - stringifyMapKeysWithFmt bool - nilMapAsEmpty bool - nilSliceAsEmpty bool - nilByteSliceAsEmpty bool - omitZeroStruct bool - useJSONStructTags bool -} - -// ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in -// the marshaled BSON when the "inline" struct tag option is set. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.ErrorOnInlineDuplicates] instead. -func (ec *EncodeContext) ErrorOnInlineDuplicates() { - ec.errorOnInlineDuplicates = true -} - -// StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name -// strings using fmt.Sprintf() instead of the default string conversion logic. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.StringifyMapKeysWithFmt] instead. -func (ec *EncodeContext) StringifyMapKeysWithFmt() { - ec.stringifyMapKeysWithFmt = true -} - -// NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON -// null. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilMapAsEmpty] instead. -func (ec *EncodeContext) NilMapAsEmpty() { - ec.nilMapAsEmpty = true -} - -// NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON -// null. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilSliceAsEmpty] instead. -func (ec *EncodeContext) NilSliceAsEmpty() { - ec.nilSliceAsEmpty = true -} - -// NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values -// instead of BSON null. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilByteSliceAsEmpty] instead. -func (ec *EncodeContext) NilByteSliceAsEmpty() { - ec.nilByteSliceAsEmpty = true -} - -// OmitZeroStruct causes the Encoder to consider the zero value for a struct (e.g. MyStruct{}) -// as empty and omit it from the marshaled BSON when the "omitempty" struct tag option is set. -// -// Note that the Encoder only examines exported struct fields when determining if a struct is the -// zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.OmitZeroStruct] instead. -func (ec *EncodeContext) OmitZeroStruct() { - ec.omitZeroStruct = true -} - -// UseJSONStructTags causes the Encoder to fall back to using the "json" struct tag if a "bson" -// struct tag is not specified. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.UseJSONStructTags] instead. -func (ec *EncodeContext) UseJSONStructTags() { - ec.useJSONStructTags = true -} - -// DecodeContext is the contextual information required for a Codec to decode a -// value. -type DecodeContext struct { - *Registry - - // Truncate, if true, instructs decoders to to truncate the fractional part of BSON "double" - // values when attempting to unmarshal them into a Go integer (int, int8, int16, int32, int64, - // uint, uint8, uint16, uint32, or uint64) struct field. The truncation logic does not apply to - // BSON "decimal128" values. - // - // Deprecated: Use bson.Decoder.AllowTruncatingDoubles instead. - Truncate bool - - // Ancestor is the type of a containing document. This is mainly used to determine what type - // should be used when decoding an embedded document into an empty interface. For example, if - // Ancestor is a bson.M, BSON embedded document values being decoded into an empty interface - // will be decoded into a bson.M. - // - // Deprecated: Use bson.Decoder.DefaultDocumentM or bson.Decoder.DefaultDocumentD instead. - Ancestor reflect.Type - - // defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the - // usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If DocumentType is - // set to a type that a BSON document cannot be unmarshaled into (e.g. "string"), unmarshalling will result in an - // error. DocumentType overrides the Ancestor field. - defaultDocumentType reflect.Type - - binaryAsSlice bool - useJSONStructTags bool - useLocalTimeZone bool - zeroMaps bool - zeroStructs bool -} - -// BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or -// "Old" BSON binary subtype as a Go byte slice instead of a Binary. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.BinaryAsSlice] instead. -func (dc *DecodeContext) BinaryAsSlice() { - dc.binaryAsSlice = true -} - -// UseJSONStructTags causes the Decoder to fall back to using the "json" struct tag if a "bson" -// struct tag is not specified. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.UseJSONStructTags] instead. -func (dc *DecodeContext) UseJSONStructTags() { - dc.useJSONStructTags = true -} - -// UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead -// of the UTC timezone. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.UseLocalTimeZone] instead. -func (dc *DecodeContext) UseLocalTimeZone() { - dc.useLocalTimeZone = true -} - -// ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value -// passed to Decode before unmarshaling BSON documents into them. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.ZeroMaps] instead. -func (dc *DecodeContext) ZeroMaps() { - dc.zeroMaps = true -} - -// ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination -// value passed to Decode before unmarshaling BSON documents into them. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.ZeroStructs] instead. -func (dc *DecodeContext) ZeroStructs() { - dc.zeroStructs = true -} - -// DefaultDocumentM causes the Decoder to always unmarshal documents into the M type. This -// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.DefaultDocumentM] instead. -func (dc *DecodeContext) DefaultDocumentM() { - dc.defaultDocumentType = reflect.TypeOf(M{}) -} - -// DefaultDocumentD causes the Decoder to always unmarshal documents into the D type. This -// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.DefaultDocumentD] instead. -func (dc *DecodeContext) DefaultDocumentD() { - dc.defaultDocumentType = reflect.TypeOf(D{}) -} - -// ValueCodec is an interface for encoding and decoding a reflect.Value. -// values. -// -// Deprecated: Use [ValueEncoder] and [ValueDecoder] instead. -type ValueCodec interface { - ValueEncoder - ValueDecoder +// EncoderRegistry is an interface provides a ValueEncoder based on the given reflect.Type. +type EncoderRegistry interface { + LookupEncoder(reflect.Type) (ValueEncoder, error) } // ValueEncoder is the interface implemented by types that can encode a provided Go type to BSON. // The value to encode is provided as a reflect.Value and a bson.ValueWriter is used within the // EncodeValue method to actually create the BSON representation. For convenience, ValueEncoderFunc -// is provided to allow use of a function with the correct signature as a ValueEncoder. An -// EncodeContext instance is provided to allow implementations to lookup further ValueEncoders and -// to provide configuration information. +// is provided to allow use of a function with the correct signature as a ValueEncoder. A pointer +// to a Registry instance is provided to allow implementations to lookup further ValueEncoders. type ValueEncoder interface { - EncodeValue(EncodeContext, ValueWriter, reflect.Value) error + EncodeValue(EncoderRegistry, ValueWriter, reflect.Value) error } // ValueEncoderFunc is an adapter function that allows a function with the correct signature to be // used as a ValueEncoder. -type ValueEncoderFunc func(EncodeContext, ValueWriter, reflect.Value) error +type ValueEncoderFunc func(EncoderRegistry, ValueWriter, reflect.Value) error // EncodeValue implements the ValueEncoder interface. -func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return fn(ec, vw, val) +func (fn ValueEncoderFunc) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { + return fn(reg, vw, val) +} + +// DecoderRegistry is an interface provides a ValueDecoder based on the given reflect.Type. +type DecoderRegistry interface { + LookupDecoder(reflect.Type) (ValueDecoder, error) + LookupTypeMapEntry(Type) (reflect.Type, error) } // ValueDecoder is the interface implemented by types that can decode BSON to a provided Go type. @@ -276,28 +107,28 @@ func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val ref // ValueDecoder. A DecodeContext instance is provided and serves similar functionality to the // EncodeContext. type ValueDecoder interface { - DecodeValue(DecodeContext, ValueReader, reflect.Value) error + DecodeValue(DecoderRegistry, ValueReader, reflect.Value) error } // ValueDecoderFunc is an adapter function that allows a function with the correct signature to be // used as a ValueDecoder. -type ValueDecoderFunc func(DecodeContext, ValueReader, reflect.Value) error +type ValueDecoderFunc func(DecoderRegistry, ValueReader, reflect.Value) error // DecodeValue implements the ValueDecoder interface. -func (fn ValueDecoderFunc) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - return fn(dc, vr, val) +func (fn ValueDecoderFunc) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { + return fn(reg, vr, val) } // typeDecoder is the interface implemented by types that can handle the decoding of a value given its type. type typeDecoder interface { - decodeType(DecodeContext, ValueReader, reflect.Type) (reflect.Value, error) + decodeType(DecoderRegistry, ValueReader, reflect.Type) (reflect.Value, error) } // typeDecoderFunc is an adapter function that allows a function with the correct signature to be used as a typeDecoder. -type typeDecoderFunc func(DecodeContext, ValueReader, reflect.Type) (reflect.Value, error) +type typeDecoderFunc func(DecoderRegistry, ValueReader, reflect.Type) (reflect.Value, error) -func (fn typeDecoderFunc) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { - return fn(dc, vr, t) +func (fn typeDecoderFunc) decodeType(reg DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { + return fn(reg, vr, t) } // decodeAdapter allows two functions with the correct signatures to be used as both a ValueDecoder and typeDecoder. @@ -309,31 +140,13 @@ type decodeAdapter struct { var _ ValueDecoder = decodeAdapter{} var _ typeDecoder = decodeAdapter{} -// decodeTypeOrValue calls decoder.decodeType is decoder is a typeDecoder. Otherwise, it allocates a new element of type -// t and calls decoder.DecodeValue on it. -func decodeTypeOrValue(decoder ValueDecoder, dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { - td, _ := decoder.(typeDecoder) - return decodeTypeOrValueWithInfo(decoder, td, dc, vr, t, true) -} - -func decodeTypeOrValueWithInfo(vd ValueDecoder, td typeDecoder, dc DecodeContext, vr ValueReader, t reflect.Type, convert bool) (reflect.Value, error) { - if td != nil { - val, err := td.decodeType(dc, vr, t) - if err == nil && convert && val.Type() != t { - // This conversion step is necessary for slices and maps. If a user declares variables like: - // - // type myBool bool - // var m map[string]myBool - // - // and tries to decode BSON bytes into the map, the decoding will fail if this conversion is not present - // because we'll try to assign a value of type bool to one of type myBool. - val = val.Convert(t) - } - return val, err +func decodeTypeOrValueWithInfo(vd ValueDecoder, reg DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { + if td, _ := vd.(typeDecoder); td != nil { + return td.decodeType(reg, vr, t) } val := reflect.New(t).Elem() - err := vd.DecodeValue(dc, vr, val) + err := vd.DecodeValue(reg, vr, val) return val, err } diff --git a/bson/bsoncodec_test.go b/bson/bsoncodec_test.go index d1dc21a953..02db6ca003 100644 --- a/bson/bsoncodec_test.go +++ b/bson/bsoncodec_test.go @@ -7,40 +7,10 @@ package bson import ( - "fmt" "reflect" "testing" ) -func ExampleValueEncoder() { - var _ ValueEncoderFunc = func(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if val.Kind() != reflect.String { - return ValueEncoderError{Name: "StringEncodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} - } - - return vw.WriteString(val.String()) - } -} - -func ExampleValueDecoder() { - var _ ValueDecoderFunc = func(dc DecodeContext, vr ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Kind() != reflect.String { - return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} - } - - if vr.Type() != TypeString { - return fmt.Errorf("cannot decode %v into a string type", vr.Type()) - } - - str, err := vr.ReadString() - if err != nil { - return err - } - val.SetString(str) - return nil - } -} - type llCodec struct { t *testing.T decodeval interface{} @@ -48,7 +18,7 @@ type llCodec struct { err error } -func (llc *llCodec) EncodeValue(_ EncodeContext, _ ValueWriter, i interface{}) error { +func (llc *llCodec) EncodeValue(_ EncoderRegistry, _ ValueWriter, i interface{}) error { if llc.err != nil { return llc.err } @@ -57,7 +27,7 @@ func (llc *llCodec) EncodeValue(_ EncodeContext, _ ValueWriter, i interface{}) e return nil } -func (llc *llCodec) DecodeValue(_ DecodeContext, _ ValueReader, val reflect.Value) error { +func (llc *llCodec) DecodeValue(_ DecoderRegistry, _ ValueReader, val reflect.Value) error { if llc.err != nil { return llc.err } diff --git a/bson/bsonoptions/byte_slice_codec_options.go b/bson/bsonoptions/byte_slice_codec_options.go deleted file mode 100644 index 996bd17127..0000000000 --- a/bson/bsonoptions/byte_slice_codec_options.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// ByteSliceCodecOptions represents all possible options for byte slice encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type ByteSliceCodecOptions struct { - EncodeNilAsEmpty *bool // Specifies if a nil byte slice should encode as an empty binary instead of null. Defaults to false. -} - -// ByteSliceCodec creates a new *ByteSliceCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func ByteSliceCodec() *ByteSliceCodecOptions { - return &ByteSliceCodecOptions{} -} - -// SetEncodeNilAsEmpty specifies if a nil byte slice should encode as an empty binary instead of null. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilByteSliceAsEmpty] instead. -func (bs *ByteSliceCodecOptions) SetEncodeNilAsEmpty(b bool) *ByteSliceCodecOptions { - bs.EncodeNilAsEmpty = &b - return bs -} - -// MergeByteSliceCodecOptions combines the given *ByteSliceCodecOptions into a single *ByteSliceCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeByteSliceCodecOptions(opts ...*ByteSliceCodecOptions) *ByteSliceCodecOptions { - bs := ByteSliceCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.EncodeNilAsEmpty != nil { - bs.EncodeNilAsEmpty = opt.EncodeNilAsEmpty - } - } - - return bs -} diff --git a/bson/bsonoptions/doc.go b/bson/bsonoptions/doc.go deleted file mode 100644 index c40973c8d4..0000000000 --- a/bson/bsonoptions/doc.go +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2022-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -// Package bsonoptions defines the optional configurations for the BSON codecs. -package bsonoptions diff --git a/bson/bsonoptions/empty_interface_codec_options.go b/bson/bsonoptions/empty_interface_codec_options.go deleted file mode 100644 index f522c7e03f..0000000000 --- a/bson/bsonoptions/empty_interface_codec_options.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// EmptyInterfaceCodecOptions represents all possible options for interface{} encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type EmptyInterfaceCodecOptions struct { - DecodeBinaryAsSlice *bool // Specifies if Old and Generic type binarys should default to []slice instead of primitive.Binary. Defaults to false. -} - -// EmptyInterfaceCodec creates a new *EmptyInterfaceCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func EmptyInterfaceCodec() *EmptyInterfaceCodecOptions { - return &EmptyInterfaceCodecOptions{} -} - -// SetDecodeBinaryAsSlice specifies if Old and Generic type binarys should default to []slice instead of primitive.Binary. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.BinaryAsSlice] instead. -func (e *EmptyInterfaceCodecOptions) SetDecodeBinaryAsSlice(b bool) *EmptyInterfaceCodecOptions { - e.DecodeBinaryAsSlice = &b - return e -} - -// MergeEmptyInterfaceCodecOptions combines the given *EmptyInterfaceCodecOptions into a single *EmptyInterfaceCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeEmptyInterfaceCodecOptions(opts ...*EmptyInterfaceCodecOptions) *EmptyInterfaceCodecOptions { - e := EmptyInterfaceCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.DecodeBinaryAsSlice != nil { - e.DecodeBinaryAsSlice = opt.DecodeBinaryAsSlice - } - } - - return e -} diff --git a/bson/bsonoptions/map_codec_options.go b/bson/bsonoptions/map_codec_options.go deleted file mode 100644 index a7a7c1d980..0000000000 --- a/bson/bsonoptions/map_codec_options.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// MapCodecOptions represents all possible options for map encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type MapCodecOptions struct { - DecodeZerosMap *bool // Specifies if the map should be zeroed before decoding into it. Defaults to false. - EncodeNilAsEmpty *bool // Specifies if a nil map should encode as an empty document instead of null. Defaults to false. - // Specifies how keys should be handled. If false, the behavior matches encoding/json, where the encoding key type must - // either be a string, an integer type, or implement bsoncodec.KeyMarshaler and the decoding key type must either be a - // string, an integer type, or implement bsoncodec.KeyUnmarshaler. If true, keys are encoded with fmt.Sprint() and the - // encoding key type must be a string, an integer type, or a float. If true, the use of Stringer will override - // TextMarshaler/TextUnmarshaler. Defaults to false. - EncodeKeysWithStringer *bool -} - -// MapCodec creates a new *MapCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func MapCodec() *MapCodecOptions { - return &MapCodecOptions{} -} - -// SetDecodeZerosMap specifies if the map should be zeroed before decoding into it. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.ZeroMaps] instead. -func (t *MapCodecOptions) SetDecodeZerosMap(b bool) *MapCodecOptions { - t.DecodeZerosMap = &b - return t -} - -// SetEncodeNilAsEmpty specifies if a nil map should encode as an empty document instead of null. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilMapAsEmpty] instead. -func (t *MapCodecOptions) SetEncodeNilAsEmpty(b bool) *MapCodecOptions { - t.EncodeNilAsEmpty = &b - return t -} - -// SetEncodeKeysWithStringer specifies how keys should be handled. If false, the behavior matches encoding/json, where the -// encoding key type must either be a string, an integer type, or implement bsoncodec.KeyMarshaler and the decoding key -// type must either be a string, an integer type, or implement bsoncodec.KeyUnmarshaler. If true, keys are encoded with -// fmt.Sprint() and the encoding key type must be a string, an integer type, or a float. If true, the use of Stringer -// will override TextMarshaler/TextUnmarshaler. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.StringifyMapKeysWithFmt] instead. -func (t *MapCodecOptions) SetEncodeKeysWithStringer(b bool) *MapCodecOptions { - t.EncodeKeysWithStringer = &b - return t -} - -// MergeMapCodecOptions combines the given *MapCodecOptions into a single *MapCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeMapCodecOptions(opts ...*MapCodecOptions) *MapCodecOptions { - s := MapCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.DecodeZerosMap != nil { - s.DecodeZerosMap = opt.DecodeZerosMap - } - if opt.EncodeNilAsEmpty != nil { - s.EncodeNilAsEmpty = opt.EncodeNilAsEmpty - } - if opt.EncodeKeysWithStringer != nil { - s.EncodeKeysWithStringer = opt.EncodeKeysWithStringer - } - } - - return s -} diff --git a/bson/bsonoptions/slice_codec_options.go b/bson/bsonoptions/slice_codec_options.go deleted file mode 100644 index 3c1e4f35ba..0000000000 --- a/bson/bsonoptions/slice_codec_options.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// SliceCodecOptions represents all possible options for slice encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type SliceCodecOptions struct { - EncodeNilAsEmpty *bool // Specifies if a nil slice should encode as an empty array instead of null. Defaults to false. -} - -// SliceCodec creates a new *SliceCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func SliceCodec() *SliceCodecOptions { - return &SliceCodecOptions{} -} - -// SetEncodeNilAsEmpty specifies if a nil slice should encode as an empty array instead of null. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilSliceAsEmpty] instead. -func (s *SliceCodecOptions) SetEncodeNilAsEmpty(b bool) *SliceCodecOptions { - s.EncodeNilAsEmpty = &b - return s -} - -// MergeSliceCodecOptions combines the given *SliceCodecOptions into a single *SliceCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeSliceCodecOptions(opts ...*SliceCodecOptions) *SliceCodecOptions { - s := SliceCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.EncodeNilAsEmpty != nil { - s.EncodeNilAsEmpty = opt.EncodeNilAsEmpty - } - } - - return s -} diff --git a/bson/bsonoptions/string_codec_options.go b/bson/bsonoptions/string_codec_options.go deleted file mode 100644 index f8b76f996e..0000000000 --- a/bson/bsonoptions/string_codec_options.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -var defaultDecodeOIDAsHex = true - -// StringCodecOptions represents all possible options for string encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type StringCodecOptions struct { - DecodeObjectIDAsHex *bool // Specifies if we should decode ObjectID as the hex value. Defaults to true. -} - -// StringCodec creates a new *StringCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func StringCodec() *StringCodecOptions { - return &StringCodecOptions{} -} - -// SetDecodeObjectIDAsHex specifies if object IDs should be decoded as their hex representation. If false, a string made -// from the raw object ID bytes will be used. Defaults to true. -// -// Deprecated: Decoding object IDs as raw bytes will not be supported in Go Driver 2.0. -func (t *StringCodecOptions) SetDecodeObjectIDAsHex(b bool) *StringCodecOptions { - t.DecodeObjectIDAsHex = &b - return t -} - -// MergeStringCodecOptions combines the given *StringCodecOptions into a single *StringCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeStringCodecOptions(opts ...*StringCodecOptions) *StringCodecOptions { - s := &StringCodecOptions{&defaultDecodeOIDAsHex} - for _, opt := range opts { - if opt == nil { - continue - } - if opt.DecodeObjectIDAsHex != nil { - s.DecodeObjectIDAsHex = opt.DecodeObjectIDAsHex - } - } - - return s -} diff --git a/bson/bsonoptions/struct_codec_options.go b/bson/bsonoptions/struct_codec_options.go deleted file mode 100644 index 1cbfa32e8b..0000000000 --- a/bson/bsonoptions/struct_codec_options.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -var defaultOverwriteDuplicatedInlinedFields = true - -// StructCodecOptions represents all possible options for struct encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type StructCodecOptions struct { - DecodeZeroStruct *bool // Specifies if structs should be zeroed before decoding into them. Defaults to false. - DecodeDeepZeroInline *bool // Specifies if structs should be recursively zeroed when a inline value is decoded. Defaults to false. - EncodeOmitDefaultStruct *bool // Specifies if default structs should be considered empty by omitempty. Defaults to false. - AllowUnexportedFields *bool // Specifies if unexported fields should be marshaled/unmarshaled. Defaults to false. - OverwriteDuplicatedInlinedFields *bool // Specifies if fields in inlined structs can be overwritten by higher level struct fields with the same key. Defaults to true. -} - -// StructCodec creates a new *StructCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func StructCodec() *StructCodecOptions { - return &StructCodecOptions{} -} - -// SetDecodeZeroStruct specifies if structs should be zeroed before decoding into them. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.ZeroStructs] instead. -func (t *StructCodecOptions) SetDecodeZeroStruct(b bool) *StructCodecOptions { - t.DecodeZeroStruct = &b - return t -} - -// SetDecodeDeepZeroInline specifies if structs should be zeroed before decoding into them. Defaults to false. -// -// Deprecated: DecodeDeepZeroInline will not be supported in Go Driver 2.0. -func (t *StructCodecOptions) SetDecodeDeepZeroInline(b bool) *StructCodecOptions { - t.DecodeDeepZeroInline = &b - return t -} - -// SetEncodeOmitDefaultStruct specifies if default structs should be considered empty by omitempty. A default struct has all -// its values set to their default value. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.OmitZeroStruct] instead. -func (t *StructCodecOptions) SetEncodeOmitDefaultStruct(b bool) *StructCodecOptions { - t.EncodeOmitDefaultStruct = &b - return t -} - -// SetOverwriteDuplicatedInlinedFields specifies if inlined struct fields can be overwritten by higher level struct fields with the -// same bson key. When true and decoding, values will be written to the outermost struct with a matching key, and when -// encoding, keys will have the value of the top-most matching field. When false, decoding and encoding will error if -// there are duplicate keys after the struct is inlined. Defaults to true. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.ErrorOnInlineDuplicates] instead. -func (t *StructCodecOptions) SetOverwriteDuplicatedInlinedFields(b bool) *StructCodecOptions { - t.OverwriteDuplicatedInlinedFields = &b - return t -} - -// SetAllowUnexportedFields specifies if unexported fields should be marshaled/unmarshaled. Defaults to false. -// -// Deprecated: AllowUnexportedFields does not work on recent versions of Go and will not be -// supported in Go Driver 2.0. -func (t *StructCodecOptions) SetAllowUnexportedFields(b bool) *StructCodecOptions { - t.AllowUnexportedFields = &b - return t -} - -// MergeStructCodecOptions combines the given *StructCodecOptions into a single *StructCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeStructCodecOptions(opts ...*StructCodecOptions) *StructCodecOptions { - s := &StructCodecOptions{ - OverwriteDuplicatedInlinedFields: &defaultOverwriteDuplicatedInlinedFields, - } - for _, opt := range opts { - if opt == nil { - continue - } - - if opt.DecodeZeroStruct != nil { - s.DecodeZeroStruct = opt.DecodeZeroStruct - } - if opt.DecodeDeepZeroInline != nil { - s.DecodeDeepZeroInline = opt.DecodeDeepZeroInline - } - if opt.EncodeOmitDefaultStruct != nil { - s.EncodeOmitDefaultStruct = opt.EncodeOmitDefaultStruct - } - if opt.OverwriteDuplicatedInlinedFields != nil { - s.OverwriteDuplicatedInlinedFields = opt.OverwriteDuplicatedInlinedFields - } - if opt.AllowUnexportedFields != nil { - s.AllowUnexportedFields = opt.AllowUnexportedFields - } - } - - return s -} diff --git a/bson/bsonoptions/time_codec_options.go b/bson/bsonoptions/time_codec_options.go deleted file mode 100644 index 3f38433d22..0000000000 --- a/bson/bsonoptions/time_codec_options.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// TimeCodecOptions represents all possible options for time.Time encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type TimeCodecOptions struct { - UseLocalTimeZone *bool // Specifies if we should decode into the local time zone. Defaults to false. -} - -// TimeCodec creates a new *TimeCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func TimeCodec() *TimeCodecOptions { - return &TimeCodecOptions{} -} - -// SetUseLocalTimeZone specifies if we should decode into the local time zone. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.UseLocalTimeZone] instead. -func (t *TimeCodecOptions) SetUseLocalTimeZone(b bool) *TimeCodecOptions { - t.UseLocalTimeZone = &b - return t -} - -// MergeTimeCodecOptions combines the given *TimeCodecOptions into a single *TimeCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeTimeCodecOptions(opts ...*TimeCodecOptions) *TimeCodecOptions { - t := TimeCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.UseLocalTimeZone != nil { - t.UseLocalTimeZone = opt.UseLocalTimeZone - } - } - - return t -} diff --git a/bson/bsonoptions/uint_codec_options.go b/bson/bsonoptions/uint_codec_options.go deleted file mode 100644 index 5091e4d963..0000000000 --- a/bson/bsonoptions/uint_codec_options.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// UIntCodecOptions represents all possible options for uint encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type UIntCodecOptions struct { - EncodeToMinSize *bool // Specifies if all uints except uint64 should be decoded to minimum size bsontype. Defaults to false. -} - -// UIntCodec creates a new *UIntCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func UIntCodec() *UIntCodecOptions { - return &UIntCodecOptions{} -} - -// SetEncodeToMinSize specifies if all uints except uint64 should be decoded to minimum size bsontype. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.IntMinSize] instead. -func (u *UIntCodecOptions) SetEncodeToMinSize(b bool) *UIntCodecOptions { - u.EncodeToMinSize = &b - return u -} - -// MergeUIntCodecOptions combines the given *UIntCodecOptions into a single *UIntCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeUIntCodecOptions(opts ...*UIntCodecOptions) *UIntCodecOptions { - u := UIntCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.EncodeToMinSize != nil { - u.EncodeToMinSize = opt.EncodeToMinSize - } - } - - return u -} diff --git a/bson/byte_slice_codec.go b/bson/byte_slice_codec.go index 586c006467..779ae9ed71 100644 --- a/bson/byte_slice_codec.go +++ b/bson/byte_slice_codec.go @@ -9,56 +9,27 @@ package bson import ( "fmt" "reflect" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) -// ByteSliceCodec is the Codec used for []byte values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// ByteSliceCodec registered. -type ByteSliceCodec struct { - // EncodeNilAsEmpty causes EncodeValue to marshal nil Go byte slices as empty BSON binary values +// byteSliceCodec is the Codec used for []byte values. +type byteSliceCodec struct { + // encodeNilAsEmpty causes EncodeValue to marshal nil Go byte slices as empty BSON binary values // instead of BSON null. - // - // Deprecated: Use bson.Encoder.NilByteSliceAsEmpty instead. - EncodeNilAsEmpty bool -} - -var ( - defaultByteSliceCodec = NewByteSliceCodec() - - // Assert that defaultByteSliceCodec satisfies the typeDecoder interface, which allows it to be - // used by collection type decoders (e.g. map, slice, etc) to set individual values in a - // collection. - _ typeDecoder = defaultByteSliceCodec -) - -// NewByteSliceCodec returns a ByteSliceCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// ByteSliceCodec registered. -func NewByteSliceCodec(opts ...*bsonoptions.ByteSliceCodecOptions) *ByteSliceCodec { - byteSliceOpt := bsonoptions.MergeByteSliceCodecOptions(opts...) - codec := ByteSliceCodec{} - if byteSliceOpt.EncodeNilAsEmpty != nil { - codec.EncodeNilAsEmpty = *byteSliceOpt.EncodeNilAsEmpty - } - return &codec + encodeNilAsEmpty bool } // EncodeValue is the ValueEncoder for []byte. -func (bsc *ByteSliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (bsc *byteSliceCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tByteSlice { return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val} } - if val.IsNil() && !bsc.EncodeNilAsEmpty && !ec.nilByteSliceAsEmpty { + if val.IsNil() && !bsc.encodeNilAsEmpty { return vw.WriteNull() } return vw.WriteBinary(val.Interface().([]byte)) } -func (bsc *ByteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (bsc *byteSliceCodec) decodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tByteSlice { return emptyValue, ValueDecoderError{ Name: "ByteSliceDecodeValue", @@ -106,12 +77,12 @@ func (bsc *ByteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect } // DecodeValue is the ValueDecoder for []byte. -func (bsc *ByteSliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (bsc *byteSliceCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tByteSlice { return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []reflect.Type{tByteSlice}, Received: val} } - elem, err := bsc.decodeType(dc, vr, tByteSlice) + elem, err := bsc.decodeType(reg, vr, tByteSlice) if err != nil { return err } diff --git a/bson/cond_addr_codec.go b/bson/cond_addr_codec.go index fba139ff07..cd2727e2cc 100644 --- a/bson/cond_addr_codec.go +++ b/bson/cond_addr_codec.go @@ -18,19 +18,13 @@ type condAddrEncoder struct { var _ ValueEncoder = (*condAddrEncoder)(nil) -// newCondAddrEncoder returns an condAddrEncoder. -func newCondAddrEncoder(canAddrEnc, elseEnc ValueEncoder) *condAddrEncoder { - encoder := condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc} - return &encoder -} - // EncodeValue is the ValueEncoderFunc for a value that may be addressable. -func (cae *condAddrEncoder) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (cae *condAddrEncoder) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if val.CanAddr() { - return cae.canAddrEnc.EncodeValue(ec, vw, val) + return cae.canAddrEnc.EncodeValue(reg, vw, val) } if cae.elseEnc != nil { - return cae.elseEnc.EncodeValue(ec, vw, val) + return cae.elseEnc.EncodeValue(reg, vw, val) } return ErrNoEncoder{Type: val.Type()} } @@ -43,19 +37,13 @@ type condAddrDecoder struct { var _ ValueDecoder = (*condAddrDecoder)(nil) -// newCondAddrDecoder returns an CondAddrDecoder. -func newCondAddrDecoder(canAddrDec, elseDec ValueDecoder) *condAddrDecoder { - decoder := condAddrDecoder{canAddrDec: canAddrDec, elseDec: elseDec} - return &decoder -} - // DecodeValue is the ValueDecoderFunc for a value that may be addressable. -func (cad *condAddrDecoder) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (cad *condAddrDecoder) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if val.CanAddr() { - return cad.canAddrDec.DecodeValue(dc, vr, val) + return cad.canAddrDec.DecodeValue(reg, vr, val) } if cad.elseDec != nil { - return cad.elseDec.DecodeValue(dc, vr, val) + return cad.elseDec.DecodeValue(reg, vr, val) } return ErrNoDecoder{Type: val.Type()} } diff --git a/bson/cond_addr_codec_test.go b/bson/cond_addr_codec_test.go index c22c29fe72..6fd777ae77 100644 --- a/bson/cond_addr_codec_test.go +++ b/bson/cond_addr_codec_test.go @@ -22,15 +22,15 @@ func TestCondAddrCodec(t *testing.T) { t.Run("addressEncode", func(t *testing.T) { invoked := 0 - encode1 := ValueEncoderFunc(func(EncodeContext, ValueWriter, reflect.Value) error { + encode1 := ValueEncoderFunc(func(EncoderRegistry, ValueWriter, reflect.Value) error { invoked = 1 return nil }) - encode2 := ValueEncoderFunc(func(EncodeContext, ValueWriter, reflect.Value) error { + encode2 := ValueEncoderFunc(func(EncoderRegistry, ValueWriter, reflect.Value) error { invoked = 2 return nil }) - condEncoder := newCondAddrEncoder(encode1, encode2) + condEncoder := &condAddrEncoder{canAddrEnc: encode1, elseEnc: encode2} testCases := []struct { name string @@ -42,7 +42,7 @@ func TestCondAddrCodec(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - err := condEncoder.EncodeValue(EncodeContext{}, rw, tc.val) + err := condEncoder.EncodeValue(nil, rw, tc.val) assert.Nil(t, err, "CondAddrEncoder error: %v", err) assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked) @@ -50,23 +50,23 @@ func TestCondAddrCodec(t *testing.T) { } t.Run("error", func(t *testing.T) { - errEncoder := newCondAddrEncoder(encode1, nil) - err := errEncoder.EncodeValue(EncodeContext{}, rw, unaddressable) + errEncoder := &condAddrEncoder{canAddrEnc: encode1, elseEnc: nil} + err := errEncoder.EncodeValue(nil, rw, unaddressable) want := ErrNoEncoder{Type: unaddressable.Type()} assert.Equal(t, err, want, "expected error %v, got %v", want, err) }) }) t.Run("addressDecode", func(t *testing.T) { invoked := 0 - decode1 := ValueDecoderFunc(func(DecodeContext, ValueReader, reflect.Value) error { + decode1 := ValueDecoderFunc(func(DecoderRegistry, ValueReader, reflect.Value) error { invoked = 1 return nil }) - decode2 := ValueDecoderFunc(func(DecodeContext, ValueReader, reflect.Value) error { + decode2 := ValueDecoderFunc(func(DecoderRegistry, ValueReader, reflect.Value) error { invoked = 2 return nil }) - condDecoder := newCondAddrDecoder(decode1, decode2) + condDecoder := &condAddrDecoder{canAddrDec: decode1, elseDec: decode2} testCases := []struct { name string @@ -78,7 +78,7 @@ func TestCondAddrCodec(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - err := condDecoder.DecodeValue(DecodeContext{}, rw, tc.val) + err := condDecoder.DecodeValue(nil, rw, tc.val) assert.Nil(t, err, "CondAddrDecoder error: %v", err) assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked) @@ -86,8 +86,8 @@ func TestCondAddrCodec(t *testing.T) { } t.Run("error", func(t *testing.T) { - errDecoder := newCondAddrDecoder(decode1, nil) - err := errDecoder.DecodeValue(DecodeContext{}, rw, unaddressable) + errDecoder := &condAddrDecoder{canAddrDec: decode1, elseDec: nil} + err := errDecoder.DecodeValue(nil, rw, unaddressable) want := ErrNoDecoder{Type: unaddressable.Type()} assert.Equal(t, err, want, "expected error %v, got %v", want, err) }) diff --git a/bson/decoder.go b/bson/decoder.go index 6ea5ad97c1..7b0d0d68bd 100644 --- a/bson/decoder.go +++ b/bson/decoder.go @@ -10,44 +10,38 @@ import ( "errors" "fmt" "reflect" - "sync" ) // ErrDecodeToNil is the error returned when trying to decode to a nil value var ErrDecodeToNil = errors.New("cannot Decode to nil value") -// This pool is used to keep the allocations of Decoders down. This is only used for the Marshal* -// methods and is not consumable from outside of this package. The Decoders retrieved from this pool -// must have both Reset and SetRegistry called on them. -var decPool = sync.Pool{ - New: func() interface{} { - return new(Decoder) - }, +// ConfigurableDecoderRegistry refers a DecoderRegistry that is configurable with *RegistryOpt. +type ConfigurableDecoderRegistry interface { + DecoderRegistry + SetCodecOption(opt *RegistryOpt) error } // A Decoder reads and decodes BSON documents from a stream. It reads from a ValueReader as // the source of BSON data. type Decoder struct { - dc DecodeContext - vr ValueReader - - // We persist defaultDocumentM and defaultDocumentD on the Decoder to prevent overwriting from - // (*Decoder).SetContext. - defaultDocumentM bool - defaultDocumentD bool - - binaryAsSlice bool - useJSONStructTags bool - useLocalTimeZone bool - zeroMaps bool - zeroStructs bool + reg ConfigurableDecoderRegistry + vr ValueReader } -// NewDecoder returns a new decoder that uses the DefaultRegistry to read from vr. +// NewDecoder returns a new decoder that uses the default registry to read from vr. func NewDecoder(vr ValueReader) *Decoder { + r := NewRegistryBuilder().Build() + return &Decoder{ + reg: r, + vr: vr, + } +} + +// NewDecoderWithRegistry returns a new decoder that uses the given registry to read from vr. +func NewDecoderWithRegistry(r *Registry, vr ValueReader) *Decoder { return &Decoder{ - dc: DecodeContext{Registry: DefaultRegistry}, - vr: vr, + reg: r, + vr: vr, } } @@ -79,92 +73,15 @@ func (d *Decoder) Decode(val interface{}) error { default: return fmt.Errorf("argument to Decode must be a pointer or a map, but got %v", rval) } - decoder, err := d.dc.LookupDecoder(rval.Type()) + decoder, err := d.reg.LookupDecoder(rval.Type()) if err != nil { return err } - if d.defaultDocumentM { - d.dc.DefaultDocumentM() - } - if d.defaultDocumentD { - d.dc.DefaultDocumentD() - } - if d.binaryAsSlice { - d.dc.BinaryAsSlice() - } - if d.useJSONStructTags { - d.dc.UseJSONStructTags() - } - if d.useLocalTimeZone { - d.dc.UseLocalTimeZone() - } - if d.zeroMaps { - d.dc.ZeroMaps() - } - if d.zeroStructs { - d.dc.ZeroStructs() - } - - return decoder.DecodeValue(d.dc, d.vr, rval) -} - -// Reset will reset the state of the decoder, using the same *DecodeContext used in -// the original construction but using vr for reading. -func (d *Decoder) Reset(vr ValueReader) { - d.vr = vr -} - -// SetRegistry replaces the current registry of the decoder with r. -func (d *Decoder) SetRegistry(r *Registry) { - d.dc.Registry = r -} - -// DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This -// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". -func (d *Decoder) DefaultDocumentM() { - d.defaultDocumentM = true -} - -// DefaultDocumentD causes the Decoder to always unmarshal documents into the primitive.D type. This -// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". -func (d *Decoder) DefaultDocumentD() { - d.defaultDocumentD = true -} - -// AllowTruncatingDoubles causes the Decoder to truncate the fractional part of BSON "double" values -// when attempting to unmarshal them into a Go integer (int, int8, int16, int32, or int64) struct -// field. The truncation logic does not apply to BSON "decimal128" values. -func (d *Decoder) AllowTruncatingDoubles() { - d.dc.Truncate = true -} - -// BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or -// "Old" BSON binary subtype as a Go byte slice instead of a primitive.Binary. -func (d *Decoder) BinaryAsSlice() { - d.binaryAsSlice = true -} - -// UseJSONStructTags causes the Decoder to fall back to using the "json" struct tag if a "bson" -// struct tag is not specified. -func (d *Decoder) UseJSONStructTags() { - d.useJSONStructTags = true -} - -// UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead -// of the UTC timezone. -func (d *Decoder) UseLocalTimeZone() { - d.useLocalTimeZone = true -} - -// ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value -// passed to Decode before unmarshaling BSON documents into them. -func (d *Decoder) ZeroMaps() { - d.zeroMaps = true + return decoder.DecodeValue(d.reg, d.vr, rval) } -// ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination -// value passed to Decode before unmarshaling BSON documents into them. -func (d *Decoder) ZeroStructs() { - d.zeroStructs = true +// SetBehavior set the decoder behavior with *RegistryOpt. +func (d *Decoder) SetBehavior(opt *RegistryOpt) error { + return d.reg.SetCodecOption(opt) } diff --git a/bson/decoder_example_test.go b/bson/decoder_example_test.go index 3e17e98927..590756090d 100644 --- a/bson/decoder_example_test.go +++ b/bson/decoder_example_test.go @@ -48,7 +48,7 @@ func ExampleDecoder() { // Output: {Name:Cereal Rounds SKU:AB12345 Price:399} } -func ExampleDecoder_DefaultDocumentM() { +func ExampleDecoder_SetBehavior_defaultDocumentM() { // Marshal a BSON document that contains a city name and a nested document // with various city properties. doc := bson.D{ @@ -77,7 +77,10 @@ func ExampleDecoder_DefaultDocumentM() { // type if the decode destination has no type information. The Properties // field in the City struct will be decoded as a "M" (i.e. map) instead // of the default "D". - decoder.DefaultDocumentM() + err = decoder.SetBehavior(bson.DefaultDocumentM) + if err != nil { + panic(err) + } var res City err = decoder.Decode(&res) @@ -89,7 +92,7 @@ func ExampleDecoder_DefaultDocumentM() { // Output: {Name:New York Properties:map[elevation:10 population:8804190 state:NY]} } -func ExampleDecoder_UseJSONStructTags() { +func ExampleDecoder_SetBehavior_useJSONStructTags() { // Marshal a BSON document that contains the name, SKU, and price (in cents) // of a product. doc := bson.D{ @@ -114,7 +117,10 @@ func ExampleDecoder_UseJSONStructTags() { // Configure the Decoder to use "json" struct tags when decoding if "bson" // struct tags are not present. - decoder.UseJSONStructTags() + err = decoder.SetBehavior(bson.UseJSONStructTags) + if err != nil { + panic(err) + } var res Product err = decoder.Decode(&res) diff --git a/bson/decoder_test.go b/bson/decoder_test.go index 8fe8d07480..6ff2ad7545 100644 --- a/bson/decoder_test.go +++ b/bson/decoder_test.go @@ -29,10 +29,10 @@ func TestBasicDecode(t *testing.T) { got := reflect.New(tc.sType).Elem() vr := NewValueReader(tc.data) - reg := DefaultRegistry + reg := NewRegistryBuilder().Build() decoder, err := reg.LookupDecoder(reflect.TypeOf(got)) noerr(t, err) - err = decoder.DecodeValue(DecodeContext{Registry: reg}, vr, got) + err = decoder.DecodeValue(reg, vr, got) noerr(t, err) assert.Equal(t, tc.want, got.Addr().Interface(), "Results do not match.") }) @@ -183,34 +183,6 @@ func TestDecoderv2(t *testing.T) { want := foo{Item: "canvas", Qty: 4, Bonus: 2} assert.Equal(t, want, got, "Results do not match.") }) - t.Run("Reset", func(t *testing.T) { - t.Parallel() - - vr1, vr2 := NewValueReader([]byte{}), NewValueReader([]byte{}) - dec := NewDecoder(vr1) - if dec.vr != vr1 { - t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr1) - } - dec.Reset(vr2) - if dec.vr != vr2 { - t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr2) - } - }) - t.Run("SetRegistry", func(t *testing.T) { - t.Parallel() - - r1, r2 := DefaultRegistry, NewRegistry() - dc1 := DecodeContext{Registry: r1} - dc2 := DecodeContext{Registry: r2} - dec := NewDecoder(NewValueReader([]byte{})) - if !reflect.DeepEqual(dec.dc, dc1) { - t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc1) - } - dec.SetRegistry(r2) - if !reflect.DeepEqual(dec.dc, dc2) { - t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc2) - } - }) t.Run("DecodeToNil", func(t *testing.T) { t.Parallel() @@ -281,7 +253,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "AllowTruncatingDoubles", configure: func(dec *Decoder) { - dec.AllowTruncatingDoubles() + _ = dec.SetBehavior(AllowTruncatingDoubles) }, input: bsoncore.NewDocumentBuilder(). AppendDouble("myInt", 1.999). @@ -314,7 +286,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "BinaryAsSlice", configure: func(dec *Decoder) { - dec.BinaryAsSlice() + _ = dec.SetBehavior(BinaryAsSlice) }, input: bsoncore.NewDocumentBuilder(). AppendBinary("myBinary", TypeBinaryGeneric, []byte{}). @@ -327,7 +299,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "DefaultDocumentD nested", configure: func(dec *Decoder) { - dec.DefaultDocumentD() + _ = dec.SetBehavior(DefaultDocumentD) }, input: bsoncore.NewDocumentBuilder(). AppendDocument("myDocument", bsoncore.NewDocumentBuilder(). @@ -344,7 +316,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "DefaultDocumentM nested", configure: func(dec *Decoder) { - dec.DefaultDocumentM() + _ = dec.SetBehavior(DefaultDocumentM) }, input: bsoncore.NewDocumentBuilder(). AppendDocument("myDocument", bsoncore.NewDocumentBuilder(). @@ -361,7 +333,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "UseJSONStructTags", configure: func(dec *Decoder) { - dec.UseJSONStructTags() + _ = dec.SetBehavior(UseJSONStructTags) }, input: bsoncore.NewDocumentBuilder(). AppendString("jsonFieldName", "test value"). @@ -374,7 +346,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "UseLocalTimeZone", configure: func(dec *Decoder) { - dec.UseLocalTimeZone() + _ = dec.SetBehavior(UseLocalTimeZone) }, input: bsoncore.NewDocumentBuilder(). AppendDateTime("myTime", 1684349179939). @@ -387,7 +359,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "ZeroMaps", configure: func(dec *Decoder) { - dec.ZeroMaps() + _ = dec.SetBehavior(ZeroMaps) }, input: bsoncore.NewDocumentBuilder(). AppendDocument("myMap", bsoncore.NewDocumentBuilder(). @@ -404,7 +376,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "ZeroStructs", configure: func(dec *Decoder) { - dec.ZeroStructs() + _ = dec.SetBehavior(ZeroStructs) }, input: bsoncore.NewDocumentBuilder(). AppendString("myString", "test value"). @@ -445,10 +417,11 @@ func TestDecoderConfiguration(t *testing.T) { dec := NewDecoder(NewValueReader(input)) - dec.DefaultDocumentM() + err := dec.SetBehavior(DefaultDocumentM) + require.NoError(t, err, "SetBehavior error") var got interface{} - err := dec.Decode(&got) + err = dec.Decode(&got) require.NoError(t, err, "Decode error") want := M{ @@ -469,10 +442,11 @@ func TestDecoderConfiguration(t *testing.T) { dec := NewDecoder(NewValueReader(input)) - dec.DefaultDocumentD() + err := dec.SetBehavior(DefaultDocumentD) + require.NoError(t, err, "SetBehavior error") var got interface{} - err := dec.Decode(&got) + err = dec.Decode(&got) require.NoError(t, err, "Decode error") want := D{ diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index bc8c7b9344..bd49385c7b 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -10,18 +10,15 @@ import ( "encoding/json" "errors" "fmt" - "math" "net/url" "reflect" "strconv" - "time" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) var ( - defaultValueDecoders DefaultValueDecoders - errCannotTruncate = errors.New("float64 can only be truncated to a lower precision type when truncation is enabled") + errCannotTruncate = errors.New("float64 can only be truncated to a lower precision type when truncation is enabled") ) type decodeBinaryError struct { @@ -33,82 +30,59 @@ func (d decodeBinaryError) Error() string { return fmt.Sprintf("only binary values with subtype 0x00 or 0x02 can be decoded into %s, but got subtype %v", d.typeName, d.subtype) } -func newDefaultStructCodec() *StructCodec { - codec, err := NewStructCodec(DefaultStructTagParser) - if err != nil { - // This function is called from the codec registration path, so errors can't be propagated. If there's an error - // constructing the StructCodec, we panic to avoid losing it. - panic(fmt.Errorf("error creating default StructCodec: %w", err)) - } - return codec -} - -// DefaultValueDecoders is a namespace type for the default ValueDecoders used -// when creating a registry. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -type DefaultValueDecoders struct{} - -// RegisterDefaultDecoders will register the decoder methods attached to DefaultValueDecoders with -// the provided RegistryBuilder. +// registerDefaultDecoders will register the default decoder methods with the provided Registry. // // There is no support for decoding map[string]interface{} because there is no decoder for // interface{}, so users must either register this decoder themselves or use the // EmptyInterfaceDecoder available in the bson package. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) RegisterDefaultDecoders(rb *RegistryBuilder) { +func registerDefaultDecoders(rb *RegistryBuilder) { if rb == nil { panic(errors.New("argument to RegisterDefaultDecoders must not be nil")) } - intDecoder := decodeAdapter{dvd.IntDecodeValue, dvd.intDecodeType} - floatDecoder := decodeAdapter{dvd.FloatDecodeValue, dvd.floatDecodeType} - + numDecoder := func(*Registry) ValueDecoder { return &numCodec{} } rb. - RegisterTypeDecoder(tD, ValueDecoderFunc(dvd.DDecodeValue)). - RegisterTypeDecoder(tBinary, decodeAdapter{dvd.BinaryDecodeValue, dvd.binaryDecodeType}). - RegisterTypeDecoder(tUndefined, decodeAdapter{dvd.UndefinedDecodeValue, dvd.undefinedDecodeType}). - RegisterTypeDecoder(tDateTime, decodeAdapter{dvd.DateTimeDecodeValue, dvd.dateTimeDecodeType}). - RegisterTypeDecoder(tNull, decodeAdapter{dvd.NullDecodeValue, dvd.nullDecodeType}). - RegisterTypeDecoder(tRegex, decodeAdapter{dvd.RegexDecodeValue, dvd.regexDecodeType}). - RegisterTypeDecoder(tDBPointer, decodeAdapter{dvd.DBPointerDecodeValue, dvd.dBPointerDecodeType}). - RegisterTypeDecoder(tTimestamp, decodeAdapter{dvd.TimestampDecodeValue, dvd.timestampDecodeType}). - RegisterTypeDecoder(tMinKey, decodeAdapter{dvd.MinKeyDecodeValue, dvd.minKeyDecodeType}). - RegisterTypeDecoder(tMaxKey, decodeAdapter{dvd.MaxKeyDecodeValue, dvd.maxKeyDecodeType}). - RegisterTypeDecoder(tJavaScript, decodeAdapter{dvd.JavaScriptDecodeValue, dvd.javaScriptDecodeType}). - RegisterTypeDecoder(tSymbol, decodeAdapter{dvd.SymbolDecodeValue, dvd.symbolDecodeType}). - RegisterTypeDecoder(tByteSlice, defaultByteSliceCodec). - RegisterTypeDecoder(tTime, defaultTimeCodec). - RegisterTypeDecoder(tEmpty, defaultEmptyInterfaceCodec). - RegisterTypeDecoder(tCoreArray, defaultArrayCodec). - RegisterTypeDecoder(tOID, decodeAdapter{dvd.ObjectIDDecodeValue, dvd.objectIDDecodeType}). - RegisterTypeDecoder(tDecimal, decodeAdapter{dvd.Decimal128DecodeValue, dvd.decimal128DecodeType}). - RegisterTypeDecoder(tJSONNumber, decodeAdapter{dvd.JSONNumberDecodeValue, dvd.jsonNumberDecodeType}). - RegisterTypeDecoder(tURL, decodeAdapter{dvd.URLDecodeValue, dvd.urlDecodeType}). - RegisterTypeDecoder(tCoreDocument, ValueDecoderFunc(dvd.CoreDocumentDecodeValue)). - RegisterTypeDecoder(tCodeWithScope, decodeAdapter{dvd.CodeWithScopeDecodeValue, dvd.codeWithScopeDecodeType}). - RegisterDefaultDecoder(reflect.Bool, decodeAdapter{dvd.BooleanDecodeValue, dvd.booleanDecodeType}). - RegisterDefaultDecoder(reflect.Int, intDecoder). - RegisterDefaultDecoder(reflect.Int8, intDecoder). - RegisterDefaultDecoder(reflect.Int16, intDecoder). - RegisterDefaultDecoder(reflect.Int32, intDecoder). - RegisterDefaultDecoder(reflect.Int64, intDecoder). - RegisterDefaultDecoder(reflect.Uint, defaultUIntCodec). - RegisterDefaultDecoder(reflect.Uint8, defaultUIntCodec). - RegisterDefaultDecoder(reflect.Uint16, defaultUIntCodec). - RegisterDefaultDecoder(reflect.Uint32, defaultUIntCodec). - RegisterDefaultDecoder(reflect.Uint64, defaultUIntCodec). - RegisterDefaultDecoder(reflect.Float32, floatDecoder). - RegisterDefaultDecoder(reflect.Float64, floatDecoder). - RegisterDefaultDecoder(reflect.Array, ValueDecoderFunc(dvd.ArrayDecodeValue)). - RegisterDefaultDecoder(reflect.Map, defaultMapCodec). - RegisterDefaultDecoder(reflect.Slice, defaultSliceCodec). - RegisterDefaultDecoder(reflect.String, defaultStringCodec). - RegisterDefaultDecoder(reflect.Struct, newDefaultStructCodec()). - RegisterDefaultDecoder(reflect.Ptr, NewPointerCodec()). + RegisterTypeDecoder(tD, func(*Registry) ValueDecoder { return ValueDecoderFunc(dDecodeValue) }). + RegisterTypeDecoder(tBinary, func(*Registry) ValueDecoder { return &decodeAdapter{binaryDecodeValue, binaryDecodeType} }). + RegisterTypeDecoder(tUndefined, func(*Registry) ValueDecoder { return &decodeAdapter{undefinedDecodeValue, undefinedDecodeType} }). + RegisterTypeDecoder(tDateTime, func(*Registry) ValueDecoder { return &decodeAdapter{dateTimeDecodeValue, dateTimeDecodeType} }). + RegisterTypeDecoder(tNull, func(*Registry) ValueDecoder { return &decodeAdapter{nullDecodeValue, nullDecodeType} }). + RegisterTypeDecoder(tRegex, func(*Registry) ValueDecoder { return &decodeAdapter{regexDecodeValue, regexDecodeType} }). + RegisterTypeDecoder(tDBPointer, func(*Registry) ValueDecoder { return &decodeAdapter{dbPointerDecodeValue, dbPointerDecodeType} }). + RegisterTypeDecoder(tTimestamp, func(*Registry) ValueDecoder { return &decodeAdapter{timestampDecodeValue, timestampDecodeType} }). + RegisterTypeDecoder(tMinKey, func(*Registry) ValueDecoder { return &decodeAdapter{minKeyDecodeValue, minKeyDecodeType} }). + RegisterTypeDecoder(tMaxKey, func(*Registry) ValueDecoder { return &decodeAdapter{maxKeyDecodeValue, maxKeyDecodeType} }). + RegisterTypeDecoder(tJavaScript, func(*Registry) ValueDecoder { return &decodeAdapter{javaScriptDecodeValue, javaScriptDecodeType} }). + RegisterTypeDecoder(tSymbol, func(*Registry) ValueDecoder { return &decodeAdapter{symbolDecodeValue, symbolDecodeType} }). + RegisterTypeDecoder(tByteSlice, func(*Registry) ValueDecoder { return &byteSliceCodec{} }). + RegisterTypeDecoder(tTime, func(*Registry) ValueDecoder { return &timeCodec{} }). + RegisterTypeDecoder(tEmpty, func(*Registry) ValueDecoder { return &emptyInterfaceCodec{} }). + RegisterTypeDecoder(tCoreArray, func(*Registry) ValueDecoder { return &arrayCodec{} }). + RegisterTypeDecoder(tOID, func(*Registry) ValueDecoder { return &decodeAdapter{objectIDDecodeValue, objectIDDecodeType} }). + RegisterTypeDecoder(tDecimal, func(*Registry) ValueDecoder { return &decodeAdapter{decimal128DecodeValue, decimal128DecodeType} }). + RegisterTypeDecoder(tJSONNumber, func(*Registry) ValueDecoder { return &decodeAdapter{jsonNumberDecodeValue, jsonNumberDecodeType} }). + RegisterTypeDecoder(tURL, func(*Registry) ValueDecoder { return &decodeAdapter{urlDecodeValue, urlDecodeType} }). + RegisterTypeDecoder(tCoreDocument, func(*Registry) ValueDecoder { return ValueDecoderFunc(coreDocumentDecodeValue) }). + RegisterTypeDecoder(tCodeWithScope, func(*Registry) ValueDecoder { return &decodeAdapter{codeWithScopeDecodeValue, codeWithScopeDecodeType} }). + RegisterKindDecoder(reflect.Bool, func(*Registry) ValueDecoder { return &decodeAdapter{booleanDecodeValue, booleanDecodeType} }). + RegisterKindDecoder(reflect.Int, numDecoder). + RegisterKindDecoder(reflect.Int8, numDecoder). + RegisterKindDecoder(reflect.Int16, numDecoder). + RegisterKindDecoder(reflect.Int32, numDecoder). + RegisterKindDecoder(reflect.Int64, numDecoder). + RegisterKindDecoder(reflect.Uint, numDecoder). + RegisterKindDecoder(reflect.Uint8, numDecoder). + RegisterKindDecoder(reflect.Uint16, numDecoder). + RegisterKindDecoder(reflect.Uint32, numDecoder). + RegisterKindDecoder(reflect.Uint64, numDecoder). + RegisterKindDecoder(reflect.Float32, numDecoder). + RegisterKindDecoder(reflect.Float64, numDecoder). + RegisterKindDecoder(reflect.Array, func(*Registry) ValueDecoder { return ValueDecoderFunc(arrayDecodeValue) }). + RegisterKindDecoder(reflect.Map, func(*Registry) ValueDecoder { return &mapCodec{} }). + RegisterKindDecoder(reflect.Slice, func(*Registry) ValueDecoder { return &sliceCodec{} }). + RegisterKindDecoder(reflect.String, func(*Registry) ValueDecoder { return &stringCodec{} }). + RegisterKindDecoder(reflect.Struct, func(*Registry) ValueDecoder { return newStructCodec(nil) }). + RegisterKindDecoder(reflect.Ptr, func(*Registry) ValueDecoder { return &pointerCodec{} }). RegisterTypeMapEntry(TypeDouble, tFloat64). RegisterTypeMapEntry(TypeString, tString). RegisterTypeMapEntry(TypeArray, tA). @@ -130,22 +104,19 @@ func (dvd DefaultValueDecoders) RegisterDefaultDecoders(rb *RegistryBuilder) { RegisterTypeMapEntry(TypeMaxKey, tMaxKey). RegisterTypeMapEntry(Type(0), tD). RegisterTypeMapEntry(TypeEmbeddedDocument, tD). - RegisterHookDecoder(tValueUnmarshaler, ValueDecoderFunc(dvd.ValueUnmarshalerDecodeValue)). - RegisterHookDecoder(tUnmarshaler, ValueDecoderFunc(dvd.UnmarshalerDecodeValue)) + RegisterInterfaceDecoder(tValueUnmarshaler, func(*Registry) ValueDecoder { return ValueDecoderFunc(valueUnmarshalerDecodeValue) }). + RegisterInterfaceDecoder(tUnmarshaler, func(*Registry) ValueDecoder { return ValueDecoderFunc(unmarshalerDecodeValue) }) } -// DDecodeValue is the ValueDecoderFunc for D instances. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// dDecodeValue is the ValueDecoderFunc for D instances. +func dDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.IsValid() || !val.CanSet() || val.Type() != tD { return ValueDecoderError{Name: "DDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } switch vrType := vr.Type(); vrType { case Type(0), TypeEmbeddedDocument: - dc.Ancestor = tD + break case TypeNull: val.Set(reflect.Zero(val.Type())) return vr.ReadNull() @@ -158,11 +129,10 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr ValueReader, v return err } - decoder, err := dc.LookupDecoder(tEmpty) + decoder, err := reg.LookupDecoder(tEmpty) if err != nil { return err } - tEmptyTypeDecoder, _ := decoder.(typeDecoder) // Use the elements in the provided value if it's non nil. Otherwise, allocate a new D instance. var elems D @@ -181,8 +151,7 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr ValueReader, v return err } - // Pass false for convert because we don't need to call reflect.Value.Convert for tEmpty. - elem, err := decodeTypeOrValueWithInfo(decoder, tEmptyTypeDecoder, dc, elemVr, tEmpty, false) + elem, err := decodeTypeOrValueWithInfo(decoder, reg, elemVr, tD) if err != nil { return err } @@ -194,7 +163,7 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr ValueReader, v return nil } -func (dvd DefaultValueDecoders) booleanDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func booleanDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t.Kind() != reflect.Bool { return emptyValue, ValueDecoderError{ Name: "BooleanDecodeValue", @@ -240,16 +209,13 @@ func (dvd DefaultValueDecoders) booleanDecodeType(_ DecodeContext, vr ValueReade return reflect.ValueOf(b), nil } -// BooleanDecodeValue is the ValueDecoderFunc for bool types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) BooleanDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +// booleanDecodeValue is the ValueDecoderFunc for bool types. +func booleanDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.IsValid() || !val.CanSet() || val.Kind() != reflect.Bool { return ValueDecoderError{Name: "BooleanDecodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val} } - elem, err := dvd.booleanDecodeType(dctx, vr, val.Type()) + elem, err := booleanDecodeType(reg, vr, val.Type()) if err != nil { return err } @@ -258,300 +224,7 @@ func (dvd DefaultValueDecoders) BooleanDecodeValue(dctx DecodeContext, vr ValueR return nil } -func (DefaultValueDecoders) intDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { - var i64 int64 - var err error - switch vrType := vr.Type(); vrType { - case TypeInt32: - i32, err := vr.ReadInt32() - if err != nil { - return emptyValue, err - } - i64 = int64(i32) - case TypeInt64: - i64, err = vr.ReadInt64() - if err != nil { - return emptyValue, err - } - case TypeDouble: - f64, err := vr.ReadDouble() - if err != nil { - return emptyValue, err - } - if !dc.Truncate && math.Floor(f64) != f64 { - return emptyValue, errCannotTruncate - } - if f64 > float64(math.MaxInt64) { - return emptyValue, fmt.Errorf("%g overflows int64", f64) - } - i64 = int64(f64) - case TypeBoolean: - b, err := vr.ReadBoolean() - if err != nil { - return emptyValue, err - } - if b { - i64 = 1 - } - case TypeNull: - if err = vr.ReadNull(); err != nil { - return emptyValue, err - } - case TypeUndefined: - if err = vr.ReadUndefined(); err != nil { - return emptyValue, err - } - default: - return emptyValue, fmt.Errorf("cannot decode %v into an integer type", vrType) - } - - switch t.Kind() { - case reflect.Int8: - if i64 < math.MinInt8 || i64 > math.MaxInt8 { - return emptyValue, fmt.Errorf("%d overflows int8", i64) - } - - return reflect.ValueOf(int8(i64)), nil - case reflect.Int16: - if i64 < math.MinInt16 || i64 > math.MaxInt16 { - return emptyValue, fmt.Errorf("%d overflows int16", i64) - } - - return reflect.ValueOf(int16(i64)), nil - case reflect.Int32: - if i64 < math.MinInt32 || i64 > math.MaxInt32 { - return emptyValue, fmt.Errorf("%d overflows int32", i64) - } - - return reflect.ValueOf(int32(i64)), nil - case reflect.Int64: - return reflect.ValueOf(i64), nil - case reflect.Int: - if int64(int(i64)) != i64 { // Can we fit this inside of an int - return emptyValue, fmt.Errorf("%d overflows int", i64) - } - - return reflect.ValueOf(int(i64)), nil - default: - return emptyValue, ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, - Received: reflect.Zero(t), - } - } -} - -// IntDecodeValue is the ValueDecoderFunc for int types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) IntDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - if !val.CanSet() { - return ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, - Received: val, - } - } - - elem, err := dvd.intDecodeType(dc, vr, val.Type()) - if err != nil { - return err - } - - val.SetInt(elem.Int()) - return nil -} - -// UintDecodeValue is the ValueDecoderFunc for uint types. -// -// Deprecated: UintDecodeValue is not registered by default. Use UintCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) UintDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - var i64 int64 - var err error - switch vr.Type() { - case TypeInt32: - i32, err := vr.ReadInt32() - if err != nil { - return err - } - i64 = int64(i32) - case TypeInt64: - i64, err = vr.ReadInt64() - if err != nil { - return err - } - case TypeDouble: - f64, err := vr.ReadDouble() - if err != nil { - return err - } - if !dc.Truncate && math.Floor(f64) != f64 { - return errors.New("UintDecodeValue can only truncate float64 to an integer type when truncation is enabled") - } - if f64 > float64(math.MaxInt64) { - return fmt.Errorf("%g overflows int64", f64) - } - i64 = int64(f64) - case TypeBoolean: - b, err := vr.ReadBoolean() - if err != nil { - return err - } - if b { - i64 = 1 - } - default: - return fmt.Errorf("cannot decode %v into an integer type", vr.Type()) - } - - if !val.CanSet() { - return ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, - Received: val, - } - } - - switch val.Kind() { - case reflect.Uint8: - if i64 < 0 || i64 > math.MaxUint8 { - return fmt.Errorf("%d overflows uint8", i64) - } - case reflect.Uint16: - if i64 < 0 || i64 > math.MaxUint16 { - return fmt.Errorf("%d overflows uint16", i64) - } - case reflect.Uint32: - if i64 < 0 || i64 > math.MaxUint32 { - return fmt.Errorf("%d overflows uint32", i64) - } - case reflect.Uint64: - if i64 < 0 { - return fmt.Errorf("%d overflows uint64", i64) - } - case reflect.Uint: - if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint - return fmt.Errorf("%d overflows uint", i64) - } - default: - return ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, - Received: val, - } - } - - val.SetUint(uint64(i64)) - return nil -} - -func (dvd DefaultValueDecoders) floatDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { - var f float64 - var err error - switch vrType := vr.Type(); vrType { - case TypeInt32: - i32, err := vr.ReadInt32() - if err != nil { - return emptyValue, err - } - f = float64(i32) - case TypeInt64: - i64, err := vr.ReadInt64() - if err != nil { - return emptyValue, err - } - f = float64(i64) - case TypeDouble: - f, err = vr.ReadDouble() - if err != nil { - return emptyValue, err - } - case TypeBoolean: - b, err := vr.ReadBoolean() - if err != nil { - return emptyValue, err - } - if b { - f = 1 - } - case TypeNull: - if err = vr.ReadNull(); err != nil { - return emptyValue, err - } - case TypeUndefined: - if err = vr.ReadUndefined(); err != nil { - return emptyValue, err - } - default: - return emptyValue, fmt.Errorf("cannot decode %v into a float32 or float64 type", vrType) - } - - switch t.Kind() { - case reflect.Float32: - if !dc.Truncate && float64(float32(f)) != f { - return emptyValue, errCannotTruncate - } - - return reflect.ValueOf(float32(f)), nil - case reflect.Float64: - return reflect.ValueOf(f), nil - default: - return emptyValue, ValueDecoderError{ - Name: "FloatDecodeValue", - Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, - Received: reflect.Zero(t), - } - } -} - -// FloatDecodeValue is the ValueDecoderFunc for float types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) FloatDecodeValue(ec DecodeContext, vr ValueReader, val reflect.Value) error { - if !val.CanSet() { - return ValueDecoderError{ - Name: "FloatDecodeValue", - Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, - Received: val, - } - } - - elem, err := dvd.floatDecodeType(ec, vr, val.Type()) - if err != nil { - return err - } - - val.SetFloat(elem.Float()) - return nil -} - -// StringDecodeValue is the ValueDecoderFunc for string types. -// -// Deprecated: StringDecodeValue is not registered by default. Use StringCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) StringDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { - var str string - var err error - switch vr.Type() { - // TODO(GODRIVER-577): Handle JavaScript and Symbol BSON types when allowed. - case TypeString: - str, err = vr.ReadString() - if err != nil { - return err - } - default: - return fmt.Errorf("cannot decode %v into a string type", vr.Type()) - } - if !val.CanSet() || val.Kind() != reflect.String { - return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} - } - - val.SetString(str) - return nil -} - -func (DefaultValueDecoders) javaScriptDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func javaScriptDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tJavaScript { return emptyValue, ValueDecoderError{ Name: "JavaScriptDecodeValue", @@ -579,16 +252,13 @@ func (DefaultValueDecoders) javaScriptDecodeType(_ DecodeContext, vr ValueReader return reflect.ValueOf(JavaScript(js)), nil } -// JavaScriptDecodeValue is the ValueDecoderFunc for the JavaScript type. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) JavaScriptDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +// javaScriptDecodeValue is the ValueDecoderFunc for the JavaScript type. +func javaScriptDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tJavaScript { return ValueDecoderError{Name: "JavaScriptDecodeValue", Types: []reflect.Type{tJavaScript}, Received: val} } - elem, err := dvd.javaScriptDecodeType(dctx, vr, tJavaScript) + elem, err := javaScriptDecodeType(reg, vr, tJavaScript) if err != nil { return err } @@ -597,7 +267,7 @@ func (dvd DefaultValueDecoders) JavaScriptDecodeValue(dctx DecodeContext, vr Val return nil } -func (DefaultValueDecoders) symbolDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func symbolDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tSymbol { return emptyValue, ValueDecoderError{ Name: "SymbolDecodeValue", @@ -637,16 +307,13 @@ func (DefaultValueDecoders) symbolDecodeType(_ DecodeContext, vr ValueReader, t return reflect.ValueOf(Symbol(symbol)), nil } -// SymbolDecodeValue is the ValueDecoderFunc for the Symbol type. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) SymbolDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +// symbolDecodeValue is the ValueDecoderFunc for the Symbol type. +func symbolDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tSymbol { return ValueDecoderError{Name: "SymbolDecodeValue", Types: []reflect.Type{tSymbol}, Received: val} } - elem, err := dvd.symbolDecodeType(dctx, vr, tSymbol) + elem, err := symbolDecodeType(reg, vr, tSymbol) if err != nil { return err } @@ -655,7 +322,7 @@ func (dvd DefaultValueDecoders) SymbolDecodeValue(dctx DecodeContext, vr ValueRe return nil } -func (DefaultValueDecoders) binaryDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func binaryDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tBinary { return emptyValue, ValueDecoderError{ Name: "BinaryDecodeValue", @@ -684,16 +351,13 @@ func (DefaultValueDecoders) binaryDecodeType(_ DecodeContext, vr ValueReader, t return reflect.ValueOf(Binary{Subtype: subtype, Data: data}), nil } -// BinaryDecodeValue is the ValueDecoderFunc for Binary. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) BinaryDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// binaryDecodeValue is the ValueDecoderFunc for Binary. +func binaryDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tBinary { return ValueDecoderError{Name: "BinaryDecodeValue", Types: []reflect.Type{tBinary}, Received: val} } - elem, err := dvd.binaryDecodeType(dc, vr, tBinary) + elem, err := binaryDecodeType(reg, vr, tBinary) if err != nil { return err } @@ -702,7 +366,7 @@ func (dvd DefaultValueDecoders) BinaryDecodeValue(dc DecodeContext, vr ValueRead return nil } -func (DefaultValueDecoders) undefinedDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func undefinedDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tUndefined { return emptyValue, ValueDecoderError{ Name: "UndefinedDecodeValue", @@ -727,16 +391,13 @@ func (DefaultValueDecoders) undefinedDecodeType(_ DecodeContext, vr ValueReader, return reflect.ValueOf(Undefined{}), nil } -// UndefinedDecodeValue is the ValueDecoderFunc for Undefined. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) UndefinedDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// undefinedDecodeValue is the ValueDecoderFunc for Undefined. +func undefinedDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tUndefined { return ValueDecoderError{Name: "UndefinedDecodeValue", Types: []reflect.Type{tUndefined}, Received: val} } - elem, err := dvd.undefinedDecodeType(dc, vr, tUndefined) + elem, err := undefinedDecodeType(reg, vr, tUndefined) if err != nil { return err } @@ -746,7 +407,7 @@ func (dvd DefaultValueDecoders) UndefinedDecodeValue(dc DecodeContext, vr ValueR } // Accept both 12-byte string and pretty-printed 24-byte hex string formats. -func (dvd DefaultValueDecoders) objectIDDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func objectIDDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tOID { return emptyValue, ValueDecoderError{ Name: "ObjectIDDecodeValue", @@ -791,16 +452,13 @@ func (dvd DefaultValueDecoders) objectIDDecodeType(_ DecodeContext, vr ValueRead return reflect.ValueOf(oid), nil } -// ObjectIDDecodeValue is the ValueDecoderFunc for ObjectID. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) ObjectIDDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// objectIDDecodeValue is the ValueDecoderFunc for ObjectID. +func objectIDDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tOID { return ValueDecoderError{Name: "ObjectIDDecodeValue", Types: []reflect.Type{tOID}, Received: val} } - elem, err := dvd.objectIDDecodeType(dc, vr, tOID) + elem, err := objectIDDecodeType(reg, vr, tOID) if err != nil { return err } @@ -809,7 +467,7 @@ func (dvd DefaultValueDecoders) ObjectIDDecodeValue(dc DecodeContext, vr ValueRe return nil } -func (DefaultValueDecoders) dateTimeDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func dateTimeDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tDateTime { return emptyValue, ValueDecoderError{ Name: "DateTimeDecodeValue", @@ -837,16 +495,13 @@ func (DefaultValueDecoders) dateTimeDecodeType(_ DecodeContext, vr ValueReader, return reflect.ValueOf(DateTime(dt)), nil } -// DateTimeDecodeValue is the ValueDecoderFunc for DateTime. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) DateTimeDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// dateTimeDecodeValue is the ValueDecoderFunc for DateTime. +func dateTimeDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tDateTime { return ValueDecoderError{Name: "DateTimeDecodeValue", Types: []reflect.Type{tDateTime}, Received: val} } - elem, err := dvd.dateTimeDecodeType(dc, vr, tDateTime) + elem, err := dateTimeDecodeType(reg, vr, tDateTime) if err != nil { return err } @@ -855,7 +510,7 @@ func (dvd DefaultValueDecoders) DateTimeDecodeValue(dc DecodeContext, vr ValueRe return nil } -func (DefaultValueDecoders) nullDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func nullDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tNull { return emptyValue, ValueDecoderError{ Name: "NullDecodeValue", @@ -880,16 +535,13 @@ func (DefaultValueDecoders) nullDecodeType(_ DecodeContext, vr ValueReader, t re return reflect.ValueOf(Null{}), nil } -// NullDecodeValue is the ValueDecoderFunc for Null. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) NullDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// nullDecodeValue is the ValueDecoderFunc for Null. +func nullDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tNull { return ValueDecoderError{Name: "NullDecodeValue", Types: []reflect.Type{tNull}, Received: val} } - elem, err := dvd.nullDecodeType(dc, vr, tNull) + elem, err := nullDecodeType(reg, vr, tNull) if err != nil { return err } @@ -898,7 +550,7 @@ func (dvd DefaultValueDecoders) NullDecodeValue(dc DecodeContext, vr ValueReader return nil } -func (DefaultValueDecoders) regexDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func regexDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tRegex { return emptyValue, ValueDecoderError{ Name: "RegexDecodeValue", @@ -926,16 +578,13 @@ func (DefaultValueDecoders) regexDecodeType(_ DecodeContext, vr ValueReader, t r return reflect.ValueOf(Regex{Pattern: pattern, Options: options}), nil } -// RegexDecodeValue is the ValueDecoderFunc for Regex. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) RegexDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// regexDecodeValue is the ValueDecoderFunc for Regex. +func regexDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tRegex { return ValueDecoderError{Name: "RegexDecodeValue", Types: []reflect.Type{tRegex}, Received: val} } - elem, err := dvd.regexDecodeType(dc, vr, tRegex) + elem, err := regexDecodeType(reg, vr, tRegex) if err != nil { return err } @@ -944,7 +593,7 @@ func (dvd DefaultValueDecoders) RegexDecodeValue(dc DecodeContext, vr ValueReade return nil } -func (DefaultValueDecoders) dBPointerDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func dbPointerDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tDBPointer { return emptyValue, ValueDecoderError{ Name: "DBPointerDecodeValue", @@ -973,16 +622,13 @@ func (DefaultValueDecoders) dBPointerDecodeType(_ DecodeContext, vr ValueReader, return reflect.ValueOf(DBPointer{DB: ns, Pointer: pointer}), nil } -// DBPointerDecodeValue is the ValueDecoderFunc for DBPointer. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) DBPointerDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// dbPointerDecodeValue is the ValueDecoderFunc for DBPointer. +func dbPointerDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tDBPointer { return ValueDecoderError{Name: "DBPointerDecodeValue", Types: []reflect.Type{tDBPointer}, Received: val} } - elem, err := dvd.dBPointerDecodeType(dc, vr, tDBPointer) + elem, err := dbPointerDecodeType(reg, vr, tDBPointer) if err != nil { return err } @@ -991,7 +637,7 @@ func (dvd DefaultValueDecoders) DBPointerDecodeValue(dc DecodeContext, vr ValueR return nil } -func (DefaultValueDecoders) timestampDecodeType(_ DecodeContext, vr ValueReader, reflectType reflect.Type) (reflect.Value, error) { +func timestampDecodeType(_ DecoderRegistry, vr ValueReader, reflectType reflect.Type) (reflect.Value, error) { if reflectType != tTimestamp { return emptyValue, ValueDecoderError{ Name: "TimestampDecodeValue", @@ -1019,16 +665,13 @@ func (DefaultValueDecoders) timestampDecodeType(_ DecodeContext, vr ValueReader, return reflect.ValueOf(Timestamp{T: t, I: incr}), nil } -// TimestampDecodeValue is the ValueDecoderFunc for Timestamp. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) TimestampDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// timestampDecodeValue is the ValueDecoderFunc for Timestamp. +func timestampDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tTimestamp { return ValueDecoderError{Name: "TimestampDecodeValue", Types: []reflect.Type{tTimestamp}, Received: val} } - elem, err := dvd.timestampDecodeType(dc, vr, tTimestamp) + elem, err := timestampDecodeType(reg, vr, tTimestamp) if err != nil { return err } @@ -1037,7 +680,7 @@ func (dvd DefaultValueDecoders) TimestampDecodeValue(dc DecodeContext, vr ValueR return nil } -func (DefaultValueDecoders) minKeyDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func minKeyDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tMinKey { return emptyValue, ValueDecoderError{ Name: "MinKeyDecodeValue", @@ -1064,16 +707,13 @@ func (DefaultValueDecoders) minKeyDecodeType(_ DecodeContext, vr ValueReader, t return reflect.ValueOf(MinKey{}), nil } -// MinKeyDecodeValue is the ValueDecoderFunc for MinKey. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) MinKeyDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// minKeyDecodeValue is the ValueDecoderFunc for MinKey. +func minKeyDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tMinKey { return ValueDecoderError{Name: "MinKeyDecodeValue", Types: []reflect.Type{tMinKey}, Received: val} } - elem, err := dvd.minKeyDecodeType(dc, vr, tMinKey) + elem, err := minKeyDecodeType(reg, vr, tMinKey) if err != nil { return err } @@ -1082,7 +722,7 @@ func (dvd DefaultValueDecoders) MinKeyDecodeValue(dc DecodeContext, vr ValueRead return nil } -func (DefaultValueDecoders) maxKeyDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func maxKeyDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tMaxKey { return emptyValue, ValueDecoderError{ Name: "MaxKeyDecodeValue", @@ -1109,16 +749,13 @@ func (DefaultValueDecoders) maxKeyDecodeType(_ DecodeContext, vr ValueReader, t return reflect.ValueOf(MaxKey{}), nil } -// MaxKeyDecodeValue is the ValueDecoderFunc for MaxKey. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) MaxKeyDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// maxKeyDecodeValue is the ValueDecoderFunc for MaxKey. +func maxKeyDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tMaxKey { return ValueDecoderError{Name: "MaxKeyDecodeValue", Types: []reflect.Type{tMaxKey}, Received: val} } - elem, err := dvd.maxKeyDecodeType(dc, vr, tMaxKey) + elem, err := maxKeyDecodeType(reg, vr, tMaxKey) if err != nil { return err } @@ -1127,7 +764,7 @@ func (dvd DefaultValueDecoders) MaxKeyDecodeValue(dc DecodeContext, vr ValueRead return nil } -func (dvd DefaultValueDecoders) decimal128DecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func decimal128DecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tDecimal { return emptyValue, ValueDecoderError{ Name: "Decimal128DecodeValue", @@ -1155,16 +792,13 @@ func (dvd DefaultValueDecoders) decimal128DecodeType(_ DecodeContext, vr ValueRe return reflect.ValueOf(d128), nil } -// Decimal128DecodeValue is the ValueDecoderFunc for Decimal128. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) Decimal128DecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +// decimal128DecodeValue is the ValueDecoderFunc for Decimal128. +func decimal128DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tDecimal { return ValueDecoderError{Name: "Decimal128DecodeValue", Types: []reflect.Type{tDecimal}, Received: val} } - elem, err := dvd.decimal128DecodeType(dctx, vr, tDecimal) + elem, err := decimal128DecodeType(reg, vr, tDecimal) if err != nil { return err } @@ -1173,7 +807,7 @@ func (dvd DefaultValueDecoders) Decimal128DecodeValue(dctx DecodeContext, vr Val return nil } -func (dvd DefaultValueDecoders) jsonNumberDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func jsonNumberDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tJSONNumber { return emptyValue, ValueDecoderError{ Name: "JSONNumberDecodeValue", @@ -1217,16 +851,13 @@ func (dvd DefaultValueDecoders) jsonNumberDecodeType(_ DecodeContext, vr ValueRe return reflect.ValueOf(jsonNum), nil } -// JSONNumberDecodeValue is the ValueDecoderFunc for json.Number. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) JSONNumberDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// jsonNumberDecodeValue is the ValueDecoderFunc for json.Number. +func jsonNumberDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tJSONNumber { return ValueDecoderError{Name: "JSONNumberDecodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} } - elem, err := dvd.jsonNumberDecodeType(dc, vr, tJSONNumber) + elem, err := jsonNumberDecodeType(reg, vr, tJSONNumber) if err != nil { return err } @@ -1235,7 +866,7 @@ func (dvd DefaultValueDecoders) JSONNumberDecodeValue(dc DecodeContext, vr Value return nil } -func (dvd DefaultValueDecoders) urlDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func urlDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tURL { return emptyValue, ValueDecoderError{ Name: "URLDecodeValue", @@ -1269,16 +900,13 @@ func (dvd DefaultValueDecoders) urlDecodeType(_ DecodeContext, vr ValueReader, t return reflect.ValueOf(urlPtr).Elem(), nil } -// URLDecodeValue is the ValueDecoderFunc for url.URL. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) URLDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// urlDecodeValue is the ValueDecoderFunc for url.URL. +func urlDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tURL { return ValueDecoderError{Name: "URLDecodeValue", Types: []reflect.Type{tURL}, Received: val} } - elem, err := dvd.urlDecodeType(dc, vr, tURL) + elem, err := urlDecodeType(reg, vr, tURL) if err != nil { return err } @@ -1287,119 +915,8 @@ func (dvd DefaultValueDecoders) URLDecodeValue(dc DecodeContext, vr ValueReader, return nil } -// TimeDecodeValue is the ValueDecoderFunc for time.Time. -// -// Deprecated: TimeDecodeValue is not registered by default. Use TimeCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) TimeDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { - if vr.Type() != TypeDateTime { - return fmt.Errorf("cannot decode %v into a time.Time", vr.Type()) - } - - dt, err := vr.ReadDateTime() - if err != nil { - return err - } - - if !val.CanSet() || val.Type() != tTime { - return ValueDecoderError{Name: "TimeDecodeValue", Types: []reflect.Type{tTime}, Received: val} - } - - val.Set(reflect.ValueOf(time.Unix(dt/1000, dt%1000*1000000).UTC())) - return nil -} - -// ByteSliceDecodeValue is the ValueDecoderFunc for []byte. -// -// Deprecated: ByteSliceDecodeValue is not registered by default. Use ByteSliceCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) ByteSliceDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { - if vr.Type() != TypeBinary && vr.Type() != TypeNull { - return fmt.Errorf("cannot decode %v into a []byte", vr.Type()) - } - - if !val.CanSet() || val.Type() != tByteSlice { - return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []reflect.Type{tByteSlice}, Received: val} - } - - if vr.Type() == TypeNull { - val.Set(reflect.Zero(val.Type())) - return vr.ReadNull() - } - - data, subtype, err := vr.ReadBinary() - if err != nil { - return err - } - if subtype != 0x00 { - return fmt.Errorf("ByteSliceDecodeValue can only be used to decode subtype 0x00 for %s, got %v", TypeBinary, subtype) - } - - val.Set(reflect.ValueOf(data)) - return nil -} - -// MapDecodeValue is the ValueDecoderFunc for map[string]* types. -// -// Deprecated: MapDecodeValue is not registered by default. Use MapCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) MapDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Kind() != reflect.Map || val.Type().Key().Kind() != reflect.String { - return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} - } - - switch vr.Type() { - case Type(0), TypeEmbeddedDocument: - case TypeNull: - val.Set(reflect.Zero(val.Type())) - return vr.ReadNull() - default: - return fmt.Errorf("cannot decode %v into a %s", vr.Type(), val.Type()) - } - - dr, err := vr.ReadDocument() - if err != nil { - return err - } - - if val.IsNil() { - val.Set(reflect.MakeMap(val.Type())) - } - - eType := val.Type().Elem() - decoder, err := dc.LookupDecoder(eType) - if err != nil { - return err - } - - if eType == tEmpty { - dc.Ancestor = val.Type() - } - - keyType := val.Type().Key() - for { - key, vr, err := dr.ReadElement() - if errors.Is(err, ErrEOD) { - break - } - if err != nil { - return err - } - - elem := reflect.New(eType).Elem() - - err = decoder.DecodeValue(dc, vr, elem) - if err != nil { - return err - } - - val.SetMapIndex(reflect.ValueOf(key).Convert(keyType), elem) - } - return nil -} - -// ArrayDecodeValue is the ValueDecoderFunc for array types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) ArrayDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// arrayDecodeValue is the ValueDecoderFunc for array types. +func arrayDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Array { return ValueDecoderError{Name: "ArrayDecodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val} } @@ -1440,15 +957,15 @@ func (dvd DefaultValueDecoders) ArrayDecodeValue(dc DecodeContext, vr ValueReade return fmt.Errorf("cannot decode %v into an array", vrType) } - var elemsFunc func(DecodeContext, ValueReader, reflect.Value) ([]reflect.Value, error) + var elemsFunc func(DecoderRegistry, ValueReader, reflect.Value) ([]reflect.Value, error) switch val.Type().Elem() { case tE: - elemsFunc = dvd.decodeD + elemsFunc = decodeD default: - elemsFunc = dvd.decodeDefault + elemsFunc = decodeDefault } - elems, err := elemsFunc(dc, vr, val) + elems, err := elemsFunc(reg, vr, val) if err != nil { return err } @@ -1464,56 +981,8 @@ func (dvd DefaultValueDecoders) ArrayDecodeValue(dc DecodeContext, vr ValueReade return nil } -// SliceDecodeValue is the ValueDecoderFunc for slice types. -// -// Deprecated: SliceDecodeValue is not registered by default. Use SliceCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) SliceDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Kind() != reflect.Slice { - return ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} - } - - switch vr.Type() { - case TypeArray: - case TypeNull: - val.Set(reflect.Zero(val.Type())) - return vr.ReadNull() - case Type(0), TypeEmbeddedDocument: - if val.Type().Elem() != tE { - return fmt.Errorf("cannot decode document into %s", val.Type()) - } - default: - return fmt.Errorf("cannot decode %v into a slice", vr.Type()) - } - - var elemsFunc func(DecodeContext, ValueReader, reflect.Value) ([]reflect.Value, error) - switch val.Type().Elem() { - case tE: - dc.Ancestor = val.Type() - elemsFunc = dvd.decodeD - default: - elemsFunc = dvd.decodeDefault - } - - elems, err := elemsFunc(dc, vr, val) - if err != nil { - return err - } - - if val.IsNil() { - val.Set(reflect.MakeSlice(val.Type(), 0, len(elems))) - } - - val.SetLen(0) - val.Set(reflect.Append(val, elems...)) - - return nil -} - -// ValueUnmarshalerDecodeValue is the ValueDecoderFunc for ValueUnmarshaler implementations. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +// valueUnmarshalerDecodeValue is the ValueDecoderFunc for ValueUnmarshaler implementations. +func valueUnmarshalerDecodeValue(_ DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.IsValid() || (!val.Type().Implements(tValueUnmarshaler) && !reflect.PtrTo(val.Type()).Implements(tValueUnmarshaler)) { return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} } @@ -1545,11 +1014,8 @@ func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr return m.UnmarshalBSONValue(t, src) } -// UnmarshalerDecodeValue is the ValueDecoderFunc for Unmarshaler implementations. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +// unmarshalerDecodeValue is the ValueDecoderFunc for Unmarshaler implementations. +func unmarshalerDecodeValue(_ DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.IsValid() || (!val.Type().Implements(tUnmarshaler) && !reflect.PtrTo(val.Type()).Implements(tUnmarshaler)) { return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val} } @@ -1593,51 +1059,8 @@ func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(_ DecodeContext, vr Value return m.UnmarshalBSON(src) } -// EmptyInterfaceDecodeValue is the ValueDecoderFunc for interface{}. -// -// Deprecated: EmptyInterfaceDecodeValue is not registered by default. Use EmptyInterfaceCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) EmptyInterfaceDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tEmpty { - return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val} - } - - rtype, err := dc.LookupTypeMapEntry(vr.Type()) - if err != nil { - switch vr.Type() { - case TypeEmbeddedDocument: - if dc.Ancestor != nil { - rtype = dc.Ancestor - break - } - rtype = tD - case TypeNull: - val.Set(reflect.Zero(val.Type())) - return vr.ReadNull() - default: - return err - } - } - - decoder, err := dc.LookupDecoder(rtype) - if err != nil { - return err - } - - elem := reflect.New(rtype).Elem() - err = decoder.DecodeValue(dc, vr, elem) - if err != nil { - return err - } - - val.Set(elem) - return nil -} - -// CoreDocumentDecodeValue is the ValueDecoderFunc for bsoncore.Document. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (DefaultValueDecoders) CoreDocumentDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +// coreDocumentDecodeValue is the ValueDecoderFunc for bsoncore.Document. +func coreDocumentDecodeValue(_ DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tCoreDocument { return ValueDecoderError{Name: "CoreDocumentDecodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} } @@ -1653,7 +1076,7 @@ func (DefaultValueDecoders) CoreDocumentDecodeValue(_ DecodeContext, vr ValueRea return err } -func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr ValueReader, val reflect.Value) ([]reflect.Value, error) { +func decodeDefault(reg DecoderRegistry, vr ValueReader, val reflect.Value) ([]reflect.Value, error) { elems := make([]reflect.Value, 0) ar, err := vr.ReadArray() @@ -1663,11 +1086,10 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr ValueReader, eType := val.Type().Elem() - decoder, err := dc.LookupDecoder(eType) + decoder, err := reg.LookupDecoder(eType) if err != nil { return nil, err } - eTypeDecoder, _ := decoder.(typeDecoder) idx := 0 for { @@ -1679,10 +1101,13 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr ValueReader, return nil, err } - elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true) + elem, err := decodeTypeOrValueWithInfo(decoder, reg, vr, eType) if err != nil { return nil, newDecodeError(strconv.Itoa(idx), err) } + if elem.Type() != eType { + elem = elem.Convert(eType) + } elems = append(elems, elem) idx++ } @@ -1690,31 +1115,7 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr ValueReader, return elems, nil } -func (dvd DefaultValueDecoders) readCodeWithScope(dc DecodeContext, vr ValueReader) (CodeWithScope, error) { - var cws CodeWithScope - - code, dr, err := vr.ReadCodeWithScope() - if err != nil { - return cws, err - } - - scope := reflect.New(tD).Elem() - elems, err := dvd.decodeElemsFromDocumentReader(dc, dr) - if err != nil { - return cws, err - } - - scope.Set(reflect.MakeSlice(tD, 0, len(elems))) - scope.Set(reflect.Append(scope, elems...)) - - cws = CodeWithScope{ - Code: JavaScript(code), - Scope: scope.Interface().(D), - } - return cws, nil -} - -func (dvd DefaultValueDecoders) codeWithScopeDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func codeWithScopeDecodeType(reg DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tCodeWithScope { return emptyValue, ValueDecoderError{ Name: "CodeWithScopeDecodeValue", @@ -1727,7 +1128,24 @@ func (dvd DefaultValueDecoders) codeWithScopeDecodeType(dc DecodeContext, vr Val var err error switch vrType := vr.Type(); vrType { case TypeCodeWithScope: - cws, err = dvd.readCodeWithScope(dc, vr) + code, dr, err := vr.ReadCodeWithScope() + if err != nil { + return emptyValue, err + } + + scope := reflect.New(tD).Elem() + elems, err := decodeElemsFromDocumentReader(reg, dr, tEmpty) + if err != nil { + return emptyValue, err + } + + scope.Set(reflect.MakeSlice(tD, 0, len(elems))) + scope.Set(reflect.Append(scope, elems...)) + + cws = CodeWithScope{ + Code: JavaScript(code), + Scope: scope.Interface().(D), + } case TypeNull: err = vr.ReadNull() case TypeUndefined: @@ -1742,16 +1160,13 @@ func (dvd DefaultValueDecoders) codeWithScopeDecodeType(dc DecodeContext, vr Val return reflect.ValueOf(cws), nil } -// CodeWithScopeDecodeValue is the ValueDecoderFunc for CodeWithScope. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) CodeWithScopeDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// codeWithScopeDecodeValue is the ValueDecoderFunc for CodeWithScope. +func codeWithScopeDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tCodeWithScope { return ValueDecoderError{Name: "CodeWithScopeDecodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} } - elem, err := dvd.codeWithScopeDecodeType(dc, vr, tCodeWithScope) + elem, err := codeWithScopeDecodeType(reg, vr, tCodeWithScope) if err != nil { return err } @@ -1760,9 +1175,10 @@ func (dvd DefaultValueDecoders) CodeWithScopeDecodeValue(dc DecodeContext, vr Va return nil } -func (dvd DefaultValueDecoders) decodeD(dc DecodeContext, vr ValueReader, _ reflect.Value) ([]reflect.Value, error) { +func decodeD(reg DecoderRegistry, vr ValueReader, val reflect.Value) ([]reflect.Value, error) { switch vr.Type() { case Type(0), TypeEmbeddedDocument: + break default: return nil, fmt.Errorf("cannot decode %v into a D", vr.Type()) } @@ -1772,11 +1188,11 @@ func (dvd DefaultValueDecoders) decodeD(dc DecodeContext, vr ValueReader, _ refl return nil, err } - return dvd.decodeElemsFromDocumentReader(dc, dr) + return decodeElemsFromDocumentReader(reg, dr, val.Type()) } -func (DefaultValueDecoders) decodeElemsFromDocumentReader(dc DecodeContext, dr DocumentReader) ([]reflect.Value, error) { - decoder, err := dc.LookupDecoder(tEmpty) +func decodeElemsFromDocumentReader(reg DecoderRegistry, dr DocumentReader, t reflect.Type) ([]reflect.Value, error) { + decoder, err := reg.LookupDecoder(tEmpty) if err != nil { return nil, err } @@ -1791,8 +1207,8 @@ func (DefaultValueDecoders) decodeElemsFromDocumentReader(dc DecodeContext, dr D return nil, err } - val := reflect.New(tEmpty).Elem() - err = decoder.DecodeValue(dc, vr, val) + var val reflect.Value + val, err = decodeTypeOrValueWithInfo(decoder, reg, vr, t) if err != nil { return nil, newDecodeError(key, err) } diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 699a958605..50b1a668c2 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -23,11 +23,10 @@ import ( ) var ( - defaultTestStructCodec = newDefaultStructCodec() + defaultTestStructCodec = newStructCodec(nil) ) func TestDefaultValueDecoders(t *testing.T) { - var dvd DefaultValueDecoders var wrong = func(string, string) string { return "wrong" } type mybool bool @@ -58,7 +57,7 @@ func TestDefaultValueDecoders(t *testing.T) { type subtest struct { name string val interface{} - dctx *DecodeContext + reg *Registry llvrw *valueReaderWriter invoke invoked err error @@ -71,7 +70,7 @@ func TestDefaultValueDecoders(t *testing.T) { }{ { "BooleanDecodeValue", - ValueDecoderFunc(dvd.BooleanDecodeValue), + ValueDecoderFunc(booleanDecodeValue), []subtest{ { "wrong type", @@ -140,17 +139,21 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "IntDecodeValue", - ValueDecoderFunc(dvd.IntDecodeValue), + &numCodec{}, []subtest{ { "wrong type", wrong, nil, &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, - readInt32, + nothing, ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf(wrong), }, }, @@ -187,15 +190,10 @@ func TestDefaultValueDecoders(t *testing.T) { errors.New("ReadDouble error"), }, { - "ReadDouble", int64(3), &DecodeContext{}, + "ReadDouble", int64(3), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.00)}, readDouble, nil, }, - { - "ReadDouble (truncate)", int64(3), &DecodeContext{Truncate: true}, - &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, - nil, - }, { "ReadDouble (no truncate)", int64(0), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, @@ -213,46 +211,66 @@ func TestDefaultValueDecoders(t *testing.T) { {"int/fast path", int(1234), nil, &valueReaderWriter{BSONType: TypeInt64, Return: int64(1234)}, readInt64, nil}, { "int8/fast path - nil", (*int8)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*int8)(nil)), }, }, { "int16/fast path - nil", (*int16)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*int16)(nil)), }, }, { "int32/fast path - nil", (*int32)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*int32)(nil)), }, }, { "int64/fast path - nil", (*int64)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*int64)(nil)), }, }, { "int/fast path - nil", (*int)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*int)(nil)), }, }, @@ -348,8 +366,12 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, }, }, { @@ -371,18 +393,22 @@ func TestDefaultValueDecoders(t *testing.T) { }, }, { - "defaultUIntCodec.DecodeValue", - defaultUIntCodec, + "UintDecodeValue", + &numCodec{}, []subtest{ { "wrong type", wrong, nil, &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, - readInt32, + nothing, ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf(wrong), }, }, @@ -419,15 +445,10 @@ func TestDefaultValueDecoders(t *testing.T) { errors.New("ReadDouble error"), }, { - "ReadDouble", uint64(3), &DecodeContext{}, + "ReadDouble", uint64(3), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.00)}, readDouble, nil, }, - { - "ReadDouble (truncate)", uint64(3), &DecodeContext{Truncate: true}, - &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, - nil, - }, { "ReadDouble (no truncate)", uint64(0), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, @@ -445,46 +466,66 @@ func TestDefaultValueDecoders(t *testing.T) { {"uint/fast path", uint(1234), nil, &valueReaderWriter{BSONType: TypeInt64, Return: int64(1234)}, readInt64, nil}, { "uint8/fast path - nil", (*uint8)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*uint8)(nil)), }, }, { "uint16/fast path - nil", (*uint16)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*uint16)(nil)), }, }, { "uint32/fast path - nil", (*uint32)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*uint32)(nil)), }, }, { "uint64/fast path - nil", (*uint64)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*uint64)(nil)), }, }, { "uint/fast path - nil", (*uint)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*uint)(nil)), }, }, @@ -600,31 +641,39 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, }, }, }, }, { "FloatDecodeValue", - ValueDecoderFunc(dvd.FloatDecodeValue), + &numCodec{}, []subtest{ { "wrong type", wrong, nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(0)}, - readDouble, + nothing, ValueDecoderError{ - Name: "FloatDecodeValue", - Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf(wrong), }, }, { "type not double", - 0, + float64(0), nil, &valueReaderWriter{BSONType: TypeString}, nothing, @@ -674,11 +723,6 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14159)}, readDouble, nil, }, - { - "float32/fast path (truncate)", float32(3.14), &DecodeContext{Truncate: true}, - &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, - nil, - }, { "float32/fast path (no truncate)", float32(0), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, @@ -686,19 +730,27 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "float32/fast path - nil", (*float32)(nil), nil, - &valueReaderWriter{BSONType: TypeDouble, Return: float64(0)}, readDouble, + &valueReaderWriter{BSONType: TypeDouble, Return: float64(0)}, nothing, ValueDecoderError{ - Name: "FloatDecodeValue", - Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*float32)(nil)), }, }, { "float64/fast path - nil", (*float64)(nil), nil, - &valueReaderWriter{BSONType: TypeDouble, Return: float64(0)}, readDouble, + &valueReaderWriter{BSONType: TypeDouble, Return: float64(0)}, nothing, ValueDecoderError{ - Name: "FloatDecodeValue", - Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*float64)(nil)), }, }, @@ -712,11 +764,6 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14159)}, readDouble, nil, }, - { - "float32/reflection path (truncate)", myfloat32(3.14), &DecodeContext{Truncate: true}, - &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, - nil, - }, { "float32/reflection path (no truncate)", myfloat32(0), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, @@ -729,15 +776,45 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeDouble, Return: float64(0)}, nothing, ValueDecoderError{ - Name: "FloatDecodeValue", - Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, }, }, }, }, { - "defaultTimeCodec.DecodeValue", - defaultTimeCodec, + "NumDecodeValue (truncate)", + &numCodec{truncate: true}, + []subtest{ + { + "int ReadDouble (truncate)", int64(3), nil, + &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, + nil, + }, + { + "uint ReadDouble (truncate)", uint64(3), nil, + &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, + nil, + }, + { + "float32/fast path (truncate)", float32(3.14), nil, + &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, + nil, + }, + { + "float32/reflection path (truncate)", myfloat32(3.14), nil, + &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, + nil, + }, + }, + }, + { + "TimeDecodeValue", + &timeCodec{}, []subtest{ { "wrong type", @@ -790,8 +867,8 @@ func TestDefaultValueDecoders(t *testing.T) { }, }, { - "defaultMapCodec.DecodeValue", - defaultMapCodec, + "MapDecodeValue", + &mapCodec{}, []subtest{ { "wrong kind", @@ -804,7 +881,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "wrong kind (non-string key)", map[bool]interface{}{}, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{}, readElement, fmt.Errorf("unsupported key type: %T", false), @@ -820,7 +897,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", map[string]string{}, - &DecodeContext{Registry: newTestRegistryBuilder().Build()}, + newTestRegistryBuilder().Build(), &valueReaderWriter{}, readDocument, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -828,7 +905,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "ReadElement Error", make(map[string]interface{}), - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("re error"), ErrAfter: readElement}, readElement, errors.New("re error"), @@ -869,7 +946,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "ArrayDecodeValue", - ValueDecoderFunc(dvd.ArrayDecodeValue), + ValueDecoderFunc(arrayDecodeValue), []subtest{ { "wrong kind", @@ -906,7 +983,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", [1]string{}, - &DecodeContext{Registry: newTestRegistryBuilder().Build()}, + newTestRegistryBuilder().Build(), &valueReaderWriter{BSONType: TypeArray}, readArray, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -914,7 +991,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "ReadValue Error", [1]string{}, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("rv error"), ErrAfter: readValue, BSONType: TypeArray}, readValue, errors.New("rv error"), @@ -922,7 +999,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "DecodeValue Error", [1]string{}, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{BSONType: TypeArray}, readValue, &DecodeError{keys: []string{"0"}, wrapped: errors.New("cannot decode array into a string type")}, @@ -962,8 +1039,8 @@ func TestDefaultValueDecoders(t *testing.T) { }, }, { - "defaultSliceCodec.DecodeValue", - defaultSliceCodec, + "SliceDecodeValue", + &sliceCodec{}, []subtest{ { "wrong kind", @@ -1000,7 +1077,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", []string{}, - &DecodeContext{Registry: newTestRegistryBuilder().Build()}, + newTestRegistryBuilder().Build(), &valueReaderWriter{BSONType: TypeArray}, readArray, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -1008,7 +1085,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "ReadValue Error", []string{}, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("rv error"), ErrAfter: readValue, BSONType: TypeArray}, readValue, errors.New("rv error"), @@ -1016,7 +1093,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "DecodeValue Error", []string{}, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{BSONType: TypeArray}, readValue, &DecodeError{keys: []string{"0"}, wrapped: errors.New("cannot decode array into a string type")}, @@ -1057,7 +1134,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "ObjectIDDecodeValue", - ValueDecoderFunc(dvd.ObjectIDDecodeValue), + ValueDecoderFunc(objectIDDecodeValue), []subtest{ { "wrong type", @@ -1144,7 +1221,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "Decimal128DecodeValue", - ValueDecoderFunc(dvd.Decimal128DecodeValue), + ValueDecoderFunc(decimal128DecodeValue), []subtest{ { "wrong type", @@ -1206,7 +1283,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "JSONNumberDecodeValue", - ValueDecoderFunc(dvd.JSONNumberDecodeValue), + ValueDecoderFunc(jsonNumberDecodeValue), []subtest{ { "wrong type", @@ -1300,7 +1377,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "URLDecodeValue", - ValueDecoderFunc(dvd.URLDecodeValue), + ValueDecoderFunc(urlDecodeValue), []subtest{ { "wrong type", @@ -1373,8 +1450,8 @@ func TestDefaultValueDecoders(t *testing.T) { }, }, { - "defaultByteSliceCodec.DecodeValue", - defaultByteSliceCodec, + "ByteSliceDecodeValue", + &byteSliceCodec{}, []subtest{ { "wrong type", @@ -1441,8 +1518,8 @@ func TestDefaultValueDecoders(t *testing.T) { }, }, { - "defaultStringCodec.DecodeValue", - defaultStringCodec, + "StringDecodeValue", + &stringCodec{}, []subtest{ { "symbol", @@ -1472,7 +1549,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "ValueUnmarshalerDecodeValue", - ValueDecoderFunc(dvd.ValueUnmarshalerDecodeValue), + ValueDecoderFunc(valueUnmarshalerDecodeValue), []subtest{ { "wrong type", @@ -1506,7 +1583,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "UnmarshalerDecodeValue", - ValueDecoderFunc(dvd.UnmarshalerDecodeValue), + ValueDecoderFunc(unmarshalerDecodeValue), []subtest{ { "wrong type", @@ -1550,19 +1627,19 @@ func TestDefaultValueDecoders(t *testing.T) { }, }, { - "PointerCodec.DecodeValue", - NewPointerCodec(), + "PointerDecodeValue", + &pointerCodec{}, []subtest{ { "not valid", nil, nil, nil, nothing, - ValueDecoderError{Name: "PointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: reflect.Value{}}, + ValueDecoderError{Name: "pointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: reflect.Value{}}, }, { "can set", cansettest, nil, nil, nothing, - ValueDecoderError{Name: "PointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}}, + ValueDecoderError{Name: "pointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}}, }, { - "No Decoder", &wrong, &DecodeContext{Registry: buildDefaultRegistry()}, nil, nothing, + "No Decoder", &wrong, buildDefaultRegistry(), nil, nothing, ErrNoDecoder{Type: reflect.TypeOf(wrong)}, }, { @@ -1585,7 +1662,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "BinaryDecodeValue", - ValueDecoderFunc(dvd.BinaryDecodeValue), + ValueDecoderFunc(binaryDecodeValue), []subtest{ { "wrong type", @@ -1645,7 +1722,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "UndefinedDecodeValue", - ValueDecoderFunc(dvd.UndefinedDecodeValue), + ValueDecoderFunc(undefinedDecodeValue), []subtest{ { "wrong type", @@ -1691,7 +1768,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "DateTimeDecodeValue", - ValueDecoderFunc(dvd.DateTimeDecodeValue), + ValueDecoderFunc(dateTimeDecodeValue), []subtest{ { "wrong type", @@ -1745,7 +1822,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "NullDecodeValue", - ValueDecoderFunc(dvd.NullDecodeValue), + ValueDecoderFunc(nullDecodeValue), []subtest{ { "wrong type", @@ -1783,7 +1860,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "RegexDecodeValue", - ValueDecoderFunc(dvd.RegexDecodeValue), + ValueDecoderFunc(regexDecodeValue), []subtest{ { "wrong type", @@ -1843,7 +1920,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "DBPointerDecodeValue", - ValueDecoderFunc(dvd.DBPointerDecodeValue), + ValueDecoderFunc(dbPointerDecodeValue), []subtest{ { "wrong type", @@ -1908,7 +1985,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "TimestampDecodeValue", - ValueDecoderFunc(dvd.TimestampDecodeValue), + ValueDecoderFunc(timestampDecodeValue), []subtest{ { "wrong type", @@ -1968,7 +2045,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "MinKeyDecodeValue", - ValueDecoderFunc(dvd.MinKeyDecodeValue), + ValueDecoderFunc(minKeyDecodeValue), []subtest{ { "wrong type", @@ -2022,7 +2099,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "MaxKeyDecodeValue", - ValueDecoderFunc(dvd.MaxKeyDecodeValue), + ValueDecoderFunc(maxKeyDecodeValue), []subtest{ { "wrong type", @@ -2076,7 +2153,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "JavaScriptDecodeValue", - ValueDecoderFunc(dvd.JavaScriptDecodeValue), + ValueDecoderFunc(javaScriptDecodeValue), []subtest{ { "wrong type", @@ -2130,7 +2207,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "SymbolDecodeValue", - ValueDecoderFunc(dvd.SymbolDecodeValue), + ValueDecoderFunc(symbolDecodeValue), []subtest{ { "wrong type", @@ -2184,7 +2261,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "CoreDocumentDecodeValue", - ValueDecoderFunc(dvd.CoreDocumentDecodeValue), + ValueDecoderFunc(coreDocumentDecodeValue), []subtest{ { "wrong type", @@ -2221,7 +2298,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, }, { - "StructCodec.DecodeValue", + "StructDecodeValue", defaultTestStructCodec, []subtest{ { @@ -2252,7 +2329,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "CodeWithScopeDecodeValue", - ValueDecoderFunc(dvd.CodeWithScopeDecodeValue), + ValueDecoderFunc(codeWithScopeDecodeValue), []subtest{ { "wrong type", @@ -2288,7 +2365,7 @@ func TestDefaultValueDecoders(t *testing.T) { Code: "var hello = 'world';", Scope: D{{"foo", nil}}, }, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{BSONType: TypeCodeWithScope, Err: errors.New("dd error"), ErrAfter: readElement}, readElement, errors.New("dd error"), @@ -2313,7 +2390,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "CoreArrayDecodeValue", - defaultArrayCodec, + &arrayCodec{}, []subtest{ { "wrong type", @@ -2347,10 +2424,6 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run(tc.name, func(t *testing.T) { for _, rc := range tc.subtests { t.Run(rc.name, func(t *testing.T) { - var dc DecodeContext - if rc.dctx != nil { - dc = *rc.dctx - } llvrw := new(valueReaderWriter) if rc.llvrw != nil { llvrw = rc.llvrw @@ -2358,13 +2431,13 @@ func TestDefaultValueDecoders(t *testing.T) { llvrw.T = t // var got interface{} if rc.val == cansetreflectiontest { // We're doing a CanSet reflection test - err := tc.vd.DecodeValue(dc, llvrw, reflect.Value{}) + err := tc.vd.DecodeValue(rc.reg, llvrw, reflect.Value{}) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } val := reflect.New(reflect.TypeOf(rc.val)).Elem() - err = tc.vd.DecodeValue(dc, llvrw, val) + err = tc.vd.DecodeValue(rc.reg, llvrw, val) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } @@ -2376,13 +2449,13 @@ func TestDefaultValueDecoders(t *testing.T) { t.Fatalf("Error must be a DecodeValueError, but got a %T", rc.err) } - err := tc.vd.DecodeValue(dc, llvrw, reflect.Value{}) + err := tc.vd.DecodeValue(rc.reg, llvrw, reflect.Value{}) wanterr.Received = reflect.ValueOf(nil) if !assert.CompareErrors(err, wanterr) { t.Errorf("Errors do not match. got %v; want %v", err, wanterr) } - err = tc.vd.DecodeValue(dc, llvrw, reflect.ValueOf(int(12345))) + err = tc.vd.DecodeValue(rc.reg, llvrw, reflect.ValueOf(int(12345))) wanterr.Received = reflect.ValueOf(int(12345)) if !assert.CompareErrors(err, wanterr) { t.Errorf("Errors do not match. got %v; want %v", err, wanterr) @@ -2396,11 +2469,10 @@ func TestDefaultValueDecoders(t *testing.T) { want := rc.val defer func() { if err := recover(); err != nil { - fmt.Println(t.Name()) panic(err) } }() - err := tc.vd.DecodeValue(dc, llvrw, val) + err := tc.vd.DecodeValue(rc.reg, llvrw, val) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } @@ -2421,7 +2493,7 @@ func TestDefaultValueDecoders(t *testing.T) { } t.Run("CodeWithScopeCodec/DecodeValue/success", func(t *testing.T) { - dc := DecodeContext{Registry: buildDefaultRegistry()} + reg := buildDefaultRegistry() b := bsoncore.BuildDocument(nil, bsoncore.AppendCodeWithScopeElement( nil, "foo", "var hello = 'world';", @@ -2439,7 +2511,7 @@ func TestDefaultValueDecoders(t *testing.T) { Scope: D{{"bar", nil}}, } val := reflect.New(tCodeWithScope).Elem() - err = dvd.CodeWithScopeDecodeValue(dc, vr, val) + err = codeWithScopeDecodeValue(reg, vr, val) noerr(t, err) got := val.Interface().(CodeWithScope) @@ -2448,34 +2520,31 @@ func TestDefaultValueDecoders(t *testing.T) { } }) t.Run("ValueUnmarshalerDecodeValue/UnmarshalBSONValue error", func(t *testing.T) { - var dc DecodeContext llvrw := &valueReaderWriter{BSONType: TypeString, Return: string("hello, world!")} llvrw.T = t want := errors.New("ubsonv error") valUnmarshaler := &testValueUnmarshaler{err: want} - got := dvd.ValueUnmarshalerDecodeValue(dc, llvrw, reflect.ValueOf(valUnmarshaler)) + got := valueUnmarshalerDecodeValue(nil, llvrw, reflect.ValueOf(valUnmarshaler)) if !assert.CompareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } }) t.Run("ValueUnmarshalerDecodeValue/Unaddressable value", func(t *testing.T) { - var dc DecodeContext llvrw := &valueReaderWriter{BSONType: TypeString, Return: string("hello, world!")} llvrw.T = t val := reflect.ValueOf(testValueUnmarshaler{}) want := ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} - got := dvd.ValueUnmarshalerDecodeValue(dc, llvrw, val) + got := valueUnmarshalerDecodeValue(nil, llvrw, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } }) - t.Run("SliceCodec/DecodeValue/can't set slice", func(t *testing.T) { var val []string want := ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: reflect.ValueOf(val)} - got := dvd.SliceDecodeValue(DecodeContext{}, nil, reflect.ValueOf(val)) + got := (&sliceCodec{}).DecodeValue(nil, nil, reflect.ValueOf(val)) if !assert.CompareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } @@ -2498,8 +2567,8 @@ func TestDefaultValueDecoders(t *testing.T) { var val [1]string want := fmt.Errorf("more elements returned in array than can fit inside %T, got 2 elements", val) - dc := DecodeContext{Registry: buildDefaultRegistry()} - got := dvd.ArrayDecodeValue(dc, vr, reflect.ValueOf(val)) + reg := buildDefaultRegistry() + got := arrayDecodeValue(reg, vr, reflect.ValueOf(val)) if !assert.CompareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } @@ -3147,7 +3216,7 @@ func TestDefaultValueDecoders(t *testing.T) { noerr(t, err) gotVal := reflect.New(reflect.TypeOf(tc.value)).Elem() - err = dec.DecodeValue(DecodeContext{Registry: reg}, vr, gotVal) + err = dec.DecodeValue(reg, vr, gotVal) noerr(t, err) got := gotVal.Interface() @@ -3196,7 +3265,7 @@ func TestDefaultValueDecoders(t *testing.T) { noerr(t, err) gotVal := reflect.New(reflect.TypeOf(tc.value)).Elem() - err = dec.DecodeValue(DecodeContext{Registry: reg}, vr, gotVal) + err = dec.DecodeValue(reg, vr, gotVal) if err == nil || !strings.Contains(err.Error(), tc.err.Error()) { t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) } @@ -3205,6 +3274,7 @@ func TestDefaultValueDecoders(t *testing.T) { }) t.Run("defaultEmptyInterfaceCodec.DecodeValue", func(t *testing.T) { + defaultEmptyInterfaceCodec := &emptyInterfaceCodec{} t.Run("DecodeValue", func(t *testing.T) { testCases := []struct { name string @@ -3319,9 +3389,9 @@ func TestDefaultValueDecoders(t *testing.T) { t.Skip() } val := reflect.New(tEmpty).Elem() - dc := DecodeContext{Registry: newTestRegistryBuilder().Build()} + reg := newTestRegistryBuilder().Build() want := ErrNoTypeMapEntry{Type: tc.bsontype} - got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, val) + got := defaultEmptyInterfaceCodec.DecodeValue(reg, llvr, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3332,13 +3402,11 @@ func TestDefaultValueDecoders(t *testing.T) { t.Skip() } val := reflect.New(tEmpty).Elem() - dc := DecodeContext{ - Registry: newTestRegistryBuilder(). - RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). - Build(), - } + reg := newTestRegistryBuilder(). + RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). + Build() want := ErrNoDecoder{Type: reflect.TypeOf(tc.val)} - got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, val) + got := defaultEmptyInterfaceCodec.DecodeValue(reg, llvr, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3349,14 +3417,12 @@ func TestDefaultValueDecoders(t *testing.T) { t.Skip() } want := errors.New("DecodeValue failure error") - llc := &llCodec{t: t, err: want} - dc := DecodeContext{ - Registry: newTestRegistryBuilder(). - RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). - RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). - Build(), - } - got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, reflect.New(tEmpty).Elem()) + llc := func(*Registry) ValueDecoder { return &llCodec{t: t, err: want} } + reg := newTestRegistryBuilder(). + RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). + RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). + Build() + got := defaultEmptyInterfaceCodec.DecodeValue(reg, llvr, reflect.New(tEmpty).Elem()) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3364,15 +3430,13 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run("Success", func(t *testing.T) { want := tc.val - llc := &llCodec{t: t, decodeval: tc.val} - dc := DecodeContext{ - Registry: newTestRegistryBuilder(). - RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). - RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). - Build(), - } + llc := func(*Registry) ValueDecoder { return &llCodec{t: t, decodeval: tc.val} } + reg := newTestRegistryBuilder(). + RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). + RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). + Build() got := reflect.New(tEmpty).Elem() - err := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, got) + err := defaultEmptyInterfaceCodec.DecodeValue(reg, llvr, got) noerr(t, err) if !cmp.Equal(got.Interface(), want, cmp.Comparer(compareDecimal128)) { t.Errorf("Did not receive expected value. got %v; want %v", got.Interface(), want) @@ -3385,7 +3449,7 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run("non-interface{}", func(t *testing.T) { val := uint64(1234567890) want := ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.ValueOf(val)} - got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{}, nil, reflect.ValueOf(val)) + got := defaultEmptyInterfaceCodec.DecodeValue(nil, nil, reflect.ValueOf(val)) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3394,7 +3458,7 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run("nil *interface{}", func(t *testing.T) { var val interface{} want := ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.ValueOf(val)} - got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{}, nil, reflect.ValueOf(val)) + got := defaultEmptyInterfaceCodec.DecodeValue(nil, nil, reflect.ValueOf(val)) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3404,7 +3468,7 @@ func TestDefaultValueDecoders(t *testing.T) { llvr := &valueReaderWriter{BSONType: TypeDouble} want := ErrNoTypeMapEntry{Type: TypeDouble} val := reflect.New(tEmpty).Elem() - got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: newTestRegistryBuilder().Build()}, llvr, val) + got := defaultEmptyInterfaceCodec.DecodeValue(newTestRegistryBuilder().Build(), llvr, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3415,7 +3479,7 @@ func TestDefaultValueDecoders(t *testing.T) { want := D{{"pi", 3.14159}} var got interface{} val := reflect.ValueOf(&got).Elem() - err := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: buildDefaultRegistry()}, vr, val) + err := defaultEmptyInterfaceCodec.DecodeValue(buildDefaultRegistry(), vr, val) noerr(t, err) if !cmp.Equal(got, want) { t.Errorf("Did not get correct result. got %v; want %v", got, want) @@ -3426,13 +3490,13 @@ func TestDefaultValueDecoders(t *testing.T) { // both top-level and embedded documents to decode to registered type when unmarshalling to interface{} topLevelRb := newTestRegistryBuilder() - defaultValueEncoders.RegisterDefaultEncoders(topLevelRb) - defaultValueDecoders.RegisterDefaultDecoders(topLevelRb) + registerDefaultEncoders(topLevelRb) + registerDefaultDecoders(topLevelRb) topLevelRb.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) embeddedRb := newTestRegistryBuilder() - defaultValueEncoders.RegisterDefaultEncoders(embeddedRb) - defaultValueDecoders.RegisterDefaultDecoders(embeddedRb) + registerDefaultEncoders(embeddedRb) + registerDefaultDecoders(embeddedRb) embeddedRb.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) // create doc {"nested": {"foo": 1}} @@ -3462,7 +3526,7 @@ func TestDefaultValueDecoders(t *testing.T) { vr := NewValueReader(doc) val := reflect.ValueOf(&got).Elem() - err := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: tc.registry}, vr, val) + err := defaultEmptyInterfaceCodec.DecodeValue(tc.registry, vr, val) noerr(t, err) if !cmp.Equal(got, want) { t.Fatalf("got %v, want %v", got, want) @@ -3474,8 +3538,8 @@ func TestDefaultValueDecoders(t *testing.T) { // information if available instead of the registered entry. rb := newTestRegistryBuilder() - defaultValueEncoders.RegisterDefaultEncoders(rb) - defaultValueDecoders.RegisterDefaultDecoders(rb) + registerDefaultEncoders(rb) + registerDefaultDecoders(rb) rb.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(M{})) reg := rb.Build() @@ -3497,7 +3561,7 @@ func TestDefaultValueDecoders(t *testing.T) { var got D vr := NewValueReader(doc) val := reflect.ValueOf(&got).Elem() - err := defaultSliceCodec.DecodeValue(DecodeContext{Registry: reg}, vr, val) + err := (&sliceCodec{}).DecodeValue(reg, vr, val) noerr(t, err) if !cmp.Equal(got, want) { t.Fatalf("got %v, want %v", got, want) @@ -3507,11 +3571,12 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run("decode errors contain key information", func(t *testing.T) { decodeValueError := errors.New("decode value error") - emptyInterfaceErrorDecode := func(DecodeContext, ValueReader, reflect.Value) error { + emptyInterfaceErrorDecode := func(DecoderRegistry, ValueReader, reflect.Value) error { return decodeValueError } emptyInterfaceErrorRegistry := newTestRegistryBuilder(). - RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)).Build() + RegisterTypeDecoder(tEmpty, func(*Registry) ValueDecoder { return ValueDecoderFunc(emptyInterfaceErrorDecode) }). + Build() // Set up a document {foo: 10} and an error that would happen if the value were decoded into interface{} // using the registry defined above. @@ -3564,10 +3629,8 @@ func TestDefaultValueDecoders(t *testing.T) { // Use a registry that has all default decoders with the custom interface{} decoder that always errors. nestedRegistryBuilder := newTestRegistryBuilder() - defaultValueDecoders.RegisterDefaultDecoders(nestedRegistryBuilder) - nestedRegistry := nestedRegistryBuilder. - RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)). - Build() + registerDefaultDecoders(nestedRegistryBuilder) + nestedRegistryBuilder.RegisterTypeDecoder(tEmpty, func(*Registry) ValueDecoder { return ValueDecoderFunc(emptyInterfaceErrorDecode) }) nestedErr := &DecodeError{ keys: []string{"fourth", "1", "third", "randomKey", "second", "first"}, wrapped: decodeValueError, @@ -3587,7 +3650,7 @@ func TestDefaultValueDecoders(t *testing.T) { D{}, NewValueReader(docBytes), emptyInterfaceErrorRegistry, - defaultSliceCodec, + &sliceCodec{}, docEmptyInterfaceErr, }, { @@ -3596,7 +3659,7 @@ func TestDefaultValueDecoders(t *testing.T) { []string{}, &valueReaderWriter{BSONType: TypeArray}, nil, - defaultSliceCodec, + &sliceCodec{}, &DecodeError{ keys: []string{"0"}, wrapped: errors.New("cannot decode array into a string type"), @@ -3610,7 +3673,7 @@ func TestDefaultValueDecoders(t *testing.T) { [1]E{}, NewValueReader(docBytes), emptyInterfaceErrorRegistry, - ValueDecoderFunc(dvd.ArrayDecodeValue), + ValueDecoderFunc(arrayDecodeValue), docEmptyInterfaceErr, }, { @@ -3621,7 +3684,7 @@ func TestDefaultValueDecoders(t *testing.T) { [1]string{}, &valueReaderWriter{BSONType: TypeArray}, nil, - ValueDecoderFunc(dvd.ArrayDecodeValue), + ValueDecoderFunc(arrayDecodeValue), &DecodeError{ keys: []string{"0"}, wrapped: errors.New("cannot decode array into a string type"), @@ -3633,7 +3696,7 @@ func TestDefaultValueDecoders(t *testing.T) { map[string]interface{}{}, NewValueReader(docBytes), emptyInterfaceErrorRegistry, - defaultMapCodec, + &mapCodec{}, docEmptyInterfaceErr, }, { @@ -3660,23 +3723,23 @@ func TestDefaultValueDecoders(t *testing.T) { "deeply nested struct", outer{}, NewValueReader(outerDoc), - nestedRegistry, + nestedRegistryBuilder.Build(), defaultTestStructCodec, nestedErr, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - dc := DecodeContext{Registry: tc.registry} - if dc.Registry == nil { - dc.Registry = buildDefaultRegistry() + reg := tc.registry + if reg == nil { + reg = buildDefaultRegistry() } var val reflect.Value if rtype := reflect.TypeOf(tc.val); rtype != nil { val = reflect.New(rtype).Elem() } - err := tc.decoder.DecodeValue(dc, tc.vr, val) + err := tc.decoder.DecodeValue(reg, tc.vr, val) assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err) }) } @@ -3688,10 +3751,10 @@ func TestDefaultValueDecoders(t *testing.T) { type inner struct{ Bar string } type outer struct{ Foo inner } - dc := DecodeContext{Registry: buildDefaultRegistry()} + reg := buildDefaultRegistry() vr := NewValueReader(outerBytes) val := reflect.New(reflect.TypeOf(outer{})).Elem() - err := defaultTestStructCodec.DecodeValue(dc, vr, val) + err := defaultTestStructCodec.DecodeValue(reg, vr, val) var decodeErr *DecodeError assert.True(t, errors.As(err, &decodeErr), "expected DecodeError, got %v of type %T", err, err) @@ -3718,13 +3781,13 @@ func TestDefaultValueDecoders(t *testing.T) { ) rb := newTestRegistryBuilder() - defaultValueDecoders.RegisterDefaultDecoders(rb) - reg := rb.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))).Build() + registerDefaultDecoders(rb) + rb.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))) - dc := DecodeContext{Registry: reg} + reg := rb.Build() vr := NewValueReader(docBytes) val := reflect.New(tD).Elem() - err := defaultValueDecoders.DDecodeValue(dc, vr, val) + err := dDecodeValue(reg, vr, val) assert.Nil(t, err, "DDecodeValue error: %v", err) want := D{ @@ -3740,10 +3803,10 @@ func TestDefaultValueDecoders(t *testing.T) { ) type myMap map[string]mybool - dc := DecodeContext{Registry: buildDefaultRegistry()} + reg := buildDefaultRegistry() vr := NewValueReader(docBytes) val := reflect.New(reflect.TypeOf(myMap{})).Elem() - err := defaultMapCodec.DecodeValue(dc, vr, val) + err := (&mapCodec{}).DecodeValue(reg, vr, val) assert.Nil(t, err, "DecodeValue error: %v", err) want := myMap{ @@ -3787,7 +3850,7 @@ func buildDocument(elems []byte) []byte { func buildDefaultRegistry() *Registry { rb := newTestRegistryBuilder() - defaultValueEncoders.RegisterDefaultEncoders(rb) - defaultValueDecoders.RegisterDefaultDecoders(rb) + registerDefaultEncoders(rb) + registerDefaultDecoders(rb) return rb.Build() } diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index f2773c36e5..c9eb3fbe08 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -9,18 +9,14 @@ package bson import ( "encoding/json" "errors" - "fmt" "math" "net/url" "reflect" "sync" - "time" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) -var defaultValueEncoders DefaultValueEncoders - var bvwPool = NewValueWriterPool() var errInvalidValue = errors.New("cannot encode invalid element") @@ -32,7 +28,7 @@ var sliceWriterPool = sync.Pool{ }, } -func encodeElement(ec EncodeContext, dw DocumentWriter, e E) error { +func encodeElement(reg EncoderRegistry, dw DocumentWriter, e E) error { vw, err := dw.WriteDocumentElement(e.Key) if err != nil { return err @@ -41,85 +37,78 @@ func encodeElement(ec EncodeContext, dw DocumentWriter, e E) error { if e.Value == nil { return vw.WriteNull() } - encoder, err := ec.LookupEncoder(reflect.TypeOf(e.Value)) + encoder, err := reg.LookupEncoder(reflect.TypeOf(e.Value)) if err != nil { return err } - err = encoder.EncodeValue(ec, vw, reflect.ValueOf(e.Value)) + err = encoder.EncodeValue(reg, vw, reflect.ValueOf(e.Value)) if err != nil { return err } return nil } -// DefaultValueEncoders is a namespace type for the default ValueEncoders used -// when creating a registry. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -type DefaultValueEncoders struct{} - -// RegisterDefaultEncoders will register the encoder methods attached to DefaultValueEncoders with -// the provided RegistryBuilder. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) RegisterDefaultEncoders(rb *RegistryBuilder) { +// registerDefaultEncoders will register the default encoder methods with the provided Registry. +func registerDefaultEncoders(rb *RegistryBuilder) { if rb == nil { panic(errors.New("argument to RegisterDefaultEncoders must not be nil")) } - rb. - RegisterTypeEncoder(tByteSlice, defaultByteSliceCodec). - RegisterTypeEncoder(tTime, defaultTimeCodec). - RegisterTypeEncoder(tEmpty, defaultEmptyInterfaceCodec). - RegisterTypeEncoder(tCoreArray, defaultArrayCodec). - RegisterTypeEncoder(tOID, ValueEncoderFunc(dve.ObjectIDEncodeValue)). - RegisterTypeEncoder(tDecimal, ValueEncoderFunc(dve.Decimal128EncodeValue)). - RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(dve.JSONNumberEncodeValue)). - RegisterTypeEncoder(tURL, ValueEncoderFunc(dve.URLEncodeValue)). - RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(dve.JavaScriptEncodeValue)). - RegisterTypeEncoder(tSymbol, ValueEncoderFunc(dve.SymbolEncodeValue)). - RegisterTypeEncoder(tBinary, ValueEncoderFunc(dve.BinaryEncodeValue)). - RegisterTypeEncoder(tUndefined, ValueEncoderFunc(dve.UndefinedEncodeValue)). - RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dve.DateTimeEncodeValue)). - RegisterTypeEncoder(tNull, ValueEncoderFunc(dve.NullEncodeValue)). - RegisterTypeEncoder(tRegex, ValueEncoderFunc(dve.RegexEncodeValue)). - RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dve.DBPointerEncodeValue)). - RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(dve.TimestampEncodeValue)). - RegisterTypeEncoder(tMinKey, ValueEncoderFunc(dve.MinKeyEncodeValue)). - RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(dve.MaxKeyEncodeValue)). - RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(dve.CoreDocumentEncodeValue)). - RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(dve.CodeWithScopeEncodeValue)). - RegisterDefaultEncoder(reflect.Bool, ValueEncoderFunc(dve.BooleanEncodeValue)). - RegisterDefaultEncoder(reflect.Int, ValueEncoderFunc(dve.IntEncodeValue)). - RegisterDefaultEncoder(reflect.Int8, ValueEncoderFunc(dve.IntEncodeValue)). - RegisterDefaultEncoder(reflect.Int16, ValueEncoderFunc(dve.IntEncodeValue)). - RegisterDefaultEncoder(reflect.Int32, ValueEncoderFunc(dve.IntEncodeValue)). - RegisterDefaultEncoder(reflect.Int64, ValueEncoderFunc(dve.IntEncodeValue)). - RegisterDefaultEncoder(reflect.Uint, defaultUIntCodec). - RegisterDefaultEncoder(reflect.Uint8, defaultUIntCodec). - RegisterDefaultEncoder(reflect.Uint16, defaultUIntCodec). - RegisterDefaultEncoder(reflect.Uint32, defaultUIntCodec). - RegisterDefaultEncoder(reflect.Uint64, defaultUIntCodec). - RegisterDefaultEncoder(reflect.Float32, ValueEncoderFunc(dve.FloatEncodeValue)). - RegisterDefaultEncoder(reflect.Float64, ValueEncoderFunc(dve.FloatEncodeValue)). - RegisterDefaultEncoder(reflect.Array, ValueEncoderFunc(dve.ArrayEncodeValue)). - RegisterDefaultEncoder(reflect.Map, defaultMapCodec). - RegisterDefaultEncoder(reflect.Slice, defaultSliceCodec). - RegisterDefaultEncoder(reflect.String, defaultStringCodec). - RegisterDefaultEncoder(reflect.Struct, newDefaultStructCodec()). - RegisterDefaultEncoder(reflect.Ptr, NewPointerCodec()). - RegisterHookEncoder(tValueMarshaler, ValueEncoderFunc(dve.ValueMarshalerEncodeValue)). - RegisterHookEncoder(tMarshaler, ValueEncoderFunc(dve.MarshalerEncodeValue)). - RegisterHookEncoder(tProxy, ValueEncoderFunc(dve.ProxyEncodeValue)) -} -// BooleanEncodeValue is the ValueEncoderFunc for bool types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) BooleanEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { + numEncoder := func(*Registry) ValueEncoder { return &numCodec{} } + rb. + RegisterTypeEncoder(tByteSlice, func(*Registry) ValueEncoder { return &byteSliceCodec{} }). + RegisterTypeEncoder(tTime, func(*Registry) ValueEncoder { return &timeCodec{} }). + RegisterTypeEncoder(tEmpty, func(*Registry) ValueEncoder { return &emptyInterfaceCodec{} }). + RegisterTypeEncoder(tCoreArray, func(*Registry) ValueEncoder { return &arrayCodec{} }). + RegisterTypeEncoder(tOID, func(*Registry) ValueEncoder { return ValueEncoderFunc(objectIDEncodeValue) }). + RegisterTypeEncoder(tDecimal, func(*Registry) ValueEncoder { return ValueEncoderFunc(decimal128EncodeValue) }). + RegisterTypeEncoder(tJSONNumber, func(*Registry) ValueEncoder { return ValueEncoderFunc(jsonNumberEncodeValue) }). + RegisterTypeEncoder(tURL, func(*Registry) ValueEncoder { return ValueEncoderFunc(urlEncodeValue) }). + RegisterTypeEncoder(tJavaScript, func(*Registry) ValueEncoder { return ValueEncoderFunc(javaScriptEncodeValue) }). + RegisterTypeEncoder(tSymbol, func(*Registry) ValueEncoder { return ValueEncoderFunc(symbolEncodeValue) }). + RegisterTypeEncoder(tBinary, func(*Registry) ValueEncoder { return ValueEncoderFunc(binaryEncodeValue) }). + RegisterTypeEncoder(tUndefined, func(*Registry) ValueEncoder { return ValueEncoderFunc(undefinedEncodeValue) }). + RegisterTypeEncoder(tDateTime, func(*Registry) ValueEncoder { return ValueEncoderFunc(dateTimeEncodeValue) }). + RegisterTypeEncoder(tNull, func(*Registry) ValueEncoder { return ValueEncoderFunc(nullEncodeValue) }). + RegisterTypeEncoder(tRegex, func(*Registry) ValueEncoder { return ValueEncoderFunc(regexEncodeValue) }). + RegisterTypeEncoder(tDBPointer, func(*Registry) ValueEncoder { return ValueEncoderFunc(dbPointerEncodeValue) }). + RegisterTypeEncoder(tTimestamp, func(*Registry) ValueEncoder { return ValueEncoderFunc(timestampEncodeValue) }). + RegisterTypeEncoder(tMinKey, func(*Registry) ValueEncoder { return ValueEncoderFunc(minKeyEncodeValue) }). + RegisterTypeEncoder(tMaxKey, func(*Registry) ValueEncoder { return ValueEncoderFunc(maxKeyEncodeValue) }). + RegisterTypeEncoder(tCoreDocument, func(*Registry) ValueEncoder { return ValueEncoderFunc(coreDocumentEncodeValue) }). + RegisterTypeEncoder(tCodeWithScope, func(*Registry) ValueEncoder { return ValueEncoderFunc(codeWithScopeEncodeValue) }). + RegisterKindEncoder(reflect.Bool, func(*Registry) ValueEncoder { return ValueEncoderFunc(booleanEncodeValue) }). + RegisterKindEncoder(reflect.Int, numEncoder). + RegisterKindEncoder(reflect.Int8, numEncoder). + RegisterKindEncoder(reflect.Int16, numEncoder). + RegisterKindEncoder(reflect.Int32, numEncoder). + RegisterKindEncoder(reflect.Int64, numEncoder). + RegisterKindEncoder(reflect.Uint, numEncoder). + RegisterKindEncoder(reflect.Uint8, numEncoder). + RegisterKindEncoder(reflect.Uint16, numEncoder). + RegisterKindEncoder(reflect.Uint32, numEncoder). + RegisterKindEncoder(reflect.Uint64, numEncoder). + RegisterKindEncoder(reflect.Float32, numEncoder). + RegisterKindEncoder(reflect.Float64, numEncoder). + RegisterKindEncoder(reflect.Array, func(*Registry) ValueEncoder { return ValueEncoderFunc(arrayEncodeValue) }). + RegisterKindEncoder(reflect.Map, func(*Registry) ValueEncoder { return &mapCodec{} }). + RegisterKindEncoder(reflect.Slice, func(*Registry) ValueEncoder { return &sliceCodec{} }). + RegisterKindEncoder(reflect.String, func(*Registry) ValueEncoder { return &stringCodec{} }). + RegisterKindEncoder(reflect.Struct, func(reg *Registry) ValueEncoder { + // reflect.Struct is 25 that is bigger than reflect.Map, 21, in the kind array, + // so Map will be registered earlier than Struct. + enc, _ := reg.lookupKindEncoder(reflect.Map) + return newStructCodec(enc.(mapElementsEncoder)) + }). + RegisterKindEncoder(reflect.Ptr, func(*Registry) ValueEncoder { return &pointerCodec{} }). + RegisterInterfaceEncoder(tValueMarshaler, func(*Registry) ValueEncoder { return ValueEncoderFunc(valueMarshalerEncodeValue) }). + RegisterInterfaceEncoder(tMarshaler, func(*Registry) ValueEncoder { return ValueEncoderFunc(marshalerEncodeValue) }). + RegisterInterfaceEncoder(tProxy, func(*Registry) ValueEncoder { return ValueEncoderFunc(proxyEncodeValue) }) +} + +// booleanEncodeValue is the ValueEncoderFunc for bool types. +func booleanEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Bool { return ValueEncoderError{Name: "BooleanEncodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val} } @@ -130,115 +119,24 @@ func fitsIn32Bits(i int64) bool { return math.MinInt32 <= i && i <= math.MaxInt32 } -// IntEncodeValue is the ValueEncoderFunc for int types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) IntEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - switch val.Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32: - return vw.WriteInt32(int32(val.Int())) - case reflect.Int: - i64 := val.Int() - if fitsIn32Bits(i64) { - return vw.WriteInt32(int32(i64)) - } - return vw.WriteInt64(i64) - case reflect.Int64: - i64 := val.Int() - if ec.MinSize && fitsIn32Bits(i64) { - return vw.WriteInt32(int32(i64)) - } - return vw.WriteInt64(i64) - } - - return ValueEncoderError{ - Name: "IntEncodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, - Received: val, - } -} - -// UintEncodeValue is the ValueEncoderFunc for uint types. -// -// Deprecated: UintEncodeValue is not registered by default. Use UintCodec.EncodeValue instead. -func (dve DefaultValueEncoders) UintEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - switch val.Kind() { - case reflect.Uint8, reflect.Uint16: - return vw.WriteInt32(int32(val.Uint())) - case reflect.Uint, reflect.Uint32, reflect.Uint64: - u64 := val.Uint() - if ec.MinSize && u64 <= math.MaxInt32 { - return vw.WriteInt32(int32(u64)) - } - if u64 > math.MaxInt64 { - return fmt.Errorf("%d overflows int64", u64) - } - return vw.WriteInt64(int64(u64)) - } - - return ValueEncoderError{ - Name: "UintEncodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, - Received: val, - } -} - -// FloatEncodeValue is the ValueEncoderFunc for float types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) FloatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - switch val.Kind() { - case reflect.Float32, reflect.Float64: - return vw.WriteDouble(val.Float()) - } - - return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val} -} - -// StringEncodeValue is the ValueEncoderFunc for string types. -// -// Deprecated: StringEncodeValue is not registered by default. Use StringCodec.EncodeValue instead. -func (dve DefaultValueEncoders) StringEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if val.Kind() != reflect.String { - return ValueEncoderError{ - Name: "StringEncodeValue", - Kinds: []reflect.Kind{reflect.String}, - Received: val, - } - } - - return vw.WriteString(val.String()) -} - -// ObjectIDEncodeValue is the ValueEncoderFunc for ObjectID. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) ObjectIDEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// objectIDEncodeValue is the ValueEncoderFunc for ObjectID. +func objectIDEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tOID { return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val} } return vw.WriteObjectID(val.Interface().(ObjectID)) } -// Decimal128EncodeValue is the ValueEncoderFunc for Decimal128. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) Decimal128EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// decimal128EncodeValue is the ValueEncoderFunc for Decimal128. +func decimal128EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDecimal { return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val} } return vw.WriteDecimal128(val.Interface().(Decimal128)) } -// JSONNumberEncodeValue is the ValueEncoderFunc for json.Number. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) JSONNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +// jsonNumberEncodeValue is the ValueEncoderFunc for json.Number. +func jsonNumberEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tJSONNumber { return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} } @@ -246,7 +144,11 @@ func (dve DefaultValueEncoders) JSONNumberEncodeValue(ec EncodeContext, vw Value // Attempt int first, then float64 if i64, err := jsnum.Int64(); err == nil { - return dve.IntEncodeValue(ec, vw, reflect.ValueOf(i64)) + encoder, err := reg.LookupEncoder(tInt64) + if err != nil { + return err + } + return encoder.EncodeValue(reg, vw, reflect.ValueOf(i64)) } f64, err := jsnum.Float64() @@ -254,123 +156,26 @@ func (dve DefaultValueEncoders) JSONNumberEncodeValue(ec EncodeContext, vw Value return err } - return dve.FloatEncodeValue(ec, vw, reflect.ValueOf(f64)) -} - -// URLEncodeValue is the ValueEncoderFunc for url.URL. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) URLEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tURL { - return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val} - } - u := val.Interface().(url.URL) - return vw.WriteString(u.String()) -} - -// TimeEncodeValue is the ValueEncoderFunc for time.TIme. -// -// Deprecated: TimeEncodeValue is not registered by default. Use TimeCodec.EncodeValue instead. -func (dve DefaultValueEncoders) TimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tTime { - return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val} - } - tt := val.Interface().(time.Time) - dt := NewDateTimeFromTime(tt) - return vw.WriteDateTime(int64(dt)) -} - -// ByteSliceEncodeValue is the ValueEncoderFunc for []byte. -// -// Deprecated: ByteSliceEncodeValue is not registered by default. Use ByteSliceCodec.EncodeValue instead. -func (dve DefaultValueEncoders) ByteSliceEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tByteSlice { - return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val} - } - if val.IsNil() { - return vw.WriteNull() - } - return vw.WriteBinary(val.Interface().([]byte)) -} - -// MapEncodeValue is the ValueEncoderFunc for map[string]* types. -// -// Deprecated: MapEncodeValue is not registered by default. Use MapCodec.EncodeValue instead. -func (dve DefaultValueEncoders) MapEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Kind() != reflect.Map || val.Type().Key().Kind() != reflect.String { - return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} - } - - if val.IsNil() { - // If we have a nill map but we can't WriteNull, that means we're probably trying to encode - // to a TopLevel document. We can't currently tell if this is what actually happened, but if - // there's a deeper underlying problem, the error will also be returned from WriteDocument, - // so just continue. The operations on a map reflection value are valid, so we can call - // MapKeys within mapEncodeValue without a problem. - err := vw.WriteNull() - if err == nil { - return nil - } - } - - dw, err := vw.WriteDocument() + var encoder ValueEncoder + encoder, err = reg.LookupEncoder(reflect.TypeOf(f64)) if err != nil { return err } - return dve.mapEncodeValue(ec, dw, val, nil) + return encoder.EncodeValue(reg, vw, reflect.ValueOf(f64)) } -// mapEncodeValue handles encoding of the values of a map. The collisionFn returns -// true if the provided key exists, this is mainly used for inline maps in the -// struct codec. -func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { - - elemType := val.Type().Elem() - encoder, err := ec.LookupEncoder(elemType) - if err != nil && elemType.Kind() != reflect.Interface { - return err - } - - keys := val.MapKeys() - for _, key := range keys { - if collisionFn != nil && collisionFn(key.String()) { - return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key) - } - - currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.MapIndex(key)) - if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { - return lookupErr - } - - vw, err := dw.WriteDocumentElement(key.String()) - if err != nil { - return err - } - - if errors.Is(lookupErr, errInvalidValue) { - err = vw.WriteNull() - if err != nil { - return err - } - continue - } - - err = currEncoder.EncodeValue(ec, vw, currVal) - if err != nil { - return err - } +// urlEncodeValue is the ValueEncoderFunc for url.URL. +func urlEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tURL { + return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val} } - - return dw.WriteDocumentEnd() + u := val.Interface().(url.URL) + return vw.WriteString(u.String()) } -// ArrayEncodeValue is the ValueEncoderFunc for array types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +// arrayEncodeValue is the ValueEncoderFunc for array types. +func arrayEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Array { return ValueEncoderError{Name: "ArrayEncodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val} } @@ -384,7 +189,7 @@ func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw ValueWrite for idx := 0; idx < val.Len(); idx++ { e := val.Index(idx).Interface().(E) - err = encodeElement(ec, dw, e) + err = encodeElement(reg, dw, e) if err != nil { return err } @@ -408,82 +213,13 @@ func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw ValueWrite } elemType := val.Type().Elem() - encoder, err := ec.LookupEncoder(elemType) - if err != nil && elemType.Kind() != reflect.Interface { - return err - } - - for idx := 0; idx < val.Len(); idx++ { - currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx)) - if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { - return lookupErr - } - - vw, err := aw.WriteArrayElement() - if err != nil { - return err - } - - if errors.Is(lookupErr, errInvalidValue) { - err = vw.WriteNull() - if err != nil { - return err - } - continue - } - - err = currEncoder.EncodeValue(ec, vw, currVal) - if err != nil { - return err - } - } - return aw.WriteArrayEnd() -} - -// SliceEncodeValue is the ValueEncoderFunc for slice types. -// -// Deprecated: SliceEncodeValue is not registered by default. Use SliceCodec.EncodeValue instead. -func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Kind() != reflect.Slice { - return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} - } - - if val.IsNil() { - return vw.WriteNull() - } - - // If we have a []E we want to treat it as a document instead of as an array. - if val.Type().ConvertibleTo(tD) { - d := val.Convert(tD).Interface().(D) - - dw, err := vw.WriteDocument() - if err != nil { - return err - } - - for _, e := range d { - err = encodeElement(ec, dw, e) - if err != nil { - return err - } - } - - return dw.WriteDocumentEnd() - } - - aw, err := vw.WriteArray() - if err != nil { - return err - } - - elemType := val.Type().Elem() - encoder, err := ec.LookupEncoder(elemType) + encoder, err := reg.LookupEncoder(elemType) if err != nil && elemType.Kind() != reflect.Interface { return err } for idx := 0; idx < val.Len(); idx++ { - currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx)) + currEncoder, currVal, lookupErr := lookupElementEncoder(reg, encoder, val.Index(idx)) if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } @@ -501,7 +237,7 @@ func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw ValueWrite continue } - err = currEncoder.EncodeValue(ec, vw, currVal) + err = currEncoder.EncodeValue(reg, vw, currVal) if err != nil { return err } @@ -509,7 +245,7 @@ func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw ValueWrite return aw.WriteArrayEnd() } -func (dve DefaultValueEncoders) lookupElementEncoder(ec EncodeContext, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) { +func lookupElementEncoder(reg EncoderRegistry, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) { if origEncoder != nil || (currVal.Kind() != reflect.Interface) { return origEncoder, currVal, nil } @@ -517,35 +253,13 @@ func (dve DefaultValueEncoders) lookupElementEncoder(ec EncodeContext, origEncod if !currVal.IsValid() { return nil, currVal, errInvalidValue } - currEncoder, err := ec.LookupEncoder(currVal.Type()) + currEncoder, err := reg.LookupEncoder(currVal.Type()) return currEncoder, currVal, err } -// EmptyInterfaceEncodeValue is the ValueEncoderFunc for interface{}. -// -// Deprecated: EmptyInterfaceEncodeValue is not registered by default. Use EmptyInterfaceCodec.EncodeValue instead. -func (dve DefaultValueEncoders) EmptyInterfaceEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tEmpty { - return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} - } - - if val.IsNil() { - return vw.WriteNull() - } - encoder, err := ec.LookupEncoder(val.Elem().Type()) - if err != nil { - return err - } - - return encoder.EncodeValue(ec, vw, val.Elem()) -} - -// ValueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// valueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations. +func valueMarshalerEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement ValueMarshaler switch { case !val.IsValid(): @@ -572,11 +286,8 @@ func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(_ EncodeContext, vw Va return copyValueFromBytes(vw, t, data) } -// MarshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) MarshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// marshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations. +func marshalerEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Marshaler switch { case !val.IsValid(): @@ -603,11 +314,8 @@ func (dve DefaultValueEncoders) MarshalerEncodeValue(_ EncodeContext, vw ValueWr return copyValueFromBytes(vw, TypeEmbeddedDocument, data) } -// ProxyEncodeValue is the ValueEncoderFunc for Proxy implementations. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +// proxyEncodeValue is the ValueEncoderFunc for Proxy implementations. +func proxyEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Proxy switch { case !val.IsValid(): @@ -632,29 +340,26 @@ func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw ValueWrite return err } if v == nil { - encoder, err := ec.LookupEncoder(nil) + encoder, err := reg.LookupEncoder(nil) if err != nil { return err } - return encoder.EncodeValue(ec, vw, reflect.ValueOf(nil)) + return encoder.EncodeValue(reg, vw, reflect.ValueOf(nil)) } vv := reflect.ValueOf(v) switch vv.Kind() { case reflect.Ptr, reflect.Interface: vv = vv.Elem() } - encoder, err := ec.LookupEncoder(vv.Type()) + encoder, err := reg.LookupEncoder(vv.Type()) if err != nil { return err } - return encoder.EncodeValue(ec, vw, vv) + return encoder.EncodeValue(reg, vw, vv) } -// JavaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) JavaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// javaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type. +func javaScriptEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tJavaScript { return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val} } @@ -662,11 +367,8 @@ func (DefaultValueEncoders) JavaScriptEncodeValue(_ EncodeContext, vw ValueWrite return vw.WriteJavascript(val.String()) } -// SymbolEncodeValue is the ValueEncoderFunc for the Symbol type. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) SymbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// symbolEncodeValue is the ValueEncoderFunc for the Symbol type. +func symbolEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tSymbol { return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val} } @@ -674,11 +376,8 @@ func (DefaultValueEncoders) SymbolEncodeValue(_ EncodeContext, vw ValueWriter, v return vw.WriteSymbol(val.String()) } -// BinaryEncodeValue is the ValueEncoderFunc for Binary. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) BinaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// binaryEncodeValue is the ValueEncoderFunc for Binary. +func binaryEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tBinary { return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val} } @@ -687,11 +386,8 @@ func (DefaultValueEncoders) BinaryEncodeValue(_ EncodeContext, vw ValueWriter, v return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) } -// UndefinedEncodeValue is the ValueEncoderFunc for Undefined. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) UndefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// undefinedEncodeValue is the ValueEncoderFunc for Undefined. +func undefinedEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tUndefined { return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val} } @@ -699,11 +395,8 @@ func (DefaultValueEncoders) UndefinedEncodeValue(_ EncodeContext, vw ValueWriter return vw.WriteUndefined() } -// DateTimeEncodeValue is the ValueEncoderFunc for DateTime. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) DateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// dateTimeEncodeValue is the ValueEncoderFunc for DateTime. +func dateTimeEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDateTime { return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val} } @@ -711,11 +404,8 @@ func (DefaultValueEncoders) DateTimeEncodeValue(_ EncodeContext, vw ValueWriter, return vw.WriteDateTime(val.Int()) } -// NullEncodeValue is the ValueEncoderFunc for Null. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) NullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// nullEncodeValue is the ValueEncoderFunc for Null. +func nullEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tNull { return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val} } @@ -723,11 +413,8 @@ func (DefaultValueEncoders) NullEncodeValue(_ EncodeContext, vw ValueWriter, val return vw.WriteNull() } -// RegexEncodeValue is the ValueEncoderFunc for Regex. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) RegexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// regexEncodeValue is the ValueEncoderFunc for Regex. +func regexEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRegex { return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val} } @@ -737,11 +424,8 @@ func (DefaultValueEncoders) RegexEncodeValue(_ EncodeContext, vw ValueWriter, va return vw.WriteRegex(regex.Pattern, regex.Options) } -// DBPointerEncodeValue is the ValueEncoderFunc for DBPointer. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) DBPointerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// dbPointerEncodeValue is the ValueEncoderFunc for DBPointer. +func dbPointerEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDBPointer { return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val} } @@ -751,11 +435,8 @@ func (DefaultValueEncoders) DBPointerEncodeValue(_ EncodeContext, vw ValueWriter return vw.WriteDBPointer(dbp.DB, dbp.Pointer) } -// TimestampEncodeValue is the ValueEncoderFunc for Timestamp. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) TimestampEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// timestampEncodeValue is the ValueEncoderFunc for Timestamp. +func timestampEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tTimestamp { return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val} } @@ -765,11 +446,8 @@ func (DefaultValueEncoders) TimestampEncodeValue(_ EncodeContext, vw ValueWriter return vw.WriteTimestamp(ts.T, ts.I) } -// MinKeyEncodeValue is the ValueEncoderFunc for MinKey. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) MinKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// minKeyEncodeValue is the ValueEncoderFunc for MinKey. +func minKeyEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tMinKey { return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val} } @@ -777,11 +455,8 @@ func (DefaultValueEncoders) MinKeyEncodeValue(_ EncodeContext, vw ValueWriter, v return vw.WriteMinKey() } -// MaxKeyEncodeValue is the ValueEncoderFunc for MaxKey. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) MaxKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// maxKeyEncodeValue is the ValueEncoderFunc for MaxKey. +func maxKeyEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tMaxKey { return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val} } @@ -789,11 +464,8 @@ func (DefaultValueEncoders) MaxKeyEncodeValue(_ EncodeContext, vw ValueWriter, v return vw.WriteMaxKey() } -// CoreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) CoreDocumentEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. +func coreDocumentEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCoreDocument { return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} } @@ -803,11 +475,8 @@ func (DefaultValueEncoders) CoreDocumentEncodeValue(_ EncodeContext, vw ValueWri return copyDocumentFromBytes(vw, cdoc) } -// CodeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) CodeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +// codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. +func codeWithScopeEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCodeWithScope { return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} } @@ -826,12 +495,12 @@ func (dve DefaultValueEncoders) CodeWithScopeEncodeValue(ec EncodeContext, vw Va scopeVW := bvwPool.Get(sw) defer bvwPool.Put(scopeVW) - encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope)) + encoder, err := reg.LookupEncoder(reflect.TypeOf(cws.Scope)) if err != nil { return err } - err = encoder.EncodeValue(ec, scopeVW, reflect.ValueOf(cws.Scope)) + err = encoder.EncodeValue(reg, scopeVW, reflect.ValueOf(cws.Scope)) if err != nil { return err } diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index 481c6cb1a1..5faf76b25f 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -35,7 +35,6 @@ func (ms myStruct) Foo() int { } func TestDefaultValueEncoders(t *testing.T) { - var dve DefaultValueEncoders var wrong = func(string, string) string { return "wrong" } type mybool bool @@ -67,7 +66,7 @@ func TestDefaultValueEncoders(t *testing.T) { type subtest struct { name string val interface{} - ectx *EncodeContext + reg *Registry llvrw *valueReaderWriter invoke invoked err error @@ -80,7 +79,7 @@ func TestDefaultValueEncoders(t *testing.T) { }{ { "BooleanEncodeValue", - ValueEncoderFunc(dve.BooleanEncodeValue), + ValueEncoderFunc(booleanEncodeValue), []subtest{ { "wrong type", @@ -96,7 +95,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "IntEncodeValue", - ValueEncoderFunc(dve.IntEncodeValue), + &numCodec{}, []subtest{ { "wrong type", @@ -105,8 +104,12 @@ func TestDefaultValueEncoders(t *testing.T) { nil, nothing, ValueEncoderError{ - Name: "IntEncodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Name: "NumEncodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf(wrong), }, }, @@ -114,9 +117,6 @@ func TestDefaultValueEncoders(t *testing.T) { {"int16/fast path", int16(32767), nil, nil, writeInt32, nil}, {"int32/fast path", int32(2147483647), nil, nil, writeInt32, nil}, {"int64/fast path", int64(1234567890987), nil, nil, writeInt64, nil}, - {"int64/fast path - minsize", int64(math.MaxInt32), &EncodeContext{MinSize: true}, nil, writeInt32, nil}, - {"int64/fast path - minsize too large", int64(math.MaxInt32 + 1), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, - {"int64/fast path - minsize too small", int64(math.MinInt32 - 1), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, {"int/fast path - positive int32", int(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, {"int/fast path - negative int32", int(math.MinInt32 + 1), nil, nil, writeInt32, nil}, {"int/fast path - MaxInt32", int(math.MaxInt32), nil, nil, writeInt32, nil}, @@ -125,9 +125,6 @@ func TestDefaultValueEncoders(t *testing.T) { {"int16/reflection path", myint16(32767), nil, nil, writeInt32, nil}, {"int32/reflection path", myint32(2147483647), nil, nil, writeInt32, nil}, {"int64/reflection path", myint64(1234567890987), nil, nil, writeInt64, nil}, - {"int64/reflection path - minsize", myint64(math.MaxInt32), &EncodeContext{MinSize: true}, nil, writeInt32, nil}, - {"int64/reflection path - minsize too large", myint64(math.MaxInt32 + 1), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, - {"int64/reflection path - minsize too small", myint64(math.MinInt32 - 1), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, {"int/reflection path - positive int32", myint(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, {"int/reflection path - negative int32", myint(math.MinInt32 + 1), nil, nil, writeInt32, nil}, {"int/reflection path - MaxInt32", myint(math.MaxInt32), nil, nil, writeInt32, nil}, @@ -136,7 +133,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "UintEncodeValue", - defaultUIntCodec, + &numCodec{}, []subtest{ { "wrong type", @@ -145,8 +142,12 @@ func TestDefaultValueEncoders(t *testing.T) { nil, nothing, ValueEncoderError{ - Name: "UintEncodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "NumEncodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf(wrong), }, }, @@ -155,29 +156,41 @@ func TestDefaultValueEncoders(t *testing.T) { {"uint32/fast path", uint32(2147483647), nil, nil, writeInt64, nil}, {"uint64/fast path", uint64(1234567890987), nil, nil, writeInt64, nil}, {"uint/fast path", uint(1234567), nil, nil, writeInt64, nil}, - {"uint32/fast path - minsize", uint32(2147483647), &EncodeContext{MinSize: true}, nil, writeInt32, nil}, - {"uint64/fast path - minsize", uint64(2147483647), &EncodeContext{MinSize: true}, nil, writeInt32, nil}, - {"uint/fast path - minsize", uint(2147483647), &EncodeContext{MinSize: true}, nil, writeInt32, nil}, - {"uint32/fast path - minsize too large", uint32(2147483648), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, - {"uint64/fast path - minsize too large", uint64(2147483648), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, - {"uint/fast path - minsize too large", uint(2147483648), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, {"uint64/fast path - overflow", uint64(1 << 63), nil, nil, nothing, fmt.Errorf("%d overflows int64", uint64(1<<63))}, {"uint8/reflection path", myuint8(127), nil, nil, writeInt32, nil}, {"uint16/reflection path", myuint16(32767), nil, nil, writeInt32, nil}, {"uint32/reflection path", myuint32(2147483647), nil, nil, writeInt64, nil}, {"uint64/reflection path", myuint64(1234567890987), nil, nil, writeInt64, nil}, - {"uint32/reflection path - minsize", myuint32(2147483647), &EncodeContext{MinSize: true}, nil, writeInt32, nil}, - {"uint64/reflection path - minsize", myuint64(2147483647), &EncodeContext{MinSize: true}, nil, writeInt32, nil}, - {"uint/reflection path - minsize", myuint(2147483647), &EncodeContext{MinSize: true}, nil, writeInt32, nil}, - {"uint32/reflection path - minsize too large", myuint(1 << 31), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, - {"uint64/reflection path - minsize too large", myuint64(1 << 31), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, - {"uint/reflection path - minsize too large", myuint(2147483648), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, {"uint64/reflection path - overflow", myuint64(1 << 63), nil, nil, nothing, fmt.Errorf("%d overflows int64", uint64(1<<63))}, }, }, + { + "NumEncodeValue (minSize)", + &numCodec{minSize: true}, + []subtest{ + {"int64/fast path - minsize", int64(math.MaxInt32), nil, nil, writeInt32, nil}, + {"int64/fast path - minsize too large", int64(math.MaxInt32 + 1), nil, nil, writeInt64, nil}, + {"int64/fast path - minsize too small", int64(math.MinInt32 - 1), nil, nil, writeInt64, nil}, + {"int64/reflection path - minsize", myint64(math.MaxInt32), nil, nil, writeInt32, nil}, + {"int64/reflection path - minsize too large", myint64(math.MaxInt32 + 1), nil, nil, writeInt64, nil}, + {"int64/reflection path - minsize too small", myint64(math.MinInt32 - 1), nil, nil, writeInt64, nil}, + {"uint32/fast path - minsize", uint32(2147483647), nil, nil, writeInt32, nil}, + {"uint64/fast path - minsize", uint64(2147483647), nil, nil, writeInt32, nil}, + {"uint/fast path - minsize", uint(2147483647), nil, nil, writeInt32, nil}, + {"uint32/fast path - minsize too large", uint32(2147483648), nil, nil, writeInt64, nil}, + {"uint64/fast path - minsize too large", uint64(2147483648), nil, nil, writeInt64, nil}, + {"uint/fast path - minsize too large", uint(2147483648), nil, nil, writeInt64, nil}, + {"uint32/reflection path - minsize", myuint32(2147483647), nil, nil, writeInt32, nil}, + {"uint64/reflection path - minsize", myuint64(2147483647), nil, nil, writeInt32, nil}, + {"uint/reflection path - minsize", myuint(2147483647), nil, nil, writeInt32, nil}, + {"uint32/reflection path - minsize too large", myuint(1 << 31), nil, nil, writeInt64, nil}, + {"uint64/reflection path - minsize too large", myuint64(1 << 31), nil, nil, writeInt64, nil}, + {"uint/reflection path - minsize too large", myuint(2147483648), nil, nil, writeInt64, nil}, + }, + }, { "FloatEncodeValue", - ValueEncoderFunc(dve.FloatEncodeValue), + &numCodec{}, []subtest{ { "wrong type", @@ -186,8 +199,12 @@ func TestDefaultValueEncoders(t *testing.T) { nil, nothing, ValueEncoderError{ - Name: "FloatEncodeValue", - Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Name: "NumEncodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf(wrong), }, }, @@ -199,7 +216,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "TimeEncodeValue", - defaultTimeCodec, + &timeCodec{}, []subtest{ { "wrong type", @@ -214,7 +231,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "MapEncodeValue", - defaultMapCodec, + &mapCodec{}, []subtest{ { "wrong kind", @@ -235,7 +252,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", map[string]int{"foo": 1}, - &EncodeContext{Registry: newTestRegistryBuilder().Build()}, + newTestRegistryBuilder().Build(), &valueReaderWriter{}, writeDocument, fmt.Errorf("no encoder found for int"), @@ -243,7 +260,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "WriteDocumentElement Error", map[string]interface{}{"foo": "bar"}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("wde error"), ErrAfter: writeDocumentElement}, writeDocumentElement, errors.New("wde error"), @@ -251,7 +268,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "EncodeValue Error", map[string]interface{}{"foo": "bar"}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("ev error"), ErrAfter: writeString}, writeString, errors.New("ev error"), @@ -259,7 +276,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "empty map/success", map[string]interface{}{}, - &EncodeContext{Registry: newTestRegistryBuilder().Build()}, + newTestRegistryBuilder().Build(), &valueReaderWriter{}, writeDocumentEnd, nil, @@ -267,7 +284,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "with interface/success", map[string]myInterface{"foo": myStruct{1}}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -275,7 +292,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "with interface/nil/success", map[string]myInterface{"foo": nil}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -285,7 +302,7 @@ func TestDefaultValueEncoders(t *testing.T) { map[int]interface{}{ 1: "foobar", }, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{}, writeDocumentEnd, nil, @@ -294,7 +311,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ArrayEncodeValue", - ValueEncoderFunc(dve.ArrayEncodeValue), + ValueEncoderFunc(arrayEncodeValue), []subtest{ { "wrong kind", @@ -315,7 +332,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", [1]int{1}, - &EncodeContext{Registry: newTestRegistryBuilder().Build()}, + newTestRegistryBuilder().Build(), &valueReaderWriter{}, writeArray, fmt.Errorf("no encoder found for int"), @@ -323,7 +340,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "WriteArrayElement Error", [1]string{"foo"}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("wae error"), ErrAfter: writeArrayElement}, writeArrayElement, errors.New("wae error"), @@ -331,7 +348,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "EncodeValue Error", [1]string{"foo"}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("ev error"), ErrAfter: writeString}, writeString, errors.New("ev error"), @@ -339,7 +356,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "[1]E/success", [1]E{{"hello", "world"}}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -347,7 +364,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "[1]E/success", [1]E{{"hello", nil}}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -355,7 +372,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "[1]interface/success", [1]myInterface{myStruct{1}}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeArrayEnd, nil, @@ -363,7 +380,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "[1]interface/nil/success", [1]myInterface{nil}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeArrayEnd, nil, @@ -372,7 +389,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "SliceEncodeValue", - defaultSliceCodec, + &sliceCodec{}, []subtest{ { "wrong kind", @@ -393,7 +410,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", []int{1}, - &EncodeContext{Registry: newTestRegistryBuilder().Build()}, + newTestRegistryBuilder().Build(), &valueReaderWriter{}, writeArray, fmt.Errorf("no encoder found for int"), @@ -401,7 +418,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "WriteArrayElement Error", []string{"foo"}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("wae error"), ErrAfter: writeArrayElement}, writeArrayElement, errors.New("wae error"), @@ -409,7 +426,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "EncodeValue Error", []string{"foo"}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("ev error"), ErrAfter: writeString}, writeString, errors.New("ev error"), @@ -417,7 +434,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "D/success", D{{"hello", "world"}}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -425,7 +442,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "D/success", D{{"hello", nil}}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -433,7 +450,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "empty slice/success", []interface{}{}, - &EncodeContext{Registry: newTestRegistryBuilder().Build()}, + newTestRegistryBuilder().Build(), &valueReaderWriter{}, writeArrayEnd, nil, @@ -441,7 +458,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "interface/success", []myInterface{myStruct{1}}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeArrayEnd, nil, @@ -449,7 +466,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "interface/success", []myInterface{nil}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeArrayEnd, nil, @@ -458,7 +475,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ObjectIDEncodeValue", - ValueEncoderFunc(dve.ObjectIDEncodeValue), + ValueEncoderFunc(objectIDEncodeValue), []subtest{ { "wrong type", @@ -477,7 +494,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "Decimal128EncodeValue", - ValueEncoderFunc(dve.Decimal128EncodeValue), + ValueEncoderFunc(decimal128EncodeValue), []subtest{ { "wrong type", @@ -492,7 +509,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "JSONNumberEncodeValue", - ValueEncoderFunc(dve.JSONNumberEncodeValue), + ValueEncoderFunc(jsonNumberEncodeValue), []subtest{ { "wrong type", @@ -510,18 +527,18 @@ func TestDefaultValueEncoders(t *testing.T) { { "json.Number/int64/success", json.Number("1234567890"), - nil, nil, writeInt64, nil, + buildDefaultRegistry(), nil, writeInt64, nil, }, { "json.Number/float64/success", json.Number("3.14159"), - nil, nil, writeDouble, nil, + buildDefaultRegistry(), nil, writeDouble, nil, }, }, }, { "URLEncodeValue", - ValueEncoderFunc(dve.URLEncodeValue), + ValueEncoderFunc(urlEncodeValue), []subtest{ { "wrong type", @@ -536,7 +553,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ByteSliceEncodeValue", - defaultByteSliceCodec, + &byteSliceCodec{}, []subtest{ { "wrong type", @@ -552,7 +569,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "EmptyInterfaceEncodeValue", - defaultEmptyInterfaceCodec, + &emptyInterfaceCodec{}, []subtest{ { "wrong type", @@ -566,7 +583,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ValueMarshalerEncodeValue", - ValueEncoderFunc(dve.ValueMarshalerEncodeValue), + ValueEncoderFunc(valueMarshalerEncodeValue), []subtest{ { "wrong type", @@ -644,7 +661,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "MarshalerEncodeValue", - ValueEncoderFunc(dve.MarshalerEncodeValue), + ValueEncoderFunc(marshalerEncodeValue), []subtest{ { "wrong type", @@ -706,7 +723,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ProxyEncodeValue", - ValueEncoderFunc(dve.ProxyEncodeValue), + ValueEncoderFunc(proxyEncodeValue), []subtest{ { "wrong type", @@ -727,7 +744,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup error", testProxy{ret: nil}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, nothing, ErrNoEncoder{Type: nil}, @@ -735,7 +752,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "success struct implementation", testProxy{ret: int64(1234567890)}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeInt64, nil, @@ -743,7 +760,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "success ptr to struct implementation", &testProxy{ret: int64(1234567890)}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeInt64, nil, @@ -759,7 +776,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "success ptr to ptr implementation", &testProxyPtr{ret: int64(1234567890)}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeInt64, nil, @@ -776,7 +793,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "PointerCodec.EncodeValue", - NewPointerCodec(), + &pointerCodec{}, []subtest{ { "nil", @@ -792,7 +809,7 @@ func TestDefaultValueEncoders(t *testing.T) { nil, nil, nothing, - ValueEncoderError{Name: "PointerCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: reflect.ValueOf(int32(123456))}, + ValueEncoderError{Name: "pointerCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: reflect.ValueOf(int32(123456))}, }, { "typed nil", @@ -805,7 +822,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "no encoder", &wrong, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, nothing, ErrNoEncoder{Type: reflect.TypeOf(wrong)}, @@ -814,12 +831,12 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "pointer implementation addressable interface", - NewPointerCodec(), + &pointerCodec{}, []subtest{ { "ValueMarshaler", &vmStruct, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -827,7 +844,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Marshaler", &mStruct, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -835,7 +852,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Proxy", &pStruct, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -844,7 +861,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "JavaScriptEncodeValue", - ValueEncoderFunc(dve.JavaScriptEncodeValue), + ValueEncoderFunc(javaScriptEncodeValue), []subtest{ { "wrong type", @@ -859,7 +876,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "SymbolEncodeValue", - ValueEncoderFunc(dve.SymbolEncodeValue), + ValueEncoderFunc(symbolEncodeValue), []subtest{ { "wrong type", @@ -874,7 +891,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "BinaryEncodeValue", - ValueEncoderFunc(dve.BinaryEncodeValue), + ValueEncoderFunc(binaryEncodeValue), []subtest{ { "wrong type", @@ -889,7 +906,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "UndefinedEncodeValue", - ValueEncoderFunc(dve.UndefinedEncodeValue), + ValueEncoderFunc(undefinedEncodeValue), []subtest{ { "wrong type", @@ -904,7 +921,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "DateTimeEncodeValue", - ValueEncoderFunc(dve.DateTimeEncodeValue), + ValueEncoderFunc(dateTimeEncodeValue), []subtest{ { "wrong type", @@ -919,7 +936,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "NullEncodeValue", - ValueEncoderFunc(dve.NullEncodeValue), + ValueEncoderFunc(nullEncodeValue), []subtest{ { "wrong type", @@ -934,7 +951,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "RegexEncodeValue", - ValueEncoderFunc(dve.RegexEncodeValue), + ValueEncoderFunc(regexEncodeValue), []subtest{ { "wrong type", @@ -949,7 +966,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "DBPointerEncodeValue", - ValueEncoderFunc(dve.DBPointerEncodeValue), + ValueEncoderFunc(dbPointerEncodeValue), []subtest{ { "wrong type", @@ -971,7 +988,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "TimestampEncodeValue", - ValueEncoderFunc(dve.TimestampEncodeValue), + ValueEncoderFunc(timestampEncodeValue), []subtest{ { "wrong type", @@ -986,7 +1003,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "MinKeyEncodeValue", - ValueEncoderFunc(dve.MinKeyEncodeValue), + ValueEncoderFunc(minKeyEncodeValue), []subtest{ { "wrong type", @@ -1001,7 +1018,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "MaxKeyEncodeValue", - ValueEncoderFunc(dve.MaxKeyEncodeValue), + ValueEncoderFunc(maxKeyEncodeValue), []subtest{ { "wrong type", @@ -1016,7 +1033,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "CoreDocumentEncodeValue", - ValueEncoderFunc(dve.CoreDocumentEncodeValue), + ValueEncoderFunc(coreDocumentEncodeValue), []subtest{ { "wrong type", @@ -1079,7 +1096,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "interface value", struct{ Foo myInterface }{Foo: myStruct{1}}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -1087,7 +1104,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "nil interface value", struct{ Foo myInterface }{Foo: nil}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -1096,7 +1113,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "CodeWithScopeEncodeValue", - ValueEncoderFunc(dve.CodeWithScopeEncodeValue), + ValueEncoderFunc(codeWithScopeEncodeValue), []subtest{ { "wrong type", @@ -1124,14 +1141,14 @@ func TestDefaultValueEncoders(t *testing.T) { Code: "var hello = 'world';", Scope: D{}, }, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, }, }, }, { "CoreArrayEncodeValue", - defaultArrayCodec, + &arrayCodec{}, []subtest{ { "wrong type", @@ -1182,16 +1199,12 @@ func TestDefaultValueEncoders(t *testing.T) { t.Run(tc.name, func(t *testing.T) { for _, subtest := range tc.subtests { t.Run(subtest.name, func(t *testing.T) { - var ec EncodeContext - if subtest.ectx != nil { - ec = *subtest.ectx - } llvrw := new(valueReaderWriter) if subtest.llvrw != nil { llvrw = subtest.llvrw } llvrw.T = t - err := tc.ve.EncodeValue(ec, llvrw, reflect.ValueOf(subtest.val)) + err := tc.ve.EncodeValue(subtest.reg, llvrw, reflect.ValueOf(subtest.val)) if !assert.CompareErrors(err, subtest.err) { t.Errorf("Errors do not match. got %v; want %v", err, subtest.err) } @@ -1771,7 +1784,7 @@ func TestDefaultValueEncoders(t *testing.T) { reg := buildDefaultRegistry() enc, err := reg.LookupEncoder(reflect.TypeOf(tc.value)) noerr(t, err) - err = enc.EncodeValue(EncodeContext{Registry: reg}, vw, reflect.ValueOf(tc.value)) + err = enc.EncodeValue(reg, vw, reflect.ValueOf(tc.value)) if !errors.Is(err, tc.err) { t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) } @@ -1821,7 +1834,7 @@ func TestDefaultValueEncoders(t *testing.T) { reg := buildDefaultRegistry() enc, err := reg.LookupEncoder(reflect.TypeOf(tc.value)) noerr(t, err) - err = enc.EncodeValue(EncodeContext{Registry: reg}, vw, reflect.ValueOf(tc.value)) + err = enc.EncodeValue(reg, vw, reflect.ValueOf(tc.value)) if err == nil || !strings.Contains(err.Error(), tc.err.Error()) { t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) } @@ -1832,7 +1845,7 @@ func TestDefaultValueEncoders(t *testing.T) { t.Run("EmptyInterfaceEncodeValue/nil", func(t *testing.T) { val := reflect.New(tEmpty).Elem() llvrw := new(valueReaderWriter) - err := dve.EmptyInterfaceEncodeValue(EncodeContext{Registry: newTestRegistryBuilder().Build()}, llvrw, val) + err := (&emptyInterfaceCodec{}).EncodeValue(newTestRegistryBuilder().Build(), llvrw, val) noerr(t, err) if llvrw.invoked != writeNull { t.Errorf("Incorrect method called. got %v; want %v", llvrw.invoked, writeNull) @@ -1843,7 +1856,7 @@ func TestDefaultValueEncoders(t *testing.T) { val := reflect.New(tEmpty).Elem() val.Set(reflect.ValueOf(int64(1234567890))) llvrw := new(valueReaderWriter) - got := dve.EmptyInterfaceEncodeValue(EncodeContext{Registry: newTestRegistryBuilder().Build()}, llvrw, val) + got := (&emptyInterfaceCodec{}).EncodeValue(newTestRegistryBuilder().Build(), llvrw, val) want := ErrNoEncoder{Type: tInt64} if !assert.CompareErrors(got, want) { t.Errorf("Did not receive expected error. got %v; want %v", got, want) diff --git a/bson/empty_interface_codec.go b/bson/empty_interface_codec.go index 56468e3068..cf30014859 100644 --- a/bson/empty_interface_codec.go +++ b/bson/empty_interface_codec.go @@ -8,47 +8,23 @@ package bson import ( "reflect" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) -// EmptyInterfaceCodec is the Codec used for interface{} values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// EmptyInterfaceCodec registered. -type EmptyInterfaceCodec struct { - // DecodeBinaryAsSlice causes DecodeValue to unmarshal BSON binary field values that are the - // "Generic" or "Old" BSON binary subtype as a Go byte slice instead of a Binary. - // - // Deprecated: Use bson.Decoder.BinaryAsSlice instead. - DecodeBinaryAsSlice bool -} - -var ( - defaultEmptyInterfaceCodec = NewEmptyInterfaceCodec() - - // Assert that defaultEmptyInterfaceCodec satisfies the typeDecoder interface, which allows it - // to be used by collection type decoders (e.g. map, slice, etc) to set individual values in a - // collection. - _ typeDecoder = defaultEmptyInterfaceCodec -) +// emptyInterfaceCodec is the Codec used for interface{} values. +type emptyInterfaceCodec struct { + // defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the + // usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If DocumentType is + // set to a type that a BSON document cannot be unmarshaled into (e.g. "string"), unmarshalling will result in an + // error. DocumentType overrides the Ancestor field. + defaultDocumentType reflect.Type -// NewEmptyInterfaceCodec returns a EmptyInterfaceCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// EmptyInterfaceCodec registered. -func NewEmptyInterfaceCodec(opts ...*bsonoptions.EmptyInterfaceCodecOptions) *EmptyInterfaceCodec { - interfaceOpt := bsonoptions.MergeEmptyInterfaceCodecOptions(opts...) - - codec := EmptyInterfaceCodec{} - if interfaceOpt.DecodeBinaryAsSlice != nil { - codec.DecodeBinaryAsSlice = *interfaceOpt.DecodeBinaryAsSlice - } - return &codec + // decodeBinaryAsSlice causes DecodeValue to unmarshal BSON binary field values that are the + // "Generic" or "Old" BSON binary subtype as a Go byte slice instead of a Binary. + decodeBinaryAsSlice bool } // EncodeValue is the ValueEncoderFunc for interface{}. -func (eic EmptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (eic *emptyInterfaceCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tEmpty { return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} } @@ -56,31 +32,31 @@ func (eic EmptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val if val.IsNil() { return vw.WriteNull() } - encoder, err := ec.LookupEncoder(val.Elem().Type()) + encoder, err := reg.LookupEncoder(val.Elem().Type()) if err != nil { return err } - return encoder.EncodeValue(ec, vw, val.Elem()) + return encoder.EncodeValue(reg, vw, val.Elem()) } -func (eic EmptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType Type) (reflect.Type, error) { +func (eic *emptyInterfaceCodec) getEmptyInterfaceDecodeType(reg DecoderRegistry, valueType Type, ancestorType reflect.Type) (reflect.Type, error) { isDocument := valueType == Type(0) || valueType == TypeEmbeddedDocument if isDocument { - if dc.defaultDocumentType != nil { + if eic.defaultDocumentType != nil { // If the bsontype is an embedded document and the DocumentType is set on the DecodeContext, then return // that type. - return dc.defaultDocumentType, nil + return eic.defaultDocumentType, nil } - if dc.Ancestor != nil { + if ancestorType != nil && ancestorType != tEmpty { // Using ancestor information rather than looking up the type map entry forces consistent decoding. // If we're decoding into a bson.D, subdocuments should also be decoded as bson.D, even if a type map entry // has been registered. - return dc.Ancestor, nil + return ancestorType, nil } } - rtype, err := dc.LookupTypeMapEntry(valueType) + rtype, err := reg.LookupTypeMapEntry(valueType) if err == nil { return rtype, nil } @@ -96,7 +72,7 @@ func (eic EmptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, val lookupType = Type(0) } - rtype, err = dc.LookupTypeMapEntry(lookupType) + rtype, err = reg.LookupTypeMapEntry(lookupType) if err == nil { return rtype, nil } @@ -105,32 +81,31 @@ func (eic EmptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, val return nil, err } -func (eic EmptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { - if t != tEmpty { - return emptyValue, ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.Zero(t)} - } - - rtype, err := eic.getEmptyInterfaceDecodeType(dc, vr.Type()) +func (eic *emptyInterfaceCodec) decodeType(reg DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { + rtype, err := eic.getEmptyInterfaceDecodeType(reg, vr.Type(), t) if err != nil { switch vr.Type() { case TypeNull: - return reflect.Zero(t), vr.ReadNull() + return reflect.Zero(tEmpty), vr.ReadNull() default: return emptyValue, err } } - decoder, err := dc.LookupDecoder(rtype) + decoder, err := reg.LookupDecoder(rtype) if err != nil { return emptyValue, err } - elem, err := decodeTypeOrValue(decoder, dc, vr, rtype) + elem, err := decodeTypeOrValueWithInfo(decoder, reg, vr, rtype) if err != nil { return emptyValue, err } + if elem.Type() != rtype { + elem = elem.Convert(rtype) + } - if (eic.DecodeBinaryAsSlice || dc.binaryAsSlice) && rtype == tBinary { + if eic.decodeBinaryAsSlice && rtype == tBinary { binElem := elem.Interface().(Binary) if binElem.Subtype == TypeBinaryGeneric || binElem.Subtype == TypeBinaryBinaryOld { elem = reflect.ValueOf(binElem.Data) @@ -141,12 +116,12 @@ func (eic EmptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t re } // DecodeValue is the ValueDecoderFunc for interface{}. -func (eic EmptyInterfaceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (eic *emptyInterfaceCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tEmpty { return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val} } - elem, err := eic.decodeType(dc, vr, val.Type()) + elem, err := eic.decodeType(reg, vr, val.Type()) if err != nil { return err } diff --git a/bson/encoder.go b/bson/encoder.go index fb865cd285..1317bee79e 100644 --- a/bson/encoder.go +++ b/bson/encoder.go @@ -8,39 +8,34 @@ package bson import ( "reflect" - "sync" ) -// This pool is used to keep the allocations of Encoders down. This is only used for the Marshal* -// methods and is not consumable from outside of this package. The Encoders retrieved from this pool -// must have both Reset and SetRegistry called on them. -var encPool = sync.Pool{ - New: func() interface{} { - return new(Encoder) - }, +// ConfigurableEncoderRegistry refers a EncoderRegistry that is configurable with *RegistryOpt. +type ConfigurableEncoderRegistry interface { + EncoderRegistry + SetCodecOption(opt *RegistryOpt) error } // An Encoder writes a serialization format to an output stream. It writes to a ValueWriter // as the destination of BSON data. type Encoder struct { - ec EncodeContext - vw ValueWriter - - errorOnInlineDuplicates bool - intMinSize bool - stringifyMapKeysWithFmt bool - nilMapAsEmpty bool - nilSliceAsEmpty bool - nilByteSliceAsEmpty bool - omitZeroStruct bool - useJSONStructTags bool + reg ConfigurableEncoderRegistry + vw ValueWriter } -// NewEncoder returns a new encoder that uses the DefaultRegistry to write to vw. +// NewEncoder returns a new encoder that uses the default registry to write to vw. func NewEncoder(vw ValueWriter) *Encoder { return &Encoder{ - ec: EncodeContext{Registry: DefaultRegistry}, - vw: vw, + reg: NewRegistryBuilder().Build(), + vw: vw, + } +} + +// NewEncoderWithRegistry returns a new encoder that uses the given registry to write to vw. +func NewEncoderWithRegistry(r ConfigurableEncoderRegistry, vw ValueWriter) *Encoder { + return &Encoder{ + reg: r, + vw: vw, } } @@ -57,103 +52,15 @@ func (e *Encoder) Encode(val interface{}) error { return copyDocumentFromBytes(e.vw, buf) } - encoder, err := e.ec.LookupEncoder(reflect.TypeOf(val)) + encoder, err := e.reg.LookupEncoder(reflect.TypeOf(val)) if err != nil { return err } - // Copy the configurations applied to the Encoder over to the EncodeContext, which actually - // communicates those configurations to the default ValueEncoders. - if e.errorOnInlineDuplicates { - e.ec.ErrorOnInlineDuplicates() - } - if e.intMinSize { - e.ec.MinSize = true - } - if e.stringifyMapKeysWithFmt { - e.ec.StringifyMapKeysWithFmt() - } - if e.nilMapAsEmpty { - e.ec.NilMapAsEmpty() - } - if e.nilSliceAsEmpty { - e.ec.NilSliceAsEmpty() - } - if e.nilByteSliceAsEmpty { - e.ec.NilByteSliceAsEmpty() - } - if e.omitZeroStruct { - e.ec.OmitZeroStruct() - } - if e.useJSONStructTags { - e.ec.UseJSONStructTags() - } - - return encoder.EncodeValue(e.ec, e.vw, reflect.ValueOf(val)) -} - -// Reset will reset the state of the Encoder, using the same *EncodeContext used in -// the original construction but using vw. -func (e *Encoder) Reset(vw ValueWriter) { - e.vw = vw -} - -// SetRegistry replaces the current registry of the Encoder with r. -func (e *Encoder) SetRegistry(r *Registry) { - e.ec.Registry = r -} - -// ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in -// the marshaled BSON when the "inline" struct tag option is set. -func (e *Encoder) ErrorOnInlineDuplicates() { - e.errorOnInlineDuplicates = true -} - -// IntMinSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, uint, -// uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) that can -// represent the integer value. -func (e *Encoder) IntMinSize() { - e.intMinSize = true -} - -// StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name -// strings using fmt.Sprint instead of the default string conversion logic. -func (e *Encoder) StringifyMapKeysWithFmt() { - e.stringifyMapKeysWithFmt = true -} - -// NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON -// null. -func (e *Encoder) NilMapAsEmpty() { - e.nilMapAsEmpty = true -} - -// NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON -// null. -func (e *Encoder) NilSliceAsEmpty() { - e.nilSliceAsEmpty = true -} - -// NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values -// instead of BSON null. -func (e *Encoder) NilByteSliceAsEmpty() { - e.nilByteSliceAsEmpty = true -} - -// TODO(GODRIVER-2820): Update the description to remove the note about only examining exported -// TODO struct fields once the logic is updated to also inspect private struct fields. - -// OmitZeroStruct causes the Encoder to consider the zero value for a struct (e.g. MyStruct{}) -// as empty and omit it from the marshaled BSON when the "omitempty" struct tag option is set. -// -// Note that the Encoder only examines exported struct fields when determining if a struct is the -// zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty. -func (e *Encoder) OmitZeroStruct() { - e.omitZeroStruct = true + return encoder.EncodeValue(e.reg, e.vw, reflect.ValueOf(val)) } -// UseJSONStructTags causes the Encoder to fall back to using the "json" struct tag if a "bson" -// struct tag is not specified. -func (e *Encoder) UseJSONStructTags() { - e.useJSONStructTags = true +// SetBehavior set the encoder behavior with *RegistryOpt. +func (e *Encoder) SetBehavior(opt *RegistryOpt) error { + return e.reg.SetCodecOption(opt) } diff --git a/bson/encoder_example_test.go b/bson/encoder_example_test.go index 5c34192db4..60e85fabd3 100644 --- a/bson/encoder_example_test.go +++ b/bson/encoder_example_test.go @@ -53,7 +53,33 @@ func (k CityState) String() string { return fmt.Sprintf("%s, %s", k.City, k.State) } -func ExampleEncoder_StringifyMapKeysWithFmt() { +func ExampleEncoder_SetBehavior_intMinSize() { + // Create an encoder that will marshal integers as the minimum BSON int size + // (either 32 or 64 bits) that can represent the integer value. + type foo struct { + Bar uint32 + } + + buf := new(bytes.Buffer) + vw := bson.NewValueWriter(buf) + + enc := bson.NewEncoder(vw) + err := enc.SetBehavior(bson.IntMinSize) + if err != nil { + panic(err) + } + + err = enc.Encode(foo{2}) + if err != nil { + panic(err) + } + + fmt.Println(bson.Raw(buf.Bytes()).String()) + // Output: + // {"bar": {"$numberInt":"2"}} +} + +func ExampleEncoder_SetBehavior_stringifyMapKeysWithFmt() { // Create an Encoder that writes BSON values to a bytes.Buffer. buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) @@ -61,14 +87,17 @@ func ExampleEncoder_StringifyMapKeysWithFmt() { // Configure the Encoder to convert Go map keys to BSON document field names // using fmt.Sprintf instead of the default string conversion logic. - encoder.StringifyMapKeysWithFmt() + err := encoder.SetBehavior(bson.StringifyMapKeysWithFmt) + if err != nil { + panic(err) + } // Use the Encoder to marshal a BSON document that contains is a map of // city and state to a list of zip codes in that city. zipCodes := map[CityState][]int{ {City: "New York", State: "NY"}: {10001, 10301, 10451}, } - err := encoder.Encode(zipCodes) + err = encoder.Encode(zipCodes) if err != nil { panic(err) } @@ -78,7 +107,7 @@ func ExampleEncoder_StringifyMapKeysWithFmt() { // Output: {"New York, NY": [{"$numberInt":"10001"},{"$numberInt":"10301"},{"$numberInt":"10451"}]} } -func ExampleEncoder_UseJSONStructTags() { +func ExampleEncoder_SetBehavior_useJSONStructTags() { // Create an Encoder that writes BSON values to a bytes.Buffer. buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) @@ -92,7 +121,10 @@ func ExampleEncoder_UseJSONStructTags() { // Configure the Encoder to use "json" struct tags when decoding if "bson" // struct tags are not present. - encoder.UseJSONStructTags() + err := encoder.SetBehavior(bson.UseJSONStructTags) + if err != nil { + panic(err) + } // Use the Encoder to marshal a BSON document that contains the name, SKU, // and price (in cents) of a product. @@ -101,7 +133,7 @@ func ExampleEncoder_UseJSONStructTags() { SKU: "AB12345", Price: 399, } - err := encoder.Encode(product) + err = encoder.Encode(product) if err != nil { panic(err) } @@ -215,26 +247,3 @@ func ExampleEncoder_multipleExtendedJSONDocuments() { // {"x":{"$numberInt":"3"},"y":{"$numberInt":"4"}} // {"x":{"$numberInt":"4"},"y":{"$numberInt":"5"}} } - -func ExampleEncoder_IntMinSize() { - // Create an encoder that will marshal integers as the minimum BSON int size - // (either 32 or 64 bits) that can represent the integer value. - type foo struct { - Bar uint32 - } - - buf := new(bytes.Buffer) - vw := bson.NewValueWriter(buf) - - enc := bson.NewEncoder(vw) - enc.IntMinSize() - - err := enc.Encode(foo{2}) - if err != nil { - panic(err) - } - - fmt.Println(bson.Raw(buf.Bytes()).String()) - // Output: - // {"bar": {"$numberInt":"2"}} -} diff --git a/bson/encoder_test.go b/bson/encoder_test.go index 999b9962ef..a9f7376b5a 100644 --- a/bson/encoder_test.go +++ b/bson/encoder_test.go @@ -22,10 +22,10 @@ func TestBasicEncode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { got := make(SliceWriter, 0, 1024) vw := NewValueWriter(&got) - reg := DefaultRegistry + reg := NewRegistryBuilder().Build() encoder, err := reg.LookupEncoder(reflect.TypeOf(tc.val)) noerr(t, err) - err = encoder.EncodeValue(EncodeContext{Registry: reg}, vw, reflect.ValueOf(tc.val)) + err = encoder.EncodeValue(reg, vw, reflect.ValueOf(tc.val)) noerr(t, err) if !bytes.Equal(got, tc.want) { @@ -160,7 +160,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "ErrorOnInlineDuplicates", configure: func(enc *Encoder) { - enc.ErrorOnInlineDuplicates() + _ = enc.SetBehavior(ErrorOnInlineDuplicates) }, input: inlineDuplicateOuter{ Inline: inlineDuplicateInner{Duplicate: "inner"}, @@ -173,7 +173,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "IntMinSize", configure: func(enc *Encoder) { - enc.IntMinSize() + _ = enc.SetBehavior(IntMinSize) }, input: D{ {Key: "myInt", Value: int(1)}, @@ -194,7 +194,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "StringifyMapKeysWithFmt", configure: func(enc *Encoder) { - enc.StringifyMapKeysWithFmt() + _ = enc.SetBehavior(StringifyMapKeysWithFmt) }, input: map[stringerTest]string{ {}: "test value", @@ -207,7 +207,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "NilMapAsEmpty", configure: func(enc *Encoder) { - enc.NilMapAsEmpty() + _ = enc.SetBehavior(NilMapAsEmpty) }, input: D{{Key: "myMap", Value: map[string]string(nil)}}, want: bsoncore.NewDocumentBuilder(). @@ -218,7 +218,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "NilSliceAsEmpty", configure: func(enc *Encoder) { - enc.NilSliceAsEmpty() + _ = enc.SetBehavior(NilSliceAsEmpty) }, input: D{{Key: "mySlice", Value: []string(nil)}}, want: bsoncore.NewDocumentBuilder(). @@ -229,7 +229,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "NilByteSliceAsEmpty", configure: func(enc *Encoder) { - enc.NilByteSliceAsEmpty() + _ = enc.SetBehavior(NilByteSliceAsEmpty) }, input: D{{Key: "myBytes", Value: []byte(nil)}}, want: bsoncore.NewDocumentBuilder(). @@ -241,7 +241,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "OmitZeroStruct", configure: func(enc *Encoder) { - enc.OmitZeroStruct() + _ = enc.SetBehavior(OmitZeroStruct) }, input: struct { Zero zeroStruct `bson:",omitempty"` @@ -253,7 +253,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "UseJSONStructTags", configure: func(enc *Encoder) { - enc.UseJSONStructTags() + _ = enc.SetBehavior(UseJSONStructTags) }, input: struct { StructFieldName string `json:"jsonFieldName"` diff --git a/bson/map_codec.go b/bson/map_codec.go index 9592957db4..0089c75717 100644 --- a/bson/map_codec.go +++ b/bson/map_codec.go @@ -12,34 +12,21 @@ import ( "fmt" "reflect" "strconv" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) -var defaultMapCodec = NewMapCodec() - -// MapCodec is the Codec used for map values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// MapCodec registered. -type MapCodec struct { - // DecodeZerosMap causes DecodeValue to delete any existing values from Go maps in the destination +// mapCodec is the Codec used for map values. +type mapCodec struct { + // decodeZerosMap causes DecodeValue to delete any existing values from Go maps in the destination // value passed to Decode before unmarshaling BSON documents into them. - // - // Deprecated: Use bson.Decoder.ZeroMaps instead. - DecodeZerosMap bool + decodeZerosMap bool - // EncodeNilAsEmpty causes EncodeValue to marshal nil Go maps as empty BSON documents instead of + // encodeNilAsEmpty causes EncodeValue to marshal nil Go maps as empty BSON documents instead of // BSON null. - // - // Deprecated: Use bson.Encoder.NilMapAsEmpty instead. - EncodeNilAsEmpty bool + encodeNilAsEmpty bool - // EncodeKeysWithStringer causes the Encoder to convert Go map keys to BSON document field name + // encodeKeysWithStringer causes the Encoder to convert Go map keys to BSON document field name // strings using fmt.Sprintf() instead of the default string conversion logic. - // - // Deprecated: Use bson.Encoder.StringifyMapKeysWithFmt instead. - EncodeKeysWithStringer bool + encodeKeysWithStringer bool } // KeyMarshaler is the interface implemented by an object that can marshal itself into a string key. @@ -58,33 +45,13 @@ type KeyUnmarshaler interface { UnmarshalKey(key string) error } -// NewMapCodec returns a MapCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// MapCodec registered. -func NewMapCodec(opts ...*bsonoptions.MapCodecOptions) *MapCodec { - mapOpt := bsonoptions.MergeMapCodecOptions(opts...) - - codec := MapCodec{} - if mapOpt.DecodeZerosMap != nil { - codec.DecodeZerosMap = *mapOpt.DecodeZerosMap - } - if mapOpt.EncodeNilAsEmpty != nil { - codec.EncodeNilAsEmpty = *mapOpt.EncodeNilAsEmpty - } - if mapOpt.EncodeKeysWithStringer != nil { - codec.EncodeKeysWithStringer = *mapOpt.EncodeKeysWithStringer - } - return &codec -} - // EncodeValue is the ValueEncoder for map[*]* types. -func (mc *MapCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (mc *mapCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Map { return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} } - if val.IsNil() && !mc.EncodeNilAsEmpty && !ec.nilMapAsEmpty { + if val.IsNil() && !mc.encodeNilAsEmpty { // If we have a nil map but we can't WriteNull, that means we're probably trying to encode // to a TopLevel document. We can't currently tell if this is what actually happened, but if // there's a deeper underlying problem, the error will also be returned from WriteDocument, @@ -101,23 +68,27 @@ func (mc *MapCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Va return err } - return mc.mapEncodeValue(ec, dw, val, nil) + err = mc.encodeMapElements(reg, dw, val, nil) + if err != nil { + return err + } + return dw.WriteDocumentEnd() } -// mapEncodeValue handles encoding of the values of a map. The collisionFn returns +// encodeMapElements handles encoding of the values of a map. The collisionFn returns // true if the provided key exists, this is mainly used for inline maps in the // struct codec. -func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { +func (mc *mapCodec) encodeMapElements(reg EncoderRegistry, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { elemType := val.Type().Elem() - encoder, err := ec.LookupEncoder(elemType) + encoder, err := reg.LookupEncoder(elemType) if err != nil && elemType.Kind() != reflect.Interface { return err } keys := val.MapKeys() for _, key := range keys { - keyStr, err := mc.encodeKey(key, ec.stringifyMapKeysWithFmt) + keyStr, err := mc.encodeKey(key, mc.encodeKeysWithStringer) if err != nil { return err } @@ -126,7 +97,7 @@ func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw DocumentWriter, val refl return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key) } - currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.MapIndex(key)) + currEncoder, currVal, lookupErr := lookupElementEncoder(reg, encoder, val.MapIndex(key)) if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } @@ -144,17 +115,17 @@ func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw DocumentWriter, val refl continue } - err = currEncoder.EncodeValue(ec, vw, currVal) + err = currEncoder.EncodeValue(reg, vw, currVal) if err != nil { return err } } - return dw.WriteDocumentEnd() + return nil } // DecodeValue is the ValueDecoder for map[string/decimal]* types. -func (mc *MapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (mc *mapCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if val.Kind() != reflect.Map || (!val.CanSet() && val.IsNil()) { return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} } @@ -180,19 +151,18 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Va val.Set(reflect.MakeMap(val.Type())) } - if val.Len() > 0 && (mc.DecodeZerosMap || dc.zeroMaps) { + if val.Len() > 0 && mc.decodeZerosMap { clearMap(val) } eType := val.Type().Elem() - decoder, err := dc.LookupDecoder(eType) + decoder, err := reg.LookupDecoder(eType) if err != nil { return err } - eTypeDecoder, _ := decoder.(typeDecoder) if eType == tEmpty { - dc.Ancestor = val.Type() + eType = val.Type() } keyType := val.Type().Key() @@ -211,10 +181,13 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Va return err } - elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true) + elem, err := decodeTypeOrValueWithInfo(decoder, reg, vr, eType) if err != nil { return newDecodeError(key, err) } + if t := val.Type().Elem(); elem.Type() != t { + elem = elem.Convert(t) + } val.SetMapIndex(k, elem) } @@ -228,8 +201,8 @@ func clearMap(m reflect.Value) { } } -func (mc *MapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (string, error) { - if mc.EncodeKeysWithStringer || encodeKeysWithStringer { +func (mc *mapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (string, error) { + if mc.encodeKeysWithStringer || encodeKeysWithStringer { return fmt.Sprint(val), nil } @@ -274,12 +247,12 @@ func (mc *MapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (s var keyUnmarshalerType = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem() var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() -func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) { +func (mc *mapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) { keyVal := reflect.ValueOf(key) var err error switch { // First, if EncodeKeysWithStringer is not enabled, try to decode withKeyUnmarshaler - case !mc.EncodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType): + case !mc.encodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType): keyVal = reflect.New(keyType) v := keyVal.Interface().(KeyUnmarshaler) err = v.UnmarshalKey(key) @@ -309,7 +282,7 @@ func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, } keyVal = reflect.ValueOf(n).Convert(keyType) case reflect.Float32, reflect.Float64: - if mc.EncodeKeysWithStringer { + if mc.encodeKeysWithStringer { parsed, err := strconv.ParseFloat(key, 64) if err != nil { return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %w", keyType.Kind(), err) diff --git a/bson/marshal.go b/bson/marshal.go index 573de16398..db0bf47fae 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -70,11 +70,7 @@ func Marshal(val interface{}) ([]byte, error) { } }() sw.Reset() - vw := NewValueWriter(sw) - enc := encPool.Get().(*Encoder) - defer encPool.Put(enc) - enc.Reset(vw) - enc.SetRegistry(DefaultRegistry) + enc := NewEncoderWithRegistry(defaultRegistry, NewValueWriter(sw)) err := enc.Encode(val) if err != nil { return nil, err @@ -85,10 +81,10 @@ func Marshal(val interface{}) ([]byte, error) { // MarshalValue returns the BSON encoding of val. // -// MarshalValue will use bson.DefaultRegistry to transform val into a BSON value. If val is a struct, this function will +// MarshalValue will use default registry to transform val into a BSON value. If val is a struct, this function will // inspect struct tags and alter the marshalling process accordingly. func MarshalValue(val interface{}) (Type, []byte, error) { - return MarshalValueWithRegistry(DefaultRegistry, val) + return MarshalValueWithRegistry(defaultRegistry, val) } // MarshalValueWithRegistry returns the BSON encoding of val using Registry r. @@ -100,10 +96,7 @@ func MarshalValueWithRegistry(r *Registry, val interface{}) (Type, []byte, error vwFlusher := bvwPool.GetAtModeElement(&sw) // get an Encoder and encode the value - enc := encPool.Get().(*Encoder) - defer encPool.Put(enc) - enc.Reset(vwFlusher) - enc.ec = EncodeContext{Registry: r} + enc := NewEncoderWithRegistry(r, vwFlusher) if err := enc.Encode(val); err != nil { return 0, nil, err } @@ -123,12 +116,7 @@ func MarshalExtJSON(val interface{}, canonical, escapeHTML bool) ([]byte, error) ejvw := extjPool.Get(&sw, canonical, escapeHTML) defer extjPool.Put(ejvw) - enc := encPool.Get().(*Encoder) - defer encPool.Put(enc) - - enc.Reset(ejvw) - enc.ec = EncodeContext{Registry: DefaultRegistry} - + enc := NewEncoderWithRegistry(defaultRegistry, ejvw) err := enc.Encode(val) if err != nil { return nil, err diff --git a/bson/marshal_test.go b/bson/marshal_test.go index ecf67d8493..93787338a2 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -28,37 +28,11 @@ func TestMarshalWithRegistry(t *testing.T) { if tc.reg != nil { reg = tc.reg } else { - reg = DefaultRegistry + reg = defaultRegistry } buf := new(bytes.Buffer) vw := NewValueWriter(buf) - enc := NewEncoder(vw) - enc.SetRegistry(reg) - err := enc.Encode(tc.val) - noerr(t, err) - - if got := buf.Bytes(); !bytes.Equal(got, tc.want) { - t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want) - t.Errorf("Bytes:\n%v\n%v", got, tc.want) - } - }) - } -} - -func TestMarshalWithContext(t *testing.T) { - for _, tc := range marshalingTestCases { - t.Run(tc.name, func(t *testing.T) { - var reg *Registry - if tc.reg != nil { - reg = tc.reg - } else { - reg = DefaultRegistry - } - buf := new(bytes.Buffer) - vw := NewValueWriter(buf) - enc := NewEncoder(vw) - enc.IntMinSize() - enc.SetRegistry(reg) + enc := NewEncoderWithRegistry(reg, vw) err := enc.Encode(tc.val) noerr(t, err) @@ -149,15 +123,16 @@ func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) { // different Registry is used. // Create a custom Registry that negates int32 values when encoding. - var encodeInt32 ValueEncoderFunc = func(_ EncodeContext, vw ValueWriter, val reflect.Value) error { + var encodeInt32 ValueEncoderFunc = func(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.Int32 { return fmt.Errorf("expected kind to be int32, got %v", val.Kind()) } return vw.WriteInt32(int32(val.Int()) * -1) } - customReg := NewRegistry() - customReg.RegisterTypeEncoder(tInt32, encodeInt32) + customReg := NewRegistryBuilder(). + RegisterTypeEncoder(tInt32, func(*Registry) ValueEncoder { return encodeInt32 }). + Build() // Helper function to run the test and make assertions. The provided original value should result in the document // {"x": {$numberInt: 1}} when marshalled with the default registry. @@ -174,8 +149,7 @@ func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) { buf := new(bytes.Buffer) vw := NewValueWriter(buf) - enc := NewEncoder(vw) - enc.SetRegistry(customReg) + enc := NewEncoderWithRegistry(customReg, vw) err = enc.Encode(original) assert.Nil(t, err, "Encode error: %v", err) second := buf.Bytes() diff --git a/bson/mgocompat/bson_test.go b/bson/mgocompat/bson_test.go index 6651509983..c31605d105 100644 --- a/bson/mgocompat/bson_test.go +++ b/bson/mgocompat/bson_test.go @@ -79,15 +79,13 @@ var sampleItems = []testItemType{ "\x13\x00\x00\x00\x05slice\x00\x02\x00\x00\x00\x00\x01\x02\x00"}, } -func TestMarshalSampleItems(t *testing.T) { +func TestEncodeSampleItems(t *testing.T) { buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, item := range sampleItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) str := buf.String() @@ -96,11 +94,13 @@ func TestMarshalSampleItems(t *testing.T) { } } -func TestUnmarshalSampleItems(t *testing.T) { +func TestDecodeSampleItems(t *testing.T) { for i, item := range sampleItems { t.Run(strconv.Itoa(i), func(t *testing.T) { value := bson.M{} - err := bson.UnmarshalWithRegistry(Registry, []byte(item.data), &value) + vr := bson.NewValueReader([]byte(item.data)) + dec := bson.NewDecoderWithRegistry(Registry, vr) + err := dec.Decode(&value) assert.Nil(t, err, "expected nil error, got: %v", err) assert.True(t, reflect.DeepEqual(value, item.obj), "expected: %v, got: %v", item.obj, value) }) @@ -164,15 +164,13 @@ var allItems = []testItemType{ "\xFF_\x00"}, } -func TestMarshalAllItems(t *testing.T) { +func TestEncodeAllItems(t *testing.T) { buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, item := range allItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) str := buf.String() @@ -181,11 +179,13 @@ func TestMarshalAllItems(t *testing.T) { } } -func TestUnmarshalAllItems(t *testing.T) { +func TestDecodeAllItems(t *testing.T) { for i, item := range allItems { t.Run(strconv.Itoa(i), func(t *testing.T) { value := bson.M{} - err := bson.UnmarshalWithRegistry(Registry, []byte(wrapInDoc(item.data)), &value) + vr := bson.NewValueReader([]byte(wrapInDoc(item.data))) + dec := bson.NewDecoderWithRegistry(Registry, vr) + err := dec.Decode(&value) assert.Nil(t, err, "expected nil error, got: %v", err) assert.True(t, reflect.DeepEqual(value, item.obj), "expected: %v, got: %v", item.obj, value) }) @@ -220,8 +220,7 @@ func TestUnmarshalRawIncompatible(t *testing.T) { func TestUnmarshalZeroesStruct(t *testing.T) { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(bson.M{"b": 2}) assert.Nil(t, err, "expected nil error, got: %v", err) type T struct{ A, B int } @@ -235,8 +234,7 @@ func TestUnmarshalZeroesStruct(t *testing.T) { func TestUnmarshalZeroesMap(t *testing.T) { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(bson.M{"b": 2}) assert.Nil(t, err, "expected nil error, got: %v", err) m := bson.M{"a": 1} @@ -250,8 +248,7 @@ func TestUnmarshalZeroesMap(t *testing.T) { func TestUnmarshalNonNilInterface(t *testing.T) { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(bson.M{"b": 2}) assert.Nil(t, err, "expected nil error, got: %v", err) m := bson.M{"a": 1} @@ -288,13 +285,11 @@ func TestPtrInline(t *testing.T) { } buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, cs := range cases { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(cs.In) assert.Nil(t, err, "expected nil error, got: %v", err) var dataBSON bson.M @@ -377,13 +372,11 @@ var oneWayMarshalItems = []testItemType{ func TestOneWayMarshalItems(t *testing.T) { buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, item := range oneWayMarshalItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -414,13 +407,11 @@ var structSampleItems = []testItemType{ func TestMarshalStructSampleItems(t *testing.T) { buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, item := range structSampleItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) assert.Equal(t, item.data, buf.String(), "expected: %v, got: %v", item.data, buf.String()) @@ -441,8 +432,7 @@ func Test64bitInt(t *testing.T) { if int(i) > 0 { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(bson.M{"i": int(i)}) assert.Nil(t, err, "expected nil error, got: %v", err) want := wrapInDoc("\x12i\x00\x00\x00\x00\x80\x00\x00\x00\x00") @@ -471,7 +461,7 @@ func (t *prefixPtr) GetBSON() (interface{}, error) { func (t *prefixPtr) SetBSON(raw bson.RawValue) error { var s string if raw.Type == 0x0A { - return ErrSetZero + return bson.ErrSetZero } rval := reflect.ValueOf(&s).Elem() decoder, err := Registry.LookupDecoder(rval.Type()) @@ -479,7 +469,7 @@ func (t *prefixPtr) SetBSON(raw bson.RawValue) error { return err } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } @@ -498,7 +488,7 @@ func (t prefixVal) GetBSON() (interface{}, error) { func (t *prefixVal) SetBSON(raw bson.RawValue) error { var s string if raw.Type == 0x0A { - return ErrSetZero + return bson.ErrSetZero } rval := reflect.ValueOf(&s).Elem() decoder, err := Registry.LookupDecoder(rval.Type()) @@ -506,7 +496,7 @@ func (t *prefixVal) SetBSON(raw bson.RawValue) error { return err } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } @@ -580,13 +570,11 @@ var structItems = []testItemType{ func TestMarshalStructItems(t *testing.T) { buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, item := range structItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) assert.Equal(t, wrapInDoc(item.data), buf.String(), "expected: %v, got: %v", wrapInDoc(item.data), buf.String()) @@ -656,13 +644,11 @@ var marshalItems = []testItemType{ func TestMarshalOneWayItems(t *testing.T) { buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, item := range marshalItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) assert.Equal(t, wrapInDoc(item.data), buf.String(), "expected: %v, got: %v", wrapInDoc(item.data), buf.String()) @@ -765,13 +751,11 @@ var marshalErrorItems = []testItemType{ func TestMarshalErrorItems(t *testing.T) { buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, item := range marshalErrorItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(item.obj) assert.NotNil(t, err, "expected error") @@ -930,7 +914,7 @@ func (o *setterType) SetBSON(raw bson.RawValue) error { raw.Type = bson.TypeEmbeddedDocument } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } @@ -1026,13 +1010,12 @@ func TestDMap(t *testing.T) { } func TestUnmarshalSetterErrSetZero(t *testing.T) { - setterResult["foo"] = ErrSetZero + setterResult["foo"] = bson.ErrSetZero defer delete(setterResult, "field") buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(bson.M{"field": "foo"}) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -1066,7 +1049,6 @@ type docWithGetterField struct { func TestMarshalAllItemsWithGetter(t *testing.T) { buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, item := range allItems { if item.data == "" { continue @@ -1076,8 +1058,7 @@ func TestMarshalAllItemsWithGetter(t *testing.T) { obj := &docWithGetterField{} obj.Field = &typeWithGetter{result: item.obj.(bson.M)["_"]} vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(obj) assert.Nil(t, err, "expected nil error, got: %v", err) assert.Equal(t, wrapInDoc(item.data), buf.String(), @@ -1090,8 +1071,7 @@ func TestMarshalWholeDocumentWithGetter(t *testing.T) { obj := &typeWithGetter{result: sampleItems[0].obj} buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(obj) assert.Nil(t, err, "expected nil error, got: %v", err) assert.Equal(t, sampleItems[0].data, buf.String(), @@ -1105,8 +1085,7 @@ func TestGetterErrors(t *testing.T) { obj1.Field = &typeWithGetter{sampleItems[0].obj, e} buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(obj1) assert.Equal(t, e, err, "expected error: %v, got: %v", e, err) assert.Nil(t, buf.Bytes(), "expected nil data, got: %v", buf.Bytes()) @@ -1114,8 +1093,7 @@ func TestGetterErrors(t *testing.T) { obj2 := &typeWithGetter{sampleItems[0].obj, e} buf.Reset() vw = bson.NewValueWriter(buf) - enc = bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc = bson.NewEncoderWithRegistry(Registry, vw) err = enc.Encode(obj2) assert.Equal(t, e, err, "expected error: %v, got: %v", e, err) assert.Nil(t, buf.Bytes(), "expected nil data, got: %v", buf.Bytes()) @@ -1135,8 +1113,7 @@ func TestMarshalShortWithGetter(t *testing.T) { obj := typeWithIntGetter{42} buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(obj) assert.Nil(t, err, "expected nil error, got: %v", err) m := bson.M{} @@ -1149,8 +1126,7 @@ func TestMarshalWithGetterNil(t *testing.T) { obj := docWithGetterField{} buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(obj) assert.Nil(t, err, "expected nil error, got: %v", err) m := bson.M{} @@ -1289,7 +1265,7 @@ func (s *getterSetterD) SetBSON(raw bson.RawValue) error { raw.Type = bson.TypeEmbeddedDocument } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } @@ -1315,7 +1291,7 @@ func (i *getterSetterInt) SetBSON(raw bson.RawValue) error { raw.Type = bson.TypeEmbeddedDocument } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } @@ -1337,7 +1313,7 @@ func (s *ifaceSlice) SetBSON(raw bson.RawValue) error { return err } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } @@ -1442,9 +1418,6 @@ var twoWayCrossItems = []crossTypeItem{ {&struct{ S []byte }{[]byte("def")}, &struct{ S bson.Symbol }{"def"}}, {&struct{ S string }{"ghi"}, &struct{ S bson.Symbol }{"ghi"}}, - {&struct{ S string }{"0123456789ab"}, - &struct{ S bson.ObjectID }{bson.ObjectID{0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x61, 0x62}}}, - // map <=> struct {&struct { A struct { @@ -1597,8 +1570,7 @@ func testCrossPair(t *testing.T, dump interface{}, load interface{}) { zero := makeZeroDoc(load) buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(dump) assert.Nil(t, err, "expected nil error, got: %v", err) err = bson.UnmarshalWithRegistry(Registry, buf.Bytes(), zero) @@ -1711,8 +1683,7 @@ func TestMarshalNotRespectNil(t *testing.T) { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(testStruct1) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -1744,8 +1715,7 @@ func TestMarshalRespectNil(t *testing.T) { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(testStruct1) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -1773,8 +1743,7 @@ func TestMarshalRespectNil(t *testing.T) { buf.Reset() vw = bson.NewValueWriter(buf) - enc = bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc = bson.NewEncoderWithRegistry(Registry, vw) err = enc.Encode(testStruct1) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -1809,8 +1778,7 @@ func TestInlineWithPointerToSelf(t *testing.T) { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(x1) assert.Nil(t, err, "expected nil error, got: %v", err) diff --git a/bson/mgocompat/doc.go b/bson/mgocompat/doc.go index 8a9434b1d1..a1c91aff4c 100644 --- a/bson/mgocompat/doc.go +++ b/bson/mgocompat/doc.go @@ -9,11 +9,6 @@ // with mgo's BSON with RespectNilValues set to true. A registry can be configured on a // mongo.Client with the SetRegistry option. See the bson docs for more details on registries. // -// Registry supports Getter and Setter equivalents by registering hooks. Note that if a value -// matches the hook for bson.Marshaler, bson.ValueMarshaler, or bson.Proxy, that -// hook will take priority over the Getter hook. The same is true for the hooks for -// bson.Unmarshaler and bson.ValueUnmarshaler and the Setter hook. -// // The functional differences between Registry and globalsign/mgo's BSON library are: // // 1) Registry errors instead of silently skipping mismatched types when decoding. diff --git a/bson/mgocompat/registry.go b/bson/mgocompat/registry.go index 7024ab9fdc..7ffb90b22e 100644 --- a/bson/mgocompat/registry.go +++ b/bson/mgocompat/registry.go @@ -7,106 +7,12 @@ package mgocompat import ( - "errors" - "reflect" - "time" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/bsonoptions" -) - -var ( - // ErrSetZero may be returned from a SetBSON method to have the value set to its respective zero value. - ErrSetZero = errors.New("set to zero") - - tInt = reflect.TypeOf(int(0)) - tTime = reflect.TypeOf(time.Time{}) - tM = reflect.TypeOf(bson.M{}) - tInterfaceSlice = reflect.TypeOf([]interface{}{}) - tByteSlice = reflect.TypeOf([]byte{}) - tEmpty = reflect.TypeOf((*interface{})(nil)).Elem() - tGetter = reflect.TypeOf((*Getter)(nil)).Elem() - tSetter = reflect.TypeOf((*Setter)(nil)).Elem() ) // Registry is the mgo compatible bson.Registry. It contains the default and // primitive codecs with mgo compatible options. -var Registry = NewRegistryBuilder().Build() +var Registry = bson.NewMgoRegistry() // RespectNilValuesRegistry is the bson.Registry compatible with mgo withSetRespectNilValues set to true. -var RespectNilValuesRegistry = NewRespectNilValuesRegistryBuilder().Build() - -// NewRegistryBuilder creates a new bson.RegistryBuilder configured with the default encoders and -// decoders from the bson.DefaultValueEncoders and bson.DefaultValueDecoders types and the -// PrimitiveCodecs type in this package. -func NewRegistryBuilder() *bson.RegistryBuilder { - rb := bson.NewRegistryBuilder() - bson.DefaultValueEncoders{}.RegisterDefaultEncoders(rb) - bson.DefaultValueDecoders{}.RegisterDefaultDecoders(rb) - bson.PrimitiveCodecs{}.RegisterPrimitiveCodecs(rb) - - structcodec, _ := bson.NewStructCodec(bson.DefaultStructTagParser, - bsonoptions.StructCodec(). - SetDecodeZeroStruct(true). - SetEncodeOmitDefaultStruct(true). - SetOverwriteDuplicatedInlinedFields(false). - SetAllowUnexportedFields(true)) - emptyInterCodec := bson.NewEmptyInterfaceCodec( - bsonoptions.EmptyInterfaceCodec(). - SetDecodeBinaryAsSlice(true)) - mapCodec := bson.NewMapCodec( - bsonoptions.MapCodec(). - SetDecodeZerosMap(true). - SetEncodeNilAsEmpty(true). - SetEncodeKeysWithStringer(true)) - uintcodec := bson.NewUIntCodec(bsonoptions.UIntCodec().SetEncodeToMinSize(true)) - - rb.RegisterTypeDecoder(tEmpty, emptyInterCodec). - RegisterDefaultDecoder(reflect.String, bson.NewStringCodec(bsonoptions.StringCodec().SetDecodeObjectIDAsHex(false))). - RegisterDefaultDecoder(reflect.Struct, structcodec). - RegisterDefaultDecoder(reflect.Map, mapCodec). - RegisterTypeEncoder(tByteSlice, bson.NewByteSliceCodec(bsonoptions.ByteSliceCodec().SetEncodeNilAsEmpty(true))). - RegisterDefaultEncoder(reflect.Struct, structcodec). - RegisterDefaultEncoder(reflect.Slice, bson.NewSliceCodec(bsonoptions.SliceCodec().SetEncodeNilAsEmpty(true))). - RegisterDefaultEncoder(reflect.Map, mapCodec). - RegisterDefaultEncoder(reflect.Uint, uintcodec). - RegisterDefaultEncoder(reflect.Uint8, uintcodec). - RegisterDefaultEncoder(reflect.Uint16, uintcodec). - RegisterDefaultEncoder(reflect.Uint32, uintcodec). - RegisterDefaultEncoder(reflect.Uint64, uintcodec). - RegisterTypeMapEntry(bson.TypeInt32, tInt). - RegisterTypeMapEntry(bson.TypeDateTime, tTime). - RegisterTypeMapEntry(bson.TypeArray, tInterfaceSlice). - RegisterTypeMapEntry(bson.Type(0), tM). - RegisterTypeMapEntry(bson.TypeEmbeddedDocument, tM). - RegisterHookEncoder(tGetter, bson.ValueEncoderFunc(GetterEncodeValue)). - RegisterHookDecoder(tSetter, bson.ValueDecoderFunc(SetterDecodeValue)) - - return rb -} - -// NewRespectNilValuesRegistryBuilder creates a new bson.RegistryBuilder configured to behave like mgo/bson -// with RespectNilValues set to true. -func NewRespectNilValuesRegistryBuilder() *bson.RegistryBuilder { - rb := NewRegistryBuilder() - - structcodec, _ := bson.NewStructCodec(bson.DefaultStructTagParser, - bsonoptions.StructCodec(). - SetDecodeZeroStruct(true). - SetEncodeOmitDefaultStruct(true). - SetOverwriteDuplicatedInlinedFields(false). - SetAllowUnexportedFields(true)) - mapCodec := bson.NewMapCodec( - bsonoptions.MapCodec(). - SetDecodeZerosMap(true). - SetEncodeNilAsEmpty(false)) - - rb.RegisterDefaultDecoder(reflect.Struct, structcodec). - RegisterDefaultDecoder(reflect.Map, mapCodec). - RegisterTypeEncoder(tByteSlice, bson.NewByteSliceCodec(bsonoptions.ByteSliceCodec().SetEncodeNilAsEmpty(false))). - RegisterDefaultEncoder(reflect.Struct, structcodec). - RegisterDefaultEncoder(reflect.Slice, bson.NewSliceCodec(bsonoptions.SliceCodec().SetEncodeNilAsEmpty(false))). - RegisterDefaultEncoder(reflect.Map, mapCodec) - - return rb -} +var RespectNilValuesRegistry = bson.NewRespectNilValuesMgoRegistry() diff --git a/bson/mgoregistry.go b/bson/mgoregistry.go new file mode 100644 index 0000000000..1aa96380cd --- /dev/null +++ b/bson/mgoregistry.go @@ -0,0 +1,84 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bson + +import ( + "errors" + "reflect" +) + +var ( + // ErrSetZero may be returned from a SetBSON method to have the value set to its respective zero value. + ErrSetZero = errors.New("set to zero") + + tInt = reflect.TypeOf(int(0)) + tM = reflect.TypeOf(M{}) + tInterfaceSlice = reflect.TypeOf([]interface{}{}) + tGetter = reflect.TypeOf((*Getter)(nil)).Elem() + tSetter = reflect.TypeOf((*Setter)(nil)).Elem() +) + +func newMgoRegistryBuilder() *RegistryBuilder { + mapCodec := &mapCodec{ + decodeZerosMap: true, + encodeNilAsEmpty: true, + encodeKeysWithStringer: true, + } + newStructCodec := func(elemEncoder mapElementsEncoder) *structCodec { + return &structCodec{ + elemEncoder: elemEncoder, + decodeZeroStruct: true, + encodeOmitDefaultStruct: true, + allowUnexportedFields: true, + } + } + numcodecFac := func(*Registry) ValueEncoder { return &numCodec{encodeUintToMinSize: true} } + + return NewRegistryBuilder(). + RegisterTypeDecoder(tEmpty, func(*Registry) ValueDecoder { return &emptyInterfaceCodec{decodeBinaryAsSlice: true} }). + RegisterKindDecoder(reflect.Struct, func(*Registry) ValueDecoder { return newStructCodec(nil) }). + RegisterKindDecoder(reflect.Map, func(*Registry) ValueDecoder { return mapCodec }). + RegisterTypeEncoder(tByteSlice, func(*Registry) ValueEncoder { return &byteSliceCodec{encodeNilAsEmpty: true} }). + RegisterKindEncoder(reflect.Struct, func(reg *Registry) ValueEncoder { + enc, _ := reg.lookupKindEncoder(reflect.Map) + return newStructCodec(enc.(mapElementsEncoder)) + }). + RegisterKindEncoder(reflect.Slice, func(*Registry) ValueEncoder { return &sliceCodec{encodeNilAsEmpty: true} }). + RegisterKindEncoder(reflect.Map, func(*Registry) ValueEncoder { return mapCodec }). + RegisterKindEncoder(reflect.Uint, numcodecFac). + RegisterKindEncoder(reflect.Uint8, numcodecFac). + RegisterKindEncoder(reflect.Uint16, numcodecFac). + RegisterKindEncoder(reflect.Uint32, numcodecFac). + RegisterKindEncoder(reflect.Uint64, numcodecFac). + RegisterTypeMapEntry(TypeInt32, tInt). + RegisterTypeMapEntry(TypeDateTime, tTime). + RegisterTypeMapEntry(TypeArray, tInterfaceSlice). + RegisterTypeMapEntry(Type(0), tM). + RegisterTypeMapEntry(TypeEmbeddedDocument, tM). + RegisterInterfaceEncoder(tGetter, func(*Registry) ValueEncoder { return ValueEncoderFunc(GetterEncodeValue) }). + RegisterInterfaceDecoder(tSetter, func(*Registry) ValueDecoder { return ValueDecoderFunc(SetterDecodeValue) }) +} + +// NewMgoRegistry creates a new bson.Registry configured with the default encoders and decoders. +func NewMgoRegistry() *Registry { + return newMgoRegistryBuilder().Build() +} + +// NewRespectNilValuesMgoRegistry creates a new bson.Registry configured to behave like mgo/bson +// with RespectNilValues set to true. +func NewRespectNilValuesMgoRegistry() *Registry { + mapCodec := &mapCodec{ + decodeZerosMap: true, + } + + return newMgoRegistryBuilder(). + RegisterKindDecoder(reflect.Map, func(*Registry) ValueDecoder { return mapCodec }). + RegisterTypeEncoder(tByteSlice, func(*Registry) ValueEncoder { return &byteSliceCodec{encodeNilAsEmpty: false} }). + RegisterKindEncoder(reflect.Slice, func(*Registry) ValueEncoder { return &sliceCodec{} }). + RegisterKindEncoder(reflect.Map, func(*Registry) ValueEncoder { return mapCodec }). + Build() +} diff --git a/bson/num_codec.go b/bson/num_codec.go new file mode 100644 index 0000000000..33c16fd15e --- /dev/null +++ b/bson/num_codec.go @@ -0,0 +1,301 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bson + +import ( + "fmt" + "math" + "reflect" +) + +// numCodec is the Codec used for numeric values. +type numCodec struct { + // minSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, + // uint, uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) + // that can represent the integer value. + minSize bool + + // encodeUintToMinSize causes EncodeValue to marshal Go uint values (excluding uint64) as the + // minimum BSON int size (either 32-bit or 64-bit) that can represent the integer value. + encodeUintToMinSize bool + + // truncate, if true, instructs decoders to to truncate the fractional part of BSON "double" + // values when attempting to unmarshal them into a Go integer (int, int8, int16, int32, int64, + // uint, uint8, uint16, uint32, or uint64) struct field. The truncation logic does not apply to + // BSON "decimal128" values. + truncate bool +} + +// EncodeValue is the ValueEncoder for numeric types. +func (nc *numCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { + switch val.Kind() { + case reflect.Float32, reflect.Float64: + return vw.WriteDouble(val.Float()) + + case reflect.Int8, reflect.Int16, reflect.Int32: + return vw.WriteInt32(int32(val.Int())) + case reflect.Int: + i64 := val.Int() + if fitsIn32Bits(i64) { + return vw.WriteInt32(int32(i64)) + } + return vw.WriteInt64(i64) + case reflect.Int64: + i64 := val.Int() + if nc.minSize && fitsIn32Bits(i64) { + return vw.WriteInt32(int32(i64)) + } + return vw.WriteInt64(i64) + + case reflect.Uint8, reflect.Uint16: + return vw.WriteInt32(int32(val.Uint())) + case reflect.Uint, reflect.Uint32, reflect.Uint64: + u64 := val.Uint() + + // If minSize or encodeToMinSize is true for a non-uint64 value we should write val as an int32 + useMinSize := nc.minSize || (nc.encodeUintToMinSize && val.Kind() != reflect.Uint64) + + if u64 <= math.MaxInt32 && useMinSize { + return vw.WriteInt32(int32(u64)) + } + if u64 > math.MaxInt64 { + return fmt.Errorf("%d overflows int64", u64) + } + return vw.WriteInt64(int64(u64)) + } + + return ValueEncoderError{ + Name: "NumEncodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, + Received: val, + } +} + +func (nc *numCodec) decodeTypeInt(vr ValueReader, t reflect.Type) (reflect.Value, error) { + var i64 int64 + switch vrType := vr.Type(); vrType { + case TypeInt32: + i32, err := vr.ReadInt32() + if err != nil { + return emptyValue, err + } + i64 = int64(i32) + case TypeInt64: + var err error + i64, err = vr.ReadInt64() + if err != nil { + return emptyValue, err + } + case TypeDouble: + f64, err := vr.ReadDouble() + if err != nil { + return emptyValue, err + } + if !nc.truncate && math.Floor(f64) != f64 { + return emptyValue, errCannotTruncate + } + if f64 > float64(math.MaxInt64) { + return emptyValue, fmt.Errorf("%g overflows int64", f64) + } + i64 = int64(f64) + case TypeBoolean: + b, err := vr.ReadBoolean() + if err != nil { + return emptyValue, err + } + if b { + i64 = 1 + } + case TypeNull: + if err := vr.ReadNull(); err != nil { + return emptyValue, err + } + case TypeUndefined: + if err := vr.ReadUndefined(); err != nil { + return emptyValue, err + } + default: + return emptyValue, fmt.Errorf("cannot decode %v into an integer type", vrType) + } + + switch t.Kind() { + case reflect.Int8: + if i64 < math.MinInt8 || i64 > math.MaxInt8 { + return emptyValue, fmt.Errorf("%d overflows int8", i64) + } + return reflect.ValueOf(int8(i64)), nil + case reflect.Int16: + if i64 < math.MinInt16 || i64 > math.MaxInt16 { + return emptyValue, fmt.Errorf("%d overflows int16", i64) + } + return reflect.ValueOf(int16(i64)), nil + case reflect.Int32: + if i64 < math.MinInt32 || i64 > math.MaxInt32 { + return emptyValue, fmt.Errorf("%d overflows int32", i64) + } + return reflect.ValueOf(int32(i64)), nil + case reflect.Int64: + return reflect.ValueOf(i64), nil + case reflect.Int: + if int64(int(i64)) != i64 { // Can we fit this inside of an int + return emptyValue, fmt.Errorf("%d overflows int", i64) + } + return reflect.ValueOf(int(i64)), nil + + case reflect.Uint8: + if i64 < 0 || i64 > math.MaxUint8 { + return emptyValue, fmt.Errorf("%d overflows uint8", i64) + } + return reflect.ValueOf(uint8(i64)), nil + case reflect.Uint16: + if i64 < 0 || i64 > math.MaxUint16 { + return emptyValue, fmt.Errorf("%d overflows uint16", i64) + } + return reflect.ValueOf(uint16(i64)), nil + case reflect.Uint32: + if i64 < 0 || i64 > math.MaxUint32 { + return emptyValue, fmt.Errorf("%d overflows uint32", i64) + } + return reflect.ValueOf(uint32(i64)), nil + case reflect.Uint64: + if i64 < 0 { + return emptyValue, fmt.Errorf("%d overflows uint64", i64) + } + return reflect.ValueOf(uint64(i64)), nil + case reflect.Uint: + if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint + return emptyValue, fmt.Errorf("%d overflows uint", i64) + } + return reflect.ValueOf(uint(i64)), nil + + default: + return emptyValue, ValueDecoderError{ + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, + Received: reflect.Zero(t), + } + } +} + +func (nc *numCodec) decodeTypeFloat(vr ValueReader, t reflect.Type) (reflect.Value, error) { + var f float64 + var err error + switch vrType := vr.Type(); vrType { + case TypeInt32: + i32, err := vr.ReadInt32() + if err != nil { + return emptyValue, err + } + f = float64(i32) + case TypeInt64: + i64, err := vr.ReadInt64() + if err != nil { + return emptyValue, err + } + f = float64(i64) + case TypeDouble: + f, err = vr.ReadDouble() + if err != nil { + return emptyValue, err + } + case TypeBoolean: + b, err := vr.ReadBoolean() + if err != nil { + return emptyValue, err + } + if b { + f = 1 + } + case TypeNull: + if err = vr.ReadNull(); err != nil { + return emptyValue, err + } + case TypeUndefined: + if err = vr.ReadUndefined(); err != nil { + return emptyValue, err + } + default: + return emptyValue, fmt.Errorf("cannot decode %v into a float32 or float64 type", vrType) + } + + switch t.Kind() { + case reflect.Float32: + if !nc.truncate && float64(float32(f)) != f { + return emptyValue, errCannotTruncate + } + + return reflect.ValueOf(float32(f)), nil + case reflect.Float64: + return reflect.ValueOf(f), nil + + default: + return emptyValue, ValueDecoderError{ + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, + Received: reflect.Zero(t), + } + } +} + +func (nc *numCodec) decodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { + switch t.Kind() { + case reflect.Float32, reflect.Float64: + return nc.decodeTypeFloat(vr, t) + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + return nc.decodeTypeInt(vr, t) + default: + return emptyValue, ValueDecoderError{ + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, + Received: reflect.Zero(t), + } + } +} + +// DecodeValue is the ValueDecoder for numeric types. +func (nc *numCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { + if !val.CanSet() { + return ValueDecoderError{ + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, + Received: val, + } + } + + elem, err := nc.decodeType(reg, vr, val.Type()) + if err != nil { + return err + } + + if t := val.Type(); elem.Type() != t { + elem = elem.Convert(t) + } + + val.Set(elem) + return nil +} diff --git a/bson/pointer_codec.go b/bson/pointer_codec.go index 5946b9cc9f..bca19742bc 100644 --- a/bson/pointer_codec.go +++ b/bson/pointer_codec.go @@ -8,36 +8,23 @@ package bson import ( "reflect" + "sync" ) -var _ ValueEncoder = &PointerCodec{} -var _ ValueDecoder = &PointerCodec{} - -// PointerCodec is the Codec used for pointers. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// PointerCodec registered. -type PointerCodec struct { - ecache typeEncoderCache - dcache typeDecoderCache -} - -// NewPointerCodec returns a PointerCodec that has been initialized. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// PointerCodec registered. -func NewPointerCodec() *PointerCodec { - return &PointerCodec{} +// pointerCodec is the Codec used for pointers. +type pointerCodec struct { + ecache sync.Map // map[reflect.Type]ValueEncoder + dcache sync.Map // map[reflect.Type]ValueDecoder } // EncodeValue handles encoding a pointer by either encoding it to BSON Null if the pointer is nil // or looking up an encoder for the type of value the pointer points to. -func (pc *PointerCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (pc *pointerCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.Ptr { if !val.IsValid() { return vw.WriteNull() } - return ValueEncoderError{Name: "PointerCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val} + return ValueEncoderError{Name: "pointerCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val} } if val.IsNil() { @@ -49,22 +36,24 @@ func (pc *PointerCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflec if v == nil { return ErrNoEncoder{Type: typ} } - return v.EncodeValue(ec, vw, val.Elem()) + return v.(ValueEncoder).EncodeValue(reg, vw, val.Elem()) } // TODO(charlie): handle concurrent requests for the same type - enc, err := ec.LookupEncoder(typ.Elem()) - enc = pc.ecache.LoadOrStore(typ, enc) + enc, err := reg.LookupEncoder(typ.Elem()) if err != nil { return err } - return enc.EncodeValue(ec, vw, val.Elem()) + if v, ok := pc.ecache.LoadOrStore(typ, enc); ok { + enc = v.(ValueEncoder) + } + return enc.EncodeValue(reg, vw, val.Elem()) } // DecodeValue handles decoding a pointer by looking up a decoder for the type it points to and // using that to decode. If the BSON value is Null, this method will set the pointer to nil. -func (pc *PointerCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (pc *pointerCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Kind() != reflect.Ptr { - return ValueDecoderError{Name: "PointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val} + return ValueDecoderError{Name: "pointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val} } typ := val.Type() @@ -85,13 +74,15 @@ func (pc *PointerCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflec if v == nil { return ErrNoDecoder{Type: typ} } - return v.DecodeValue(dc, vr, val.Elem()) + return v.(ValueDecoder).DecodeValue(reg, vr, val.Elem()) } // TODO(charlie): handle concurrent requests for the same type - dec, err := dc.LookupDecoder(typ.Elem()) - dec = pc.dcache.LoadOrStore(typ, dec) + dec, err := reg.LookupDecoder(typ.Elem()) if err != nil { return err } - return dec.DecodeValue(dc, vr, val.Elem()) + if v, ok := pc.dcache.LoadOrStore(typ, dec); ok { + dec = v.(ValueDecoder) + } + return dec.DecodeValue(reg, vr, val.Elem()) } diff --git a/bson/primitive_codecs.go b/bson/primitive_codecs.go index 262645ce4c..2cf68ddf85 100644 --- a/bson/primitive_codecs.go +++ b/bson/primitive_codecs.go @@ -15,38 +15,24 @@ import ( var tRawValue = reflect.TypeOf(RawValue{}) var tRaw = reflect.TypeOf(Raw(nil)) -// PrimitiveCodecs is a namespace for all of the default Codecs for the primitive types -// defined in this package. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders -// registered. -type PrimitiveCodecs struct{} - -// RegisterPrimitiveCodecs will register the encode and decode methods attached to PrimitiveCodecs -// with the provided RegistryBuilder. if rb is nil, a new empty RegistryBuilder will be created. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders -// registered. -func (pc PrimitiveCodecs) RegisterPrimitiveCodecs(rb *RegistryBuilder) { +// registerPrimitiveCodecs will register the encode and decode methods with the provided Registry. +func registerPrimitiveCodecs(rb *RegistryBuilder) { if rb == nil { panic(errors.New("argument to RegisterPrimitiveCodecs must not be nil")) } rb. - RegisterTypeEncoder(tRawValue, ValueEncoderFunc(pc.RawValueEncodeValue)). - RegisterTypeEncoder(tRaw, ValueEncoderFunc(pc.RawEncodeValue)). - RegisterTypeDecoder(tRawValue, ValueDecoderFunc(pc.RawValueDecodeValue)). - RegisterTypeDecoder(tRaw, ValueDecoderFunc(pc.RawDecodeValue)) + RegisterTypeEncoder(tRawValue, func(*Registry) ValueEncoder { return ValueEncoderFunc(rawValueEncodeValue) }). + RegisterTypeEncoder(tRaw, func(*Registry) ValueEncoder { return ValueEncoderFunc(rawEncodeValue) }). + RegisterTypeDecoder(tRawValue, func(*Registry) ValueDecoder { return ValueDecoderFunc(rawValueDecodeValue) }). + RegisterTypeDecoder(tRaw, func(*Registry) ValueDecoder { return ValueDecoderFunc(rawDecodeValue) }) } -// RawValueEncodeValue is the ValueEncoderFunc for RawValue. +// rawValueEncodeValue is the ValueEncoderFunc for RawValue. // // If the RawValue's Type is "invalid" and the RawValue's Value is not empty or // nil, then this method will return an error. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive -// encoders and decoders registered. -func (PrimitiveCodecs) RawValueEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func rawValueEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRawValue { return ValueEncoderError{ Name: "RawValueEncodeValue", @@ -64,11 +50,8 @@ func (PrimitiveCodecs) RawValueEncodeValue(_ EncodeContext, vw ValueWriter, val return copyValueFromBytes(vw, rawvalue.Type, rawvalue.Value) } -// RawValueDecodeValue is the ValueDecoderFunc for RawValue. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders -// registered. -func (PrimitiveCodecs) RawValueDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +// rawValueDecodeValue is the ValueDecoderFunc for RawValue. +func rawValueDecodeValue(_ DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tRawValue { return ValueDecoderError{Name: "RawValueDecodeValue", Types: []reflect.Type{tRawValue}, Received: val} } @@ -82,11 +65,8 @@ func (PrimitiveCodecs) RawValueDecodeValue(_ DecodeContext, vr ValueReader, val return nil } -// RawEncodeValue is the ValueEncoderFunc for Reader. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders -// registered. -func (PrimitiveCodecs) RawEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// rawEncodeValue is the ValueEncoderFunc for Reader. +func rawEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRaw { return ValueEncoderError{Name: "RawEncodeValue", Types: []reflect.Type{tRaw}, Received: val} } @@ -96,11 +76,8 @@ func (PrimitiveCodecs) RawEncodeValue(_ EncodeContext, vw ValueWriter, val refle return copyDocumentFromBytes(vw, rdr) } -// RawDecodeValue is the ValueDecoderFunc for Reader. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders -// registered. -func (PrimitiveCodecs) RawDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +// rawDecodeValue is the ValueDecoderFunc for Reader. +func rawDecodeValue(_ DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tRaw { return ValueDecoderError{Name: "RawDecodeValue", Types: []reflect.Type{tRaw}, Received: val} } diff --git a/bson/primitive_codecs_test.go b/bson/primitive_codecs_test.go index be3aeab978..18dcfb71b3 100644 --- a/bson/primitive_codecs_test.go +++ b/bson/primitive_codecs_test.go @@ -32,14 +32,11 @@ func bytesFromDoc(doc interface{}) []byte { func TestPrimitiveValueEncoders(t *testing.T) { t.Parallel() - var pc PrimitiveCodecs - var wrong = func(string, string) string { return "wrong" } type subtest struct { name string val interface{} - ectx *EncodeContext llvrw *valueReaderWriter invoke invoked err error @@ -52,13 +49,12 @@ func TestPrimitiveValueEncoders(t *testing.T) { }{ { "RawValueEncodeValue", - ValueEncoderFunc(pc.RawValueEncodeValue), + ValueEncoderFunc(rawValueEncodeValue), []subtest{ { "wrong type", wrong, nil, - nil, nothing, ValueEncoderError{ Name: "RawValueEncodeValue", @@ -70,7 +66,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { "RawValue/success", RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, 3.14159)}, nil, - nil, writeDouble, nil, }, @@ -81,7 +76,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { Value: bsoncore.AppendDouble(nil, 3.14159), }, nil, - nil, nothing, fmt.Errorf("the RawValue Type specifies an invalid BSON type: 0x0"), }, @@ -92,7 +86,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { Value: bsoncore.AppendDouble(nil, 3.14159), }, nil, - nil, nothing, fmt.Errorf("the RawValue Type specifies an invalid BSON type: 0x8f"), }, @@ -100,20 +93,18 @@ func TestPrimitiveValueEncoders(t *testing.T) { }, { "RawEncodeValue", - ValueEncoderFunc(pc.RawEncodeValue), + ValueEncoderFunc(rawEncodeValue), []subtest{ { "wrong type", wrong, nil, - nil, nothing, ValueEncoderError{Name: "RawEncodeValue", Types: []reflect.Type{tRaw}, Received: reflect.ValueOf(wrong)}, }, { "WriteDocument Error", Raw{}, - nil, &valueReaderWriter{Err: errors.New("wd error"), ErrAfter: writeDocument}, writeDocument, errors.New("wd error"), @@ -121,7 +112,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { { "Raw.Elements Error", Raw{0xFF, 0x00, 0x00, 0x00, 0x00}, - nil, &valueReaderWriter{}, writeDocument, errors.New("length read exceeds number of bytes available. length=5 bytes=255"), @@ -129,7 +119,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { { "WriteDocumentElement Error", Raw(bytesFromDoc(D{{"foo", nil}})), - nil, &valueReaderWriter{Err: errors.New("wde error"), ErrAfter: writeDocumentElement}, writeDocumentElement, errors.New("wde error"), @@ -137,7 +126,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { { "encodeValue error", Raw(bytesFromDoc(D{{"foo", nil}})), - nil, &valueReaderWriter{Err: errors.New("ev error"), ErrAfter: writeNull}, writeNull, errors.New("ev error"), @@ -145,7 +133,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { { "iterator error", Raw{0x0C, 0x00, 0x00, 0x00, 0x01, 'f', 'o', 'o', 0x00, 0x01, 0x02, 0x03}, - nil, &valueReaderWriter{}, writeDocumentElement, errors.New("not enough bytes available to read type. bytes=3 type=double"), @@ -166,16 +153,12 @@ func TestPrimitiveValueEncoders(t *testing.T) { t.Run(subtest.name, func(t *testing.T) { t.Parallel() - var ec EncodeContext - if subtest.ectx != nil { - ec = *subtest.ectx - } llvrw := new(valueReaderWriter) if subtest.llvrw != nil { llvrw = subtest.llvrw } llvrw.T = t - err := tc.ve.EncodeValue(ec, llvrw, reflect.ValueOf(subtest.val)) + err := tc.ve.EncodeValue(nil, llvrw, reflect.ValueOf(subtest.val)) if !assert.CompareErrors(err, subtest.err) { t.Errorf("Errors do not match. got %v; want %v", err, subtest.err) } @@ -478,8 +461,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { } func TestPrimitiveValueDecoders(t *testing.T) { - var pc PrimitiveCodecs - var wrong = func(string, string) string { return "wrong" } const cansetreflectiontest = "cansetreflectiontest" @@ -487,7 +468,7 @@ func TestPrimitiveValueDecoders(t *testing.T) { type subtest struct { name string val interface{} - dctx *DecodeContext + reg *Registry llvrw *valueReaderWriter invoke invoked err error @@ -500,7 +481,7 @@ func TestPrimitiveValueDecoders(t *testing.T) { }{ { "RawValueDecodeValue", - ValueDecoderFunc(pc.RawValueDecodeValue), + ValueDecoderFunc(rawValueDecodeValue), []subtest{ { "wrong type", @@ -544,7 +525,7 @@ func TestPrimitiveValueDecoders(t *testing.T) { }, { "RawDecodeValue", - ValueDecoderFunc(pc.RawDecodeValue), + ValueDecoderFunc(rawDecodeValue), []subtest{ { "wrong type", @@ -582,23 +563,19 @@ func TestPrimitiveValueDecoders(t *testing.T) { t.Run(tc.name, func(t *testing.T) { for _, rc := range tc.subtests { t.Run(rc.name, func(t *testing.T) { - var dc DecodeContext - if rc.dctx != nil { - dc = *rc.dctx - } llvrw := new(valueReaderWriter) if rc.llvrw != nil { llvrw = rc.llvrw } llvrw.T = t if rc.val == cansetreflectiontest { // We're doing a CanSet reflection test - err := tc.vd.DecodeValue(dc, llvrw, reflect.Value{}) + err := tc.vd.DecodeValue(rc.reg, llvrw, reflect.Value{}) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } val := reflect.New(reflect.TypeOf(rc.val)).Elem() - err = tc.vd.DecodeValue(dc, llvrw, val) + err = tc.vd.DecodeValue(rc.reg, llvrw, val) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } @@ -615,7 +592,7 @@ func TestPrimitiveValueDecoders(t *testing.T) { panic(err) } }() - err := tc.vd.DecodeValue(dc, llvrw, val) + err := tc.vd.DecodeValue(rc.reg, llvrw, val) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } diff --git a/bson/raw_value.go b/bson/raw_value.go index a32b82e41d..0e32361ee4 100644 --- a/bson/raw_value.go +++ b/bson/raw_value.go @@ -46,7 +46,7 @@ func (rv RawValue) IsZero() bool { func (rv RawValue) Unmarshal(val interface{}) error { reg := rv.r if reg == nil { - reg = DefaultRegistry + reg = defaultRegistry } return rv.UnmarshalWithRegistry(reg, val) } @@ -81,13 +81,13 @@ func (rv RawValue) UnmarshalWithRegistry(r *Registry, val interface{}) error { if err != nil { return err } - return dec.DecodeValue(DecodeContext{Registry: r}, vr, rval) + return dec.DecodeValue(r, vr, rval) } // UnmarshalWithContext performs the same unmarshalling as Unmarshal but uses the provided DecodeContext // instead of the one attached or the default registry. -func (rv RawValue) UnmarshalWithContext(dc *DecodeContext, val interface{}) error { - if dc == nil { +func (rv RawValue) UnmarshalWithContext(reg *Registry, val interface{}) error { + if reg == nil { return ErrNilContext } @@ -97,11 +97,11 @@ func (rv RawValue) UnmarshalWithContext(dc *DecodeContext, val interface{}) erro return fmt.Errorf("argument to Unmarshal* must be a pointer to a type, but got %v", rval) } rval = rval.Elem() - dec, err := dc.LookupDecoder(rval.Type()) + dec, err := reg.LookupDecoder(rval.Type()) if err != nil { return err } - return dec.DecodeValue(*dc, vr, rval) + return dec.DecodeValue(reg, vr, rval) } func convertFromCoreValue(v bsoncore.Value) RawValue { diff --git a/bson/raw_value_test.go b/bson/raw_value_test.go index f02fe8f326..aa2b9a0eb6 100644 --- a/bson/raw_value_test.go +++ b/bson/raw_value_test.go @@ -75,7 +75,7 @@ func TestRawValue(t *testing.T) { t.Run("Returns DecodeValue error", func(t *testing.T) { t.Parallel() - reg := NewRegistry() + reg := NewRegistryBuilder().Build() val := RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, 3.14159)} var s string want := fmt.Errorf("cannot decode %v into a string type", TypeDouble) @@ -87,7 +87,7 @@ func TestRawValue(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - reg := NewRegistry() + reg := NewRegistryBuilder().Build() want := float64(3.14159) val := RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, want)} var got float64 @@ -114,11 +114,11 @@ func TestRawValue(t *testing.T) { t.Run("Returns lookup error", func(t *testing.T) { t.Parallel() - dc := DecodeContext{Registry: newTestRegistryBuilder().Build()} + reg := newTestRegistryBuilder().Build() var val RawValue var s string want := ErrNoDecoder{Type: reflect.TypeOf(s)} - got := val.UnmarshalWithContext(&dc, &s) + got := val.UnmarshalWithContext(reg, &s) if !assert.CompareErrors(got, want) { t.Errorf("Expected errors to match. got %v; want %v", got, want) } @@ -126,11 +126,11 @@ func TestRawValue(t *testing.T) { t.Run("Returns DecodeValue error", func(t *testing.T) { t.Parallel() - dc := DecodeContext{Registry: NewRegistry()} + reg := NewRegistryBuilder().Build() val := RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, 3.14159)} var s string want := fmt.Errorf("cannot decode %v into a string type", TypeDouble) - got := val.UnmarshalWithContext(&dc, &s) + got := val.UnmarshalWithContext(reg, &s) if !assert.CompareErrors(got, want) { t.Errorf("Expected errors to match. got %v; want %v", got, want) } @@ -138,11 +138,11 @@ func TestRawValue(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - dc := DecodeContext{Registry: NewRegistry()} + reg := NewRegistryBuilder().Build() want := float64(3.14159) val := RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, want)} var got float64 - err := val.UnmarshalWithContext(&dc, &got) + err := val.UnmarshalWithContext(reg, &got) noerr(t, err) if got != want { t.Errorf("Expected results to match. got %g; want %g", got, want) diff --git a/bson/registry.go b/bson/registry.go index 74b99e93ab..740a2bb665 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -7,25 +7,12 @@ package bson import ( - "errors" "fmt" "reflect" "sync" ) -// DefaultRegistry is the default Registry. It contains the default codecs and the -// primitive codecs. -var DefaultRegistry = NewRegistry() - -// ErrNilType is returned when nil is passed to either LookupEncoder or LookupDecoder. -// -// Deprecated: ErrNilType will not be supported in Go Driver 2.0. -var ErrNilType = errors.New("cannot perform a decoder lookup on ") - -// ErrNotPointer is returned when a non-pointer type is provided to LookupDecoder. -// -// Deprecated: ErrNotPointer will not be supported in Go Driver 2.0. -var ErrNotPointer = errors.New("non-pointer provided to LookupDecoder") +var defaultRegistry = NewRegistryBuilder().Build() // ErrNoEncoder is returned when there wasn't an encoder available for a type. // @@ -49,7 +36,13 @@ type ErrNoDecoder struct { } func (end ErrNoDecoder) Error() string { - return "no decoder found for " + end.Type.String() + var typeStr string + if end.Type != nil { + typeStr = end.Type.String() + } else { + typeStr = "nil type" + } + return "no decoder found for " + typeStr } // ErrNoTypeMapEntry is returned when there wasn't a type available for the provided BSON type. @@ -63,147 +56,148 @@ func (entme ErrNoTypeMapEntry) Error() string { return "no type map entry found for " + entme.Type.String() } -// ErrNotInterface is returned when the provided type is not an interface. -// -// Deprecated: ErrNotInterface will not be supported in Go Driver 2.0. -var ErrNotInterface = errors.New("The provided type is not an interface") +// EncoderFactory is an idempotent factory function that generates a new ValueEncoder. +type EncoderFactory func(*Registry) ValueEncoder + +// DecoderFactory is an idempotent factory function that generates a new ValueDecoder. +type DecoderFactory func(*Registry) ValueDecoder // A RegistryBuilder is used to build a Registry. This type is not goroutine // safe. -// -// Deprecated: Use Registry instead. type RegistryBuilder struct { - registry *Registry + typeEncoders map[reflect.Type]EncoderFactory + typeDecoders map[reflect.Type]DecoderFactory + interfaceEncoders map[reflect.Type]EncoderFactory + interfaceDecoders map[reflect.Type]DecoderFactory + kindEncoders [reflect.UnsafePointer + 1]EncoderFactory + kindDecoders [reflect.UnsafePointer + 1]DecoderFactory + typeMap map[Type]reflect.Type } // NewRegistryBuilder creates a new empty RegistryBuilder. -// -// Deprecated: Use NewRegistry instead. func NewRegistryBuilder() *RegistryBuilder { rb := &RegistryBuilder{ - registry: &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), - }, - } - DefaultValueEncoders{}.RegisterDefaultEncoders(rb) - DefaultValueDecoders{}.RegisterDefaultDecoders(rb) - PrimitiveCodecs{}.RegisterPrimitiveCodecs(rb) + typeEncoders: make(map[reflect.Type]EncoderFactory), + typeDecoders: make(map[reflect.Type]DecoderFactory), + interfaceEncoders: make(map[reflect.Type]EncoderFactory), + interfaceDecoders: make(map[reflect.Type]DecoderFactory), + typeMap: make(map[Type]reflect.Type), + } + registerDefaultEncoders(rb) + registerDefaultDecoders(rb) + registerPrimitiveCodecs(rb) return rb } -// RegisterCodec will register the provided ValueCodec for the provided type. +// RegisterTypeEncoder registers a ValueEncoder factory for the provided type. +// +// The type will be used as provided, so an encoder factory can be registered for a type and a +// different one can be registered for a pointer to that type. +// +// If the given type is an interface, the encoder will be called when marshaling a type that is +// that interface. It will not be called when marshaling a non-interface type that implements the +// interface. To get the latter behavior, call RegisterInterfaceEncoder instead. // -// Deprecated: Use Registry.RegisterTypeEncoder and Registry.RegisterTypeDecoder instead. -func (rb *RegistryBuilder) RegisterCodec(t reflect.Type, codec ValueCodec) *RegistryBuilder { - rb.RegisterTypeEncoder(t, codec) - rb.RegisterTypeDecoder(t, codec) +// RegisterTypeEncoder should not be called concurrently with any other Registry method. +func (rb *RegistryBuilder) RegisterTypeEncoder(valueType reflect.Type, encFac EncoderFactory) *RegistryBuilder { + if encFac != nil { + rb.typeEncoders[valueType] = encFac + } return rb } -// RegisterTypeEncoder will register the provided ValueEncoder for the provided type. +// RegisterTypeDecoder registers a ValueDecoder factory for the provided type. // -// The type will be used directly, so an encoder can be registered for a type and a different encoder can be registered -// for a pointer to that type. +// The type will be used as provided, so a decoder can be registered for a type and a different +// decoder can be registered for a pointer to that type. // -// If the given type is an interface, the encoder will be called when marshaling a type that is that interface. It -// will not be called when marshaling a non-interface type that implements the interface. +// If the given type is an interface, the decoder will be called when unmarshaling into a type that +// is that interface. It will not be called when unmarshaling into a non-interface type that +// implements the interface. To get the latter behavior, call RegisterHookDecoder instead. // -// Deprecated: Use Registry.RegisterTypeEncoder instead. -func (rb *RegistryBuilder) RegisterTypeEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder { - rb.registry.RegisterTypeEncoder(t, enc) +// RegisterTypeDecoder should not be called concurrently with any other Registry method. +func (rb *RegistryBuilder) RegisterTypeDecoder(valueType reflect.Type, decFac DecoderFactory) *RegistryBuilder { + if decFac != nil { + rb.typeDecoders[valueType] = decFac + } return rb } -// RegisterHookEncoder will register an encoder for the provided interface type t. This encoder will be called when -// marshaling a type if the type implements t or a pointer to the type implements t. If the provided type is not -// an interface (i.e. t.Kind() != reflect.Interface), this method will panic. +// RegisterKindEncoder registers a ValueEncoder factory for the provided kind. // -// Deprecated: Use Registry.RegisterInterfaceEncoder instead. -func (rb *RegistryBuilder) RegisterHookEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder { - rb.registry.RegisterInterfaceEncoder(t, enc) - return rb -} - -// RegisterTypeDecoder will register the provided ValueDecoder for the provided type. +// Use RegisterKindEncoder to register an encoder factory for any type with the same underlying kind. +// For example, consider the type MyInt defined as // -// The type will be used directly, so a decoder can be registered for a type and a different decoder can be registered -// for a pointer to that type. +// type MyInt int32 // -// If the given type is an interface, the decoder will be called when unmarshaling into a type that is that interface. -// It will not be called when unmarshaling into a non-interface type that implements the interface. +// To define an encoder factory for MyInt and int32, use RegisterKindEncoder like // -// Deprecated: Use Registry.RegisterTypeDecoder instead. -func (rb *RegistryBuilder) RegisterTypeDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder { - rb.registry.RegisterTypeDecoder(t, dec) - return rb -} - -// RegisterHookDecoder will register an decoder for the provided interface type t. This decoder will be called when -// unmarshaling into a type if the type implements t or a pointer to the type implements t. If the provided type is not -// an interface (i.e. t.Kind() != reflect.Interface), this method will panic. +// reg.RegisterKindEncoder(reflect.Int32, myEncoder) // -// Deprecated: Use Registry.RegisterInterfaceDecoder instead. -func (rb *RegistryBuilder) RegisterHookDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder { - rb.registry.RegisterInterfaceDecoder(t, dec) +// RegisterKindEncoder should not be called concurrently with any other Registry method. +func (rb *RegistryBuilder) RegisterKindEncoder(kind reflect.Kind, encFac EncoderFactory) *RegistryBuilder { + if encFac != nil && kind < reflect.Kind(len(rb.kindEncoders)) { + rb.kindEncoders[kind] = encFac + } return rb } -// RegisterEncoder registers the provided type and encoder pair. +// RegisterKindDecoder registers a ValueDecoder factory for the provided kind. // -// Deprecated: Use Registry.RegisterTypeEncoder or Registry.RegisterInterfaceEncoder instead. -func (rb *RegistryBuilder) RegisterEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder { - if t == tEmpty { - rb.registry.RegisterTypeEncoder(t, enc) - return rb - } - switch t.Kind() { - case reflect.Interface: - rb.registry.RegisterInterfaceEncoder(t, enc) - default: - rb.registry.RegisterTypeEncoder(t, enc) +// Use RegisterKindDecoder to register a decoder for any type with the same underlying kind. For +// example, consider the type MyInt defined as +// +// type MyInt int32 +// +// To define an decoder for MyInt and int32, use RegisterKindDecoder like +// +// reg.RegisterKindDecoder(reflect.Int32, myDecoder) +// +// RegisterKindDecoder should not be called concurrently with any other Registry method. +func (rb *RegistryBuilder) RegisterKindDecoder(kind reflect.Kind, decFac DecoderFactory) *RegistryBuilder { + if decFac != nil && kind < reflect.Kind(len(rb.kindDecoders)) { + rb.kindDecoders[kind] = decFac } return rb } -// RegisterDecoder registers the provided type and decoder pair. +// RegisterInterfaceEncoder registers an encoder factory for the provided interface type iface. This +// encoder will be called when marshaling a type if the type implements iface or a pointer to the type +// implements iface. If the provided type is not an interface +// (i.e. iface.Kind() != reflect.Interface), this method will panic. // -// Deprecated: Use Registry.RegisterTypeDecoder or Registry.RegisterInterfaceDecoder instead. -func (rb *RegistryBuilder) RegisterDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder { - if t == nil { - rb.registry.RegisterTypeDecoder(t, dec) - return rb - } - if t == tEmpty { - rb.registry.RegisterTypeDecoder(t, dec) - return rb +// RegisterInterfaceEncoder should not be called concurrently with any other Registry method. +func (rb *RegistryBuilder) RegisterInterfaceEncoder(iface reflect.Type, encFac EncoderFactory) *RegistryBuilder { + if iface.Kind() != reflect.Interface { + panicStr := fmt.Errorf("RegisterInterfaceEncoder expects a type with kind reflect.Interface, "+ + "got type %s with kind %s", iface, iface.Kind()) + panic(panicStr) } - switch t.Kind() { - case reflect.Interface: - rb.registry.RegisterInterfaceDecoder(t, dec) - default: - rb.registry.RegisterTypeDecoder(t, dec) + + if encFac != nil { + rb.interfaceEncoders[iface] = encFac } - return rb -} -// RegisterDefaultEncoder will register the provided ValueEncoder to the provided -// kind. -// -// Deprecated: Use Registry.RegisterKindEncoder instead. -func (rb *RegistryBuilder) RegisterDefaultEncoder(kind reflect.Kind, enc ValueEncoder) *RegistryBuilder { - rb.registry.RegisterKindEncoder(kind, enc) return rb } -// RegisterDefaultDecoder will register the provided ValueDecoder to the -// provided kind. +// RegisterInterfaceDecoder registers a decoder factory for the provided interface type iface. This decoder +// will be called when unmarshaling into a type if the type implements iface or a pointer to the type +// implements iface. If the provided type is not an interface (i.e. iface.Kind() != reflect.Interface), +// this method will panic. // -// Deprecated: Use Registry.RegisterKindDecoder instead. -func (rb *RegistryBuilder) RegisterDefaultDecoder(kind reflect.Kind, dec ValueDecoder) *RegistryBuilder { - rb.registry.RegisterKindDecoder(kind, dec) +// RegisterInterfaceDecoder should not be called concurrently with any other Registry method. +func (rb *RegistryBuilder) RegisterInterfaceDecoder(iface reflect.Type, decFac DecoderFactory) *RegistryBuilder { + if iface.Kind() != reflect.Interface { + panicStr := fmt.Errorf("RegisterInterfaceDecoder expects a type with kind reflect.Interface, "+ + "got type %s with kind %s", iface, iface.Kind()) + panic(panicStr) + } + + if decFac != nil { + rb.interfaceDecoders[iface] = decFac + } + return rb } @@ -217,30 +211,82 @@ func (rb *RegistryBuilder) RegisterDefaultDecoder(kind reflect.Kind, dec ValueDe // // rb.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) // -// Deprecated: Use Registry.RegisterTypeMapEntry instead. +// RegisterTypeMapEntry should not be called concurrently with any other Registry method. func (rb *RegistryBuilder) RegisterTypeMapEntry(bt Type, rt reflect.Type) *RegistryBuilder { - rb.registry.RegisterTypeMapEntry(bt, rt) + rb.typeMap[bt] = rt return rb } // Build creates a Registry from the current state of this RegistryBuilder. -// -// Deprecated: Use NewRegistry instead. func (rb *RegistryBuilder) Build() *Registry { r := &Registry{ - interfaceEncoders: append([]interfaceValueEncoder(nil), rb.registry.interfaceEncoders...), - interfaceDecoders: append([]interfaceValueDecoder(nil), rb.registry.interfaceDecoders...), - typeEncoders: rb.registry.typeEncoders.Clone(), - typeDecoders: rb.registry.typeDecoders.Clone(), - kindEncoders: rb.registry.kindEncoders.Clone(), - kindDecoders: rb.registry.kindDecoders.Clone(), - } - rb.registry.typeMap.Range(func(k, v interface{}) bool { - if k != nil && v != nil { - r.typeMap.Store(k, v) + typeEncoders: new(sync.Map), + typeDecoders: new(sync.Map), + interfaceEncoders: make([]interfaceValueEncoder, 0, len(rb.interfaceEncoders)), + interfaceDecoders: make([]interfaceValueDecoder, 0, len(rb.interfaceDecoders)), + typeMap: make(map[Type]reflect.Type), + + codecTypeMap: make(map[reflect.Type][]interface{}), + } + + codecCache := make(map[reflect.Value]interface{}) + + getEncoder := func(encFac EncoderFactory) ValueEncoder { + if enc, ok := codecCache[reflect.ValueOf(encFac)]; ok { + return enc.(ValueEncoder) } - return true - }) + encoder := encFac(r) + codecCache[reflect.ValueOf(encFac)] = encoder + t := reflect.ValueOf(encoder).Type() + r.codecTypeMap[t] = append(r.codecTypeMap[t], encoder) + return encoder + } + for k, v := range rb.typeEncoders { + encoder := getEncoder(v) + r.typeEncoders.Store(k, encoder) + } + for k, v := range rb.interfaceEncoders { + encoder := getEncoder(v) + r.interfaceEncoders = append(r.interfaceEncoders, interfaceValueEncoder{k, encoder}) + } + for i, v := range rb.kindEncoders { + if v == nil { + continue + } + encoder := getEncoder(v) + r.kindEncoders[i] = encoder + } + + getDecoder := func(decFac DecoderFactory) ValueDecoder { + if dec, ok := codecCache[reflect.ValueOf(decFac)]; ok { + return dec.(ValueDecoder) + } + decoder := decFac(r) + codecCache[reflect.ValueOf(decFac)] = decoder + t := reflect.ValueOf(decoder).Type() + r.codecTypeMap[t] = append(r.codecTypeMap[t], decoder) + return decoder + } + for k, v := range rb.typeDecoders { + decoder := getDecoder(v) + r.typeDecoders.Store(k, decoder) + } + for k, v := range rb.interfaceDecoders { + decoder := getDecoder(v) + r.interfaceDecoders = append(r.interfaceDecoders, interfaceValueDecoder{k, decoder}) + } + for i, v := range rb.kindDecoders { + if v == nil { + continue + } + decoder := getDecoder(v) + r.kindDecoders[i] = decoder + } + + for k, v := range rb.typeMap { + r.typeMap[k] = v + } + return r } @@ -278,137 +324,32 @@ func (rb *RegistryBuilder) Build() *Registry { // // Read [Registry.LookupDecoder] and [Registry.LookupEncoder] for Registry lookup procedure. type Registry struct { + typeEncoders *sync.Map // map[reflect.Type]ValueEncoder + typeDecoders *sync.Map // map[reflect.Type]ValueDecoder interfaceEncoders []interfaceValueEncoder interfaceDecoders []interfaceValueDecoder - typeEncoders *typeEncoderCache - typeDecoders *typeDecoderCache - kindEncoders *kindEncoderCache - kindDecoders *kindDecoderCache - typeMap sync.Map // map[Type]reflect.Type -} + kindEncoders [reflect.UnsafePointer + 1]ValueEncoder + kindDecoders [reflect.UnsafePointer + 1]ValueDecoder + typeMap map[Type]reflect.Type -// NewRegistry creates a new empty Registry. -func NewRegistry() *Registry { - return NewRegistryBuilder().Build() + codecTypeMap map[reflect.Type][]interface{} } -// RegisterTypeEncoder registers the provided ValueEncoder for the provided type. -// -// The type will be used as provided, so an encoder can be registered for a type and a different -// encoder can be registered for a pointer to that type. -// -// If the given type is an interface, the encoder will be called when marshaling a type that is -// that interface. It will not be called when marshaling a non-interface type that implements the -// interface. To get the latter behavior, call RegisterHookEncoder instead. -// -// RegisterTypeEncoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) { - r.typeEncoders.Store(valueType, enc) -} - -// RegisterTypeDecoder registers the provided ValueDecoder for the provided type. -// -// The type will be used as provided, so a decoder can be registered for a type and a different -// decoder can be registered for a pointer to that type. -// -// If the given type is an interface, the decoder will be called when unmarshaling into a type that -// is that interface. It will not be called when unmarshaling into a non-interface type that -// implements the interface. To get the latter behavior, call RegisterHookDecoder instead. -// -// RegisterTypeDecoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterTypeDecoder(valueType reflect.Type, dec ValueDecoder) { - r.typeDecoders.Store(valueType, dec) -} - -// RegisterKindEncoder registers the provided ValueEncoder for the provided kind. -// -// Use RegisterKindEncoder to register an encoder for any type with the same underlying kind. For -// example, consider the type MyInt defined as -// -// type MyInt int32 -// -// To define an encoder for MyInt and int32, use RegisterKindEncoder like -// -// reg.RegisterKindEncoder(reflect.Int32, myEncoder) -// -// RegisterKindEncoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc ValueEncoder) { - r.kindEncoders.Store(kind, enc) -} - -// RegisterKindDecoder registers the provided ValueDecoder for the provided kind. -// -// Use RegisterKindDecoder to register a decoder for any type with the same underlying kind. For -// example, consider the type MyInt defined as -// -// type MyInt int32 -// -// To define an decoder for MyInt and int32, use RegisterKindDecoder like -// -// reg.RegisterKindDecoder(reflect.Int32, myDecoder) -// -// RegisterKindDecoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterKindDecoder(kind reflect.Kind, dec ValueDecoder) { - r.kindDecoders.Store(kind, dec) -} - -// RegisterInterfaceEncoder registers an encoder for the provided interface type iface. This encoder will -// be called when marshaling a type if the type implements iface or a pointer to the type -// implements iface. If the provided type is not an interface -// (i.e. iface.Kind() != reflect.Interface), this method will panic. -// -// RegisterInterfaceEncoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterInterfaceEncoder(iface reflect.Type, enc ValueEncoder) { - if iface.Kind() != reflect.Interface { - panicStr := fmt.Errorf("RegisterInterfaceEncoder expects a type with kind reflect.Interface, "+ - "got type %s with kind %s", iface, iface.Kind()) - panic(panicStr) - } - - for idx, encoder := range r.interfaceEncoders { - if encoder.i == iface { - r.interfaceEncoders[idx].ve = enc - return - } +// SetCodecOption configures Registry using a *RegistryOpt. +func (r *Registry) SetCodecOption(opt *RegistryOpt) error { + v, ok := r.codecTypeMap[opt.typ] + if !ok || len(v) == 0 { + return fmt.Errorf("could not find codec %s", opt.typ.String()) } - - r.interfaceEncoders = append(r.interfaceEncoders, interfaceValueEncoder{i: iface, ve: enc}) -} - -// RegisterInterfaceDecoder registers an decoder for the provided interface type iface. This decoder will -// be called when unmarshaling into a type if the type implements iface or a pointer to the type -// implements iface. If the provided type is not an interface (i.e. iface.Kind() != reflect.Interface), -// this method will panic. -// -// RegisterInterfaceDecoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterInterfaceDecoder(iface reflect.Type, dec ValueDecoder) { - if iface.Kind() != reflect.Interface { - panicStr := fmt.Errorf("RegisterInterfaceDecoder expects a type with kind reflect.Interface, "+ - "got type %s with kind %s", iface, iface.Kind()) - panic(panicStr) - } - - for idx, decoder := range r.interfaceDecoders { - if decoder.i == iface { - r.interfaceDecoders[idx].vd = dec - return + for i := range v { + rtns := opt.fn.Call([]reflect.Value{reflect.ValueOf(v[i])}) + for _, r := range rtns { + if !r.IsNil() { + return r.Interface().(error) + } } } - - r.interfaceDecoders = append(r.interfaceDecoders, interfaceValueDecoder{i: iface, vd: dec}) -} - -// RegisterTypeMapEntry will register the provided type to the BSON type. The primary usage for this -// mapping is decoding situations where an empty interface is used and a default type needs to be -// created and decoded into. -// -// By default, BSON documents will decode into interface{} values as bson.D. To change the default type for BSON -// documents, a type map entry for TypeEmbeddedDocument should be registered. For example, to force BSON documents -// to decode to bson.Raw, use the following code: -// -// reg.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) -func (r *Registry) RegisterTypeMapEntry(bt Type, rt reflect.Type) { - r.typeMap.Store(bt, rt) + return nil } // LookupEncoder returns the first matching encoder in the Registry. It uses the following lookup @@ -427,36 +368,38 @@ func (r *Registry) RegisterTypeMapEntry(bt Type, rt reflect.Type) { // 3. An encoder registered using RegisterKindEncoder for the kind of value. // // If no encoder is found, an error of type ErrNoEncoder is returned. LookupEncoder is safe for -// concurrent use by multiple goroutines after all codecs and encoders are registered. +// concurrent use by multiple goroutines. func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { if valueType == nil { return nil, ErrNoEncoder{Type: valueType} } - enc, found := r.lookupTypeEncoder(valueType) - if found { + + if enc, found := r.typeEncoders.Load(valueType); found { if enc == nil { return nil, ErrNoEncoder{Type: valueType} } - return enc, nil + return enc.(ValueEncoder), nil } - enc, found = r.lookupInterfaceEncoder(valueType, true) - if found { - return r.typeEncoders.LoadOrStore(valueType, enc), nil + if enc, found := r.lookupInterfaceEncoder(valueType, true); found { + r.typeEncoders.Store(valueType, enc) + return enc, nil } - if v, ok := r.kindEncoders.Load(valueType.Kind()); ok { - return r.storeTypeEncoder(valueType, v), nil + if enc, found := r.lookupKindEncoder(valueType.Kind()); found { + r.typeEncoders.Store(valueType, enc) + return enc, nil } return nil, ErrNoEncoder{Type: valueType} } -func (r *Registry) storeTypeEncoder(rt reflect.Type, enc ValueEncoder) ValueEncoder { - return r.typeEncoders.LoadOrStore(rt, enc) -} - -func (r *Registry) lookupTypeEncoder(rt reflect.Type) (ValueEncoder, bool) { - return r.typeEncoders.Load(rt) +func (r *Registry) lookupKindEncoder(valueKind reflect.Kind) (ValueEncoder, bool) { + if valueKind < reflect.Kind(len(r.kindEncoders)) { + if enc := r.kindEncoders[valueKind]; enc != nil { + return enc, true + } + } + return nil, false } func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool) (ValueEncoder, bool) { @@ -472,9 +415,9 @@ func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool // ahead in interfaceEncoders defaultEnc, found := r.lookupInterfaceEncoder(valueType, false) if !found { - defaultEnc, _ = r.kindEncoders.Load(valueType.Kind()) + defaultEnc, _ = r.lookupKindEncoder(valueType.Kind()) } - return newCondAddrEncoder(ienc.ve, defaultEnc), true + return &condAddrEncoder{canAddrEnc: ienc.ve, elseEnc: defaultEnc}, true } } return nil, false @@ -496,39 +439,44 @@ func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool // 3. A decoder registered using RegisterKindDecoder for the kind of value. // // If no decoder is found, an error of type ErrNoDecoder is returned. LookupDecoder is safe for -// concurrent use by multiple goroutines after all codecs and decoders are registered. +// concurrent use by multiple goroutines. func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) { if valueType == nil { - return nil, ErrNilType + return nil, ErrNoDecoder{Type: valueType} } - dec, found := r.lookupTypeDecoder(valueType) - if found { + + if dec, found := r.typeDecoders.Load(valueType); found { if dec == nil { return nil, ErrNoDecoder{Type: valueType} } - return dec, nil + return dec.(ValueDecoder), nil } - dec, found = r.lookupInterfaceDecoder(valueType, true) - if found { - return r.storeTypeDecoder(valueType, dec), nil + if dec, found := r.lookupInterfaceDecoder(valueType, true); found { + r.typeDecoders.Store(valueType, dec) + return dec, nil } - if v, ok := r.kindDecoders.Load(valueType.Kind()); ok { - return r.storeTypeDecoder(valueType, v), nil + if dec, found := r.lookupKindDecoder(valueType.Kind()); found { + r.typeDecoders.Store(valueType, dec) + return dec, nil } return nil, ErrNoDecoder{Type: valueType} } -func (r *Registry) lookupTypeDecoder(valueType reflect.Type) (ValueDecoder, bool) { - return r.typeDecoders.Load(valueType) -} - -func (r *Registry) storeTypeDecoder(typ reflect.Type, dec ValueDecoder) ValueDecoder { - return r.typeDecoders.LoadOrStore(typ, dec) +func (r *Registry) lookupKindDecoder(valueKind reflect.Kind) (ValueDecoder, bool) { + if valueKind < reflect.Kind(len(r.kindDecoders)) { + if dec := r.kindDecoders[valueKind]; dec != nil { + return dec, true + } + } + return nil, false } func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool) (ValueDecoder, bool) { + if valueType == nil { + return nil, false + } for _, idec := range r.interfaceDecoders { if valueType.Implements(idec.i) { return idec.vd, true @@ -538,9 +486,9 @@ func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool // ahead in interfaceDecoders defaultDec, found := r.lookupInterfaceDecoder(valueType, false) if !found { - defaultDec, _ = r.kindDecoders.Load(valueType.Kind()) + defaultDec, _ = r.lookupKindDecoder(valueType.Kind()) } - return newCondAddrDecoder(idec.vd, defaultDec), true + return &condAddrDecoder{canAddrDec: idec.vd, elseDec: defaultDec}, true } } return nil, false @@ -548,14 +496,12 @@ func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool // LookupTypeMapEntry inspects the registry's type map for a Go type for the corresponding BSON // type. If no type is found, ErrNoTypeMapEntry is returned. -// -// LookupTypeMapEntry should not be called concurrently with any other Registry method. func (r *Registry) LookupTypeMapEntry(bt Type) (reflect.Type, error) { - v, ok := r.typeMap.Load(bt) + v, ok := r.typeMap[bt] if v == nil || !ok { return nil, ErrNoTypeMapEntry{Type: bt} } - return v.(reflect.Type), nil + return v, nil } type interfaceValueEncoder struct { diff --git a/bson/registry_examples_test.go b/bson/registry_examples_test.go index 39214f1b65..92d82ba58e 100644 --- a/bson/registry_examples_test.go +++ b/bson/registry_examples_test.go @@ -23,7 +23,7 @@ func ExampleRegistry_customEncoder() { negatedIntType := reflect.TypeOf(negatedInt(0)) negatedIntEncoder := func( - ec bson.EncodeContext, + _ bson.EncoderRegistry, vw bson.ValueWriter, val reflect.Value, ) error { @@ -46,10 +46,14 @@ func ExampleRegistry_customEncoder() { return vw.WriteInt64(negatedVal) } - reg := bson.NewRegistry() - reg.RegisterTypeEncoder( - negatedIntType, - bson.ValueEncoderFunc(negatedIntEncoder)) + reg := bson.NewRegistryBuilder(). + RegisterTypeEncoder( + negatedIntType, + func(*bson.Registry) bson.ValueEncoder { + return bson.ValueEncoderFunc(negatedIntEncoder) + }, + ). + Build() // Define a document that includes both int and negatedInt fields with the // same value. @@ -66,8 +70,7 @@ func ExampleRegistry_customEncoder() { // same value and that the negatedInt field is encoded as the negated value. buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(reg) + enc := bson.NewEncoderWithRegistry(reg, vw) err := enc.Encode(doc) if err != nil { panic(err) @@ -85,7 +88,7 @@ func ExampleRegistry_customDecoder() { lenientBoolType := reflect.TypeOf(lenientBool(true)) lenientBoolDecoder := func( - dc bson.DecodeContext, + _ bson.DecoderRegistry, vr bson.ValueReader, val reflect.Value, ) error { @@ -129,10 +132,13 @@ func ExampleRegistry_customDecoder() { return nil } - reg := bson.NewRegistry() - reg.RegisterTypeDecoder( + rb := bson.NewRegistryBuilder() + rb.RegisterTypeDecoder( lenientBoolType, - bson.ValueDecoderFunc(lenientBoolDecoder)) + func(*bson.Registry) bson.ValueDecoder { + return bson.ValueDecoderFunc(lenientBoolDecoder) + }, + ) // Marshal a BSON document with a single field "isOK" that is a non-zero // integer value. @@ -148,7 +154,7 @@ func ExampleRegistry_customDecoder() { IsOK lenientBool `bson:"isOK"` } var doc MyDocument - err = bson.UnmarshalWithRegistry(reg, b, &doc) + err = bson.UnmarshalWithRegistry(rb.Build(), b, &doc) if err != nil { panic(err) } @@ -156,13 +162,13 @@ func ExampleRegistry_customDecoder() { // Output: {IsOK:true} } -func ExampleRegistry_RegisterKindEncoder() { +func ExampleRegistryBuilder_RegisterKindEncoder() { // Create a custom encoder that writes any Go type that has underlying type // int32 as an a BSON int64. To do that, we register the encoder as a "kind" // encoder for kind reflect.Int32. That way, even user-defined types with // underlying type int32 will be encoded as a BSON int64. int32To64Encoder := func( - ec bson.EncodeContext, + _ bson.EncoderRegistry, vw bson.ValueWriter, val reflect.Value, ) error { @@ -179,12 +185,16 @@ func ExampleRegistry_RegisterKindEncoder() { return vw.WriteInt64(val.Int()) } - // Create a default registry and register our int32-to-int64 encoder for + // Create a registry with our int32-to-int64 register encoder for // kind reflect.Int32. - reg := bson.NewRegistry() - reg.RegisterKindEncoder( - reflect.Int32, - bson.ValueEncoderFunc(int32To64Encoder)) + reg := bson.NewRegistryBuilder(). + RegisterKindEncoder( + reflect.Int32, + func(*bson.Registry) bson.ValueEncoder { + return bson.ValueEncoderFunc(int32To64Encoder) + }, + ). + Build() // Define a document that includes an int32, an int64, and a user-defined // type "myInt" that has underlying type int32. @@ -204,8 +214,7 @@ func ExampleRegistry_RegisterKindEncoder() { // int64 (represented as "$numberLong" when encoded as Extended JSON). buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(reg) + enc := bson.NewEncoderWithRegistry(reg, vw) err := enc.Encode(doc) if err != nil { panic(err) @@ -214,14 +223,14 @@ func ExampleRegistry_RegisterKindEncoder() { // Output: {"myint": {"$numberLong":"1"},"int32": {"$numberLong":"1"},"int64": {"$numberLong":"1"}} } -func ExampleRegistry_RegisterKindDecoder() { +func ExampleRegistryBuilder_RegisterKindDecoder() { // Create a custom decoder that can decode any integer value, including // integer values encoded as floating point numbers, to any Go type // with underlying type int64. To do that, we register the decoder as a // "kind" decoder for kind reflect.Int64. That way, we can even decode to // user-defined types with underlying type int64. flexibleInt64KindDecoder := func( - dc bson.DecodeContext, + _ bson.DecoderRegistry, vr bson.ValueReader, val reflect.Value, ) error { @@ -270,10 +279,13 @@ func ExampleRegistry_RegisterKindDecoder() { return nil } - reg := bson.NewRegistry() - reg.RegisterKindDecoder( + rb := bson.NewRegistryBuilder() + rb.RegisterKindDecoder( reflect.Int64, - bson.ValueDecoderFunc(flexibleInt64KindDecoder)) + func(*bson.Registry) bson.ValueDecoder { + return bson.ValueDecoderFunc(flexibleInt64KindDecoder) + }, + ) // Marshal a BSON document with fields that are mixed numeric types but all // hold integer values (i.e. values with no fractional part). @@ -290,7 +302,7 @@ func ExampleRegistry_RegisterKindDecoder() { Int64 int64 } var doc myDocument - err = bson.UnmarshalWithRegistry(reg, b, &doc) + err = bson.UnmarshalWithRegistry(rb.Build(), b, &doc) if err != nil { panic(err) } diff --git a/bson/registry_option.go b/bson/registry_option.go new file mode 100644 index 0000000000..cee6c9f7a0 --- /dev/null +++ b/bson/registry_option.go @@ -0,0 +1,155 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bson + +import ( + "reflect" +) + +// RegistryOpt is used to configure a Registry. +type RegistryOpt struct { + typ reflect.Type + fn reflect.Value +} + +// NewRegistryOpt creates a *RegistryOpt from a setter function. +// For example: +// +// opt := NewRegistryOpt(func(c *Codec) { +// c.attr = value +// }) +// +// reg := NewRegistryBuilder().Build() +// reg.SetCodecOptions(opt) +// +// The "attr" field in the registered Codec can be set to "value". +func NewRegistryOpt[T any](fn func(T) error) *RegistryOpt { + var zero [0]T + return &RegistryOpt{ + typ: reflect.TypeOf(zero).Elem(), + fn: reflect.ValueOf(fn), + } +} + +// NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values +// instead of BSON null. +var NilByteSliceAsEmpty = NewRegistryOpt(func(c *byteSliceCodec) error { + c.encodeNilAsEmpty = true + return nil +}) + +// BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or +// "Old" BSON binary subtype as a Go byte slice instead of a primitive.Binary. +var BinaryAsSlice = NewRegistryOpt(func(c *emptyInterfaceCodec) error { + c.decodeBinaryAsSlice = true + return nil +}) + +// DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This +// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". +var DefaultDocumentM = NewRegistryOpt(func(c *emptyInterfaceCodec) error { + c.defaultDocumentType = reflect.TypeOf(M{}) + return nil +}) + +// DefaultDocumentD causes the Decoder to always unmarshal documents into the primitive.D type. This +// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". +var DefaultDocumentD = NewRegistryOpt(func(c *emptyInterfaceCodec) error { + c.defaultDocumentType = reflect.TypeOf(D{}) + return nil +}) + +// NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON +// null. +var NilMapAsEmpty = NewRegistryOpt(func(c *mapCodec) error { + c.encodeNilAsEmpty = true + return nil +}) + +// StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name +// strings using fmt.Sprint instead of the default string conversion logic. +var StringifyMapKeysWithFmt = NewRegistryOpt(func(c *mapCodec) error { + c.encodeKeysWithStringer = true + return nil +}) + +// ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value +// passed to Decode before unmarshaling BSON documents into them. +var ZeroMaps = NewRegistryOpt(func(c *mapCodec) error { + c.decodeZerosMap = true + return nil +}) + +// AllowTruncatingDoubles causes the Decoder to truncate the fractional part of BSON "double" values +// when attempting to unmarshal them into a Go integer (int, int8, int16, int32, or int64) struct +// field. The truncation logic does not apply to BSON "decimal128" values. +var AllowTruncatingDoubles = NewRegistryOpt(func(c *numCodec) error { + c.truncate = true + return nil +}) + +// IntMinSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, uint, +// uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) that can +// represent the integer value. +var IntMinSize = NewRegistryOpt(func(c *numCodec) error { + c.minSize = true + return nil +}) + +// NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON +// null. +var NilSliceAsEmpty = NewRegistryOpt(func(c *sliceCodec) error { + c.encodeNilAsEmpty = true + return nil +}) + +// ObjectIDAsHex causes the Decoder to unmarshal BSON ObjectID as a hexadecimal string. +var ObjectIDAsHex = NewRegistryOpt(func(c *stringCodec) error { + c.decodeObjectIDAsHex = true + return nil +}) + +// ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in +// the marshaled BSON when the "inline" struct tag option is set. +var ErrorOnInlineDuplicates = NewRegistryOpt(func(c *structCodec) error { + c.overwriteDuplicatedInlinedFields = false + return nil +}) + +// TODO(GODRIVER-2820): Update the description to remove the note about only examining exported +// TODO struct fields once the logic is updated to also inspect private struct fields. + +// OmitZeroStruct causes the Encoder to consider the zero value for a struct (e.g. MyStruct{}) +// as empty and omit it from the marshaled BSON when the "omitempty" struct tag option is set. +// +// Note that the Encoder only examines exported struct fields when determining if a struct is the +// zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty. +var OmitZeroStruct = NewRegistryOpt(func(c *structCodec) error { + c.encodeOmitDefaultStruct = true + return nil +}) + +// UseJSONStructTags causes the Encoder and Decoder to fall back to using the "json" struct tag if +// a "bson" struct tag is not specified. +var UseJSONStructTags = NewRegistryOpt(func(c *structCodec) error { + c.useJSONStructTags = true + return nil +}) + +// ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination +// value passed to Decode before unmarshaling BSON documents into them. +var ZeroStructs = NewRegistryOpt(func(c *structCodec) error { + c.decodeZeroStruct = true + return nil +}) + +// UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead +// of the UTC timezone. +var UseLocalTimeZone = NewRegistryOpt(func(c *timeCodec) error { + c.useLocalTimeZone = true + return nil +}) diff --git a/bson/registry_test.go b/bson/registry_test.go index 2bc87364d3..a87fd3c6cb 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -15,523 +15,247 @@ import ( "go.mongodb.org/mongo-driver/internal/assert" ) -// newTestRegistryBuilder creates a new empty Registry. +// newTestRegistryBuilder creates a new empty RegistryBuilder. func newTestRegistryBuilder() *RegistryBuilder { return &RegistryBuilder{ - registry: &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), - }, + typeEncoders: make(map[reflect.Type]EncoderFactory), + typeDecoders: make(map[reflect.Type]DecoderFactory), + interfaceEncoders: make(map[reflect.Type]EncoderFactory), + interfaceDecoders: make(map[reflect.Type]DecoderFactory), + typeMap: make(map[Type]reflect.Type), } } func TestRegistryBuilder(t *testing.T) { + t.Parallel() + t.Run("Register", func(t *testing.T) { + t.Parallel() + fc1, fc2, fc3, fc4 := new(fakeCodec), new(fakeCodec), new(fakeCodec), new(fakeCodec) t.Run("interface", func(t *testing.T) { - var t1f *testInterface1 - var t2f *testInterface2 - var t4f *testInterface4 - ips := []interfaceValueEncoder{ - {i: reflect.TypeOf(t1f).Elem(), ve: fc1}, - {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, - {i: reflect.TypeOf(t1f).Elem(), ve: fc3}, - {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, + t.Parallel() + + t1f, t2f, t3f, t4f := + reflect.TypeOf((*testInterface1)(nil)).Elem(), + reflect.TypeOf((*testInterface2)(nil)).Elem(), + reflect.TypeOf((*testInterface3)(nil)).Elem(), + reflect.TypeOf((*testInterface4)(nil)).Elem() + + var c1, c2, c3, c4 int + ef1 := func(*Registry) ValueEncoder { + c1++ + return fc1 + } + ef2 := func(*Registry) ValueEncoder { + c2++ + return fc2 + } + ef3 := func(*Registry) ValueEncoder { + c3++ + return fc3 + } + ef4 := func(*Registry) ValueEncoder { + c4++ + return fc4 + } + + ips := []struct { + i reflect.Type + ef EncoderFactory + }{ + {i: t1f, ef: ef1}, + {i: t2f, ef: ef2}, + {i: t1f, ef: ef3}, + {i: t3f, ef: ef2}, + {i: t4f, ef: ef4}, } want := []interfaceValueEncoder{ - {i: reflect.TypeOf(t1f).Elem(), ve: fc3}, - {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, - {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, + {i: t1f, ve: fc3}, {i: t2f, ve: fc2}, + {i: t3f, ve: fc2}, {i: t4f, ve: fc4}, } + rb := newTestRegistryBuilder() for _, ip := range ips { - rb.RegisterHookEncoder(ip.i, ip.ve) + rb.RegisterInterfaceEncoder(ip.i, ip.ef) } - reg := rb.Build() - got := reg.interfaceEncoders - if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) { - t.Errorf("the registered interfaces are not correct: got %#v, want %#v", got, want) + + if !cmp.Equal(c1, 0) { + t.Errorf("ef1 is called %d time(s); expected 0", c1) } - }) - t.Run("type", func(t *testing.T) { - ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{} - rb := newTestRegistryBuilder(). - RegisterTypeEncoder(reflect.TypeOf(ft1), fc1). - RegisterTypeEncoder(reflect.TypeOf(ft2), fc2). - RegisterTypeEncoder(reflect.TypeOf(ft1), fc3). - RegisterTypeEncoder(reflect.TypeOf(ft4), fc4) - want := []struct { - t reflect.Type - c ValueEncoder - }{ - {reflect.TypeOf(ft1), fc3}, - {reflect.TypeOf(ft2), fc2}, - {reflect.TypeOf(ft4), fc4}, + if !cmp.Equal(c2, 1) { + t.Errorf("ef2 is called %d time(s); expected 1", c2) } - - reg := rb.Build() - got := reg.typeEncoders - for _, s := range want { - wantT, wantC := s.t, s.c - gotC, exists := got.Load(wantT) - if !exists { - t.Errorf("Did not find type in the type registry: %v", wantT) - } - if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) { - t.Errorf("codecs did not match: got %#v; want %#v", gotC, wantC) - } + if !cmp.Equal(c3, 1) { + t.Errorf("ef3 is called %d time(s); expected 1", c3) } - }) - t.Run("kind", func(t *testing.T) { - k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map - rb := newTestRegistryBuilder(). - RegisterDefaultEncoder(k1, fc1). - RegisterDefaultEncoder(k2, fc2). - RegisterDefaultEncoder(k1, fc3). - RegisterDefaultEncoder(k4, fc4) - want := []struct { - k reflect.Kind - c ValueEncoder - }{ - {k1, fc3}, - {k2, fc2}, - {k4, fc4}, + if !cmp.Equal(c4, 1) { + t.Errorf("ef4 is called %d time(s); expected 1", c4) + } + codecs, ok := reg.codecTypeMap[reflect.TypeOf((*fakeCodec)(nil))] + if !cmp.Equal(len(reg.codecTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { + t.Errorf("codecs were not cached correctly") + } + got := make(map[reflect.Type]ValueEncoder) + for _, e := range reg.interfaceEncoders { + got[e.i] = e.ve } - - reg := rb.Build() - got := reg.kindEncoders for _, s := range want { - wantK, wantC := s.k, s.c - gotC, exists := got.Load(wantK) + wantI, wantVe := s.i, s.ve + gotVe, exists := got[wantI] if !exists { - t.Errorf("Did not find kind in the kind registry: %v", wantK) + t.Errorf("Did not find type in the type registry: %v", wantI) } - if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) { - t.Errorf("codecs did not match: got %#v; want %#v", gotC, wantC) + if !cmp.Equal(gotVe, wantVe, cmp.AllowUnexported(fakeCodec{})) { + t.Errorf("codecs did not match: got %#v; want %#v", gotVe, wantVe) } } }) - t.Run("RegisterDefault", func(t *testing.T) { - t.Run("MapCodec", func(t *testing.T) { - codec := &fakeCodec{num: 1} - codec2 := &fakeCodec{num: 2} - rb := newTestRegistryBuilder() - - rb.RegisterDefaultEncoder(reflect.Map, codec) - reg := rb.Build() - if reg.kindEncoders.get(reflect.Map) != codec { - t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec) - } - - rb.RegisterDefaultEncoder(reflect.Map, codec2) - reg = rb.Build() - if reg.kindEncoders.get(reflect.Map) != codec2 { - t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec2) - } - }) - t.Run("StructCodec", func(t *testing.T) { - codec := &fakeCodec{num: 1} - codec2 := &fakeCodec{num: 2} - rb := newTestRegistryBuilder() - - rb.RegisterDefaultEncoder(reflect.Struct, codec) - reg := rb.Build() - if reg.kindEncoders.get(reflect.Struct) != codec { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec) - } - - rb.RegisterDefaultEncoder(reflect.Struct, codec2) - reg = rb.Build() - if reg.kindEncoders.get(reflect.Struct) != codec2 { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec2) - } - }) - t.Run("SliceCodec", func(t *testing.T) { - codec := &fakeCodec{num: 1} - codec2 := &fakeCodec{num: 2} - rb := newTestRegistryBuilder() - - rb.RegisterDefaultEncoder(reflect.Slice, codec) - reg := rb.Build() - if reg.kindEncoders.get(reflect.Slice) != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec) - } - - rb.RegisterDefaultEncoder(reflect.Slice, codec2) - reg = rb.Build() - if reg.kindEncoders.get(reflect.Slice) != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec2) - } - }) - t.Run("ArrayCodec", func(t *testing.T) { - codec := &fakeCodec{num: 1} - codec2 := &fakeCodec{num: 2} - rb := newTestRegistryBuilder() + t.Run("type", func(t *testing.T) { + t.Parallel() - rb.RegisterDefaultEncoder(reflect.Array, codec) - reg := rb.Build() - if reg.kindEncoders.get(reflect.Array) != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec) - } + ft1, ft2, ft3, ft4 := + reflect.TypeOf(fakeType1{}), + reflect.TypeOf(fakeType2{}), + reflect.TypeOf(fakeType3{}), + reflect.TypeOf(fakeType4{}) - rb.RegisterDefaultEncoder(reflect.Array, codec2) - reg = rb.Build() - if reg.kindEncoders.get(reflect.Array) != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec2) - } - }) - }) - t.Run("Lookup", func(t *testing.T) { - type Codec interface { - ValueEncoder - ValueDecoder + var c1, c2, c3, c4 int + ef1 := func(*Registry) ValueEncoder { + c1++ + return fc1 + } + ef2 := func(*Registry) ValueEncoder { + c2++ + return fc2 + } + ef3 := func(*Registry) ValueEncoder { + c3++ + return fc3 + } + ef4 := func(*Registry) ValueEncoder { + c4++ + return fc4 } - var ( - arrinstance [12]int - arr = reflect.TypeOf(arrinstance) - slc = reflect.TypeOf(make([]int, 12)) - m = reflect.TypeOf(make(map[string]int)) - strct = reflect.TypeOf(struct{ Foo string }{}) - ft1 = reflect.PtrTo(reflect.TypeOf(fakeType1{})) - ft2 = reflect.TypeOf(fakeType2{}) - ft3 = reflect.TypeOf(fakeType5(func(string, string) string { return "fakeType5" })) - ti1 = reflect.TypeOf((*testInterface1)(nil)).Elem() - ti2 = reflect.TypeOf((*testInterface2)(nil)).Elem() - ti1Impl = reflect.TypeOf(testInterface1Impl{}) - ti2Impl = reflect.TypeOf(testInterface2Impl{}) - ti3 = reflect.TypeOf((*testInterface3)(nil)).Elem() - ti3Impl = reflect.TypeOf(testInterface3Impl{}) - ti3ImplPtr = reflect.TypeOf((*testInterface3Impl)(nil)) - fc1, fc2 = &fakeCodec{num: 1}, &fakeCodec{num: 2} - fsc, fslcc, fmc = new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec) - pc = NewPointerCodec() - ) - - reg := newTestRegistryBuilder(). - RegisterTypeEncoder(ft1, fc1). - RegisterTypeEncoder(ft2, fc2). - RegisterTypeEncoder(ti1, fc1). - RegisterDefaultEncoder(reflect.Struct, fsc). - RegisterDefaultEncoder(reflect.Slice, fslcc). - RegisterDefaultEncoder(reflect.Array, fslcc). - RegisterDefaultEncoder(reflect.Map, fmc). - RegisterDefaultEncoder(reflect.Ptr, pc). - RegisterTypeDecoder(ft1, fc1). - RegisterTypeDecoder(ft2, fc2). - RegisterTypeDecoder(ti1, fc1). // values whose exact type is testInterface1 will use fc1 encoder - RegisterDefaultDecoder(reflect.Struct, fsc). - RegisterDefaultDecoder(reflect.Slice, fslcc). - RegisterDefaultDecoder(reflect.Array, fslcc). - RegisterDefaultDecoder(reflect.Map, fmc). - RegisterDefaultDecoder(reflect.Ptr, pc). - RegisterHookEncoder(ti2, fc2). - RegisterHookDecoder(ti2, fc2). - RegisterHookEncoder(ti3, fc3). - RegisterHookDecoder(ti3, fc3). - Build() - - testCases := []struct { - name string - t reflect.Type - wantcodec Codec - wanterr error - testcache bool + ips := []struct { + i reflect.Type + ef EncoderFactory }{ - { - "type registry (pointer)", - ft1, - fc1, - nil, - false, - }, - { - "type registry (non-pointer)", - ft2, - fc2, - nil, - false, - }, - { - // lookup an interface type and expect that the registered encoder is returned - "interface with type encoder", - ti1, - fc1, - nil, - true, - }, - { - // lookup a type that implements an interface and expect that the default struct codec is returned - "interface implementation with type encoder", - ti1Impl, - fsc, - nil, - false, - }, - { - // lookup an interface type and expect that the registered hook is returned - "interface with hook", - ti2, - fc2, - nil, - false, - }, - { - // lookup a type that implements an interface and expect that the registered hook is returned - "interface implementation with hook", - ti2Impl, - fc2, - nil, - false, - }, - { - // lookup a pointer to a type where the pointer implements an interface and expect that the - // registered hook is returned - "interface pointer to implementation with hook (pointer)", - ti3ImplPtr, - fc3, - nil, - false, - }, - { - "default struct codec (pointer)", - reflect.PtrTo(strct), - pc, - nil, - false, - }, - { - "default struct codec (non-pointer)", - strct, - fsc, - nil, - false, - }, - { - "default array codec", - arr, - fslcc, - nil, - false, - }, - { - "default slice codec", - slc, - fslcc, - nil, - false, - }, - { - "default map", - m, - fmc, - nil, - false, - }, - { - "map non-string key", - reflect.TypeOf(map[int]int{}), - fmc, - nil, - false, - }, - { - "No Codec Registered", - ft3, - nil, - ErrNoEncoder{Type: ft3}, - false, - }, + {i: ft1, ef: ef1}, + {i: ft2, ef: ef2}, + {i: ft1, ef: ef3}, + {i: ft3, ef: ef2}, + {i: ft4, ef: ef4}, } - - allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{}) - comparepc := func(pc1, pc2 *PointerCodec) bool { return true } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Run("Encoder", func(t *testing.T) { - gotcodec, goterr := reg.LookupEncoder(tc.t) - if !cmp.Equal(goterr, tc.wanterr, cmp.Comparer(assert.CompareErrors)) { - t.Errorf("errors did not match: got %#v, want %#v", goterr, tc.wanterr) - } - if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("codecs did not match: got %#v, want %#v", gotcodec, tc.wantcodec) - } - }) - t.Run("Decoder", func(t *testing.T) { - wanterr := tc.wanterr - if ene, ok := tc.wanterr.(ErrNoEncoder); ok { - wanterr = ErrNoDecoder(ene) - } - - gotcodec, goterr := reg.LookupDecoder(tc.t) - if !cmp.Equal(goterr, wanterr, cmp.Comparer(assert.CompareErrors)) { - t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr) - } - if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("codecs did not match: got %#v, want %#v", gotcodec, tc.wantcodec) - } - }) - }) + want := []interfaceValueEncoder{ + {i: ft1, ve: fc3}, {i: ft2, ve: fc2}, + {i: ft3, ve: fc2}, {i: ft4, ve: fc4}, } - // lookup a type whose pointer implements an interface and expect that the registered hook is - // returned - t.Run("interface implementation with hook (pointer)", func(t *testing.T) { - t.Run("Encoder", func(t *testing.T) { - gotEnc, err := reg.LookupEncoder(ti3Impl) - assert.Nil(t, err, "LookupEncoder error: %v", err) - - cae, ok := gotEnc.(*condAddrEncoder) - assert.True(t, ok, "Expected CondAddrEncoder, got %T", gotEnc) - if !cmp.Equal(cae.canAddrEnc, fc3, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("expected canAddrEnc %#v, got %#v", cae.canAddrEnc, fc3) - } - if !cmp.Equal(cae.elseEnc, fsc, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("expected elseEnc %#v, got %#v", cae.elseEnc, fsc) - } - }) - t.Run("Decoder", func(t *testing.T) { - gotDec, err := reg.LookupDecoder(ti3Impl) - assert.Nil(t, err, "LookupDecoder error: %v", err) - - cad, ok := gotDec.(*condAddrDecoder) - assert.True(t, ok, "Expected CondAddrDecoder, got %T", gotDec) - if !cmp.Equal(cad.canAddrDec, fc3, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("expected canAddrDec %#v, got %#v", cad.canAddrDec, fc3) - } - if !cmp.Equal(cad.elseDec, fsc, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("expected elseDec %#v, got %#v", cad.elseDec, fsc) - } - }) - }) - }) - }) - t.Run("Type Map", func(t *testing.T) { - reg := newTestRegistryBuilder(). - RegisterTypeMapEntry(TypeString, reflect.TypeOf("")). - RegisterTypeMapEntry(TypeInt32, reflect.TypeOf(int(0))). - Build() - - var got, want reflect.Type - want = reflect.TypeOf("") - got, err := reg.LookupTypeMapEntry(TypeString) - noerr(t, err) - if got != want { - t.Errorf("unexpected type: got %#v, want %#v", got, want) - } - - want = reflect.TypeOf(int(0)) - got, err = reg.LookupTypeMapEntry(TypeInt32) - noerr(t, err) - if got != want { - t.Errorf("unexpected type: got %#v, want %#v", got, want) - } - - want = nil - wanterr := ErrNoTypeMapEntry{Type: TypeObjectID} - got, err = reg.LookupTypeMapEntry(TypeObjectID) - if !errors.Is(err, wanterr) { - t.Errorf("did not get expected error: got %#v, want %#v", err, wanterr) - } - if got != want { - t.Errorf("unexpected type: got %#v, want %#v", got, want) - } - }) -} - -func TestRegistry(t *testing.T) { - t.Parallel() - - t.Run("Register", func(t *testing.T) { - t.Parallel() - - fc1, fc2, fc3, fc4 := new(fakeCodec), new(fakeCodec), new(fakeCodec), new(fakeCodec) - t.Run("interface", func(t *testing.T) { - t.Parallel() + rb := newTestRegistryBuilder() + for _, ip := range ips { + rb.RegisterTypeEncoder(ip.i, ip.ef) + } + reg := rb.Build() - var t1f *testInterface1 - var t2f *testInterface2 - var t4f *testInterface4 - ips := []interfaceValueEncoder{ - {i: reflect.TypeOf(t1f).Elem(), ve: fc1}, - {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, - {i: reflect.TypeOf(t1f).Elem(), ve: fc3}, - {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, + if !cmp.Equal(c1, 0) { + t.Errorf("ef1 is called %d time(s); expected 0", c1) } - want := []interfaceValueEncoder{ - {i: reflect.TypeOf(t1f).Elem(), ve: fc3}, - {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, - {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, + if !cmp.Equal(c2, 1) { + t.Errorf("ef2 is called %d time(s); expected 1", c2) } - reg := newTestRegistryBuilder().Build() - for _, ip := range ips { - reg.RegisterInterfaceEncoder(ip.i, ip.ve) + if !cmp.Equal(c3, 1) { + t.Errorf("ef3 is called %d time(s); expected 1", c3) } - got := reg.interfaceEncoders - if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) { - t.Errorf("registered interfaces are not correct: got %#v, want %#v", got, want) + if !cmp.Equal(c4, 1) { + t.Errorf("ef4 is called %d time(s); expected 1", c4) } - }) - t.Run("type", func(t *testing.T) { - t.Parallel() - - ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{} - reg := newTestRegistryBuilder().Build() - reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc1) - reg.RegisterTypeEncoder(reflect.TypeOf(ft2), fc2) - reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc3) - reg.RegisterTypeEncoder(reflect.TypeOf(ft4), fc4) - - want := []struct { - t reflect.Type - c ValueEncoder - }{ - {reflect.TypeOf(ft1), fc3}, - {reflect.TypeOf(ft2), fc2}, - {reflect.TypeOf(ft4), fc4}, + codecs, ok := reg.codecTypeMap[reflect.TypeOf((*fakeCodec)(nil))] + if !cmp.Equal(len(reg.codecTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { + t.Errorf("codecs were not cached correctly") } got := reg.typeEncoders for _, s := range want { - wantT, wantC := s.t, s.c - gotC, exists := got.Load(wantT) + wantI, wantVe := s.i, s.ve + gotVe, exists := got.Load(wantI) if !exists { - t.Errorf("type missing in registry: %v", wantT) + t.Errorf("type missing in registry: %v", wantI) } - if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) { - t.Errorf("codecs did not match: got %#v; want %#v", gotC, wantC) + if !cmp.Equal(gotVe, wantVe, cmp.AllowUnexported(fakeCodec{})) { + t.Errorf("codecs did not match: got %#v; want %#v", gotVe, wantVe) } } }) t.Run("kind", func(t *testing.T) { t.Parallel() - k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map - reg := newTestRegistryBuilder().Build() - reg.RegisterKindEncoder(k1, fc1) - reg.RegisterKindEncoder(k2, fc2) - reg.RegisterKindEncoder(k1, fc3) - reg.RegisterKindEncoder(k4, fc4) + k1, k2, k3, k4 := reflect.Struct, reflect.Slice, reflect.Int, reflect.Map + var c1, c2, c3, c4 int + ef1 := func(*Registry) ValueEncoder { + c1++ + return fc1 + } + ef2 := func(*Registry) ValueEncoder { + c2++ + return fc2 + } + ef3 := func(*Registry) ValueEncoder { + c3++ + return fc3 + } + ef4 := func(*Registry) ValueEncoder { + c4++ + return fc4 + } + + ips := []struct { + k reflect.Kind + ef EncoderFactory + }{ + {k: k1, ef: ef1}, + {k: k2, ef: ef2}, + {k: k1, ef: ef3}, + {k: k3, ef: ef2}, + {k: k4, ef: ef4}, + } want := []struct { k reflect.Kind c ValueEncoder }{ - {k1, fc3}, - {k2, fc2}, - {k4, fc4}, + {k1, fc3}, {k2, fc2}, {k4, fc4}, + } + + rb := newTestRegistryBuilder() + for _, ip := range ips { + rb.RegisterKindEncoder(ip.k, ip.ef) + } + reg := rb.Build() + + if !cmp.Equal(c1, 0) { + t.Errorf("ef1 is called %d time(s); expected 0", c1) + } + if !cmp.Equal(c2, 1) { + t.Errorf("ef2 is called %d time(s); expected 1", c2) + } + if !cmp.Equal(c3, 1) { + t.Errorf("ef3 is called %d time(s); expected 1", c3) + } + if !cmp.Equal(c4, 1) { + t.Errorf("ef4 is called %d time(s); expected 1", c4) + } + codecs, ok := reg.codecTypeMap[reflect.TypeOf((*fakeCodec)(nil))] + if !cmp.Equal(len(reg.codecTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { + t.Errorf("codecs were not cached correctly") } got := reg.kindEncoders for _, s := range want { - wantK, wantC := s.k, s.c - gotC, exists := got.Load(wantK) - if !exists { - t.Errorf("type missing in registry: %v", wantK) - } - if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) { - t.Errorf("codecs did not match: got %#v, want %#v", gotC, wantC) + wantI, wantVe := s.k, s.c + gotC := got[wantI] + if !cmp.Equal(gotC, wantVe, cmp.AllowUnexported(fakeCodec{})) { + t.Errorf("codecs did not match: got %#v, want %#v", gotC, wantVe) } } }) @@ -543,14 +267,18 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistryBuilder().Build() - reg.RegisterKindEncoder(reflect.Map, codec) - if reg.kindEncoders.get(reflect.Map) != codec { - t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec) + rb := newTestRegistryBuilder() + + rb.RegisterKindEncoder(reflect.Map, func(*Registry) ValueEncoder { return codec }) + reg := rb.Build() + if got := reg.kindEncoders[reflect.Map]; got != codec { + t.Errorf("map codec not properly set: got %#v, want %#v", got, codec) } - reg.RegisterKindEncoder(reflect.Map, codec2) - if reg.kindEncoders.get(reflect.Map) != codec2 { - t.Errorf("map codec properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec2) + + rb.RegisterKindEncoder(reflect.Map, func(*Registry) ValueEncoder { return codec2 }) + reg = rb.Build() + if got := reg.kindEncoders[reflect.Map]; got != codec2 { + t.Errorf("map codec not properly set: got %#v, want %#v", got, codec2) } }) t.Run("StructCodec", func(t *testing.T) { @@ -558,14 +286,18 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistryBuilder().Build() - reg.RegisterKindEncoder(reflect.Struct, codec) - if reg.kindEncoders.get(reflect.Struct) != codec { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec) + rb := newTestRegistryBuilder() + + rb.RegisterKindEncoder(reflect.Struct, func(*Registry) ValueEncoder { return codec }) + reg := rb.Build() + if got := reg.kindEncoders[reflect.Struct]; got != codec { + t.Errorf("struct codec not properly set: got %#v, want %#v", got, codec) } - reg.RegisterKindEncoder(reflect.Struct, codec2) - if reg.kindEncoders.get(reflect.Struct) != codec2 { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec2) + + rb.RegisterKindEncoder(reflect.Struct, func(*Registry) ValueEncoder { return codec2 }) + reg = rb.Build() + if got := reg.kindEncoders[reflect.Struct]; got != codec2 { + t.Errorf("struct codec not properly set: got %#v, want %#v", got, codec2) } }) t.Run("SliceCodec", func(t *testing.T) { @@ -573,14 +305,18 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistryBuilder().Build() - reg.RegisterKindEncoder(reflect.Slice, codec) - if reg.kindEncoders.get(reflect.Slice) != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec) + rb := newTestRegistryBuilder() + + rb.RegisterKindEncoder(reflect.Slice, func(*Registry) ValueEncoder { return codec }) + reg := rb.Build() + if got := reg.kindEncoders[reflect.Slice]; got != codec { + t.Errorf("slice codec not properly set: got %#v, want %#v", got, codec) } - reg.RegisterKindEncoder(reflect.Slice, codec2) - if reg.kindEncoders.get(reflect.Slice) != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec2) + + rb.RegisterKindEncoder(reflect.Slice, func(*Registry) ValueEncoder { return codec2 }) + reg = rb.Build() + if got := reg.kindEncoders[reflect.Slice]; got != codec2 { + t.Errorf("slice codec not properly set: got %#v, want %#v", got, codec2) } }) t.Run("ArrayCodec", func(t *testing.T) { @@ -588,14 +324,18 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistryBuilder().Build() - reg.RegisterKindEncoder(reflect.Array, codec) - if reg.kindEncoders.get(reflect.Array) != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec) + rb := newTestRegistryBuilder() + + rb.RegisterKindEncoder(reflect.Array, func(*Registry) ValueEncoder { return codec }) + reg := rb.Build() + if got := reg.kindEncoders[reflect.Array]; got != codec { + t.Errorf("slice codec not properly set: got %#v, want %#v", got, codec) } - reg.RegisterKindEncoder(reflect.Array, codec2) - if reg.kindEncoders.get(reflect.Array) != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec2) + + rb.RegisterKindEncoder(reflect.Array, func(*Registry) ValueEncoder { return codec2 }) + reg = rb.Build() + if got := reg.kindEncoders[reflect.Array]; got != codec2 { + t.Errorf("slice codec not properly set: got %#v, want %#v", got, codec2) } }) }) @@ -625,30 +365,47 @@ func TestRegistry(t *testing.T) { ti3ImplPtr = reflect.TypeOf((*testInterface3Impl)(nil)) fc1, fc2 = &fakeCodec{num: 1}, &fakeCodec{num: 2} fsc, fslcc, fmc = new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec) - pc = NewPointerCodec() + pc = &pointerCodec{} ) - reg := newTestRegistryBuilder().Build() - reg.RegisterTypeEncoder(ft1, fc1) - reg.RegisterTypeEncoder(ft2, fc2) - reg.RegisterTypeEncoder(ti1, fc1) - reg.RegisterKindEncoder(reflect.Struct, fsc) - reg.RegisterKindEncoder(reflect.Slice, fslcc) - reg.RegisterKindEncoder(reflect.Array, fslcc) - reg.RegisterKindEncoder(reflect.Map, fmc) - reg.RegisterKindEncoder(reflect.Ptr, pc) - reg.RegisterTypeDecoder(ft1, fc1) - reg.RegisterTypeDecoder(ft2, fc2) - reg.RegisterTypeDecoder(ti1, fc1) // values whose exact type is testInterface1 will use fc1 encoder - reg.RegisterKindDecoder(reflect.Struct, fsc) - reg.RegisterKindDecoder(reflect.Slice, fslcc) - reg.RegisterKindDecoder(reflect.Array, fslcc) - reg.RegisterKindDecoder(reflect.Map, fmc) - reg.RegisterKindDecoder(reflect.Ptr, pc) - reg.RegisterInterfaceEncoder(ti2, fc2) - reg.RegisterInterfaceDecoder(ti2, fc2) - reg.RegisterInterfaceEncoder(ti3, fc3) - reg.RegisterInterfaceDecoder(ti3, fc3) + fc1EncFac := func(*Registry) ValueEncoder { return fc1 } + fc2EncFac := func(*Registry) ValueEncoder { return fc2 } + fc3EncFac := func(*Registry) ValueEncoder { return fc3 } + fscEncFac := func(*Registry) ValueEncoder { return fsc } + fslccEncFac := func(*Registry) ValueEncoder { return fslcc } + fmcEncFac := func(*Registry) ValueEncoder { return fmc } + pcEncFac := func(*Registry) ValueEncoder { return pc } + + fc1DecFac := func(*Registry) ValueDecoder { return fc1 } + fc2DecFac := func(*Registry) ValueDecoder { return fc2 } + fc3DecFac := func(*Registry) ValueDecoder { return fc3 } + fscDecFac := func(*Registry) ValueDecoder { return fsc } + fslccDecFac := func(*Registry) ValueDecoder { return fslcc } + fmcDecFac := func(*Registry) ValueDecoder { return fmc } + pcDecFac := func(*Registry) ValueDecoder { return pc } + + reg := newTestRegistryBuilder(). + RegisterTypeEncoder(ft1, fc1EncFac). + RegisterTypeEncoder(ft2, fc2EncFac). + RegisterTypeEncoder(ti1, fc1EncFac). + RegisterKindEncoder(reflect.Struct, fscEncFac). + RegisterKindEncoder(reflect.Slice, fslccEncFac). + RegisterKindEncoder(reflect.Array, fslccEncFac). + RegisterKindEncoder(reflect.Map, fmcEncFac). + RegisterKindEncoder(reflect.Ptr, pcEncFac). + RegisterTypeDecoder(ft1, fc1DecFac). + RegisterTypeDecoder(ft2, fc2DecFac). + RegisterTypeDecoder(ti1, fc1DecFac). // values whose exact type is testInterface1 will use fc1 encoder + RegisterKindDecoder(reflect.Struct, fscDecFac). + RegisterKindDecoder(reflect.Slice, fslccDecFac). + RegisterKindDecoder(reflect.Array, fslccDecFac). + RegisterKindDecoder(reflect.Map, fmcDecFac). + RegisterKindDecoder(reflect.Ptr, pcDecFac). + RegisterInterfaceEncoder(ti2, fc2EncFac). + RegisterInterfaceEncoder(ti3, fc3EncFac). + RegisterInterfaceDecoder(ti2, fc2DecFac). + RegisterInterfaceDecoder(ti3, fc3DecFac). + Build() testCases := []struct { name string @@ -764,7 +521,7 @@ func TestRegistry(t *testing.T) { } allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{}) - comparepc := func(pc1, pc2 *PointerCodec) bool { return true } + comparepc := func(pc1, pc2 *pointerCodec) bool { return true } for _, tc := range testCases { tc := tc @@ -819,7 +576,7 @@ func TestRegistry(t *testing.T) { t.Run("Decoder", func(t *testing.T) { t.Parallel() - wanterr := ErrNilType + wanterr := ErrNoDecoder{Type: nil} gotcodec, goterr := reg.LookupDecoder(nil) if !cmp.Equal(goterr, wanterr, cmp.Comparer(assert.CompareErrors)) { @@ -869,9 +626,10 @@ func TestRegistry(t *testing.T) { }) t.Run("Type Map", func(t *testing.T) { t.Parallel() - reg := newTestRegistryBuilder().Build() - reg.RegisterTypeMapEntry(TypeString, reflect.TypeOf("")) - reg.RegisterTypeMapEntry(TypeInt32, reflect.TypeOf(int(0))) + reg := newTestRegistryBuilder(). + RegisterTypeMapEntry(TypeString, reflect.TypeOf("")). + RegisterTypeMapEntry(TypeInt32, reflect.TypeOf(int(0))). + Build() var got, want reflect.Type @@ -901,12 +659,6 @@ func TestRegistry(t *testing.T) { }) } -// get is only for testing as it does return if the value was found -func (c *kindEncoderCache) get(rt reflect.Kind) ValueEncoder { - e, _ := c.Load(rt) - return e -} - func BenchmarkLookupEncoder(b *testing.B) { type childStruct struct { V1, V2, V3, V4 int @@ -923,10 +675,11 @@ func BenchmarkLookupEncoder(b *testing.B) { reflect.TypeOf(&testInterface1Impl{}), reflect.TypeOf(&nestedStruct{}), } - r := NewRegistry() + rb := NewRegistryBuilder() for _, typ := range types { - r.RegisterTypeEncoder(typ, &fakeCodec{}) + rb.RegisterTypeEncoder(typ, func(*Registry) ValueEncoder { return &fakeCodec{} }) } + r := rb.Build() b.Run("Serial", func(b *testing.B) { for i := 0; i < b.N; i++ { _, err := r.LookupEncoder(types[i%len(types)]) @@ -949,6 +702,7 @@ func BenchmarkLookupEncoder(b *testing.B) { type fakeType1 struct{} type fakeType2 struct{} +type fakeType3 struct{} type fakeType4 struct{} type fakeType5 func(string, string) string type fakeStructCodec struct{ *fakeCodec } @@ -963,10 +717,10 @@ type fakeCodec struct { num int } -func (*fakeCodec) EncodeValue(EncodeContext, ValueWriter, reflect.Value) error { +func (*fakeCodec) EncodeValue(EncoderRegistry, ValueWriter, reflect.Value) error { return nil } -func (*fakeCodec) DecodeValue(DecodeContext, ValueReader, reflect.Value) error { +func (*fakeCodec) DecodeValue(DecoderRegistry, ValueReader, reflect.Value) error { return nil } @@ -992,5 +746,3 @@ type testInterface3Impl struct{} var _ testInterface3 = (*testInterface3Impl)(nil) func (*testInterface3Impl) test3() {} - -func typeComparer(i1, i2 reflect.Type) bool { return i1 == i2 } diff --git a/bson/mgocompat/setter_getter.go b/bson/setter_getter.go similarity index 66% rename from bson/mgocompat/setter_getter.go rename to bson/setter_getter.go index fc620fbba8..5d08b40c42 100644 --- a/bson/mgocompat/setter_getter.go +++ b/bson/setter_getter.go @@ -4,13 +4,11 @@ // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -package mgocompat +package bson import ( "errors" "reflect" - - "go.mongodb.org/mongo-driver/bson" ) // Setter interface: a value implementing the bson.Setter interface will receive the BSON @@ -34,7 +32,7 @@ import ( // return raw.Unmarshal(s) // } type Setter interface { - SetBSON(raw bson.RawValue) error + SetBSON(raw RawValue) error } // Getter interface: a value implementing the bson.Getter interface will have its GetBSON @@ -48,35 +46,35 @@ type Getter interface { } // SetterDecodeValue is the ValueDecoderFunc for Setter types. -func SetterDecodeValue(_ bson.DecodeContext, vr bson.ValueReader, val reflect.Value) error { +func SetterDecodeValue(_ DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.IsValid() || (!val.Type().Implements(tSetter) && !reflect.PtrTo(val.Type()).Implements(tSetter)) { - return bson.ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} + return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} } if val.Kind() == reflect.Ptr && val.IsNil() { if !val.CanSet() { - return bson.ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} + return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} } val.Set(reflect.New(val.Type().Elem())) } if !val.Type().Implements(tSetter) { if !val.CanAddr() { - return bson.ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tSetter}, Received: val} + return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tSetter}, Received: val} } val = val.Addr() // If the type doesn't implement the interface, a pointer to it must. } - t, src, err := bson.CopyValueToBytes(vr) + t, src, err := CopyValueToBytes(vr) if err != nil { return err } m, ok := val.Interface().(Setter) if !ok { - return bson.ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} + return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} } - if err := m.SetBSON(bson.RawValue{Type: t, Value: src}); err != nil { + if err := m.SetBSON(RawValue{Type: t, Value: src}); err != nil { if !errors.Is(err, ErrSetZero) { return err } @@ -86,11 +84,11 @@ func SetterDecodeValue(_ bson.DecodeContext, vr bson.ValueReader, val reflect.Va } // GetterEncodeValue is the ValueEncoderFunc for Getter types. -func GetterEncodeValue(ec bson.EncodeContext, vw bson.ValueWriter, val reflect.Value) error { +func GetterEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Getter switch { case !val.IsValid(): - return bson.ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val} + return ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val} case val.Type().Implements(tGetter): // If Getter is implemented on a concrete type, make sure that val isn't a nil pointer if isImplementationNil(val, tGetter) { @@ -99,7 +97,7 @@ func GetterEncodeValue(ec bson.EncodeContext, vw bson.ValueWriter, val reflect.V case reflect.PtrTo(val.Type()).Implements(tGetter) && val.CanAddr(): val = val.Addr() default: - return bson.ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val} + return ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val} } m, ok := val.Interface().(Getter) @@ -114,18 +112,9 @@ func GetterEncodeValue(ec bson.EncodeContext, vw bson.ValueWriter, val reflect.V return vw.WriteNull() } vv := reflect.ValueOf(x) - encoder, err := ec.Registry.LookupEncoder(vv.Type()) + encoder, err := reg.LookupEncoder(vv.Type()) if err != nil { return err } - return encoder.EncodeValue(ec, vw, vv) -} - -// isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type -func isImplementationNil(val reflect.Value, inter reflect.Type) bool { - vt := val.Type() - for vt.Kind() == reflect.Ptr { - vt = vt.Elem() - } - return vt.Implements(inter) && val.Kind() == reflect.Ptr && val.IsNil() + return encoder.EncodeValue(reg, vw, vv) } diff --git a/bson/slice_codec.go b/bson/slice_codec.go index 52449239b9..f08f8100d6 100644 --- a/bson/slice_codec.go +++ b/bson/slice_codec.go @@ -10,45 +10,22 @@ import ( "errors" "fmt" "reflect" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) -var defaultSliceCodec = NewSliceCodec() - -// SliceCodec is the Codec used for slice values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// SliceCodec registered. -type SliceCodec struct { - // EncodeNilAsEmpty causes EncodeValue to marshal nil Go slices as empty BSON arrays instead of +// sliceCodec is the Codec used for slice values. +type sliceCodec struct { + // encodeNilAsEmpty causes EncodeValue to marshal nil Go slices as empty BSON arrays instead of // BSON null. - // - // Deprecated: Use bson.Encoder.NilSliceAsEmpty instead. - EncodeNilAsEmpty bool -} - -// NewSliceCodec returns a MapCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// SliceCodec registered. -func NewSliceCodec(opts ...*bsonoptions.SliceCodecOptions) *SliceCodec { - sliceOpt := bsonoptions.MergeSliceCodecOptions(opts...) - - codec := SliceCodec{} - if sliceOpt.EncodeNilAsEmpty != nil { - codec.EncodeNilAsEmpty = *sliceOpt.EncodeNilAsEmpty - } - return &codec + encodeNilAsEmpty bool } // EncodeValue is the ValueEncoder for slice types. -func (sc SliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (sc *sliceCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Slice { return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } - if val.IsNil() && !sc.EncodeNilAsEmpty && !ec.nilSliceAsEmpty { + if val.IsNil() && !sc.encodeNilAsEmpty { return vw.WriteNull() } @@ -69,7 +46,7 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.V } for _, e := range d { - err = encodeElement(ec, dw, e) + err = encodeElement(reg, dw, e) if err != nil { return err } @@ -84,13 +61,13 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.V } elemType := val.Type().Elem() - encoder, err := ec.LookupEncoder(elemType) + encoder, err := reg.LookupEncoder(elemType) if err != nil && elemType.Kind() != reflect.Interface { return err } for idx := 0; idx < val.Len(); idx++ { - currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.Index(idx)) + currEncoder, currVal, lookupErr := lookupElementEncoder(reg, encoder, val.Index(idx)) if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } @@ -108,7 +85,7 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.V continue } - err = currEncoder.EncodeValue(ec, vw, currVal) + err = currEncoder.EncodeValue(reg, vw, currVal) if err != nil { return err } @@ -117,7 +94,7 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.V } // DecodeValue is the ValueDecoder for slice types. -func (sc *SliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (sc *sliceCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Kind() != reflect.Slice { return ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } @@ -172,16 +149,15 @@ func (sc *SliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect. return fmt.Errorf("cannot decode %v into a slice", vrType) } - var elemsFunc func(DecodeContext, ValueReader, reflect.Value) ([]reflect.Value, error) + var elemsFunc func(DecoderRegistry, ValueReader, reflect.Value) ([]reflect.Value, error) switch val.Type().Elem() { case tE: - dc.Ancestor = val.Type() - elemsFunc = defaultValueDecoders.decodeD + elemsFunc = decodeD default: - elemsFunc = defaultValueDecoders.decodeDefault + elemsFunc = decodeDefault } - elems, err := elemsFunc(dc, vr, val) + elems, err := elemsFunc(reg, vr, val) if err != nil { return err } diff --git a/bson/string_codec.go b/bson/string_codec.go index 50fb9229fe..55b0fd9c62 100644 --- a/bson/string_codec.go +++ b/bson/string_codec.go @@ -7,44 +7,20 @@ package bson import ( + "errors" "fmt" "reflect" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) -// StringCodec is the Codec used for string values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// StringCodec registered. -type StringCodec struct { - // DecodeObjectIDAsHex specifies if object IDs should be decoded as their hex representation. +// stringCodec is the Codec used for string values. +type stringCodec struct { + // decodeObjectIDAsHex specifies if object IDs should be decoded as their hex representation. // If false, a string made from the raw object ID bytes will be used. Defaults to true. - // - // Deprecated: Decoding object IDs as raw bytes will not be supported in Go Driver 2.0. - DecodeObjectIDAsHex bool -} - -var ( - defaultStringCodec = NewStringCodec() - - // Assert that defaultStringCodec satisfies the typeDecoder interface, which allows it to be - // used by collection type decoders (e.g. map, slice, etc) to set individual values in a - // collection. - _ typeDecoder = defaultStringCodec -) - -// NewStringCodec returns a StringCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// StringCodec registered. -func NewStringCodec(opts ...*bsonoptions.StringCodecOptions) *StringCodec { - stringOpt := bsonoptions.MergeStringCodecOptions(opts...) - return &StringCodec{*stringOpt.DecodeObjectIDAsHex} + decodeObjectIDAsHex bool } // EncodeValue is the ValueEncoder for string types. -func (sc *StringCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func (sc *stringCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.String { return ValueEncoderError{ Name: "StringEncodeValue", @@ -56,7 +32,7 @@ func (sc *StringCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect. return vw.WriteString(val.String()) } -func (sc *StringCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (sc *stringCodec) decodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t.Kind() != reflect.String { return emptyValue, ValueDecoderError{ Name: "StringDecodeValue", @@ -78,13 +54,10 @@ func (sc *StringCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Typ if err != nil { return emptyValue, err } - if sc.DecodeObjectIDAsHex { - str = oid.Hex() - } else { - // TODO(GODRIVER-2796): Return an error here instead of decoding to a garbled string. - byteArray := [12]byte(oid) - str = string(byteArray[:]) + if !sc.decodeObjectIDAsHex { + return emptyValue, errors.New("cannot decode ObjectID as string if DecodeObjectIDAsHex is not set") } + str = oid.Hex() case TypeSymbol: str, err = vr.ReadSymbol() if err != nil { @@ -115,12 +88,12 @@ func (sc *StringCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Typ } // DecodeValue is the ValueDecoder for string types. -func (sc *StringCodec) DecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +func (sc *stringCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Kind() != reflect.String { return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} } - elem, err := sc.decodeType(dctx, vr, val.Type()) + elem, err := sc.decodeType(reg, vr, val.Type()) if err != nil { return err } diff --git a/bson/string_codec_test.go b/bson/string_codec_test.go index 75ace60c5d..c764af97dc 100644 --- a/bson/string_codec_test.go +++ b/bson/string_codec_test.go @@ -7,35 +7,36 @@ package bson import ( + "errors" "reflect" "testing" - "go.mongodb.org/mongo-driver/bson/bsonoptions" "go.mongodb.org/mongo-driver/internal/assert" ) func TestStringCodec(t *testing.T) { t.Run("ObjectIDAsHex", func(t *testing.T) { oid := NewObjectID() - byteArray := [12]byte(oid) reader := &valueReaderWriter{BSONType: TypeObjectID, Return: oid} testCases := []struct { name string - opts *bsonoptions.StringCodecOptions - hex bool + codec *stringCodec + err error result string }{ - {"default", bsonoptions.StringCodec(), true, oid.Hex()}, - {"true", bsonoptions.StringCodec().SetDecodeObjectIDAsHex(true), true, oid.Hex()}, - {"false", bsonoptions.StringCodec().SetDecodeObjectIDAsHex(false), false, string(byteArray[:])}, + {"default", &stringCodec{}, errors.New("cannot decode ObjectID as string if DecodeObjectIDAsHex is not set"), ""}, + {"true", &stringCodec{decodeObjectIDAsHex: true}, nil, oid.Hex()}, + {"false", &stringCodec{decodeObjectIDAsHex: false}, errors.New("cannot decode ObjectID as string if DecodeObjectIDAsHex is not set"), ""}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - stringCodec := NewStringCodec(tc.opts) - actual := reflect.New(reflect.TypeOf("")).Elem() - err := stringCodec.DecodeValue(DecodeContext{}, reader, actual) - assert.Nil(t, err, "StringCodec.DecodeValue error: %v", err) + err := tc.codec.DecodeValue(nil, reader, actual) + if tc.err == nil { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, tc.err.Error()) + } actualString := actual.Interface().(string) assert.Equal(t, tc.result, actualString, "Expected string %v, got %v", tc.result, actualString) diff --git a/bson/struct_codec.go b/bson/struct_codec.go index 917ac17bfd..c9b5306c33 100644 --- a/bson/struct_codec.go +++ b/bson/struct_codec.go @@ -14,8 +14,6 @@ import ( "strings" "sync" "time" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) // DecodeError represents an error that occurs when unmarshalling BSON bytes into a native Go type. @@ -49,91 +47,97 @@ func (de *DecodeError) Keys() []string { return reversedKeys } -// StructCodec is the Codec used for struct values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// StructCodec registered. -type StructCodec struct { - cache sync.Map // map[reflect.Type]*structDescription - parser StructTagParser +// mapElementsEncoder handles encoding of the values of an inline map. +type mapElementsEncoder interface { + encodeMapElements(EncoderRegistry, DocumentWriter, reflect.Value, func(string) bool) error +} - // DecodeZeroStruct causes DecodeValue to delete any existing values from Go structs in the - // destination value passed to Decode before unmarshaling BSON documents into them. - // - // Deprecated: Use bson.Decoder.ZeroStructs instead. - DecodeZeroStruct bool +// structCodec is the Codec used for struct values. +type structCodec struct { + cache sync.Map // map[reflect.Type]*structDescription + elemEncoder mapElementsEncoder + + // decodeZeroStruct causes DecodeValue to delete any existing values from Go structs in the + decodeZeroStruct bool - // DecodeDeepZeroInline causes DecodeValue to delete any existing values from Go structs in the + // decodeDeepZeroInline causes DecodeValue to delete any existing values from Go structs in the // destination value passed to Decode before unmarshaling BSON documents into them. - // - // Deprecated: DecodeDeepZeroInline will not be supported in Go Driver 2.0. - DecodeDeepZeroInline bool + decodeDeepZeroInline bool - // EncodeOmitDefaultStruct causes the Encoder to consider the zero value for a struct (e.g. + // encodeOmitDefaultStruct causes the Encoder to consider the zero value for a struct (e.g. // MyStruct{}) as empty and omit it from the marshaled BSON when the "omitempty" struct tag // option is set. - // - // Deprecated: Use bson.Encoder.OmitZeroStruct instead. - EncodeOmitDefaultStruct bool + encodeOmitDefaultStruct bool - // AllowUnexportedFields allows encoding and decoding values from un-exported struct fields. - // - // Deprecated: AllowUnexportedFields does not work on recent versions of Go and will not be - // supported in Go Driver 2.0. - AllowUnexportedFields bool + // allowUnexportedFields allows encoding and decoding values from un-exported struct fields. + allowUnexportedFields bool - // OverwriteDuplicatedInlinedFields, if false, causes EncodeValue to return an error if there is + // overwriteDuplicatedInlinedFields, if false, causes EncodeValue to return an error if there is // a duplicate field in the marshaled BSON when the "inline" struct tag option is set. The // default value is true. - // - // Deprecated: Use bson.Encoder.ErrorOnInlineDuplicates instead. - OverwriteDuplicatedInlinedFields bool -} + overwriteDuplicatedInlinedFields bool -var _ ValueEncoder = &StructCodec{} -var _ ValueDecoder = &StructCodec{} + useJSONStructTags bool +} -// NewStructCodec returns a StructCodec that uses p for struct tag parsing. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// StructCodec registered. -func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) (*StructCodec, error) { - if p == nil { - return nil, errors.New("a StructTagParser must be provided to NewStructCodec") +// newStructCodec returns a StructCodec that uses p for struct tag parsing. +func newStructCodec(elemEncoder mapElementsEncoder) *structCodec { + return &structCodec{ + elemEncoder: elemEncoder, + overwriteDuplicatedInlinedFields: true, } +} - structOpt := bsonoptions.MergeStructCodecOptions(opts...) +type localEncoderRegistry struct { + registry EncoderRegistry + encoderLookup func(EncoderRegistry, reflect.Type) (ValueEncoder, error) +} - codec := &StructCodec{ - parser: p, - } +func (r *localEncoderRegistry) LookupEncoder(t reflect.Type) (ValueEncoder, error) { + return r.encoderLookup(r.registry, t) +} - if structOpt.DecodeZeroStruct != nil { - codec.DecodeZeroStruct = *structOpt.DecodeZeroStruct - } - if structOpt.DecodeDeepZeroInline != nil { - codec.DecodeDeepZeroInline = *structOpt.DecodeDeepZeroInline +func onMinSize(reg EncoderRegistry, t reflect.Type) (ValueEncoder, error) { + enc, err := reg.LookupEncoder(t) + if err != nil { + return enc, err } - if structOpt.EncodeOmitDefaultStruct != nil { - codec.EncodeOmitDefaultStruct = *structOpt.EncodeOmitDefaultStruct + switch t.Kind() { + case reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: + if codec, ok := enc.(*numCodec); ok { + c := *codec + c.minSize = true + return &c, nil + } } - if structOpt.OverwriteDuplicatedInlinedFields != nil { - codec.OverwriteDuplicatedInlinedFields = *structOpt.OverwriteDuplicatedInlinedFields + return enc, nil +} + +func onTruncate(reg EncoderRegistry, t reflect.Type) (ValueEncoder, error) { + enc, err := reg.LookupEncoder(t) + if err != nil { + return enc, err } - if structOpt.AllowUnexportedFields != nil { - codec.AllowUnexportedFields = *structOpt.AllowUnexportedFields + switch t.Kind() { + case reflect.Float32, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + if codec, ok := enc.(*numCodec); ok { + c := *codec + c.truncate = true + return &c, nil + } } - - return codec, nil + return enc, nil } // EncodeValue handles encoding generic struct types. -func (sc *StructCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (sc *structCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Struct { - return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} + return ValueEncoderError{Name: "structCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} } - sd, err := sc.describeStruct(ec.Registry, val.Type(), ec.useJSONStructTags, ec.errorOnInlineDuplicates) + sd, err := sc.describeStruct(val.Type(), sc.useJSONStructTags, !sc.overwriteDuplicatedInlinedFields) if err != nil { return err } @@ -153,7 +157,17 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect } } - desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(ec, desc.encoder, rv) + reg = &localEncoderRegistry{ + registry: reg, + encoderLookup: desc.encoderLookup, + } + + var encoder ValueEncoder + if encoder, err = reg.LookupEncoder(desc.fieldType); err != nil { + encoder = nil + } + + encoder, rv, err = lookupElementEncoder(reg, encoder, rv) if err != nil && !errors.Is(err, errInvalidValue) { return err @@ -174,12 +188,10 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect continue } - if desc.encoder == nil { + if encoder == nil { return ErrNoEncoder{Type: rv.Type()} } - encoder := desc.encoder - var empty bool if cz, ok := encoder.(CodecZeroer); ok { empty = cz.IsTypeZero(rv.Interface()) @@ -188,7 +200,7 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect // nil interface separately. empty = rv.IsNil() } else { - empty = isEmpty(rv, sc.EncodeOmitDefaultStruct || ec.omitZeroStruct) + empty = isEmpty(rv, sc.encodeOmitDefaultStruct) } if desc.omitEmpty && empty { continue @@ -199,18 +211,7 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect return err } - ectx := EncodeContext{ - Registry: ec.Registry, - MinSize: desc.minSize || ec.MinSize, - errorOnInlineDuplicates: ec.errorOnInlineDuplicates, - stringifyMapKeysWithFmt: ec.stringifyMapKeysWithFmt, - nilMapAsEmpty: ec.nilMapAsEmpty, - nilSliceAsEmpty: ec.nilSliceAsEmpty, - nilByteSliceAsEmpty: ec.nilByteSliceAsEmpty, - omitZeroStruct: ec.omitZeroStruct, - useJSONStructTags: ec.useJSONStructTags, - } - err = encoder.EncodeValue(ectx, vw2, rv) + err = encoder.EncodeValue(reg, vw2, rv) if err != nil { return err } @@ -223,7 +224,10 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect return exists } - return defaultMapCodec.mapEncodeValue(ec, dw, rv, collisionFn) + err = sc.elemEncoder.encodeMapElements(reg, dw, rv, collisionFn) + if err != nil { + return err + } } return dw.WriteDocumentEnd() @@ -245,9 +249,9 @@ func newDecodeError(key string, original error) error { // DecodeValue implements the Codec interface. // By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr. // For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared. -func (sc *StructCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (sc *structCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Kind() != reflect.Struct { - return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} + return ValueDecoderError{Name: "structCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} } switch vrType := vr.Type(); vrType { @@ -270,15 +274,15 @@ func (sc *StructCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type()) } - sd, err := sc.describeStruct(dc.Registry, val.Type(), dc.useJSONStructTags, false) + sd, err := sc.describeStruct(val.Type(), sc.useJSONStructTags, false) if err != nil { return err } - if sc.DecodeZeroStruct || dc.zeroStructs { + if sc.decodeZeroStruct { val.Set(reflect.Zero(val.Type())) } - if sc.DecodeDeepZeroInline && sd.inline { + if sc.decodeDeepZeroInline && sd.inline { val.Set(deepZero(val.Type())) } @@ -286,7 +290,7 @@ func (sc *StructCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect var inlineMap reflect.Value if sd.inlineMap >= 0 { inlineMap = val.Field(sd.inlineMap) - decoder, err = dc.LookupDecoder(inlineMap.Type().Elem()) + decoder, err = reg.LookupDecoder(inlineMap.Type().Elem()) if err != nil { return err } @@ -325,13 +329,22 @@ func (sc *StructCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect continue } + inlineT := inlineMap.Type() + if inlineMap.IsNil() { - inlineMap.Set(reflect.MakeMap(inlineMap.Type())) + inlineMap.Set(reflect.MakeMap(inlineT)) } - elem := reflect.New(inlineMap.Type().Elem()).Elem() - dc.Ancestor = inlineMap.Type() - err = decoder.DecodeValue(dc, vr, elem) + var elem reflect.Value + if elemT := inlineT.Elem(); elemT == tEmpty { + elem, err = decodeTypeOrValueWithInfo(decoder, reg, vr, inlineT) + if elem.Type() != elemT { + elem = elem.Convert(elemT) + } + } else { + elem = reflect.New(elemT).Elem() + err = decoder.DecodeValue(reg, vr, elem) + } if err != nil { return err } @@ -358,22 +371,12 @@ func (sc *StructCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect } field = field.Addr() - dctx := DecodeContext{ - Registry: dc.Registry, - Truncate: fd.truncate || dc.Truncate, - defaultDocumentType: dc.defaultDocumentType, - binaryAsSlice: dc.binaryAsSlice, - useJSONStructTags: dc.useJSONStructTags, - useLocalTimeZone: dc.useLocalTimeZone, - zeroMaps: dc.zeroMaps, - zeroStructs: dc.zeroStructs, - } - - if fd.decoder == nil { + decoder, err := reg.LookupDecoder(fd.fieldType) + if err != nil { return newDecodeError(fd.name, ErrNoDecoder{Type: field.Elem().Type()}) } - err = fd.decoder.DecodeValue(dctx, vr, field.Elem()) + err = decoder.DecodeValue(reg, vr, field.Elem()) if err != nil { return newDecodeError(fd.name, err) } @@ -421,15 +424,13 @@ type structDescription struct { } type fieldDescription struct { - name string // BSON key name - fieldName string // struct field name - idx int - omitEmpty bool - minSize bool - truncate bool - inline []int - encoder ValueEncoder - decoder ValueDecoder + name string // BSON key name + fieldName string // struct field name + idx int + inline []int + omitEmpty bool + fieldType reflect.Type + encoderLookup func(EncoderRegistry, reflect.Type) (ValueEncoder, error) } type byIndex []fieldDescription @@ -461,8 +462,7 @@ func (bi byIndex) Less(i, j int) bool { return len(bi[i].inline) < len(bi[j].inline) } -func (sc *StructCodec) describeStruct( - r *Registry, +func (sc *structCodec) describeStruct( t reflect.Type, useJSONStructTags bool, errorOnDuplicates bool, @@ -474,7 +474,7 @@ func (sc *StructCodec) describeStruct( } // TODO(charlie): Only describe the struct once when called // concurrently with the same type. - ds, err := sc.describeStructSlow(r, t, useJSONStructTags, errorOnDuplicates) + ds, err := sc.describeStructSlow(t, useJSONStructTags, errorOnDuplicates) if err != nil { return nil, err } @@ -484,8 +484,7 @@ func (sc *StructCodec) describeStruct( return ds, nil } -func (sc *StructCodec) describeStructSlow( - r *Registry, +func (sc *structCodec) describeStructSlow( t reflect.Type, useJSONStructTags bool, errorOnDuplicates bool, @@ -500,35 +499,27 @@ func (sc *StructCodec) describeStructSlow( var fields []fieldDescription for i := 0; i < numFields; i++ { sf := t.Field(i) - if sf.PkgPath != "" && (!sc.AllowUnexportedFields || !sf.Anonymous) { + if sf.PkgPath != "" && (!sc.allowUnexportedFields || !sf.Anonymous) { // field is private or unexported fields aren't allowed, ignore continue } sfType := sf.Type - encoder, err := r.LookupEncoder(sfType) - if err != nil { - encoder = nil - } - decoder, err := r.LookupDecoder(sfType) - if err != nil { - decoder = nil - } description := fieldDescription{ fieldName: sf.Name, idx: i, - encoder: encoder, - decoder: decoder, + fieldType: sfType, } - var stags StructTags + var stags *structTags + var err error // If the caller requested that we use JSON struct tags, use the JSONFallbackStructTagParser // instead of the parser defined on the codec. if useJSONStructTags { - stags, err = JSONFallbackStructTagParser.ParseStructTags(sf) + stags, err = parseJSONStructTags(sf) } else { - stags, err = sc.parser.ParseStructTags(sf) + stags, err = parseStructTags(sf) } if err != nil { return nil, err @@ -538,8 +529,21 @@ func (sc *StructCodec) describeStructSlow( } description.name = stags.Name description.omitEmpty = stags.OmitEmpty - description.minSize = stags.MinSize - description.truncate = stags.Truncate + description.encoderLookup = func(reg EncoderRegistry, t reflect.Type) (ValueEncoder, error) { + if stags.MinSize { + reg = &localEncoderRegistry{ + registry: reg, + encoderLookup: onMinSize, + } + } + if stags.Truncate { + reg = &localEncoderRegistry{ + registry: reg, + encoderLookup: onTruncate, + } + } + return reg.LookupEncoder(t) + } if stags.Inline { sd.inline = true @@ -559,7 +563,7 @@ func (sc *StructCodec) describeStructSlow( } fallthrough case reflect.Struct: - inlinesf, err := sc.describeStruct(r, sfType, useJSONStructTags, errorOnDuplicates) + inlinesf, err := sc.describeStruct(sfType, useJSONStructTags, errorOnDuplicates) if err != nil { return nil, err } @@ -611,7 +615,7 @@ func (sc *StructCodec) describeStructSlow( continue } dominant, ok := dominantField(fields[i : i+advance]) - if !ok || !sc.OverwriteDuplicatedInlinedFields || errorOnDuplicates { + if !ok || !sc.overwriteDuplicatedInlinedFields || errorOnDuplicates { return nil, fmt.Errorf("struct %s has duplicated key %s", t.String(), name) } sd.fl = append(sd.fl, dominant) diff --git a/bson/struct_tag_parser.go b/bson/struct_tag_parser.go index d116c14040..7cf8aecffe 100644 --- a/bson/struct_tag_parser.go +++ b/bson/struct_tag_parser.go @@ -11,25 +11,7 @@ import ( "strings" ) -// StructTagParser returns the struct tags for a given struct field. -// -// Deprecated: Defining custom BSON struct tag parsers will not be supported in Go Driver 2.0. -type StructTagParser interface { - ParseStructTags(reflect.StructField) (StructTags, error) -} - -// StructTagParserFunc is an adapter that allows a generic function to be used -// as a StructTagParser. -// -// Deprecated: Defining custom BSON struct tag parsers will not be supported in Go Driver 2.0. -type StructTagParserFunc func(reflect.StructField) (StructTags, error) - -// ParseStructTags implements the StructTagParser interface. -func (stpf StructTagParserFunc) ParseStructTags(sf reflect.StructField) (StructTags, error) { - return stpf(sf) -} - -// StructTags represents the struct tag fields that the StructCodec uses during +// structTags represents the struct tag fields that the StructCodec uses during // the encoding and decoding process. // // In the case of a struct, the lowercased field name is used as the key for each exported @@ -38,34 +20,25 @@ func (stpf StructTagParserFunc) ParseStructTags(sf reflect.StructField) (StructT // // The properties are defined below: // -// OmitEmpty Only include the field if it's not set to the zero value for the type or to -// empty slices or maps. -// -// MinSize Marshal an integer of a type larger than 32 bits value as an int32, if that's -// feasible while preserving the numeric value. -// -// Truncate When unmarshaling a BSON double, it is permitted to lose precision to fit within -// a float32. -// -// Inline Inline the field, which must be a struct or a map, causing all of its fields +// inline Inline the field, which must be a struct or a map, causing all of its fields // or keys to be processed as if they were part of the outer struct. For maps, // keys must not conflict with the bson keys of other struct fields. // -// Skip This struct field should be skipped. This is usually denoted by parsing a "-" -// for the name. +// omitEmpty Only include the field if it's not set to the zero value for the type or to +// empty slices or maps. // -// Deprecated: Defining custom BSON struct tag parsers will not be supported in Go Driver 2.0. -type StructTags struct { +// skip This struct field should be skipped. This is usually denoted by parsing a "-" +// for the name. +type structTags struct { Name string - OmitEmpty bool - MinSize bool - Truncate bool Inline bool + MinSize bool + OmitEmpty bool Skip bool + Truncate bool } -// DefaultStructTagParser is the StructTagParser used by the StructCodec by default. -// It will handle the bson struct tag. See the documentation for StructTags to see +// parseStructTags handles the bson struct tag. See the documentation for StructTags to see // what each of the returned fields means. // // If there is no name in the struct tag fields, the struct field name is lowercased. @@ -89,22 +62,35 @@ type StructTags struct { // A struct tag either consisting entirely of '-' or with a bson key with a // value consisting entirely of '-' will return a StructTags with Skip true and // the remaining fields will be their default values. -// -// Deprecated: DefaultStructTagParser will be removed in Go Driver 2.0. -var DefaultStructTagParser StructTagParserFunc = func(sf reflect.StructField) (StructTags, error) { +func parseStructTags(sf reflect.StructField) (*structTags, error) { + key := strings.ToLower(sf.Name) + tag, ok := sf.Tag.Lookup("bson") + if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 { + tag = string(sf.Tag) + } + return parseTags(key, tag) +} + +// parseJSONStructTags parses the json tag instead on a field where the +// bson tag isn't available. +func parseJSONStructTags(sf reflect.StructField) (*structTags, error) { key := strings.ToLower(sf.Name) tag, ok := sf.Tag.Lookup("bson") + if !ok { + tag, ok = sf.Tag.Lookup("json") + } if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 { tag = string(sf.Tag) } + return parseTags(key, tag) } -func parseTags(key string, tag string) (StructTags, error) { - var st StructTags +func parseTags(key string, tag string) (*structTags, error) { + var st structTags if tag == "-" { st.Skip = true - return st, nil + return &st, nil } for idx, str := range strings.Split(tag, ",") { @@ -112,37 +98,18 @@ func parseTags(key string, tag string) (StructTags, error) { key = str } switch str { - case "omitempty": - st.OmitEmpty = true + case "inline": + st.Inline = true case "minsize": st.MinSize = true + case "omitempty": + st.OmitEmpty = true case "truncate": st.Truncate = true - case "inline": - st.Inline = true } } st.Name = key - return st, nil -} - -// JSONFallbackStructTagParser has the same behavior as DefaultStructTagParser -// but will also fallback to parsing the json tag instead on a field where the -// bson tag isn't available. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.UseJSONStructTags] and -// [go.mongodb.org/mongo-driver/bson.Decoder.UseJSONStructTags] instead. -var JSONFallbackStructTagParser StructTagParserFunc = func(sf reflect.StructField) (StructTags, error) { - key := strings.ToLower(sf.Name) - tag, ok := sf.Tag.Lookup("bson") - if !ok { - tag, ok = sf.Tag.Lookup("json") - } - if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 { - tag = string(sf.Tag) - } - - return parseTags(key, tag) + return &st, nil } diff --git a/bson/struct_tag_parser_test.go b/bson/struct_tag_parser_test.go index b03815488a..c34faec0b5 100644 --- a/bson/struct_tag_parser_test.go +++ b/bson/struct_tag_parser_test.go @@ -17,134 +17,134 @@ func TestStructTagParsers(t *testing.T) { testCases := []struct { name string sf reflect.StructField - want StructTags - parser StructTagParserFunc + want *structTags + parser func(reflect.StructField) (*structTags, error) }{ { "default no bson tag", reflect.StructField{Name: "foo", Tag: reflect.StructTag("bar")}, - StructTags{Name: "bar"}, - DefaultStructTagParser, + &structTags{Name: "bar"}, + parseStructTags, }, { "default empty", reflect.StructField{Name: "foo", Tag: reflect.StructTag("")}, - StructTags{Name: "foo"}, - DefaultStructTagParser, + &structTags{Name: "foo"}, + parseStructTags, }, { "default tag only dash", reflect.StructField{Name: "foo", Tag: reflect.StructTag("-")}, - StructTags{Skip: true}, - DefaultStructTagParser, + &structTags{Skip: true}, + parseStructTags, }, { "default bson tag only dash", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"-"`)}, - StructTags{Skip: true}, - DefaultStructTagParser, + &structTags{Skip: true}, + parseStructTags, }, { "default all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bar,omitempty,minsize,truncate,inline`)}, - StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - DefaultStructTagParser, + &structTags{Name: "bar", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseStructTags, }, { "default all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`,omitempty,minsize,truncate,inline`)}, - StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - DefaultStructTagParser, + &structTags{Name: "foo", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseStructTags, }, { "default bson tag all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar,omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - DefaultStructTagParser, + &structTags{Name: "bar", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseStructTags, }, { "default bson tag all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:",omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - DefaultStructTagParser, + &structTags{Name: "foo", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseStructTags, }, { "default ignore xml", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`xml:"bar"`)}, - StructTags{Name: "foo"}, - DefaultStructTagParser, + &structTags{Name: "foo"}, + parseStructTags, }, { "JSONFallback no bson tag", reflect.StructField{Name: "foo", Tag: reflect.StructTag("bar")}, - StructTags{Name: "bar"}, - JSONFallbackStructTagParser, + &structTags{Name: "bar"}, + parseJSONStructTags, }, { "JSONFallback empty", reflect.StructField{Name: "foo", Tag: reflect.StructTag("")}, - StructTags{Name: "foo"}, - JSONFallbackStructTagParser, + &structTags{Name: "foo"}, + parseJSONStructTags, }, { "JSONFallback tag only dash", reflect.StructField{Name: "foo", Tag: reflect.StructTag("-")}, - StructTags{Skip: true}, - JSONFallbackStructTagParser, + &structTags{Skip: true}, + parseJSONStructTags, }, { "JSONFallback bson tag only dash", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"-"`)}, - StructTags{Skip: true}, - JSONFallbackStructTagParser, + &structTags{Skip: true}, + parseJSONStructTags, }, { "JSONFallback all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bar,omitempty,minsize,truncate,inline`)}, - StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - JSONFallbackStructTagParser, + &structTags{Name: "bar", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseJSONStructTags, }, { "JSONFallback all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`,omitempty,minsize,truncate,inline`)}, - StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - JSONFallbackStructTagParser, + &structTags{Name: "foo", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseJSONStructTags, }, { "JSONFallback bson tag all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar,omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - JSONFallbackStructTagParser, + &structTags{Name: "bar", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseJSONStructTags, }, { "JSONFallback bson tag all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:",omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - JSONFallbackStructTagParser, + &structTags{Name: "foo", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseJSONStructTags, }, { "JSONFallback json tag all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`json:"bar,omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - JSONFallbackStructTagParser, + &structTags{Name: "bar", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseJSONStructTags, }, { "JSONFallback json tag all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`json:",omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - JSONFallbackStructTagParser, + &structTags{Name: "foo", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseJSONStructTags, }, { "JSONFallback bson tag overrides other tags", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar" json:"qux,truncate"`)}, - StructTags{Name: "bar"}, - JSONFallbackStructTagParser, + &structTags{Name: "bar"}, + parseJSONStructTags, }, { "JSONFallback ignore xml", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`xml:"bar"`)}, - StructTags{Name: "foo"}, - JSONFallbackStructTagParser, + &structTags{Name: "foo"}, + parseJSONStructTags, }, } diff --git a/bson/time_codec.go b/bson/time_codec.go index a168d1e769..1e62117d47 100644 --- a/bson/time_codec.go +++ b/bson/time_codec.go @@ -10,48 +10,19 @@ import ( "fmt" "reflect" "time" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) const ( timeFormatString = "2006-01-02T15:04:05.999Z07:00" ) -// TimeCodec is the Codec used for time.Time values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// TimeCodec registered. -type TimeCodec struct { - // UseLocalTimeZone specifies if we should decode into the local time zone. Defaults to false. - // - // Deprecated: Use bson.Decoder.UseLocalTimeZone instead. - UseLocalTimeZone bool -} - -var ( - defaultTimeCodec = NewTimeCodec() - - // Assert that defaultTimeCodec satisfies the typeDecoder interface, which allows it to be used - // by collection type decoders (e.g. map, slice, etc) to set individual values in a collection. - _ typeDecoder = defaultTimeCodec -) - -// NewTimeCodec returns a TimeCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// TimeCodec registered. -func NewTimeCodec(opts ...*bsonoptions.TimeCodecOptions) *TimeCodec { - timeOpt := bsonoptions.MergeTimeCodecOptions(opts...) - - codec := TimeCodec{} - if timeOpt.UseLocalTimeZone != nil { - codec.UseLocalTimeZone = *timeOpt.UseLocalTimeZone - } - return &codec +// timeCodec is the Codec used for time.Time values. +type timeCodec struct { + // useLocalTimeZone specifies if we should decode into the local time zone. Defaults to false. + useLocalTimeZone bool } -func (tc *TimeCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (tc *timeCodec) decodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tTime { return emptyValue, ValueDecoderError{ Name: "TimeDecodeValue", @@ -102,19 +73,19 @@ func (tc *TimeCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type return emptyValue, fmt.Errorf("cannot decode %v into a time.Time", vrType) } - if !tc.UseLocalTimeZone && !dc.useLocalTimeZone { + if !tc.useLocalTimeZone { timeVal = timeVal.UTC() } return reflect.ValueOf(timeVal), nil } // DecodeValue is the ValueDecoderFunc for time.Time. -func (tc *TimeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (tc *timeCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tTime { return ValueDecoderError{Name: "TimeDecodeValue", Types: []reflect.Type{tTime}, Received: val} } - elem, err := tc.decodeType(dc, vr, tTime) + elem, err := tc.decodeType(reg, vr, tTime) if err != nil { return err } @@ -124,7 +95,7 @@ func (tc *TimeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.V } // EncodeValue is the ValueEncoderFunc for time.TIme. -func (tc *TimeCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func (tc *timeCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tTime { return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val} } diff --git a/bson/time_codec_test.go b/bson/time_codec_test.go index 1f185692da..fc32339602 100644 --- a/bson/time_codec_test.go +++ b/bson/time_codec_test.go @@ -11,7 +11,6 @@ import ( "testing" "time" - "go.mongodb.org/mongo-driver/bson/bsonoptions" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) @@ -22,20 +21,17 @@ func TestTimeCodec(t *testing.T) { t.Run("UseLocalTimeZone", func(t *testing.T) { reader := &valueReaderWriter{BSONType: TypeDateTime, Return: now.UnixNano() / int64(time.Millisecond)} testCases := []struct { - name string - opts *bsonoptions.TimeCodecOptions - utc bool + name string + codec *timeCodec + utc bool }{ - {"default", bsonoptions.TimeCodec(), true}, - {"false", bsonoptions.TimeCodec().SetUseLocalTimeZone(false), true}, - {"true", bsonoptions.TimeCodec().SetUseLocalTimeZone(true), false}, + {"default", &timeCodec{}, true}, + {"true", &timeCodec{useLocalTimeZone: true}, false}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - timeCodec := NewTimeCodec(tc.opts) - actual := reflect.New(reflect.TypeOf(now)).Elem() - err := timeCodec.DecodeValue(DecodeContext{}, reader, actual) + err := tc.codec.DecodeValue(nil, reader, actual) assert.Nil(t, err, "TimeCodec.DecodeValue error: %v", err) actualTime := actual.Interface().(time.Time) @@ -69,7 +65,7 @@ func TestTimeCodec(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { actual := reflect.New(reflect.TypeOf(now)).Elem() - err := defaultTimeCodec.DecodeValue(DecodeContext{}, tc.reader, actual) + err := (&timeCodec{}).DecodeValue(nil, tc.reader, actual) assert.Nil(t, err, "DecodeValue error: %v", err) actualTime := actual.Interface().(time.Time) diff --git a/bson/truncation_test.go b/bson/truncation_test.go index 865917cfe4..04deb0efde 100644 --- a/bson/truncation_test.go +++ b/bson/truncation_test.go @@ -32,19 +32,22 @@ func TestTruncation(t *testing.T) { buf := new(bytes.Buffer) vw := NewValueWriter(buf) - enc := NewEncoder(vw) - enc.IntMinSize() - enc.SetRegistry(DefaultRegistry) - err := enc.Encode(&input) + enc := NewEncoderWithRegistry(NewRegistryBuilder().Build(), vw) + err := enc.SetBehavior(IntMinSize) + assert.Nil(t, err) + err = enc.Encode(&input) assert.Nil(t, err) var output outputArgs - dc := DecodeContext{ - Registry: DefaultRegistry, - Truncate: true, - } + opt := NewRegistryOpt(func(c *numCodec) error { + c.truncate = true + return nil + }) + reg := NewRegistryBuilder().Build() + err = reg.SetCodecOption(opt) + assert.Nil(t, err) - err = UnmarshalWithContext(dc, buf.Bytes(), &output) + err = UnmarshalWithRegistry(reg, buf.Bytes(), &output) assert.Nil(t, err) assert.Equal(t, inputName, output.Name) @@ -58,20 +61,23 @@ func TestTruncation(t *testing.T) { buf := new(bytes.Buffer) vw := NewValueWriter(buf) - enc := NewEncoder(vw) - enc.IntMinSize() - enc.SetRegistry(DefaultRegistry) - err := enc.Encode(&input) + enc := NewEncoderWithRegistry(NewRegistryBuilder().Build(), vw) + err := enc.SetBehavior(IntMinSize) + assert.Nil(t, err) + err = enc.Encode(&input) assert.Nil(t, err) var output outputArgs - dc := DecodeContext{ - Registry: DefaultRegistry, - Truncate: false, - } + opt := NewRegistryOpt(func(c *numCodec) error { + c.truncate = false + return nil + }) + reg := NewRegistryBuilder().Build() + err = reg.SetCodecOption(opt) + assert.Nil(t, err) // case throws an error when truncation is disabled - err = UnmarshalWithContext(dc, buf.Bytes(), &output) + err = UnmarshalWithRegistry(reg, buf.Bytes(), &output) assert.NotNil(t, err) }) } diff --git a/bson/uint_codec.go b/bson/uint_codec.go deleted file mode 100644 index 73bc01966e..0000000000 --- a/bson/uint_codec.go +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bson - -import ( - "fmt" - "math" - "reflect" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" -) - -// UIntCodec is the Codec used for uint values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// UIntCodec registered. -type UIntCodec struct { - // EncodeToMinSize causes EncodeValue to marshal Go uint values (excluding uint64) as the - // minimum BSON int size (either 32-bit or 64-bit) that can represent the integer value. - // - // Deprecated: Use bson.Encoder.IntMinSize instead. - EncodeToMinSize bool -} - -var ( - defaultUIntCodec = NewUIntCodec() - - // Assert that defaultUIntCodec satisfies the typeDecoder interface, which allows it to be used - // by collection type decoders (e.g. map, slice, etc) to set individual values in a collection. - _ typeDecoder = defaultUIntCodec -) - -// NewUIntCodec returns a UIntCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// UIntCodec registered. -func NewUIntCodec(opts ...*bsonoptions.UIntCodecOptions) *UIntCodec { - uintOpt := bsonoptions.MergeUIntCodecOptions(opts...) - - codec := UIntCodec{} - if uintOpt.EncodeToMinSize != nil { - codec.EncodeToMinSize = *uintOpt.EncodeToMinSize - } - return &codec -} - -// EncodeValue is the ValueEncoder for uint types. -func (uic *UIntCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - switch val.Kind() { - case reflect.Uint8, reflect.Uint16: - return vw.WriteInt32(int32(val.Uint())) - case reflect.Uint, reflect.Uint32, reflect.Uint64: - u64 := val.Uint() - - // If ec.MinSize or if encodeToMinSize is true for a non-uint64 value we should write val as an int32 - useMinSize := ec.MinSize || (uic.EncodeToMinSize && val.Kind() != reflect.Uint64) - - if u64 <= math.MaxInt32 && useMinSize { - return vw.WriteInt32(int32(u64)) - } - if u64 > math.MaxInt64 { - return fmt.Errorf("%d overflows int64", u64) - } - return vw.WriteInt64(int64(u64)) - } - - return ValueEncoderError{ - Name: "UintEncodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, - Received: val, - } -} - -func (uic *UIntCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { - var i64 int64 - var err error - switch vrType := vr.Type(); vrType { - case TypeInt32: - i32, err := vr.ReadInt32() - if err != nil { - return emptyValue, err - } - i64 = int64(i32) - case TypeInt64: - i64, err = vr.ReadInt64() - if err != nil { - return emptyValue, err - } - case TypeDouble: - f64, err := vr.ReadDouble() - if err != nil { - return emptyValue, err - } - if !dc.Truncate && math.Floor(f64) != f64 { - return emptyValue, errCannotTruncate - } - if f64 > float64(math.MaxInt64) { - return emptyValue, fmt.Errorf("%g overflows int64", f64) - } - i64 = int64(f64) - case TypeBoolean: - b, err := vr.ReadBoolean() - if err != nil { - return emptyValue, err - } - if b { - i64 = 1 - } - case TypeNull: - if err = vr.ReadNull(); err != nil { - return emptyValue, err - } - case TypeUndefined: - if err = vr.ReadUndefined(); err != nil { - return emptyValue, err - } - default: - return emptyValue, fmt.Errorf("cannot decode %v into an integer type", vrType) - } - - switch t.Kind() { - case reflect.Uint8: - if i64 < 0 || i64 > math.MaxUint8 { - return emptyValue, fmt.Errorf("%d overflows uint8", i64) - } - - return reflect.ValueOf(uint8(i64)), nil - case reflect.Uint16: - if i64 < 0 || i64 > math.MaxUint16 { - return emptyValue, fmt.Errorf("%d overflows uint16", i64) - } - - return reflect.ValueOf(uint16(i64)), nil - case reflect.Uint32: - if i64 < 0 || i64 > math.MaxUint32 { - return emptyValue, fmt.Errorf("%d overflows uint32", i64) - } - - return reflect.ValueOf(uint32(i64)), nil - case reflect.Uint64: - if i64 < 0 { - return emptyValue, fmt.Errorf("%d overflows uint64", i64) - } - - return reflect.ValueOf(uint64(i64)), nil - case reflect.Uint: - if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint - return emptyValue, fmt.Errorf("%d overflows uint", i64) - } - - return reflect.ValueOf(uint(i64)), nil - default: - return emptyValue, ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, - Received: reflect.Zero(t), - } - } -} - -// DecodeValue is the ValueDecoder for uint types. -func (uic *UIntCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - if !val.CanSet() { - return ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, - Received: val, - } - } - - elem, err := uic.decodeType(dc, vr, val.Type()) - if err != nil { - return err - } - - val.SetUint(elem.Uint()) - return nil -} diff --git a/bson/unmarshal.go b/bson/unmarshal.go index 7caadc5dbc..6cead8048b 100644 --- a/bson/unmarshal.go +++ b/bson/unmarshal.go @@ -38,51 +38,31 @@ type ValueUnmarshaler interface { // pointed to by val. If val is nil or not a pointer, Unmarshal returns // InvalidUnmarshalError. func Unmarshal(data []byte, val interface{}) error { - return UnmarshalWithRegistry(DefaultRegistry, data, val) + return UnmarshalWithRegistry(defaultRegistry, data, val) } // UnmarshalWithRegistry parses the BSON-encoded data using Registry r and // stores the result in the value pointed to by val. If val is nil or not // a pointer, UnmarshalWithRegistry returns InvalidUnmarshalError. // -// Deprecated: Use [NewDecoder] and specify the Registry by calling [Decoder.SetRegistry] instead: +// Deprecated: Use [NewDecoderWithRegistry] instead: // -// dec, err := bson.NewDecoder(NewBSONDocumentReader(data)) +// dec, err := bson.NewDecoderWithRegistry(reg, NewBSONDocumentReader(data)) // if err != nil { // panic(err) // } -// dec.SetRegistry(reg) // // See [Decoder] for more examples. -func UnmarshalWithRegistry(r *Registry, data []byte, val interface{}) error { +func UnmarshalWithRegistry(reg *Registry, data []byte, val interface{}) error { vr := NewValueReader(data) - return unmarshalFromReader(DecodeContext{Registry: r}, vr, val) + return NewDecoderWithRegistry(reg, vr).Decode(val) } -// UnmarshalWithContext parses the BSON-encoded data using DecodeContext dc and -// stores the result in the value pointed to by val. If val is nil or not -// a pointer, UnmarshalWithRegistry returns InvalidUnmarshalError. -// -// Deprecated: Use [NewDecoder] and use the Decoder configuration methods to set the desired unmarshal -// behavior instead: -// -// dec, err := bson.NewDecoder(NewBSONDocumentReader(data)) -// if err != nil { -// panic(err) -// } -// dec.DefaultDocumentM() -// -// See [Decoder] for more examples. -func UnmarshalWithContext(dc DecodeContext, data []byte, val interface{}) error { - vr := NewValueReader(data) - return unmarshalFromReader(dc, vr, val) -} - -// UnmarshalValue parses the BSON value of type t with bson.DefaultRegistry and +// UnmarshalValue parses the BSON value of type t with default registry and // stores the result in the value pointed to by val. If val is nil or not a pointer, // UnmarshalValue returns an error. func UnmarshalValue(t Type, data []byte, val interface{}) error { - return UnmarshalValueWithRegistry(DefaultRegistry, t, data, val) + return UnmarshalValueWithRegistry(defaultRegistry, t, data, val) } // UnmarshalValueWithRegistry parses the BSON value of type t with registry r and @@ -91,16 +71,16 @@ func UnmarshalValue(t Type, data []byte, val interface{}) error { // // Deprecated: Using a custom registry to unmarshal individual BSON values will not be supported in // Go Driver 2.0. -func UnmarshalValueWithRegistry(r *Registry, t Type, data []byte, val interface{}) error { +func UnmarshalValueWithRegistry(reg *Registry, t Type, data []byte, val interface{}) error { vr := NewBSONValueReader(t, data) - return unmarshalFromReader(DecodeContext{Registry: r}, vr, val) + return NewDecoderWithRegistry(reg, vr).Decode(val) } // UnmarshalExtJSON parses the extended JSON-encoded data and stores the result // in the value pointed to by val. If val is nil or not a pointer, Unmarshal // returns InvalidUnmarshalError. func UnmarshalExtJSON(data []byte, canonical bool, val interface{}) error { - return UnmarshalExtJSONWithRegistry(DefaultRegistry, data, canonical, val) + return UnmarshalExtJSONWithRegistry(defaultRegistry, data, canonical, val) } // UnmarshalExtJSONWithRegistry parses the extended JSON-encoded data using @@ -120,13 +100,13 @@ func UnmarshalExtJSON(data []byte, canonical bool, val interface{}) error { // dec.SetRegistry(reg) // // See [Decoder] for more examples. -func UnmarshalExtJSONWithRegistry(r *Registry, data []byte, canonical bool, val interface{}) error { - ejvr, err := NewExtJSONValueReader(bytes.NewReader(data), canonical) +func UnmarshalExtJSONWithRegistry(reg *Registry, data []byte, canonical bool, val interface{}) error { + vr, err := NewExtJSONValueReader(bytes.NewReader(data), canonical) if err != nil { return err } - return unmarshalFromReader(DecodeContext{Registry: r}, ejvr, val) + return NewDecoderWithRegistry(reg, vr).Decode(val) } // UnmarshalExtJSONWithContext parses the extended JSON-encoded data using @@ -147,21 +127,11 @@ func UnmarshalExtJSONWithRegistry(r *Registry, data []byte, canonical bool, val // dec.DefaultDocumentM() // // See [Decoder] for more examples. -func UnmarshalExtJSONWithContext(dc DecodeContext, data []byte, canonical bool, val interface{}) error { - ejvr, err := NewExtJSONValueReader(bytes.NewReader(data), canonical) +func UnmarshalExtJSONWithContext(reg *Registry, data []byte, canonical bool, val interface{}) error { + vr, err := NewExtJSONValueReader(bytes.NewReader(data), canonical) if err != nil { return err } - return unmarshalFromReader(dc, ejvr, val) -} - -func unmarshalFromReader(dc DecodeContext, vr ValueReader, val interface{}) error { - dec := decPool.Get().(*Decoder) - defer decPool.Put(dec) - - dec.Reset(vr) - dec.dc = dc - - return dec.Decode(val) + return NewDecoderWithRegistry(reg, vr).Decode(val) } diff --git a/bson/unmarshal_test.go b/bson/unmarshal_test.go index 0871237386..d643a7db57 100644 --- a/bson/unmarshal_test.go +++ b/bson/unmarshal_test.go @@ -48,30 +48,7 @@ func TestUnmarshalWithRegistry(t *testing.T) { // Assert that unmarshaling the input data results in the expected value. got := reflect.New(tc.sType).Interface() - err := UnmarshalWithRegistry(DefaultRegistry, data, got) - noerr(t, err) - assert.Equal(t, tc.want, got, "Did not unmarshal as expected.") - - // Fill the input data slice with random bytes and then assert that the result still - // matches the expected value. - _, err = rand.Read(data) - noerr(t, err) - assert.Equal(t, tc.want, got, "unmarshaled value does not match expected after modifying the input bytes") - }) - } -} - -func TestUnmarshalWithContext(t *testing.T) { - for _, tc := range unmarshalingTestCases() { - t.Run(tc.name, func(t *testing.T) { - // Make a copy of the test data so we can modify it later. - data := make([]byte, len(tc.data)) - copy(data, tc.data) - - // Assert that unmarshaling the input data results in the expected value. - dc := DecodeContext{Registry: DefaultRegistry} - got := reflect.New(tc.sType).Interface() - err := UnmarshalWithContext(dc, data, got) + err := UnmarshalWithRegistry(defaultRegistry, data, got) noerr(t, err) assert.Equal(t, tc.want, got, "Did not unmarshal as expected.") @@ -89,7 +66,7 @@ func TestUnmarshalExtJSONWithRegistry(t *testing.T) { type teststruct struct{ Foo int } var got teststruct data := []byte("{\"foo\":1}") - err := UnmarshalExtJSONWithRegistry(DefaultRegistry, data, true, &got) + err := UnmarshalExtJSONWithRegistry(defaultRegistry, data, true, &got) noerr(t, err) want := teststruct{1} assert.Equal(t, want, got, "Did not unmarshal as expected.") @@ -97,7 +74,7 @@ func TestUnmarshalExtJSONWithRegistry(t *testing.T) { t.Run("UnmarshalExtJSONInvalidInput", func(t *testing.T) { data := []byte("invalid") - err := UnmarshalExtJSONWithRegistry(DefaultRegistry, data, true, &M{}) + err := UnmarshalExtJSONWithRegistry(defaultRegistry, data, true, &M{}) if !errors.Is(err, ErrInvalidJSON) { t.Fatalf("wanted ErrInvalidJSON, got %v", err) } @@ -199,8 +176,7 @@ func TestUnmarshalExtJSONWithContext(t *testing.T) { // Assert that unmarshaling the input data results in the expected value. got := reflect.New(tc.sType).Interface() - dc := DecodeContext{Registry: DefaultRegistry} - err := UnmarshalExtJSONWithContext(dc, data, true, got) + err := UnmarshalExtJSONWithContext(defaultRegistry, data, true, got) noerr(t, err) assert.Equal(t, tc.want, got, "Did not unmarshal as expected.") @@ -219,7 +195,7 @@ func TestCachingDecodersNotSharedAcrossRegistries(t *testing.T) { // different Registry is used. // Create a custom Registry that negates BSON int32 values when decoding. - var decodeInt32 ValueDecoderFunc = func(_ DecodeContext, vr ValueReader, val reflect.Value) error { + var decodeInt32 ValueDecoderFunc = func(_ DecoderRegistry, vr ValueReader, val reflect.Value) error { i32, err := vr.ReadInt32() if err != nil { return err @@ -228,8 +204,9 @@ func TestCachingDecodersNotSharedAcrossRegistries(t *testing.T) { val.SetInt(int64(-1 * i32)) return nil } - customReg := NewRegistry() - customReg.RegisterTypeDecoder(tInt32, decodeInt32) + customReg := NewRegistryBuilder(). + RegisterTypeDecoder(tInt32, func(*Registry) ValueDecoder { return decodeInt32 }). + Build() docBytes := bsoncore.BuildDocumentFromElements( nil, diff --git a/bson/unmarshal_value_test.go b/bson/unmarshal_value_test.go index 8d9dfb5351..f9acc852ad 100644 --- a/bson/unmarshal_value_test.go +++ b/bson/unmarshal_value_test.go @@ -36,7 +36,7 @@ func TestUnmarshalValue(t *testing.T) { }) } }) - t.Run("UnmarshalValueWithRegistry with DefaultRegistry", func(t *testing.T) { + t.Run("UnmarshalValueWithRegistry with default registry", func(t *testing.T) { t.Parallel() for _, tc := range unmarshalValueTestCases { @@ -46,7 +46,7 @@ func TestUnmarshalValue(t *testing.T) { t.Parallel() gotValue := reflect.New(reflect.TypeOf(tc.val)) - err := UnmarshalValueWithRegistry(DefaultRegistry, tc.bsontype, tc.bytes, gotValue.Interface()) + err := UnmarshalValueWithRegistry(defaultRegistry, tc.bsontype, tc.bytes, gotValue.Interface()) assert.Nil(t, err, "UnmarshalValueWithRegistry error: %v", err) assert.Equal(t, tc.val, gotValue.Elem().Interface(), "value mismatch; expected %s, got %s", tc.val, gotValue.Elem()) }) @@ -75,8 +75,9 @@ func TestUnmarshalValue(t *testing.T) { bytes: bsoncore.AppendString(nil, "hello world"), }, } - reg := NewRegistry() - reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), NewSliceCodec()) + reg := NewRegistryBuilder(). + RegisterTypeDecoder(reflect.TypeOf([]byte{}), func(*Registry) ValueDecoder { return &sliceCodec{} }). + Build() for _, tc := range testCases { tc := tc @@ -110,8 +111,9 @@ func BenchmarkSliceCodecUnmarshal(b *testing.B) { bytes: bsoncore.AppendString(nil, strings.Repeat("t", 4096)), }, } - reg := NewRegistry() - reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), NewSliceCodec()) + reg := NewRegistryBuilder(). + RegisterTypeDecoder(reflect.TypeOf([]byte{}), func(*Registry) ValueDecoder { return &sliceCodec{} }). + Build() for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { b.RunParallel(func(pb *testing.PB) { diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index ceae58ac81..e2f9fb36d1 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -39,12 +39,12 @@ type negateCodec struct { ID int64 `bson:"_id"` } -func (e *negateCodec) EncodeValue(_ bson.EncodeContext, vw bson.ValueWriter, val reflect.Value) error { +func (e *negateCodec) EncodeValue(_ bson.EncoderRegistry, vw bson.ValueWriter, val reflect.Value) error { return vw.WriteInt64(val.Int()) } // DecodeValue negates the value of ID when reading -func (e *negateCodec) DecodeValue(_ bson.DecodeContext, vr bson.ValueReader, val reflect.Value) error { +func (e *negateCodec) DecodeValue(_ bson.DecoderRegistry, vr bson.ValueReader, val reflect.Value) error { i, err := vr.ReadInt64() if err != nil { return err @@ -100,9 +100,10 @@ func (sc *slowConn) Read(b []byte) (n int, err error) { func TestClient(t *testing.T) { mt := mtest.New(t, noClientOpts) - reg := bson.NewRegistry() - reg.RegisterTypeEncoder(reflect.TypeOf(int64(0)), &negateCodec{}) - reg.RegisterTypeDecoder(reflect.TypeOf(int64(0)), &negateCodec{}) + reg := bson.NewRegistryBuilder(). + RegisterTypeEncoder(reflect.TypeOf(int64(0)), func(*bson.Registry) bson.ValueEncoder { return &negateCodec{} }). + RegisterTypeDecoder(reflect.TypeOf(int64(0)), func(*bson.Registry) bson.ValueDecoder { return &negateCodec{} }). + Build() registryOpts := options.Client(). SetRegistry(reg) mt.RunOpts("registry passed to cursors", mtest.NewOptions().ClientOptions(registryOpts), func(mt *mtest.T) { diff --git a/internal/integration/crud_spec_test.go b/internal/integration/crud_spec_test.go index e6583f8ade..996cdd27f4 100644 --- a/internal/integration/crud_spec_test.go +++ b/internal/integration/crud_spec_test.go @@ -55,11 +55,9 @@ type crudOutcome struct { Collection *outcomeCollection `bson:"collection"` } -var crudRegistry = func() *bson.Registry { - reg := bson.NewRegistry() - reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) - return reg -}() +var crudRegistry = bson.NewRegistryBuilder(). + RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})). + Build() func TestCrudSpec(t *testing.T) { for _, dir := range []string{crudReadDir, crudWriteDir} { diff --git a/internal/integration/database_test.go b/internal/integration/database_test.go index 12c2e0cd53..da043a6636 100644 --- a/internal/integration/database_test.go +++ b/internal/integration/database_test.go @@ -29,11 +29,9 @@ const ( ) var ( - interfaceAsMapRegistry = func() *bson.Registry { - reg := bson.NewRegistry() - reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.M{})) - return reg - }() + interfaceAsMapRegistry = bson.NewRegistryBuilder(). + RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.M{})). + Build() ) func TestDatabase(t *testing.T) { diff --git a/internal/integration/unified_spec_test.go b/internal/integration/unified_spec_test.go index cba3244db3..4ea80d0a54 100644 --- a/internal/integration/unified_spec_test.go +++ b/internal/integration/unified_spec_test.go @@ -77,24 +77,24 @@ type testData struct { } // custom decoder for testData type -func decodeTestData(dc bson.DecodeContext, vr bson.ValueReader, val reflect.Value) error { +func decodeTestData(reg bson.DecoderRegistry, vr bson.ValueReader, val reflect.Value) error { switch vr.Type() { case bson.TypeArray: docsVal := val.FieldByName("Documents") - decoder, err := dc.Registry.LookupDecoder(docsVal.Type()) + decoder, err := reg.LookupDecoder(docsVal.Type()) if err != nil { return err } - return decoder.DecodeValue(dc, vr, docsVal) + return decoder.DecodeValue(reg, vr, docsVal) case bson.TypeEmbeddedDocument: gridfsDataVal := val.FieldByName("GridFSData") - decoder, err := dc.Registry.LookupDecoder(gridfsDataVal.Type()) + decoder, err := reg.LookupDecoder(gridfsDataVal.Type()) if err != nil { return err } - return decoder.DecodeValue(dc, vr, gridfsDataVal) + return decoder.DecodeValue(reg, vr, gridfsDataVal) } return nil } @@ -181,12 +181,10 @@ var directories = []string{ } var checkOutcomeOpts = options.Collection().SetReadPreference(readpref.Primary()).SetReadConcern(readconcern.Local()) -var specTestRegistry = func() *bson.Registry { - reg := bson.NewRegistry() - reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) - reg.RegisterTypeDecoder(reflect.TypeOf(testData{}), bson.ValueDecoderFunc(decodeTestData)) - return reg -}() +var specTestRegistry = bson.NewRegistryBuilder(). + RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})). + RegisterTypeDecoder(reflect.TypeOf(testData{}), func(*bson.Registry) bson.ValueDecoder { return bson.ValueDecoderFunc(decodeTestData) }). + Build() func TestUnifiedSpecs(t *testing.T) { for _, specDir := range directories { diff --git a/mongo/change_stream.go b/mongo/change_stream.go index 0df6ae03c7..b8a06a9d24 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -602,7 +602,10 @@ func (cs *ChangeStream) Decode(val interface{}) error { return ErrNilCursor } - dec := getDecoder(cs.Current, cs.bsonOpts, cs.registry) + dec, err := getDecoder(cs.Current, cs.bsonOpts, cs.registry) + if err != nil { + return err + } return dec.Decode(val) } diff --git a/mongo/client.go b/mongo/client.go index 36f6fbc35f..eecbdbaa03 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -166,7 +166,7 @@ func newClient(opts ...*options.ClientOptions) (*Client, error) { client.bsonOpts = clientOpt.BSONOptions } // Registry - client.registry = bson.DefaultRegistry + client.registry = bson.NewRegistryBuilder().Build() if clientOpt.Registry != nil { client.registry = clientOpt.Registry } diff --git a/mongo/cursor.go b/mongo/cursor.go index 8f07b1ee9b..251d190e89 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -55,7 +55,7 @@ func newCursorWithSession( clientSession *session.Client, ) (*Cursor, error) { if registry == nil { - registry = bson.DefaultRegistry + registry = bson.NewRegistryBuilder().Build() } if bc == nil { return nil, errors.New("batch cursor must not be nil") @@ -82,16 +82,15 @@ func newEmptyCursor() *Cursor { } // NewCursorFromDocuments creates a new Cursor pre-loaded with the provided documents, error and registry. If no registry is provided, -// bson.DefaultRegistry will be used. +// a default registry will be used. // // The documents parameter must be a slice of documents. The slice may be nil or empty, but all elements must be non-nil. func NewCursorFromDocuments(documents []interface{}, preloadedErr error, registry *bson.Registry) (*Cursor, error) { if registry == nil { - registry = bson.DefaultRegistry + registry = bson.NewRegistryBuilder().Build() } buf := new(bytes.Buffer) - enc := new(bson.Encoder) values := make([]bsoncore.Value, len(documents)) for i, doc := range documents { @@ -104,9 +103,7 @@ func NewCursorFromDocuments(documents []interface{}, preloadedErr error, registr } vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(registry) - + enc := bson.NewEncoderWithRegistry(registry, vw) if err := enc.Encode(doc); err != nil { return nil, err } @@ -237,47 +234,59 @@ func getDecoder( data []byte, opts *options.BSONOptions, reg *bson.Registry, -) *bson.Decoder { - dec := bson.NewDecoder(bson.NewValueReader(data)) +) (*bson.Decoder, error) { + vr := bson.NewValueReader(data) + var dec *bson.Decoder + if reg != nil { + dec = bson.NewDecoderWithRegistry(reg, vr) + } else { + dec = bson.NewDecoder(vr) + } if opts != nil { + regOpts := []*bson.RegistryOpt{} if opts.AllowTruncatingDoubles { - dec.AllowTruncatingDoubles() + regOpts = append(regOpts, bson.AllowTruncatingDoubles) } if opts.BinaryAsSlice { - dec.BinaryAsSlice() + regOpts = append(regOpts, bson.BinaryAsSlice) } if opts.DefaultDocumentD { - dec.DefaultDocumentD() + regOpts = append(regOpts, bson.DefaultDocumentD) } if opts.DefaultDocumentM { - dec.DefaultDocumentM() + regOpts = append(regOpts, bson.DefaultDocumentM) } if opts.UseJSONStructTags { - dec.UseJSONStructTags() + regOpts = append(regOpts, bson.UseJSONStructTags) } if opts.UseLocalTimeZone { - dec.UseLocalTimeZone() + regOpts = append(regOpts, bson.UseLocalTimeZone) } if opts.ZeroMaps { - dec.ZeroMaps() + regOpts = append(regOpts, bson.ZeroMaps) } if opts.ZeroStructs { - dec.ZeroStructs() + regOpts = append(regOpts, bson.ZeroStructs) + } + for _, opt := range regOpts { + err := dec.SetBehavior(opt) + if err != nil { + return nil, err + } } } - if reg != nil { - dec.SetRegistry(reg) - } - - return dec + return dec, nil } // Decode will unmarshal the current document into val and return any errors from the unmarshalling process without any // modification. If val is nil or is a typed nil, an error will be returned. func (c *Cursor) Decode(val interface{}) error { - dec := getDecoder(c.Current, c.bsonOpts, c.registry) + dec, err := getDecoder(c.Current, c.bsonOpts, c.registry) + if err != nil { + return err + } return dec.Decode(val) } @@ -368,7 +377,11 @@ func (c *Cursor) addFromBatch(sliceVal reflect.Value, elemType reflect.Type, bat } currElem := sliceVal.Index(index).Addr().Interface() - dec := getDecoder(doc, c.bsonOpts, c.registry) + var dec *bson.Decoder + dec, err = getDecoder(doc, c.bsonOpts, c.registry) + if err != nil { + return sliceVal, index, err + } err = dec.Decode(currElem) if err != nil { return sliceVal, index, err diff --git a/mongo/database_test.go b/mongo/database_test.go index 31bd900439..1142b6df9c 100644 --- a/mongo/database_test.go +++ b/mongo/database_test.go @@ -53,7 +53,7 @@ func TestDatabase(t *testing.T) { wc2 := &writeconcern.WriteConcern{W: 10} rcLocal := readconcern.Local() rcMajority := readconcern.Majority() - reg := bson.NewRegistry() + reg := bson.NewRegistryBuilder().Build() opts := options.Database().SetReadPreference(rpPrimary).SetReadConcern(rcLocal).SetWriteConcern(wc1). SetReadPreference(rpSecondary).SetReadConcern(rcMajority).SetWriteConcern(wc2).SetRegistry(reg) @@ -70,7 +70,7 @@ func TestDatabase(t *testing.T) { rpPrimary := readpref.Primary() rcLocal := readconcern.Local() wc1 := &writeconcern.WriteConcern{W: 10} - reg := bson.NewRegistry() + reg := bson.NewRegistryBuilder().Build() client := setupClient(options.Client().SetReadPreference(rpPrimary).SetReadConcern(rcLocal).SetRegistry(reg)) got := client.Database("foo", options.Database().SetWriteConcern(wc1)) diff --git a/mongo/gridfs_bucket.go b/mongo/gridfs_bucket.go index 7c2bbac64e..55212eb334 100644 --- a/mongo/gridfs_bucket.go +++ b/mongo/gridfs_bucket.go @@ -613,15 +613,14 @@ func (b *GridFSBucket) parseUploadOptions(opts ...*options.UploadOptions) (*uplo upload.chunkSize = *uo.ChunkSizeBytes } if uo.Registry == nil { - uo.Registry = bson.DefaultRegistry + uo.Registry = bson.NewRegistryBuilder().Build() } if uo.Metadata != nil { // TODO(GODRIVER-2726): Replace with marshal() and unmarshal() once the // TODO gridfs package is merged into the mongo package. buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(uo.Registry) + enc := bson.NewEncoderWithRegistry(uo.Registry, vw) err := enc.Encode(uo.Metadata) if err != nil { return nil, err diff --git a/mongo/mongo.go b/mongo/mongo.go index 318c765000..8cd6258e38 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -63,37 +63,45 @@ func getEncoder( reg *bson.Registry, ) (*bson.Encoder, error) { vw := bvwPool.Get(w) - enc := bson.NewEncoder(vw) + var enc *bson.Encoder + if reg != nil { + enc = bson.NewEncoderWithRegistry(reg, vw) + } else { + enc = bson.NewEncoder(vw) + } if opts != nil { + regOpts := []*bson.RegistryOpt{} if opts.ErrorOnInlineDuplicates { - enc.ErrorOnInlineDuplicates() + regOpts = append(regOpts, bson.ErrorOnInlineDuplicates) } if opts.IntMinSize { - enc.IntMinSize() + regOpts = append(regOpts, bson.IntMinSize) } if opts.NilByteSliceAsEmpty { - enc.NilByteSliceAsEmpty() + regOpts = append(regOpts, bson.NilByteSliceAsEmpty) } if opts.NilMapAsEmpty { - enc.NilMapAsEmpty() + regOpts = append(regOpts, bson.NilMapAsEmpty) } if opts.NilSliceAsEmpty { - enc.NilSliceAsEmpty() + regOpts = append(regOpts, bson.NilSliceAsEmpty) } if opts.OmitZeroStruct { - enc.OmitZeroStruct() + regOpts = append(regOpts, bson.OmitZeroStruct) } if opts.StringifyMapKeysWithFmt { - enc.StringifyMapKeysWithFmt() + regOpts = append(regOpts, bson.StringifyMapKeysWithFmt) } if opts.UseJSONStructTags { - enc.UseJSONStructTags() + regOpts = append(regOpts, bson.UseJSONStructTags) + } + for _, opt := range regOpts { + err := enc.SetBehavior(opt) + if err != nil { + return nil, err + } } - } - - if reg != nil { - enc.SetRegistry(reg) } return enc, nil @@ -118,7 +126,7 @@ func marshal( registry *bson.Registry, ) (bsoncore.Document, error) { if registry == nil { - registry = bson.DefaultRegistry + registry = bson.NewRegistryBuilder().Build() } if val == nil { return nil, ErrNilDocument @@ -153,10 +161,10 @@ func ensureID( doc bsoncore.Document, oid bson.ObjectID, bsonOpts *options.BSONOptions, - reg *bson.Registry, + registry *bson.Registry, ) (bsoncore.Document, interface{}, error) { - if reg == nil { - reg = bson.DefaultRegistry + if registry == nil { + registry = bson.NewRegistryBuilder().Build() } // Try to find the "_id" element. If it exists, try to unmarshal just the @@ -166,7 +174,11 @@ func ensureID( var id struct { ID interface{} `bson:"_id"` } - dec := getDecoder(doc, bsonOpts, reg) + var dec *bson.Decoder + dec, err = getDecoder(doc, bsonOpts, registry) + if err != nil { + return nil, nil, fmt.Errorf("error unmarshaling BSON document: %w", err) + } err = dec.Decode(&id) if err != nil { return nil, nil, fmt.Errorf("error unmarshaling BSON document: %w", err) diff --git a/mongo/options/clientoptions_test.go b/mongo/options/clientoptions_test.go index beba45514f..078c029308 100644 --- a/mongo/options/clientoptions_test.go +++ b/mongo/options/clientoptions_test.go @@ -80,7 +80,7 @@ func TestClientOptions(t *testing.T) { {"Monitor", (*ClientOptions).SetMonitor, &event.CommandMonitor{}, "Monitor", false}, {"ReadConcern", (*ClientOptions).SetReadConcern, readconcern.Majority(), "ReadConcern", false}, {"ReadPreference", (*ClientOptions).SetReadPreference, readpref.SecondaryPreferred(), "ReadPreference", false}, - {"Registry", (*ClientOptions).SetRegistry, bson.NewRegistry(), "Registry", false}, + {"Registry", (*ClientOptions).SetRegistry, bson.NewRegistryBuilder().Build(), "Registry", false}, {"ReplicaSet", (*ClientOptions).SetReplicaSet, "example-replicaset", "ReplicaSet", true}, {"RetryWrites", (*ClientOptions).SetRetryWrites, true, "RetryWrites", true}, {"ServerSelectionTimeout", (*ClientOptions).SetServerSelectionTimeout, 5 * time.Second, "ServerSelectionTimeout", true}, diff --git a/mongo/options/gridfsoptions.go b/mongo/options/gridfsoptions.go index 10d454c89d..47f97a5a51 100644 --- a/mongo/options/gridfsoptions.go +++ b/mongo/options/gridfsoptions.go @@ -99,7 +99,7 @@ type UploadOptions struct { // GridFSUpload creates a new UploadOptions instance. func GridFSUpload() *UploadOptions { - return &UploadOptions{Registry: bson.DefaultRegistry} + return &UploadOptions{Registry: bson.NewRegistryBuilder().Build()} } // SetChunkSizeBytes sets the value for the ChunkSize field. diff --git a/mongo/options/mongooptions.go b/mongo/options/mongooptions.go index 756684af7b..e6b1e5c5f9 100644 --- a/mongo/options/mongooptions.go +++ b/mongo/options/mongooptions.go @@ -128,16 +128,14 @@ type ArrayFilters struct { func (af *ArrayFilters) ToArray() ([]bson.Raw, error) { registry := af.Registry if registry == nil { - registry = bson.DefaultRegistry + registry = bson.NewRegistryBuilder().Build() } filters := make([]bson.Raw, 0, len(af.Filters)) buf := new(bytes.Buffer) - enc := new(bson.Encoder) for _, f := range af.Filters { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(registry) + enc := bson.NewEncoderWithRegistry(registry, vw) err := enc.Encode(f) if err != nil { return nil, err @@ -154,17 +152,15 @@ func (af *ArrayFilters) ToArray() ([]bson.Raw, error) { func (af *ArrayFilters) ToArrayDocument() (bson.Raw, error) { registry := af.Registry if registry == nil { - registry = bson.DefaultRegistry + registry = bson.NewRegistryBuilder().Build() } idx, arr := bsoncore.AppendArrayStart(nil) buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, f := range af.Filters { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(registry) + enc := bson.NewEncoderWithRegistry(registry, vw) err := enc.Encode(f) if err != nil { return nil, err diff --git a/mongo/read_write_concern_spec_test.go b/mongo/read_write_concern_spec_test.go index ec49bb91db..c737f76a9b 100644 --- a/mongo/read_write_concern_spec_test.go +++ b/mongo/read_write_concern_spec_test.go @@ -31,11 +31,9 @@ const ( var ( serverDefaultConcern = []byte{5, 0, 0, 0, 0} // server default read concern and write concern is empty document - specTestRegistry = func() *bson.Registry { - reg := bson.NewRegistry() - reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) - return reg - }() + specTestRegistry = bson.NewRegistryBuilder(). + RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})). + Build() ) type connectionStringTestFile struct { diff --git a/mongo/single_result.go b/mongo/single_result.go index 6a0a695685..793d8aa6eb 100644 --- a/mongo/single_result.go +++ b/mongo/single_result.go @@ -44,7 +44,7 @@ func NewSingleResultFromDocument( return &SingleResult{err: ErrNilDocument} } if registry == nil { - registry = bson.DefaultRegistry + registry = bson.NewRegistryBuilder().Build() } cur, createErr := NewCursorFromDocuments([]interface{}{document}, err, registry) @@ -77,7 +77,10 @@ func (sr *SingleResult) Decode(v interface{}) error { return sr.err } - dec := getDecoder(sr.rdr, sr.bsonOpts, sr.reg) + dec, err := getDecoder(sr.rdr, sr.bsonOpts, sr.reg) + if err != nil { + return err + } return dec.Decode(v) } diff --git a/mongo/single_result_test.go b/mongo/single_result_test.go index a9f409eeb0..34068dd2fc 100644 --- a/mongo/single_result_test.go +++ b/mongo/single_result_test.go @@ -96,10 +96,10 @@ func TestSingleResult_Decode(t *testing.T) { t.Run("decode twice", func(t *testing.T) { t.Run("bson.Raw", func(t *testing.T) { // Test that Decode and Raw can be called more than once - c, err := newCursor(newTestBatchCursor(1, 1), nil, bson.DefaultRegistry) + c, err := newCursor(newTestBatchCursor(1, 1), nil, bson.NewRegistryBuilder().Build()) assert.Nil(t, err, "newCursor error: %v", err) - sr := &SingleResult{cur: c, reg: bson.DefaultRegistry} + sr := &SingleResult{cur: c, reg: c.registry} var firstDecode, secondDecode bson.Raw err = sr.Decode(&firstDecode) assert.Nil(t, err, "Decode error: %v", err) diff --git a/x/mongo/driver/topology/server_options.go b/x/mongo/driver/topology/server_options.go index c02600e232..dca9c0581b 100644 --- a/x/mongo/driver/topology/server_options.go +++ b/x/mongo/driver/topology/server_options.go @@ -17,7 +17,7 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) -var defaultRegistry = bson.NewRegistry() +var defaultRegistry = bson.NewRegistryBuilder().Build() type serverConfig struct { clock *session.ClusterClock