From 9f587a2a7224fbbdc71ec86a53b2caa37c4c6ad4 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Wed, 24 Apr 2024 14:13:55 -0400 Subject: [PATCH 01/15] WIP --- bson/bsoncodec.go | 18 +- bson/decoder.go | 19 +- bson/default_value_decoders.go | 755 +++++++--------------------- bson/default_value_decoders_test.go | 159 +++--- bson/default_value_encoders.go | 513 ++++--------------- bson/default_value_encoders_test.go | 80 ++- bson/map_codec.go | 2 +- bson/mgocompat/bson_test.go | 3 - bson/mgocompat/registry.go | 77 ++- bson/primitive_codecs.go | 54 +- bson/primitive_codecs_test.go | 12 +- bson/raw_value_test.go | 6 +- bson/registry.go | 197 +------- bson/registry_test.go | 143 +++--- bson/slice_codec.go | 6 +- bson/string_codec.go | 9 +- bson/string_codec_test.go | 22 +- bson/struct_codec.go | 2 +- 18 files changed, 563 insertions(+), 1514 deletions(-) diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index 860a6b82af..b7aaadf2c2 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -179,11 +179,12 @@ type DecodeContext struct { // error. DocumentType overrides the Ancestor field. defaultDocumentType reflect.Type - binaryAsSlice bool - useJSONStructTags bool - useLocalTimeZone bool - zeroMaps bool - zeroStructs bool + binaryAsSlice bool + decodeObjectIDAsHex bool + useJSONStructTags bool + useLocalTimeZone bool + zeroMaps bool + zeroStructs bool } // BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or @@ -194,6 +195,13 @@ func (dc *DecodeContext) BinaryAsSlice() { dc.binaryAsSlice = true } +// DecodeObjectIDAsHex causes the Decoder to unmarshal BSON ObjectID as a hexadecimal string. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.DecodeObjectIDAsHex] instead. +func (dc *DecodeContext) DecodeObjectIDAsHex() { + dc.decodeObjectIDAsHex = true +} + // UseJSONStructTags causes the Decoder to fall back to using the "json" struct tag if a "bson" // struct tag is not specified. // diff --git a/bson/decoder.go b/bson/decoder.go index 6ea5ad97c1..898dbc87af 100644 --- a/bson/decoder.go +++ b/bson/decoder.go @@ -36,11 +36,12 @@ type Decoder struct { defaultDocumentM bool defaultDocumentD bool - binaryAsSlice bool - useJSONStructTags bool - useLocalTimeZone bool - zeroMaps bool - zeroStructs bool + binaryAsSlice bool + decodeObjectIDAsHex bool + useJSONStructTags bool + useLocalTimeZone bool + zeroMaps bool + zeroStructs bool } // NewDecoder returns a new decoder that uses the DefaultRegistry to read from vr. @@ -93,6 +94,9 @@ func (d *Decoder) Decode(val interface{}) error { if d.binaryAsSlice { d.dc.BinaryAsSlice() } + if d.decodeObjectIDAsHex { + d.dc.DecodeObjectIDAsHex() + } if d.useJSONStructTags { d.dc.UseJSONStructTags() } @@ -145,6 +149,11 @@ func (d *Decoder) BinaryAsSlice() { d.binaryAsSlice = true } +// DecodeObjectIDAsHex causes the Decoder to unmarshal BSON ObjectID as a hexadecimal string. +func (d *Decoder) DecodeObjectIDAsHex() { + d.decodeObjectIDAsHex = true +} + // UseJSONStructTags causes the Decoder to fall back to using the "json" struct tag if a "bson" // struct tag is not specified. func (d *Decoder) UseJSONStructTags() { diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index bc8c7b9344..e4ea1f394e 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -14,14 +14,12 @@ import ( "net/url" "reflect" "strconv" - "time" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) var ( - defaultValueDecoders DefaultValueDecoders - errCannotTruncate = errors.New("float64 can only be truncated to a lower precision type when truncation is enabled") + errCannotTruncate = errors.New("float64 can only be truncated to a lower precision type when truncation is enabled") ) type decodeBinaryError struct { @@ -43,102 +41,87 @@ func newDefaultStructCodec() *StructCodec { return codec } -// DefaultValueDecoders is a namespace type for the default ValueDecoders used -// when creating a registry. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -type DefaultValueDecoders struct{} - -// RegisterDefaultDecoders will register the decoder methods attached to DefaultValueDecoders with -// the provided RegistryBuilder. +// registerDefaultDecoders will register the default decoder methods with the provided Registry. // // There is no support for decoding map[string]interface{} because there is no decoder for // interface{}, so users must either register this decoder themselves or use the // EmptyInterfaceDecoder available in the bson package. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) RegisterDefaultDecoders(rb *RegistryBuilder) { - if rb == nil { +func registerDefaultDecoders(reg *Registry) { + if reg == nil { panic(errors.New("argument to RegisterDefaultDecoders must not be nil")) } - intDecoder := decodeAdapter{dvd.IntDecodeValue, dvd.intDecodeType} - floatDecoder := decodeAdapter{dvd.FloatDecodeValue, dvd.floatDecodeType} - - rb. - RegisterTypeDecoder(tD, ValueDecoderFunc(dvd.DDecodeValue)). - RegisterTypeDecoder(tBinary, decodeAdapter{dvd.BinaryDecodeValue, dvd.binaryDecodeType}). - RegisterTypeDecoder(tUndefined, decodeAdapter{dvd.UndefinedDecodeValue, dvd.undefinedDecodeType}). - RegisterTypeDecoder(tDateTime, decodeAdapter{dvd.DateTimeDecodeValue, dvd.dateTimeDecodeType}). - RegisterTypeDecoder(tNull, decodeAdapter{dvd.NullDecodeValue, dvd.nullDecodeType}). - RegisterTypeDecoder(tRegex, decodeAdapter{dvd.RegexDecodeValue, dvd.regexDecodeType}). - RegisterTypeDecoder(tDBPointer, decodeAdapter{dvd.DBPointerDecodeValue, dvd.dBPointerDecodeType}). - RegisterTypeDecoder(tTimestamp, decodeAdapter{dvd.TimestampDecodeValue, dvd.timestampDecodeType}). - RegisterTypeDecoder(tMinKey, decodeAdapter{dvd.MinKeyDecodeValue, dvd.minKeyDecodeType}). - RegisterTypeDecoder(tMaxKey, decodeAdapter{dvd.MaxKeyDecodeValue, dvd.maxKeyDecodeType}). - RegisterTypeDecoder(tJavaScript, decodeAdapter{dvd.JavaScriptDecodeValue, dvd.javaScriptDecodeType}). - RegisterTypeDecoder(tSymbol, decodeAdapter{dvd.SymbolDecodeValue, dvd.symbolDecodeType}). - RegisterTypeDecoder(tByteSlice, defaultByteSliceCodec). - RegisterTypeDecoder(tTime, defaultTimeCodec). - RegisterTypeDecoder(tEmpty, defaultEmptyInterfaceCodec). - RegisterTypeDecoder(tCoreArray, defaultArrayCodec). - RegisterTypeDecoder(tOID, decodeAdapter{dvd.ObjectIDDecodeValue, dvd.objectIDDecodeType}). - RegisterTypeDecoder(tDecimal, decodeAdapter{dvd.Decimal128DecodeValue, dvd.decimal128DecodeType}). - RegisterTypeDecoder(tJSONNumber, decodeAdapter{dvd.JSONNumberDecodeValue, dvd.jsonNumberDecodeType}). - RegisterTypeDecoder(tURL, decodeAdapter{dvd.URLDecodeValue, dvd.urlDecodeType}). - RegisterTypeDecoder(tCoreDocument, ValueDecoderFunc(dvd.CoreDocumentDecodeValue)). - RegisterTypeDecoder(tCodeWithScope, decodeAdapter{dvd.CodeWithScopeDecodeValue, dvd.codeWithScopeDecodeType}). - RegisterDefaultDecoder(reflect.Bool, decodeAdapter{dvd.BooleanDecodeValue, dvd.booleanDecodeType}). - RegisterDefaultDecoder(reflect.Int, intDecoder). - RegisterDefaultDecoder(reflect.Int8, intDecoder). - RegisterDefaultDecoder(reflect.Int16, intDecoder). - RegisterDefaultDecoder(reflect.Int32, intDecoder). - RegisterDefaultDecoder(reflect.Int64, intDecoder). - RegisterDefaultDecoder(reflect.Uint, defaultUIntCodec). - RegisterDefaultDecoder(reflect.Uint8, defaultUIntCodec). - RegisterDefaultDecoder(reflect.Uint16, defaultUIntCodec). - RegisterDefaultDecoder(reflect.Uint32, defaultUIntCodec). - RegisterDefaultDecoder(reflect.Uint64, defaultUIntCodec). - RegisterDefaultDecoder(reflect.Float32, floatDecoder). - RegisterDefaultDecoder(reflect.Float64, floatDecoder). - RegisterDefaultDecoder(reflect.Array, ValueDecoderFunc(dvd.ArrayDecodeValue)). - RegisterDefaultDecoder(reflect.Map, defaultMapCodec). - RegisterDefaultDecoder(reflect.Slice, defaultSliceCodec). - RegisterDefaultDecoder(reflect.String, defaultStringCodec). - RegisterDefaultDecoder(reflect.Struct, newDefaultStructCodec()). - RegisterDefaultDecoder(reflect.Ptr, NewPointerCodec()). - RegisterTypeMapEntry(TypeDouble, tFloat64). - RegisterTypeMapEntry(TypeString, tString). - RegisterTypeMapEntry(TypeArray, tA). - RegisterTypeMapEntry(TypeBinary, tBinary). - RegisterTypeMapEntry(TypeUndefined, tUndefined). - RegisterTypeMapEntry(TypeObjectID, tOID). - RegisterTypeMapEntry(TypeBoolean, tBool). - RegisterTypeMapEntry(TypeDateTime, tDateTime). - RegisterTypeMapEntry(TypeRegex, tRegex). - RegisterTypeMapEntry(TypeDBPointer, tDBPointer). - RegisterTypeMapEntry(TypeJavaScript, tJavaScript). - RegisterTypeMapEntry(TypeSymbol, tSymbol). - RegisterTypeMapEntry(TypeCodeWithScope, tCodeWithScope). - RegisterTypeMapEntry(TypeInt32, tInt32). - RegisterTypeMapEntry(TypeInt64, tInt64). - RegisterTypeMapEntry(TypeTimestamp, tTimestamp). - RegisterTypeMapEntry(TypeDecimal128, tDecimal). - RegisterTypeMapEntry(TypeMinKey, tMinKey). - RegisterTypeMapEntry(TypeMaxKey, tMaxKey). - RegisterTypeMapEntry(Type(0), tD). - RegisterTypeMapEntry(TypeEmbeddedDocument, tD). - RegisterHookDecoder(tValueUnmarshaler, ValueDecoderFunc(dvd.ValueUnmarshalerDecodeValue)). - RegisterHookDecoder(tUnmarshaler, ValueDecoderFunc(dvd.UnmarshalerDecodeValue)) + intDecoder := decodeAdapter{intDecodeValue, intDecodeType} + floatDecoder := decodeAdapter{floatDecodeValue, floatDecodeType} + + reg.RegisterTypeDecoder(tD, ValueDecoderFunc(dDecodeValue)) + reg.RegisterTypeDecoder(tBinary, decodeAdapter{binaryDecodeValue, binaryDecodeType}) + reg.RegisterTypeDecoder(tUndefined, decodeAdapter{undefinedDecodeValue, undefinedDecodeType}) + reg.RegisterTypeDecoder(tDateTime, decodeAdapter{dateTimeDecodeValue, dateTimeDecodeType}) + reg.RegisterTypeDecoder(tNull, decodeAdapter{nullDecodeValue, nullDecodeType}) + reg.RegisterTypeDecoder(tRegex, decodeAdapter{regexDecodeValue, regexDecodeType}) + reg.RegisterTypeDecoder(tDBPointer, decodeAdapter{dbPointerDecodeValue, dbPointerDecodeType}) + reg.RegisterTypeDecoder(tTimestamp, decodeAdapter{timestampDecodeValue, timestampDecodeType}) + reg.RegisterTypeDecoder(tMinKey, decodeAdapter{minKeyDecodeValue, minKeyDecodeType}) + reg.RegisterTypeDecoder(tMaxKey, decodeAdapter{maxKeyDecodeValue, maxKeyDecodeType}) + reg.RegisterTypeDecoder(tJavaScript, decodeAdapter{javaScriptDecodeValue, javaScriptDecodeType}) + reg.RegisterTypeDecoder(tSymbol, decodeAdapter{symbolDecodeValue, symbolDecodeType}) + reg.RegisterTypeDecoder(tByteSlice, defaultByteSliceCodec) + reg.RegisterTypeDecoder(tTime, defaultTimeCodec) + reg.RegisterTypeDecoder(tEmpty, defaultEmptyInterfaceCodec) + reg.RegisterTypeDecoder(tCoreArray, defaultArrayCodec) + reg.RegisterTypeDecoder(tOID, decodeAdapter{objectIDDecodeValue, objectIDDecodeType}) + reg.RegisterTypeDecoder(tDecimal, decodeAdapter{decimal128DecodeValue, decimal128DecodeType}) + reg.RegisterTypeDecoder(tJSONNumber, decodeAdapter{jsonNumberDecodeValue, jsonNumberDecodeType}) + reg.RegisterTypeDecoder(tURL, decodeAdapter{urlDecodeValue, urlDecodeType}) + reg.RegisterTypeDecoder(tCoreDocument, ValueDecoderFunc(coreDocumentDecodeValue)) + reg.RegisterTypeDecoder(tCodeWithScope, decodeAdapter{codeWithScopeDecodeValue, codeWithScopeDecodeType}) + reg.RegisterKindDecoder(reflect.Bool, decodeAdapter{booleanDecodeValue, booleanDecodeType}) + reg.RegisterKindDecoder(reflect.Int, intDecoder) + reg.RegisterKindDecoder(reflect.Int8, intDecoder) + reg.RegisterKindDecoder(reflect.Int16, intDecoder) + reg.RegisterKindDecoder(reflect.Int32, intDecoder) + reg.RegisterKindDecoder(reflect.Int64, intDecoder) + reg.RegisterKindDecoder(reflect.Uint, defaultUIntCodec) + reg.RegisterKindDecoder(reflect.Uint8, defaultUIntCodec) + reg.RegisterKindDecoder(reflect.Uint16, defaultUIntCodec) + reg.RegisterKindDecoder(reflect.Uint32, defaultUIntCodec) + reg.RegisterKindDecoder(reflect.Uint64, defaultUIntCodec) + reg.RegisterKindDecoder(reflect.Float32, floatDecoder) + reg.RegisterKindDecoder(reflect.Float64, floatDecoder) + reg.RegisterKindDecoder(reflect.Array, ValueDecoderFunc(arrayDecodeValue)) + reg.RegisterKindDecoder(reflect.Map, defaultMapCodec) + reg.RegisterKindDecoder(reflect.Slice, defaultSliceCodec) + reg.RegisterKindDecoder(reflect.String, defaultStringCodec) + reg.RegisterKindDecoder(reflect.Struct, newDefaultStructCodec()) + reg.RegisterKindDecoder(reflect.Ptr, NewPointerCodec()) + reg.RegisterTypeMapEntry(TypeDouble, tFloat64) + reg.RegisterTypeMapEntry(TypeString, tString) + reg.RegisterTypeMapEntry(TypeArray, tA) + reg.RegisterTypeMapEntry(TypeBinary, tBinary) + reg.RegisterTypeMapEntry(TypeUndefined, tUndefined) + reg.RegisterTypeMapEntry(TypeObjectID, tOID) + reg.RegisterTypeMapEntry(TypeBoolean, tBool) + reg.RegisterTypeMapEntry(TypeDateTime, tDateTime) + reg.RegisterTypeMapEntry(TypeRegex, tRegex) + reg.RegisterTypeMapEntry(TypeDBPointer, tDBPointer) + reg.RegisterTypeMapEntry(TypeJavaScript, tJavaScript) + reg.RegisterTypeMapEntry(TypeSymbol, tSymbol) + reg.RegisterTypeMapEntry(TypeCodeWithScope, tCodeWithScope) + reg.RegisterTypeMapEntry(TypeInt32, tInt32) + reg.RegisterTypeMapEntry(TypeInt64, tInt64) + reg.RegisterTypeMapEntry(TypeTimestamp, tTimestamp) + reg.RegisterTypeMapEntry(TypeDecimal128, tDecimal) + reg.RegisterTypeMapEntry(TypeMinKey, tMinKey) + reg.RegisterTypeMapEntry(TypeMaxKey, tMaxKey) + reg.RegisterTypeMapEntry(Type(0), tD) + reg.RegisterTypeMapEntry(TypeEmbeddedDocument, tD) + reg.RegisterInterfaceDecoder(tValueUnmarshaler, ValueDecoderFunc(valueUnmarshalerDecodeValue)) + reg.RegisterInterfaceDecoder(tUnmarshaler, ValueDecoderFunc(unmarshalerDecodeValue)) } -// DDecodeValue is the ValueDecoderFunc for D instances. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// dDecodeValue is the ValueDecoderFunc for D instances. +func dDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.IsValid() || !val.CanSet() || val.Type() != tD { return ValueDecoderError{Name: "DDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } @@ -194,7 +177,7 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr ValueReader, v return nil } -func (dvd DefaultValueDecoders) booleanDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func booleanDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t.Kind() != reflect.Bool { return emptyValue, ValueDecoderError{ Name: "BooleanDecodeValue", @@ -240,16 +223,13 @@ func (dvd DefaultValueDecoders) booleanDecodeType(_ DecodeContext, vr ValueReade return reflect.ValueOf(b), nil } -// BooleanDecodeValue is the ValueDecoderFunc for bool types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) BooleanDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +// booleanDecodeValue is the ValueDecoderFunc for bool types. +func booleanDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { if !val.IsValid() || !val.CanSet() || val.Kind() != reflect.Bool { return ValueDecoderError{Name: "BooleanDecodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val} } - elem, err := dvd.booleanDecodeType(dctx, vr, val.Type()) + elem, err := booleanDecodeType(dctx, vr, val.Type()) if err != nil { return err } @@ -258,7 +238,7 @@ func (dvd DefaultValueDecoders) BooleanDecodeValue(dctx DecodeContext, vr ValueR return nil } -func (DefaultValueDecoders) intDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func intDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { var i64 int64 var err error switch vrType := vr.Type(); vrType { @@ -341,11 +321,8 @@ func (DefaultValueDecoders) intDecodeType(dc DecodeContext, vr ValueReader, t re } } -// IntDecodeValue is the ValueDecoderFunc for int types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) IntDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// intDecodeValue is the ValueDecoderFunc for int types. +func intDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() { return ValueDecoderError{ Name: "IntDecodeValue", @@ -354,7 +331,7 @@ func (dvd DefaultValueDecoders) IntDecodeValue(dc DecodeContext, vr ValueReader, } } - elem, err := dvd.intDecodeType(dc, vr, val.Type()) + elem, err := intDecodeType(dc, vr, val.Type()) if err != nil { return err } @@ -363,90 +340,7 @@ func (dvd DefaultValueDecoders) IntDecodeValue(dc DecodeContext, vr ValueReader, return nil } -// UintDecodeValue is the ValueDecoderFunc for uint types. -// -// Deprecated: UintDecodeValue is not registered by default. Use UintCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) UintDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - var i64 int64 - var err error - switch vr.Type() { - case TypeInt32: - i32, err := vr.ReadInt32() - if err != nil { - return err - } - i64 = int64(i32) - case TypeInt64: - i64, err = vr.ReadInt64() - if err != nil { - return err - } - case TypeDouble: - f64, err := vr.ReadDouble() - if err != nil { - return err - } - if !dc.Truncate && math.Floor(f64) != f64 { - return errors.New("UintDecodeValue can only truncate float64 to an integer type when truncation is enabled") - } - if f64 > float64(math.MaxInt64) { - return fmt.Errorf("%g overflows int64", f64) - } - i64 = int64(f64) - case TypeBoolean: - b, err := vr.ReadBoolean() - if err != nil { - return err - } - if b { - i64 = 1 - } - default: - return fmt.Errorf("cannot decode %v into an integer type", vr.Type()) - } - - if !val.CanSet() { - return ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, - Received: val, - } - } - - switch val.Kind() { - case reflect.Uint8: - if i64 < 0 || i64 > math.MaxUint8 { - return fmt.Errorf("%d overflows uint8", i64) - } - case reflect.Uint16: - if i64 < 0 || i64 > math.MaxUint16 { - return fmt.Errorf("%d overflows uint16", i64) - } - case reflect.Uint32: - if i64 < 0 || i64 > math.MaxUint32 { - return fmt.Errorf("%d overflows uint32", i64) - } - case reflect.Uint64: - if i64 < 0 { - return fmt.Errorf("%d overflows uint64", i64) - } - case reflect.Uint: - if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint - return fmt.Errorf("%d overflows uint", i64) - } - default: - return ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, - Received: val, - } - } - - val.SetUint(uint64(i64)) - return nil -} - -func (dvd DefaultValueDecoders) floatDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func floatDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { var f float64 var err error switch vrType := vr.Type(); vrType { @@ -505,11 +399,8 @@ func (dvd DefaultValueDecoders) floatDecodeType(dc DecodeContext, vr ValueReader } } -// FloatDecodeValue is the ValueDecoderFunc for float types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) FloatDecodeValue(ec DecodeContext, vr ValueReader, val reflect.Value) error { +// floatDecodeValue is the ValueDecoderFunc for float types. +func floatDecodeValue(ec DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() { return ValueDecoderError{ Name: "FloatDecodeValue", @@ -518,7 +409,7 @@ func (dvd DefaultValueDecoders) FloatDecodeValue(ec DecodeContext, vr ValueReade } } - elem, err := dvd.floatDecodeType(ec, vr, val.Type()) + elem, err := floatDecodeType(ec, vr, val.Type()) if err != nil { return err } @@ -527,31 +418,7 @@ func (dvd DefaultValueDecoders) FloatDecodeValue(ec DecodeContext, vr ValueReade return nil } -// StringDecodeValue is the ValueDecoderFunc for string types. -// -// Deprecated: StringDecodeValue is not registered by default. Use StringCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) StringDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { - var str string - var err error - switch vr.Type() { - // TODO(GODRIVER-577): Handle JavaScript and Symbol BSON types when allowed. - case TypeString: - str, err = vr.ReadString() - if err != nil { - return err - } - default: - return fmt.Errorf("cannot decode %v into a string type", vr.Type()) - } - if !val.CanSet() || val.Kind() != reflect.String { - return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} - } - - val.SetString(str) - return nil -} - -func (DefaultValueDecoders) javaScriptDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func javaScriptDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tJavaScript { return emptyValue, ValueDecoderError{ Name: "JavaScriptDecodeValue", @@ -579,16 +446,13 @@ func (DefaultValueDecoders) javaScriptDecodeType(_ DecodeContext, vr ValueReader return reflect.ValueOf(JavaScript(js)), nil } -// JavaScriptDecodeValue is the ValueDecoderFunc for the JavaScript type. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) JavaScriptDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +// javaScriptDecodeValue is the ValueDecoderFunc for the JavaScript type. +func javaScriptDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tJavaScript { return ValueDecoderError{Name: "JavaScriptDecodeValue", Types: []reflect.Type{tJavaScript}, Received: val} } - elem, err := dvd.javaScriptDecodeType(dctx, vr, tJavaScript) + elem, err := javaScriptDecodeType(dctx, vr, tJavaScript) if err != nil { return err } @@ -597,7 +461,7 @@ func (dvd DefaultValueDecoders) JavaScriptDecodeValue(dctx DecodeContext, vr Val return nil } -func (DefaultValueDecoders) symbolDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func symbolDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tSymbol { return emptyValue, ValueDecoderError{ Name: "SymbolDecodeValue", @@ -637,16 +501,13 @@ func (DefaultValueDecoders) symbolDecodeType(_ DecodeContext, vr ValueReader, t return reflect.ValueOf(Symbol(symbol)), nil } -// SymbolDecodeValue is the ValueDecoderFunc for the Symbol type. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) SymbolDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +// symbolDecodeValue is the ValueDecoderFunc for the Symbol type. +func symbolDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tSymbol { return ValueDecoderError{Name: "SymbolDecodeValue", Types: []reflect.Type{tSymbol}, Received: val} } - elem, err := dvd.symbolDecodeType(dctx, vr, tSymbol) + elem, err := symbolDecodeType(dctx, vr, tSymbol) if err != nil { return err } @@ -655,7 +516,7 @@ func (dvd DefaultValueDecoders) SymbolDecodeValue(dctx DecodeContext, vr ValueRe return nil } -func (DefaultValueDecoders) binaryDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func binaryDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tBinary { return emptyValue, ValueDecoderError{ Name: "BinaryDecodeValue", @@ -684,16 +545,13 @@ func (DefaultValueDecoders) binaryDecodeType(_ DecodeContext, vr ValueReader, t return reflect.ValueOf(Binary{Subtype: subtype, Data: data}), nil } -// BinaryDecodeValue is the ValueDecoderFunc for Binary. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) BinaryDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// binaryDecodeValue is the ValueDecoderFunc for Binary. +func binaryDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tBinary { return ValueDecoderError{Name: "BinaryDecodeValue", Types: []reflect.Type{tBinary}, Received: val} } - elem, err := dvd.binaryDecodeType(dc, vr, tBinary) + elem, err := binaryDecodeType(dc, vr, tBinary) if err != nil { return err } @@ -702,7 +560,7 @@ func (dvd DefaultValueDecoders) BinaryDecodeValue(dc DecodeContext, vr ValueRead return nil } -func (DefaultValueDecoders) undefinedDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func undefinedDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tUndefined { return emptyValue, ValueDecoderError{ Name: "UndefinedDecodeValue", @@ -727,16 +585,13 @@ func (DefaultValueDecoders) undefinedDecodeType(_ DecodeContext, vr ValueReader, return reflect.ValueOf(Undefined{}), nil } -// UndefinedDecodeValue is the ValueDecoderFunc for Undefined. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) UndefinedDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// undefinedDecodeValue is the ValueDecoderFunc for Undefined. +func undefinedDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tUndefined { return ValueDecoderError{Name: "UndefinedDecodeValue", Types: []reflect.Type{tUndefined}, Received: val} } - elem, err := dvd.undefinedDecodeType(dc, vr, tUndefined) + elem, err := undefinedDecodeType(dc, vr, tUndefined) if err != nil { return err } @@ -746,7 +601,7 @@ func (dvd DefaultValueDecoders) UndefinedDecodeValue(dc DecodeContext, vr ValueR } // Accept both 12-byte string and pretty-printed 24-byte hex string formats. -func (dvd DefaultValueDecoders) objectIDDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func objectIDDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tOID { return emptyValue, ValueDecoderError{ Name: "ObjectIDDecodeValue", @@ -791,16 +646,13 @@ func (dvd DefaultValueDecoders) objectIDDecodeType(_ DecodeContext, vr ValueRead return reflect.ValueOf(oid), nil } -// ObjectIDDecodeValue is the ValueDecoderFunc for ObjectID. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) ObjectIDDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// objectIDDecodeValue is the ValueDecoderFunc for ObjectID. +func objectIDDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tOID { return ValueDecoderError{Name: "ObjectIDDecodeValue", Types: []reflect.Type{tOID}, Received: val} } - elem, err := dvd.objectIDDecodeType(dc, vr, tOID) + elem, err := objectIDDecodeType(dc, vr, tOID) if err != nil { return err } @@ -809,7 +661,7 @@ func (dvd DefaultValueDecoders) ObjectIDDecodeValue(dc DecodeContext, vr ValueRe return nil } -func (DefaultValueDecoders) dateTimeDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func dateTimeDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tDateTime { return emptyValue, ValueDecoderError{ Name: "DateTimeDecodeValue", @@ -837,16 +689,13 @@ func (DefaultValueDecoders) dateTimeDecodeType(_ DecodeContext, vr ValueReader, return reflect.ValueOf(DateTime(dt)), nil } -// DateTimeDecodeValue is the ValueDecoderFunc for DateTime. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) DateTimeDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// dateTimeDecodeValue is the ValueDecoderFunc for DateTime. +func dateTimeDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tDateTime { return ValueDecoderError{Name: "DateTimeDecodeValue", Types: []reflect.Type{tDateTime}, Received: val} } - elem, err := dvd.dateTimeDecodeType(dc, vr, tDateTime) + elem, err := dateTimeDecodeType(dc, vr, tDateTime) if err != nil { return err } @@ -855,7 +704,7 @@ func (dvd DefaultValueDecoders) DateTimeDecodeValue(dc DecodeContext, vr ValueRe return nil } -func (DefaultValueDecoders) nullDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func nullDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tNull { return emptyValue, ValueDecoderError{ Name: "NullDecodeValue", @@ -880,16 +729,13 @@ func (DefaultValueDecoders) nullDecodeType(_ DecodeContext, vr ValueReader, t re return reflect.ValueOf(Null{}), nil } -// NullDecodeValue is the ValueDecoderFunc for Null. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) NullDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// nullDecodeValue is the ValueDecoderFunc for Null. +func nullDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tNull { return ValueDecoderError{Name: "NullDecodeValue", Types: []reflect.Type{tNull}, Received: val} } - elem, err := dvd.nullDecodeType(dc, vr, tNull) + elem, err := nullDecodeType(dc, vr, tNull) if err != nil { return err } @@ -898,7 +744,7 @@ func (dvd DefaultValueDecoders) NullDecodeValue(dc DecodeContext, vr ValueReader return nil } -func (DefaultValueDecoders) regexDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func regexDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tRegex { return emptyValue, ValueDecoderError{ Name: "RegexDecodeValue", @@ -926,16 +772,13 @@ func (DefaultValueDecoders) regexDecodeType(_ DecodeContext, vr ValueReader, t r return reflect.ValueOf(Regex{Pattern: pattern, Options: options}), nil } -// RegexDecodeValue is the ValueDecoderFunc for Regex. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) RegexDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// regexDecodeValue is the ValueDecoderFunc for Regex. +func regexDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tRegex { return ValueDecoderError{Name: "RegexDecodeValue", Types: []reflect.Type{tRegex}, Received: val} } - elem, err := dvd.regexDecodeType(dc, vr, tRegex) + elem, err := regexDecodeType(dc, vr, tRegex) if err != nil { return err } @@ -944,7 +787,7 @@ func (dvd DefaultValueDecoders) RegexDecodeValue(dc DecodeContext, vr ValueReade return nil } -func (DefaultValueDecoders) dBPointerDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func dbPointerDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tDBPointer { return emptyValue, ValueDecoderError{ Name: "DBPointerDecodeValue", @@ -973,16 +816,13 @@ func (DefaultValueDecoders) dBPointerDecodeType(_ DecodeContext, vr ValueReader, return reflect.ValueOf(DBPointer{DB: ns, Pointer: pointer}), nil } -// DBPointerDecodeValue is the ValueDecoderFunc for DBPointer. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) DBPointerDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// dbPointerDecodeValue is the ValueDecoderFunc for DBPointer. +func dbPointerDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tDBPointer { return ValueDecoderError{Name: "DBPointerDecodeValue", Types: []reflect.Type{tDBPointer}, Received: val} } - elem, err := dvd.dBPointerDecodeType(dc, vr, tDBPointer) + elem, err := dbPointerDecodeType(dc, vr, tDBPointer) if err != nil { return err } @@ -991,7 +831,7 @@ func (dvd DefaultValueDecoders) DBPointerDecodeValue(dc DecodeContext, vr ValueR return nil } -func (DefaultValueDecoders) timestampDecodeType(_ DecodeContext, vr ValueReader, reflectType reflect.Type) (reflect.Value, error) { +func timestampDecodeType(_ DecodeContext, vr ValueReader, reflectType reflect.Type) (reflect.Value, error) { if reflectType != tTimestamp { return emptyValue, ValueDecoderError{ Name: "TimestampDecodeValue", @@ -1019,16 +859,13 @@ func (DefaultValueDecoders) timestampDecodeType(_ DecodeContext, vr ValueReader, return reflect.ValueOf(Timestamp{T: t, I: incr}), nil } -// TimestampDecodeValue is the ValueDecoderFunc for Timestamp. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) TimestampDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// timestampDecodeValue is the ValueDecoderFunc for Timestamp. +func timestampDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tTimestamp { return ValueDecoderError{Name: "TimestampDecodeValue", Types: []reflect.Type{tTimestamp}, Received: val} } - elem, err := dvd.timestampDecodeType(dc, vr, tTimestamp) + elem, err := timestampDecodeType(dc, vr, tTimestamp) if err != nil { return err } @@ -1037,7 +874,7 @@ func (dvd DefaultValueDecoders) TimestampDecodeValue(dc DecodeContext, vr ValueR return nil } -func (DefaultValueDecoders) minKeyDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func minKeyDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tMinKey { return emptyValue, ValueDecoderError{ Name: "MinKeyDecodeValue", @@ -1064,16 +901,13 @@ func (DefaultValueDecoders) minKeyDecodeType(_ DecodeContext, vr ValueReader, t return reflect.ValueOf(MinKey{}), nil } -// MinKeyDecodeValue is the ValueDecoderFunc for MinKey. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) MinKeyDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// minKeyDecodeValue is the ValueDecoderFunc for MinKey. +func minKeyDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tMinKey { return ValueDecoderError{Name: "MinKeyDecodeValue", Types: []reflect.Type{tMinKey}, Received: val} } - elem, err := dvd.minKeyDecodeType(dc, vr, tMinKey) + elem, err := minKeyDecodeType(dc, vr, tMinKey) if err != nil { return err } @@ -1082,7 +916,7 @@ func (dvd DefaultValueDecoders) MinKeyDecodeValue(dc DecodeContext, vr ValueRead return nil } -func (DefaultValueDecoders) maxKeyDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func maxKeyDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tMaxKey { return emptyValue, ValueDecoderError{ Name: "MaxKeyDecodeValue", @@ -1109,16 +943,13 @@ func (DefaultValueDecoders) maxKeyDecodeType(_ DecodeContext, vr ValueReader, t return reflect.ValueOf(MaxKey{}), nil } -// MaxKeyDecodeValue is the ValueDecoderFunc for MaxKey. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) MaxKeyDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// maxKeyDecodeValue is the ValueDecoderFunc for MaxKey. +func maxKeyDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tMaxKey { return ValueDecoderError{Name: "MaxKeyDecodeValue", Types: []reflect.Type{tMaxKey}, Received: val} } - elem, err := dvd.maxKeyDecodeType(dc, vr, tMaxKey) + elem, err := maxKeyDecodeType(dc, vr, tMaxKey) if err != nil { return err } @@ -1127,7 +958,7 @@ func (dvd DefaultValueDecoders) MaxKeyDecodeValue(dc DecodeContext, vr ValueRead return nil } -func (dvd DefaultValueDecoders) decimal128DecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func decimal128DecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tDecimal { return emptyValue, ValueDecoderError{ Name: "Decimal128DecodeValue", @@ -1155,16 +986,13 @@ func (dvd DefaultValueDecoders) decimal128DecodeType(_ DecodeContext, vr ValueRe return reflect.ValueOf(d128), nil } -// Decimal128DecodeValue is the ValueDecoderFunc for Decimal128. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) Decimal128DecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +// decimal128DecodeValue is the ValueDecoderFunc for Decimal128. +func decimal128DecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tDecimal { return ValueDecoderError{Name: "Decimal128DecodeValue", Types: []reflect.Type{tDecimal}, Received: val} } - elem, err := dvd.decimal128DecodeType(dctx, vr, tDecimal) + elem, err := decimal128DecodeType(dctx, vr, tDecimal) if err != nil { return err } @@ -1173,7 +1001,7 @@ func (dvd DefaultValueDecoders) Decimal128DecodeValue(dctx DecodeContext, vr Val return nil } -func (dvd DefaultValueDecoders) jsonNumberDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func jsonNumberDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tJSONNumber { return emptyValue, ValueDecoderError{ Name: "JSONNumberDecodeValue", @@ -1217,16 +1045,13 @@ func (dvd DefaultValueDecoders) jsonNumberDecodeType(_ DecodeContext, vr ValueRe return reflect.ValueOf(jsonNum), nil } -// JSONNumberDecodeValue is the ValueDecoderFunc for json.Number. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) JSONNumberDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// jsonNumberDecodeValue is the ValueDecoderFunc for json.Number. +func jsonNumberDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tJSONNumber { return ValueDecoderError{Name: "JSONNumberDecodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} } - elem, err := dvd.jsonNumberDecodeType(dc, vr, tJSONNumber) + elem, err := jsonNumberDecodeType(dc, vr, tJSONNumber) if err != nil { return err } @@ -1235,7 +1060,7 @@ func (dvd DefaultValueDecoders) JSONNumberDecodeValue(dc DecodeContext, vr Value return nil } -func (dvd DefaultValueDecoders) urlDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func urlDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tURL { return emptyValue, ValueDecoderError{ Name: "URLDecodeValue", @@ -1269,16 +1094,13 @@ func (dvd DefaultValueDecoders) urlDecodeType(_ DecodeContext, vr ValueReader, t return reflect.ValueOf(urlPtr).Elem(), nil } -// URLDecodeValue is the ValueDecoderFunc for url.URL. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) URLDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// urlDecodeValue is the ValueDecoderFunc for url.URL. +func urlDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tURL { return ValueDecoderError{Name: "URLDecodeValue", Types: []reflect.Type{tURL}, Received: val} } - elem, err := dvd.urlDecodeType(dc, vr, tURL) + elem, err := urlDecodeType(dc, vr, tURL) if err != nil { return err } @@ -1287,119 +1109,8 @@ func (dvd DefaultValueDecoders) URLDecodeValue(dc DecodeContext, vr ValueReader, return nil } -// TimeDecodeValue is the ValueDecoderFunc for time.Time. -// -// Deprecated: TimeDecodeValue is not registered by default. Use TimeCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) TimeDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { - if vr.Type() != TypeDateTime { - return fmt.Errorf("cannot decode %v into a time.Time", vr.Type()) - } - - dt, err := vr.ReadDateTime() - if err != nil { - return err - } - - if !val.CanSet() || val.Type() != tTime { - return ValueDecoderError{Name: "TimeDecodeValue", Types: []reflect.Type{tTime}, Received: val} - } - - val.Set(reflect.ValueOf(time.Unix(dt/1000, dt%1000*1000000).UTC())) - return nil -} - -// ByteSliceDecodeValue is the ValueDecoderFunc for []byte. -// -// Deprecated: ByteSliceDecodeValue is not registered by default. Use ByteSliceCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) ByteSliceDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { - if vr.Type() != TypeBinary && vr.Type() != TypeNull { - return fmt.Errorf("cannot decode %v into a []byte", vr.Type()) - } - - if !val.CanSet() || val.Type() != tByteSlice { - return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []reflect.Type{tByteSlice}, Received: val} - } - - if vr.Type() == TypeNull { - val.Set(reflect.Zero(val.Type())) - return vr.ReadNull() - } - - data, subtype, err := vr.ReadBinary() - if err != nil { - return err - } - if subtype != 0x00 { - return fmt.Errorf("ByteSliceDecodeValue can only be used to decode subtype 0x00 for %s, got %v", TypeBinary, subtype) - } - - val.Set(reflect.ValueOf(data)) - return nil -} - -// MapDecodeValue is the ValueDecoderFunc for map[string]* types. -// -// Deprecated: MapDecodeValue is not registered by default. Use MapCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) MapDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Kind() != reflect.Map || val.Type().Key().Kind() != reflect.String { - return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} - } - - switch vr.Type() { - case Type(0), TypeEmbeddedDocument: - case TypeNull: - val.Set(reflect.Zero(val.Type())) - return vr.ReadNull() - default: - return fmt.Errorf("cannot decode %v into a %s", vr.Type(), val.Type()) - } - - dr, err := vr.ReadDocument() - if err != nil { - return err - } - - if val.IsNil() { - val.Set(reflect.MakeMap(val.Type())) - } - - eType := val.Type().Elem() - decoder, err := dc.LookupDecoder(eType) - if err != nil { - return err - } - - if eType == tEmpty { - dc.Ancestor = val.Type() - } - - keyType := val.Type().Key() - for { - key, vr, err := dr.ReadElement() - if errors.Is(err, ErrEOD) { - break - } - if err != nil { - return err - } - - elem := reflect.New(eType).Elem() - - err = decoder.DecodeValue(dc, vr, elem) - if err != nil { - return err - } - - val.SetMapIndex(reflect.ValueOf(key).Convert(keyType), elem) - } - return nil -} - -// ArrayDecodeValue is the ValueDecoderFunc for array types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) ArrayDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// arrayDecodeValue is the ValueDecoderFunc for array types. +func arrayDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Array { return ValueDecoderError{Name: "ArrayDecodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val} } @@ -1443,9 +1154,9 @@ func (dvd DefaultValueDecoders) ArrayDecodeValue(dc DecodeContext, vr ValueReade var elemsFunc func(DecodeContext, ValueReader, reflect.Value) ([]reflect.Value, error) switch val.Type().Elem() { case tE: - elemsFunc = dvd.decodeD + elemsFunc = decodeD default: - elemsFunc = dvd.decodeDefault + elemsFunc = decodeDefault } elems, err := elemsFunc(dc, vr, val) @@ -1464,56 +1175,8 @@ func (dvd DefaultValueDecoders) ArrayDecodeValue(dc DecodeContext, vr ValueReade return nil } -// SliceDecodeValue is the ValueDecoderFunc for slice types. -// -// Deprecated: SliceDecodeValue is not registered by default. Use SliceCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) SliceDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Kind() != reflect.Slice { - return ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} - } - - switch vr.Type() { - case TypeArray: - case TypeNull: - val.Set(reflect.Zero(val.Type())) - return vr.ReadNull() - case Type(0), TypeEmbeddedDocument: - if val.Type().Elem() != tE { - return fmt.Errorf("cannot decode document into %s", val.Type()) - } - default: - return fmt.Errorf("cannot decode %v into a slice", vr.Type()) - } - - var elemsFunc func(DecodeContext, ValueReader, reflect.Value) ([]reflect.Value, error) - switch val.Type().Elem() { - case tE: - dc.Ancestor = val.Type() - elemsFunc = dvd.decodeD - default: - elemsFunc = dvd.decodeDefault - } - - elems, err := elemsFunc(dc, vr, val) - if err != nil { - return err - } - - if val.IsNil() { - val.Set(reflect.MakeSlice(val.Type(), 0, len(elems))) - } - - val.SetLen(0) - val.Set(reflect.Append(val, elems...)) - - return nil -} - -// ValueUnmarshalerDecodeValue is the ValueDecoderFunc for ValueUnmarshaler implementations. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +// valueUnmarshalerDecodeValue is the ValueDecoderFunc for ValueUnmarshaler implementations. +func valueUnmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { if !val.IsValid() || (!val.Type().Implements(tValueUnmarshaler) && !reflect.PtrTo(val.Type()).Implements(tValueUnmarshaler)) { return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} } @@ -1545,11 +1208,8 @@ func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr return m.UnmarshalBSONValue(t, src) } -// UnmarshalerDecodeValue is the ValueDecoderFunc for Unmarshaler implementations. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +// unmarshalerDecodeValue is the ValueDecoderFunc for Unmarshaler implementations. +func unmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { if !val.IsValid() || (!val.Type().Implements(tUnmarshaler) && !reflect.PtrTo(val.Type()).Implements(tUnmarshaler)) { return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val} } @@ -1593,51 +1253,8 @@ func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(_ DecodeContext, vr Value return m.UnmarshalBSON(src) } -// EmptyInterfaceDecodeValue is the ValueDecoderFunc for interface{}. -// -// Deprecated: EmptyInterfaceDecodeValue is not registered by default. Use EmptyInterfaceCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) EmptyInterfaceDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tEmpty { - return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val} - } - - rtype, err := dc.LookupTypeMapEntry(vr.Type()) - if err != nil { - switch vr.Type() { - case TypeEmbeddedDocument: - if dc.Ancestor != nil { - rtype = dc.Ancestor - break - } - rtype = tD - case TypeNull: - val.Set(reflect.Zero(val.Type())) - return vr.ReadNull() - default: - return err - } - } - - decoder, err := dc.LookupDecoder(rtype) - if err != nil { - return err - } - - elem := reflect.New(rtype).Elem() - err = decoder.DecodeValue(dc, vr, elem) - if err != nil { - return err - } - - val.Set(elem) - return nil -} - -// CoreDocumentDecodeValue is the ValueDecoderFunc for bsoncore.Document. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (DefaultValueDecoders) CoreDocumentDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +// coreDocumentDecodeValue is the ValueDecoderFunc for bsoncore.Document. +func coreDocumentDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tCoreDocument { return ValueDecoderError{Name: "CoreDocumentDecodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} } @@ -1653,7 +1270,7 @@ func (DefaultValueDecoders) CoreDocumentDecodeValue(_ DecodeContext, vr ValueRea return err } -func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr ValueReader, val reflect.Value) ([]reflect.Value, error) { +func decodeDefault(dc DecodeContext, vr ValueReader, val reflect.Value) ([]reflect.Value, error) { elems := make([]reflect.Value, 0) ar, err := vr.ReadArray() @@ -1690,31 +1307,7 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr ValueReader, return elems, nil } -func (dvd DefaultValueDecoders) readCodeWithScope(dc DecodeContext, vr ValueReader) (CodeWithScope, error) { - var cws CodeWithScope - - code, dr, err := vr.ReadCodeWithScope() - if err != nil { - return cws, err - } - - scope := reflect.New(tD).Elem() - elems, err := dvd.decodeElemsFromDocumentReader(dc, dr) - if err != nil { - return cws, err - } - - scope.Set(reflect.MakeSlice(tD, 0, len(elems))) - scope.Set(reflect.Append(scope, elems...)) - - cws = CodeWithScope{ - Code: JavaScript(code), - Scope: scope.Interface().(D), - } - return cws, nil -} - -func (dvd DefaultValueDecoders) codeWithScopeDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func codeWithScopeDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tCodeWithScope { return emptyValue, ValueDecoderError{ Name: "CodeWithScopeDecodeValue", @@ -1727,7 +1320,24 @@ func (dvd DefaultValueDecoders) codeWithScopeDecodeType(dc DecodeContext, vr Val var err error switch vrType := vr.Type(); vrType { case TypeCodeWithScope: - cws, err = dvd.readCodeWithScope(dc, vr) + code, dr, err := vr.ReadCodeWithScope() + if err != nil { + return emptyValue, err + } + + scope := reflect.New(tD).Elem() + elems, err := decodeElemsFromDocumentReader(dc, dr) + if err != nil { + return emptyValue, err + } + + scope.Set(reflect.MakeSlice(tD, 0, len(elems))) + scope.Set(reflect.Append(scope, elems...)) + + cws = CodeWithScope{ + Code: JavaScript(code), + Scope: scope.Interface().(D), + } case TypeNull: err = vr.ReadNull() case TypeUndefined: @@ -1742,16 +1352,13 @@ func (dvd DefaultValueDecoders) codeWithScopeDecodeType(dc DecodeContext, vr Val return reflect.ValueOf(cws), nil } -// CodeWithScopeDecodeValue is the ValueDecoderFunc for CodeWithScope. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value decoders registered. -func (dvd DefaultValueDecoders) CodeWithScopeDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +// codeWithScopeDecodeValue is the ValueDecoderFunc for CodeWithScope. +func codeWithScopeDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tCodeWithScope { return ValueDecoderError{Name: "CodeWithScopeDecodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} } - elem, err := dvd.codeWithScopeDecodeType(dc, vr, tCodeWithScope) + elem, err := codeWithScopeDecodeType(dc, vr, tCodeWithScope) if err != nil { return err } @@ -1760,7 +1367,7 @@ func (dvd DefaultValueDecoders) CodeWithScopeDecodeValue(dc DecodeContext, vr Va return nil } -func (dvd DefaultValueDecoders) decodeD(dc DecodeContext, vr ValueReader, _ reflect.Value) ([]reflect.Value, error) { +func decodeD(dc DecodeContext, vr ValueReader, _ reflect.Value) ([]reflect.Value, error) { switch vr.Type() { case Type(0), TypeEmbeddedDocument: default: @@ -1772,10 +1379,10 @@ func (dvd DefaultValueDecoders) decodeD(dc DecodeContext, vr ValueReader, _ refl return nil, err } - return dvd.decodeElemsFromDocumentReader(dc, dr) + return decodeElemsFromDocumentReader(dc, dr) } -func (DefaultValueDecoders) decodeElemsFromDocumentReader(dc DecodeContext, dr DocumentReader) ([]reflect.Value, error) { +func decodeElemsFromDocumentReader(dc DecodeContext, dr DocumentReader) ([]reflect.Value, error) { decoder, err := dc.LookupDecoder(tEmpty) if err != nil { return nil, err diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 699a958605..56fdc464c2 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -27,7 +27,6 @@ var ( ) func TestDefaultValueDecoders(t *testing.T) { - var dvd DefaultValueDecoders var wrong = func(string, string) string { return "wrong" } type mybool bool @@ -71,7 +70,7 @@ func TestDefaultValueDecoders(t *testing.T) { }{ { "BooleanDecodeValue", - ValueDecoderFunc(dvd.BooleanDecodeValue), + ValueDecoderFunc(booleanDecodeValue), []subtest{ { "wrong type", @@ -140,7 +139,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "IntDecodeValue", - ValueDecoderFunc(dvd.IntDecodeValue), + ValueDecoderFunc(intDecodeValue), []subtest{ { "wrong type", @@ -608,7 +607,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "FloatDecodeValue", - ValueDecoderFunc(dvd.FloatDecodeValue), + ValueDecoderFunc(floatDecodeValue), []subtest{ { "wrong type", @@ -820,7 +819,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", map[string]string{}, - &DecodeContext{Registry: newTestRegistryBuilder().Build()}, + &DecodeContext{Registry: newTestRegistry()}, &valueReaderWriter{}, readDocument, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -869,7 +868,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "ArrayDecodeValue", - ValueDecoderFunc(dvd.ArrayDecodeValue), + ValueDecoderFunc(arrayDecodeValue), []subtest{ { "wrong kind", @@ -906,7 +905,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", [1]string{}, - &DecodeContext{Registry: newTestRegistryBuilder().Build()}, + &DecodeContext{Registry: newTestRegistry()}, &valueReaderWriter{BSONType: TypeArray}, readArray, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -1000,7 +999,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", []string{}, - &DecodeContext{Registry: newTestRegistryBuilder().Build()}, + &DecodeContext{Registry: newTestRegistry()}, &valueReaderWriter{BSONType: TypeArray}, readArray, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -1057,7 +1056,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "ObjectIDDecodeValue", - ValueDecoderFunc(dvd.ObjectIDDecodeValue), + ValueDecoderFunc(objectIDDecodeValue), []subtest{ { "wrong type", @@ -1144,7 +1143,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "Decimal128DecodeValue", - ValueDecoderFunc(dvd.Decimal128DecodeValue), + ValueDecoderFunc(decimal128DecodeValue), []subtest{ { "wrong type", @@ -1206,7 +1205,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "JSONNumberDecodeValue", - ValueDecoderFunc(dvd.JSONNumberDecodeValue), + ValueDecoderFunc(jsonNumberDecodeValue), []subtest{ { "wrong type", @@ -1300,7 +1299,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "URLDecodeValue", - ValueDecoderFunc(dvd.URLDecodeValue), + ValueDecoderFunc(urlDecodeValue), []subtest{ { "wrong type", @@ -1472,7 +1471,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "ValueUnmarshalerDecodeValue", - ValueDecoderFunc(dvd.ValueUnmarshalerDecodeValue), + ValueDecoderFunc(valueUnmarshalerDecodeValue), []subtest{ { "wrong type", @@ -1506,7 +1505,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "UnmarshalerDecodeValue", - ValueDecoderFunc(dvd.UnmarshalerDecodeValue), + ValueDecoderFunc(unmarshalerDecodeValue), []subtest{ { "wrong type", @@ -1585,7 +1584,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "BinaryDecodeValue", - ValueDecoderFunc(dvd.BinaryDecodeValue), + ValueDecoderFunc(binaryDecodeValue), []subtest{ { "wrong type", @@ -1645,7 +1644,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "UndefinedDecodeValue", - ValueDecoderFunc(dvd.UndefinedDecodeValue), + ValueDecoderFunc(undefinedDecodeValue), []subtest{ { "wrong type", @@ -1691,7 +1690,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "DateTimeDecodeValue", - ValueDecoderFunc(dvd.DateTimeDecodeValue), + ValueDecoderFunc(dateTimeDecodeValue), []subtest{ { "wrong type", @@ -1745,7 +1744,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "NullDecodeValue", - ValueDecoderFunc(dvd.NullDecodeValue), + ValueDecoderFunc(nullDecodeValue), []subtest{ { "wrong type", @@ -1783,7 +1782,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "RegexDecodeValue", - ValueDecoderFunc(dvd.RegexDecodeValue), + ValueDecoderFunc(regexDecodeValue), []subtest{ { "wrong type", @@ -1843,7 +1842,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "DBPointerDecodeValue", - ValueDecoderFunc(dvd.DBPointerDecodeValue), + ValueDecoderFunc(dbPointerDecodeValue), []subtest{ { "wrong type", @@ -1908,7 +1907,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "TimestampDecodeValue", - ValueDecoderFunc(dvd.TimestampDecodeValue), + ValueDecoderFunc(timestampDecodeValue), []subtest{ { "wrong type", @@ -1968,7 +1967,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "MinKeyDecodeValue", - ValueDecoderFunc(dvd.MinKeyDecodeValue), + ValueDecoderFunc(minKeyDecodeValue), []subtest{ { "wrong type", @@ -2022,7 +2021,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "MaxKeyDecodeValue", - ValueDecoderFunc(dvd.MaxKeyDecodeValue), + ValueDecoderFunc(maxKeyDecodeValue), []subtest{ { "wrong type", @@ -2076,7 +2075,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "JavaScriptDecodeValue", - ValueDecoderFunc(dvd.JavaScriptDecodeValue), + ValueDecoderFunc(javaScriptDecodeValue), []subtest{ { "wrong type", @@ -2130,7 +2129,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "SymbolDecodeValue", - ValueDecoderFunc(dvd.SymbolDecodeValue), + ValueDecoderFunc(symbolDecodeValue), []subtest{ { "wrong type", @@ -2184,7 +2183,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "CoreDocumentDecodeValue", - ValueDecoderFunc(dvd.CoreDocumentDecodeValue), + ValueDecoderFunc(coreDocumentDecodeValue), []subtest{ { "wrong type", @@ -2252,7 +2251,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "CodeWithScopeDecodeValue", - ValueDecoderFunc(dvd.CodeWithScopeDecodeValue), + ValueDecoderFunc(codeWithScopeDecodeValue), []subtest{ { "wrong type", @@ -2439,7 +2438,7 @@ func TestDefaultValueDecoders(t *testing.T) { Scope: D{{"bar", nil}}, } val := reflect.New(tCodeWithScope).Elem() - err = dvd.CodeWithScopeDecodeValue(dc, vr, val) + err = codeWithScopeDecodeValue(dc, vr, val) noerr(t, err) got := val.Interface().(CodeWithScope) @@ -2454,7 +2453,7 @@ func TestDefaultValueDecoders(t *testing.T) { want := errors.New("ubsonv error") valUnmarshaler := &testValueUnmarshaler{err: want} - got := dvd.ValueUnmarshalerDecodeValue(dc, llvrw, reflect.ValueOf(valUnmarshaler)) + got := valueUnmarshalerDecodeValue(dc, llvrw, reflect.ValueOf(valUnmarshaler)) if !assert.CompareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } @@ -2466,16 +2465,7 @@ func TestDefaultValueDecoders(t *testing.T) { val := reflect.ValueOf(testValueUnmarshaler{}) want := ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} - got := dvd.ValueUnmarshalerDecodeValue(dc, llvrw, val) - if !assert.CompareErrors(got, want) { - t.Errorf("Errors do not match. got %v; want %v", got, want) - } - }) - - t.Run("SliceCodec/DecodeValue/can't set slice", func(t *testing.T) { - var val []string - want := ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: reflect.ValueOf(val)} - got := dvd.SliceDecodeValue(DecodeContext{}, nil, reflect.ValueOf(val)) + got := valueUnmarshalerDecodeValue(dc, llvrw, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } @@ -2499,7 +2489,7 @@ func TestDefaultValueDecoders(t *testing.T) { want := fmt.Errorf("more elements returned in array than can fit inside %T, got 2 elements", val) dc := DecodeContext{Registry: buildDefaultRegistry()} - got := dvd.ArrayDecodeValue(dc, vr, reflect.ValueOf(val)) + got := arrayDecodeValue(dc, vr, reflect.ValueOf(val)) if !assert.CompareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } @@ -3319,7 +3309,7 @@ func TestDefaultValueDecoders(t *testing.T) { t.Skip() } val := reflect.New(tEmpty).Elem() - dc := DecodeContext{Registry: newTestRegistryBuilder().Build()} + dc := DecodeContext{Registry: newTestRegistry()} want := ErrNoTypeMapEntry{Type: tc.bsontype} got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, val) if !assert.CompareErrors(got, want) { @@ -3332,10 +3322,10 @@ func TestDefaultValueDecoders(t *testing.T) { t.Skip() } val := reflect.New(tEmpty).Elem() + reg := newTestRegistry() + reg.RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)) dc := DecodeContext{ - Registry: newTestRegistryBuilder(). - RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). - Build(), + Registry: reg, } want := ErrNoDecoder{Type: reflect.TypeOf(tc.val)} got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, val) @@ -3350,11 +3340,11 @@ func TestDefaultValueDecoders(t *testing.T) { } want := errors.New("DecodeValue failure error") llc := &llCodec{t: t, err: want} + reg := newTestRegistry() + reg.RegisterTypeDecoder(reflect.TypeOf(tc.val), llc) + reg.RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)) dc := DecodeContext{ - Registry: newTestRegistryBuilder(). - RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). - RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). - Build(), + Registry: reg, } got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, reflect.New(tEmpty).Elem()) if !assert.CompareErrors(got, want) { @@ -3365,11 +3355,11 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run("Success", func(t *testing.T) { want := tc.val llc := &llCodec{t: t, decodeval: tc.val} + reg := newTestRegistry() + reg.RegisterTypeDecoder(reflect.TypeOf(tc.val), llc) + reg.RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)) dc := DecodeContext{ - Registry: newTestRegistryBuilder(). - RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). - RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). - Build(), + Registry: reg, } got := reflect.New(tEmpty).Elem() err := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, got) @@ -3404,7 +3394,7 @@ func TestDefaultValueDecoders(t *testing.T) { llvr := &valueReaderWriter{BSONType: TypeDouble} want := ErrNoTypeMapEntry{Type: TypeDouble} val := reflect.New(tEmpty).Elem() - got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: newTestRegistryBuilder().Build()}, llvr, val) + got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: newTestRegistry()}, llvr, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3425,15 +3415,15 @@ func TestDefaultValueDecoders(t *testing.T) { // registering a custom type map entry for both Type(0) anad TypeEmbeddedDocument should cause // both top-level and embedded documents to decode to registered type when unmarshalling to interface{} - topLevelRb := newTestRegistryBuilder() - defaultValueEncoders.RegisterDefaultEncoders(topLevelRb) - defaultValueDecoders.RegisterDefaultDecoders(topLevelRb) - topLevelRb.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) + topLevelReg := newTestRegistry() + registerDefaultEncoders(topLevelReg) + registerDefaultDecoders(topLevelReg) + topLevelReg.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) - embeddedRb := newTestRegistryBuilder() - defaultValueEncoders.RegisterDefaultEncoders(embeddedRb) - defaultValueDecoders.RegisterDefaultDecoders(embeddedRb) - embeddedRb.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) + embeddedReg := newTestRegistry() + registerDefaultEncoders(embeddedReg) + registerDefaultDecoders(embeddedReg) + embeddedReg.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) // create doc {"nested": {"foo": 1}} innerDoc := bsoncore.BuildDocument( @@ -3454,8 +3444,8 @@ func TestDefaultValueDecoders(t *testing.T) { name string registry *Registry }{ - {"top level", topLevelRb.Build()}, - {"embedded", embeddedRb.Build()}, + {"top level", topLevelReg}, + {"embedded", embeddedReg}, } for _, tc := range testCases { var got interface{} @@ -3473,11 +3463,10 @@ func TestDefaultValueDecoders(t *testing.T) { // If a type map entry is registered for TypeEmbeddedDocument, the decoder should use ancestor // information if available instead of the registered entry. - rb := newTestRegistryBuilder() - defaultValueEncoders.RegisterDefaultEncoders(rb) - defaultValueDecoders.RegisterDefaultDecoders(rb) - rb.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(M{})) - reg := rb.Build() + reg := newTestRegistry() + registerDefaultEncoders(reg) + registerDefaultDecoders(reg) + reg.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(M{})) // build document {"nested": {"foo": 10}} inner := bsoncore.BuildDocument( @@ -3510,8 +3499,8 @@ func TestDefaultValueDecoders(t *testing.T) { emptyInterfaceErrorDecode := func(DecodeContext, ValueReader, reflect.Value) error { return decodeValueError } - emptyInterfaceErrorRegistry := newTestRegistryBuilder(). - RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)).Build() + emptyInterfaceErrorRegistry := newTestRegistry() + emptyInterfaceErrorRegistry.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)) // Set up a document {foo: 10} and an error that would happen if the value were decoded into interface{} // using the registry defined above. @@ -3563,11 +3552,9 @@ func TestDefaultValueDecoders(t *testing.T) { outerDoc := buildDocument(bsoncore.AppendDocumentElement(nil, "first", inner1Doc)) // Use a registry that has all default decoders with the custom interface{} decoder that always errors. - nestedRegistryBuilder := newTestRegistryBuilder() - defaultValueDecoders.RegisterDefaultDecoders(nestedRegistryBuilder) - nestedRegistry := nestedRegistryBuilder. - RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)). - Build() + nestedRegistry := newTestRegistry() + registerDefaultDecoders(nestedRegistry) + nestedRegistry.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)) nestedErr := &DecodeError{ keys: []string{"fourth", "1", "third", "randomKey", "second", "first"}, wrapped: decodeValueError, @@ -3610,7 +3597,7 @@ func TestDefaultValueDecoders(t *testing.T) { [1]E{}, NewValueReader(docBytes), emptyInterfaceErrorRegistry, - ValueDecoderFunc(dvd.ArrayDecodeValue), + ValueDecoderFunc(arrayDecodeValue), docEmptyInterfaceErr, }, { @@ -3621,7 +3608,7 @@ func TestDefaultValueDecoders(t *testing.T) { [1]string{}, &valueReaderWriter{BSONType: TypeArray}, nil, - ValueDecoderFunc(dvd.ArrayDecodeValue), + ValueDecoderFunc(arrayDecodeValue), &DecodeError{ keys: []string{"0"}, wrapped: errors.New("cannot decode array into a string type"), @@ -3652,7 +3639,7 @@ func TestDefaultValueDecoders(t *testing.T) { "struct - no decoder found", stringStruct{}, NewValueReader(docBytes), - newTestRegistryBuilder().Build(), + newTestRegistry(), defaultTestStructCodec, stringStructErr, }, @@ -3717,14 +3704,14 @@ func TestDefaultValueDecoders(t *testing.T) { bsoncore.BuildArrayElement(nil, "boolArray", trueValue), ) - rb := newTestRegistryBuilder() - defaultValueDecoders.RegisterDefaultDecoders(rb) - reg := rb.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))).Build() + reg := newTestRegistry() + registerDefaultDecoders(reg) + reg.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))) dc := DecodeContext{Registry: reg} vr := NewValueReader(docBytes) val := reflect.New(tD).Elem() - err := defaultValueDecoders.DDecodeValue(dc, vr, val) + err := dDecodeValue(dc, vr, val) assert.Nil(t, err, "DDecodeValue error: %v", err) want := D{ @@ -3786,8 +3773,8 @@ func buildDocument(elems []byte) []byte { } func buildDefaultRegistry() *Registry { - rb := newTestRegistryBuilder() - defaultValueEncoders.RegisterDefaultEncoders(rb) - defaultValueDecoders.RegisterDefaultDecoders(rb) - return rb.Build() + reg := newTestRegistry() + registerDefaultEncoders(reg) + registerDefaultDecoders(reg) + return reg } diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index f2773c36e5..6b28f1594b 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -9,18 +9,14 @@ package bson import ( "encoding/json" "errors" - "fmt" "math" "net/url" "reflect" "sync" - "time" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) -var defaultValueEncoders DefaultValueEncoders - var bvwPool = NewValueWriterPool() var errInvalidValue = errors.New("cannot encode invalid element") @@ -53,73 +49,58 @@ func encodeElement(ec EncodeContext, dw DocumentWriter, e E) error { return nil } -// DefaultValueEncoders is a namespace type for the default ValueEncoders used -// when creating a registry. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -type DefaultValueEncoders struct{} - -// RegisterDefaultEncoders will register the encoder methods attached to DefaultValueEncoders with -// the provided RegistryBuilder. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) RegisterDefaultEncoders(rb *RegistryBuilder) { - if rb == nil { +// registerDefaultEncoders will register the default encoder methods with the provided Registry. +func registerDefaultEncoders(reg *Registry) { + if reg == nil { panic(errors.New("argument to RegisterDefaultEncoders must not be nil")) } - rb. - RegisterTypeEncoder(tByteSlice, defaultByteSliceCodec). - RegisterTypeEncoder(tTime, defaultTimeCodec). - RegisterTypeEncoder(tEmpty, defaultEmptyInterfaceCodec). - RegisterTypeEncoder(tCoreArray, defaultArrayCodec). - RegisterTypeEncoder(tOID, ValueEncoderFunc(dve.ObjectIDEncodeValue)). - RegisterTypeEncoder(tDecimal, ValueEncoderFunc(dve.Decimal128EncodeValue)). - RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(dve.JSONNumberEncodeValue)). - RegisterTypeEncoder(tURL, ValueEncoderFunc(dve.URLEncodeValue)). - RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(dve.JavaScriptEncodeValue)). - RegisterTypeEncoder(tSymbol, ValueEncoderFunc(dve.SymbolEncodeValue)). - RegisterTypeEncoder(tBinary, ValueEncoderFunc(dve.BinaryEncodeValue)). - RegisterTypeEncoder(tUndefined, ValueEncoderFunc(dve.UndefinedEncodeValue)). - RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dve.DateTimeEncodeValue)). - RegisterTypeEncoder(tNull, ValueEncoderFunc(dve.NullEncodeValue)). - RegisterTypeEncoder(tRegex, ValueEncoderFunc(dve.RegexEncodeValue)). - RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dve.DBPointerEncodeValue)). - RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(dve.TimestampEncodeValue)). - RegisterTypeEncoder(tMinKey, ValueEncoderFunc(dve.MinKeyEncodeValue)). - RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(dve.MaxKeyEncodeValue)). - RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(dve.CoreDocumentEncodeValue)). - RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(dve.CodeWithScopeEncodeValue)). - RegisterDefaultEncoder(reflect.Bool, ValueEncoderFunc(dve.BooleanEncodeValue)). - RegisterDefaultEncoder(reflect.Int, ValueEncoderFunc(dve.IntEncodeValue)). - RegisterDefaultEncoder(reflect.Int8, ValueEncoderFunc(dve.IntEncodeValue)). - RegisterDefaultEncoder(reflect.Int16, ValueEncoderFunc(dve.IntEncodeValue)). - RegisterDefaultEncoder(reflect.Int32, ValueEncoderFunc(dve.IntEncodeValue)). - RegisterDefaultEncoder(reflect.Int64, ValueEncoderFunc(dve.IntEncodeValue)). - RegisterDefaultEncoder(reflect.Uint, defaultUIntCodec). - RegisterDefaultEncoder(reflect.Uint8, defaultUIntCodec). - RegisterDefaultEncoder(reflect.Uint16, defaultUIntCodec). - RegisterDefaultEncoder(reflect.Uint32, defaultUIntCodec). - RegisterDefaultEncoder(reflect.Uint64, defaultUIntCodec). - RegisterDefaultEncoder(reflect.Float32, ValueEncoderFunc(dve.FloatEncodeValue)). - RegisterDefaultEncoder(reflect.Float64, ValueEncoderFunc(dve.FloatEncodeValue)). - RegisterDefaultEncoder(reflect.Array, ValueEncoderFunc(dve.ArrayEncodeValue)). - RegisterDefaultEncoder(reflect.Map, defaultMapCodec). - RegisterDefaultEncoder(reflect.Slice, defaultSliceCodec). - RegisterDefaultEncoder(reflect.String, defaultStringCodec). - RegisterDefaultEncoder(reflect.Struct, newDefaultStructCodec()). - RegisterDefaultEncoder(reflect.Ptr, NewPointerCodec()). - RegisterHookEncoder(tValueMarshaler, ValueEncoderFunc(dve.ValueMarshalerEncodeValue)). - RegisterHookEncoder(tMarshaler, ValueEncoderFunc(dve.MarshalerEncodeValue)). - RegisterHookEncoder(tProxy, ValueEncoderFunc(dve.ProxyEncodeValue)) -} - -// BooleanEncodeValue is the ValueEncoderFunc for bool types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) BooleanEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { + reg.RegisterTypeEncoder(tByteSlice, defaultByteSliceCodec) + reg.RegisterTypeEncoder(tTime, defaultTimeCodec) + reg.RegisterTypeEncoder(tEmpty, defaultEmptyInterfaceCodec) + reg.RegisterTypeEncoder(tCoreArray, defaultArrayCodec) + reg.RegisterTypeEncoder(tOID, ValueEncoderFunc(objectIDEncodeValue)) + reg.RegisterTypeEncoder(tDecimal, ValueEncoderFunc(decimal128EncodeValue)) + reg.RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(jsonNumberEncodeValue)) + reg.RegisterTypeEncoder(tURL, ValueEncoderFunc(urlEncodeValue)) + reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue)) + reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue)) + reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue)) + reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue)) + reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue)) + reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue)) + reg.RegisterTypeEncoder(tRegex, ValueEncoderFunc(regexEncodeValue)) + reg.RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dbPointerEncodeValue)) + reg.RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(timestampEncodeValue)) + reg.RegisterTypeEncoder(tMinKey, ValueEncoderFunc(minKeyEncodeValue)) + reg.RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(maxKeyEncodeValue)) + reg.RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(coreDocumentEncodeValue)) + reg.RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(codeWithScopeEncodeValue)) + reg.RegisterKindEncoder(reflect.Bool, ValueEncoderFunc(booleanEncodeValue)) + reg.RegisterKindEncoder(reflect.Int, ValueEncoderFunc(intEncodeValue)) + reg.RegisterKindEncoder(reflect.Int8, ValueEncoderFunc(intEncodeValue)) + reg.RegisterKindEncoder(reflect.Int16, ValueEncoderFunc(intEncodeValue)) + reg.RegisterKindEncoder(reflect.Int32, ValueEncoderFunc(intEncodeValue)) + reg.RegisterKindEncoder(reflect.Int64, ValueEncoderFunc(intEncodeValue)) + reg.RegisterKindEncoder(reflect.Uint, defaultUIntCodec) + reg.RegisterKindEncoder(reflect.Uint8, defaultUIntCodec) + reg.RegisterKindEncoder(reflect.Uint16, defaultUIntCodec) + reg.RegisterKindEncoder(reflect.Uint32, defaultUIntCodec) + reg.RegisterKindEncoder(reflect.Uint64, defaultUIntCodec) + reg.RegisterKindEncoder(reflect.Float32, ValueEncoderFunc(floatEncodeValue)) + reg.RegisterKindEncoder(reflect.Float64, ValueEncoderFunc(floatEncodeValue)) + reg.RegisterKindEncoder(reflect.Array, ValueEncoderFunc(arrayEncodeValue)) + reg.RegisterKindEncoder(reflect.Map, defaultMapCodec) + reg.RegisterKindEncoder(reflect.Slice, defaultSliceCodec) + reg.RegisterKindEncoder(reflect.String, defaultStringCodec) + reg.RegisterKindEncoder(reflect.Struct, newDefaultStructCodec()) + reg.RegisterKindEncoder(reflect.Ptr, NewPointerCodec()) + reg.RegisterInterfaceEncoder(tValueMarshaler, ValueEncoderFunc(valueMarshalerEncodeValue)) + reg.RegisterInterfaceEncoder(tMarshaler, ValueEncoderFunc(marshalerEncodeValue)) + reg.RegisterInterfaceEncoder(tProxy, ValueEncoderFunc(proxyEncodeValue)) +} + +// booleanEncodeValue is the ValueEncoderFunc for bool types. +func booleanEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Bool { return ValueEncoderError{Name: "BooleanEncodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val} } @@ -130,11 +111,8 @@ func fitsIn32Bits(i int64) bool { return math.MinInt32 <= i && i <= math.MaxInt32 } -// IntEncodeValue is the ValueEncoderFunc for int types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) IntEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +// intEncodeValue is the ValueEncoderFunc for int types. +func intEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Int8, reflect.Int16, reflect.Int32: return vw.WriteInt32(int32(val.Int())) @@ -159,36 +137,8 @@ func (dve DefaultValueEncoders) IntEncodeValue(ec EncodeContext, vw ValueWriter, } } -// UintEncodeValue is the ValueEncoderFunc for uint types. -// -// Deprecated: UintEncodeValue is not registered by default. Use UintCodec.EncodeValue instead. -func (dve DefaultValueEncoders) UintEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - switch val.Kind() { - case reflect.Uint8, reflect.Uint16: - return vw.WriteInt32(int32(val.Uint())) - case reflect.Uint, reflect.Uint32, reflect.Uint64: - u64 := val.Uint() - if ec.MinSize && u64 <= math.MaxInt32 { - return vw.WriteInt32(int32(u64)) - } - if u64 > math.MaxInt64 { - return fmt.Errorf("%d overflows int64", u64) - } - return vw.WriteInt64(int64(u64)) - } - - return ValueEncoderError{ - Name: "UintEncodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, - Received: val, - } -} - -// FloatEncodeValue is the ValueEncoderFunc for float types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) FloatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// floatEncodeValue is the ValueEncoderFunc for float types. +func floatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Float32, reflect.Float64: return vw.WriteDouble(val.Float()) @@ -197,48 +147,24 @@ func (dve DefaultValueEncoders) FloatEncodeValue(_ EncodeContext, vw ValueWriter return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val} } -// StringEncodeValue is the ValueEncoderFunc for string types. -// -// Deprecated: StringEncodeValue is not registered by default. Use StringCodec.EncodeValue instead. -func (dve DefaultValueEncoders) StringEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if val.Kind() != reflect.String { - return ValueEncoderError{ - Name: "StringEncodeValue", - Kinds: []reflect.Kind{reflect.String}, - Received: val, - } - } - - return vw.WriteString(val.String()) -} - -// ObjectIDEncodeValue is the ValueEncoderFunc for ObjectID. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) ObjectIDEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// objectIDEncodeValue is the ValueEncoderFunc for ObjectID. +func objectIDEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tOID { return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val} } return vw.WriteObjectID(val.Interface().(ObjectID)) } -// Decimal128EncodeValue is the ValueEncoderFunc for Decimal128. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) Decimal128EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// decimal128EncodeValue is the ValueEncoderFunc for Decimal128. +func decimal128EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDecimal { return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val} } return vw.WriteDecimal128(val.Interface().(Decimal128)) } -// JSONNumberEncodeValue is the ValueEncoderFunc for json.Number. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) JSONNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +// jsonNumberEncodeValue is the ValueEncoderFunc for json.Number. +func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tJSONNumber { return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} } @@ -246,7 +172,7 @@ func (dve DefaultValueEncoders) JSONNumberEncodeValue(ec EncodeContext, vw Value // Attempt int first, then float64 if i64, err := jsnum.Int64(); err == nil { - return dve.IntEncodeValue(ec, vw, reflect.ValueOf(i64)) + return intEncodeValue(ec, vw, reflect.ValueOf(i64)) } f64, err := jsnum.Float64() @@ -254,14 +180,11 @@ func (dve DefaultValueEncoders) JSONNumberEncodeValue(ec EncodeContext, vw Value return err } - return dve.FloatEncodeValue(ec, vw, reflect.ValueOf(f64)) + return floatEncodeValue(ec, vw, reflect.ValueOf(f64)) } -// URLEncodeValue is the ValueEncoderFunc for url.URL. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) URLEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// urlEncodeValue is the ValueEncoderFunc for url.URL. +func urlEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tURL { return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val} } @@ -269,108 +192,8 @@ func (dve DefaultValueEncoders) URLEncodeValue(_ EncodeContext, vw ValueWriter, return vw.WriteString(u.String()) } -// TimeEncodeValue is the ValueEncoderFunc for time.TIme. -// -// Deprecated: TimeEncodeValue is not registered by default. Use TimeCodec.EncodeValue instead. -func (dve DefaultValueEncoders) TimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tTime { - return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val} - } - tt := val.Interface().(time.Time) - dt := NewDateTimeFromTime(tt) - return vw.WriteDateTime(int64(dt)) -} - -// ByteSliceEncodeValue is the ValueEncoderFunc for []byte. -// -// Deprecated: ByteSliceEncodeValue is not registered by default. Use ByteSliceCodec.EncodeValue instead. -func (dve DefaultValueEncoders) ByteSliceEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tByteSlice { - return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val} - } - if val.IsNil() { - return vw.WriteNull() - } - return vw.WriteBinary(val.Interface().([]byte)) -} - -// MapEncodeValue is the ValueEncoderFunc for map[string]* types. -// -// Deprecated: MapEncodeValue is not registered by default. Use MapCodec.EncodeValue instead. -func (dve DefaultValueEncoders) MapEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Kind() != reflect.Map || val.Type().Key().Kind() != reflect.String { - return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} - } - - if val.IsNil() { - // If we have a nill map but we can't WriteNull, that means we're probably trying to encode - // to a TopLevel document. We can't currently tell if this is what actually happened, but if - // there's a deeper underlying problem, the error will also be returned from WriteDocument, - // so just continue. The operations on a map reflection value are valid, so we can call - // MapKeys within mapEncodeValue without a problem. - err := vw.WriteNull() - if err == nil { - return nil - } - } - - dw, err := vw.WriteDocument() - if err != nil { - return err - } - - return dve.mapEncodeValue(ec, dw, val, nil) -} - -// mapEncodeValue handles encoding of the values of a map. The collisionFn returns -// true if the provided key exists, this is mainly used for inline maps in the -// struct codec. -func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { - - elemType := val.Type().Elem() - encoder, err := ec.LookupEncoder(elemType) - if err != nil && elemType.Kind() != reflect.Interface { - return err - } - - keys := val.MapKeys() - for _, key := range keys { - if collisionFn != nil && collisionFn(key.String()) { - return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key) - } - - currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.MapIndex(key)) - if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { - return lookupErr - } - - vw, err := dw.WriteDocumentElement(key.String()) - if err != nil { - return err - } - - if errors.Is(lookupErr, errInvalidValue) { - err = vw.WriteNull() - if err != nil { - return err - } - continue - } - - err = currEncoder.EncodeValue(ec, vw, currVal) - if err != nil { - return err - } - } - - return dw.WriteDocumentEnd() -} - -// ArrayEncodeValue is the ValueEncoderFunc for array types. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +// arrayEncodeValue is the ValueEncoderFunc for array types. +func arrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Array { return ValueEncoderError{Name: "ArrayEncodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val} } @@ -414,76 +237,7 @@ func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw ValueWrite } for idx := 0; idx < val.Len(); idx++ { - currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx)) - if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { - return lookupErr - } - - vw, err := aw.WriteArrayElement() - if err != nil { - return err - } - - if errors.Is(lookupErr, errInvalidValue) { - err = vw.WriteNull() - if err != nil { - return err - } - continue - } - - err = currEncoder.EncodeValue(ec, vw, currVal) - if err != nil { - return err - } - } - return aw.WriteArrayEnd() -} - -// SliceEncodeValue is the ValueEncoderFunc for slice types. -// -// Deprecated: SliceEncodeValue is not registered by default. Use SliceCodec.EncodeValue instead. -func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Kind() != reflect.Slice { - return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} - } - - if val.IsNil() { - return vw.WriteNull() - } - - // If we have a []E we want to treat it as a document instead of as an array. - if val.Type().ConvertibleTo(tD) { - d := val.Convert(tD).Interface().(D) - - dw, err := vw.WriteDocument() - if err != nil { - return err - } - - for _, e := range d { - err = encodeElement(ec, dw, e) - if err != nil { - return err - } - } - - return dw.WriteDocumentEnd() - } - - aw, err := vw.WriteArray() - if err != nil { - return err - } - - elemType := val.Type().Elem() - encoder, err := ec.LookupEncoder(elemType) - if err != nil && elemType.Kind() != reflect.Interface { - return err - } - - for idx := 0; idx < val.Len(); idx++ { - currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx)) + currEncoder, currVal, lookupErr := lookupElementEncoder(ec, encoder, val.Index(idx)) if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } @@ -509,7 +263,7 @@ func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw ValueWrite return aw.WriteArrayEnd() } -func (dve DefaultValueEncoders) lookupElementEncoder(ec EncodeContext, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) { +func lookupElementEncoder(ec EncodeContext, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) { if origEncoder != nil || (currVal.Kind() != reflect.Interface) { return origEncoder, currVal, nil } @@ -522,30 +276,8 @@ func (dve DefaultValueEncoders) lookupElementEncoder(ec EncodeContext, origEncod return currEncoder, currVal, err } -// EmptyInterfaceEncodeValue is the ValueEncoderFunc for interface{}. -// -// Deprecated: EmptyInterfaceEncodeValue is not registered by default. Use EmptyInterfaceCodec.EncodeValue instead. -func (dve DefaultValueEncoders) EmptyInterfaceEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if !val.IsValid() || val.Type() != tEmpty { - return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} - } - - if val.IsNil() { - return vw.WriteNull() - } - encoder, err := ec.LookupEncoder(val.Elem().Type()) - if err != nil { - return err - } - - return encoder.EncodeValue(ec, vw, val.Elem()) -} - -// ValueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// valueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations. +func valueMarshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement ValueMarshaler switch { case !val.IsValid(): @@ -572,11 +304,8 @@ func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(_ EncodeContext, vw Va return copyValueFromBytes(vw, t, data) } -// MarshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) MarshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// marshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations. +func marshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Marshaler switch { case !val.IsValid(): @@ -603,11 +332,8 @@ func (dve DefaultValueEncoders) MarshalerEncodeValue(_ EncodeContext, vw ValueWr return copyValueFromBytes(vw, TypeEmbeddedDocument, data) } -// ProxyEncodeValue is the ValueEncoderFunc for Proxy implementations. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +// proxyEncodeValue is the ValueEncoderFunc for Proxy implementations. +func proxyEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Proxy switch { case !val.IsValid(): @@ -650,11 +376,8 @@ func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw ValueWrite return encoder.EncodeValue(ec, vw, vv) } -// JavaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) JavaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// javaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type. +func javaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tJavaScript { return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val} } @@ -662,11 +385,8 @@ func (DefaultValueEncoders) JavaScriptEncodeValue(_ EncodeContext, vw ValueWrite return vw.WriteJavascript(val.String()) } -// SymbolEncodeValue is the ValueEncoderFunc for the Symbol type. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) SymbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// symbolEncodeValue is the ValueEncoderFunc for the Symbol type. +func symbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tSymbol { return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val} } @@ -674,11 +394,8 @@ func (DefaultValueEncoders) SymbolEncodeValue(_ EncodeContext, vw ValueWriter, v return vw.WriteSymbol(val.String()) } -// BinaryEncodeValue is the ValueEncoderFunc for Binary. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) BinaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// binaryEncodeValue is the ValueEncoderFunc for Binary. +func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tBinary { return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val} } @@ -687,11 +404,8 @@ func (DefaultValueEncoders) BinaryEncodeValue(_ EncodeContext, vw ValueWriter, v return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) } -// UndefinedEncodeValue is the ValueEncoderFunc for Undefined. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) UndefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// undefinedEncodeValue is the ValueEncoderFunc for Undefined. +func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tUndefined { return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val} } @@ -699,11 +413,8 @@ func (DefaultValueEncoders) UndefinedEncodeValue(_ EncodeContext, vw ValueWriter return vw.WriteUndefined() } -// DateTimeEncodeValue is the ValueEncoderFunc for DateTime. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) DateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// dateTimeEncodeValue is the ValueEncoderFunc for DateTime. +func dateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDateTime { return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val} } @@ -711,11 +422,8 @@ func (DefaultValueEncoders) DateTimeEncodeValue(_ EncodeContext, vw ValueWriter, return vw.WriteDateTime(val.Int()) } -// NullEncodeValue is the ValueEncoderFunc for Null. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) NullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// nullEncodeValue is the ValueEncoderFunc for Null. +func nullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tNull { return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val} } @@ -723,11 +431,8 @@ func (DefaultValueEncoders) NullEncodeValue(_ EncodeContext, vw ValueWriter, val return vw.WriteNull() } -// RegexEncodeValue is the ValueEncoderFunc for Regex. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) RegexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// regexEncodeValue is the ValueEncoderFunc for Regex. +func regexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRegex { return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val} } @@ -737,11 +442,8 @@ func (DefaultValueEncoders) RegexEncodeValue(_ EncodeContext, vw ValueWriter, va return vw.WriteRegex(regex.Pattern, regex.Options) } -// DBPointerEncodeValue is the ValueEncoderFunc for DBPointer. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) DBPointerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// dbPointerEncodeValue is the ValueEncoderFunc for DBPointer. +func dbPointerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDBPointer { return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val} } @@ -751,11 +453,8 @@ func (DefaultValueEncoders) DBPointerEncodeValue(_ EncodeContext, vw ValueWriter return vw.WriteDBPointer(dbp.DB, dbp.Pointer) } -// TimestampEncodeValue is the ValueEncoderFunc for Timestamp. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) TimestampEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// timestampEncodeValue is the ValueEncoderFunc for Timestamp. +func timestampEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tTimestamp { return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val} } @@ -765,11 +464,8 @@ func (DefaultValueEncoders) TimestampEncodeValue(_ EncodeContext, vw ValueWriter return vw.WriteTimestamp(ts.T, ts.I) } -// MinKeyEncodeValue is the ValueEncoderFunc for MinKey. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) MinKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// minKeyEncodeValue is the ValueEncoderFunc for MinKey. +func minKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tMinKey { return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val} } @@ -777,11 +473,8 @@ func (DefaultValueEncoders) MinKeyEncodeValue(_ EncodeContext, vw ValueWriter, v return vw.WriteMinKey() } -// MaxKeyEncodeValue is the ValueEncoderFunc for MaxKey. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) MaxKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// maxKeyEncodeValue is the ValueEncoderFunc for MaxKey. +func maxKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tMaxKey { return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val} } @@ -789,11 +482,8 @@ func (DefaultValueEncoders) MaxKeyEncodeValue(_ EncodeContext, vw ValueWriter, v return vw.WriteMaxKey() } -// CoreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (DefaultValueEncoders) CoreDocumentEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. +func coreDocumentEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCoreDocument { return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} } @@ -803,11 +493,8 @@ func (DefaultValueEncoders) CoreDocumentEncodeValue(_ EncodeContext, vw ValueWri return copyDocumentFromBytes(vw, cdoc) } -// CodeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default -// value encoders registered. -func (dve DefaultValueEncoders) CodeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +// codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. +func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCodeWithScope { return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} } diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index 481c6cb1a1..1ebd57f891 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -35,7 +35,6 @@ func (ms myStruct) Foo() int { } func TestDefaultValueEncoders(t *testing.T) { - var dve DefaultValueEncoders var wrong = func(string, string) string { return "wrong" } type mybool bool @@ -80,7 +79,7 @@ func TestDefaultValueEncoders(t *testing.T) { }{ { "BooleanEncodeValue", - ValueEncoderFunc(dve.BooleanEncodeValue), + ValueEncoderFunc(booleanEncodeValue), []subtest{ { "wrong type", @@ -96,7 +95,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "IntEncodeValue", - ValueEncoderFunc(dve.IntEncodeValue), + ValueEncoderFunc(intEncodeValue), []subtest{ { "wrong type", @@ -177,7 +176,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "FloatEncodeValue", - ValueEncoderFunc(dve.FloatEncodeValue), + ValueEncoderFunc(floatEncodeValue), []subtest{ { "wrong type", @@ -235,7 +234,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", map[string]int{"foo": 1}, - &EncodeContext{Registry: newTestRegistryBuilder().Build()}, + &EncodeContext{Registry: newTestRegistry()}, &valueReaderWriter{}, writeDocument, fmt.Errorf("no encoder found for int"), @@ -259,7 +258,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "empty map/success", map[string]interface{}{}, - &EncodeContext{Registry: newTestRegistryBuilder().Build()}, + &EncodeContext{Registry: newTestRegistry()}, &valueReaderWriter{}, writeDocumentEnd, nil, @@ -294,7 +293,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ArrayEncodeValue", - ValueEncoderFunc(dve.ArrayEncodeValue), + ValueEncoderFunc(arrayEncodeValue), []subtest{ { "wrong kind", @@ -315,7 +314,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", [1]int{1}, - &EncodeContext{Registry: newTestRegistryBuilder().Build()}, + &EncodeContext{Registry: newTestRegistry()}, &valueReaderWriter{}, writeArray, fmt.Errorf("no encoder found for int"), @@ -393,7 +392,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", []int{1}, - &EncodeContext{Registry: newTestRegistryBuilder().Build()}, + &EncodeContext{Registry: newTestRegistry()}, &valueReaderWriter{}, writeArray, fmt.Errorf("no encoder found for int"), @@ -433,7 +432,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "empty slice/success", []interface{}{}, - &EncodeContext{Registry: newTestRegistryBuilder().Build()}, + &EncodeContext{Registry: newTestRegistry()}, &valueReaderWriter{}, writeArrayEnd, nil, @@ -458,7 +457,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ObjectIDEncodeValue", - ValueEncoderFunc(dve.ObjectIDEncodeValue), + ValueEncoderFunc(objectIDEncodeValue), []subtest{ { "wrong type", @@ -477,7 +476,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "Decimal128EncodeValue", - ValueEncoderFunc(dve.Decimal128EncodeValue), + ValueEncoderFunc(decimal128EncodeValue), []subtest{ { "wrong type", @@ -492,7 +491,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "JSONNumberEncodeValue", - ValueEncoderFunc(dve.JSONNumberEncodeValue), + ValueEncoderFunc(jsonNumberEncodeValue), []subtest{ { "wrong type", @@ -521,7 +520,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "URLEncodeValue", - ValueEncoderFunc(dve.URLEncodeValue), + ValueEncoderFunc(urlEncodeValue), []subtest{ { "wrong type", @@ -566,7 +565,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ValueMarshalerEncodeValue", - ValueEncoderFunc(dve.ValueMarshalerEncodeValue), + ValueEncoderFunc(valueMarshalerEncodeValue), []subtest{ { "wrong type", @@ -644,7 +643,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "MarshalerEncodeValue", - ValueEncoderFunc(dve.MarshalerEncodeValue), + ValueEncoderFunc(marshalerEncodeValue), []subtest{ { "wrong type", @@ -706,7 +705,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ProxyEncodeValue", - ValueEncoderFunc(dve.ProxyEncodeValue), + ValueEncoderFunc(proxyEncodeValue), []subtest{ { "wrong type", @@ -844,7 +843,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "JavaScriptEncodeValue", - ValueEncoderFunc(dve.JavaScriptEncodeValue), + ValueEncoderFunc(javaScriptEncodeValue), []subtest{ { "wrong type", @@ -859,7 +858,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "SymbolEncodeValue", - ValueEncoderFunc(dve.SymbolEncodeValue), + ValueEncoderFunc(symbolEncodeValue), []subtest{ { "wrong type", @@ -874,7 +873,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "BinaryEncodeValue", - ValueEncoderFunc(dve.BinaryEncodeValue), + ValueEncoderFunc(binaryEncodeValue), []subtest{ { "wrong type", @@ -889,7 +888,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "UndefinedEncodeValue", - ValueEncoderFunc(dve.UndefinedEncodeValue), + ValueEncoderFunc(undefinedEncodeValue), []subtest{ { "wrong type", @@ -904,7 +903,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "DateTimeEncodeValue", - ValueEncoderFunc(dve.DateTimeEncodeValue), + ValueEncoderFunc(dateTimeEncodeValue), []subtest{ { "wrong type", @@ -919,7 +918,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "NullEncodeValue", - ValueEncoderFunc(dve.NullEncodeValue), + ValueEncoderFunc(nullEncodeValue), []subtest{ { "wrong type", @@ -934,7 +933,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "RegexEncodeValue", - ValueEncoderFunc(dve.RegexEncodeValue), + ValueEncoderFunc(regexEncodeValue), []subtest{ { "wrong type", @@ -949,7 +948,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "DBPointerEncodeValue", - ValueEncoderFunc(dve.DBPointerEncodeValue), + ValueEncoderFunc(dbPointerEncodeValue), []subtest{ { "wrong type", @@ -971,7 +970,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "TimestampEncodeValue", - ValueEncoderFunc(dve.TimestampEncodeValue), + ValueEncoderFunc(timestampEncodeValue), []subtest{ { "wrong type", @@ -986,7 +985,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "MinKeyEncodeValue", - ValueEncoderFunc(dve.MinKeyEncodeValue), + ValueEncoderFunc(minKeyEncodeValue), []subtest{ { "wrong type", @@ -1001,7 +1000,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "MaxKeyEncodeValue", - ValueEncoderFunc(dve.MaxKeyEncodeValue), + ValueEncoderFunc(maxKeyEncodeValue), []subtest{ { "wrong type", @@ -1016,7 +1015,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "CoreDocumentEncodeValue", - ValueEncoderFunc(dve.CoreDocumentEncodeValue), + ValueEncoderFunc(coreDocumentEncodeValue), []subtest{ { "wrong type", @@ -1096,7 +1095,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "CodeWithScopeEncodeValue", - ValueEncoderFunc(dve.CodeWithScopeEncodeValue), + ValueEncoderFunc(codeWithScopeEncodeValue), []subtest{ { "wrong type", @@ -1828,27 +1827,6 @@ func TestDefaultValueEncoders(t *testing.T) { }) } }) - - t.Run("EmptyInterfaceEncodeValue/nil", func(t *testing.T) { - val := reflect.New(tEmpty).Elem() - llvrw := new(valueReaderWriter) - err := dve.EmptyInterfaceEncodeValue(EncodeContext{Registry: newTestRegistryBuilder().Build()}, llvrw, val) - noerr(t, err) - if llvrw.invoked != writeNull { - t.Errorf("Incorrect method called. got %v; want %v", llvrw.invoked, writeNull) - } - }) - - t.Run("EmptyInterfaceEncodeValue/LookupEncoder error", func(t *testing.T) { - val := reflect.New(tEmpty).Elem() - val.Set(reflect.ValueOf(int64(1234567890))) - llvrw := new(valueReaderWriter) - got := dve.EmptyInterfaceEncodeValue(EncodeContext{Registry: newTestRegistryBuilder().Build()}, llvrw, val) - want := ErrNoEncoder{Type: tInt64} - if !assert.CompareErrors(got, want) { - t.Errorf("Did not receive expected error. got %v; want %v", got, want) - } - }) } type testValueMarshalPtr struct { diff --git a/bson/map_codec.go b/bson/map_codec.go index 9592957db4..f1294ae99d 100644 --- a/bson/map_codec.go +++ b/bson/map_codec.go @@ -126,7 +126,7 @@ func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw DocumentWriter, val refl return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key) } - currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.MapIndex(key)) + currEncoder, currVal, lookupErr := lookupElementEncoder(ec, encoder, val.MapIndex(key)) if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } diff --git a/bson/mgocompat/bson_test.go b/bson/mgocompat/bson_test.go index 6651509983..a74a5a892d 100644 --- a/bson/mgocompat/bson_test.go +++ b/bson/mgocompat/bson_test.go @@ -1442,9 +1442,6 @@ var twoWayCrossItems = []crossTypeItem{ {&struct{ S []byte }{[]byte("def")}, &struct{ S bson.Symbol }{"def"}}, {&struct{ S string }{"ghi"}, &struct{ S bson.Symbol }{"ghi"}}, - {&struct{ S string }{"0123456789ab"}, - &struct{ S bson.ObjectID }{bson.ObjectID{0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x61, 0x62}}}, - // map <=> struct {&struct { A struct { diff --git a/bson/mgocompat/registry.go b/bson/mgocompat/registry.go index 7024ab9fdc..0d61a029ec 100644 --- a/bson/mgocompat/registry.go +++ b/bson/mgocompat/registry.go @@ -31,19 +31,14 @@ var ( // Registry is the mgo compatible bson.Registry. It contains the default and // primitive codecs with mgo compatible options. -var Registry = NewRegistryBuilder().Build() +var Registry = newRegistry() // RespectNilValuesRegistry is the bson.Registry compatible with mgo withSetRespectNilValues set to true. -var RespectNilValuesRegistry = NewRespectNilValuesRegistryBuilder().Build() +var RespectNilValuesRegistry = newRespectNilValuesRegistry() -// NewRegistryBuilder creates a new bson.RegistryBuilder configured with the default encoders and -// decoders from the bson.DefaultValueEncoders and bson.DefaultValueDecoders types and the -// PrimitiveCodecs type in this package. -func NewRegistryBuilder() *bson.RegistryBuilder { - rb := bson.NewRegistryBuilder() - bson.DefaultValueEncoders{}.RegisterDefaultEncoders(rb) - bson.DefaultValueDecoders{}.RegisterDefaultDecoders(rb) - bson.PrimitiveCodecs{}.RegisterPrimitiveCodecs(rb) +// newRegistry creates a new bson.Registry configured with the default encoders and decoders. +func newRegistry() *bson.Registry { + reg := bson.NewRegistry() structcodec, _ := bson.NewStructCodec(bson.DefaultStructTagParser, bsonoptions.StructCodec(). @@ -61,34 +56,34 @@ func NewRegistryBuilder() *bson.RegistryBuilder { SetEncodeKeysWithStringer(true)) uintcodec := bson.NewUIntCodec(bsonoptions.UIntCodec().SetEncodeToMinSize(true)) - rb.RegisterTypeDecoder(tEmpty, emptyInterCodec). - RegisterDefaultDecoder(reflect.String, bson.NewStringCodec(bsonoptions.StringCodec().SetDecodeObjectIDAsHex(false))). - RegisterDefaultDecoder(reflect.Struct, structcodec). - RegisterDefaultDecoder(reflect.Map, mapCodec). - RegisterTypeEncoder(tByteSlice, bson.NewByteSliceCodec(bsonoptions.ByteSliceCodec().SetEncodeNilAsEmpty(true))). - RegisterDefaultEncoder(reflect.Struct, structcodec). - RegisterDefaultEncoder(reflect.Slice, bson.NewSliceCodec(bsonoptions.SliceCodec().SetEncodeNilAsEmpty(true))). - RegisterDefaultEncoder(reflect.Map, mapCodec). - RegisterDefaultEncoder(reflect.Uint, uintcodec). - RegisterDefaultEncoder(reflect.Uint8, uintcodec). - RegisterDefaultEncoder(reflect.Uint16, uintcodec). - RegisterDefaultEncoder(reflect.Uint32, uintcodec). - RegisterDefaultEncoder(reflect.Uint64, uintcodec). - RegisterTypeMapEntry(bson.TypeInt32, tInt). - RegisterTypeMapEntry(bson.TypeDateTime, tTime). - RegisterTypeMapEntry(bson.TypeArray, tInterfaceSlice). - RegisterTypeMapEntry(bson.Type(0), tM). - RegisterTypeMapEntry(bson.TypeEmbeddedDocument, tM). - RegisterHookEncoder(tGetter, bson.ValueEncoderFunc(GetterEncodeValue)). - RegisterHookDecoder(tSetter, bson.ValueDecoderFunc(SetterDecodeValue)) + reg.RegisterTypeDecoder(tEmpty, emptyInterCodec) + reg.RegisterKindDecoder(reflect.String, bson.NewStringCodec(bsonoptions.StringCodec().SetDecodeObjectIDAsHex(false))) + reg.RegisterKindDecoder(reflect.Struct, structcodec) + reg.RegisterKindDecoder(reflect.Map, mapCodec) + reg.RegisterTypeEncoder(tByteSlice, bson.NewByteSliceCodec(bsonoptions.ByteSliceCodec().SetEncodeNilAsEmpty(true))) + reg.RegisterKindEncoder(reflect.Struct, structcodec) + reg.RegisterKindEncoder(reflect.Slice, bson.NewSliceCodec(bsonoptions.SliceCodec().SetEncodeNilAsEmpty(true))) + reg.RegisterKindEncoder(reflect.Map, mapCodec) + reg.RegisterKindEncoder(reflect.Uint, uintcodec) + reg.RegisterKindEncoder(reflect.Uint8, uintcodec) + reg.RegisterKindEncoder(reflect.Uint16, uintcodec) + reg.RegisterKindEncoder(reflect.Uint32, uintcodec) + reg.RegisterKindEncoder(reflect.Uint64, uintcodec) + reg.RegisterTypeMapEntry(bson.TypeInt32, tInt) + reg.RegisterTypeMapEntry(bson.TypeDateTime, tTime) + reg.RegisterTypeMapEntry(bson.TypeArray, tInterfaceSlice) + reg.RegisterTypeMapEntry(bson.Type(0), tM) + reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, tM) + reg.RegisterInterfaceEncoder(tGetter, bson.ValueEncoderFunc(GetterEncodeValue)) + reg.RegisterInterfaceDecoder(tSetter, bson.ValueDecoderFunc(SetterDecodeValue)) - return rb + return reg } -// NewRespectNilValuesRegistryBuilder creates a new bson.RegistryBuilder configured to behave like mgo/bson +// newRespectNilValuesRegistry creates a new bson.Registry configured to behave like mgo/bson // with RespectNilValues set to true. -func NewRespectNilValuesRegistryBuilder() *bson.RegistryBuilder { - rb := NewRegistryBuilder() +func newRespectNilValuesRegistry() *bson.Registry { + reg := newRegistry() structcodec, _ := bson.NewStructCodec(bson.DefaultStructTagParser, bsonoptions.StructCodec(). @@ -101,12 +96,12 @@ func NewRespectNilValuesRegistryBuilder() *bson.RegistryBuilder { SetDecodeZerosMap(true). SetEncodeNilAsEmpty(false)) - rb.RegisterDefaultDecoder(reflect.Struct, structcodec). - RegisterDefaultDecoder(reflect.Map, mapCodec). - RegisterTypeEncoder(tByteSlice, bson.NewByteSliceCodec(bsonoptions.ByteSliceCodec().SetEncodeNilAsEmpty(false))). - RegisterDefaultEncoder(reflect.Struct, structcodec). - RegisterDefaultEncoder(reflect.Slice, bson.NewSliceCodec(bsonoptions.SliceCodec().SetEncodeNilAsEmpty(false))). - RegisterDefaultEncoder(reflect.Map, mapCodec) + reg.RegisterKindDecoder(reflect.Struct, structcodec) + reg.RegisterKindDecoder(reflect.Map, mapCodec) + reg.RegisterTypeEncoder(tByteSlice, bson.NewByteSliceCodec(bsonoptions.ByteSliceCodec().SetEncodeNilAsEmpty(false))) + reg.RegisterKindEncoder(reflect.Struct, structcodec) + reg.RegisterKindEncoder(reflect.Slice, bson.NewSliceCodec(bsonoptions.SliceCodec().SetEncodeNilAsEmpty(false))) + reg.RegisterKindEncoder(reflect.Map, mapCodec) - return rb + return reg } diff --git a/bson/primitive_codecs.go b/bson/primitive_codecs.go index 262645ce4c..df6e059c4a 100644 --- a/bson/primitive_codecs.go +++ b/bson/primitive_codecs.go @@ -15,38 +15,23 @@ import ( var tRawValue = reflect.TypeOf(RawValue{}) var tRaw = reflect.TypeOf(Raw(nil)) -// PrimitiveCodecs is a namespace for all of the default Codecs for the primitive types -// defined in this package. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders -// registered. -type PrimitiveCodecs struct{} - -// RegisterPrimitiveCodecs will register the encode and decode methods attached to PrimitiveCodecs -// with the provided RegistryBuilder. if rb is nil, a new empty RegistryBuilder will be created. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders -// registered. -func (pc PrimitiveCodecs) RegisterPrimitiveCodecs(rb *RegistryBuilder) { - if rb == nil { +// registerPrimitiveCodecs will register the encode and decode methods with the provided Registry. +func registerPrimitiveCodecs(reg *Registry) { + if reg == nil { panic(errors.New("argument to RegisterPrimitiveCodecs must not be nil")) } - rb. - RegisterTypeEncoder(tRawValue, ValueEncoderFunc(pc.RawValueEncodeValue)). - RegisterTypeEncoder(tRaw, ValueEncoderFunc(pc.RawEncodeValue)). - RegisterTypeDecoder(tRawValue, ValueDecoderFunc(pc.RawValueDecodeValue)). - RegisterTypeDecoder(tRaw, ValueDecoderFunc(pc.RawDecodeValue)) + reg.RegisterTypeEncoder(tRawValue, ValueEncoderFunc(rawValueEncodeValue)) + reg.RegisterTypeEncoder(tRaw, ValueEncoderFunc(rawEncodeValue)) + reg.RegisterTypeDecoder(tRawValue, ValueDecoderFunc(rawValueDecodeValue)) + reg.RegisterTypeDecoder(tRaw, ValueDecoderFunc(rawDecodeValue)) } -// RawValueEncodeValue is the ValueEncoderFunc for RawValue. +// rawValueEncodeValue is the ValueEncoderFunc for RawValue. // // If the RawValue's Type is "invalid" and the RawValue's Value is not empty or // nil, then this method will return an error. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive -// encoders and decoders registered. -func (PrimitiveCodecs) RawValueEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func rawValueEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRawValue { return ValueEncoderError{ Name: "RawValueEncodeValue", @@ -64,11 +49,8 @@ func (PrimitiveCodecs) RawValueEncodeValue(_ EncodeContext, vw ValueWriter, val return copyValueFromBytes(vw, rawvalue.Type, rawvalue.Value) } -// RawValueDecodeValue is the ValueDecoderFunc for RawValue. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders -// registered. -func (PrimitiveCodecs) RawValueDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +// rawValueDecodeValue is the ValueDecoderFunc for RawValue. +func rawValueDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tRawValue { return ValueDecoderError{Name: "RawValueDecodeValue", Types: []reflect.Type{tRawValue}, Received: val} } @@ -82,11 +64,8 @@ func (PrimitiveCodecs) RawValueDecodeValue(_ DecodeContext, vr ValueReader, val return nil } -// RawEncodeValue is the ValueEncoderFunc for Reader. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders -// registered. -func (PrimitiveCodecs) RawEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +// rawEncodeValue is the ValueEncoderFunc for Reader. +func rawEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRaw { return ValueEncoderError{Name: "RawEncodeValue", Types: []reflect.Type{tRaw}, Received: val} } @@ -96,11 +75,8 @@ func (PrimitiveCodecs) RawEncodeValue(_ EncodeContext, vw ValueWriter, val refle return copyDocumentFromBytes(vw, rdr) } -// RawDecodeValue is the ValueDecoderFunc for Reader. -// -// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders -// registered. -func (PrimitiveCodecs) RawDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +// rawDecodeValue is the ValueDecoderFunc for Reader. +func rawDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tRaw { return ValueDecoderError{Name: "RawDecodeValue", Types: []reflect.Type{tRaw}, Received: val} } diff --git a/bson/primitive_codecs_test.go b/bson/primitive_codecs_test.go index be3aeab978..6b0c4c1e05 100644 --- a/bson/primitive_codecs_test.go +++ b/bson/primitive_codecs_test.go @@ -32,8 +32,6 @@ func bytesFromDoc(doc interface{}) []byte { func TestPrimitiveValueEncoders(t *testing.T) { t.Parallel() - var pc PrimitiveCodecs - var wrong = func(string, string) string { return "wrong" } type subtest struct { @@ -52,7 +50,7 @@ func TestPrimitiveValueEncoders(t *testing.T) { }{ { "RawValueEncodeValue", - ValueEncoderFunc(pc.RawValueEncodeValue), + ValueEncoderFunc(rawValueEncodeValue), []subtest{ { "wrong type", @@ -100,7 +98,7 @@ func TestPrimitiveValueEncoders(t *testing.T) { }, { "RawEncodeValue", - ValueEncoderFunc(pc.RawEncodeValue), + ValueEncoderFunc(rawEncodeValue), []subtest{ { "wrong type", @@ -478,8 +476,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { } func TestPrimitiveValueDecoders(t *testing.T) { - var pc PrimitiveCodecs - var wrong = func(string, string) string { return "wrong" } const cansetreflectiontest = "cansetreflectiontest" @@ -500,7 +496,7 @@ func TestPrimitiveValueDecoders(t *testing.T) { }{ { "RawValueDecodeValue", - ValueDecoderFunc(pc.RawValueDecodeValue), + ValueDecoderFunc(rawValueDecodeValue), []subtest{ { "wrong type", @@ -544,7 +540,7 @@ func TestPrimitiveValueDecoders(t *testing.T) { }, { "RawDecodeValue", - ValueDecoderFunc(pc.RawDecodeValue), + ValueDecoderFunc(rawDecodeValue), []subtest{ { "wrong type", diff --git a/bson/raw_value_test.go b/bson/raw_value_test.go index f02fe8f326..67444faa61 100644 --- a/bson/raw_value_test.go +++ b/bson/raw_value_test.go @@ -25,7 +25,7 @@ func TestRawValue(t *testing.T) { t.Run("Uses registry attached to value", func(t *testing.T) { t.Parallel() - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() val := RawValue{Type: TypeString, Value: bsoncore.AppendString(nil, "foobar"), r: reg} var s string want := ErrNoDecoder{Type: reflect.TypeOf(s)} @@ -63,7 +63,7 @@ func TestRawValue(t *testing.T) { t.Run("Returns lookup error", func(t *testing.T) { t.Parallel() - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() var val RawValue var s string want := ErrNoDecoder{Type: reflect.TypeOf(s)} @@ -114,7 +114,7 @@ func TestRawValue(t *testing.T) { t.Run("Returns lookup error", func(t *testing.T) { t.Parallel() - dc := DecodeContext{Registry: newTestRegistryBuilder().Build()} + dc := DecodeContext{Registry: newTestRegistry()} var val RawValue var s string want := ErrNoDecoder{Type: reflect.TypeOf(s)} diff --git a/bson/registry.go b/bson/registry.go index 74b99e93ab..d20424b340 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -22,11 +22,6 @@ var DefaultRegistry = NewRegistry() // Deprecated: ErrNilType will not be supported in Go Driver 2.0. var ErrNilType = errors.New("cannot perform a decoder lookup on ") -// ErrNotPointer is returned when a non-pointer type is provided to LookupDecoder. -// -// Deprecated: ErrNotPointer will not be supported in Go Driver 2.0. -var ErrNotPointer = errors.New("non-pointer provided to LookupDecoder") - // ErrNoEncoder is returned when there wasn't an encoder available for a type. // // Deprecated: ErrNoEncoder will not be supported in Go Driver 2.0. @@ -63,187 +58,6 @@ func (entme ErrNoTypeMapEntry) Error() string { return "no type map entry found for " + entme.Type.String() } -// ErrNotInterface is returned when the provided type is not an interface. -// -// Deprecated: ErrNotInterface will not be supported in Go Driver 2.0. -var ErrNotInterface = errors.New("The provided type is not an interface") - -// A RegistryBuilder is used to build a Registry. This type is not goroutine -// safe. -// -// Deprecated: Use Registry instead. -type RegistryBuilder struct { - registry *Registry -} - -// NewRegistryBuilder creates a new empty RegistryBuilder. -// -// Deprecated: Use NewRegistry instead. -func NewRegistryBuilder() *RegistryBuilder { - rb := &RegistryBuilder{ - registry: &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), - }, - } - DefaultValueEncoders{}.RegisterDefaultEncoders(rb) - DefaultValueDecoders{}.RegisterDefaultDecoders(rb) - PrimitiveCodecs{}.RegisterPrimitiveCodecs(rb) - return rb -} - -// RegisterCodec will register the provided ValueCodec for the provided type. -// -// Deprecated: Use Registry.RegisterTypeEncoder and Registry.RegisterTypeDecoder instead. -func (rb *RegistryBuilder) RegisterCodec(t reflect.Type, codec ValueCodec) *RegistryBuilder { - rb.RegisterTypeEncoder(t, codec) - rb.RegisterTypeDecoder(t, codec) - return rb -} - -// RegisterTypeEncoder will register the provided ValueEncoder for the provided type. -// -// The type will be used directly, so an encoder can be registered for a type and a different encoder can be registered -// for a pointer to that type. -// -// If the given type is an interface, the encoder will be called when marshaling a type that is that interface. It -// will not be called when marshaling a non-interface type that implements the interface. -// -// Deprecated: Use Registry.RegisterTypeEncoder instead. -func (rb *RegistryBuilder) RegisterTypeEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder { - rb.registry.RegisterTypeEncoder(t, enc) - return rb -} - -// RegisterHookEncoder will register an encoder for the provided interface type t. This encoder will be called when -// marshaling a type if the type implements t or a pointer to the type implements t. If the provided type is not -// an interface (i.e. t.Kind() != reflect.Interface), this method will panic. -// -// Deprecated: Use Registry.RegisterInterfaceEncoder instead. -func (rb *RegistryBuilder) RegisterHookEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder { - rb.registry.RegisterInterfaceEncoder(t, enc) - return rb -} - -// RegisterTypeDecoder will register the provided ValueDecoder for the provided type. -// -// The type will be used directly, so a decoder can be registered for a type and a different decoder can be registered -// for a pointer to that type. -// -// If the given type is an interface, the decoder will be called when unmarshaling into a type that is that interface. -// It will not be called when unmarshaling into a non-interface type that implements the interface. -// -// Deprecated: Use Registry.RegisterTypeDecoder instead. -func (rb *RegistryBuilder) RegisterTypeDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder { - rb.registry.RegisterTypeDecoder(t, dec) - return rb -} - -// RegisterHookDecoder will register an decoder for the provided interface type t. This decoder will be called when -// unmarshaling into a type if the type implements t or a pointer to the type implements t. If the provided type is not -// an interface (i.e. t.Kind() != reflect.Interface), this method will panic. -// -// Deprecated: Use Registry.RegisterInterfaceDecoder instead. -func (rb *RegistryBuilder) RegisterHookDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder { - rb.registry.RegisterInterfaceDecoder(t, dec) - return rb -} - -// RegisterEncoder registers the provided type and encoder pair. -// -// Deprecated: Use Registry.RegisterTypeEncoder or Registry.RegisterInterfaceEncoder instead. -func (rb *RegistryBuilder) RegisterEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder { - if t == tEmpty { - rb.registry.RegisterTypeEncoder(t, enc) - return rb - } - switch t.Kind() { - case reflect.Interface: - rb.registry.RegisterInterfaceEncoder(t, enc) - default: - rb.registry.RegisterTypeEncoder(t, enc) - } - return rb -} - -// RegisterDecoder registers the provided type and decoder pair. -// -// Deprecated: Use Registry.RegisterTypeDecoder or Registry.RegisterInterfaceDecoder instead. -func (rb *RegistryBuilder) RegisterDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder { - if t == nil { - rb.registry.RegisterTypeDecoder(t, dec) - return rb - } - if t == tEmpty { - rb.registry.RegisterTypeDecoder(t, dec) - return rb - } - switch t.Kind() { - case reflect.Interface: - rb.registry.RegisterInterfaceDecoder(t, dec) - default: - rb.registry.RegisterTypeDecoder(t, dec) - } - return rb -} - -// RegisterDefaultEncoder will register the provided ValueEncoder to the provided -// kind. -// -// Deprecated: Use Registry.RegisterKindEncoder instead. -func (rb *RegistryBuilder) RegisterDefaultEncoder(kind reflect.Kind, enc ValueEncoder) *RegistryBuilder { - rb.registry.RegisterKindEncoder(kind, enc) - return rb -} - -// RegisterDefaultDecoder will register the provided ValueDecoder to the -// provided kind. -// -// Deprecated: Use Registry.RegisterKindDecoder instead. -func (rb *RegistryBuilder) RegisterDefaultDecoder(kind reflect.Kind, dec ValueDecoder) *RegistryBuilder { - rb.registry.RegisterKindDecoder(kind, dec) - return rb -} - -// RegisterTypeMapEntry will register the provided type to the BSON type. The primary usage for this -// mapping is decoding situations where an empty interface is used and a default type needs to be -// created and decoded into. -// -// By default, BSON documents will decode into interface{} values as bson.D. To change the default type for BSON -// documents, a type map entry for TypeEmbeddedDocument should be registered. For example, to force BSON documents -// to decode to bson.Raw, use the following code: -// -// rb.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) -// -// Deprecated: Use Registry.RegisterTypeMapEntry instead. -func (rb *RegistryBuilder) RegisterTypeMapEntry(bt Type, rt reflect.Type) *RegistryBuilder { - rb.registry.RegisterTypeMapEntry(bt, rt) - return rb -} - -// Build creates a Registry from the current state of this RegistryBuilder. -// -// Deprecated: Use NewRegistry instead. -func (rb *RegistryBuilder) Build() *Registry { - r := &Registry{ - interfaceEncoders: append([]interfaceValueEncoder(nil), rb.registry.interfaceEncoders...), - interfaceDecoders: append([]interfaceValueDecoder(nil), rb.registry.interfaceDecoders...), - typeEncoders: rb.registry.typeEncoders.Clone(), - typeDecoders: rb.registry.typeDecoders.Clone(), - kindEncoders: rb.registry.kindEncoders.Clone(), - kindDecoders: rb.registry.kindDecoders.Clone(), - } - rb.registry.typeMap.Range(func(k, v interface{}) bool { - if k != nil && v != nil { - r.typeMap.Store(k, v) - } - return true - }) - return r -} - // A Registry is a store for ValueEncoders, ValueDecoders, and a type map. See the Registry type // documentation for examples of registering various custom encoders and decoders. A Registry can // have four main types of codecs: @@ -289,7 +103,16 @@ type Registry struct { // NewRegistry creates a new empty Registry. func NewRegistry() *Registry { - return NewRegistryBuilder().Build() + reg := &Registry{ + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), + } + registerDefaultEncoders(reg) + registerDefaultDecoders(reg) + registerPrimitiveCodecs(reg) + return reg } // RegisterTypeEncoder registers the provided ValueEncoder for the provided type. diff --git a/bson/registry_test.go b/bson/registry_test.go index 2bc87364d3..c7963d4edd 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -15,15 +15,13 @@ import ( "go.mongodb.org/mongo-driver/internal/assert" ) -// newTestRegistryBuilder creates a new empty Registry. -func newTestRegistryBuilder() *RegistryBuilder { - return &RegistryBuilder{ - registry: &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), - }, +// newTestRegistry creates a new empty Registry. +func newTestRegistry() *Registry { + return &Registry{ + typeEncoders: new(typeEncoderCache), + typeDecoders: new(typeDecoderCache), + kindEncoders: new(kindEncoderCache), + kindDecoders: new(kindDecoderCache), } } @@ -45,12 +43,11 @@ func TestRegistryBuilder(t *testing.T) { {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, } - rb := newTestRegistryBuilder() + reg := newTestRegistry() for _, ip := range ips { - rb.RegisterHookEncoder(ip.i, ip.ve) + reg.RegisterInterfaceEncoder(ip.i, ip.ve) } - reg := rb.Build() got := reg.interfaceEncoders if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) { t.Errorf("the registered interfaces are not correct: got %#v, want %#v", got, want) @@ -58,11 +55,11 @@ func TestRegistryBuilder(t *testing.T) { }) t.Run("type", func(t *testing.T) { ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{} - rb := newTestRegistryBuilder(). - RegisterTypeEncoder(reflect.TypeOf(ft1), fc1). - RegisterTypeEncoder(reflect.TypeOf(ft2), fc2). - RegisterTypeEncoder(reflect.TypeOf(ft1), fc3). - RegisterTypeEncoder(reflect.TypeOf(ft4), fc4) + reg := newTestRegistry() + reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc1) + reg.RegisterTypeEncoder(reflect.TypeOf(ft2), fc2) + reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc3) + reg.RegisterTypeEncoder(reflect.TypeOf(ft4), fc4) want := []struct { t reflect.Type c ValueEncoder @@ -72,7 +69,6 @@ func TestRegistryBuilder(t *testing.T) { {reflect.TypeOf(ft4), fc4}, } - reg := rb.Build() got := reg.typeEncoders for _, s := range want { wantT, wantC := s.t, s.c @@ -87,11 +83,11 @@ func TestRegistryBuilder(t *testing.T) { }) t.Run("kind", func(t *testing.T) { k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map - rb := newTestRegistryBuilder(). - RegisterDefaultEncoder(k1, fc1). - RegisterDefaultEncoder(k2, fc2). - RegisterDefaultEncoder(k1, fc3). - RegisterDefaultEncoder(k4, fc4) + reg := newTestRegistry() + reg.RegisterKindEncoder(k1, fc1) + reg.RegisterKindEncoder(k2, fc2) + reg.RegisterKindEncoder(k1, fc3) + reg.RegisterKindEncoder(k4, fc4) want := []struct { k reflect.Kind c ValueEncoder @@ -101,7 +97,6 @@ func TestRegistryBuilder(t *testing.T) { {k4, fc4}, } - reg := rb.Build() got := reg.kindEncoders for _, s := range want { wantK, wantC := s.k, s.c @@ -118,16 +113,14 @@ func TestRegistryBuilder(t *testing.T) { t.Run("MapCodec", func(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - rb := newTestRegistryBuilder() + reg := newTestRegistry() - rb.RegisterDefaultEncoder(reflect.Map, codec) - reg := rb.Build() + reg.RegisterKindEncoder(reflect.Map, codec) if reg.kindEncoders.get(reflect.Map) != codec { t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec) } - rb.RegisterDefaultEncoder(reflect.Map, codec2) - reg = rb.Build() + reg.RegisterKindEncoder(reflect.Map, codec2) if reg.kindEncoders.get(reflect.Map) != codec2 { t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec2) } @@ -135,16 +128,14 @@ func TestRegistryBuilder(t *testing.T) { t.Run("StructCodec", func(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - rb := newTestRegistryBuilder() + reg := newTestRegistry() - rb.RegisterDefaultEncoder(reflect.Struct, codec) - reg := rb.Build() + reg.RegisterKindEncoder(reflect.Struct, codec) if reg.kindEncoders.get(reflect.Struct) != codec { t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec) } - rb.RegisterDefaultEncoder(reflect.Struct, codec2) - reg = rb.Build() + reg.RegisterKindEncoder(reflect.Struct, codec2) if reg.kindEncoders.get(reflect.Struct) != codec2 { t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec2) } @@ -152,16 +143,14 @@ func TestRegistryBuilder(t *testing.T) { t.Run("SliceCodec", func(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - rb := newTestRegistryBuilder() + reg := newTestRegistry() - rb.RegisterDefaultEncoder(reflect.Slice, codec) - reg := rb.Build() + reg.RegisterKindEncoder(reflect.Slice, codec) if reg.kindEncoders.get(reflect.Slice) != codec { t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec) } - rb.RegisterDefaultEncoder(reflect.Slice, codec2) - reg = rb.Build() + reg.RegisterKindEncoder(reflect.Slice, codec2) if reg.kindEncoders.get(reflect.Slice) != codec2 { t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec2) } @@ -169,16 +158,14 @@ func TestRegistryBuilder(t *testing.T) { t.Run("ArrayCodec", func(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - rb := newTestRegistryBuilder() + reg := newTestRegistry() - rb.RegisterDefaultEncoder(reflect.Array, codec) - reg := rb.Build() + reg.RegisterKindEncoder(reflect.Array, codec) if reg.kindEncoders.get(reflect.Array) != codec { t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec) } - rb.RegisterDefaultEncoder(reflect.Array, codec2) - reg = rb.Build() + reg.RegisterKindEncoder(reflect.Array, codec2) if reg.kindEncoders.get(reflect.Array) != codec2 { t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec2) } @@ -211,28 +198,27 @@ func TestRegistryBuilder(t *testing.T) { pc = NewPointerCodec() ) - reg := newTestRegistryBuilder(). - RegisterTypeEncoder(ft1, fc1). - RegisterTypeEncoder(ft2, fc2). - RegisterTypeEncoder(ti1, fc1). - RegisterDefaultEncoder(reflect.Struct, fsc). - RegisterDefaultEncoder(reflect.Slice, fslcc). - RegisterDefaultEncoder(reflect.Array, fslcc). - RegisterDefaultEncoder(reflect.Map, fmc). - RegisterDefaultEncoder(reflect.Ptr, pc). - RegisterTypeDecoder(ft1, fc1). - RegisterTypeDecoder(ft2, fc2). - RegisterTypeDecoder(ti1, fc1). // values whose exact type is testInterface1 will use fc1 encoder - RegisterDefaultDecoder(reflect.Struct, fsc). - RegisterDefaultDecoder(reflect.Slice, fslcc). - RegisterDefaultDecoder(reflect.Array, fslcc). - RegisterDefaultDecoder(reflect.Map, fmc). - RegisterDefaultDecoder(reflect.Ptr, pc). - RegisterHookEncoder(ti2, fc2). - RegisterHookDecoder(ti2, fc2). - RegisterHookEncoder(ti3, fc3). - RegisterHookDecoder(ti3, fc3). - Build() + reg := newTestRegistry() + reg.RegisterTypeEncoder(ft1, fc1) + reg.RegisterTypeEncoder(ft2, fc2) + reg.RegisterTypeEncoder(ti1, fc1) + reg.RegisterKindEncoder(reflect.Struct, fsc) + reg.RegisterKindEncoder(reflect.Slice, fslcc) + reg.RegisterKindEncoder(reflect.Array, fslcc) + reg.RegisterKindEncoder(reflect.Map, fmc) + reg.RegisterKindEncoder(reflect.Ptr, pc) + reg.RegisterTypeDecoder(ft1, fc1) + reg.RegisterTypeDecoder(ft2, fc2) + reg.RegisterTypeDecoder(ti1, fc1) // values whose exact type is testInterface1 will use fc1 encoder + reg.RegisterKindDecoder(reflect.Struct, fsc) + reg.RegisterKindDecoder(reflect.Slice, fslcc) + reg.RegisterKindDecoder(reflect.Array, fslcc) + reg.RegisterKindDecoder(reflect.Map, fmc) + reg.RegisterKindDecoder(reflect.Ptr, pc) + reg.RegisterInterfaceEncoder(ti2, fc2) + reg.RegisterInterfaceEncoder(ti3, fc3) + reg.RegisterInterfaceDecoder(ti2, fc2) + reg.RegisterInterfaceDecoder(ti3, fc3) testCases := []struct { name string @@ -409,10 +395,9 @@ func TestRegistryBuilder(t *testing.T) { }) }) t.Run("Type Map", func(t *testing.T) { - reg := newTestRegistryBuilder(). - RegisterTypeMapEntry(TypeString, reflect.TypeOf("")). - RegisterTypeMapEntry(TypeInt32, reflect.TypeOf(int(0))). - Build() + reg := newTestRegistry() + reg.RegisterTypeMapEntry(TypeString, reflect.TypeOf("")) + reg.RegisterTypeMapEntry(TypeInt32, reflect.TypeOf(int(0))) var got, want reflect.Type @@ -466,7 +451,7 @@ func TestRegistry(t *testing.T) { {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, } - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() for _, ip := range ips { reg.RegisterInterfaceEncoder(ip.i, ip.ve) } @@ -479,7 +464,7 @@ func TestRegistry(t *testing.T) { t.Parallel() ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{} - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc1) reg.RegisterTypeEncoder(reflect.TypeOf(ft2), fc2) reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc3) @@ -509,7 +494,7 @@ func TestRegistry(t *testing.T) { t.Parallel() k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() reg.RegisterKindEncoder(k1, fc1) reg.RegisterKindEncoder(k2, fc2) reg.RegisterKindEncoder(k1, fc3) @@ -543,7 +528,7 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() reg.RegisterKindEncoder(reflect.Map, codec) if reg.kindEncoders.get(reflect.Map) != codec { t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec) @@ -558,7 +543,7 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() reg.RegisterKindEncoder(reflect.Struct, codec) if reg.kindEncoders.get(reflect.Struct) != codec { t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec) @@ -573,7 +558,7 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() reg.RegisterKindEncoder(reflect.Slice, codec) if reg.kindEncoders.get(reflect.Slice) != codec { t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec) @@ -588,7 +573,7 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() reg.RegisterKindEncoder(reflect.Array, codec) if reg.kindEncoders.get(reflect.Array) != codec { t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec) @@ -628,7 +613,7 @@ func TestRegistry(t *testing.T) { pc = NewPointerCodec() ) - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() reg.RegisterTypeEncoder(ft1, fc1) reg.RegisterTypeEncoder(ft2, fc2) reg.RegisterTypeEncoder(ti1, fc1) @@ -869,7 +854,7 @@ func TestRegistry(t *testing.T) { }) t.Run("Type Map", func(t *testing.T) { t.Parallel() - reg := newTestRegistryBuilder().Build() + reg := newTestRegistry() reg.RegisterTypeMapEntry(TypeString, reflect.TypeOf("")) reg.RegisterTypeMapEntry(TypeInt32, reflect.TypeOf(int(0))) diff --git a/bson/slice_codec.go b/bson/slice_codec.go index 52449239b9..f29c36b26d 100644 --- a/bson/slice_codec.go +++ b/bson/slice_codec.go @@ -90,7 +90,7 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.V } for idx := 0; idx < val.Len(); idx++ { - currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.Index(idx)) + currEncoder, currVal, lookupErr := lookupElementEncoder(ec, encoder, val.Index(idx)) if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } @@ -176,9 +176,9 @@ func (sc *SliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect. switch val.Type().Elem() { case tE: dc.Ancestor = val.Type() - elemsFunc = defaultValueDecoders.decodeD + elemsFunc = decodeD default: - elemsFunc = defaultValueDecoders.decodeDefault + elemsFunc = decodeDefault } elems, err := elemsFunc(dc, vr, val) diff --git a/bson/string_codec.go b/bson/string_codec.go index 50fb9229fe..4681f15bd4 100644 --- a/bson/string_codec.go +++ b/bson/string_codec.go @@ -7,6 +7,7 @@ package bson import ( + "errors" "fmt" "reflect" @@ -56,7 +57,7 @@ func (sc *StringCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect. return vw.WriteString(val.String()) } -func (sc *StringCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (sc *StringCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t.Kind() != reflect.String { return emptyValue, ValueDecoderError{ Name: "StringDecodeValue", @@ -78,12 +79,10 @@ func (sc *StringCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Typ if err != nil { return emptyValue, err } - if sc.DecodeObjectIDAsHex { + if dc.decodeObjectIDAsHex { str = oid.Hex() } else { - // TODO(GODRIVER-2796): Return an error here instead of decoding to a garbled string. - byteArray := [12]byte(oid) - str = string(byteArray[:]) + return emptyValue, errors.New("cannot decode ObjectID as string if DecodeObjectIDAsHex is not set") } case TypeSymbol: str, err = vr.ReadSymbol() diff --git a/bson/string_codec_test.go b/bson/string_codec_test.go index 75ace60c5d..56d1215af5 100644 --- a/bson/string_codec_test.go +++ b/bson/string_codec_test.go @@ -7,35 +7,37 @@ package bson import ( + "errors" "reflect" "testing" - "go.mongodb.org/mongo-driver/bson/bsonoptions" "go.mongodb.org/mongo-driver/internal/assert" ) func TestStringCodec(t *testing.T) { t.Run("ObjectIDAsHex", func(t *testing.T) { oid := NewObjectID() - byteArray := [12]byte(oid) reader := &valueReaderWriter{BSONType: TypeObjectID, Return: oid} testCases := []struct { name string - opts *bsonoptions.StringCodecOptions - hex bool + dctx DecodeContext + err error result string }{ - {"default", bsonoptions.StringCodec(), true, oid.Hex()}, - {"true", bsonoptions.StringCodec().SetDecodeObjectIDAsHex(true), true, oid.Hex()}, - {"false", bsonoptions.StringCodec().SetDecodeObjectIDAsHex(false), false, string(byteArray[:])}, + {"default", DecodeContext{}, errors.New("cannot decode ObjectID as string if DecodeObjectIDAsHex is not set"), ""}, + {"decode hex", DecodeContext{decodeObjectIDAsHex: true}, nil, oid.Hex()}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - stringCodec := NewStringCodec(tc.opts) + stringCodec := NewStringCodec() actual := reflect.New(reflect.TypeOf("")).Elem() - err := stringCodec.DecodeValue(DecodeContext{}, reader, actual) - assert.Nil(t, err, "StringCodec.DecodeValue error: %v", err) + err := stringCodec.DecodeValue(tc.dctx, reader, actual) + if tc.err == nil { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, tc.err.Error()) + } actualString := actual.Interface().(string) assert.Equal(t, tc.result, actualString, "Expected string %v, got %v", tc.result, actualString) diff --git a/bson/struct_codec.go b/bson/struct_codec.go index 917ac17bfd..17e51bce14 100644 --- a/bson/struct_codec.go +++ b/bson/struct_codec.go @@ -153,7 +153,7 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect } } - desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(ec, desc.encoder, rv) + desc.encoder, rv, err = lookupElementEncoder(ec, desc.encoder, rv) if err != nil && !errors.Is(err, errInvalidValue) { return err From d8fba65d6d8ba00755aaeff42306ecf93706a5d9 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Mon, 29 Apr 2024 17:43:07 -0400 Subject: [PATCH 02/15] WIP --- bson/array_codec.go | 21 +--- bson/bson_test.go | 16 ++- bson/bsonoptions/byte_slice_codec_options.go | 49 -------- bson/bsonoptions/doc.go | 8 -- .../empty_interface_codec_options.go | 49 -------- bson/bsonoptions/map_codec_options.go | 82 -------------- bson/bsonoptions/slice_codec_options.go | 49 -------- bson/bsonoptions/string_codec_options.go | 52 --------- bson/bsonoptions/struct_codec_options.go | 107 ------------------ bson/bsonoptions/time_codec_options.go | 49 -------- bson/bsonoptions/uint_codec_options.go | 49 -------- bson/byte_slice_codec.go | 48 ++------ bson/cond_addr_codec.go | 6 - bson/cond_addr_codec_test.go | 4 +- bson/default_value_decoders.go | 38 +++---- bson/default_value_decoders_test.go | 33 +++--- bson/default_value_encoders.go | 28 ++--- bson/default_value_encoders_test.go | 22 ++-- bson/empty_interface_codec.go | 51 +++------ bson/map_codec.go | 69 +++-------- bson/mgocompat/bson_test.go | 6 +- bson/mgocompat/doc.go | 5 - bson/mgocompat/registry.go | 93 +-------------- bson/mgoregistry.go | 81 +++++++++++++ bson/pointer_codec.go | 27 ++--- bson/registry.go | 2 +- bson/registry_test.go | 8 +- bson/{mgocompat => }/setter_getter.go | 35 ++---- bson/slice_codec.go | 37 ++---- bson/string_codec.go | 47 +++----- bson/string_codec_test.go | 5 +- bson/struct_codec.go | 105 +++++------------ bson/time_codec.go | 47 ++------ bson/time_codec_test.go | 18 ++- bson/uint_codec.go | 47 ++------ bson/unmarshal_value_test.go | 4 +- 36 files changed, 308 insertions(+), 1089 deletions(-) delete mode 100644 bson/bsonoptions/byte_slice_codec_options.go delete mode 100644 bson/bsonoptions/doc.go delete mode 100644 bson/bsonoptions/empty_interface_codec_options.go delete mode 100644 bson/bsonoptions/map_codec_options.go delete mode 100644 bson/bsonoptions/slice_codec_options.go delete mode 100644 bson/bsonoptions/string_codec_options.go delete mode 100644 bson/bsonoptions/struct_codec_options.go delete mode 100644 bson/bsonoptions/time_codec_options.go delete mode 100644 bson/bsonoptions/uint_codec_options.go create mode 100644 bson/mgoregistry.go rename bson/{mgocompat => }/setter_getter.go (68%) diff --git a/bson/array_codec.go b/bson/array_codec.go index 5b07f4acd4..4a53d376bc 100644 --- a/bson/array_codec.go +++ b/bson/array_codec.go @@ -12,24 +12,11 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) -// ArrayCodec is the Codec used for bsoncore.Array values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// ArrayCodec registered. -type ArrayCodec struct{} - -var defaultArrayCodec = NewArrayCodec() - -// NewArrayCodec returns an ArrayCodec. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// ArrayCodec registered. -func NewArrayCodec() *ArrayCodec { - return &ArrayCodec{} -} +// arrayCodec is the Codec used for bsoncore.Array values. +type arrayCodec struct{} // EncodeValue is the ValueEncoder for bsoncore.Array values. -func (ac *ArrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func (ac *arrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCoreArray { return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val} } @@ -39,7 +26,7 @@ func (ac *ArrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.V } // DecodeValue is the ValueDecoder for bsoncore.Array values. -func (ac *ArrayCodec) DecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +func (ac *arrayCodec) DecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tCoreArray { return ValueDecoderError{Name: "CoreArrayDecodeValue", Types: []reflect.Type{tCoreArray}, Received: val} } diff --git a/bson/bson_test.go b/bson/bson_test.go index 31b6ffb884..5d99e066a8 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -17,7 +17,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - "go.mongodb.org/mongo-driver/bson/bsonoptions" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" @@ -349,19 +348,18 @@ func TestMapCodec(t *testing.T) { strstr := stringerString("foo") mapObj := map[stringerString]int{strstr: 1} testCases := []struct { - name string - opts *bsonoptions.MapCodecOptions - key string + name string + codec *mapCodec + key string }{ - {"default", bsonoptions.MapCodec(), "foo"}, - {"true", bsonoptions.MapCodec().SetEncodeKeysWithStringer(true), "bar"}, - {"false", bsonoptions.MapCodec().SetEncodeKeysWithStringer(false), "foo"}, + {"default", &mapCodec{}, "foo"}, + {"true", &mapCodec{encodeKeysWithStringer: true}, "bar"}, + {"false", &mapCodec{encodeKeysWithStringer: false}, "foo"}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - mapCodec := NewMapCodec(tc.opts) mapRegistry := NewRegistry() - mapRegistry.RegisterKindEncoder(reflect.Map, mapCodec) + mapRegistry.RegisterKindEncoder(reflect.Map, tc.codec) buf := new(bytes.Buffer) vw := NewValueWriter(buf) enc := NewEncoder(vw) diff --git a/bson/bsonoptions/byte_slice_codec_options.go b/bson/bsonoptions/byte_slice_codec_options.go deleted file mode 100644 index 996bd17127..0000000000 --- a/bson/bsonoptions/byte_slice_codec_options.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// ByteSliceCodecOptions represents all possible options for byte slice encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type ByteSliceCodecOptions struct { - EncodeNilAsEmpty *bool // Specifies if a nil byte slice should encode as an empty binary instead of null. Defaults to false. -} - -// ByteSliceCodec creates a new *ByteSliceCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func ByteSliceCodec() *ByteSliceCodecOptions { - return &ByteSliceCodecOptions{} -} - -// SetEncodeNilAsEmpty specifies if a nil byte slice should encode as an empty binary instead of null. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilByteSliceAsEmpty] instead. -func (bs *ByteSliceCodecOptions) SetEncodeNilAsEmpty(b bool) *ByteSliceCodecOptions { - bs.EncodeNilAsEmpty = &b - return bs -} - -// MergeByteSliceCodecOptions combines the given *ByteSliceCodecOptions into a single *ByteSliceCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeByteSliceCodecOptions(opts ...*ByteSliceCodecOptions) *ByteSliceCodecOptions { - bs := ByteSliceCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.EncodeNilAsEmpty != nil { - bs.EncodeNilAsEmpty = opt.EncodeNilAsEmpty - } - } - - return bs -} diff --git a/bson/bsonoptions/doc.go b/bson/bsonoptions/doc.go deleted file mode 100644 index c40973c8d4..0000000000 --- a/bson/bsonoptions/doc.go +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2022-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -// Package bsonoptions defines the optional configurations for the BSON codecs. -package bsonoptions diff --git a/bson/bsonoptions/empty_interface_codec_options.go b/bson/bsonoptions/empty_interface_codec_options.go deleted file mode 100644 index f522c7e03f..0000000000 --- a/bson/bsonoptions/empty_interface_codec_options.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// EmptyInterfaceCodecOptions represents all possible options for interface{} encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type EmptyInterfaceCodecOptions struct { - DecodeBinaryAsSlice *bool // Specifies if Old and Generic type binarys should default to []slice instead of primitive.Binary. Defaults to false. -} - -// EmptyInterfaceCodec creates a new *EmptyInterfaceCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func EmptyInterfaceCodec() *EmptyInterfaceCodecOptions { - return &EmptyInterfaceCodecOptions{} -} - -// SetDecodeBinaryAsSlice specifies if Old and Generic type binarys should default to []slice instead of primitive.Binary. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.BinaryAsSlice] instead. -func (e *EmptyInterfaceCodecOptions) SetDecodeBinaryAsSlice(b bool) *EmptyInterfaceCodecOptions { - e.DecodeBinaryAsSlice = &b - return e -} - -// MergeEmptyInterfaceCodecOptions combines the given *EmptyInterfaceCodecOptions into a single *EmptyInterfaceCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeEmptyInterfaceCodecOptions(opts ...*EmptyInterfaceCodecOptions) *EmptyInterfaceCodecOptions { - e := EmptyInterfaceCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.DecodeBinaryAsSlice != nil { - e.DecodeBinaryAsSlice = opt.DecodeBinaryAsSlice - } - } - - return e -} diff --git a/bson/bsonoptions/map_codec_options.go b/bson/bsonoptions/map_codec_options.go deleted file mode 100644 index a7a7c1d980..0000000000 --- a/bson/bsonoptions/map_codec_options.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// MapCodecOptions represents all possible options for map encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type MapCodecOptions struct { - DecodeZerosMap *bool // Specifies if the map should be zeroed before decoding into it. Defaults to false. - EncodeNilAsEmpty *bool // Specifies if a nil map should encode as an empty document instead of null. Defaults to false. - // Specifies how keys should be handled. If false, the behavior matches encoding/json, where the encoding key type must - // either be a string, an integer type, or implement bsoncodec.KeyMarshaler and the decoding key type must either be a - // string, an integer type, or implement bsoncodec.KeyUnmarshaler. If true, keys are encoded with fmt.Sprint() and the - // encoding key type must be a string, an integer type, or a float. If true, the use of Stringer will override - // TextMarshaler/TextUnmarshaler. Defaults to false. - EncodeKeysWithStringer *bool -} - -// MapCodec creates a new *MapCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func MapCodec() *MapCodecOptions { - return &MapCodecOptions{} -} - -// SetDecodeZerosMap specifies if the map should be zeroed before decoding into it. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.ZeroMaps] instead. -func (t *MapCodecOptions) SetDecodeZerosMap(b bool) *MapCodecOptions { - t.DecodeZerosMap = &b - return t -} - -// SetEncodeNilAsEmpty specifies if a nil map should encode as an empty document instead of null. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilMapAsEmpty] instead. -func (t *MapCodecOptions) SetEncodeNilAsEmpty(b bool) *MapCodecOptions { - t.EncodeNilAsEmpty = &b - return t -} - -// SetEncodeKeysWithStringer specifies how keys should be handled. If false, the behavior matches encoding/json, where the -// encoding key type must either be a string, an integer type, or implement bsoncodec.KeyMarshaler and the decoding key -// type must either be a string, an integer type, or implement bsoncodec.KeyUnmarshaler. If true, keys are encoded with -// fmt.Sprint() and the encoding key type must be a string, an integer type, or a float. If true, the use of Stringer -// will override TextMarshaler/TextUnmarshaler. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.StringifyMapKeysWithFmt] instead. -func (t *MapCodecOptions) SetEncodeKeysWithStringer(b bool) *MapCodecOptions { - t.EncodeKeysWithStringer = &b - return t -} - -// MergeMapCodecOptions combines the given *MapCodecOptions into a single *MapCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeMapCodecOptions(opts ...*MapCodecOptions) *MapCodecOptions { - s := MapCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.DecodeZerosMap != nil { - s.DecodeZerosMap = opt.DecodeZerosMap - } - if opt.EncodeNilAsEmpty != nil { - s.EncodeNilAsEmpty = opt.EncodeNilAsEmpty - } - if opt.EncodeKeysWithStringer != nil { - s.EncodeKeysWithStringer = opt.EncodeKeysWithStringer - } - } - - return s -} diff --git a/bson/bsonoptions/slice_codec_options.go b/bson/bsonoptions/slice_codec_options.go deleted file mode 100644 index 3c1e4f35ba..0000000000 --- a/bson/bsonoptions/slice_codec_options.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// SliceCodecOptions represents all possible options for slice encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type SliceCodecOptions struct { - EncodeNilAsEmpty *bool // Specifies if a nil slice should encode as an empty array instead of null. Defaults to false. -} - -// SliceCodec creates a new *SliceCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func SliceCodec() *SliceCodecOptions { - return &SliceCodecOptions{} -} - -// SetEncodeNilAsEmpty specifies if a nil slice should encode as an empty array instead of null. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilSliceAsEmpty] instead. -func (s *SliceCodecOptions) SetEncodeNilAsEmpty(b bool) *SliceCodecOptions { - s.EncodeNilAsEmpty = &b - return s -} - -// MergeSliceCodecOptions combines the given *SliceCodecOptions into a single *SliceCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeSliceCodecOptions(opts ...*SliceCodecOptions) *SliceCodecOptions { - s := SliceCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.EncodeNilAsEmpty != nil { - s.EncodeNilAsEmpty = opt.EncodeNilAsEmpty - } - } - - return s -} diff --git a/bson/bsonoptions/string_codec_options.go b/bson/bsonoptions/string_codec_options.go deleted file mode 100644 index f8b76f996e..0000000000 --- a/bson/bsonoptions/string_codec_options.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -var defaultDecodeOIDAsHex = true - -// StringCodecOptions represents all possible options for string encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type StringCodecOptions struct { - DecodeObjectIDAsHex *bool // Specifies if we should decode ObjectID as the hex value. Defaults to true. -} - -// StringCodec creates a new *StringCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func StringCodec() *StringCodecOptions { - return &StringCodecOptions{} -} - -// SetDecodeObjectIDAsHex specifies if object IDs should be decoded as their hex representation. If false, a string made -// from the raw object ID bytes will be used. Defaults to true. -// -// Deprecated: Decoding object IDs as raw bytes will not be supported in Go Driver 2.0. -func (t *StringCodecOptions) SetDecodeObjectIDAsHex(b bool) *StringCodecOptions { - t.DecodeObjectIDAsHex = &b - return t -} - -// MergeStringCodecOptions combines the given *StringCodecOptions into a single *StringCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeStringCodecOptions(opts ...*StringCodecOptions) *StringCodecOptions { - s := &StringCodecOptions{&defaultDecodeOIDAsHex} - for _, opt := range opts { - if opt == nil { - continue - } - if opt.DecodeObjectIDAsHex != nil { - s.DecodeObjectIDAsHex = opt.DecodeObjectIDAsHex - } - } - - return s -} diff --git a/bson/bsonoptions/struct_codec_options.go b/bson/bsonoptions/struct_codec_options.go deleted file mode 100644 index 1cbfa32e8b..0000000000 --- a/bson/bsonoptions/struct_codec_options.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -var defaultOverwriteDuplicatedInlinedFields = true - -// StructCodecOptions represents all possible options for struct encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type StructCodecOptions struct { - DecodeZeroStruct *bool // Specifies if structs should be zeroed before decoding into them. Defaults to false. - DecodeDeepZeroInline *bool // Specifies if structs should be recursively zeroed when a inline value is decoded. Defaults to false. - EncodeOmitDefaultStruct *bool // Specifies if default structs should be considered empty by omitempty. Defaults to false. - AllowUnexportedFields *bool // Specifies if unexported fields should be marshaled/unmarshaled. Defaults to false. - OverwriteDuplicatedInlinedFields *bool // Specifies if fields in inlined structs can be overwritten by higher level struct fields with the same key. Defaults to true. -} - -// StructCodec creates a new *StructCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func StructCodec() *StructCodecOptions { - return &StructCodecOptions{} -} - -// SetDecodeZeroStruct specifies if structs should be zeroed before decoding into them. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.ZeroStructs] instead. -func (t *StructCodecOptions) SetDecodeZeroStruct(b bool) *StructCodecOptions { - t.DecodeZeroStruct = &b - return t -} - -// SetDecodeDeepZeroInline specifies if structs should be zeroed before decoding into them. Defaults to false. -// -// Deprecated: DecodeDeepZeroInline will not be supported in Go Driver 2.0. -func (t *StructCodecOptions) SetDecodeDeepZeroInline(b bool) *StructCodecOptions { - t.DecodeDeepZeroInline = &b - return t -} - -// SetEncodeOmitDefaultStruct specifies if default structs should be considered empty by omitempty. A default struct has all -// its values set to their default value. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.OmitZeroStruct] instead. -func (t *StructCodecOptions) SetEncodeOmitDefaultStruct(b bool) *StructCodecOptions { - t.EncodeOmitDefaultStruct = &b - return t -} - -// SetOverwriteDuplicatedInlinedFields specifies if inlined struct fields can be overwritten by higher level struct fields with the -// same bson key. When true and decoding, values will be written to the outermost struct with a matching key, and when -// encoding, keys will have the value of the top-most matching field. When false, decoding and encoding will error if -// there are duplicate keys after the struct is inlined. Defaults to true. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.ErrorOnInlineDuplicates] instead. -func (t *StructCodecOptions) SetOverwriteDuplicatedInlinedFields(b bool) *StructCodecOptions { - t.OverwriteDuplicatedInlinedFields = &b - return t -} - -// SetAllowUnexportedFields specifies if unexported fields should be marshaled/unmarshaled. Defaults to false. -// -// Deprecated: AllowUnexportedFields does not work on recent versions of Go and will not be -// supported in Go Driver 2.0. -func (t *StructCodecOptions) SetAllowUnexportedFields(b bool) *StructCodecOptions { - t.AllowUnexportedFields = &b - return t -} - -// MergeStructCodecOptions combines the given *StructCodecOptions into a single *StructCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeStructCodecOptions(opts ...*StructCodecOptions) *StructCodecOptions { - s := &StructCodecOptions{ - OverwriteDuplicatedInlinedFields: &defaultOverwriteDuplicatedInlinedFields, - } - for _, opt := range opts { - if opt == nil { - continue - } - - if opt.DecodeZeroStruct != nil { - s.DecodeZeroStruct = opt.DecodeZeroStruct - } - if opt.DecodeDeepZeroInline != nil { - s.DecodeDeepZeroInline = opt.DecodeDeepZeroInline - } - if opt.EncodeOmitDefaultStruct != nil { - s.EncodeOmitDefaultStruct = opt.EncodeOmitDefaultStruct - } - if opt.OverwriteDuplicatedInlinedFields != nil { - s.OverwriteDuplicatedInlinedFields = opt.OverwriteDuplicatedInlinedFields - } - if opt.AllowUnexportedFields != nil { - s.AllowUnexportedFields = opt.AllowUnexportedFields - } - } - - return s -} diff --git a/bson/bsonoptions/time_codec_options.go b/bson/bsonoptions/time_codec_options.go deleted file mode 100644 index 3f38433d22..0000000000 --- a/bson/bsonoptions/time_codec_options.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// TimeCodecOptions represents all possible options for time.Time encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type TimeCodecOptions struct { - UseLocalTimeZone *bool // Specifies if we should decode into the local time zone. Defaults to false. -} - -// TimeCodec creates a new *TimeCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func TimeCodec() *TimeCodecOptions { - return &TimeCodecOptions{} -} - -// SetUseLocalTimeZone specifies if we should decode into the local time zone. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.UseLocalTimeZone] instead. -func (t *TimeCodecOptions) SetUseLocalTimeZone(b bool) *TimeCodecOptions { - t.UseLocalTimeZone = &b - return t -} - -// MergeTimeCodecOptions combines the given *TimeCodecOptions into a single *TimeCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeTimeCodecOptions(opts ...*TimeCodecOptions) *TimeCodecOptions { - t := TimeCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.UseLocalTimeZone != nil { - t.UseLocalTimeZone = opt.UseLocalTimeZone - } - } - - return t -} diff --git a/bson/bsonoptions/uint_codec_options.go b/bson/bsonoptions/uint_codec_options.go deleted file mode 100644 index 5091e4d963..0000000000 --- a/bson/bsonoptions/uint_codec_options.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bsonoptions - -// UIntCodecOptions represents all possible options for uint encoding and decoding. -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -type UIntCodecOptions struct { - EncodeToMinSize *bool // Specifies if all uints except uint64 should be decoded to minimum size bsontype. Defaults to false. -} - -// UIntCodec creates a new *UIntCodecOptions -// -// Deprecated: Use the bson.Encoder and bson.Decoder configuration methods to set the desired BSON marshal -// and unmarshal behavior instead. -func UIntCodec() *UIntCodecOptions { - return &UIntCodecOptions{} -} - -// SetEncodeToMinSize specifies if all uints except uint64 should be decoded to minimum size bsontype. Defaults to false. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.IntMinSize] instead. -func (u *UIntCodecOptions) SetEncodeToMinSize(b bool) *UIntCodecOptions { - u.EncodeToMinSize = &b - return u -} - -// MergeUIntCodecOptions combines the given *UIntCodecOptions into a single *UIntCodecOptions in a last one wins fashion. -// -// Deprecated: Merging options structs will not be supported in Go Driver 2.0. Users should create a -// single options struct instead. -func MergeUIntCodecOptions(opts ...*UIntCodecOptions) *UIntCodecOptions { - u := UIntCodec() - for _, opt := range opts { - if opt == nil { - continue - } - if opt.EncodeToMinSize != nil { - u.EncodeToMinSize = opt.EncodeToMinSize - } - } - - return u -} diff --git a/bson/byte_slice_codec.go b/bson/byte_slice_codec.go index 586c006467..bd5c5dae85 100644 --- a/bson/byte_slice_codec.go +++ b/bson/byte_slice_codec.go @@ -9,56 +9,32 @@ package bson import ( "fmt" "reflect" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) -// ByteSliceCodec is the Codec used for []byte values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// ByteSliceCodec registered. -type ByteSliceCodec struct { - // EncodeNilAsEmpty causes EncodeValue to marshal nil Go byte slices as empty BSON binary values +// byteSliceCodec is the Codec used for []byte values. +type byteSliceCodec struct { + // encodeNilAsEmpty causes EncodeValue to marshal nil Go byte slices as empty BSON binary values // instead of BSON null. - // - // Deprecated: Use bson.Encoder.NilByteSliceAsEmpty instead. - EncodeNilAsEmpty bool + encodeNilAsEmpty bool } -var ( - defaultByteSliceCodec = NewByteSliceCodec() - - // Assert that defaultByteSliceCodec satisfies the typeDecoder interface, which allows it to be - // used by collection type decoders (e.g. map, slice, etc) to set individual values in a - // collection. - _ typeDecoder = defaultByteSliceCodec -) - -// NewByteSliceCodec returns a ByteSliceCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// ByteSliceCodec registered. -func NewByteSliceCodec(opts ...*bsonoptions.ByteSliceCodecOptions) *ByteSliceCodec { - byteSliceOpt := bsonoptions.MergeByteSliceCodecOptions(opts...) - codec := ByteSliceCodec{} - if byteSliceOpt.EncodeNilAsEmpty != nil { - codec.EncodeNilAsEmpty = *byteSliceOpt.EncodeNilAsEmpty - } - return &codec -} +// Assert that defaultByteSliceCodec satisfies the typeDecoder interface, which allows it to be +// used by collection type decoders (e.g. map, slice, etc) to set individual values in a +// collection. +var _ typeDecoder = (*byteSliceCodec)(nil) // EncodeValue is the ValueEncoder for []byte. -func (bsc *ByteSliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (bsc *byteSliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tByteSlice { return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val} } - if val.IsNil() && !bsc.EncodeNilAsEmpty && !ec.nilByteSliceAsEmpty { + if val.IsNil() && !bsc.encodeNilAsEmpty && !ec.nilByteSliceAsEmpty { return vw.WriteNull() } return vw.WriteBinary(val.Interface().([]byte)) } -func (bsc *ByteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (bsc *byteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tByteSlice { return emptyValue, ValueDecoderError{ Name: "ByteSliceDecodeValue", @@ -106,7 +82,7 @@ func (bsc *ByteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect } // DecodeValue is the ValueDecoder for []byte. -func (bsc *ByteSliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (bsc *byteSliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tByteSlice { return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []reflect.Type{tByteSlice}, Received: val} } diff --git a/bson/cond_addr_codec.go b/bson/cond_addr_codec.go index fba139ff07..26eed212f1 100644 --- a/bson/cond_addr_codec.go +++ b/bson/cond_addr_codec.go @@ -18,12 +18,6 @@ type condAddrEncoder struct { var _ ValueEncoder = (*condAddrEncoder)(nil) -// newCondAddrEncoder returns an condAddrEncoder. -func newCondAddrEncoder(canAddrEnc, elseEnc ValueEncoder) *condAddrEncoder { - encoder := condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc} - return &encoder -} - // EncodeValue is the ValueEncoderFunc for a value that may be addressable. func (cae *condAddrEncoder) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if val.CanAddr() { diff --git a/bson/cond_addr_codec_test.go b/bson/cond_addr_codec_test.go index c22c29fe72..55d73e8204 100644 --- a/bson/cond_addr_codec_test.go +++ b/bson/cond_addr_codec_test.go @@ -30,7 +30,7 @@ func TestCondAddrCodec(t *testing.T) { invoked = 2 return nil }) - condEncoder := newCondAddrEncoder(encode1, encode2) + condEncoder := &condAddrEncoder{canAddrEnc: encode1, elseEnc: encode2} testCases := []struct { name string @@ -50,7 +50,7 @@ func TestCondAddrCodec(t *testing.T) { } t.Run("error", func(t *testing.T) { - errEncoder := newCondAddrEncoder(encode1, nil) + errEncoder := &condAddrEncoder{canAddrEnc: encode1, elseEnc: nil} err := errEncoder.EncodeValue(EncodeContext{}, rw, unaddressable) want := ErrNoEncoder{Type: unaddressable.Type()} assert.Equal(t, err, want, "expected error %v, got %v", want, err) diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index e4ea1f394e..5ad52e19e2 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -31,16 +31,6 @@ func (d decodeBinaryError) Error() string { return fmt.Sprintf("only binary values with subtype 0x00 or 0x02 can be decoded into %s, but got subtype %v", d.typeName, d.subtype) } -func newDefaultStructCodec() *StructCodec { - codec, err := NewStructCodec(DefaultStructTagParser) - if err != nil { - // This function is called from the codec registration path, so errors can't be propagated. If there's an error - // constructing the StructCodec, we panic to avoid losing it. - panic(fmt.Errorf("error creating default StructCodec: %w", err)) - } - return codec -} - // registerDefaultDecoders will register the default decoder methods with the provided Registry. // // There is no support for decoding map[string]interface{} because there is no decoder for @@ -66,10 +56,10 @@ func registerDefaultDecoders(reg *Registry) { reg.RegisterTypeDecoder(tMaxKey, decodeAdapter{maxKeyDecodeValue, maxKeyDecodeType}) reg.RegisterTypeDecoder(tJavaScript, decodeAdapter{javaScriptDecodeValue, javaScriptDecodeType}) reg.RegisterTypeDecoder(tSymbol, decodeAdapter{symbolDecodeValue, symbolDecodeType}) - reg.RegisterTypeDecoder(tByteSlice, defaultByteSliceCodec) - reg.RegisterTypeDecoder(tTime, defaultTimeCodec) - reg.RegisterTypeDecoder(tEmpty, defaultEmptyInterfaceCodec) - reg.RegisterTypeDecoder(tCoreArray, defaultArrayCodec) + reg.RegisterTypeDecoder(tByteSlice, &byteSliceCodec{}) + reg.RegisterTypeDecoder(tTime, &timeCodec{}) + reg.RegisterTypeDecoder(tEmpty, &emptyInterfaceCodec{}) + reg.RegisterTypeDecoder(tCoreArray, &arrayCodec{}) reg.RegisterTypeDecoder(tOID, decodeAdapter{objectIDDecodeValue, objectIDDecodeType}) reg.RegisterTypeDecoder(tDecimal, decodeAdapter{decimal128DecodeValue, decimal128DecodeType}) reg.RegisterTypeDecoder(tJSONNumber, decodeAdapter{jsonNumberDecodeValue, jsonNumberDecodeType}) @@ -82,19 +72,19 @@ func registerDefaultDecoders(reg *Registry) { reg.RegisterKindDecoder(reflect.Int16, intDecoder) reg.RegisterKindDecoder(reflect.Int32, intDecoder) reg.RegisterKindDecoder(reflect.Int64, intDecoder) - reg.RegisterKindDecoder(reflect.Uint, defaultUIntCodec) - reg.RegisterKindDecoder(reflect.Uint8, defaultUIntCodec) - reg.RegisterKindDecoder(reflect.Uint16, defaultUIntCodec) - reg.RegisterKindDecoder(reflect.Uint32, defaultUIntCodec) - reg.RegisterKindDecoder(reflect.Uint64, defaultUIntCodec) + reg.RegisterKindDecoder(reflect.Uint, &uintCodec{}) + reg.RegisterKindDecoder(reflect.Uint8, &uintCodec{}) + reg.RegisterKindDecoder(reflect.Uint16, &uintCodec{}) + reg.RegisterKindDecoder(reflect.Uint32, &uintCodec{}) + reg.RegisterKindDecoder(reflect.Uint64, &uintCodec{}) reg.RegisterKindDecoder(reflect.Float32, floatDecoder) reg.RegisterKindDecoder(reflect.Float64, floatDecoder) reg.RegisterKindDecoder(reflect.Array, ValueDecoderFunc(arrayDecodeValue)) - reg.RegisterKindDecoder(reflect.Map, defaultMapCodec) - reg.RegisterKindDecoder(reflect.Slice, defaultSliceCodec) - reg.RegisterKindDecoder(reflect.String, defaultStringCodec) - reg.RegisterKindDecoder(reflect.Struct, newDefaultStructCodec()) - reg.RegisterKindDecoder(reflect.Ptr, NewPointerCodec()) + reg.RegisterKindDecoder(reflect.Map, &mapCodec{}) + reg.RegisterKindDecoder(reflect.Slice, &sliceCodec{}) + reg.RegisterKindDecoder(reflect.String, &stringCodec{}) + reg.RegisterKindDecoder(reflect.Struct, newStructCodec(DefaultStructTagParser)) + reg.RegisterKindDecoder(reflect.Ptr, &pointerCodec{}) reg.RegisterTypeMapEntry(TypeDouble, tFloat64) reg.RegisterTypeMapEntry(TypeString, tString) reg.RegisterTypeMapEntry(TypeArray, tA) diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 56fdc464c2..258ef9e758 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -23,7 +23,7 @@ import ( ) var ( - defaultTestStructCodec = newDefaultStructCodec() + defaultTestStructCodec = newStructCodec(DefaultStructTagParser) ) func TestDefaultValueDecoders(t *testing.T) { @@ -371,7 +371,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "defaultUIntCodec.DecodeValue", - defaultUIntCodec, + &uintCodec{}, []subtest{ { "wrong type", @@ -736,7 +736,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "defaultTimeCodec.DecodeValue", - defaultTimeCodec, + &timeCodec{}, []subtest{ { "wrong type", @@ -790,7 +790,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "defaultMapCodec.DecodeValue", - defaultMapCodec, + &mapCodec{}, []subtest{ { "wrong kind", @@ -962,7 +962,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "defaultSliceCodec.DecodeValue", - defaultSliceCodec, + &sliceCodec{}, []subtest{ { "wrong kind", @@ -1373,7 +1373,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "defaultByteSliceCodec.DecodeValue", - defaultByteSliceCodec, + &byteSliceCodec{}, []subtest{ { "wrong type", @@ -1441,7 +1441,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "defaultStringCodec.DecodeValue", - defaultStringCodec, + &stringCodec{}, []subtest{ { "symbol", @@ -1550,15 +1550,15 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "PointerCodec.DecodeValue", - NewPointerCodec(), + &pointerCodec{}, []subtest{ { "not valid", nil, nil, nil, nothing, - ValueDecoderError{Name: "PointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: reflect.Value{}}, + ValueDecoderError{Name: "pointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: reflect.Value{}}, }, { "can set", cansettest, nil, nil, nothing, - ValueDecoderError{Name: "PointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}}, + ValueDecoderError{Name: "pointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}}, }, { "No Decoder", &wrong, &DecodeContext{Registry: buildDefaultRegistry()}, nil, nothing, @@ -2312,7 +2312,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "CoreArrayDecodeValue", - defaultArrayCodec, + &arrayCodec{}, []subtest{ { "wrong type", @@ -3195,6 +3195,7 @@ func TestDefaultValueDecoders(t *testing.T) { }) t.Run("defaultEmptyInterfaceCodec.DecodeValue", func(t *testing.T) { + defaultEmptyInterfaceCodec := &emptyInterfaceCodec{} t.Run("DecodeValue", func(t *testing.T) { testCases := []struct { name string @@ -3486,7 +3487,7 @@ func TestDefaultValueDecoders(t *testing.T) { var got D vr := NewValueReader(doc) val := reflect.ValueOf(&got).Elem() - err := defaultSliceCodec.DecodeValue(DecodeContext{Registry: reg}, vr, val) + err := (&sliceCodec{}).DecodeValue(DecodeContext{Registry: reg}, vr, val) noerr(t, err) if !cmp.Equal(got, want) { t.Fatalf("got %v, want %v", got, want) @@ -3574,7 +3575,7 @@ func TestDefaultValueDecoders(t *testing.T) { D{}, NewValueReader(docBytes), emptyInterfaceErrorRegistry, - defaultSliceCodec, + &sliceCodec{}, docEmptyInterfaceErr, }, { @@ -3583,7 +3584,7 @@ func TestDefaultValueDecoders(t *testing.T) { []string{}, &valueReaderWriter{BSONType: TypeArray}, nil, - defaultSliceCodec, + &sliceCodec{}, &DecodeError{ keys: []string{"0"}, wrapped: errors.New("cannot decode array into a string type"), @@ -3620,7 +3621,7 @@ func TestDefaultValueDecoders(t *testing.T) { map[string]interface{}{}, NewValueReader(docBytes), emptyInterfaceErrorRegistry, - defaultMapCodec, + &mapCodec{}, docEmptyInterfaceErr, }, { @@ -3730,7 +3731,7 @@ func TestDefaultValueDecoders(t *testing.T) { dc := DecodeContext{Registry: buildDefaultRegistry()} vr := NewValueReader(docBytes) val := reflect.New(reflect.TypeOf(myMap{})).Elem() - err := defaultMapCodec.DecodeValue(dc, vr, val) + err := (&mapCodec{}).DecodeValue(dc, vr, val) assert.Nil(t, err, "DecodeValue error: %v", err) want := myMap{ diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index 6b28f1594b..6b2ff14f61 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -54,10 +54,10 @@ func registerDefaultEncoders(reg *Registry) { if reg == nil { panic(errors.New("argument to RegisterDefaultEncoders must not be nil")) } - reg.RegisterTypeEncoder(tByteSlice, defaultByteSliceCodec) - reg.RegisterTypeEncoder(tTime, defaultTimeCodec) - reg.RegisterTypeEncoder(tEmpty, defaultEmptyInterfaceCodec) - reg.RegisterTypeEncoder(tCoreArray, defaultArrayCodec) + reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{}) + reg.RegisterTypeEncoder(tTime, &timeCodec{}) + reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{}) + reg.RegisterTypeEncoder(tCoreArray, &arrayCodec{}) reg.RegisterTypeEncoder(tOID, ValueEncoderFunc(objectIDEncodeValue)) reg.RegisterTypeEncoder(tDecimal, ValueEncoderFunc(decimal128EncodeValue)) reg.RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(jsonNumberEncodeValue)) @@ -81,19 +81,19 @@ func registerDefaultEncoders(reg *Registry) { reg.RegisterKindEncoder(reflect.Int16, ValueEncoderFunc(intEncodeValue)) reg.RegisterKindEncoder(reflect.Int32, ValueEncoderFunc(intEncodeValue)) reg.RegisterKindEncoder(reflect.Int64, ValueEncoderFunc(intEncodeValue)) - reg.RegisterKindEncoder(reflect.Uint, defaultUIntCodec) - reg.RegisterKindEncoder(reflect.Uint8, defaultUIntCodec) - reg.RegisterKindEncoder(reflect.Uint16, defaultUIntCodec) - reg.RegisterKindEncoder(reflect.Uint32, defaultUIntCodec) - reg.RegisterKindEncoder(reflect.Uint64, defaultUIntCodec) + reg.RegisterKindEncoder(reflect.Uint, &uintCodec{}) + reg.RegisterKindEncoder(reflect.Uint8, &uintCodec{}) + reg.RegisterKindEncoder(reflect.Uint16, &uintCodec{}) + reg.RegisterKindEncoder(reflect.Uint32, &uintCodec{}) + reg.RegisterKindEncoder(reflect.Uint64, &uintCodec{}) reg.RegisterKindEncoder(reflect.Float32, ValueEncoderFunc(floatEncodeValue)) reg.RegisterKindEncoder(reflect.Float64, ValueEncoderFunc(floatEncodeValue)) reg.RegisterKindEncoder(reflect.Array, ValueEncoderFunc(arrayEncodeValue)) - reg.RegisterKindEncoder(reflect.Map, defaultMapCodec) - reg.RegisterKindEncoder(reflect.Slice, defaultSliceCodec) - reg.RegisterKindEncoder(reflect.String, defaultStringCodec) - reg.RegisterKindEncoder(reflect.Struct, newDefaultStructCodec()) - reg.RegisterKindEncoder(reflect.Ptr, NewPointerCodec()) + reg.RegisterKindEncoder(reflect.Map, &mapCodec{}) + reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{}) + reg.RegisterKindEncoder(reflect.String, &stringCodec{}) + reg.RegisterKindEncoder(reflect.Struct, newStructCodec(DefaultStructTagParser)) + reg.RegisterKindEncoder(reflect.Ptr, &pointerCodec{}) reg.RegisterInterfaceEncoder(tValueMarshaler, ValueEncoderFunc(valueMarshalerEncodeValue)) reg.RegisterInterfaceEncoder(tMarshaler, ValueEncoderFunc(marshalerEncodeValue)) reg.RegisterInterfaceEncoder(tProxy, ValueEncoderFunc(proxyEncodeValue)) diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index 1ebd57f891..797a77322a 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -135,7 +135,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "UintEncodeValue", - defaultUIntCodec, + &uintCodec{}, []subtest{ { "wrong type", @@ -198,7 +198,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "TimeEncodeValue", - defaultTimeCodec, + &timeCodec{}, []subtest{ { "wrong type", @@ -213,7 +213,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "MapEncodeValue", - defaultMapCodec, + &mapCodec{}, []subtest{ { "wrong kind", @@ -371,7 +371,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "SliceEncodeValue", - defaultSliceCodec, + &sliceCodec{}, []subtest{ { "wrong kind", @@ -535,7 +535,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "ByteSliceEncodeValue", - defaultByteSliceCodec, + &byteSliceCodec{}, []subtest{ { "wrong type", @@ -551,7 +551,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "EmptyInterfaceEncodeValue", - defaultEmptyInterfaceCodec, + &emptyInterfaceCodec{}, []subtest{ { "wrong type", @@ -775,7 +775,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "PointerCodec.EncodeValue", - NewPointerCodec(), + &pointerCodec{}, []subtest{ { "nil", @@ -791,7 +791,7 @@ func TestDefaultValueEncoders(t *testing.T) { nil, nil, nothing, - ValueEncoderError{Name: "PointerCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: reflect.ValueOf(int32(123456))}, + ValueEncoderError{Name: "pointerCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: reflect.ValueOf(int32(123456))}, }, { "typed nil", @@ -813,7 +813,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "pointer implementation addressable interface", - NewPointerCodec(), + &pointerCodec{}, []subtest{ { "ValueMarshaler", @@ -1073,7 +1073,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "StructEncodeValue", - defaultTestStructCodec, + newStructCodec(DefaultStructTagParser), []subtest{ { "interface value", @@ -1130,7 +1130,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "CoreArrayEncodeValue", - defaultArrayCodec, + &arrayCodec{}, []subtest{ { "wrong type", diff --git a/bson/empty_interface_codec.go b/bson/empty_interface_codec.go index 56468e3068..da9efdded3 100644 --- a/bson/empty_interface_codec.go +++ b/bson/empty_interface_codec.go @@ -8,47 +8,22 @@ package bson import ( "reflect" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) -// EmptyInterfaceCodec is the Codec used for interface{} values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// EmptyInterfaceCodec registered. -type EmptyInterfaceCodec struct { - // DecodeBinaryAsSlice causes DecodeValue to unmarshal BSON binary field values that are the +// emptyInterfaceCodec is the Codec used for interface{} values. +type emptyInterfaceCodec struct { + // decodeBinaryAsSlice causes DecodeValue to unmarshal BSON binary field values that are the // "Generic" or "Old" BSON binary subtype as a Go byte slice instead of a Binary. - // - // Deprecated: Use bson.Decoder.BinaryAsSlice instead. - DecodeBinaryAsSlice bool + decodeBinaryAsSlice bool } -var ( - defaultEmptyInterfaceCodec = NewEmptyInterfaceCodec() - - // Assert that defaultEmptyInterfaceCodec satisfies the typeDecoder interface, which allows it - // to be used by collection type decoders (e.g. map, slice, etc) to set individual values in a - // collection. - _ typeDecoder = defaultEmptyInterfaceCodec -) - -// NewEmptyInterfaceCodec returns a EmptyInterfaceCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// EmptyInterfaceCodec registered. -func NewEmptyInterfaceCodec(opts ...*bsonoptions.EmptyInterfaceCodecOptions) *EmptyInterfaceCodec { - interfaceOpt := bsonoptions.MergeEmptyInterfaceCodecOptions(opts...) - - codec := EmptyInterfaceCodec{} - if interfaceOpt.DecodeBinaryAsSlice != nil { - codec.DecodeBinaryAsSlice = *interfaceOpt.DecodeBinaryAsSlice - } - return &codec -} +// Assert that defaultEmptyInterfaceCodec satisfies the typeDecoder interface, which allows it +// to be used by collection type decoders (e.g. map, slice, etc) to set individual values in a +// collection. +var _ typeDecoder = (*emptyInterfaceCodec)(nil) // EncodeValue is the ValueEncoderFunc for interface{}. -func (eic EmptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (eic emptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tEmpty { return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} } @@ -64,7 +39,7 @@ func (eic EmptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val return encoder.EncodeValue(ec, vw, val.Elem()) } -func (eic EmptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType Type) (reflect.Type, error) { +func (eic emptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType Type) (reflect.Type, error) { isDocument := valueType == Type(0) || valueType == TypeEmbeddedDocument if isDocument { if dc.defaultDocumentType != nil { @@ -105,7 +80,7 @@ func (eic EmptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, val return nil, err } -func (eic EmptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (eic emptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tEmpty { return emptyValue, ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.Zero(t)} } @@ -130,7 +105,7 @@ func (eic EmptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t re return emptyValue, err } - if (eic.DecodeBinaryAsSlice || dc.binaryAsSlice) && rtype == tBinary { + if (eic.decodeBinaryAsSlice || dc.binaryAsSlice) && rtype == tBinary { binElem := elem.Interface().(Binary) if binElem.Subtype == TypeBinaryGeneric || binElem.Subtype == TypeBinaryBinaryOld { elem = reflect.ValueOf(binElem.Data) @@ -141,7 +116,7 @@ func (eic EmptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t re } // DecodeValue is the ValueDecoderFunc for interface{}. -func (eic EmptyInterfaceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (eic emptyInterfaceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tEmpty { return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val} } diff --git a/bson/map_codec.go b/bson/map_codec.go index f1294ae99d..e11ffaf726 100644 --- a/bson/map_codec.go +++ b/bson/map_codec.go @@ -12,34 +12,21 @@ import ( "fmt" "reflect" "strconv" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) -var defaultMapCodec = NewMapCodec() - -// MapCodec is the Codec used for map values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// MapCodec registered. -type MapCodec struct { - // DecodeZerosMap causes DecodeValue to delete any existing values from Go maps in the destination +// mapCodec is the Codec used for map values. +type mapCodec struct { + // decodeZerosMap causes DecodeValue to delete any existing values from Go maps in the destination // value passed to Decode before unmarshaling BSON documents into them. - // - // Deprecated: Use bson.Decoder.ZeroMaps instead. - DecodeZerosMap bool + decodeZerosMap bool - // EncodeNilAsEmpty causes EncodeValue to marshal nil Go maps as empty BSON documents instead of + // encodeNilAsEmpty causes EncodeValue to marshal nil Go maps as empty BSON documents instead of // BSON null. - // - // Deprecated: Use bson.Encoder.NilMapAsEmpty instead. - EncodeNilAsEmpty bool + encodeNilAsEmpty bool - // EncodeKeysWithStringer causes the Encoder to convert Go map keys to BSON document field name + // encodeKeysWithStringer causes the Encoder to convert Go map keys to BSON document field name // strings using fmt.Sprintf() instead of the default string conversion logic. - // - // Deprecated: Use bson.Encoder.StringifyMapKeysWithFmt instead. - EncodeKeysWithStringer bool + encodeKeysWithStringer bool } // KeyMarshaler is the interface implemented by an object that can marshal itself into a string key. @@ -58,33 +45,13 @@ type KeyUnmarshaler interface { UnmarshalKey(key string) error } -// NewMapCodec returns a MapCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// MapCodec registered. -func NewMapCodec(opts ...*bsonoptions.MapCodecOptions) *MapCodec { - mapOpt := bsonoptions.MergeMapCodecOptions(opts...) - - codec := MapCodec{} - if mapOpt.DecodeZerosMap != nil { - codec.DecodeZerosMap = *mapOpt.DecodeZerosMap - } - if mapOpt.EncodeNilAsEmpty != nil { - codec.EncodeNilAsEmpty = *mapOpt.EncodeNilAsEmpty - } - if mapOpt.EncodeKeysWithStringer != nil { - codec.EncodeKeysWithStringer = *mapOpt.EncodeKeysWithStringer - } - return &codec -} - // EncodeValue is the ValueEncoder for map[*]* types. -func (mc *MapCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (mc *mapCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Map { return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} } - if val.IsNil() && !mc.EncodeNilAsEmpty && !ec.nilMapAsEmpty { + if val.IsNil() && !mc.encodeNilAsEmpty && !ec.nilMapAsEmpty { // If we have a nil map but we can't WriteNull, that means we're probably trying to encode // to a TopLevel document. We can't currently tell if this is what actually happened, but if // there's a deeper underlying problem, the error will also be returned from WriteDocument, @@ -107,7 +74,7 @@ func (mc *MapCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Va // mapEncodeValue handles encoding of the values of a map. The collisionFn returns // true if the provided key exists, this is mainly used for inline maps in the // struct codec. -func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { +func (mc *mapCodec) mapEncodeValue(ec EncodeContext, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { elemType := val.Type().Elem() encoder, err := ec.LookupEncoder(elemType) @@ -154,7 +121,7 @@ func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw DocumentWriter, val refl } // DecodeValue is the ValueDecoder for map[string/decimal]* types. -func (mc *MapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (mc *mapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if val.Kind() != reflect.Map || (!val.CanSet() && val.IsNil()) { return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} } @@ -180,7 +147,7 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Va val.Set(reflect.MakeMap(val.Type())) } - if val.Len() > 0 && (mc.DecodeZerosMap || dc.zeroMaps) { + if val.Len() > 0 && (mc.decodeZerosMap || dc.zeroMaps) { clearMap(val) } @@ -228,8 +195,8 @@ func clearMap(m reflect.Value) { } } -func (mc *MapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (string, error) { - if mc.EncodeKeysWithStringer || encodeKeysWithStringer { +func (mc *mapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (string, error) { + if mc.encodeKeysWithStringer || encodeKeysWithStringer { return fmt.Sprint(val), nil } @@ -274,12 +241,12 @@ func (mc *MapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (s var keyUnmarshalerType = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem() var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() -func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) { +func (mc *mapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) { keyVal := reflect.ValueOf(key) var err error switch { // First, if EncodeKeysWithStringer is not enabled, try to decode withKeyUnmarshaler - case !mc.EncodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType): + case !mc.encodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType): keyVal = reflect.New(keyType) v := keyVal.Interface().(KeyUnmarshaler) err = v.UnmarshalKey(key) @@ -309,7 +276,7 @@ func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, } keyVal = reflect.ValueOf(n).Convert(keyType) case reflect.Float32, reflect.Float64: - if mc.EncodeKeysWithStringer { + if mc.encodeKeysWithStringer { parsed, err := strconv.ParseFloat(key, 64) if err != nil { return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %w", keyType.Kind(), err) diff --git a/bson/mgocompat/bson_test.go b/bson/mgocompat/bson_test.go index a74a5a892d..7571abb19f 100644 --- a/bson/mgocompat/bson_test.go +++ b/bson/mgocompat/bson_test.go @@ -471,7 +471,7 @@ func (t *prefixPtr) GetBSON() (interface{}, error) { func (t *prefixPtr) SetBSON(raw bson.RawValue) error { var s string if raw.Type == 0x0A { - return ErrSetZero + return bson.ErrSetZero } rval := reflect.ValueOf(&s).Elem() decoder, err := Registry.LookupDecoder(rval.Type()) @@ -498,7 +498,7 @@ func (t prefixVal) GetBSON() (interface{}, error) { func (t *prefixVal) SetBSON(raw bson.RawValue) error { var s string if raw.Type == 0x0A { - return ErrSetZero + return bson.ErrSetZero } rval := reflect.ValueOf(&s).Elem() decoder, err := Registry.LookupDecoder(rval.Type()) @@ -1026,7 +1026,7 @@ func TestDMap(t *testing.T) { } func TestUnmarshalSetterErrSetZero(t *testing.T) { - setterResult["foo"] = ErrSetZero + setterResult["foo"] = bson.ErrSetZero defer delete(setterResult, "field") buf := new(bytes.Buffer) diff --git a/bson/mgocompat/doc.go b/bson/mgocompat/doc.go index 8a9434b1d1..a1c91aff4c 100644 --- a/bson/mgocompat/doc.go +++ b/bson/mgocompat/doc.go @@ -9,11 +9,6 @@ // with mgo's BSON with RespectNilValues set to true. A registry can be configured on a // mongo.Client with the SetRegistry option. See the bson docs for more details on registries. // -// Registry supports Getter and Setter equivalents by registering hooks. Note that if a value -// matches the hook for bson.Marshaler, bson.ValueMarshaler, or bson.Proxy, that -// hook will take priority over the Getter hook. The same is true for the hooks for -// bson.Unmarshaler and bson.ValueUnmarshaler and the Setter hook. -// // The functional differences between Registry and globalsign/mgo's BSON library are: // // 1) Registry errors instead of silently skipping mismatched types when decoding. diff --git a/bson/mgocompat/registry.go b/bson/mgocompat/registry.go index 0d61a029ec..7ffb90b22e 100644 --- a/bson/mgocompat/registry.go +++ b/bson/mgocompat/registry.go @@ -7,101 +7,12 @@ package mgocompat import ( - "errors" - "reflect" - "time" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/bsonoptions" -) - -var ( - // ErrSetZero may be returned from a SetBSON method to have the value set to its respective zero value. - ErrSetZero = errors.New("set to zero") - - tInt = reflect.TypeOf(int(0)) - tTime = reflect.TypeOf(time.Time{}) - tM = reflect.TypeOf(bson.M{}) - tInterfaceSlice = reflect.TypeOf([]interface{}{}) - tByteSlice = reflect.TypeOf([]byte{}) - tEmpty = reflect.TypeOf((*interface{})(nil)).Elem() - tGetter = reflect.TypeOf((*Getter)(nil)).Elem() - tSetter = reflect.TypeOf((*Setter)(nil)).Elem() ) // Registry is the mgo compatible bson.Registry. It contains the default and // primitive codecs with mgo compatible options. -var Registry = newRegistry() +var Registry = bson.NewMgoRegistry() // RespectNilValuesRegistry is the bson.Registry compatible with mgo withSetRespectNilValues set to true. -var RespectNilValuesRegistry = newRespectNilValuesRegistry() - -// newRegistry creates a new bson.Registry configured with the default encoders and decoders. -func newRegistry() *bson.Registry { - reg := bson.NewRegistry() - - structcodec, _ := bson.NewStructCodec(bson.DefaultStructTagParser, - bsonoptions.StructCodec(). - SetDecodeZeroStruct(true). - SetEncodeOmitDefaultStruct(true). - SetOverwriteDuplicatedInlinedFields(false). - SetAllowUnexportedFields(true)) - emptyInterCodec := bson.NewEmptyInterfaceCodec( - bsonoptions.EmptyInterfaceCodec(). - SetDecodeBinaryAsSlice(true)) - mapCodec := bson.NewMapCodec( - bsonoptions.MapCodec(). - SetDecodeZerosMap(true). - SetEncodeNilAsEmpty(true). - SetEncodeKeysWithStringer(true)) - uintcodec := bson.NewUIntCodec(bsonoptions.UIntCodec().SetEncodeToMinSize(true)) - - reg.RegisterTypeDecoder(tEmpty, emptyInterCodec) - reg.RegisterKindDecoder(reflect.String, bson.NewStringCodec(bsonoptions.StringCodec().SetDecodeObjectIDAsHex(false))) - reg.RegisterKindDecoder(reflect.Struct, structcodec) - reg.RegisterKindDecoder(reflect.Map, mapCodec) - reg.RegisterTypeEncoder(tByteSlice, bson.NewByteSliceCodec(bsonoptions.ByteSliceCodec().SetEncodeNilAsEmpty(true))) - reg.RegisterKindEncoder(reflect.Struct, structcodec) - reg.RegisterKindEncoder(reflect.Slice, bson.NewSliceCodec(bsonoptions.SliceCodec().SetEncodeNilAsEmpty(true))) - reg.RegisterKindEncoder(reflect.Map, mapCodec) - reg.RegisterKindEncoder(reflect.Uint, uintcodec) - reg.RegisterKindEncoder(reflect.Uint8, uintcodec) - reg.RegisterKindEncoder(reflect.Uint16, uintcodec) - reg.RegisterKindEncoder(reflect.Uint32, uintcodec) - reg.RegisterKindEncoder(reflect.Uint64, uintcodec) - reg.RegisterTypeMapEntry(bson.TypeInt32, tInt) - reg.RegisterTypeMapEntry(bson.TypeDateTime, tTime) - reg.RegisterTypeMapEntry(bson.TypeArray, tInterfaceSlice) - reg.RegisterTypeMapEntry(bson.Type(0), tM) - reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, tM) - reg.RegisterInterfaceEncoder(tGetter, bson.ValueEncoderFunc(GetterEncodeValue)) - reg.RegisterInterfaceDecoder(tSetter, bson.ValueDecoderFunc(SetterDecodeValue)) - - return reg -} - -// newRespectNilValuesRegistry creates a new bson.Registry configured to behave like mgo/bson -// with RespectNilValues set to true. -func newRespectNilValuesRegistry() *bson.Registry { - reg := newRegistry() - - structcodec, _ := bson.NewStructCodec(bson.DefaultStructTagParser, - bsonoptions.StructCodec(). - SetDecodeZeroStruct(true). - SetEncodeOmitDefaultStruct(true). - SetOverwriteDuplicatedInlinedFields(false). - SetAllowUnexportedFields(true)) - mapCodec := bson.NewMapCodec( - bsonoptions.MapCodec(). - SetDecodeZerosMap(true). - SetEncodeNilAsEmpty(false)) - - reg.RegisterKindDecoder(reflect.Struct, structcodec) - reg.RegisterKindDecoder(reflect.Map, mapCodec) - reg.RegisterTypeEncoder(tByteSlice, bson.NewByteSliceCodec(bsonoptions.ByteSliceCodec().SetEncodeNilAsEmpty(false))) - reg.RegisterKindEncoder(reflect.Struct, structcodec) - reg.RegisterKindEncoder(reflect.Slice, bson.NewSliceCodec(bsonoptions.SliceCodec().SetEncodeNilAsEmpty(false))) - reg.RegisterKindEncoder(reflect.Map, mapCodec) - - return reg -} +var RespectNilValuesRegistry = bson.NewRespectNilValuesMgoRegistry() diff --git a/bson/mgoregistry.go b/bson/mgoregistry.go new file mode 100644 index 0000000000..398de2afb9 --- /dev/null +++ b/bson/mgoregistry.go @@ -0,0 +1,81 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bson + +import ( + "errors" + "reflect" +) + +var ( + // ErrSetZero may be returned from a SetBSON method to have the value set to its respective zero value. + ErrSetZero = errors.New("set to zero") + + tInt = reflect.TypeOf(int(0)) + tM = reflect.TypeOf(M{}) + tInterfaceSlice = reflect.TypeOf([]interface{}{}) + tGetter = reflect.TypeOf((*Getter)(nil)).Elem() + tSetter = reflect.TypeOf((*Setter)(nil)).Elem() +) + +// NewMgoRegistry creates a new bson.Registry configured with the default encoders and decoders. +func NewMgoRegistry() *Registry { + reg := NewRegistry() + + structcodec := &structCodec{ + parser: DefaultStructTagParser, + decodeZeroStruct: true, + encodeOmitDefaultStruct: true, + allowUnexportedFields: true, + } + mapCodec := &mapCodec{ + decodeZerosMap: true, + encodeNilAsEmpty: true, + encodeKeysWithStringer: true, + } + uintcodec := &uintCodec{encodeToMinSize: true} + + reg.RegisterTypeDecoder(tEmpty, &emptyInterfaceCodec{decodeBinaryAsSlice: true}) + reg.RegisterKindDecoder(reflect.String, &stringCodec{}) + reg.RegisterKindDecoder(reflect.Struct, structcodec) + reg.RegisterKindDecoder(reflect.Map, mapCodec) + reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: true}) + reg.RegisterKindEncoder(reflect.Struct, structcodec) + reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{encodeNilAsEmpty: true}) + reg.RegisterKindEncoder(reflect.Map, mapCodec) + reg.RegisterKindEncoder(reflect.Uint, uintcodec) + reg.RegisterKindEncoder(reflect.Uint8, uintcodec) + reg.RegisterKindEncoder(reflect.Uint16, uintcodec) + reg.RegisterKindEncoder(reflect.Uint32, uintcodec) + reg.RegisterKindEncoder(reflect.Uint64, uintcodec) + reg.RegisterTypeMapEntry(TypeInt32, tInt) + reg.RegisterTypeMapEntry(TypeDateTime, tTime) + reg.RegisterTypeMapEntry(TypeArray, tInterfaceSlice) + reg.RegisterTypeMapEntry(Type(0), tM) + reg.RegisterTypeMapEntry(TypeEmbeddedDocument, tM) + reg.RegisterInterfaceEncoder(tGetter, ValueEncoderFunc(GetterEncodeValue)) + reg.RegisterInterfaceDecoder(tSetter, ValueDecoderFunc(SetterDecodeValue)) + + return reg +} + +// NewRespectNilValuesMgoRegistry creates a new bson.Registry configured to behave like mgo/bson +// with RespectNilValues set to true. +func NewRespectNilValuesMgoRegistry() *Registry { + reg := NewMgoRegistry() + + mapCodec := &mapCodec{ + decodeZerosMap: true, + } + + reg.RegisterKindDecoder(reflect.Map, mapCodec) + reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: false}) + reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{}) + reg.RegisterKindEncoder(reflect.Map, mapCodec) + + return reg +} diff --git a/bson/pointer_codec.go b/bson/pointer_codec.go index 5946b9cc9f..4ed1de3013 100644 --- a/bson/pointer_codec.go +++ b/bson/pointer_codec.go @@ -10,34 +10,23 @@ import ( "reflect" ) -var _ ValueEncoder = &PointerCodec{} -var _ ValueDecoder = &PointerCodec{} +var _ ValueEncoder = &pointerCodec{} +var _ ValueDecoder = &pointerCodec{} -// PointerCodec is the Codec used for pointers. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// PointerCodec registered. -type PointerCodec struct { +// pointerCodec is the Codec used for pointers. +type pointerCodec struct { ecache typeEncoderCache dcache typeDecoderCache } -// NewPointerCodec returns a PointerCodec that has been initialized. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// PointerCodec registered. -func NewPointerCodec() *PointerCodec { - return &PointerCodec{} -} - // EncodeValue handles encoding a pointer by either encoding it to BSON Null if the pointer is nil // or looking up an encoder for the type of value the pointer points to. -func (pc *PointerCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (pc *pointerCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.Ptr { if !val.IsValid() { return vw.WriteNull() } - return ValueEncoderError{Name: "PointerCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val} + return ValueEncoderError{Name: "pointerCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val} } if val.IsNil() { @@ -62,9 +51,9 @@ func (pc *PointerCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflec // DecodeValue handles decoding a pointer by looking up a decoder for the type it points to and // using that to decode. If the BSON value is Null, this method will set the pointer to nil. -func (pc *PointerCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (pc *pointerCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Kind() != reflect.Ptr { - return ValueDecoderError{Name: "PointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val} + return ValueDecoderError{Name: "pointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val} } typ := val.Type() diff --git a/bson/registry.go b/bson/registry.go index d20424b340..71d65259d6 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -297,7 +297,7 @@ func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool if !found { defaultEnc, _ = r.kindEncoders.Load(valueType.Kind()) } - return newCondAddrEncoder(ienc.ve, defaultEnc), true + return &condAddrEncoder{canAddrEnc: ienc.ve, elseEnc: defaultEnc}, true } } return nil, false diff --git a/bson/registry_test.go b/bson/registry_test.go index c7963d4edd..b897f04db6 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -195,7 +195,7 @@ func TestRegistryBuilder(t *testing.T) { ti3ImplPtr = reflect.TypeOf((*testInterface3Impl)(nil)) fc1, fc2 = &fakeCodec{num: 1}, &fakeCodec{num: 2} fsc, fslcc, fmc = new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec) - pc = NewPointerCodec() + pc = &pointerCodec{} ) reg := newTestRegistry() @@ -334,7 +334,7 @@ func TestRegistryBuilder(t *testing.T) { } allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{}) - comparepc := func(pc1, pc2 *PointerCodec) bool { return true } + comparepc := func(pc1, pc2 *pointerCodec) bool { return true } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Run("Encoder", func(t *testing.T) { @@ -610,7 +610,7 @@ func TestRegistry(t *testing.T) { ti3ImplPtr = reflect.TypeOf((*testInterface3Impl)(nil)) fc1, fc2 = &fakeCodec{num: 1}, &fakeCodec{num: 2} fsc, fslcc, fmc = new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec) - pc = NewPointerCodec() + pc = &pointerCodec{} ) reg := newTestRegistry() @@ -749,7 +749,7 @@ func TestRegistry(t *testing.T) { } allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{}) - comparepc := func(pc1, pc2 *PointerCodec) bool { return true } + comparepc := func(pc1, pc2 *pointerCodec) bool { return true } for _, tc := range testCases { tc := tc diff --git a/bson/mgocompat/setter_getter.go b/bson/setter_getter.go similarity index 68% rename from bson/mgocompat/setter_getter.go rename to bson/setter_getter.go index fc620fbba8..3616d25603 100644 --- a/bson/mgocompat/setter_getter.go +++ b/bson/setter_getter.go @@ -4,13 +4,11 @@ // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -package mgocompat +package bson import ( "errors" "reflect" - - "go.mongodb.org/mongo-driver/bson" ) // Setter interface: a value implementing the bson.Setter interface will receive the BSON @@ -34,7 +32,7 @@ import ( // return raw.Unmarshal(s) // } type Setter interface { - SetBSON(raw bson.RawValue) error + SetBSON(raw RawValue) error } // Getter interface: a value implementing the bson.Getter interface will have its GetBSON @@ -48,35 +46,35 @@ type Getter interface { } // SetterDecodeValue is the ValueDecoderFunc for Setter types. -func SetterDecodeValue(_ bson.DecodeContext, vr bson.ValueReader, val reflect.Value) error { +func SetterDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { if !val.IsValid() || (!val.Type().Implements(tSetter) && !reflect.PtrTo(val.Type()).Implements(tSetter)) { - return bson.ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} + return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} } if val.Kind() == reflect.Ptr && val.IsNil() { if !val.CanSet() { - return bson.ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} + return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} } val.Set(reflect.New(val.Type().Elem())) } if !val.Type().Implements(tSetter) { if !val.CanAddr() { - return bson.ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tSetter}, Received: val} + return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tSetter}, Received: val} } val = val.Addr() // If the type doesn't implement the interface, a pointer to it must. } - t, src, err := bson.CopyValueToBytes(vr) + t, src, err := CopyValueToBytes(vr) if err != nil { return err } m, ok := val.Interface().(Setter) if !ok { - return bson.ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} + return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} } - if err := m.SetBSON(bson.RawValue{Type: t, Value: src}); err != nil { + if err := m.SetBSON(RawValue{Type: t, Value: src}); err != nil { if !errors.Is(err, ErrSetZero) { return err } @@ -86,11 +84,11 @@ func SetterDecodeValue(_ bson.DecodeContext, vr bson.ValueReader, val reflect.Va } // GetterEncodeValue is the ValueEncoderFunc for Getter types. -func GetterEncodeValue(ec bson.EncodeContext, vw bson.ValueWriter, val reflect.Value) error { +func GetterEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Getter switch { case !val.IsValid(): - return bson.ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val} + return ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val} case val.Type().Implements(tGetter): // If Getter is implemented on a concrete type, make sure that val isn't a nil pointer if isImplementationNil(val, tGetter) { @@ -99,7 +97,7 @@ func GetterEncodeValue(ec bson.EncodeContext, vw bson.ValueWriter, val reflect.V case reflect.PtrTo(val.Type()).Implements(tGetter) && val.CanAddr(): val = val.Addr() default: - return bson.ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val} + return ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val} } m, ok := val.Interface().(Getter) @@ -120,12 +118,3 @@ func GetterEncodeValue(ec bson.EncodeContext, vw bson.ValueWriter, val reflect.V } return encoder.EncodeValue(ec, vw, vv) } - -// isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type -func isImplementationNil(val reflect.Value, inter reflect.Type) bool { - vt := val.Type() - for vt.Kind() == reflect.Ptr { - vt = vt.Elem() - } - return vt.Implements(inter) && val.Kind() == reflect.Ptr && val.IsNil() -} diff --git a/bson/slice_codec.go b/bson/slice_codec.go index f29c36b26d..6d26f6283c 100644 --- a/bson/slice_codec.go +++ b/bson/slice_codec.go @@ -10,45 +10,22 @@ import ( "errors" "fmt" "reflect" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) -var defaultSliceCodec = NewSliceCodec() - -// SliceCodec is the Codec used for slice values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// SliceCodec registered. -type SliceCodec struct { - // EncodeNilAsEmpty causes EncodeValue to marshal nil Go slices as empty BSON arrays instead of +// sliceCodec is the Codec used for slice values. +type sliceCodec struct { + // encodeNilAsEmpty causes EncodeValue to marshal nil Go slices as empty BSON arrays instead of // BSON null. - // - // Deprecated: Use bson.Encoder.NilSliceAsEmpty instead. - EncodeNilAsEmpty bool -} - -// NewSliceCodec returns a MapCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// SliceCodec registered. -func NewSliceCodec(opts ...*bsonoptions.SliceCodecOptions) *SliceCodec { - sliceOpt := bsonoptions.MergeSliceCodecOptions(opts...) - - codec := SliceCodec{} - if sliceOpt.EncodeNilAsEmpty != nil { - codec.EncodeNilAsEmpty = *sliceOpt.EncodeNilAsEmpty - } - return &codec + encodeNilAsEmpty bool } // EncodeValue is the ValueEncoder for slice types. -func (sc SliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (sc sliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Slice { return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } - if val.IsNil() && !sc.EncodeNilAsEmpty && !ec.nilSliceAsEmpty { + if val.IsNil() && !sc.encodeNilAsEmpty && !ec.nilSliceAsEmpty { return vw.WriteNull() } @@ -117,7 +94,7 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.V } // DecodeValue is the ValueDecoder for slice types. -func (sc *SliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (sc *sliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Kind() != reflect.Slice { return ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } diff --git a/bson/string_codec.go b/bson/string_codec.go index 4681f15bd4..7d7205f34d 100644 --- a/bson/string_codec.go +++ b/bson/string_codec.go @@ -10,42 +10,22 @@ import ( "errors" "fmt" "reflect" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) -// StringCodec is the Codec used for string values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// StringCodec registered. -type StringCodec struct { - // DecodeObjectIDAsHex specifies if object IDs should be decoded as their hex representation. +// stringCodec is the Codec used for string values. +type stringCodec struct { + // decodeObjectIDAsHex specifies if object IDs should be decoded as their hex representation. // If false, a string made from the raw object ID bytes will be used. Defaults to true. - // - // Deprecated: Decoding object IDs as raw bytes will not be supported in Go Driver 2.0. - DecodeObjectIDAsHex bool + decodeObjectIDAsHex bool } -var ( - defaultStringCodec = NewStringCodec() - - // Assert that defaultStringCodec satisfies the typeDecoder interface, which allows it to be - // used by collection type decoders (e.g. map, slice, etc) to set individual values in a - // collection. - _ typeDecoder = defaultStringCodec -) - -// NewStringCodec returns a StringCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// StringCodec registered. -func NewStringCodec(opts ...*bsonoptions.StringCodecOptions) *StringCodec { - stringOpt := bsonoptions.MergeStringCodecOptions(opts...) - return &StringCodec{*stringOpt.DecodeObjectIDAsHex} -} +// Assert that defaultStringCodec satisfies the typeDecoder interface, which allows it to be +// used by collection type decoders (e.g. map, slice, etc) to set individual values in a +// collection. +var _ typeDecoder = (*stringCodec)(nil) // EncodeValue is the ValueEncoder for string types. -func (sc *StringCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func (sc *stringCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.String { return ValueEncoderError{ Name: "StringEncodeValue", @@ -57,7 +37,7 @@ func (sc *StringCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect. return vw.WriteString(val.String()) } -func (sc *StringCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (sc *stringCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t.Kind() != reflect.String { return emptyValue, ValueDecoderError{ Name: "StringDecodeValue", @@ -79,11 +59,10 @@ func (sc *StringCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Ty if err != nil { return emptyValue, err } - if dc.decodeObjectIDAsHex { - str = oid.Hex() - } else { + if !sc.decodeObjectIDAsHex && !dc.decodeObjectIDAsHex { return emptyValue, errors.New("cannot decode ObjectID as string if DecodeObjectIDAsHex is not set") } + str = oid.Hex() case TypeSymbol: str, err = vr.ReadSymbol() if err != nil { @@ -114,7 +93,7 @@ func (sc *StringCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Ty } // DecodeValue is the ValueDecoder for string types. -func (sc *StringCodec) DecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +func (sc *stringCodec) DecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Kind() != reflect.String { return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} } diff --git a/bson/string_codec_test.go b/bson/string_codec_test.go index 56d1215af5..16d1727d4f 100644 --- a/bson/string_codec_test.go +++ b/bson/string_codec_test.go @@ -25,11 +25,12 @@ func TestStringCodec(t *testing.T) { result string }{ {"default", DecodeContext{}, errors.New("cannot decode ObjectID as string if DecodeObjectIDAsHex is not set"), ""}, - {"decode hex", DecodeContext{decodeObjectIDAsHex: true}, nil, oid.Hex()}, + {"true", DecodeContext{decodeObjectIDAsHex: true}, nil, oid.Hex()}, + {"false", DecodeContext{decodeObjectIDAsHex: false}, errors.New("cannot decode ObjectID as string if DecodeObjectIDAsHex is not set"), ""}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - stringCodec := NewStringCodec() + stringCodec := &stringCodec{} actual := reflect.New(reflect.TypeOf("")).Elem() err := stringCodec.DecodeValue(tc.dctx, reader, actual) diff --git a/bson/struct_codec.go b/bson/struct_codec.go index 17e51bce14..0c3eac5c73 100644 --- a/bson/struct_codec.go +++ b/bson/struct_codec.go @@ -14,8 +14,6 @@ import ( "strings" "sync" "time" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) // DecodeError represents an error that occurs when unmarshalling BSON bytes into a native Go type. @@ -49,88 +47,47 @@ func (de *DecodeError) Keys() []string { return reversedKeys } -// StructCodec is the Codec used for struct values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// StructCodec registered. -type StructCodec struct { +// structCodec is the Codec used for struct values. +type structCodec struct { cache sync.Map // map[reflect.Type]*structDescription parser StructTagParser - // DecodeZeroStruct causes DecodeValue to delete any existing values from Go structs in the - // destination value passed to Decode before unmarshaling BSON documents into them. - // - // Deprecated: Use bson.Decoder.ZeroStructs instead. - DecodeZeroStruct bool + // decodeZeroStruct causes DecodeValue to delete any existing values from Go structs in the + decodeZeroStruct bool - // DecodeDeepZeroInline causes DecodeValue to delete any existing values from Go structs in the + // decodeDeepZeroInline causes DecodeValue to delete any existing values from Go structs in the // destination value passed to Decode before unmarshaling BSON documents into them. - // - // Deprecated: DecodeDeepZeroInline will not be supported in Go Driver 2.0. - DecodeDeepZeroInline bool + decodeDeepZeroInline bool - // EncodeOmitDefaultStruct causes the Encoder to consider the zero value for a struct (e.g. + // encodeOmitDefaultStruct causes the Encoder to consider the zero value for a struct (e.g. // MyStruct{}) as empty and omit it from the marshaled BSON when the "omitempty" struct tag // option is set. - // - // Deprecated: Use bson.Encoder.OmitZeroStruct instead. - EncodeOmitDefaultStruct bool + encodeOmitDefaultStruct bool - // AllowUnexportedFields allows encoding and decoding values from un-exported struct fields. - // - // Deprecated: AllowUnexportedFields does not work on recent versions of Go and will not be - // supported in Go Driver 2.0. - AllowUnexportedFields bool + // allowUnexportedFields allows encoding and decoding values from un-exported struct fields. + allowUnexportedFields bool - // OverwriteDuplicatedInlinedFields, if false, causes EncodeValue to return an error if there is + // overwriteDuplicatedInlinedFields, if false, causes EncodeValue to return an error if there is // a duplicate field in the marshaled BSON when the "inline" struct tag option is set. The // default value is true. - // - // Deprecated: Use bson.Encoder.ErrorOnInlineDuplicates instead. - OverwriteDuplicatedInlinedFields bool + overwriteDuplicatedInlinedFields bool } -var _ ValueEncoder = &StructCodec{} -var _ ValueDecoder = &StructCodec{} - -// NewStructCodec returns a StructCodec that uses p for struct tag parsing. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// StructCodec registered. -func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) (*StructCodec, error) { - if p == nil { - return nil, errors.New("a StructTagParser must be provided to NewStructCodec") - } - - structOpt := bsonoptions.MergeStructCodecOptions(opts...) +var _ ValueEncoder = &structCodec{} +var _ ValueDecoder = &structCodec{} - codec := &StructCodec{ - parser: p, +// newStructCodec returns a StructCodec that uses p for struct tag parsing. +func newStructCodec(p StructTagParser) *structCodec { + return &structCodec{ + parser: p, + overwriteDuplicatedInlinedFields: true, } - - if structOpt.DecodeZeroStruct != nil { - codec.DecodeZeroStruct = *structOpt.DecodeZeroStruct - } - if structOpt.DecodeDeepZeroInline != nil { - codec.DecodeDeepZeroInline = *structOpt.DecodeDeepZeroInline - } - if structOpt.EncodeOmitDefaultStruct != nil { - codec.EncodeOmitDefaultStruct = *structOpt.EncodeOmitDefaultStruct - } - if structOpt.OverwriteDuplicatedInlinedFields != nil { - codec.OverwriteDuplicatedInlinedFields = *structOpt.OverwriteDuplicatedInlinedFields - } - if structOpt.AllowUnexportedFields != nil { - codec.AllowUnexportedFields = *structOpt.AllowUnexportedFields - } - - return codec, nil } // EncodeValue handles encoding generic struct types. -func (sc *StructCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (sc *structCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Struct { - return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} + return ValueEncoderError{Name: "structCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} } sd, err := sc.describeStruct(ec.Registry, val.Type(), ec.useJSONStructTags, ec.errorOnInlineDuplicates) @@ -188,7 +145,7 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect // nil interface separately. empty = rv.IsNil() } else { - empty = isEmpty(rv, sc.EncodeOmitDefaultStruct || ec.omitZeroStruct) + empty = isEmpty(rv, sc.encodeOmitDefaultStruct || ec.omitZeroStruct) } if desc.omitEmpty && empty { continue @@ -223,7 +180,7 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect return exists } - return defaultMapCodec.mapEncodeValue(ec, dw, rv, collisionFn) + return (&mapCodec{}).mapEncodeValue(ec, dw, rv, collisionFn) } return dw.WriteDocumentEnd() @@ -245,9 +202,9 @@ func newDecodeError(key string, original error) error { // DecodeValue implements the Codec interface. // By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr. // For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared. -func (sc *StructCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (sc *structCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Kind() != reflect.Struct { - return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} + return ValueDecoderError{Name: "structCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} } switch vrType := vr.Type(); vrType { @@ -275,10 +232,10 @@ func (sc *StructCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect return err } - if sc.DecodeZeroStruct || dc.zeroStructs { + if sc.decodeZeroStruct || dc.zeroStructs { val.Set(reflect.Zero(val.Type())) } - if sc.DecodeDeepZeroInline && sd.inline { + if sc.decodeDeepZeroInline && sd.inline { val.Set(deepZero(val.Type())) } @@ -461,7 +418,7 @@ func (bi byIndex) Less(i, j int) bool { return len(bi[i].inline) < len(bi[j].inline) } -func (sc *StructCodec) describeStruct( +func (sc *structCodec) describeStruct( r *Registry, t reflect.Type, useJSONStructTags bool, @@ -484,7 +441,7 @@ func (sc *StructCodec) describeStruct( return ds, nil } -func (sc *StructCodec) describeStructSlow( +func (sc *structCodec) describeStructSlow( r *Registry, t reflect.Type, useJSONStructTags bool, @@ -500,7 +457,7 @@ func (sc *StructCodec) describeStructSlow( var fields []fieldDescription for i := 0; i < numFields; i++ { sf := t.Field(i) - if sf.PkgPath != "" && (!sc.AllowUnexportedFields || !sf.Anonymous) { + if sf.PkgPath != "" && (!sc.allowUnexportedFields || !sf.Anonymous) { // field is private or unexported fields aren't allowed, ignore continue } @@ -611,7 +568,7 @@ func (sc *StructCodec) describeStructSlow( continue } dominant, ok := dominantField(fields[i : i+advance]) - if !ok || !sc.OverwriteDuplicatedInlinedFields || errorOnDuplicates { + if !ok || !sc.overwriteDuplicatedInlinedFields || errorOnDuplicates { return nil, fmt.Errorf("struct %s has duplicated key %s", t.String(), name) } sd.fl = append(sd.fl, dominant) diff --git a/bson/time_codec.go b/bson/time_codec.go index a168d1e769..6bbe300e4a 100644 --- a/bson/time_codec.go +++ b/bson/time_codec.go @@ -10,48 +10,23 @@ import ( "fmt" "reflect" "time" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) const ( timeFormatString = "2006-01-02T15:04:05.999Z07:00" ) -// TimeCodec is the Codec used for time.Time values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// TimeCodec registered. -type TimeCodec struct { - // UseLocalTimeZone specifies if we should decode into the local time zone. Defaults to false. - // - // Deprecated: Use bson.Decoder.UseLocalTimeZone instead. - UseLocalTimeZone bool +// timeCodec is the Codec used for time.Time values. +type timeCodec struct { + // useLocalTimeZone specifies if we should decode into the local time zone. Defaults to false. + useLocalTimeZone bool } -var ( - defaultTimeCodec = NewTimeCodec() - - // Assert that defaultTimeCodec satisfies the typeDecoder interface, which allows it to be used - // by collection type decoders (e.g. map, slice, etc) to set individual values in a collection. - _ typeDecoder = defaultTimeCodec -) - -// NewTimeCodec returns a TimeCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// TimeCodec registered. -func NewTimeCodec(opts ...*bsonoptions.TimeCodecOptions) *TimeCodec { - timeOpt := bsonoptions.MergeTimeCodecOptions(opts...) - - codec := TimeCodec{} - if timeOpt.UseLocalTimeZone != nil { - codec.UseLocalTimeZone = *timeOpt.UseLocalTimeZone - } - return &codec -} +// Assert that defaultTimeCodec satisfies the typeDecoder interface, which allows it to be used +// by collection type decoders (e.g. map, slice, etc) to set individual values in a collection. +var _ typeDecoder = (*timeCodec)(nil) -func (tc *TimeCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (tc *timeCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tTime { return emptyValue, ValueDecoderError{ Name: "TimeDecodeValue", @@ -102,14 +77,14 @@ func (tc *TimeCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type return emptyValue, fmt.Errorf("cannot decode %v into a time.Time", vrType) } - if !tc.UseLocalTimeZone && !dc.useLocalTimeZone { + if !tc.useLocalTimeZone && !dc.useLocalTimeZone { timeVal = timeVal.UTC() } return reflect.ValueOf(timeVal), nil } // DecodeValue is the ValueDecoderFunc for time.Time. -func (tc *TimeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (tc *timeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tTime { return ValueDecoderError{Name: "TimeDecodeValue", Types: []reflect.Type{tTime}, Received: val} } @@ -124,7 +99,7 @@ func (tc *TimeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.V } // EncodeValue is the ValueEncoderFunc for time.TIme. -func (tc *TimeCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func (tc *timeCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tTime { return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val} } diff --git a/bson/time_codec_test.go b/bson/time_codec_test.go index 1f185692da..70f52906b2 100644 --- a/bson/time_codec_test.go +++ b/bson/time_codec_test.go @@ -11,7 +11,6 @@ import ( "testing" "time" - "go.mongodb.org/mongo-driver/bson/bsonoptions" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) @@ -22,20 +21,17 @@ func TestTimeCodec(t *testing.T) { t.Run("UseLocalTimeZone", func(t *testing.T) { reader := &valueReaderWriter{BSONType: TypeDateTime, Return: now.UnixNano() / int64(time.Millisecond)} testCases := []struct { - name string - opts *bsonoptions.TimeCodecOptions - utc bool + name string + codec *timeCodec + utc bool }{ - {"default", bsonoptions.TimeCodec(), true}, - {"false", bsonoptions.TimeCodec().SetUseLocalTimeZone(false), true}, - {"true", bsonoptions.TimeCodec().SetUseLocalTimeZone(true), false}, + {"default", &timeCodec{}, true}, + {"true", &timeCodec{useLocalTimeZone: true}, false}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - timeCodec := NewTimeCodec(tc.opts) - actual := reflect.New(reflect.TypeOf(now)).Elem() - err := timeCodec.DecodeValue(DecodeContext{}, reader, actual) + err := tc.codec.DecodeValue(DecodeContext{}, reader, actual) assert.Nil(t, err, "TimeCodec.DecodeValue error: %v", err) actualTime := actual.Interface().(time.Time) @@ -69,7 +65,7 @@ func TestTimeCodec(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { actual := reflect.New(reflect.TypeOf(now)).Elem() - err := defaultTimeCodec.DecodeValue(DecodeContext{}, tc.reader, actual) + err := (&timeCodec{}).DecodeValue(DecodeContext{}, tc.reader, actual) assert.Nil(t, err, "DecodeValue error: %v", err) actualTime := actual.Interface().(time.Time) diff --git a/bson/uint_codec.go b/bson/uint_codec.go index 73bc01966e..b8f97ae5ab 100644 --- a/bson/uint_codec.go +++ b/bson/uint_codec.go @@ -10,46 +10,21 @@ import ( "fmt" "math" "reflect" - - "go.mongodb.org/mongo-driver/bson/bsonoptions" ) -// UIntCodec is the Codec used for uint values. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// UIntCodec registered. -type UIntCodec struct { - // EncodeToMinSize causes EncodeValue to marshal Go uint values (excluding uint64) as the +// uintCodec is the Codec used for uint values. +type uintCodec struct { + // encodeToMinSize causes EncodeValue to marshal Go uint values (excluding uint64) as the // minimum BSON int size (either 32-bit or 64-bit) that can represent the integer value. - // - // Deprecated: Use bson.Encoder.IntMinSize instead. - EncodeToMinSize bool + encodeToMinSize bool } -var ( - defaultUIntCodec = NewUIntCodec() - - // Assert that defaultUIntCodec satisfies the typeDecoder interface, which allows it to be used - // by collection type decoders (e.g. map, slice, etc) to set individual values in a collection. - _ typeDecoder = defaultUIntCodec -) - -// NewUIntCodec returns a UIntCodec with options opts. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the -// UIntCodec registered. -func NewUIntCodec(opts ...*bsonoptions.UIntCodecOptions) *UIntCodec { - uintOpt := bsonoptions.MergeUIntCodecOptions(opts...) - - codec := UIntCodec{} - if uintOpt.EncodeToMinSize != nil { - codec.EncodeToMinSize = *uintOpt.EncodeToMinSize - } - return &codec -} +// Assert that defaultUIntCodec satisfies the typeDecoder interface, which allows it to be used +// by collection type decoders (e.g. map, slice, etc) to set individual values in a collection. +var _ typeDecoder = (*uintCodec)(nil) // EncodeValue is the ValueEncoder for uint types. -func (uic *UIntCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (uic *uintCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Uint8, reflect.Uint16: return vw.WriteInt32(int32(val.Uint())) @@ -57,7 +32,7 @@ func (uic *UIntCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect. u64 := val.Uint() // If ec.MinSize or if encodeToMinSize is true for a non-uint64 value we should write val as an int32 - useMinSize := ec.MinSize || (uic.EncodeToMinSize && val.Kind() != reflect.Uint64) + useMinSize := ec.MinSize || (uic.encodeToMinSize && val.Kind() != reflect.Uint64) if u64 <= math.MaxInt32 && useMinSize { return vw.WriteInt32(int32(u64)) @@ -75,7 +50,7 @@ func (uic *UIntCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect. } } -func (uic *UIntCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (uic *uintCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { var i64 int64 var err error switch vrType := vr.Type(); vrType { @@ -163,7 +138,7 @@ func (uic *UIntCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Typ } // DecodeValue is the ValueDecoder for uint types. -func (uic *UIntCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (uic *uintCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if !val.CanSet() { return ValueDecoderError{ Name: "UintDecodeValue", diff --git a/bson/unmarshal_value_test.go b/bson/unmarshal_value_test.go index 8d9dfb5351..fd379b5daa 100644 --- a/bson/unmarshal_value_test.go +++ b/bson/unmarshal_value_test.go @@ -76,7 +76,7 @@ func TestUnmarshalValue(t *testing.T) { }, } reg := NewRegistry() - reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), NewSliceCodec()) + reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), &sliceCodec{}) for _, tc := range testCases { tc := tc @@ -111,7 +111,7 @@ func BenchmarkSliceCodecUnmarshal(b *testing.B) { }, } reg := NewRegistry() - reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), NewSliceCodec()) + reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), &sliceCodec{}) for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { b.RunParallel(func(pb *testing.PB) { From 1834500adef0a95d9cea7ad52ec9a563aee37e36 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Tue, 30 Apr 2024 09:26:15 -0400 Subject: [PATCH 03/15] WIP --- bson/bsoncodec.go | 171 +++------------------------- bson/bsoncodec_test.go | 30 ----- bson/codec_cache.go | 32 +++--- bson/codec_cache_test.go | 4 +- bson/cond_addr_codec.go | 14 +-- bson/decoder.go | 55 ++------- bson/default_value_decoders.go | 6 +- bson/default_value_decoders_test.go | 12 +- bson/default_value_encoders.go | 4 +- bson/default_value_encoders_test.go | 38 +++---- bson/empty_interface_codec.go | 4 +- bson/encoder.go | 52 ++------- bson/map_codec.go | 2 +- bson/pointer_codec.go | 4 +- bson/primitive_codecs_test.go | 4 +- bson/registry.go | 32 +++--- bson/registry_test.go | 18 +-- bson/slice_codec.go | 2 +- bson/struct_codec.go | 16 +-- bson/truncation_test.go | 4 +- bson/uint_codec.go | 4 +- 21 files changed, 134 insertions(+), 374 deletions(-) diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index b7aaadf2c2..e3ae365fd4 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -77,12 +77,10 @@ func (vde ValueDecoderError) Error() string { type EncodeContext struct { *Registry - // MinSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, + // minSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, // uint, uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) // that can represent the integer value. - // - // Deprecated: Use bson.Encoder.IntMinSize instead. - MinSize bool + minSize bool errorOnInlineDuplicates bool stringifyMapKeysWithFmt bool @@ -93,85 +91,16 @@ type EncodeContext struct { useJSONStructTags bool } -// ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in -// the marshaled BSON when the "inline" struct tag option is set. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.ErrorOnInlineDuplicates] instead. -func (ec *EncodeContext) ErrorOnInlineDuplicates() { - ec.errorOnInlineDuplicates = true -} - -// StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name -// strings using fmt.Sprintf() instead of the default string conversion logic. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.StringifyMapKeysWithFmt] instead. -func (ec *EncodeContext) StringifyMapKeysWithFmt() { - ec.stringifyMapKeysWithFmt = true -} - -// NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON -// null. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilMapAsEmpty] instead. -func (ec *EncodeContext) NilMapAsEmpty() { - ec.nilMapAsEmpty = true -} - -// NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON -// null. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilSliceAsEmpty] instead. -func (ec *EncodeContext) NilSliceAsEmpty() { - ec.nilSliceAsEmpty = true -} - -// NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values -// instead of BSON null. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.NilByteSliceAsEmpty] instead. -func (ec *EncodeContext) NilByteSliceAsEmpty() { - ec.nilByteSliceAsEmpty = true -} - -// OmitZeroStruct causes the Encoder to consider the zero value for a struct (e.g. MyStruct{}) -// as empty and omit it from the marshaled BSON when the "omitempty" struct tag option is set. -// -// Note that the Encoder only examines exported struct fields when determining if a struct is the -// zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.OmitZeroStruct] instead. -func (ec *EncodeContext) OmitZeroStruct() { - ec.omitZeroStruct = true -} - -// UseJSONStructTags causes the Encoder to fall back to using the "json" struct tag if a "bson" -// struct tag is not specified. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.UseJSONStructTags] instead. -func (ec *EncodeContext) UseJSONStructTags() { - ec.useJSONStructTags = true -} - // DecodeContext is the contextual information required for a Codec to decode a // value. type DecodeContext struct { *Registry - // Truncate, if true, instructs decoders to to truncate the fractional part of BSON "double" - // values when attempting to unmarshal them into a Go integer (int, int8, int16, int32, int64, - // uint, uint8, uint16, uint32, or uint64) struct field. The truncation logic does not apply to - // BSON "decimal128" values. - // - // Deprecated: Use bson.Decoder.AllowTruncatingDoubles instead. - Truncate bool - - // Ancestor is the type of a containing document. This is mainly used to determine what type + // ancestor is the type of a containing document. This is mainly used to determine what type // should be used when decoding an embedded document into an empty interface. For example, if // Ancestor is a bson.M, BSON embedded document values being decoded into an empty interface // will be decoded into a bson.M. - // - // Deprecated: Use bson.Decoder.DefaultDocumentM or bson.Decoder.DefaultDocumentD instead. - Ancestor reflect.Type + ancestor reflect.Type // defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the // usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If DocumentType is @@ -179,6 +108,12 @@ type DecodeContext struct { // error. DocumentType overrides the Ancestor field. defaultDocumentType reflect.Type + // truncate, if true, instructs decoders to to truncate the fractional part of BSON "double" + // values when attempting to unmarshal them into a Go integer (int, int8, int16, int32, int64, + // uint, uint8, uint16, uint32, or uint64) struct field. The truncation logic does not apply to + // BSON "decimal128" values. + truncate bool + binaryAsSlice bool decodeObjectIDAsHex bool useJSONStructTags bool @@ -187,85 +122,13 @@ type DecodeContext struct { zeroStructs bool } -// BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or -// "Old" BSON binary subtype as a Go byte slice instead of a Binary. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.BinaryAsSlice] instead. -func (dc *DecodeContext) BinaryAsSlice() { - dc.binaryAsSlice = true -} - -// DecodeObjectIDAsHex causes the Decoder to unmarshal BSON ObjectID as a hexadecimal string. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.DecodeObjectIDAsHex] instead. -func (dc *DecodeContext) DecodeObjectIDAsHex() { - dc.decodeObjectIDAsHex = true -} - -// UseJSONStructTags causes the Decoder to fall back to using the "json" struct tag if a "bson" -// struct tag is not specified. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.UseJSONStructTags] instead. -func (dc *DecodeContext) UseJSONStructTags() { - dc.useJSONStructTags = true -} - -// UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead -// of the UTC timezone. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.UseLocalTimeZone] instead. -func (dc *DecodeContext) UseLocalTimeZone() { - dc.useLocalTimeZone = true -} - -// ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value -// passed to Decode before unmarshaling BSON documents into them. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.ZeroMaps] instead. -func (dc *DecodeContext) ZeroMaps() { - dc.zeroMaps = true -} - -// ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination -// value passed to Decode before unmarshaling BSON documents into them. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.ZeroStructs] instead. -func (dc *DecodeContext) ZeroStructs() { - dc.zeroStructs = true -} - -// DefaultDocumentM causes the Decoder to always unmarshal documents into the M type. This -// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.DefaultDocumentM] instead. -func (dc *DecodeContext) DefaultDocumentM() { - dc.defaultDocumentType = reflect.TypeOf(M{}) -} - -// DefaultDocumentD causes the Decoder to always unmarshal documents into the D type. This -// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.DefaultDocumentD] instead. -func (dc *DecodeContext) DefaultDocumentD() { - dc.defaultDocumentType = reflect.TypeOf(D{}) -} - -// ValueCodec is an interface for encoding and decoding a reflect.Value. -// values. -// -// Deprecated: Use [ValueEncoder] and [ValueDecoder] instead. -type ValueCodec interface { - ValueEncoder - ValueDecoder -} - -// ValueEncoder is the interface implemented by types that can encode a provided Go type to BSON. +// valueEncoder is the interface implemented by types that can encode a provided Go type to BSON. // The value to encode is provided as a reflect.Value and a bson.ValueWriter is used within the // EncodeValue method to actually create the BSON representation. For convenience, ValueEncoderFunc // is provided to allow use of a function with the correct signature as a ValueEncoder. An // EncodeContext instance is provided to allow implementations to lookup further ValueEncoders and // to provide configuration information. -type ValueEncoder interface { +type valueEncoder interface { EncodeValue(EncodeContext, ValueWriter, reflect.Value) error } @@ -278,12 +141,12 @@ func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val ref return fn(ec, vw, val) } -// ValueDecoder is the interface implemented by types that can decode BSON to a provided Go type. +// valueDecoder is the interface implemented by types that can decode BSON to a provided Go type. // Implementations should ensure that the value they receive is settable. Similar to ValueEncoderFunc, // ValueDecoderFunc is provided to allow the use of a function with the correct signature as a // ValueDecoder. A DecodeContext instance is provided and serves similar functionality to the // EncodeContext. -type ValueDecoder interface { +type valueDecoder interface { DecodeValue(DecodeContext, ValueReader, reflect.Value) error } @@ -314,17 +177,17 @@ type decodeAdapter struct { typeDecoderFunc } -var _ ValueDecoder = decodeAdapter{} +var _ valueDecoder = decodeAdapter{} var _ typeDecoder = decodeAdapter{} // decodeTypeOrValue calls decoder.decodeType is decoder is a typeDecoder. Otherwise, it allocates a new element of type // t and calls decoder.DecodeValue on it. -func decodeTypeOrValue(decoder ValueDecoder, dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func decodeTypeOrValue(decoder valueDecoder, dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { td, _ := decoder.(typeDecoder) return decodeTypeOrValueWithInfo(decoder, td, dc, vr, t, true) } -func decodeTypeOrValueWithInfo(vd ValueDecoder, td typeDecoder, dc DecodeContext, vr ValueReader, t reflect.Type, convert bool) (reflect.Value, error) { +func decodeTypeOrValueWithInfo(vd valueDecoder, td typeDecoder, dc DecodeContext, vr ValueReader, t reflect.Type, convert bool) (reflect.Value, error) { if td != nil { val, err := td.decodeType(dc, vr, t) if err == nil && convert && val.Type() != t { diff --git a/bson/bsoncodec_test.go b/bson/bsoncodec_test.go index d1dc21a953..797bc9b383 100644 --- a/bson/bsoncodec_test.go +++ b/bson/bsoncodec_test.go @@ -7,40 +7,10 @@ package bson import ( - "fmt" "reflect" "testing" ) -func ExampleValueEncoder() { - var _ ValueEncoderFunc = func(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - if val.Kind() != reflect.String { - return ValueEncoderError{Name: "StringEncodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} - } - - return vw.WriteString(val.String()) - } -} - -func ExampleValueDecoder() { - var _ ValueDecoderFunc = func(dc DecodeContext, vr ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Kind() != reflect.String { - return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} - } - - if vr.Type() != TypeString { - return fmt.Errorf("cannot decode %v into a string type", vr.Type()) - } - - str, err := vr.ReadString() - if err != nil { - return err - } - val.SetString(str) - return nil - } -} - type llCodec struct { t *testing.T decodeval interface{} diff --git a/bson/codec_cache.go b/bson/codec_cache.go index b4042822e6..f28ec30911 100644 --- a/bson/codec_cache.go +++ b/bson/codec_cache.go @@ -29,20 +29,20 @@ type typeEncoderCache struct { cache sync.Map // map[reflect.Type]ValueEncoder } -func (c *typeEncoderCache) Store(rt reflect.Type, enc ValueEncoder) { +func (c *typeEncoderCache) Store(rt reflect.Type, enc valueEncoder) { c.cache.Store(rt, enc) } -func (c *typeEncoderCache) Load(rt reflect.Type) (ValueEncoder, bool) { +func (c *typeEncoderCache) Load(rt reflect.Type) (valueEncoder, bool) { if v, _ := c.cache.Load(rt); v != nil { - return v.(ValueEncoder), true + return v.(valueEncoder), true } return nil, false } -func (c *typeEncoderCache) LoadOrStore(rt reflect.Type, enc ValueEncoder) ValueEncoder { +func (c *typeEncoderCache) LoadOrStore(rt reflect.Type, enc valueEncoder) valueEncoder { if v, loaded := c.cache.LoadOrStore(rt, enc); loaded { - enc = v.(ValueEncoder) + enc = v.(valueEncoder) } return enc } @@ -62,20 +62,20 @@ type typeDecoderCache struct { cache sync.Map // map[reflect.Type]ValueDecoder } -func (c *typeDecoderCache) Store(rt reflect.Type, dec ValueDecoder) { +func (c *typeDecoderCache) Store(rt reflect.Type, dec valueDecoder) { c.cache.Store(rt, dec) } -func (c *typeDecoderCache) Load(rt reflect.Type) (ValueDecoder, bool) { +func (c *typeDecoderCache) Load(rt reflect.Type) (valueDecoder, bool) { if v, _ := c.cache.Load(rt); v != nil { - return v.(ValueDecoder), true + return v.(valueDecoder), true } return nil, false } -func (c *typeDecoderCache) LoadOrStore(rt reflect.Type, dec ValueDecoder) ValueDecoder { +func (c *typeDecoderCache) LoadOrStore(rt reflect.Type, dec valueDecoder) valueDecoder { if v, loaded := c.cache.LoadOrStore(rt, dec); loaded { - dec = v.(ValueDecoder) + dec = v.(valueDecoder) } return dec } @@ -96,20 +96,20 @@ func (c *typeDecoderCache) Clone() *typeDecoderCache { // is always the same (since different concrete types may implement the // ValueEncoder interface). type kindEncoderCacheEntry struct { - enc ValueEncoder + enc valueEncoder } type kindEncoderCache struct { entries [reflect.UnsafePointer + 1]atomic.Value // *kindEncoderCacheEntry } -func (c *kindEncoderCache) Store(rt reflect.Kind, enc ValueEncoder) { +func (c *kindEncoderCache) Store(rt reflect.Kind, enc valueEncoder) { if enc != nil && rt < reflect.Kind(len(c.entries)) { c.entries[rt].Store(&kindEncoderCacheEntry{enc: enc}) } } -func (c *kindEncoderCache) Load(rt reflect.Kind) (ValueEncoder, bool) { +func (c *kindEncoderCache) Load(rt reflect.Kind) (valueEncoder, bool) { if rt < reflect.Kind(len(c.entries)) { if ent, ok := c.entries[rt].Load().(*kindEncoderCacheEntry); ok { return ent.enc, ent.enc != nil @@ -133,20 +133,20 @@ func (c *kindEncoderCache) Clone() *kindEncoderCache { // is always the same (since different concrete types may implement the // ValueDecoder interface). type kindDecoderCacheEntry struct { - dec ValueDecoder + dec valueDecoder } type kindDecoderCache struct { entries [reflect.UnsafePointer + 1]atomic.Value // *kindDecoderCacheEntry } -func (c *kindDecoderCache) Store(rt reflect.Kind, dec ValueDecoder) { +func (c *kindDecoderCache) Store(rt reflect.Kind, dec valueDecoder) { if rt < reflect.Kind(len(c.entries)) { c.entries[rt].Store(&kindDecoderCacheEntry{dec: dec}) } } -func (c *kindDecoderCache) Load(rt reflect.Kind) (ValueDecoder, bool) { +func (c *kindDecoderCache) Load(rt reflect.Kind) (valueDecoder, bool) { if rt < reflect.Kind(len(c.entries)) { if ent, ok := c.entries[rt].Load().(*kindDecoderCacheEntry); ok { return ent.dec, ent.dec != nil diff --git a/bson/codec_cache_test.go b/bson/codec_cache_test.go index d48e05f5a3..3937db68ac 100644 --- a/bson/codec_cache_test.go +++ b/bson/codec_cache_test.go @@ -134,7 +134,7 @@ func TestKindCacheClone(t *testing.T) { func TestKindCacheEncoderNilEncoder(t *testing.T) { t.Run("Encoder", func(t *testing.T) { c := new(kindEncoderCache) - c.Store(reflect.Invalid, ValueEncoder(nil)) + c.Store(reflect.Invalid, valueEncoder(nil)) v, ok := c.Load(reflect.Invalid) if v != nil || ok { t.Errorf("Load of nil ValueEncoder should return: nil, false; got: %v, %t", v, ok) @@ -142,7 +142,7 @@ func TestKindCacheEncoderNilEncoder(t *testing.T) { }) t.Run("Decoder", func(t *testing.T) { c := new(kindDecoderCache) - c.Store(reflect.Invalid, ValueDecoder(nil)) + c.Store(reflect.Invalid, valueDecoder(nil)) v, ok := c.Load(reflect.Invalid) if v != nil || ok { t.Errorf("Load of nil ValueDecoder should return: nil, false; got: %v, %t", v, ok) diff --git a/bson/cond_addr_codec.go b/bson/cond_addr_codec.go index 26eed212f1..a64a9c434f 100644 --- a/bson/cond_addr_codec.go +++ b/bson/cond_addr_codec.go @@ -12,11 +12,11 @@ import ( // condAddrEncoder is the encoder used when a pointer to the encoding value has an encoder. type condAddrEncoder struct { - canAddrEnc ValueEncoder - elseEnc ValueEncoder + canAddrEnc valueEncoder + elseEnc valueEncoder } -var _ ValueEncoder = (*condAddrEncoder)(nil) +var _ valueEncoder = (*condAddrEncoder)(nil) // EncodeValue is the ValueEncoderFunc for a value that may be addressable. func (cae *condAddrEncoder) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { @@ -31,14 +31,14 @@ func (cae *condAddrEncoder) EncodeValue(ec EncodeContext, vw ValueWriter, val re // condAddrDecoder is the decoder used when a pointer to the value has a decoder. type condAddrDecoder struct { - canAddrDec ValueDecoder - elseDec ValueDecoder + canAddrDec valueDecoder + elseDec valueDecoder } -var _ ValueDecoder = (*condAddrDecoder)(nil) +var _ valueDecoder = (*condAddrDecoder)(nil) // newCondAddrDecoder returns an CondAddrDecoder. -func newCondAddrDecoder(canAddrDec, elseDec ValueDecoder) *condAddrDecoder { +func newCondAddrDecoder(canAddrDec, elseDec valueDecoder) *condAddrDecoder { decoder := condAddrDecoder{canAddrDec: canAddrDec, elseDec: elseDec} return &decoder } diff --git a/bson/decoder.go b/bson/decoder.go index 898dbc87af..cc53f422f9 100644 --- a/bson/decoder.go +++ b/bson/decoder.go @@ -30,18 +30,6 @@ var decPool = sync.Pool{ type Decoder struct { dc DecodeContext vr ValueReader - - // We persist defaultDocumentM and defaultDocumentD on the Decoder to prevent overwriting from - // (*Decoder).SetContext. - defaultDocumentM bool - defaultDocumentD bool - - binaryAsSlice bool - decodeObjectIDAsHex bool - useJSONStructTags bool - useLocalTimeZone bool - zeroMaps bool - zeroStructs bool } // NewDecoder returns a new decoder that uses the DefaultRegistry to read from vr. @@ -85,31 +73,6 @@ func (d *Decoder) Decode(val interface{}) error { return err } - if d.defaultDocumentM { - d.dc.DefaultDocumentM() - } - if d.defaultDocumentD { - d.dc.DefaultDocumentD() - } - if d.binaryAsSlice { - d.dc.BinaryAsSlice() - } - if d.decodeObjectIDAsHex { - d.dc.DecodeObjectIDAsHex() - } - if d.useJSONStructTags { - d.dc.UseJSONStructTags() - } - if d.useLocalTimeZone { - d.dc.UseLocalTimeZone() - } - if d.zeroMaps { - d.dc.ZeroMaps() - } - if d.zeroStructs { - d.dc.ZeroStructs() - } - return decoder.DecodeValue(d.dc, d.vr, rval) } @@ -127,53 +90,53 @@ func (d *Decoder) SetRegistry(r *Registry) { // DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This // behavior is restricted to data typed as "interface{}" or "map[string]interface{}". func (d *Decoder) DefaultDocumentM() { - d.defaultDocumentM = true + d.dc.defaultDocumentType = reflect.TypeOf(M{}) } // DefaultDocumentD causes the Decoder to always unmarshal documents into the primitive.D type. This // behavior is restricted to data typed as "interface{}" or "map[string]interface{}". func (d *Decoder) DefaultDocumentD() { - d.defaultDocumentD = true + d.dc.defaultDocumentType = reflect.TypeOf(D{}) } // AllowTruncatingDoubles causes the Decoder to truncate the fractional part of BSON "double" values // when attempting to unmarshal them into a Go integer (int, int8, int16, int32, or int64) struct // field. The truncation logic does not apply to BSON "decimal128" values. func (d *Decoder) AllowTruncatingDoubles() { - d.dc.Truncate = true + d.dc.truncate = true } // BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or // "Old" BSON binary subtype as a Go byte slice instead of a primitive.Binary. func (d *Decoder) BinaryAsSlice() { - d.binaryAsSlice = true + d.dc.binaryAsSlice = true } // DecodeObjectIDAsHex causes the Decoder to unmarshal BSON ObjectID as a hexadecimal string. func (d *Decoder) DecodeObjectIDAsHex() { - d.decodeObjectIDAsHex = true + d.dc.decodeObjectIDAsHex = true } // UseJSONStructTags causes the Decoder to fall back to using the "json" struct tag if a "bson" // struct tag is not specified. func (d *Decoder) UseJSONStructTags() { - d.useJSONStructTags = true + d.dc.useJSONStructTags = true } // UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead // of the UTC timezone. func (d *Decoder) UseLocalTimeZone() { - d.useLocalTimeZone = true + d.dc.useLocalTimeZone = true } // ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value // passed to Decode before unmarshaling BSON documents into them. func (d *Decoder) ZeroMaps() { - d.zeroMaps = true + d.dc.zeroMaps = true } // ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination // value passed to Decode before unmarshaling BSON documents into them. func (d *Decoder) ZeroStructs() { - d.zeroStructs = true + d.dc.zeroStructs = true } diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index 5ad52e19e2..cfa09bdd10 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -118,7 +118,7 @@ func dDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { switch vrType := vr.Type(); vrType { case Type(0), TypeEmbeddedDocument: - dc.Ancestor = tD + dc.ancestor = tD case TypeNull: val.Set(reflect.Zero(val.Type())) return vr.ReadNull() @@ -248,7 +248,7 @@ func intDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Va if err != nil { return emptyValue, err } - if !dc.Truncate && math.Floor(f64) != f64 { + if !dc.truncate && math.Floor(f64) != f64 { return emptyValue, errCannotTruncate } if f64 > float64(math.MaxInt64) { @@ -373,7 +373,7 @@ func floatDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect. switch t.Kind() { case reflect.Float32: - if !dc.Truncate && float64(float32(f)) != f { + if !dc.truncate && float64(float32(f)) != f { return emptyValue, errCannotTruncate } diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 258ef9e758..fc86145bfd 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -65,7 +65,7 @@ func TestDefaultValueDecoders(t *testing.T) { testCases := []struct { name string - vd ValueDecoder + vd valueDecoder subtests []subtest }{ { @@ -191,7 +191,7 @@ func TestDefaultValueDecoders(t *testing.T) { nil, }, { - "ReadDouble (truncate)", int64(3), &DecodeContext{Truncate: true}, + "ReadDouble (truncate)", int64(3), &DecodeContext{truncate: true}, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, nil, }, @@ -423,7 +423,7 @@ func TestDefaultValueDecoders(t *testing.T) { nil, }, { - "ReadDouble (truncate)", uint64(3), &DecodeContext{Truncate: true}, + "ReadDouble (truncate)", uint64(3), &DecodeContext{truncate: true}, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, nil, }, @@ -674,7 +674,7 @@ func TestDefaultValueDecoders(t *testing.T) { nil, }, { - "float32/fast path (truncate)", float32(3.14), &DecodeContext{Truncate: true}, + "float32/fast path (truncate)", float32(3.14), &DecodeContext{truncate: true}, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, nil, }, @@ -712,7 +712,7 @@ func TestDefaultValueDecoders(t *testing.T) { nil, }, { - "float32/reflection path (truncate)", myfloat32(3.14), &DecodeContext{Truncate: true}, + "float32/reflection path (truncate)", myfloat32(3.14), &DecodeContext{truncate: true}, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, nil, }, @@ -3566,7 +3566,7 @@ func TestDefaultValueDecoders(t *testing.T) { val interface{} vr ValueReader registry *Registry // buildDefaultRegistry will be used if this is nil - decoder ValueDecoder + decoder valueDecoder err error }{ { diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index 6b2ff14f61..31c4bf314c 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -124,7 +124,7 @@ func intEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { return vw.WriteInt64(i64) case reflect.Int64: i64 := val.Int() - if ec.MinSize && fitsIn32Bits(i64) { + if ec.minSize && fitsIn32Bits(i64) { return vw.WriteInt32(int32(i64)) } return vw.WriteInt64(i64) @@ -263,7 +263,7 @@ func arrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error return aw.WriteArrayEnd() } -func lookupElementEncoder(ec EncodeContext, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) { +func lookupElementEncoder(ec EncodeContext, origEncoder valueEncoder, currVal reflect.Value) (valueEncoder, reflect.Value, error) { if origEncoder != nil || (currVal.Kind() != reflect.Interface) { return origEncoder, currVal, nil } diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index 797a77322a..9c869448f9 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -74,7 +74,7 @@ func TestDefaultValueEncoders(t *testing.T) { testCases := []struct { name string - ve ValueEncoder + ve valueEncoder subtests []subtest }{ { @@ -113,9 +113,9 @@ func TestDefaultValueEncoders(t *testing.T) { {"int16/fast path", int16(32767), nil, nil, writeInt32, nil}, {"int32/fast path", int32(2147483647), nil, nil, writeInt32, nil}, {"int64/fast path", int64(1234567890987), nil, nil, writeInt64, nil}, - {"int64/fast path - minsize", int64(math.MaxInt32), &EncodeContext{MinSize: true}, nil, writeInt32, nil}, - {"int64/fast path - minsize too large", int64(math.MaxInt32 + 1), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, - {"int64/fast path - minsize too small", int64(math.MinInt32 - 1), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, + {"int64/fast path - minsize", int64(math.MaxInt32), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + {"int64/fast path - minsize too large", int64(math.MaxInt32 + 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + {"int64/fast path - minsize too small", int64(math.MinInt32 - 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, {"int/fast path - positive int32", int(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, {"int/fast path - negative int32", int(math.MinInt32 + 1), nil, nil, writeInt32, nil}, {"int/fast path - MaxInt32", int(math.MaxInt32), nil, nil, writeInt32, nil}, @@ -124,9 +124,9 @@ func TestDefaultValueEncoders(t *testing.T) { {"int16/reflection path", myint16(32767), nil, nil, writeInt32, nil}, {"int32/reflection path", myint32(2147483647), nil, nil, writeInt32, nil}, {"int64/reflection path", myint64(1234567890987), nil, nil, writeInt64, nil}, - {"int64/reflection path - minsize", myint64(math.MaxInt32), &EncodeContext{MinSize: true}, nil, writeInt32, nil}, - {"int64/reflection path - minsize too large", myint64(math.MaxInt32 + 1), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, - {"int64/reflection path - minsize too small", myint64(math.MinInt32 - 1), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, + {"int64/reflection path - minsize", myint64(math.MaxInt32), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + {"int64/reflection path - minsize too large", myint64(math.MaxInt32 + 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + {"int64/reflection path - minsize too small", myint64(math.MinInt32 - 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, {"int/reflection path - positive int32", myint(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, {"int/reflection path - negative int32", myint(math.MinInt32 + 1), nil, nil, writeInt32, nil}, {"int/reflection path - MaxInt32", myint(math.MaxInt32), nil, nil, writeInt32, nil}, @@ -154,23 +154,23 @@ func TestDefaultValueEncoders(t *testing.T) { {"uint32/fast path", uint32(2147483647), nil, nil, writeInt64, nil}, {"uint64/fast path", uint64(1234567890987), nil, nil, writeInt64, nil}, {"uint/fast path", uint(1234567), nil, nil, writeInt64, nil}, - {"uint32/fast path - minsize", uint32(2147483647), &EncodeContext{MinSize: true}, nil, writeInt32, nil}, - {"uint64/fast path - minsize", uint64(2147483647), &EncodeContext{MinSize: true}, nil, writeInt32, nil}, - {"uint/fast path - minsize", uint(2147483647), &EncodeContext{MinSize: true}, nil, writeInt32, nil}, - {"uint32/fast path - minsize too large", uint32(2147483648), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, - {"uint64/fast path - minsize too large", uint64(2147483648), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, - {"uint/fast path - minsize too large", uint(2147483648), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, + {"uint32/fast path - minsize", uint32(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + {"uint64/fast path - minsize", uint64(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + {"uint/fast path - minsize", uint(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + {"uint32/fast path - minsize too large", uint32(2147483648), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + {"uint64/fast path - minsize too large", uint64(2147483648), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + {"uint/fast path - minsize too large", uint(2147483648), &EncodeContext{minSize: true}, nil, writeInt64, nil}, {"uint64/fast path - overflow", uint64(1 << 63), nil, nil, nothing, fmt.Errorf("%d overflows int64", uint64(1<<63))}, {"uint8/reflection path", myuint8(127), nil, nil, writeInt32, nil}, {"uint16/reflection path", myuint16(32767), nil, nil, writeInt32, nil}, {"uint32/reflection path", myuint32(2147483647), nil, nil, writeInt64, nil}, {"uint64/reflection path", myuint64(1234567890987), nil, nil, writeInt64, nil}, - {"uint32/reflection path - minsize", myuint32(2147483647), &EncodeContext{MinSize: true}, nil, writeInt32, nil}, - {"uint64/reflection path - minsize", myuint64(2147483647), &EncodeContext{MinSize: true}, nil, writeInt32, nil}, - {"uint/reflection path - minsize", myuint(2147483647), &EncodeContext{MinSize: true}, nil, writeInt32, nil}, - {"uint32/reflection path - minsize too large", myuint(1 << 31), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, - {"uint64/reflection path - minsize too large", myuint64(1 << 31), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, - {"uint/reflection path - minsize too large", myuint(2147483648), &EncodeContext{MinSize: true}, nil, writeInt64, nil}, + {"uint32/reflection path - minsize", myuint32(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + {"uint64/reflection path - minsize", myuint64(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + {"uint/reflection path - minsize", myuint(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + {"uint32/reflection path - minsize too large", myuint(1 << 31), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + {"uint64/reflection path - minsize too large", myuint64(1 << 31), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + {"uint/reflection path - minsize too large", myuint(2147483648), &EncodeContext{minSize: true}, nil, writeInt64, nil}, {"uint64/reflection path - overflow", myuint64(1 << 63), nil, nil, nothing, fmt.Errorf("%d overflows int64", uint64(1<<63))}, }, }, diff --git a/bson/empty_interface_codec.go b/bson/empty_interface_codec.go index da9efdded3..fce09c0ac1 100644 --- a/bson/empty_interface_codec.go +++ b/bson/empty_interface_codec.go @@ -47,11 +47,11 @@ func (eic emptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, val // that type. return dc.defaultDocumentType, nil } - if dc.Ancestor != nil { + if dc.ancestor != nil { // Using ancestor information rather than looking up the type map entry forces consistent decoding. // If we're decoding into a bson.D, subdocuments should also be decoded as bson.D, even if a type map entry // has been registered. - return dc.Ancestor, nil + return dc.ancestor, nil } } diff --git a/bson/encoder.go b/bson/encoder.go index fb865cd285..1b348f9488 100644 --- a/bson/encoder.go +++ b/bson/encoder.go @@ -25,15 +25,6 @@ var encPool = sync.Pool{ type Encoder struct { ec EncodeContext vw ValueWriter - - errorOnInlineDuplicates bool - intMinSize bool - stringifyMapKeysWithFmt bool - nilMapAsEmpty bool - nilSliceAsEmpty bool - nilByteSliceAsEmpty bool - omitZeroStruct bool - useJSONStructTags bool } // NewEncoder returns a new encoder that uses the DefaultRegistry to write to vw. @@ -62,33 +53,6 @@ func (e *Encoder) Encode(val interface{}) error { return err } - // Copy the configurations applied to the Encoder over to the EncodeContext, which actually - // communicates those configurations to the default ValueEncoders. - if e.errorOnInlineDuplicates { - e.ec.ErrorOnInlineDuplicates() - } - if e.intMinSize { - e.ec.MinSize = true - } - if e.stringifyMapKeysWithFmt { - e.ec.StringifyMapKeysWithFmt() - } - if e.nilMapAsEmpty { - e.ec.NilMapAsEmpty() - } - if e.nilSliceAsEmpty { - e.ec.NilSliceAsEmpty() - } - if e.nilByteSliceAsEmpty { - e.ec.NilByteSliceAsEmpty() - } - if e.omitZeroStruct { - e.ec.OmitZeroStruct() - } - if e.useJSONStructTags { - e.ec.UseJSONStructTags() - } - return encoder.EncodeValue(e.ec, e.vw, reflect.ValueOf(val)) } @@ -106,38 +70,38 @@ func (e *Encoder) SetRegistry(r *Registry) { // ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in // the marshaled BSON when the "inline" struct tag option is set. func (e *Encoder) ErrorOnInlineDuplicates() { - e.errorOnInlineDuplicates = true + e.ec.errorOnInlineDuplicates = true } // IntMinSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, uint, // uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) that can // represent the integer value. func (e *Encoder) IntMinSize() { - e.intMinSize = true + e.ec.minSize = true } // StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name // strings using fmt.Sprint instead of the default string conversion logic. func (e *Encoder) StringifyMapKeysWithFmt() { - e.stringifyMapKeysWithFmt = true + e.ec.stringifyMapKeysWithFmt = true } // NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON // null. func (e *Encoder) NilMapAsEmpty() { - e.nilMapAsEmpty = true + e.ec.nilMapAsEmpty = true } // NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON // null. func (e *Encoder) NilSliceAsEmpty() { - e.nilSliceAsEmpty = true + e.ec.nilSliceAsEmpty = true } // NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values // instead of BSON null. func (e *Encoder) NilByteSliceAsEmpty() { - e.nilByteSliceAsEmpty = true + e.ec.nilByteSliceAsEmpty = true } // TODO(GODRIVER-2820): Update the description to remove the note about only examining exported @@ -149,11 +113,11 @@ func (e *Encoder) NilByteSliceAsEmpty() { // Note that the Encoder only examines exported struct fields when determining if a struct is the // zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty. func (e *Encoder) OmitZeroStruct() { - e.omitZeroStruct = true + e.ec.omitZeroStruct = true } // UseJSONStructTags causes the Encoder to fall back to using the "json" struct tag if a "bson" // struct tag is not specified. func (e *Encoder) UseJSONStructTags() { - e.useJSONStructTags = true + e.ec.useJSONStructTags = true } diff --git a/bson/map_codec.go b/bson/map_codec.go index e11ffaf726..4da32ccae4 100644 --- a/bson/map_codec.go +++ b/bson/map_codec.go @@ -159,7 +159,7 @@ func (mc *mapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Va eTypeDecoder, _ := decoder.(typeDecoder) if eType == tEmpty { - dc.Ancestor = val.Type() + dc.ancestor = val.Type() } keyType := val.Type().Key() diff --git a/bson/pointer_codec.go b/bson/pointer_codec.go index 4ed1de3013..0fd7fdb81c 100644 --- a/bson/pointer_codec.go +++ b/bson/pointer_codec.go @@ -10,8 +10,8 @@ import ( "reflect" ) -var _ ValueEncoder = &pointerCodec{} -var _ ValueDecoder = &pointerCodec{} +var _ valueEncoder = &pointerCodec{} +var _ valueDecoder = &pointerCodec{} // pointerCodec is the Codec used for pointers. type pointerCodec struct { diff --git a/bson/primitive_codecs_test.go b/bson/primitive_codecs_test.go index 6b0c4c1e05..a2e1e61e89 100644 --- a/bson/primitive_codecs_test.go +++ b/bson/primitive_codecs_test.go @@ -45,7 +45,7 @@ func TestPrimitiveValueEncoders(t *testing.T) { testCases := []struct { name string - ve ValueEncoder + ve valueEncoder subtests []subtest }{ { @@ -491,7 +491,7 @@ func TestPrimitiveValueDecoders(t *testing.T) { testCases := []struct { name string - vd ValueDecoder + vd valueDecoder subtests []subtest }{ { diff --git a/bson/registry.go b/bson/registry.go index 71d65259d6..9c680e4e66 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -125,7 +125,7 @@ func NewRegistry() *Registry { // interface. To get the latter behavior, call RegisterHookEncoder instead. // // RegisterTypeEncoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) { +func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc valueEncoder) { r.typeEncoders.Store(valueType, enc) } @@ -139,7 +139,7 @@ func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) // implements the interface. To get the latter behavior, call RegisterHookDecoder instead. // // RegisterTypeDecoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterTypeDecoder(valueType reflect.Type, dec ValueDecoder) { +func (r *Registry) RegisterTypeDecoder(valueType reflect.Type, dec valueDecoder) { r.typeDecoders.Store(valueType, dec) } @@ -155,7 +155,7 @@ func (r *Registry) RegisterTypeDecoder(valueType reflect.Type, dec ValueDecoder) // reg.RegisterKindEncoder(reflect.Int32, myEncoder) // // RegisterKindEncoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc ValueEncoder) { +func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc valueEncoder) { r.kindEncoders.Store(kind, enc) } @@ -171,7 +171,7 @@ func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc ValueEncoder) { // reg.RegisterKindDecoder(reflect.Int32, myDecoder) // // RegisterKindDecoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterKindDecoder(kind reflect.Kind, dec ValueDecoder) { +func (r *Registry) RegisterKindDecoder(kind reflect.Kind, dec valueDecoder) { r.kindDecoders.Store(kind, dec) } @@ -181,7 +181,7 @@ func (r *Registry) RegisterKindDecoder(kind reflect.Kind, dec ValueDecoder) { // (i.e. iface.Kind() != reflect.Interface), this method will panic. // // RegisterInterfaceEncoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterInterfaceEncoder(iface reflect.Type, enc ValueEncoder) { +func (r *Registry) RegisterInterfaceEncoder(iface reflect.Type, enc valueEncoder) { if iface.Kind() != reflect.Interface { panicStr := fmt.Errorf("RegisterInterfaceEncoder expects a type with kind reflect.Interface, "+ "got type %s with kind %s", iface, iface.Kind()) @@ -204,7 +204,7 @@ func (r *Registry) RegisterInterfaceEncoder(iface reflect.Type, enc ValueEncoder // this method will panic. // // RegisterInterfaceDecoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterInterfaceDecoder(iface reflect.Type, dec ValueDecoder) { +func (r *Registry) RegisterInterfaceDecoder(iface reflect.Type, dec valueDecoder) { if iface.Kind() != reflect.Interface { panicStr := fmt.Errorf("RegisterInterfaceDecoder expects a type with kind reflect.Interface, "+ "got type %s with kind %s", iface, iface.Kind()) @@ -251,7 +251,7 @@ func (r *Registry) RegisterTypeMapEntry(bt Type, rt reflect.Type) { // // If no encoder is found, an error of type ErrNoEncoder is returned. LookupEncoder is safe for // concurrent use by multiple goroutines after all codecs and encoders are registered. -func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { +func (r *Registry) LookupEncoder(valueType reflect.Type) (valueEncoder, error) { if valueType == nil { return nil, ErrNoEncoder{Type: valueType} } @@ -274,15 +274,15 @@ func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { return nil, ErrNoEncoder{Type: valueType} } -func (r *Registry) storeTypeEncoder(rt reflect.Type, enc ValueEncoder) ValueEncoder { +func (r *Registry) storeTypeEncoder(rt reflect.Type, enc valueEncoder) valueEncoder { return r.typeEncoders.LoadOrStore(rt, enc) } -func (r *Registry) lookupTypeEncoder(rt reflect.Type) (ValueEncoder, bool) { +func (r *Registry) lookupTypeEncoder(rt reflect.Type) (valueEncoder, bool) { return r.typeEncoders.Load(rt) } -func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool) (ValueEncoder, bool) { +func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool) (valueEncoder, bool) { if valueType == nil { return nil, false } @@ -320,7 +320,7 @@ func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool // // If no decoder is found, an error of type ErrNoDecoder is returned. LookupDecoder is safe for // concurrent use by multiple goroutines after all codecs and decoders are registered. -func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) { +func (r *Registry) LookupDecoder(valueType reflect.Type) (valueDecoder, error) { if valueType == nil { return nil, ErrNilType } @@ -343,15 +343,15 @@ func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) { return nil, ErrNoDecoder{Type: valueType} } -func (r *Registry) lookupTypeDecoder(valueType reflect.Type) (ValueDecoder, bool) { +func (r *Registry) lookupTypeDecoder(valueType reflect.Type) (valueDecoder, bool) { return r.typeDecoders.Load(valueType) } -func (r *Registry) storeTypeDecoder(typ reflect.Type, dec ValueDecoder) ValueDecoder { +func (r *Registry) storeTypeDecoder(typ reflect.Type, dec valueDecoder) valueDecoder { return r.typeDecoders.LoadOrStore(typ, dec) } -func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool) (ValueDecoder, bool) { +func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool) (valueDecoder, bool) { for _, idec := range r.interfaceDecoders { if valueType.Implements(idec.i) { return idec.vd, true @@ -383,10 +383,10 @@ func (r *Registry) LookupTypeMapEntry(bt Type) (reflect.Type, error) { type interfaceValueEncoder struct { i reflect.Type - ve ValueEncoder + ve valueEncoder } type interfaceValueDecoder struct { i reflect.Type - vd ValueDecoder + vd valueDecoder } diff --git a/bson/registry_test.go b/bson/registry_test.go index b897f04db6..2a5150ad24 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -62,7 +62,7 @@ func TestRegistryBuilder(t *testing.T) { reg.RegisterTypeEncoder(reflect.TypeOf(ft4), fc4) want := []struct { t reflect.Type - c ValueEncoder + c valueEncoder }{ {reflect.TypeOf(ft1), fc3}, {reflect.TypeOf(ft2), fc2}, @@ -90,7 +90,7 @@ func TestRegistryBuilder(t *testing.T) { reg.RegisterKindEncoder(k4, fc4) want := []struct { k reflect.Kind - c ValueEncoder + c valueEncoder }{ {k1, fc3}, {k2, fc2}, @@ -173,8 +173,8 @@ func TestRegistryBuilder(t *testing.T) { }) t.Run("Lookup", func(t *testing.T) { type Codec interface { - ValueEncoder - ValueDecoder + valueEncoder + valueDecoder } var ( @@ -472,7 +472,7 @@ func TestRegistry(t *testing.T) { want := []struct { t reflect.Type - c ValueEncoder + c valueEncoder }{ {reflect.TypeOf(ft1), fc3}, {reflect.TypeOf(ft2), fc2}, @@ -502,7 +502,7 @@ func TestRegistry(t *testing.T) { want := []struct { k reflect.Kind - c ValueEncoder + c valueEncoder }{ {k1, fc3}, {k2, fc2}, @@ -588,8 +588,8 @@ func TestRegistry(t *testing.T) { t.Parallel() type Codec interface { - ValueEncoder - ValueDecoder + valueEncoder + valueDecoder } var ( @@ -887,7 +887,7 @@ func TestRegistry(t *testing.T) { } // get is only for testing as it does return if the value was found -func (c *kindEncoderCache) get(rt reflect.Kind) ValueEncoder { +func (c *kindEncoderCache) get(rt reflect.Kind) valueEncoder { e, _ := c.Load(rt) return e } diff --git a/bson/slice_codec.go b/bson/slice_codec.go index 6d26f6283c..b25efc6bff 100644 --- a/bson/slice_codec.go +++ b/bson/slice_codec.go @@ -152,7 +152,7 @@ func (sc *sliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect. var elemsFunc func(DecodeContext, ValueReader, reflect.Value) ([]reflect.Value, error) switch val.Type().Elem() { case tE: - dc.Ancestor = val.Type() + dc.ancestor = val.Type() elemsFunc = decodeD default: elemsFunc = decodeDefault diff --git a/bson/struct_codec.go b/bson/struct_codec.go index 0c3eac5c73..c8f783a00f 100644 --- a/bson/struct_codec.go +++ b/bson/struct_codec.go @@ -73,8 +73,8 @@ type structCodec struct { overwriteDuplicatedInlinedFields bool } -var _ ValueEncoder = &structCodec{} -var _ ValueDecoder = &structCodec{} +var _ valueEncoder = &structCodec{} +var _ valueDecoder = &structCodec{} // newStructCodec returns a StructCodec that uses p for struct tag parsing. func newStructCodec(p StructTagParser) *structCodec { @@ -158,7 +158,7 @@ func (sc *structCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect ectx := EncodeContext{ Registry: ec.Registry, - MinSize: desc.minSize || ec.MinSize, + minSize: desc.minSize || ec.minSize, errorOnInlineDuplicates: ec.errorOnInlineDuplicates, stringifyMapKeysWithFmt: ec.stringifyMapKeysWithFmt, nilMapAsEmpty: ec.nilMapAsEmpty, @@ -239,7 +239,7 @@ func (sc *structCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect val.Set(deepZero(val.Type())) } - var decoder ValueDecoder + var decoder valueDecoder var inlineMap reflect.Value if sd.inlineMap >= 0 { inlineMap = val.Field(sd.inlineMap) @@ -287,7 +287,7 @@ func (sc *structCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect } elem := reflect.New(inlineMap.Type().Elem()).Elem() - dc.Ancestor = inlineMap.Type() + dc.ancestor = inlineMap.Type() err = decoder.DecodeValue(dc, vr, elem) if err != nil { return err @@ -317,7 +317,7 @@ func (sc *structCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect dctx := DecodeContext{ Registry: dc.Registry, - Truncate: fd.truncate || dc.Truncate, + truncate: fd.truncate || dc.truncate, defaultDocumentType: dc.defaultDocumentType, binaryAsSlice: dc.binaryAsSlice, useJSONStructTags: dc.useJSONStructTags, @@ -385,8 +385,8 @@ type fieldDescription struct { minSize bool truncate bool inline []int - encoder ValueEncoder - decoder ValueDecoder + encoder valueEncoder + decoder valueDecoder } type byIndex []fieldDescription diff --git a/bson/truncation_test.go b/bson/truncation_test.go index 865917cfe4..a9aeea278b 100644 --- a/bson/truncation_test.go +++ b/bson/truncation_test.go @@ -41,7 +41,7 @@ func TestTruncation(t *testing.T) { var output outputArgs dc := DecodeContext{ Registry: DefaultRegistry, - Truncate: true, + truncate: true, } err = UnmarshalWithContext(dc, buf.Bytes(), &output) @@ -67,7 +67,7 @@ func TestTruncation(t *testing.T) { var output outputArgs dc := DecodeContext{ Registry: DefaultRegistry, - Truncate: false, + truncate: false, } // case throws an error when truncation is disabled diff --git a/bson/uint_codec.go b/bson/uint_codec.go index b8f97ae5ab..27a297d043 100644 --- a/bson/uint_codec.go +++ b/bson/uint_codec.go @@ -32,7 +32,7 @@ func (uic *uintCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect. u64 := val.Uint() // If ec.MinSize or if encodeToMinSize is true for a non-uint64 value we should write val as an int32 - useMinSize := ec.MinSize || (uic.encodeToMinSize && val.Kind() != reflect.Uint64) + useMinSize := ec.minSize || (uic.encodeToMinSize && val.Kind() != reflect.Uint64) if u64 <= math.MaxInt32 && useMinSize { return vw.WriteInt32(int32(u64)) @@ -70,7 +70,7 @@ func (uic *uintCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Typ if err != nil { return emptyValue, err } - if !dc.Truncate && math.Floor(f64) != f64 { + if !dc.truncate && math.Floor(f64) != f64 { return emptyValue, errCannotTruncate } if f64 > float64(math.MaxInt64) { From dfdb9c3c0e55ad469f75e9010dc56a42d21bfa36 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Wed, 1 May 2024 14:43:28 -0400 Subject: [PATCH 04/15] WIP --- bson/array_codec.go | 6 +- bson/bsoncodec.go | 46 ++----- bson/bsoncodec_test.go | 2 +- bson/byte_slice_codec.go | 11 +- bson/codec_cache.go | 32 ++--- bson/codec_cache_test.go | 4 +- bson/cond_addr_codec.go | 20 ++- bson/cond_addr_codec_test.go | 8 +- bson/default_value_decoders.go | 26 ++-- bson/default_value_decoders_test.go | 4 +- bson/default_value_encoders.go | 132 ++++++++----------- bson/default_value_encoders_test.go | 195 ++++++++++++++-------------- bson/empty_interface_codec.go | 13 +- bson/encoder.go | 67 +++++++--- bson/encoder_test.go | 2 +- bson/int_codec.go | 44 +++++++ bson/map_codec.go | 20 +-- bson/marshal.go | 4 +- bson/marshal_test.go | 2 +- bson/pointer_codec.go | 11 +- bson/primitive_codecs.go | 4 +- bson/primitive_codecs_test.go | 21 +-- bson/registry.go | 32 ++--- bson/registry_examples_test.go | 4 +- bson/registry_test.go | 20 +-- bson/setter_getter.go | 6 +- bson/slice_codec.go | 16 ++- bson/string_codec.go | 9 +- bson/struct_codec.go | 48 +++---- bson/time_codec.go | 8 +- bson/uint_codec.go | 12 +- internal/integration/client_test.go | 2 +- 32 files changed, 431 insertions(+), 400 deletions(-) create mode 100644 bson/int_codec.go diff --git a/bson/array_codec.go b/bson/array_codec.go index 4a53d376bc..9ea43d4028 100644 --- a/bson/array_codec.go +++ b/bson/array_codec.go @@ -15,8 +15,12 @@ import ( // arrayCodec is the Codec used for bsoncore.Array values. type arrayCodec struct{} +var ( + defaultArrayCodec = &arrayCodec{} +) + // EncodeValue is the ValueEncoder for bsoncore.Array values. -func (ac *arrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func (ac *arrayCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCoreArray { return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val} } diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index e3ae365fd4..68c108e104 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -72,25 +72,6 @@ func (vde ValueDecoderError) Error() string { return fmt.Sprintf("%s can only decode valid and settable %s, but got %s", vde.Name, strings.Join(typeKinds, ", "), received) } -// EncodeContext is the contextual information required for a Codec to encode a -// value. -type EncodeContext struct { - *Registry - - // minSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, - // uint, uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) - // that can represent the integer value. - minSize bool - - errorOnInlineDuplicates bool - stringifyMapKeysWithFmt bool - nilMapAsEmpty bool - nilSliceAsEmpty bool - nilByteSliceAsEmpty bool - omitZeroStruct bool - useJSONStructTags bool -} - // DecodeContext is the contextual information required for a Codec to decode a // value. type DecodeContext struct { @@ -122,31 +103,30 @@ type DecodeContext struct { zeroStructs bool } -// valueEncoder is the interface implemented by types that can encode a provided Go type to BSON. +// ValueEncoder is the interface implemented by types that can encode a provided Go type to BSON. // The value to encode is provided as a reflect.Value and a bson.ValueWriter is used within the // EncodeValue method to actually create the BSON representation. For convenience, ValueEncoderFunc -// is provided to allow use of a function with the correct signature as a ValueEncoder. An -// EncodeContext instance is provided to allow implementations to lookup further ValueEncoders and -// to provide configuration information. -type valueEncoder interface { - EncodeValue(EncodeContext, ValueWriter, reflect.Value) error +// is provided to allow use of a function with the correct signature as a ValueEncoder. A pointer +// to a Registry instance is provided to allow implementations to lookup further ValueEncoders. +type ValueEncoder interface { + EncodeValue(*Registry, ValueWriter, reflect.Value) error } // ValueEncoderFunc is an adapter function that allows a function with the correct signature to be // used as a ValueEncoder. -type ValueEncoderFunc func(EncodeContext, ValueWriter, reflect.Value) error +type ValueEncoderFunc func(*Registry, ValueWriter, reflect.Value) error // EncodeValue implements the ValueEncoder interface. -func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - return fn(ec, vw, val) +func (fn ValueEncoderFunc) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { + return fn(reg, vw, val) } -// valueDecoder is the interface implemented by types that can decode BSON to a provided Go type. +// ValueDecoder is the interface implemented by types that can decode BSON to a provided Go type. // Implementations should ensure that the value they receive is settable. Similar to ValueEncoderFunc, // ValueDecoderFunc is provided to allow the use of a function with the correct signature as a // ValueDecoder. A DecodeContext instance is provided and serves similar functionality to the // EncodeContext. -type valueDecoder interface { +type ValueDecoder interface { DecodeValue(DecodeContext, ValueReader, reflect.Value) error } @@ -177,17 +157,17 @@ type decodeAdapter struct { typeDecoderFunc } -var _ valueDecoder = decodeAdapter{} +var _ ValueDecoder = decodeAdapter{} var _ typeDecoder = decodeAdapter{} // decodeTypeOrValue calls decoder.decodeType is decoder is a typeDecoder. Otherwise, it allocates a new element of type // t and calls decoder.DecodeValue on it. -func decodeTypeOrValue(decoder valueDecoder, dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func decodeTypeOrValue(decoder ValueDecoder, dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { td, _ := decoder.(typeDecoder) return decodeTypeOrValueWithInfo(decoder, td, dc, vr, t, true) } -func decodeTypeOrValueWithInfo(vd valueDecoder, td typeDecoder, dc DecodeContext, vr ValueReader, t reflect.Type, convert bool) (reflect.Value, error) { +func decodeTypeOrValueWithInfo(vd ValueDecoder, td typeDecoder, dc DecodeContext, vr ValueReader, t reflect.Type, convert bool) (reflect.Value, error) { if td != nil { val, err := td.decodeType(dc, vr, t) if err == nil && convert && val.Type() != t { diff --git a/bson/bsoncodec_test.go b/bson/bsoncodec_test.go index 797bc9b383..e4ba05d5e1 100644 --- a/bson/bsoncodec_test.go +++ b/bson/bsoncodec_test.go @@ -18,7 +18,7 @@ type llCodec struct { err error } -func (llc *llCodec) EncodeValue(_ EncodeContext, _ ValueWriter, i interface{}) error { +func (llc *llCodec) EncodeValue(_ *Registry, _ ValueWriter, i interface{}) error { if llc.err != nil { return llc.err } diff --git a/bson/byte_slice_codec.go b/bson/byte_slice_codec.go index bd5c5dae85..83dba12ecb 100644 --- a/bson/byte_slice_codec.go +++ b/bson/byte_slice_codec.go @@ -18,17 +18,16 @@ type byteSliceCodec struct { encodeNilAsEmpty bool } -// Assert that defaultByteSliceCodec satisfies the typeDecoder interface, which allows it to be -// used by collection type decoders (e.g. map, slice, etc) to set individual values in a -// collection. -var _ typeDecoder = (*byteSliceCodec)(nil) +var ( + defaultByteSliceCodec = &byteSliceCodec{} +) // EncodeValue is the ValueEncoder for []byte. -func (bsc *byteSliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (bsc *byteSliceCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tByteSlice { return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val} } - if val.IsNil() && !bsc.encodeNilAsEmpty && !ec.nilByteSliceAsEmpty { + if val.IsNil() && !bsc.encodeNilAsEmpty { return vw.WriteNull() } return vw.WriteBinary(val.Interface().([]byte)) diff --git a/bson/codec_cache.go b/bson/codec_cache.go index f28ec30911..b4042822e6 100644 --- a/bson/codec_cache.go +++ b/bson/codec_cache.go @@ -29,20 +29,20 @@ type typeEncoderCache struct { cache sync.Map // map[reflect.Type]ValueEncoder } -func (c *typeEncoderCache) Store(rt reflect.Type, enc valueEncoder) { +func (c *typeEncoderCache) Store(rt reflect.Type, enc ValueEncoder) { c.cache.Store(rt, enc) } -func (c *typeEncoderCache) Load(rt reflect.Type) (valueEncoder, bool) { +func (c *typeEncoderCache) Load(rt reflect.Type) (ValueEncoder, bool) { if v, _ := c.cache.Load(rt); v != nil { - return v.(valueEncoder), true + return v.(ValueEncoder), true } return nil, false } -func (c *typeEncoderCache) LoadOrStore(rt reflect.Type, enc valueEncoder) valueEncoder { +func (c *typeEncoderCache) LoadOrStore(rt reflect.Type, enc ValueEncoder) ValueEncoder { if v, loaded := c.cache.LoadOrStore(rt, enc); loaded { - enc = v.(valueEncoder) + enc = v.(ValueEncoder) } return enc } @@ -62,20 +62,20 @@ type typeDecoderCache struct { cache sync.Map // map[reflect.Type]ValueDecoder } -func (c *typeDecoderCache) Store(rt reflect.Type, dec valueDecoder) { +func (c *typeDecoderCache) Store(rt reflect.Type, dec ValueDecoder) { c.cache.Store(rt, dec) } -func (c *typeDecoderCache) Load(rt reflect.Type) (valueDecoder, bool) { +func (c *typeDecoderCache) Load(rt reflect.Type) (ValueDecoder, bool) { if v, _ := c.cache.Load(rt); v != nil { - return v.(valueDecoder), true + return v.(ValueDecoder), true } return nil, false } -func (c *typeDecoderCache) LoadOrStore(rt reflect.Type, dec valueDecoder) valueDecoder { +func (c *typeDecoderCache) LoadOrStore(rt reflect.Type, dec ValueDecoder) ValueDecoder { if v, loaded := c.cache.LoadOrStore(rt, dec); loaded { - dec = v.(valueDecoder) + dec = v.(ValueDecoder) } return dec } @@ -96,20 +96,20 @@ func (c *typeDecoderCache) Clone() *typeDecoderCache { // is always the same (since different concrete types may implement the // ValueEncoder interface). type kindEncoderCacheEntry struct { - enc valueEncoder + enc ValueEncoder } type kindEncoderCache struct { entries [reflect.UnsafePointer + 1]atomic.Value // *kindEncoderCacheEntry } -func (c *kindEncoderCache) Store(rt reflect.Kind, enc valueEncoder) { +func (c *kindEncoderCache) Store(rt reflect.Kind, enc ValueEncoder) { if enc != nil && rt < reflect.Kind(len(c.entries)) { c.entries[rt].Store(&kindEncoderCacheEntry{enc: enc}) } } -func (c *kindEncoderCache) Load(rt reflect.Kind) (valueEncoder, bool) { +func (c *kindEncoderCache) Load(rt reflect.Kind) (ValueEncoder, bool) { if rt < reflect.Kind(len(c.entries)) { if ent, ok := c.entries[rt].Load().(*kindEncoderCacheEntry); ok { return ent.enc, ent.enc != nil @@ -133,20 +133,20 @@ func (c *kindEncoderCache) Clone() *kindEncoderCache { // is always the same (since different concrete types may implement the // ValueDecoder interface). type kindDecoderCacheEntry struct { - dec valueDecoder + dec ValueDecoder } type kindDecoderCache struct { entries [reflect.UnsafePointer + 1]atomic.Value // *kindDecoderCacheEntry } -func (c *kindDecoderCache) Store(rt reflect.Kind, dec valueDecoder) { +func (c *kindDecoderCache) Store(rt reflect.Kind, dec ValueDecoder) { if rt < reflect.Kind(len(c.entries)) { c.entries[rt].Store(&kindDecoderCacheEntry{dec: dec}) } } -func (c *kindDecoderCache) Load(rt reflect.Kind) (valueDecoder, bool) { +func (c *kindDecoderCache) Load(rt reflect.Kind) (ValueDecoder, bool) { if rt < reflect.Kind(len(c.entries)) { if ent, ok := c.entries[rt].Load().(*kindDecoderCacheEntry); ok { return ent.dec, ent.dec != nil diff --git a/bson/codec_cache_test.go b/bson/codec_cache_test.go index 3937db68ac..d48e05f5a3 100644 --- a/bson/codec_cache_test.go +++ b/bson/codec_cache_test.go @@ -134,7 +134,7 @@ func TestKindCacheClone(t *testing.T) { func TestKindCacheEncoderNilEncoder(t *testing.T) { t.Run("Encoder", func(t *testing.T) { c := new(kindEncoderCache) - c.Store(reflect.Invalid, valueEncoder(nil)) + c.Store(reflect.Invalid, ValueEncoder(nil)) v, ok := c.Load(reflect.Invalid) if v != nil || ok { t.Errorf("Load of nil ValueEncoder should return: nil, false; got: %v, %t", v, ok) @@ -142,7 +142,7 @@ func TestKindCacheEncoderNilEncoder(t *testing.T) { }) t.Run("Decoder", func(t *testing.T) { c := new(kindDecoderCache) - c.Store(reflect.Invalid, valueDecoder(nil)) + c.Store(reflect.Invalid, ValueDecoder(nil)) v, ok := c.Load(reflect.Invalid) if v != nil || ok { t.Errorf("Load of nil ValueDecoder should return: nil, false; got: %v, %t", v, ok) diff --git a/bson/cond_addr_codec.go b/bson/cond_addr_codec.go index a64a9c434f..d1baf96ef4 100644 --- a/bson/cond_addr_codec.go +++ b/bson/cond_addr_codec.go @@ -12,33 +12,31 @@ import ( // condAddrEncoder is the encoder used when a pointer to the encoding value has an encoder. type condAddrEncoder struct { - canAddrEnc valueEncoder - elseEnc valueEncoder + canAddrEnc ValueEncoder + elseEnc ValueEncoder } -var _ valueEncoder = (*condAddrEncoder)(nil) - // EncodeValue is the ValueEncoderFunc for a value that may be addressable. -func (cae *condAddrEncoder) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (cae *condAddrEncoder) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { if val.CanAddr() { - return cae.canAddrEnc.EncodeValue(ec, vw, val) + return cae.canAddrEnc.EncodeValue(reg, vw, val) } if cae.elseEnc != nil { - return cae.elseEnc.EncodeValue(ec, vw, val) + return cae.elseEnc.EncodeValue(reg, vw, val) } return ErrNoEncoder{Type: val.Type()} } // condAddrDecoder is the decoder used when a pointer to the value has a decoder. type condAddrDecoder struct { - canAddrDec valueDecoder - elseDec valueDecoder + canAddrDec ValueDecoder + elseDec ValueDecoder } -var _ valueDecoder = (*condAddrDecoder)(nil) +var _ ValueDecoder = (*condAddrDecoder)(nil) // newCondAddrDecoder returns an CondAddrDecoder. -func newCondAddrDecoder(canAddrDec, elseDec valueDecoder) *condAddrDecoder { +func newCondAddrDecoder(canAddrDec, elseDec ValueDecoder) *condAddrDecoder { decoder := condAddrDecoder{canAddrDec: canAddrDec, elseDec: elseDec} return &decoder } diff --git a/bson/cond_addr_codec_test.go b/bson/cond_addr_codec_test.go index 55d73e8204..26cd9c5534 100644 --- a/bson/cond_addr_codec_test.go +++ b/bson/cond_addr_codec_test.go @@ -22,11 +22,11 @@ func TestCondAddrCodec(t *testing.T) { t.Run("addressEncode", func(t *testing.T) { invoked := 0 - encode1 := ValueEncoderFunc(func(EncodeContext, ValueWriter, reflect.Value) error { + encode1 := ValueEncoderFunc(func(*Registry, ValueWriter, reflect.Value) error { invoked = 1 return nil }) - encode2 := ValueEncoderFunc(func(EncodeContext, ValueWriter, reflect.Value) error { + encode2 := ValueEncoderFunc(func(*Registry, ValueWriter, reflect.Value) error { invoked = 2 return nil }) @@ -42,7 +42,7 @@ func TestCondAddrCodec(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - err := condEncoder.EncodeValue(EncodeContext{}, rw, tc.val) + err := condEncoder.EncodeValue(nil, rw, tc.val) assert.Nil(t, err, "CondAddrEncoder error: %v", err) assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked) @@ -51,7 +51,7 @@ func TestCondAddrCodec(t *testing.T) { t.Run("error", func(t *testing.T) { errEncoder := &condAddrEncoder{canAddrEnc: encode1, elseEnc: nil} - err := errEncoder.EncodeValue(EncodeContext{}, rw, unaddressable) + err := errEncoder.EncodeValue(nil, rw, unaddressable) want := ErrNoEncoder{Type: unaddressable.Type()} assert.Equal(t, err, want, "expected error %v, got %v", want, err) }) diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index cfa09bdd10..b105ab0715 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -56,10 +56,10 @@ func registerDefaultDecoders(reg *Registry) { reg.RegisterTypeDecoder(tMaxKey, decodeAdapter{maxKeyDecodeValue, maxKeyDecodeType}) reg.RegisterTypeDecoder(tJavaScript, decodeAdapter{javaScriptDecodeValue, javaScriptDecodeType}) reg.RegisterTypeDecoder(tSymbol, decodeAdapter{symbolDecodeValue, symbolDecodeType}) - reg.RegisterTypeDecoder(tByteSlice, &byteSliceCodec{}) - reg.RegisterTypeDecoder(tTime, &timeCodec{}) - reg.RegisterTypeDecoder(tEmpty, &emptyInterfaceCodec{}) - reg.RegisterTypeDecoder(tCoreArray, &arrayCodec{}) + reg.RegisterTypeDecoder(tByteSlice, defaultByteSliceCodec) + reg.RegisterTypeDecoder(tTime, defaultTimeCodec) + reg.RegisterTypeDecoder(tEmpty, defaultEmptyInterfaceCodec) + reg.RegisterTypeDecoder(tCoreArray, defaultArrayCodec) reg.RegisterTypeDecoder(tOID, decodeAdapter{objectIDDecodeValue, objectIDDecodeType}) reg.RegisterTypeDecoder(tDecimal, decodeAdapter{decimal128DecodeValue, decimal128DecodeType}) reg.RegisterTypeDecoder(tJSONNumber, decodeAdapter{jsonNumberDecodeValue, jsonNumberDecodeType}) @@ -72,18 +72,18 @@ func registerDefaultDecoders(reg *Registry) { reg.RegisterKindDecoder(reflect.Int16, intDecoder) reg.RegisterKindDecoder(reflect.Int32, intDecoder) reg.RegisterKindDecoder(reflect.Int64, intDecoder) - reg.RegisterKindDecoder(reflect.Uint, &uintCodec{}) - reg.RegisterKindDecoder(reflect.Uint8, &uintCodec{}) - reg.RegisterKindDecoder(reflect.Uint16, &uintCodec{}) - reg.RegisterKindDecoder(reflect.Uint32, &uintCodec{}) - reg.RegisterKindDecoder(reflect.Uint64, &uintCodec{}) + reg.RegisterKindDecoder(reflect.Uint, defaultUIntCodec) + reg.RegisterKindDecoder(reflect.Uint8, defaultUIntCodec) + reg.RegisterKindDecoder(reflect.Uint16, defaultUIntCodec) + reg.RegisterKindDecoder(reflect.Uint32, defaultUIntCodec) + reg.RegisterKindDecoder(reflect.Uint64, defaultUIntCodec) reg.RegisterKindDecoder(reflect.Float32, floatDecoder) reg.RegisterKindDecoder(reflect.Float64, floatDecoder) reg.RegisterKindDecoder(reflect.Array, ValueDecoderFunc(arrayDecodeValue)) - reg.RegisterKindDecoder(reflect.Map, &mapCodec{}) - reg.RegisterKindDecoder(reflect.Slice, &sliceCodec{}) - reg.RegisterKindDecoder(reflect.String, &stringCodec{}) - reg.RegisterKindDecoder(reflect.Struct, newStructCodec(DefaultStructTagParser)) + reg.RegisterKindDecoder(reflect.Map, defaultMapCodec) + reg.RegisterKindDecoder(reflect.Slice, defaultSliceCodec) + reg.RegisterKindDecoder(reflect.String, defaultStringCodec) + reg.RegisterKindDecoder(reflect.Struct, defaultStructCodec) reg.RegisterKindDecoder(reflect.Ptr, &pointerCodec{}) reg.RegisterTypeMapEntry(TypeDouble, tFloat64) reg.RegisterTypeMapEntry(TypeString, tString) diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index fc86145bfd..d2434a19e1 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -65,7 +65,7 @@ func TestDefaultValueDecoders(t *testing.T) { testCases := []struct { name string - vd valueDecoder + vd ValueDecoder subtests []subtest }{ { @@ -3566,7 +3566,7 @@ func TestDefaultValueDecoders(t *testing.T) { val interface{} vr ValueReader registry *Registry // buildDefaultRegistry will be used if this is nil - decoder valueDecoder + decoder ValueDecoder err error }{ { diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index 31c4bf314c..df80ef0080 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -28,7 +28,7 @@ var sliceWriterPool = sync.Pool{ }, } -func encodeElement(ec EncodeContext, dw DocumentWriter, e E) error { +func encodeElement(reg *Registry, dw DocumentWriter, e E) error { vw, err := dw.WriteDocumentElement(e.Key) if err != nil { return err @@ -37,12 +37,12 @@ func encodeElement(ec EncodeContext, dw DocumentWriter, e E) error { if e.Value == nil { return vw.WriteNull() } - encoder, err := ec.LookupEncoder(reflect.TypeOf(e.Value)) + encoder, err := reg.LookupEncoder(reflect.TypeOf(e.Value)) if err != nil { return err } - err = encoder.EncodeValue(ec, vw, reflect.ValueOf(e.Value)) + err = encoder.EncodeValue(reg, vw, reflect.ValueOf(e.Value)) if err != nil { return err } @@ -54,6 +54,8 @@ func registerDefaultEncoders(reg *Registry) { if reg == nil { panic(errors.New("argument to RegisterDefaultEncoders must not be nil")) } + intEncoder := &intCodec{} + uintEncoder := &uintCodec{} reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{}) reg.RegisterTypeEncoder(tTime, &timeCodec{}) reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{}) @@ -76,16 +78,16 @@ func registerDefaultEncoders(reg *Registry) { reg.RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(coreDocumentEncodeValue)) reg.RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(codeWithScopeEncodeValue)) reg.RegisterKindEncoder(reflect.Bool, ValueEncoderFunc(booleanEncodeValue)) - reg.RegisterKindEncoder(reflect.Int, ValueEncoderFunc(intEncodeValue)) - reg.RegisterKindEncoder(reflect.Int8, ValueEncoderFunc(intEncodeValue)) - reg.RegisterKindEncoder(reflect.Int16, ValueEncoderFunc(intEncodeValue)) - reg.RegisterKindEncoder(reflect.Int32, ValueEncoderFunc(intEncodeValue)) - reg.RegisterKindEncoder(reflect.Int64, ValueEncoderFunc(intEncodeValue)) - reg.RegisterKindEncoder(reflect.Uint, &uintCodec{}) - reg.RegisterKindEncoder(reflect.Uint8, &uintCodec{}) - reg.RegisterKindEncoder(reflect.Uint16, &uintCodec{}) - reg.RegisterKindEncoder(reflect.Uint32, &uintCodec{}) - reg.RegisterKindEncoder(reflect.Uint64, &uintCodec{}) + reg.RegisterKindEncoder(reflect.Int, intEncoder) + reg.RegisterKindEncoder(reflect.Int8, intEncoder) + reg.RegisterKindEncoder(reflect.Int16, intEncoder) + reg.RegisterKindEncoder(reflect.Int32, intEncoder) + reg.RegisterKindEncoder(reflect.Int64, intEncoder) + reg.RegisterKindEncoder(reflect.Uint, uintEncoder) + reg.RegisterKindEncoder(reflect.Uint8, uintEncoder) + reg.RegisterKindEncoder(reflect.Uint16, uintEncoder) + reg.RegisterKindEncoder(reflect.Uint32, uintEncoder) + reg.RegisterKindEncoder(reflect.Uint64, uintEncoder) reg.RegisterKindEncoder(reflect.Float32, ValueEncoderFunc(floatEncodeValue)) reg.RegisterKindEncoder(reflect.Float64, ValueEncoderFunc(floatEncodeValue)) reg.RegisterKindEncoder(reflect.Array, ValueEncoderFunc(arrayEncodeValue)) @@ -100,7 +102,7 @@ func registerDefaultEncoders(reg *Registry) { } // booleanEncodeValue is the ValueEncoderFunc for bool types. -func booleanEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func booleanEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Bool { return ValueEncoderError{Name: "BooleanEncodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val} } @@ -111,34 +113,8 @@ func fitsIn32Bits(i int64) bool { return math.MinInt32 <= i && i <= math.MaxInt32 } -// intEncodeValue is the ValueEncoderFunc for int types. -func intEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { - switch val.Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32: - return vw.WriteInt32(int32(val.Int())) - case reflect.Int: - i64 := val.Int() - if fitsIn32Bits(i64) { - return vw.WriteInt32(int32(i64)) - } - return vw.WriteInt64(i64) - case reflect.Int64: - i64 := val.Int() - if ec.minSize && fitsIn32Bits(i64) { - return vw.WriteInt32(int32(i64)) - } - return vw.WriteInt64(i64) - } - - return ValueEncoderError{ - Name: "IntEncodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, - Received: val, - } -} - // floatEncodeValue is the ValueEncoderFunc for float types. -func floatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func floatEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Float32, reflect.Float64: return vw.WriteDouble(val.Float()) @@ -148,7 +124,7 @@ func floatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error } // objectIDEncodeValue is the ValueEncoderFunc for ObjectID. -func objectIDEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func objectIDEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tOID { return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val} } @@ -156,7 +132,7 @@ func objectIDEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) err } // decimal128EncodeValue is the ValueEncoderFunc for Decimal128. -func decimal128EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func decimal128EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDecimal { return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val} } @@ -164,7 +140,7 @@ func decimal128EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) e } // jsonNumberEncodeValue is the ValueEncoderFunc for json.Number. -func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func jsonNumberEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tJSONNumber { return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} } @@ -172,7 +148,11 @@ func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) // Attempt int first, then float64 if i64, err := jsnum.Int64(); err == nil { - return intEncodeValue(ec, vw, reflect.ValueOf(i64)) + encoder, err := reg.LookupEncoder(tInt64) + if err != nil { + return err + } + return encoder.EncodeValue(reg, vw, reflect.ValueOf(i64)) } f64, err := jsnum.Float64() @@ -180,11 +160,11 @@ func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) return err } - return floatEncodeValue(ec, vw, reflect.ValueOf(f64)) + return floatEncodeValue(reg, vw, reflect.ValueOf(f64)) } // urlEncodeValue is the ValueEncoderFunc for url.URL. -func urlEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func urlEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tURL { return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val} } @@ -193,7 +173,7 @@ func urlEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { } // arrayEncodeValue is the ValueEncoderFunc for array types. -func arrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func arrayEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Array { return ValueEncoderError{Name: "ArrayEncodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val} } @@ -207,7 +187,7 @@ func arrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error for idx := 0; idx < val.Len(); idx++ { e := val.Index(idx).Interface().(E) - err = encodeElement(ec, dw, e) + err = encodeElement(reg, dw, e) if err != nil { return err } @@ -231,13 +211,13 @@ func arrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error } elemType := val.Type().Elem() - encoder, err := ec.LookupEncoder(elemType) + encoder, err := reg.LookupEncoder(elemType) if err != nil && elemType.Kind() != reflect.Interface { return err } for idx := 0; idx < val.Len(); idx++ { - currEncoder, currVal, lookupErr := lookupElementEncoder(ec, encoder, val.Index(idx)) + currEncoder, currVal, lookupErr := lookupElementEncoder(reg, encoder, val.Index(idx)) if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } @@ -255,7 +235,7 @@ func arrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error continue } - err = currEncoder.EncodeValue(ec, vw, currVal) + err = currEncoder.EncodeValue(reg, vw, currVal) if err != nil { return err } @@ -263,7 +243,7 @@ func arrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error return aw.WriteArrayEnd() } -func lookupElementEncoder(ec EncodeContext, origEncoder valueEncoder, currVal reflect.Value) (valueEncoder, reflect.Value, error) { +func lookupElementEncoder(reg *Registry, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) { if origEncoder != nil || (currVal.Kind() != reflect.Interface) { return origEncoder, currVal, nil } @@ -271,13 +251,13 @@ func lookupElementEncoder(ec EncodeContext, origEncoder valueEncoder, currVal re if !currVal.IsValid() { return nil, currVal, errInvalidValue } - currEncoder, err := ec.LookupEncoder(currVal.Type()) + currEncoder, err := reg.LookupEncoder(currVal.Type()) return currEncoder, currVal, err } // valueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations. -func valueMarshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func valueMarshalerEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement ValueMarshaler switch { case !val.IsValid(): @@ -305,7 +285,7 @@ func valueMarshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Valu } // marshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations. -func marshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func marshalerEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Marshaler switch { case !val.IsValid(): @@ -333,7 +313,7 @@ func marshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) er } // proxyEncodeValue is the ValueEncoderFunc for Proxy implementations. -func proxyEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func proxyEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Proxy switch { case !val.IsValid(): @@ -358,26 +338,26 @@ func proxyEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error return err } if v == nil { - encoder, err := ec.LookupEncoder(nil) + encoder, err := reg.LookupEncoder(nil) if err != nil { return err } - return encoder.EncodeValue(ec, vw, reflect.ValueOf(nil)) + return encoder.EncodeValue(reg, vw, reflect.ValueOf(nil)) } vv := reflect.ValueOf(v) switch vv.Kind() { case reflect.Ptr, reflect.Interface: vv = vv.Elem() } - encoder, err := ec.LookupEncoder(vv.Type()) + encoder, err := reg.LookupEncoder(vv.Type()) if err != nil { return err } - return encoder.EncodeValue(ec, vw, vv) + return encoder.EncodeValue(reg, vw, vv) } // javaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type. -func javaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func javaScriptEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tJavaScript { return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val} } @@ -386,7 +366,7 @@ func javaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) e } // symbolEncodeValue is the ValueEncoderFunc for the Symbol type. -func symbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func symbolEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tSymbol { return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val} } @@ -395,7 +375,7 @@ func symbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error } // binaryEncodeValue is the ValueEncoderFunc for Binary. -func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func binaryEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tBinary { return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val} } @@ -405,7 +385,7 @@ func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error } // undefinedEncodeValue is the ValueEncoderFunc for Undefined. -func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func undefinedEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tUndefined { return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val} } @@ -414,7 +394,7 @@ func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) er } // dateTimeEncodeValue is the ValueEncoderFunc for DateTime. -func dateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func dateTimeEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDateTime { return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val} } @@ -423,7 +403,7 @@ func dateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) err } // nullEncodeValue is the ValueEncoderFunc for Null. -func nullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func nullEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tNull { return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val} } @@ -432,7 +412,7 @@ func nullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { } // regexEncodeValue is the ValueEncoderFunc for Regex. -func regexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func regexEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRegex { return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val} } @@ -443,7 +423,7 @@ func regexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error } // dbPointerEncodeValue is the ValueEncoderFunc for DBPointer. -func dbPointerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func dbPointerEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDBPointer { return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val} } @@ -454,7 +434,7 @@ func dbPointerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) er } // timestampEncodeValue is the ValueEncoderFunc for Timestamp. -func timestampEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func timestampEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tTimestamp { return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val} } @@ -465,7 +445,7 @@ func timestampEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) er } // minKeyEncodeValue is the ValueEncoderFunc for MinKey. -func minKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func minKeyEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tMinKey { return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val} } @@ -474,7 +454,7 @@ func minKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error } // maxKeyEncodeValue is the ValueEncoderFunc for MaxKey. -func maxKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func maxKeyEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tMaxKey { return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val} } @@ -483,7 +463,7 @@ func maxKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error } // coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. -func coreDocumentEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func coreDocumentEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCoreDocument { return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} } @@ -494,7 +474,7 @@ func coreDocumentEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) } // codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. -func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func codeWithScopeEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCodeWithScope { return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} } @@ -513,12 +493,12 @@ func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Valu scopeVW := bvwPool.Get(sw) defer bvwPool.Put(scopeVW) - encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope)) + encoder, err := reg.LookupEncoder(reflect.TypeOf(cws.Scope)) if err != nil { return err } - err = encoder.EncodeValue(ec, scopeVW, reflect.ValueOf(cws.Scope)) + err = encoder.EncodeValue(reg, scopeVW, reflect.ValueOf(cws.Scope)) if err != nil { return err } diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index 9c869448f9..0cc12ce597 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -10,7 +10,6 @@ import ( "encoding/json" "errors" "fmt" - "math" "net/url" "reflect" "strings" @@ -38,16 +37,16 @@ func TestDefaultValueEncoders(t *testing.T) { var wrong = func(string, string) string { return "wrong" } type mybool bool - type myint8 int8 - type myint16 int16 - type myint32 int32 - type myint64 int64 - type myint int + // type myint8 int8 + // type myint16 int16 + // type myint32 int32 + // type myint64 int64 + // type myint int type myuint8 uint8 type myuint16 uint16 type myuint32 uint32 type myuint64 uint64 - type myuint uint + // type myuint uint type myfloat32 float32 type myfloat64 float64 @@ -66,7 +65,7 @@ func TestDefaultValueEncoders(t *testing.T) { type subtest struct { name string val interface{} - ectx *EncodeContext + reg *Registry llvrw *valueReaderWriter invoke invoked err error @@ -74,7 +73,7 @@ func TestDefaultValueEncoders(t *testing.T) { testCases := []struct { name string - ve valueEncoder + ve ValueEncoder subtests []subtest }{ { @@ -93,46 +92,48 @@ func TestDefaultValueEncoders(t *testing.T) { {"reflection path", mybool(true), nil, nil, writeBoolean, nil}, }, }, - { - "IntEncodeValue", - ValueEncoderFunc(intEncodeValue), - []subtest{ - { - "wrong type", - wrong, - nil, - nil, - nothing, - ValueEncoderError{ - Name: "IntEncodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, - Received: reflect.ValueOf(wrong), + /* + { + "IntEncodeValue", + ValueEncoderFunc(intEncodeValue), + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + nothing, + ValueEncoderError{ + Name: "IntEncodeValue", + Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Received: reflect.ValueOf(wrong), + }, }, + {"int8/fast path", int8(127), nil, nil, writeInt32, nil}, + {"int16/fast path", int16(32767), nil, nil, writeInt32, nil}, + {"int32/fast path", int32(2147483647), nil, nil, writeInt32, nil}, + {"int64/fast path", int64(1234567890987), nil, nil, writeInt64, nil}, + {"int64/fast path - minsize", int64(math.MaxInt32), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + {"int64/fast path - minsize too large", int64(math.MaxInt32 + 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + {"int64/fast path - minsize too small", int64(math.MinInt32 - 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + {"int/fast path - positive int32", int(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, + {"int/fast path - negative int32", int(math.MinInt32 + 1), nil, nil, writeInt32, nil}, + {"int/fast path - MaxInt32", int(math.MaxInt32), nil, nil, writeInt32, nil}, + {"int/fast path - MinInt32", int(math.MinInt32), nil, nil, writeInt32, nil}, + {"int8/reflection path", myint8(127), nil, nil, writeInt32, nil}, + {"int16/reflection path", myint16(32767), nil, nil, writeInt32, nil}, + {"int32/reflection path", myint32(2147483647), nil, nil, writeInt32, nil}, + {"int64/reflection path", myint64(1234567890987), nil, nil, writeInt64, nil}, + {"int64/reflection path - minsize", myint64(math.MaxInt32), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + {"int64/reflection path - minsize too large", myint64(math.MaxInt32 + 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + {"int64/reflection path - minsize too small", myint64(math.MinInt32 - 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + {"int/reflection path - positive int32", myint(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, + {"int/reflection path - negative int32", myint(math.MinInt32 + 1), nil, nil, writeInt32, nil}, + {"int/reflection path - MaxInt32", myint(math.MaxInt32), nil, nil, writeInt32, nil}, + {"int/reflection path - MinInt32", myint(math.MinInt32), nil, nil, writeInt32, nil}, }, - {"int8/fast path", int8(127), nil, nil, writeInt32, nil}, - {"int16/fast path", int16(32767), nil, nil, writeInt32, nil}, - {"int32/fast path", int32(2147483647), nil, nil, writeInt32, nil}, - {"int64/fast path", int64(1234567890987), nil, nil, writeInt64, nil}, - {"int64/fast path - minsize", int64(math.MaxInt32), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - {"int64/fast path - minsize too large", int64(math.MaxInt32 + 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - {"int64/fast path - minsize too small", int64(math.MinInt32 - 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - {"int/fast path - positive int32", int(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, - {"int/fast path - negative int32", int(math.MinInt32 + 1), nil, nil, writeInt32, nil}, - {"int/fast path - MaxInt32", int(math.MaxInt32), nil, nil, writeInt32, nil}, - {"int/fast path - MinInt32", int(math.MinInt32), nil, nil, writeInt32, nil}, - {"int8/reflection path", myint8(127), nil, nil, writeInt32, nil}, - {"int16/reflection path", myint16(32767), nil, nil, writeInt32, nil}, - {"int32/reflection path", myint32(2147483647), nil, nil, writeInt32, nil}, - {"int64/reflection path", myint64(1234567890987), nil, nil, writeInt64, nil}, - {"int64/reflection path - minsize", myint64(math.MaxInt32), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - {"int64/reflection path - minsize too large", myint64(math.MaxInt32 + 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - {"int64/reflection path - minsize too small", myint64(math.MinInt32 - 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - {"int/reflection path - positive int32", myint(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, - {"int/reflection path - negative int32", myint(math.MinInt32 + 1), nil, nil, writeInt32, nil}, - {"int/reflection path - MaxInt32", myint(math.MaxInt32), nil, nil, writeInt32, nil}, - {"int/reflection path - MinInt32", myint(math.MinInt32), nil, nil, writeInt32, nil}, }, - }, + */ { "UintEncodeValue", &uintCodec{}, @@ -154,23 +155,23 @@ func TestDefaultValueEncoders(t *testing.T) { {"uint32/fast path", uint32(2147483647), nil, nil, writeInt64, nil}, {"uint64/fast path", uint64(1234567890987), nil, nil, writeInt64, nil}, {"uint/fast path", uint(1234567), nil, nil, writeInt64, nil}, - {"uint32/fast path - minsize", uint32(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - {"uint64/fast path - minsize", uint64(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - {"uint/fast path - minsize", uint(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - {"uint32/fast path - minsize too large", uint32(2147483648), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - {"uint64/fast path - minsize too large", uint64(2147483648), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - {"uint/fast path - minsize too large", uint(2147483648), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + // {"uint32/fast path - minsize", uint32(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + // {"uint64/fast path - minsize", uint64(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + // {"uint/fast path - minsize", uint(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + // {"uint32/fast path - minsize too large", uint32(2147483648), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + // {"uint64/fast path - minsize too large", uint64(2147483648), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + // {"uint/fast path - minsize too large", uint(2147483648), &EncodeContext{minSize: true}, nil, writeInt64, nil}, {"uint64/fast path - overflow", uint64(1 << 63), nil, nil, nothing, fmt.Errorf("%d overflows int64", uint64(1<<63))}, {"uint8/reflection path", myuint8(127), nil, nil, writeInt32, nil}, {"uint16/reflection path", myuint16(32767), nil, nil, writeInt32, nil}, {"uint32/reflection path", myuint32(2147483647), nil, nil, writeInt64, nil}, {"uint64/reflection path", myuint64(1234567890987), nil, nil, writeInt64, nil}, - {"uint32/reflection path - minsize", myuint32(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - {"uint64/reflection path - minsize", myuint64(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - {"uint/reflection path - minsize", myuint(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - {"uint32/reflection path - minsize too large", myuint(1 << 31), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - {"uint64/reflection path - minsize too large", myuint64(1 << 31), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - {"uint/reflection path - minsize too large", myuint(2147483648), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + // {"uint32/reflection path - minsize", myuint32(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + // {"uint64/reflection path - minsize", myuint64(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + // {"uint/reflection path - minsize", myuint(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + // {"uint32/reflection path - minsize too large", myuint(1 << 31), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + // {"uint64/reflection path - minsize too large", myuint64(1 << 31), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + // {"uint/reflection path - minsize too large", myuint(2147483648), &EncodeContext{minSize: true}, nil, writeInt64, nil}, {"uint64/reflection path - overflow", myuint64(1 << 63), nil, nil, nothing, fmt.Errorf("%d overflows int64", uint64(1<<63))}, }, }, @@ -234,7 +235,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", map[string]int{"foo": 1}, - &EncodeContext{Registry: newTestRegistry()}, + newTestRegistry(), &valueReaderWriter{}, writeDocument, fmt.Errorf("no encoder found for int"), @@ -242,7 +243,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "WriteDocumentElement Error", map[string]interface{}{"foo": "bar"}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("wde error"), ErrAfter: writeDocumentElement}, writeDocumentElement, errors.New("wde error"), @@ -250,7 +251,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "EncodeValue Error", map[string]interface{}{"foo": "bar"}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("ev error"), ErrAfter: writeString}, writeString, errors.New("ev error"), @@ -258,7 +259,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "empty map/success", map[string]interface{}{}, - &EncodeContext{Registry: newTestRegistry()}, + newTestRegistry(), &valueReaderWriter{}, writeDocumentEnd, nil, @@ -266,7 +267,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "with interface/success", map[string]myInterface{"foo": myStruct{1}}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -274,7 +275,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "with interface/nil/success", map[string]myInterface{"foo": nil}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -284,7 +285,7 @@ func TestDefaultValueEncoders(t *testing.T) { map[int]interface{}{ 1: "foobar", }, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{}, writeDocumentEnd, nil, @@ -314,7 +315,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", [1]int{1}, - &EncodeContext{Registry: newTestRegistry()}, + newTestRegistry(), &valueReaderWriter{}, writeArray, fmt.Errorf("no encoder found for int"), @@ -322,7 +323,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "WriteArrayElement Error", [1]string{"foo"}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("wae error"), ErrAfter: writeArrayElement}, writeArrayElement, errors.New("wae error"), @@ -330,7 +331,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "EncodeValue Error", [1]string{"foo"}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("ev error"), ErrAfter: writeString}, writeString, errors.New("ev error"), @@ -338,7 +339,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "[1]E/success", [1]E{{"hello", "world"}}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -346,7 +347,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "[1]E/success", [1]E{{"hello", nil}}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -354,7 +355,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "[1]interface/success", [1]myInterface{myStruct{1}}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeArrayEnd, nil, @@ -362,7 +363,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "[1]interface/nil/success", [1]myInterface{nil}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeArrayEnd, nil, @@ -392,7 +393,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", []int{1}, - &EncodeContext{Registry: newTestRegistry()}, + newTestRegistry(), &valueReaderWriter{}, writeArray, fmt.Errorf("no encoder found for int"), @@ -400,7 +401,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "WriteArrayElement Error", []string{"foo"}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("wae error"), ErrAfter: writeArrayElement}, writeArrayElement, errors.New("wae error"), @@ -408,7 +409,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "EncodeValue Error", []string{"foo"}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("ev error"), ErrAfter: writeString}, writeString, errors.New("ev error"), @@ -416,7 +417,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "D/success", D{{"hello", "world"}}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -424,7 +425,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "D/success", D{{"hello", nil}}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -432,7 +433,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "empty slice/success", []interface{}{}, - &EncodeContext{Registry: newTestRegistry()}, + newTestRegistry(), &valueReaderWriter{}, writeArrayEnd, nil, @@ -440,7 +441,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "interface/success", []myInterface{myStruct{1}}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeArrayEnd, nil, @@ -448,7 +449,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "interface/success", []myInterface{nil}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeArrayEnd, nil, @@ -726,7 +727,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup error", testProxy{ret: nil}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, nothing, ErrNoEncoder{Type: nil}, @@ -734,7 +735,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "success struct implementation", testProxy{ret: int64(1234567890)}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeInt64, nil, @@ -742,7 +743,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "success ptr to struct implementation", &testProxy{ret: int64(1234567890)}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeInt64, nil, @@ -758,7 +759,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "success ptr to ptr implementation", &testProxyPtr{ret: int64(1234567890)}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeInt64, nil, @@ -804,7 +805,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "no encoder", &wrong, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, nothing, ErrNoEncoder{Type: reflect.TypeOf(wrong)}, @@ -818,7 +819,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "ValueMarshaler", &vmStruct, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -826,7 +827,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Marshaler", &mStruct, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -834,7 +835,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Proxy", &pStruct, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -1073,12 +1074,12 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "StructEncodeValue", - newStructCodec(DefaultStructTagParser), + defaultStructCodec, []subtest{ { "interface value", struct{ Foo myInterface }{Foo: myStruct{1}}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -1086,7 +1087,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "nil interface value", struct{ Foo myInterface }{Foo: nil}, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, @@ -1123,7 +1124,7 @@ func TestDefaultValueEncoders(t *testing.T) { Code: "var hello = 'world';", Scope: D{}, }, - &EncodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), nil, writeDocumentEnd, nil, }, }, @@ -1181,16 +1182,12 @@ func TestDefaultValueEncoders(t *testing.T) { t.Run(tc.name, func(t *testing.T) { for _, subtest := range tc.subtests { t.Run(subtest.name, func(t *testing.T) { - var ec EncodeContext - if subtest.ectx != nil { - ec = *subtest.ectx - } llvrw := new(valueReaderWriter) if subtest.llvrw != nil { llvrw = subtest.llvrw } llvrw.T = t - err := tc.ve.EncodeValue(ec, llvrw, reflect.ValueOf(subtest.val)) + err := tc.ve.EncodeValue(subtest.reg, llvrw, reflect.ValueOf(subtest.val)) if !assert.CompareErrors(err, subtest.err) { t.Errorf("Errors do not match. got %v; want %v", err, subtest.err) } @@ -1770,7 +1767,7 @@ func TestDefaultValueEncoders(t *testing.T) { reg := buildDefaultRegistry() enc, err := reg.LookupEncoder(reflect.TypeOf(tc.value)) noerr(t, err) - err = enc.EncodeValue(EncodeContext{Registry: reg}, vw, reflect.ValueOf(tc.value)) + err = enc.EncodeValue(reg, vw, reflect.ValueOf(tc.value)) if !errors.Is(err, tc.err) { t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) } @@ -1820,7 +1817,7 @@ func TestDefaultValueEncoders(t *testing.T) { reg := buildDefaultRegistry() enc, err := reg.LookupEncoder(reflect.TypeOf(tc.value)) noerr(t, err) - err = enc.EncodeValue(EncodeContext{Registry: reg}, vw, reflect.ValueOf(tc.value)) + err = enc.EncodeValue(reg, vw, reflect.ValueOf(tc.value)) if err == nil || !strings.Contains(err.Error(), tc.err.Error()) { t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) } diff --git a/bson/empty_interface_codec.go b/bson/empty_interface_codec.go index fce09c0ac1..0a68c77a40 100644 --- a/bson/empty_interface_codec.go +++ b/bson/empty_interface_codec.go @@ -17,13 +17,12 @@ type emptyInterfaceCodec struct { decodeBinaryAsSlice bool } -// Assert that defaultEmptyInterfaceCodec satisfies the typeDecoder interface, which allows it -// to be used by collection type decoders (e.g. map, slice, etc) to set individual values in a -// collection. -var _ typeDecoder = (*emptyInterfaceCodec)(nil) +var ( + defaultEmptyInterfaceCodec = &emptyInterfaceCodec{} +) // EncodeValue is the ValueEncoderFunc for interface{}. -func (eic emptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (eic emptyInterfaceCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tEmpty { return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} } @@ -31,12 +30,12 @@ func (eic emptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val if val.IsNil() { return vw.WriteNull() } - encoder, err := ec.LookupEncoder(val.Elem().Type()) + encoder, err := reg.LookupEncoder(val.Elem().Type()) if err != nil { return err } - return encoder.EncodeValue(ec, vw, val.Elem()) + return encoder.EncodeValue(reg, vw, val.Elem()) } func (eic emptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType Type) (reflect.Type, error) { diff --git a/bson/encoder.go b/bson/encoder.go index 1b348f9488..33cb16cc13 100644 --- a/bson/encoder.go +++ b/bson/encoder.go @@ -23,15 +23,15 @@ var encPool = sync.Pool{ // An Encoder writes a serialization format to an output stream. It writes to a ValueWriter // as the destination of BSON data. type Encoder struct { - ec EncodeContext - vw ValueWriter + reg *Registry + vw ValueWriter } // NewEncoder returns a new encoder that uses the DefaultRegistry to write to vw. func NewEncoder(vw ValueWriter) *Encoder { return &Encoder{ - ec: EncodeContext{Registry: DefaultRegistry}, - vw: vw, + reg: DefaultRegistry, + vw: vw, } } @@ -48,12 +48,12 @@ func (e *Encoder) Encode(val interface{}) error { return copyDocumentFromBytes(e.vw, buf) } - encoder, err := e.ec.LookupEncoder(reflect.TypeOf(val)) + encoder, err := e.reg.LookupEncoder(reflect.TypeOf(val)) if err != nil { return err } - return encoder.EncodeValue(e.ec, e.vw, reflect.ValueOf(val)) + return encoder.EncodeValue(e.reg, e.vw, reflect.ValueOf(val)) } // Reset will reset the state of the Encoder, using the same *EncodeContext used in @@ -64,44 +64,73 @@ func (e *Encoder) Reset(vw ValueWriter) { // SetRegistry replaces the current registry of the Encoder with r. func (e *Encoder) SetRegistry(r *Registry) { - e.ec.Registry = r + e.reg = r } // ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in // the marshaled BSON when the "inline" struct tag option is set. func (e *Encoder) ErrorOnInlineDuplicates() { - e.ec.errorOnInlineDuplicates = true + if v, ok := e.reg.kindEncoders.Load(reflect.Struct); ok { + if enc, ok := v.(*structCodec); ok { + enc.overwriteDuplicatedInlinedFields = false + } + } } // IntMinSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, uint, // uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) that can // represent the integer value. func (e *Encoder) IntMinSize() { - e.ec.minSize = true + if v, ok := e.reg.kindEncoders.Load(reflect.Int); ok { + if enc, ok := v.(*intCodec); ok { + enc.encodeToMinSize = true + } + } + if v, ok := e.reg.kindEncoders.Load(reflect.Uint); ok { + if enc, ok := v.(*uintCodec); ok { + enc.encodeToMinSize = true + } + } } // StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name // strings using fmt.Sprint instead of the default string conversion logic. func (e *Encoder) StringifyMapKeysWithFmt() { - e.ec.stringifyMapKeysWithFmt = true + if v, ok := e.reg.kindEncoders.Load(reflect.Map); ok { + if enc, ok := v.(*mapCodec); ok { + enc.encodeKeysWithStringer = true + } + } } // NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON // null. func (e *Encoder) NilMapAsEmpty() { - e.ec.nilMapAsEmpty = true + if v, ok := e.reg.kindEncoders.Load(reflect.Map); ok { + if enc, ok := v.(*mapCodec); ok { + enc.encodeNilAsEmpty = true + } + } } // NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON // null. func (e *Encoder) NilSliceAsEmpty() { - e.ec.nilSliceAsEmpty = true + if v, ok := e.reg.kindEncoders.Load(reflect.Slice); ok { + if enc, ok := v.(*sliceCodec); ok { + enc.encodeNilAsEmpty = true + } + } } // NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values // instead of BSON null. func (e *Encoder) NilByteSliceAsEmpty() { - e.ec.nilByteSliceAsEmpty = true + if v, ok := e.reg.typeEncoders.Load(tByteSlice); ok { + if enc, ok := v.(*byteSliceCodec); ok { + enc.encodeNilAsEmpty = true + } + } } // TODO(GODRIVER-2820): Update the description to remove the note about only examining exported @@ -113,11 +142,19 @@ func (e *Encoder) NilByteSliceAsEmpty() { // Note that the Encoder only examines exported struct fields when determining if a struct is the // zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty. func (e *Encoder) OmitZeroStruct() { - e.ec.omitZeroStruct = true + if v, ok := e.reg.kindEncoders.Load(reflect.Struct); ok { + if enc, ok := v.(*structCodec); ok { + enc.encodeOmitDefaultStruct = true + } + } } // UseJSONStructTags causes the Encoder to fall back to using the "json" struct tag if a "bson" // struct tag is not specified. func (e *Encoder) UseJSONStructTags() { - e.ec.useJSONStructTags = true + if v, ok := e.reg.kindEncoders.Load(reflect.Struct); ok { + if enc, ok := v.(*structCodec); ok { + enc.useJSONStructTags = true + } + } } diff --git a/bson/encoder_test.go b/bson/encoder_test.go index 999b9962ef..15cce55700 100644 --- a/bson/encoder_test.go +++ b/bson/encoder_test.go @@ -25,7 +25,7 @@ func TestBasicEncode(t *testing.T) { reg := DefaultRegistry encoder, err := reg.LookupEncoder(reflect.TypeOf(tc.val)) noerr(t, err) - err = encoder.EncodeValue(EncodeContext{Registry: reg}, vw, reflect.ValueOf(tc.val)) + err = encoder.EncodeValue(reg, vw, reflect.ValueOf(tc.val)) noerr(t, err) if !bytes.Equal(got, tc.want) { diff --git a/bson/int_codec.go b/bson/int_codec.go new file mode 100644 index 0000000000..d0791ad70b --- /dev/null +++ b/bson/int_codec.go @@ -0,0 +1,44 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bson + +import ( + "reflect" +) + +// intCodec is the Codec used for uint values. +type intCodec struct { + // encodeToMinSize causes EncodeValue to marshal Go uint values (excluding uint64) as the + // minimum BSON int size (either 32-bit or 64-bit) that can represent the integer value. + encodeToMinSize bool +} + +// EncodeValue is the ValueEncoder for uint types. +func (ic *intCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { + switch val.Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32: + return vw.WriteInt32(int32(val.Int())) + case reflect.Int: + i64 := val.Int() + if fitsIn32Bits(i64) { + return vw.WriteInt32(int32(i64)) + } + return vw.WriteInt64(i64) + case reflect.Int64: + i64 := val.Int() + if ic.encodeToMinSize && fitsIn32Bits(i64) { + return vw.WriteInt32(int32(i64)) + } + return vw.WriteInt64(i64) + } + + return ValueEncoderError{ + Name: "IntEncodeValue", + Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Received: val, + } +} diff --git a/bson/map_codec.go b/bson/map_codec.go index 4da32ccae4..ce2ac9ff9e 100644 --- a/bson/map_codec.go +++ b/bson/map_codec.go @@ -14,6 +14,10 @@ import ( "strconv" ) +var ( + defaultMapCodec = &mapCodec{} +) + // mapCodec is the Codec used for map values. type mapCodec struct { // decodeZerosMap causes DecodeValue to delete any existing values from Go maps in the destination @@ -46,12 +50,12 @@ type KeyUnmarshaler interface { } // EncodeValue is the ValueEncoder for map[*]* types. -func (mc *mapCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (mc *mapCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Map { return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} } - if val.IsNil() && !mc.encodeNilAsEmpty && !ec.nilMapAsEmpty { + if val.IsNil() && !mc.encodeNilAsEmpty { // If we have a nil map but we can't WriteNull, that means we're probably trying to encode // to a TopLevel document. We can't currently tell if this is what actually happened, but if // there's a deeper underlying problem, the error will also be returned from WriteDocument, @@ -68,23 +72,23 @@ func (mc *mapCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Va return err } - return mc.mapEncodeValue(ec, dw, val, nil) + return mc.mapEncodeValue(reg, dw, val, nil) } // mapEncodeValue handles encoding of the values of a map. The collisionFn returns // true if the provided key exists, this is mainly used for inline maps in the // struct codec. -func (mc *mapCodec) mapEncodeValue(ec EncodeContext, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { +func (mc *mapCodec) mapEncodeValue(reg *Registry, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { elemType := val.Type().Elem() - encoder, err := ec.LookupEncoder(elemType) + encoder, err := reg.LookupEncoder(elemType) if err != nil && elemType.Kind() != reflect.Interface { return err } keys := val.MapKeys() for _, key := range keys { - keyStr, err := mc.encodeKey(key, ec.stringifyMapKeysWithFmt) + keyStr, err := mc.encodeKey(key, mc.encodeKeysWithStringer) if err != nil { return err } @@ -93,7 +97,7 @@ func (mc *mapCodec) mapEncodeValue(ec EncodeContext, dw DocumentWriter, val refl return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key) } - currEncoder, currVal, lookupErr := lookupElementEncoder(ec, encoder, val.MapIndex(key)) + currEncoder, currVal, lookupErr := lookupElementEncoder(reg, encoder, val.MapIndex(key)) if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } @@ -111,7 +115,7 @@ func (mc *mapCodec) mapEncodeValue(ec EncodeContext, dw DocumentWriter, val refl continue } - err = currEncoder.EncodeValue(ec, vw, currVal) + err = currEncoder.EncodeValue(reg, vw, currVal) if err != nil { return err } diff --git a/bson/marshal.go b/bson/marshal.go index 573de16398..32151cc465 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -103,7 +103,7 @@ func MarshalValueWithRegistry(r *Registry, val interface{}) (Type, []byte, error enc := encPool.Get().(*Encoder) defer encPool.Put(enc) enc.Reset(vwFlusher) - enc.ec = EncodeContext{Registry: r} + enc.SetRegistry(r) if err := enc.Encode(val); err != nil { return 0, nil, err } @@ -127,7 +127,7 @@ func MarshalExtJSON(val interface{}, canonical, escapeHTML bool) ([]byte, error) defer encPool.Put(enc) enc.Reset(ejvw) - enc.ec = EncodeContext{Registry: DefaultRegistry} + enc.SetRegistry(DefaultRegistry) err := enc.Encode(val) if err != nil { diff --git a/bson/marshal_test.go b/bson/marshal_test.go index ecf67d8493..6013d7b911 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -149,7 +149,7 @@ func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) { // different Registry is used. // Create a custom Registry that negates int32 values when encoding. - var encodeInt32 ValueEncoderFunc = func(_ EncodeContext, vw ValueWriter, val reflect.Value) error { + var encodeInt32 ValueEncoderFunc = func(_ *Registry, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.Int32 { return fmt.Errorf("expected kind to be int32, got %v", val.Kind()) } diff --git a/bson/pointer_codec.go b/bson/pointer_codec.go index 0fd7fdb81c..425d371d0e 100644 --- a/bson/pointer_codec.go +++ b/bson/pointer_codec.go @@ -10,9 +10,6 @@ import ( "reflect" ) -var _ valueEncoder = &pointerCodec{} -var _ valueDecoder = &pointerCodec{} - // pointerCodec is the Codec used for pointers. type pointerCodec struct { ecache typeEncoderCache @@ -21,7 +18,7 @@ type pointerCodec struct { // EncodeValue handles encoding a pointer by either encoding it to BSON Null if the pointer is nil // or looking up an encoder for the type of value the pointer points to. -func (pc *pointerCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (pc *pointerCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.Ptr { if !val.IsValid() { return vw.WriteNull() @@ -38,15 +35,15 @@ func (pc *pointerCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflec if v == nil { return ErrNoEncoder{Type: typ} } - return v.EncodeValue(ec, vw, val.Elem()) + return v.EncodeValue(reg, vw, val.Elem()) } // TODO(charlie): handle concurrent requests for the same type - enc, err := ec.LookupEncoder(typ.Elem()) + enc, err := reg.LookupEncoder(typ.Elem()) enc = pc.ecache.LoadOrStore(typ, enc) if err != nil { return err } - return enc.EncodeValue(ec, vw, val.Elem()) + return enc.EncodeValue(reg, vw, val.Elem()) } // DecodeValue handles decoding a pointer by looking up a decoder for the type it points to and diff --git a/bson/primitive_codecs.go b/bson/primitive_codecs.go index df6e059c4a..adbb28d601 100644 --- a/bson/primitive_codecs.go +++ b/bson/primitive_codecs.go @@ -31,7 +31,7 @@ func registerPrimitiveCodecs(reg *Registry) { // // If the RawValue's Type is "invalid" and the RawValue's Value is not empty or // nil, then this method will return an error. -func rawValueEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func rawValueEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRawValue { return ValueEncoderError{ Name: "RawValueEncodeValue", @@ -65,7 +65,7 @@ func rawValueDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) err } // rawEncodeValue is the ValueEncoderFunc for Reader. -func rawEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func rawEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRaw { return ValueEncoderError{Name: "RawEncodeValue", Types: []reflect.Type{tRaw}, Received: val} } diff --git a/bson/primitive_codecs_test.go b/bson/primitive_codecs_test.go index a2e1e61e89..a38113b72d 100644 --- a/bson/primitive_codecs_test.go +++ b/bson/primitive_codecs_test.go @@ -37,7 +37,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { type subtest struct { name string val interface{} - ectx *EncodeContext llvrw *valueReaderWriter invoke invoked err error @@ -45,7 +44,7 @@ func TestPrimitiveValueEncoders(t *testing.T) { testCases := []struct { name string - ve valueEncoder + ve ValueEncoder subtests []subtest }{ { @@ -56,7 +55,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { "wrong type", wrong, nil, - nil, nothing, ValueEncoderError{ Name: "RawValueEncodeValue", @@ -68,7 +66,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { "RawValue/success", RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, 3.14159)}, nil, - nil, writeDouble, nil, }, @@ -79,7 +76,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { Value: bsoncore.AppendDouble(nil, 3.14159), }, nil, - nil, nothing, fmt.Errorf("the RawValue Type specifies an invalid BSON type: 0x0"), }, @@ -90,7 +86,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { Value: bsoncore.AppendDouble(nil, 3.14159), }, nil, - nil, nothing, fmt.Errorf("the RawValue Type specifies an invalid BSON type: 0x8f"), }, @@ -104,14 +99,12 @@ func TestPrimitiveValueEncoders(t *testing.T) { "wrong type", wrong, nil, - nil, nothing, ValueEncoderError{Name: "RawEncodeValue", Types: []reflect.Type{tRaw}, Received: reflect.ValueOf(wrong)}, }, { "WriteDocument Error", Raw{}, - nil, &valueReaderWriter{Err: errors.New("wd error"), ErrAfter: writeDocument}, writeDocument, errors.New("wd error"), @@ -119,7 +112,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { { "Raw.Elements Error", Raw{0xFF, 0x00, 0x00, 0x00, 0x00}, - nil, &valueReaderWriter{}, writeDocument, errors.New("length read exceeds number of bytes available. length=5 bytes=255"), @@ -127,7 +119,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { { "WriteDocumentElement Error", Raw(bytesFromDoc(D{{"foo", nil}})), - nil, &valueReaderWriter{Err: errors.New("wde error"), ErrAfter: writeDocumentElement}, writeDocumentElement, errors.New("wde error"), @@ -135,7 +126,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { { "encodeValue error", Raw(bytesFromDoc(D{{"foo", nil}})), - nil, &valueReaderWriter{Err: errors.New("ev error"), ErrAfter: writeNull}, writeNull, errors.New("ev error"), @@ -143,7 +133,6 @@ func TestPrimitiveValueEncoders(t *testing.T) { { "iterator error", Raw{0x0C, 0x00, 0x00, 0x00, 0x01, 'f', 'o', 'o', 0x00, 0x01, 0x02, 0x03}, - nil, &valueReaderWriter{}, writeDocumentElement, errors.New("not enough bytes available to read type. bytes=3 type=double"), @@ -164,16 +153,12 @@ func TestPrimitiveValueEncoders(t *testing.T) { t.Run(subtest.name, func(t *testing.T) { t.Parallel() - var ec EncodeContext - if subtest.ectx != nil { - ec = *subtest.ectx - } llvrw := new(valueReaderWriter) if subtest.llvrw != nil { llvrw = subtest.llvrw } llvrw.T = t - err := tc.ve.EncodeValue(ec, llvrw, reflect.ValueOf(subtest.val)) + err := tc.ve.EncodeValue(nil, llvrw, reflect.ValueOf(subtest.val)) if !assert.CompareErrors(err, subtest.err) { t.Errorf("Errors do not match. got %v; want %v", err, subtest.err) } @@ -491,7 +476,7 @@ func TestPrimitiveValueDecoders(t *testing.T) { testCases := []struct { name string - vd valueDecoder + vd ValueDecoder subtests []subtest }{ { diff --git a/bson/registry.go b/bson/registry.go index 9c680e4e66..71d65259d6 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -125,7 +125,7 @@ func NewRegistry() *Registry { // interface. To get the latter behavior, call RegisterHookEncoder instead. // // RegisterTypeEncoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc valueEncoder) { +func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) { r.typeEncoders.Store(valueType, enc) } @@ -139,7 +139,7 @@ func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc valueEncoder) // implements the interface. To get the latter behavior, call RegisterHookDecoder instead. // // RegisterTypeDecoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterTypeDecoder(valueType reflect.Type, dec valueDecoder) { +func (r *Registry) RegisterTypeDecoder(valueType reflect.Type, dec ValueDecoder) { r.typeDecoders.Store(valueType, dec) } @@ -155,7 +155,7 @@ func (r *Registry) RegisterTypeDecoder(valueType reflect.Type, dec valueDecoder) // reg.RegisterKindEncoder(reflect.Int32, myEncoder) // // RegisterKindEncoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc valueEncoder) { +func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc ValueEncoder) { r.kindEncoders.Store(kind, enc) } @@ -171,7 +171,7 @@ func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc valueEncoder) { // reg.RegisterKindDecoder(reflect.Int32, myDecoder) // // RegisterKindDecoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterKindDecoder(kind reflect.Kind, dec valueDecoder) { +func (r *Registry) RegisterKindDecoder(kind reflect.Kind, dec ValueDecoder) { r.kindDecoders.Store(kind, dec) } @@ -181,7 +181,7 @@ func (r *Registry) RegisterKindDecoder(kind reflect.Kind, dec valueDecoder) { // (i.e. iface.Kind() != reflect.Interface), this method will panic. // // RegisterInterfaceEncoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterInterfaceEncoder(iface reflect.Type, enc valueEncoder) { +func (r *Registry) RegisterInterfaceEncoder(iface reflect.Type, enc ValueEncoder) { if iface.Kind() != reflect.Interface { panicStr := fmt.Errorf("RegisterInterfaceEncoder expects a type with kind reflect.Interface, "+ "got type %s with kind %s", iface, iface.Kind()) @@ -204,7 +204,7 @@ func (r *Registry) RegisterInterfaceEncoder(iface reflect.Type, enc valueEncoder // this method will panic. // // RegisterInterfaceDecoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterInterfaceDecoder(iface reflect.Type, dec valueDecoder) { +func (r *Registry) RegisterInterfaceDecoder(iface reflect.Type, dec ValueDecoder) { if iface.Kind() != reflect.Interface { panicStr := fmt.Errorf("RegisterInterfaceDecoder expects a type with kind reflect.Interface, "+ "got type %s with kind %s", iface, iface.Kind()) @@ -251,7 +251,7 @@ func (r *Registry) RegisterTypeMapEntry(bt Type, rt reflect.Type) { // // If no encoder is found, an error of type ErrNoEncoder is returned. LookupEncoder is safe for // concurrent use by multiple goroutines after all codecs and encoders are registered. -func (r *Registry) LookupEncoder(valueType reflect.Type) (valueEncoder, error) { +func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { if valueType == nil { return nil, ErrNoEncoder{Type: valueType} } @@ -274,15 +274,15 @@ func (r *Registry) LookupEncoder(valueType reflect.Type) (valueEncoder, error) { return nil, ErrNoEncoder{Type: valueType} } -func (r *Registry) storeTypeEncoder(rt reflect.Type, enc valueEncoder) valueEncoder { +func (r *Registry) storeTypeEncoder(rt reflect.Type, enc ValueEncoder) ValueEncoder { return r.typeEncoders.LoadOrStore(rt, enc) } -func (r *Registry) lookupTypeEncoder(rt reflect.Type) (valueEncoder, bool) { +func (r *Registry) lookupTypeEncoder(rt reflect.Type) (ValueEncoder, bool) { return r.typeEncoders.Load(rt) } -func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool) (valueEncoder, bool) { +func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool) (ValueEncoder, bool) { if valueType == nil { return nil, false } @@ -320,7 +320,7 @@ func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool // // If no decoder is found, an error of type ErrNoDecoder is returned. LookupDecoder is safe for // concurrent use by multiple goroutines after all codecs and decoders are registered. -func (r *Registry) LookupDecoder(valueType reflect.Type) (valueDecoder, error) { +func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) { if valueType == nil { return nil, ErrNilType } @@ -343,15 +343,15 @@ func (r *Registry) LookupDecoder(valueType reflect.Type) (valueDecoder, error) { return nil, ErrNoDecoder{Type: valueType} } -func (r *Registry) lookupTypeDecoder(valueType reflect.Type) (valueDecoder, bool) { +func (r *Registry) lookupTypeDecoder(valueType reflect.Type) (ValueDecoder, bool) { return r.typeDecoders.Load(valueType) } -func (r *Registry) storeTypeDecoder(typ reflect.Type, dec valueDecoder) valueDecoder { +func (r *Registry) storeTypeDecoder(typ reflect.Type, dec ValueDecoder) ValueDecoder { return r.typeDecoders.LoadOrStore(typ, dec) } -func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool) (valueDecoder, bool) { +func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool) (ValueDecoder, bool) { for _, idec := range r.interfaceDecoders { if valueType.Implements(idec.i) { return idec.vd, true @@ -383,10 +383,10 @@ func (r *Registry) LookupTypeMapEntry(bt Type) (reflect.Type, error) { type interfaceValueEncoder struct { i reflect.Type - ve valueEncoder + ve ValueEncoder } type interfaceValueDecoder struct { i reflect.Type - vd valueDecoder + vd ValueDecoder } diff --git a/bson/registry_examples_test.go b/bson/registry_examples_test.go index 39214f1b65..b866df8cdb 100644 --- a/bson/registry_examples_test.go +++ b/bson/registry_examples_test.go @@ -23,7 +23,7 @@ func ExampleRegistry_customEncoder() { negatedIntType := reflect.TypeOf(negatedInt(0)) negatedIntEncoder := func( - ec bson.EncodeContext, + _ *bson.Registry, vw bson.ValueWriter, val reflect.Value, ) error { @@ -162,7 +162,7 @@ func ExampleRegistry_RegisterKindEncoder() { // encoder for kind reflect.Int32. That way, even user-defined types with // underlying type int32 will be encoded as a BSON int64. int32To64Encoder := func( - ec bson.EncodeContext, + _ *bson.Registry, vw bson.ValueWriter, val reflect.Value, ) error { diff --git a/bson/registry_test.go b/bson/registry_test.go index 2a5150ad24..5375b6f444 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -62,7 +62,7 @@ func TestRegistryBuilder(t *testing.T) { reg.RegisterTypeEncoder(reflect.TypeOf(ft4), fc4) want := []struct { t reflect.Type - c valueEncoder + c ValueEncoder }{ {reflect.TypeOf(ft1), fc3}, {reflect.TypeOf(ft2), fc2}, @@ -90,7 +90,7 @@ func TestRegistryBuilder(t *testing.T) { reg.RegisterKindEncoder(k4, fc4) want := []struct { k reflect.Kind - c valueEncoder + c ValueEncoder }{ {k1, fc3}, {k2, fc2}, @@ -173,8 +173,8 @@ func TestRegistryBuilder(t *testing.T) { }) t.Run("Lookup", func(t *testing.T) { type Codec interface { - valueEncoder - valueDecoder + ValueEncoder + ValueDecoder } var ( @@ -472,7 +472,7 @@ func TestRegistry(t *testing.T) { want := []struct { t reflect.Type - c valueEncoder + c ValueEncoder }{ {reflect.TypeOf(ft1), fc3}, {reflect.TypeOf(ft2), fc2}, @@ -502,7 +502,7 @@ func TestRegistry(t *testing.T) { want := []struct { k reflect.Kind - c valueEncoder + c ValueEncoder }{ {k1, fc3}, {k2, fc2}, @@ -588,8 +588,8 @@ func TestRegistry(t *testing.T) { t.Parallel() type Codec interface { - valueEncoder - valueDecoder + ValueEncoder + ValueDecoder } var ( @@ -887,7 +887,7 @@ func TestRegistry(t *testing.T) { } // get is only for testing as it does return if the value was found -func (c *kindEncoderCache) get(rt reflect.Kind) valueEncoder { +func (c *kindEncoderCache) get(rt reflect.Kind) ValueEncoder { e, _ := c.Load(rt) return e } @@ -948,7 +948,7 @@ type fakeCodec struct { num int } -func (*fakeCodec) EncodeValue(EncodeContext, ValueWriter, reflect.Value) error { +func (*fakeCodec) EncodeValue(*Registry, ValueWriter, reflect.Value) error { return nil } func (*fakeCodec) DecodeValue(DecodeContext, ValueReader, reflect.Value) error { diff --git a/bson/setter_getter.go b/bson/setter_getter.go index 3616d25603..069408c9ab 100644 --- a/bson/setter_getter.go +++ b/bson/setter_getter.go @@ -84,7 +84,7 @@ func SetterDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error } // GetterEncodeValue is the ValueEncoderFunc for Getter types. -func GetterEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func GetterEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Getter switch { case !val.IsValid(): @@ -112,9 +112,9 @@ func GetterEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) erro return vw.WriteNull() } vv := reflect.ValueOf(x) - encoder, err := ec.Registry.LookupEncoder(vv.Type()) + encoder, err := reg.LookupEncoder(vv.Type()) if err != nil { return err } - return encoder.EncodeValue(ec, vw, vv) + return encoder.EncodeValue(reg, vw, vv) } diff --git a/bson/slice_codec.go b/bson/slice_codec.go index b25efc6bff..d7db3cf9da 100644 --- a/bson/slice_codec.go +++ b/bson/slice_codec.go @@ -12,6 +12,10 @@ import ( "reflect" ) +var ( + defaultSliceCodec = &sliceCodec{} +) + // sliceCodec is the Codec used for slice values. type sliceCodec struct { // encodeNilAsEmpty causes EncodeValue to marshal nil Go slices as empty BSON arrays instead of @@ -20,12 +24,12 @@ type sliceCodec struct { } // EncodeValue is the ValueEncoder for slice types. -func (sc sliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (sc sliceCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Slice { return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } - if val.IsNil() && !sc.encodeNilAsEmpty && !ec.nilSliceAsEmpty { + if val.IsNil() && !sc.encodeNilAsEmpty { return vw.WriteNull() } @@ -46,7 +50,7 @@ func (sc sliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.V } for _, e := range d { - err = encodeElement(ec, dw, e) + err = encodeElement(reg, dw, e) if err != nil { return err } @@ -61,13 +65,13 @@ func (sc sliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.V } elemType := val.Type().Elem() - encoder, err := ec.LookupEncoder(elemType) + encoder, err := reg.LookupEncoder(elemType) if err != nil && elemType.Kind() != reflect.Interface { return err } for idx := 0; idx < val.Len(); idx++ { - currEncoder, currVal, lookupErr := lookupElementEncoder(ec, encoder, val.Index(idx)) + currEncoder, currVal, lookupErr := lookupElementEncoder(reg, encoder, val.Index(idx)) if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } @@ -85,7 +89,7 @@ func (sc sliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.V continue } - err = currEncoder.EncodeValue(ec, vw, currVal) + err = currEncoder.EncodeValue(reg, vw, currVal) if err != nil { return err } diff --git a/bson/string_codec.go b/bson/string_codec.go index 7d7205f34d..de73fc6f0d 100644 --- a/bson/string_codec.go +++ b/bson/string_codec.go @@ -19,13 +19,12 @@ type stringCodec struct { decodeObjectIDAsHex bool } -// Assert that defaultStringCodec satisfies the typeDecoder interface, which allows it to be -// used by collection type decoders (e.g. map, slice, etc) to set individual values in a -// collection. -var _ typeDecoder = (*stringCodec)(nil) +var ( + defaultStringCodec = &stringCodec{} +) // EncodeValue is the ValueEncoder for string types. -func (sc *stringCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func (sc *stringCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.String { return ValueEncoderError{ Name: "StringEncodeValue", diff --git a/bson/struct_codec.go b/bson/struct_codec.go index c8f783a00f..b489b1dc56 100644 --- a/bson/struct_codec.go +++ b/bson/struct_codec.go @@ -16,6 +16,10 @@ import ( "time" ) +var ( + defaultStructCodec = newStructCodec(DefaultStructTagParser) +) + // DecodeError represents an error that occurs when unmarshalling BSON bytes into a native Go type. type DecodeError struct { keys []string @@ -71,10 +75,9 @@ type structCodec struct { // a duplicate field in the marshaled BSON when the "inline" struct tag option is set. The // default value is true. overwriteDuplicatedInlinedFields bool -} -var _ valueEncoder = &structCodec{} -var _ valueDecoder = &structCodec{} + useJSONStructTags bool +} // newStructCodec returns a StructCodec that uses p for struct tag parsing. func newStructCodec(p StructTagParser) *structCodec { @@ -85,12 +88,12 @@ func newStructCodec(p StructTagParser) *structCodec { } // EncodeValue handles encoding generic struct types. -func (sc *structCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (sc *structCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Struct { return ValueEncoderError{Name: "structCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} } - sd, err := sc.describeStruct(ec.Registry, val.Type(), ec.useJSONStructTags, ec.errorOnInlineDuplicates) + sd, err := sc.describeStruct(reg, val.Type(), sc.useJSONStructTags, !sc.overwriteDuplicatedInlinedFields) if err != nil { return err } @@ -110,7 +113,7 @@ func (sc *structCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect } } - desc.encoder, rv, err = lookupElementEncoder(ec, desc.encoder, rv) + desc.encoder, rv, err = lookupElementEncoder(reg, desc.encoder, rv) if err != nil && !errors.Is(err, errInvalidValue) { return err @@ -145,7 +148,7 @@ func (sc *structCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect // nil interface separately. empty = rv.IsNil() } else { - empty = isEmpty(rv, sc.encodeOmitDefaultStruct || ec.omitZeroStruct) + empty = isEmpty(rv, sc.encodeOmitDefaultStruct) } if desc.omitEmpty && empty { continue @@ -156,18 +159,19 @@ func (sc *structCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect return err } - ectx := EncodeContext{ - Registry: ec.Registry, - minSize: desc.minSize || ec.minSize, - errorOnInlineDuplicates: ec.errorOnInlineDuplicates, - stringifyMapKeysWithFmt: ec.stringifyMapKeysWithFmt, - nilMapAsEmpty: ec.nilMapAsEmpty, - nilSliceAsEmpty: ec.nilSliceAsEmpty, - nilByteSliceAsEmpty: ec.nilByteSliceAsEmpty, - omitZeroStruct: ec.omitZeroStruct, - useJSONStructTags: ec.useJSONStructTags, + // defaultUIntCodec.encodeToMinSize = desc.minSize + switch v := encoder.(type) { + case *uintCodec: + encoder = &uintCodec{ + encodeToMinSize: v.encodeToMinSize || desc.minSize, + } + case *intCodec: + encoder = &intCodec{ + encodeToMinSize: v.encodeToMinSize || desc.minSize, + } } - err = encoder.EncodeValue(ectx, vw2, rv) + + err = encoder.EncodeValue(reg, vw2, rv) if err != nil { return err } @@ -180,7 +184,7 @@ func (sc *structCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect return exists } - return (&mapCodec{}).mapEncodeValue(ec, dw, rv, collisionFn) + return (&mapCodec{}).mapEncodeValue(reg, dw, rv, collisionFn) } return dw.WriteDocumentEnd() @@ -239,7 +243,7 @@ func (sc *structCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect val.Set(deepZero(val.Type())) } - var decoder valueDecoder + var decoder ValueDecoder var inlineMap reflect.Value if sd.inlineMap >= 0 { inlineMap = val.Field(sd.inlineMap) @@ -385,8 +389,8 @@ type fieldDescription struct { minSize bool truncate bool inline []int - encoder valueEncoder - decoder valueDecoder + encoder ValueEncoder + decoder ValueDecoder } type byIndex []fieldDescription diff --git a/bson/time_codec.go b/bson/time_codec.go index 6bbe300e4a..535861ed71 100644 --- a/bson/time_codec.go +++ b/bson/time_codec.go @@ -22,9 +22,9 @@ type timeCodec struct { useLocalTimeZone bool } -// Assert that defaultTimeCodec satisfies the typeDecoder interface, which allows it to be used -// by collection type decoders (e.g. map, slice, etc) to set individual values in a collection. -var _ typeDecoder = (*timeCodec)(nil) +var ( + defaultTimeCodec = &timeCodec{} +) func (tc *timeCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tTime { @@ -99,7 +99,7 @@ func (tc *timeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.V } // EncodeValue is the ValueEncoderFunc for time.TIme. -func (tc *timeCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error { +func (tc *timeCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tTime { return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val} } diff --git a/bson/uint_codec.go b/bson/uint_codec.go index 27a297d043..f63404e934 100644 --- a/bson/uint_codec.go +++ b/bson/uint_codec.go @@ -19,20 +19,20 @@ type uintCodec struct { encodeToMinSize bool } -// Assert that defaultUIntCodec satisfies the typeDecoder interface, which allows it to be used -// by collection type decoders (e.g. map, slice, etc) to set individual values in a collection. -var _ typeDecoder = (*uintCodec)(nil) +var ( + defaultUIntCodec = &uintCodec{} +) // EncodeValue is the ValueEncoder for uint types. -func (uic *uintCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { +func (uic *uintCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Uint8, reflect.Uint16: return vw.WriteInt32(int32(val.Uint())) case reflect.Uint, reflect.Uint32, reflect.Uint64: u64 := val.Uint() - // If ec.MinSize or if encodeToMinSize is true for a non-uint64 value we should write val as an int32 - useMinSize := ec.minSize || (uic.encodeToMinSize && val.Kind() != reflect.Uint64) + // If encodeToMinSize is true for a non-uint64 value we should write val as an int32 + useMinSize := uic.encodeToMinSize && val.Kind() != reflect.Uint64 if u64 <= math.MaxInt32 && useMinSize { return vw.WriteInt32(int32(u64)) diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index 677bb44868..a098880b4c 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -39,7 +39,7 @@ type negateCodec struct { ID int64 `bson:"_id"` } -func (e *negateCodec) EncodeValue(_ bson.EncodeContext, vw bson.ValueWriter, val reflect.Value) error { +func (e *negateCodec) EncodeValue(_ *bson.Registry, vw bson.ValueWriter, val reflect.Value) error { return vw.WriteInt64(val.Int()) } From 394b263a606e1b3f221cee02e78b1f9567a8ac1f Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Fri, 10 May 2024 17:30:21 -0400 Subject: [PATCH 05/15] WIP --- bson/array_codec.go | 2 +- bson/bson_test.go | 6 +- bson/bsoncodec.go | 11 +- bson/byte_slice_codec.go | 2 +- bson/cond_addr_codec.go | 2 +- bson/cond_addr_codec_test.go | 4 +- bson/decoder_test.go | 4 +- bson/default_value_decoders.go | 2 +- bson/default_value_decoders_test.go | 89 +-- bson/default_value_encoders.go | 144 ++-- bson/default_value_encoders_test.go | 109 +-- bson/empty_interface_codec.go | 2 +- bson/encoder.go | 76 ++- bson/int_codec.go | 156 ++++- bson/map_codec.go | 4 +- bson/marshal_test.go | 7 +- bson/mgoregistry.go | 67 +- bson/pointer_codec.go | 2 +- bson/primitive_codecs.go | 16 +- bson/raw_value_test.go | 14 +- bson/registry.go | 310 +++++---- bson/registry_examples_test.go | 40 +- bson/registry_test.go | 779 ++++++++-------------- bson/setter_getter.go | 2 +- bson/slice_codec.go | 2 +- bson/string_codec.go | 2 +- bson/struct_codec.go | 50 +- bson/time_codec.go | 2 +- bson/unmarshal_test.go | 5 +- bson/unmarshal_value_test.go | 10 +- internal/integration/client_test.go | 9 +- internal/integration/crud_spec_test.go | 8 +- internal/integration/database_test.go | 8 +- internal/integration/unified_spec_test.go | 10 +- mongo/database_test.go | 4 +- mongo/options/clientoptions_test.go | 2 +- mongo/read_write_concern_spec_test.go | 8 +- x/mongo/driver/topology/server_options.go | 2 +- 38 files changed, 993 insertions(+), 979 deletions(-) diff --git a/bson/array_codec.go b/bson/array_codec.go index 9ea43d4028..757fd60004 100644 --- a/bson/array_codec.go +++ b/bson/array_codec.go @@ -20,7 +20,7 @@ var ( ) // EncodeValue is the ValueEncoder for bsoncore.Array values. -func (ac *arrayCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func (ac *arrayCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCoreArray { return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val} } diff --git a/bson/bson_test.go b/bson/bson_test.go index 5d99e066a8..246b0e913a 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -358,12 +358,12 @@ func TestMapCodec(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - mapRegistry := NewRegistry() - mapRegistry.RegisterKindEncoder(reflect.Map, tc.codec) + mapRegistry := NewRegistryBuilder() + mapRegistry.RegisterKindEncoder(reflect.Map, func() ValueEncoder { return tc.codec }) buf := new(bytes.Buffer) vw := NewValueWriter(buf) enc := NewEncoder(vw) - enc.SetRegistry(mapRegistry) + enc.SetRegistry(mapRegistry.Build()) err := enc.Encode(mapObj) assert.Nil(t, err, "Encode error: %v", err) str := buf.String() diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index 68c108e104..db176ad906 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -103,21 +103,26 @@ type DecodeContext struct { zeroStructs bool } +// EncoderRegistry is an interface provides a ValueEncoder based on the given reflect.Type. +type EncoderRegistry interface { + LookupEncoder(reflect.Type) (ValueEncoder, error) +} + // ValueEncoder is the interface implemented by types that can encode a provided Go type to BSON. // The value to encode is provided as a reflect.Value and a bson.ValueWriter is used within the // EncodeValue method to actually create the BSON representation. For convenience, ValueEncoderFunc // is provided to allow use of a function with the correct signature as a ValueEncoder. A pointer // to a Registry instance is provided to allow implementations to lookup further ValueEncoders. type ValueEncoder interface { - EncodeValue(*Registry, ValueWriter, reflect.Value) error + EncodeValue(EncoderRegistry, ValueWriter, reflect.Value) error } // ValueEncoderFunc is an adapter function that allows a function with the correct signature to be // used as a ValueEncoder. -type ValueEncoderFunc func(*Registry, ValueWriter, reflect.Value) error +type ValueEncoderFunc func(EncoderRegistry, ValueWriter, reflect.Value) error // EncodeValue implements the ValueEncoder interface. -func (fn ValueEncoderFunc) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func (fn ValueEncoderFunc) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { return fn(reg, vw, val) } diff --git a/bson/byte_slice_codec.go b/bson/byte_slice_codec.go index 83dba12ecb..e012c3d913 100644 --- a/bson/byte_slice_codec.go +++ b/bson/byte_slice_codec.go @@ -23,7 +23,7 @@ var ( ) // EncodeValue is the ValueEncoder for []byte. -func (bsc *byteSliceCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func (bsc *byteSliceCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tByteSlice { return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val} } diff --git a/bson/cond_addr_codec.go b/bson/cond_addr_codec.go index d1baf96ef4..ef87da6250 100644 --- a/bson/cond_addr_codec.go +++ b/bson/cond_addr_codec.go @@ -17,7 +17,7 @@ type condAddrEncoder struct { } // EncodeValue is the ValueEncoderFunc for a value that may be addressable. -func (cae *condAddrEncoder) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func (cae *condAddrEncoder) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if val.CanAddr() { return cae.canAddrEnc.EncodeValue(reg, vw, val) } diff --git a/bson/cond_addr_codec_test.go b/bson/cond_addr_codec_test.go index 26cd9c5534..ee4f61f3a7 100644 --- a/bson/cond_addr_codec_test.go +++ b/bson/cond_addr_codec_test.go @@ -22,11 +22,11 @@ func TestCondAddrCodec(t *testing.T) { t.Run("addressEncode", func(t *testing.T) { invoked := 0 - encode1 := ValueEncoderFunc(func(*Registry, ValueWriter, reflect.Value) error { + encode1 := ValueEncoderFunc(func(EncoderRegistry, ValueWriter, reflect.Value) error { invoked = 1 return nil }) - encode2 := ValueEncoderFunc(func(*Registry, ValueWriter, reflect.Value) error { + encode2 := ValueEncoderFunc(func(EncoderRegistry, ValueWriter, reflect.Value) error { invoked = 2 return nil }) diff --git a/bson/decoder_test.go b/bson/decoder_test.go index 8fe8d07480..3b96f63559 100644 --- a/bson/decoder_test.go +++ b/bson/decoder_test.go @@ -29,7 +29,7 @@ func TestBasicDecode(t *testing.T) { got := reflect.New(tc.sType).Elem() vr := NewValueReader(tc.data) - reg := DefaultRegistry + reg := NewRegistryBuilder().Build() decoder, err := reg.LookupDecoder(reflect.TypeOf(got)) noerr(t, err) err = decoder.DecodeValue(DecodeContext{Registry: reg}, vr, got) @@ -199,7 +199,7 @@ func TestDecoderv2(t *testing.T) { t.Run("SetRegistry", func(t *testing.T) { t.Parallel() - r1, r2 := DefaultRegistry, NewRegistry() + r1, r2 := DefaultRegistry, NewRegistryBuilder().Build() dc1 := DecodeContext{Registry: r1} dc2 := DecodeContext{Registry: r2} dec := NewDecoder(NewValueReader([]byte{})) diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index b105ab0715..56331da9a8 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -36,7 +36,7 @@ func (d decodeBinaryError) Error() string { // There is no support for decoding map[string]interface{} because there is no decoder for // interface{}, so users must either register this decoder themselves or use the // EmptyInterfaceDecoder available in the bson package. -func registerDefaultDecoders(reg *Registry) { +func registerDefaultDecoders(reg *RegistryBuilder) { if reg == nil { panic(errors.New("argument to RegisterDefaultDecoders must not be nil")) } diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index d2434a19e1..0e32e64ba7 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -819,7 +819,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", map[string]string{}, - &DecodeContext{Registry: newTestRegistry()}, + &DecodeContext{Registry: newTestRegistryBuilder().Build()}, &valueReaderWriter{}, readDocument, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -905,7 +905,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", [1]string{}, - &DecodeContext{Registry: newTestRegistry()}, + &DecodeContext{Registry: newTestRegistryBuilder().Build()}, &valueReaderWriter{BSONType: TypeArray}, readArray, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -999,7 +999,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", []string{}, - &DecodeContext{Registry: newTestRegistry()}, + &DecodeContext{Registry: newTestRegistryBuilder().Build()}, &valueReaderWriter{BSONType: TypeArray}, readArray, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -3310,7 +3310,7 @@ func TestDefaultValueDecoders(t *testing.T) { t.Skip() } val := reflect.New(tEmpty).Elem() - dc := DecodeContext{Registry: newTestRegistry()} + dc := DecodeContext{Registry: newTestRegistryBuilder().Build()} want := ErrNoTypeMapEntry{Type: tc.bsontype} got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, val) if !assert.CompareErrors(got, want) { @@ -3323,8 +3323,9 @@ func TestDefaultValueDecoders(t *testing.T) { t.Skip() } val := reflect.New(tEmpty).Elem() - reg := newTestRegistry() - reg.RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)) + reg := newTestRegistryBuilder(). + RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). + Build() dc := DecodeContext{ Registry: reg, } @@ -3341,9 +3342,10 @@ func TestDefaultValueDecoders(t *testing.T) { } want := errors.New("DecodeValue failure error") llc := &llCodec{t: t, err: want} - reg := newTestRegistry() - reg.RegisterTypeDecoder(reflect.TypeOf(tc.val), llc) - reg.RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)) + reg := newTestRegistryBuilder(). + RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). + RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). + Build() dc := DecodeContext{ Registry: reg, } @@ -3356,9 +3358,10 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run("Success", func(t *testing.T) { want := tc.val llc := &llCodec{t: t, decodeval: tc.val} - reg := newTestRegistry() - reg.RegisterTypeDecoder(reflect.TypeOf(tc.val), llc) - reg.RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)) + reg := newTestRegistryBuilder(). + RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). + RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). + Build() dc := DecodeContext{ Registry: reg, } @@ -3395,7 +3398,7 @@ func TestDefaultValueDecoders(t *testing.T) { llvr := &valueReaderWriter{BSONType: TypeDouble} want := ErrNoTypeMapEntry{Type: TypeDouble} val := reflect.New(tEmpty).Elem() - got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: newTestRegistry()}, llvr, val) + got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: newTestRegistryBuilder().Build()}, llvr, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3416,15 +3419,15 @@ func TestDefaultValueDecoders(t *testing.T) { // registering a custom type map entry for both Type(0) anad TypeEmbeddedDocument should cause // both top-level and embedded documents to decode to registered type when unmarshalling to interface{} - topLevelReg := newTestRegistry() - registerDefaultEncoders(topLevelReg) - registerDefaultDecoders(topLevelReg) - topLevelReg.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) + topLevelRb := newTestRegistryBuilder() + registerDefaultEncoders(topLevelRb) + registerDefaultDecoders(topLevelRb) + topLevelRb.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) - embeddedReg := newTestRegistry() - registerDefaultEncoders(embeddedReg) - registerDefaultDecoders(embeddedReg) - embeddedReg.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) + embeddedRb := newTestRegistryBuilder() + registerDefaultEncoders(embeddedRb) + registerDefaultDecoders(embeddedRb) + embeddedRb.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) // create doc {"nested": {"foo": 1}} innerDoc := bsoncore.BuildDocument( @@ -3445,8 +3448,8 @@ func TestDefaultValueDecoders(t *testing.T) { name string registry *Registry }{ - {"top level", topLevelReg}, - {"embedded", embeddedReg}, + {"top level", topLevelRb.Build()}, + {"embedded", embeddedRb.Build()}, } for _, tc := range testCases { var got interface{} @@ -3464,10 +3467,11 @@ func TestDefaultValueDecoders(t *testing.T) { // If a type map entry is registered for TypeEmbeddedDocument, the decoder should use ancestor // information if available instead of the registered entry. - reg := newTestRegistry() - registerDefaultEncoders(reg) - registerDefaultDecoders(reg) - reg.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(M{})) + rb := newTestRegistryBuilder() + registerDefaultEncoders(rb) + registerDefaultDecoders(rb) + rb.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(M{})) + reg := rb.Build() // build document {"nested": {"foo": 10}} inner := bsoncore.BuildDocument( @@ -3500,8 +3504,9 @@ func TestDefaultValueDecoders(t *testing.T) { emptyInterfaceErrorDecode := func(DecodeContext, ValueReader, reflect.Value) error { return decodeValueError } - emptyInterfaceErrorRegistry := newTestRegistry() - emptyInterfaceErrorRegistry.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)) + emptyInterfaceErrorRegistry := newTestRegistryBuilder(). + RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)). + Build() // Set up a document {foo: 10} and an error that would happen if the value were decoded into interface{} // using the registry defined above. @@ -3553,9 +3558,9 @@ func TestDefaultValueDecoders(t *testing.T) { outerDoc := buildDocument(bsoncore.AppendDocumentElement(nil, "first", inner1Doc)) // Use a registry that has all default decoders with the custom interface{} decoder that always errors. - nestedRegistry := newTestRegistry() - registerDefaultDecoders(nestedRegistry) - nestedRegistry.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)) + nestedRegistryBuilder := newTestRegistryBuilder() + registerDefaultDecoders(nestedRegistryBuilder) + nestedRegistryBuilder.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)) nestedErr := &DecodeError{ keys: []string{"fourth", "1", "third", "randomKey", "second", "first"}, wrapped: decodeValueError, @@ -3640,7 +3645,7 @@ func TestDefaultValueDecoders(t *testing.T) { "struct - no decoder found", stringStruct{}, NewValueReader(docBytes), - newTestRegistry(), + newTestRegistryBuilder().Build(), defaultTestStructCodec, stringStructErr, }, @@ -3648,7 +3653,7 @@ func TestDefaultValueDecoders(t *testing.T) { "deeply nested struct", outer{}, NewValueReader(outerDoc), - nestedRegistry, + nestedRegistryBuilder.Build(), defaultTestStructCodec, nestedErr, }, @@ -3705,11 +3710,11 @@ func TestDefaultValueDecoders(t *testing.T) { bsoncore.BuildArrayElement(nil, "boolArray", trueValue), ) - reg := newTestRegistry() - registerDefaultDecoders(reg) - reg.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))) + rb := newTestRegistryBuilder() + registerDefaultDecoders(rb) + rb.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))) - dc := DecodeContext{Registry: reg} + dc := DecodeContext{Registry: rb.Build()} vr := NewValueReader(docBytes) val := reflect.New(tD).Elem() err := dDecodeValue(dc, vr, val) @@ -3774,8 +3779,8 @@ func buildDocument(elems []byte) []byte { } func buildDefaultRegistry() *Registry { - reg := newTestRegistry() - registerDefaultEncoders(reg) - registerDefaultDecoders(reg) - return reg + rb := newTestRegistryBuilder() + registerDefaultEncoders(rb) + registerDefaultDecoders(rb) + return rb.Build() } diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index df80ef0080..ca6a4a9cad 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -28,7 +28,7 @@ var sliceWriterPool = sync.Pool{ }, } -func encodeElement(reg *Registry, dw DocumentWriter, e E) error { +func encodeElement(reg EncoderRegistry, dw DocumentWriter, e E) error { vw, err := dw.WriteDocumentElement(e.Key) if err != nil { return err @@ -50,59 +50,59 @@ func encodeElement(reg *Registry, dw DocumentWriter, e E) error { } // registerDefaultEncoders will register the default encoder methods with the provided Registry. -func registerDefaultEncoders(reg *Registry) { - if reg == nil { +func registerDefaultEncoders(rb *RegistryBuilder) { + if rb == nil { panic(errors.New("argument to RegisterDefaultEncoders must not be nil")) } - intEncoder := &intCodec{} - uintEncoder := &uintCodec{} - reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{}) - reg.RegisterTypeEncoder(tTime, &timeCodec{}) - reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{}) - reg.RegisterTypeEncoder(tCoreArray, &arrayCodec{}) - reg.RegisterTypeEncoder(tOID, ValueEncoderFunc(objectIDEncodeValue)) - reg.RegisterTypeEncoder(tDecimal, ValueEncoderFunc(decimal128EncodeValue)) - reg.RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(jsonNumberEncodeValue)) - reg.RegisterTypeEncoder(tURL, ValueEncoderFunc(urlEncodeValue)) - reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue)) - reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue)) - reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue)) - reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue)) - reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue)) - reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue)) - reg.RegisterTypeEncoder(tRegex, ValueEncoderFunc(regexEncodeValue)) - reg.RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dbPointerEncodeValue)) - reg.RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(timestampEncodeValue)) - reg.RegisterTypeEncoder(tMinKey, ValueEncoderFunc(minKeyEncodeValue)) - reg.RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(maxKeyEncodeValue)) - reg.RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(coreDocumentEncodeValue)) - reg.RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(codeWithScopeEncodeValue)) - reg.RegisterKindEncoder(reflect.Bool, ValueEncoderFunc(booleanEncodeValue)) - reg.RegisterKindEncoder(reflect.Int, intEncoder) - reg.RegisterKindEncoder(reflect.Int8, intEncoder) - reg.RegisterKindEncoder(reflect.Int16, intEncoder) - reg.RegisterKindEncoder(reflect.Int32, intEncoder) - reg.RegisterKindEncoder(reflect.Int64, intEncoder) - reg.RegisterKindEncoder(reflect.Uint, uintEncoder) - reg.RegisterKindEncoder(reflect.Uint8, uintEncoder) - reg.RegisterKindEncoder(reflect.Uint16, uintEncoder) - reg.RegisterKindEncoder(reflect.Uint32, uintEncoder) - reg.RegisterKindEncoder(reflect.Uint64, uintEncoder) - reg.RegisterKindEncoder(reflect.Float32, ValueEncoderFunc(floatEncodeValue)) - reg.RegisterKindEncoder(reflect.Float64, ValueEncoderFunc(floatEncodeValue)) - reg.RegisterKindEncoder(reflect.Array, ValueEncoderFunc(arrayEncodeValue)) - reg.RegisterKindEncoder(reflect.Map, &mapCodec{}) - reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{}) - reg.RegisterKindEncoder(reflect.String, &stringCodec{}) - reg.RegisterKindEncoder(reflect.Struct, newStructCodec(DefaultStructTagParser)) - reg.RegisterKindEncoder(reflect.Ptr, &pointerCodec{}) - reg.RegisterInterfaceEncoder(tValueMarshaler, ValueEncoderFunc(valueMarshalerEncodeValue)) - reg.RegisterInterfaceEncoder(tMarshaler, ValueEncoderFunc(marshalerEncodeValue)) - reg.RegisterInterfaceEncoder(tProxy, ValueEncoderFunc(proxyEncodeValue)) + intEncoder := func() ValueEncoder { return &intCodec{} } + floatEncoder := func() ValueEncoder { return ValueEncoderFunc(floatEncodeValue) } + rb.RegisterTypeEncoder(tByteSlice, func() ValueEncoder { return &byteSliceCodec{} }). + RegisterTypeEncoder(tTime, func() ValueEncoder { return &timeCodec{} }). + RegisterTypeEncoder(tEmpty, func() ValueEncoder { return &emptyInterfaceCodec{} }). + RegisterTypeEncoder(tCoreArray, func() ValueEncoder { return &arrayCodec{} }). + RegisterTypeEncoder(tOID, func() ValueEncoder { return ValueEncoderFunc(objectIDEncodeValue) }). + RegisterTypeEncoder(tDecimal, func() ValueEncoder { return ValueEncoderFunc(decimal128EncodeValue) }). + RegisterTypeEncoder(tJSONNumber, func() ValueEncoder { return ValueEncoderFunc(jsonNumberEncodeValue) }). + RegisterTypeEncoder(tURL, func() ValueEncoder { return ValueEncoderFunc(urlEncodeValue) }). + RegisterTypeEncoder(tJavaScript, func() ValueEncoder { return ValueEncoderFunc(javaScriptEncodeValue) }). + RegisterTypeEncoder(tSymbol, func() ValueEncoder { return ValueEncoderFunc(symbolEncodeValue) }). + RegisterTypeEncoder(tBinary, func() ValueEncoder { return ValueEncoderFunc(binaryEncodeValue) }). + RegisterTypeEncoder(tUndefined, func() ValueEncoder { return ValueEncoderFunc(undefinedEncodeValue) }). + RegisterTypeEncoder(tDateTime, func() ValueEncoder { return ValueEncoderFunc(dateTimeEncodeValue) }). + RegisterTypeEncoder(tNull, func() ValueEncoder { return ValueEncoderFunc(nullEncodeValue) }). + RegisterTypeEncoder(tRegex, func() ValueEncoder { return ValueEncoderFunc(regexEncodeValue) }). + RegisterTypeEncoder(tDBPointer, func() ValueEncoder { return ValueEncoderFunc(dbPointerEncodeValue) }). + RegisterTypeEncoder(tTimestamp, func() ValueEncoder { return ValueEncoderFunc(timestampEncodeValue) }). + RegisterTypeEncoder(tMinKey, func() ValueEncoder { return ValueEncoderFunc(minKeyEncodeValue) }). + RegisterTypeEncoder(tMaxKey, func() ValueEncoder { return ValueEncoderFunc(maxKeyEncodeValue) }). + RegisterTypeEncoder(tCoreDocument, func() ValueEncoder { return ValueEncoderFunc(coreDocumentEncodeValue) }). + RegisterTypeEncoder(tCodeWithScope, func() ValueEncoder { return ValueEncoderFunc(codeWithScopeEncodeValue) }). + RegisterKindEncoder(reflect.Bool, func() ValueEncoder { return ValueEncoderFunc(booleanEncodeValue) }). + RegisterKindEncoder(reflect.Int, intEncoder). + RegisterKindEncoder(reflect.Int8, intEncoder). + RegisterKindEncoder(reflect.Int16, intEncoder). + RegisterKindEncoder(reflect.Int32, intEncoder). + RegisterKindEncoder(reflect.Int64, intEncoder). + RegisterKindEncoder(reflect.Uint, intEncoder). + RegisterKindEncoder(reflect.Uint8, intEncoder). + RegisterKindEncoder(reflect.Uint16, intEncoder). + RegisterKindEncoder(reflect.Uint32, intEncoder). + RegisterKindEncoder(reflect.Uint64, intEncoder). + RegisterKindEncoder(reflect.Float32, floatEncoder). + RegisterKindEncoder(reflect.Float64, floatEncoder). + RegisterKindEncoder(reflect.Array, func() ValueEncoder { return ValueEncoderFunc(arrayEncodeValue) }). + RegisterKindEncoder(reflect.Map, func() ValueEncoder { return &mapCodec{} }). + RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return &sliceCodec{} }). + RegisterKindEncoder(reflect.String, func() ValueEncoder { return &stringCodec{} }). + RegisterKindEncoder(reflect.Struct, func() ValueEncoder { return newStructCodec(DefaultStructTagParser) }). + RegisterKindEncoder(reflect.Ptr, func() ValueEncoder { return &pointerCodec{} }). + RegisterInterfaceEncoder(tValueMarshaler, func() ValueEncoder { return ValueEncoderFunc(valueMarshalerEncodeValue) }). + RegisterInterfaceEncoder(tMarshaler, func() ValueEncoder { return ValueEncoderFunc(marshalerEncodeValue) }). + RegisterInterfaceEncoder(tProxy, func() ValueEncoder { return ValueEncoderFunc(proxyEncodeValue) }) } // booleanEncodeValue is the ValueEncoderFunc for bool types. -func booleanEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func booleanEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Bool { return ValueEncoderError{Name: "BooleanEncodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val} } @@ -114,7 +114,7 @@ func fitsIn32Bits(i int64) bool { } // floatEncodeValue is the ValueEncoderFunc for float types. -func floatEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func floatEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Float32, reflect.Float64: return vw.WriteDouble(val.Float()) @@ -124,7 +124,7 @@ func floatEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // objectIDEncodeValue is the ValueEncoderFunc for ObjectID. -func objectIDEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func objectIDEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tOID { return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val} } @@ -132,7 +132,7 @@ func objectIDEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // decimal128EncodeValue is the ValueEncoderFunc for Decimal128. -func decimal128EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func decimal128EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDecimal { return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val} } @@ -140,7 +140,7 @@ func decimal128EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error } // jsonNumberEncodeValue is the ValueEncoderFunc for json.Number. -func jsonNumberEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func jsonNumberEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tJSONNumber { return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} } @@ -164,7 +164,7 @@ func jsonNumberEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) err } // urlEncodeValue is the ValueEncoderFunc for url.URL. -func urlEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func urlEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tURL { return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val} } @@ -173,7 +173,7 @@ func urlEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // arrayEncodeValue is the ValueEncoderFunc for array types. -func arrayEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func arrayEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Array { return ValueEncoderError{Name: "ArrayEncodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val} } @@ -243,7 +243,7 @@ func arrayEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { return aw.WriteArrayEnd() } -func lookupElementEncoder(reg *Registry, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) { +func lookupElementEncoder(reg EncoderRegistry, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) { if origEncoder != nil || (currVal.Kind() != reflect.Interface) { return origEncoder, currVal, nil } @@ -257,7 +257,7 @@ func lookupElementEncoder(reg *Registry, origEncoder ValueEncoder, currVal refle } // valueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations. -func valueMarshalerEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func valueMarshalerEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement ValueMarshaler switch { case !val.IsValid(): @@ -285,7 +285,7 @@ func valueMarshalerEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) e } // marshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations. -func marshalerEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func marshalerEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Marshaler switch { case !val.IsValid(): @@ -313,7 +313,7 @@ func marshalerEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error } // proxyEncodeValue is the ValueEncoderFunc for Proxy implementations. -func proxyEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func proxyEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Proxy switch { case !val.IsValid(): @@ -357,7 +357,7 @@ func proxyEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { } // javaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type. -func javaScriptEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func javaScriptEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tJavaScript { return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val} } @@ -366,7 +366,7 @@ func javaScriptEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error } // symbolEncodeValue is the ValueEncoderFunc for the Symbol type. -func symbolEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func symbolEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tSymbol { return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val} } @@ -375,7 +375,7 @@ func symbolEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // binaryEncodeValue is the ValueEncoderFunc for Binary. -func binaryEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func binaryEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tBinary { return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val} } @@ -385,7 +385,7 @@ func binaryEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // undefinedEncodeValue is the ValueEncoderFunc for Undefined. -func undefinedEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func undefinedEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tUndefined { return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val} } @@ -394,7 +394,7 @@ func undefinedEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error } // dateTimeEncodeValue is the ValueEncoderFunc for DateTime. -func dateTimeEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func dateTimeEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDateTime { return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val} } @@ -403,7 +403,7 @@ func dateTimeEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // nullEncodeValue is the ValueEncoderFunc for Null. -func nullEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func nullEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tNull { return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val} } @@ -412,7 +412,7 @@ func nullEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // regexEncodeValue is the ValueEncoderFunc for Regex. -func regexEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func regexEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRegex { return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val} } @@ -423,7 +423,7 @@ func regexEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // dbPointerEncodeValue is the ValueEncoderFunc for DBPointer. -func dbPointerEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func dbPointerEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDBPointer { return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val} } @@ -434,7 +434,7 @@ func dbPointerEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error } // timestampEncodeValue is the ValueEncoderFunc for Timestamp. -func timestampEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func timestampEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tTimestamp { return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val} } @@ -445,7 +445,7 @@ func timestampEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error } // minKeyEncodeValue is the ValueEncoderFunc for MinKey. -func minKeyEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func minKeyEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tMinKey { return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val} } @@ -454,7 +454,7 @@ func minKeyEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // maxKeyEncodeValue is the ValueEncoderFunc for MaxKey. -func maxKeyEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func maxKeyEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tMaxKey { return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val} } @@ -463,7 +463,7 @@ func maxKeyEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. -func coreDocumentEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func coreDocumentEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCoreDocument { return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} } @@ -474,7 +474,7 @@ func coreDocumentEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) err } // codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. -func codeWithScopeEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func codeWithScopeEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCodeWithScope { return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} } diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index 0cc12ce597..cd8efe72db 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -10,6 +10,7 @@ import ( "encoding/json" "errors" "fmt" + "math" "net/url" "reflect" "strings" @@ -37,11 +38,11 @@ func TestDefaultValueEncoders(t *testing.T) { var wrong = func(string, string) string { return "wrong" } type mybool bool - // type myint8 int8 - // type myint16 int16 - // type myint32 int32 - // type myint64 int64 - // type myint int + type myint8 int8 + type myint16 int16 + type myint32 int32 + type myint64 int64 + type myint int type myuint8 uint8 type myuint16 uint16 type myuint32 uint32 @@ -92,51 +93,52 @@ func TestDefaultValueEncoders(t *testing.T) { {"reflection path", mybool(true), nil, nil, writeBoolean, nil}, }, }, - /* - { - "IntEncodeValue", - ValueEncoderFunc(intEncodeValue), - []subtest{ - { - "wrong type", - wrong, - nil, - nil, - nothing, - ValueEncoderError{ - Name: "IntEncodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, - Received: reflect.ValueOf(wrong), + { + "IntEncodeValue", + &intCodec{}, + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + nothing, + ValueEncoderError{ + Name: "IntEncodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, + Received: reflect.ValueOf(wrong), }, - {"int8/fast path", int8(127), nil, nil, writeInt32, nil}, - {"int16/fast path", int16(32767), nil, nil, writeInt32, nil}, - {"int32/fast path", int32(2147483647), nil, nil, writeInt32, nil}, - {"int64/fast path", int64(1234567890987), nil, nil, writeInt64, nil}, - {"int64/fast path - minsize", int64(math.MaxInt32), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - {"int64/fast path - minsize too large", int64(math.MaxInt32 + 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - {"int64/fast path - minsize too small", int64(math.MinInt32 - 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - {"int/fast path - positive int32", int(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, - {"int/fast path - negative int32", int(math.MinInt32 + 1), nil, nil, writeInt32, nil}, - {"int/fast path - MaxInt32", int(math.MaxInt32), nil, nil, writeInt32, nil}, - {"int/fast path - MinInt32", int(math.MinInt32), nil, nil, writeInt32, nil}, - {"int8/reflection path", myint8(127), nil, nil, writeInt32, nil}, - {"int16/reflection path", myint16(32767), nil, nil, writeInt32, nil}, - {"int32/reflection path", myint32(2147483647), nil, nil, writeInt32, nil}, - {"int64/reflection path", myint64(1234567890987), nil, nil, writeInt64, nil}, - {"int64/reflection path - minsize", myint64(math.MaxInt32), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - {"int64/reflection path - minsize too large", myint64(math.MaxInt32 + 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - {"int64/reflection path - minsize too small", myint64(math.MinInt32 - 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - {"int/reflection path - positive int32", myint(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, - {"int/reflection path - negative int32", myint(math.MinInt32 + 1), nil, nil, writeInt32, nil}, - {"int/reflection path - MaxInt32", myint(math.MaxInt32), nil, nil, writeInt32, nil}, - {"int/reflection path - MinInt32", myint(math.MinInt32), nil, nil, writeInt32, nil}, }, + {"int8/fast path", int8(127), nil, nil, writeInt32, nil}, + {"int16/fast path", int16(32767), nil, nil, writeInt32, nil}, + {"int32/fast path", int32(2147483647), nil, nil, writeInt32, nil}, + {"int64/fast path", int64(1234567890987), nil, nil, writeInt64, nil}, + // {"int64/fast path - minsize", int64(math.MaxInt32), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + // {"int64/fast path - minsize too large", int64(math.MaxInt32 + 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + // {"int64/fast path - minsize too small", int64(math.MinInt32 - 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + {"int/fast path - positive int32", int(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, + {"int/fast path - negative int32", int(math.MinInt32 + 1), nil, nil, writeInt32, nil}, + {"int/fast path - MaxInt32", int(math.MaxInt32), nil, nil, writeInt32, nil}, + {"int/fast path - MinInt32", int(math.MinInt32), nil, nil, writeInt32, nil}, + {"int8/reflection path", myint8(127), nil, nil, writeInt32, nil}, + {"int16/reflection path", myint16(32767), nil, nil, writeInt32, nil}, + {"int32/reflection path", myint32(2147483647), nil, nil, writeInt32, nil}, + {"int64/reflection path", myint64(1234567890987), nil, nil, writeInt64, nil}, + // {"int64/reflection path - minsize", myint64(math.MaxInt32), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + // {"int64/reflection path - minsize too large", myint64(math.MaxInt32 + 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + // {"int64/reflection path - minsize too small", myint64(math.MinInt32 - 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + {"int/reflection path - positive int32", myint(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, + {"int/reflection path - negative int32", myint(math.MinInt32 + 1), nil, nil, writeInt32, nil}, + {"int/reflection path - MaxInt32", myint(math.MaxInt32), nil, nil, writeInt32, nil}, + {"int/reflection path - MinInt32", myint(math.MinInt32), nil, nil, writeInt32, nil}, }, - */ + }, { "UintEncodeValue", - &uintCodec{}, + &intCodec{}, []subtest{ { "wrong type", @@ -145,8 +147,11 @@ func TestDefaultValueEncoders(t *testing.T) { nil, nothing, ValueEncoderError{ - Name: "UintEncodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "IntEncodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf(wrong), }, }, @@ -235,7 +240,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", map[string]int{"foo": 1}, - newTestRegistry(), + newTestRegistryBuilder().Build(), &valueReaderWriter{}, writeDocument, fmt.Errorf("no encoder found for int"), @@ -259,7 +264,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "empty map/success", map[string]interface{}{}, - newTestRegistry(), + newTestRegistryBuilder().Build(), &valueReaderWriter{}, writeDocumentEnd, nil, @@ -315,7 +320,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", [1]int{1}, - newTestRegistry(), + newTestRegistryBuilder().Build(), &valueReaderWriter{}, writeArray, fmt.Errorf("no encoder found for int"), @@ -393,7 +398,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", []int{1}, - newTestRegistry(), + newTestRegistryBuilder().Build(), &valueReaderWriter{}, writeArray, fmt.Errorf("no encoder found for int"), @@ -433,7 +438,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "empty slice/success", []interface{}{}, - newTestRegistry(), + newTestRegistryBuilder().Build(), &valueReaderWriter{}, writeArrayEnd, nil, @@ -510,7 +515,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "json.Number/int64/success", json.Number("1234567890"), - nil, nil, writeInt64, nil, + buildDefaultRegistry(), nil, writeInt64, nil, }, { "json.Number/float64/success", diff --git a/bson/empty_interface_codec.go b/bson/empty_interface_codec.go index 0a68c77a40..cea7dfd348 100644 --- a/bson/empty_interface_codec.go +++ b/bson/empty_interface_codec.go @@ -22,7 +22,7 @@ var ( ) // EncodeValue is the ValueEncoderFunc for interface{}. -func (eic emptyInterfaceCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func (eic emptyInterfaceCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tEmpty { return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} } diff --git a/bson/encoder.go b/bson/encoder.go index 33cb16cc13..42c1900006 100644 --- a/bson/encoder.go +++ b/bson/encoder.go @@ -70,9 +70,10 @@ func (e *Encoder) SetRegistry(r *Registry) { // ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in // the marshaled BSON when the "inline" struct tag option is set. func (e *Encoder) ErrorOnInlineDuplicates() { - if v, ok := e.reg.kindEncoders.Load(reflect.Struct); ok { - if enc, ok := v.(*structCodec); ok { - enc.overwriteDuplicatedInlinedFields = false + t := reflect.TypeOf((*structCodec)(nil)) + if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*structCodec).overwriteDuplicatedInlinedFields = false } } } @@ -81,14 +82,20 @@ func (e *Encoder) ErrorOnInlineDuplicates() { // uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) that can // represent the integer value. func (e *Encoder) IntMinSize() { - if v, ok := e.reg.kindEncoders.Load(reflect.Int); ok { - if enc, ok := v.(*intCodec); ok { - enc.encodeToMinSize = true - } - } - if v, ok := e.reg.kindEncoders.Load(reflect.Uint); ok { - if enc, ok := v.(*uintCodec); ok { - enc.encodeToMinSize = true + // if v, ok := e.reg.kindEncoders.Load(reflect.Int); ok { + // if enc, ok := v.(*intCodec); ok { + // enc.encodeToMinSize = true + // } + // } + // if v, ok := e.reg.kindEncoders.Load(reflect.Uint); ok { + // if enc, ok := v.(*uintCodec); ok { + // enc.encodeToMinSize = true + // } + // } + t := reflect.TypeOf((*intCodec)(nil)) + if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*intCodec).encodeToMinSize = true } } } @@ -96,9 +103,10 @@ func (e *Encoder) IntMinSize() { // StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name // strings using fmt.Sprint instead of the default string conversion logic. func (e *Encoder) StringifyMapKeysWithFmt() { - if v, ok := e.reg.kindEncoders.Load(reflect.Map); ok { - if enc, ok := v.(*mapCodec); ok { - enc.encodeKeysWithStringer = true + t := reflect.TypeOf((*mapCodec)(nil)) + if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*mapCodec).encodeKeysWithStringer = true } } } @@ -106,9 +114,10 @@ func (e *Encoder) StringifyMapKeysWithFmt() { // NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON // null. func (e *Encoder) NilMapAsEmpty() { - if v, ok := e.reg.kindEncoders.Load(reflect.Map); ok { - if enc, ok := v.(*mapCodec); ok { - enc.encodeNilAsEmpty = true + t := reflect.TypeOf((*mapCodec)(nil)) + if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*mapCodec).encodeNilAsEmpty = true } } } @@ -116,9 +125,10 @@ func (e *Encoder) NilMapAsEmpty() { // NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON // null. func (e *Encoder) NilSliceAsEmpty() { - if v, ok := e.reg.kindEncoders.Load(reflect.Slice); ok { - if enc, ok := v.(*sliceCodec); ok { - enc.encodeNilAsEmpty = true + t := reflect.TypeOf((*sliceCodec)(nil)) + if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*sliceCodec).encodeNilAsEmpty = true } } } @@ -126,9 +136,15 @@ func (e *Encoder) NilSliceAsEmpty() { // NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values // instead of BSON null. func (e *Encoder) NilByteSliceAsEmpty() { - if v, ok := e.reg.typeEncoders.Load(tByteSlice); ok { - if enc, ok := v.(*byteSliceCodec); ok { - enc.encodeNilAsEmpty = true + // if v, ok := e.reg.typeEncoders.Load(tByteSlice); ok { + // if enc, ok := v.(*byteSliceCodec); ok { + // enc.encodeNilAsEmpty = true + // } + // } + t := reflect.TypeOf((*byteSliceCodec)(nil)) + if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*byteSliceCodec).encodeNilAsEmpty = true } } } @@ -142,9 +158,10 @@ func (e *Encoder) NilByteSliceAsEmpty() { // Note that the Encoder only examines exported struct fields when determining if a struct is the // zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty. func (e *Encoder) OmitZeroStruct() { - if v, ok := e.reg.kindEncoders.Load(reflect.Struct); ok { - if enc, ok := v.(*structCodec); ok { - enc.encodeOmitDefaultStruct = true + t := reflect.TypeOf((*structCodec)(nil)) + if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*structCodec).encodeOmitDefaultStruct = true } } } @@ -152,9 +169,10 @@ func (e *Encoder) OmitZeroStruct() { // UseJSONStructTags causes the Encoder to fall back to using the "json" struct tag if a "bson" // struct tag is not specified. func (e *Encoder) UseJSONStructTags() { - if v, ok := e.reg.kindEncoders.Load(reflect.Struct); ok { - if enc, ok := v.(*structCodec); ok { - enc.useJSONStructTags = true + t := reflect.TypeOf((*structCodec)(nil)) + if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*structCodec).useJSONStructTags = true } } } diff --git a/bson/int_codec.go b/bson/int_codec.go index d0791ad70b..4d82092309 100644 --- a/bson/int_codec.go +++ b/bson/int_codec.go @@ -7,6 +7,8 @@ package bson import ( + "fmt" + "math" "reflect" ) @@ -15,10 +17,16 @@ type intCodec struct { // encodeToMinSize causes EncodeValue to marshal Go uint values (excluding uint64) as the // minimum BSON int size (either 32-bit or 64-bit) that can represent the integer value. encodeToMinSize bool + + // truncate, if true, instructs decoders to to truncate the fractional part of BSON "double" + // values when attempting to unmarshal them into a Go integer (int, int8, int16, int32, int64, + // uint, uint8, uint16, uint32, or uint64) struct field. The truncation logic does not apply to + // BSON "decimal128" values. + truncate bool } // EncodeValue is the ValueEncoder for uint types. -func (ic *intCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func (ic *intCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Int8, reflect.Int16, reflect.Int32: return vw.WriteInt32(int32(val.Int())) @@ -34,11 +42,153 @@ func (ic *intCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) return vw.WriteInt32(int32(i64)) } return vw.WriteInt64(i64) + + case reflect.Uint8, reflect.Uint16: + return vw.WriteInt32(int32(val.Uint())) + case reflect.Uint, reflect.Uint32, reflect.Uint64: + u64 := val.Uint() + + // If encodeToMinSize is true for a non-uint64 value we should write val as an int32 + useMinSize := ic.encodeToMinSize && val.Kind() != reflect.Uint64 + + if u64 <= math.MaxInt32 && useMinSize { + return vw.WriteInt32(int32(u64)) + } + if u64 > math.MaxInt64 { + return fmt.Errorf("%d overflows int64", u64) + } + return vw.WriteInt64(int64(u64)) } return ValueEncoderError{ - Name: "IntEncodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Name: "IntEncodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: val, } } + +// DecodeValue is the ValueDecoder for uint types. +func (ic *intCodec) DecodeValue(_ *Registry, vr ValueReader, val reflect.Value) error { + if !val.CanSet() { + return ValueDecoderError{ + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, + Received: val, + } + } + + var i64 int64 + switch vrType := vr.Type(); vrType { + case TypeInt32: + i32, err := vr.ReadInt32() + if err != nil { + return err + } + i64 = int64(i32) + case TypeInt64: + var err error + i64, err = vr.ReadInt64() + if err != nil { + return err + } + case TypeDouble: + f64, err := vr.ReadDouble() + if err != nil { + return err + } + if !ic.truncate && math.Floor(f64) != f64 { + return errCannotTruncate + } + if f64 > float64(math.MaxInt64) { + return fmt.Errorf("%g overflows int64", f64) + } + i64 = int64(f64) + case TypeBoolean: + b, err := vr.ReadBoolean() + if err != nil { + return err + } + if b { + i64 = 1 + } + case TypeNull: + if err := vr.ReadNull(); err != nil { + return err + } + case TypeUndefined: + if err := vr.ReadUndefined(); err != nil { + return err + } + default: + return fmt.Errorf("cannot decode %v into an integer type", vrType) + } + + switch t := val.Type(); t.Kind() { + case reflect.Int8: + if i64 < math.MinInt8 || i64 > math.MaxInt8 { + return fmt.Errorf("%d overflows int8", i64) + } + val.SetInt(i64) + case reflect.Int16: + if i64 < math.MinInt16 || i64 > math.MaxInt16 { + return fmt.Errorf("%d overflows int16", i64) + } + val.SetInt(i64) + case reflect.Int32: + if i64 < math.MinInt32 || i64 > math.MaxInt32 { + return fmt.Errorf("%d overflows int32", i64) + } + val.SetInt(i64) + case reflect.Int64: + val.SetInt(i64) + case reflect.Int: + if int64(int(i64)) != i64 { // Can we fit this inside of an int + return fmt.Errorf("%d overflows int", i64) + } + val.SetInt(i64) + + case reflect.Uint8: + if i64 < 0 || i64 > math.MaxUint8 { + return fmt.Errorf("%d overflows uint8", i64) + } + val.SetUint(uint64(i64)) + case reflect.Uint16: + if i64 < 0 || i64 > math.MaxUint16 { + return fmt.Errorf("%d overflows uint16", i64) + } + val.SetUint(uint64(i64)) + case reflect.Uint32: + if i64 < 0 || i64 > math.MaxUint32 { + return fmt.Errorf("%d overflows uint32", i64) + } + val.SetUint(uint64(i64)) + case reflect.Uint64: + if i64 < 0 { + return fmt.Errorf("%d overflows uint64", i64) + } + val.SetUint(uint64(i64)) + case reflect.Uint: + if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint + return fmt.Errorf("%d overflows uint", i64) + } + val.SetUint(uint64(i64)) + + default: + return ValueDecoderError{ + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, + Received: reflect.Zero(t), + } + } + + return nil +} diff --git a/bson/map_codec.go b/bson/map_codec.go index ce2ac9ff9e..bfa77ca0d8 100644 --- a/bson/map_codec.go +++ b/bson/map_codec.go @@ -50,7 +50,7 @@ type KeyUnmarshaler interface { } // EncodeValue is the ValueEncoder for map[*]* types. -func (mc *mapCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func (mc *mapCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Map { return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} } @@ -78,7 +78,7 @@ func (mc *mapCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value // mapEncodeValue handles encoding of the values of a map. The collisionFn returns // true if the provided key exists, this is mainly used for inline maps in the // struct codec. -func (mc *mapCodec) mapEncodeValue(reg *Registry, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { +func (mc *mapCodec) mapEncodeValue(reg EncoderRegistry, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { elemType := val.Type().Elem() encoder, err := reg.LookupEncoder(elemType) diff --git a/bson/marshal_test.go b/bson/marshal_test.go index 6013d7b911..5eeada562c 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -149,15 +149,16 @@ func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) { // different Registry is used. // Create a custom Registry that negates int32 values when encoding. - var encodeInt32 ValueEncoderFunc = func(_ *Registry, vw ValueWriter, val reflect.Value) error { + var encodeInt32 ValueEncoderFunc = func(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.Int32 { return fmt.Errorf("expected kind to be int32, got %v", val.Kind()) } return vw.WriteInt32(int32(val.Int()) * -1) } - customReg := NewRegistry() - customReg.RegisterTypeEncoder(tInt32, encodeInt32) + customReg := NewRegistryBuilder(). + RegisterTypeEncoder(tInt32, func() ValueEncoder { return encodeInt32 }). + Build() // Helper function to run the test and make assertions. The provided original value should result in the document // {"x": {$numberInt: 1}} when marshalled with the default registry. diff --git a/bson/mgoregistry.go b/bson/mgoregistry.go index 398de2afb9..f0b77f4efb 100644 --- a/bson/mgoregistry.go +++ b/bson/mgoregistry.go @@ -22,10 +22,7 @@ var ( tSetter = reflect.TypeOf((*Setter)(nil)).Elem() ) -// NewMgoRegistry creates a new bson.Registry configured with the default encoders and decoders. -func NewMgoRegistry() *Registry { - reg := NewRegistry() - +func newMgoRegistryBuilder() *RegistryBuilder { structcodec := &structCodec{ parser: DefaultStructTagParser, decodeZeroStruct: true, @@ -37,45 +34,47 @@ func NewMgoRegistry() *Registry { encodeNilAsEmpty: true, encodeKeysWithStringer: true, } - uintcodec := &uintCodec{encodeToMinSize: true} + intcodec := func() ValueEncoder { return &intCodec{encodeToMinSize: true} } - reg.RegisterTypeDecoder(tEmpty, &emptyInterfaceCodec{decodeBinaryAsSlice: true}) - reg.RegisterKindDecoder(reflect.String, &stringCodec{}) - reg.RegisterKindDecoder(reflect.Struct, structcodec) - reg.RegisterKindDecoder(reflect.Map, mapCodec) - reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: true}) - reg.RegisterKindEncoder(reflect.Struct, structcodec) - reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{encodeNilAsEmpty: true}) - reg.RegisterKindEncoder(reflect.Map, mapCodec) - reg.RegisterKindEncoder(reflect.Uint, uintcodec) - reg.RegisterKindEncoder(reflect.Uint8, uintcodec) - reg.RegisterKindEncoder(reflect.Uint16, uintcodec) - reg.RegisterKindEncoder(reflect.Uint32, uintcodec) - reg.RegisterKindEncoder(reflect.Uint64, uintcodec) - reg.RegisterTypeMapEntry(TypeInt32, tInt) - reg.RegisterTypeMapEntry(TypeDateTime, tTime) - reg.RegisterTypeMapEntry(TypeArray, tInterfaceSlice) - reg.RegisterTypeMapEntry(Type(0), tM) - reg.RegisterTypeMapEntry(TypeEmbeddedDocument, tM) - reg.RegisterInterfaceEncoder(tGetter, ValueEncoderFunc(GetterEncodeValue)) - reg.RegisterInterfaceDecoder(tSetter, ValueDecoderFunc(SetterDecodeValue)) + return NewRegistryBuilder(). + RegisterTypeDecoder(tEmpty, &emptyInterfaceCodec{decodeBinaryAsSlice: true}). + RegisterKindDecoder(reflect.String, &stringCodec{}). + RegisterKindDecoder(reflect.Struct, structcodec). + RegisterKindDecoder(reflect.Map, mapCodec). + RegisterTypeEncoder(tByteSlice, func() ValueEncoder { return &byteSliceCodec{encodeNilAsEmpty: true} }). + RegisterKindEncoder(reflect.Struct, func() ValueEncoder { return structcodec }). + RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return &sliceCodec{encodeNilAsEmpty: true} }). + RegisterKindEncoder(reflect.Map, func() ValueEncoder { return mapCodec }). + RegisterKindEncoder(reflect.Uint, intcodec). + RegisterKindEncoder(reflect.Uint8, intcodec). + RegisterKindEncoder(reflect.Uint16, intcodec). + RegisterKindEncoder(reflect.Uint32, intcodec). + RegisterKindEncoder(reflect.Uint64, intcodec). + RegisterTypeMapEntry(TypeInt32, tInt). + RegisterTypeMapEntry(TypeDateTime, tTime). + RegisterTypeMapEntry(TypeArray, tInterfaceSlice). + RegisterTypeMapEntry(Type(0), tM). + RegisterTypeMapEntry(TypeEmbeddedDocument, tM). + RegisterInterfaceEncoder(tGetter, func() ValueEncoder { return ValueEncoderFunc(GetterEncodeValue) }). + RegisterInterfaceDecoder(tSetter, ValueDecoderFunc(SetterDecodeValue)) +} - return reg +// NewMgoRegistry creates a new bson.Registry configured with the default encoders and decoders. +func NewMgoRegistry() *Registry { + return newMgoRegistryBuilder().Build() } // NewRespectNilValuesMgoRegistry creates a new bson.Registry configured to behave like mgo/bson // with RespectNilValues set to true. func NewRespectNilValuesMgoRegistry() *Registry { - reg := NewMgoRegistry() - mapCodec := &mapCodec{ decodeZerosMap: true, } - reg.RegisterKindDecoder(reflect.Map, mapCodec) - reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: false}) - reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{}) - reg.RegisterKindEncoder(reflect.Map, mapCodec) - - return reg + return newMgoRegistryBuilder(). + RegisterKindDecoder(reflect.Map, mapCodec). + RegisterTypeEncoder(tByteSlice, func() ValueEncoder { return &byteSliceCodec{encodeNilAsEmpty: false} }). + RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return &sliceCodec{} }). + RegisterKindEncoder(reflect.Map, func() ValueEncoder { return mapCodec }). + Build() } diff --git a/bson/pointer_codec.go b/bson/pointer_codec.go index 425d371d0e..af35da68b2 100644 --- a/bson/pointer_codec.go +++ b/bson/pointer_codec.go @@ -18,7 +18,7 @@ type pointerCodec struct { // EncodeValue handles encoding a pointer by either encoding it to BSON Null if the pointer is nil // or looking up an encoder for the type of value the pointer points to. -func (pc *pointerCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func (pc *pointerCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.Ptr { if !val.IsValid() { return vw.WriteNull() diff --git a/bson/primitive_codecs.go b/bson/primitive_codecs.go index adbb28d601..082cd15357 100644 --- a/bson/primitive_codecs.go +++ b/bson/primitive_codecs.go @@ -16,22 +16,22 @@ var tRawValue = reflect.TypeOf(RawValue{}) var tRaw = reflect.TypeOf(Raw(nil)) // registerPrimitiveCodecs will register the encode and decode methods with the provided Registry. -func registerPrimitiveCodecs(reg *Registry) { - if reg == nil { +func registerPrimitiveCodecs(rb *RegistryBuilder) { + if rb == nil { panic(errors.New("argument to RegisterPrimitiveCodecs must not be nil")) } - reg.RegisterTypeEncoder(tRawValue, ValueEncoderFunc(rawValueEncodeValue)) - reg.RegisterTypeEncoder(tRaw, ValueEncoderFunc(rawEncodeValue)) - reg.RegisterTypeDecoder(tRawValue, ValueDecoderFunc(rawValueDecodeValue)) - reg.RegisterTypeDecoder(tRaw, ValueDecoderFunc(rawDecodeValue)) + rb.RegisterTypeEncoder(tRawValue, func() ValueEncoder { return ValueEncoderFunc(rawValueEncodeValue) }). + RegisterTypeEncoder(tRaw, func() ValueEncoder { return ValueEncoderFunc(rawEncodeValue) }). + RegisterTypeDecoder(tRawValue, ValueDecoderFunc(rawValueDecodeValue)). + RegisterTypeDecoder(tRaw, ValueDecoderFunc(rawDecodeValue)) } // rawValueEncodeValue is the ValueEncoderFunc for RawValue. // // If the RawValue's Type is "invalid" and the RawValue's Value is not empty or // nil, then this method will return an error. -func rawValueEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func rawValueEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRawValue { return ValueEncoderError{ Name: "RawValueEncodeValue", @@ -65,7 +65,7 @@ func rawValueDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) err } // rawEncodeValue is the ValueEncoderFunc for Reader. -func rawEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func rawEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRaw { return ValueEncoderError{Name: "RawEncodeValue", Types: []reflect.Type{tRaw}, Received: val} } diff --git a/bson/raw_value_test.go b/bson/raw_value_test.go index 67444faa61..18598ebe8f 100644 --- a/bson/raw_value_test.go +++ b/bson/raw_value_test.go @@ -25,7 +25,7 @@ func TestRawValue(t *testing.T) { t.Run("Uses registry attached to value", func(t *testing.T) { t.Parallel() - reg := newTestRegistry() + reg := newTestRegistryBuilder().Build() val := RawValue{Type: TypeString, Value: bsoncore.AppendString(nil, "foobar"), r: reg} var s string want := ErrNoDecoder{Type: reflect.TypeOf(s)} @@ -63,7 +63,7 @@ func TestRawValue(t *testing.T) { t.Run("Returns lookup error", func(t *testing.T) { t.Parallel() - reg := newTestRegistry() + reg := newTestRegistryBuilder().Build() var val RawValue var s string want := ErrNoDecoder{Type: reflect.TypeOf(s)} @@ -75,7 +75,7 @@ func TestRawValue(t *testing.T) { t.Run("Returns DecodeValue error", func(t *testing.T) { t.Parallel() - reg := NewRegistry() + reg := NewRegistryBuilder().Build() val := RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, 3.14159)} var s string want := fmt.Errorf("cannot decode %v into a string type", TypeDouble) @@ -87,7 +87,7 @@ func TestRawValue(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - reg := NewRegistry() + reg := NewRegistryBuilder().Build() want := float64(3.14159) val := RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, want)} var got float64 @@ -114,7 +114,7 @@ func TestRawValue(t *testing.T) { t.Run("Returns lookup error", func(t *testing.T) { t.Parallel() - dc := DecodeContext{Registry: newTestRegistry()} + dc := DecodeContext{Registry: newTestRegistryBuilder().Build()} var val RawValue var s string want := ErrNoDecoder{Type: reflect.TypeOf(s)} @@ -126,7 +126,7 @@ func TestRawValue(t *testing.T) { t.Run("Returns DecodeValue error", func(t *testing.T) { t.Parallel() - dc := DecodeContext{Registry: NewRegistry()} + dc := DecodeContext{Registry: NewRegistryBuilder().Build()} val := RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, 3.14159)} var s string want := fmt.Errorf("cannot decode %v into a string type", TypeDouble) @@ -138,7 +138,7 @@ func TestRawValue(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - dc := DecodeContext{Registry: NewRegistry()} + dc := DecodeContext{Registry: NewRegistryBuilder().Build()} want := float64(3.14159) val := RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, want)} var got float64 diff --git a/bson/registry.go b/bson/registry.go index 71d65259d6..179b61de91 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -15,7 +15,7 @@ import ( // DefaultRegistry is the default Registry. It contains the default codecs and the // primitive codecs. -var DefaultRegistry = NewRegistry() +var DefaultRegistry = NewRegistryBuilder().Build() // ErrNilType is returned when nil is passed to either LookupEncoder or LookupDecoder. // @@ -58,75 +58,54 @@ func (entme ErrNoTypeMapEntry) Error() string { return "no type map entry found for " + entme.Type.String() } -// A Registry is a store for ValueEncoders, ValueDecoders, and a type map. See the Registry type -// documentation for examples of registering various custom encoders and decoders. A Registry can -// have four main types of codecs: -// -// 1. Type encoders/decoders - These can be registered using the RegisterTypeEncoder and -// RegisterTypeDecoder methods. The registered codec will be invoked when encoding/decoding a value -// whose type matches the registered type exactly. -// If the registered type is an interface, the codec will be invoked when encoding or decoding -// values whose type is the interface, but not for values with concrete types that implement the -// interface. -// -// 2. Interface encoders/decoders - These can be registered using the RegisterInterfaceEncoder and -// RegisterInterfaceDecoder methods. These methods only accept interface types and the registered codecs -// will be invoked when encoding or decoding values whose types implement the interface. An example -// of an interface defined by the driver is bson.Marshaler. The driver will call the MarshalBSON method -// for any value whose type implements bson.Marshaler, regardless of the value's concrete type. -// -// 3. Type map entries - This can be used to associate a BSON type with a Go type. These type -// associations are used when decoding into a bson.D/bson.M or a struct field of type interface{}. -// For example, by default, BSON int32 and int64 values decode as Go int32 and int64 instances, -// respectively, when decoding into a bson.D. The following code would change the behavior so these -// values decode as Go int instances instead: -// -// intType := reflect.TypeOf(int(0)) -// registry.RegisterTypeMapEntry(bson.TypeInt32, intType).RegisterTypeMapEntry(bson.TypeInt64, intType) -// -// 4. Kind encoder/decoders - These can be registered using the RegisterDefaultEncoder and -// RegisterDefaultDecoder methods. The registered codec will be invoked when encoding or decoding -// values whose reflect.Kind matches the registered reflect.Kind as long as the value's type doesn't -// match a registered type or interface encoder/decoder first. These methods should be used to change the -// behavior for all values for a specific kind. -// -// Read [Registry.LookupDecoder] and [Registry.LookupEncoder] for Registry lookup procedure. -type Registry struct { - interfaceEncoders []interfaceValueEncoder - interfaceDecoders []interfaceValueDecoder - typeEncoders *typeEncoderCache +// A RegistryBuilder is used to build a Registry. This type is not goroutine +// safe. +type RegistryBuilder struct { + typeEncoders map[reflect.Type]EncoderFactory typeDecoders *typeDecoderCache - kindEncoders *kindEncoderCache + interfaceEncoders map[reflect.Type]EncoderFactory + interfaceDecoders []interfaceValueDecoder + kindEncoders [reflect.UnsafePointer + 1]EncoderFactory kindDecoders *kindDecoderCache - typeMap sync.Map // map[Type]reflect.Type + typeMap map[Type]reflect.Type } -// NewRegistry creates a new empty Registry. -func NewRegistry() *Registry { - reg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), +// NewRegistryBuilder creates a new empty RegistryBuilder. +func NewRegistryBuilder() *RegistryBuilder { + rb := &RegistryBuilder{ + typeEncoders: make(map[reflect.Type]EncoderFactory), + typeDecoders: new(typeDecoderCache), + interfaceEncoders: make(map[reflect.Type]EncoderFactory), + kindDecoders: new(kindDecoderCache), + typeMap: make(map[Type]reflect.Type), } - registerDefaultEncoders(reg) - registerDefaultDecoders(reg) - registerPrimitiveCodecs(reg) - return reg + registerDefaultEncoders(rb) + registerDefaultDecoders(rb) + registerPrimitiveCodecs(rb) + return rb } -// RegisterTypeEncoder registers the provided ValueEncoder for the provided type. +// EncoderFactory is a factory function that generates a new ValueEncoder. +type EncoderFactory func() ValueEncoder + +// DecoderFactory is a factory function that generates a new ValueDecoder. +type DecoderFactory func() ValueDecoder + +// RegisterTypeEncoder registers a ValueEncoder factory for the provided type. // -// The type will be used as provided, so an encoder can be registered for a type and a different -// encoder can be registered for a pointer to that type. +// The type will be used as provided, so an encoder factory can be registered for a type and a +// different one can be registered for a pointer to that type. // // If the given type is an interface, the encoder will be called when marshaling a type that is // that interface. It will not be called when marshaling a non-interface type that implements the -// interface. To get the latter behavior, call RegisterHookEncoder instead. +// interface. To get the latter behavior, call RegisterInterfaceEncoder instead. // // RegisterTypeEncoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) { - r.typeEncoders.Store(valueType, enc) +func (rb *RegistryBuilder) RegisterTypeEncoder(valueType reflect.Type, encFac EncoderFactory) *RegistryBuilder { + if encFac != nil { + rb.typeEncoders[valueType] = encFac + } + return rb } // RegisterTypeDecoder registers the provided ValueDecoder for the provided type. @@ -139,24 +118,28 @@ func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) // implements the interface. To get the latter behavior, call RegisterHookDecoder instead. // // RegisterTypeDecoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterTypeDecoder(valueType reflect.Type, dec ValueDecoder) { - r.typeDecoders.Store(valueType, dec) +func (rb *RegistryBuilder) RegisterTypeDecoder(valueType reflect.Type, dec ValueDecoder) *RegistryBuilder { + rb.typeDecoders.Store(valueType, dec) + return rb } -// RegisterKindEncoder registers the provided ValueEncoder for the provided kind. +// RegisterKindEncoder registers a ValueEncoder factory for the provided kind. // -// Use RegisterKindEncoder to register an encoder for any type with the same underlying kind. For -// example, consider the type MyInt defined as +// Use RegisterKindEncoder to register an encoder factory for any type with the same underlying kind. +// For example, consider the type MyInt defined as // // type MyInt int32 // -// To define an encoder for MyInt and int32, use RegisterKindEncoder like +// To define an encoder factory for MyInt and int32, use RegisterKindEncoder like // // reg.RegisterKindEncoder(reflect.Int32, myEncoder) // // RegisterKindEncoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc ValueEncoder) { - r.kindEncoders.Store(kind, enc) +func (rb *RegistryBuilder) RegisterKindEncoder(kind reflect.Kind, encFac EncoderFactory) *RegistryBuilder { + if encFac != nil && kind < reflect.Kind(len(rb.kindEncoders)) { + rb.kindEncoders[kind] = encFac + } + return rb } // RegisterKindDecoder registers the provided ValueDecoder for the provided kind. @@ -171,31 +154,29 @@ func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc ValueEncoder) { // reg.RegisterKindDecoder(reflect.Int32, myDecoder) // // RegisterKindDecoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterKindDecoder(kind reflect.Kind, dec ValueDecoder) { - r.kindDecoders.Store(kind, dec) +func (rb *RegistryBuilder) RegisterKindDecoder(kind reflect.Kind, dec ValueDecoder) *RegistryBuilder { + rb.kindDecoders.Store(kind, dec) + return rb } -// RegisterInterfaceEncoder registers an encoder for the provided interface type iface. This encoder will -// be called when marshaling a type if the type implements iface or a pointer to the type +// RegisterInterfaceEncoder registers an encoder factory for the provided interface type iface. This +// encoder will be called when marshaling a type if the type implements iface or a pointer to the type // implements iface. If the provided type is not an interface // (i.e. iface.Kind() != reflect.Interface), this method will panic. // // RegisterInterfaceEncoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterInterfaceEncoder(iface reflect.Type, enc ValueEncoder) { +func (rb *RegistryBuilder) RegisterInterfaceEncoder(iface reflect.Type, encFac EncoderFactory) *RegistryBuilder { if iface.Kind() != reflect.Interface { panicStr := fmt.Errorf("RegisterInterfaceEncoder expects a type with kind reflect.Interface, "+ "got type %s with kind %s", iface, iface.Kind()) panic(panicStr) } - for idx, encoder := range r.interfaceEncoders { - if encoder.i == iface { - r.interfaceEncoders[idx].ve = enc - return - } + if encFac != nil { + rb.interfaceEncoders[iface] = encFac } - r.interfaceEncoders = append(r.interfaceEncoders, interfaceValueEncoder{i: iface, ve: enc}) + return rb } // RegisterInterfaceDecoder registers an decoder for the provided interface type iface. This decoder will @@ -204,21 +185,23 @@ func (r *Registry) RegisterInterfaceEncoder(iface reflect.Type, enc ValueEncoder // this method will panic. // // RegisterInterfaceDecoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterInterfaceDecoder(iface reflect.Type, dec ValueDecoder) { +func (rb *RegistryBuilder) RegisterInterfaceDecoder(iface reflect.Type, dec ValueDecoder) *RegistryBuilder { if iface.Kind() != reflect.Interface { panicStr := fmt.Errorf("RegisterInterfaceDecoder expects a type with kind reflect.Interface, "+ "got type %s with kind %s", iface, iface.Kind()) panic(panicStr) } - for idx, decoder := range r.interfaceDecoders { + for idx, decoder := range rb.interfaceDecoders { if decoder.i == iface { - r.interfaceDecoders[idx].vd = dec - return + rb.interfaceDecoders[idx].vd = dec + return rb } } - r.interfaceDecoders = append(r.interfaceDecoders, interfaceValueDecoder{i: iface, vd: dec}) + rb.interfaceDecoders = append(rb.interfaceDecoders, interfaceValueDecoder{i: iface, vd: dec}) + + return rb } // RegisterTypeMapEntry will register the provided type to the BSON type. The primary usage for this @@ -230,8 +213,113 @@ func (r *Registry) RegisterInterfaceDecoder(iface reflect.Type, dec ValueDecoder // to decode to bson.Raw, use the following code: // // reg.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) -func (r *Registry) RegisterTypeMapEntry(bt Type, rt reflect.Type) { - r.typeMap.Store(bt, rt) +// +// RegisterTypeMapEntry should not be called concurrently with any other Registry method. +func (rb *RegistryBuilder) RegisterTypeMapEntry(bt Type, rt reflect.Type) *RegistryBuilder { + rb.typeMap[bt] = rt + return rb +} + +// Build creates a Registry from the current state of this RegistryBuilder. +func (rb *RegistryBuilder) Build() *Registry { + r := &Registry{ + typeEncoders: new(sync.Map), + typeDecoders: rb.typeDecoders.Clone(), + interfaceEncoders: make([]interfaceValueEncoder, 0, len(rb.interfaceEncoders)), + interfaceDecoders: append([]interfaceValueDecoder(nil), rb.interfaceDecoders...), + kindDecoders: rb.kindDecoders.Clone(), + encoderTypeMap: make(map[reflect.Type][]ValueEncoder), + typeMap: make(map[Type]reflect.Type), + } + encoderCache := make(map[reflect.Value]ValueEncoder) + for k, v := range rb.typeEncoders { + var encoder ValueEncoder + if enc, ok := encoderCache[reflect.ValueOf(v)]; ok { + encoder = enc + } else { + encoder = v() + encoderCache[reflect.ValueOf(v)] = encoder + et := reflect.ValueOf(encoder).Type() + r.encoderTypeMap[et] = append(r.encoderTypeMap[et], encoder) + } + r.typeEncoders.Store(k, encoder) + } + for k, v := range rb.interfaceEncoders { + var encoder ValueEncoder + if enc, ok := encoderCache[reflect.ValueOf(v)]; ok { + encoder = enc + } else { + encoder = v() + encoderCache[reflect.ValueOf(v)] = encoder + et := reflect.ValueOf(encoder).Type() + r.encoderTypeMap[et] = append(r.encoderTypeMap[et], encoder) + } + r.interfaceEncoders = append(r.interfaceEncoders, interfaceValueEncoder{k, encoder}) + } + for i, v := range rb.kindEncoders { + if v == nil { + continue + } + var encoder ValueEncoder + if enc, ok := encoderCache[reflect.ValueOf(v)]; ok { + encoder = enc + } else { + encoder = v() + encoderCache[reflect.ValueOf(v)] = encoder + et := reflect.ValueOf(encoder).Type() + r.encoderTypeMap[et] = append(r.encoderTypeMap[et], encoder) + } + r.kindEncoders[i] = encoder + } + for k, v := range rb.typeMap { + r.typeMap[k] = v + } + return r +} + +// A Registry is a store for ValueEncoders, ValueDecoders, and a type map. See the Registry type +// documentation for examples of registering various custom encoders and decoders. A Registry can +// have four main types of codecs: +// +// 1. Type encoders/decoders - These can be registered using the RegisterTypeEncoder and +// RegisterTypeDecoder methods. The registered codec will be invoked when encoding/decoding a value +// whose type matches the registered type exactly. +// If the registered type is an interface, the codec will be invoked when encoding or decoding +// values whose type is the interface, but not for values with concrete types that implement the +// interface. +// +// 2. Interface encoders/decoders - These can be registered using the RegisterInterfaceEncoder and +// RegisterInterfaceDecoder methods. These methods only accept interface types and the registered codecs +// will be invoked when encoding or decoding values whose types implement the interface. An example +// of an interface defined by the driver is bson.Marshaler. The driver will call the MarshalBSON method +// for any value whose type implements bson.Marshaler, regardless of the value's concrete type. +// +// 3. Type map entries - This can be used to associate a BSON type with a Go type. These type +// associations are used when decoding into a bson.D/bson.M or a struct field of type interface{}. +// For example, by default, BSON int32 and int64 values decode as Go int32 and int64 instances, +// respectively, when decoding into a bson.D. The following code would change the behavior so these +// values decode as Go int instances instead: +// +// intType := reflect.TypeOf(int(0)) +// registry.RegisterTypeMapEntry(bson.TypeInt32, intType).RegisterTypeMapEntry(bson.TypeInt64, intType) +// +// 4. Kind encoder/decoders - These can be registered using the RegisterDefaultEncoder and +// RegisterDefaultDecoder methods. The registered codec will be invoked when encoding or decoding +// values whose reflect.Kind matches the registered reflect.Kind as long as the value's type doesn't +// match a registered type or interface encoder/decoder first. These methods should be used to change the +// behavior for all values for a specific kind. +// +// Read [Registry.LookupDecoder] and [Registry.LookupEncoder] for Registry lookup procedure. +type Registry struct { + typeEncoders *sync.Map // map[reflect.Type]ValueEncoder + typeDecoders *typeDecoderCache + interfaceEncoders []interfaceValueEncoder + interfaceDecoders []interfaceValueDecoder + kindEncoders [reflect.UnsafePointer + 1]ValueEncoder + kindDecoders *kindDecoderCache + typeMap map[Type]reflect.Type + + encoderTypeMap map[reflect.Type][]ValueEncoder } // LookupEncoder returns the first matching encoder in the Registry. It uses the following lookup @@ -250,36 +338,38 @@ func (r *Registry) RegisterTypeMapEntry(bt Type, rt reflect.Type) { // 3. An encoder registered using RegisterKindEncoder for the kind of value. // // If no encoder is found, an error of type ErrNoEncoder is returned. LookupEncoder is safe for -// concurrent use by multiple goroutines after all codecs and encoders are registered. +// concurrent use by multiple goroutines. func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { if valueType == nil { return nil, ErrNoEncoder{Type: valueType} } - enc, found := r.lookupTypeEncoder(valueType) - if found { + + if enc, found := r.typeEncoders.Load(valueType); found { if enc == nil { return nil, ErrNoEncoder{Type: valueType} } - return enc, nil + return enc.(ValueEncoder), nil } - enc, found = r.lookupInterfaceEncoder(valueType, true) - if found { - return r.typeEncoders.LoadOrStore(valueType, enc), nil + if enc, found := r.lookupInterfaceEncoder(valueType, true); found { + r.typeEncoders.Store(valueType, enc) + return enc, nil } - if v, ok := r.kindEncoders.Load(valueType.Kind()); ok { - return r.storeTypeEncoder(valueType, v), nil + if enc, found := r.lookupKindEncoder(valueType.Kind()); found { + r.typeEncoders.Store(valueType, enc) + return enc, nil } return nil, ErrNoEncoder{Type: valueType} } -func (r *Registry) storeTypeEncoder(rt reflect.Type, enc ValueEncoder) ValueEncoder { - return r.typeEncoders.LoadOrStore(rt, enc) -} - -func (r *Registry) lookupTypeEncoder(rt reflect.Type) (ValueEncoder, bool) { - return r.typeEncoders.Load(rt) +func (r *Registry) lookupKindEncoder(valueKind reflect.Kind) (ValueEncoder, bool) { + if valueKind < reflect.Kind(len(r.kindEncoders)) { + if enc := r.kindEncoders[valueKind]; enc != nil { + return enc, true + } + } + return nil, false } func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool) (ValueEncoder, bool) { @@ -295,7 +385,7 @@ func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool // ahead in interfaceEncoders defaultEnc, found := r.lookupInterfaceEncoder(valueType, false) if !found { - defaultEnc, _ = r.kindEncoders.Load(valueType.Kind()) + defaultEnc, _ = r.lookupKindEncoder(valueType.Kind()) } return &condAddrEncoder{canAddrEnc: ienc.ve, elseEnc: defaultEnc}, true } @@ -319,12 +409,12 @@ func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool // 3. A decoder registered using RegisterKindDecoder for the kind of value. // // If no decoder is found, an error of type ErrNoDecoder is returned. LookupDecoder is safe for -// concurrent use by multiple goroutines after all codecs and decoders are registered. +// concurrent use by multiple goroutines. func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) { if valueType == nil { return nil, ErrNilType } - dec, found := r.lookupTypeDecoder(valueType) + dec, found := r.typeDecoders.Load(valueType) if found { if dec == nil { return nil, ErrNoDecoder{Type: valueType} @@ -334,23 +424,15 @@ func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) { dec, found = r.lookupInterfaceDecoder(valueType, true) if found { - return r.storeTypeDecoder(valueType, dec), nil + return r.typeDecoders.LoadOrStore(valueType, dec), nil } if v, ok := r.kindDecoders.Load(valueType.Kind()); ok { - return r.storeTypeDecoder(valueType, v), nil + return r.typeDecoders.LoadOrStore(valueType, v), nil } return nil, ErrNoDecoder{Type: valueType} } -func (r *Registry) lookupTypeDecoder(valueType reflect.Type) (ValueDecoder, bool) { - return r.typeDecoders.Load(valueType) -} - -func (r *Registry) storeTypeDecoder(typ reflect.Type, dec ValueDecoder) ValueDecoder { - return r.typeDecoders.LoadOrStore(typ, dec) -} - func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool) (ValueDecoder, bool) { for _, idec := range r.interfaceDecoders { if valueType.Implements(idec.i) { @@ -371,14 +453,12 @@ func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool // LookupTypeMapEntry inspects the registry's type map for a Go type for the corresponding BSON // type. If no type is found, ErrNoTypeMapEntry is returned. -// -// LookupTypeMapEntry should not be called concurrently with any other Registry method. func (r *Registry) LookupTypeMapEntry(bt Type) (reflect.Type, error) { - v, ok := r.typeMap.Load(bt) + v, ok := r.typeMap[bt] if v == nil || !ok { return nil, ErrNoTypeMapEntry{Type: bt} } - return v.(reflect.Type), nil + return v, nil } type interfaceValueEncoder struct { diff --git a/bson/registry_examples_test.go b/bson/registry_examples_test.go index b866df8cdb..4b15dde3d5 100644 --- a/bson/registry_examples_test.go +++ b/bson/registry_examples_test.go @@ -23,7 +23,7 @@ func ExampleRegistry_customEncoder() { negatedIntType := reflect.TypeOf(negatedInt(0)) negatedIntEncoder := func( - _ *bson.Registry, + _ bson.EncoderRegistry, vw bson.ValueWriter, val reflect.Value, ) error { @@ -46,10 +46,13 @@ func ExampleRegistry_customEncoder() { return vw.WriteInt64(negatedVal) } - reg := bson.NewRegistry() + reg := bson.NewRegistryBuilder() reg.RegisterTypeEncoder( negatedIntType, - bson.ValueEncoderFunc(negatedIntEncoder)) + func() bson.ValueEncoder { + return bson.ValueEncoderFunc(negatedIntEncoder) + }, + ) // Define a document that includes both int and negatedInt fields with the // same value. @@ -67,7 +70,7 @@ func ExampleRegistry_customEncoder() { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) enc := bson.NewEncoder(vw) - enc.SetRegistry(reg) + enc.SetRegistry(reg.Build()) err := enc.Encode(doc) if err != nil { panic(err) @@ -129,10 +132,11 @@ func ExampleRegistry_customDecoder() { return nil } - reg := bson.NewRegistry() + reg := bson.NewRegistryBuilder() reg.RegisterTypeDecoder( lenientBoolType, - bson.ValueDecoderFunc(lenientBoolDecoder)) + bson.ValueDecoderFunc(lenientBoolDecoder), + ) // Marshal a BSON document with a single field "isOK" that is a non-zero // integer value. @@ -148,7 +152,7 @@ func ExampleRegistry_customDecoder() { IsOK lenientBool `bson:"isOK"` } var doc MyDocument - err = bson.UnmarshalWithRegistry(reg, b, &doc) + err = bson.UnmarshalWithRegistry(reg.Build(), b, &doc) if err != nil { panic(err) } @@ -156,13 +160,13 @@ func ExampleRegistry_customDecoder() { // Output: {IsOK:true} } -func ExampleRegistry_RegisterKindEncoder() { +func ExampleRegistryBuilder_RegisterKindEncoder() { // Create a custom encoder that writes any Go type that has underlying type // int32 as an a BSON int64. To do that, we register the encoder as a "kind" // encoder for kind reflect.Int32. That way, even user-defined types with // underlying type int32 will be encoded as a BSON int64. int32To64Encoder := func( - _ *bson.Registry, + _ bson.EncoderRegistry, vw bson.ValueWriter, val reflect.Value, ) error { @@ -181,10 +185,13 @@ func ExampleRegistry_RegisterKindEncoder() { // Create a default registry and register our int32-to-int64 encoder for // kind reflect.Int32. - reg := bson.NewRegistry() + reg := bson.NewRegistryBuilder() reg.RegisterKindEncoder( reflect.Int32, - bson.ValueEncoderFunc(int32To64Encoder)) + func() bson.ValueEncoder { + return bson.ValueEncoderFunc(int32To64Encoder) + }, + ) // Define a document that includes an int32, an int64, and a user-defined // type "myInt" that has underlying type int32. @@ -205,7 +212,7 @@ func ExampleRegistry_RegisterKindEncoder() { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) enc := bson.NewEncoder(vw) - enc.SetRegistry(reg) + enc.SetRegistry(reg.Build()) err := enc.Encode(doc) if err != nil { panic(err) @@ -214,7 +221,7 @@ func ExampleRegistry_RegisterKindEncoder() { // Output: {"myint": {"$numberLong":"1"},"int32": {"$numberLong":"1"},"int64": {"$numberLong":"1"}} } -func ExampleRegistry_RegisterKindDecoder() { +func ExampleRegistryBuilder_RegisterKindDecoder() { // Create a custom decoder that can decode any integer value, including // integer values encoded as floating point numbers, to any Go type // with underlying type int64. To do that, we register the decoder as a @@ -270,10 +277,11 @@ func ExampleRegistry_RegisterKindDecoder() { return nil } - reg := bson.NewRegistry() + reg := bson.NewRegistryBuilder() reg.RegisterKindDecoder( reflect.Int64, - bson.ValueDecoderFunc(flexibleInt64KindDecoder)) + bson.ValueDecoderFunc(flexibleInt64KindDecoder), + ) // Marshal a BSON document with fields that are mixed numeric types but all // hold integer values (i.e. values with no fractional part). @@ -290,7 +298,7 @@ func ExampleRegistry_RegisterKindDecoder() { Int64 int64 } var doc myDocument - err = bson.UnmarshalWithRegistry(reg, b, &doc) + err = bson.UnmarshalWithRegistry(reg.Build(), b, &doc) if err != nil { panic(err) } diff --git a/bson/registry_test.go b/bson/registry_test.go index 5375b6f444..003bb69d6b 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -15,508 +15,247 @@ import ( "go.mongodb.org/mongo-driver/internal/assert" ) -// newTestRegistry creates a new empty Registry. -func newTestRegistry() *Registry { - return &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), +// newTestRegistryBuilder creates a new empty RegistryBuilder. +func newTestRegistryBuilder() *RegistryBuilder { + return &RegistryBuilder{ + typeEncoders: make(map[reflect.Type]EncoderFactory), + typeDecoders: new(typeDecoderCache), + interfaceEncoders: make(map[reflect.Type]EncoderFactory), + kindDecoders: new(kindDecoderCache), + typeMap: make(map[Type]reflect.Type), } } func TestRegistryBuilder(t *testing.T) { + t.Parallel() + t.Run("Register", func(t *testing.T) { + t.Parallel() + fc1, fc2, fc3, fc4 := new(fakeCodec), new(fakeCodec), new(fakeCodec), new(fakeCodec) t.Run("interface", func(t *testing.T) { - var t1f *testInterface1 - var t2f *testInterface2 - var t4f *testInterface4 - ips := []interfaceValueEncoder{ - {i: reflect.TypeOf(t1f).Elem(), ve: fc1}, - {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, - {i: reflect.TypeOf(t1f).Elem(), ve: fc3}, - {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, + t.Parallel() + + t1f, t2f, t3f, t4f := + reflect.TypeOf((*testInterface1)(nil)).Elem(), + reflect.TypeOf((*testInterface2)(nil)).Elem(), + reflect.TypeOf((*testInterface3)(nil)).Elem(), + reflect.TypeOf((*testInterface4)(nil)).Elem() + + var c1, c2, c3, c4 int + ef1 := func() ValueEncoder { + c1++ + return fc1 + } + ef2 := func() ValueEncoder { + c2++ + return fc2 + } + ef3 := func() ValueEncoder { + c3++ + return fc3 + } + ef4 := func() ValueEncoder { + c4++ + return fc4 + } + + ips := []struct { + i reflect.Type + ef EncoderFactory + }{ + {i: t1f, ef: ef1}, + {i: t2f, ef: ef2}, + {i: t1f, ef: ef3}, + {i: t3f, ef: ef2}, + {i: t4f, ef: ef4}, } want := []interfaceValueEncoder{ - {i: reflect.TypeOf(t1f).Elem(), ve: fc3}, - {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, - {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, + {i: t1f, ve: fc3}, {i: t2f, ve: fc2}, + {i: t3f, ve: fc2}, {i: t4f, ve: fc4}, } - reg := newTestRegistry() + + rb := newTestRegistryBuilder() for _, ip := range ips { - reg.RegisterInterfaceEncoder(ip.i, ip.ve) + rb.RegisterInterfaceEncoder(ip.i, ip.ef) } + reg := rb.Build() - got := reg.interfaceEncoders - if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) { - t.Errorf("the registered interfaces are not correct: got %#v, want %#v", got, want) + if !cmp.Equal(c1, 0) { + t.Errorf("ef1 is called %d time(s); expected 0", c1) } - }) - t.Run("type", func(t *testing.T) { - ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{} - reg := newTestRegistry() - reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc1) - reg.RegisterTypeEncoder(reflect.TypeOf(ft2), fc2) - reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc3) - reg.RegisterTypeEncoder(reflect.TypeOf(ft4), fc4) - want := []struct { - t reflect.Type - c ValueEncoder - }{ - {reflect.TypeOf(ft1), fc3}, - {reflect.TypeOf(ft2), fc2}, - {reflect.TypeOf(ft4), fc4}, + if !cmp.Equal(c2, 1) { + t.Errorf("ef2 is called %d time(s); expected 1", c2) } - - got := reg.typeEncoders - for _, s := range want { - wantT, wantC := s.t, s.c - gotC, exists := got.Load(wantT) - if !exists { - t.Errorf("Did not find type in the type registry: %v", wantT) - } - if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) { - t.Errorf("codecs did not match: got %#v; want %#v", gotC, wantC) - } + if !cmp.Equal(c3, 1) { + t.Errorf("ef3 is called %d time(s); expected 1", c3) } - }) - t.Run("kind", func(t *testing.T) { - k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map - reg := newTestRegistry() - reg.RegisterKindEncoder(k1, fc1) - reg.RegisterKindEncoder(k2, fc2) - reg.RegisterKindEncoder(k1, fc3) - reg.RegisterKindEncoder(k4, fc4) - want := []struct { - k reflect.Kind - c ValueEncoder - }{ - {k1, fc3}, - {k2, fc2}, - {k4, fc4}, + if !cmp.Equal(c4, 1) { + t.Errorf("ef4 is called %d time(s); expected 1", c4) + } + codecs, ok := reg.encoderTypeMap[reflect.TypeOf((*fakeCodec)(nil))] + if !cmp.Equal(len(reg.encoderTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { + t.Errorf("codecs were not cached correctly") + } + got := make(map[reflect.Type]ValueEncoder) + for _, e := range reg.interfaceEncoders { + got[e.i] = e.ve } - - got := reg.kindEncoders for _, s := range want { - wantK, wantC := s.k, s.c - gotC, exists := got.Load(wantK) + wantI, wantVe := s.i, s.ve + gotVe, exists := got[wantI] if !exists { - t.Errorf("Did not find kind in the kind registry: %v", wantK) + t.Errorf("Did not find type in the type registry: %v", wantI) } - if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) { - t.Errorf("codecs did not match: got %#v; want %#v", gotC, wantC) + if !cmp.Equal(gotVe, wantVe, cmp.AllowUnexported(fakeCodec{})) { + t.Errorf("codecs did not match: got %#v; want %#v", gotVe, wantVe) } } }) - t.Run("RegisterDefault", func(t *testing.T) { - t.Run("MapCodec", func(t *testing.T) { - codec := &fakeCodec{num: 1} - codec2 := &fakeCodec{num: 2} - reg := newTestRegistry() - - reg.RegisterKindEncoder(reflect.Map, codec) - if reg.kindEncoders.get(reflect.Map) != codec { - t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec) - } - - reg.RegisterKindEncoder(reflect.Map, codec2) - if reg.kindEncoders.get(reflect.Map) != codec2 { - t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec2) - } - }) - t.Run("StructCodec", func(t *testing.T) { - codec := &fakeCodec{num: 1} - codec2 := &fakeCodec{num: 2} - reg := newTestRegistry() - - reg.RegisterKindEncoder(reflect.Struct, codec) - if reg.kindEncoders.get(reflect.Struct) != codec { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec) - } - - reg.RegisterKindEncoder(reflect.Struct, codec2) - if reg.kindEncoders.get(reflect.Struct) != codec2 { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec2) - } - }) - t.Run("SliceCodec", func(t *testing.T) { - codec := &fakeCodec{num: 1} - codec2 := &fakeCodec{num: 2} - reg := newTestRegistry() - - reg.RegisterKindEncoder(reflect.Slice, codec) - if reg.kindEncoders.get(reflect.Slice) != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec) - } - - reg.RegisterKindEncoder(reflect.Slice, codec2) - if reg.kindEncoders.get(reflect.Slice) != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec2) - } - }) - t.Run("ArrayCodec", func(t *testing.T) { - codec := &fakeCodec{num: 1} - codec2 := &fakeCodec{num: 2} - reg := newTestRegistry() + t.Run("type", func(t *testing.T) { + t.Parallel() - reg.RegisterKindEncoder(reflect.Array, codec) - if reg.kindEncoders.get(reflect.Array) != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec) - } + ft1, ft2, ft3, ft4 := + reflect.TypeOf(fakeType1{}), + reflect.TypeOf(fakeType2{}), + reflect.TypeOf(fakeType3{}), + reflect.TypeOf(fakeType4{}) - reg.RegisterKindEncoder(reflect.Array, codec2) - if reg.kindEncoders.get(reflect.Array) != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec2) - } - }) - }) - t.Run("Lookup", func(t *testing.T) { - type Codec interface { - ValueEncoder - ValueDecoder + var c1, c2, c3, c4 int + ef1 := func() ValueEncoder { + c1++ + return fc1 + } + ef2 := func() ValueEncoder { + c2++ + return fc2 + } + ef3 := func() ValueEncoder { + c3++ + return fc3 + } + ef4 := func() ValueEncoder { + c4++ + return fc4 } - var ( - arrinstance [12]int - arr = reflect.TypeOf(arrinstance) - slc = reflect.TypeOf(make([]int, 12)) - m = reflect.TypeOf(make(map[string]int)) - strct = reflect.TypeOf(struct{ Foo string }{}) - ft1 = reflect.PtrTo(reflect.TypeOf(fakeType1{})) - ft2 = reflect.TypeOf(fakeType2{}) - ft3 = reflect.TypeOf(fakeType5(func(string, string) string { return "fakeType5" })) - ti1 = reflect.TypeOf((*testInterface1)(nil)).Elem() - ti2 = reflect.TypeOf((*testInterface2)(nil)).Elem() - ti1Impl = reflect.TypeOf(testInterface1Impl{}) - ti2Impl = reflect.TypeOf(testInterface2Impl{}) - ti3 = reflect.TypeOf((*testInterface3)(nil)).Elem() - ti3Impl = reflect.TypeOf(testInterface3Impl{}) - ti3ImplPtr = reflect.TypeOf((*testInterface3Impl)(nil)) - fc1, fc2 = &fakeCodec{num: 1}, &fakeCodec{num: 2} - fsc, fslcc, fmc = new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec) - pc = &pointerCodec{} - ) - - reg := newTestRegistry() - reg.RegisterTypeEncoder(ft1, fc1) - reg.RegisterTypeEncoder(ft2, fc2) - reg.RegisterTypeEncoder(ti1, fc1) - reg.RegisterKindEncoder(reflect.Struct, fsc) - reg.RegisterKindEncoder(reflect.Slice, fslcc) - reg.RegisterKindEncoder(reflect.Array, fslcc) - reg.RegisterKindEncoder(reflect.Map, fmc) - reg.RegisterKindEncoder(reflect.Ptr, pc) - reg.RegisterTypeDecoder(ft1, fc1) - reg.RegisterTypeDecoder(ft2, fc2) - reg.RegisterTypeDecoder(ti1, fc1) // values whose exact type is testInterface1 will use fc1 encoder - reg.RegisterKindDecoder(reflect.Struct, fsc) - reg.RegisterKindDecoder(reflect.Slice, fslcc) - reg.RegisterKindDecoder(reflect.Array, fslcc) - reg.RegisterKindDecoder(reflect.Map, fmc) - reg.RegisterKindDecoder(reflect.Ptr, pc) - reg.RegisterInterfaceEncoder(ti2, fc2) - reg.RegisterInterfaceEncoder(ti3, fc3) - reg.RegisterInterfaceDecoder(ti2, fc2) - reg.RegisterInterfaceDecoder(ti3, fc3) - - testCases := []struct { - name string - t reflect.Type - wantcodec Codec - wanterr error - testcache bool + ips := []struct { + i reflect.Type + ef EncoderFactory }{ - { - "type registry (pointer)", - ft1, - fc1, - nil, - false, - }, - { - "type registry (non-pointer)", - ft2, - fc2, - nil, - false, - }, - { - // lookup an interface type and expect that the registered encoder is returned - "interface with type encoder", - ti1, - fc1, - nil, - true, - }, - { - // lookup a type that implements an interface and expect that the default struct codec is returned - "interface implementation with type encoder", - ti1Impl, - fsc, - nil, - false, - }, - { - // lookup an interface type and expect that the registered hook is returned - "interface with hook", - ti2, - fc2, - nil, - false, - }, - { - // lookup a type that implements an interface and expect that the registered hook is returned - "interface implementation with hook", - ti2Impl, - fc2, - nil, - false, - }, - { - // lookup a pointer to a type where the pointer implements an interface and expect that the - // registered hook is returned - "interface pointer to implementation with hook (pointer)", - ti3ImplPtr, - fc3, - nil, - false, - }, - { - "default struct codec (pointer)", - reflect.PtrTo(strct), - pc, - nil, - false, - }, - { - "default struct codec (non-pointer)", - strct, - fsc, - nil, - false, - }, - { - "default array codec", - arr, - fslcc, - nil, - false, - }, - { - "default slice codec", - slc, - fslcc, - nil, - false, - }, - { - "default map", - m, - fmc, - nil, - false, - }, - { - "map non-string key", - reflect.TypeOf(map[int]int{}), - fmc, - nil, - false, - }, - { - "No Codec Registered", - ft3, - nil, - ErrNoEncoder{Type: ft3}, - false, - }, + {i: ft1, ef: ef1}, + {i: ft2, ef: ef2}, + {i: ft1, ef: ef3}, + {i: ft3, ef: ef2}, + {i: ft4, ef: ef4}, } - - allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{}) - comparepc := func(pc1, pc2 *pointerCodec) bool { return true } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Run("Encoder", func(t *testing.T) { - gotcodec, goterr := reg.LookupEncoder(tc.t) - if !cmp.Equal(goterr, tc.wanterr, cmp.Comparer(assert.CompareErrors)) { - t.Errorf("errors did not match: got %#v, want %#v", goterr, tc.wanterr) - } - if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("codecs did not match: got %#v, want %#v", gotcodec, tc.wantcodec) - } - }) - t.Run("Decoder", func(t *testing.T) { - wanterr := tc.wanterr - if ene, ok := tc.wanterr.(ErrNoEncoder); ok { - wanterr = ErrNoDecoder(ene) - } - - gotcodec, goterr := reg.LookupDecoder(tc.t) - if !cmp.Equal(goterr, wanterr, cmp.Comparer(assert.CompareErrors)) { - t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr) - } - if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("codecs did not match: got %#v, want %#v", gotcodec, tc.wantcodec) - } - }) - }) + want := []interfaceValueEncoder{ + {i: ft1, ve: fc3}, {i: ft2, ve: fc2}, + {i: ft3, ve: fc2}, {i: ft4, ve: fc4}, } - // lookup a type whose pointer implements an interface and expect that the registered hook is - // returned - t.Run("interface implementation with hook (pointer)", func(t *testing.T) { - t.Run("Encoder", func(t *testing.T) { - gotEnc, err := reg.LookupEncoder(ti3Impl) - assert.Nil(t, err, "LookupEncoder error: %v", err) - cae, ok := gotEnc.(*condAddrEncoder) - assert.True(t, ok, "Expected CondAddrEncoder, got %T", gotEnc) - if !cmp.Equal(cae.canAddrEnc, fc3, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("expected canAddrEnc %#v, got %#v", cae.canAddrEnc, fc3) - } - if !cmp.Equal(cae.elseEnc, fsc, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("expected elseEnc %#v, got %#v", cae.elseEnc, fsc) - } - }) - t.Run("Decoder", func(t *testing.T) { - gotDec, err := reg.LookupDecoder(ti3Impl) - assert.Nil(t, err, "LookupDecoder error: %v", err) - - cad, ok := gotDec.(*condAddrDecoder) - assert.True(t, ok, "Expected CondAddrDecoder, got %T", gotDec) - if !cmp.Equal(cad.canAddrDec, fc3, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("expected canAddrDec %#v, got %#v", cad.canAddrDec, fc3) - } - if !cmp.Equal(cad.elseDec, fsc, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("expected elseDec %#v, got %#v", cad.elseDec, fsc) - } - }) - }) - }) - }) - t.Run("Type Map", func(t *testing.T) { - reg := newTestRegistry() - reg.RegisterTypeMapEntry(TypeString, reflect.TypeOf("")) - reg.RegisterTypeMapEntry(TypeInt32, reflect.TypeOf(int(0))) - - var got, want reflect.Type - - want = reflect.TypeOf("") - got, err := reg.LookupTypeMapEntry(TypeString) - noerr(t, err) - if got != want { - t.Errorf("unexpected type: got %#v, want %#v", got, want) - } - - want = reflect.TypeOf(int(0)) - got, err = reg.LookupTypeMapEntry(TypeInt32) - noerr(t, err) - if got != want { - t.Errorf("unexpected type: got %#v, want %#v", got, want) - } - - want = nil - wanterr := ErrNoTypeMapEntry{Type: TypeObjectID} - got, err = reg.LookupTypeMapEntry(TypeObjectID) - if !errors.Is(err, wanterr) { - t.Errorf("did not get expected error: got %#v, want %#v", err, wanterr) - } - if got != want { - t.Errorf("unexpected type: got %#v, want %#v", got, want) - } - }) -} - -func TestRegistry(t *testing.T) { - t.Parallel() - - t.Run("Register", func(t *testing.T) { - t.Parallel() - - fc1, fc2, fc3, fc4 := new(fakeCodec), new(fakeCodec), new(fakeCodec), new(fakeCodec) - t.Run("interface", func(t *testing.T) { - t.Parallel() + rb := newTestRegistryBuilder() + for _, ip := range ips { + rb.RegisterTypeEncoder(ip.i, ip.ef) + } + reg := rb.Build() - var t1f *testInterface1 - var t2f *testInterface2 - var t4f *testInterface4 - ips := []interfaceValueEncoder{ - {i: reflect.TypeOf(t1f).Elem(), ve: fc1}, - {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, - {i: reflect.TypeOf(t1f).Elem(), ve: fc3}, - {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, + if !cmp.Equal(c1, 0) { + t.Errorf("ef1 is called %d time(s); expected 0", c1) } - want := []interfaceValueEncoder{ - {i: reflect.TypeOf(t1f).Elem(), ve: fc3}, - {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, - {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, + if !cmp.Equal(c2, 1) { + t.Errorf("ef2 is called %d time(s); expected 1", c2) } - reg := newTestRegistry() - for _, ip := range ips { - reg.RegisterInterfaceEncoder(ip.i, ip.ve) + if !cmp.Equal(c3, 1) { + t.Errorf("ef3 is called %d time(s); expected 1", c3) } - got := reg.interfaceEncoders - if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) { - t.Errorf("registered interfaces are not correct: got %#v, want %#v", got, want) + if !cmp.Equal(c4, 1) { + t.Errorf("ef4 is called %d time(s); expected 1", c4) } - }) - t.Run("type", func(t *testing.T) { - t.Parallel() - - ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{} - reg := newTestRegistry() - reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc1) - reg.RegisterTypeEncoder(reflect.TypeOf(ft2), fc2) - reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc3) - reg.RegisterTypeEncoder(reflect.TypeOf(ft4), fc4) - - want := []struct { - t reflect.Type - c ValueEncoder - }{ - {reflect.TypeOf(ft1), fc3}, - {reflect.TypeOf(ft2), fc2}, - {reflect.TypeOf(ft4), fc4}, + codecs, ok := reg.encoderTypeMap[reflect.TypeOf((*fakeCodec)(nil))] + if !cmp.Equal(len(reg.encoderTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { + t.Errorf("codecs were not cached correctly") } got := reg.typeEncoders for _, s := range want { - wantT, wantC := s.t, s.c - gotC, exists := got.Load(wantT) + wantI, wantVe := s.i, s.ve + gotVe, exists := got.Load(wantI) if !exists { - t.Errorf("type missing in registry: %v", wantT) + t.Errorf("type missing in registry: %v", wantI) } - if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) { - t.Errorf("codecs did not match: got %#v; want %#v", gotC, wantC) + if !cmp.Equal(gotVe, wantVe, cmp.AllowUnexported(fakeCodec{})) { + t.Errorf("codecs did not match: got %#v; want %#v", gotVe, wantVe) } } }) t.Run("kind", func(t *testing.T) { t.Parallel() - k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map - reg := newTestRegistry() - reg.RegisterKindEncoder(k1, fc1) - reg.RegisterKindEncoder(k2, fc2) - reg.RegisterKindEncoder(k1, fc3) - reg.RegisterKindEncoder(k4, fc4) + k1, k2, k3, k4 := reflect.Struct, reflect.Slice, reflect.Int, reflect.Map + var c1, c2, c3, c4 int + ef1 := func() ValueEncoder { + c1++ + return fc1 + } + ef2 := func() ValueEncoder { + c2++ + return fc2 + } + ef3 := func() ValueEncoder { + c3++ + return fc3 + } + ef4 := func() ValueEncoder { + c4++ + return fc4 + } + + ips := []struct { + k reflect.Kind + ef EncoderFactory + }{ + {k: k1, ef: ef1}, + {k: k2, ef: ef2}, + {k: k1, ef: ef3}, + {k: k3, ef: ef2}, + {k: k4, ef: ef4}, + } want := []struct { k reflect.Kind c ValueEncoder }{ - {k1, fc3}, - {k2, fc2}, - {k4, fc4}, + {k1, fc3}, {k2, fc2}, {k4, fc4}, + } + + rb := newTestRegistryBuilder() + for _, ip := range ips { + rb.RegisterKindEncoder(ip.k, ip.ef) + } + reg := rb.Build() + + if !cmp.Equal(c1, 0) { + t.Errorf("ef1 is called %d time(s); expected 0", c1) + } + if !cmp.Equal(c2, 1) { + t.Errorf("ef2 is called %d time(s); expected 1", c2) + } + if !cmp.Equal(c3, 1) { + t.Errorf("ef3 is called %d time(s); expected 1", c3) + } + if !cmp.Equal(c4, 1) { + t.Errorf("ef4 is called %d time(s); expected 1", c4) + } + codecs, ok := reg.encoderTypeMap[reflect.TypeOf((*fakeCodec)(nil))] + if !cmp.Equal(len(reg.encoderTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { + t.Errorf("codecs were not cached correctly") } got := reg.kindEncoders for _, s := range want { - wantK, wantC := s.k, s.c - gotC, exists := got.Load(wantK) - if !exists { - t.Errorf("type missing in registry: %v", wantK) - } - if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) { - t.Errorf("codecs did not match: got %#v, want %#v", gotC, wantC) + wantI, wantVe := s.k, s.c + gotC := got[wantI] + if !cmp.Equal(gotC, wantVe, cmp.AllowUnexported(fakeCodec{})) { + t.Errorf("codecs did not match: got %#v, want %#v", gotC, wantVe) } } }) @@ -528,14 +267,18 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistry() - reg.RegisterKindEncoder(reflect.Map, codec) - if reg.kindEncoders.get(reflect.Map) != codec { - t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec) + rb := newTestRegistryBuilder() + + rb.RegisterKindEncoder(reflect.Map, func() ValueEncoder { return codec }) + reg := rb.Build() + if got := reg.kindEncoders[reflect.Map]; got != codec { + t.Errorf("map codec not properly set: got %#v, want %#v", got, codec) } - reg.RegisterKindEncoder(reflect.Map, codec2) - if reg.kindEncoders.get(reflect.Map) != codec2 { - t.Errorf("map codec properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec2) + + rb.RegisterKindEncoder(reflect.Map, func() ValueEncoder { return codec2 }) + reg = rb.Build() + if got := reg.kindEncoders[reflect.Map]; got != codec2 { + t.Errorf("map codec not properly set: got %#v, want %#v", got, codec2) } }) t.Run("StructCodec", func(t *testing.T) { @@ -543,14 +286,18 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistry() - reg.RegisterKindEncoder(reflect.Struct, codec) - if reg.kindEncoders.get(reflect.Struct) != codec { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec) + rb := newTestRegistryBuilder() + + rb.RegisterKindEncoder(reflect.Struct, func() ValueEncoder { return codec }) + reg := rb.Build() + if got := reg.kindEncoders[reflect.Struct]; got != codec { + t.Errorf("struct codec not properly set: got %#v, want %#v", got, codec) } - reg.RegisterKindEncoder(reflect.Struct, codec2) - if reg.kindEncoders.get(reflect.Struct) != codec2 { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec2) + + rb.RegisterKindEncoder(reflect.Struct, func() ValueEncoder { return codec2 }) + reg = rb.Build() + if got := reg.kindEncoders[reflect.Struct]; got != codec2 { + t.Errorf("struct codec not properly set: got %#v, want %#v", got, codec2) } }) t.Run("SliceCodec", func(t *testing.T) { @@ -558,14 +305,18 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistry() - reg.RegisterKindEncoder(reflect.Slice, codec) - if reg.kindEncoders.get(reflect.Slice) != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec) + rb := newTestRegistryBuilder() + + rb.RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return codec }) + reg := rb.Build() + if got := reg.kindEncoders[reflect.Slice]; got != codec { + t.Errorf("slice codec not properly set: got %#v, want %#v", got, codec) } - reg.RegisterKindEncoder(reflect.Slice, codec2) - if reg.kindEncoders.get(reflect.Slice) != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec2) + + rb.RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return codec2 }) + reg = rb.Build() + if got := reg.kindEncoders[reflect.Slice]; got != codec2 { + t.Errorf("slice codec not properly set: got %#v, want %#v", got, codec2) } }) t.Run("ArrayCodec", func(t *testing.T) { @@ -573,14 +324,18 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistry() - reg.RegisterKindEncoder(reflect.Array, codec) - if reg.kindEncoders.get(reflect.Array) != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec) + rb := newTestRegistryBuilder() + + rb.RegisterKindEncoder(reflect.Array, func() ValueEncoder { return codec }) + reg := rb.Build() + if got := reg.kindEncoders[reflect.Array]; got != codec { + t.Errorf("slice codec not properly set: got %#v, want %#v", got, codec) } - reg.RegisterKindEncoder(reflect.Array, codec2) - if reg.kindEncoders.get(reflect.Array) != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec2) + + rb.RegisterKindEncoder(reflect.Array, func() ValueEncoder { return codec2 }) + reg = rb.Build() + if got := reg.kindEncoders[reflect.Array]; got != codec2 { + t.Errorf("slice codec not properly set: got %#v, want %#v", got, codec2) } }) }) @@ -613,27 +368,36 @@ func TestRegistry(t *testing.T) { pc = &pointerCodec{} ) - reg := newTestRegistry() - reg.RegisterTypeEncoder(ft1, fc1) - reg.RegisterTypeEncoder(ft2, fc2) - reg.RegisterTypeEncoder(ti1, fc1) - reg.RegisterKindEncoder(reflect.Struct, fsc) - reg.RegisterKindEncoder(reflect.Slice, fslcc) - reg.RegisterKindEncoder(reflect.Array, fslcc) - reg.RegisterKindEncoder(reflect.Map, fmc) - reg.RegisterKindEncoder(reflect.Ptr, pc) - reg.RegisterTypeDecoder(ft1, fc1) - reg.RegisterTypeDecoder(ft2, fc2) - reg.RegisterTypeDecoder(ti1, fc1) // values whose exact type is testInterface1 will use fc1 encoder - reg.RegisterKindDecoder(reflect.Struct, fsc) - reg.RegisterKindDecoder(reflect.Slice, fslcc) - reg.RegisterKindDecoder(reflect.Array, fslcc) - reg.RegisterKindDecoder(reflect.Map, fmc) - reg.RegisterKindDecoder(reflect.Ptr, pc) - reg.RegisterInterfaceEncoder(ti2, fc2) - reg.RegisterInterfaceDecoder(ti2, fc2) - reg.RegisterInterfaceEncoder(ti3, fc3) - reg.RegisterInterfaceDecoder(ti3, fc3) + fc1EncFac := func() ValueEncoder { return fc1 } + fc2EncFac := func() ValueEncoder { return fc2 } + fc3EncFac := func() ValueEncoder { return fc3 } + fscEncFac := func() ValueEncoder { return fsc } + fslccEncFac := func() ValueEncoder { return fslcc } + fmcEncFac := func() ValueEncoder { return fmc } + pcEncFac := func() ValueEncoder { return pc } + + reg := newTestRegistryBuilder(). + RegisterTypeEncoder(ft1, fc1EncFac). + RegisterTypeEncoder(ft2, fc2EncFac). + RegisterTypeEncoder(ti1, fc1EncFac). + RegisterKindEncoder(reflect.Struct, fscEncFac). + RegisterKindEncoder(reflect.Slice, fslccEncFac). + RegisterKindEncoder(reflect.Array, fslccEncFac). + RegisterKindEncoder(reflect.Map, fmcEncFac). + RegisterKindEncoder(reflect.Ptr, pcEncFac). + RegisterTypeDecoder(ft1, fc1). + RegisterTypeDecoder(ft2, fc2). + RegisterTypeDecoder(ti1, fc1). // values whose exact type is testInterface1 will use fc1 encoder + RegisterKindDecoder(reflect.Struct, fsc). + RegisterKindDecoder(reflect.Slice, fslcc). + RegisterKindDecoder(reflect.Array, fslcc). + RegisterKindDecoder(reflect.Map, fmc). + RegisterKindDecoder(reflect.Ptr, pc). + RegisterInterfaceEncoder(ti2, fc2EncFac). + RegisterInterfaceEncoder(ti3, fc3EncFac). + RegisterInterfaceDecoder(ti2, fc2). + RegisterInterfaceDecoder(ti3, fc3). + Build() testCases := []struct { name string @@ -854,9 +618,10 @@ func TestRegistry(t *testing.T) { }) t.Run("Type Map", func(t *testing.T) { t.Parallel() - reg := newTestRegistry() - reg.RegisterTypeMapEntry(TypeString, reflect.TypeOf("")) - reg.RegisterTypeMapEntry(TypeInt32, reflect.TypeOf(int(0))) + reg := newTestRegistryBuilder(). + RegisterTypeMapEntry(TypeString, reflect.TypeOf("")). + RegisterTypeMapEntry(TypeInt32, reflect.TypeOf(int(0))). + Build() var got, want reflect.Type @@ -886,12 +651,6 @@ func TestRegistry(t *testing.T) { }) } -// get is only for testing as it does return if the value was found -func (c *kindEncoderCache) get(rt reflect.Kind) ValueEncoder { - e, _ := c.Load(rt) - return e -} - func BenchmarkLookupEncoder(b *testing.B) { type childStruct struct { V1, V2, V3, V4 int @@ -908,10 +667,11 @@ func BenchmarkLookupEncoder(b *testing.B) { reflect.TypeOf(&testInterface1Impl{}), reflect.TypeOf(&nestedStruct{}), } - r := NewRegistry() + rb := NewRegistryBuilder() for _, typ := range types { - r.RegisterTypeEncoder(typ, &fakeCodec{}) + rb.RegisterTypeEncoder(typ, func() ValueEncoder { return &fakeCodec{} }) } + r := rb.Build() b.Run("Serial", func(b *testing.B) { for i := 0; i < b.N; i++ { _, err := r.LookupEncoder(types[i%len(types)]) @@ -934,6 +694,7 @@ func BenchmarkLookupEncoder(b *testing.B) { type fakeType1 struct{} type fakeType2 struct{} +type fakeType3 struct{} type fakeType4 struct{} type fakeType5 func(string, string) string type fakeStructCodec struct{ *fakeCodec } @@ -948,7 +709,7 @@ type fakeCodec struct { num int } -func (*fakeCodec) EncodeValue(*Registry, ValueWriter, reflect.Value) error { +func (*fakeCodec) EncodeValue(EncoderRegistry, ValueWriter, reflect.Value) error { return nil } func (*fakeCodec) DecodeValue(DecodeContext, ValueReader, reflect.Value) error { @@ -977,5 +738,3 @@ type testInterface3Impl struct{} var _ testInterface3 = (*testInterface3Impl)(nil) func (*testInterface3Impl) test3() {} - -func typeComparer(i1, i2 reflect.Type) bool { return i1 == i2 } diff --git a/bson/setter_getter.go b/bson/setter_getter.go index 069408c9ab..46706241be 100644 --- a/bson/setter_getter.go +++ b/bson/setter_getter.go @@ -84,7 +84,7 @@ func SetterDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error } // GetterEncodeValue is the ValueEncoderFunc for Getter types. -func GetterEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func GetterEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Getter switch { case !val.IsValid(): diff --git a/bson/slice_codec.go b/bson/slice_codec.go index d7db3cf9da..3640cdd124 100644 --- a/bson/slice_codec.go +++ b/bson/slice_codec.go @@ -24,7 +24,7 @@ type sliceCodec struct { } // EncodeValue is the ValueEncoder for slice types. -func (sc sliceCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func (sc sliceCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Slice { return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } diff --git a/bson/string_codec.go b/bson/string_codec.go index de73fc6f0d..9f1ee76136 100644 --- a/bson/string_codec.go +++ b/bson/string_codec.go @@ -24,7 +24,7 @@ var ( ) // EncodeValue is the ValueEncoder for string types. -func (sc *stringCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func (sc *stringCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.String { return ValueEncoderError{ Name: "StringEncodeValue", diff --git a/bson/struct_codec.go b/bson/struct_codec.go index b489b1dc56..c3ddd4f2c6 100644 --- a/bson/struct_codec.go +++ b/bson/struct_codec.go @@ -88,12 +88,12 @@ func newStructCodec(p StructTagParser) *structCodec { } // EncodeValue handles encoding generic struct types. -func (sc *structCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func (sc *structCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Struct { return ValueEncoderError{Name: "structCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} } - sd, err := sc.describeStruct(reg, val.Type(), sc.useJSONStructTags, !sc.overwriteDuplicatedInlinedFields) + sd, err := sc.describeStruct(val.Type(), sc.useJSONStructTags, !sc.overwriteDuplicatedInlinedFields) if err != nil { return err } @@ -113,7 +113,12 @@ func (sc *structCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Va } } - desc.encoder, rv, err = lookupElementEncoder(reg, desc.encoder, rv) + var encoder ValueEncoder + if encoder, err = reg.LookupEncoder(desc.fieldType); err != nil { + encoder = nil + } + + encoder, rv, err = lookupElementEncoder(reg, encoder, rv) if err != nil && !errors.Is(err, errInvalidValue) { return err @@ -134,12 +139,10 @@ func (sc *structCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Va continue } - if desc.encoder == nil { + if encoder == nil { return ErrNoEncoder{Type: rv.Type()} } - encoder := desc.encoder - var empty bool if cz, ok := encoder.(CodecZeroer); ok { empty = cz.IsTypeZero(rv.Interface()) @@ -160,12 +163,7 @@ func (sc *structCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Va } // defaultUIntCodec.encodeToMinSize = desc.minSize - switch v := encoder.(type) { - case *uintCodec: - encoder = &uintCodec{ - encodeToMinSize: v.encodeToMinSize || desc.minSize, - } - case *intCodec: + if v, ok := encoder.(*intCodec); ok { encoder = &intCodec{ encodeToMinSize: v.encodeToMinSize || desc.minSize, } @@ -231,7 +229,7 @@ func (sc *structCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type()) } - sd, err := sc.describeStruct(dc.Registry, val.Type(), dc.useJSONStructTags, false) + sd, err := sc.describeStruct(val.Type(), dc.useJSONStructTags, false) if err != nil { return err } @@ -330,11 +328,12 @@ func (sc *structCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect zeroStructs: dc.zeroStructs, } - if fd.decoder == nil { + decoder, err := dc.Registry.LookupDecoder(fd.fieldType) + if err != nil { return newDecodeError(fd.name, ErrNoDecoder{Type: field.Elem().Type()}) } - err = fd.decoder.DecodeValue(dctx, vr, field.Elem()) + err = decoder.DecodeValue(dctx, vr, field.Elem()) if err != nil { return newDecodeError(fd.name, err) } @@ -389,8 +388,7 @@ type fieldDescription struct { minSize bool truncate bool inline []int - encoder ValueEncoder - decoder ValueDecoder + fieldType reflect.Type } type byIndex []fieldDescription @@ -423,7 +421,6 @@ func (bi byIndex) Less(i, j int) bool { } func (sc *structCodec) describeStruct( - r *Registry, t reflect.Type, useJSONStructTags bool, errorOnDuplicates bool, @@ -435,7 +432,7 @@ func (sc *structCodec) describeStruct( } // TODO(charlie): Only describe the struct once when called // concurrently with the same type. - ds, err := sc.describeStructSlow(r, t, useJSONStructTags, errorOnDuplicates) + ds, err := sc.describeStructSlow(t, useJSONStructTags, errorOnDuplicates) if err != nil { return nil, err } @@ -446,7 +443,6 @@ func (sc *structCodec) describeStruct( } func (sc *structCodec) describeStructSlow( - r *Registry, t reflect.Type, useJSONStructTags bool, errorOnDuplicates bool, @@ -467,23 +463,15 @@ func (sc *structCodec) describeStructSlow( } sfType := sf.Type - encoder, err := r.LookupEncoder(sfType) - if err != nil { - encoder = nil - } - decoder, err := r.LookupDecoder(sfType) - if err != nil { - decoder = nil - } description := fieldDescription{ fieldName: sf.Name, idx: i, - encoder: encoder, - decoder: decoder, + fieldType: sfType, } var stags StructTags + var err error // If the caller requested that we use JSON struct tags, use the JSONFallbackStructTagParser // instead of the parser defined on the codec. if useJSONStructTags { @@ -520,7 +508,7 @@ func (sc *structCodec) describeStructSlow( } fallthrough case reflect.Struct: - inlinesf, err := sc.describeStruct(r, sfType, useJSONStructTags, errorOnDuplicates) + inlinesf, err := sc.describeStruct(sfType, useJSONStructTags, errorOnDuplicates) if err != nil { return nil, err } diff --git a/bson/time_codec.go b/bson/time_codec.go index 535861ed71..d9bb57404b 100644 --- a/bson/time_codec.go +++ b/bson/time_codec.go @@ -99,7 +99,7 @@ func (tc *timeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.V } // EncodeValue is the ValueEncoderFunc for time.TIme. -func (tc *timeCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func (tc *timeCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tTime { return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val} } diff --git a/bson/unmarshal_test.go b/bson/unmarshal_test.go index 0871237386..d8ef9b69ba 100644 --- a/bson/unmarshal_test.go +++ b/bson/unmarshal_test.go @@ -228,8 +228,9 @@ func TestCachingDecodersNotSharedAcrossRegistries(t *testing.T) { val.SetInt(int64(-1 * i32)) return nil } - customReg := NewRegistry() - customReg.RegisterTypeDecoder(tInt32, decodeInt32) + customReg := NewRegistryBuilder(). + RegisterTypeDecoder(tInt32, decodeInt32). + Build() docBytes := bsoncore.BuildDocumentFromElements( nil, diff --git a/bson/unmarshal_value_test.go b/bson/unmarshal_value_test.go index fd379b5daa..3af7578d12 100644 --- a/bson/unmarshal_value_test.go +++ b/bson/unmarshal_value_test.go @@ -75,8 +75,9 @@ func TestUnmarshalValue(t *testing.T) { bytes: bsoncore.AppendString(nil, "hello world"), }, } - reg := NewRegistry() - reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), &sliceCodec{}) + reg := NewRegistryBuilder(). + RegisterTypeDecoder(reflect.TypeOf([]byte{}), &sliceCodec{}). + Build() for _, tc := range testCases { tc := tc @@ -110,8 +111,9 @@ func BenchmarkSliceCodecUnmarshal(b *testing.B) { bytes: bsoncore.AppendString(nil, strings.Repeat("t", 4096)), }, } - reg := NewRegistry() - reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), &sliceCodec{}) + reg := NewRegistryBuilder(). + RegisterTypeDecoder(reflect.TypeOf([]byte{}), &sliceCodec{}). + Build() for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { b.RunParallel(func(pb *testing.PB) { diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index a098880b4c..d3ebb1421b 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -39,7 +39,7 @@ type negateCodec struct { ID int64 `bson:"_id"` } -func (e *negateCodec) EncodeValue(_ *bson.Registry, vw bson.ValueWriter, val reflect.Value) error { +func (e *negateCodec) EncodeValue(_ bson.EncoderRegistry, vw bson.ValueWriter, val reflect.Value) error { return vw.WriteInt64(val.Int()) } @@ -100,9 +100,10 @@ func (sc *slowConn) Read(b []byte) (n int, err error) { func TestClient(t *testing.T) { mt := mtest.New(t, noClientOpts) - reg := bson.NewRegistry() - reg.RegisterTypeEncoder(reflect.TypeOf(int64(0)), &negateCodec{}) - reg.RegisterTypeDecoder(reflect.TypeOf(int64(0)), &negateCodec{}) + reg := bson.NewRegistryBuilder(). + RegisterTypeEncoder(reflect.TypeOf(int64(0)), func() bson.ValueEncoder { return &negateCodec{} }). + RegisterTypeDecoder(reflect.TypeOf(int64(0)), &negateCodec{}). + Build() registryOpts := options.Client(). SetRegistry(reg) mt.RunOpts("registry passed to cursors", mtest.NewOptions().ClientOptions(registryOpts), func(mt *mtest.T) { diff --git a/internal/integration/crud_spec_test.go b/internal/integration/crud_spec_test.go index e6583f8ade..996cdd27f4 100644 --- a/internal/integration/crud_spec_test.go +++ b/internal/integration/crud_spec_test.go @@ -55,11 +55,9 @@ type crudOutcome struct { Collection *outcomeCollection `bson:"collection"` } -var crudRegistry = func() *bson.Registry { - reg := bson.NewRegistry() - reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) - return reg -}() +var crudRegistry = bson.NewRegistryBuilder(). + RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})). + Build() func TestCrudSpec(t *testing.T) { for _, dir := range []string{crudReadDir, crudWriteDir} { diff --git a/internal/integration/database_test.go b/internal/integration/database_test.go index 12c2e0cd53..da043a6636 100644 --- a/internal/integration/database_test.go +++ b/internal/integration/database_test.go @@ -29,11 +29,9 @@ const ( ) var ( - interfaceAsMapRegistry = func() *bson.Registry { - reg := bson.NewRegistry() - reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.M{})) - return reg - }() + interfaceAsMapRegistry = bson.NewRegistryBuilder(). + RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.M{})). + Build() ) func TestDatabase(t *testing.T) { diff --git a/internal/integration/unified_spec_test.go b/internal/integration/unified_spec_test.go index c9199f6135..487714d834 100644 --- a/internal/integration/unified_spec_test.go +++ b/internal/integration/unified_spec_test.go @@ -181,12 +181,10 @@ var directories = []string{ } var checkOutcomeOpts = options.Collection().SetReadPreference(readpref.Primary()).SetReadConcern(readconcern.Local()) -var specTestRegistry = func() *bson.Registry { - reg := bson.NewRegistry() - reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) - reg.RegisterTypeDecoder(reflect.TypeOf(testData{}), bson.ValueDecoderFunc(decodeTestData)) - return reg -}() +var specTestRegistry = bson.NewRegistryBuilder(). + RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})). + RegisterTypeDecoder(reflect.TypeOf(testData{}), bson.ValueDecoderFunc(decodeTestData)). + Build() func TestUnifiedSpecs(t *testing.T) { for _, specDir := range directories { diff --git a/mongo/database_test.go b/mongo/database_test.go index 31bd900439..1142b6df9c 100644 --- a/mongo/database_test.go +++ b/mongo/database_test.go @@ -53,7 +53,7 @@ func TestDatabase(t *testing.T) { wc2 := &writeconcern.WriteConcern{W: 10} rcLocal := readconcern.Local() rcMajority := readconcern.Majority() - reg := bson.NewRegistry() + reg := bson.NewRegistryBuilder().Build() opts := options.Database().SetReadPreference(rpPrimary).SetReadConcern(rcLocal).SetWriteConcern(wc1). SetReadPreference(rpSecondary).SetReadConcern(rcMajority).SetWriteConcern(wc2).SetRegistry(reg) @@ -70,7 +70,7 @@ func TestDatabase(t *testing.T) { rpPrimary := readpref.Primary() rcLocal := readconcern.Local() wc1 := &writeconcern.WriteConcern{W: 10} - reg := bson.NewRegistry() + reg := bson.NewRegistryBuilder().Build() client := setupClient(options.Client().SetReadPreference(rpPrimary).SetReadConcern(rcLocal).SetRegistry(reg)) got := client.Database("foo", options.Database().SetWriteConcern(wc1)) diff --git a/mongo/options/clientoptions_test.go b/mongo/options/clientoptions_test.go index beba45514f..078c029308 100644 --- a/mongo/options/clientoptions_test.go +++ b/mongo/options/clientoptions_test.go @@ -80,7 +80,7 @@ func TestClientOptions(t *testing.T) { {"Monitor", (*ClientOptions).SetMonitor, &event.CommandMonitor{}, "Monitor", false}, {"ReadConcern", (*ClientOptions).SetReadConcern, readconcern.Majority(), "ReadConcern", false}, {"ReadPreference", (*ClientOptions).SetReadPreference, readpref.SecondaryPreferred(), "ReadPreference", false}, - {"Registry", (*ClientOptions).SetRegistry, bson.NewRegistry(), "Registry", false}, + {"Registry", (*ClientOptions).SetRegistry, bson.NewRegistryBuilder().Build(), "Registry", false}, {"ReplicaSet", (*ClientOptions).SetReplicaSet, "example-replicaset", "ReplicaSet", true}, {"RetryWrites", (*ClientOptions).SetRetryWrites, true, "RetryWrites", true}, {"ServerSelectionTimeout", (*ClientOptions).SetServerSelectionTimeout, 5 * time.Second, "ServerSelectionTimeout", true}, diff --git a/mongo/read_write_concern_spec_test.go b/mongo/read_write_concern_spec_test.go index ec49bb91db..c737f76a9b 100644 --- a/mongo/read_write_concern_spec_test.go +++ b/mongo/read_write_concern_spec_test.go @@ -31,11 +31,9 @@ const ( var ( serverDefaultConcern = []byte{5, 0, 0, 0, 0} // server default read concern and write concern is empty document - specTestRegistry = func() *bson.Registry { - reg := bson.NewRegistry() - reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) - return reg - }() + specTestRegistry = bson.NewRegistryBuilder(). + RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})). + Build() ) type connectionStringTestFile struct { diff --git a/x/mongo/driver/topology/server_options.go b/x/mongo/driver/topology/server_options.go index c02600e232..dca9c0581b 100644 --- a/x/mongo/driver/topology/server_options.go +++ b/x/mongo/driver/topology/server_options.go @@ -17,7 +17,7 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) -var defaultRegistry = bson.NewRegistry() +var defaultRegistry = bson.NewRegistryBuilder().Build() type serverConfig struct { clock *session.ClusterClock From 51517bcd3f909b92f0d87bead310e4c753d97ebf Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Tue, 14 May 2024 17:01:18 -0400 Subject: [PATCH 06/15] WIP --- bson/bsoncodec.go | 13 ++++++------- bson/cond_addr_codec.go | 8 ++------ bson/cond_addr_codec_test.go | 4 ++-- bson/empty_interface_codec.go | 2 +- bson/pointer_codec.go | 17 +++++++++++------ bson/registry.go | 25 +++++++++++-------------- bson/registry_test.go | 2 +- 7 files changed, 34 insertions(+), 37 deletions(-) diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index db176ad906..e0369a1111 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -126,6 +126,12 @@ func (fn ValueEncoderFunc) EncodeValue(reg EncoderRegistry, vw ValueWriter, val return fn(reg, vw, val) } +// DecoderRegistry is an interface provides a ValueDecoder based on the given reflect.Type. +type DecoderRegistry interface { + LookupDecoder(reflect.Type) (ValueDecoder, error) + LookupTypeMapEntry(Type) (reflect.Type, error) +} + // ValueDecoder is the interface implemented by types that can decode BSON to a provided Go type. // Implementations should ensure that the value they receive is settable. Similar to ValueEncoderFunc, // ValueDecoderFunc is provided to allow the use of a function with the correct signature as a @@ -165,13 +171,6 @@ type decodeAdapter struct { var _ ValueDecoder = decodeAdapter{} var _ typeDecoder = decodeAdapter{} -// decodeTypeOrValue calls decoder.decodeType is decoder is a typeDecoder. Otherwise, it allocates a new element of type -// t and calls decoder.DecodeValue on it. -func decodeTypeOrValue(decoder ValueDecoder, dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { - td, _ := decoder.(typeDecoder) - return decodeTypeOrValueWithInfo(decoder, td, dc, vr, t, true) -} - func decodeTypeOrValueWithInfo(vd ValueDecoder, td typeDecoder, dc DecodeContext, vr ValueReader, t reflect.Type, convert bool) (reflect.Value, error) { if td != nil { val, err := td.decodeType(dc, vr, t) diff --git a/bson/cond_addr_codec.go b/bson/cond_addr_codec.go index ef87da6250..897d311a79 100644 --- a/bson/cond_addr_codec.go +++ b/bson/cond_addr_codec.go @@ -16,6 +16,8 @@ type condAddrEncoder struct { elseEnc ValueEncoder } +var _ ValueEncoder = (*condAddrEncoder)(nil) + // EncodeValue is the ValueEncoderFunc for a value that may be addressable. func (cae *condAddrEncoder) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if val.CanAddr() { @@ -35,12 +37,6 @@ type condAddrDecoder struct { var _ ValueDecoder = (*condAddrDecoder)(nil) -// newCondAddrDecoder returns an CondAddrDecoder. -func newCondAddrDecoder(canAddrDec, elseDec ValueDecoder) *condAddrDecoder { - decoder := condAddrDecoder{canAddrDec: canAddrDec, elseDec: elseDec} - return &decoder -} - // DecodeValue is the ValueDecoderFunc for a value that may be addressable. func (cad *condAddrDecoder) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if val.CanAddr() { diff --git a/bson/cond_addr_codec_test.go b/bson/cond_addr_codec_test.go index ee4f61f3a7..b0bf63c1cd 100644 --- a/bson/cond_addr_codec_test.go +++ b/bson/cond_addr_codec_test.go @@ -66,7 +66,7 @@ func TestCondAddrCodec(t *testing.T) { invoked = 2 return nil }) - condDecoder := newCondAddrDecoder(decode1, decode2) + condDecoder := &condAddrDecoder{canAddrDec: decode1, elseDec: decode2} testCases := []struct { name string @@ -86,7 +86,7 @@ func TestCondAddrCodec(t *testing.T) { } t.Run("error", func(t *testing.T) { - errDecoder := newCondAddrDecoder(decode1, nil) + errDecoder := &condAddrDecoder{canAddrDec: decode1, elseDec: nil} err := errDecoder.DecodeValue(DecodeContext{}, rw, unaddressable) want := ErrNoDecoder{Type: unaddressable.Type()} assert.Equal(t, err, want, "expected error %v, got %v", want, err) diff --git a/bson/empty_interface_codec.go b/bson/empty_interface_codec.go index cea7dfd348..59754ec6b5 100644 --- a/bson/empty_interface_codec.go +++ b/bson/empty_interface_codec.go @@ -99,7 +99,7 @@ func (eic emptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t re return emptyValue, err } - elem, err := decodeTypeOrValue(decoder, dc, vr, rtype) + elem, err := decodeTypeOrValueWithInfo(decoder, decoder.(typeDecoder), dc, vr, rtype, true) if err != nil { return emptyValue, err } diff --git a/bson/pointer_codec.go b/bson/pointer_codec.go index af35da68b2..d0aec9c7d5 100644 --- a/bson/pointer_codec.go +++ b/bson/pointer_codec.go @@ -8,12 +8,13 @@ package bson import ( "reflect" + "sync" ) // pointerCodec is the Codec used for pointers. type pointerCodec struct { - ecache typeEncoderCache - dcache typeDecoderCache + ecache sync.Map // map[reflect.Type]ValueEncoder + dcache sync.Map // map[reflect.Type]ValueDecoder } // EncodeValue handles encoding a pointer by either encoding it to BSON Null if the pointer is nil @@ -35,14 +36,16 @@ func (pc *pointerCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val ref if v == nil { return ErrNoEncoder{Type: typ} } - return v.EncodeValue(reg, vw, val.Elem()) + return v.(ValueEncoder).EncodeValue(reg, vw, val.Elem()) } // TODO(charlie): handle concurrent requests for the same type enc, err := reg.LookupEncoder(typ.Elem()) - enc = pc.ecache.LoadOrStore(typ, enc) if err != nil { return err } + if v, ok := pc.ecache.LoadOrStore(typ, enc); ok { + enc = v.(ValueEncoder) + } return enc.EncodeValue(reg, vw, val.Elem()) } @@ -71,13 +74,15 @@ func (pc *pointerCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflec if v == nil { return ErrNoDecoder{Type: typ} } - return v.DecodeValue(dc, vr, val.Elem()) + return v.(ValueDecoder).DecodeValue(dc, vr, val.Elem()) } // TODO(charlie): handle concurrent requests for the same type dec, err := dc.LookupDecoder(typ.Elem()) - dec = pc.dcache.LoadOrStore(typ, dec) if err != nil { return err } + if v, ok := pc.dcache.LoadOrStore(typ, dec); ok { + dec = v.(ValueDecoder) + } return dec.DecodeValue(dc, vr, val.Elem()) } diff --git a/bson/registry.go b/bson/registry.go index 179b61de91..8f1cc421b9 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -7,7 +7,6 @@ package bson import ( - "errors" "fmt" "reflect" "sync" @@ -17,11 +16,6 @@ import ( // primitive codecs. var DefaultRegistry = NewRegistryBuilder().Build() -// ErrNilType is returned when nil is passed to either LookupEncoder or LookupDecoder. -// -// Deprecated: ErrNilType will not be supported in Go Driver 2.0. -var ErrNilType = errors.New("cannot perform a decoder lookup on ") - // ErrNoEncoder is returned when there wasn't an encoder available for a type. // // Deprecated: ErrNoEncoder will not be supported in Go Driver 2.0. @@ -58,6 +52,12 @@ func (entme ErrNoTypeMapEntry) Error() string { return "no type map entry found for " + entme.Type.String() } +// EncoderFactory is a factory function that generates a new ValueEncoder. +type EncoderFactory func() ValueEncoder + +// DecoderFactory is a factory function that generates a new ValueDecoder. +type DecoderFactory func() ValueDecoder + // A RegistryBuilder is used to build a Registry. This type is not goroutine // safe. type RegistryBuilder struct { @@ -85,12 +85,6 @@ func NewRegistryBuilder() *RegistryBuilder { return rb } -// EncoderFactory is a factory function that generates a new ValueEncoder. -type EncoderFactory func() ValueEncoder - -// DecoderFactory is a factory function that generates a new ValueDecoder. -type DecoderFactory func() ValueDecoder - // RegisterTypeEncoder registers a ValueEncoder factory for the provided type. // // The type will be used as provided, so an encoder factory can be registered for a type and a @@ -412,7 +406,7 @@ func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool // concurrent use by multiple goroutines. func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) { if valueType == nil { - return nil, ErrNilType + return nil, ErrNoDecoder{Type: valueType} } dec, found := r.typeDecoders.Load(valueType) if found { @@ -434,6 +428,9 @@ func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) { } func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool) (ValueDecoder, bool) { + if valueType == nil { + return nil, false + } for _, idec := range r.interfaceDecoders { if valueType.Implements(idec.i) { return idec.vd, true @@ -445,7 +442,7 @@ func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool if !found { defaultDec, _ = r.kindDecoders.Load(valueType.Kind()) } - return newCondAddrDecoder(idec.vd, defaultDec), true + return &condAddrDecoder{canAddrDec: idec.vd, elseDec: defaultDec}, true } } return nil, false diff --git a/bson/registry_test.go b/bson/registry_test.go index 003bb69d6b..f60542b74d 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -568,7 +568,7 @@ func TestRegistryBuilder(t *testing.T) { t.Run("Decoder", func(t *testing.T) { t.Parallel() - wanterr := ErrNilType + wanterr := ErrNoDecoder{Type: nil} gotcodec, goterr := reg.LookupDecoder(nil) if !cmp.Equal(goterr, wanterr, cmp.Comparer(assert.CompareErrors)) { From 7938e909ac39d1ccb61b1388d195a891dd384800 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Tue, 14 May 2024 18:05:36 -0400 Subject: [PATCH 07/15] WIP --- bson/bsoncodec.go | 4 ++-- bson/default_value_decoders.go | 6 ++---- bson/empty_interface_codec.go | 2 +- bson/map_codec.go | 3 +-- 4 files changed, 6 insertions(+), 9 deletions(-) diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index e0369a1111..08e3d4b862 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -171,8 +171,8 @@ type decodeAdapter struct { var _ ValueDecoder = decodeAdapter{} var _ typeDecoder = decodeAdapter{} -func decodeTypeOrValueWithInfo(vd ValueDecoder, td typeDecoder, dc DecodeContext, vr ValueReader, t reflect.Type, convert bool) (reflect.Value, error) { - if td != nil { +func decodeTypeOrValueWithInfo(vd ValueDecoder, dc DecodeContext, vr ValueReader, t reflect.Type, convert bool) (reflect.Value, error) { + if td := vd.(typeDecoder); td != nil { val, err := td.decodeType(dc, vr, t) if err == nil && convert && val.Type() != t { // This conversion step is necessary for slices and maps. If a user declares variables like: diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index 56331da9a8..50f7fc7cb9 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -135,7 +135,6 @@ func dDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { if err != nil { return err } - tEmptyTypeDecoder, _ := decoder.(typeDecoder) // Use the elements in the provided value if it's non nil. Otherwise, allocate a new D instance. var elems D @@ -155,7 +154,7 @@ func dDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { } // Pass false for convert because we don't need to call reflect.Value.Convert for tEmpty. - elem, err := decodeTypeOrValueWithInfo(decoder, tEmptyTypeDecoder, dc, elemVr, tEmpty, false) + elem, err := decodeTypeOrValueWithInfo(decoder, dc, elemVr, tEmpty, false) if err != nil { return err } @@ -1274,7 +1273,6 @@ func decodeDefault(dc DecodeContext, vr ValueReader, val reflect.Value) ([]refle if err != nil { return nil, err } - eTypeDecoder, _ := decoder.(typeDecoder) idx := 0 for { @@ -1286,7 +1284,7 @@ func decodeDefault(dc DecodeContext, vr ValueReader, val reflect.Value) ([]refle return nil, err } - elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true) + elem, err := decodeTypeOrValueWithInfo(decoder, dc, vr, eType, true) if err != nil { return nil, newDecodeError(strconv.Itoa(idx), err) } diff --git a/bson/empty_interface_codec.go b/bson/empty_interface_codec.go index 59754ec6b5..b8314a873f 100644 --- a/bson/empty_interface_codec.go +++ b/bson/empty_interface_codec.go @@ -99,7 +99,7 @@ func (eic emptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t re return emptyValue, err } - elem, err := decodeTypeOrValueWithInfo(decoder, decoder.(typeDecoder), dc, vr, rtype, true) + elem, err := decodeTypeOrValueWithInfo(decoder, dc, vr, rtype, true) if err != nil { return emptyValue, err } diff --git a/bson/map_codec.go b/bson/map_codec.go index bfa77ca0d8..5f6963ed26 100644 --- a/bson/map_codec.go +++ b/bson/map_codec.go @@ -160,7 +160,6 @@ func (mc *mapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Va if err != nil { return err } - eTypeDecoder, _ := decoder.(typeDecoder) if eType == tEmpty { dc.ancestor = val.Type() @@ -182,7 +181,7 @@ func (mc *mapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Va return err } - elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true) + elem, err := decodeTypeOrValueWithInfo(decoder, dc, vr, eType, true) if err != nil { return newDecodeError(key, err) } From b2a1c150928471d7b07d7b7fc366cc998150d4a2 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Mon, 20 May 2024 11:44:27 -0400 Subject: [PATCH 08/15] WIP --- bson/array_codec.go | 6 +- bson/bsoncodec.go | 42 +- bson/bsoncodec_test.go | 4 +- bson/byte_slice_codec.go | 10 +- bson/cond_addr_codec.go | 6 +- bson/cond_addr_codec_test.go | 4 +- bson/default_value_decoders.go | 461 +++++++--------------- bson/default_value_decoders_test.go | 16 +- bson/default_value_encoders.go | 15 +- bson/default_value_encoders_test.go | 24 +- bson/empty_interface_codec.go | 45 +-- bson/float_codec.go | 107 +++++ bson/int_codec.go | 91 +++-- bson/map_codec.go | 17 +- bson/mgoregistry.go | 12 +- bson/pointer_codec.go | 8 +- bson/primitive_codecs.go | 8 +- bson/registry.go | 153 ++++--- bson/registry_examples_test.go | 8 +- bson/registry_test.go | 34 +- bson/setter_getter.go | 2 +- bson/slice_codec.go | 11 +- bson/string_codec.go | 12 +- bson/struct_codec.go | 30 +- bson/time_codec.go | 12 +- bson/uint_codec.go | 157 -------- bson/unmarshal_test.go | 4 +- bson/unmarshal_value_test.go | 4 +- internal/integration/client_test.go | 4 +- internal/integration/unified_spec_test.go | 12 +- 30 files changed, 521 insertions(+), 798 deletions(-) create mode 100644 bson/float_codec.go delete mode 100644 bson/uint_codec.go diff --git a/bson/array_codec.go b/bson/array_codec.go index 757fd60004..76b9a059f7 100644 --- a/bson/array_codec.go +++ b/bson/array_codec.go @@ -15,10 +15,6 @@ import ( // arrayCodec is the Codec used for bsoncore.Array values. type arrayCodec struct{} -var ( - defaultArrayCodec = &arrayCodec{} -) - // EncodeValue is the ValueEncoder for bsoncore.Array values. func (ac *arrayCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCoreArray { @@ -30,7 +26,7 @@ func (ac *arrayCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect } // DecodeValue is the ValueDecoder for bsoncore.Array values. -func (ac *arrayCodec) DecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +func (ac *arrayCodec) DecodeValue(_ DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tCoreArray { return ValueDecoderError{Name: "CoreArrayDecodeValue", Types: []reflect.Type{tCoreArray}, Received: val} } diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index 08e3d4b862..6d70a99ca1 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -77,12 +77,6 @@ func (vde ValueDecoderError) Error() string { type DecodeContext struct { *Registry - // ancestor is the type of a containing document. This is mainly used to determine what type - // should be used when decoding an embedded document into an empty interface. For example, if - // Ancestor is a bson.M, BSON embedded document values being decoded into an empty interface - // will be decoded into a bson.M. - ancestor reflect.Type - // defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the // usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If DocumentType is // set to a type that a BSON document cannot be unmarshaled into (e.g. "string"), unmarshalling will result in an @@ -138,28 +132,28 @@ type DecoderRegistry interface { // ValueDecoder. A DecodeContext instance is provided and serves similar functionality to the // EncodeContext. type ValueDecoder interface { - DecodeValue(DecodeContext, ValueReader, reflect.Value) error + DecodeValue(DecoderRegistry, ValueReader, reflect.Value) error } // ValueDecoderFunc is an adapter function that allows a function with the correct signature to be // used as a ValueDecoder. -type ValueDecoderFunc func(DecodeContext, ValueReader, reflect.Value) error +type ValueDecoderFunc func(DecoderRegistry, ValueReader, reflect.Value) error // DecodeValue implements the ValueDecoder interface. -func (fn ValueDecoderFunc) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - return fn(dc, vr, val) +func (fn ValueDecoderFunc) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { + return fn(reg, vr, val) } // typeDecoder is the interface implemented by types that can handle the decoding of a value given its type. type typeDecoder interface { - decodeType(DecodeContext, ValueReader, reflect.Type) (reflect.Value, error) + decodeType(DecoderRegistry, ValueReader, reflect.Type) (reflect.Value, error) } // typeDecoderFunc is an adapter function that allows a function with the correct signature to be used as a typeDecoder. -type typeDecoderFunc func(DecodeContext, ValueReader, reflect.Type) (reflect.Value, error) +type typeDecoderFunc func(DecoderRegistry, ValueReader, reflect.Type) (reflect.Value, error) -func (fn typeDecoderFunc) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { - return fn(dc, vr, t) +func (fn typeDecoderFunc) decodeType(reg DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { + return fn(reg, vr, t) } // decodeAdapter allows two functions with the correct signatures to be used as both a ValueDecoder and typeDecoder. @@ -171,24 +165,14 @@ type decodeAdapter struct { var _ ValueDecoder = decodeAdapter{} var _ typeDecoder = decodeAdapter{} -func decodeTypeOrValueWithInfo(vd ValueDecoder, dc DecodeContext, vr ValueReader, t reflect.Type, convert bool) (reflect.Value, error) { - if td := vd.(typeDecoder); td != nil { - val, err := td.decodeType(dc, vr, t) - if err == nil && convert && val.Type() != t { - // This conversion step is necessary for slices and maps. If a user declares variables like: - // - // type myBool bool - // var m map[string]myBool - // - // and tries to decode BSON bytes into the map, the decoding will fail if this conversion is not present - // because we'll try to assign a value of type bool to one of type myBool. - val = val.Convert(t) - } - return val, err +func decodeTypeOrValueWithInfo(vd ValueDecoder, reg DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { + td, ok := vd.(typeDecoder) + if ok && td != nil { + return td.decodeType(reg, vr, t) } val := reflect.New(t).Elem() - err := vd.DecodeValue(dc, vr, val) + err := vd.DecodeValue(reg, vr, val) return val, err } diff --git a/bson/bsoncodec_test.go b/bson/bsoncodec_test.go index e4ba05d5e1..02db6ca003 100644 --- a/bson/bsoncodec_test.go +++ b/bson/bsoncodec_test.go @@ -18,7 +18,7 @@ type llCodec struct { err error } -func (llc *llCodec) EncodeValue(_ *Registry, _ ValueWriter, i interface{}) error { +func (llc *llCodec) EncodeValue(_ EncoderRegistry, _ ValueWriter, i interface{}) error { if llc.err != nil { return llc.err } @@ -27,7 +27,7 @@ func (llc *llCodec) EncodeValue(_ *Registry, _ ValueWriter, i interface{}) error return nil } -func (llc *llCodec) DecodeValue(_ DecodeContext, _ ValueReader, val reflect.Value) error { +func (llc *llCodec) DecodeValue(_ DecoderRegistry, _ ValueReader, val reflect.Value) error { if llc.err != nil { return llc.err } diff --git a/bson/byte_slice_codec.go b/bson/byte_slice_codec.go index e012c3d913..779ae9ed71 100644 --- a/bson/byte_slice_codec.go +++ b/bson/byte_slice_codec.go @@ -18,10 +18,6 @@ type byteSliceCodec struct { encodeNilAsEmpty bool } -var ( - defaultByteSliceCodec = &byteSliceCodec{} -) - // EncodeValue is the ValueEncoder for []byte. func (bsc *byteSliceCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tByteSlice { @@ -33,7 +29,7 @@ func (bsc *byteSliceCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val re return vw.WriteBinary(val.Interface().([]byte)) } -func (bsc *byteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (bsc *byteSliceCodec) decodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tByteSlice { return emptyValue, ValueDecoderError{ Name: "ByteSliceDecodeValue", @@ -81,12 +77,12 @@ func (bsc *byteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect } // DecodeValue is the ValueDecoder for []byte. -func (bsc *byteSliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (bsc *byteSliceCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tByteSlice { return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []reflect.Type{tByteSlice}, Received: val} } - elem, err := bsc.decodeType(dc, vr, tByteSlice) + elem, err := bsc.decodeType(reg, vr, tByteSlice) if err != nil { return err } diff --git a/bson/cond_addr_codec.go b/bson/cond_addr_codec.go index 897d311a79..cd2727e2cc 100644 --- a/bson/cond_addr_codec.go +++ b/bson/cond_addr_codec.go @@ -38,12 +38,12 @@ type condAddrDecoder struct { var _ ValueDecoder = (*condAddrDecoder)(nil) // DecodeValue is the ValueDecoderFunc for a value that may be addressable. -func (cad *condAddrDecoder) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (cad *condAddrDecoder) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if val.CanAddr() { - return cad.canAddrDec.DecodeValue(dc, vr, val) + return cad.canAddrDec.DecodeValue(reg, vr, val) } if cad.elseDec != nil { - return cad.elseDec.DecodeValue(dc, vr, val) + return cad.elseDec.DecodeValue(reg, vr, val) } return ErrNoDecoder{Type: val.Type()} } diff --git a/bson/cond_addr_codec_test.go b/bson/cond_addr_codec_test.go index b0bf63c1cd..15b4a8a333 100644 --- a/bson/cond_addr_codec_test.go +++ b/bson/cond_addr_codec_test.go @@ -58,11 +58,11 @@ func TestCondAddrCodec(t *testing.T) { }) t.Run("addressDecode", func(t *testing.T) { invoked := 0 - decode1 := ValueDecoderFunc(func(DecodeContext, ValueReader, reflect.Value) error { + decode1 := ValueDecoderFunc(func(DecoderRegistry, ValueReader, reflect.Value) error { invoked = 1 return nil }) - decode2 := ValueDecoderFunc(func(DecodeContext, ValueReader, reflect.Value) error { + decode2 := ValueDecoderFunc(func(DecoderRegistry, ValueReader, reflect.Value) error { invoked = 2 return nil }) diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index 50f7fc7cb9..ec8b3c8730 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -10,7 +10,6 @@ import ( "encoding/json" "errors" "fmt" - "math" "net/url" "reflect" "strconv" @@ -36,89 +35,88 @@ func (d decodeBinaryError) Error() string { // There is no support for decoding map[string]interface{} because there is no decoder for // interface{}, so users must either register this decoder themselves or use the // EmptyInterfaceDecoder available in the bson package. -func registerDefaultDecoders(reg *RegistryBuilder) { - if reg == nil { +func registerDefaultDecoders(rb *RegistryBuilder) { + if rb == nil { panic(errors.New("argument to RegisterDefaultDecoders must not be nil")) } - intDecoder := decodeAdapter{intDecodeValue, intDecodeType} - floatDecoder := decodeAdapter{floatDecodeValue, floatDecodeType} - - reg.RegisterTypeDecoder(tD, ValueDecoderFunc(dDecodeValue)) - reg.RegisterTypeDecoder(tBinary, decodeAdapter{binaryDecodeValue, binaryDecodeType}) - reg.RegisterTypeDecoder(tUndefined, decodeAdapter{undefinedDecodeValue, undefinedDecodeType}) - reg.RegisterTypeDecoder(tDateTime, decodeAdapter{dateTimeDecodeValue, dateTimeDecodeType}) - reg.RegisterTypeDecoder(tNull, decodeAdapter{nullDecodeValue, nullDecodeType}) - reg.RegisterTypeDecoder(tRegex, decodeAdapter{regexDecodeValue, regexDecodeType}) - reg.RegisterTypeDecoder(tDBPointer, decodeAdapter{dbPointerDecodeValue, dbPointerDecodeType}) - reg.RegisterTypeDecoder(tTimestamp, decodeAdapter{timestampDecodeValue, timestampDecodeType}) - reg.RegisterTypeDecoder(tMinKey, decodeAdapter{minKeyDecodeValue, minKeyDecodeType}) - reg.RegisterTypeDecoder(tMaxKey, decodeAdapter{maxKeyDecodeValue, maxKeyDecodeType}) - reg.RegisterTypeDecoder(tJavaScript, decodeAdapter{javaScriptDecodeValue, javaScriptDecodeType}) - reg.RegisterTypeDecoder(tSymbol, decodeAdapter{symbolDecodeValue, symbolDecodeType}) - reg.RegisterTypeDecoder(tByteSlice, defaultByteSliceCodec) - reg.RegisterTypeDecoder(tTime, defaultTimeCodec) - reg.RegisterTypeDecoder(tEmpty, defaultEmptyInterfaceCodec) - reg.RegisterTypeDecoder(tCoreArray, defaultArrayCodec) - reg.RegisterTypeDecoder(tOID, decodeAdapter{objectIDDecodeValue, objectIDDecodeType}) - reg.RegisterTypeDecoder(tDecimal, decodeAdapter{decimal128DecodeValue, decimal128DecodeType}) - reg.RegisterTypeDecoder(tJSONNumber, decodeAdapter{jsonNumberDecodeValue, jsonNumberDecodeType}) - reg.RegisterTypeDecoder(tURL, decodeAdapter{urlDecodeValue, urlDecodeType}) - reg.RegisterTypeDecoder(tCoreDocument, ValueDecoderFunc(coreDocumentDecodeValue)) - reg.RegisterTypeDecoder(tCodeWithScope, decodeAdapter{codeWithScopeDecodeValue, codeWithScopeDecodeType}) - reg.RegisterKindDecoder(reflect.Bool, decodeAdapter{booleanDecodeValue, booleanDecodeType}) - reg.RegisterKindDecoder(reflect.Int, intDecoder) - reg.RegisterKindDecoder(reflect.Int8, intDecoder) - reg.RegisterKindDecoder(reflect.Int16, intDecoder) - reg.RegisterKindDecoder(reflect.Int32, intDecoder) - reg.RegisterKindDecoder(reflect.Int64, intDecoder) - reg.RegisterKindDecoder(reflect.Uint, defaultUIntCodec) - reg.RegisterKindDecoder(reflect.Uint8, defaultUIntCodec) - reg.RegisterKindDecoder(reflect.Uint16, defaultUIntCodec) - reg.RegisterKindDecoder(reflect.Uint32, defaultUIntCodec) - reg.RegisterKindDecoder(reflect.Uint64, defaultUIntCodec) - reg.RegisterKindDecoder(reflect.Float32, floatDecoder) - reg.RegisterKindDecoder(reflect.Float64, floatDecoder) - reg.RegisterKindDecoder(reflect.Array, ValueDecoderFunc(arrayDecodeValue)) - reg.RegisterKindDecoder(reflect.Map, defaultMapCodec) - reg.RegisterKindDecoder(reflect.Slice, defaultSliceCodec) - reg.RegisterKindDecoder(reflect.String, defaultStringCodec) - reg.RegisterKindDecoder(reflect.Struct, defaultStructCodec) - reg.RegisterKindDecoder(reflect.Ptr, &pointerCodec{}) - reg.RegisterTypeMapEntry(TypeDouble, tFloat64) - reg.RegisterTypeMapEntry(TypeString, tString) - reg.RegisterTypeMapEntry(TypeArray, tA) - reg.RegisterTypeMapEntry(TypeBinary, tBinary) - reg.RegisterTypeMapEntry(TypeUndefined, tUndefined) - reg.RegisterTypeMapEntry(TypeObjectID, tOID) - reg.RegisterTypeMapEntry(TypeBoolean, tBool) - reg.RegisterTypeMapEntry(TypeDateTime, tDateTime) - reg.RegisterTypeMapEntry(TypeRegex, tRegex) - reg.RegisterTypeMapEntry(TypeDBPointer, tDBPointer) - reg.RegisterTypeMapEntry(TypeJavaScript, tJavaScript) - reg.RegisterTypeMapEntry(TypeSymbol, tSymbol) - reg.RegisterTypeMapEntry(TypeCodeWithScope, tCodeWithScope) - reg.RegisterTypeMapEntry(TypeInt32, tInt32) - reg.RegisterTypeMapEntry(TypeInt64, tInt64) - reg.RegisterTypeMapEntry(TypeTimestamp, tTimestamp) - reg.RegisterTypeMapEntry(TypeDecimal128, tDecimal) - reg.RegisterTypeMapEntry(TypeMinKey, tMinKey) - reg.RegisterTypeMapEntry(TypeMaxKey, tMaxKey) - reg.RegisterTypeMapEntry(Type(0), tD) - reg.RegisterTypeMapEntry(TypeEmbeddedDocument, tD) - reg.RegisterInterfaceDecoder(tValueUnmarshaler, ValueDecoderFunc(valueUnmarshalerDecodeValue)) - reg.RegisterInterfaceDecoder(tUnmarshaler, ValueDecoderFunc(unmarshalerDecodeValue)) + intDecoder := func() ValueDecoder { return &intCodec{} } + floatDecoder := func() ValueDecoder { return &floatCodec{} } + rb.RegisterTypeDecoder(tD, func() ValueDecoder { return ValueDecoderFunc(dDecodeValue) }). + RegisterTypeDecoder(tBinary, func() ValueDecoder { return &decodeAdapter{binaryDecodeValue, binaryDecodeType} }). + RegisterTypeDecoder(tUndefined, func() ValueDecoder { return &decodeAdapter{undefinedDecodeValue, undefinedDecodeType} }). + RegisterTypeDecoder(tDateTime, func() ValueDecoder { return &decodeAdapter{dateTimeDecodeValue, dateTimeDecodeType} }). + RegisterTypeDecoder(tNull, func() ValueDecoder { return &decodeAdapter{nullDecodeValue, nullDecodeType} }). + RegisterTypeDecoder(tRegex, func() ValueDecoder { return &decodeAdapter{regexDecodeValue, regexDecodeType} }). + RegisterTypeDecoder(tDBPointer, func() ValueDecoder { return &decodeAdapter{dbPointerDecodeValue, dbPointerDecodeType} }). + RegisterTypeDecoder(tTimestamp, func() ValueDecoder { return &decodeAdapter{timestampDecodeValue, timestampDecodeType} }). + RegisterTypeDecoder(tMinKey, func() ValueDecoder { return &decodeAdapter{minKeyDecodeValue, minKeyDecodeType} }). + RegisterTypeDecoder(tMaxKey, func() ValueDecoder { return &decodeAdapter{maxKeyDecodeValue, maxKeyDecodeType} }). + RegisterTypeDecoder(tJavaScript, func() ValueDecoder { return &decodeAdapter{javaScriptDecodeValue, javaScriptDecodeType} }). + RegisterTypeDecoder(tSymbol, func() ValueDecoder { return &decodeAdapter{symbolDecodeValue, symbolDecodeType} }). + RegisterTypeDecoder(tByteSlice, func() ValueDecoder { return &byteSliceCodec{} }). + RegisterTypeDecoder(tTime, func() ValueDecoder { return &timeCodec{} }). + RegisterTypeDecoder(tEmpty, func() ValueDecoder { return &emptyInterfaceCodec{} }). + RegisterTypeDecoder(tCoreArray, func() ValueDecoder { return &arrayCodec{} }). + RegisterTypeDecoder(tOID, func() ValueDecoder { return &decodeAdapter{objectIDDecodeValue, objectIDDecodeType} }). + RegisterTypeDecoder(tDecimal, func() ValueDecoder { return &decodeAdapter{decimal128DecodeValue, decimal128DecodeType} }). + RegisterTypeDecoder(tJSONNumber, func() ValueDecoder { return &decodeAdapter{jsonNumberDecodeValue, jsonNumberDecodeType} }). + RegisterTypeDecoder(tURL, func() ValueDecoder { return &decodeAdapter{urlDecodeValue, urlDecodeType} }). + RegisterTypeDecoder(tCoreDocument, func() ValueDecoder { return ValueDecoderFunc(coreDocumentDecodeValue) }). + RegisterTypeDecoder(tCodeWithScope, func() ValueDecoder { return &decodeAdapter{codeWithScopeDecodeValue, codeWithScopeDecodeType} }). + RegisterKindDecoder(reflect.Bool, func() ValueDecoder { return &decodeAdapter{booleanDecodeValue, booleanDecodeType} }). + RegisterKindDecoder(reflect.Int, intDecoder). + RegisterKindDecoder(reflect.Int8, intDecoder). + RegisterKindDecoder(reflect.Int16, intDecoder). + RegisterKindDecoder(reflect.Int32, intDecoder). + RegisterKindDecoder(reflect.Int64, intDecoder). + RegisterKindDecoder(reflect.Uint, intDecoder). + RegisterKindDecoder(reflect.Uint8, intDecoder). + RegisterKindDecoder(reflect.Uint16, intDecoder). + RegisterKindDecoder(reflect.Uint32, intDecoder). + RegisterKindDecoder(reflect.Uint64, intDecoder). + RegisterKindDecoder(reflect.Float32, floatDecoder). + RegisterKindDecoder(reflect.Float64, floatDecoder). + RegisterKindDecoder(reflect.Array, func() ValueDecoder { return ValueDecoderFunc(arrayDecodeValue) }). + RegisterKindDecoder(reflect.Map, func() ValueDecoder { return &mapCodec{} }). + RegisterKindDecoder(reflect.Slice, func() ValueDecoder { return &sliceCodec{} }). + RegisterKindDecoder(reflect.String, func() ValueDecoder { return &stringCodec{} }). + RegisterKindDecoder(reflect.Struct, func() ValueDecoder { return newStructCodec(DefaultStructTagParser) }). + RegisterKindDecoder(reflect.Ptr, func() ValueDecoder { return &pointerCodec{} }). + RegisterTypeMapEntry(TypeDouble, tFloat64). + RegisterTypeMapEntry(TypeString, tString). + RegisterTypeMapEntry(TypeArray, tA). + RegisterTypeMapEntry(TypeBinary, tBinary). + RegisterTypeMapEntry(TypeUndefined, tUndefined). + RegisterTypeMapEntry(TypeObjectID, tOID). + RegisterTypeMapEntry(TypeBoolean, tBool). + RegisterTypeMapEntry(TypeDateTime, tDateTime). + RegisterTypeMapEntry(TypeRegex, tRegex). + RegisterTypeMapEntry(TypeDBPointer, tDBPointer). + RegisterTypeMapEntry(TypeJavaScript, tJavaScript). + RegisterTypeMapEntry(TypeSymbol, tSymbol). + RegisterTypeMapEntry(TypeCodeWithScope, tCodeWithScope). + RegisterTypeMapEntry(TypeInt32, tInt32). + RegisterTypeMapEntry(TypeInt64, tInt64). + RegisterTypeMapEntry(TypeTimestamp, tTimestamp). + RegisterTypeMapEntry(TypeDecimal128, tDecimal). + RegisterTypeMapEntry(TypeMinKey, tMinKey). + RegisterTypeMapEntry(TypeMaxKey, tMaxKey). + RegisterTypeMapEntry(Type(0), tD). + RegisterTypeMapEntry(TypeEmbeddedDocument, tD). + RegisterInterfaceDecoder(tValueUnmarshaler, func() ValueDecoder { return ValueDecoderFunc(valueUnmarshalerDecodeValue) }). + RegisterInterfaceDecoder(tUnmarshaler, func() ValueDecoder { return ValueDecoderFunc(unmarshalerDecodeValue) }) } // dDecodeValue is the ValueDecoderFunc for D instances. -func dDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func dDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.IsValid() || !val.CanSet() || val.Type() != tD { return ValueDecoderError{Name: "DDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } switch vrType := vr.Type(); vrType { case Type(0), TypeEmbeddedDocument: - dc.ancestor = tD + break case TypeNull: val.Set(reflect.Zero(val.Type())) return vr.ReadNull() @@ -131,7 +129,7 @@ func dDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { return err } - decoder, err := dc.LookupDecoder(tEmpty) + decoder, err := reg.LookupDecoder(tEmpty) if err != nil { return err } @@ -153,8 +151,7 @@ func dDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { return err } - // Pass false for convert because we don't need to call reflect.Value.Convert for tEmpty. - elem, err := decodeTypeOrValueWithInfo(decoder, dc, elemVr, tEmpty, false) + elem, err := decodeTypeOrValueWithInfo(decoder, reg, elemVr, tEmpty) if err != nil { return err } @@ -166,7 +163,7 @@ func dDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { return nil } -func booleanDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func booleanDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t.Kind() != reflect.Bool { return emptyValue, ValueDecoderError{ Name: "BooleanDecodeValue", @@ -213,12 +210,12 @@ func booleanDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect } // booleanDecodeValue is the ValueDecoderFunc for bool types. -func booleanDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +func booleanDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.IsValid() || !val.CanSet() || val.Kind() != reflect.Bool { return ValueDecoderError{Name: "BooleanDecodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val} } - elem, err := booleanDecodeType(dctx, vr, val.Type()) + elem, err := booleanDecodeType(reg, vr, val.Type()) if err != nil { return err } @@ -227,187 +224,7 @@ func booleanDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) e return nil } -func intDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { - var i64 int64 - var err error - switch vrType := vr.Type(); vrType { - case TypeInt32: - i32, err := vr.ReadInt32() - if err != nil { - return emptyValue, err - } - i64 = int64(i32) - case TypeInt64: - i64, err = vr.ReadInt64() - if err != nil { - return emptyValue, err - } - case TypeDouble: - f64, err := vr.ReadDouble() - if err != nil { - return emptyValue, err - } - if !dc.truncate && math.Floor(f64) != f64 { - return emptyValue, errCannotTruncate - } - if f64 > float64(math.MaxInt64) { - return emptyValue, fmt.Errorf("%g overflows int64", f64) - } - i64 = int64(f64) - case TypeBoolean: - b, err := vr.ReadBoolean() - if err != nil { - return emptyValue, err - } - if b { - i64 = 1 - } - case TypeNull: - if err = vr.ReadNull(); err != nil { - return emptyValue, err - } - case TypeUndefined: - if err = vr.ReadUndefined(); err != nil { - return emptyValue, err - } - default: - return emptyValue, fmt.Errorf("cannot decode %v into an integer type", vrType) - } - - switch t.Kind() { - case reflect.Int8: - if i64 < math.MinInt8 || i64 > math.MaxInt8 { - return emptyValue, fmt.Errorf("%d overflows int8", i64) - } - - return reflect.ValueOf(int8(i64)), nil - case reflect.Int16: - if i64 < math.MinInt16 || i64 > math.MaxInt16 { - return emptyValue, fmt.Errorf("%d overflows int16", i64) - } - - return reflect.ValueOf(int16(i64)), nil - case reflect.Int32: - if i64 < math.MinInt32 || i64 > math.MaxInt32 { - return emptyValue, fmt.Errorf("%d overflows int32", i64) - } - - return reflect.ValueOf(int32(i64)), nil - case reflect.Int64: - return reflect.ValueOf(i64), nil - case reflect.Int: - if int64(int(i64)) != i64 { // Can we fit this inside of an int - return emptyValue, fmt.Errorf("%d overflows int", i64) - } - - return reflect.ValueOf(int(i64)), nil - default: - return emptyValue, ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, - Received: reflect.Zero(t), - } - } -} - -// intDecodeValue is the ValueDecoderFunc for int types. -func intDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - if !val.CanSet() { - return ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, - Received: val, - } - } - - elem, err := intDecodeType(dc, vr, val.Type()) - if err != nil { - return err - } - - val.SetInt(elem.Int()) - return nil -} - -func floatDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { - var f float64 - var err error - switch vrType := vr.Type(); vrType { - case TypeInt32: - i32, err := vr.ReadInt32() - if err != nil { - return emptyValue, err - } - f = float64(i32) - case TypeInt64: - i64, err := vr.ReadInt64() - if err != nil { - return emptyValue, err - } - f = float64(i64) - case TypeDouble: - f, err = vr.ReadDouble() - if err != nil { - return emptyValue, err - } - case TypeBoolean: - b, err := vr.ReadBoolean() - if err != nil { - return emptyValue, err - } - if b { - f = 1 - } - case TypeNull: - if err = vr.ReadNull(); err != nil { - return emptyValue, err - } - case TypeUndefined: - if err = vr.ReadUndefined(); err != nil { - return emptyValue, err - } - default: - return emptyValue, fmt.Errorf("cannot decode %v into a float32 or float64 type", vrType) - } - - switch t.Kind() { - case reflect.Float32: - if !dc.truncate && float64(float32(f)) != f { - return emptyValue, errCannotTruncate - } - - return reflect.ValueOf(float32(f)), nil - case reflect.Float64: - return reflect.ValueOf(f), nil - default: - return emptyValue, ValueDecoderError{ - Name: "FloatDecodeValue", - Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, - Received: reflect.Zero(t), - } - } -} - -// floatDecodeValue is the ValueDecoderFunc for float types. -func floatDecodeValue(ec DecodeContext, vr ValueReader, val reflect.Value) error { - if !val.CanSet() { - return ValueDecoderError{ - Name: "FloatDecodeValue", - Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, - Received: val, - } - } - - elem, err := floatDecodeType(ec, vr, val.Type()) - if err != nil { - return err - } - - val.SetFloat(elem.Float()) - return nil -} - -func javaScriptDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func javaScriptDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tJavaScript { return emptyValue, ValueDecoderError{ Name: "JavaScriptDecodeValue", @@ -436,12 +253,12 @@ func javaScriptDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (refl } // javaScriptDecodeValue is the ValueDecoderFunc for the JavaScript type. -func javaScriptDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +func javaScriptDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tJavaScript { return ValueDecoderError{Name: "JavaScriptDecodeValue", Types: []reflect.Type{tJavaScript}, Received: val} } - elem, err := javaScriptDecodeType(dctx, vr, tJavaScript) + elem, err := javaScriptDecodeType(reg, vr, tJavaScript) if err != nil { return err } @@ -450,7 +267,7 @@ func javaScriptDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value return nil } -func symbolDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func symbolDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tSymbol { return emptyValue, ValueDecoderError{ Name: "SymbolDecodeValue", @@ -491,12 +308,12 @@ func symbolDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect. } // symbolDecodeValue is the ValueDecoderFunc for the Symbol type. -func symbolDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +func symbolDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tSymbol { return ValueDecoderError{Name: "SymbolDecodeValue", Types: []reflect.Type{tSymbol}, Received: val} } - elem, err := symbolDecodeType(dctx, vr, tSymbol) + elem, err := symbolDecodeType(reg, vr, tSymbol) if err != nil { return err } @@ -505,7 +322,7 @@ func symbolDecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) er return nil } -func binaryDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func binaryDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tBinary { return emptyValue, ValueDecoderError{ Name: "BinaryDecodeValue", @@ -535,12 +352,12 @@ func binaryDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect. } // binaryDecodeValue is the ValueDecoderFunc for Binary. -func binaryDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func binaryDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tBinary { return ValueDecoderError{Name: "BinaryDecodeValue", Types: []reflect.Type{tBinary}, Received: val} } - elem, err := binaryDecodeType(dc, vr, tBinary) + elem, err := binaryDecodeType(reg, vr, tBinary) if err != nil { return err } @@ -549,7 +366,7 @@ func binaryDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) erro return nil } -func undefinedDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func undefinedDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tUndefined { return emptyValue, ValueDecoderError{ Name: "UndefinedDecodeValue", @@ -575,12 +392,12 @@ func undefinedDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (refle } // undefinedDecodeValue is the ValueDecoderFunc for Undefined. -func undefinedDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func undefinedDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tUndefined { return ValueDecoderError{Name: "UndefinedDecodeValue", Types: []reflect.Type{tUndefined}, Received: val} } - elem, err := undefinedDecodeType(dc, vr, tUndefined) + elem, err := undefinedDecodeType(reg, vr, tUndefined) if err != nil { return err } @@ -590,7 +407,7 @@ func undefinedDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) e } // Accept both 12-byte string and pretty-printed 24-byte hex string formats. -func objectIDDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func objectIDDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tOID { return emptyValue, ValueDecoderError{ Name: "ObjectIDDecodeValue", @@ -636,12 +453,12 @@ func objectIDDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflec } // objectIDDecodeValue is the ValueDecoderFunc for ObjectID. -func objectIDDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func objectIDDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tOID { return ValueDecoderError{Name: "ObjectIDDecodeValue", Types: []reflect.Type{tOID}, Received: val} } - elem, err := objectIDDecodeType(dc, vr, tOID) + elem, err := objectIDDecodeType(reg, vr, tOID) if err != nil { return err } @@ -650,7 +467,7 @@ func objectIDDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) er return nil } -func dateTimeDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func dateTimeDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tDateTime { return emptyValue, ValueDecoderError{ Name: "DateTimeDecodeValue", @@ -679,12 +496,12 @@ func dateTimeDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflec } // dateTimeDecodeValue is the ValueDecoderFunc for DateTime. -func dateTimeDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func dateTimeDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tDateTime { return ValueDecoderError{Name: "DateTimeDecodeValue", Types: []reflect.Type{tDateTime}, Received: val} } - elem, err := dateTimeDecodeType(dc, vr, tDateTime) + elem, err := dateTimeDecodeType(reg, vr, tDateTime) if err != nil { return err } @@ -693,7 +510,7 @@ func dateTimeDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) er return nil } -func nullDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func nullDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tNull { return emptyValue, ValueDecoderError{ Name: "NullDecodeValue", @@ -719,12 +536,12 @@ func nullDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Va } // nullDecodeValue is the ValueDecoderFunc for Null. -func nullDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func nullDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tNull { return ValueDecoderError{Name: "NullDecodeValue", Types: []reflect.Type{tNull}, Received: val} } - elem, err := nullDecodeType(dc, vr, tNull) + elem, err := nullDecodeType(reg, vr, tNull) if err != nil { return err } @@ -733,7 +550,7 @@ func nullDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error return nil } -func regexDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func regexDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tRegex { return emptyValue, ValueDecoderError{ Name: "RegexDecodeValue", @@ -762,12 +579,12 @@ func regexDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.V } // regexDecodeValue is the ValueDecoderFunc for Regex. -func regexDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func regexDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tRegex { return ValueDecoderError{Name: "RegexDecodeValue", Types: []reflect.Type{tRegex}, Received: val} } - elem, err := regexDecodeType(dc, vr, tRegex) + elem, err := regexDecodeType(reg, vr, tRegex) if err != nil { return err } @@ -776,7 +593,7 @@ func regexDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error return nil } -func dbPointerDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func dbPointerDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tDBPointer { return emptyValue, ValueDecoderError{ Name: "DBPointerDecodeValue", @@ -806,12 +623,12 @@ func dbPointerDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (refle } // dbPointerDecodeValue is the ValueDecoderFunc for DBPointer. -func dbPointerDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func dbPointerDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tDBPointer { return ValueDecoderError{Name: "DBPointerDecodeValue", Types: []reflect.Type{tDBPointer}, Received: val} } - elem, err := dbPointerDecodeType(dc, vr, tDBPointer) + elem, err := dbPointerDecodeType(reg, vr, tDBPointer) if err != nil { return err } @@ -820,7 +637,7 @@ func dbPointerDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) e return nil } -func timestampDecodeType(_ DecodeContext, vr ValueReader, reflectType reflect.Type) (reflect.Value, error) { +func timestampDecodeType(_ DecoderRegistry, vr ValueReader, reflectType reflect.Type) (reflect.Value, error) { if reflectType != tTimestamp { return emptyValue, ValueDecoderError{ Name: "TimestampDecodeValue", @@ -849,12 +666,12 @@ func timestampDecodeType(_ DecodeContext, vr ValueReader, reflectType reflect.Ty } // timestampDecodeValue is the ValueDecoderFunc for Timestamp. -func timestampDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func timestampDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tTimestamp { return ValueDecoderError{Name: "TimestampDecodeValue", Types: []reflect.Type{tTimestamp}, Received: val} } - elem, err := timestampDecodeType(dc, vr, tTimestamp) + elem, err := timestampDecodeType(reg, vr, tTimestamp) if err != nil { return err } @@ -863,7 +680,7 @@ func timestampDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) e return nil } -func minKeyDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func minKeyDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tMinKey { return emptyValue, ValueDecoderError{ Name: "MinKeyDecodeValue", @@ -891,12 +708,12 @@ func minKeyDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect. } // minKeyDecodeValue is the ValueDecoderFunc for MinKey. -func minKeyDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func minKeyDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tMinKey { return ValueDecoderError{Name: "MinKeyDecodeValue", Types: []reflect.Type{tMinKey}, Received: val} } - elem, err := minKeyDecodeType(dc, vr, tMinKey) + elem, err := minKeyDecodeType(reg, vr, tMinKey) if err != nil { return err } @@ -905,7 +722,7 @@ func minKeyDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) erro return nil } -func maxKeyDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func maxKeyDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tMaxKey { return emptyValue, ValueDecoderError{ Name: "MaxKeyDecodeValue", @@ -933,12 +750,12 @@ func maxKeyDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect. } // maxKeyDecodeValue is the ValueDecoderFunc for MaxKey. -func maxKeyDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func maxKeyDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tMaxKey { return ValueDecoderError{Name: "MaxKeyDecodeValue", Types: []reflect.Type{tMaxKey}, Received: val} } - elem, err := maxKeyDecodeType(dc, vr, tMaxKey) + elem, err := maxKeyDecodeType(reg, vr, tMaxKey) if err != nil { return err } @@ -947,7 +764,7 @@ func maxKeyDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) erro return nil } -func decimal128DecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func decimal128DecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tDecimal { return emptyValue, ValueDecoderError{ Name: "Decimal128DecodeValue", @@ -976,12 +793,12 @@ func decimal128DecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (refl } // decimal128DecodeValue is the ValueDecoderFunc for Decimal128. -func decimal128DecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +func decimal128DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tDecimal { return ValueDecoderError{Name: "Decimal128DecodeValue", Types: []reflect.Type{tDecimal}, Received: val} } - elem, err := decimal128DecodeType(dctx, vr, tDecimal) + elem, err := decimal128DecodeType(reg, vr, tDecimal) if err != nil { return err } @@ -990,7 +807,7 @@ func decimal128DecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value return nil } -func jsonNumberDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func jsonNumberDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tJSONNumber { return emptyValue, ValueDecoderError{ Name: "JSONNumberDecodeValue", @@ -1035,12 +852,12 @@ func jsonNumberDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (refl } // jsonNumberDecodeValue is the ValueDecoderFunc for json.Number. -func jsonNumberDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func jsonNumberDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tJSONNumber { return ValueDecoderError{Name: "JSONNumberDecodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} } - elem, err := jsonNumberDecodeType(dc, vr, tJSONNumber) + elem, err := jsonNumberDecodeType(reg, vr, tJSONNumber) if err != nil { return err } @@ -1049,7 +866,7 @@ func jsonNumberDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) return nil } -func urlDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func urlDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tURL { return emptyValue, ValueDecoderError{ Name: "URLDecodeValue", @@ -1084,12 +901,12 @@ func urlDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Val } // urlDecodeValue is the ValueDecoderFunc for url.URL. -func urlDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func urlDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tURL { return ValueDecoderError{Name: "URLDecodeValue", Types: []reflect.Type{tURL}, Received: val} } - elem, err := urlDecodeType(dc, vr, tURL) + elem, err := urlDecodeType(reg, vr, tURL) if err != nil { return err } @@ -1099,7 +916,7 @@ func urlDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { } // arrayDecodeValue is the ValueDecoderFunc for array types. -func arrayDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func arrayDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Array { return ValueDecoderError{Name: "ArrayDecodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val} } @@ -1140,7 +957,7 @@ func arrayDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error return fmt.Errorf("cannot decode %v into an array", vrType) } - var elemsFunc func(DecodeContext, ValueReader, reflect.Value) ([]reflect.Value, error) + var elemsFunc func(DecoderRegistry, ValueReader, reflect.Value) ([]reflect.Value, error) switch val.Type().Elem() { case tE: elemsFunc = decodeD @@ -1148,7 +965,7 @@ func arrayDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error elemsFunc = decodeDefault } - elems, err := elemsFunc(dc, vr, val) + elems, err := elemsFunc(reg, vr, val) if err != nil { return err } @@ -1165,7 +982,7 @@ func arrayDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error } // valueUnmarshalerDecodeValue is the ValueDecoderFunc for ValueUnmarshaler implementations. -func valueUnmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +func valueUnmarshalerDecodeValue(_ DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.IsValid() || (!val.Type().Implements(tValueUnmarshaler) && !reflect.PtrTo(val.Type()).Implements(tValueUnmarshaler)) { return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} } @@ -1198,7 +1015,7 @@ func valueUnmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Va } // unmarshalerDecodeValue is the ValueDecoderFunc for Unmarshaler implementations. -func unmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +func unmarshalerDecodeValue(_ DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.IsValid() || (!val.Type().Implements(tUnmarshaler) && !reflect.PtrTo(val.Type()).Implements(tUnmarshaler)) { return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val} } @@ -1243,7 +1060,7 @@ func unmarshalerDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) } // coreDocumentDecodeValue is the ValueDecoderFunc for bsoncore.Document. -func coreDocumentDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +func coreDocumentDecodeValue(_ DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tCoreDocument { return ValueDecoderError{Name: "CoreDocumentDecodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} } @@ -1259,7 +1076,7 @@ func coreDocumentDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) return err } -func decodeDefault(dc DecodeContext, vr ValueReader, val reflect.Value) ([]reflect.Value, error) { +func decodeDefault(reg DecoderRegistry, vr ValueReader, val reflect.Value) ([]reflect.Value, error) { elems := make([]reflect.Value, 0) ar, err := vr.ReadArray() @@ -1269,7 +1086,7 @@ func decodeDefault(dc DecodeContext, vr ValueReader, val reflect.Value) ([]refle eType := val.Type().Elem() - decoder, err := dc.LookupDecoder(eType) + decoder, err := reg.LookupDecoder(eType) if err != nil { return nil, err } @@ -1284,7 +1101,7 @@ func decodeDefault(dc DecodeContext, vr ValueReader, val reflect.Value) ([]refle return nil, err } - elem, err := decodeTypeOrValueWithInfo(decoder, dc, vr, eType, true) + elem, err := decodeTypeOrValueWithInfo(decoder, reg, vr, eType) if err != nil { return nil, newDecodeError(strconv.Itoa(idx), err) } @@ -1295,7 +1112,7 @@ func decodeDefault(dc DecodeContext, vr ValueReader, val reflect.Value) ([]refle return elems, nil } -func codeWithScopeDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func codeWithScopeDecodeType(reg DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tCodeWithScope { return emptyValue, ValueDecoderError{ Name: "CodeWithScopeDecodeValue", @@ -1314,7 +1131,7 @@ func codeWithScopeDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) ( } scope := reflect.New(tD).Elem() - elems, err := decodeElemsFromDocumentReader(dc, dr) + elems, err := decodeElemsFromDocumentReader(reg, dr, tEmpty) if err != nil { return emptyValue, err } @@ -1341,12 +1158,12 @@ func codeWithScopeDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) ( } // codeWithScopeDecodeValue is the ValueDecoderFunc for CodeWithScope. -func codeWithScopeDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func codeWithScopeDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tCodeWithScope { return ValueDecoderError{Name: "CodeWithScopeDecodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} } - elem, err := codeWithScopeDecodeType(dc, vr, tCodeWithScope) + elem, err := codeWithScopeDecodeType(reg, vr, tCodeWithScope) if err != nil { return err } @@ -1355,7 +1172,7 @@ func codeWithScopeDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Valu return nil } -func decodeD(dc DecodeContext, vr ValueReader, _ reflect.Value) ([]reflect.Value, error) { +func decodeD(reg DecoderRegistry, vr ValueReader, val reflect.Value) ([]reflect.Value, error) { switch vr.Type() { case Type(0), TypeEmbeddedDocument: default: @@ -1367,11 +1184,11 @@ func decodeD(dc DecodeContext, vr ValueReader, _ reflect.Value) ([]reflect.Value return nil, err } - return decodeElemsFromDocumentReader(dc, dr) + return decodeElemsFromDocumentReader(reg, dr, val.Type()) } -func decodeElemsFromDocumentReader(dc DecodeContext, dr DocumentReader) ([]reflect.Value, error) { - decoder, err := dc.LookupDecoder(tEmpty) +func decodeElemsFromDocumentReader(reg DecoderRegistry, dr DocumentReader, t reflect.Type) ([]reflect.Value, error) { + decoder, err := reg.LookupDecoder(tEmpty) if err != nil { return nil, err } @@ -1386,8 +1203,8 @@ func decodeElemsFromDocumentReader(dc DecodeContext, dr DocumentReader) ([]refle return nil, err } - val := reflect.New(tEmpty).Elem() - err = decoder.DecodeValue(dc, vr, val) + var val reflect.Value + val, err = decodeTypeOrValueWithInfo(decoder, reg, vr, t) if err != nil { return nil, newDecodeError(key, err) } diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 0e32e64ba7..24892adeae 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -139,7 +139,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "IntDecodeValue", - ValueDecoderFunc(intDecodeValue), + &intCodec{}, []subtest{ { "wrong type", @@ -371,7 +371,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "defaultUIntCodec.DecodeValue", - &uintCodec{}, + &intCodec{}, []subtest{ { "wrong type", @@ -607,7 +607,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "FloatDecodeValue", - ValueDecoderFunc(floatDecodeValue), + &floatCodec{}, []subtest{ { "wrong type", @@ -3341,7 +3341,7 @@ func TestDefaultValueDecoders(t *testing.T) { t.Skip() } want := errors.New("DecodeValue failure error") - llc := &llCodec{t: t, err: want} + llc := func() ValueDecoder { return &llCodec{t: t, err: want} } reg := newTestRegistryBuilder(). RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). @@ -3357,7 +3357,7 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run("Success", func(t *testing.T) { want := tc.val - llc := &llCodec{t: t, decodeval: tc.val} + llc := func() ValueDecoder { return &llCodec{t: t, decodeval: tc.val} } reg := newTestRegistryBuilder(). RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). @@ -3501,11 +3501,11 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run("decode errors contain key information", func(t *testing.T) { decodeValueError := errors.New("decode value error") - emptyInterfaceErrorDecode := func(DecodeContext, ValueReader, reflect.Value) error { + emptyInterfaceErrorDecode := func(DecoderRegistry, ValueReader, reflect.Value) error { return decodeValueError } emptyInterfaceErrorRegistry := newTestRegistryBuilder(). - RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)). + RegisterTypeDecoder(tEmpty, func() ValueDecoder { return ValueDecoderFunc(emptyInterfaceErrorDecode) }). Build() // Set up a document {foo: 10} and an error that would happen if the value were decoded into interface{} @@ -3560,7 +3560,7 @@ func TestDefaultValueDecoders(t *testing.T) { // Use a registry that has all default decoders with the custom interface{} decoder that always errors. nestedRegistryBuilder := newTestRegistryBuilder() registerDefaultDecoders(nestedRegistryBuilder) - nestedRegistryBuilder.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)) + nestedRegistryBuilder.RegisterTypeDecoder(tEmpty, func() ValueDecoder { return ValueDecoderFunc(emptyInterfaceErrorDecode) }) nestedErr := &DecodeError{ keys: []string{"fourth", "1", "third", "randomKey", "second", "first"}, wrapped: decodeValueError, diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index ca6a4a9cad..56d48e3722 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -54,8 +54,9 @@ func registerDefaultEncoders(rb *RegistryBuilder) { if rb == nil { panic(errors.New("argument to RegisterDefaultEncoders must not be nil")) } + intEncoder := func() ValueEncoder { return &intCodec{} } - floatEncoder := func() ValueEncoder { return ValueEncoderFunc(floatEncodeValue) } + floatEncoder := func() ValueEncoder { return &floatCodec{} } rb.RegisterTypeEncoder(tByteSlice, func() ValueEncoder { return &byteSliceCodec{} }). RegisterTypeEncoder(tTime, func() ValueEncoder { return &timeCodec{} }). RegisterTypeEncoder(tEmpty, func() ValueEncoder { return &emptyInterfaceCodec{} }). @@ -113,16 +114,6 @@ func fitsIn32Bits(i int64) bool { return math.MinInt32 <= i && i <= math.MaxInt32 } -// floatEncodeValue is the ValueEncoderFunc for float types. -func floatEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { - switch val.Kind() { - case reflect.Float32, reflect.Float64: - return vw.WriteDouble(val.Float()) - } - - return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val} -} - // objectIDEncodeValue is the ValueEncoderFunc for ObjectID. func objectIDEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tOID { @@ -160,7 +151,7 @@ func jsonNumberEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Valu return err } - return floatEncodeValue(reg, vw, reflect.ValueOf(f64)) + return (&floatCodec{}).EncodeValue(reg, vw, reflect.ValueOf(f64)) } // urlEncodeValue is the ValueEncoderFunc for url.URL. diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index cd8efe72db..47cb21e6c1 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -182,7 +182,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "FloatEncodeValue", - ValueEncoderFunc(floatEncodeValue), + &floatCodec{}, []subtest{ { "wrong type", @@ -1077,28 +1077,6 @@ func TestDefaultValueEncoders(t *testing.T) { }, }, }, - { - "StructEncodeValue", - defaultStructCodec, - []subtest{ - { - "interface value", - struct{ Foo myInterface }{Foo: myStruct{1}}, - buildDefaultRegistry(), - nil, - writeDocumentEnd, - nil, - }, - { - "nil interface value", - struct{ Foo myInterface }{Foo: nil}, - buildDefaultRegistry(), - nil, - writeDocumentEnd, - nil, - }, - }, - }, { "CodeWithScopeEncodeValue", ValueEncoderFunc(codeWithScopeEncodeValue), diff --git a/bson/empty_interface_codec.go b/bson/empty_interface_codec.go index b8314a873f..e428176d2d 100644 --- a/bson/empty_interface_codec.go +++ b/bson/empty_interface_codec.go @@ -12,15 +12,17 @@ import ( // emptyInterfaceCodec is the Codec used for interface{} values. type emptyInterfaceCodec struct { + // defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the + // usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If DocumentType is + // set to a type that a BSON document cannot be unmarshaled into (e.g. "string"), unmarshalling will result in an + // error. DocumentType overrides the Ancestor field. + defaultDocumentType reflect.Type + // decodeBinaryAsSlice causes DecodeValue to unmarshal BSON binary field values that are the // "Generic" or "Old" BSON binary subtype as a Go byte slice instead of a Binary. decodeBinaryAsSlice bool } -var ( - defaultEmptyInterfaceCodec = &emptyInterfaceCodec{} -) - // EncodeValue is the ValueEncoderFunc for interface{}. func (eic emptyInterfaceCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tEmpty { @@ -38,23 +40,23 @@ func (eic emptyInterfaceCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, return encoder.EncodeValue(reg, vw, val.Elem()) } -func (eic emptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType Type) (reflect.Type, error) { +func (eic emptyInterfaceCodec) getEmptyInterfaceDecodeType(reg DecoderRegistry, valueType Type, ancestorType reflect.Type) (reflect.Type, error) { isDocument := valueType == Type(0) || valueType == TypeEmbeddedDocument if isDocument { - if dc.defaultDocumentType != nil { + if eic.defaultDocumentType != nil { // If the bsontype is an embedded document and the DocumentType is set on the DecodeContext, then return // that type. - return dc.defaultDocumentType, nil + return eic.defaultDocumentType, nil } - if dc.ancestor != nil { + if ancestorType != nil && ancestorType != tEmpty { // Using ancestor information rather than looking up the type map entry forces consistent decoding. // If we're decoding into a bson.D, subdocuments should also be decoded as bson.D, even if a type map entry // has been registered. - return dc.ancestor, nil + return ancestorType, nil } } - rtype, err := dc.LookupTypeMapEntry(valueType) + rtype, err := reg.LookupTypeMapEntry(valueType) if err == nil { return rtype, nil } @@ -70,7 +72,7 @@ func (eic emptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, val lookupType = Type(0) } - rtype, err = dc.LookupTypeMapEntry(lookupType) + rtype, err = reg.LookupTypeMapEntry(lookupType) if err == nil { return rtype, nil } @@ -79,12 +81,8 @@ func (eic emptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, val return nil, err } -func (eic emptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { - if t != tEmpty { - return emptyValue, ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.Zero(t)} - } - - rtype, err := eic.getEmptyInterfaceDecodeType(dc, vr.Type()) +func (eic emptyInterfaceCodec) decodeType(reg DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { + rtype, err := eic.getEmptyInterfaceDecodeType(reg, vr.Type(), t) if err != nil { switch vr.Type() { case TypeNull: @@ -94,17 +92,20 @@ func (eic emptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t re } } - decoder, err := dc.LookupDecoder(rtype) + decoder, err := reg.LookupDecoder(rtype) if err != nil { return emptyValue, err } - elem, err := decodeTypeOrValueWithInfo(decoder, dc, vr, rtype, true) + elem, err := decodeTypeOrValueWithInfo(decoder, reg, vr, rtype) if err != nil { return emptyValue, err } + if elem.Type() != rtype { + elem = elem.Convert(rtype) + } - if (eic.decodeBinaryAsSlice || dc.binaryAsSlice) && rtype == tBinary { + if eic.decodeBinaryAsSlice && rtype == tBinary { binElem := elem.Interface().(Binary) if binElem.Subtype == TypeBinaryGeneric || binElem.Subtype == TypeBinaryBinaryOld { elem = reflect.ValueOf(binElem.Data) @@ -115,12 +116,12 @@ func (eic emptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t re } // DecodeValue is the ValueDecoderFunc for interface{}. -func (eic emptyInterfaceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (eic emptyInterfaceCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tEmpty { return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val} } - elem, err := eic.decodeType(dc, vr, val.Type()) + elem, err := eic.decodeType(reg, vr, val.Type()) if err != nil { return err } diff --git a/bson/float_codec.go b/bson/float_codec.go new file mode 100644 index 0000000000..aa99857877 --- /dev/null +++ b/bson/float_codec.go @@ -0,0 +1,107 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bson + +import ( + "fmt" + "reflect" +) + +type floatCodec struct { + // truncate, if true, instructs decoders to to truncate the fractional part of BSON "double" + // values when attempting to unmarshal them into a Go float struct field. The truncation logic + // does not apply to BSON "decimal128" values. + truncate bool +} + +// floatEncodeValue is the ValueEncoderFunc for float types. +func (fc *floatCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { + switch val.Kind() { + case reflect.Float32, reflect.Float64: + return vw.WriteDouble(val.Float()) + } + + return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val} +} + +func (fc *floatCodec) floatDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { + var f float64 + var err error + switch vrType := vr.Type(); vrType { + case TypeInt32: + i32, err := vr.ReadInt32() + if err != nil { + return emptyValue, err + } + f = float64(i32) + case TypeInt64: + i64, err := vr.ReadInt64() + if err != nil { + return emptyValue, err + } + f = float64(i64) + case TypeDouble: + f, err = vr.ReadDouble() + if err != nil { + return emptyValue, err + } + case TypeBoolean: + b, err := vr.ReadBoolean() + if err != nil { + return emptyValue, err + } + if b { + f = 1 + } + case TypeNull: + if err = vr.ReadNull(); err != nil { + return emptyValue, err + } + case TypeUndefined: + if err = vr.ReadUndefined(); err != nil { + return emptyValue, err + } + default: + return emptyValue, fmt.Errorf("cannot decode %v into a float32 or float64 type", vrType) + } + + switch t.Kind() { + case reflect.Float32: + if !fc.truncate && float64(float32(f)) != f { + return emptyValue, errCannotTruncate + } + + return reflect.ValueOf(float32(f)), nil + case reflect.Float64: + return reflect.ValueOf(f), nil + default: + return emptyValue, ValueDecoderError{ + Name: "FloatDecodeValue", + Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Received: reflect.Zero(t), + } + } +} + +// DecodeValue is the ValueDecoder for float types. +func (fc *floatCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { + if !val.CanSet() { + return ValueDecoderError{ + Name: "FloatDecodeValue", + Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Received: val, + } + } + + elem, err := fc.floatDecodeType(reg, vr, val.Type()) + if err != nil { + return err + } + + val.SetFloat(elem.Float()) + return nil +} diff --git a/bson/int_codec.go b/bson/int_codec.go index 4d82092309..4caff3aa5a 100644 --- a/bson/int_codec.go +++ b/bson/int_codec.go @@ -70,117 +70,105 @@ func (ic *intCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.V } } -// DecodeValue is the ValueDecoder for uint types. -func (ic *intCodec) DecodeValue(_ *Registry, vr ValueReader, val reflect.Value) error { - if !val.CanSet() { - return ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{ - reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, - reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, - }, - Received: val, - } - } - +func (ic *intCodec) decodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { var i64 int64 switch vrType := vr.Type(); vrType { case TypeInt32: i32, err := vr.ReadInt32() if err != nil { - return err + return emptyValue, err } i64 = int64(i32) case TypeInt64: var err error i64, err = vr.ReadInt64() if err != nil { - return err + return emptyValue, err } case TypeDouble: f64, err := vr.ReadDouble() if err != nil { - return err + return emptyValue, err } if !ic.truncate && math.Floor(f64) != f64 { - return errCannotTruncate + return emptyValue, errCannotTruncate } if f64 > float64(math.MaxInt64) { - return fmt.Errorf("%g overflows int64", f64) + return emptyValue, fmt.Errorf("%g overflows int64", f64) } i64 = int64(f64) case TypeBoolean: b, err := vr.ReadBoolean() if err != nil { - return err + return emptyValue, err } if b { i64 = 1 } case TypeNull: if err := vr.ReadNull(); err != nil { - return err + return emptyValue, err } case TypeUndefined: if err := vr.ReadUndefined(); err != nil { - return err + return emptyValue, err } default: - return fmt.Errorf("cannot decode %v into an integer type", vrType) + return emptyValue, fmt.Errorf("cannot decode %v into an integer type", vrType) } - switch t := val.Type(); t.Kind() { + switch t.Kind() { case reflect.Int8: if i64 < math.MinInt8 || i64 > math.MaxInt8 { - return fmt.Errorf("%d overflows int8", i64) + return emptyValue, fmt.Errorf("%d overflows int8", i64) } - val.SetInt(i64) + return reflect.ValueOf(int8(i64)), nil case reflect.Int16: if i64 < math.MinInt16 || i64 > math.MaxInt16 { - return fmt.Errorf("%d overflows int16", i64) + return emptyValue, fmt.Errorf("%d overflows int16", i64) } - val.SetInt(i64) + return reflect.ValueOf(int16(i64)), nil case reflect.Int32: if i64 < math.MinInt32 || i64 > math.MaxInt32 { - return fmt.Errorf("%d overflows int32", i64) + return emptyValue, fmt.Errorf("%d overflows int32", i64) } - val.SetInt(i64) + return reflect.ValueOf(int32(i64)), nil case reflect.Int64: - val.SetInt(i64) + return reflect.ValueOf(i64), nil case reflect.Int: if int64(int(i64)) != i64 { // Can we fit this inside of an int - return fmt.Errorf("%d overflows int", i64) + return emptyValue, fmt.Errorf("%d overflows int", i64) } - val.SetInt(i64) + return reflect.ValueOf(int(i64)), nil case reflect.Uint8: if i64 < 0 || i64 > math.MaxUint8 { - return fmt.Errorf("%d overflows uint8", i64) + return emptyValue, fmt.Errorf("%d overflows uint8", i64) } - val.SetUint(uint64(i64)) + return reflect.ValueOf(uint8(i64)), nil case reflect.Uint16: if i64 < 0 || i64 > math.MaxUint16 { - return fmt.Errorf("%d overflows uint16", i64) + return emptyValue, fmt.Errorf("%d overflows uint16", i64) } - val.SetUint(uint64(i64)) + return reflect.ValueOf(uint16(i64)), nil case reflect.Uint32: if i64 < 0 || i64 > math.MaxUint32 { - return fmt.Errorf("%d overflows uint32", i64) + return emptyValue, fmt.Errorf("%d overflows uint32", i64) } - val.SetUint(uint64(i64)) + return reflect.ValueOf(uint32(i64)), nil case reflect.Uint64: if i64 < 0 { - return fmt.Errorf("%d overflows uint64", i64) + return emptyValue, fmt.Errorf("%d overflows uint64", i64) } - val.SetUint(uint64(i64)) + return reflect.ValueOf(uint64(i64)), nil case reflect.Uint: if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint - return fmt.Errorf("%d overflows uint", i64) + return emptyValue, fmt.Errorf("%d overflows uint", i64) } - val.SetUint(uint64(i64)) + return reflect.ValueOf(uint(i64)), nil default: - return ValueDecoderError{ + return emptyValue, ValueDecoderError{ Name: "IntDecodeValue", Kinds: []reflect.Kind{ reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, @@ -189,6 +177,23 @@ func (ic *intCodec) DecodeValue(_ *Registry, vr ValueReader, val reflect.Value) Received: reflect.Zero(t), } } +} + +// DecodeValue is the ValueDecoder for uint types. +func (ic *intCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { + if !val.CanSet() { + return ValueDecoderError{ + Name: "IntDecodeValue", + Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Received: val, + } + } + + elem, err := ic.decodeType(reg, vr, val.Type()) + if err != nil { + return err + } + val.Set(elem) return nil } diff --git a/bson/map_codec.go b/bson/map_codec.go index 5f6963ed26..a9640d34c6 100644 --- a/bson/map_codec.go +++ b/bson/map_codec.go @@ -14,10 +14,6 @@ import ( "strconv" ) -var ( - defaultMapCodec = &mapCodec{} -) - // mapCodec is the Codec used for map values. type mapCodec struct { // decodeZerosMap causes DecodeValue to delete any existing values from Go maps in the destination @@ -125,7 +121,7 @@ func (mc *mapCodec) mapEncodeValue(reg EncoderRegistry, dw DocumentWriter, val r } // DecodeValue is the ValueDecoder for map[string/decimal]* types. -func (mc *mapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (mc *mapCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if val.Kind() != reflect.Map || (!val.CanSet() && val.IsNil()) { return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} } @@ -151,18 +147,18 @@ func (mc *mapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Va val.Set(reflect.MakeMap(val.Type())) } - if val.Len() > 0 && (mc.decodeZerosMap || dc.zeroMaps) { + if val.Len() > 0 && mc.decodeZerosMap { clearMap(val) } eType := val.Type().Elem() - decoder, err := dc.LookupDecoder(eType) + decoder, err := reg.LookupDecoder(eType) if err != nil { return err } if eType == tEmpty { - dc.ancestor = val.Type() + eType = val.Type() } keyType := val.Type().Key() @@ -181,10 +177,13 @@ func (mc *mapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Va return err } - elem, err := decodeTypeOrValueWithInfo(decoder, dc, vr, eType, true) + elem, err := decodeTypeOrValueWithInfo(decoder, reg, vr, eType) if err != nil { return newDecodeError(key, err) } + if t := val.Type().Elem(); elem.Type() != t { + elem = elem.Convert(t) + } val.SetMapIndex(k, elem) } diff --git a/bson/mgoregistry.go b/bson/mgoregistry.go index f0b77f4efb..9d9255f3bd 100644 --- a/bson/mgoregistry.go +++ b/bson/mgoregistry.go @@ -37,10 +37,10 @@ func newMgoRegistryBuilder() *RegistryBuilder { intcodec := func() ValueEncoder { return &intCodec{encodeToMinSize: true} } return NewRegistryBuilder(). - RegisterTypeDecoder(tEmpty, &emptyInterfaceCodec{decodeBinaryAsSlice: true}). - RegisterKindDecoder(reflect.String, &stringCodec{}). - RegisterKindDecoder(reflect.Struct, structcodec). - RegisterKindDecoder(reflect.Map, mapCodec). + RegisterTypeDecoder(tEmpty, func() ValueDecoder { return &emptyInterfaceCodec{decodeBinaryAsSlice: true} }). + RegisterKindDecoder(reflect.String, func() ValueDecoder { return &stringCodec{} }). + RegisterKindDecoder(reflect.Struct, func() ValueDecoder { return structcodec }). + RegisterKindDecoder(reflect.Map, func() ValueDecoder { return mapCodec }). RegisterTypeEncoder(tByteSlice, func() ValueEncoder { return &byteSliceCodec{encodeNilAsEmpty: true} }). RegisterKindEncoder(reflect.Struct, func() ValueEncoder { return structcodec }). RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return &sliceCodec{encodeNilAsEmpty: true} }). @@ -56,7 +56,7 @@ func newMgoRegistryBuilder() *RegistryBuilder { RegisterTypeMapEntry(Type(0), tM). RegisterTypeMapEntry(TypeEmbeddedDocument, tM). RegisterInterfaceEncoder(tGetter, func() ValueEncoder { return ValueEncoderFunc(GetterEncodeValue) }). - RegisterInterfaceDecoder(tSetter, ValueDecoderFunc(SetterDecodeValue)) + RegisterInterfaceDecoder(tSetter, func() ValueDecoder { return ValueDecoderFunc(SetterDecodeValue) }) } // NewMgoRegistry creates a new bson.Registry configured with the default encoders and decoders. @@ -72,7 +72,7 @@ func NewRespectNilValuesMgoRegistry() *Registry { } return newMgoRegistryBuilder(). - RegisterKindDecoder(reflect.Map, mapCodec). + RegisterKindDecoder(reflect.Map, func() ValueDecoder { return mapCodec }). RegisterTypeEncoder(tByteSlice, func() ValueEncoder { return &byteSliceCodec{encodeNilAsEmpty: false} }). RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return &sliceCodec{} }). RegisterKindEncoder(reflect.Map, func() ValueEncoder { return mapCodec }). diff --git a/bson/pointer_codec.go b/bson/pointer_codec.go index d0aec9c7d5..bca19742bc 100644 --- a/bson/pointer_codec.go +++ b/bson/pointer_codec.go @@ -51,7 +51,7 @@ func (pc *pointerCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val ref // DecodeValue handles decoding a pointer by looking up a decoder for the type it points to and // using that to decode. If the BSON value is Null, this method will set the pointer to nil. -func (pc *pointerCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (pc *pointerCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Kind() != reflect.Ptr { return ValueDecoderError{Name: "pointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val} } @@ -74,15 +74,15 @@ func (pc *pointerCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflec if v == nil { return ErrNoDecoder{Type: typ} } - return v.(ValueDecoder).DecodeValue(dc, vr, val.Elem()) + return v.(ValueDecoder).DecodeValue(reg, vr, val.Elem()) } // TODO(charlie): handle concurrent requests for the same type - dec, err := dc.LookupDecoder(typ.Elem()) + dec, err := reg.LookupDecoder(typ.Elem()) if err != nil { return err } if v, ok := pc.dcache.LoadOrStore(typ, dec); ok { dec = v.(ValueDecoder) } - return dec.DecodeValue(dc, vr, val.Elem()) + return dec.DecodeValue(reg, vr, val.Elem()) } diff --git a/bson/primitive_codecs.go b/bson/primitive_codecs.go index 082cd15357..f5a67165e4 100644 --- a/bson/primitive_codecs.go +++ b/bson/primitive_codecs.go @@ -23,8 +23,8 @@ func registerPrimitiveCodecs(rb *RegistryBuilder) { rb.RegisterTypeEncoder(tRawValue, func() ValueEncoder { return ValueEncoderFunc(rawValueEncodeValue) }). RegisterTypeEncoder(tRaw, func() ValueEncoder { return ValueEncoderFunc(rawEncodeValue) }). - RegisterTypeDecoder(tRawValue, ValueDecoderFunc(rawValueDecodeValue)). - RegisterTypeDecoder(tRaw, ValueDecoderFunc(rawDecodeValue)) + RegisterTypeDecoder(tRawValue, func() ValueDecoder { return ValueDecoderFunc(rawValueDecodeValue) }). + RegisterTypeDecoder(tRaw, func() ValueDecoder { return ValueDecoderFunc(rawDecodeValue) }) } // rawValueEncodeValue is the ValueEncoderFunc for RawValue. @@ -50,7 +50,7 @@ func rawValueEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) e } // rawValueDecodeValue is the ValueDecoderFunc for RawValue. -func rawValueDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +func rawValueDecodeValue(_ DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tRawValue { return ValueDecoderError{Name: "RawValueDecodeValue", Types: []reflect.Type{tRawValue}, Received: val} } @@ -76,7 +76,7 @@ func rawEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error } // rawDecodeValue is the ValueDecoderFunc for Reader. -func rawDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +func rawDecodeValue(_ DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tRaw { return ValueDecoderError{Name: "RawDecodeValue", Types: []reflect.Type{tRaw}, Received: val} } diff --git a/bson/registry.go b/bson/registry.go index 8f1cc421b9..41a652fa31 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -62,11 +62,11 @@ type DecoderFactory func() ValueDecoder // safe. type RegistryBuilder struct { typeEncoders map[reflect.Type]EncoderFactory - typeDecoders *typeDecoderCache + typeDecoders map[reflect.Type]DecoderFactory interfaceEncoders map[reflect.Type]EncoderFactory - interfaceDecoders []interfaceValueDecoder + interfaceDecoders map[reflect.Type]DecoderFactory kindEncoders [reflect.UnsafePointer + 1]EncoderFactory - kindDecoders *kindDecoderCache + kindDecoders [reflect.UnsafePointer + 1]DecoderFactory typeMap map[Type]reflect.Type } @@ -74,9 +74,9 @@ type RegistryBuilder struct { func NewRegistryBuilder() *RegistryBuilder { rb := &RegistryBuilder{ typeEncoders: make(map[reflect.Type]EncoderFactory), - typeDecoders: new(typeDecoderCache), + typeDecoders: make(map[reflect.Type]DecoderFactory), interfaceEncoders: make(map[reflect.Type]EncoderFactory), - kindDecoders: new(kindDecoderCache), + interfaceDecoders: make(map[reflect.Type]DecoderFactory), typeMap: make(map[Type]reflect.Type), } registerDefaultEncoders(rb) @@ -102,7 +102,7 @@ func (rb *RegistryBuilder) RegisterTypeEncoder(valueType reflect.Type, encFac En return rb } -// RegisterTypeDecoder registers the provided ValueDecoder for the provided type. +// RegisterTypeDecoder registers a ValueDecoder factory for the provided type. // // The type will be used as provided, so a decoder can be registered for a type and a different // decoder can be registered for a pointer to that type. @@ -112,8 +112,10 @@ func (rb *RegistryBuilder) RegisterTypeEncoder(valueType reflect.Type, encFac En // implements the interface. To get the latter behavior, call RegisterHookDecoder instead. // // RegisterTypeDecoder should not be called concurrently with any other Registry method. -func (rb *RegistryBuilder) RegisterTypeDecoder(valueType reflect.Type, dec ValueDecoder) *RegistryBuilder { - rb.typeDecoders.Store(valueType, dec) +func (rb *RegistryBuilder) RegisterTypeDecoder(valueType reflect.Type, decFac DecoderFactory) *RegistryBuilder { + if decFac != nil { + rb.typeDecoders[valueType] = decFac + } return rb } @@ -136,7 +138,7 @@ func (rb *RegistryBuilder) RegisterKindEncoder(kind reflect.Kind, encFac Encoder return rb } -// RegisterKindDecoder registers the provided ValueDecoder for the provided kind. +// RegisterKindDecoder registers a ValueDecoder factory for the provided kind. // // Use RegisterKindDecoder to register a decoder for any type with the same underlying kind. For // example, consider the type MyInt defined as @@ -148,8 +150,10 @@ func (rb *RegistryBuilder) RegisterKindEncoder(kind reflect.Kind, encFac Encoder // reg.RegisterKindDecoder(reflect.Int32, myDecoder) // // RegisterKindDecoder should not be called concurrently with any other Registry method. -func (rb *RegistryBuilder) RegisterKindDecoder(kind reflect.Kind, dec ValueDecoder) *RegistryBuilder { - rb.kindDecoders.Store(kind, dec) +func (rb *RegistryBuilder) RegisterKindDecoder(kind reflect.Kind, decFac DecoderFactory) *RegistryBuilder { + if decFac != nil && kind < reflect.Kind(len(rb.kindDecoders)) { + rb.kindDecoders[kind] = decFac + } return rb } @@ -173,28 +177,23 @@ func (rb *RegistryBuilder) RegisterInterfaceEncoder(iface reflect.Type, encFac E return rb } -// RegisterInterfaceDecoder registers an decoder for the provided interface type iface. This decoder will -// be called when unmarshaling into a type if the type implements iface or a pointer to the type +// RegisterInterfaceDecoder registers a decoder factory for the provided interface type iface. This decoder +// will be called when unmarshaling into a type if the type implements iface or a pointer to the type // implements iface. If the provided type is not an interface (i.e. iface.Kind() != reflect.Interface), // this method will panic. // // RegisterInterfaceDecoder should not be called concurrently with any other Registry method. -func (rb *RegistryBuilder) RegisterInterfaceDecoder(iface reflect.Type, dec ValueDecoder) *RegistryBuilder { +func (rb *RegistryBuilder) RegisterInterfaceDecoder(iface reflect.Type, decFac DecoderFactory) *RegistryBuilder { if iface.Kind() != reflect.Interface { panicStr := fmt.Errorf("RegisterInterfaceDecoder expects a type with kind reflect.Interface, "+ "got type %s with kind %s", iface, iface.Kind()) panic(panicStr) } - for idx, decoder := range rb.interfaceDecoders { - if decoder.i == iface { - rb.interfaceDecoders[idx].vd = dec - return rb - } + if decFac != nil { + rb.interfaceDecoders[iface] = decFac } - rb.interfaceDecoders = append(rb.interfaceDecoders, interfaceValueDecoder{i: iface, vd: dec}) - return rb } @@ -218,56 +217,73 @@ func (rb *RegistryBuilder) RegisterTypeMapEntry(bt Type, rt reflect.Type) *Regis func (rb *RegistryBuilder) Build() *Registry { r := &Registry{ typeEncoders: new(sync.Map), - typeDecoders: rb.typeDecoders.Clone(), + typeDecoders: new(sync.Map), interfaceEncoders: make([]interfaceValueEncoder, 0, len(rb.interfaceEncoders)), - interfaceDecoders: append([]interfaceValueDecoder(nil), rb.interfaceDecoders...), - kindDecoders: rb.kindDecoders.Clone(), - encoderTypeMap: make(map[reflect.Type][]ValueEncoder), + interfaceDecoders: make([]interfaceValueDecoder, 0, len(rb.interfaceDecoders)), typeMap: make(map[Type]reflect.Type), + + encoderTypeMap: make(map[reflect.Type][]ValueEncoder), + decoderTypeMap: make(map[reflect.Type][]ValueDecoder), } + encoderCache := make(map[reflect.Value]ValueEncoder) - for k, v := range rb.typeEncoders { - var encoder ValueEncoder - if enc, ok := encoderCache[reflect.ValueOf(v)]; ok { - encoder = enc - } else { - encoder = v() - encoderCache[reflect.ValueOf(v)] = encoder - et := reflect.ValueOf(encoder).Type() - r.encoderTypeMap[et] = append(r.encoderTypeMap[et], encoder) + getEncoder := func(encFac EncoderFactory) ValueEncoder { + if enc, ok := encoderCache[reflect.ValueOf(encFac)]; ok { + return enc } + encoder := encFac() + encoderCache[reflect.ValueOf(encFac)] = encoder + t := reflect.ValueOf(encoder).Type() + r.encoderTypeMap[t] = append(r.encoderTypeMap[t], encoder) + return encoder + } + for k, v := range rb.typeEncoders { + encoder := getEncoder(v) r.typeEncoders.Store(k, encoder) } for k, v := range rb.interfaceEncoders { - var encoder ValueEncoder - if enc, ok := encoderCache[reflect.ValueOf(v)]; ok { - encoder = enc - } else { - encoder = v() - encoderCache[reflect.ValueOf(v)] = encoder - et := reflect.ValueOf(encoder).Type() - r.encoderTypeMap[et] = append(r.encoderTypeMap[et], encoder) - } + encoder := getEncoder(v) r.interfaceEncoders = append(r.interfaceEncoders, interfaceValueEncoder{k, encoder}) } for i, v := range rb.kindEncoders { if v == nil { continue } - var encoder ValueEncoder - if enc, ok := encoderCache[reflect.ValueOf(v)]; ok { - encoder = enc - } else { - encoder = v() - encoderCache[reflect.ValueOf(v)] = encoder - et := reflect.ValueOf(encoder).Type() - r.encoderTypeMap[et] = append(r.encoderTypeMap[et], encoder) - } + encoder := getEncoder(v) r.kindEncoders[i] = encoder } + + decoderCache := make(map[reflect.Value]ValueDecoder) + getDecoder := func(decFac DecoderFactory) ValueDecoder { + if dec, ok := decoderCache[reflect.ValueOf(decFac)]; ok { + return dec + } + decoder := decFac() + decoderCache[reflect.ValueOf(decFac)] = decoder + t := reflect.ValueOf(decoder).Type() + r.decoderTypeMap[t] = append(r.decoderTypeMap[t], decoder) + return decoder + } + for k, v := range rb.typeDecoders { + decoder := getDecoder(v) + r.typeDecoders.Store(k, decoder) + } + for k, v := range rb.interfaceDecoders { + decoder := getDecoder(v) + r.interfaceDecoders = append(r.interfaceDecoders, interfaceValueDecoder{k, decoder}) + } + for i, v := range rb.kindDecoders { + if v == nil { + continue + } + decoder := getDecoder(v) + r.kindDecoders[i] = decoder + } + for k, v := range rb.typeMap { r.typeMap[k] = v } + return r } @@ -306,14 +322,15 @@ func (rb *RegistryBuilder) Build() *Registry { // Read [Registry.LookupDecoder] and [Registry.LookupEncoder] for Registry lookup procedure. type Registry struct { typeEncoders *sync.Map // map[reflect.Type]ValueEncoder - typeDecoders *typeDecoderCache + typeDecoders *sync.Map // map[reflect.Type]ValueDecoder interfaceEncoders []interfaceValueEncoder interfaceDecoders []interfaceValueDecoder kindEncoders [reflect.UnsafePointer + 1]ValueEncoder - kindDecoders *kindDecoderCache + kindDecoders [reflect.UnsafePointer + 1]ValueDecoder typeMap map[Type]reflect.Type encoderTypeMap map[reflect.Type][]ValueEncoder + decoderTypeMap map[reflect.Type][]ValueDecoder } // LookupEncoder returns the first matching encoder in the Registry. It uses the following lookup @@ -408,25 +425,35 @@ func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) { if valueType == nil { return nil, ErrNoDecoder{Type: valueType} } - dec, found := r.typeDecoders.Load(valueType) - if found { + + if dec, found := r.typeDecoders.Load(valueType); found { if dec == nil { return nil, ErrNoDecoder{Type: valueType} } - return dec, nil + return dec.(ValueDecoder), nil } - dec, found = r.lookupInterfaceDecoder(valueType, true) - if found { - return r.typeDecoders.LoadOrStore(valueType, dec), nil + if dec, found := r.lookupInterfaceDecoder(valueType, true); found { + r.typeDecoders.Store(valueType, dec) + return dec, nil } - if v, ok := r.kindDecoders.Load(valueType.Kind()); ok { - return r.typeDecoders.LoadOrStore(valueType, v), nil + if dec, found := r.lookupKindDecoder(valueType.Kind()); found { + r.typeDecoders.Store(valueType, dec) + return dec, nil } return nil, ErrNoDecoder{Type: valueType} } +func (r *Registry) lookupKindDecoder(valueKind reflect.Kind) (ValueDecoder, bool) { + if valueKind < reflect.Kind(len(r.kindDecoders)) { + if dec := r.kindDecoders[valueKind]; dec != nil { + return dec, true + } + } + return nil, false +} + func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool) (ValueDecoder, bool) { if valueType == nil { return nil, false @@ -440,7 +467,7 @@ func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool // ahead in interfaceDecoders defaultDec, found := r.lookupInterfaceDecoder(valueType, false) if !found { - defaultDec, _ = r.kindDecoders.Load(valueType.Kind()) + defaultDec, _ = r.lookupKindDecoder(valueType.Kind()) } return &condAddrDecoder{canAddrDec: idec.vd, elseDec: defaultDec}, true } diff --git a/bson/registry_examples_test.go b/bson/registry_examples_test.go index 4b15dde3d5..35b5016eba 100644 --- a/bson/registry_examples_test.go +++ b/bson/registry_examples_test.go @@ -88,7 +88,7 @@ func ExampleRegistry_customDecoder() { lenientBoolType := reflect.TypeOf(lenientBool(true)) lenientBoolDecoder := func( - dc bson.DecodeContext, + _ bson.DecoderRegistry, vr bson.ValueReader, val reflect.Value, ) error { @@ -135,7 +135,7 @@ func ExampleRegistry_customDecoder() { reg := bson.NewRegistryBuilder() reg.RegisterTypeDecoder( lenientBoolType, - bson.ValueDecoderFunc(lenientBoolDecoder), + func() bson.ValueDecoder { return bson.ValueDecoderFunc(lenientBoolDecoder) }, ) // Marshal a BSON document with a single field "isOK" that is a non-zero @@ -228,7 +228,7 @@ func ExampleRegistryBuilder_RegisterKindDecoder() { // "kind" decoder for kind reflect.Int64. That way, we can even decode to // user-defined types with underlying type int64. flexibleInt64KindDecoder := func( - dc bson.DecodeContext, + _ bson.DecoderRegistry, vr bson.ValueReader, val reflect.Value, ) error { @@ -280,7 +280,7 @@ func ExampleRegistryBuilder_RegisterKindDecoder() { reg := bson.NewRegistryBuilder() reg.RegisterKindDecoder( reflect.Int64, - bson.ValueDecoderFunc(flexibleInt64KindDecoder), + func() bson.ValueDecoder { return bson.ValueDecoderFunc(flexibleInt64KindDecoder) }, ) // Marshal a BSON document with fields that are mixed numeric types but all diff --git a/bson/registry_test.go b/bson/registry_test.go index f60542b74d..c321f48d88 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -19,9 +19,9 @@ import ( func newTestRegistryBuilder() *RegistryBuilder { return &RegistryBuilder{ typeEncoders: make(map[reflect.Type]EncoderFactory), - typeDecoders: new(typeDecoderCache), + typeDecoders: make(map[reflect.Type]DecoderFactory), interfaceEncoders: make(map[reflect.Type]EncoderFactory), - kindDecoders: new(kindDecoderCache), + interfaceDecoders: make(map[reflect.Type]DecoderFactory), typeMap: make(map[Type]reflect.Type), } } @@ -376,6 +376,14 @@ func TestRegistryBuilder(t *testing.T) { fmcEncFac := func() ValueEncoder { return fmc } pcEncFac := func() ValueEncoder { return pc } + fc1DecFac := func() ValueDecoder { return fc1 } + fc2DecFac := func() ValueDecoder { return fc2 } + fc3DecFac := func() ValueDecoder { return fc3 } + fscDecFac := func() ValueDecoder { return fsc } + fslccDecFac := func() ValueDecoder { return fslcc } + fmcDecFac := func() ValueDecoder { return fmc } + pcDecFac := func() ValueDecoder { return pc } + reg := newTestRegistryBuilder(). RegisterTypeEncoder(ft1, fc1EncFac). RegisterTypeEncoder(ft2, fc2EncFac). @@ -385,18 +393,18 @@ func TestRegistryBuilder(t *testing.T) { RegisterKindEncoder(reflect.Array, fslccEncFac). RegisterKindEncoder(reflect.Map, fmcEncFac). RegisterKindEncoder(reflect.Ptr, pcEncFac). - RegisterTypeDecoder(ft1, fc1). - RegisterTypeDecoder(ft2, fc2). - RegisterTypeDecoder(ti1, fc1). // values whose exact type is testInterface1 will use fc1 encoder - RegisterKindDecoder(reflect.Struct, fsc). - RegisterKindDecoder(reflect.Slice, fslcc). - RegisterKindDecoder(reflect.Array, fslcc). - RegisterKindDecoder(reflect.Map, fmc). - RegisterKindDecoder(reflect.Ptr, pc). + RegisterTypeDecoder(ft1, fc1DecFac). + RegisterTypeDecoder(ft2, fc2DecFac). + RegisterTypeDecoder(ti1, fc1DecFac). // values whose exact type is testInterface1 will use fc1 encoder + RegisterKindDecoder(reflect.Struct, fscDecFac). + RegisterKindDecoder(reflect.Slice, fslccDecFac). + RegisterKindDecoder(reflect.Array, fslccDecFac). + RegisterKindDecoder(reflect.Map, fmcDecFac). + RegisterKindDecoder(reflect.Ptr, pcDecFac). RegisterInterfaceEncoder(ti2, fc2EncFac). RegisterInterfaceEncoder(ti3, fc3EncFac). - RegisterInterfaceDecoder(ti2, fc2). - RegisterInterfaceDecoder(ti3, fc3). + RegisterInterfaceDecoder(ti2, fc2DecFac). + RegisterInterfaceDecoder(ti3, fc3DecFac). Build() testCases := []struct { @@ -712,7 +720,7 @@ type fakeCodec struct { func (*fakeCodec) EncodeValue(EncoderRegistry, ValueWriter, reflect.Value) error { return nil } -func (*fakeCodec) DecodeValue(DecodeContext, ValueReader, reflect.Value) error { +func (*fakeCodec) DecodeValue(DecoderRegistry, ValueReader, reflect.Value) error { return nil } diff --git a/bson/setter_getter.go b/bson/setter_getter.go index 46706241be..5d08b40c42 100644 --- a/bson/setter_getter.go +++ b/bson/setter_getter.go @@ -46,7 +46,7 @@ type Getter interface { } // SetterDecodeValue is the ValueDecoderFunc for Setter types. -func SetterDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { +func SetterDecodeValue(_ DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.IsValid() || (!val.Type().Implements(tSetter) && !reflect.PtrTo(val.Type()).Implements(tSetter)) { return ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} } diff --git a/bson/slice_codec.go b/bson/slice_codec.go index 3640cdd124..71aaf32b93 100644 --- a/bson/slice_codec.go +++ b/bson/slice_codec.go @@ -12,10 +12,6 @@ import ( "reflect" ) -var ( - defaultSliceCodec = &sliceCodec{} -) - // sliceCodec is the Codec used for slice values. type sliceCodec struct { // encodeNilAsEmpty causes EncodeValue to marshal nil Go slices as empty BSON arrays instead of @@ -98,7 +94,7 @@ func (sc sliceCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflec } // DecodeValue is the ValueDecoder for slice types. -func (sc *sliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (sc *sliceCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Kind() != reflect.Slice { return ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } @@ -153,16 +149,15 @@ func (sc *sliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect. return fmt.Errorf("cannot decode %v into a slice", vrType) } - var elemsFunc func(DecodeContext, ValueReader, reflect.Value) ([]reflect.Value, error) + var elemsFunc func(DecoderRegistry, ValueReader, reflect.Value) ([]reflect.Value, error) switch val.Type().Elem() { case tE: - dc.ancestor = val.Type() elemsFunc = decodeD default: elemsFunc = decodeDefault } - elems, err := elemsFunc(dc, vr, val) + elems, err := elemsFunc(reg, vr, val) if err != nil { return err } diff --git a/bson/string_codec.go b/bson/string_codec.go index 9f1ee76136..55b0fd9c62 100644 --- a/bson/string_codec.go +++ b/bson/string_codec.go @@ -19,10 +19,6 @@ type stringCodec struct { decodeObjectIDAsHex bool } -var ( - defaultStringCodec = &stringCodec{} -) - // EncodeValue is the ValueEncoder for string types. func (sc *stringCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.String { @@ -36,7 +32,7 @@ func (sc *stringCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflec return vw.WriteString(val.String()) } -func (sc *stringCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (sc *stringCodec) decodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t.Kind() != reflect.String { return emptyValue, ValueDecoderError{ Name: "StringDecodeValue", @@ -58,7 +54,7 @@ func (sc *stringCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Ty if err != nil { return emptyValue, err } - if !sc.decodeObjectIDAsHex && !dc.decodeObjectIDAsHex { + if !sc.decodeObjectIDAsHex { return emptyValue, errors.New("cannot decode ObjectID as string if DecodeObjectIDAsHex is not set") } str = oid.Hex() @@ -92,12 +88,12 @@ func (sc *stringCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Ty } // DecodeValue is the ValueDecoder for string types. -func (sc *stringCodec) DecodeValue(dctx DecodeContext, vr ValueReader, val reflect.Value) error { +func (sc *stringCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Kind() != reflect.String { return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} } - elem, err := sc.decodeType(dctx, vr, val.Type()) + elem, err := sc.decodeType(reg, vr, val.Type()) if err != nil { return err } diff --git a/bson/struct_codec.go b/bson/struct_codec.go index c3ddd4f2c6..ac917a5b17 100644 --- a/bson/struct_codec.go +++ b/bson/struct_codec.go @@ -16,10 +16,6 @@ import ( "time" ) -var ( - defaultStructCodec = newStructCodec(DefaultStructTagParser) -) - // DecodeError represents an error that occurs when unmarshalling BSON bytes into a native Go type. type DecodeError struct { keys []string @@ -204,7 +200,7 @@ func newDecodeError(key string, original error) error { // DecodeValue implements the Codec interface. // By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr. // For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared. -func (sc *structCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (sc *structCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Kind() != reflect.Struct { return ValueDecoderError{Name: "structCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} } @@ -229,12 +225,12 @@ func (sc *structCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type()) } - sd, err := sc.describeStruct(val.Type(), dc.useJSONStructTags, false) + sd, err := sc.describeStruct(val.Type(), sc.useJSONStructTags, false) if err != nil { return err } - if sc.decodeZeroStruct || dc.zeroStructs { + if sc.decodeZeroStruct { val.Set(reflect.Zero(val.Type())) } if sc.decodeDeepZeroInline && sd.inline { @@ -245,7 +241,7 @@ func (sc *structCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect var inlineMap reflect.Value if sd.inlineMap >= 0 { inlineMap = val.Field(sd.inlineMap) - decoder, err = dc.LookupDecoder(inlineMap.Type().Elem()) + decoder, err = reg.LookupDecoder(inlineMap.Type().Elem()) if err != nil { return err } @@ -289,8 +285,7 @@ func (sc *structCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect } elem := reflect.New(inlineMap.Type().Elem()).Elem() - dc.ancestor = inlineMap.Type() - err = decoder.DecodeValue(dc, vr, elem) + err = decoder.DecodeValue(reg, vr, elem) if err != nil { return err } @@ -317,23 +312,12 @@ func (sc *structCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect } field = field.Addr() - dctx := DecodeContext{ - Registry: dc.Registry, - truncate: fd.truncate || dc.truncate, - defaultDocumentType: dc.defaultDocumentType, - binaryAsSlice: dc.binaryAsSlice, - useJSONStructTags: dc.useJSONStructTags, - useLocalTimeZone: dc.useLocalTimeZone, - zeroMaps: dc.zeroMaps, - zeroStructs: dc.zeroStructs, - } - - decoder, err := dc.Registry.LookupDecoder(fd.fieldType) + decoder, err := reg.LookupDecoder(fd.fieldType) if err != nil { return newDecodeError(fd.name, ErrNoDecoder{Type: field.Elem().Type()}) } - err = decoder.DecodeValue(dctx, vr, field.Elem()) + err = decoder.DecodeValue(reg, vr, field.Elem()) if err != nil { return newDecodeError(fd.name, err) } diff --git a/bson/time_codec.go b/bson/time_codec.go index d9bb57404b..1e62117d47 100644 --- a/bson/time_codec.go +++ b/bson/time_codec.go @@ -22,11 +22,7 @@ type timeCodec struct { useLocalTimeZone bool } -var ( - defaultTimeCodec = &timeCodec{} -) - -func (tc *timeCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (tc *timeCodec) decodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tTime { return emptyValue, ValueDecoderError{ Name: "TimeDecodeValue", @@ -77,19 +73,19 @@ func (tc *timeCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type return emptyValue, fmt.Errorf("cannot decode %v into a time.Time", vrType) } - if !tc.useLocalTimeZone && !dc.useLocalTimeZone { + if !tc.useLocalTimeZone { timeVal = timeVal.UTC() } return reflect.ValueOf(timeVal), nil } // DecodeValue is the ValueDecoderFunc for time.Time. -func (tc *timeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { +func (tc *timeCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tTime { return ValueDecoderError{Name: "TimeDecodeValue", Types: []reflect.Type{tTime}, Received: val} } - elem, err := tc.decodeType(dc, vr, tTime) + elem, err := tc.decodeType(reg, vr, tTime) if err != nil { return err } diff --git a/bson/uint_codec.go b/bson/uint_codec.go deleted file mode 100644 index f63404e934..0000000000 --- a/bson/uint_codec.go +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bson - -import ( - "fmt" - "math" - "reflect" -) - -// uintCodec is the Codec used for uint values. -type uintCodec struct { - // encodeToMinSize causes EncodeValue to marshal Go uint values (excluding uint64) as the - // minimum BSON int size (either 32-bit or 64-bit) that can represent the integer value. - encodeToMinSize bool -} - -var ( - defaultUIntCodec = &uintCodec{} -) - -// EncodeValue is the ValueEncoder for uint types. -func (uic *uintCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { - switch val.Kind() { - case reflect.Uint8, reflect.Uint16: - return vw.WriteInt32(int32(val.Uint())) - case reflect.Uint, reflect.Uint32, reflect.Uint64: - u64 := val.Uint() - - // If encodeToMinSize is true for a non-uint64 value we should write val as an int32 - useMinSize := uic.encodeToMinSize && val.Kind() != reflect.Uint64 - - if u64 <= math.MaxInt32 && useMinSize { - return vw.WriteInt32(int32(u64)) - } - if u64 > math.MaxInt64 { - return fmt.Errorf("%d overflows int64", u64) - } - return vw.WriteInt64(int64(u64)) - } - - return ValueEncoderError{ - Name: "UintEncodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, - Received: val, - } -} - -func (uic *uintCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { - var i64 int64 - var err error - switch vrType := vr.Type(); vrType { - case TypeInt32: - i32, err := vr.ReadInt32() - if err != nil { - return emptyValue, err - } - i64 = int64(i32) - case TypeInt64: - i64, err = vr.ReadInt64() - if err != nil { - return emptyValue, err - } - case TypeDouble: - f64, err := vr.ReadDouble() - if err != nil { - return emptyValue, err - } - if !dc.truncate && math.Floor(f64) != f64 { - return emptyValue, errCannotTruncate - } - if f64 > float64(math.MaxInt64) { - return emptyValue, fmt.Errorf("%g overflows int64", f64) - } - i64 = int64(f64) - case TypeBoolean: - b, err := vr.ReadBoolean() - if err != nil { - return emptyValue, err - } - if b { - i64 = 1 - } - case TypeNull: - if err = vr.ReadNull(); err != nil { - return emptyValue, err - } - case TypeUndefined: - if err = vr.ReadUndefined(); err != nil { - return emptyValue, err - } - default: - return emptyValue, fmt.Errorf("cannot decode %v into an integer type", vrType) - } - - switch t.Kind() { - case reflect.Uint8: - if i64 < 0 || i64 > math.MaxUint8 { - return emptyValue, fmt.Errorf("%d overflows uint8", i64) - } - - return reflect.ValueOf(uint8(i64)), nil - case reflect.Uint16: - if i64 < 0 || i64 > math.MaxUint16 { - return emptyValue, fmt.Errorf("%d overflows uint16", i64) - } - - return reflect.ValueOf(uint16(i64)), nil - case reflect.Uint32: - if i64 < 0 || i64 > math.MaxUint32 { - return emptyValue, fmt.Errorf("%d overflows uint32", i64) - } - - return reflect.ValueOf(uint32(i64)), nil - case reflect.Uint64: - if i64 < 0 { - return emptyValue, fmt.Errorf("%d overflows uint64", i64) - } - - return reflect.ValueOf(uint64(i64)), nil - case reflect.Uint: - if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint - return emptyValue, fmt.Errorf("%d overflows uint", i64) - } - - return reflect.ValueOf(uint(i64)), nil - default: - return emptyValue, ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, - Received: reflect.Zero(t), - } - } -} - -// DecodeValue is the ValueDecoder for uint types. -func (uic *uintCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - if !val.CanSet() { - return ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, - Received: val, - } - } - - elem, err := uic.decodeType(dc, vr, val.Type()) - if err != nil { - return err - } - - val.SetUint(elem.Uint()) - return nil -} diff --git a/bson/unmarshal_test.go b/bson/unmarshal_test.go index d8ef9b69ba..dc8102b02c 100644 --- a/bson/unmarshal_test.go +++ b/bson/unmarshal_test.go @@ -219,7 +219,7 @@ func TestCachingDecodersNotSharedAcrossRegistries(t *testing.T) { // different Registry is used. // Create a custom Registry that negates BSON int32 values when decoding. - var decodeInt32 ValueDecoderFunc = func(_ DecodeContext, vr ValueReader, val reflect.Value) error { + var decodeInt32 ValueDecoderFunc = func(_ DecoderRegistry, vr ValueReader, val reflect.Value) error { i32, err := vr.ReadInt32() if err != nil { return err @@ -229,7 +229,7 @@ func TestCachingDecodersNotSharedAcrossRegistries(t *testing.T) { return nil } customReg := NewRegistryBuilder(). - RegisterTypeDecoder(tInt32, decodeInt32). + RegisterTypeDecoder(tInt32, func() ValueDecoder { return decodeInt32 }). Build() docBytes := bsoncore.BuildDocumentFromElements( diff --git a/bson/unmarshal_value_test.go b/bson/unmarshal_value_test.go index 3af7578d12..05524f658a 100644 --- a/bson/unmarshal_value_test.go +++ b/bson/unmarshal_value_test.go @@ -76,7 +76,7 @@ func TestUnmarshalValue(t *testing.T) { }, } reg := NewRegistryBuilder(). - RegisterTypeDecoder(reflect.TypeOf([]byte{}), &sliceCodec{}). + RegisterTypeDecoder(reflect.TypeOf([]byte{}), func() ValueDecoder { return &sliceCodec{} }). Build() for _, tc := range testCases { tc := tc @@ -112,7 +112,7 @@ func BenchmarkSliceCodecUnmarshal(b *testing.B) { }, } reg := NewRegistryBuilder(). - RegisterTypeDecoder(reflect.TypeOf([]byte{}), &sliceCodec{}). + RegisterTypeDecoder(reflect.TypeOf([]byte{}), func() ValueDecoder { return &sliceCodec{} }). Build() for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index d3ebb1421b..49607d6689 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -44,7 +44,7 @@ func (e *negateCodec) EncodeValue(_ bson.EncoderRegistry, vw bson.ValueWriter, v } // DecodeValue negates the value of ID when reading -func (e *negateCodec) DecodeValue(_ bson.DecodeContext, vr bson.ValueReader, val reflect.Value) error { +func (e *negateCodec) DecodeValue(_ bson.DecoderRegistry, vr bson.ValueReader, val reflect.Value) error { i, err := vr.ReadInt64() if err != nil { return err @@ -102,7 +102,7 @@ func TestClient(t *testing.T) { reg := bson.NewRegistryBuilder(). RegisterTypeEncoder(reflect.TypeOf(int64(0)), func() bson.ValueEncoder { return &negateCodec{} }). - RegisterTypeDecoder(reflect.TypeOf(int64(0)), &negateCodec{}). + RegisterTypeDecoder(reflect.TypeOf(int64(0)), func() bson.ValueDecoder { return &negateCodec{} }). Build() registryOpts := options.Client(). SetRegistry(reg) diff --git a/internal/integration/unified_spec_test.go b/internal/integration/unified_spec_test.go index 487714d834..f8cb489933 100644 --- a/internal/integration/unified_spec_test.go +++ b/internal/integration/unified_spec_test.go @@ -77,24 +77,24 @@ type testData struct { } // custom decoder for testData type -func decodeTestData(dc bson.DecodeContext, vr bson.ValueReader, val reflect.Value) error { +func decodeTestData(reg bson.DecoderRegistry, vr bson.ValueReader, val reflect.Value) error { switch vr.Type() { case bson.TypeArray: docsVal := val.FieldByName("Documents") - decoder, err := dc.Registry.LookupDecoder(docsVal.Type()) + decoder, err := reg.LookupDecoder(docsVal.Type()) if err != nil { return err } - return decoder.DecodeValue(dc, vr, docsVal) + return decoder.DecodeValue(reg, vr, docsVal) case bson.TypeEmbeddedDocument: gridfsDataVal := val.FieldByName("GridFSData") - decoder, err := dc.Registry.LookupDecoder(gridfsDataVal.Type()) + decoder, err := reg.LookupDecoder(gridfsDataVal.Type()) if err != nil { return err } - return decoder.DecodeValue(dc, vr, gridfsDataVal) + return decoder.DecodeValue(reg, vr, gridfsDataVal) } return nil } @@ -183,7 +183,7 @@ var directories = []string{ var checkOutcomeOpts = options.Collection().SetReadPreference(readpref.Primary()).SetReadConcern(readconcern.Local()) var specTestRegistry = bson.NewRegistryBuilder(). RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})). - RegisterTypeDecoder(reflect.TypeOf(testData{}), bson.ValueDecoderFunc(decodeTestData)). + RegisterTypeDecoder(reflect.TypeOf(testData{}), func() bson.ValueDecoder { return bson.ValueDecoderFunc(decodeTestData) }). Build() func TestUnifiedSpecs(t *testing.T) { From 7a0b8bf799d3f02827f316d4209206ca90da0388 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Mon, 20 May 2024 14:13:43 -0400 Subject: [PATCH 09/15] WIP --- bson/bsoncodec.go | 28 +----- bson/cond_addr_codec_test.go | 4 +- bson/decoder.go | 78 +++++++++++--- bson/decoder_test.go | 12 +-- bson/default_value_decoders_test.go | 151 +++++++++++++--------------- bson/encoder.go | 16 +-- bson/mgocompat/bson_test.go | 12 +-- bson/primitive_codecs_test.go | 12 +-- bson/raw_value.go | 10 +- bson/raw_value_test.go | 12 +-- bson/registry.go | 10 +- bson/registry_test.go | 12 +-- bson/string_codec_test.go | 12 +-- bson/time_codec_test.go | 4 +- bson/truncation_test.go | 20 ++-- bson/unmarshal.go | 18 ++-- bson/unmarshal_test.go | 6 +- 17 files changed, 205 insertions(+), 212 deletions(-) diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index 6d70a99ca1..5e910fca88 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -72,31 +72,6 @@ func (vde ValueDecoderError) Error() string { return fmt.Sprintf("%s can only decode valid and settable %s, but got %s", vde.Name, strings.Join(typeKinds, ", "), received) } -// DecodeContext is the contextual information required for a Codec to decode a -// value. -type DecodeContext struct { - *Registry - - // defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the - // usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If DocumentType is - // set to a type that a BSON document cannot be unmarshaled into (e.g. "string"), unmarshalling will result in an - // error. DocumentType overrides the Ancestor field. - defaultDocumentType reflect.Type - - // truncate, if true, instructs decoders to to truncate the fractional part of BSON "double" - // values when attempting to unmarshal them into a Go integer (int, int8, int16, int32, int64, - // uint, uint8, uint16, uint32, or uint64) struct field. The truncation logic does not apply to - // BSON "decimal128" values. - truncate bool - - binaryAsSlice bool - decodeObjectIDAsHex bool - useJSONStructTags bool - useLocalTimeZone bool - zeroMaps bool - zeroStructs bool -} - // EncoderRegistry is an interface provides a ValueEncoder based on the given reflect.Type. type EncoderRegistry interface { LookupEncoder(reflect.Type) (ValueEncoder, error) @@ -166,8 +141,7 @@ var _ ValueDecoder = decodeAdapter{} var _ typeDecoder = decodeAdapter{} func decodeTypeOrValueWithInfo(vd ValueDecoder, reg DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { - td, ok := vd.(typeDecoder) - if ok && td != nil { + if td, _ := vd.(typeDecoder); td != nil { return td.decodeType(reg, vr, t) } diff --git a/bson/cond_addr_codec_test.go b/bson/cond_addr_codec_test.go index 15b4a8a333..6fd777ae77 100644 --- a/bson/cond_addr_codec_test.go +++ b/bson/cond_addr_codec_test.go @@ -78,7 +78,7 @@ func TestCondAddrCodec(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - err := condDecoder.DecodeValue(DecodeContext{}, rw, tc.val) + err := condDecoder.DecodeValue(nil, rw, tc.val) assert.Nil(t, err, "CondAddrDecoder error: %v", err) assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked) @@ -87,7 +87,7 @@ func TestCondAddrCodec(t *testing.T) { t.Run("error", func(t *testing.T) { errDecoder := &condAddrDecoder{canAddrDec: decode1, elseDec: nil} - err := errDecoder.DecodeValue(DecodeContext{}, rw, unaddressable) + err := errDecoder.DecodeValue(nil, rw, unaddressable) want := ErrNoDecoder{Type: unaddressable.Type()} assert.Equal(t, err, want, "expected error %v, got %v", want, err) }) diff --git a/bson/decoder.go b/bson/decoder.go index cc53f422f9..14335c5bb5 100644 --- a/bson/decoder.go +++ b/bson/decoder.go @@ -28,15 +28,15 @@ var decPool = sync.Pool{ // A Decoder reads and decodes BSON documents from a stream. It reads from a ValueReader as // the source of BSON data. type Decoder struct { - dc DecodeContext - vr ValueReader + reg *Registry + vr ValueReader } // NewDecoder returns a new decoder that uses the DefaultRegistry to read from vr. func NewDecoder(vr ValueReader) *Decoder { return &Decoder{ - dc: DecodeContext{Registry: DefaultRegistry}, - vr: vr, + reg: DefaultRegistry, + vr: vr, } } @@ -68,12 +68,12 @@ func (d *Decoder) Decode(val interface{}) error { default: return fmt.Errorf("argument to Decode must be a pointer or a map, but got %v", rval) } - decoder, err := d.dc.LookupDecoder(rval.Type()) + decoder, err := d.reg.LookupDecoder(rval.Type()) if err != nil { return err } - return decoder.DecodeValue(d.dc, d.vr, rval) + return decoder.DecodeValue(d.reg, d.vr, rval) } // Reset will reset the state of the decoder, using the same *DecodeContext used in @@ -84,59 +84,105 @@ func (d *Decoder) Reset(vr ValueReader) { // SetRegistry replaces the current registry of the decoder with r. func (d *Decoder) SetRegistry(r *Registry) { - d.dc.Registry = r + d.reg = r } // DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This // behavior is restricted to data typed as "interface{}" or "map[string]interface{}". func (d *Decoder) DefaultDocumentM() { - d.dc.defaultDocumentType = reflect.TypeOf(M{}) + t := reflect.TypeOf((*emptyInterfaceCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*emptyInterfaceCodec).defaultDocumentType = reflect.TypeOf(M{}) + } + } } // DefaultDocumentD causes the Decoder to always unmarshal documents into the primitive.D type. This // behavior is restricted to data typed as "interface{}" or "map[string]interface{}". func (d *Decoder) DefaultDocumentD() { - d.dc.defaultDocumentType = reflect.TypeOf(D{}) + t := reflect.TypeOf((*emptyInterfaceCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*emptyInterfaceCodec).defaultDocumentType = reflect.TypeOf(D{}) + } + } } // AllowTruncatingDoubles causes the Decoder to truncate the fractional part of BSON "double" values // when attempting to unmarshal them into a Go integer (int, int8, int16, int32, or int64) struct // field. The truncation logic does not apply to BSON "decimal128" values. func (d *Decoder) AllowTruncatingDoubles() { - d.dc.truncate = true + t := reflect.TypeOf((*intCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*intCodec).truncate = true + } + } + // TODO floatCodec } // BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or // "Old" BSON binary subtype as a Go byte slice instead of a primitive.Binary. func (d *Decoder) BinaryAsSlice() { - d.dc.binaryAsSlice = true + t := reflect.TypeOf((*emptyInterfaceCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*emptyInterfaceCodec).decodeBinaryAsSlice = true + } + } } // DecodeObjectIDAsHex causes the Decoder to unmarshal BSON ObjectID as a hexadecimal string. func (d *Decoder) DecodeObjectIDAsHex() { - d.dc.decodeObjectIDAsHex = true + t := reflect.TypeOf((*stringCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*stringCodec).decodeObjectIDAsHex = true + } + } } // UseJSONStructTags causes the Decoder to fall back to using the "json" struct tag if a "bson" // struct tag is not specified. func (d *Decoder) UseJSONStructTags() { - d.dc.useJSONStructTags = true + t := reflect.TypeOf((*structCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*structCodec).useJSONStructTags = true + } + } } // UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead // of the UTC timezone. func (d *Decoder) UseLocalTimeZone() { - d.dc.useLocalTimeZone = true + t := reflect.TypeOf((*timeCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*timeCodec).useLocalTimeZone = true + } + } } // ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value // passed to Decode before unmarshaling BSON documents into them. func (d *Decoder) ZeroMaps() { - d.dc.zeroMaps = true + t := reflect.TypeOf((*mapCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*mapCodec).decodeZerosMap = true + } + } } // ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination // value passed to Decode before unmarshaling BSON documents into them. func (d *Decoder) ZeroStructs() { - d.dc.zeroStructs = true + t := reflect.TypeOf((*structCodec)(nil)) + if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*structCodec).decodeZeroStruct = true + } + } } diff --git a/bson/decoder_test.go b/bson/decoder_test.go index 3b96f63559..b101b38d65 100644 --- a/bson/decoder_test.go +++ b/bson/decoder_test.go @@ -32,7 +32,7 @@ func TestBasicDecode(t *testing.T) { reg := NewRegistryBuilder().Build() decoder, err := reg.LookupDecoder(reflect.TypeOf(got)) noerr(t, err) - err = decoder.DecodeValue(DecodeContext{Registry: reg}, vr, got) + err = decoder.DecodeValue(reg, vr, got) noerr(t, err) assert.Equal(t, tc.want, got.Addr().Interface(), "Results do not match.") }) @@ -200,15 +200,13 @@ func TestDecoderv2(t *testing.T) { t.Parallel() r1, r2 := DefaultRegistry, NewRegistryBuilder().Build() - dc1 := DecodeContext{Registry: r1} - dc2 := DecodeContext{Registry: r2} dec := NewDecoder(NewValueReader([]byte{})) - if !reflect.DeepEqual(dec.dc, dc1) { - t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc1) + if !reflect.DeepEqual(dec.reg, r1) { + t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.reg, r1) } dec.SetRegistry(r2) - if !reflect.DeepEqual(dec.dc, dc2) { - t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc2) + if !reflect.DeepEqual(dec.reg, r2) { + t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.reg, r2) } }) t.Run("DecodeToNil", func(t *testing.T) { diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 24892adeae..42602e3cbe 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -57,7 +57,7 @@ func TestDefaultValueDecoders(t *testing.T) { type subtest struct { name string val interface{} - dctx *DecodeContext + reg *Registry llvrw *valueReaderWriter invoke invoked err error @@ -186,15 +186,15 @@ func TestDefaultValueDecoders(t *testing.T) { errors.New("ReadDouble error"), }, { - "ReadDouble", int64(3), &DecodeContext{}, + "ReadDouble", int64(3), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.00)}, readDouble, nil, }, - { - "ReadDouble (truncate)", int64(3), &DecodeContext{truncate: true}, - &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, - nil, - }, + // { + // "ReadDouble (truncate)", int64(3), &DecodeContext{truncate: true}, + // &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, + // nil, + // }, { "ReadDouble (no truncate)", int64(0), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, @@ -418,15 +418,15 @@ func TestDefaultValueDecoders(t *testing.T) { errors.New("ReadDouble error"), }, { - "ReadDouble", uint64(3), &DecodeContext{}, + "ReadDouble", uint64(3), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.00)}, readDouble, nil, }, - { - "ReadDouble (truncate)", uint64(3), &DecodeContext{truncate: true}, - &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, - nil, - }, + // { + // "ReadDouble (truncate)", uint64(3), &DecodeContext{truncate: true}, + // &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, + // nil, + // }, { "ReadDouble (no truncate)", uint64(0), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, @@ -673,11 +673,11 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14159)}, readDouble, nil, }, - { - "float32/fast path (truncate)", float32(3.14), &DecodeContext{truncate: true}, - &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, - nil, - }, + // { + // "float32/fast path (truncate)", float32(3.14), &DecodeContext{truncate: true}, + // &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, + // nil, + // }, { "float32/fast path (no truncate)", float32(0), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, @@ -711,11 +711,11 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14159)}, readDouble, nil, }, - { - "float32/reflection path (truncate)", myfloat32(3.14), &DecodeContext{truncate: true}, - &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, - nil, - }, + // { + // "float32/reflection path (truncate)", myfloat32(3.14), &DecodeContext{truncate: true}, + // &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, + // nil, + // }, { "float32/reflection path (no truncate)", myfloat32(0), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, @@ -803,7 +803,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "wrong kind (non-string key)", map[bool]interface{}{}, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{}, readElement, fmt.Errorf("unsupported key type: %T", false), @@ -819,7 +819,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", map[string]string{}, - &DecodeContext{Registry: newTestRegistryBuilder().Build()}, + newTestRegistryBuilder().Build(), &valueReaderWriter{}, readDocument, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -827,7 +827,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "ReadElement Error", make(map[string]interface{}), - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("re error"), ErrAfter: readElement}, readElement, errors.New("re error"), @@ -905,7 +905,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", [1]string{}, - &DecodeContext{Registry: newTestRegistryBuilder().Build()}, + newTestRegistryBuilder().Build(), &valueReaderWriter{BSONType: TypeArray}, readArray, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -913,7 +913,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "ReadValue Error", [1]string{}, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("rv error"), ErrAfter: readValue, BSONType: TypeArray}, readValue, errors.New("rv error"), @@ -921,7 +921,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "DecodeValue Error", [1]string{}, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{BSONType: TypeArray}, readValue, &DecodeError{keys: []string{"0"}, wrapped: errors.New("cannot decode array into a string type")}, @@ -999,7 +999,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", []string{}, - &DecodeContext{Registry: newTestRegistryBuilder().Build()}, + newTestRegistryBuilder().Build(), &valueReaderWriter{BSONType: TypeArray}, readArray, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -1007,7 +1007,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "ReadValue Error", []string{}, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{Err: errors.New("rv error"), ErrAfter: readValue, BSONType: TypeArray}, readValue, errors.New("rv error"), @@ -1015,7 +1015,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "DecodeValue Error", []string{}, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{BSONType: TypeArray}, readValue, &DecodeError{keys: []string{"0"}, wrapped: errors.New("cannot decode array into a string type")}, @@ -1561,7 +1561,7 @@ func TestDefaultValueDecoders(t *testing.T) { ValueDecoderError{Name: "pointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}}, }, { - "No Decoder", &wrong, &DecodeContext{Registry: buildDefaultRegistry()}, nil, nothing, + "No Decoder", &wrong, buildDefaultRegistry(), nil, nothing, ErrNoDecoder{Type: reflect.TypeOf(wrong)}, }, { @@ -2287,7 +2287,7 @@ func TestDefaultValueDecoders(t *testing.T) { Code: "var hello = 'world';", Scope: D{{"foo", nil}}, }, - &DecodeContext{Registry: buildDefaultRegistry()}, + buildDefaultRegistry(), &valueReaderWriter{BSONType: TypeCodeWithScope, Err: errors.New("dd error"), ErrAfter: readElement}, readElement, errors.New("dd error"), @@ -2346,10 +2346,6 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run(tc.name, func(t *testing.T) { for _, rc := range tc.subtests { t.Run(rc.name, func(t *testing.T) { - var dc DecodeContext - if rc.dctx != nil { - dc = *rc.dctx - } llvrw := new(valueReaderWriter) if rc.llvrw != nil { llvrw = rc.llvrw @@ -2357,13 +2353,13 @@ func TestDefaultValueDecoders(t *testing.T) { llvrw.T = t // var got interface{} if rc.val == cansetreflectiontest { // We're doing a CanSet reflection test - err := tc.vd.DecodeValue(dc, llvrw, reflect.Value{}) + err := tc.vd.DecodeValue(rc.reg, llvrw, reflect.Value{}) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } val := reflect.New(reflect.TypeOf(rc.val)).Elem() - err = tc.vd.DecodeValue(dc, llvrw, val) + err = tc.vd.DecodeValue(rc.reg, llvrw, val) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } @@ -2375,13 +2371,13 @@ func TestDefaultValueDecoders(t *testing.T) { t.Fatalf("Error must be a DecodeValueError, but got a %T", rc.err) } - err := tc.vd.DecodeValue(dc, llvrw, reflect.Value{}) + err := tc.vd.DecodeValue(rc.reg, llvrw, reflect.Value{}) wanterr.Received = reflect.ValueOf(nil) if !assert.CompareErrors(err, wanterr) { t.Errorf("Errors do not match. got %v; want %v", err, wanterr) } - err = tc.vd.DecodeValue(dc, llvrw, reflect.ValueOf(int(12345))) + err = tc.vd.DecodeValue(rc.reg, llvrw, reflect.ValueOf(int(12345))) wanterr.Received = reflect.ValueOf(int(12345)) if !assert.CompareErrors(err, wanterr) { t.Errorf("Errors do not match. got %v; want %v", err, wanterr) @@ -2399,7 +2395,7 @@ func TestDefaultValueDecoders(t *testing.T) { panic(err) } }() - err := tc.vd.DecodeValue(dc, llvrw, val) + err := tc.vd.DecodeValue(rc.reg, llvrw, val) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } @@ -2420,7 +2416,7 @@ func TestDefaultValueDecoders(t *testing.T) { } t.Run("CodeWithScopeCodec/DecodeValue/success", func(t *testing.T) { - dc := DecodeContext{Registry: buildDefaultRegistry()} + reg := buildDefaultRegistry() b := bsoncore.BuildDocument(nil, bsoncore.AppendCodeWithScopeElement( nil, "foo", "var hello = 'world';", @@ -2438,7 +2434,7 @@ func TestDefaultValueDecoders(t *testing.T) { Scope: D{{"bar", nil}}, } val := reflect.New(tCodeWithScope).Elem() - err = codeWithScopeDecodeValue(dc, vr, val) + err = codeWithScopeDecodeValue(reg, vr, val) noerr(t, err) got := val.Interface().(CodeWithScope) @@ -2447,25 +2443,23 @@ func TestDefaultValueDecoders(t *testing.T) { } }) t.Run("ValueUnmarshalerDecodeValue/UnmarshalBSONValue error", func(t *testing.T) { - var dc DecodeContext llvrw := &valueReaderWriter{BSONType: TypeString, Return: string("hello, world!")} llvrw.T = t want := errors.New("ubsonv error") valUnmarshaler := &testValueUnmarshaler{err: want} - got := valueUnmarshalerDecodeValue(dc, llvrw, reflect.ValueOf(valUnmarshaler)) + got := valueUnmarshalerDecodeValue(nil, llvrw, reflect.ValueOf(valUnmarshaler)) if !assert.CompareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } }) t.Run("ValueUnmarshalerDecodeValue/Unaddressable value", func(t *testing.T) { - var dc DecodeContext llvrw := &valueReaderWriter{BSONType: TypeString, Return: string("hello, world!")} llvrw.T = t val := reflect.ValueOf(testValueUnmarshaler{}) want := ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} - got := valueUnmarshalerDecodeValue(dc, llvrw, val) + got := valueUnmarshalerDecodeValue(nil, llvrw, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } @@ -2488,8 +2482,8 @@ func TestDefaultValueDecoders(t *testing.T) { var val [1]string want := fmt.Errorf("more elements returned in array than can fit inside %T, got 2 elements", val) - dc := DecodeContext{Registry: buildDefaultRegistry()} - got := arrayDecodeValue(dc, vr, reflect.ValueOf(val)) + reg := buildDefaultRegistry() + got := arrayDecodeValue(reg, vr, reflect.ValueOf(val)) if !assert.CompareErrors(got, want) { t.Errorf("Errors do not match. got %v; want %v", got, want) } @@ -3137,7 +3131,7 @@ func TestDefaultValueDecoders(t *testing.T) { noerr(t, err) gotVal := reflect.New(reflect.TypeOf(tc.value)).Elem() - err = dec.DecodeValue(DecodeContext{Registry: reg}, vr, gotVal) + err = dec.DecodeValue(reg, vr, gotVal) noerr(t, err) got := gotVal.Interface() @@ -3186,7 +3180,7 @@ func TestDefaultValueDecoders(t *testing.T) { noerr(t, err) gotVal := reflect.New(reflect.TypeOf(tc.value)).Elem() - err = dec.DecodeValue(DecodeContext{Registry: reg}, vr, gotVal) + err = dec.DecodeValue(reg, vr, gotVal) if err == nil || !strings.Contains(err.Error(), tc.err.Error()) { t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) } @@ -3310,9 +3304,9 @@ func TestDefaultValueDecoders(t *testing.T) { t.Skip() } val := reflect.New(tEmpty).Elem() - dc := DecodeContext{Registry: newTestRegistryBuilder().Build()} + reg := newTestRegistryBuilder().Build() want := ErrNoTypeMapEntry{Type: tc.bsontype} - got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, val) + got := defaultEmptyInterfaceCodec.DecodeValue(reg, llvr, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3326,11 +3320,8 @@ func TestDefaultValueDecoders(t *testing.T) { reg := newTestRegistryBuilder(). RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). Build() - dc := DecodeContext{ - Registry: reg, - } want := ErrNoDecoder{Type: reflect.TypeOf(tc.val)} - got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, val) + got := defaultEmptyInterfaceCodec.DecodeValue(reg, llvr, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3346,10 +3337,7 @@ func TestDefaultValueDecoders(t *testing.T) { RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). Build() - dc := DecodeContext{ - Registry: reg, - } - got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, reflect.New(tEmpty).Elem()) + got := defaultEmptyInterfaceCodec.DecodeValue(reg, llvr, reflect.New(tEmpty).Elem()) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3362,11 +3350,8 @@ func TestDefaultValueDecoders(t *testing.T) { RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). Build() - dc := DecodeContext{ - Registry: reg, - } got := reflect.New(tEmpty).Elem() - err := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, got) + err := defaultEmptyInterfaceCodec.DecodeValue(reg, llvr, got) noerr(t, err) if !cmp.Equal(got.Interface(), want, cmp.Comparer(compareDecimal128)) { t.Errorf("Did not receive expected value. got %v; want %v", got.Interface(), want) @@ -3379,7 +3364,7 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run("non-interface{}", func(t *testing.T) { val := uint64(1234567890) want := ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.ValueOf(val)} - got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{}, nil, reflect.ValueOf(val)) + got := defaultEmptyInterfaceCodec.DecodeValue(nil, nil, reflect.ValueOf(val)) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3388,7 +3373,7 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run("nil *interface{}", func(t *testing.T) { var val interface{} want := ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.ValueOf(val)} - got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{}, nil, reflect.ValueOf(val)) + got := defaultEmptyInterfaceCodec.DecodeValue(nil, nil, reflect.ValueOf(val)) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3398,7 +3383,7 @@ func TestDefaultValueDecoders(t *testing.T) { llvr := &valueReaderWriter{BSONType: TypeDouble} want := ErrNoTypeMapEntry{Type: TypeDouble} val := reflect.New(tEmpty).Elem() - got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: newTestRegistryBuilder().Build()}, llvr, val) + got := defaultEmptyInterfaceCodec.DecodeValue(newTestRegistryBuilder().Build(), llvr, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3409,7 +3394,7 @@ func TestDefaultValueDecoders(t *testing.T) { want := D{{"pi", 3.14159}} var got interface{} val := reflect.ValueOf(&got).Elem() - err := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: buildDefaultRegistry()}, vr, val) + err := defaultEmptyInterfaceCodec.DecodeValue(buildDefaultRegistry(), vr, val) noerr(t, err) if !cmp.Equal(got, want) { t.Errorf("Did not get correct result. got %v; want %v", got, want) @@ -3456,7 +3441,7 @@ func TestDefaultValueDecoders(t *testing.T) { vr := NewValueReader(doc) val := reflect.ValueOf(&got).Elem() - err := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: tc.registry}, vr, val) + err := defaultEmptyInterfaceCodec.DecodeValue(tc.registry, vr, val) noerr(t, err) if !cmp.Equal(got, want) { t.Fatalf("got %v, want %v", got, want) @@ -3491,7 +3476,7 @@ func TestDefaultValueDecoders(t *testing.T) { var got D vr := NewValueReader(doc) val := reflect.ValueOf(&got).Elem() - err := (&sliceCodec{}).DecodeValue(DecodeContext{Registry: reg}, vr, val) + err := (&sliceCodec{}).DecodeValue(reg, vr, val) noerr(t, err) if !cmp.Equal(got, want) { t.Fatalf("got %v, want %v", got, want) @@ -3660,16 +3645,16 @@ func TestDefaultValueDecoders(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - dc := DecodeContext{Registry: tc.registry} - if dc.Registry == nil { - dc.Registry = buildDefaultRegistry() + reg := tc.registry + if reg == nil { + reg = buildDefaultRegistry() } var val reflect.Value if rtype := reflect.TypeOf(tc.val); rtype != nil { val = reflect.New(rtype).Elem() } - err := tc.decoder.DecodeValue(dc, tc.vr, val) + err := tc.decoder.DecodeValue(reg, tc.vr, val) assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err) }) } @@ -3681,10 +3666,10 @@ func TestDefaultValueDecoders(t *testing.T) { type inner struct{ Bar string } type outer struct{ Foo inner } - dc := DecodeContext{Registry: buildDefaultRegistry()} + reg := buildDefaultRegistry() vr := NewValueReader(outerBytes) val := reflect.New(reflect.TypeOf(outer{})).Elem() - err := defaultTestStructCodec.DecodeValue(dc, vr, val) + err := defaultTestStructCodec.DecodeValue(reg, vr, val) var decodeErr *DecodeError assert.True(t, errors.As(err, &decodeErr), "expected DecodeError, got %v of type %T", err, err) @@ -3714,10 +3699,10 @@ func TestDefaultValueDecoders(t *testing.T) { registerDefaultDecoders(rb) rb.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))) - dc := DecodeContext{Registry: rb.Build()} + reg := rb.Build() vr := NewValueReader(docBytes) val := reflect.New(tD).Elem() - err := dDecodeValue(dc, vr, val) + err := dDecodeValue(reg, vr, val) assert.Nil(t, err, "DDecodeValue error: %v", err) want := D{ @@ -3733,10 +3718,10 @@ func TestDefaultValueDecoders(t *testing.T) { ) type myMap map[string]mybool - dc := DecodeContext{Registry: buildDefaultRegistry()} + reg := buildDefaultRegistry() vr := NewValueReader(docBytes) val := reflect.New(reflect.TypeOf(myMap{})).Elem() - err := (&mapCodec{}).DecodeValue(dc, vr, val) + err := (&mapCodec{}).DecodeValue(reg, vr, val) assert.Nil(t, err, "DecodeValue error: %v", err) want := myMap{ diff --git a/bson/encoder.go b/bson/encoder.go index 42c1900006..dba91b7424 100644 --- a/bson/encoder.go +++ b/bson/encoder.go @@ -71,7 +71,7 @@ func (e *Encoder) SetRegistry(r *Registry) { // the marshaled BSON when the "inline" struct tag option is set. func (e *Encoder) ErrorOnInlineDuplicates() { t := reflect.TypeOf((*structCodec)(nil)) - if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { v[i].(*structCodec).overwriteDuplicatedInlinedFields = false } @@ -93,7 +93,7 @@ func (e *Encoder) IntMinSize() { // } // } t := reflect.TypeOf((*intCodec)(nil)) - if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { v[i].(*intCodec).encodeToMinSize = true } @@ -104,7 +104,7 @@ func (e *Encoder) IntMinSize() { // strings using fmt.Sprint instead of the default string conversion logic. func (e *Encoder) StringifyMapKeysWithFmt() { t := reflect.TypeOf((*mapCodec)(nil)) - if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { v[i].(*mapCodec).encodeKeysWithStringer = true } @@ -115,7 +115,7 @@ func (e *Encoder) StringifyMapKeysWithFmt() { // null. func (e *Encoder) NilMapAsEmpty() { t := reflect.TypeOf((*mapCodec)(nil)) - if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { v[i].(*mapCodec).encodeNilAsEmpty = true } @@ -126,7 +126,7 @@ func (e *Encoder) NilMapAsEmpty() { // null. func (e *Encoder) NilSliceAsEmpty() { t := reflect.TypeOf((*sliceCodec)(nil)) - if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { v[i].(*sliceCodec).encodeNilAsEmpty = true } @@ -142,7 +142,7 @@ func (e *Encoder) NilByteSliceAsEmpty() { // } // } t := reflect.TypeOf((*byteSliceCodec)(nil)) - if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { v[i].(*byteSliceCodec).encodeNilAsEmpty = true } @@ -159,7 +159,7 @@ func (e *Encoder) NilByteSliceAsEmpty() { // zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty. func (e *Encoder) OmitZeroStruct() { t := reflect.TypeOf((*structCodec)(nil)) - if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { v[i].(*structCodec).encodeOmitDefaultStruct = true } @@ -170,7 +170,7 @@ func (e *Encoder) OmitZeroStruct() { // struct tag is not specified. func (e *Encoder) UseJSONStructTags() { t := reflect.TypeOf((*structCodec)(nil)) - if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { v[i].(*structCodec).useJSONStructTags = true } diff --git a/bson/mgocompat/bson_test.go b/bson/mgocompat/bson_test.go index 7571abb19f..8a1ff811fd 100644 --- a/bson/mgocompat/bson_test.go +++ b/bson/mgocompat/bson_test.go @@ -479,7 +479,7 @@ func (t *prefixPtr) SetBSON(raw bson.RawValue) error { return err } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } @@ -506,7 +506,7 @@ func (t *prefixVal) SetBSON(raw bson.RawValue) error { return err } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } @@ -930,7 +930,7 @@ func (o *setterType) SetBSON(raw bson.RawValue) error { raw.Type = bson.TypeEmbeddedDocument } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } @@ -1289,7 +1289,7 @@ func (s *getterSetterD) SetBSON(raw bson.RawValue) error { raw.Type = bson.TypeEmbeddedDocument } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } @@ -1315,7 +1315,7 @@ func (i *getterSetterInt) SetBSON(raw bson.RawValue) error { raw.Type = bson.TypeEmbeddedDocument } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } @@ -1337,7 +1337,7 @@ func (s *ifaceSlice) SetBSON(raw bson.RawValue) error { return err } vr := bson.NewBSONValueReader(raw.Type, raw.Value) - err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) + err = decoder.DecodeValue(Registry, vr, rval) if err != nil { return err } diff --git a/bson/primitive_codecs_test.go b/bson/primitive_codecs_test.go index a38113b72d..18dcfb71b3 100644 --- a/bson/primitive_codecs_test.go +++ b/bson/primitive_codecs_test.go @@ -468,7 +468,7 @@ func TestPrimitiveValueDecoders(t *testing.T) { type subtest struct { name string val interface{} - dctx *DecodeContext + reg *Registry llvrw *valueReaderWriter invoke invoked err error @@ -563,23 +563,19 @@ func TestPrimitiveValueDecoders(t *testing.T) { t.Run(tc.name, func(t *testing.T) { for _, rc := range tc.subtests { t.Run(rc.name, func(t *testing.T) { - var dc DecodeContext - if rc.dctx != nil { - dc = *rc.dctx - } llvrw := new(valueReaderWriter) if rc.llvrw != nil { llvrw = rc.llvrw } llvrw.T = t if rc.val == cansetreflectiontest { // We're doing a CanSet reflection test - err := tc.vd.DecodeValue(dc, llvrw, reflect.Value{}) + err := tc.vd.DecodeValue(rc.reg, llvrw, reflect.Value{}) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } val := reflect.New(reflect.TypeOf(rc.val)).Elem() - err = tc.vd.DecodeValue(dc, llvrw, val) + err = tc.vd.DecodeValue(rc.reg, llvrw, val) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } @@ -596,7 +592,7 @@ func TestPrimitiveValueDecoders(t *testing.T) { panic(err) } }() - err := tc.vd.DecodeValue(dc, llvrw, val) + err := tc.vd.DecodeValue(rc.reg, llvrw, val) if !assert.CompareErrors(err, rc.err) { t.Errorf("Errors do not match. got %v; want %v", err, rc.err) } diff --git a/bson/raw_value.go b/bson/raw_value.go index a32b82e41d..f119cbd9fe 100644 --- a/bson/raw_value.go +++ b/bson/raw_value.go @@ -81,13 +81,13 @@ func (rv RawValue) UnmarshalWithRegistry(r *Registry, val interface{}) error { if err != nil { return err } - return dec.DecodeValue(DecodeContext{Registry: r}, vr, rval) + return dec.DecodeValue(r, vr, rval) } // UnmarshalWithContext performs the same unmarshalling as Unmarshal but uses the provided DecodeContext // instead of the one attached or the default registry. -func (rv RawValue) UnmarshalWithContext(dc *DecodeContext, val interface{}) error { - if dc == nil { +func (rv RawValue) UnmarshalWithContext(reg *Registry, val interface{}) error { + if reg == nil { return ErrNilContext } @@ -97,11 +97,11 @@ func (rv RawValue) UnmarshalWithContext(dc *DecodeContext, val interface{}) erro return fmt.Errorf("argument to Unmarshal* must be a pointer to a type, but got %v", rval) } rval = rval.Elem() - dec, err := dc.LookupDecoder(rval.Type()) + dec, err := reg.LookupDecoder(rval.Type()) if err != nil { return err } - return dec.DecodeValue(*dc, vr, rval) + return dec.DecodeValue(reg, vr, rval) } func convertFromCoreValue(v bsoncore.Value) RawValue { diff --git a/bson/raw_value_test.go b/bson/raw_value_test.go index 18598ebe8f..aa2b9a0eb6 100644 --- a/bson/raw_value_test.go +++ b/bson/raw_value_test.go @@ -114,11 +114,11 @@ func TestRawValue(t *testing.T) { t.Run("Returns lookup error", func(t *testing.T) { t.Parallel() - dc := DecodeContext{Registry: newTestRegistryBuilder().Build()} + reg := newTestRegistryBuilder().Build() var val RawValue var s string want := ErrNoDecoder{Type: reflect.TypeOf(s)} - got := val.UnmarshalWithContext(&dc, &s) + got := val.UnmarshalWithContext(reg, &s) if !assert.CompareErrors(got, want) { t.Errorf("Expected errors to match. got %v; want %v", got, want) } @@ -126,11 +126,11 @@ func TestRawValue(t *testing.T) { t.Run("Returns DecodeValue error", func(t *testing.T) { t.Parallel() - dc := DecodeContext{Registry: NewRegistryBuilder().Build()} + reg := NewRegistryBuilder().Build() val := RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, 3.14159)} var s string want := fmt.Errorf("cannot decode %v into a string type", TypeDouble) - got := val.UnmarshalWithContext(&dc, &s) + got := val.UnmarshalWithContext(reg, &s) if !assert.CompareErrors(got, want) { t.Errorf("Expected errors to match. got %v; want %v", got, want) } @@ -138,11 +138,11 @@ func TestRawValue(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - dc := DecodeContext{Registry: NewRegistryBuilder().Build()} + reg := NewRegistryBuilder().Build() want := float64(3.14159) val := RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, want)} var got float64 - err := val.UnmarshalWithContext(&dc, &got) + err := val.UnmarshalWithContext(reg, &got) noerr(t, err) if got != want { t.Errorf("Expected results to match. got %g; want %g", got, want) diff --git a/bson/registry.go b/bson/registry.go index 41a652fa31..fa63a4c7eb 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -222,8 +222,7 @@ func (rb *RegistryBuilder) Build() *Registry { interfaceDecoders: make([]interfaceValueDecoder, 0, len(rb.interfaceDecoders)), typeMap: make(map[Type]reflect.Type), - encoderTypeMap: make(map[reflect.Type][]ValueEncoder), - decoderTypeMap: make(map[reflect.Type][]ValueDecoder), + codecTypeMap: make(map[reflect.Type][]interface{}), } encoderCache := make(map[reflect.Value]ValueEncoder) @@ -234,7 +233,7 @@ func (rb *RegistryBuilder) Build() *Registry { encoder := encFac() encoderCache[reflect.ValueOf(encFac)] = encoder t := reflect.ValueOf(encoder).Type() - r.encoderTypeMap[t] = append(r.encoderTypeMap[t], encoder) + r.codecTypeMap[t] = append(r.codecTypeMap[t], encoder) return encoder } for k, v := range rb.typeEncoders { @@ -261,7 +260,7 @@ func (rb *RegistryBuilder) Build() *Registry { decoder := decFac() decoderCache[reflect.ValueOf(decFac)] = decoder t := reflect.ValueOf(decoder).Type() - r.decoderTypeMap[t] = append(r.decoderTypeMap[t], decoder) + r.codecTypeMap[t] = append(r.codecTypeMap[t], decoder) return decoder } for k, v := range rb.typeDecoders { @@ -329,8 +328,7 @@ type Registry struct { kindDecoders [reflect.UnsafePointer + 1]ValueDecoder typeMap map[Type]reflect.Type - encoderTypeMap map[reflect.Type][]ValueEncoder - decoderTypeMap map[reflect.Type][]ValueDecoder + codecTypeMap map[reflect.Type][]interface{} } // LookupEncoder returns the first matching encoder in the Registry. It uses the following lookup diff --git a/bson/registry_test.go b/bson/registry_test.go index c321f48d88..fd66e0cc84 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -93,8 +93,8 @@ func TestRegistryBuilder(t *testing.T) { if !cmp.Equal(c4, 1) { t.Errorf("ef4 is called %d time(s); expected 1", c4) } - codecs, ok := reg.encoderTypeMap[reflect.TypeOf((*fakeCodec)(nil))] - if !cmp.Equal(len(reg.encoderTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { + codecs, ok := reg.codecTypeMap[reflect.TypeOf((*fakeCodec)(nil))] + if !cmp.Equal(len(reg.codecTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { t.Errorf("codecs were not cached correctly") } got := make(map[reflect.Type]ValueEncoder) @@ -172,8 +172,8 @@ func TestRegistryBuilder(t *testing.T) { if !cmp.Equal(c4, 1) { t.Errorf("ef4 is called %d time(s); expected 1", c4) } - codecs, ok := reg.encoderTypeMap[reflect.TypeOf((*fakeCodec)(nil))] - if !cmp.Equal(len(reg.encoderTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { + codecs, ok := reg.codecTypeMap[reflect.TypeOf((*fakeCodec)(nil))] + if !cmp.Equal(len(reg.codecTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { t.Errorf("codecs were not cached correctly") } got := reg.typeEncoders @@ -246,8 +246,8 @@ func TestRegistryBuilder(t *testing.T) { if !cmp.Equal(c4, 1) { t.Errorf("ef4 is called %d time(s); expected 1", c4) } - codecs, ok := reg.encoderTypeMap[reflect.TypeOf((*fakeCodec)(nil))] - if !cmp.Equal(len(reg.encoderTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { + codecs, ok := reg.codecTypeMap[reflect.TypeOf((*fakeCodec)(nil))] + if !cmp.Equal(len(reg.codecTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { t.Errorf("codecs were not cached correctly") } got := reg.kindEncoders diff --git a/bson/string_codec_test.go b/bson/string_codec_test.go index 16d1727d4f..c764af97dc 100644 --- a/bson/string_codec_test.go +++ b/bson/string_codec_test.go @@ -20,20 +20,18 @@ func TestStringCodec(t *testing.T) { reader := &valueReaderWriter{BSONType: TypeObjectID, Return: oid} testCases := []struct { name string - dctx DecodeContext + codec *stringCodec err error result string }{ - {"default", DecodeContext{}, errors.New("cannot decode ObjectID as string if DecodeObjectIDAsHex is not set"), ""}, - {"true", DecodeContext{decodeObjectIDAsHex: true}, nil, oid.Hex()}, - {"false", DecodeContext{decodeObjectIDAsHex: false}, errors.New("cannot decode ObjectID as string if DecodeObjectIDAsHex is not set"), ""}, + {"default", &stringCodec{}, errors.New("cannot decode ObjectID as string if DecodeObjectIDAsHex is not set"), ""}, + {"true", &stringCodec{decodeObjectIDAsHex: true}, nil, oid.Hex()}, + {"false", &stringCodec{decodeObjectIDAsHex: false}, errors.New("cannot decode ObjectID as string if DecodeObjectIDAsHex is not set"), ""}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - stringCodec := &stringCodec{} - actual := reflect.New(reflect.TypeOf("")).Elem() - err := stringCodec.DecodeValue(tc.dctx, reader, actual) + err := tc.codec.DecodeValue(nil, reader, actual) if tc.err == nil { assert.NoError(t, err) } else { diff --git a/bson/time_codec_test.go b/bson/time_codec_test.go index 70f52906b2..fc32339602 100644 --- a/bson/time_codec_test.go +++ b/bson/time_codec_test.go @@ -31,7 +31,7 @@ func TestTimeCodec(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { actual := reflect.New(reflect.TypeOf(now)).Elem() - err := tc.codec.DecodeValue(DecodeContext{}, reader, actual) + err := tc.codec.DecodeValue(nil, reader, actual) assert.Nil(t, err, "TimeCodec.DecodeValue error: %v", err) actualTime := actual.Interface().(time.Time) @@ -65,7 +65,7 @@ func TestTimeCodec(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { actual := reflect.New(reflect.TypeOf(now)).Elem() - err := (&timeCodec{}).DecodeValue(DecodeContext{}, tc.reader, actual) + err := (&timeCodec{}).DecodeValue(nil, tc.reader, actual) assert.Nil(t, err, "DecodeValue error: %v", err) actualTime := actual.Interface().(time.Time) diff --git a/bson/truncation_test.go b/bson/truncation_test.go index a9aeea278b..8f3301e85d 100644 --- a/bson/truncation_test.go +++ b/bson/truncation_test.go @@ -39,12 +39,12 @@ func TestTruncation(t *testing.T) { assert.Nil(t, err) var output outputArgs - dc := DecodeContext{ - Registry: DefaultRegistry, - truncate: true, - } + // dc := DecodeContext{ + // Registry: DefaultRegistry, + // truncate: true, + // } - err = UnmarshalWithContext(dc, buf.Bytes(), &output) + err = UnmarshalWithContext(DefaultRegistry, buf.Bytes(), &output) assert.Nil(t, err) assert.Equal(t, inputName, output.Name) @@ -65,13 +65,13 @@ func TestTruncation(t *testing.T) { assert.Nil(t, err) var output outputArgs - dc := DecodeContext{ - Registry: DefaultRegistry, - truncate: false, - } + // dc := DecodeContext{ + // Registry: DefaultRegistry, + // truncate: false, + // } // case throws an error when truncation is disabled - err = UnmarshalWithContext(dc, buf.Bytes(), &output) + err = UnmarshalWithContext(DefaultRegistry, buf.Bytes(), &output) assert.NotNil(t, err) }) } diff --git a/bson/unmarshal.go b/bson/unmarshal.go index 7caadc5dbc..a02582577e 100644 --- a/bson/unmarshal.go +++ b/bson/unmarshal.go @@ -56,7 +56,7 @@ func Unmarshal(data []byte, val interface{}) error { // See [Decoder] for more examples. func UnmarshalWithRegistry(r *Registry, data []byte, val interface{}) error { vr := NewValueReader(data) - return unmarshalFromReader(DecodeContext{Registry: r}, vr, val) + return unmarshalFromReader(r, vr, val) } // UnmarshalWithContext parses the BSON-encoded data using DecodeContext dc and @@ -73,9 +73,9 @@ func UnmarshalWithRegistry(r *Registry, data []byte, val interface{}) error { // dec.DefaultDocumentM() // // See [Decoder] for more examples. -func UnmarshalWithContext(dc DecodeContext, data []byte, val interface{}) error { +func UnmarshalWithContext(reg *Registry, data []byte, val interface{}) error { vr := NewValueReader(data) - return unmarshalFromReader(dc, vr, val) + return unmarshalFromReader(reg, vr, val) } // UnmarshalValue parses the BSON value of type t with bson.DefaultRegistry and @@ -93,7 +93,7 @@ func UnmarshalValue(t Type, data []byte, val interface{}) error { // Go Driver 2.0. func UnmarshalValueWithRegistry(r *Registry, t Type, data []byte, val interface{}) error { vr := NewBSONValueReader(t, data) - return unmarshalFromReader(DecodeContext{Registry: r}, vr, val) + return unmarshalFromReader(r, vr, val) } // UnmarshalExtJSON parses the extended JSON-encoded data and stores the result @@ -126,7 +126,7 @@ func UnmarshalExtJSONWithRegistry(r *Registry, data []byte, canonical bool, val return err } - return unmarshalFromReader(DecodeContext{Registry: r}, ejvr, val) + return unmarshalFromReader(r, ejvr, val) } // UnmarshalExtJSONWithContext parses the extended JSON-encoded data using @@ -147,21 +147,21 @@ func UnmarshalExtJSONWithRegistry(r *Registry, data []byte, canonical bool, val // dec.DefaultDocumentM() // // See [Decoder] for more examples. -func UnmarshalExtJSONWithContext(dc DecodeContext, data []byte, canonical bool, val interface{}) error { +func UnmarshalExtJSONWithContext(reg *Registry, data []byte, canonical bool, val interface{}) error { ejvr, err := NewExtJSONValueReader(bytes.NewReader(data), canonical) if err != nil { return err } - return unmarshalFromReader(dc, ejvr, val) + return unmarshalFromReader(reg, ejvr, val) } -func unmarshalFromReader(dc DecodeContext, vr ValueReader, val interface{}) error { +func unmarshalFromReader(reg *Registry, vr ValueReader, val interface{}) error { dec := decPool.Get().(*Decoder) defer decPool.Put(dec) dec.Reset(vr) - dec.dc = dc + dec.reg = reg return dec.Decode(val) } diff --git a/bson/unmarshal_test.go b/bson/unmarshal_test.go index dc8102b02c..9748d8a6db 100644 --- a/bson/unmarshal_test.go +++ b/bson/unmarshal_test.go @@ -69,9 +69,8 @@ func TestUnmarshalWithContext(t *testing.T) { copy(data, tc.data) // Assert that unmarshaling the input data results in the expected value. - dc := DecodeContext{Registry: DefaultRegistry} got := reflect.New(tc.sType).Interface() - err := UnmarshalWithContext(dc, data, got) + err := UnmarshalWithContext(DefaultRegistry, data, got) noerr(t, err) assert.Equal(t, tc.want, got, "Did not unmarshal as expected.") @@ -199,8 +198,7 @@ func TestUnmarshalExtJSONWithContext(t *testing.T) { // Assert that unmarshaling the input data results in the expected value. got := reflect.New(tc.sType).Interface() - dc := DecodeContext{Registry: DefaultRegistry} - err := UnmarshalExtJSONWithContext(dc, data, true, got) + err := UnmarshalExtJSONWithContext(DefaultRegistry, data, true, got) noerr(t, err) assert.Equal(t, tc.want, got, "Did not unmarshal as expected.") From 5ce26704aac115fe5f54d24f98c6b0eae0840d44 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Tue, 21 May 2024 12:32:51 -0400 Subject: [PATCH 10/15] WIP --- bson/decoder.go | 4 +- bson/decoder_test.go | 26 ++++---- bson/default_value_decoders.go | 4 ++ bson/default_value_decoders_test.go | 98 ++++++++++++++++++++--------- bson/empty_interface_codec.go | 10 +-- bson/encoder.go | 6 +- bson/encoder_test.go | 2 +- bson/int_codec.go | 26 +++++--- bson/marshal.go | 8 +-- bson/marshal_test.go | 4 +- bson/mgoregistry.go | 2 +- bson/raw_value.go | 2 +- bson/registry.go | 12 ++-- bson/registry_examples_test.go | 8 ++- bson/slice_codec.go | 2 +- bson/struct_codec.go | 34 +++++++--- bson/truncation_test.go | 8 +-- bson/unmarshal.go | 8 +-- bson/unmarshal_test.go | 10 +-- bson/unmarshal_value_test.go | 4 +- mongo/client.go | 2 +- mongo/cursor.go | 6 +- mongo/gridfs_bucket.go | 2 +- mongo/mongo.go | 4 +- mongo/options/gridfsoptions.go | 2 +- mongo/options/mongooptions.go | 4 +- mongo/single_result.go | 2 +- mongo/single_result_test.go | 8 +-- 28 files changed, 196 insertions(+), 112 deletions(-) diff --git a/bson/decoder.go b/bson/decoder.go index 14335c5bb5..ae335fc8ff 100644 --- a/bson/decoder.go +++ b/bson/decoder.go @@ -32,10 +32,10 @@ type Decoder struct { vr ValueReader } -// NewDecoder returns a new decoder that uses the DefaultRegistry to read from vr. +// NewDecoder returns a new decoder that uses the default registry to read from vr. func NewDecoder(vr ValueReader) *Decoder { return &Decoder{ - reg: DefaultRegistry, + reg: NewRegistryBuilder().Build(), vr: vr, } } diff --git a/bson/decoder_test.go b/bson/decoder_test.go index b101b38d65..973e48b869 100644 --- a/bson/decoder_test.go +++ b/bson/decoder_test.go @@ -196,19 +196,19 @@ func TestDecoderv2(t *testing.T) { t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr2) } }) - t.Run("SetRegistry", func(t *testing.T) { - t.Parallel() - - r1, r2 := DefaultRegistry, NewRegistryBuilder().Build() - dec := NewDecoder(NewValueReader([]byte{})) - if !reflect.DeepEqual(dec.reg, r1) { - t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.reg, r1) - } - dec.SetRegistry(r2) - if !reflect.DeepEqual(dec.reg, r2) { - t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.reg, r2) - } - }) + // t.Run("SetRegistry", func(t *testing.T) { + // t.Parallel() + + // r1, r2 := DefaultRegistry, NewRegistryBuilder().Build() + // dec := NewDecoder(NewValueReader([]byte{})) + // if !reflect.DeepEqual(dec.reg, r1) { + // t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.reg, r1) + // } + // dec.SetRegistry(r2) + // if !reflect.DeepEqual(dec.reg, r2) { + // t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.reg, r2) + // } + // }) t.Run("DecodeToNil", func(t *testing.T) { t.Parallel() diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index ec8b3c8730..5bb49cc60c 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -1105,6 +1105,9 @@ func decodeDefault(reg DecoderRegistry, vr ValueReader, val reflect.Value) ([]re if err != nil { return nil, newDecodeError(strconv.Itoa(idx), err) } + if elem.Type() != eType { + elem = elem.Convert(eType) + } elems = append(elems, elem) idx++ } @@ -1175,6 +1178,7 @@ func codeWithScopeDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.V func decodeD(reg DecoderRegistry, vr ValueReader, val reflect.Value) ([]reflect.Value, error) { switch vr.Type() { case Type(0), TypeEmbeddedDocument: + break default: return nil, fmt.Errorf("cannot decode %v into a D", vr.Type()) } diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 42602e3cbe..f931c16f6e 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -148,8 +148,11 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf(wrong), }, }, @@ -214,8 +217,11 @@ func TestDefaultValueDecoders(t *testing.T) { "int8/fast path - nil", (*int8)(nil), nil, &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*int8)(nil)), }, }, @@ -223,8 +229,11 @@ func TestDefaultValueDecoders(t *testing.T) { "int16/fast path - nil", (*int16)(nil), nil, &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*int16)(nil)), }, }, @@ -232,8 +241,11 @@ func TestDefaultValueDecoders(t *testing.T) { "int32/fast path - nil", (*int32)(nil), nil, &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*int32)(nil)), }, }, @@ -241,8 +253,11 @@ func TestDefaultValueDecoders(t *testing.T) { "int64/fast path - nil", (*int64)(nil), nil, &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*int64)(nil)), }, }, @@ -250,8 +265,11 @@ func TestDefaultValueDecoders(t *testing.T) { "int/fast path - nil", (*int)(nil), nil, &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*int)(nil)), }, }, @@ -347,8 +365,11 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, }, }, { @@ -380,8 +401,11 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf(wrong), }, }, @@ -446,8 +470,11 @@ func TestDefaultValueDecoders(t *testing.T) { "uint8/fast path - nil", (*uint8)(nil), nil, &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*uint8)(nil)), }, }, @@ -455,8 +482,11 @@ func TestDefaultValueDecoders(t *testing.T) { "uint16/fast path - nil", (*uint16)(nil), nil, &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*uint16)(nil)), }, }, @@ -464,8 +494,11 @@ func TestDefaultValueDecoders(t *testing.T) { "uint32/fast path - nil", (*uint32)(nil), nil, &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*uint32)(nil)), }, }, @@ -473,8 +506,11 @@ func TestDefaultValueDecoders(t *testing.T) { "uint64/fast path - nil", (*uint64)(nil), nil, &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*uint64)(nil)), }, }, @@ -482,8 +518,11 @@ func TestDefaultValueDecoders(t *testing.T) { "uint/fast path - nil", (*uint)(nil), nil, &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*uint)(nil)), }, }, @@ -599,8 +638,11 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, }, }, }, diff --git a/bson/empty_interface_codec.go b/bson/empty_interface_codec.go index e428176d2d..cf30014859 100644 --- a/bson/empty_interface_codec.go +++ b/bson/empty_interface_codec.go @@ -24,7 +24,7 @@ type emptyInterfaceCodec struct { } // EncodeValue is the ValueEncoderFunc for interface{}. -func (eic emptyInterfaceCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { +func (eic *emptyInterfaceCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tEmpty { return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} } @@ -40,7 +40,7 @@ func (eic emptyInterfaceCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, return encoder.EncodeValue(reg, vw, val.Elem()) } -func (eic emptyInterfaceCodec) getEmptyInterfaceDecodeType(reg DecoderRegistry, valueType Type, ancestorType reflect.Type) (reflect.Type, error) { +func (eic *emptyInterfaceCodec) getEmptyInterfaceDecodeType(reg DecoderRegistry, valueType Type, ancestorType reflect.Type) (reflect.Type, error) { isDocument := valueType == Type(0) || valueType == TypeEmbeddedDocument if isDocument { if eic.defaultDocumentType != nil { @@ -81,12 +81,12 @@ func (eic emptyInterfaceCodec) getEmptyInterfaceDecodeType(reg DecoderRegistry, return nil, err } -func (eic emptyInterfaceCodec) decodeType(reg DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (eic *emptyInterfaceCodec) decodeType(reg DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { rtype, err := eic.getEmptyInterfaceDecodeType(reg, vr.Type(), t) if err != nil { switch vr.Type() { case TypeNull: - return reflect.Zero(t), vr.ReadNull() + return reflect.Zero(tEmpty), vr.ReadNull() default: return emptyValue, err } @@ -116,7 +116,7 @@ func (eic emptyInterfaceCodec) decodeType(reg DecoderRegistry, vr ValueReader, t } // DecodeValue is the ValueDecoderFunc for interface{}. -func (eic emptyInterfaceCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { +func (eic *emptyInterfaceCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() || val.Type() != tEmpty { return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val} } diff --git a/bson/encoder.go b/bson/encoder.go index dba91b7424..1b90e9e948 100644 --- a/bson/encoder.go +++ b/bson/encoder.go @@ -27,10 +27,10 @@ type Encoder struct { vw ValueWriter } -// NewEncoder returns a new encoder that uses the DefaultRegistry to write to vw. +// NewEncoder returns a new encoder that uses the default registry to write to vw. func NewEncoder(vw ValueWriter) *Encoder { return &Encoder{ - reg: DefaultRegistry, + reg: NewRegistryBuilder().Build(), vw: vw, } } @@ -95,7 +95,7 @@ func (e *Encoder) IntMinSize() { t := reflect.TypeOf((*intCodec)(nil)) if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { - v[i].(*intCodec).encodeToMinSize = true + v[i].(*intCodec).minSize = true } } } diff --git a/bson/encoder_test.go b/bson/encoder_test.go index 15cce55700..2dff4fbfdc 100644 --- a/bson/encoder_test.go +++ b/bson/encoder_test.go @@ -22,7 +22,7 @@ func TestBasicEncode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { got := make(SliceWriter, 0, 1024) vw := NewValueWriter(&got) - reg := DefaultRegistry + reg := NewRegistryBuilder().Build() encoder, err := reg.LookupEncoder(reflect.TypeOf(tc.val)) noerr(t, err) err = encoder.EncodeValue(reg, vw, reflect.ValueOf(tc.val)) diff --git a/bson/int_codec.go b/bson/int_codec.go index 4caff3aa5a..a7edab4e58 100644 --- a/bson/int_codec.go +++ b/bson/int_codec.go @@ -14,9 +14,14 @@ import ( // intCodec is the Codec used for uint values. type intCodec struct { - // encodeToMinSize causes EncodeValue to marshal Go uint values (excluding uint64) as the + // minSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, + // uint, uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) + // that can represent the integer value. + minSize bool + + // encodeUintToMinSize causes EncodeValue to marshal Go uint values (excluding uint64) as the // minimum BSON int size (either 32-bit or 64-bit) that can represent the integer value. - encodeToMinSize bool + encodeUintToMinSize bool // truncate, if true, instructs decoders to to truncate the fractional part of BSON "double" // values when attempting to unmarshal them into a Go integer (int, int8, int16, int32, int64, @@ -38,7 +43,7 @@ func (ic *intCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.V return vw.WriteInt64(i64) case reflect.Int64: i64 := val.Int() - if ic.encodeToMinSize && fitsIn32Bits(i64) { + if ic.minSize && fitsIn32Bits(i64) { return vw.WriteInt32(int32(i64)) } return vw.WriteInt64(i64) @@ -48,8 +53,8 @@ func (ic *intCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.V case reflect.Uint, reflect.Uint32, reflect.Uint64: u64 := val.Uint() - // If encodeToMinSize is true for a non-uint64 value we should write val as an int32 - useMinSize := ic.encodeToMinSize && val.Kind() != reflect.Uint64 + // If minSize or encodeToMinSize is true for a non-uint64 value we should write val as an int32 + useMinSize := ic.minSize || (ic.encodeUintToMinSize && val.Kind() != reflect.Uint64) if u64 <= math.MaxInt32 && useMinSize { return vw.WriteInt32(int32(u64)) @@ -183,8 +188,11 @@ func (ic *intCodec) decodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type func (ic *intCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() { return ValueDecoderError{ - Name: "IntDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: val, } } @@ -194,6 +202,10 @@ func (ic *intCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect return err } + if t := val.Type(); elem.Type() != t { + elem = elem.Convert(t) + } + val.Set(elem) return nil } diff --git a/bson/marshal.go b/bson/marshal.go index 32151cc465..a26270a158 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -74,7 +74,7 @@ func Marshal(val interface{}) ([]byte, error) { enc := encPool.Get().(*Encoder) defer encPool.Put(enc) enc.Reset(vw) - enc.SetRegistry(DefaultRegistry) + enc.SetRegistry(NewRegistryBuilder().Build()) err := enc.Encode(val) if err != nil { return nil, err @@ -85,10 +85,10 @@ func Marshal(val interface{}) ([]byte, error) { // MarshalValue returns the BSON encoding of val. // -// MarshalValue will use bson.DefaultRegistry to transform val into a BSON value. If val is a struct, this function will +// MarshalValue will use default registry to transform val into a BSON value. If val is a struct, this function will // inspect struct tags and alter the marshalling process accordingly. func MarshalValue(val interface{}) (Type, []byte, error) { - return MarshalValueWithRegistry(DefaultRegistry, val) + return MarshalValueWithRegistry(NewRegistryBuilder().Build(), val) } // MarshalValueWithRegistry returns the BSON encoding of val using Registry r. @@ -127,7 +127,7 @@ func MarshalExtJSON(val interface{}, canonical, escapeHTML bool) ([]byte, error) defer encPool.Put(enc) enc.Reset(ejvw) - enc.SetRegistry(DefaultRegistry) + enc.SetRegistry(NewRegistryBuilder().Build()) err := enc.Encode(val) if err != nil { diff --git a/bson/marshal_test.go b/bson/marshal_test.go index 5eeada562c..82d07d99f2 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -28,7 +28,7 @@ func TestMarshalWithRegistry(t *testing.T) { if tc.reg != nil { reg = tc.reg } else { - reg = DefaultRegistry + reg = NewRegistryBuilder().Build() } buf := new(bytes.Buffer) vw := NewValueWriter(buf) @@ -52,7 +52,7 @@ func TestMarshalWithContext(t *testing.T) { if tc.reg != nil { reg = tc.reg } else { - reg = DefaultRegistry + reg = NewRegistryBuilder().Build() } buf := new(bytes.Buffer) vw := NewValueWriter(buf) diff --git a/bson/mgoregistry.go b/bson/mgoregistry.go index 9d9255f3bd..1efac62e92 100644 --- a/bson/mgoregistry.go +++ b/bson/mgoregistry.go @@ -34,7 +34,7 @@ func newMgoRegistryBuilder() *RegistryBuilder { encodeNilAsEmpty: true, encodeKeysWithStringer: true, } - intcodec := func() ValueEncoder { return &intCodec{encodeToMinSize: true} } + intcodec := func() ValueEncoder { return &intCodec{encodeUintToMinSize: true} } return NewRegistryBuilder(). RegisterTypeDecoder(tEmpty, func() ValueDecoder { return &emptyInterfaceCodec{decodeBinaryAsSlice: true} }). diff --git a/bson/raw_value.go b/bson/raw_value.go index f119cbd9fe..732379e118 100644 --- a/bson/raw_value.go +++ b/bson/raw_value.go @@ -46,7 +46,7 @@ func (rv RawValue) IsZero() bool { func (rv RawValue) Unmarshal(val interface{}) error { reg := rv.r if reg == nil { - reg = DefaultRegistry + reg = NewRegistryBuilder().Build() } return rv.UnmarshalWithRegistry(reg, val) } diff --git a/bson/registry.go b/bson/registry.go index fa63a4c7eb..5eaa2fecee 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -12,10 +12,6 @@ import ( "sync" ) -// DefaultRegistry is the default Registry. It contains the default codecs and the -// primitive codecs. -var DefaultRegistry = NewRegistryBuilder().Build() - // ErrNoEncoder is returned when there wasn't an encoder available for a type. // // Deprecated: ErrNoEncoder will not be supported in Go Driver 2.0. @@ -38,7 +34,13 @@ type ErrNoDecoder struct { } func (end ErrNoDecoder) Error() string { - return "no decoder found for " + end.Type.String() + var typeStr string + if end.Type != nil { + typeStr = end.Type.String() + } else { + typeStr = "nil type" + } + return "no decoder found for " + typeStr } // ErrNoTypeMapEntry is returned when there wasn't a type available for the provided BSON type. diff --git a/bson/registry_examples_test.go b/bson/registry_examples_test.go index 35b5016eba..b8b1010c9f 100644 --- a/bson/registry_examples_test.go +++ b/bson/registry_examples_test.go @@ -135,7 +135,9 @@ func ExampleRegistry_customDecoder() { reg := bson.NewRegistryBuilder() reg.RegisterTypeDecoder( lenientBoolType, - func() bson.ValueDecoder { return bson.ValueDecoderFunc(lenientBoolDecoder) }, + func() bson.ValueDecoder { + return bson.ValueDecoderFunc(lenientBoolDecoder) + }, ) // Marshal a BSON document with a single field "isOK" that is a non-zero @@ -280,7 +282,9 @@ func ExampleRegistryBuilder_RegisterKindDecoder() { reg := bson.NewRegistryBuilder() reg.RegisterKindDecoder( reflect.Int64, - func() bson.ValueDecoder { return bson.ValueDecoderFunc(flexibleInt64KindDecoder) }, + func() bson.ValueDecoder { + return bson.ValueDecoderFunc(flexibleInt64KindDecoder) + }, ) // Marshal a BSON document with fields that are mixed numeric types but all diff --git a/bson/slice_codec.go b/bson/slice_codec.go index 71aaf32b93..f08f8100d6 100644 --- a/bson/slice_codec.go +++ b/bson/slice_codec.go @@ -20,7 +20,7 @@ type sliceCodec struct { } // EncodeValue is the ValueEncoder for slice types. -func (sc sliceCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { +func (sc *sliceCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Slice { return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } diff --git a/bson/struct_codec.go b/bson/struct_codec.go index ac917a5b17..fb4d36f258 100644 --- a/bson/struct_codec.go +++ b/bson/struct_codec.go @@ -83,6 +83,28 @@ func newStructCodec(p StructTagParser) *structCodec { } } +type localEncoderRegistry struct { + registry EncoderRegistry + + minSize bool +} + +func (r *localEncoderRegistry) LookupEncoder(t reflect.Type) (ValueEncoder, error) { + ve, err := r.registry.LookupEncoder(t) + if err != nil { + return ve, err + } + if r.minSize { + if ic, ok := ve.(*intCodec); ok { + ve = &intCodec{ + minSize: true, + truncate: ic.truncate, + } + } + } + return ve, nil +} + // EncodeValue handles encoding generic struct types. func (sc *structCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Struct { @@ -109,6 +131,11 @@ func (sc *structCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val refl } } + reg = &localEncoderRegistry{ + registry: reg, + minSize: desc.minSize, + } + var encoder ValueEncoder if encoder, err = reg.LookupEncoder(desc.fieldType); err != nil { encoder = nil @@ -158,13 +185,6 @@ func (sc *structCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val refl return err } - // defaultUIntCodec.encodeToMinSize = desc.minSize - if v, ok := encoder.(*intCodec); ok { - encoder = &intCodec{ - encodeToMinSize: v.encodeToMinSize || desc.minSize, - } - } - err = encoder.EncodeValue(reg, vw2, rv) if err != nil { return err diff --git a/bson/truncation_test.go b/bson/truncation_test.go index 8f3301e85d..36d66e9b38 100644 --- a/bson/truncation_test.go +++ b/bson/truncation_test.go @@ -34,7 +34,7 @@ func TestTruncation(t *testing.T) { vw := NewValueWriter(buf) enc := NewEncoder(vw) enc.IntMinSize() - enc.SetRegistry(DefaultRegistry) + enc.SetRegistry(NewRegistryBuilder().Build()) err := enc.Encode(&input) assert.Nil(t, err) @@ -44,7 +44,7 @@ func TestTruncation(t *testing.T) { // truncate: true, // } - err = UnmarshalWithContext(DefaultRegistry, buf.Bytes(), &output) + err = UnmarshalWithContext(NewRegistryBuilder().Build(), buf.Bytes(), &output) assert.Nil(t, err) assert.Equal(t, inputName, output.Name) @@ -60,7 +60,7 @@ func TestTruncation(t *testing.T) { vw := NewValueWriter(buf) enc := NewEncoder(vw) enc.IntMinSize() - enc.SetRegistry(DefaultRegistry) + enc.SetRegistry(NewRegistryBuilder().Build()) err := enc.Encode(&input) assert.Nil(t, err) @@ -71,7 +71,7 @@ func TestTruncation(t *testing.T) { // } // case throws an error when truncation is disabled - err = UnmarshalWithContext(DefaultRegistry, buf.Bytes(), &output) + err = UnmarshalWithContext(NewRegistryBuilder().Build(), buf.Bytes(), &output) assert.NotNil(t, err) }) } diff --git a/bson/unmarshal.go b/bson/unmarshal.go index a02582577e..48bac97643 100644 --- a/bson/unmarshal.go +++ b/bson/unmarshal.go @@ -38,7 +38,7 @@ type ValueUnmarshaler interface { // pointed to by val. If val is nil or not a pointer, Unmarshal returns // InvalidUnmarshalError. func Unmarshal(data []byte, val interface{}) error { - return UnmarshalWithRegistry(DefaultRegistry, data, val) + return UnmarshalWithRegistry(NewRegistryBuilder().Build(), data, val) } // UnmarshalWithRegistry parses the BSON-encoded data using Registry r and @@ -78,11 +78,11 @@ func UnmarshalWithContext(reg *Registry, data []byte, val interface{}) error { return unmarshalFromReader(reg, vr, val) } -// UnmarshalValue parses the BSON value of type t with bson.DefaultRegistry and +// UnmarshalValue parses the BSON value of type t with default registry and // stores the result in the value pointed to by val. If val is nil or not a pointer, // UnmarshalValue returns an error. func UnmarshalValue(t Type, data []byte, val interface{}) error { - return UnmarshalValueWithRegistry(DefaultRegistry, t, data, val) + return UnmarshalValueWithRegistry(NewRegistryBuilder().Build(), t, data, val) } // UnmarshalValueWithRegistry parses the BSON value of type t with registry r and @@ -100,7 +100,7 @@ func UnmarshalValueWithRegistry(r *Registry, t Type, data []byte, val interface{ // in the value pointed to by val. If val is nil or not a pointer, Unmarshal // returns InvalidUnmarshalError. func UnmarshalExtJSON(data []byte, canonical bool, val interface{}) error { - return UnmarshalExtJSONWithRegistry(DefaultRegistry, data, canonical, val) + return UnmarshalExtJSONWithRegistry(NewRegistryBuilder().Build(), data, canonical, val) } // UnmarshalExtJSONWithRegistry parses the extended JSON-encoded data using diff --git a/bson/unmarshal_test.go b/bson/unmarshal_test.go index 9748d8a6db..1abf59fd48 100644 --- a/bson/unmarshal_test.go +++ b/bson/unmarshal_test.go @@ -48,7 +48,7 @@ func TestUnmarshalWithRegistry(t *testing.T) { // Assert that unmarshaling the input data results in the expected value. got := reflect.New(tc.sType).Interface() - err := UnmarshalWithRegistry(DefaultRegistry, data, got) + err := UnmarshalWithRegistry(NewRegistryBuilder().Build(), data, got) noerr(t, err) assert.Equal(t, tc.want, got, "Did not unmarshal as expected.") @@ -70,7 +70,7 @@ func TestUnmarshalWithContext(t *testing.T) { // Assert that unmarshaling the input data results in the expected value. got := reflect.New(tc.sType).Interface() - err := UnmarshalWithContext(DefaultRegistry, data, got) + err := UnmarshalWithContext(NewRegistryBuilder().Build(), data, got) noerr(t, err) assert.Equal(t, tc.want, got, "Did not unmarshal as expected.") @@ -88,7 +88,7 @@ func TestUnmarshalExtJSONWithRegistry(t *testing.T) { type teststruct struct{ Foo int } var got teststruct data := []byte("{\"foo\":1}") - err := UnmarshalExtJSONWithRegistry(DefaultRegistry, data, true, &got) + err := UnmarshalExtJSONWithRegistry(NewRegistryBuilder().Build(), data, true, &got) noerr(t, err) want := teststruct{1} assert.Equal(t, want, got, "Did not unmarshal as expected.") @@ -96,7 +96,7 @@ func TestUnmarshalExtJSONWithRegistry(t *testing.T) { t.Run("UnmarshalExtJSONInvalidInput", func(t *testing.T) { data := []byte("invalid") - err := UnmarshalExtJSONWithRegistry(DefaultRegistry, data, true, &M{}) + err := UnmarshalExtJSONWithRegistry(NewRegistryBuilder().Build(), data, true, &M{}) if !errors.Is(err, ErrInvalidJSON) { t.Fatalf("wanted ErrInvalidJSON, got %v", err) } @@ -198,7 +198,7 @@ func TestUnmarshalExtJSONWithContext(t *testing.T) { // Assert that unmarshaling the input data results in the expected value. got := reflect.New(tc.sType).Interface() - err := UnmarshalExtJSONWithContext(DefaultRegistry, data, true, got) + err := UnmarshalExtJSONWithContext(NewRegistryBuilder().Build(), data, true, got) noerr(t, err) assert.Equal(t, tc.want, got, "Did not unmarshal as expected.") diff --git a/bson/unmarshal_value_test.go b/bson/unmarshal_value_test.go index 05524f658a..a4d27eca01 100644 --- a/bson/unmarshal_value_test.go +++ b/bson/unmarshal_value_test.go @@ -36,7 +36,7 @@ func TestUnmarshalValue(t *testing.T) { }) } }) - t.Run("UnmarshalValueWithRegistry with DefaultRegistry", func(t *testing.T) { + t.Run("UnmarshalValueWithRegistry with default registry", func(t *testing.T) { t.Parallel() for _, tc := range unmarshalValueTestCases { @@ -46,7 +46,7 @@ func TestUnmarshalValue(t *testing.T) { t.Parallel() gotValue := reflect.New(reflect.TypeOf(tc.val)) - err := UnmarshalValueWithRegistry(DefaultRegistry, tc.bsontype, tc.bytes, gotValue.Interface()) + err := UnmarshalValueWithRegistry(NewRegistryBuilder().Build(), tc.bsontype, tc.bytes, gotValue.Interface()) assert.Nil(t, err, "UnmarshalValueWithRegistry error: %v", err) assert.Equal(t, tc.val, gotValue.Elem().Interface(), "value mismatch; expected %s, got %s", tc.val, gotValue.Elem()) }) diff --git a/mongo/client.go b/mongo/client.go index 6cba70ce8d..f87eefd133 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -166,7 +166,7 @@ func newClient(opts ...*options.ClientOptions) (*Client, error) { client.bsonOpts = clientOpt.BSONOptions } // Registry - client.registry = bson.DefaultRegistry + client.registry = bson.NewRegistryBuilder().Build() if clientOpt.Registry != nil { client.registry = clientOpt.Registry } diff --git a/mongo/cursor.go b/mongo/cursor.go index 8f07b1ee9b..8d1e58f6f9 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -55,7 +55,7 @@ func newCursorWithSession( clientSession *session.Client, ) (*Cursor, error) { if registry == nil { - registry = bson.DefaultRegistry + registry = bson.NewRegistryBuilder().Build() } if bc == nil { return nil, errors.New("batch cursor must not be nil") @@ -82,12 +82,12 @@ func newEmptyCursor() *Cursor { } // NewCursorFromDocuments creates a new Cursor pre-loaded with the provided documents, error and registry. If no registry is provided, -// bson.DefaultRegistry will be used. +// a default registry will be used. // // The documents parameter must be a slice of documents. The slice may be nil or empty, but all elements must be non-nil. func NewCursorFromDocuments(documents []interface{}, preloadedErr error, registry *bson.Registry) (*Cursor, error) { if registry == nil { - registry = bson.DefaultRegistry + registry = bson.NewRegistryBuilder().Build() } buf := new(bytes.Buffer) diff --git a/mongo/gridfs_bucket.go b/mongo/gridfs_bucket.go index 7c2bbac64e..48c02f8716 100644 --- a/mongo/gridfs_bucket.go +++ b/mongo/gridfs_bucket.go @@ -613,7 +613,7 @@ func (b *GridFSBucket) parseUploadOptions(opts ...*options.UploadOptions) (*uplo upload.chunkSize = *uo.ChunkSizeBytes } if uo.Registry == nil { - uo.Registry = bson.DefaultRegistry + uo.Registry = bson.NewRegistryBuilder().Build() } if uo.Metadata != nil { // TODO(GODRIVER-2726): Replace with marshal() and unmarshal() once the diff --git a/mongo/mongo.go b/mongo/mongo.go index 318c765000..ff499556dc 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -118,7 +118,7 @@ func marshal( registry *bson.Registry, ) (bsoncore.Document, error) { if registry == nil { - registry = bson.DefaultRegistry + registry = bson.NewRegistryBuilder().Build() } if val == nil { return nil, ErrNilDocument @@ -156,7 +156,7 @@ func ensureID( reg *bson.Registry, ) (bsoncore.Document, interface{}, error) { if reg == nil { - reg = bson.DefaultRegistry + reg = bson.NewRegistryBuilder().Build() } // Try to find the "_id" element. If it exists, try to unmarshal just the diff --git a/mongo/options/gridfsoptions.go b/mongo/options/gridfsoptions.go index 10d454c89d..47f97a5a51 100644 --- a/mongo/options/gridfsoptions.go +++ b/mongo/options/gridfsoptions.go @@ -99,7 +99,7 @@ type UploadOptions struct { // GridFSUpload creates a new UploadOptions instance. func GridFSUpload() *UploadOptions { - return &UploadOptions{Registry: bson.DefaultRegistry} + return &UploadOptions{Registry: bson.NewRegistryBuilder().Build()} } // SetChunkSizeBytes sets the value for the ChunkSize field. diff --git a/mongo/options/mongooptions.go b/mongo/options/mongooptions.go index 2279f66d0d..bfe9ad523b 100644 --- a/mongo/options/mongooptions.go +++ b/mongo/options/mongooptions.go @@ -128,7 +128,7 @@ type ArrayFilters struct { func (af *ArrayFilters) ToArray() ([]bson.Raw, error) { registry := af.Registry if registry == nil { - registry = bson.DefaultRegistry + registry = bson.NewRegistryBuilder().Build() } filters := make([]bson.Raw, 0, len(af.Filters)) buf := new(bytes.Buffer) @@ -154,7 +154,7 @@ func (af *ArrayFilters) ToArray() ([]bson.Raw, error) { func (af *ArrayFilters) ToArrayDocument() (bson.Raw, error) { registry := af.Registry if registry == nil { - registry = bson.DefaultRegistry + registry = bson.NewRegistryBuilder().Build() } idx, arr := bsoncore.AppendArrayStart(nil) diff --git a/mongo/single_result.go b/mongo/single_result.go index e0639e4069..f467666167 100644 --- a/mongo/single_result.go +++ b/mongo/single_result.go @@ -40,7 +40,7 @@ func NewSingleResultFromDocument(document interface{}, err error, registry *bson return &SingleResult{err: ErrNilDocument} } if registry == nil { - registry = bson.DefaultRegistry + registry = bson.NewRegistryBuilder().Build() } cur, createErr := NewCursorFromDocuments([]interface{}{document}, err, registry) diff --git a/mongo/single_result_test.go b/mongo/single_result_test.go index 1338fe90c6..ae1e49a7a0 100644 --- a/mongo/single_result_test.go +++ b/mongo/single_result_test.go @@ -22,10 +22,10 @@ func TestSingleResult(t *testing.T) { t.Run("Decode", func(t *testing.T) { t.Run("decode twice", func(t *testing.T) { // Test that Decode and Raw can be called more than once - c, err := newCursor(newTestBatchCursor(1, 1), nil, bson.DefaultRegistry) + c, err := newCursor(newTestBatchCursor(1, 1), nil, bson.NewRegistryBuilder().Build()) assert.Nil(t, err, "newCursor error: %v", err) - sr := &SingleResult{cur: c, reg: bson.DefaultRegistry} + sr := &SingleResult{cur: c, reg: c.registry} var firstDecode, secondDecode bson.Raw err = sr.Decode(&firstDecode) assert.Nil(t, err, "Decode error: %v", err) @@ -47,7 +47,7 @@ func TestSingleResult(t *testing.T) { assert.Equal(t, sr.err, err, "expected error %v, got %v", sr.err, err) }) t.Run("with BSONOptions", func(t *testing.T) { - c, err := newCursor(newTestBatchCursor(1, 1), nil, bson.DefaultRegistry) + c, err := newCursor(newTestBatchCursor(1, 1), nil, bson.NewRegistryBuilder().Build()) require.NoError(t, err, "newCursor error") sr := &SingleResult{ @@ -55,7 +55,7 @@ func TestSingleResult(t *testing.T) { bsonOpts: &options.BSONOptions{ UseJSONStructTags: true, }, - reg: bson.DefaultRegistry, + reg: c.registry, } type myDocument struct { From 6ca8e114d57160ca7ae014d492cf144d5654a5bb Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Wed, 22 May 2024 17:18:58 -0400 Subject: [PATCH 11/15] WIP --- bson/bson_test.go | 3 +- bson/decoder.go | 32 +++------- bson/decoder_test.go | 26 -------- bson/default_value_decoders.go | 2 +- bson/encoder.go | 33 ++++------ bson/marshal.go | 18 +----- bson/marshal_test.go | 31 +--------- bson/mgocompat/bson_test.go | 107 +++++++-------------------------- bson/registry_examples_test.go | 38 ++++++------ bson/struct_codec.go | 13 +++- bson/truncation_test.go | 6 +- bson/unmarshal.go | 8 +-- mongo/cursor.go | 17 +++--- mongo/gridfs_bucket.go | 3 +- mongo/mongo.go | 11 ++-- mongo/options/mongooptions.go | 8 +-- 16 files changed, 99 insertions(+), 257 deletions(-) diff --git a/bson/bson_test.go b/bson/bson_test.go index 246b0e913a..4f41c28b7f 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -362,8 +362,7 @@ func TestMapCodec(t *testing.T) { mapRegistry.RegisterKindEncoder(reflect.Map, func() ValueEncoder { return tc.codec }) buf := new(bytes.Buffer) vw := NewValueWriter(buf) - enc := NewEncoder(vw) - enc.SetRegistry(mapRegistry.Build()) + enc := NewEncoderWithRegistry(mapRegistry.Build(), vw) err := enc.Encode(mapObj) assert.Nil(t, err, "Encode error: %v", err) str := buf.String() diff --git a/bson/decoder.go b/bson/decoder.go index ae335fc8ff..62714d1f4b 100644 --- a/bson/decoder.go +++ b/bson/decoder.go @@ -10,21 +10,11 @@ import ( "errors" "fmt" "reflect" - "sync" ) // ErrDecodeToNil is the error returned when trying to decode to a nil value var ErrDecodeToNil = errors.New("cannot Decode to nil value") -// This pool is used to keep the allocations of Decoders down. This is only used for the Marshal* -// methods and is not consumable from outside of this package. The Decoders retrieved from this pool -// must have both Reset and SetRegistry called on them. -var decPool = sync.Pool{ - New: func() interface{} { - return new(Decoder) - }, -} - // A Decoder reads and decodes BSON documents from a stream. It reads from a ValueReader as // the source of BSON data. type Decoder struct { @@ -34,8 +24,17 @@ type Decoder struct { // NewDecoder returns a new decoder that uses the default registry to read from vr. func NewDecoder(vr ValueReader) *Decoder { + r := NewRegistryBuilder().Build() return &Decoder{ - reg: NewRegistryBuilder().Build(), + reg: r, + vr: vr, + } +} + +// NewDecoderWithRegistry returns a new decoder that uses the given registry to read from vr. +func NewDecoderWithRegistry(r *Registry, vr ValueReader) *Decoder { + return &Decoder{ + reg: r, vr: vr, } } @@ -76,17 +75,6 @@ func (d *Decoder) Decode(val interface{}) error { return decoder.DecodeValue(d.reg, d.vr, rval) } -// Reset will reset the state of the decoder, using the same *DecodeContext used in -// the original construction but using vr for reading. -func (d *Decoder) Reset(vr ValueReader) { - d.vr = vr -} - -// SetRegistry replaces the current registry of the decoder with r. -func (d *Decoder) SetRegistry(r *Registry) { - d.reg = r -} - // DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This // behavior is restricted to data typed as "interface{}" or "map[string]interface{}". func (d *Decoder) DefaultDocumentM() { diff --git a/bson/decoder_test.go b/bson/decoder_test.go index 973e48b869..ccae4f2bb3 100644 --- a/bson/decoder_test.go +++ b/bson/decoder_test.go @@ -183,32 +183,6 @@ func TestDecoderv2(t *testing.T) { want := foo{Item: "canvas", Qty: 4, Bonus: 2} assert.Equal(t, want, got, "Results do not match.") }) - t.Run("Reset", func(t *testing.T) { - t.Parallel() - - vr1, vr2 := NewValueReader([]byte{}), NewValueReader([]byte{}) - dec := NewDecoder(vr1) - if dec.vr != vr1 { - t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr1) - } - dec.Reset(vr2) - if dec.vr != vr2 { - t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr2) - } - }) - // t.Run("SetRegistry", func(t *testing.T) { - // t.Parallel() - - // r1, r2 := DefaultRegistry, NewRegistryBuilder().Build() - // dec := NewDecoder(NewValueReader([]byte{})) - // if !reflect.DeepEqual(dec.reg, r1) { - // t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.reg, r1) - // } - // dec.SetRegistry(r2) - // if !reflect.DeepEqual(dec.reg, r2) { - // t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.reg, r2) - // } - // }) t.Run("DecodeToNil", func(t *testing.T) { t.Parallel() diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index 5bb49cc60c..08fec2dd61 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -151,7 +151,7 @@ func dDecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error return err } - elem, err := decodeTypeOrValueWithInfo(decoder, reg, elemVr, tEmpty) + elem, err := decodeTypeOrValueWithInfo(decoder, reg, elemVr, tD) if err != nil { return err } diff --git a/bson/encoder.go b/bson/encoder.go index 1b90e9e948..1e1965cd91 100644 --- a/bson/encoder.go +++ b/bson/encoder.go @@ -8,18 +8,8 @@ package bson import ( "reflect" - "sync" ) -// This pool is used to keep the allocations of Encoders down. This is only used for the Marshal* -// methods and is not consumable from outside of this package. The Encoders retrieved from this pool -// must have both Reset and SetRegistry called on them. -var encPool = sync.Pool{ - New: func() interface{} { - return new(Encoder) - }, -} - // An Encoder writes a serialization format to an output stream. It writes to a ValueWriter // as the destination of BSON data. type Encoder struct { @@ -35,6 +25,14 @@ func NewEncoder(vw ValueWriter) *Encoder { } } +// NewEncoderWithRegistry returns a new encoder that uses the given registry to write to vw. +func NewEncoderWithRegistry(r *Registry, vw ValueWriter) *Encoder { + return &Encoder{ + reg: r, + vw: vw, + } +} + // Encode writes the BSON encoding of val to the stream. // // See [Marshal] for details about BSON marshaling behavior. @@ -56,17 +54,6 @@ func (e *Encoder) Encode(val interface{}) error { return encoder.EncodeValue(e.reg, e.vw, reflect.ValueOf(val)) } -// Reset will reset the state of the Encoder, using the same *EncodeContext used in -// the original construction but using vw. -func (e *Encoder) Reset(vw ValueWriter) { - e.vw = vw -} - -// SetRegistry replaces the current registry of the Encoder with r. -func (e *Encoder) SetRegistry(r *Registry) { - e.reg = r -} - // ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in // the marshaled BSON when the "inline" struct tag option is set. func (e *Encoder) ErrorOnInlineDuplicates() { @@ -172,7 +159,9 @@ func (e *Encoder) UseJSONStructTags() { t := reflect.TypeOf((*structCodec)(nil)) if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { for i := range v { - v[i].(*structCodec).useJSONStructTags = true + if enc, ok := v[i].(*structCodec); ok { + enc.useJSONStructTags = true + } } } } diff --git a/bson/marshal.go b/bson/marshal.go index a26270a158..01dea11fb6 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -70,11 +70,7 @@ func Marshal(val interface{}) ([]byte, error) { } }() sw.Reset() - vw := NewValueWriter(sw) - enc := encPool.Get().(*Encoder) - defer encPool.Put(enc) - enc.Reset(vw) - enc.SetRegistry(NewRegistryBuilder().Build()) + enc := NewEncoderWithRegistry(NewRegistryBuilder().Build(), NewValueWriter(sw)) err := enc.Encode(val) if err != nil { return nil, err @@ -100,10 +96,7 @@ func MarshalValueWithRegistry(r *Registry, val interface{}) (Type, []byte, error vwFlusher := bvwPool.GetAtModeElement(&sw) // get an Encoder and encode the value - enc := encPool.Get().(*Encoder) - defer encPool.Put(enc) - enc.Reset(vwFlusher) - enc.SetRegistry(r) + enc := NewEncoderWithRegistry(r, vwFlusher) if err := enc.Encode(val); err != nil { return 0, nil, err } @@ -123,12 +116,7 @@ func MarshalExtJSON(val interface{}, canonical, escapeHTML bool) ([]byte, error) ejvw := extjPool.Get(&sw, canonical, escapeHTML) defer extjPool.Put(ejvw) - enc := encPool.Get().(*Encoder) - defer encPool.Put(enc) - - enc.Reset(ejvw) - enc.SetRegistry(NewRegistryBuilder().Build()) - + enc := NewEncoderWithRegistry(NewRegistryBuilder().Build(), ejvw) err := enc.Encode(val) if err != nil { return nil, err diff --git a/bson/marshal_test.go b/bson/marshal_test.go index 82d07d99f2..0bb650b668 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -32,33 +32,7 @@ func TestMarshalWithRegistry(t *testing.T) { } buf := new(bytes.Buffer) vw := NewValueWriter(buf) - enc := NewEncoder(vw) - enc.SetRegistry(reg) - err := enc.Encode(tc.val) - noerr(t, err) - - if got := buf.Bytes(); !bytes.Equal(got, tc.want) { - t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want) - t.Errorf("Bytes:\n%v\n%v", got, tc.want) - } - }) - } -} - -func TestMarshalWithContext(t *testing.T) { - for _, tc := range marshalingTestCases { - t.Run(tc.name, func(t *testing.T) { - var reg *Registry - if tc.reg != nil { - reg = tc.reg - } else { - reg = NewRegistryBuilder().Build() - } - buf := new(bytes.Buffer) - vw := NewValueWriter(buf) - enc := NewEncoder(vw) - enc.IntMinSize() - enc.SetRegistry(reg) + enc := NewEncoderWithRegistry(reg, vw) err := enc.Encode(tc.val) noerr(t, err) @@ -175,8 +149,7 @@ func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) { buf := new(bytes.Buffer) vw := NewValueWriter(buf) - enc := NewEncoder(vw) - enc.SetRegistry(customReg) + enc := NewEncoderWithRegistry(customReg, vw) err = enc.Encode(original) assert.Nil(t, err, "Encode error: %v", err) second := buf.Bytes() diff --git a/bson/mgocompat/bson_test.go b/bson/mgocompat/bson_test.go index 8a1ff811fd..302c9fc6da 100644 --- a/bson/mgocompat/bson_test.go +++ b/bson/mgocompat/bson_test.go @@ -79,23 +79,6 @@ var sampleItems = []testItemType{ "\x13\x00\x00\x00\x05slice\x00\x02\x00\x00\x00\x00\x01\x02\x00"}, } -func TestMarshalSampleItems(t *testing.T) { - buf := new(bytes.Buffer) - enc := new(bson.Encoder) - for i, item := range sampleItems { - t.Run(strconv.Itoa(i), func(t *testing.T) { - buf.Reset() - vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) - err := enc.Encode(item.obj) - assert.Nil(t, err, "expected nil error, got: %v", err) - str := buf.String() - assert.Equal(t, str, item.data, "expected: %v, got: %v", item.data, str) - }) - } -} - func TestUnmarshalSampleItems(t *testing.T) { for i, item := range sampleItems { t.Run(strconv.Itoa(i), func(t *testing.T) { @@ -164,23 +147,6 @@ var allItems = []testItemType{ "\xFF_\x00"}, } -func TestMarshalAllItems(t *testing.T) { - buf := new(bytes.Buffer) - enc := new(bson.Encoder) - for i, item := range allItems { - t.Run(strconv.Itoa(i), func(t *testing.T) { - buf.Reset() - vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) - err := enc.Encode(item.obj) - assert.Nil(t, err, "expected nil error, got: %v", err) - str := buf.String() - assert.Equal(t, str, wrapInDoc(item.data), "expected: %v, got: %v", wrapInDoc(item.data), str) - }) - } -} - func TestUnmarshalAllItems(t *testing.T) { for i, item := range allItems { t.Run(strconv.Itoa(i), func(t *testing.T) { @@ -220,8 +186,7 @@ func TestUnmarshalRawIncompatible(t *testing.T) { func TestUnmarshalZeroesStruct(t *testing.T) { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(bson.M{"b": 2}) assert.Nil(t, err, "expected nil error, got: %v", err) type T struct{ A, B int } @@ -235,8 +200,7 @@ func TestUnmarshalZeroesStruct(t *testing.T) { func TestUnmarshalZeroesMap(t *testing.T) { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(bson.M{"b": 2}) assert.Nil(t, err, "expected nil error, got: %v", err) m := bson.M{"a": 1} @@ -250,8 +214,7 @@ func TestUnmarshalZeroesMap(t *testing.T) { func TestUnmarshalNonNilInterface(t *testing.T) { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(bson.M{"b": 2}) assert.Nil(t, err, "expected nil error, got: %v", err) m := bson.M{"a": 1} @@ -288,13 +251,11 @@ func TestPtrInline(t *testing.T) { } buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, cs := range cases { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(cs.In) assert.Nil(t, err, "expected nil error, got: %v", err) var dataBSON bson.M @@ -377,13 +338,11 @@ var oneWayMarshalItems = []testItemType{ func TestOneWayMarshalItems(t *testing.T) { buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, item := range oneWayMarshalItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -414,13 +373,11 @@ var structSampleItems = []testItemType{ func TestMarshalStructSampleItems(t *testing.T) { buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, item := range structSampleItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) assert.Equal(t, item.data, buf.String(), "expected: %v, got: %v", item.data, buf.String()) @@ -441,8 +398,7 @@ func Test64bitInt(t *testing.T) { if int(i) > 0 { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(bson.M{"i": int(i)}) assert.Nil(t, err, "expected nil error, got: %v", err) want := wrapInDoc("\x12i\x00\x00\x00\x00\x80\x00\x00\x00\x00") @@ -580,13 +536,11 @@ var structItems = []testItemType{ func TestMarshalStructItems(t *testing.T) { buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, item := range structItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) assert.Equal(t, wrapInDoc(item.data), buf.String(), "expected: %v, got: %v", wrapInDoc(item.data), buf.String()) @@ -656,13 +610,11 @@ var marshalItems = []testItemType{ func TestMarshalOneWayItems(t *testing.T) { buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, item := range marshalItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) assert.Equal(t, wrapInDoc(item.data), buf.String(), "expected: %v, got: %v", wrapInDoc(item.data), buf.String()) @@ -765,13 +717,11 @@ var marshalErrorItems = []testItemType{ func TestMarshalErrorItems(t *testing.T) { buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, item := range marshalErrorItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(item.obj) assert.NotNil(t, err, "expected error") @@ -1031,8 +981,7 @@ func TestUnmarshalSetterErrSetZero(t *testing.T) { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(bson.M{"field": "foo"}) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -1066,7 +1015,6 @@ type docWithGetterField struct { func TestMarshalAllItemsWithGetter(t *testing.T) { buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, item := range allItems { if item.data == "" { continue @@ -1076,8 +1024,7 @@ func TestMarshalAllItemsWithGetter(t *testing.T) { obj := &docWithGetterField{} obj.Field = &typeWithGetter{result: item.obj.(bson.M)["_"]} vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(obj) assert.Nil(t, err, "expected nil error, got: %v", err) assert.Equal(t, wrapInDoc(item.data), buf.String(), @@ -1090,8 +1037,7 @@ func TestMarshalWholeDocumentWithGetter(t *testing.T) { obj := &typeWithGetter{result: sampleItems[0].obj} buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(obj) assert.Nil(t, err, "expected nil error, got: %v", err) assert.Equal(t, sampleItems[0].data, buf.String(), @@ -1105,8 +1051,7 @@ func TestGetterErrors(t *testing.T) { obj1.Field = &typeWithGetter{sampleItems[0].obj, e} buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(obj1) assert.Equal(t, e, err, "expected error: %v, got: %v", e, err) assert.Nil(t, buf.Bytes(), "expected nil data, got: %v", buf.Bytes()) @@ -1114,8 +1059,7 @@ func TestGetterErrors(t *testing.T) { obj2 := &typeWithGetter{sampleItems[0].obj, e} buf.Reset() vw = bson.NewValueWriter(buf) - enc = bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc = bson.NewEncoderWithRegistry(Registry, vw) err = enc.Encode(obj2) assert.Equal(t, e, err, "expected error: %v, got: %v", e, err) assert.Nil(t, buf.Bytes(), "expected nil data, got: %v", buf.Bytes()) @@ -1135,8 +1079,7 @@ func TestMarshalShortWithGetter(t *testing.T) { obj := typeWithIntGetter{42} buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(obj) assert.Nil(t, err, "expected nil error, got: %v", err) m := bson.M{} @@ -1149,8 +1092,7 @@ func TestMarshalWithGetterNil(t *testing.T) { obj := docWithGetterField{} buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(obj) assert.Nil(t, err, "expected nil error, got: %v", err) m := bson.M{} @@ -1594,8 +1536,7 @@ func testCrossPair(t *testing.T, dump interface{}, load interface{}) { zero := makeZeroDoc(load) buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(dump) assert.Nil(t, err, "expected nil error, got: %v", err) err = bson.UnmarshalWithRegistry(Registry, buf.Bytes(), zero) @@ -1708,8 +1649,7 @@ func TestMarshalNotRespectNil(t *testing.T) { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(testStruct1) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -1741,8 +1681,7 @@ func TestMarshalRespectNil(t *testing.T) { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(testStruct1) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -1770,8 +1709,7 @@ func TestMarshalRespectNil(t *testing.T) { buf.Reset() vw = bson.NewValueWriter(buf) - enc = bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc = bson.NewEncoderWithRegistry(Registry, vw) err = enc.Encode(testStruct1) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -1806,8 +1744,7 @@ func TestInlineWithPointerToSelf(t *testing.T) { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(Registry) + enc := bson.NewEncoderWithRegistry(Registry, vw) err := enc.Encode(x1) assert.Nil(t, err, "expected nil error, got: %v", err) diff --git a/bson/registry_examples_test.go b/bson/registry_examples_test.go index b8b1010c9f..20fab280a4 100644 --- a/bson/registry_examples_test.go +++ b/bson/registry_examples_test.go @@ -46,13 +46,14 @@ func ExampleRegistry_customEncoder() { return vw.WriteInt64(negatedVal) } - reg := bson.NewRegistryBuilder() - reg.RegisterTypeEncoder( - negatedIntType, - func() bson.ValueEncoder { - return bson.ValueEncoderFunc(negatedIntEncoder) - }, - ) + reg := bson.NewRegistryBuilder(). + RegisterTypeEncoder( + negatedIntType, + func() bson.ValueEncoder { + return bson.ValueEncoderFunc(negatedIntEncoder) + }, + ). + Build() // Define a document that includes both int and negatedInt fields with the // same value. @@ -69,8 +70,7 @@ func ExampleRegistry_customEncoder() { // same value and that the negatedInt field is encoded as the negated value. buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(reg.Build()) + enc := bson.NewEncoderWithRegistry(reg, vw) err := enc.Encode(doc) if err != nil { panic(err) @@ -185,15 +185,16 @@ func ExampleRegistryBuilder_RegisterKindEncoder() { return vw.WriteInt64(val.Int()) } - // Create a default registry and register our int32-to-int64 encoder for + // Create a registry with our int32-to-int64 register encoder for // kind reflect.Int32. - reg := bson.NewRegistryBuilder() - reg.RegisterKindEncoder( - reflect.Int32, - func() bson.ValueEncoder { - return bson.ValueEncoderFunc(int32To64Encoder) - }, - ) + reg := bson.NewRegistryBuilder(). + RegisterKindEncoder( + reflect.Int32, + func() bson.ValueEncoder { + return bson.ValueEncoderFunc(int32To64Encoder) + }, + ). + Build() // Define a document that includes an int32, an int64, and a user-defined // type "myInt" that has underlying type int32. @@ -213,8 +214,7 @@ func ExampleRegistryBuilder_RegisterKindEncoder() { // int64 (represented as "$numberLong" when encoded as Extended JSON). buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(reg.Build()) + enc := bson.NewEncoderWithRegistry(reg, vw) err := enc.Encode(doc) if err != nil { panic(err) diff --git a/bson/struct_codec.go b/bson/struct_codec.go index fb4d36f258..8db0f06dfa 100644 --- a/bson/struct_codec.go +++ b/bson/struct_codec.go @@ -300,12 +300,19 @@ func (sc *structCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val refl continue } + inlineT := inlineMap.Type() + if inlineMap.IsNil() { - inlineMap.Set(reflect.MakeMap(inlineMap.Type())) + inlineMap.Set(reflect.MakeMap(inlineT)) } - elem := reflect.New(inlineMap.Type().Elem()).Elem() - err = decoder.DecodeValue(reg, vr, elem) + var elem reflect.Value + if elemT := inlineT.Elem(); elemT == tEmpty { + elem, err = decodeTypeOrValueWithInfo(decoder, reg, vr, inlineT) + } else { + elem = reflect.New(elemT).Elem() + err = decoder.DecodeValue(reg, vr, elem) + } if err != nil { return err } diff --git a/bson/truncation_test.go b/bson/truncation_test.go index 36d66e9b38..d25621ea92 100644 --- a/bson/truncation_test.go +++ b/bson/truncation_test.go @@ -32,9 +32,8 @@ func TestTruncation(t *testing.T) { buf := new(bytes.Buffer) vw := NewValueWriter(buf) - enc := NewEncoder(vw) + enc := NewEncoderWithRegistry(NewRegistryBuilder().Build(), vw) enc.IntMinSize() - enc.SetRegistry(NewRegistryBuilder().Build()) err := enc.Encode(&input) assert.Nil(t, err) @@ -58,9 +57,8 @@ func TestTruncation(t *testing.T) { buf := new(bytes.Buffer) vw := NewValueWriter(buf) - enc := NewEncoder(vw) + enc := NewEncoderWithRegistry(NewRegistryBuilder().Build(), vw) enc.IntMinSize() - enc.SetRegistry(NewRegistryBuilder().Build()) err := enc.Encode(&input) assert.Nil(t, err) diff --git a/bson/unmarshal.go b/bson/unmarshal.go index 48bac97643..696ab8c727 100644 --- a/bson/unmarshal.go +++ b/bson/unmarshal.go @@ -157,11 +157,5 @@ func UnmarshalExtJSONWithContext(reg *Registry, data []byte, canonical bool, val } func unmarshalFromReader(reg *Registry, vr ValueReader, val interface{}) error { - dec := decPool.Get().(*Decoder) - defer decPool.Put(dec) - - dec.Reset(vr) - dec.reg = reg - - return dec.Decode(val) + return NewDecoderWithRegistry(reg, vr).Decode(val) } diff --git a/mongo/cursor.go b/mongo/cursor.go index 8d1e58f6f9..577927dfac 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -91,7 +91,6 @@ func NewCursorFromDocuments(documents []interface{}, preloadedErr error, registr } buf := new(bytes.Buffer) - enc := new(bson.Encoder) values := make([]bsoncore.Value, len(documents)) for i, doc := range documents { @@ -104,9 +103,7 @@ func NewCursorFromDocuments(documents []interface{}, preloadedErr error, registr } vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(registry) - + enc := bson.NewEncoderWithRegistry(registry, vw) if err := enc.Encode(doc); err != nil { return nil, err } @@ -238,7 +235,13 @@ func getDecoder( opts *options.BSONOptions, reg *bson.Registry, ) *bson.Decoder { - dec := bson.NewDecoder(bson.NewValueReader(data)) + vr := bson.NewValueReader(data) + var dec *bson.Decoder + if reg != nil { + dec = bson.NewDecoderWithRegistry(reg, vr) + } else { + dec = bson.NewDecoder(vr) + } if opts != nil { if opts.AllowTruncatingDoubles { @@ -267,10 +270,6 @@ func getDecoder( } } - if reg != nil { - dec.SetRegistry(reg) - } - return dec } diff --git a/mongo/gridfs_bucket.go b/mongo/gridfs_bucket.go index 48c02f8716..55212eb334 100644 --- a/mongo/gridfs_bucket.go +++ b/mongo/gridfs_bucket.go @@ -620,8 +620,7 @@ func (b *GridFSBucket) parseUploadOptions(opts ...*options.UploadOptions) (*uplo // TODO gridfs package is merged into the mongo package. buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) - enc := bson.NewEncoder(vw) - enc.SetRegistry(uo.Registry) + enc := bson.NewEncoderWithRegistry(uo.Registry, vw) err := enc.Encode(uo.Metadata) if err != nil { return nil, err diff --git a/mongo/mongo.go b/mongo/mongo.go index ff499556dc..d102c05b66 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -63,7 +63,12 @@ func getEncoder( reg *bson.Registry, ) (*bson.Encoder, error) { vw := bvwPool.Get(w) - enc := bson.NewEncoder(vw) + var enc *bson.Encoder + if reg != nil { + enc = bson.NewEncoderWithRegistry(reg, vw) + } else { + enc = bson.NewEncoder(vw) + } if opts != nil { if opts.ErrorOnInlineDuplicates { @@ -92,10 +97,6 @@ func getEncoder( } } - if reg != nil { - enc.SetRegistry(reg) - } - return enc, nil } diff --git a/mongo/options/mongooptions.go b/mongo/options/mongooptions.go index bfe9ad523b..7124eb81ae 100644 --- a/mongo/options/mongooptions.go +++ b/mongo/options/mongooptions.go @@ -132,12 +132,10 @@ func (af *ArrayFilters) ToArray() ([]bson.Raw, error) { } filters := make([]bson.Raw, 0, len(af.Filters)) buf := new(bytes.Buffer) - enc := new(bson.Encoder) for _, f := range af.Filters { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(registry) + enc := bson.NewEncoderWithRegistry(registry, vw) err := enc.Encode(f) if err != nil { return nil, err @@ -159,12 +157,10 @@ func (af *ArrayFilters) ToArrayDocument() (bson.Raw, error) { idx, arr := bsoncore.AppendArrayStart(nil) buf := new(bytes.Buffer) - enc := new(bson.Encoder) for i, f := range af.Filters { buf.Reset() vw := bson.NewValueWriter(buf) - enc.Reset(vw) - enc.SetRegistry(registry) + enc := bson.NewEncoderWithRegistry(registry, vw) err := enc.Encode(f) if err != nil { return nil, err From 9c32dc781fd0e6e9b0adc52002a807888df26b0b Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Thu, 23 May 2024 17:02:22 -0400 Subject: [PATCH 12/15] WIP --- bson/decoder.go | 109 ++-------------------- bson/decoder_example_test.go | 8 +- bson/decoder_test.go | 20 ++-- bson/default_value_decoders.go | 27 +++--- bson/default_value_decoders_test.go | 128 +++++++++++++++---------- bson/default_value_encoders.go | 29 +++--- bson/default_value_encoders_test.go | 20 ++-- bson/encoder.go | 123 +++--------------------- bson/encoder_example_test.go | 54 +++++------ bson/encoder_test.go | 16 ++-- bson/float_codec.go | 107 --------------------- bson/mgoregistry.go | 12 +-- bson/{int_codec.go => num_codec.go} | 118 ++++++++++++++++++++--- bson/registry.go | 28 ++++-- bson/registry_option.go | 139 ++++++++++++++++++++++++++++ bson/struct_codec.go | 7 +- bson/truncation_test.go | 26 +++--- bson/unmarshal.go | 27 +++--- mongo/cursor.go | 16 ++-- mongo/mongo.go | 24 ++--- 20 files changed, 508 insertions(+), 530 deletions(-) delete mode 100644 bson/float_codec.go rename bson/{int_codec.go => num_codec.go} (64%) create mode 100644 bson/registry_option.go diff --git a/bson/decoder.go b/bson/decoder.go index 62714d1f4b..dd37150f60 100644 --- a/bson/decoder.go +++ b/bson/decoder.go @@ -15,10 +15,16 @@ import ( // ErrDecodeToNil is the error returned when trying to decode to a nil value var ErrDecodeToNil = errors.New("cannot Decode to nil value") +// ConfigurableDecoderRegistry refers a DecoderRegistry that is configurable with *RegistryOpt. +type ConfigurableDecoderRegistry interface { + DecoderRegistry + SetCodecOptions(opts ...*RegistryOpt) +} + // A Decoder reads and decodes BSON documents from a stream. It reads from a ValueReader as // the source of BSON data. type Decoder struct { - reg *Registry + reg ConfigurableDecoderRegistry vr ValueReader } @@ -75,102 +81,7 @@ func (d *Decoder) Decode(val interface{}) error { return decoder.DecodeValue(d.reg, d.vr, rval) } -// DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This -// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". -func (d *Decoder) DefaultDocumentM() { - t := reflect.TypeOf((*emptyInterfaceCodec)(nil)) - if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { - for i := range v { - v[i].(*emptyInterfaceCodec).defaultDocumentType = reflect.TypeOf(M{}) - } - } -} - -// DefaultDocumentD causes the Decoder to always unmarshal documents into the primitive.D type. This -// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". -func (d *Decoder) DefaultDocumentD() { - t := reflect.TypeOf((*emptyInterfaceCodec)(nil)) - if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { - for i := range v { - v[i].(*emptyInterfaceCodec).defaultDocumentType = reflect.TypeOf(D{}) - } - } -} - -// AllowTruncatingDoubles causes the Decoder to truncate the fractional part of BSON "double" values -// when attempting to unmarshal them into a Go integer (int, int8, int16, int32, or int64) struct -// field. The truncation logic does not apply to BSON "decimal128" values. -func (d *Decoder) AllowTruncatingDoubles() { - t := reflect.TypeOf((*intCodec)(nil)) - if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { - for i := range v { - v[i].(*intCodec).truncate = true - } - } - // TODO floatCodec -} - -// BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or -// "Old" BSON binary subtype as a Go byte slice instead of a primitive.Binary. -func (d *Decoder) BinaryAsSlice() { - t := reflect.TypeOf((*emptyInterfaceCodec)(nil)) - if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { - for i := range v { - v[i].(*emptyInterfaceCodec).decodeBinaryAsSlice = true - } - } -} - -// DecodeObjectIDAsHex causes the Decoder to unmarshal BSON ObjectID as a hexadecimal string. -func (d *Decoder) DecodeObjectIDAsHex() { - t := reflect.TypeOf((*stringCodec)(nil)) - if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { - for i := range v { - v[i].(*stringCodec).decodeObjectIDAsHex = true - } - } -} - -// UseJSONStructTags causes the Decoder to fall back to using the "json" struct tag if a "bson" -// struct tag is not specified. -func (d *Decoder) UseJSONStructTags() { - t := reflect.TypeOf((*structCodec)(nil)) - if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { - for i := range v { - v[i].(*structCodec).useJSONStructTags = true - } - } -} - -// UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead -// of the UTC timezone. -func (d *Decoder) UseLocalTimeZone() { - t := reflect.TypeOf((*timeCodec)(nil)) - if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { - for i := range v { - v[i].(*timeCodec).useLocalTimeZone = true - } - } -} - -// ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value -// passed to Decode before unmarshaling BSON documents into them. -func (d *Decoder) ZeroMaps() { - t := reflect.TypeOf((*mapCodec)(nil)) - if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { - for i := range v { - v[i].(*mapCodec).decodeZerosMap = true - } - } -} - -// ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination -// value passed to Decode before unmarshaling BSON documents into them. -func (d *Decoder) ZeroStructs() { - t := reflect.TypeOf((*structCodec)(nil)) - if v, ok := d.reg.codecTypeMap[t]; ok && v != nil { - for i := range v { - v[i].(*structCodec).decodeZeroStruct = true - } - } +// SetBehavior set the decoder behavior with *RegistryOpt. +func (d *Decoder) SetBehavior(opts ...*RegistryOpt) { + d.reg.SetCodecOptions(opts...) } diff --git a/bson/decoder_example_test.go b/bson/decoder_example_test.go index 3e17e98927..60fb360710 100644 --- a/bson/decoder_example_test.go +++ b/bson/decoder_example_test.go @@ -48,7 +48,7 @@ func ExampleDecoder() { // Output: {Name:Cereal Rounds SKU:AB12345 Price:399} } -func ExampleDecoder_DefaultDocumentM() { +func ExampleDecoder_SetBehavior_defaultDocumentM() { // Marshal a BSON document that contains a city name and a nested document // with various city properties. doc := bson.D{ @@ -77,7 +77,7 @@ func ExampleDecoder_DefaultDocumentM() { // type if the decode destination has no type information. The Properties // field in the City struct will be decoded as a "M" (i.e. map) instead // of the default "D". - decoder.DefaultDocumentM() + decoder.SetBehavior(bson.DefaultDocumentM) var res City err = decoder.Decode(&res) @@ -89,7 +89,7 @@ func ExampleDecoder_DefaultDocumentM() { // Output: {Name:New York Properties:map[elevation:10 population:8804190 state:NY]} } -func ExampleDecoder_UseJSONStructTags() { +func ExampleDecoder_SetBehavior_useJSONStructTags() { // Marshal a BSON document that contains the name, SKU, and price (in cents) // of a product. doc := bson.D{ @@ -114,7 +114,7 @@ func ExampleDecoder_UseJSONStructTags() { // Configure the Decoder to use "json" struct tags when decoding if "bson" // struct tags are not present. - decoder.UseJSONStructTags() + decoder.SetBehavior(bson.UseJSONStructTags) var res Product err = decoder.Decode(&res) diff --git a/bson/decoder_test.go b/bson/decoder_test.go index ccae4f2bb3..1c884c6d56 100644 --- a/bson/decoder_test.go +++ b/bson/decoder_test.go @@ -253,7 +253,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "AllowTruncatingDoubles", configure: func(dec *Decoder) { - dec.AllowTruncatingDoubles() + dec.SetBehavior(AllowTruncatingDoubles) }, input: bsoncore.NewDocumentBuilder(). AppendDouble("myInt", 1.999). @@ -286,7 +286,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "BinaryAsSlice", configure: func(dec *Decoder) { - dec.BinaryAsSlice() + dec.SetBehavior(BinaryAsSlice) }, input: bsoncore.NewDocumentBuilder(). AppendBinary("myBinary", TypeBinaryGeneric, []byte{}). @@ -299,7 +299,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "DefaultDocumentD nested", configure: func(dec *Decoder) { - dec.DefaultDocumentD() + dec.SetBehavior(DefaultDocumentD) }, input: bsoncore.NewDocumentBuilder(). AppendDocument("myDocument", bsoncore.NewDocumentBuilder(). @@ -316,7 +316,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "DefaultDocumentM nested", configure: func(dec *Decoder) { - dec.DefaultDocumentM() + dec.SetBehavior(DefaultDocumentM) }, input: bsoncore.NewDocumentBuilder(). AppendDocument("myDocument", bsoncore.NewDocumentBuilder(). @@ -333,7 +333,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "UseJSONStructTags", configure: func(dec *Decoder) { - dec.UseJSONStructTags() + dec.SetBehavior(UseJSONStructTags) }, input: bsoncore.NewDocumentBuilder(). AppendString("jsonFieldName", "test value"). @@ -346,7 +346,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "UseLocalTimeZone", configure: func(dec *Decoder) { - dec.UseLocalTimeZone() + dec.SetBehavior(UseLocalTimeZone) }, input: bsoncore.NewDocumentBuilder(). AppendDateTime("myTime", 1684349179939). @@ -359,7 +359,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "ZeroMaps", configure: func(dec *Decoder) { - dec.ZeroMaps() + dec.SetBehavior(ZeroMaps) }, input: bsoncore.NewDocumentBuilder(). AppendDocument("myMap", bsoncore.NewDocumentBuilder(). @@ -376,7 +376,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "ZeroStructs", configure: func(dec *Decoder) { - dec.ZeroStructs() + dec.SetBehavior(ZeroStructs) }, input: bsoncore.NewDocumentBuilder(). AppendString("myString", "test value"). @@ -417,7 +417,7 @@ func TestDecoderConfiguration(t *testing.T) { dec := NewDecoder(NewValueReader(input)) - dec.DefaultDocumentM() + dec.SetBehavior(DefaultDocumentM) var got interface{} err := dec.Decode(&got) @@ -441,7 +441,7 @@ func TestDecoderConfiguration(t *testing.T) { dec := NewDecoder(NewValueReader(input)) - dec.DefaultDocumentD() + dec.SetBehavior(DefaultDocumentD) var got interface{} err := dec.Decode(&got) diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index 08fec2dd61..5c28ec9a86 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -40,8 +40,7 @@ func registerDefaultDecoders(rb *RegistryBuilder) { panic(errors.New("argument to RegisterDefaultDecoders must not be nil")) } - intDecoder := func() ValueDecoder { return &intCodec{} } - floatDecoder := func() ValueDecoder { return &floatCodec{} } + numDecoder := func() ValueDecoder { return &numCodec{} } rb.RegisterTypeDecoder(tD, func() ValueDecoder { return ValueDecoderFunc(dDecodeValue) }). RegisterTypeDecoder(tBinary, func() ValueDecoder { return &decodeAdapter{binaryDecodeValue, binaryDecodeType} }). RegisterTypeDecoder(tUndefined, func() ValueDecoder { return &decodeAdapter{undefinedDecodeValue, undefinedDecodeType} }). @@ -65,18 +64,18 @@ func registerDefaultDecoders(rb *RegistryBuilder) { RegisterTypeDecoder(tCoreDocument, func() ValueDecoder { return ValueDecoderFunc(coreDocumentDecodeValue) }). RegisterTypeDecoder(tCodeWithScope, func() ValueDecoder { return &decodeAdapter{codeWithScopeDecodeValue, codeWithScopeDecodeType} }). RegisterKindDecoder(reflect.Bool, func() ValueDecoder { return &decodeAdapter{booleanDecodeValue, booleanDecodeType} }). - RegisterKindDecoder(reflect.Int, intDecoder). - RegisterKindDecoder(reflect.Int8, intDecoder). - RegisterKindDecoder(reflect.Int16, intDecoder). - RegisterKindDecoder(reflect.Int32, intDecoder). - RegisterKindDecoder(reflect.Int64, intDecoder). - RegisterKindDecoder(reflect.Uint, intDecoder). - RegisterKindDecoder(reflect.Uint8, intDecoder). - RegisterKindDecoder(reflect.Uint16, intDecoder). - RegisterKindDecoder(reflect.Uint32, intDecoder). - RegisterKindDecoder(reflect.Uint64, intDecoder). - RegisterKindDecoder(reflect.Float32, floatDecoder). - RegisterKindDecoder(reflect.Float64, floatDecoder). + RegisterKindDecoder(reflect.Int, numDecoder). + RegisterKindDecoder(reflect.Int8, numDecoder). + RegisterKindDecoder(reflect.Int16, numDecoder). + RegisterKindDecoder(reflect.Int32, numDecoder). + RegisterKindDecoder(reflect.Int64, numDecoder). + RegisterKindDecoder(reflect.Uint, numDecoder). + RegisterKindDecoder(reflect.Uint8, numDecoder). + RegisterKindDecoder(reflect.Uint16, numDecoder). + RegisterKindDecoder(reflect.Uint32, numDecoder). + RegisterKindDecoder(reflect.Uint64, numDecoder). + RegisterKindDecoder(reflect.Float32, numDecoder). + RegisterKindDecoder(reflect.Float64, numDecoder). RegisterKindDecoder(reflect.Array, func() ValueDecoder { return ValueDecoderFunc(arrayDecodeValue) }). RegisterKindDecoder(reflect.Map, func() ValueDecoder { return &mapCodec{} }). RegisterKindDecoder(reflect.Slice, func() ValueDecoder { return &sliceCodec{} }). diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index f931c16f6e..019057ea00 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -139,17 +139,18 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "IntDecodeValue", - &intCodec{}, + &numCodec{}, []subtest{ { "wrong type", wrong, nil, &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, - readInt32, + nothing, ValueDecoderError{ - Name: "IntDecodeValue", + Name: "NumDecodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -215,10 +216,11 @@ func TestDefaultValueDecoders(t *testing.T) { {"int/fast path", int(1234), nil, &valueReaderWriter{BSONType: TypeInt64, Return: int64(1234)}, readInt64, nil}, { "int8/fast path - nil", (*int8)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", + Name: "NumDecodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -227,10 +229,11 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "int16/fast path - nil", (*int16)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", + Name: "NumDecodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -239,10 +242,11 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "int32/fast path - nil", (*int32)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", + Name: "NumDecodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -251,10 +255,11 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "int64/fast path - nil", (*int64)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", + Name: "NumDecodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -263,10 +268,11 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "int/fast path - nil", (*int)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", + Name: "NumDecodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -365,8 +371,9 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", + Name: "NumDecodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -391,18 +398,19 @@ func TestDefaultValueDecoders(t *testing.T) { }, }, { - "defaultUIntCodec.DecodeValue", - &intCodec{}, + "UintDecodeValue", + &numCodec{}, []subtest{ { "wrong type", wrong, nil, &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, - readInt32, + nothing, ValueDecoderError{ - Name: "IntDecodeValue", + Name: "NumDecodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -468,10 +476,11 @@ func TestDefaultValueDecoders(t *testing.T) { {"uint/fast path", uint(1234), nil, &valueReaderWriter{BSONType: TypeInt64, Return: int64(1234)}, readInt64, nil}, { "uint8/fast path - nil", (*uint8)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", + Name: "NumDecodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -480,10 +489,11 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "uint16/fast path - nil", (*uint16)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", + Name: "NumDecodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -492,10 +502,11 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "uint32/fast path - nil", (*uint32)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", + Name: "NumDecodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -504,10 +515,11 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "uint64/fast path - nil", (*uint64)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", + Name: "NumDecodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -516,10 +528,11 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "uint/fast path - nil", (*uint)(nil), nil, - &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, readInt32, + &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", + Name: "NumDecodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -638,8 +651,9 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeInt32, Return: int32(0)}, nothing, ValueDecoderError{ - Name: "IntDecodeValue", + Name: "NumDecodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -649,23 +663,27 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "FloatDecodeValue", - &floatCodec{}, + &numCodec{}, []subtest{ { "wrong type", wrong, nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(0)}, - readDouble, + nothing, ValueDecoderError{ - Name: "FloatDecodeValue", - Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf(wrong), }, }, { "type not double", - 0, + float64(0), nil, &valueReaderWriter{BSONType: TypeString}, nothing, @@ -727,19 +745,27 @@ func TestDefaultValueDecoders(t *testing.T) { }, { "float32/fast path - nil", (*float32)(nil), nil, - &valueReaderWriter{BSONType: TypeDouble, Return: float64(0)}, readDouble, + &valueReaderWriter{BSONType: TypeDouble, Return: float64(0)}, nothing, ValueDecoderError{ - Name: "FloatDecodeValue", - Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*float32)(nil)), }, }, { "float64/fast path - nil", (*float64)(nil), nil, - &valueReaderWriter{BSONType: TypeDouble, Return: float64(0)}, readDouble, + &valueReaderWriter{BSONType: TypeDouble, Return: float64(0)}, nothing, ValueDecoderError{ - Name: "FloatDecodeValue", - Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf((*float64)(nil)), }, }, @@ -770,14 +796,18 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeDouble, Return: float64(0)}, nothing, ValueDecoderError{ - Name: "FloatDecodeValue", - Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, }, }, }, }, { - "defaultTimeCodec.DecodeValue", + "TimeDecodeValue", &timeCodec{}, []subtest{ { @@ -831,7 +861,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, }, { - "defaultMapCodec.DecodeValue", + "MapDecodeValue", &mapCodec{}, []subtest{ { @@ -1003,7 +1033,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, }, { - "defaultSliceCodec.DecodeValue", + "SliceDecodeValue", &sliceCodec{}, []subtest{ { @@ -1414,7 +1444,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, }, { - "defaultByteSliceCodec.DecodeValue", + "ByteSliceDecodeValue", &byteSliceCodec{}, []subtest{ { @@ -1482,7 +1512,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, }, { - "defaultStringCodec.DecodeValue", + "StringDecodeValue", &stringCodec{}, []subtest{ { @@ -1591,7 +1621,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, }, { - "PointerCodec.DecodeValue", + "PointerDecodeValue", &pointerCodec{}, []subtest{ { @@ -2262,7 +2292,7 @@ func TestDefaultValueDecoders(t *testing.T) { }, }, { - "StructCodec.DecodeValue", + "StructDecodeValue", defaultTestStructCodec, []subtest{ { diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index 56d48e3722..be57563530 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -55,8 +55,7 @@ func registerDefaultEncoders(rb *RegistryBuilder) { panic(errors.New("argument to RegisterDefaultEncoders must not be nil")) } - intEncoder := func() ValueEncoder { return &intCodec{} } - floatEncoder := func() ValueEncoder { return &floatCodec{} } + numEncoder := func() ValueEncoder { return &numCodec{} } rb.RegisterTypeEncoder(tByteSlice, func() ValueEncoder { return &byteSliceCodec{} }). RegisterTypeEncoder(tTime, func() ValueEncoder { return &timeCodec{} }). RegisterTypeEncoder(tEmpty, func() ValueEncoder { return &emptyInterfaceCodec{} }). @@ -79,18 +78,18 @@ func registerDefaultEncoders(rb *RegistryBuilder) { RegisterTypeEncoder(tCoreDocument, func() ValueEncoder { return ValueEncoderFunc(coreDocumentEncodeValue) }). RegisterTypeEncoder(tCodeWithScope, func() ValueEncoder { return ValueEncoderFunc(codeWithScopeEncodeValue) }). RegisterKindEncoder(reflect.Bool, func() ValueEncoder { return ValueEncoderFunc(booleanEncodeValue) }). - RegisterKindEncoder(reflect.Int, intEncoder). - RegisterKindEncoder(reflect.Int8, intEncoder). - RegisterKindEncoder(reflect.Int16, intEncoder). - RegisterKindEncoder(reflect.Int32, intEncoder). - RegisterKindEncoder(reflect.Int64, intEncoder). - RegisterKindEncoder(reflect.Uint, intEncoder). - RegisterKindEncoder(reflect.Uint8, intEncoder). - RegisterKindEncoder(reflect.Uint16, intEncoder). - RegisterKindEncoder(reflect.Uint32, intEncoder). - RegisterKindEncoder(reflect.Uint64, intEncoder). - RegisterKindEncoder(reflect.Float32, floatEncoder). - RegisterKindEncoder(reflect.Float64, floatEncoder). + RegisterKindEncoder(reflect.Int, numEncoder). + RegisterKindEncoder(reflect.Int8, numEncoder). + RegisterKindEncoder(reflect.Int16, numEncoder). + RegisterKindEncoder(reflect.Int32, numEncoder). + RegisterKindEncoder(reflect.Int64, numEncoder). + RegisterKindEncoder(reflect.Uint, numEncoder). + RegisterKindEncoder(reflect.Uint8, numEncoder). + RegisterKindEncoder(reflect.Uint16, numEncoder). + RegisterKindEncoder(reflect.Uint32, numEncoder). + RegisterKindEncoder(reflect.Uint64, numEncoder). + RegisterKindEncoder(reflect.Float32, numEncoder). + RegisterKindEncoder(reflect.Float64, numEncoder). RegisterKindEncoder(reflect.Array, func() ValueEncoder { return ValueEncoderFunc(arrayEncodeValue) }). RegisterKindEncoder(reflect.Map, func() ValueEncoder { return &mapCodec{} }). RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return &sliceCodec{} }). @@ -151,7 +150,7 @@ func jsonNumberEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Valu return err } - return (&floatCodec{}).EncodeValue(reg, vw, reflect.ValueOf(f64)) + return (&numCodec{}).EncodeValue(reg, vw, reflect.ValueOf(f64)) } // urlEncodeValue is the ValueEncoderFunc for url.URL. diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index 47cb21e6c1..9a5d51cb04 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -95,7 +95,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "IntEncodeValue", - &intCodec{}, + &numCodec{}, []subtest{ { "wrong type", @@ -104,8 +104,9 @@ func TestDefaultValueEncoders(t *testing.T) { nil, nothing, ValueEncoderError{ - Name: "IntEncodeValue", + Name: "NumEncodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -138,7 +139,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "UintEncodeValue", - &intCodec{}, + &numCodec{}, []subtest{ { "wrong type", @@ -147,8 +148,9 @@ func TestDefaultValueEncoders(t *testing.T) { nil, nothing, ValueEncoderError{ - Name: "IntEncodeValue", + Name: "NumEncodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -182,7 +184,7 @@ func TestDefaultValueEncoders(t *testing.T) { }, { "FloatEncodeValue", - &floatCodec{}, + &numCodec{}, []subtest{ { "wrong type", @@ -191,8 +193,12 @@ func TestDefaultValueEncoders(t *testing.T) { nil, nothing, ValueEncoderError{ - Name: "FloatEncodeValue", - Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Name: "NumEncodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf(wrong), }, }, diff --git a/bson/encoder.go b/bson/encoder.go index 1e1965cd91..53db25e5c1 100644 --- a/bson/encoder.go +++ b/bson/encoder.go @@ -10,10 +10,16 @@ import ( "reflect" ) +// ConfigurableEncoderRegistry refers a EncoderRegistry that is configurable with *RegistryOpt. +type ConfigurableEncoderRegistry interface { + EncoderRegistry + SetCodecOptions(opts ...*RegistryOpt) +} + // An Encoder writes a serialization format to an output stream. It writes to a ValueWriter // as the destination of BSON data. type Encoder struct { - reg *Registry + reg ConfigurableEncoderRegistry vw ValueWriter } @@ -26,7 +32,7 @@ func NewEncoder(vw ValueWriter) *Encoder { } // NewEncoderWithRegistry returns a new encoder that uses the given registry to write to vw. -func NewEncoderWithRegistry(r *Registry, vw ValueWriter) *Encoder { +func NewEncoderWithRegistry(r ConfigurableEncoderRegistry, vw ValueWriter) *Encoder { return &Encoder{ reg: r, vw: vw, @@ -54,114 +60,7 @@ func (e *Encoder) Encode(val interface{}) error { return encoder.EncodeValue(e.reg, e.vw, reflect.ValueOf(val)) } -// ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in -// the marshaled BSON when the "inline" struct tag option is set. -func (e *Encoder) ErrorOnInlineDuplicates() { - t := reflect.TypeOf((*structCodec)(nil)) - if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { - for i := range v { - v[i].(*structCodec).overwriteDuplicatedInlinedFields = false - } - } -} - -// IntMinSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, uint, -// uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) that can -// represent the integer value. -func (e *Encoder) IntMinSize() { - // if v, ok := e.reg.kindEncoders.Load(reflect.Int); ok { - // if enc, ok := v.(*intCodec); ok { - // enc.encodeToMinSize = true - // } - // } - // if v, ok := e.reg.kindEncoders.Load(reflect.Uint); ok { - // if enc, ok := v.(*uintCodec); ok { - // enc.encodeToMinSize = true - // } - // } - t := reflect.TypeOf((*intCodec)(nil)) - if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { - for i := range v { - v[i].(*intCodec).minSize = true - } - } -} - -// StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name -// strings using fmt.Sprint instead of the default string conversion logic. -func (e *Encoder) StringifyMapKeysWithFmt() { - t := reflect.TypeOf((*mapCodec)(nil)) - if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { - for i := range v { - v[i].(*mapCodec).encodeKeysWithStringer = true - } - } -} - -// NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON -// null. -func (e *Encoder) NilMapAsEmpty() { - t := reflect.TypeOf((*mapCodec)(nil)) - if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { - for i := range v { - v[i].(*mapCodec).encodeNilAsEmpty = true - } - } -} - -// NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON -// null. -func (e *Encoder) NilSliceAsEmpty() { - t := reflect.TypeOf((*sliceCodec)(nil)) - if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { - for i := range v { - v[i].(*sliceCodec).encodeNilAsEmpty = true - } - } -} - -// NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values -// instead of BSON null. -func (e *Encoder) NilByteSliceAsEmpty() { - // if v, ok := e.reg.typeEncoders.Load(tByteSlice); ok { - // if enc, ok := v.(*byteSliceCodec); ok { - // enc.encodeNilAsEmpty = true - // } - // } - t := reflect.TypeOf((*byteSliceCodec)(nil)) - if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { - for i := range v { - v[i].(*byteSliceCodec).encodeNilAsEmpty = true - } - } -} - -// TODO(GODRIVER-2820): Update the description to remove the note about only examining exported -// TODO struct fields once the logic is updated to also inspect private struct fields. - -// OmitZeroStruct causes the Encoder to consider the zero value for a struct (e.g. MyStruct{}) -// as empty and omit it from the marshaled BSON when the "omitempty" struct tag option is set. -// -// Note that the Encoder only examines exported struct fields when determining if a struct is the -// zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty. -func (e *Encoder) OmitZeroStruct() { - t := reflect.TypeOf((*structCodec)(nil)) - if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { - for i := range v { - v[i].(*structCodec).encodeOmitDefaultStruct = true - } - } -} - -// UseJSONStructTags causes the Encoder to fall back to using the "json" struct tag if a "bson" -// struct tag is not specified. -func (e *Encoder) UseJSONStructTags() { - t := reflect.TypeOf((*structCodec)(nil)) - if v, ok := e.reg.codecTypeMap[t]; ok && v != nil { - for i := range v { - if enc, ok := v[i].(*structCodec); ok { - enc.useJSONStructTags = true - } - } - } +// SetBehavior set the encoder behavior with *RegistryOpt. +func (e *Encoder) SetBehavior(opts ...*RegistryOpt) { + e.reg.SetCodecOptions(opts...) } diff --git a/bson/encoder_example_test.go b/bson/encoder_example_test.go index 5c34192db4..f56249b7ef 100644 --- a/bson/encoder_example_test.go +++ b/bson/encoder_example_test.go @@ -53,7 +53,30 @@ func (k CityState) String() string { return fmt.Sprintf("%s, %s", k.City, k.State) } -func ExampleEncoder_StringifyMapKeysWithFmt() { +func ExampleEncoder_SetBehavior_intMinSize() { + // Create an encoder that will marshal integers as the minimum BSON int size + // (either 32 or 64 bits) that can represent the integer value. + type foo struct { + Bar uint32 + } + + buf := new(bytes.Buffer) + vw := bson.NewValueWriter(buf) + + enc := bson.NewEncoder(vw) + enc.SetBehavior(bson.IntMinSize) + + err := enc.Encode(foo{2}) + if err != nil { + panic(err) + } + + fmt.Println(bson.Raw(buf.Bytes()).String()) + // Output: + // {"bar": {"$numberInt":"2"}} +} + +func ExampleEncoder_SetBehavior_stringifyMapKeysWithFmt() { // Create an Encoder that writes BSON values to a bytes.Buffer. buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) @@ -61,7 +84,7 @@ func ExampleEncoder_StringifyMapKeysWithFmt() { // Configure the Encoder to convert Go map keys to BSON document field names // using fmt.Sprintf instead of the default string conversion logic. - encoder.StringifyMapKeysWithFmt() + encoder.SetBehavior(bson.StringifyMapKeysWithFmt) // Use the Encoder to marshal a BSON document that contains is a map of // city and state to a list of zip codes in that city. @@ -78,7 +101,7 @@ func ExampleEncoder_StringifyMapKeysWithFmt() { // Output: {"New York, NY": [{"$numberInt":"10001"},{"$numberInt":"10301"},{"$numberInt":"10451"}]} } -func ExampleEncoder_UseJSONStructTags() { +func ExampleEncoder_SetBehavior_useJSONStructTags() { // Create an Encoder that writes BSON values to a bytes.Buffer. buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) @@ -92,7 +115,7 @@ func ExampleEncoder_UseJSONStructTags() { // Configure the Encoder to use "json" struct tags when decoding if "bson" // struct tags are not present. - encoder.UseJSONStructTags() + encoder.SetBehavior(bson.UseJSONStructTags) // Use the Encoder to marshal a BSON document that contains the name, SKU, // and price (in cents) of a product. @@ -215,26 +238,3 @@ func ExampleEncoder_multipleExtendedJSONDocuments() { // {"x":{"$numberInt":"3"},"y":{"$numberInt":"4"}} // {"x":{"$numberInt":"4"},"y":{"$numberInt":"5"}} } - -func ExampleEncoder_IntMinSize() { - // Create an encoder that will marshal integers as the minimum BSON int size - // (either 32 or 64 bits) that can represent the integer value. - type foo struct { - Bar uint32 - } - - buf := new(bytes.Buffer) - vw := bson.NewValueWriter(buf) - - enc := bson.NewEncoder(vw) - enc.IntMinSize() - - err := enc.Encode(foo{2}) - if err != nil { - panic(err) - } - - fmt.Println(bson.Raw(buf.Bytes()).String()) - // Output: - // {"bar": {"$numberInt":"2"}} -} diff --git a/bson/encoder_test.go b/bson/encoder_test.go index 2dff4fbfdc..1110fc7d8d 100644 --- a/bson/encoder_test.go +++ b/bson/encoder_test.go @@ -160,7 +160,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "ErrorOnInlineDuplicates", configure: func(enc *Encoder) { - enc.ErrorOnInlineDuplicates() + enc.SetBehavior(ErrorOnInlineDuplicates) }, input: inlineDuplicateOuter{ Inline: inlineDuplicateInner{Duplicate: "inner"}, @@ -173,7 +173,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "IntMinSize", configure: func(enc *Encoder) { - enc.IntMinSize() + enc.SetBehavior(IntMinSize) }, input: D{ {Key: "myInt", Value: int(1)}, @@ -194,7 +194,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "StringifyMapKeysWithFmt", configure: func(enc *Encoder) { - enc.StringifyMapKeysWithFmt() + enc.SetBehavior(StringifyMapKeysWithFmt) }, input: map[stringerTest]string{ {}: "test value", @@ -207,7 +207,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "NilMapAsEmpty", configure: func(enc *Encoder) { - enc.NilMapAsEmpty() + enc.SetBehavior(NilMapAsEmpty) }, input: D{{Key: "myMap", Value: map[string]string(nil)}}, want: bsoncore.NewDocumentBuilder(). @@ -218,7 +218,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "NilSliceAsEmpty", configure: func(enc *Encoder) { - enc.NilSliceAsEmpty() + enc.SetBehavior(NilSliceAsEmpty) }, input: D{{Key: "mySlice", Value: []string(nil)}}, want: bsoncore.NewDocumentBuilder(). @@ -229,7 +229,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "NilByteSliceAsEmpty", configure: func(enc *Encoder) { - enc.NilByteSliceAsEmpty() + enc.SetBehavior(NilByteSliceAsEmpty) }, input: D{{Key: "myBytes", Value: []byte(nil)}}, want: bsoncore.NewDocumentBuilder(). @@ -241,7 +241,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "OmitZeroStruct", configure: func(enc *Encoder) { - enc.OmitZeroStruct() + enc.SetBehavior(OmitZeroStruct) }, input: struct { Zero zeroStruct `bson:",omitempty"` @@ -253,7 +253,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "UseJSONStructTags", configure: func(enc *Encoder) { - enc.UseJSONStructTags() + enc.SetBehavior(UseJSONStructTags) }, input: struct { StructFieldName string `json:"jsonFieldName"` diff --git a/bson/float_codec.go b/bson/float_codec.go deleted file mode 100644 index aa99857877..0000000000 --- a/bson/float_codec.go +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2024-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package bson - -import ( - "fmt" - "reflect" -) - -type floatCodec struct { - // truncate, if true, instructs decoders to to truncate the fractional part of BSON "double" - // values when attempting to unmarshal them into a Go float struct field. The truncation logic - // does not apply to BSON "decimal128" values. - truncate bool -} - -// floatEncodeValue is the ValueEncoderFunc for float types. -func (fc *floatCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { - switch val.Kind() { - case reflect.Float32, reflect.Float64: - return vw.WriteDouble(val.Float()) - } - - return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val} -} - -func (fc *floatCodec) floatDecodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { - var f float64 - var err error - switch vrType := vr.Type(); vrType { - case TypeInt32: - i32, err := vr.ReadInt32() - if err != nil { - return emptyValue, err - } - f = float64(i32) - case TypeInt64: - i64, err := vr.ReadInt64() - if err != nil { - return emptyValue, err - } - f = float64(i64) - case TypeDouble: - f, err = vr.ReadDouble() - if err != nil { - return emptyValue, err - } - case TypeBoolean: - b, err := vr.ReadBoolean() - if err != nil { - return emptyValue, err - } - if b { - f = 1 - } - case TypeNull: - if err = vr.ReadNull(); err != nil { - return emptyValue, err - } - case TypeUndefined: - if err = vr.ReadUndefined(); err != nil { - return emptyValue, err - } - default: - return emptyValue, fmt.Errorf("cannot decode %v into a float32 or float64 type", vrType) - } - - switch t.Kind() { - case reflect.Float32: - if !fc.truncate && float64(float32(f)) != f { - return emptyValue, errCannotTruncate - } - - return reflect.ValueOf(float32(f)), nil - case reflect.Float64: - return reflect.ValueOf(f), nil - default: - return emptyValue, ValueDecoderError{ - Name: "FloatDecodeValue", - Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, - Received: reflect.Zero(t), - } - } -} - -// DecodeValue is the ValueDecoder for float types. -func (fc *floatCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { - if !val.CanSet() { - return ValueDecoderError{ - Name: "FloatDecodeValue", - Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, - Received: val, - } - } - - elem, err := fc.floatDecodeType(reg, vr, val.Type()) - if err != nil { - return err - } - - val.SetFloat(elem.Float()) - return nil -} diff --git a/bson/mgoregistry.go b/bson/mgoregistry.go index 1efac62e92..b8d9ca7a10 100644 --- a/bson/mgoregistry.go +++ b/bson/mgoregistry.go @@ -34,7 +34,7 @@ func newMgoRegistryBuilder() *RegistryBuilder { encodeNilAsEmpty: true, encodeKeysWithStringer: true, } - intcodec := func() ValueEncoder { return &intCodec{encodeUintToMinSize: true} } + numcodec := func() ValueEncoder { return &numCodec{encodeUintToMinSize: true} } return NewRegistryBuilder(). RegisterTypeDecoder(tEmpty, func() ValueDecoder { return &emptyInterfaceCodec{decodeBinaryAsSlice: true} }). @@ -45,11 +45,11 @@ func newMgoRegistryBuilder() *RegistryBuilder { RegisterKindEncoder(reflect.Struct, func() ValueEncoder { return structcodec }). RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return &sliceCodec{encodeNilAsEmpty: true} }). RegisterKindEncoder(reflect.Map, func() ValueEncoder { return mapCodec }). - RegisterKindEncoder(reflect.Uint, intcodec). - RegisterKindEncoder(reflect.Uint8, intcodec). - RegisterKindEncoder(reflect.Uint16, intcodec). - RegisterKindEncoder(reflect.Uint32, intcodec). - RegisterKindEncoder(reflect.Uint64, intcodec). + RegisterKindEncoder(reflect.Uint, numcodec). + RegisterKindEncoder(reflect.Uint8, numcodec). + RegisterKindEncoder(reflect.Uint16, numcodec). + RegisterKindEncoder(reflect.Uint32, numcodec). + RegisterKindEncoder(reflect.Uint64, numcodec). RegisterTypeMapEntry(TypeInt32, tInt). RegisterTypeMapEntry(TypeDateTime, tTime). RegisterTypeMapEntry(TypeArray, tInterfaceSlice). diff --git a/bson/int_codec.go b/bson/num_codec.go similarity index 64% rename from bson/int_codec.go rename to bson/num_codec.go index a7edab4e58..33c16fd15e 100644 --- a/bson/int_codec.go +++ b/bson/num_codec.go @@ -12,8 +12,8 @@ import ( "reflect" ) -// intCodec is the Codec used for uint values. -type intCodec struct { +// numCodec is the Codec used for numeric values. +type numCodec struct { // minSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, // uint, uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) // that can represent the integer value. @@ -30,9 +30,12 @@ type intCodec struct { truncate bool } -// EncodeValue is the ValueEncoder for uint types. -func (ic *intCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { +// EncodeValue is the ValueEncoder for numeric types. +func (nc *numCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { switch val.Kind() { + case reflect.Float32, reflect.Float64: + return vw.WriteDouble(val.Float()) + case reflect.Int8, reflect.Int16, reflect.Int32: return vw.WriteInt32(int32(val.Int())) case reflect.Int: @@ -43,7 +46,7 @@ func (ic *intCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.V return vw.WriteInt64(i64) case reflect.Int64: i64 := val.Int() - if ic.minSize && fitsIn32Bits(i64) { + if nc.minSize && fitsIn32Bits(i64) { return vw.WriteInt32(int32(i64)) } return vw.WriteInt64(i64) @@ -54,7 +57,7 @@ func (ic *intCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.V u64 := val.Uint() // If minSize or encodeToMinSize is true for a non-uint64 value we should write val as an int32 - useMinSize := ic.minSize || (ic.encodeUintToMinSize && val.Kind() != reflect.Uint64) + useMinSize := nc.minSize || (nc.encodeUintToMinSize && val.Kind() != reflect.Uint64) if u64 <= math.MaxInt32 && useMinSize { return vw.WriteInt32(int32(u64)) @@ -66,8 +69,9 @@ func (ic *intCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.V } return ValueEncoderError{ - Name: "IntEncodeValue", + Name: "NumEncodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -75,7 +79,7 @@ func (ic *intCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.V } } -func (ic *intCodec) decodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { +func (nc *numCodec) decodeTypeInt(vr ValueReader, t reflect.Type) (reflect.Value, error) { var i64 int64 switch vrType := vr.Type(); vrType { case TypeInt32: @@ -95,7 +99,7 @@ func (ic *intCodec) decodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type if err != nil { return emptyValue, err } - if !ic.truncate && math.Floor(f64) != f64 { + if !nc.truncate && math.Floor(f64) != f64 { return emptyValue, errCannotTruncate } if f64 > float64(math.MaxInt64) { @@ -174,8 +178,93 @@ func (ic *intCodec) decodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type default: return emptyValue, ValueDecoderError{ - Name: "IntDecodeValue", + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, + Received: reflect.Zero(t), + } + } +} + +func (nc *numCodec) decodeTypeFloat(vr ValueReader, t reflect.Type) (reflect.Value, error) { + var f float64 + var err error + switch vrType := vr.Type(); vrType { + case TypeInt32: + i32, err := vr.ReadInt32() + if err != nil { + return emptyValue, err + } + f = float64(i32) + case TypeInt64: + i64, err := vr.ReadInt64() + if err != nil { + return emptyValue, err + } + f = float64(i64) + case TypeDouble: + f, err = vr.ReadDouble() + if err != nil { + return emptyValue, err + } + case TypeBoolean: + b, err := vr.ReadBoolean() + if err != nil { + return emptyValue, err + } + if b { + f = 1 + } + case TypeNull: + if err = vr.ReadNull(); err != nil { + return emptyValue, err + } + case TypeUndefined: + if err = vr.ReadUndefined(); err != nil { + return emptyValue, err + } + default: + return emptyValue, fmt.Errorf("cannot decode %v into a float32 or float64 type", vrType) + } + + switch t.Kind() { + case reflect.Float32: + if !nc.truncate && float64(float32(f)) != f { + return emptyValue, errCannotTruncate + } + + return reflect.ValueOf(float32(f)), nil + case reflect.Float64: + return reflect.ValueOf(f), nil + + default: + return emptyValue, ValueDecoderError{ + Name: "NumDecodeValue", + Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, + Received: reflect.Zero(t), + } + } +} + +func (nc *numCodec) decodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) { + switch t.Kind() { + case reflect.Float32, reflect.Float64: + return nc.decodeTypeFloat(vr, t) + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + return nc.decodeTypeInt(vr, t) + default: + return emptyValue, ValueDecoderError{ + Name: "NumDecodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -184,12 +273,13 @@ func (ic *intCodec) decodeType(_ DecoderRegistry, vr ValueReader, t reflect.Type } } -// DecodeValue is the ValueDecoder for uint types. -func (ic *intCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { +// DecodeValue is the ValueDecoder for numeric types. +func (nc *numCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect.Value) error { if !val.CanSet() { return ValueDecoderError{ - Name: "IntDecodeValue", + Name: "NumDecodeValue", Kinds: []reflect.Kind{ + reflect.Float32, reflect.Float64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, @@ -197,7 +287,7 @@ func (ic *intCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val reflect } } - elem, err := ic.decodeType(reg, vr, val.Type()) + elem, err := nc.decodeType(reg, vr, val.Type()) if err != nil { return err } diff --git a/bson/registry.go b/bson/registry.go index 5eaa2fecee..6deb66c6a3 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -227,13 +227,14 @@ func (rb *RegistryBuilder) Build() *Registry { codecTypeMap: make(map[reflect.Type][]interface{}), } - encoderCache := make(map[reflect.Value]ValueEncoder) + codecCache := make(map[reflect.Value]interface{}) + getEncoder := func(encFac EncoderFactory) ValueEncoder { - if enc, ok := encoderCache[reflect.ValueOf(encFac)]; ok { - return enc + if enc, ok := codecCache[reflect.ValueOf(encFac)]; ok { + return enc.(ValueEncoder) } encoder := encFac() - encoderCache[reflect.ValueOf(encFac)] = encoder + codecCache[reflect.ValueOf(encFac)] = encoder t := reflect.ValueOf(encoder).Type() r.codecTypeMap[t] = append(r.codecTypeMap[t], encoder) return encoder @@ -254,13 +255,12 @@ func (rb *RegistryBuilder) Build() *Registry { r.kindEncoders[i] = encoder } - decoderCache := make(map[reflect.Value]ValueDecoder) getDecoder := func(decFac DecoderFactory) ValueDecoder { - if dec, ok := decoderCache[reflect.ValueOf(decFac)]; ok { - return dec + if dec, ok := codecCache[reflect.ValueOf(decFac)]; ok { + return dec.(ValueDecoder) } decoder := decFac() - decoderCache[reflect.ValueOf(decFac)] = decoder + codecCache[reflect.ValueOf(decFac)] = decoder t := reflect.ValueOf(decoder).Type() r.codecTypeMap[t] = append(r.codecTypeMap[t], decoder) return decoder @@ -333,6 +333,18 @@ type Registry struct { codecTypeMap map[reflect.Type][]interface{} } +// SetCodecOptions configures Registry using a *RegistryOpt. +func (r *Registry) SetCodecOptions(opts ...*RegistryOpt) { + for _, opt := range opts { + v, ok := r.codecTypeMap[opt.typ] + if ok && v != nil { + for i := range v { + _ = opt.fn.Call([]reflect.Value{reflect.ValueOf(v[i])}) + } + } + } +} + // LookupEncoder returns the first matching encoder in the Registry. It uses the following lookup // order: // diff --git a/bson/registry_option.go b/bson/registry_option.go new file mode 100644 index 0000000000..b655b1af47 --- /dev/null +++ b/bson/registry_option.go @@ -0,0 +1,139 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bson + +import ( + "reflect" +) + +// RegistryOpt is used to configure a Registry. +type RegistryOpt struct { + typ reflect.Type + fn reflect.Value +} + +// NewRegistryOpt creates a *RegistryOpt from a setter function. +// For example: +// +// opt := NewRegistryOpt(func(c *Codec) { +// c.attr = value +// }) +// +// reg := NewRegistryBuilder().Build() +// reg.SetCodecOptions(opt) +// +// The "attr" field in the registered Codec can be set to "value". +func NewRegistryOpt[T any](fn func(T)) *RegistryOpt { + var zero [0]T + return &RegistryOpt{ + typ: reflect.TypeOf(zero).Elem(), + fn: reflect.ValueOf(fn), + } +} + +// NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values +// instead of BSON null. +var NilByteSliceAsEmpty = NewRegistryOpt(func(c *byteSliceCodec) { + c.encodeNilAsEmpty = true +}) + +// BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or +// "Old" BSON binary subtype as a Go byte slice instead of a primitive.Binary. +var BinaryAsSlice = NewRegistryOpt(func(c *emptyInterfaceCodec) { + c.decodeBinaryAsSlice = true +}) + +// DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This +// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". +var DefaultDocumentM = NewRegistryOpt(func(c *emptyInterfaceCodec) { + c.defaultDocumentType = reflect.TypeOf(M{}) +}) + +// DefaultDocumentD causes the Decoder to always unmarshal documents into the primitive.D type. This +// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". +var DefaultDocumentD = NewRegistryOpt(func(c *emptyInterfaceCodec) { + c.defaultDocumentType = reflect.TypeOf(D{}) +}) + +// NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON +// null. +var NilMapAsEmpty = NewRegistryOpt(func(c *mapCodec) { + c.encodeNilAsEmpty = true +}) + +// StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name +// strings using fmt.Sprint instead of the default string conversion logic. +var StringifyMapKeysWithFmt = NewRegistryOpt(func(c *mapCodec) { + c.encodeKeysWithStringer = true +}) + +// ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value +// passed to Decode before unmarshaling BSON documents into them. +var ZeroMaps = NewRegistryOpt(func(c *mapCodec) { + c.decodeZerosMap = true +}) + +// AllowTruncatingDoubles causes the Decoder to truncate the fractional part of BSON "double" values +// when attempting to unmarshal them into a Go integer (int, int8, int16, int32, or int64) struct +// field. The truncation logic does not apply to BSON "decimal128" values. +var AllowTruncatingDoubles = NewRegistryOpt(func(c *numCodec) { + c.truncate = true +}) + +// IntMinSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, uint, +// uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) that can +// represent the integer value. +var IntMinSize = NewRegistryOpt(func(c *numCodec) { + c.minSize = true +}) + +// NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON +// null. +var NilSliceAsEmpty = NewRegistryOpt(func(c *sliceCodec) { + c.encodeNilAsEmpty = true +}) + +// DecodeObjectIDAsHex causes the Decoder to unmarshal BSON ObjectID as a hexadecimal string. +var DecodeObjectIDAsHex = NewRegistryOpt(func(c *stringCodec) { + c.decodeObjectIDAsHex = true +}) + +// ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in +// the marshaled BSON when the "inline" struct tag option is set. +var ErrorOnInlineDuplicates = NewRegistryOpt(func(c *structCodec) { + c.overwriteDuplicatedInlinedFields = false +}) + +// TODO(GODRIVER-2820): Update the description to remove the note about only examining exported +// TODO struct fields once the logic is updated to also inspect private struct fields. + +// OmitZeroStruct causes the Encoder to consider the zero value for a struct (e.g. MyStruct{}) +// as empty and omit it from the marshaled BSON when the "omitempty" struct tag option is set. +// +// Note that the Encoder only examines exported struct fields when determining if a struct is the +// zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty. +var OmitZeroStruct = NewRegistryOpt(func(c *structCodec) { + c.encodeOmitDefaultStruct = true +}) + +// UseJSONStructTags causes the Encoder and Decoder to fall back to using the "json" struct tag if +// a "bson" struct tag is not specified. +var UseJSONStructTags = NewRegistryOpt(func(c *structCodec) { + c.useJSONStructTags = true +}) + +// ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination +// value passed to Decode before unmarshaling BSON documents into them. +var ZeroStructs = NewRegistryOpt(func(c *structCodec) { + c.decodeZeroStruct = true +}) + +// UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead +// of the UTC timezone. +var UseLocalTimeZone = NewRegistryOpt(func(c *timeCodec) { + c.useLocalTimeZone = true +}) diff --git a/bson/struct_codec.go b/bson/struct_codec.go index 8db0f06dfa..dc42b31150 100644 --- a/bson/struct_codec.go +++ b/bson/struct_codec.go @@ -95,8 +95,8 @@ func (r *localEncoderRegistry) LookupEncoder(t reflect.Type) (ValueEncoder, erro return ve, err } if r.minSize { - if ic, ok := ve.(*intCodec); ok { - ve = &intCodec{ + if ic, ok := ve.(*numCodec); ok { + ve = &numCodec{ minSize: true, truncate: ic.truncate, } @@ -309,6 +309,9 @@ func (sc *structCodec) DecodeValue(reg DecoderRegistry, vr ValueReader, val refl var elem reflect.Value if elemT := inlineT.Elem(); elemT == tEmpty { elem, err = decodeTypeOrValueWithInfo(decoder, reg, vr, inlineT) + if elem.Type() != elemT { + elem = elem.Convert(elemT) + } } else { elem = reflect.New(elemT).Elem() err = decoder.DecodeValue(reg, vr, elem) diff --git a/bson/truncation_test.go b/bson/truncation_test.go index d25621ea92..311b2942d4 100644 --- a/bson/truncation_test.go +++ b/bson/truncation_test.go @@ -33,17 +33,18 @@ func TestTruncation(t *testing.T) { buf := new(bytes.Buffer) vw := NewValueWriter(buf) enc := NewEncoderWithRegistry(NewRegistryBuilder().Build(), vw) - enc.IntMinSize() + enc.SetBehavior(IntMinSize) err := enc.Encode(&input) assert.Nil(t, err) var output outputArgs - // dc := DecodeContext{ - // Registry: DefaultRegistry, - // truncate: true, - // } + opt := NewRegistryOpt(func(c *numCodec) { + c.truncate = true + }) + reg := NewRegistryBuilder().Build() + reg.SetCodecOptions(opt) - err = UnmarshalWithContext(NewRegistryBuilder().Build(), buf.Bytes(), &output) + err = UnmarshalWithContext(reg, buf.Bytes(), &output) assert.Nil(t, err) assert.Equal(t, inputName, output.Name) @@ -58,18 +59,19 @@ func TestTruncation(t *testing.T) { buf := new(bytes.Buffer) vw := NewValueWriter(buf) enc := NewEncoderWithRegistry(NewRegistryBuilder().Build(), vw) - enc.IntMinSize() + enc.SetBehavior(IntMinSize) err := enc.Encode(&input) assert.Nil(t, err) var output outputArgs - // dc := DecodeContext{ - // Registry: DefaultRegistry, - // truncate: false, - // } + opt := NewRegistryOpt(func(c *numCodec) { + c.truncate = false + }) + reg := NewRegistryBuilder().Build() + reg.SetCodecOptions(opt) // case throws an error when truncation is disabled - err = UnmarshalWithContext(NewRegistryBuilder().Build(), buf.Bytes(), &output) + err = UnmarshalWithContext(reg, buf.Bytes(), &output) assert.NotNil(t, err) }) } diff --git a/bson/unmarshal.go b/bson/unmarshal.go index 696ab8c727..371d2dfc3d 100644 --- a/bson/unmarshal.go +++ b/bson/unmarshal.go @@ -45,18 +45,17 @@ func Unmarshal(data []byte, val interface{}) error { // stores the result in the value pointed to by val. If val is nil or not // a pointer, UnmarshalWithRegistry returns InvalidUnmarshalError. // -// Deprecated: Use [NewDecoder] and specify the Registry by calling [Decoder.SetRegistry] instead: +// Deprecated: Use [NewDecoderWithRegistry] instead: // -// dec, err := bson.NewDecoder(NewBSONDocumentReader(data)) +// dec, err := bson.NewDecoderWithRegistry(reg, NewBSONDocumentReader(data)) // if err != nil { // panic(err) // } -// dec.SetRegistry(reg) // // See [Decoder] for more examples. -func UnmarshalWithRegistry(r *Registry, data []byte, val interface{}) error { +func UnmarshalWithRegistry(reg *Registry, data []byte, val interface{}) error { vr := NewValueReader(data) - return unmarshalFromReader(r, vr, val) + return NewDecoderWithRegistry(reg, vr).Decode(val) } // UnmarshalWithContext parses the BSON-encoded data using DecodeContext dc and @@ -75,7 +74,7 @@ func UnmarshalWithRegistry(r *Registry, data []byte, val interface{}) error { // See [Decoder] for more examples. func UnmarshalWithContext(reg *Registry, data []byte, val interface{}) error { vr := NewValueReader(data) - return unmarshalFromReader(reg, vr, val) + return NewDecoderWithRegistry(reg, vr).Decode(val) } // UnmarshalValue parses the BSON value of type t with default registry and @@ -91,9 +90,9 @@ func UnmarshalValue(t Type, data []byte, val interface{}) error { // // Deprecated: Using a custom registry to unmarshal individual BSON values will not be supported in // Go Driver 2.0. -func UnmarshalValueWithRegistry(r *Registry, t Type, data []byte, val interface{}) error { +func UnmarshalValueWithRegistry(reg *Registry, t Type, data []byte, val interface{}) error { vr := NewBSONValueReader(t, data) - return unmarshalFromReader(r, vr, val) + return NewDecoderWithRegistry(reg, vr).Decode(val) } // UnmarshalExtJSON parses the extended JSON-encoded data and stores the result @@ -120,13 +119,13 @@ func UnmarshalExtJSON(data []byte, canonical bool, val interface{}) error { // dec.SetRegistry(reg) // // See [Decoder] for more examples. -func UnmarshalExtJSONWithRegistry(r *Registry, data []byte, canonical bool, val interface{}) error { - ejvr, err := NewExtJSONValueReader(bytes.NewReader(data), canonical) +func UnmarshalExtJSONWithRegistry(reg *Registry, data []byte, canonical bool, val interface{}) error { + vr, err := NewExtJSONValueReader(bytes.NewReader(data), canonical) if err != nil { return err } - return unmarshalFromReader(r, ejvr, val) + return NewDecoderWithRegistry(reg, vr).Decode(val) } // UnmarshalExtJSONWithContext parses the extended JSON-encoded data using @@ -148,14 +147,10 @@ func UnmarshalExtJSONWithRegistry(r *Registry, data []byte, canonical bool, val // // See [Decoder] for more examples. func UnmarshalExtJSONWithContext(reg *Registry, data []byte, canonical bool, val interface{}) error { - ejvr, err := NewExtJSONValueReader(bytes.NewReader(data), canonical) + vr, err := NewExtJSONValueReader(bytes.NewReader(data), canonical) if err != nil { return err } - return unmarshalFromReader(reg, ejvr, val) -} - -func unmarshalFromReader(reg *Registry, vr ValueReader, val interface{}) error { return NewDecoderWithRegistry(reg, vr).Decode(val) } diff --git a/mongo/cursor.go b/mongo/cursor.go index 577927dfac..703aebf90b 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -245,28 +245,28 @@ func getDecoder( if opts != nil { if opts.AllowTruncatingDoubles { - dec.AllowTruncatingDoubles() + dec.SetBehavior(bson.AllowTruncatingDoubles) } if opts.BinaryAsSlice { - dec.BinaryAsSlice() + dec.SetBehavior(bson.BinaryAsSlice) } if opts.DefaultDocumentD { - dec.DefaultDocumentD() + dec.SetBehavior(bson.DefaultDocumentD) } if opts.DefaultDocumentM { - dec.DefaultDocumentM() + dec.SetBehavior(bson.DefaultDocumentM) } if opts.UseJSONStructTags { - dec.UseJSONStructTags() + dec.SetBehavior(bson.UseJSONStructTags) } if opts.UseLocalTimeZone { - dec.UseLocalTimeZone() + dec.SetBehavior(bson.UseLocalTimeZone) } if opts.ZeroMaps { - dec.ZeroMaps() + dec.SetBehavior(bson.ZeroMaps) } if opts.ZeroStructs { - dec.ZeroStructs() + dec.SetBehavior(bson.ZeroStructs) } } diff --git a/mongo/mongo.go b/mongo/mongo.go index d102c05b66..1112a297db 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -72,28 +72,28 @@ func getEncoder( if opts != nil { if opts.ErrorOnInlineDuplicates { - enc.ErrorOnInlineDuplicates() + enc.SetBehavior(bson.ErrorOnInlineDuplicates) } if opts.IntMinSize { - enc.IntMinSize() + enc.SetBehavior(bson.IntMinSize) } if opts.NilByteSliceAsEmpty { - enc.NilByteSliceAsEmpty() + enc.SetBehavior(bson.NilByteSliceAsEmpty) } if opts.NilMapAsEmpty { - enc.NilMapAsEmpty() + enc.SetBehavior(bson.NilMapAsEmpty) } if opts.NilSliceAsEmpty { - enc.NilSliceAsEmpty() + enc.SetBehavior(bson.NilSliceAsEmpty) } if opts.OmitZeroStruct { - enc.OmitZeroStruct() + enc.SetBehavior(bson.OmitZeroStruct) } if opts.StringifyMapKeysWithFmt { - enc.StringifyMapKeysWithFmt() + enc.SetBehavior(bson.StringifyMapKeysWithFmt) } if opts.UseJSONStructTags { - enc.UseJSONStructTags() + enc.SetBehavior(bson.UseJSONStructTags) } } @@ -154,10 +154,10 @@ func ensureID( doc bsoncore.Document, oid bson.ObjectID, bsonOpts *options.BSONOptions, - reg *bson.Registry, + registry *bson.Registry, ) (bsoncore.Document, interface{}, error) { - if reg == nil { - reg = bson.NewRegistryBuilder().Build() + if registry == nil { + registry = bson.NewRegistryBuilder().Build() } // Try to find the "_id" element. If it exists, try to unmarshal just the @@ -167,7 +167,7 @@ func ensureID( var id struct { ID interface{} `bson:"_id"` } - dec := getDecoder(doc, bsonOpts, reg) + dec := getDecoder(doc, bsonOpts, registry) err = dec.Decode(&id) if err != nil { return nil, nil, fmt.Errorf("error unmarshaling BSON document: %w", err) From 7f7a86b922cfdbb9b31d5f3b0a9aa4a0494a3455 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Fri, 24 May 2024 11:50:03 -0400 Subject: [PATCH 13/15] WIP --- bson/decoder.go | 6 +- bson/decoder_example_test.go | 10 ++- bson/decoder_test.go | 26 ++++--- bson/encoder.go | 6 +- bson/encoder_example_test.go | 21 ++++-- bson/encoder_test.go | 16 ++-- bson/registry.go | 19 +++-- bson/registry_option.go | 50 ++++++++----- bson/struct_codec.go | 65 ++++++++-------- bson/struct_tag_parser.go | 133 ++++++++++++++++++++------------- bson/struct_tag_parser_test.go | 128 ++++++++++++++++++++----------- bson/truncation_test.go | 22 ++++-- mongo/change_stream.go | 5 +- mongo/cursor.go | 38 +++++++--- mongo/mongo.go | 29 ++++--- mongo/single_result.go | 5 +- 16 files changed, 361 insertions(+), 218 deletions(-) diff --git a/bson/decoder.go b/bson/decoder.go index dd37150f60..7b0d0d68bd 100644 --- a/bson/decoder.go +++ b/bson/decoder.go @@ -18,7 +18,7 @@ var ErrDecodeToNil = errors.New("cannot Decode to nil value") // ConfigurableDecoderRegistry refers a DecoderRegistry that is configurable with *RegistryOpt. type ConfigurableDecoderRegistry interface { DecoderRegistry - SetCodecOptions(opts ...*RegistryOpt) + SetCodecOption(opt *RegistryOpt) error } // A Decoder reads and decodes BSON documents from a stream. It reads from a ValueReader as @@ -82,6 +82,6 @@ func (d *Decoder) Decode(val interface{}) error { } // SetBehavior set the decoder behavior with *RegistryOpt. -func (d *Decoder) SetBehavior(opts ...*RegistryOpt) { - d.reg.SetCodecOptions(opts...) +func (d *Decoder) SetBehavior(opt *RegistryOpt) error { + return d.reg.SetCodecOption(opt) } diff --git a/bson/decoder_example_test.go b/bson/decoder_example_test.go index 60fb360710..590756090d 100644 --- a/bson/decoder_example_test.go +++ b/bson/decoder_example_test.go @@ -77,7 +77,10 @@ func ExampleDecoder_SetBehavior_defaultDocumentM() { // type if the decode destination has no type information. The Properties // field in the City struct will be decoded as a "M" (i.e. map) instead // of the default "D". - decoder.SetBehavior(bson.DefaultDocumentM) + err = decoder.SetBehavior(bson.DefaultDocumentM) + if err != nil { + panic(err) + } var res City err = decoder.Decode(&res) @@ -114,7 +117,10 @@ func ExampleDecoder_SetBehavior_useJSONStructTags() { // Configure the Decoder to use "json" struct tags when decoding if "bson" // struct tags are not present. - decoder.SetBehavior(bson.UseJSONStructTags) + err = decoder.SetBehavior(bson.UseJSONStructTags) + if err != nil { + panic(err) + } var res Product err = decoder.Decode(&res) diff --git a/bson/decoder_test.go b/bson/decoder_test.go index 1c884c6d56..6ff2ad7545 100644 --- a/bson/decoder_test.go +++ b/bson/decoder_test.go @@ -253,7 +253,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "AllowTruncatingDoubles", configure: func(dec *Decoder) { - dec.SetBehavior(AllowTruncatingDoubles) + _ = dec.SetBehavior(AllowTruncatingDoubles) }, input: bsoncore.NewDocumentBuilder(). AppendDouble("myInt", 1.999). @@ -286,7 +286,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "BinaryAsSlice", configure: func(dec *Decoder) { - dec.SetBehavior(BinaryAsSlice) + _ = dec.SetBehavior(BinaryAsSlice) }, input: bsoncore.NewDocumentBuilder(). AppendBinary("myBinary", TypeBinaryGeneric, []byte{}). @@ -299,7 +299,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "DefaultDocumentD nested", configure: func(dec *Decoder) { - dec.SetBehavior(DefaultDocumentD) + _ = dec.SetBehavior(DefaultDocumentD) }, input: bsoncore.NewDocumentBuilder(). AppendDocument("myDocument", bsoncore.NewDocumentBuilder(). @@ -316,7 +316,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "DefaultDocumentM nested", configure: func(dec *Decoder) { - dec.SetBehavior(DefaultDocumentM) + _ = dec.SetBehavior(DefaultDocumentM) }, input: bsoncore.NewDocumentBuilder(). AppendDocument("myDocument", bsoncore.NewDocumentBuilder(). @@ -333,7 +333,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "UseJSONStructTags", configure: func(dec *Decoder) { - dec.SetBehavior(UseJSONStructTags) + _ = dec.SetBehavior(UseJSONStructTags) }, input: bsoncore.NewDocumentBuilder(). AppendString("jsonFieldName", "test value"). @@ -346,7 +346,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "UseLocalTimeZone", configure: func(dec *Decoder) { - dec.SetBehavior(UseLocalTimeZone) + _ = dec.SetBehavior(UseLocalTimeZone) }, input: bsoncore.NewDocumentBuilder(). AppendDateTime("myTime", 1684349179939). @@ -359,7 +359,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "ZeroMaps", configure: func(dec *Decoder) { - dec.SetBehavior(ZeroMaps) + _ = dec.SetBehavior(ZeroMaps) }, input: bsoncore.NewDocumentBuilder(). AppendDocument("myMap", bsoncore.NewDocumentBuilder(). @@ -376,7 +376,7 @@ func TestDecoderConfiguration(t *testing.T) { { description: "ZeroStructs", configure: func(dec *Decoder) { - dec.SetBehavior(ZeroStructs) + _ = dec.SetBehavior(ZeroStructs) }, input: bsoncore.NewDocumentBuilder(). AppendString("myString", "test value"). @@ -417,10 +417,11 @@ func TestDecoderConfiguration(t *testing.T) { dec := NewDecoder(NewValueReader(input)) - dec.SetBehavior(DefaultDocumentM) + err := dec.SetBehavior(DefaultDocumentM) + require.NoError(t, err, "SetBehavior error") var got interface{} - err := dec.Decode(&got) + err = dec.Decode(&got) require.NoError(t, err, "Decode error") want := M{ @@ -441,10 +442,11 @@ func TestDecoderConfiguration(t *testing.T) { dec := NewDecoder(NewValueReader(input)) - dec.SetBehavior(DefaultDocumentD) + err := dec.SetBehavior(DefaultDocumentD) + require.NoError(t, err, "SetBehavior error") var got interface{} - err := dec.Decode(&got) + err = dec.Decode(&got) require.NoError(t, err, "Decode error") want := D{ diff --git a/bson/encoder.go b/bson/encoder.go index 53db25e5c1..1317bee79e 100644 --- a/bson/encoder.go +++ b/bson/encoder.go @@ -13,7 +13,7 @@ import ( // ConfigurableEncoderRegistry refers a EncoderRegistry that is configurable with *RegistryOpt. type ConfigurableEncoderRegistry interface { EncoderRegistry - SetCodecOptions(opts ...*RegistryOpt) + SetCodecOption(opt *RegistryOpt) error } // An Encoder writes a serialization format to an output stream. It writes to a ValueWriter @@ -61,6 +61,6 @@ func (e *Encoder) Encode(val interface{}) error { } // SetBehavior set the encoder behavior with *RegistryOpt. -func (e *Encoder) SetBehavior(opts ...*RegistryOpt) { - e.reg.SetCodecOptions(opts...) +func (e *Encoder) SetBehavior(opt *RegistryOpt) error { + return e.reg.SetCodecOption(opt) } diff --git a/bson/encoder_example_test.go b/bson/encoder_example_test.go index f56249b7ef..60e85fabd3 100644 --- a/bson/encoder_example_test.go +++ b/bson/encoder_example_test.go @@ -64,9 +64,12 @@ func ExampleEncoder_SetBehavior_intMinSize() { vw := bson.NewValueWriter(buf) enc := bson.NewEncoder(vw) - enc.SetBehavior(bson.IntMinSize) + err := enc.SetBehavior(bson.IntMinSize) + if err != nil { + panic(err) + } - err := enc.Encode(foo{2}) + err = enc.Encode(foo{2}) if err != nil { panic(err) } @@ -84,14 +87,17 @@ func ExampleEncoder_SetBehavior_stringifyMapKeysWithFmt() { // Configure the Encoder to convert Go map keys to BSON document field names // using fmt.Sprintf instead of the default string conversion logic. - encoder.SetBehavior(bson.StringifyMapKeysWithFmt) + err := encoder.SetBehavior(bson.StringifyMapKeysWithFmt) + if err != nil { + panic(err) + } // Use the Encoder to marshal a BSON document that contains is a map of // city and state to a list of zip codes in that city. zipCodes := map[CityState][]int{ {City: "New York", State: "NY"}: {10001, 10301, 10451}, } - err := encoder.Encode(zipCodes) + err = encoder.Encode(zipCodes) if err != nil { panic(err) } @@ -115,7 +121,10 @@ func ExampleEncoder_SetBehavior_useJSONStructTags() { // Configure the Encoder to use "json" struct tags when decoding if "bson" // struct tags are not present. - encoder.SetBehavior(bson.UseJSONStructTags) + err := encoder.SetBehavior(bson.UseJSONStructTags) + if err != nil { + panic(err) + } // Use the Encoder to marshal a BSON document that contains the name, SKU, // and price (in cents) of a product. @@ -124,7 +133,7 @@ func ExampleEncoder_SetBehavior_useJSONStructTags() { SKU: "AB12345", Price: 399, } - err := encoder.Encode(product) + err = encoder.Encode(product) if err != nil { panic(err) } diff --git a/bson/encoder_test.go b/bson/encoder_test.go index 1110fc7d8d..a9f7376b5a 100644 --- a/bson/encoder_test.go +++ b/bson/encoder_test.go @@ -160,7 +160,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "ErrorOnInlineDuplicates", configure: func(enc *Encoder) { - enc.SetBehavior(ErrorOnInlineDuplicates) + _ = enc.SetBehavior(ErrorOnInlineDuplicates) }, input: inlineDuplicateOuter{ Inline: inlineDuplicateInner{Duplicate: "inner"}, @@ -173,7 +173,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "IntMinSize", configure: func(enc *Encoder) { - enc.SetBehavior(IntMinSize) + _ = enc.SetBehavior(IntMinSize) }, input: D{ {Key: "myInt", Value: int(1)}, @@ -194,7 +194,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "StringifyMapKeysWithFmt", configure: func(enc *Encoder) { - enc.SetBehavior(StringifyMapKeysWithFmt) + _ = enc.SetBehavior(StringifyMapKeysWithFmt) }, input: map[stringerTest]string{ {}: "test value", @@ -207,7 +207,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "NilMapAsEmpty", configure: func(enc *Encoder) { - enc.SetBehavior(NilMapAsEmpty) + _ = enc.SetBehavior(NilMapAsEmpty) }, input: D{{Key: "myMap", Value: map[string]string(nil)}}, want: bsoncore.NewDocumentBuilder(). @@ -218,7 +218,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "NilSliceAsEmpty", configure: func(enc *Encoder) { - enc.SetBehavior(NilSliceAsEmpty) + _ = enc.SetBehavior(NilSliceAsEmpty) }, input: D{{Key: "mySlice", Value: []string(nil)}}, want: bsoncore.NewDocumentBuilder(). @@ -229,7 +229,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "NilByteSliceAsEmpty", configure: func(enc *Encoder) { - enc.SetBehavior(NilByteSliceAsEmpty) + _ = enc.SetBehavior(NilByteSliceAsEmpty) }, input: D{{Key: "myBytes", Value: []byte(nil)}}, want: bsoncore.NewDocumentBuilder(). @@ -241,7 +241,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "OmitZeroStruct", configure: func(enc *Encoder) { - enc.SetBehavior(OmitZeroStruct) + _ = enc.SetBehavior(OmitZeroStruct) }, input: struct { Zero zeroStruct `bson:",omitempty"` @@ -253,7 +253,7 @@ func TestEncoderConfiguration(t *testing.T) { { description: "UseJSONStructTags", configure: func(enc *Encoder) { - enc.SetBehavior(UseJSONStructTags) + _ = enc.SetBehavior(UseJSONStructTags) }, input: struct { StructFieldName string `json:"jsonFieldName"` diff --git a/bson/registry.go b/bson/registry.go index 6deb66c6a3..cf88c703ba 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -333,16 +333,21 @@ type Registry struct { codecTypeMap map[reflect.Type][]interface{} } -// SetCodecOptions configures Registry using a *RegistryOpt. -func (r *Registry) SetCodecOptions(opts ...*RegistryOpt) { - for _, opt := range opts { - v, ok := r.codecTypeMap[opt.typ] - if ok && v != nil { - for i := range v { - _ = opt.fn.Call([]reflect.Value{reflect.ValueOf(v[i])}) +// SetCodecOption configures Registry using a *RegistryOpt. +func (r *Registry) SetCodecOption(opt *RegistryOpt) error { + v, ok := r.codecTypeMap[opt.typ] + if !ok || len(v) == 0 { + return fmt.Errorf("could not find codec %s", opt.typ.String()) + } + for i := range v { + rtns := opt.fn.Call([]reflect.Value{reflect.ValueOf(v[i])}) + for _, r := range rtns { + if !r.IsNil() { + return r.Interface().(error) } } } + return nil } // LookupEncoder returns the first matching encoder in the Registry. It uses the following lookup diff --git a/bson/registry_option.go b/bson/registry_option.go index b655b1af47..5f07052f5c 100644 --- a/bson/registry_option.go +++ b/bson/registry_option.go @@ -27,7 +27,7 @@ type RegistryOpt struct { // reg.SetCodecOptions(opt) // // The "attr" field in the registered Codec can be set to "value". -func NewRegistryOpt[T any](fn func(T)) *RegistryOpt { +func NewRegistryOpt[T any](fn func(T) error) *RegistryOpt { var zero [0]T return &RegistryOpt{ typ: reflect.TypeOf(zero).Elem(), @@ -37,75 +37,87 @@ func NewRegistryOpt[T any](fn func(T)) *RegistryOpt { // NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values // instead of BSON null. -var NilByteSliceAsEmpty = NewRegistryOpt(func(c *byteSliceCodec) { +var NilByteSliceAsEmpty = NewRegistryOpt(func(c *byteSliceCodec) error { c.encodeNilAsEmpty = true + return nil }) // BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or // "Old" BSON binary subtype as a Go byte slice instead of a primitive.Binary. -var BinaryAsSlice = NewRegistryOpt(func(c *emptyInterfaceCodec) { +var BinaryAsSlice = NewRegistryOpt(func(c *emptyInterfaceCodec) error { c.decodeBinaryAsSlice = true + return nil }) // DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This // behavior is restricted to data typed as "interface{}" or "map[string]interface{}". -var DefaultDocumentM = NewRegistryOpt(func(c *emptyInterfaceCodec) { +var DefaultDocumentM = NewRegistryOpt(func(c *emptyInterfaceCodec) error { c.defaultDocumentType = reflect.TypeOf(M{}) + return nil }) // DefaultDocumentD causes the Decoder to always unmarshal documents into the primitive.D type. This // behavior is restricted to data typed as "interface{}" or "map[string]interface{}". -var DefaultDocumentD = NewRegistryOpt(func(c *emptyInterfaceCodec) { +var DefaultDocumentD = NewRegistryOpt(func(c *emptyInterfaceCodec) error { c.defaultDocumentType = reflect.TypeOf(D{}) + return nil }) // NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON // null. -var NilMapAsEmpty = NewRegistryOpt(func(c *mapCodec) { +var NilMapAsEmpty = NewRegistryOpt(func(c *mapCodec) error { c.encodeNilAsEmpty = true + return nil }) // StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name // strings using fmt.Sprint instead of the default string conversion logic. -var StringifyMapKeysWithFmt = NewRegistryOpt(func(c *mapCodec) { +var StringifyMapKeysWithFmt = NewRegistryOpt(func(c *mapCodec) error { c.encodeKeysWithStringer = true + return nil }) // ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value // passed to Decode before unmarshaling BSON documents into them. -var ZeroMaps = NewRegistryOpt(func(c *mapCodec) { +var ZeroMaps = NewRegistryOpt(func(c *mapCodec) error { c.decodeZerosMap = true + return nil }) // AllowTruncatingDoubles causes the Decoder to truncate the fractional part of BSON "double" values // when attempting to unmarshal them into a Go integer (int, int8, int16, int32, or int64) struct // field. The truncation logic does not apply to BSON "decimal128" values. -var AllowTruncatingDoubles = NewRegistryOpt(func(c *numCodec) { +var AllowTruncatingDoubles = NewRegistryOpt(func(c *numCodec) error { c.truncate = true + return nil }) // IntMinSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, uint, // uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) that can // represent the integer value. -var IntMinSize = NewRegistryOpt(func(c *numCodec) { +var IntMinSize = NewRegistryOpt(func(c *numCodec) error { c.minSize = true + return nil }) // NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON // null. -var NilSliceAsEmpty = NewRegistryOpt(func(c *sliceCodec) { +var NilSliceAsEmpty = NewRegistryOpt(func(c *sliceCodec) error { c.encodeNilAsEmpty = true + return nil }) // DecodeObjectIDAsHex causes the Decoder to unmarshal BSON ObjectID as a hexadecimal string. -var DecodeObjectIDAsHex = NewRegistryOpt(func(c *stringCodec) { +var DecodeObjectIDAsHex = NewRegistryOpt(func(c *stringCodec) error { c.decodeObjectIDAsHex = true + return nil }) // ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in // the marshaled BSON when the "inline" struct tag option is set. -var ErrorOnInlineDuplicates = NewRegistryOpt(func(c *structCodec) { +var ErrorOnInlineDuplicates = NewRegistryOpt(func(c *structCodec) error { c.overwriteDuplicatedInlinedFields = false + return nil }) // TODO(GODRIVER-2820): Update the description to remove the note about only examining exported @@ -116,24 +128,28 @@ var ErrorOnInlineDuplicates = NewRegistryOpt(func(c *structCodec) { // // Note that the Encoder only examines exported struct fields when determining if a struct is the // zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty. -var OmitZeroStruct = NewRegistryOpt(func(c *structCodec) { +var OmitZeroStruct = NewRegistryOpt(func(c *structCodec) error { c.encodeOmitDefaultStruct = true + return nil }) // UseJSONStructTags causes the Encoder and Decoder to fall back to using the "json" struct tag if // a "bson" struct tag is not specified. -var UseJSONStructTags = NewRegistryOpt(func(c *structCodec) { +var UseJSONStructTags = NewRegistryOpt(func(c *structCodec) error { c.useJSONStructTags = true + return nil }) // ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination // value passed to Decode before unmarshaling BSON documents into them. -var ZeroStructs = NewRegistryOpt(func(c *structCodec) { +var ZeroStructs = NewRegistryOpt(func(c *structCodec) error { c.decodeZeroStruct = true + return nil }) // UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead // of the UTC timezone. -var UseLocalTimeZone = NewRegistryOpt(func(c *timeCodec) { +var UseLocalTimeZone = NewRegistryOpt(func(c *timeCodec) error { c.useLocalTimeZone = true + return nil }) diff --git a/bson/struct_codec.go b/bson/struct_codec.go index dc42b31150..24ea1a2018 100644 --- a/bson/struct_codec.go +++ b/bson/struct_codec.go @@ -50,7 +50,7 @@ func (de *DecodeError) Keys() []string { // structCodec is the Codec used for struct values. type structCodec struct { cache sync.Map // map[reflect.Type]*structDescription - parser StructTagParser + parser structTagParser // decodeZeroStruct causes DecodeValue to delete any existing values from Go structs in the decodeZeroStruct bool @@ -76,7 +76,7 @@ type structCodec struct { } // newStructCodec returns a StructCodec that uses p for struct tag parsing. -func newStructCodec(p StructTagParser) *structCodec { +func newStructCodec(p structTagParser) *structCodec { return &structCodec{ parser: p, overwriteDuplicatedInlinedFields: true, @@ -84,25 +84,12 @@ func newStructCodec(p StructTagParser) *structCodec { } type localEncoderRegistry struct { - registry EncoderRegistry - - minSize bool + registry EncoderRegistry + encoderLookup func(EncoderRegistry, reflect.Type) (ValueEncoder, error) } func (r *localEncoderRegistry) LookupEncoder(t reflect.Type) (ValueEncoder, error) { - ve, err := r.registry.LookupEncoder(t) - if err != nil { - return ve, err - } - if r.minSize { - if ic, ok := ve.(*numCodec); ok { - ve = &numCodec{ - minSize: true, - truncate: ic.truncate, - } - } - } - return ve, nil + return r.encoderLookup(r.registry, t) } // EncodeValue handles encoding generic struct types. @@ -132,8 +119,8 @@ func (sc *structCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val refl } reg = &localEncoderRegistry{ - registry: reg, - minSize: desc.minSize, + registry: reg, + encoderLookup: desc.encoderLookup, } var encoder ValueEncoder @@ -395,14 +382,13 @@ type structDescription struct { } type fieldDescription struct { - name string // BSON key name - fieldName string // struct field name - idx int - omitEmpty bool - minSize bool - truncate bool - inline []int - fieldType reflect.Type + name string // BSON key name + fieldName string // struct field name + idx int + inline []int + omitEmpty bool + fieldType reflect.Type + encoderLookup func(EncoderRegistry, reflect.Type) (ValueEncoder, error) } type byIndex []fieldDescription @@ -484,14 +470,14 @@ func (sc *structCodec) describeStructSlow( fieldType: sfType, } - var stags StructTags + var stags *structTags var err error // If the caller requested that we use JSON struct tags, use the JSONFallbackStructTagParser // instead of the parser defined on the codec. if useJSONStructTags { - stags, err = JSONFallbackStructTagParser.ParseStructTags(sf) + stags, err = sc.parser.parseJSONStructTags(sf) } else { - stags, err = sc.parser.ParseStructTags(sf) + stags, err = sc.parser.parseStructTags(sf) } if err != nil { return nil, err @@ -501,8 +487,21 @@ func (sc *structCodec) describeStructSlow( } description.name = stags.Name description.omitEmpty = stags.OmitEmpty - description.minSize = stags.MinSize - description.truncate = stags.Truncate + description.encoderLookup = func(reg EncoderRegistry, t reflect.Type) (ValueEncoder, error) { + if stags.LookupEncoderOnMinSize != nil { + reg = &localEncoderRegistry{ + registry: reg, + encoderLookup: stags.LookupEncoderOnMinSize.LookupEncoder, + } + } + if stags.LookupEncoderOnTruncate != nil { + reg = &localEncoderRegistry{ + registry: reg, + encoderLookup: stags.LookupEncoderOnTruncate.LookupEncoder, + } + } + return reg.LookupEncoder(t) + } if stags.Inline { sd.inline = true diff --git a/bson/struct_tag_parser.go b/bson/struct_tag_parser.go index d116c14040..30b0e9815d 100644 --- a/bson/struct_tag_parser.go +++ b/bson/struct_tag_parser.go @@ -11,25 +11,57 @@ import ( "strings" ) -// StructTagParser returns the struct tags for a given struct field. -// -// Deprecated: Defining custom BSON struct tag parsers will not be supported in Go Driver 2.0. -type StructTagParser interface { - ParseStructTags(reflect.StructField) (StructTags, error) +// structTagParser returns the struct tags for a given reflect.StructField. +type structTagParser interface { + parseStructTags(reflect.StructField) (*structTags, error) + parseJSONStructTags(reflect.StructField) (*structTags, error) } -// StructTagParserFunc is an adapter that allows a generic function to be used -// as a StructTagParser. -// -// Deprecated: Defining custom BSON struct tag parsers will not be supported in Go Driver 2.0. -type StructTagParserFunc func(reflect.StructField) (StructTags, error) +// DefaultStructTagParser is the StructTagParser used by the StructCodec by default. +var DefaultStructTagParser = &StructTagParser{ + LookupEncoderOnMinSize: retrieverOnMinSize{}, + LookupEncoderOnTruncate: retrieverOnTruncate{}, +} + +type retrieverOnMinSize struct{} + +func (retrieverOnMinSize) LookupEncoder(reg EncoderRegistry, t reflect.Type) (ValueEncoder, error) { + enc, err := reg.LookupEncoder(t) + if err != nil { + return enc, err + } + switch t.Kind() { + case reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: + if codec, ok := enc.(*numCodec); ok { + c := *codec + c.minSize = true + return &c, nil + } + } + return enc, nil +} + +type retrieverOnTruncate struct{} -// ParseStructTags implements the StructTagParser interface. -func (stpf StructTagParserFunc) ParseStructTags(sf reflect.StructField) (StructTags, error) { - return stpf(sf) +func (retrieverOnTruncate) LookupEncoder(reg EncoderRegistry, t reflect.Type) (ValueEncoder, error) { + enc, err := reg.LookupEncoder(t) + if err != nil { + return enc, err + } + switch t.Kind() { + case reflect.Float32, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + if codec, ok := enc.(*numCodec); ok { + c := *codec + c.truncate = true + return &c, nil + } + } + return enc, nil } -// StructTags represents the struct tag fields that the StructCodec uses during +// structTags represents the struct tag fields that the StructCodec uses during // the encoding and decoding process. // // In the case of a struct, the lowercased field name is used as the key for each exported @@ -38,34 +70,37 @@ func (stpf StructTagParserFunc) ParseStructTags(sf reflect.StructField) (StructT // // The properties are defined below: // -// OmitEmpty Only include the field if it's not set to the zero value for the type or to -// empty slices or maps. -// -// MinSize Marshal an integer of a type larger than 32 bits value as an int32, if that's -// feasible while preserving the numeric value. -// -// Truncate When unmarshaling a BSON double, it is permitted to lose precision to fit within -// a float32. -// -// Inline Inline the field, which must be a struct or a map, causing all of its fields +// inline Inline the field, which must be a struct or a map, causing all of its fields // or keys to be processed as if they were part of the outer struct. For maps, // keys must not conflict with the bson keys of other struct fields. // -// Skip This struct field should be skipped. This is usually denoted by parsing a "-" -// for the name. +// omitEmpty Only include the field if it's not set to the zero value for the type or to +// empty slices or maps. // -// Deprecated: Defining custom BSON struct tag parsers will not be supported in Go Driver 2.0. -type StructTags struct { +// skip This struct field should be skipped. This is usually denoted by parsing a "-" +// for the name. +type structTags struct { Name string - OmitEmpty bool - MinSize bool - Truncate bool Inline bool + OmitEmpty bool Skip bool + + LookupEncoderOnMinSize EncoderRetriever + LookupEncoderOnTruncate EncoderRetriever } -// DefaultStructTagParser is the StructTagParser used by the StructCodec by default. -// It will handle the bson struct tag. See the documentation for StructTags to see +// EncoderRetriever is used to look up ValueEncoder with given EncoderRegistry and reflect.Type. +type EncoderRetriever interface { + LookupEncoder(EncoderRegistry, reflect.Type) (ValueEncoder, error) +} + +// StructTagParser defines the encoder lookup bahavior when minSize and truncate tags are set. +type StructTagParser struct { + LookupEncoderOnMinSize EncoderRetriever + LookupEncoderOnTruncate EncoderRetriever +} + +// parseStructTags handles the bson struct tag. See the documentation for StructTags to see // what each of the returned fields means. // // If there is no name in the struct tag fields, the struct field name is lowercased. @@ -89,22 +124,20 @@ type StructTags struct { // A struct tag either consisting entirely of '-' or with a bson key with a // value consisting entirely of '-' will return a StructTags with Skip true and // the remaining fields will be their default values. -// -// Deprecated: DefaultStructTagParser will be removed in Go Driver 2.0. -var DefaultStructTagParser StructTagParserFunc = func(sf reflect.StructField) (StructTags, error) { +func (p *StructTagParser) parseStructTags(sf reflect.StructField) (*structTags, error) { key := strings.ToLower(sf.Name) tag, ok := sf.Tag.Lookup("bson") if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 { tag = string(sf.Tag) } - return parseTags(key, tag) + return p.parseTags(key, tag) } -func parseTags(key string, tag string) (StructTags, error) { - var st StructTags +func (p *StructTagParser) parseTags(key string, tag string) (*structTags, error) { + var st structTags if tag == "-" { st.Skip = true - return st, nil + return &st, nil } for idx, str := range strings.Split(tag, ",") { @@ -112,29 +145,25 @@ func parseTags(key string, tag string) (StructTags, error) { key = str } switch str { + case "inline": + st.Inline = true case "omitempty": st.OmitEmpty = true case "minsize": - st.MinSize = true + st.LookupEncoderOnMinSize = p.LookupEncoderOnMinSize case "truncate": - st.Truncate = true - case "inline": - st.Inline = true + st.LookupEncoderOnTruncate = p.LookupEncoderOnTruncate } } st.Name = key - return st, nil + return &st, nil } -// JSONFallbackStructTagParser has the same behavior as DefaultStructTagParser -// but will also fallback to parsing the json tag instead on a field where the +// parseJSONStructTags parses the json tag instead on a field where the // bson tag isn't available. -// -// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Encoder.UseJSONStructTags] and -// [go.mongodb.org/mongo-driver/bson.Decoder.UseJSONStructTags] instead. -var JSONFallbackStructTagParser StructTagParserFunc = func(sf reflect.StructField) (StructTags, error) { +func (p *StructTagParser) parseJSONStructTags(sf reflect.StructField) (*structTags, error) { key := strings.ToLower(sf.Name) tag, ok := sf.Tag.Lookup("bson") if !ok { @@ -144,5 +173,5 @@ var JSONFallbackStructTagParser StructTagParserFunc = func(sf reflect.StructFiel tag = string(sf.Tag) } - return parseTags(key, tag) + return p.parseTags(key, tag) } diff --git a/bson/struct_tag_parser_test.go b/bson/struct_tag_parser_test.go index b03815488a..312d1ab7f3 100644 --- a/bson/struct_tag_parser_test.go +++ b/bson/struct_tag_parser_test.go @@ -17,134 +17,174 @@ func TestStructTagParsers(t *testing.T) { testCases := []struct { name string sf reflect.StructField - want StructTags - parser StructTagParserFunc + want *structTags + parser func(reflect.StructField) (*structTags, error) }{ { "default no bson tag", reflect.StructField{Name: "foo", Tag: reflect.StructTag("bar")}, - StructTags{Name: "bar"}, - DefaultStructTagParser, + &structTags{Name: "bar"}, + DefaultStructTagParser.parseStructTags, }, { "default empty", reflect.StructField{Name: "foo", Tag: reflect.StructTag("")}, - StructTags{Name: "foo"}, - DefaultStructTagParser, + &structTags{Name: "foo"}, + DefaultStructTagParser.parseStructTags, }, { "default tag only dash", reflect.StructField{Name: "foo", Tag: reflect.StructTag("-")}, - StructTags{Skip: true}, - DefaultStructTagParser, + &structTags{Skip: true}, + DefaultStructTagParser.parseStructTags, }, { "default bson tag only dash", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"-"`)}, - StructTags{Skip: true}, - DefaultStructTagParser, + &structTags{Skip: true}, + DefaultStructTagParser.parseStructTags, }, { "default all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bar,omitempty,minsize,truncate,inline`)}, - StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - DefaultStructTagParser, + &structTags{ + Name: "bar", Inline: true, OmitEmpty: true, + LookupEncoderOnMinSize: retrieverOnMinSize{}, + LookupEncoderOnTruncate: retrieverOnTruncate{}, + }, + DefaultStructTagParser.parseStructTags, }, { "default all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`,omitempty,minsize,truncate,inline`)}, - StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - DefaultStructTagParser, + &structTags{ + Name: "foo", Inline: true, OmitEmpty: true, + LookupEncoderOnMinSize: retrieverOnMinSize{}, + LookupEncoderOnTruncate: retrieverOnTruncate{}, + }, + DefaultStructTagParser.parseStructTags, }, { "default bson tag all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar,omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - DefaultStructTagParser, + &structTags{ + Name: "bar", Inline: true, OmitEmpty: true, + LookupEncoderOnMinSize: retrieverOnMinSize{}, + LookupEncoderOnTruncate: retrieverOnTruncate{}, + }, + DefaultStructTagParser.parseStructTags, }, { "default bson tag all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:",omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - DefaultStructTagParser, + &structTags{ + Name: "foo", Inline: true, OmitEmpty: true, + LookupEncoderOnMinSize: retrieverOnMinSize{}, + LookupEncoderOnTruncate: retrieverOnTruncate{}, + }, + DefaultStructTagParser.parseStructTags, }, { "default ignore xml", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`xml:"bar"`)}, - StructTags{Name: "foo"}, - DefaultStructTagParser, + &structTags{Name: "foo"}, + DefaultStructTagParser.parseStructTags, }, { "JSONFallback no bson tag", reflect.StructField{Name: "foo", Tag: reflect.StructTag("bar")}, - StructTags{Name: "bar"}, - JSONFallbackStructTagParser, + &structTags{Name: "bar"}, + DefaultStructTagParser.parseJSONStructTags, }, { "JSONFallback empty", reflect.StructField{Name: "foo", Tag: reflect.StructTag("")}, - StructTags{Name: "foo"}, - JSONFallbackStructTagParser, + &structTags{Name: "foo"}, + DefaultStructTagParser.parseJSONStructTags, }, { "JSONFallback tag only dash", reflect.StructField{Name: "foo", Tag: reflect.StructTag("-")}, - StructTags{Skip: true}, - JSONFallbackStructTagParser, + &structTags{Skip: true}, + DefaultStructTagParser.parseJSONStructTags, }, { "JSONFallback bson tag only dash", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"-"`)}, - StructTags{Skip: true}, - JSONFallbackStructTagParser, + &structTags{Skip: true}, + DefaultStructTagParser.parseJSONStructTags, }, { "JSONFallback all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bar,omitempty,minsize,truncate,inline`)}, - StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - JSONFallbackStructTagParser, + &structTags{ + Name: "bar", Inline: true, OmitEmpty: true, + LookupEncoderOnMinSize: retrieverOnMinSize{}, + LookupEncoderOnTruncate: retrieverOnTruncate{}, + }, + DefaultStructTagParser.parseJSONStructTags, }, { "JSONFallback all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`,omitempty,minsize,truncate,inline`)}, - StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - JSONFallbackStructTagParser, + &structTags{ + Name: "foo", Inline: true, OmitEmpty: true, + LookupEncoderOnMinSize: retrieverOnMinSize{}, + LookupEncoderOnTruncate: retrieverOnTruncate{}, + }, + DefaultStructTagParser.parseJSONStructTags, }, { "JSONFallback bson tag all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar,omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - JSONFallbackStructTagParser, + &structTags{ + Name: "bar", Inline: true, OmitEmpty: true, + LookupEncoderOnMinSize: retrieverOnMinSize{}, + LookupEncoderOnTruncate: retrieverOnTruncate{}, + }, + DefaultStructTagParser.parseJSONStructTags, }, { "JSONFallback bson tag all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:",omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - JSONFallbackStructTagParser, + &structTags{ + Name: "foo", Inline: true, OmitEmpty: true, + LookupEncoderOnMinSize: retrieverOnMinSize{}, + LookupEncoderOnTruncate: retrieverOnTruncate{}, + }, + DefaultStructTagParser.parseJSONStructTags, }, { "JSONFallback json tag all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`json:"bar,omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - JSONFallbackStructTagParser, + &structTags{ + Name: "bar", Inline: true, OmitEmpty: true, + LookupEncoderOnMinSize: retrieverOnMinSize{}, + LookupEncoderOnTruncate: retrieverOnTruncate{}, + }, + DefaultStructTagParser.parseJSONStructTags, }, { "JSONFallback json tag all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`json:",omitempty,minsize,truncate,inline"`)}, - StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, - JSONFallbackStructTagParser, + &structTags{ + Name: "foo", Inline: true, OmitEmpty: true, + LookupEncoderOnMinSize: retrieverOnMinSize{}, + LookupEncoderOnTruncate: retrieverOnTruncate{}, + }, + DefaultStructTagParser.parseJSONStructTags, }, { "JSONFallback bson tag overrides other tags", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar" json:"qux,truncate"`)}, - StructTags{Name: "bar"}, - JSONFallbackStructTagParser, + &structTags{Name: "bar"}, + DefaultStructTagParser.parseJSONStructTags, }, { "JSONFallback ignore xml", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`xml:"bar"`)}, - StructTags{Name: "foo"}, - JSONFallbackStructTagParser, + &structTags{Name: "foo"}, + DefaultStructTagParser.parseJSONStructTags, }, } diff --git a/bson/truncation_test.go b/bson/truncation_test.go index 311b2942d4..e0a1579494 100644 --- a/bson/truncation_test.go +++ b/bson/truncation_test.go @@ -33,16 +33,19 @@ func TestTruncation(t *testing.T) { buf := new(bytes.Buffer) vw := NewValueWriter(buf) enc := NewEncoderWithRegistry(NewRegistryBuilder().Build(), vw) - enc.SetBehavior(IntMinSize) - err := enc.Encode(&input) + err := enc.SetBehavior(IntMinSize) + assert.Nil(t, err) + err = enc.Encode(&input) assert.Nil(t, err) var output outputArgs - opt := NewRegistryOpt(func(c *numCodec) { + opt := NewRegistryOpt(func(c *numCodec) error { c.truncate = true + return nil }) reg := NewRegistryBuilder().Build() - reg.SetCodecOptions(opt) + err = reg.SetCodecOption(opt) + assert.Nil(t, err) err = UnmarshalWithContext(reg, buf.Bytes(), &output) assert.Nil(t, err) @@ -59,16 +62,19 @@ func TestTruncation(t *testing.T) { buf := new(bytes.Buffer) vw := NewValueWriter(buf) enc := NewEncoderWithRegistry(NewRegistryBuilder().Build(), vw) - enc.SetBehavior(IntMinSize) - err := enc.Encode(&input) + err := enc.SetBehavior(IntMinSize) + assert.Nil(t, err) + err = enc.Encode(&input) assert.Nil(t, err) var output outputArgs - opt := NewRegistryOpt(func(c *numCodec) { + opt := NewRegistryOpt(func(c *numCodec) error { c.truncate = false + return nil }) reg := NewRegistryBuilder().Build() - reg.SetCodecOptions(opt) + err = reg.SetCodecOption(opt) + assert.Nil(t, err) // case throws an error when truncation is disabled err = UnmarshalWithContext(reg, buf.Bytes(), &output) diff --git a/mongo/change_stream.go b/mongo/change_stream.go index d578abce63..a63dcd8bb7 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -596,7 +596,10 @@ func (cs *ChangeStream) Decode(val interface{}) error { return ErrNilCursor } - dec := getDecoder(cs.Current, cs.bsonOpts, cs.registry) + dec, err := getDecoder(cs.Current, cs.bsonOpts, cs.registry) + if err != nil { + return err + } return dec.Decode(val) } diff --git a/mongo/cursor.go b/mongo/cursor.go index 703aebf90b..251d190e89 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -234,7 +234,7 @@ func getDecoder( data []byte, opts *options.BSONOptions, reg *bson.Registry, -) *bson.Decoder { +) (*bson.Decoder, error) { vr := bson.NewValueReader(data) var dec *bson.Decoder if reg != nil { @@ -244,39 +244,49 @@ func getDecoder( } if opts != nil { + regOpts := []*bson.RegistryOpt{} if opts.AllowTruncatingDoubles { - dec.SetBehavior(bson.AllowTruncatingDoubles) + regOpts = append(regOpts, bson.AllowTruncatingDoubles) } if opts.BinaryAsSlice { - dec.SetBehavior(bson.BinaryAsSlice) + regOpts = append(regOpts, bson.BinaryAsSlice) } if opts.DefaultDocumentD { - dec.SetBehavior(bson.DefaultDocumentD) + regOpts = append(regOpts, bson.DefaultDocumentD) } if opts.DefaultDocumentM { - dec.SetBehavior(bson.DefaultDocumentM) + regOpts = append(regOpts, bson.DefaultDocumentM) } if opts.UseJSONStructTags { - dec.SetBehavior(bson.UseJSONStructTags) + regOpts = append(regOpts, bson.UseJSONStructTags) } if opts.UseLocalTimeZone { - dec.SetBehavior(bson.UseLocalTimeZone) + regOpts = append(regOpts, bson.UseLocalTimeZone) } if opts.ZeroMaps { - dec.SetBehavior(bson.ZeroMaps) + regOpts = append(regOpts, bson.ZeroMaps) } if opts.ZeroStructs { - dec.SetBehavior(bson.ZeroStructs) + regOpts = append(regOpts, bson.ZeroStructs) + } + for _, opt := range regOpts { + err := dec.SetBehavior(opt) + if err != nil { + return nil, err + } } } - return dec + return dec, nil } // Decode will unmarshal the current document into val and return any errors from the unmarshalling process without any // modification. If val is nil or is a typed nil, an error will be returned. func (c *Cursor) Decode(val interface{}) error { - dec := getDecoder(c.Current, c.bsonOpts, c.registry) + dec, err := getDecoder(c.Current, c.bsonOpts, c.registry) + if err != nil { + return err + } return dec.Decode(val) } @@ -367,7 +377,11 @@ func (c *Cursor) addFromBatch(sliceVal reflect.Value, elemType reflect.Type, bat } currElem := sliceVal.Index(index).Addr().Interface() - dec := getDecoder(doc, c.bsonOpts, c.registry) + var dec *bson.Decoder + dec, err = getDecoder(doc, c.bsonOpts, c.registry) + if err != nil { + return sliceVal, index, err + } err = dec.Decode(currElem) if err != nil { return sliceVal, index, err diff --git a/mongo/mongo.go b/mongo/mongo.go index 1112a297db..8cd6258e38 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -71,29 +71,36 @@ func getEncoder( } if opts != nil { + regOpts := []*bson.RegistryOpt{} if opts.ErrorOnInlineDuplicates { - enc.SetBehavior(bson.ErrorOnInlineDuplicates) + regOpts = append(regOpts, bson.ErrorOnInlineDuplicates) } if opts.IntMinSize { - enc.SetBehavior(bson.IntMinSize) + regOpts = append(regOpts, bson.IntMinSize) } if opts.NilByteSliceAsEmpty { - enc.SetBehavior(bson.NilByteSliceAsEmpty) + regOpts = append(regOpts, bson.NilByteSliceAsEmpty) } if opts.NilMapAsEmpty { - enc.SetBehavior(bson.NilMapAsEmpty) + regOpts = append(regOpts, bson.NilMapAsEmpty) } if opts.NilSliceAsEmpty { - enc.SetBehavior(bson.NilSliceAsEmpty) + regOpts = append(regOpts, bson.NilSliceAsEmpty) } if opts.OmitZeroStruct { - enc.SetBehavior(bson.OmitZeroStruct) + regOpts = append(regOpts, bson.OmitZeroStruct) } if opts.StringifyMapKeysWithFmt { - enc.SetBehavior(bson.StringifyMapKeysWithFmt) + regOpts = append(regOpts, bson.StringifyMapKeysWithFmt) } if opts.UseJSONStructTags { - enc.SetBehavior(bson.UseJSONStructTags) + regOpts = append(regOpts, bson.UseJSONStructTags) + } + for _, opt := range regOpts { + err := enc.SetBehavior(opt) + if err != nil { + return nil, err + } } } @@ -167,7 +174,11 @@ func ensureID( var id struct { ID interface{} `bson:"_id"` } - dec := getDecoder(doc, bsonOpts, registry) + var dec *bson.Decoder + dec, err = getDecoder(doc, bsonOpts, registry) + if err != nil { + return nil, nil, fmt.Errorf("error unmarshaling BSON document: %w", err) + } err = dec.Decode(&id) if err != nil { return nil, nil, fmt.Errorf("error unmarshaling BSON document: %w", err) diff --git a/mongo/single_result.go b/mongo/single_result.go index f467666167..d6223d31ee 100644 --- a/mongo/single_result.go +++ b/mongo/single_result.go @@ -73,7 +73,10 @@ func (sr *SingleResult) Decode(v interface{}) error { return sr.err } - dec := getDecoder(sr.rdr, sr.bsonOpts, sr.reg) + dec, err := getDecoder(sr.rdr, sr.bsonOpts, sr.reg) + if err != nil { + return err + } return dec.Decode(v) } From e9b3bfaaeae2b5ed6fc21890d63fb6b163f4d5f1 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Fri, 24 May 2024 22:38:01 -0400 Subject: [PATCH 14/15] WIP --- bson/default_value_decoders.go | 5 +- bson/default_value_decoders_test.go | 57 +++++++++------ bson/default_value_encoders.go | 13 +++- bson/default_value_encoders_test.go | 89 +++++++++++++++++------ bson/map_codec.go | 8 ++- bson/marshal.go | 6 +- bson/marshal_test.go | 2 +- bson/mgocompat/bson_test.go | 42 +++++++++-- bson/mgoregistry.go | 2 +- bson/primitive_codecs.go | 3 +- bson/raw_value.go | 2 +- bson/registry.go | 62 +++++++++++++++- bson/registry_examples_test.go | 12 ++-- bson/registry_option.go | 4 +- bson/registry_test.go | 1 + bson/struct_codec.go | 35 ++++++--- bson/struct_tag_parser.go | 108 ++++++---------------------- bson/struct_tag_parser_test.go | 102 ++++++++------------------ bson/truncation_test.go | 4 +- bson/unmarshal.go | 25 +------ bson/unmarshal_test.go | 30 ++------ bson/unmarshal_value_test.go | 2 +- 22 files changed, 327 insertions(+), 287 deletions(-) diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index 5c28ec9a86..e56b3d2faa 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -41,7 +41,8 @@ func registerDefaultDecoders(rb *RegistryBuilder) { } numDecoder := func() ValueDecoder { return &numCodec{} } - rb.RegisterTypeDecoder(tD, func() ValueDecoder { return ValueDecoderFunc(dDecodeValue) }). + rb. + RegisterTypeDecoder(tD, func() ValueDecoder { return ValueDecoderFunc(dDecodeValue) }). RegisterTypeDecoder(tBinary, func() ValueDecoder { return &decodeAdapter{binaryDecodeValue, binaryDecodeType} }). RegisterTypeDecoder(tUndefined, func() ValueDecoder { return &decodeAdapter{undefinedDecodeValue, undefinedDecodeType} }). RegisterTypeDecoder(tDateTime, func() ValueDecoder { return &decodeAdapter{dateTimeDecodeValue, dateTimeDecodeType} }). @@ -80,7 +81,7 @@ func registerDefaultDecoders(rb *RegistryBuilder) { RegisterKindDecoder(reflect.Map, func() ValueDecoder { return &mapCodec{} }). RegisterKindDecoder(reflect.Slice, func() ValueDecoder { return &sliceCodec{} }). RegisterKindDecoder(reflect.String, func() ValueDecoder { return &stringCodec{} }). - RegisterKindDecoder(reflect.Struct, func() ValueDecoder { return newStructCodec(DefaultStructTagParser) }). + RegisterKindDecoder(reflect.Struct, func() ValueDecoder { return newStructCodec(rb.StructTagHandler()) }). RegisterKindDecoder(reflect.Ptr, func() ValueDecoder { return &pointerCodec{} }). RegisterTypeMapEntry(TypeDouble, tFloat64). RegisterTypeMapEntry(TypeString, tString). diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 019057ea00..d07df571ec 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -23,7 +23,7 @@ import ( ) var ( - defaultTestStructCodec = newStructCodec(DefaultStructTagParser) + defaultTestStructCodec = newStructCodec(DefaultStructTagHandler()) ) func TestDefaultValueDecoders(t *testing.T) { @@ -194,11 +194,6 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.00)}, readDouble, nil, }, - // { - // "ReadDouble (truncate)", int64(3), &DecodeContext{truncate: true}, - // &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, - // nil, - // }, { "ReadDouble (no truncate)", int64(0), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, @@ -454,11 +449,6 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.00)}, readDouble, nil, }, - // { - // "ReadDouble (truncate)", uint64(3), &DecodeContext{truncate: true}, - // &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, - // nil, - // }, { "ReadDouble (no truncate)", uint64(0), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, @@ -733,11 +723,6 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14159)}, readDouble, nil, }, - // { - // "float32/fast path (truncate)", float32(3.14), &DecodeContext{truncate: true}, - // &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, - // nil, - // }, { "float32/fast path (no truncate)", float32(0), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, @@ -779,11 +764,6 @@ func TestDefaultValueDecoders(t *testing.T) { &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14159)}, readDouble, nil, }, - // { - // "float32/reflection path (truncate)", myfloat32(3.14), &DecodeContext{truncate: true}, - // &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, - // nil, - // }, { "float32/reflection path (no truncate)", myfloat32(0), nil, &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, @@ -806,6 +786,32 @@ func TestDefaultValueDecoders(t *testing.T) { }, }, }, + { + "NumDecodeValue (truncate)", + &numCodec{truncate: true}, + []subtest{ + { + "int ReadDouble (truncate)", int64(3), nil, + &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, + nil, + }, + { + "uint ReadDouble (truncate)", uint64(3), nil, + &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, + nil, + }, + { + "float32/fast path (truncate)", float32(3.14), nil, + &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, + nil, + }, + { + "float32/reflection path (truncate)", myfloat32(3.14), nil, + &valueReaderWriter{BSONType: TypeDouble, Return: float64(3.14)}, readDouble, + nil, + }, + }, + }, { "TimeDecodeValue", &timeCodec{}, @@ -2463,7 +2469,6 @@ func TestDefaultValueDecoders(t *testing.T) { want := rc.val defer func() { if err := recover(); err != nil { - fmt.Println(t.Name()) panic(err) } }() @@ -2536,6 +2541,14 @@ func TestDefaultValueDecoders(t *testing.T) { t.Errorf("Errors do not match. got %v; want %v", got, want) } }) + t.Run("SliceCodec/DecodeValue/can't set slice", func(t *testing.T) { + var val []string + want := ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: reflect.ValueOf(val)} + got := (&sliceCodec{}).DecodeValue(nil, nil, reflect.ValueOf(val)) + if !assert.CompareErrors(got, want) { + t.Errorf("Errors do not match. got %v; want %v", got, want) + } + }) t.Run("SliceCodec/DecodeValue/too many elements", func(t *testing.T) { idx, doc := bsoncore.AppendDocumentStart(nil) aidx, doc := bsoncore.AppendArrayElementStart(doc, "foo") diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index be57563530..12a1eb1412 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -56,7 +56,8 @@ func registerDefaultEncoders(rb *RegistryBuilder) { } numEncoder := func() ValueEncoder { return &numCodec{} } - rb.RegisterTypeEncoder(tByteSlice, func() ValueEncoder { return &byteSliceCodec{} }). + rb. + RegisterTypeEncoder(tByteSlice, func() ValueEncoder { return &byteSliceCodec{} }). RegisterTypeEncoder(tTime, func() ValueEncoder { return &timeCodec{} }). RegisterTypeEncoder(tEmpty, func() ValueEncoder { return &emptyInterfaceCodec{} }). RegisterTypeEncoder(tCoreArray, func() ValueEncoder { return &arrayCodec{} }). @@ -94,7 +95,7 @@ func registerDefaultEncoders(rb *RegistryBuilder) { RegisterKindEncoder(reflect.Map, func() ValueEncoder { return &mapCodec{} }). RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return &sliceCodec{} }). RegisterKindEncoder(reflect.String, func() ValueEncoder { return &stringCodec{} }). - RegisterKindEncoder(reflect.Struct, func() ValueEncoder { return newStructCodec(DefaultStructTagParser) }). + RegisterKindEncoder(reflect.Struct, func() ValueEncoder { return newStructCodec(rb.StructTagHandler()) }). RegisterKindEncoder(reflect.Ptr, func() ValueEncoder { return &pointerCodec{} }). RegisterInterfaceEncoder(tValueMarshaler, func() ValueEncoder { return ValueEncoderFunc(valueMarshalerEncodeValue) }). RegisterInterfaceEncoder(tMarshaler, func() ValueEncoder { return ValueEncoderFunc(marshalerEncodeValue) }). @@ -150,7 +151,13 @@ func jsonNumberEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Valu return err } - return (&numCodec{}).EncodeValue(reg, vw, reflect.ValueOf(f64)) + var encoder ValueEncoder + encoder, err = reg.LookupEncoder(reflect.TypeOf(f64)) + if err != nil { + return err + } + + return encoder.EncodeValue(reg, vw, reflect.ValueOf(f64)) } // urlEncodeValue is the ValueEncoderFunc for url.URL. diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index 9a5d51cb04..5faf76b25f 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -47,7 +47,7 @@ func TestDefaultValueEncoders(t *testing.T) { type myuint16 uint16 type myuint32 uint32 type myuint64 uint64 - // type myuint uint + type myuint uint type myfloat32 float32 type myfloat64 float64 @@ -117,9 +117,6 @@ func TestDefaultValueEncoders(t *testing.T) { {"int16/fast path", int16(32767), nil, nil, writeInt32, nil}, {"int32/fast path", int32(2147483647), nil, nil, writeInt32, nil}, {"int64/fast path", int64(1234567890987), nil, nil, writeInt64, nil}, - // {"int64/fast path - minsize", int64(math.MaxInt32), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - // {"int64/fast path - minsize too large", int64(math.MaxInt32 + 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - // {"int64/fast path - minsize too small", int64(math.MinInt32 - 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, {"int/fast path - positive int32", int(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, {"int/fast path - negative int32", int(math.MinInt32 + 1), nil, nil, writeInt32, nil}, {"int/fast path - MaxInt32", int(math.MaxInt32), nil, nil, writeInt32, nil}, @@ -128,9 +125,6 @@ func TestDefaultValueEncoders(t *testing.T) { {"int16/reflection path", myint16(32767), nil, nil, writeInt32, nil}, {"int32/reflection path", myint32(2147483647), nil, nil, writeInt32, nil}, {"int64/reflection path", myint64(1234567890987), nil, nil, writeInt64, nil}, - // {"int64/reflection path - minsize", myint64(math.MaxInt32), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - // {"int64/reflection path - minsize too large", myint64(math.MaxInt32 + 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - // {"int64/reflection path - minsize too small", myint64(math.MinInt32 - 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, {"int/reflection path - positive int32", myint(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, {"int/reflection path - negative int32", myint(math.MinInt32 + 1), nil, nil, writeInt32, nil}, {"int/reflection path - MaxInt32", myint(math.MaxInt32), nil, nil, writeInt32, nil}, @@ -162,26 +156,38 @@ func TestDefaultValueEncoders(t *testing.T) { {"uint32/fast path", uint32(2147483647), nil, nil, writeInt64, nil}, {"uint64/fast path", uint64(1234567890987), nil, nil, writeInt64, nil}, {"uint/fast path", uint(1234567), nil, nil, writeInt64, nil}, - // {"uint32/fast path - minsize", uint32(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - // {"uint64/fast path - minsize", uint64(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - // {"uint/fast path - minsize", uint(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - // {"uint32/fast path - minsize too large", uint32(2147483648), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - // {"uint64/fast path - minsize too large", uint64(2147483648), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - // {"uint/fast path - minsize too large", uint(2147483648), &EncodeContext{minSize: true}, nil, writeInt64, nil}, {"uint64/fast path - overflow", uint64(1 << 63), nil, nil, nothing, fmt.Errorf("%d overflows int64", uint64(1<<63))}, {"uint8/reflection path", myuint8(127), nil, nil, writeInt32, nil}, {"uint16/reflection path", myuint16(32767), nil, nil, writeInt32, nil}, {"uint32/reflection path", myuint32(2147483647), nil, nil, writeInt64, nil}, {"uint64/reflection path", myuint64(1234567890987), nil, nil, writeInt64, nil}, - // {"uint32/reflection path - minsize", myuint32(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - // {"uint64/reflection path - minsize", myuint64(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - // {"uint/reflection path - minsize", myuint(2147483647), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - // {"uint32/reflection path - minsize too large", myuint(1 << 31), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - // {"uint64/reflection path - minsize too large", myuint64(1 << 31), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - // {"uint/reflection path - minsize too large", myuint(2147483648), &EncodeContext{minSize: true}, nil, writeInt64, nil}, {"uint64/reflection path - overflow", myuint64(1 << 63), nil, nil, nothing, fmt.Errorf("%d overflows int64", uint64(1<<63))}, }, }, + { + "NumEncodeValue (minSize)", + &numCodec{minSize: true}, + []subtest{ + {"int64/fast path - minsize", int64(math.MaxInt32), nil, nil, writeInt32, nil}, + {"int64/fast path - minsize too large", int64(math.MaxInt32 + 1), nil, nil, writeInt64, nil}, + {"int64/fast path - minsize too small", int64(math.MinInt32 - 1), nil, nil, writeInt64, nil}, + {"int64/reflection path - minsize", myint64(math.MaxInt32), nil, nil, writeInt32, nil}, + {"int64/reflection path - minsize too large", myint64(math.MaxInt32 + 1), nil, nil, writeInt64, nil}, + {"int64/reflection path - minsize too small", myint64(math.MinInt32 - 1), nil, nil, writeInt64, nil}, + {"uint32/fast path - minsize", uint32(2147483647), nil, nil, writeInt32, nil}, + {"uint64/fast path - minsize", uint64(2147483647), nil, nil, writeInt32, nil}, + {"uint/fast path - minsize", uint(2147483647), nil, nil, writeInt32, nil}, + {"uint32/fast path - minsize too large", uint32(2147483648), nil, nil, writeInt64, nil}, + {"uint64/fast path - minsize too large", uint64(2147483648), nil, nil, writeInt64, nil}, + {"uint/fast path - minsize too large", uint(2147483648), nil, nil, writeInt64, nil}, + {"uint32/reflection path - minsize", myuint32(2147483647), nil, nil, writeInt32, nil}, + {"uint64/reflection path - minsize", myuint64(2147483647), nil, nil, writeInt32, nil}, + {"uint/reflection path - minsize", myuint(2147483647), nil, nil, writeInt32, nil}, + {"uint32/reflection path - minsize too large", myuint(1 << 31), nil, nil, writeInt64, nil}, + {"uint64/reflection path - minsize too large", myuint64(1 << 31), nil, nil, writeInt64, nil}, + {"uint/reflection path - minsize too large", myuint(2147483648), nil, nil, writeInt64, nil}, + }, + }, { "FloatEncodeValue", &numCodec{}, @@ -526,7 +532,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "json.Number/float64/success", json.Number("3.14159"), - nil, nil, writeDouble, nil, + buildDefaultRegistry(), nil, writeDouble, nil, }, }, }, @@ -1083,6 +1089,28 @@ func TestDefaultValueEncoders(t *testing.T) { }, }, }, + { + "StructEncodeValue", + defaultTestStructCodec, + []subtest{ + { + "interface value", + struct{ Foo myInterface }{Foo: myStruct{1}}, + buildDefaultRegistry(), + nil, + writeDocumentEnd, + nil, + }, + { + "nil interface value", + struct{ Foo myInterface }{Foo: nil}, + buildDefaultRegistry(), + nil, + writeDocumentEnd, + nil, + }, + }, + }, { "CodeWithScopeEncodeValue", ValueEncoderFunc(codeWithScopeEncodeValue), @@ -1813,6 +1841,27 @@ func TestDefaultValueEncoders(t *testing.T) { }) } }) + + t.Run("EmptyInterfaceEncodeValue/nil", func(t *testing.T) { + val := reflect.New(tEmpty).Elem() + llvrw := new(valueReaderWriter) + err := (&emptyInterfaceCodec{}).EncodeValue(newTestRegistryBuilder().Build(), llvrw, val) + noerr(t, err) + if llvrw.invoked != writeNull { + t.Errorf("Incorrect method called. got %v; want %v", llvrw.invoked, writeNull) + } + }) + + t.Run("EmptyInterfaceEncodeValue/LookupEncoder error", func(t *testing.T) { + val := reflect.New(tEmpty).Elem() + val.Set(reflect.ValueOf(int64(1234567890))) + llvrw := new(valueReaderWriter) + got := (&emptyInterfaceCodec{}).EncodeValue(newTestRegistryBuilder().Build(), llvrw, val) + want := ErrNoEncoder{Type: tInt64} + if !assert.CompareErrors(got, want) { + t.Errorf("Did not receive expected error. got %v; want %v", got, want) + } + }) } type testValueMarshalPtr struct { diff --git a/bson/map_codec.go b/bson/map_codec.go index a9640d34c6..db43347722 100644 --- a/bson/map_codec.go +++ b/bson/map_codec.go @@ -68,7 +68,11 @@ func (mc *mapCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect return err } - return mc.mapEncodeValue(reg, dw, val, nil) + err = mc.mapEncodeValue(reg, dw, val, nil) + if err != nil { + return err + } + return dw.WriteDocumentEnd() } // mapEncodeValue handles encoding of the values of a map. The collisionFn returns @@ -117,7 +121,7 @@ func (mc *mapCodec) mapEncodeValue(reg EncoderRegistry, dw DocumentWriter, val r } } - return dw.WriteDocumentEnd() + return nil } // DecodeValue is the ValueDecoder for map[string/decimal]* types. diff --git a/bson/marshal.go b/bson/marshal.go index 01dea11fb6..db0bf47fae 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -70,7 +70,7 @@ func Marshal(val interface{}) ([]byte, error) { } }() sw.Reset() - enc := NewEncoderWithRegistry(NewRegistryBuilder().Build(), NewValueWriter(sw)) + enc := NewEncoderWithRegistry(defaultRegistry, NewValueWriter(sw)) err := enc.Encode(val) if err != nil { return nil, err @@ -84,7 +84,7 @@ func Marshal(val interface{}) ([]byte, error) { // MarshalValue will use default registry to transform val into a BSON value. If val is a struct, this function will // inspect struct tags and alter the marshalling process accordingly. func MarshalValue(val interface{}) (Type, []byte, error) { - return MarshalValueWithRegistry(NewRegistryBuilder().Build(), val) + return MarshalValueWithRegistry(defaultRegistry, val) } // MarshalValueWithRegistry returns the BSON encoding of val using Registry r. @@ -116,7 +116,7 @@ func MarshalExtJSON(val interface{}, canonical, escapeHTML bool) ([]byte, error) ejvw := extjPool.Get(&sw, canonical, escapeHTML) defer extjPool.Put(ejvw) - enc := NewEncoderWithRegistry(NewRegistryBuilder().Build(), ejvw) + enc := NewEncoderWithRegistry(defaultRegistry, ejvw) err := enc.Encode(val) if err != nil { return nil, err diff --git a/bson/marshal_test.go b/bson/marshal_test.go index 0bb650b668..1edf66e33b 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -28,7 +28,7 @@ func TestMarshalWithRegistry(t *testing.T) { if tc.reg != nil { reg = tc.reg } else { - reg = NewRegistryBuilder().Build() + reg = defaultRegistry } buf := new(bytes.Buffer) vw := NewValueWriter(buf) diff --git a/bson/mgocompat/bson_test.go b/bson/mgocompat/bson_test.go index 302c9fc6da..c31605d105 100644 --- a/bson/mgocompat/bson_test.go +++ b/bson/mgocompat/bson_test.go @@ -79,11 +79,28 @@ var sampleItems = []testItemType{ "\x13\x00\x00\x00\x05slice\x00\x02\x00\x00\x00\x00\x01\x02\x00"}, } -func TestUnmarshalSampleItems(t *testing.T) { +func TestEncodeSampleItems(t *testing.T) { + buf := new(bytes.Buffer) + for i, item := range sampleItems { + t.Run(strconv.Itoa(i), func(t *testing.T) { + buf.Reset() + vw := bson.NewValueWriter(buf) + enc := bson.NewEncoderWithRegistry(Registry, vw) + err := enc.Encode(item.obj) + assert.Nil(t, err, "expected nil error, got: %v", err) + str := buf.String() + assert.Equal(t, str, item.data, "expected: %v, got: %v", item.data, str) + }) + } +} + +func TestDecodeSampleItems(t *testing.T) { for i, item := range sampleItems { t.Run(strconv.Itoa(i), func(t *testing.T) { value := bson.M{} - err := bson.UnmarshalWithRegistry(Registry, []byte(item.data), &value) + vr := bson.NewValueReader([]byte(item.data)) + dec := bson.NewDecoderWithRegistry(Registry, vr) + err := dec.Decode(&value) assert.Nil(t, err, "expected nil error, got: %v", err) assert.True(t, reflect.DeepEqual(value, item.obj), "expected: %v, got: %v", item.obj, value) }) @@ -147,11 +164,28 @@ var allItems = []testItemType{ "\xFF_\x00"}, } -func TestUnmarshalAllItems(t *testing.T) { +func TestEncodeAllItems(t *testing.T) { + buf := new(bytes.Buffer) + for i, item := range allItems { + t.Run(strconv.Itoa(i), func(t *testing.T) { + buf.Reset() + vw := bson.NewValueWriter(buf) + enc := bson.NewEncoderWithRegistry(Registry, vw) + err := enc.Encode(item.obj) + assert.Nil(t, err, "expected nil error, got: %v", err) + str := buf.String() + assert.Equal(t, str, wrapInDoc(item.data), "expected: %v, got: %v", wrapInDoc(item.data), str) + }) + } +} + +func TestDecodeAllItems(t *testing.T) { for i, item := range allItems { t.Run(strconv.Itoa(i), func(t *testing.T) { value := bson.M{} - err := bson.UnmarshalWithRegistry(Registry, []byte(wrapInDoc(item.data)), &value) + vr := bson.NewValueReader([]byte(wrapInDoc(item.data))) + dec := bson.NewDecoderWithRegistry(Registry, vr) + err := dec.Decode(&value) assert.Nil(t, err, "expected nil error, got: %v", err) assert.True(t, reflect.DeepEqual(value, item.obj), "expected: %v, got: %v", item.obj, value) }) diff --git a/bson/mgoregistry.go b/bson/mgoregistry.go index b8d9ca7a10..0b7af8dda8 100644 --- a/bson/mgoregistry.go +++ b/bson/mgoregistry.go @@ -24,7 +24,7 @@ var ( func newMgoRegistryBuilder() *RegistryBuilder { structcodec := &structCodec{ - parser: DefaultStructTagParser, + tagHndl: DefaultStructTagHandler(), decodeZeroStruct: true, encodeOmitDefaultStruct: true, allowUnexportedFields: true, diff --git a/bson/primitive_codecs.go b/bson/primitive_codecs.go index f5a67165e4..b8cdc8a94f 100644 --- a/bson/primitive_codecs.go +++ b/bson/primitive_codecs.go @@ -21,7 +21,8 @@ func registerPrimitiveCodecs(rb *RegistryBuilder) { panic(errors.New("argument to RegisterPrimitiveCodecs must not be nil")) } - rb.RegisterTypeEncoder(tRawValue, func() ValueEncoder { return ValueEncoderFunc(rawValueEncodeValue) }). + rb. + RegisterTypeEncoder(tRawValue, func() ValueEncoder { return ValueEncoderFunc(rawValueEncodeValue) }). RegisterTypeEncoder(tRaw, func() ValueEncoder { return ValueEncoderFunc(rawEncodeValue) }). RegisterTypeDecoder(tRawValue, func() ValueDecoder { return ValueDecoderFunc(rawValueDecodeValue) }). RegisterTypeDecoder(tRaw, func() ValueDecoder { return ValueDecoderFunc(rawDecodeValue) }) diff --git a/bson/raw_value.go b/bson/raw_value.go index 732379e118..0e32361ee4 100644 --- a/bson/raw_value.go +++ b/bson/raw_value.go @@ -46,7 +46,7 @@ func (rv RawValue) IsZero() bool { func (rv RawValue) Unmarshal(val interface{}) error { reg := rv.r if reg == nil { - reg = NewRegistryBuilder().Build() + reg = defaultRegistry } return rv.UnmarshalWithRegistry(reg, val) } diff --git a/bson/registry.go b/bson/registry.go index cf88c703ba..9d7b1c52dc 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -12,6 +12,8 @@ import ( "sync" ) +var defaultRegistry = NewRegistryBuilder().Build() + // ErrNoEncoder is returned when there wasn't an encoder available for a type. // // Deprecated: ErrNoEncoder will not be supported in Go Driver 2.0. @@ -60,9 +62,57 @@ type EncoderFactory func() ValueEncoder // DecoderFactory is a factory function that generates a new ValueDecoder. type DecoderFactory func() ValueDecoder +func inlineEncoder(reg EncoderRegistry, w DocumentWriter, v reflect.Value, collisionFn func(string) bool) error { + enc, err := reg.LookupEncoder(v.Type()) + if err != nil { + return err + } + codec, ok := enc.(*mapCodec) + if !ok { + return fmt.Errorf("failed to find an encoder for inline map") + } + return codec.mapEncodeValue(reg, w, v, collisionFn) +} + +func retrieverOnMinSize(reg EncoderRegistry, t reflect.Type) (ValueEncoder, error) { + enc, err := reg.LookupEncoder(t) + if err != nil { + return enc, err + } + switch t.Kind() { + case reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: + if codec, ok := enc.(*numCodec); ok { + c := *codec + c.minSize = true + return &c, nil + } + } + return enc, nil +} + +func retrieverOnTruncate(reg EncoderRegistry, t reflect.Type) (ValueEncoder, error) { + enc, err := reg.LookupEncoder(t) + if err != nil { + return enc, err + } + switch t.Kind() { + case reflect.Float32, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + if codec, ok := enc.(*numCodec); ok { + c := *codec + c.truncate = true + return &c, nil + } + } + return enc, nil +} + // A RegistryBuilder is used to build a Registry. This type is not goroutine // safe. type RegistryBuilder struct { + StructTagHandler func() StructTagHandler + typeEncoders map[reflect.Type]EncoderFactory typeDecoders map[reflect.Type]DecoderFactory interfaceEncoders map[reflect.Type]EncoderFactory @@ -72,9 +122,19 @@ type RegistryBuilder struct { typeMap map[Type]reflect.Type } +// DefaultStructTagHandler generates a new *StructTagHandler to initialize the struct codec. +func DefaultStructTagHandler() StructTagHandler { + return StructTagHandler{ + InlineEncoder: inlineEncoder, + LookupEncoderOnMinSize: retrieverOnMinSize, + LookupEncoderOnTruncate: retrieverOnTruncate, + } +} + // NewRegistryBuilder creates a new empty RegistryBuilder. func NewRegistryBuilder() *RegistryBuilder { rb := &RegistryBuilder{ + StructTagHandler: DefaultStructTagHandler, typeEncoders: make(map[reflect.Type]EncoderFactory), typeDecoders: make(map[reflect.Type]DecoderFactory), interfaceEncoders: make(map[reflect.Type]EncoderFactory), @@ -207,7 +267,7 @@ func (rb *RegistryBuilder) RegisterInterfaceDecoder(iface reflect.Type, decFac D // documents, a type map entry for TypeEmbeddedDocument should be registered. For example, to force BSON documents // to decode to bson.Raw, use the following code: // -// reg.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) +// rb.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) // // RegisterTypeMapEntry should not be called concurrently with any other Registry method. func (rb *RegistryBuilder) RegisterTypeMapEntry(bt Type, rt reflect.Type) *RegistryBuilder { diff --git a/bson/registry_examples_test.go b/bson/registry_examples_test.go index 20fab280a4..a807a0d80a 100644 --- a/bson/registry_examples_test.go +++ b/bson/registry_examples_test.go @@ -132,8 +132,8 @@ func ExampleRegistry_customDecoder() { return nil } - reg := bson.NewRegistryBuilder() - reg.RegisterTypeDecoder( + rb := bson.NewRegistryBuilder() + rb.RegisterTypeDecoder( lenientBoolType, func() bson.ValueDecoder { return bson.ValueDecoderFunc(lenientBoolDecoder) @@ -154,7 +154,7 @@ func ExampleRegistry_customDecoder() { IsOK lenientBool `bson:"isOK"` } var doc MyDocument - err = bson.UnmarshalWithRegistry(reg.Build(), b, &doc) + err = bson.UnmarshalWithRegistry(rb.Build(), b, &doc) if err != nil { panic(err) } @@ -279,8 +279,8 @@ func ExampleRegistryBuilder_RegisterKindDecoder() { return nil } - reg := bson.NewRegistryBuilder() - reg.RegisterKindDecoder( + rb := bson.NewRegistryBuilder() + rb.RegisterKindDecoder( reflect.Int64, func() bson.ValueDecoder { return bson.ValueDecoderFunc(flexibleInt64KindDecoder) @@ -302,7 +302,7 @@ func ExampleRegistryBuilder_RegisterKindDecoder() { Int64 int64 } var doc myDocument - err = bson.UnmarshalWithRegistry(reg.Build(), b, &doc) + err = bson.UnmarshalWithRegistry(rb.Build(), b, &doc) if err != nil { panic(err) } diff --git a/bson/registry_option.go b/bson/registry_option.go index 5f07052f5c..cee6c9f7a0 100644 --- a/bson/registry_option.go +++ b/bson/registry_option.go @@ -107,8 +107,8 @@ var NilSliceAsEmpty = NewRegistryOpt(func(c *sliceCodec) error { return nil }) -// DecodeObjectIDAsHex causes the Decoder to unmarshal BSON ObjectID as a hexadecimal string. -var DecodeObjectIDAsHex = NewRegistryOpt(func(c *stringCodec) error { +// ObjectIDAsHex causes the Decoder to unmarshal BSON ObjectID as a hexadecimal string. +var ObjectIDAsHex = NewRegistryOpt(func(c *stringCodec) error { c.decodeObjectIDAsHex = true return nil }) diff --git a/bson/registry_test.go b/bson/registry_test.go index fd66e0cc84..3711cb770c 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -18,6 +18,7 @@ import ( // newTestRegistryBuilder creates a new empty RegistryBuilder. func newTestRegistryBuilder() *RegistryBuilder { return &RegistryBuilder{ + StructTagHandler: DefaultStructTagHandler, typeEncoders: make(map[reflect.Type]EncoderFactory), typeDecoders: make(map[reflect.Type]DecoderFactory), interfaceEncoders: make(map[reflect.Type]EncoderFactory), diff --git a/bson/struct_codec.go b/bson/struct_codec.go index 24ea1a2018..5155da984e 100644 --- a/bson/struct_codec.go +++ b/bson/struct_codec.go @@ -47,10 +47,17 @@ func (de *DecodeError) Keys() []string { return reversedKeys } +// StructTagHandler defines the struct encoder bahavior when inline, minSize and truncate tags are set. +type StructTagHandler struct { + InlineEncoder func(EncoderRegistry, DocumentWriter, reflect.Value, func(string) bool) error + LookupEncoderOnMinSize func(EncoderRegistry, reflect.Type) (ValueEncoder, error) + LookupEncoderOnTruncate func(EncoderRegistry, reflect.Type) (ValueEncoder, error) +} + // structCodec is the Codec used for struct values. type structCodec struct { - cache sync.Map // map[reflect.Type]*structDescription - parser structTagParser + cache sync.Map // map[reflect.Type]*structDescription + tagHndl StructTagHandler // decodeZeroStruct causes DecodeValue to delete any existing values from Go structs in the decodeZeroStruct bool @@ -76,9 +83,9 @@ type structCodec struct { } // newStructCodec returns a StructCodec that uses p for struct tag parsing. -func newStructCodec(p structTagParser) *structCodec { +func newStructCodec(hndl StructTagHandler) *structCodec { return &structCodec{ - parser: p, + tagHndl: hndl, overwriteDuplicatedInlinedFields: true, } } @@ -185,7 +192,13 @@ func (sc *structCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val refl return exists } - return (&mapCodec{}).mapEncodeValue(reg, dw, rv, collisionFn) + if sc.tagHndl.InlineEncoder == nil { + return errors.New("inline map encoder is not defined") + } + err = sc.tagHndl.InlineEncoder(reg, dw, rv, collisionFn) + if err != nil { + return err + } } return dw.WriteDocumentEnd() @@ -475,9 +488,9 @@ func (sc *structCodec) describeStructSlow( // If the caller requested that we use JSON struct tags, use the JSONFallbackStructTagParser // instead of the parser defined on the codec. if useJSONStructTags { - stags, err = sc.parser.parseJSONStructTags(sf) + stags, err = parseJSONStructTags(sf) } else { - stags, err = sc.parser.parseStructTags(sf) + stags, err = parseStructTags(sf) } if err != nil { return nil, err @@ -488,16 +501,16 @@ func (sc *structCodec) describeStructSlow( description.name = stags.Name description.omitEmpty = stags.OmitEmpty description.encoderLookup = func(reg EncoderRegistry, t reflect.Type) (ValueEncoder, error) { - if stags.LookupEncoderOnMinSize != nil { + if stags.MinSize && sc.tagHndl.LookupEncoderOnMinSize != nil { reg = &localEncoderRegistry{ registry: reg, - encoderLookup: stags.LookupEncoderOnMinSize.LookupEncoder, + encoderLookup: sc.tagHndl.LookupEncoderOnMinSize, } } - if stags.LookupEncoderOnTruncate != nil { + if stags.Truncate && sc.tagHndl.LookupEncoderOnTruncate != nil { reg = &localEncoderRegistry{ registry: reg, - encoderLookup: stags.LookupEncoderOnTruncate.LookupEncoder, + encoderLookup: sc.tagHndl.LookupEncoderOnTruncate, } } return reg.LookupEncoder(t) diff --git a/bson/struct_tag_parser.go b/bson/struct_tag_parser.go index 30b0e9815d..7cf8aecffe 100644 --- a/bson/struct_tag_parser.go +++ b/bson/struct_tag_parser.go @@ -11,56 +11,6 @@ import ( "strings" ) -// structTagParser returns the struct tags for a given reflect.StructField. -type structTagParser interface { - parseStructTags(reflect.StructField) (*structTags, error) - parseJSONStructTags(reflect.StructField) (*structTags, error) -} - -// DefaultStructTagParser is the StructTagParser used by the StructCodec by default. -var DefaultStructTagParser = &StructTagParser{ - LookupEncoderOnMinSize: retrieverOnMinSize{}, - LookupEncoderOnTruncate: retrieverOnTruncate{}, -} - -type retrieverOnMinSize struct{} - -func (retrieverOnMinSize) LookupEncoder(reg EncoderRegistry, t reflect.Type) (ValueEncoder, error) { - enc, err := reg.LookupEncoder(t) - if err != nil { - return enc, err - } - switch t.Kind() { - case reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: - if codec, ok := enc.(*numCodec); ok { - c := *codec - c.minSize = true - return &c, nil - } - } - return enc, nil -} - -type retrieverOnTruncate struct{} - -func (retrieverOnTruncate) LookupEncoder(reg EncoderRegistry, t reflect.Type) (ValueEncoder, error) { - enc, err := reg.LookupEncoder(t) - if err != nil { - return enc, err - } - switch t.Kind() { - case reflect.Float32, - reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, - reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - if codec, ok := enc.(*numCodec); ok { - c := *codec - c.truncate = true - return &c, nil - } - } - return enc, nil -} - // structTags represents the struct tag fields that the StructCodec uses during // the encoding and decoding process. // @@ -82,22 +32,10 @@ func (retrieverOnTruncate) LookupEncoder(reg EncoderRegistry, t reflect.Type) (V type structTags struct { Name string Inline bool + MinSize bool OmitEmpty bool Skip bool - - LookupEncoderOnMinSize EncoderRetriever - LookupEncoderOnTruncate EncoderRetriever -} - -// EncoderRetriever is used to look up ValueEncoder with given EncoderRegistry and reflect.Type. -type EncoderRetriever interface { - LookupEncoder(EncoderRegistry, reflect.Type) (ValueEncoder, error) -} - -// StructTagParser defines the encoder lookup bahavior when minSize and truncate tags are set. -type StructTagParser struct { - LookupEncoderOnMinSize EncoderRetriever - LookupEncoderOnTruncate EncoderRetriever + Truncate bool } // parseStructTags handles the bson struct tag. See the documentation for StructTags to see @@ -124,16 +62,31 @@ type StructTagParser struct { // A struct tag either consisting entirely of '-' or with a bson key with a // value consisting entirely of '-' will return a StructTags with Skip true and // the remaining fields will be their default values. -func (p *StructTagParser) parseStructTags(sf reflect.StructField) (*structTags, error) { +func parseStructTags(sf reflect.StructField) (*structTags, error) { key := strings.ToLower(sf.Name) tag, ok := sf.Tag.Lookup("bson") if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 { tag = string(sf.Tag) } - return p.parseTags(key, tag) + return parseTags(key, tag) } -func (p *StructTagParser) parseTags(key string, tag string) (*structTags, error) { +// parseJSONStructTags parses the json tag instead on a field where the +// bson tag isn't available. +func parseJSONStructTags(sf reflect.StructField) (*structTags, error) { + key := strings.ToLower(sf.Name) + tag, ok := sf.Tag.Lookup("bson") + if !ok { + tag, ok = sf.Tag.Lookup("json") + } + if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 { + tag = string(sf.Tag) + } + + return parseTags(key, tag) +} + +func parseTags(key string, tag string) (*structTags, error) { var st structTags if tag == "-" { st.Skip = true @@ -147,12 +100,12 @@ func (p *StructTagParser) parseTags(key string, tag string) (*structTags, error) switch str { case "inline": st.Inline = true + case "minsize": + st.MinSize = true case "omitempty": st.OmitEmpty = true - case "minsize": - st.LookupEncoderOnMinSize = p.LookupEncoderOnMinSize case "truncate": - st.LookupEncoderOnTruncate = p.LookupEncoderOnTruncate + st.Truncate = true } } @@ -160,18 +113,3 @@ func (p *StructTagParser) parseTags(key string, tag string) (*structTags, error) return &st, nil } - -// parseJSONStructTags parses the json tag instead on a field where the -// bson tag isn't available. -func (p *StructTagParser) parseJSONStructTags(sf reflect.StructField) (*structTags, error) { - key := strings.ToLower(sf.Name) - tag, ok := sf.Tag.Lookup("bson") - if !ok { - tag, ok = sf.Tag.Lookup("json") - } - if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 { - tag = string(sf.Tag) - } - - return p.parseTags(key, tag) -} diff --git a/bson/struct_tag_parser_test.go b/bson/struct_tag_parser_test.go index 312d1ab7f3..c34faec0b5 100644 --- a/bson/struct_tag_parser_test.go +++ b/bson/struct_tag_parser_test.go @@ -24,167 +24,127 @@ func TestStructTagParsers(t *testing.T) { "default no bson tag", reflect.StructField{Name: "foo", Tag: reflect.StructTag("bar")}, &structTags{Name: "bar"}, - DefaultStructTagParser.parseStructTags, + parseStructTags, }, { "default empty", reflect.StructField{Name: "foo", Tag: reflect.StructTag("")}, &structTags{Name: "foo"}, - DefaultStructTagParser.parseStructTags, + parseStructTags, }, { "default tag only dash", reflect.StructField{Name: "foo", Tag: reflect.StructTag("-")}, &structTags{Skip: true}, - DefaultStructTagParser.parseStructTags, + parseStructTags, }, { "default bson tag only dash", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"-"`)}, &structTags{Skip: true}, - DefaultStructTagParser.parseStructTags, + parseStructTags, }, { "default all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bar,omitempty,minsize,truncate,inline`)}, - &structTags{ - Name: "bar", Inline: true, OmitEmpty: true, - LookupEncoderOnMinSize: retrieverOnMinSize{}, - LookupEncoderOnTruncate: retrieverOnTruncate{}, - }, - DefaultStructTagParser.parseStructTags, + &structTags{Name: "bar", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseStructTags, }, { "default all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`,omitempty,minsize,truncate,inline`)}, - &structTags{ - Name: "foo", Inline: true, OmitEmpty: true, - LookupEncoderOnMinSize: retrieverOnMinSize{}, - LookupEncoderOnTruncate: retrieverOnTruncate{}, - }, - DefaultStructTagParser.parseStructTags, + &structTags{Name: "foo", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseStructTags, }, { "default bson tag all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar,omitempty,minsize,truncate,inline"`)}, - &structTags{ - Name: "bar", Inline: true, OmitEmpty: true, - LookupEncoderOnMinSize: retrieverOnMinSize{}, - LookupEncoderOnTruncate: retrieverOnTruncate{}, - }, - DefaultStructTagParser.parseStructTags, + &structTags{Name: "bar", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseStructTags, }, { "default bson tag all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:",omitempty,minsize,truncate,inline"`)}, - &structTags{ - Name: "foo", Inline: true, OmitEmpty: true, - LookupEncoderOnMinSize: retrieverOnMinSize{}, - LookupEncoderOnTruncate: retrieverOnTruncate{}, - }, - DefaultStructTagParser.parseStructTags, + &structTags{Name: "foo", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseStructTags, }, { "default ignore xml", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`xml:"bar"`)}, &structTags{Name: "foo"}, - DefaultStructTagParser.parseStructTags, + parseStructTags, }, { "JSONFallback no bson tag", reflect.StructField{Name: "foo", Tag: reflect.StructTag("bar")}, &structTags{Name: "bar"}, - DefaultStructTagParser.parseJSONStructTags, + parseJSONStructTags, }, { "JSONFallback empty", reflect.StructField{Name: "foo", Tag: reflect.StructTag("")}, &structTags{Name: "foo"}, - DefaultStructTagParser.parseJSONStructTags, + parseJSONStructTags, }, { "JSONFallback tag only dash", reflect.StructField{Name: "foo", Tag: reflect.StructTag("-")}, &structTags{Skip: true}, - DefaultStructTagParser.parseJSONStructTags, + parseJSONStructTags, }, { "JSONFallback bson tag only dash", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"-"`)}, &structTags{Skip: true}, - DefaultStructTagParser.parseJSONStructTags, + parseJSONStructTags, }, { "JSONFallback all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bar,omitempty,minsize,truncate,inline`)}, - &structTags{ - Name: "bar", Inline: true, OmitEmpty: true, - LookupEncoderOnMinSize: retrieverOnMinSize{}, - LookupEncoderOnTruncate: retrieverOnTruncate{}, - }, - DefaultStructTagParser.parseJSONStructTags, + &structTags{Name: "bar", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseJSONStructTags, }, { "JSONFallback all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`,omitempty,minsize,truncate,inline`)}, - &structTags{ - Name: "foo", Inline: true, OmitEmpty: true, - LookupEncoderOnMinSize: retrieverOnMinSize{}, - LookupEncoderOnTruncate: retrieverOnTruncate{}, - }, - DefaultStructTagParser.parseJSONStructTags, + &structTags{Name: "foo", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseJSONStructTags, }, { "JSONFallback bson tag all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar,omitempty,minsize,truncate,inline"`)}, - &structTags{ - Name: "bar", Inline: true, OmitEmpty: true, - LookupEncoderOnMinSize: retrieverOnMinSize{}, - LookupEncoderOnTruncate: retrieverOnTruncate{}, - }, - DefaultStructTagParser.parseJSONStructTags, + &structTags{Name: "bar", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseJSONStructTags, }, { "JSONFallback bson tag all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:",omitempty,minsize,truncate,inline"`)}, - &structTags{ - Name: "foo", Inline: true, OmitEmpty: true, - LookupEncoderOnMinSize: retrieverOnMinSize{}, - LookupEncoderOnTruncate: retrieverOnTruncate{}, - }, - DefaultStructTagParser.parseJSONStructTags, + &structTags{Name: "foo", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseJSONStructTags, }, { "JSONFallback json tag all options", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`json:"bar,omitempty,minsize,truncate,inline"`)}, - &structTags{ - Name: "bar", Inline: true, OmitEmpty: true, - LookupEncoderOnMinSize: retrieverOnMinSize{}, - LookupEncoderOnTruncate: retrieverOnTruncate{}, - }, - DefaultStructTagParser.parseJSONStructTags, + &structTags{Name: "bar", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseJSONStructTags, }, { "JSONFallback json tag all options default name", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`json:",omitempty,minsize,truncate,inline"`)}, - &structTags{ - Name: "foo", Inline: true, OmitEmpty: true, - LookupEncoderOnMinSize: retrieverOnMinSize{}, - LookupEncoderOnTruncate: retrieverOnTruncate{}, - }, - DefaultStructTagParser.parseJSONStructTags, + &structTags{Name: "foo", Inline: true, OmitEmpty: true, MinSize: true, Truncate: true}, + parseJSONStructTags, }, { "JSONFallback bson tag overrides other tags", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar" json:"qux,truncate"`)}, &structTags{Name: "bar"}, - DefaultStructTagParser.parseJSONStructTags, + parseJSONStructTags, }, { "JSONFallback ignore xml", reflect.StructField{Name: "foo", Tag: reflect.StructTag(`xml:"bar"`)}, &structTags{Name: "foo"}, - DefaultStructTagParser.parseJSONStructTags, + parseJSONStructTags, }, } diff --git a/bson/truncation_test.go b/bson/truncation_test.go index e0a1579494..04deb0efde 100644 --- a/bson/truncation_test.go +++ b/bson/truncation_test.go @@ -47,7 +47,7 @@ func TestTruncation(t *testing.T) { err = reg.SetCodecOption(opt) assert.Nil(t, err) - err = UnmarshalWithContext(reg, buf.Bytes(), &output) + err = UnmarshalWithRegistry(reg, buf.Bytes(), &output) assert.Nil(t, err) assert.Equal(t, inputName, output.Name) @@ -77,7 +77,7 @@ func TestTruncation(t *testing.T) { assert.Nil(t, err) // case throws an error when truncation is disabled - err = UnmarshalWithContext(reg, buf.Bytes(), &output) + err = UnmarshalWithRegistry(reg, buf.Bytes(), &output) assert.NotNil(t, err) }) } diff --git a/bson/unmarshal.go b/bson/unmarshal.go index 371d2dfc3d..6cead8048b 100644 --- a/bson/unmarshal.go +++ b/bson/unmarshal.go @@ -38,7 +38,7 @@ type ValueUnmarshaler interface { // pointed to by val. If val is nil or not a pointer, Unmarshal returns // InvalidUnmarshalError. func Unmarshal(data []byte, val interface{}) error { - return UnmarshalWithRegistry(NewRegistryBuilder().Build(), data, val) + return UnmarshalWithRegistry(defaultRegistry, data, val) } // UnmarshalWithRegistry parses the BSON-encoded data using Registry r and @@ -58,30 +58,11 @@ func UnmarshalWithRegistry(reg *Registry, data []byte, val interface{}) error { return NewDecoderWithRegistry(reg, vr).Decode(val) } -// UnmarshalWithContext parses the BSON-encoded data using DecodeContext dc and -// stores the result in the value pointed to by val. If val is nil or not -// a pointer, UnmarshalWithRegistry returns InvalidUnmarshalError. -// -// Deprecated: Use [NewDecoder] and use the Decoder configuration methods to set the desired unmarshal -// behavior instead: -// -// dec, err := bson.NewDecoder(NewBSONDocumentReader(data)) -// if err != nil { -// panic(err) -// } -// dec.DefaultDocumentM() -// -// See [Decoder] for more examples. -func UnmarshalWithContext(reg *Registry, data []byte, val interface{}) error { - vr := NewValueReader(data) - return NewDecoderWithRegistry(reg, vr).Decode(val) -} - // UnmarshalValue parses the BSON value of type t with default registry and // stores the result in the value pointed to by val. If val is nil or not a pointer, // UnmarshalValue returns an error. func UnmarshalValue(t Type, data []byte, val interface{}) error { - return UnmarshalValueWithRegistry(NewRegistryBuilder().Build(), t, data, val) + return UnmarshalValueWithRegistry(defaultRegistry, t, data, val) } // UnmarshalValueWithRegistry parses the BSON value of type t with registry r and @@ -99,7 +80,7 @@ func UnmarshalValueWithRegistry(reg *Registry, t Type, data []byte, val interfac // in the value pointed to by val. If val is nil or not a pointer, Unmarshal // returns InvalidUnmarshalError. func UnmarshalExtJSON(data []byte, canonical bool, val interface{}) error { - return UnmarshalExtJSONWithRegistry(NewRegistryBuilder().Build(), data, canonical, val) + return UnmarshalExtJSONWithRegistry(defaultRegistry, data, canonical, val) } // UnmarshalExtJSONWithRegistry parses the extended JSON-encoded data using diff --git a/bson/unmarshal_test.go b/bson/unmarshal_test.go index 1abf59fd48..c6ba38d760 100644 --- a/bson/unmarshal_test.go +++ b/bson/unmarshal_test.go @@ -48,29 +48,7 @@ func TestUnmarshalWithRegistry(t *testing.T) { // Assert that unmarshaling the input data results in the expected value. got := reflect.New(tc.sType).Interface() - err := UnmarshalWithRegistry(NewRegistryBuilder().Build(), data, got) - noerr(t, err) - assert.Equal(t, tc.want, got, "Did not unmarshal as expected.") - - // Fill the input data slice with random bytes and then assert that the result still - // matches the expected value. - _, err = rand.Read(data) - noerr(t, err) - assert.Equal(t, tc.want, got, "unmarshaled value does not match expected after modifying the input bytes") - }) - } -} - -func TestUnmarshalWithContext(t *testing.T) { - for _, tc := range unmarshalingTestCases() { - t.Run(tc.name, func(t *testing.T) { - // Make a copy of the test data so we can modify it later. - data := make([]byte, len(tc.data)) - copy(data, tc.data) - - // Assert that unmarshaling the input data results in the expected value. - got := reflect.New(tc.sType).Interface() - err := UnmarshalWithContext(NewRegistryBuilder().Build(), data, got) + err := UnmarshalWithRegistry(defaultRegistry, data, got) noerr(t, err) assert.Equal(t, tc.want, got, "Did not unmarshal as expected.") @@ -88,7 +66,7 @@ func TestUnmarshalExtJSONWithRegistry(t *testing.T) { type teststruct struct{ Foo int } var got teststruct data := []byte("{\"foo\":1}") - err := UnmarshalExtJSONWithRegistry(NewRegistryBuilder().Build(), data, true, &got) + err := UnmarshalExtJSONWithRegistry(defaultRegistry, data, true, &got) noerr(t, err) want := teststruct{1} assert.Equal(t, want, got, "Did not unmarshal as expected.") @@ -96,7 +74,7 @@ func TestUnmarshalExtJSONWithRegistry(t *testing.T) { t.Run("UnmarshalExtJSONInvalidInput", func(t *testing.T) { data := []byte("invalid") - err := UnmarshalExtJSONWithRegistry(NewRegistryBuilder().Build(), data, true, &M{}) + err := UnmarshalExtJSONWithRegistry(defaultRegistry, data, true, &M{}) if !errors.Is(err, ErrInvalidJSON) { t.Fatalf("wanted ErrInvalidJSON, got %v", err) } @@ -198,7 +176,7 @@ func TestUnmarshalExtJSONWithContext(t *testing.T) { // Assert that unmarshaling the input data results in the expected value. got := reflect.New(tc.sType).Interface() - err := UnmarshalExtJSONWithContext(NewRegistryBuilder().Build(), data, true, got) + err := UnmarshalExtJSONWithContext(defaultRegistry, data, true, got) noerr(t, err) assert.Equal(t, tc.want, got, "Did not unmarshal as expected.") diff --git a/bson/unmarshal_value_test.go b/bson/unmarshal_value_test.go index a4d27eca01..c91eece0a9 100644 --- a/bson/unmarshal_value_test.go +++ b/bson/unmarshal_value_test.go @@ -46,7 +46,7 @@ func TestUnmarshalValue(t *testing.T) { t.Parallel() gotValue := reflect.New(reflect.TypeOf(tc.val)) - err := UnmarshalValueWithRegistry(NewRegistryBuilder().Build(), tc.bsontype, tc.bytes, gotValue.Interface()) + err := UnmarshalValueWithRegistry(defaultRegistry, tc.bsontype, tc.bytes, gotValue.Interface()) assert.Nil(t, err, "UnmarshalValueWithRegistry error: %v", err) assert.Equal(t, tc.val, gotValue.Elem().Interface(), "value mismatch; expected %s, got %s", tc.val, gotValue.Elem()) }) From c15afa8c7b0c45574674f32bf09551b960b68021 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Sun, 26 May 2024 12:04:00 -0400 Subject: [PATCH 15/15] update struct codec --- bson/bson_test.go | 2 +- bson/default_value_decoders.go | 64 ++++++++++---------- bson/default_value_decoders_test.go | 10 ++-- bson/default_value_encoders.go | 69 +++++++++++---------- bson/map_codec.go | 6 +- bson/marshal_test.go | 2 +- bson/mgoregistry.go | 56 +++++++++-------- bson/primitive_codecs.go | 8 +-- bson/registry.go | 70 ++-------------------- bson/registry_examples_test.go | 8 +-- bson/registry_test.go | 73 +++++++++++------------ bson/struct_codec.go | 63 +++++++++++++------ bson/unmarshal_test.go | 2 +- bson/unmarshal_value_test.go | 4 +- internal/integration/client_test.go | 4 +- internal/integration/unified_spec_test.go | 2 +- 16 files changed, 211 insertions(+), 232 deletions(-) diff --git a/bson/bson_test.go b/bson/bson_test.go index 54e38c6c0b..e54b4dc865 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -359,7 +359,7 @@ func TestMapCodec(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { mapRegistry := NewRegistryBuilder() - mapRegistry.RegisterKindEncoder(reflect.Map, func() ValueEncoder { return tc.codec }) + mapRegistry.RegisterKindEncoder(reflect.Map, func(*Registry) ValueEncoder { return tc.codec }) buf := new(bytes.Buffer) vw := NewValueWriter(buf) enc := NewEncoderWithRegistry(mapRegistry.Build(), vw) diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index e56b3d2faa..bd49385c7b 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -40,31 +40,31 @@ func registerDefaultDecoders(rb *RegistryBuilder) { panic(errors.New("argument to RegisterDefaultDecoders must not be nil")) } - numDecoder := func() ValueDecoder { return &numCodec{} } + numDecoder := func(*Registry) ValueDecoder { return &numCodec{} } rb. - RegisterTypeDecoder(tD, func() ValueDecoder { return ValueDecoderFunc(dDecodeValue) }). - RegisterTypeDecoder(tBinary, func() ValueDecoder { return &decodeAdapter{binaryDecodeValue, binaryDecodeType} }). - RegisterTypeDecoder(tUndefined, func() ValueDecoder { return &decodeAdapter{undefinedDecodeValue, undefinedDecodeType} }). - RegisterTypeDecoder(tDateTime, func() ValueDecoder { return &decodeAdapter{dateTimeDecodeValue, dateTimeDecodeType} }). - RegisterTypeDecoder(tNull, func() ValueDecoder { return &decodeAdapter{nullDecodeValue, nullDecodeType} }). - RegisterTypeDecoder(tRegex, func() ValueDecoder { return &decodeAdapter{regexDecodeValue, regexDecodeType} }). - RegisterTypeDecoder(tDBPointer, func() ValueDecoder { return &decodeAdapter{dbPointerDecodeValue, dbPointerDecodeType} }). - RegisterTypeDecoder(tTimestamp, func() ValueDecoder { return &decodeAdapter{timestampDecodeValue, timestampDecodeType} }). - RegisterTypeDecoder(tMinKey, func() ValueDecoder { return &decodeAdapter{minKeyDecodeValue, minKeyDecodeType} }). - RegisterTypeDecoder(tMaxKey, func() ValueDecoder { return &decodeAdapter{maxKeyDecodeValue, maxKeyDecodeType} }). - RegisterTypeDecoder(tJavaScript, func() ValueDecoder { return &decodeAdapter{javaScriptDecodeValue, javaScriptDecodeType} }). - RegisterTypeDecoder(tSymbol, func() ValueDecoder { return &decodeAdapter{symbolDecodeValue, symbolDecodeType} }). - RegisterTypeDecoder(tByteSlice, func() ValueDecoder { return &byteSliceCodec{} }). - RegisterTypeDecoder(tTime, func() ValueDecoder { return &timeCodec{} }). - RegisterTypeDecoder(tEmpty, func() ValueDecoder { return &emptyInterfaceCodec{} }). - RegisterTypeDecoder(tCoreArray, func() ValueDecoder { return &arrayCodec{} }). - RegisterTypeDecoder(tOID, func() ValueDecoder { return &decodeAdapter{objectIDDecodeValue, objectIDDecodeType} }). - RegisterTypeDecoder(tDecimal, func() ValueDecoder { return &decodeAdapter{decimal128DecodeValue, decimal128DecodeType} }). - RegisterTypeDecoder(tJSONNumber, func() ValueDecoder { return &decodeAdapter{jsonNumberDecodeValue, jsonNumberDecodeType} }). - RegisterTypeDecoder(tURL, func() ValueDecoder { return &decodeAdapter{urlDecodeValue, urlDecodeType} }). - RegisterTypeDecoder(tCoreDocument, func() ValueDecoder { return ValueDecoderFunc(coreDocumentDecodeValue) }). - RegisterTypeDecoder(tCodeWithScope, func() ValueDecoder { return &decodeAdapter{codeWithScopeDecodeValue, codeWithScopeDecodeType} }). - RegisterKindDecoder(reflect.Bool, func() ValueDecoder { return &decodeAdapter{booleanDecodeValue, booleanDecodeType} }). + RegisterTypeDecoder(tD, func(*Registry) ValueDecoder { return ValueDecoderFunc(dDecodeValue) }). + RegisterTypeDecoder(tBinary, func(*Registry) ValueDecoder { return &decodeAdapter{binaryDecodeValue, binaryDecodeType} }). + RegisterTypeDecoder(tUndefined, func(*Registry) ValueDecoder { return &decodeAdapter{undefinedDecodeValue, undefinedDecodeType} }). + RegisterTypeDecoder(tDateTime, func(*Registry) ValueDecoder { return &decodeAdapter{dateTimeDecodeValue, dateTimeDecodeType} }). + RegisterTypeDecoder(tNull, func(*Registry) ValueDecoder { return &decodeAdapter{nullDecodeValue, nullDecodeType} }). + RegisterTypeDecoder(tRegex, func(*Registry) ValueDecoder { return &decodeAdapter{regexDecodeValue, regexDecodeType} }). + RegisterTypeDecoder(tDBPointer, func(*Registry) ValueDecoder { return &decodeAdapter{dbPointerDecodeValue, dbPointerDecodeType} }). + RegisterTypeDecoder(tTimestamp, func(*Registry) ValueDecoder { return &decodeAdapter{timestampDecodeValue, timestampDecodeType} }). + RegisterTypeDecoder(tMinKey, func(*Registry) ValueDecoder { return &decodeAdapter{minKeyDecodeValue, minKeyDecodeType} }). + RegisterTypeDecoder(tMaxKey, func(*Registry) ValueDecoder { return &decodeAdapter{maxKeyDecodeValue, maxKeyDecodeType} }). + RegisterTypeDecoder(tJavaScript, func(*Registry) ValueDecoder { return &decodeAdapter{javaScriptDecodeValue, javaScriptDecodeType} }). + RegisterTypeDecoder(tSymbol, func(*Registry) ValueDecoder { return &decodeAdapter{symbolDecodeValue, symbolDecodeType} }). + RegisterTypeDecoder(tByteSlice, func(*Registry) ValueDecoder { return &byteSliceCodec{} }). + RegisterTypeDecoder(tTime, func(*Registry) ValueDecoder { return &timeCodec{} }). + RegisterTypeDecoder(tEmpty, func(*Registry) ValueDecoder { return &emptyInterfaceCodec{} }). + RegisterTypeDecoder(tCoreArray, func(*Registry) ValueDecoder { return &arrayCodec{} }). + RegisterTypeDecoder(tOID, func(*Registry) ValueDecoder { return &decodeAdapter{objectIDDecodeValue, objectIDDecodeType} }). + RegisterTypeDecoder(tDecimal, func(*Registry) ValueDecoder { return &decodeAdapter{decimal128DecodeValue, decimal128DecodeType} }). + RegisterTypeDecoder(tJSONNumber, func(*Registry) ValueDecoder { return &decodeAdapter{jsonNumberDecodeValue, jsonNumberDecodeType} }). + RegisterTypeDecoder(tURL, func(*Registry) ValueDecoder { return &decodeAdapter{urlDecodeValue, urlDecodeType} }). + RegisterTypeDecoder(tCoreDocument, func(*Registry) ValueDecoder { return ValueDecoderFunc(coreDocumentDecodeValue) }). + RegisterTypeDecoder(tCodeWithScope, func(*Registry) ValueDecoder { return &decodeAdapter{codeWithScopeDecodeValue, codeWithScopeDecodeType} }). + RegisterKindDecoder(reflect.Bool, func(*Registry) ValueDecoder { return &decodeAdapter{booleanDecodeValue, booleanDecodeType} }). RegisterKindDecoder(reflect.Int, numDecoder). RegisterKindDecoder(reflect.Int8, numDecoder). RegisterKindDecoder(reflect.Int16, numDecoder). @@ -77,12 +77,12 @@ func registerDefaultDecoders(rb *RegistryBuilder) { RegisterKindDecoder(reflect.Uint64, numDecoder). RegisterKindDecoder(reflect.Float32, numDecoder). RegisterKindDecoder(reflect.Float64, numDecoder). - RegisterKindDecoder(reflect.Array, func() ValueDecoder { return ValueDecoderFunc(arrayDecodeValue) }). - RegisterKindDecoder(reflect.Map, func() ValueDecoder { return &mapCodec{} }). - RegisterKindDecoder(reflect.Slice, func() ValueDecoder { return &sliceCodec{} }). - RegisterKindDecoder(reflect.String, func() ValueDecoder { return &stringCodec{} }). - RegisterKindDecoder(reflect.Struct, func() ValueDecoder { return newStructCodec(rb.StructTagHandler()) }). - RegisterKindDecoder(reflect.Ptr, func() ValueDecoder { return &pointerCodec{} }). + RegisterKindDecoder(reflect.Array, func(*Registry) ValueDecoder { return ValueDecoderFunc(arrayDecodeValue) }). + RegisterKindDecoder(reflect.Map, func(*Registry) ValueDecoder { return &mapCodec{} }). + RegisterKindDecoder(reflect.Slice, func(*Registry) ValueDecoder { return &sliceCodec{} }). + RegisterKindDecoder(reflect.String, func(*Registry) ValueDecoder { return &stringCodec{} }). + RegisterKindDecoder(reflect.Struct, func(*Registry) ValueDecoder { return newStructCodec(nil) }). + RegisterKindDecoder(reflect.Ptr, func(*Registry) ValueDecoder { return &pointerCodec{} }). RegisterTypeMapEntry(TypeDouble, tFloat64). RegisterTypeMapEntry(TypeString, tString). RegisterTypeMapEntry(TypeArray, tA). @@ -104,8 +104,8 @@ func registerDefaultDecoders(rb *RegistryBuilder) { RegisterTypeMapEntry(TypeMaxKey, tMaxKey). RegisterTypeMapEntry(Type(0), tD). RegisterTypeMapEntry(TypeEmbeddedDocument, tD). - RegisterInterfaceDecoder(tValueUnmarshaler, func() ValueDecoder { return ValueDecoderFunc(valueUnmarshalerDecodeValue) }). - RegisterInterfaceDecoder(tUnmarshaler, func() ValueDecoder { return ValueDecoderFunc(unmarshalerDecodeValue) }) + RegisterInterfaceDecoder(tValueUnmarshaler, func(*Registry) ValueDecoder { return ValueDecoderFunc(valueUnmarshalerDecodeValue) }). + RegisterInterfaceDecoder(tUnmarshaler, func(*Registry) ValueDecoder { return ValueDecoderFunc(unmarshalerDecodeValue) }) } // dDecodeValue is the ValueDecoderFunc for D instances. diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index d07df571ec..50b1a668c2 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -23,7 +23,7 @@ import ( ) var ( - defaultTestStructCodec = newStructCodec(DefaultStructTagHandler()) + defaultTestStructCodec = newStructCodec(nil) ) func TestDefaultValueDecoders(t *testing.T) { @@ -3417,7 +3417,7 @@ func TestDefaultValueDecoders(t *testing.T) { t.Skip() } want := errors.New("DecodeValue failure error") - llc := func() ValueDecoder { return &llCodec{t: t, err: want} } + llc := func(*Registry) ValueDecoder { return &llCodec{t: t, err: want} } reg := newTestRegistryBuilder(). RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). @@ -3430,7 +3430,7 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run("Success", func(t *testing.T) { want := tc.val - llc := func() ValueDecoder { return &llCodec{t: t, decodeval: tc.val} } + llc := func(*Registry) ValueDecoder { return &llCodec{t: t, decodeval: tc.val} } reg := newTestRegistryBuilder(). RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). @@ -3575,7 +3575,7 @@ func TestDefaultValueDecoders(t *testing.T) { return decodeValueError } emptyInterfaceErrorRegistry := newTestRegistryBuilder(). - RegisterTypeDecoder(tEmpty, func() ValueDecoder { return ValueDecoderFunc(emptyInterfaceErrorDecode) }). + RegisterTypeDecoder(tEmpty, func(*Registry) ValueDecoder { return ValueDecoderFunc(emptyInterfaceErrorDecode) }). Build() // Set up a document {foo: 10} and an error that would happen if the value were decoded into interface{} @@ -3630,7 +3630,7 @@ func TestDefaultValueDecoders(t *testing.T) { // Use a registry that has all default decoders with the custom interface{} decoder that always errors. nestedRegistryBuilder := newTestRegistryBuilder() registerDefaultDecoders(nestedRegistryBuilder) - nestedRegistryBuilder.RegisterTypeDecoder(tEmpty, func() ValueDecoder { return ValueDecoderFunc(emptyInterfaceErrorDecode) }) + nestedRegistryBuilder.RegisterTypeDecoder(tEmpty, func(*Registry) ValueDecoder { return ValueDecoderFunc(emptyInterfaceErrorDecode) }) nestedErr := &DecodeError{ keys: []string{"fourth", "1", "third", "randomKey", "second", "first"}, wrapped: decodeValueError, diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index 12a1eb1412..c9eb3fbe08 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -55,30 +55,30 @@ func registerDefaultEncoders(rb *RegistryBuilder) { panic(errors.New("argument to RegisterDefaultEncoders must not be nil")) } - numEncoder := func() ValueEncoder { return &numCodec{} } + numEncoder := func(*Registry) ValueEncoder { return &numCodec{} } rb. - RegisterTypeEncoder(tByteSlice, func() ValueEncoder { return &byteSliceCodec{} }). - RegisterTypeEncoder(tTime, func() ValueEncoder { return &timeCodec{} }). - RegisterTypeEncoder(tEmpty, func() ValueEncoder { return &emptyInterfaceCodec{} }). - RegisterTypeEncoder(tCoreArray, func() ValueEncoder { return &arrayCodec{} }). - RegisterTypeEncoder(tOID, func() ValueEncoder { return ValueEncoderFunc(objectIDEncodeValue) }). - RegisterTypeEncoder(tDecimal, func() ValueEncoder { return ValueEncoderFunc(decimal128EncodeValue) }). - RegisterTypeEncoder(tJSONNumber, func() ValueEncoder { return ValueEncoderFunc(jsonNumberEncodeValue) }). - RegisterTypeEncoder(tURL, func() ValueEncoder { return ValueEncoderFunc(urlEncodeValue) }). - RegisterTypeEncoder(tJavaScript, func() ValueEncoder { return ValueEncoderFunc(javaScriptEncodeValue) }). - RegisterTypeEncoder(tSymbol, func() ValueEncoder { return ValueEncoderFunc(symbolEncodeValue) }). - RegisterTypeEncoder(tBinary, func() ValueEncoder { return ValueEncoderFunc(binaryEncodeValue) }). - RegisterTypeEncoder(tUndefined, func() ValueEncoder { return ValueEncoderFunc(undefinedEncodeValue) }). - RegisterTypeEncoder(tDateTime, func() ValueEncoder { return ValueEncoderFunc(dateTimeEncodeValue) }). - RegisterTypeEncoder(tNull, func() ValueEncoder { return ValueEncoderFunc(nullEncodeValue) }). - RegisterTypeEncoder(tRegex, func() ValueEncoder { return ValueEncoderFunc(regexEncodeValue) }). - RegisterTypeEncoder(tDBPointer, func() ValueEncoder { return ValueEncoderFunc(dbPointerEncodeValue) }). - RegisterTypeEncoder(tTimestamp, func() ValueEncoder { return ValueEncoderFunc(timestampEncodeValue) }). - RegisterTypeEncoder(tMinKey, func() ValueEncoder { return ValueEncoderFunc(minKeyEncodeValue) }). - RegisterTypeEncoder(tMaxKey, func() ValueEncoder { return ValueEncoderFunc(maxKeyEncodeValue) }). - RegisterTypeEncoder(tCoreDocument, func() ValueEncoder { return ValueEncoderFunc(coreDocumentEncodeValue) }). - RegisterTypeEncoder(tCodeWithScope, func() ValueEncoder { return ValueEncoderFunc(codeWithScopeEncodeValue) }). - RegisterKindEncoder(reflect.Bool, func() ValueEncoder { return ValueEncoderFunc(booleanEncodeValue) }). + RegisterTypeEncoder(tByteSlice, func(*Registry) ValueEncoder { return &byteSliceCodec{} }). + RegisterTypeEncoder(tTime, func(*Registry) ValueEncoder { return &timeCodec{} }). + RegisterTypeEncoder(tEmpty, func(*Registry) ValueEncoder { return &emptyInterfaceCodec{} }). + RegisterTypeEncoder(tCoreArray, func(*Registry) ValueEncoder { return &arrayCodec{} }). + RegisterTypeEncoder(tOID, func(*Registry) ValueEncoder { return ValueEncoderFunc(objectIDEncodeValue) }). + RegisterTypeEncoder(tDecimal, func(*Registry) ValueEncoder { return ValueEncoderFunc(decimal128EncodeValue) }). + RegisterTypeEncoder(tJSONNumber, func(*Registry) ValueEncoder { return ValueEncoderFunc(jsonNumberEncodeValue) }). + RegisterTypeEncoder(tURL, func(*Registry) ValueEncoder { return ValueEncoderFunc(urlEncodeValue) }). + RegisterTypeEncoder(tJavaScript, func(*Registry) ValueEncoder { return ValueEncoderFunc(javaScriptEncodeValue) }). + RegisterTypeEncoder(tSymbol, func(*Registry) ValueEncoder { return ValueEncoderFunc(symbolEncodeValue) }). + RegisterTypeEncoder(tBinary, func(*Registry) ValueEncoder { return ValueEncoderFunc(binaryEncodeValue) }). + RegisterTypeEncoder(tUndefined, func(*Registry) ValueEncoder { return ValueEncoderFunc(undefinedEncodeValue) }). + RegisterTypeEncoder(tDateTime, func(*Registry) ValueEncoder { return ValueEncoderFunc(dateTimeEncodeValue) }). + RegisterTypeEncoder(tNull, func(*Registry) ValueEncoder { return ValueEncoderFunc(nullEncodeValue) }). + RegisterTypeEncoder(tRegex, func(*Registry) ValueEncoder { return ValueEncoderFunc(regexEncodeValue) }). + RegisterTypeEncoder(tDBPointer, func(*Registry) ValueEncoder { return ValueEncoderFunc(dbPointerEncodeValue) }). + RegisterTypeEncoder(tTimestamp, func(*Registry) ValueEncoder { return ValueEncoderFunc(timestampEncodeValue) }). + RegisterTypeEncoder(tMinKey, func(*Registry) ValueEncoder { return ValueEncoderFunc(minKeyEncodeValue) }). + RegisterTypeEncoder(tMaxKey, func(*Registry) ValueEncoder { return ValueEncoderFunc(maxKeyEncodeValue) }). + RegisterTypeEncoder(tCoreDocument, func(*Registry) ValueEncoder { return ValueEncoderFunc(coreDocumentEncodeValue) }). + RegisterTypeEncoder(tCodeWithScope, func(*Registry) ValueEncoder { return ValueEncoderFunc(codeWithScopeEncodeValue) }). + RegisterKindEncoder(reflect.Bool, func(*Registry) ValueEncoder { return ValueEncoderFunc(booleanEncodeValue) }). RegisterKindEncoder(reflect.Int, numEncoder). RegisterKindEncoder(reflect.Int8, numEncoder). RegisterKindEncoder(reflect.Int16, numEncoder). @@ -91,15 +91,20 @@ func registerDefaultEncoders(rb *RegistryBuilder) { RegisterKindEncoder(reflect.Uint64, numEncoder). RegisterKindEncoder(reflect.Float32, numEncoder). RegisterKindEncoder(reflect.Float64, numEncoder). - RegisterKindEncoder(reflect.Array, func() ValueEncoder { return ValueEncoderFunc(arrayEncodeValue) }). - RegisterKindEncoder(reflect.Map, func() ValueEncoder { return &mapCodec{} }). - RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return &sliceCodec{} }). - RegisterKindEncoder(reflect.String, func() ValueEncoder { return &stringCodec{} }). - RegisterKindEncoder(reflect.Struct, func() ValueEncoder { return newStructCodec(rb.StructTagHandler()) }). - RegisterKindEncoder(reflect.Ptr, func() ValueEncoder { return &pointerCodec{} }). - RegisterInterfaceEncoder(tValueMarshaler, func() ValueEncoder { return ValueEncoderFunc(valueMarshalerEncodeValue) }). - RegisterInterfaceEncoder(tMarshaler, func() ValueEncoder { return ValueEncoderFunc(marshalerEncodeValue) }). - RegisterInterfaceEncoder(tProxy, func() ValueEncoder { return ValueEncoderFunc(proxyEncodeValue) }) + RegisterKindEncoder(reflect.Array, func(*Registry) ValueEncoder { return ValueEncoderFunc(arrayEncodeValue) }). + RegisterKindEncoder(reflect.Map, func(*Registry) ValueEncoder { return &mapCodec{} }). + RegisterKindEncoder(reflect.Slice, func(*Registry) ValueEncoder { return &sliceCodec{} }). + RegisterKindEncoder(reflect.String, func(*Registry) ValueEncoder { return &stringCodec{} }). + RegisterKindEncoder(reflect.Struct, func(reg *Registry) ValueEncoder { + // reflect.Struct is 25 that is bigger than reflect.Map, 21, in the kind array, + // so Map will be registered earlier than Struct. + enc, _ := reg.lookupKindEncoder(reflect.Map) + return newStructCodec(enc.(mapElementsEncoder)) + }). + RegisterKindEncoder(reflect.Ptr, func(*Registry) ValueEncoder { return &pointerCodec{} }). + RegisterInterfaceEncoder(tValueMarshaler, func(*Registry) ValueEncoder { return ValueEncoderFunc(valueMarshalerEncodeValue) }). + RegisterInterfaceEncoder(tMarshaler, func(*Registry) ValueEncoder { return ValueEncoderFunc(marshalerEncodeValue) }). + RegisterInterfaceEncoder(tProxy, func(*Registry) ValueEncoder { return ValueEncoderFunc(proxyEncodeValue) }) } // booleanEncodeValue is the ValueEncoderFunc for bool types. diff --git a/bson/map_codec.go b/bson/map_codec.go index db43347722..0089c75717 100644 --- a/bson/map_codec.go +++ b/bson/map_codec.go @@ -68,17 +68,17 @@ func (mc *mapCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect return err } - err = mc.mapEncodeValue(reg, dw, val, nil) + err = mc.encodeMapElements(reg, dw, val, nil) if err != nil { return err } return dw.WriteDocumentEnd() } -// mapEncodeValue handles encoding of the values of a map. The collisionFn returns +// encodeMapElements handles encoding of the values of a map. The collisionFn returns // true if the provided key exists, this is mainly used for inline maps in the // struct codec. -func (mc *mapCodec) mapEncodeValue(reg EncoderRegistry, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { +func (mc *mapCodec) encodeMapElements(reg EncoderRegistry, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { elemType := val.Type().Elem() encoder, err := reg.LookupEncoder(elemType) diff --git a/bson/marshal_test.go b/bson/marshal_test.go index 1edf66e33b..93787338a2 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -131,7 +131,7 @@ func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) { return vw.WriteInt32(int32(val.Int()) * -1) } customReg := NewRegistryBuilder(). - RegisterTypeEncoder(tInt32, func() ValueEncoder { return encodeInt32 }). + RegisterTypeEncoder(tInt32, func(*Registry) ValueEncoder { return encodeInt32 }). Build() // Helper function to run the test and make assertions. The provided original value should result in the document diff --git a/bson/mgoregistry.go b/bson/mgoregistry.go index 0b7af8dda8..1aa96380cd 100644 --- a/bson/mgoregistry.go +++ b/bson/mgoregistry.go @@ -23,40 +23,44 @@ var ( ) func newMgoRegistryBuilder() *RegistryBuilder { - structcodec := &structCodec{ - tagHndl: DefaultStructTagHandler(), - decodeZeroStruct: true, - encodeOmitDefaultStruct: true, - allowUnexportedFields: true, - } mapCodec := &mapCodec{ decodeZerosMap: true, encodeNilAsEmpty: true, encodeKeysWithStringer: true, } - numcodec := func() ValueEncoder { return &numCodec{encodeUintToMinSize: true} } + newStructCodec := func(elemEncoder mapElementsEncoder) *structCodec { + return &structCodec{ + elemEncoder: elemEncoder, + decodeZeroStruct: true, + encodeOmitDefaultStruct: true, + allowUnexportedFields: true, + } + } + numcodecFac := func(*Registry) ValueEncoder { return &numCodec{encodeUintToMinSize: true} } return NewRegistryBuilder(). - RegisterTypeDecoder(tEmpty, func() ValueDecoder { return &emptyInterfaceCodec{decodeBinaryAsSlice: true} }). - RegisterKindDecoder(reflect.String, func() ValueDecoder { return &stringCodec{} }). - RegisterKindDecoder(reflect.Struct, func() ValueDecoder { return structcodec }). - RegisterKindDecoder(reflect.Map, func() ValueDecoder { return mapCodec }). - RegisterTypeEncoder(tByteSlice, func() ValueEncoder { return &byteSliceCodec{encodeNilAsEmpty: true} }). - RegisterKindEncoder(reflect.Struct, func() ValueEncoder { return structcodec }). - RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return &sliceCodec{encodeNilAsEmpty: true} }). - RegisterKindEncoder(reflect.Map, func() ValueEncoder { return mapCodec }). - RegisterKindEncoder(reflect.Uint, numcodec). - RegisterKindEncoder(reflect.Uint8, numcodec). - RegisterKindEncoder(reflect.Uint16, numcodec). - RegisterKindEncoder(reflect.Uint32, numcodec). - RegisterKindEncoder(reflect.Uint64, numcodec). + RegisterTypeDecoder(tEmpty, func(*Registry) ValueDecoder { return &emptyInterfaceCodec{decodeBinaryAsSlice: true} }). + RegisterKindDecoder(reflect.Struct, func(*Registry) ValueDecoder { return newStructCodec(nil) }). + RegisterKindDecoder(reflect.Map, func(*Registry) ValueDecoder { return mapCodec }). + RegisterTypeEncoder(tByteSlice, func(*Registry) ValueEncoder { return &byteSliceCodec{encodeNilAsEmpty: true} }). + RegisterKindEncoder(reflect.Struct, func(reg *Registry) ValueEncoder { + enc, _ := reg.lookupKindEncoder(reflect.Map) + return newStructCodec(enc.(mapElementsEncoder)) + }). + RegisterKindEncoder(reflect.Slice, func(*Registry) ValueEncoder { return &sliceCodec{encodeNilAsEmpty: true} }). + RegisterKindEncoder(reflect.Map, func(*Registry) ValueEncoder { return mapCodec }). + RegisterKindEncoder(reflect.Uint, numcodecFac). + RegisterKindEncoder(reflect.Uint8, numcodecFac). + RegisterKindEncoder(reflect.Uint16, numcodecFac). + RegisterKindEncoder(reflect.Uint32, numcodecFac). + RegisterKindEncoder(reflect.Uint64, numcodecFac). RegisterTypeMapEntry(TypeInt32, tInt). RegisterTypeMapEntry(TypeDateTime, tTime). RegisterTypeMapEntry(TypeArray, tInterfaceSlice). RegisterTypeMapEntry(Type(0), tM). RegisterTypeMapEntry(TypeEmbeddedDocument, tM). - RegisterInterfaceEncoder(tGetter, func() ValueEncoder { return ValueEncoderFunc(GetterEncodeValue) }). - RegisterInterfaceDecoder(tSetter, func() ValueDecoder { return ValueDecoderFunc(SetterDecodeValue) }) + RegisterInterfaceEncoder(tGetter, func(*Registry) ValueEncoder { return ValueEncoderFunc(GetterEncodeValue) }). + RegisterInterfaceDecoder(tSetter, func(*Registry) ValueDecoder { return ValueDecoderFunc(SetterDecodeValue) }) } // NewMgoRegistry creates a new bson.Registry configured with the default encoders and decoders. @@ -72,9 +76,9 @@ func NewRespectNilValuesMgoRegistry() *Registry { } return newMgoRegistryBuilder(). - RegisterKindDecoder(reflect.Map, func() ValueDecoder { return mapCodec }). - RegisterTypeEncoder(tByteSlice, func() ValueEncoder { return &byteSliceCodec{encodeNilAsEmpty: false} }). - RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return &sliceCodec{} }). - RegisterKindEncoder(reflect.Map, func() ValueEncoder { return mapCodec }). + RegisterKindDecoder(reflect.Map, func(*Registry) ValueDecoder { return mapCodec }). + RegisterTypeEncoder(tByteSlice, func(*Registry) ValueEncoder { return &byteSliceCodec{encodeNilAsEmpty: false} }). + RegisterKindEncoder(reflect.Slice, func(*Registry) ValueEncoder { return &sliceCodec{} }). + RegisterKindEncoder(reflect.Map, func(*Registry) ValueEncoder { return mapCodec }). Build() } diff --git a/bson/primitive_codecs.go b/bson/primitive_codecs.go index b8cdc8a94f..2cf68ddf85 100644 --- a/bson/primitive_codecs.go +++ b/bson/primitive_codecs.go @@ -22,10 +22,10 @@ func registerPrimitiveCodecs(rb *RegistryBuilder) { } rb. - RegisterTypeEncoder(tRawValue, func() ValueEncoder { return ValueEncoderFunc(rawValueEncodeValue) }). - RegisterTypeEncoder(tRaw, func() ValueEncoder { return ValueEncoderFunc(rawEncodeValue) }). - RegisterTypeDecoder(tRawValue, func() ValueDecoder { return ValueDecoderFunc(rawValueDecodeValue) }). - RegisterTypeDecoder(tRaw, func() ValueDecoder { return ValueDecoderFunc(rawDecodeValue) }) + RegisterTypeEncoder(tRawValue, func(*Registry) ValueEncoder { return ValueEncoderFunc(rawValueEncodeValue) }). + RegisterTypeEncoder(tRaw, func(*Registry) ValueEncoder { return ValueEncoderFunc(rawEncodeValue) }). + RegisterTypeDecoder(tRawValue, func(*Registry) ValueDecoder { return ValueDecoderFunc(rawValueDecodeValue) }). + RegisterTypeDecoder(tRaw, func(*Registry) ValueDecoder { return ValueDecoderFunc(rawDecodeValue) }) } // rawValueEncodeValue is the ValueEncoderFunc for RawValue. diff --git a/bson/registry.go b/bson/registry.go index 9d7b1c52dc..740a2bb665 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -56,63 +56,15 @@ func (entme ErrNoTypeMapEntry) Error() string { return "no type map entry found for " + entme.Type.String() } -// EncoderFactory is a factory function that generates a new ValueEncoder. -type EncoderFactory func() ValueEncoder +// EncoderFactory is an idempotent factory function that generates a new ValueEncoder. +type EncoderFactory func(*Registry) ValueEncoder -// DecoderFactory is a factory function that generates a new ValueDecoder. -type DecoderFactory func() ValueDecoder - -func inlineEncoder(reg EncoderRegistry, w DocumentWriter, v reflect.Value, collisionFn func(string) bool) error { - enc, err := reg.LookupEncoder(v.Type()) - if err != nil { - return err - } - codec, ok := enc.(*mapCodec) - if !ok { - return fmt.Errorf("failed to find an encoder for inline map") - } - return codec.mapEncodeValue(reg, w, v, collisionFn) -} - -func retrieverOnMinSize(reg EncoderRegistry, t reflect.Type) (ValueEncoder, error) { - enc, err := reg.LookupEncoder(t) - if err != nil { - return enc, err - } - switch t.Kind() { - case reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: - if codec, ok := enc.(*numCodec); ok { - c := *codec - c.minSize = true - return &c, nil - } - } - return enc, nil -} - -func retrieverOnTruncate(reg EncoderRegistry, t reflect.Type) (ValueEncoder, error) { - enc, err := reg.LookupEncoder(t) - if err != nil { - return enc, err - } - switch t.Kind() { - case reflect.Float32, - reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, - reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - if codec, ok := enc.(*numCodec); ok { - c := *codec - c.truncate = true - return &c, nil - } - } - return enc, nil -} +// DecoderFactory is an idempotent factory function that generates a new ValueDecoder. +type DecoderFactory func(*Registry) ValueDecoder // A RegistryBuilder is used to build a Registry. This type is not goroutine // safe. type RegistryBuilder struct { - StructTagHandler func() StructTagHandler - typeEncoders map[reflect.Type]EncoderFactory typeDecoders map[reflect.Type]DecoderFactory interfaceEncoders map[reflect.Type]EncoderFactory @@ -122,19 +74,9 @@ type RegistryBuilder struct { typeMap map[Type]reflect.Type } -// DefaultStructTagHandler generates a new *StructTagHandler to initialize the struct codec. -func DefaultStructTagHandler() StructTagHandler { - return StructTagHandler{ - InlineEncoder: inlineEncoder, - LookupEncoderOnMinSize: retrieverOnMinSize, - LookupEncoderOnTruncate: retrieverOnTruncate, - } -} - // NewRegistryBuilder creates a new empty RegistryBuilder. func NewRegistryBuilder() *RegistryBuilder { rb := &RegistryBuilder{ - StructTagHandler: DefaultStructTagHandler, typeEncoders: make(map[reflect.Type]EncoderFactory), typeDecoders: make(map[reflect.Type]DecoderFactory), interfaceEncoders: make(map[reflect.Type]EncoderFactory), @@ -293,7 +235,7 @@ func (rb *RegistryBuilder) Build() *Registry { if enc, ok := codecCache[reflect.ValueOf(encFac)]; ok { return enc.(ValueEncoder) } - encoder := encFac() + encoder := encFac(r) codecCache[reflect.ValueOf(encFac)] = encoder t := reflect.ValueOf(encoder).Type() r.codecTypeMap[t] = append(r.codecTypeMap[t], encoder) @@ -319,7 +261,7 @@ func (rb *RegistryBuilder) Build() *Registry { if dec, ok := codecCache[reflect.ValueOf(decFac)]; ok { return dec.(ValueDecoder) } - decoder := decFac() + decoder := decFac(r) codecCache[reflect.ValueOf(decFac)] = decoder t := reflect.ValueOf(decoder).Type() r.codecTypeMap[t] = append(r.codecTypeMap[t], decoder) diff --git a/bson/registry_examples_test.go b/bson/registry_examples_test.go index a807a0d80a..92d82ba58e 100644 --- a/bson/registry_examples_test.go +++ b/bson/registry_examples_test.go @@ -49,7 +49,7 @@ func ExampleRegistry_customEncoder() { reg := bson.NewRegistryBuilder(). RegisterTypeEncoder( negatedIntType, - func() bson.ValueEncoder { + func(*bson.Registry) bson.ValueEncoder { return bson.ValueEncoderFunc(negatedIntEncoder) }, ). @@ -135,7 +135,7 @@ func ExampleRegistry_customDecoder() { rb := bson.NewRegistryBuilder() rb.RegisterTypeDecoder( lenientBoolType, - func() bson.ValueDecoder { + func(*bson.Registry) bson.ValueDecoder { return bson.ValueDecoderFunc(lenientBoolDecoder) }, ) @@ -190,7 +190,7 @@ func ExampleRegistryBuilder_RegisterKindEncoder() { reg := bson.NewRegistryBuilder(). RegisterKindEncoder( reflect.Int32, - func() bson.ValueEncoder { + func(*bson.Registry) bson.ValueEncoder { return bson.ValueEncoderFunc(int32To64Encoder) }, ). @@ -282,7 +282,7 @@ func ExampleRegistryBuilder_RegisterKindDecoder() { rb := bson.NewRegistryBuilder() rb.RegisterKindDecoder( reflect.Int64, - func() bson.ValueDecoder { + func(*bson.Registry) bson.ValueDecoder { return bson.ValueDecoderFunc(flexibleInt64KindDecoder) }, ) diff --git a/bson/registry_test.go b/bson/registry_test.go index 3711cb770c..a87fd3c6cb 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -18,7 +18,6 @@ import ( // newTestRegistryBuilder creates a new empty RegistryBuilder. func newTestRegistryBuilder() *RegistryBuilder { return &RegistryBuilder{ - StructTagHandler: DefaultStructTagHandler, typeEncoders: make(map[reflect.Type]EncoderFactory), typeDecoders: make(map[reflect.Type]DecoderFactory), interfaceEncoders: make(map[reflect.Type]EncoderFactory), @@ -44,19 +43,19 @@ func TestRegistryBuilder(t *testing.T) { reflect.TypeOf((*testInterface4)(nil)).Elem() var c1, c2, c3, c4 int - ef1 := func() ValueEncoder { + ef1 := func(*Registry) ValueEncoder { c1++ return fc1 } - ef2 := func() ValueEncoder { + ef2 := func(*Registry) ValueEncoder { c2++ return fc2 } - ef3 := func() ValueEncoder { + ef3 := func(*Registry) ValueEncoder { c3++ return fc3 } - ef4 := func() ValueEncoder { + ef4 := func(*Registry) ValueEncoder { c4++ return fc4 } @@ -123,19 +122,19 @@ func TestRegistryBuilder(t *testing.T) { reflect.TypeOf(fakeType4{}) var c1, c2, c3, c4 int - ef1 := func() ValueEncoder { + ef1 := func(*Registry) ValueEncoder { c1++ return fc1 } - ef2 := func() ValueEncoder { + ef2 := func(*Registry) ValueEncoder { c2++ return fc2 } - ef3 := func() ValueEncoder { + ef3 := func(*Registry) ValueEncoder { c3++ return fc3 } - ef4 := func() ValueEncoder { + ef4 := func(*Registry) ValueEncoder { c4++ return fc4 } @@ -195,19 +194,19 @@ func TestRegistryBuilder(t *testing.T) { k1, k2, k3, k4 := reflect.Struct, reflect.Slice, reflect.Int, reflect.Map var c1, c2, c3, c4 int - ef1 := func() ValueEncoder { + ef1 := func(*Registry) ValueEncoder { c1++ return fc1 } - ef2 := func() ValueEncoder { + ef2 := func(*Registry) ValueEncoder { c2++ return fc2 } - ef3 := func() ValueEncoder { + ef3 := func(*Registry) ValueEncoder { c3++ return fc3 } - ef4 := func() ValueEncoder { + ef4 := func(*Registry) ValueEncoder { c4++ return fc4 } @@ -270,13 +269,13 @@ func TestRegistryBuilder(t *testing.T) { codec2 := &fakeCodec{num: 2} rb := newTestRegistryBuilder() - rb.RegisterKindEncoder(reflect.Map, func() ValueEncoder { return codec }) + rb.RegisterKindEncoder(reflect.Map, func(*Registry) ValueEncoder { return codec }) reg := rb.Build() if got := reg.kindEncoders[reflect.Map]; got != codec { t.Errorf("map codec not properly set: got %#v, want %#v", got, codec) } - rb.RegisterKindEncoder(reflect.Map, func() ValueEncoder { return codec2 }) + rb.RegisterKindEncoder(reflect.Map, func(*Registry) ValueEncoder { return codec2 }) reg = rb.Build() if got := reg.kindEncoders[reflect.Map]; got != codec2 { t.Errorf("map codec not properly set: got %#v, want %#v", got, codec2) @@ -289,13 +288,13 @@ func TestRegistryBuilder(t *testing.T) { codec2 := &fakeCodec{num: 2} rb := newTestRegistryBuilder() - rb.RegisterKindEncoder(reflect.Struct, func() ValueEncoder { return codec }) + rb.RegisterKindEncoder(reflect.Struct, func(*Registry) ValueEncoder { return codec }) reg := rb.Build() if got := reg.kindEncoders[reflect.Struct]; got != codec { t.Errorf("struct codec not properly set: got %#v, want %#v", got, codec) } - rb.RegisterKindEncoder(reflect.Struct, func() ValueEncoder { return codec2 }) + rb.RegisterKindEncoder(reflect.Struct, func(*Registry) ValueEncoder { return codec2 }) reg = rb.Build() if got := reg.kindEncoders[reflect.Struct]; got != codec2 { t.Errorf("struct codec not properly set: got %#v, want %#v", got, codec2) @@ -308,13 +307,13 @@ func TestRegistryBuilder(t *testing.T) { codec2 := &fakeCodec{num: 2} rb := newTestRegistryBuilder() - rb.RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return codec }) + rb.RegisterKindEncoder(reflect.Slice, func(*Registry) ValueEncoder { return codec }) reg := rb.Build() if got := reg.kindEncoders[reflect.Slice]; got != codec { t.Errorf("slice codec not properly set: got %#v, want %#v", got, codec) } - rb.RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return codec2 }) + rb.RegisterKindEncoder(reflect.Slice, func(*Registry) ValueEncoder { return codec2 }) reg = rb.Build() if got := reg.kindEncoders[reflect.Slice]; got != codec2 { t.Errorf("slice codec not properly set: got %#v, want %#v", got, codec2) @@ -327,13 +326,13 @@ func TestRegistryBuilder(t *testing.T) { codec2 := &fakeCodec{num: 2} rb := newTestRegistryBuilder() - rb.RegisterKindEncoder(reflect.Array, func() ValueEncoder { return codec }) + rb.RegisterKindEncoder(reflect.Array, func(*Registry) ValueEncoder { return codec }) reg := rb.Build() if got := reg.kindEncoders[reflect.Array]; got != codec { t.Errorf("slice codec not properly set: got %#v, want %#v", got, codec) } - rb.RegisterKindEncoder(reflect.Array, func() ValueEncoder { return codec2 }) + rb.RegisterKindEncoder(reflect.Array, func(*Registry) ValueEncoder { return codec2 }) reg = rb.Build() if got := reg.kindEncoders[reflect.Array]; got != codec2 { t.Errorf("slice codec not properly set: got %#v, want %#v", got, codec2) @@ -369,21 +368,21 @@ func TestRegistryBuilder(t *testing.T) { pc = &pointerCodec{} ) - fc1EncFac := func() ValueEncoder { return fc1 } - fc2EncFac := func() ValueEncoder { return fc2 } - fc3EncFac := func() ValueEncoder { return fc3 } - fscEncFac := func() ValueEncoder { return fsc } - fslccEncFac := func() ValueEncoder { return fslcc } - fmcEncFac := func() ValueEncoder { return fmc } - pcEncFac := func() ValueEncoder { return pc } - - fc1DecFac := func() ValueDecoder { return fc1 } - fc2DecFac := func() ValueDecoder { return fc2 } - fc3DecFac := func() ValueDecoder { return fc3 } - fscDecFac := func() ValueDecoder { return fsc } - fslccDecFac := func() ValueDecoder { return fslcc } - fmcDecFac := func() ValueDecoder { return fmc } - pcDecFac := func() ValueDecoder { return pc } + fc1EncFac := func(*Registry) ValueEncoder { return fc1 } + fc2EncFac := func(*Registry) ValueEncoder { return fc2 } + fc3EncFac := func(*Registry) ValueEncoder { return fc3 } + fscEncFac := func(*Registry) ValueEncoder { return fsc } + fslccEncFac := func(*Registry) ValueEncoder { return fslcc } + fmcEncFac := func(*Registry) ValueEncoder { return fmc } + pcEncFac := func(*Registry) ValueEncoder { return pc } + + fc1DecFac := func(*Registry) ValueDecoder { return fc1 } + fc2DecFac := func(*Registry) ValueDecoder { return fc2 } + fc3DecFac := func(*Registry) ValueDecoder { return fc3 } + fscDecFac := func(*Registry) ValueDecoder { return fsc } + fslccDecFac := func(*Registry) ValueDecoder { return fslcc } + fmcDecFac := func(*Registry) ValueDecoder { return fmc } + pcDecFac := func(*Registry) ValueDecoder { return pc } reg := newTestRegistryBuilder(). RegisterTypeEncoder(ft1, fc1EncFac). @@ -678,7 +677,7 @@ func BenchmarkLookupEncoder(b *testing.B) { } rb := NewRegistryBuilder() for _, typ := range types { - rb.RegisterTypeEncoder(typ, func() ValueEncoder { return &fakeCodec{} }) + rb.RegisterTypeEncoder(typ, func(*Registry) ValueEncoder { return &fakeCodec{} }) } r := rb.Build() b.Run("Serial", func(b *testing.B) { diff --git a/bson/struct_codec.go b/bson/struct_codec.go index 5155da984e..c9b5306c33 100644 --- a/bson/struct_codec.go +++ b/bson/struct_codec.go @@ -47,17 +47,15 @@ func (de *DecodeError) Keys() []string { return reversedKeys } -// StructTagHandler defines the struct encoder bahavior when inline, minSize and truncate tags are set. -type StructTagHandler struct { - InlineEncoder func(EncoderRegistry, DocumentWriter, reflect.Value, func(string) bool) error - LookupEncoderOnMinSize func(EncoderRegistry, reflect.Type) (ValueEncoder, error) - LookupEncoderOnTruncate func(EncoderRegistry, reflect.Type) (ValueEncoder, error) +// mapElementsEncoder handles encoding of the values of an inline map. +type mapElementsEncoder interface { + encodeMapElements(EncoderRegistry, DocumentWriter, reflect.Value, func(string) bool) error } // structCodec is the Codec used for struct values. type structCodec struct { - cache sync.Map // map[reflect.Type]*structDescription - tagHndl StructTagHandler + cache sync.Map // map[reflect.Type]*structDescription + elemEncoder mapElementsEncoder // decodeZeroStruct causes DecodeValue to delete any existing values from Go structs in the decodeZeroStruct bool @@ -83,9 +81,9 @@ type structCodec struct { } // newStructCodec returns a StructCodec that uses p for struct tag parsing. -func newStructCodec(hndl StructTagHandler) *structCodec { +func newStructCodec(elemEncoder mapElementsEncoder) *structCodec { return &structCodec{ - tagHndl: hndl, + elemEncoder: elemEncoder, overwriteDuplicatedInlinedFields: true, } } @@ -99,6 +97,40 @@ func (r *localEncoderRegistry) LookupEncoder(t reflect.Type) (ValueEncoder, erro return r.encoderLookup(r.registry, t) } +func onMinSize(reg EncoderRegistry, t reflect.Type) (ValueEncoder, error) { + enc, err := reg.LookupEncoder(t) + if err != nil { + return enc, err + } + switch t.Kind() { + case reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64: + if codec, ok := enc.(*numCodec); ok { + c := *codec + c.minSize = true + return &c, nil + } + } + return enc, nil +} + +func onTruncate(reg EncoderRegistry, t reflect.Type) (ValueEncoder, error) { + enc, err := reg.LookupEncoder(t) + if err != nil { + return enc, err + } + switch t.Kind() { + case reflect.Float32, + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + if codec, ok := enc.(*numCodec); ok { + c := *codec + c.truncate = true + return &c, nil + } + } + return enc, nil +} + // EncodeValue handles encoding generic struct types. func (sc *structCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Struct { @@ -192,10 +224,7 @@ func (sc *structCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val refl return exists } - if sc.tagHndl.InlineEncoder == nil { - return errors.New("inline map encoder is not defined") - } - err = sc.tagHndl.InlineEncoder(reg, dw, rv, collisionFn) + err = sc.elemEncoder.encodeMapElements(reg, dw, rv, collisionFn) if err != nil { return err } @@ -501,16 +530,16 @@ func (sc *structCodec) describeStructSlow( description.name = stags.Name description.omitEmpty = stags.OmitEmpty description.encoderLookup = func(reg EncoderRegistry, t reflect.Type) (ValueEncoder, error) { - if stags.MinSize && sc.tagHndl.LookupEncoderOnMinSize != nil { + if stags.MinSize { reg = &localEncoderRegistry{ registry: reg, - encoderLookup: sc.tagHndl.LookupEncoderOnMinSize, + encoderLookup: onMinSize, } } - if stags.Truncate && sc.tagHndl.LookupEncoderOnTruncate != nil { + if stags.Truncate { reg = &localEncoderRegistry{ registry: reg, - encoderLookup: sc.tagHndl.LookupEncoderOnTruncate, + encoderLookup: onTruncate, } } return reg.LookupEncoder(t) diff --git a/bson/unmarshal_test.go b/bson/unmarshal_test.go index c6ba38d760..d643a7db57 100644 --- a/bson/unmarshal_test.go +++ b/bson/unmarshal_test.go @@ -205,7 +205,7 @@ func TestCachingDecodersNotSharedAcrossRegistries(t *testing.T) { return nil } customReg := NewRegistryBuilder(). - RegisterTypeDecoder(tInt32, func() ValueDecoder { return decodeInt32 }). + RegisterTypeDecoder(tInt32, func(*Registry) ValueDecoder { return decodeInt32 }). Build() docBytes := bsoncore.BuildDocumentFromElements( diff --git a/bson/unmarshal_value_test.go b/bson/unmarshal_value_test.go index c91eece0a9..f9acc852ad 100644 --- a/bson/unmarshal_value_test.go +++ b/bson/unmarshal_value_test.go @@ -76,7 +76,7 @@ func TestUnmarshalValue(t *testing.T) { }, } reg := NewRegistryBuilder(). - RegisterTypeDecoder(reflect.TypeOf([]byte{}), func() ValueDecoder { return &sliceCodec{} }). + RegisterTypeDecoder(reflect.TypeOf([]byte{}), func(*Registry) ValueDecoder { return &sliceCodec{} }). Build() for _, tc := range testCases { tc := tc @@ -112,7 +112,7 @@ func BenchmarkSliceCodecUnmarshal(b *testing.B) { }, } reg := NewRegistryBuilder(). - RegisterTypeDecoder(reflect.TypeOf([]byte{}), func() ValueDecoder { return &sliceCodec{} }). + RegisterTypeDecoder(reflect.TypeOf([]byte{}), func(*Registry) ValueDecoder { return &sliceCodec{} }). Build() for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index cc3942db8f..e2f9fb36d1 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -101,8 +101,8 @@ func TestClient(t *testing.T) { mt := mtest.New(t, noClientOpts) reg := bson.NewRegistryBuilder(). - RegisterTypeEncoder(reflect.TypeOf(int64(0)), func() bson.ValueEncoder { return &negateCodec{} }). - RegisterTypeDecoder(reflect.TypeOf(int64(0)), func() bson.ValueDecoder { return &negateCodec{} }). + RegisterTypeEncoder(reflect.TypeOf(int64(0)), func(*bson.Registry) bson.ValueEncoder { return &negateCodec{} }). + RegisterTypeDecoder(reflect.TypeOf(int64(0)), func(*bson.Registry) bson.ValueDecoder { return &negateCodec{} }). Build() registryOpts := options.Client(). SetRegistry(reg) diff --git a/internal/integration/unified_spec_test.go b/internal/integration/unified_spec_test.go index 8cd987d2de..4ea80d0a54 100644 --- a/internal/integration/unified_spec_test.go +++ b/internal/integration/unified_spec_test.go @@ -183,7 +183,7 @@ var directories = []string{ var checkOutcomeOpts = options.Collection().SetReadPreference(readpref.Primary()).SetReadConcern(readconcern.Local()) var specTestRegistry = bson.NewRegistryBuilder(). RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})). - RegisterTypeDecoder(reflect.TypeOf(testData{}), func() bson.ValueDecoder { return bson.ValueDecoderFunc(decodeTestData) }). + RegisterTypeDecoder(reflect.TypeOf(testData{}), func(*bson.Registry) bson.ValueDecoder { return bson.ValueDecoderFunc(decodeTestData) }). Build() func TestUnifiedSpecs(t *testing.T) {