Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Dec 5, 2023
1 parent 60e5756 commit 37bb727
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 95 deletions.
3 changes: 2 additions & 1 deletion bson/bson_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,14 +179,15 @@ func TestMapCodec(t *testing.T) {
{"false", bsonoptions.MapCodec().SetEncodeKeysWithStringer(false), "foo"},
}
buf := new(bytes.Buffer)
enc := new(Encoder)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mapCodec := bsoncodec.NewMapCodec(tc.opts)
mapRegistry := NewRegistryBuilder().RegisterDefaultEncoder(reflect.Map, mapCodec).Build()
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err)
enc := NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(mapRegistry)
err = enc.Encode(mapObj)
assert.Nil(t, err, "Encode error: %v", err)
Expand Down
6 changes: 4 additions & 2 deletions bson/marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ var tInt32 = reflect.TypeOf(int32(0))

func TestMarshalWithRegistry(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(Encoder)
for _, tc := range marshalingTestCases {
t.Run(tc.name, func(t *testing.T) {
var reg *bsoncodec.Registry
Expand All @@ -38,7 +39,7 @@ func TestMarshalWithRegistry(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
noerr(t, err)
enc := NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(reg)
err = enc.Encode(tc.val)
noerr(t, err)
Expand All @@ -53,6 +54,7 @@ func TestMarshalWithRegistry(t *testing.T) {

func TestMarshalWithContext(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(Encoder)
for _, tc := range marshalingTestCases {
t.Run(tc.name, func(t *testing.T) {
var reg *bsoncodec.Registry
Expand All @@ -64,7 +66,7 @@ func TestMarshalWithContext(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
noerr(t, err)
enc := NewEncoder(vw)
enc.Reset(vw)
enc.IntMinSize()
enc.SetRegistry(reg)
err = enc.Encode(tc.val)
Expand Down
78 changes: 0 additions & 78 deletions bson/marshal_value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"strings"
"testing"

"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/internal/assert"
)
Expand All @@ -35,83 +34,6 @@ func TestMarshalValue(t *testing.T) {
})
}
})
t.Run("MarshalValueAppend", func(t *testing.T) {
t.Parallel()

for _, tc := range marshalValueTestCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

valueType, valueBytes, err := MarshalValueAppend(nil, tc.val)
assert.Nil(t, err, "MarshalValueAppend error: %v", err)
compareMarshalValueResults(t, tc, valueType, valueBytes)
})
}
})
t.Run("MarshalValueWithRegistry", func(t *testing.T) {
t.Parallel()

for _, tc := range marshalValueTestCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

valueType, valueBytes, err := MarshalValueWithRegistry(DefaultRegistry, tc.val)
assert.Nil(t, err, "MarshalValueWithRegistry error: %v", err)
compareMarshalValueResults(t, tc, valueType, valueBytes)
})
}
})
t.Run("MarshalValueWithContext", func(t *testing.T) {
t.Parallel()

ec := bsoncodec.EncodeContext{Registry: DefaultRegistry}
for _, tc := range marshalValueTestCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

valueType, valueBytes, err := MarshalValueWithContext(ec, tc.val)
assert.Nil(t, err, "MarshalValueWithContext error: %v", err)
compareMarshalValueResults(t, tc, valueType, valueBytes)
})
}
})
t.Run("MarshalValueAppendWithRegistry", func(t *testing.T) {
t.Parallel()

for _, tc := range marshalValueTestCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

valueType, valueBytes, err := MarshalValueAppendWithRegistry(DefaultRegistry, nil, tc.val)
assert.Nil(t, err, "MarshalValueAppendWithRegistry error: %v", err)
compareMarshalValueResults(t, tc, valueType, valueBytes)
})
}
})
t.Run("MarshalValueAppendWithContext", func(t *testing.T) {
t.Parallel()

ec := bsoncodec.EncodeContext{Registry: DefaultRegistry}
for _, tc := range marshalValueTestCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

valueType, valueBytes, err := MarshalValueAppendWithContext(ec, nil, tc.val)
assert.Nil(t, err, "MarshalValueWithContext error: %v", err)
compareMarshalValueResults(t, tc, valueType, valueBytes)
})
}
})
}

func compareMarshalValueResults(t *testing.T, tc marshalValueTestCase, gotType bsontype.Type, gotBytes []byte) {
Expand Down
27 changes: 18 additions & 9 deletions bson/mgocompat/bson_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,13 @@ var sampleItems = []testItemType{

func TestMarshalSampleItems(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, item := range sampleItems {
t.Run(strconv.Itoa(i), func(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(item.obj)
assert.Nil(t, err, "expected nil error, got: %v", err)
Expand Down Expand Up @@ -170,12 +171,13 @@ var allItems = []testItemType{

func TestMarshalAllItems(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, item := range allItems {
t.Run(strconv.Itoa(i), func(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err, "expected nil error, got: %v", err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(item.obj)
assert.Nil(t, err, "expected nil error, got: %v", err)
Expand Down Expand Up @@ -295,12 +297,13 @@ func TestPtrInline(t *testing.T) {
}

buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, cs := range cases {
t.Run(strconv.Itoa(i), func(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err, "expected nil error, got: %v", err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(cs.In)
assert.Nil(t, err, "expected nil error, got: %v", err)
Expand Down Expand Up @@ -384,12 +387,13 @@ var oneWayMarshalItems = []testItemType{

func TestOneWayMarshalItems(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, item := range oneWayMarshalItems {
t.Run(strconv.Itoa(i), func(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err, "expected nil error, got: %v", err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(item.obj)
assert.Nil(t, err, "expected nil error, got: %v", err)
Expand Down Expand Up @@ -421,12 +425,13 @@ var structSampleItems = []testItemType{

func TestMarshalStructSampleItems(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, item := range structSampleItems {
t.Run(strconv.Itoa(i), func(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err, "expected nil error, got: %v", err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(item.obj)
assert.Nil(t, err, "expected nil error, got: %v", err)
Expand Down Expand Up @@ -588,12 +593,13 @@ var structItems = []testItemType{

func TestMarshalStructItems(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, item := range structItems {
t.Run(strconv.Itoa(i), func(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err, "expected nil error, got: %v", err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(item.obj)
assert.Nil(t, err, "expected nil error, got: %v", err)
Expand Down Expand Up @@ -664,12 +670,13 @@ var marshalItems = []testItemType{

func TestMarshalOneWayItems(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, item := range marshalItems {
t.Run(strconv.Itoa(i), func(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err, "expected nil error, got: %v", err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(item.obj)
assert.Nil(t, err, "expected nil error, got: %v", err)
Expand Down Expand Up @@ -777,12 +784,13 @@ var marshalErrorItems = []testItemType{

func TestMarshalErrorItems(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, item := range marshalErrorItems {
t.Run(strconv.Itoa(i), func(t *testing.T) {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err, "expected nil error, got: %v", err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(item.obj)

Expand Down Expand Up @@ -1079,6 +1087,7 @@ type docWithGetterField struct {

func TestMarshalAllItemsWithGetter(t *testing.T) {
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, item := range allItems {
if item.data == "" {
continue
Expand All @@ -1089,7 +1098,7 @@ func TestMarshalAllItemsWithGetter(t *testing.T) {
obj.Field = &typeWithGetter{result: item.obj.(bson.M)["_"]}
vw, err := bsonrw.NewBSONValueWriter(buf)
assert.Nil(t, err, "expected nil error, got: %v", err)
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(Registry)
err = enc.Encode(obj)
assert.Nil(t, err, "expected nil error, got: %v", err)
Expand Down
5 changes: 2 additions & 3 deletions mongo/cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func NewCursorFromDocuments(documents []interface{}, err error, registry *bsonco

// Convert documents slice to a sequence-style byte array.
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for _, doc := range documents {
switch t := doc.(type) {
case nil:
Expand All @@ -105,11 +106,9 @@ func NewCursorFromDocuments(documents []interface{}, err error, registry *bsonco
if err != nil {
return nil, err
}
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(registry)
err = enc.Encode(doc)
// var marshalErr error
// docsBytes, marshalErr = bson.MarshalAppendWithRegistry(registry, docsBytes, doc)
if err != nil {
return nil, err
}
Expand Down
6 changes: 4 additions & 2 deletions mongo/options/mongooptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,14 @@ func (af *ArrayFilters) ToArray() ([]bson.Raw, error) {
}
filters := make([]bson.Raw, 0, len(af.Filters))
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for _, f := range af.Filters {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
if err != nil {
return nil, err
}
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(registry)
err = enc.Encode(f)
if err != nil {
Expand All @@ -163,13 +164,14 @@ func (af *ArrayFilters) ToArrayDocument() (bson.Raw, error) {

idx, arr := bsoncore.AppendArrayStart(nil)
buf := new(bytes.Buffer)
enc := new(bson.Encoder)
for i, f := range af.Filters {
buf.Reset()
vw, err := bsonrw.NewBSONValueWriter(buf)
if err != nil {
return nil, err
}
enc := bson.NewEncoder(vw)
enc.Reset(vw)
enc.SetRegistry(registry)
err = enc.Encode(f)
if err != nil {
Expand Down

0 comments on commit 37bb727

Please sign in to comment.