Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed May 20, 2024
1 parent b2a1c15 commit 7a0b8bf
Show file tree
Hide file tree
Showing 17 changed files with 205 additions and 212 deletions.
28 changes: 1 addition & 27 deletions bson/bsoncodec.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,31 +72,6 @@ func (vde ValueDecoderError) Error() string {
return fmt.Sprintf("%s can only decode valid and settable %s, but got %s", vde.Name, strings.Join(typeKinds, ", "), received)
}

// DecodeContext is the contextual information required for a Codec to decode a
// value.
type DecodeContext struct {
*Registry

// defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the
// usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If DocumentType is
// set to a type that a BSON document cannot be unmarshaled into (e.g. "string"), unmarshalling will result in an
// error. DocumentType overrides the Ancestor field.
defaultDocumentType reflect.Type

// truncate, if true, instructs decoders to to truncate the fractional part of BSON "double"
// values when attempting to unmarshal them into a Go integer (int, int8, int16, int32, int64,
// uint, uint8, uint16, uint32, or uint64) struct field. The truncation logic does not apply to
// BSON "decimal128" values.
truncate bool

binaryAsSlice bool
decodeObjectIDAsHex bool
useJSONStructTags bool
useLocalTimeZone bool
zeroMaps bool
zeroStructs bool
}

// EncoderRegistry is an interface provides a ValueEncoder based on the given reflect.Type.
type EncoderRegistry interface {
LookupEncoder(reflect.Type) (ValueEncoder, error)
Expand Down Expand Up @@ -166,8 +141,7 @@ var _ ValueDecoder = decodeAdapter{}
var _ typeDecoder = decodeAdapter{}

func decodeTypeOrValueWithInfo(vd ValueDecoder, reg DecoderRegistry, vr ValueReader, t reflect.Type) (reflect.Value, error) {
td, ok := vd.(typeDecoder)
if ok && td != nil {
if td, _ := vd.(typeDecoder); td != nil {
return td.decodeType(reg, vr, t)
}

Expand Down
4 changes: 2 additions & 2 deletions bson/cond_addr_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func TestCondAddrCodec(t *testing.T) {
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := condDecoder.DecodeValue(DecodeContext{}, rw, tc.val)
err := condDecoder.DecodeValue(nil, rw, tc.val)
assert.Nil(t, err, "CondAddrDecoder error: %v", err)

assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked)
Expand All @@ -87,7 +87,7 @@ func TestCondAddrCodec(t *testing.T) {

t.Run("error", func(t *testing.T) {
errDecoder := &condAddrDecoder{canAddrDec: decode1, elseDec: nil}
err := errDecoder.DecodeValue(DecodeContext{}, rw, unaddressable)
err := errDecoder.DecodeValue(nil, rw, unaddressable)
want := ErrNoDecoder{Type: unaddressable.Type()}
assert.Equal(t, err, want, "expected error %v, got %v", want, err)
})
Expand Down
78 changes: 62 additions & 16 deletions bson/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ var decPool = sync.Pool{
// A Decoder reads and decodes BSON documents from a stream. It reads from a ValueReader as
// the source of BSON data.
type Decoder struct {
dc DecodeContext
vr ValueReader
reg *Registry
vr ValueReader
}

// NewDecoder returns a new decoder that uses the DefaultRegistry to read from vr.
func NewDecoder(vr ValueReader) *Decoder {
return &Decoder{
dc: DecodeContext{Registry: DefaultRegistry},
vr: vr,
reg: DefaultRegistry,
vr: vr,
}
}

Expand Down Expand Up @@ -68,12 +68,12 @@ func (d *Decoder) Decode(val interface{}) error {
default:
return fmt.Errorf("argument to Decode must be a pointer or a map, but got %v", rval)
}
decoder, err := d.dc.LookupDecoder(rval.Type())
decoder, err := d.reg.LookupDecoder(rval.Type())
if err != nil {
return err
}

return decoder.DecodeValue(d.dc, d.vr, rval)
return decoder.DecodeValue(d.reg, d.vr, rval)
}

// Reset will reset the state of the decoder, using the same *DecodeContext used in
Expand All @@ -84,59 +84,105 @@ func (d *Decoder) Reset(vr ValueReader) {

// SetRegistry replaces the current registry of the decoder with r.
func (d *Decoder) SetRegistry(r *Registry) {
d.dc.Registry = r
d.reg = r
}

// DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This
// behavior is restricted to data typed as "interface{}" or "map[string]interface{}".
func (d *Decoder) DefaultDocumentM() {
d.dc.defaultDocumentType = reflect.TypeOf(M{})
t := reflect.TypeOf((*emptyInterfaceCodec)(nil))
if v, ok := d.reg.codecTypeMap[t]; ok && v != nil {
for i := range v {
v[i].(*emptyInterfaceCodec).defaultDocumentType = reflect.TypeOf(M{})
}
}
}

// DefaultDocumentD causes the Decoder to always unmarshal documents into the primitive.D type. This
// behavior is restricted to data typed as "interface{}" or "map[string]interface{}".
func (d *Decoder) DefaultDocumentD() {
d.dc.defaultDocumentType = reflect.TypeOf(D{})
t := reflect.TypeOf((*emptyInterfaceCodec)(nil))
if v, ok := d.reg.codecTypeMap[t]; ok && v != nil {
for i := range v {
v[i].(*emptyInterfaceCodec).defaultDocumentType = reflect.TypeOf(D{})
}
}
}

// AllowTruncatingDoubles causes the Decoder to truncate the fractional part of BSON "double" values
// when attempting to unmarshal them into a Go integer (int, int8, int16, int32, or int64) struct
// field. The truncation logic does not apply to BSON "decimal128" values.
func (d *Decoder) AllowTruncatingDoubles() {
d.dc.truncate = true
t := reflect.TypeOf((*intCodec)(nil))
if v, ok := d.reg.codecTypeMap[t]; ok && v != nil {
for i := range v {
v[i].(*intCodec).truncate = true
}
}
// TODO floatCodec
}

// BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or
// "Old" BSON binary subtype as a Go byte slice instead of a primitive.Binary.
func (d *Decoder) BinaryAsSlice() {
d.dc.binaryAsSlice = true
t := reflect.TypeOf((*emptyInterfaceCodec)(nil))
if v, ok := d.reg.codecTypeMap[t]; ok && v != nil {
for i := range v {
v[i].(*emptyInterfaceCodec).decodeBinaryAsSlice = true
}
}
}

// DecodeObjectIDAsHex causes the Decoder to unmarshal BSON ObjectID as a hexadecimal string.
func (d *Decoder) DecodeObjectIDAsHex() {
d.dc.decodeObjectIDAsHex = true
t := reflect.TypeOf((*stringCodec)(nil))
if v, ok := d.reg.codecTypeMap[t]; ok && v != nil {
for i := range v {
v[i].(*stringCodec).decodeObjectIDAsHex = true
}
}
}

// UseJSONStructTags causes the Decoder to fall back to using the "json" struct tag if a "bson"
// struct tag is not specified.
func (d *Decoder) UseJSONStructTags() {
d.dc.useJSONStructTags = true
t := reflect.TypeOf((*structCodec)(nil))
if v, ok := d.reg.codecTypeMap[t]; ok && v != nil {
for i := range v {
v[i].(*structCodec).useJSONStructTags = true
}
}
}

// UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead
// of the UTC timezone.
func (d *Decoder) UseLocalTimeZone() {
d.dc.useLocalTimeZone = true
t := reflect.TypeOf((*timeCodec)(nil))
if v, ok := d.reg.codecTypeMap[t]; ok && v != nil {
for i := range v {
v[i].(*timeCodec).useLocalTimeZone = true
}
}
}

// ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value
// passed to Decode before unmarshaling BSON documents into them.
func (d *Decoder) ZeroMaps() {
d.dc.zeroMaps = true
t := reflect.TypeOf((*mapCodec)(nil))
if v, ok := d.reg.codecTypeMap[t]; ok && v != nil {
for i := range v {
v[i].(*mapCodec).decodeZerosMap = true
}
}
}

// ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination
// value passed to Decode before unmarshaling BSON documents into them.
func (d *Decoder) ZeroStructs() {
d.dc.zeroStructs = true
t := reflect.TypeOf((*structCodec)(nil))
if v, ok := d.reg.codecTypeMap[t]; ok && v != nil {
for i := range v {
v[i].(*structCodec).decodeZeroStruct = true
}
}
}
12 changes: 5 additions & 7 deletions bson/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestBasicDecode(t *testing.T) {
reg := NewRegistryBuilder().Build()
decoder, err := reg.LookupDecoder(reflect.TypeOf(got))
noerr(t, err)
err = decoder.DecodeValue(DecodeContext{Registry: reg}, vr, got)
err = decoder.DecodeValue(reg, vr, got)
noerr(t, err)
assert.Equal(t, tc.want, got.Addr().Interface(), "Results do not match.")
})
Expand Down Expand Up @@ -200,15 +200,13 @@ func TestDecoderv2(t *testing.T) {
t.Parallel()

r1, r2 := DefaultRegistry, NewRegistryBuilder().Build()
dc1 := DecodeContext{Registry: r1}
dc2 := DecodeContext{Registry: r2}
dec := NewDecoder(NewValueReader([]byte{}))
if !reflect.DeepEqual(dec.dc, dc1) {
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc1)
if !reflect.DeepEqual(dec.reg, r1) {
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.reg, r1)
}
dec.SetRegistry(r2)
if !reflect.DeepEqual(dec.dc, dc2) {
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc2)
if !reflect.DeepEqual(dec.reg, r2) {
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.reg, r2)
}
})
t.Run("DecodeToNil", func(t *testing.T) {
Expand Down
Loading

0 comments on commit 7a0b8bf

Please sign in to comment.