Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed May 22, 2024
1 parent 5ce2670 commit 08ec80e
Show file tree
Hide file tree
Showing 14 changed files with 88 additions and 253 deletions.
3 changes: 1 addition & 2 deletions bson/bson_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,7 @@ func TestMapCodec(t *testing.T) {
mapRegistry.RegisterKindEncoder(reflect.Map, func() ValueEncoder { return tc.codec })
buf := new(bytes.Buffer)
vw := NewValueWriter(buf)
enc := NewEncoder(vw)
enc.SetRegistry(mapRegistry.Build())
enc := NewEncoderWithRegistry(mapRegistry.Build(), vw)
err := enc.Encode(mapObj)
assert.Nil(t, err, "Encode error: %v", err)
str := buf.String()
Expand Down
32 changes: 10 additions & 22 deletions bson/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,11 @@ import (
"errors"
"fmt"
"reflect"
"sync"
)

// ErrDecodeToNil is the error returned when trying to decode to a nil value
var ErrDecodeToNil = errors.New("cannot Decode to nil value")

// This pool is used to keep the allocations of Decoders down. This is only used for the Marshal*
// methods and is not consumable from outside of this package. The Decoders retrieved from this pool
// must have both Reset and SetRegistry called on them.
var decPool = sync.Pool{
New: func() interface{} {
return new(Decoder)
},
}

// A Decoder reads and decodes BSON documents from a stream. It reads from a ValueReader as
// the source of BSON data.
type Decoder struct {
Expand All @@ -34,8 +24,17 @@ type Decoder struct {

// NewDecoder returns a new decoder that uses the default registry to read from vr.
func NewDecoder(vr ValueReader) *Decoder {
r := NewRegistryBuilder().Build()
return &Decoder{
reg: NewRegistryBuilder().Build(),
reg: r,
vr: vr,
}
}

// NewDecoderWithRegistry returns a new decoder that uses the given registry to read from vr.
func NewDecoderWithRegistry(r *Registry, vr ValueReader) *Decoder {
return &Decoder{
reg: r,
vr: vr,
}
}
Expand Down Expand Up @@ -76,17 +75,6 @@ func (d *Decoder) Decode(val interface{}) error {
return decoder.DecodeValue(d.reg, d.vr, rval)
}

// Reset will reset the state of the decoder, using the same *DecodeContext used in
// the original construction but using vr for reading.
func (d *Decoder) Reset(vr ValueReader) {
d.vr = vr
}

// SetRegistry replaces the current registry of the decoder with r.
func (d *Decoder) SetRegistry(r *Registry) {
d.reg = r
}

// DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This
// behavior is restricted to data typed as "interface{}" or "map[string]interface{}".
func (d *Decoder) DefaultDocumentM() {
Expand Down
26 changes: 0 additions & 26 deletions bson/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,32 +183,6 @@ func TestDecoderv2(t *testing.T) {
want := foo{Item: "canvas", Qty: 4, Bonus: 2}
assert.Equal(t, want, got, "Results do not match.")
})
t.Run("Reset", func(t *testing.T) {
t.Parallel()

vr1, vr2 := NewValueReader([]byte{}), NewValueReader([]byte{})
dec := NewDecoder(vr1)
if dec.vr != vr1 {
t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr1)
}
dec.Reset(vr2)
if dec.vr != vr2 {
t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr2)
}
})
// t.Run("SetRegistry", func(t *testing.T) {
// t.Parallel()

// r1, r2 := DefaultRegistry, NewRegistryBuilder().Build()
// dec := NewDecoder(NewValueReader([]byte{}))
// if !reflect.DeepEqual(dec.reg, r1) {
// t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.reg, r1)
// }
// dec.SetRegistry(r2)
// if !reflect.DeepEqual(dec.reg, r2) {
// t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.reg, r2)
// }
// })
t.Run("DecodeToNil", func(t *testing.T) {
t.Parallel()

Expand Down
33 changes: 11 additions & 22 deletions bson/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,8 @@ package bson

import (
"reflect"
"sync"
)

// This pool is used to keep the allocations of Encoders down. This is only used for the Marshal*
// methods and is not consumable from outside of this package. The Encoders retrieved from this pool
// must have both Reset and SetRegistry called on them.
var encPool = sync.Pool{
New: func() interface{} {
return new(Encoder)
},
}

// An Encoder writes a serialization format to an output stream. It writes to a ValueWriter
// as the destination of BSON data.
type Encoder struct {
Expand All @@ -35,6 +25,14 @@ func NewEncoder(vw ValueWriter) *Encoder {
}
}

// NewEncoderWithRegistry returns a new encoder that uses the given registry to write to vw.
func NewEncoderWithRegistry(r *Registry, vw ValueWriter) *Encoder {
return &Encoder{
reg: r,
vw: vw,
}
}

// Encode writes the BSON encoding of val to the stream.
//
// See [Marshal] for details about BSON marshaling behavior.
Expand All @@ -56,17 +54,6 @@ func (e *Encoder) Encode(val interface{}) error {
return encoder.EncodeValue(e.reg, e.vw, reflect.ValueOf(val))
}

// Reset will reset the state of the Encoder, using the same *EncodeContext used in
// the original construction but using vw.
func (e *Encoder) Reset(vw ValueWriter) {
e.vw = vw
}

// SetRegistry replaces the current registry of the Encoder with r.
func (e *Encoder) SetRegistry(r *Registry) {
e.reg = r
}

// ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in
// the marshaled BSON when the "inline" struct tag option is set.
func (e *Encoder) ErrorOnInlineDuplicates() {
Expand Down Expand Up @@ -172,7 +159,9 @@ func (e *Encoder) UseJSONStructTags() {
t := reflect.TypeOf((*structCodec)(nil))
if v, ok := e.reg.codecTypeMap[t]; ok && v != nil {
for i := range v {
v[i].(*structCodec).useJSONStructTags = true
if enc, ok := v[i].(*structCodec); ok {
enc.useJSONStructTags = true
}
}
}
}
18 changes: 3 additions & 15 deletions bson/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,7 @@ func Marshal(val interface{}) ([]byte, error) {
}
}()
sw.Reset()
vw := NewValueWriter(sw)
enc := encPool.Get().(*Encoder)
defer encPool.Put(enc)
enc.Reset(vw)
enc.SetRegistry(NewRegistryBuilder().Build())
enc := NewEncoderWithRegistry(NewRegistryBuilder().Build(), NewValueWriter(sw))
err := enc.Encode(val)
if err != nil {
return nil, err
Expand All @@ -100,10 +96,7 @@ func MarshalValueWithRegistry(r *Registry, val interface{}) (Type, []byte, error
vwFlusher := bvwPool.GetAtModeElement(&sw)

// get an Encoder and encode the value
enc := encPool.Get().(*Encoder)
defer encPool.Put(enc)
enc.Reset(vwFlusher)
enc.SetRegistry(r)
enc := NewEncoderWithRegistry(r, vwFlusher)
if err := enc.Encode(val); err != nil {
return 0, nil, err
}
Expand All @@ -123,12 +116,7 @@ func MarshalExtJSON(val interface{}, canonical, escapeHTML bool) ([]byte, error)
ejvw := extjPool.Get(&sw, canonical, escapeHTML)
defer extjPool.Put(ejvw)

enc := encPool.Get().(*Encoder)
defer encPool.Put(enc)

enc.Reset(ejvw)
enc.SetRegistry(NewRegistryBuilder().Build())

enc := NewEncoderWithRegistry(NewRegistryBuilder().Build(), ejvw)
err := enc.Encode(val)
if err != nil {
return nil, err
Expand Down
31 changes: 2 additions & 29 deletions bson/marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,33 +32,7 @@ func TestMarshalWithRegistry(t *testing.T) {
}
buf := new(bytes.Buffer)
vw := NewValueWriter(buf)
enc := NewEncoder(vw)
enc.SetRegistry(reg)
err := enc.Encode(tc.val)
noerr(t, err)

if got := buf.Bytes(); !bytes.Equal(got, tc.want) {
t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
t.Errorf("Bytes:\n%v\n%v", got, tc.want)
}
})
}
}

func TestMarshalWithContext(t *testing.T) {
for _, tc := range marshalingTestCases {
t.Run(tc.name, func(t *testing.T) {
var reg *Registry
if tc.reg != nil {
reg = tc.reg
} else {
reg = NewRegistryBuilder().Build()
}
buf := new(bytes.Buffer)
vw := NewValueWriter(buf)
enc := NewEncoder(vw)
enc.IntMinSize()
enc.SetRegistry(reg)
enc := NewEncoderWithRegistry(reg, vw)
err := enc.Encode(tc.val)
noerr(t, err)

Expand Down Expand Up @@ -175,8 +149,7 @@ func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) {

buf := new(bytes.Buffer)
vw := NewValueWriter(buf)
enc := NewEncoder(vw)
enc.SetRegistry(customReg)
enc := NewEncoderWithRegistry(customReg, vw)
err = enc.Encode(original)
assert.Nil(t, err, "Encode error: %v", err)
second := buf.Bytes()
Expand Down
Loading

0 comments on commit 08ec80e

Please sign in to comment.