Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Dec 5, 2023
1 parent 60e5756 commit 6409678
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 165 deletions.
3 changes: 2 additions & 1 deletion bson/bson_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,14 +179,15 @@ func TestMapCodec(t *testing.T) {
{"false", bsonoptions.MapCodec().SetEncodeKeysWithStringer(false), "foo"},
}
buf := new(bytes.Buffer)
enc := new(Encoder)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mapCodec := bsoncodec.NewMapCodec(tc.opts)
mapRegistry := NewRegistryBuilder().RegisterDefaultEncoder(reflect.Map, mapCodec).Build()
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err)
enc := NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(mapRegistry)
err = enc.Encode(mapObj)
assert.Nil(t, err, "Encode error: %v", err)
Expand Down
71 changes: 5 additions & 66 deletions bson/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (

const defaultDstCap = 256

var bvwPool = bsonrw.NewBSONValueWriterPool()
var extjPool = bsonrw.NewExtJSONValueWriterPool()

// Marshaler is the interface implemented by types that can marshal themselves
Expand Down Expand Up @@ -53,7 +52,9 @@ func Marshal(val interface{}) ([]byte, error) {
if err != nil {
return nil, err
}
enc := NewEncoder(vw)
enc := encPool.Get().(*Encoder)
defer encPool.Put(enc)
enc.Reset(vw)
enc.SetRegistry(DefaultRegistry)
err = enc.Encode(val)
if err != nil {
Expand All @@ -68,72 +69,10 @@ func Marshal(val interface{}) ([]byte, error) {
// MarshalValue will use bson.DefaultRegistry to transform val into a BSON value. If val is a struct, this function will
// inspect struct tags and alter the marshalling process accordingly.
func MarshalValue(val interface{}) (bsontype.Type, []byte, error) {
return MarshalValueWithRegistry(DefaultRegistry, val)
}

// MarshalValueAppend will append the BSON encoding of val to dst. If dst is not large enough to hold the BSON encoding
// of val, dst will be grown.
//
// Deprecated: Appending individual BSON elements to an existing slice will not be supported in Go
// Driver 2.0.
func MarshalValueAppend(dst []byte, val interface{}) (bsontype.Type, []byte, error) {
return MarshalValueAppendWithRegistry(DefaultRegistry, dst, val)
}

// MarshalValueWithRegistry returns the BSON encoding of val using Registry r.
//
// Deprecated: Using a custom registry to marshal individual BSON values will not be supported in Go
// Driver 2.0.
func MarshalValueWithRegistry(r *bsoncodec.Registry, val interface{}) (bsontype.Type, []byte, error) {
dst := make([]byte, 0)
return MarshalValueAppendWithRegistry(r, dst, val)
}

// MarshalValueWithContext returns the BSON encoding of val using EncodeContext ec.
//
// Deprecated: Using a custom EncodeContext to marshal individual BSON elements will not be
// supported in Go Driver 2.0.
func MarshalValueWithContext(ec bsoncodec.EncodeContext, val interface{}) (bsontype.Type, []byte, error) {
dst := make([]byte, 0)
return MarshalValueAppendWithContext(ec, dst, val)
}

// MarshalValueAppendWithRegistry will append the BSON encoding of val to dst using Registry r. If dst is not large
// enough to hold the BSON encoding of val, dst will be grown.
//
// Deprecated: Appending individual BSON elements to an existing slice will not be supported in Go
// Driver 2.0.
func MarshalValueAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{}) (bsontype.Type, []byte, error) {
return MarshalValueAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val)
}

// MarshalValueAppendWithContext will append the BSON encoding of val to dst using EncodeContext ec. If dst is not large
// enough to hold the BSON encoding of val, dst will be grown.
//
// Deprecated: Appending individual BSON elements to an existing slice will not be supported in Go
// Driver 2.0.
func MarshalValueAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}) (bsontype.Type, []byte, error) {
// get a ValueWriter configured to write to dst
sw := new(bsonrw.SliceWriter)
*sw = dst
vwFlusher := bvwPool.GetAtModeElement(sw)

// get an Encoder and encode the value
enc := encPool.Get().(*Encoder)
defer encPool.Put(enc)
enc.Reset(vwFlusher)
enc.ec = ec
if err := enc.Encode(val); err != nil {
return 0, nil, err
}

// flush the bytes written because we cannot guarantee that a full document has been written
// after the flush, *sw will be in the format
// [value type, 0 (null byte to indicate end of empty element name), value bytes..]
if err := vwFlusher.Flush(); err != nil {
buffer, err := Marshal(val)
if err != nil {
return 0, nil, err
}
buffer := *sw
return bsontype.Type(buffer[0]), buffer[2:], nil
}

Expand Down
6 changes: 4 additions & 2 deletions bson/marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ var tInt32 = reflect.TypeOf(int32(0))

func TestMarshalWithRegistry(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(Encoder)
for _, tc := range marshalingTestCases {
t.Run(tc.name, func(t *testing.T) {
var reg *bsoncodec.Registry
Expand All @@ -38,7 +39,7 @@ func TestMarshalWithRegistry(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
noerr(t, err)
enc := NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(reg)
err = enc.Encode(tc.val)
noerr(t, err)
Expand All @@ -53,6 +54,7 @@ func TestMarshalWithRegistry(t *testing.T) {

func TestMarshalWithContext(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(Encoder)
for _, tc := range marshalingTestCases {
t.Run(tc.name, func(t *testing.T) {
var reg *bsoncodec.Registry
Expand All @@ -64,7 +66,7 @@ func TestMarshalWithContext(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
noerr(t, err)
enc := NewEncoder(vw)
enc.Reset(vw)
enc.IntMinSize()
enc.SetRegistry(reg)
err = enc.Encode(tc.val)
Expand Down
78 changes: 0 additions & 78 deletions bson/marshal_value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"strings"
"testing"

"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/internal/assert"
)
Expand All @@ -35,83 +34,6 @@ func TestMarshalValue(t *testing.T) {
})
}
})
t.Run("MarshalValueAppend", func(t *testing.T) {
t.Parallel()

for _, tc := range marshalValueTestCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

valueType, valueBytes, err := MarshalValueAppend(nil, tc.val)
assert.Nil(t, err, "MarshalValueAppend error: %v", err)
compareMarshalValueResults(t, tc, valueType, valueBytes)
})
}
})
t.Run("MarshalValueWithRegistry", func(t *testing.T) {
t.Parallel()

for _, tc := range marshalValueTestCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

valueType, valueBytes, err := MarshalValueWithRegistry(DefaultRegistry, tc.val)
assert.Nil(t, err, "MarshalValueWithRegistry error: %v", err)
compareMarshalValueResults(t, tc, valueType, valueBytes)
})
}
})
t.Run("MarshalValueWithContext", func(t *testing.T) {
t.Parallel()

ec := bsoncodec.EncodeContext{Registry: DefaultRegistry}
for _, tc := range marshalValueTestCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

valueType, valueBytes, err := MarshalValueWithContext(ec, tc.val)
assert.Nil(t, err, "MarshalValueWithContext error: %v", err)
compareMarshalValueResults(t, tc, valueType, valueBytes)
})
}
})
t.Run("MarshalValueAppendWithRegistry", func(t *testing.T) {
t.Parallel()

for _, tc := range marshalValueTestCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

valueType, valueBytes, err := MarshalValueAppendWithRegistry(DefaultRegistry, nil, tc.val)
assert.Nil(t, err, "MarshalValueAppendWithRegistry error: %v", err)
compareMarshalValueResults(t, tc, valueType, valueBytes)
})
}
})
t.Run("MarshalValueAppendWithContext", func(t *testing.T) {
t.Parallel()

ec := bsoncodec.EncodeContext{Registry: DefaultRegistry}
for _, tc := range marshalValueTestCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

valueType, valueBytes, err := MarshalValueAppendWithContext(ec, nil, tc.val)
assert.Nil(t, err, "MarshalValueWithContext error: %v", err)
compareMarshalValueResults(t, tc, valueType, valueBytes)
})
}
})
}

func compareMarshalValueResults(t *testing.T, tc marshalValueTestCase, gotType bsontype.Type, gotBytes []byte) {
Expand Down
27 changes: 18 additions & 9 deletions bson/mgocompat/bson_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,13 @@ var sampleItems = []testItemType{

func TestMarshalSampleItems(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, item := range sampleItems {
t.Run(strconv.Itoa(i), func(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(item.obj)
assert.Nil(t, err, "expected nil error, got: %v", err)
Expand Down Expand Up @@ -170,12 +171,13 @@ var allItems = []testItemType{

func TestMarshalAllItems(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, item := range allItems {
t.Run(strconv.Itoa(i), func(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err, "expected nil error, got: %v", err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(item.obj)
assert.Nil(t, err, "expected nil error, got: %v", err)
Expand Down Expand Up @@ -295,12 +297,13 @@ func TestPtrInline(t *testing.T) {
}

buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, cs := range cases {
t.Run(strconv.Itoa(i), func(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err, "expected nil error, got: %v", err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(cs.In)
assert.Nil(t, err, "expected nil error, got: %v", err)
Expand Down Expand Up @@ -384,12 +387,13 @@ var oneWayMarshalItems = []testItemType{

func TestOneWayMarshalItems(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, item := range oneWayMarshalItems {
t.Run(strconv.Itoa(i), func(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err, "expected nil error, got: %v", err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(item.obj)
assert.Nil(t, err, "expected nil error, got: %v", err)
Expand Down Expand Up @@ -421,12 +425,13 @@ var structSampleItems = []testItemType{

func TestMarshalStructSampleItems(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, item := range structSampleItems {
t.Run(strconv.Itoa(i), func(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err, "expected nil error, got: %v", err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(item.obj)
assert.Nil(t, err, "expected nil error, got: %v", err)
Expand Down Expand Up @@ -588,12 +593,13 @@ var structItems = []testItemType{

func TestMarshalStructItems(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, item := range structItems {
t.Run(strconv.Itoa(i), func(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err, "expected nil error, got: %v", err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(item.obj)
assert.Nil(t, err, "expected nil error, got: %v", err)
Expand Down Expand Up @@ -664,12 +670,13 @@ var marshalItems = []testItemType{

func TestMarshalOneWayItems(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, item := range marshalItems {
t.Run(strconv.Itoa(i), func(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err, "expected nil error, got: %v", err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(item.obj)
assert.Nil(t, err, "expected nil error, got: %v", err)
Expand Down Expand Up @@ -777,12 +784,13 @@ var marshalErrorItems = []testItemType{

func TestMarshalErrorItems(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, item := range marshalErrorItems {
t.Run(strconv.Itoa(i), func(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err, "expected nil error, got: %v", err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(item.obj)

Expand Down Expand Up @@ -1079,6 +1087,7 @@ type docWithGetterField struct {

func TestMarshalAllItemsWithGetter(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, item := range allItems {
if item.data == "" {
continue
Expand All @@ -1089,7 +1098,7 @@ func TestMarshalAllItemsWithGetter(t *testing.T) {
obj.Field = &typeWithGetter{result: item.obj.(bson.M)["_"]}
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err, "expected nil error, got: %v", err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(obj)
assert.Nil(t, err, "expected nil error, got: %v", err)
Expand Down
Loading

0 comments on commit 6409678

Please sign in to comment.