From 999bdd8bf3d67fe6e7bbb2b8493121a29ac3fd16 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Tue, 5 Dec 2023 11:59:09 -0500 Subject: [PATCH] WIP --- benchmark/bson_map.go | 4 +- benchmark/bson_struct.go | 5 ++- bson/bson_test.go | 3 +- bson/marshal.go | 71 +++---------------------------- bson/marshal_test.go | 6 ++- bson/marshal_value_test.go | 78 ----------------------------------- bson/mgocompat/bson_test.go | 27 ++++++++---- mongo/change_stream.go | 22 +++++++++- mongo/collection.go | 16 ++++++- mongo/cursor.go | 5 +-- mongo/options/mongooptions.go | 6 ++- 11 files changed, 75 insertions(+), 168 deletions(-) diff --git a/benchmark/bson_map.go b/benchmark/bson_map.go index 8760692053..b49ac48da7 100644 --- a/benchmark/bson_map.go +++ b/benchmark/bson_map.go @@ -51,13 +51,15 @@ func bsonMapEncoding(tm TimerManager, iters int, dataSet string) error { tm.ResetTimer() buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i := 0; i < iters; i++ { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) if err != nil { return err } - err = bson.NewEncoder(vw).Encode(doc) + enc.Reset(vw) + err = enc.Encode(doc) if err != nil { return err } diff --git a/benchmark/bson_struct.go b/benchmark/bson_struct.go index 3abf97ff26..484f2bcb11 100644 --- a/benchmark/bson_struct.go +++ b/benchmark/bson_struct.go @@ -73,7 +73,7 @@ func BSONFlatStructTagsEncoding(_ context.Context, tm TimerManager, iters int) e } buf := new(bytes.Buffer) - + enc := new(bson.Encoder) tm.ResetTimer() for i := 0; i < iters; i++ { buf.Reset() @@ -81,7 +81,8 @@ func BSONFlatStructTagsEncoding(_ context.Context, tm TimerManager, iters int) e if err != nil { return err } - err = bson.NewEncoder(vw).Encode(doc) + enc.Reset(vw) + err = enc.Encode(doc) if err != nil { return err } diff --git a/bson/bson_test.go b/bson/bson_test.go index 876928b33a..357308a7d1 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -179,6 +179,7 @@ func TestMapCodec(t *testing.T) { {"false", bsonoptions.MapCodec().SetEncodeKeysWithStringer(false), "foo"}, } buf := new(bytes.Buffer) + enc := new(Encoder) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { mapCodec := bsoncodec.NewMapCodec(tc.opts) @@ -186,7 +187,7 @@ func TestMapCodec(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err) - enc := NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(mapRegistry) err = enc.Encode(mapObj) assert.Nil(t, err, "Encode error: %v", err) diff --git a/bson/marshal.go b/bson/marshal.go index 6fd42792bb..0af337aca0 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -17,7 +17,6 @@ import ( const defaultDstCap = 256 -var bvwPool = bsonrw.NewBSONValueWriterPool() var extjPool = bsonrw.NewExtJSONValueWriterPool() // Marshaler is the interface implemented by types that can marshal themselves @@ -53,7 +52,9 @@ func Marshal(val interface{}) ([]byte, error) { if err != nil { return nil, err } - enc := NewEncoder(vw) + enc := encPool.Get().(*Encoder) + defer encPool.Put(enc) + enc.Reset(vw) enc.SetRegistry(DefaultRegistry) err = enc.Encode(val) if err != nil { @@ -68,72 +69,10 @@ func Marshal(val interface{}) ([]byte, error) { // MarshalValue will use bson.DefaultRegistry to transform val into a BSON value. If val is a struct, this function will // inspect struct tags and alter the marshalling process accordingly. func MarshalValue(val interface{}) (bsontype.Type, []byte, error) { - return MarshalValueWithRegistry(DefaultRegistry, val) -} - -// MarshalValueAppend will append the BSON encoding of val to dst. If dst is not large enough to hold the BSON encoding -// of val, dst will be grown. -// -// Deprecated: Appending individual BSON elements to an existing slice will not be supported in Go -// Driver 2.0. -func MarshalValueAppend(dst []byte, val interface{}) (bsontype.Type, []byte, error) { - return MarshalValueAppendWithRegistry(DefaultRegistry, dst, val) -} - -// MarshalValueWithRegistry returns the BSON encoding of val using Registry r. -// -// Deprecated: Using a custom registry to marshal individual BSON values will not be supported in Go -// Driver 2.0. -func MarshalValueWithRegistry(r *bsoncodec.Registry, val interface{}) (bsontype.Type, []byte, error) { - dst := make([]byte, 0) - return MarshalValueAppendWithRegistry(r, dst, val) -} - -// MarshalValueWithContext returns the BSON encoding of val using EncodeContext ec. -// -// Deprecated: Using a custom EncodeContext to marshal individual BSON elements will not be -// supported in Go Driver 2.0. -func MarshalValueWithContext(ec bsoncodec.EncodeContext, val interface{}) (bsontype.Type, []byte, error) { - dst := make([]byte, 0) - return MarshalValueAppendWithContext(ec, dst, val) -} - -// MarshalValueAppendWithRegistry will append the BSON encoding of val to dst using Registry r. If dst is not large -// enough to hold the BSON encoding of val, dst will be grown. -// -// Deprecated: Appending individual BSON elements to an existing slice will not be supported in Go -// Driver 2.0. -func MarshalValueAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{}) (bsontype.Type, []byte, error) { - return MarshalValueAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val) -} - -// MarshalValueAppendWithContext will append the BSON encoding of val to dst using EncodeContext ec. If dst is not large -// enough to hold the BSON encoding of val, dst will be grown. -// -// Deprecated: Appending individual BSON elements to an existing slice will not be supported in Go -// Driver 2.0. -func MarshalValueAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}) (bsontype.Type, []byte, error) { - // get a ValueWriter configured to write to dst - sw := new(bsonrw.SliceWriter) - *sw = dst - vwFlusher := bvwPool.GetAtModeElement(sw) - - // get an Encoder and encode the value - enc := encPool.Get().(*Encoder) - defer encPool.Put(enc) - enc.Reset(vwFlusher) - enc.ec = ec - if err := enc.Encode(val); err != nil { - return 0, nil, err - } - - // flush the bytes written because we cannot guarantee that a full document has been written - // after the flush, *sw will be in the format - // [value type, 0 (null byte to indicate end of empty element name), value bytes..] - if err := vwFlusher.Flush(); err != nil { + buffer, err := Marshal(val) + if err != nil { return 0, nil, err } - buffer := *sw return bsontype.Type(buffer[0]), buffer[2:], nil } diff --git a/bson/marshal_test.go b/bson/marshal_test.go index 1eaeda8169..4031bf38c7 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -27,6 +27,7 @@ var tInt32 = reflect.TypeOf(int32(0)) func TestMarshalWithRegistry(t *testing.T) { buf := new(bytes.Buffer) + enc := new(Encoder) for _, tc := range marshalingTestCases { t.Run(tc.name, func(t *testing.T) { var reg *bsoncodec.Registry @@ -38,7 +39,7 @@ func TestMarshalWithRegistry(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) noerr(t, err) - enc := NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(reg) err = enc.Encode(tc.val) noerr(t, err) @@ -53,6 +54,7 @@ func TestMarshalWithRegistry(t *testing.T) { func TestMarshalWithContext(t *testing.T) { buf := new(bytes.Buffer) + enc := new(Encoder) for _, tc := range marshalingTestCases { t.Run(tc.name, func(t *testing.T) { var reg *bsoncodec.Registry @@ -64,7 +66,7 @@ func TestMarshalWithContext(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) noerr(t, err) - enc := NewEncoder(vw) + enc.Reset(vw) enc.IntMinSize() enc.SetRegistry(reg) err = enc.Encode(tc.val) diff --git a/bson/marshal_value_test.go b/bson/marshal_value_test.go index cfc273f0de..66dec6c186 100644 --- a/bson/marshal_value_test.go +++ b/bson/marshal_value_test.go @@ -10,7 +10,6 @@ import ( "strings" "testing" - "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/internal/assert" ) @@ -35,83 +34,6 @@ func TestMarshalValue(t *testing.T) { }) } }) - t.Run("MarshalValueAppend", func(t *testing.T) { - t.Parallel() - - for _, tc := range marshalValueTestCases { - tc := tc - - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - valueType, valueBytes, err := MarshalValueAppend(nil, tc.val) - assert.Nil(t, err, "MarshalValueAppend error: %v", err) - compareMarshalValueResults(t, tc, valueType, valueBytes) - }) - } - }) - t.Run("MarshalValueWithRegistry", func(t *testing.T) { - t.Parallel() - - for _, tc := range marshalValueTestCases { - tc := tc - - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - valueType, valueBytes, err := MarshalValueWithRegistry(DefaultRegistry, tc.val) - assert.Nil(t, err, "MarshalValueWithRegistry error: %v", err) - compareMarshalValueResults(t, tc, valueType, valueBytes) - }) - } - }) - t.Run("MarshalValueWithContext", func(t *testing.T) { - t.Parallel() - - ec := bsoncodec.EncodeContext{Registry: DefaultRegistry} - for _, tc := range marshalValueTestCases { - tc := tc - - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - valueType, valueBytes, err := MarshalValueWithContext(ec, tc.val) - assert.Nil(t, err, "MarshalValueWithContext error: %v", err) - compareMarshalValueResults(t, tc, valueType, valueBytes) - }) - } - }) - t.Run("MarshalValueAppendWithRegistry", func(t *testing.T) { - t.Parallel() - - for _, tc := range marshalValueTestCases { - tc := tc - - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - valueType, valueBytes, err := MarshalValueAppendWithRegistry(DefaultRegistry, nil, tc.val) - assert.Nil(t, err, "MarshalValueAppendWithRegistry error: %v", err) - compareMarshalValueResults(t, tc, valueType, valueBytes) - }) - } - }) - t.Run("MarshalValueAppendWithContext", func(t *testing.T) { - t.Parallel() - - ec := bsoncodec.EncodeContext{Registry: DefaultRegistry} - for _, tc := range marshalValueTestCases { - tc := tc - - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - valueType, valueBytes, err := MarshalValueAppendWithContext(ec, nil, tc.val) - assert.Nil(t, err, "MarshalValueWithContext error: %v", err) - compareMarshalValueResults(t, tc, valueType, valueBytes) - }) - } - }) } func compareMarshalValueResults(t *testing.T, tc marshalValueTestCase, gotType bsontype.Type, gotBytes []byte) { diff --git a/bson/mgocompat/bson_test.go b/bson/mgocompat/bson_test.go index 4ef95d4560..4ff8fb57e8 100644 --- a/bson/mgocompat/bson_test.go +++ b/bson/mgocompat/bson_test.go @@ -85,12 +85,13 @@ var sampleItems = []testItemType{ func TestMarshalSampleItems(t *testing.T) { buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, item := range sampleItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -170,12 +171,13 @@ var allItems = []testItemType{ func TestMarshalAllItems(t *testing.T) { buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, item := range allItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -295,12 +297,13 @@ func TestPtrInline(t *testing.T) { } buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, cs := range cases { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(cs.In) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -384,12 +387,13 @@ var oneWayMarshalItems = []testItemType{ func TestOneWayMarshalItems(t *testing.T) { buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, item := range oneWayMarshalItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -421,12 +425,13 @@ var structSampleItems = []testItemType{ func TestMarshalStructSampleItems(t *testing.T) { buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, item := range structSampleItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -588,12 +593,13 @@ var structItems = []testItemType{ func TestMarshalStructItems(t *testing.T) { buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, item := range structItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -664,12 +670,13 @@ var marshalItems = []testItemType{ func TestMarshalOneWayItems(t *testing.T) { buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, item := range marshalItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -777,12 +784,13 @@ var marshalErrorItems = []testItemType{ func TestMarshalErrorItems(t *testing.T) { buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, item := range marshalErrorItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(item.obj) @@ -1079,6 +1087,7 @@ type docWithGetterField struct { func TestMarshalAllItemsWithGetter(t *testing.T) { buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, item := range allItems { if item.data == "" { continue @@ -1089,7 +1098,7 @@ func TestMarshalAllItemsWithGetter(t *testing.T) { obj.Field = &typeWithGetter{result: item.obj.(bson.M)["_"]} vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(obj) assert.Nil(t, err, "expected nil error, got: %v", err) diff --git a/mongo/change_stream.go b/mongo/change_stream.go index b7e16f99aa..a9a2011304 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -7,6 +7,7 @@ package mongo import ( + "bytes" "context" "errors" "fmt" @@ -16,6 +17,8 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/bson/bsonrw" + "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/internal/csot" "go.mongodb.org/mongo-driver/mongo/description" @@ -150,6 +153,21 @@ func mergeChangeStreamOptions(opts ...*options.ChangeStreamOptions) *options.Cha return csOpts } +func marshalValueWithRegistry(r *bsoncodec.Registry, val interface{}) (bsontype.Type, []byte, error) { + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + if err != nil { + return 0, nil, err + } + enc := bson.NewEncoder(vw) + enc.SetRegistry(r) + if err = enc.Encode(val); err != nil { + return 0, nil, err + } + hdr := buf.Next(2) + return bsontype.Type(hdr[0]), buf.Bytes(), nil +} + func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline interface{}, opts ...*options.ChangeStreamOptions) (*ChangeStream, error) { if ctx == nil { @@ -212,7 +230,7 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in // any errors from Marshaling. customOptions := make(map[string]bsoncore.Value) for optionName, optionValue := range cs.options.Custom { - bsonType, bsonData, err := bson.MarshalValueWithRegistry(cs.registry, optionValue) + bsonType, bsonData, err := marshalValueWithRegistry(cs.registry, optionValue) if err != nil { cs.err = err closeImplicitSession(cs.sess) @@ -228,7 +246,7 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in // any errors from Marshaling. cs.pipelineOptions = make(map[string]bsoncore.Value) for optionName, optionValue := range cs.options.CustomPipeline { - bsonType, bsonData, err := bson.MarshalValueWithRegistry(cs.registry, optionValue) + bsonType, bsonData, err := marshalValueWithRegistry(cs.registry, optionValue) if err != nil { cs.err = err closeImplicitSession(cs.sess) diff --git a/mongo/collection.go b/mongo/collection.go index 55db6692df..0f6fad3684 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -7,6 +7,7 @@ package mongo import ( + "bytes" "context" "errors" "fmt" @@ -16,6 +17,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/internal/csfle" @@ -1075,12 +1077,22 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { // Marshal all custom options before passing to the aggregate operation. Return // any errors from Marshaling. customOptions := make(map[string]bsoncore.Value) + buf := new(bytes.Buffer) + enc := new(bson.Encoder) for optionName, optionValue := range ao.Custom { - bsonType, bsonData, err := bson.MarshalValueWithRegistry(a.registry, optionValue) + buf.Reset() + vw, err := bsonrw.NewBSONValueWriter(buf) if err != nil { return nil, err } - optionValueBSON := bsoncore.Value{Type: bsonType, Data: bsonData} + enc.Reset(vw) + enc.SetRegistry(a.registry) + err = enc.Encode(optionValue) + if err != nil { + return nil, err + } + hdr := buf.Next(2) + optionValueBSON := bsoncore.Value{Type: bsontype.Type(hdr[0]), Data: buf.Bytes()} customOptions[optionName] = optionValueBSON } op.CustomOptions(customOptions) diff --git a/mongo/cursor.go b/mongo/cursor.go index c48814f42b..67db1c2953 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -93,6 +93,7 @@ func NewCursorFromDocuments(documents []interface{}, err error, registry *bsonco // Convert documents slice to a sequence-style byte array. buf := new(bytes.Buffer) + enc := new(bson.Encoder) for _, doc := range documents { switch t := doc.(type) { case nil: @@ -105,11 +106,9 @@ func NewCursorFromDocuments(documents []interface{}, err error, registry *bsonco if err != nil { return nil, err } - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(registry) err = enc.Encode(doc) - // var marshalErr error - // docsBytes, marshalErr = bson.MarshalAppendWithRegistry(registry, docsBytes, doc) if err != nil { return nil, err } diff --git a/mongo/options/mongooptions.go b/mongo/options/mongooptions.go index 95d88f6e96..bf1bc69adb 100644 --- a/mongo/options/mongooptions.go +++ b/mongo/options/mongooptions.go @@ -134,13 +134,14 @@ func (af *ArrayFilters) ToArray() ([]bson.Raw, error) { } filters := make([]bson.Raw, 0, len(af.Filters)) buf := new(bytes.Buffer) + enc := new(bson.Encoder) for _, f := range af.Filters { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) if err != nil { return nil, err } - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(registry) err = enc.Encode(f) if err != nil { @@ -163,13 +164,14 @@ func (af *ArrayFilters) ToArrayDocument() (bson.Raw, error) { idx, arr := bsoncore.AppendArrayStart(nil) buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, f := range af.Filters { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) if err != nil { return nil, err } - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(registry) err = enc.Encode(f) if err != nil {