From 6d0c365ae491417e071735660ff213f79b6e1a2e Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Wed, 4 Oct 2023 15:19:17 -0400 Subject: [PATCH 01/12] WIP --- benchmark/canary.go | 27 --------------------------- bson/decoder.go | 35 ++--------------------------------- bson/decoder_test.go | 38 ++++++-------------------------------- bson/encoder.go | 27 --------------------------- bson/marshal.go | 14 +++----------- bson/unmarshal.go | 10 ++-------- mongo/cursor.go | 5 +---- 7 files changed, 14 insertions(+), 142 deletions(-) delete mode 100644 benchmark/canary.go diff --git a/benchmark/canary.go b/benchmark/canary.go deleted file mode 100644 index 8742c79a9c..0000000000 --- a/benchmark/canary.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package benchmark - -import ( - "context" -) - -// CanaryIncCase is a no-op. -// -// Deprecated: CanaryIncCase has no observable effect, so recent versions of the Go compiler may -// bypass calls to it in the compiled binary. It should not be used in benchmarks. -func CanaryIncCase(context.Context, TimerManager, int) error { - return nil -} - -// GlobalCanaryIncCase is a no-op. -// -// Deprecated: GlobalCanaryIncCase has no observable effect, so recent versions of the Go compiler -// may bypass calls to it in the compiled binary. It should not be used in benchmarks. -func GlobalCanaryIncCase(context.Context, TimerManager, int) error { - return nil -} diff --git a/bson/decoder.go b/bson/decoder.go index eac74cd399..c455aba568 100644 --- a/bson/decoder.go +++ b/bson/decoder.go @@ -58,24 +58,6 @@ func NewDecoder(vr bsonrw.ValueReader) (*Decoder, error) { }, nil } -// NewDecoderWithContext returns a new decoder that uses DecodeContext dc to read from vr. -// -// Deprecated: Use [NewDecoder] and use the Decoder configuration methods set the desired unmarshal -// behavior instead. -func NewDecoderWithContext(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader) (*Decoder, error) { - if dc.Registry == nil { - dc.Registry = DefaultRegistry - } - if vr == nil { - return nil, errors.New("cannot create a new Decoder with a nil ValueReader") - } - - return &Decoder{ - dc: dc, - vr: vr, - }, nil -} - // Decode reads the next BSON document from the stream and decodes it into the // value pointed to by val. // @@ -136,26 +118,13 @@ func (d *Decoder) Decode(val interface{}) error { // Reset will reset the state of the decoder, using the same *DecodeContext used in // the original construction but using vr for reading. -func (d *Decoder) Reset(vr bsonrw.ValueReader) error { - // TODO:(GODRIVER-2719): Remove error return value. +func (d *Decoder) Reset(vr bsonrw.ValueReader) { d.vr = vr - return nil } // SetRegistry replaces the current registry of the decoder with r. -func (d *Decoder) SetRegistry(r *bsoncodec.Registry) error { - // TODO:(GODRIVER-2719): Remove error return value. +func (d *Decoder) SetRegistry(r *bsoncodec.Registry) { d.dc.Registry = r - return nil -} - -// SetContext replaces the current registry of the decoder with dc. -// -// Deprecated: Use the Decoder configuration methods to set the desired unmarshal behavior instead. -func (d *Decoder) SetContext(dc bsoncodec.DecodeContext) error { - // TODO:(GODRIVER-2719): Remove error return value. - d.dc = dc - return nil } // DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This diff --git a/bson/decoder_test.go b/bson/decoder_test.go index c91f4e0491..1295b2617a 100644 --- a/bson/decoder_test.go +++ b/bson/decoder_test.go @@ -58,7 +58,7 @@ func TestDecoderv2(t *testing.T) { got := reflect.New(tc.sType).Interface() vr := bsonrw.NewBSONDocumentReader(tc.data) - dec, err := NewDecoderWithContext(bsoncodec.DecodeContext{Registry: DefaultRegistry}, vr) + dec, err := NewDecoder(vr) noerr(t, err) err = dec.Decode(got) noerr(t, err) @@ -177,8 +177,7 @@ func TestDecoderv2(t *testing.T) { t.Run("errors", func(t *testing.T) { t.Parallel() - dc := bsoncodec.DecodeContext{Registry: DefaultRegistry} - _, got := NewDecoderWithContext(dc, nil) + _, got := NewDecoder(nil) want := errors.New("cannot create a new Decoder with a nil ValueReader") if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { t.Errorf("Was expecting error but got different error. got %v; want %v", got, want) @@ -187,13 +186,7 @@ func TestDecoderv2(t *testing.T) { t.Run("success", func(t *testing.T) { t.Parallel() - got, err := NewDecoderWithContext(bsoncodec.DecodeContext{}, bsonrw.NewBSONDocumentReader([]byte{})) - noerr(t, err) - if got == nil { - t.Errorf("Was expecting a non-nil Decoder, but got ") - } - dc := bsoncodec.DecodeContext{Registry: DefaultRegistry} - got, err = NewDecoderWithContext(dc, bsonrw.NewBSONDocumentReader([]byte{})) + got, err := NewDecoder(bsonrw.NewBSONDocumentReader([]byte{})) noerr(t, err) if got == nil { t.Errorf("Was expecting a non-nil Decoder, but got ") @@ -224,34 +217,16 @@ func TestDecoderv2(t *testing.T) { t.Parallel() vr1, vr2 := bsonrw.NewBSONDocumentReader([]byte{}), bsonrw.NewBSONDocumentReader([]byte{}) - dc := bsoncodec.DecodeContext{Registry: DefaultRegistry} - dec, err := NewDecoderWithContext(dc, vr1) + dec, err := NewDecoder(vr1) noerr(t, err) if dec.vr != vr1 { t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr1) } - err = dec.Reset(vr2) - noerr(t, err) + dec.Reset(vr2) if dec.vr != vr2 { t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr2) } }) - t.Run("SetContext", func(t *testing.T) { - t.Parallel() - - dc1 := bsoncodec.DecodeContext{Registry: DefaultRegistry} - dc2 := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()} - dec, err := NewDecoderWithContext(dc1, bsonrw.NewBSONDocumentReader([]byte{})) - noerr(t, err) - if !reflect.DeepEqual(dec.dc, dc1) { - t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc1) - } - err = dec.SetContext(dc2) - noerr(t, err) - if !reflect.DeepEqual(dec.dc, dc2) { - t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc2) - } - }) t.Run("SetRegistry", func(t *testing.T) { t.Parallel() @@ -263,8 +238,7 @@ func TestDecoderv2(t *testing.T) { if !reflect.DeepEqual(dec.dc, dc1) { t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc1) } - err = dec.SetRegistry(r2) - noerr(t, err) + dec.SetRegistry(r2) if !reflect.DeepEqual(dec.dc, dc2) { t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc2) } diff --git a/bson/encoder.go b/bson/encoder.go index 0be2a97fbc..1cf1759b45 100644 --- a/bson/encoder.go +++ b/bson/encoder.go @@ -53,24 +53,6 @@ func NewEncoder(vw bsonrw.ValueWriter) (*Encoder, error) { }, nil } -// NewEncoderWithContext returns a new encoder that uses EncodeContext ec to write to vw. -// -// Deprecated: Use [NewEncoder] and use the Encoder configuration methods to set the desired marshal -// behavior instead. -func NewEncoderWithContext(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter) (*Encoder, error) { - if ec.Registry == nil { - ec = bsoncodec.EncodeContext{Registry: DefaultRegistry} - } - if vw == nil { - return nil, errors.New("cannot create a new Encoder with a nil ValueWriter") - } - - return &Encoder{ - ec: ec, - vw: vw, - }, nil -} - // Encode writes the BSON encoding of val to the stream. // // See [Marshal] for details about BSON marshaling behavior. @@ -134,15 +116,6 @@ func (e *Encoder) SetRegistry(r *bsoncodec.Registry) error { return nil } -// SetContext replaces the current EncodeContext of the encoder with ec. -// -// Deprecated: Use the Encoder configuration methods set the desired marshal behavior instead. -func (e *Encoder) SetContext(ec bsoncodec.EncodeContext) error { - // TODO:(GODRIVER-2719): Remove error return value. - e.ec = ec - return nil -} - // ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in // the marshaled BSON when the "inline" struct tag option is set. func (e *Encoder) ErrorOnInlineDuplicates() { diff --git a/bson/marshal.go b/bson/marshal.go index 17ce6697e0..4dea5668d7 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -200,10 +200,7 @@ func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interf if err != nil { return nil, err } - err = enc.SetContext(ec) - if err != nil { - return nil, err - } + enc.ec = ec err = enc.Encode(val) if err != nil { @@ -274,9 +271,7 @@ func MarshalValueAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val i if err := enc.Reset(vwFlusher); err != nil { return 0, nil, err } - if err := enc.SetContext(ec); err != nil { - return 0, nil, err - } + enc.ec = ec if err := enc.Encode(val); err != nil { return 0, nil, err } @@ -417,10 +412,7 @@ func MarshalExtJSONAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val if err != nil { return nil, err } - err = enc.SetContext(ec) - if err != nil { - return nil, err - } + enc.ec = ec err = enc.Encode(val) if err != nil { diff --git a/bson/unmarshal.go b/bson/unmarshal.go index 66da17ee01..28d0233e05 100644 --- a/bson/unmarshal.go +++ b/bson/unmarshal.go @@ -164,14 +164,8 @@ func unmarshalFromReader(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val dec := decPool.Get().(*Decoder) defer decPool.Put(dec) - err := dec.Reset(vr) - if err != nil { - return err - } - err = dec.SetContext(dc) - if err != nil { - return err - } + dec.Reset(vr) + dec.dc = dc return dec.Decode(val) } diff --git a/mongo/cursor.go b/mongo/cursor.go index d2228ed9c4..8a39ae9bd8 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -255,10 +255,7 @@ func getDecoder( } if reg != nil { - // TODO:(GODRIVER-2719): Remove error handling. - if err := dec.SetRegistry(reg); err != nil { - return nil, err - } + dec.SetRegistry(reg) } return dec, nil From 904a1ea1a9f8a8e3afb12738254a50f54dd81f83 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Wed, 4 Oct 2023 15:48:11 -0400 Subject: [PATCH 02/12] WIP --- bson/encoder.go | 18 ++++------------- bson/encoder_example_test.go | 30 ++++++----------------------- bson/encoder_test.go | 9 +++------ bson/marshal.go | 18 +++++------------ bson/marshal_test.go | 18 ----------------- bson/primitive_codecs_test.go | 3 +-- internal/codecutil/encoding_test.go | 3 +-- mongo/mongo.go | 10 ++-------- 8 files changed, 22 insertions(+), 87 deletions(-) diff --git a/bson/encoder.go b/bson/encoder.go index 1cf1759b45..4d22c7fbd4 100644 --- a/bson/encoder.go +++ b/bson/encoder.go @@ -7,7 +7,6 @@ package bson import ( - "errors" "reflect" "sync" @@ -41,16 +40,11 @@ type Encoder struct { } // NewEncoder returns a new encoder that uses the DefaultRegistry to write to vw. -func NewEncoder(vw bsonrw.ValueWriter) (*Encoder, error) { - // TODO:(GODRIVER-2719): Remove error return value. - if vw == nil { - return nil, errors.New("cannot create a new Encoder with a nil ValueWriter") - } - +func NewEncoder(vw bsonrw.ValueWriter) *Encoder { return &Encoder{ ec: bsoncodec.EncodeContext{Registry: DefaultRegistry}, vw: vw, - }, nil + } } // Encode writes the BSON encoding of val to the stream. @@ -103,17 +97,13 @@ func (e *Encoder) Encode(val interface{}) error { // Reset will reset the state of the Encoder, using the same *EncodeContext used in // the original construction but using vw. -func (e *Encoder) Reset(vw bsonrw.ValueWriter) error { - // TODO:(GODRIVER-2719): Remove error return value. +func (e *Encoder) Reset(vw bsonrw.ValueWriter) { e.vw = vw - return nil } // SetRegistry replaces the current registry of the Encoder with r. -func (e *Encoder) SetRegistry(r *bsoncodec.Registry) error { - // TODO:(GODRIVER-2719): Remove error return value. +func (e *Encoder) SetRegistry(r *bsoncodec.Registry) { e.ec.Registry = r - return nil } // ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in diff --git a/bson/encoder_example_test.go b/bson/encoder_example_test.go index 054c6497ec..c9a9f76d3f 100644 --- a/bson/encoder_example_test.go +++ b/bson/encoder_example_test.go @@ -22,10 +22,7 @@ func ExampleEncoder() { if err != nil { panic(err) } - encoder, err := bson.NewEncoder(vw) - if err != nil { - panic(err) - } + encoder := bson.NewEncoder(vw) type Product struct { Name string `bson:"name"` @@ -66,10 +63,7 @@ func ExampleEncoder_StringifyMapKeysWithFmt() { if err != nil { panic(err) } - encoder, err := bson.NewEncoder(vw) - if err != nil { - panic(err) - } + encoder := bson.NewEncoder(vw) // Configure the Encoder to convert Go map keys to BSON document field names // using fmt.Sprintf instead of the default string conversion logic. @@ -97,10 +91,7 @@ func ExampleEncoder_UseJSONStructTags() { if err != nil { panic(err) } - encoder, err := bson.NewEncoder(vw) - if err != nil { - panic(err) - } + encoder := bson.NewEncoder(vw) type Product struct { Name string `json:"name"` @@ -136,10 +127,7 @@ func ExampleEncoder_multipleBSONDocuments() { if err != nil { panic(err) } - encoder, err := bson.NewEncoder(vw) - if err != nil { - panic(err) - } + encoder := bson.NewEncoder(vw) type Coordinate struct { X int @@ -186,10 +174,7 @@ func ExampleEncoder_extendedJSON() { if err != nil { panic(err) } - encoder, err := bson.NewEncoder(vw) - if err != nil { - panic(err) - } + encoder := bson.NewEncoder(vw) type Product struct { Name string `bson:"name"` @@ -221,10 +206,7 @@ func ExampleEncoder_multipleExtendedJSONDocuments() { if err != nil { panic(err) } - encoder, err := bson.NewEncoder(vw) - if err != nil { - panic(err) - } + encoder := bson.NewEncoder(vw) type Coordinate struct { X int diff --git a/bson/encoder_test.go b/bson/encoder_test.go index 9458b8d06e..597c74b7be 100644 --- a/bson/encoder_test.go +++ b/bson/encoder_test.go @@ -47,8 +47,7 @@ func TestEncoderEncode(t *testing.T) { got := make(bsonrw.SliceWriter, 0, 1024) vw, err := bsonrw.NewBSONValueWriter(&got) noerr(t, err) - enc, err := NewEncoder(vw) - noerr(t, err) + enc := NewEncoder(vw) err = enc.Encode(tc.val) noerr(t, err) @@ -105,8 +104,7 @@ func TestEncoderEncode(t *testing.T) { vw, err = bsonrw.NewBSONValueWriter(&b) noerr(t, err) } - enc, err := NewEncoder(vw) - noerr(t, err) + enc := NewEncoder(vw) got := enc.Encode(marshaler) want := tc.wanterr if !compareErrors(got, want) { @@ -285,8 +283,7 @@ func TestEncoderConfiguration(t *testing.T) { got := new(bytes.Buffer) vw, err := bsonrw.NewBSONValueWriter(got) require.NoError(t, err, "bsonrw.NewBSONValueWriter error") - enc, err := NewEncoder(vw) - require.NoError(t, err, "NewEncoder error") + enc := NewEncoder(vw) tc.configure(enc) diff --git a/bson/marshal.go b/bson/marshal.go index 4dea5668d7..d90c1f82da 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -196,13 +196,10 @@ func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interf enc := encPool.Get().(*Encoder) defer encPool.Put(enc) - err := enc.Reset(vw) - if err != nil { - return nil, err - } + enc.Reset(vw) enc.ec = ec - err = enc.Encode(val) + err := enc.Encode(val) if err != nil { return nil, err } @@ -268,9 +265,7 @@ func MarshalValueAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val i // get an Encoder and encode the value enc := encPool.Get().(*Encoder) defer encPool.Put(enc) - if err := enc.Reset(vwFlusher); err != nil { - return 0, nil, err - } + enc.Reset(vwFlusher) enc.ec = ec if err := enc.Encode(val); err != nil { return 0, nil, err @@ -408,13 +403,10 @@ func MarshalExtJSONAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val enc := encPool.Get().(*Encoder) defer encPool.Put(enc) - err := enc.Reset(ejvw) - if err != nil { - return nil, err - } + enc.Reset(ejvw) enc.ec = ec - err = enc.Encode(val) + err := enc.Encode(val) if err != nil { return nil, err } diff --git a/bson/marshal_test.go b/bson/marshal_test.go index 54b27dfcf1..1df1db3db0 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -109,24 +109,6 @@ func TestMarshalWithContext(t *testing.T) { } } -func TestMarshalAppend(t *testing.T) { - for _, tc := range marshalingTestCases { - t.Run(tc.name, func(t *testing.T) { - if tc.reg != nil { - t.Skip() // test requires custom registry - } - dst := make([]byte, 0, 1024) - got, err := MarshalAppend(dst, tc.val) - noerr(t, err) - - if !bytes.Equal(got, tc.want) { - t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want) - t.Errorf("Bytes:\n%v\n%v", got, tc.want) - } - }) - } -} - func TestMarshalExtJSONAppendWithContext(t *testing.T) { t.Run("MarshalExtJSONAppendWithContext", func(t *testing.T) { dst := make([]byte, 0, 1024) diff --git a/bson/primitive_codecs_test.go b/bson/primitive_codecs_test.go index 466f135e83..1e6554cbee 100644 --- a/bson/primitive_codecs_test.go +++ b/bson/primitive_codecs_test.go @@ -468,8 +468,7 @@ func TestDefaultValueEncoders(t *testing.T) { b := make(bsonrw.SliceWriter, 0, 512) vw, err := bsonrw.NewBSONValueWriter(&b) noerr(t, err) - enc, err := NewEncoder(vw) - noerr(t, err) + enc := NewEncoder(vw) err = enc.Encode(tc.value) if err != tc.err { t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) diff --git a/internal/codecutil/encoding_test.go b/internal/codecutil/encoding_test.go index 9696048f71..707d961cf7 100644 --- a/internal/codecutil/encoding_test.go +++ b/internal/codecutil/encoding_test.go @@ -24,8 +24,7 @@ func testEncFn(t *testing.T) EncoderFn { rw, err := bsonrw.NewBSONValueWriter(w) require.NoError(t, err, "failed to construct BSONValue writer") - enc, err := bson.NewEncoder(rw) - require.NoError(t, err, "failed to construct encoder") + enc := bson.NewEncoder(rw) return enc, nil } diff --git a/mongo/mongo.go b/mongo/mongo.go index 393c5b7713..fc29c13bee 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -91,10 +91,7 @@ func getEncoder( reg *bsoncodec.Registry, ) (*bson.Encoder, error) { vw := bvwPool.Get(w) - enc, err := bson.NewEncoder(vw) - if err != nil { - return nil, err - } + enc := bson.NewEncoder(vw) if opts != nil { if opts.ErrorOnInlineDuplicates { @@ -124,10 +121,7 @@ func getEncoder( } if reg != nil { - // TODO:(GODRIVER-2719): Remove error handling. - if err := enc.SetRegistry(reg); err != nil { - return nil, err - } + enc.SetRegistry(reg) } return enc, nil From f10e808434abea27e2083f5c77000bc08e8391bc Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Thu, 5 Oct 2023 09:59:46 -0400 Subject: [PATCH 03/12] WIP --- benchmark/bson_map.go | 15 +++++-- benchmark/bson_struct.go | 13 ++++-- bson/bson_test.go | 14 +++++-- bson/bsoncodec/registry_examples_test.go | 23 +++++++++-- bson/marshal.go | 34 ++++++---------- bson/marshal_test.go | 20 ++++++++-- bson/mgocompat/bson_test.go | 50 +++++++++++++++++++----- mongo/gridfs/bucket.go | 12 +++++- mongo/options/mongooptions.go | 26 ++++++++++-- 9 files changed, 151 insertions(+), 56 deletions(-) diff --git a/benchmark/bson_map.go b/benchmark/bson_map.go index 8fd56ee81e..8760692053 100644 --- a/benchmark/bson_map.go +++ b/benchmark/bson_map.go @@ -7,11 +7,13 @@ package benchmark import ( + "bytes" "context" "errors" "fmt" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsonrw" ) func bsonMapDecoding(tm TimerManager, iters int, dataSet string) error { @@ -47,15 +49,20 @@ func bsonMapEncoding(tm TimerManager, iters int, dataSet string) error { return err } - var buf []byte tm.ResetTimer() + buf := new(bytes.Buffer) for i := 0; i < iters; i++ { - buf, err = bson.MarshalAppend(buf[:0], doc) + buf.Reset() + vw, err := bsonrw.NewBSONValueWriter(buf) if err != nil { - return nil + return err + } + err = bson.NewEncoder(vw).Encode(doc) + if err != nil { + return err } - if len(buf) == 0 { + if buf.Len() == 0 { return errors.New("encoding failed") } } diff --git a/benchmark/bson_struct.go b/benchmark/bson_struct.go index 3fec93cc2c..3abf97ff26 100644 --- a/benchmark/bson_struct.go +++ b/benchmark/bson_struct.go @@ -7,10 +7,12 @@ package benchmark import ( + "bytes" "context" "errors" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsonrw" ) func BSONFlatStructDecoding(_ context.Context, tm TimerManager, iters int) error { @@ -70,15 +72,20 @@ func BSONFlatStructTagsEncoding(_ context.Context, tm TimerManager, iters int) e return err } - var buf []byte + buf := new(bytes.Buffer) tm.ResetTimer() for i := 0; i < iters; i++ { - buf, err = bson.MarshalAppend(buf[:0], doc) + buf.Reset() + vw, err := bsonrw.NewBSONValueWriter(buf) if err != nil { return err } - if len(buf) == 0 { + err = bson.NewEncoder(vw).Encode(doc) + if err != nil { + return err + } + if buf.Len() == 0 { return errors.New("encoding failed") } } diff --git a/bson/bson_test.go b/bson/bson_test.go index e2c1bf9e8b..876928b33a 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -18,6 +18,7 @@ import ( "github.com/google/go-cmp/cmp" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsonoptions" + "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" @@ -177,13 +178,20 @@ func TestMapCodec(t *testing.T) { {"true", bsonoptions.MapCodec().SetEncodeKeysWithStringer(true), "bar"}, {"false", bsonoptions.MapCodec().SetEncodeKeysWithStringer(false), "foo"}, } + buf := new(bytes.Buffer) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { mapCodec := bsoncodec.NewMapCodec(tc.opts) mapRegistry := NewRegistryBuilder().RegisterDefaultEncoder(reflect.Map, mapCodec).Build() - val, err := MarshalWithRegistry(mapRegistry, mapObj) - assert.Nil(t, err, "Marshal error: %v", err) - assert.True(t, strings.Contains(string(val), tc.key), "expected result to contain %v, got: %v", tc.key, string(val)) + buf.Reset() + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err) + enc := NewEncoder(vw) + enc.SetRegistry(mapRegistry) + err = enc.Encode(mapObj) + assert.Nil(t, err, "Encode error: %v", err) + str := buf.String() + assert.True(t, strings.Contains(str, tc.key), "expected result to contain %v, got: %v", tc.key, str) }) } }) diff --git a/bson/bsoncodec/registry_examples_test.go b/bson/bsoncodec/registry_examples_test.go index 9dc72bd503..49d2ef77f7 100644 --- a/bson/bsoncodec/registry_examples_test.go +++ b/bson/bsoncodec/registry_examples_test.go @@ -7,6 +7,7 @@ package bsoncodec_test import ( + "bytes" "fmt" "math" "reflect" @@ -66,11 +67,18 @@ func ExampleRegistry_customEncoder() { // Marshal the document as BSON. Expect that the int field is encoded to the // same value and that the negatedInt field is encoded as the negated value. - b, err := bson.MarshalWithRegistry(reg, doc) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) if err != nil { panic(err) } - fmt.Println(bson.Raw(b).String()) + enc := bson.NewEncoder(vw) + enc.SetRegistry(reg) + err = enc.Encode(doc) + if err != nil { + panic(err) + } + fmt.Println(bson.Raw(buf.Bytes()).String()) // Output: {"int": {"$numberInt":"1"},"negatedint": {"$numberInt":"-1"}} } @@ -200,11 +208,18 @@ func ExampleRegistry_RegisterKindEncoder() { // Marshal the document as BSON. Expect that all fields are encoded as BSON // int64 (represented as "$numberLong" when encoded as Extended JSON). - b, err := bson.MarshalWithRegistry(reg, doc) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + if err != nil { + panic(err) + } + enc := bson.NewEncoder(vw) + enc.SetRegistry(reg) + err = enc.Encode(doc) if err != nil { panic(err) } - fmt.Println(bson.Raw(b).String()) + fmt.Println(bson.Raw(buf.Bytes()).String()) // Output: {"myint": {"$numberLong":"1"},"int32": {"$numberLong":"1"},"int64": {"$numberLong":"1"}} } diff --git a/bson/marshal.go b/bson/marshal.go index d90c1f82da..aa1d82bffe 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -49,29 +49,19 @@ type ValueMarshaler interface { // marshal val into a []byte. Marshal will inspect struct tags and alter the // marshaling process accordingly. func Marshal(val interface{}) ([]byte, error) { - return MarshalWithRegistry(DefaultRegistry, val) -} + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + if err != nil { + return nil, err + } + enc := NewEncoder(vw) + enc.SetRegistry(DefaultRegistry) + err = enc.Encode(val) + if err != nil { + return nil, err + } -// MarshalAppend will encode val as a BSON document and append the bytes to dst. If dst is not large enough to hold the -// bytes, it will be grown. If val is not a type that can be transformed into a document, MarshalValueAppend should be -// used instead. -// -// Deprecated: Use [NewEncoder] and pass the dst byte slice (wrapped by a bytes.Buffer) into -// [bsonrw.NewBSONValueWriter]: -// -// buf := bytes.NewBuffer(dst) -// vw, err := bsonrw.NewBSONValueWriter(buf) -// if err != nil { -// panic(err) -// } -// enc, err := bson.NewEncoder(vw) -// if err != nil { -// panic(err) -// } -// -// See [Encoder] for more examples. -func MarshalAppend(dst []byte, val interface{}) ([]byte, error) { - return MarshalAppendWithRegistry(DefaultRegistry, dst, val) + return buf.Bytes(), nil } // MarshalWithRegistry returns the BSON encoding of val as a BSON document. If val is not a type that can be transformed diff --git a/bson/marshal_test.go b/bson/marshal_test.go index 1df1db3db0..334bba7b85 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -69,6 +69,7 @@ func TestMarshalAppendWithContext(t *testing.T) { } func TestMarshalWithRegistry(t *testing.T) { + buf := new(bytes.Buffer) for _, tc := range marshalingTestCases { t.Run(tc.name, func(t *testing.T) { var reg *bsoncodec.Registry @@ -77,10 +78,15 @@ func TestMarshalWithRegistry(t *testing.T) { } else { reg = DefaultRegistry } - got, err := MarshalWithRegistry(reg, tc.val) + buf.Reset() + vw, err := bsonrw.NewBSONValueWriter(buf) + noerr(t, err) + enc := NewEncoder(vw) + enc.SetRegistry(reg) + err = enc.Encode(tc.val) noerr(t, err) - if !bytes.Equal(got, tc.want) { + if got := buf.Bytes(); !bytes.Equal(got, tc.want) { t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want) t.Errorf("Bytes:\n%v\n%v", got, tc.want) } @@ -229,8 +235,14 @@ func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) { )) assert.Equal(t, expectedFirst, Raw(first), "expected document %v, got %v", expectedFirst, Raw(first)) - second, err := MarshalWithRegistry(customReg, original) - assert.Nil(t, err, "Marshal error: %v", err) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err) + enc := NewEncoder(vw) + enc.SetRegistry(customReg) + err = enc.Encode(original) + assert.Nil(t, err, "Encode error: %v", err) + second := buf.Bytes() expectedSecond := Raw(bsoncore.BuildDocumentFromElements( nil, bsoncore.AppendInt32Element(nil, "x", -1), diff --git a/bson/mgocompat/bson_test.go b/bson/mgocompat/bson_test.go index 9b4495dcc3..5cb818cc29 100644 --- a/bson/mgocompat/bson_test.go +++ b/bson/mgocompat/bson_test.go @@ -10,6 +10,7 @@ package mgocompat import ( + "bytes" "encoding/binary" "encoding/json" "errors" @@ -83,11 +84,18 @@ var sampleItems = []testItemType{ } func TestMarshalSampleItems(t *testing.T) { + buf := new(bytes.Buffer) for i, item := range sampleItems { t.Run(strconv.Itoa(i), func(t *testing.T) { - data, err := bson.MarshalWithRegistry(Registry, item.obj) + buf.Reset() + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) - assert.Equal(t, string(data), item.data, "expected: %v, got: %v", item.data, string(data)) + str := buf.String() + assert.Equal(t, str, item.data, "expected: %v, got: %v", item.data, str) }) } } @@ -161,11 +169,18 @@ var allItems = []testItemType{ } func TestMarshalAllItems(t *testing.T) { + buf := new(bytes.Buffer) for i, item := range allItems { t.Run(strconv.Itoa(i), func(t *testing.T) { - data, err := bson.MarshalWithRegistry(Registry, item.obj) + buf.Reset() + vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - assert.Equal(t, string(data), wrapInDoc(item.data), "expected: %v, got: %v", wrapInDoc(item.data), string(data)) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(item.obj) + assert.Nil(t, err, "expected nil error, got: %v", err) + str := buf.String() + assert.Equal(t, str, wrapInDoc(item.data), "expected: %v, got: %v", wrapInDoc(item.data), str) }) } } @@ -207,21 +222,31 @@ func TestUnmarshalRawIncompatible(t *testing.T) { } func TestUnmarshalZeroesStruct(t *testing.T) { - data, err := bson.MarshalWithRegistry(Registry, bson.M{"b": 2}) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(bson.M{"b": 2}) assert.Nil(t, err, "expected nil error, got: %v", err) type T struct{ A, B int } v := T{A: 1} - err = bson.UnmarshalWithRegistry(Registry, data, &v) + err = bson.UnmarshalWithRegistry(Registry, buf.Bytes(), &v) assert.Nil(t, err, "expected nil error, got: %v", err) assert.Equal(t, 0, v.A, "expected: 0, got: %v", v.A) assert.Equal(t, 2, v.B, "expected: 2, got: %v", v.B) } func TestUnmarshalZeroesMap(t *testing.T) { - data, err := bson.MarshalWithRegistry(Registry, bson.M{"b": 2}) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(bson.M{"b": 2}) assert.Nil(t, err, "expected nil error, got: %v", err) m := bson.M{"a": 1} - err = bson.UnmarshalWithRegistry(Registry, data, &m) + err = bson.UnmarshalWithRegistry(Registry, buf.Bytes(), &m) assert.Nil(t, err, "expected nil error, got: %v", err) want := bson.M{"b": 2} @@ -229,11 +254,16 @@ func TestUnmarshalZeroesMap(t *testing.T) { } func TestUnmarshalNonNilInterface(t *testing.T) { - data, err := bson.MarshalWithRegistry(Registry, bson.M{"b": 2}) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(bson.M{"b": 2}) assert.Nil(t, err, "expected nil error, got: %v", err) m := bson.M{"a": 1} var i interface{} = m - err = bson.UnmarshalWithRegistry(Registry, data, &i) + err = bson.UnmarshalWithRegistry(Registry, buf.Bytes(), &i) assert.Nil(t, err, "expected nil error, got: %v", err) assert.True(t, reflect.DeepEqual(bson.M{"b": 2}, i), "expected: %v, got: %v", bson.M{"b": 2}, i) assert.True(t, reflect.DeepEqual(bson.M{"a": 1}, m), "expected: %v, got: %v", bson.M{"a": 1}, m) diff --git a/mongo/gridfs/bucket.go b/mongo/gridfs/bucket.go index c9f40744f2..83d9218393 100644 --- a/mongo/gridfs/bucket.go +++ b/mongo/gridfs/bucket.go @@ -15,6 +15,7 @@ import ( "time" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/internal/csot" "go.mongodb.org/mongo-driver/mongo" @@ -652,12 +653,19 @@ func (b *Bucket) parseUploadOptions(opts ...*options.UploadOptions) (*Upload, er if uo.Metadata != nil { // TODO(GODRIVER-2726): Replace with marshal() and unmarshal() once the // TODO gridfs package is merged into the mongo package. - raw, err := bson.MarshalWithRegistry(uo.Registry, uo.Metadata) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + if err != nil { + return nil, err + } + enc := bson.NewEncoder(vw) + enc.SetRegistry(uo.Registry) + err = enc.Encode(uo.Metadata) if err != nil { return nil, err } var doc bson.D - unMarErr := bson.UnmarshalWithRegistry(uo.Registry, raw, &doc) + unMarErr := bson.UnmarshalWithRegistry(uo.Registry, buf.Bytes(), &doc) if unMarErr != nil { return nil, unMarErr } diff --git a/mongo/options/mongooptions.go b/mongo/options/mongooptions.go index fd17ce44e1..1fa42576d0 100644 --- a/mongo/options/mongooptions.go +++ b/mongo/options/mongooptions.go @@ -7,12 +7,14 @@ package options import ( + "bytes" "fmt" "reflect" "strconv" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) @@ -133,12 +135,20 @@ func (af *ArrayFilters) ToArray() ([]bson.Raw, error) { registry = bson.DefaultRegistry } filters := make([]bson.Raw, 0, len(af.Filters)) + buf := new(bytes.Buffer) for _, f := range af.Filters { - filter, err := bson.MarshalWithRegistry(registry, f) + buf.Reset() + vw, err := bsonrw.NewBSONValueWriter(buf) if err != nil { return nil, err } - filters = append(filters, filter) + enc := bson.NewEncoder(vw) + enc.SetRegistry(registry) + err = enc.Encode(f) + if err != nil { + return nil, err + } + filters = append(filters, buf.Bytes()) } return filters, nil } @@ -154,13 +164,21 @@ func (af *ArrayFilters) ToArrayDocument() (bson.Raw, error) { } idx, arr := bsoncore.AppendArrayStart(nil) + buf := new(bytes.Buffer) for i, f := range af.Filters { - filter, err := bson.MarshalWithRegistry(registry, f) + buf.Reset() + vw, err := bsonrw.NewBSONValueWriter(buf) + if err != nil { + return nil, err + } + enc := bson.NewEncoder(vw) + enc.SetRegistry(registry) + err = enc.Encode(f) if err != nil { return nil, err } - arr = bsoncore.AppendDocumentElement(arr, strconv.Itoa(i), filter) + arr = bsoncore.AppendDocumentElement(arr, strconv.Itoa(i), buf.Bytes()) } arr, _ = bsoncore.AppendArrayEnd(arr, idx) return arr, nil From fba7c471ec25bbfdf4719566ab61a7f0334328a4 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Wed, 29 Nov 2023 17:05:42 -0500 Subject: [PATCH 04/12] WIP --- bson/raw.go | 8 -------- bson/raw_value.go | 14 -------------- mongo/description/server.go | 8 +++++--- .../unified/gridfs_bucket_operation_execution.go | 2 +- mongo/mongo.go | 6 +++--- mongo/mongo_test.go | 2 +- 6 files changed, 10 insertions(+), 30 deletions(-) diff --git a/bson/raw.go b/bson/raw.go index 130da61ba0..cd5dab20be 100644 --- a/bson/raw.go +++ b/bson/raw.go @@ -29,14 +29,6 @@ func ReadDocument(r io.Reader) (Raw, error) { return Raw(doc), err } -// NewFromIOReader reads a BSON document from the io.Reader and returns it as a bson.Raw. If the -// reader contains multiple BSON documents, only the first document is read. -// -// Deprecated: Use ReadDocument instead. -func NewFromIOReader(r io.Reader) (Raw, error) { - return ReadDocument(r) -} - // Validate validates the document. This method only validates the first document in // the slice, to validate other documents, the slice must be resliced. func (r Raw) Validate() (err error) { return bsoncore.Document(r).Validate() } diff --git a/bson/raw_value.go b/bson/raw_value.go index 4d1bfb3160..b5855f62e4 100644 --- a/bson/raw_value.go +++ b/bson/raw_value.go @@ -271,20 +271,6 @@ func (rv RawValue) Int32() int32 { return convertToCoreValue(rv).Int32() } // panicking. func (rv RawValue) Int32OK() (int32, bool) { return convertToCoreValue(rv).Int32OK() } -// AsInt32 returns a BSON number as an int32. If the BSON type is not a numeric one, this method -// will panic. -// -// Deprecated: Use AsInt64 instead. If an int32 is required, convert the returned value to an int32 -// and perform any required overflow/underflow checking. -func (rv RawValue) AsInt32() int32 { return convertToCoreValue(rv).AsInt32() } - -// AsInt32OK is the same as AsInt32, except that it returns a boolean instead of -// panicking. -// -// Deprecated: Use AsInt64OK instead. If an int32 is required, convert the returned value to an -// int32 and perform any required overflow/underflow checking. -func (rv RawValue) AsInt32OK() (int32, bool) { return convertToCoreValue(rv).AsInt32OK() } - // Timestamp returns the BSON timestamp value the Value represents. It panics if the value is a // BSON type other than timestamp. func (rv RawValue) Timestamp() (t, i uint32) { return convertToCoreValue(rv).Timestamp() } diff --git a/mongo/description/server.go b/mongo/description/server.go index cf39423839..7e953796e3 100644 --- a/mongo/description/server.go +++ b/mongo/description/server.go @@ -202,13 +202,15 @@ func NewServer(addr address.Address, response bson.Raw) Server { } desc.CanonicalAddr = address.Address(me).Canonicalize() case "maxWireVersion": - versionRange.Max, ok = element.Value().AsInt32OK() + verMax, ok := element.Value().AsInt64OK() + versionRange.Max = int32(verMax) if !ok { desc.LastError = fmt.Errorf("expected 'maxWireVersion' to be an integer but it's a BSON %s", element.Value().Type) return desc } case "minWireVersion": - versionRange.Min, ok = element.Value().AsInt32OK() + verMin, ok := element.Value().AsInt64OK() + versionRange.Min = int32(verMin) if !ok { desc.LastError = fmt.Errorf("expected 'minWireVersion' to be an integer but it's a BSON %s", element.Value().Type) return desc @@ -220,7 +222,7 @@ func NewServer(addr address.Address, response bson.Raw) Server { return desc } case "ok": - okay, ok := element.Value().AsInt32OK() + okay, ok := element.Value().AsInt64OK() if !ok { desc.LastError = fmt.Errorf("expected 'ok' to be a boolean but it's a BSON %s", element.Value().Type) return desc diff --git a/mongo/integration/unified/gridfs_bucket_operation_execution.go b/mongo/integration/unified/gridfs_bucket_operation_execution.go index 3be6fded0c..9beea6d88b 100644 --- a/mongo/integration/unified/gridfs_bucket_operation_execution.go +++ b/mongo/integration/unified/gridfs_bucket_operation_execution.go @@ -148,7 +148,7 @@ func executeBucketDownloadByName(ctx context.Context, operation *operation) (*op case "filename": filename = val.StringValue() case "revision": - opts.SetRevision(val.AsInt32()) + opts.SetRevision(int32(val.AsInt64())) default: return nil, fmt.Errorf("unrecognized bucket download option %q", key) } diff --git a/mongo/mongo.go b/mongo/mongo.go index fc29c13bee..af95edd544 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -249,7 +249,7 @@ func marshalAggregatePipeline( registry *bsoncodec.Registry, ) (bsoncore.Document, bool, error) { switch t := pipeline.(type) { - case bsoncodec.ValueMarshaler: + case bson.ValueMarshaler: btype, val, err := t.MarshalBSONValue() if err != nil { return nil, false, err @@ -367,7 +367,7 @@ func marshalUpdateValue( u.Type = bsontype.EmbeddedDocument u.Data = t return u, documentCheckerFunc(u.Data) - case bsoncodec.Marshaler: + case bson.Marshaler: u.Type = bsontype.EmbeddedDocument u.Data, err = t.MarshalBSON() if err != nil { @@ -375,7 +375,7 @@ func marshalUpdateValue( } return u, documentCheckerFunc(u.Data) - case bsoncodec.ValueMarshaler: + case bson.ValueMarshaler: u.Type, u.Data, err = t.MarshalBSONValue() if err != nil { return u, err diff --git a/mongo/mongo_test.go b/mongo/mongo_test.go index fd3f1ce869..3b56e18a20 100644 --- a/mongo/mongo_test.go +++ b/mongo/mongo_test.go @@ -588,7 +588,7 @@ func TestMarshalValue(t *testing.T) { } } -var _ bsoncodec.ValueMarshaler = bvMarsh{} +var _ bson.ValueMarshaler = bvMarsh{} type bvMarsh struct { t bsontype.Type From c7f3a17aa696c9776bd46c9f43525a69149e5ef0 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Wed, 29 Nov 2023 19:21:55 -0500 Subject: [PATCH 05/12] WIP --- .../client_side_encryption_prose_test.go | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/mongo/integration/client_side_encryption_prose_test.go b/mongo/integration/client_side_encryption_prose_test.go index 22dde7c896..a0f70ffc5e 100644 --- a/mongo/integration/client_side_encryption_prose_test.go +++ b/mongo/integration/client_side_encryption_prose_test.go @@ -10,6 +10,7 @@ package integration import ( + "bytes" "context" "crypto/tls" "encoding/base64" @@ -24,6 +25,7 @@ import ( "time" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/event" @@ -2147,12 +2149,14 @@ func TestClientSideEncryptionProse(t *testing.T) { }) mt.RunOpts("18. Azure IMDS Credentials", noClientOpts, func(mt *mtest.T) { - buf := make([]byte, 0, 256) + buf := new(bytes.Buffer) kmsProvidersMap := map[string]map[string]interface{}{ "azure": {}, } - p, err := bson.MarshalAppend(buf[:0], kmsProvidersMap) - assert.Nil(mt, err, "error in MarshalAppendWithRegistry: %v", err) + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(mt, err, "error in NewBSONValueWriter: %v", err) + err = bson.NewEncoder(vw).Encode(kmsProvidersMap) + assert.Nil(mt, err, "error in Encode: %v", err) getClient := func(header http.Header) *http.Client { lt := &localTransport{ @@ -2167,7 +2171,7 @@ func TestClientSideEncryptionProse(t *testing.T) { mt.Run("Case 1: Success", func(mt *mtest.T) { opts := &mongocryptopts.MongoCryptOptions{ - KmsProviders: p, + KmsProviders: buf.Bytes(), HTTPClient: getClient(nil), } crypt, err := mongocrypt.NewMongoCrypt(opts) @@ -2182,7 +2186,7 @@ func TestClientSideEncryptionProse(t *testing.T) { header := make(http.Header) header.Set("X-MongoDB-HTTP-TestParams", "case=empty-json") opts := &mongocryptopts.MongoCryptOptions{ - KmsProviders: p, + KmsProviders: buf.Bytes(), HTTPClient: getClient(header), } crypt, err := mongocrypt.NewMongoCrypt(opts) @@ -2194,7 +2198,7 @@ func TestClientSideEncryptionProse(t *testing.T) { header := make(http.Header) header.Set("X-MongoDB-HTTP-TestParams", "case=bad-json") opts := &mongocryptopts.MongoCryptOptions{ - KmsProviders: p, + KmsProviders: buf.Bytes(), HTTPClient: getClient(header), } crypt, err := mongocrypt.NewMongoCrypt(opts) @@ -2206,7 +2210,7 @@ func TestClientSideEncryptionProse(t *testing.T) { header := make(http.Header) header.Set("X-MongoDB-HTTP-TestParams", "case=404") opts := &mongocryptopts.MongoCryptOptions{ - KmsProviders: p, + KmsProviders: buf.Bytes(), HTTPClient: getClient(header), } crypt, err := mongocrypt.NewMongoCrypt(opts) @@ -2218,7 +2222,7 @@ func TestClientSideEncryptionProse(t *testing.T) { header := make(http.Header) header.Set("X-MongoDB-HTTP-TestParams", "case=500") opts := &mongocryptopts.MongoCryptOptions{ - KmsProviders: p, + KmsProviders: buf.Bytes(), HTTPClient: getClient(header), } crypt, err := mongocrypt.NewMongoCrypt(opts) @@ -2230,7 +2234,7 @@ func TestClientSideEncryptionProse(t *testing.T) { header := make(http.Header) header.Set("X-MongoDB-HTTP-TestParams", "case=slow") opts := &mongocryptopts.MongoCryptOptions{ - KmsProviders: p, + KmsProviders: buf.Bytes(), HTTPClient: getClient(header), } crypt, err := mongocrypt.NewMongoCrypt(opts) From d9a1e55b65064bc0324e5b7fc3540ed48b16e441 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Fri, 1 Dec 2023 19:15:04 -0500 Subject: [PATCH 06/12] WIP --- bson/marshal.go | 22 ----- bson/mgocompat/bson_test.go | 191 ++++++++++++++++++++++++++++-------- 2 files changed, 148 insertions(+), 65 deletions(-) diff --git a/bson/marshal.go b/bson/marshal.go index aa1d82bffe..a4671da680 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -64,28 +64,6 @@ func Marshal(val interface{}) ([]byte, error) { return buf.Bytes(), nil } -// MarshalWithRegistry returns the BSON encoding of val as a BSON document. If val is not a type that can be transformed -// into a document, MarshalValueWithRegistry should be used instead. -// -// Deprecated: Use [NewEncoder] and specify the Registry by calling [Encoder.SetRegistry] instead: -// -// buf := new(bytes.Buffer) -// vw, err := bsonrw.NewBSONValueWriter(buf) -// if err != nil { -// panic(err) -// } -// enc, err := bson.NewEncoder(vw) -// if err != nil { -// panic(err) -// } -// enc.SetRegistry(reg) -// -// See [Encoder] for more examples. -func MarshalWithRegistry(r *bsoncodec.Registry, val interface{}) ([]byte, error) { - dst := make([]byte, 0) - return MarshalAppendWithRegistry(r, dst, val) -} - // MarshalWithContext returns the BSON encoding of val as a BSON document using EncodeContext ec. If val is not a type // that can be transformed into a document, MarshalValueWithContext should be used instead. // diff --git a/bson/mgocompat/bson_test.go b/bson/mgocompat/bson_test.go index 949420bd6d..4ef95d4560 100644 --- a/bson/mgocompat/bson_test.go +++ b/bson/mgocompat/bson_test.go @@ -294,12 +294,18 @@ func TestPtrInline(t *testing.T) { }, } + buf := new(bytes.Buffer) for i, cs := range cases { t.Run(strconv.Itoa(i), func(t *testing.T) { - data, err := bson.MarshalWithRegistry(Registry, cs.In) + buf.Reset() + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(cs.In) assert.Nil(t, err, "expected nil error, got: %v", err) var dataBSON bson.M - err = bson.UnmarshalWithRegistry(Registry, data, &dataBSON) + err = bson.UnmarshalWithRegistry(Registry, buf.Bytes(), &dataBSON) assert.Nil(t, err, "expected nil error, got: %v", err) assert.True(t, reflect.DeepEqual(cs.Out, dataBSON), "expected: %v, got: %v", cs.Out, dataBSON) @@ -377,12 +383,18 @@ var oneWayMarshalItems = []testItemType{ } func TestOneWayMarshalItems(t *testing.T) { + buf := new(bytes.Buffer) for i, item := range oneWayMarshalItems { t.Run(strconv.Itoa(i), func(t *testing.T) { - data, err := bson.MarshalWithRegistry(Registry, item.obj) + buf.Reset() + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) - assert.Equal(t, wrapInDoc(item.data), string(data), "expected: %v, got: %v", bson.Raw(wrapInDoc(item.data)), bson.Raw(data)) + assert.Equal(t, wrapInDoc(item.data), buf.String(), "expected: %v, got: %v", bson.Raw(wrapInDoc(item.data)), bson.Raw(buf.Bytes())) }) } } @@ -408,11 +420,17 @@ var structSampleItems = []testItemType{ } func TestMarshalStructSampleItems(t *testing.T) { + buf := new(bytes.Buffer) for i, item := range structSampleItems { t.Run(strconv.Itoa(i), func(t *testing.T) { - data, err := bson.MarshalWithRegistry(Registry, item.obj) + buf.Reset() + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) - assert.Equal(t, item.data, string(data), "expected: %v, got: %v", item.data, string(data)) + assert.Equal(t, item.data, buf.String(), "expected: %v, got: %v", item.data, buf.String()) }) } } @@ -428,13 +446,18 @@ func TestUnmarshalStructSampleItems(t *testing.T) { func Test64bitInt(t *testing.T) { var i int64 = (1 << 31) if int(i) > 0 { - data, err := bson.MarshalWithRegistry(Registry, bson.M{"i": int(i)}) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(bson.M{"i": int(i)}) assert.Nil(t, err, "expected nil error, got: %v", err) want := wrapInDoc("\x12i\x00\x00\x00\x00\x80\x00\x00\x00\x00") - assert.Equal(t, want, string(data), "expected: %v, got: %v", want, string(data)) + assert.Equal(t, want, buf.String(), "expected: %v, got: %v", want, buf.String()) var result struct{ I int } - err = bson.UnmarshalWithRegistry(Registry, data, &result) + err = bson.UnmarshalWithRegistry(Registry, buf.Bytes(), &result) assert.Nil(t, err, "expected nil error, got: %v", err) assert.Equal(t, i, int64(result.I), "expected: %v, got: %v", i, int64(result.I)) } @@ -564,11 +587,17 @@ var structItems = []testItemType{ } func TestMarshalStructItems(t *testing.T) { + buf := new(bytes.Buffer) for i, item := range structItems { t.Run(strconv.Itoa(i), func(t *testing.T) { - data, err := bson.MarshalWithRegistry(Registry, item.obj) + buf.Reset() + vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - assert.Equal(t, wrapInDoc(item.data), string(data), "expected: %v, got: %v", wrapInDoc(item.data), string(data)) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(item.obj) + assert.Nil(t, err, "expected nil error, got: %v", err) + assert.Equal(t, wrapInDoc(item.data), buf.String(), "expected: %v, got: %v", wrapInDoc(item.data), buf.String()) }) } } @@ -634,11 +663,17 @@ var marshalItems = []testItemType{ } func TestMarshalOneWayItems(t *testing.T) { + buf := new(bytes.Buffer) for i, item := range marshalItems { t.Run(strconv.Itoa(i), func(t *testing.T) { - data, err := bson.MarshalWithRegistry(Registry, item.obj) + buf.Reset() + vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - assert.Equal(t, wrapInDoc(item.data), string(data), "expected: %v, got: %v", wrapInDoc(item.data), string(data)) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(item.obj) + assert.Nil(t, err, "expected nil error, got: %v", err) + assert.Equal(t, wrapInDoc(item.data), buf.String(), "expected: %v, got: %v", wrapInDoc(item.data), buf.String()) }) } } @@ -741,12 +776,18 @@ var marshalErrorItems = []testItemType{ } func TestMarshalErrorItems(t *testing.T) { + buf := new(bytes.Buffer) for i, item := range marshalErrorItems { t.Run(strconv.Itoa(i), func(t *testing.T) { - data, err := bson.MarshalWithRegistry(Registry, item.obj) + buf.Reset() + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(item.obj) assert.NotNil(t, err, "expected error") - assert.Nil(t, data, " expected nil data, got: %v", data) + assert.Nil(t, buf.Bytes(), " expected nil data, got: %v", buf.Bytes()) }) } } @@ -1000,11 +1041,16 @@ func TestUnmarshalSetterErrSetZero(t *testing.T) { setterResult["foo"] = ErrSetZero defer delete(setterResult, "field") - data, err := bson.MarshalWithRegistry(Registry, bson.M{"field": "foo"}) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(bson.M{"field": "foo"}) assert.Nil(t, err, "expected nil error, got: %v", err) m := map[string]*setterType{} - err = bson.UnmarshalWithRegistry(Registry, data, m) + err = bson.UnmarshalWithRegistry(Registry, buf.Bytes(), m) assert.Nil(t, err, "expected nil error, got: %v", err) value, ok := m["field"] @@ -1032,27 +1078,38 @@ type docWithGetterField struct { } func TestMarshalAllItemsWithGetter(t *testing.T) { + buf := new(bytes.Buffer) for i, item := range allItems { if item.data == "" { continue } t.Run(strconv.Itoa(i), func(t *testing.T) { + buf.Reset() obj := &docWithGetterField{} obj.Field = &typeWithGetter{result: item.obj.(bson.M)["_"]} - data, err := bson.MarshalWithRegistry(Registry, obj) + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(obj) assert.Nil(t, err, "expected nil error, got: %v", err) - assert.Equal(t, wrapInDoc(item.data), string(data), - "expected value at %v to be: %v, got: %v", i, wrapInDoc(item.data), string(data)) + assert.Equal(t, wrapInDoc(item.data), buf.String(), + "expected value at %v to be: %v, got: %v", i, wrapInDoc(item.data), buf.String()) }) } } func TestMarshalWholeDocumentWithGetter(t *testing.T) { obj := &typeWithGetter{result: sampleItems[0].obj} - data, err := bson.MarshalWithRegistry(Registry, obj) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(obj) assert.Nil(t, err, "expected nil error, got: %v", err) - assert.Equal(t, sampleItems[0].data, string(data), - "expected: %v, got: %v", sampleItems[0].data, string(data)) + assert.Equal(t, sampleItems[0].data, buf.String(), + "expected: %v, got: %v", sampleItems[0].data, buf.String()) } func TestGetterErrors(t *testing.T) { @@ -1060,14 +1117,24 @@ func TestGetterErrors(t *testing.T) { obj1 := &docWithGetterField{} obj1.Field = &typeWithGetter{sampleItems[0].obj, e} - data, err := bson.MarshalWithRegistry(Registry, obj1) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(obj1) assert.Equal(t, e, err, "expected error: %v, got: %v", e, err) - assert.Nil(t, data, "expected nil data, got: %v", data) + assert.Nil(t, buf.Bytes(), "expected nil data, got: %v", buf.Bytes()) obj2 := &typeWithGetter{sampleItems[0].obj, e} - data, err = bson.MarshalWithRegistry(Registry, obj2) + buf.Reset() + vw, err = bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc = bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(obj2) assert.Equal(t, e, err, "expected error: %v, got: %v", e, err) - assert.Nil(t, data, "expected nil data, got: %v", data) + assert.Nil(t, buf.Bytes(), "expected nil data, got: %v", buf.Bytes()) } type intGetter int64 @@ -1082,20 +1149,30 @@ type typeWithIntGetter struct { func TestMarshalShortWithGetter(t *testing.T) { obj := typeWithIntGetter{42} - data, err := bson.MarshalWithRegistry(Registry, obj) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(obj) assert.Nil(t, err, "expected nil error, got: %v", err) m := bson.M{} - err = bson.UnmarshalWithRegistry(Registry, data, &m) + err = bson.UnmarshalWithRegistry(Registry, buf.Bytes(), &m) assert.Nil(t, err, "expected nil error, got: %v", err) assert.Equal(t, 42, m["v"], "expected m[\"v\"] to be: %v, got: %v", 42, m["v"]) } func TestMarshalWithGetterNil(t *testing.T) { obj := docWithGetterField{} - data, err := bson.MarshalWithRegistry(Registry, obj) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(obj) assert.Nil(t, err, "expected nil error, got: %v", err) m := bson.M{} - err = bson.UnmarshalWithRegistry(Registry, data, &m) + err = bson.UnmarshalWithRegistry(Registry, buf.Bytes(), &m) assert.Nil(t, err, "expected nil error, got: %v", err) want := bson.M{"_": ""} assert.Equal(t, want, m, "expected m[\"v\"] to be: %v, got: %v", want, m) @@ -1539,9 +1616,14 @@ var oneWayCrossItems = []crossTypeItem{ func testCrossPair(t *testing.T, dump interface{}, load interface{}) { zero := makeZeroDoc(load) - data, err := bson.MarshalWithRegistry(Registry, dump) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - err = bson.UnmarshalWithRegistry(Registry, data, zero) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(dump) + assert.Nil(t, err, "expected nil error, got: %v", err) + err = bson.UnmarshalWithRegistry(Registry, buf.Bytes(), zero) assert.Nil(t, err, "expected nil error, got: %v", err) assert.True(t, reflect.DeepEqual(load, zero), "expected: %v, got: %v", load, zero) @@ -1649,11 +1731,17 @@ func TestMarshalNotRespectNil(t *testing.T) { assert.Nil(t, testStruct1.BSlice, "expected nil byte slice, got: %v", testStruct1.BSlice) assert.Nil(t, testStruct1.Map, "expected nil map, got: %v", testStruct1.Map) - b, _ := bson.MarshalWithRegistry(Registry, testStruct1) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(testStruct1) + assert.Nil(t, err, "expected nil error, got: %v", err) testStruct2 := T{} - _ = bson.UnmarshalWithRegistry(Registry, b, &testStruct2) + _ = bson.UnmarshalWithRegistry(Registry, buf.Bytes(), &testStruct2) assert.NotNil(t, testStruct2.Slice, "expected non-nil slice") assert.NotNil(t, testStruct2.BSlice, "expected non-nil byte slice") @@ -1677,15 +1765,21 @@ func TestMarshalRespectNil(t *testing.T) { assert.Nil(t, testStruct1.MapPtr, "expected nil map ptr, got: %v", testStruct1.MapPtr) assert.Nil(t, testStruct1.Ptr, "expected nil ptr, got: %v", testStruct1.Ptr) - b, _ := bson.MarshalWithRegistry(RegistryRespectNilValues, testStruct1) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(testStruct1) + assert.Nil(t, err, "expected nil error, got: %v", err) testStruct2 := T{} - _ = bson.UnmarshalWithRegistry(RegistryRespectNilValues, b, &testStruct2) + _ = bson.UnmarshalWithRegistry(RegistryRespectNilValues, buf.Bytes(), &testStruct2) - assert.Nil(t, testStruct2.Slice, "expected nil slice, got: %v", testStruct2.Slice) + assert.Len(t, testStruct2.Slice, 0, "expected empty slice, got: %v", testStruct2.Slice) assert.Nil(t, testStruct2.SlicePtr, "expected nil slice ptr, got: %v", testStruct2.SlicePtr) - assert.Nil(t, testStruct2.Map, "expected nil map, got: %v", testStruct2.Map) + assert.Len(t, testStruct2.Map, 0, "expected empty map, got: %v", testStruct2.Map) assert.Nil(t, testStruct2.MapPtr, "expected nil map ptr, got: %v", testStruct2.MapPtr) assert.Nil(t, testStruct2.Ptr, "expected nil ptr, got: %v", testStruct2.Ptr) @@ -1701,11 +1795,17 @@ func TestMarshalRespectNil(t *testing.T) { assert.NotNil(t, testStruct1.Map, "expected non-nil map") assert.NotNil(t, testStruct1.MapPtr, "expected non-nil map ptr") - b, _ = bson.MarshalWithRegistry(RegistryRespectNilValues, testStruct1) + buf.Reset() + vw, err = bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc = bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(testStruct1) + assert.Nil(t, err, "expected nil error, got: %v", err) testStruct2 = T{} - _ = bson.UnmarshalWithRegistry(RegistryRespectNilValues, b, &testStruct2) + _ = bson.UnmarshalWithRegistry(RegistryRespectNilValues, buf.Bytes(), &testStruct2) assert.NotNil(t, testStruct2.Slice, "expected non-nil slice") assert.NotNil(t, testStruct2.SlicePtr, "expected non-nil slice ptr") @@ -1732,11 +1832,16 @@ func TestInlineWithPointerToSelf(t *testing.T) { Value: "", } - bytes, err := bson.MarshalWithRegistry(Registry, x1) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err, "expected nil error, got: %v", err) + enc := bson.NewEncoder(vw) + enc.SetRegistry(Registry) + err = enc.Encode(x1) assert.Nil(t, err, "expected nil error, got: %v", err) var x2 InlineLoop - err = bson.UnmarshalWithRegistry(Registry, bytes, &x2) + err = bson.UnmarshalWithRegistry(Registry, buf.Bytes(), &x2) assert.Nil(t, err, "expected nil error, got: %v", err) assert.Equal(t, x1, x2, "Expected %v, got %v", x1, x2) } From 60e575653b6be5b66b128f6963db4ef3f0e22503 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Mon, 4 Dec 2023 18:10:54 -0500 Subject: [PATCH 07/12] WIP --- bson/marshal.go | 112 ---------------------------------------- bson/marshal_test.go | 55 ++++---------------- bson/truncation_test.go | 24 ++++++--- mongo/cursor.go | 20 ++++--- 4 files changed, 41 insertions(+), 170 deletions(-) diff --git a/bson/marshal.go b/bson/marshal.go index a4671da680..6fd42792bb 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -9,7 +9,6 @@ package bson import ( "bytes" "encoding/json" - "sync" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsonrw" @@ -64,117 +63,6 @@ func Marshal(val interface{}) ([]byte, error) { return buf.Bytes(), nil } -// MarshalWithContext returns the BSON encoding of val as a BSON document using EncodeContext ec. If val is not a type -// that can be transformed into a document, MarshalValueWithContext should be used instead. -// -// Deprecated: Use [NewEncoder] and use the Encoder configuration methods to set the desired marshal -// behavior instead: -// -// buf := bytes.NewBuffer(dst) -// vw, err := bsonrw.NewBSONValueWriter(buf) -// if err != nil { -// panic(err) -// } -// enc, err := bson.NewEncoder(vw) -// if err != nil { -// panic(err) -// } -// enc.IntMinSize() -// -// See [Encoder] for more examples. -func MarshalWithContext(ec bsoncodec.EncodeContext, val interface{}) ([]byte, error) { - dst := make([]byte, 0) - return MarshalAppendWithContext(ec, dst, val) -} - -// MarshalAppendWithRegistry will encode val as a BSON document using Registry r and append the bytes to dst. If dst is -// not large enough to hold the bytes, it will be grown. If val is not a type that can be transformed into a document, -// MarshalValueAppendWithRegistry should be used instead. -// -// Deprecated: Use [NewEncoder], and pass the dst byte slice (wrapped by a bytes.Buffer) into -// [bsonrw.NewBSONValueWriter], and specify the Registry by calling [Encoder.SetRegistry] instead: -// -// buf := bytes.NewBuffer(dst) -// vw, err := bsonrw.NewBSONValueWriter(buf) -// if err != nil { -// panic(err) -// } -// enc, err := bson.NewEncoder(vw) -// if err != nil { -// panic(err) -// } -// enc.SetRegistry(reg) -// -// See [Encoder] for more examples. -func MarshalAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{}) ([]byte, error) { - return MarshalAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val) -} - -// Pool of buffers for marshalling BSON. -var bufPool = sync.Pool{ - New: func() interface{} { - return new(bytes.Buffer) - }, -} - -// MarshalAppendWithContext will encode val as a BSON document using Registry r and EncodeContext ec and append the -// bytes to dst. If dst is not large enough to hold the bytes, it will be grown. If val is not a type that can be -// transformed into a document, MarshalValueAppendWithContext should be used instead. -// -// Deprecated: Use [NewEncoder], pass the dst byte slice (wrapped by a bytes.Buffer) into -// [bsonrw.NewBSONValueWriter], and use the Encoder configuration methods to set the desired marshal -// behavior instead: -// -// buf := bytes.NewBuffer(dst) -// vw, err := bsonrw.NewBSONValueWriter(buf) -// if err != nil { -// panic(err) -// } -// enc, err := bson.NewEncoder(vw) -// if err != nil { -// panic(err) -// } -// enc.IntMinSize() -// -// See [Encoder] for more examples. -func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}) ([]byte, error) { - sw := bufPool.Get().(*bytes.Buffer) - defer func() { - // Proper usage of a sync.Pool requires each entry to have approximately - // the same memory cost. To obtain this property when the stored type - // contains a variably-sized buffer, we add a hard limit on the maximum - // buffer to place back in the pool. We limit the size to 16MiB because - // that's the maximum wire message size supported by any current MongoDB - // server. - // - // Comment based on - // https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/fmt/print.go;l=147 - // - // Recycle byte slices that are smaller than 16MiB and at least half - // occupied. - if sw.Cap() < 16*1024*1024 && sw.Cap()/2 < sw.Len() { - bufPool.Put(sw) - } - }() - - sw.Reset() - vw := bvwPool.Get(sw) - defer bvwPool.Put(vw) - - enc := encPool.Get().(*Encoder) - defer encPool.Put(enc) - - enc.Reset(vw) - enc.ec = ec - - err := enc.Encode(val) - if err != nil { - return nil, err - } - - return append(dst, sw.Bytes()...), nil -} - // MarshalValue returns the BSON encoding of val. // // MarshalValue will use bson.DefaultRegistry to transform val into a BSON value. If val is a struct, this function will diff --git a/bson/marshal_test.go b/bson/marshal_test.go index 334bba7b85..1eaeda8169 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -25,49 +25,6 @@ import ( var tInt32 = reflect.TypeOf(int32(0)) -func TestMarshalAppendWithRegistry(t *testing.T) { - for _, tc := range marshalingTestCases { - t.Run(tc.name, func(t *testing.T) { - dst := make([]byte, 0, 1024) - var reg *bsoncodec.Registry - if tc.reg != nil { - reg = tc.reg - } else { - reg = DefaultRegistry - } - got, err := MarshalAppendWithRegistry(reg, dst, tc.val) - noerr(t, err) - - if !bytes.Equal(got, tc.want) { - t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want) - t.Errorf("Bytes:\n%v\n%v", got, tc.want) - } - }) - } -} - -func TestMarshalAppendWithContext(t *testing.T) { - for _, tc := range marshalingTestCases { - t.Run(tc.name, func(t *testing.T) { - dst := make([]byte, 0, 1024) - var reg *bsoncodec.Registry - if tc.reg != nil { - reg = tc.reg - } else { - reg = DefaultRegistry - } - ec := bsoncodec.EncodeContext{Registry: reg} - got, err := MarshalAppendWithContext(ec, dst, tc.val) - noerr(t, err) - - if !bytes.Equal(got, tc.want) { - t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want) - t.Errorf("Bytes:\n%v\n%v", got, tc.want) - } - }) - } -} - func TestMarshalWithRegistry(t *testing.T) { buf := new(bytes.Buffer) for _, tc := range marshalingTestCases { @@ -95,6 +52,7 @@ func TestMarshalWithRegistry(t *testing.T) { } func TestMarshalWithContext(t *testing.T) { + buf := new(bytes.Buffer) for _, tc := range marshalingTestCases { t.Run(tc.name, func(t *testing.T) { var reg *bsoncodec.Registry @@ -103,11 +61,16 @@ func TestMarshalWithContext(t *testing.T) { } else { reg = DefaultRegistry } - ec := bsoncodec.EncodeContext{Registry: reg} - got, err := MarshalWithContext(ec, tc.val) + buf.Reset() + vw, err := bsonrw.NewBSONValueWriter(buf) + noerr(t, err) + enc := NewEncoder(vw) + enc.IntMinSize() + enc.SetRegistry(reg) + err = enc.Encode(tc.val) noerr(t, err) - if !bytes.Equal(got, tc.want) { + if got := buf.Bytes(); !bytes.Equal(got, tc.want) { t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want) t.Errorf("Bytes:\n%v\n%v", got, tc.want) } diff --git a/bson/truncation_test.go b/bson/truncation_test.go index c8ba759a33..1db7562884 100644 --- a/bson/truncation_test.go +++ b/bson/truncation_test.go @@ -7,9 +7,11 @@ package bson import ( + "bytes" "testing" "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/internal/assert" ) @@ -29,9 +31,14 @@ func TestTruncation(t *testing.T) { inputVal := 4.7892 input := inputArgs{Name: inputName, Val: &inputVal} - ec := bsoncodec.EncodeContext{Registry: DefaultRegistry} - doc, err := MarshalWithContext(ec, &input) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err) + enc := NewEncoder(vw) + enc.IntMinSize() + enc.SetRegistry(DefaultRegistry) + err = enc.Encode(&input) assert.Nil(t, err) var output outputArgs @@ -40,7 +47,7 @@ func TestTruncation(t *testing.T) { Truncate: true, } - err = UnmarshalWithContext(dc, doc, &output) + err = UnmarshalWithContext(dc, buf.Bytes(), &output) assert.Nil(t, err) assert.Equal(t, inputName, output.Name) @@ -51,9 +58,14 @@ func TestTruncation(t *testing.T) { inputVal := 7.382 input := inputArgs{Name: inputName, Val: &inputVal} - ec := bsoncodec.EncodeContext{Registry: DefaultRegistry} - doc, err := MarshalWithContext(ec, &input) + buf := new(bytes.Buffer) + vw, err := bsonrw.NewBSONValueWriter(buf) + assert.Nil(t, err) + enc := NewEncoder(vw) + enc.IntMinSize() + enc.SetRegistry(DefaultRegistry) + err = enc.Encode(&input) assert.Nil(t, err) var output outputArgs @@ -63,7 +75,7 @@ func TestTruncation(t *testing.T) { } // case throws an error when truncation is disabled - err = UnmarshalWithContext(dc, doc, &output) + err = UnmarshalWithContext(dc, buf.Bytes(), &output) assert.NotNil(t, err) }) } diff --git a/mongo/cursor.go b/mongo/cursor.go index 3958568854..c48814f42b 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -7,6 +7,7 @@ package mongo import ( + "bytes" "context" "errors" "fmt" @@ -91,7 +92,7 @@ func NewCursorFromDocuments(documents []interface{}, err error, registry *bsonco } // Convert documents slice to a sequence-style byte array. - var docsBytes []byte + buf := new(bytes.Buffer) for _, doc := range documents { switch t := doc.(type) { case nil: @@ -100,15 +101,22 @@ func NewCursorFromDocuments(documents []interface{}, err error, registry *bsonco // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery. doc = bson.Raw(t) } - var marshalErr error - docsBytes, marshalErr = bson.MarshalAppendWithRegistry(registry, docsBytes, doc) - if marshalErr != nil { - return nil, marshalErr + vw, err := bsonrw.NewBSONValueWriter(buf) + if err != nil { + return nil, err + } + enc := bson.NewEncoder(vw) + enc.SetRegistry(registry) + err = enc.Encode(doc) + // var marshalErr error + // docsBytes, marshalErr = bson.MarshalAppendWithRegistry(registry, docsBytes, doc) + if err != nil { + return nil, err } } c := &Cursor{ - bc: driver.NewBatchCursorFromDocuments(docsBytes), + bc: driver.NewBatchCursorFromDocuments(buf.Bytes()), registry: registry, err: err, } From 37bb727d2a394ff3604c2f46a29455219b701beb Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Tue, 5 Dec 2023 16:41:49 -0500 Subject: [PATCH 08/12] WIP --- bson/bson_test.go | 3 +- bson/marshal_test.go | 6 ++- bson/marshal_value_test.go | 78 ----------------------------------- bson/mgocompat/bson_test.go | 27 ++++++++---- mongo/cursor.go | 5 +-- mongo/options/mongooptions.go | 6 ++- 6 files changed, 30 insertions(+), 95 deletions(-) diff --git a/bson/bson_test.go b/bson/bson_test.go index 876928b33a..357308a7d1 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -179,6 +179,7 @@ func TestMapCodec(t *testing.T) { {"false", bsonoptions.MapCodec().SetEncodeKeysWithStringer(false), "foo"}, } buf := new(bytes.Buffer) + enc := new(Encoder) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { mapCodec := bsoncodec.NewMapCodec(tc.opts) @@ -186,7 +187,7 @@ func TestMapCodec(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err) - enc := NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(mapRegistry) err = enc.Encode(mapObj) assert.Nil(t, err, "Encode error: %v", err) diff --git a/bson/marshal_test.go b/bson/marshal_test.go index 1eaeda8169..4031bf38c7 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -27,6 +27,7 @@ var tInt32 = reflect.TypeOf(int32(0)) func TestMarshalWithRegistry(t *testing.T) { buf := new(bytes.Buffer) + enc := new(Encoder) for _, tc := range marshalingTestCases { t.Run(tc.name, func(t *testing.T) { var reg *bsoncodec.Registry @@ -38,7 +39,7 @@ func TestMarshalWithRegistry(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) noerr(t, err) - enc := NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(reg) err = enc.Encode(tc.val) noerr(t, err) @@ -53,6 +54,7 @@ func TestMarshalWithRegistry(t *testing.T) { func TestMarshalWithContext(t *testing.T) { buf := new(bytes.Buffer) + enc := new(Encoder) for _, tc := range marshalingTestCases { t.Run(tc.name, func(t *testing.T) { var reg *bsoncodec.Registry @@ -64,7 +66,7 @@ func TestMarshalWithContext(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) noerr(t, err) - enc := NewEncoder(vw) + enc.Reset(vw) enc.IntMinSize() enc.SetRegistry(reg) err = enc.Encode(tc.val) diff --git a/bson/marshal_value_test.go b/bson/marshal_value_test.go index cfc273f0de..66dec6c186 100644 --- a/bson/marshal_value_test.go +++ b/bson/marshal_value_test.go @@ -10,7 +10,6 @@ import ( "strings" "testing" - "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/internal/assert" ) @@ -35,83 +34,6 @@ func TestMarshalValue(t *testing.T) { }) } }) - t.Run("MarshalValueAppend", func(t *testing.T) { - t.Parallel() - - for _, tc := range marshalValueTestCases { - tc := tc - - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - valueType, valueBytes, err := MarshalValueAppend(nil, tc.val) - assert.Nil(t, err, "MarshalValueAppend error: %v", err) - compareMarshalValueResults(t, tc, valueType, valueBytes) - }) - } - }) - t.Run("MarshalValueWithRegistry", func(t *testing.T) { - t.Parallel() - - for _, tc := range marshalValueTestCases { - tc := tc - - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - valueType, valueBytes, err := MarshalValueWithRegistry(DefaultRegistry, tc.val) - assert.Nil(t, err, "MarshalValueWithRegistry error: %v", err) - compareMarshalValueResults(t, tc, valueType, valueBytes) - }) - } - }) - t.Run("MarshalValueWithContext", func(t *testing.T) { - t.Parallel() - - ec := bsoncodec.EncodeContext{Registry: DefaultRegistry} - for _, tc := range marshalValueTestCases { - tc := tc - - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - valueType, valueBytes, err := MarshalValueWithContext(ec, tc.val) - assert.Nil(t, err, "MarshalValueWithContext error: %v", err) - compareMarshalValueResults(t, tc, valueType, valueBytes) - }) - } - }) - t.Run("MarshalValueAppendWithRegistry", func(t *testing.T) { - t.Parallel() - - for _, tc := range marshalValueTestCases { - tc := tc - - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - valueType, valueBytes, err := MarshalValueAppendWithRegistry(DefaultRegistry, nil, tc.val) - assert.Nil(t, err, "MarshalValueAppendWithRegistry error: %v", err) - compareMarshalValueResults(t, tc, valueType, valueBytes) - }) - } - }) - t.Run("MarshalValueAppendWithContext", func(t *testing.T) { - t.Parallel() - - ec := bsoncodec.EncodeContext{Registry: DefaultRegistry} - for _, tc := range marshalValueTestCases { - tc := tc - - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - valueType, valueBytes, err := MarshalValueAppendWithContext(ec, nil, tc.val) - assert.Nil(t, err, "MarshalValueWithContext error: %v", err) - compareMarshalValueResults(t, tc, valueType, valueBytes) - }) - } - }) } func compareMarshalValueResults(t *testing.T, tc marshalValueTestCase, gotType bsontype.Type, gotBytes []byte) { diff --git a/bson/mgocompat/bson_test.go b/bson/mgocompat/bson_test.go index 4ef95d4560..4ff8fb57e8 100644 --- a/bson/mgocompat/bson_test.go +++ b/bson/mgocompat/bson_test.go @@ -85,12 +85,13 @@ var sampleItems = []testItemType{ func TestMarshalSampleItems(t *testing.T) { buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, item := range sampleItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -170,12 +171,13 @@ var allItems = []testItemType{ func TestMarshalAllItems(t *testing.T) { buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, item := range allItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -295,12 +297,13 @@ func TestPtrInline(t *testing.T) { } buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, cs := range cases { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(cs.In) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -384,12 +387,13 @@ var oneWayMarshalItems = []testItemType{ func TestOneWayMarshalItems(t *testing.T) { buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, item := range oneWayMarshalItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -421,12 +425,13 @@ var structSampleItems = []testItemType{ func TestMarshalStructSampleItems(t *testing.T) { buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, item := range structSampleItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -588,12 +593,13 @@ var structItems = []testItemType{ func TestMarshalStructItems(t *testing.T) { buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, item := range structItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -664,12 +670,13 @@ var marshalItems = []testItemType{ func TestMarshalOneWayItems(t *testing.T) { buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, item := range marshalItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(item.obj) assert.Nil(t, err, "expected nil error, got: %v", err) @@ -777,12 +784,13 @@ var marshalErrorItems = []testItemType{ func TestMarshalErrorItems(t *testing.T) { buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, item := range marshalErrorItems { t.Run(strconv.Itoa(i), func(t *testing.T) { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(item.obj) @@ -1079,6 +1087,7 @@ type docWithGetterField struct { func TestMarshalAllItemsWithGetter(t *testing.T) { buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, item := range allItems { if item.data == "" { continue @@ -1089,7 +1098,7 @@ func TestMarshalAllItemsWithGetter(t *testing.T) { obj.Field = &typeWithGetter{result: item.obj.(bson.M)["_"]} vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err, "expected nil error, got: %v", err) - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(Registry) err = enc.Encode(obj) assert.Nil(t, err, "expected nil error, got: %v", err) diff --git a/mongo/cursor.go b/mongo/cursor.go index c48814f42b..67db1c2953 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -93,6 +93,7 @@ func NewCursorFromDocuments(documents []interface{}, err error, registry *bsonco // Convert documents slice to a sequence-style byte array. buf := new(bytes.Buffer) + enc := new(bson.Encoder) for _, doc := range documents { switch t := doc.(type) { case nil: @@ -105,11 +106,9 @@ func NewCursorFromDocuments(documents []interface{}, err error, registry *bsonco if err != nil { return nil, err } - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(registry) err = enc.Encode(doc) - // var marshalErr error - // docsBytes, marshalErr = bson.MarshalAppendWithRegistry(registry, docsBytes, doc) if err != nil { return nil, err } diff --git a/mongo/options/mongooptions.go b/mongo/options/mongooptions.go index 95d88f6e96..bf1bc69adb 100644 --- a/mongo/options/mongooptions.go +++ b/mongo/options/mongooptions.go @@ -134,13 +134,14 @@ func (af *ArrayFilters) ToArray() ([]bson.Raw, error) { } filters := make([]bson.Raw, 0, len(af.Filters)) buf := new(bytes.Buffer) + enc := new(bson.Encoder) for _, f := range af.Filters { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) if err != nil { return nil, err } - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(registry) err = enc.Encode(f) if err != nil { @@ -163,13 +164,14 @@ func (af *ArrayFilters) ToArrayDocument() (bson.Raw, error) { idx, arr := bsoncore.AppendArrayStart(nil) buf := new(bytes.Buffer) + enc := new(bson.Encoder) for i, f := range af.Filters { buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) if err != nil { return nil, err } - enc := bson.NewEncoder(vw) + enc.Reset(vw) enc.SetRegistry(registry) err = enc.Encode(f) if err != nil { From 02284e7561b257fe45b46fa330c7592f0a936a32 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Wed, 6 Dec 2023 17:40:14 -0500 Subject: [PATCH 09/12] WIP --- bson/marshal.go | 52 +++++++------------------------------------------ 1 file changed, 7 insertions(+), 45 deletions(-) diff --git a/bson/marshal.go b/bson/marshal.go index 6fd42792bb..15793bb2e6 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -53,7 +53,9 @@ func Marshal(val interface{}) ([]byte, error) { if err != nil { return nil, err } - enc := NewEncoder(vw) + enc := encPool.Get().(*Encoder) + defer encPool.Put(enc) + enc.Reset(vw) enc.SetRegistry(DefaultRegistry) err = enc.Encode(val) if err != nil { @@ -71,58 +73,19 @@ func MarshalValue(val interface{}) (bsontype.Type, []byte, error) { return MarshalValueWithRegistry(DefaultRegistry, val) } -// MarshalValueAppend will append the BSON encoding of val to dst. If dst is not large enough to hold the BSON encoding -// of val, dst will be grown. -// -// Deprecated: Appending individual BSON elements to an existing slice will not be supported in Go -// Driver 2.0. -func MarshalValueAppend(dst []byte, val interface{}) (bsontype.Type, []byte, error) { - return MarshalValueAppendWithRegistry(DefaultRegistry, dst, val) -} - // MarshalValueWithRegistry returns the BSON encoding of val using Registry r. // // Deprecated: Using a custom registry to marshal individual BSON values will not be supported in Go // Driver 2.0. func MarshalValueWithRegistry(r *bsoncodec.Registry, val interface{}) (bsontype.Type, []byte, error) { - dst := make([]byte, 0) - return MarshalValueAppendWithRegistry(r, dst, val) -} - -// MarshalValueWithContext returns the BSON encoding of val using EncodeContext ec. -// -// Deprecated: Using a custom EncodeContext to marshal individual BSON elements will not be -// supported in Go Driver 2.0. -func MarshalValueWithContext(ec bsoncodec.EncodeContext, val interface{}) (bsontype.Type, []byte, error) { - dst := make([]byte, 0) - return MarshalValueAppendWithContext(ec, dst, val) -} - -// MarshalValueAppendWithRegistry will append the BSON encoding of val to dst using Registry r. If dst is not large -// enough to hold the BSON encoding of val, dst will be grown. -// -// Deprecated: Appending individual BSON elements to an existing slice will not be supported in Go -// Driver 2.0. -func MarshalValueAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{}) (bsontype.Type, []byte, error) { - return MarshalValueAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val) -} - -// MarshalValueAppendWithContext will append the BSON encoding of val to dst using EncodeContext ec. If dst is not large -// enough to hold the BSON encoding of val, dst will be grown. -// -// Deprecated: Appending individual BSON elements to an existing slice will not be supported in Go -// Driver 2.0. -func MarshalValueAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}) (bsontype.Type, []byte, error) { - // get a ValueWriter configured to write to dst - sw := new(bsonrw.SliceWriter) - *sw = dst - vwFlusher := bvwPool.GetAtModeElement(sw) + var sw bsonrw.SliceWriter + vwFlusher := bvwPool.GetAtModeElement(&sw) // get an Encoder and encode the value enc := encPool.Get().(*Encoder) defer encPool.Put(enc) enc.Reset(vwFlusher) - enc.ec = ec + enc.ec = bsoncodec.EncodeContext{Registry: r} if err := enc.Encode(val); err != nil { return 0, nil, err } @@ -133,8 +96,7 @@ func MarshalValueAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val i if err := vwFlusher.Flush(); err != nil { return 0, nil, err } - buffer := *sw - return bsontype.Type(buffer[0]), buffer[2:], nil + return bsontype.Type(sw[0]), sw[2:], nil } // MarshalExtJSON returns the extended JSON encoding of val. From c6914e1b39b8209eff1fbbca1d7fc4cb3411b1e2 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Wed, 6 Dec 2023 18:26:31 -0500 Subject: [PATCH 10/12] WIP --- bson/marshal.go | 123 ++----------------------------------------- bson/marshal_test.go | 23 ++------ 2 files changed, 8 insertions(+), 138 deletions(-) diff --git a/bson/marshal.go b/bson/marshal.go index 15793bb2e6..8006efdc1e 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -78,7 +78,7 @@ func MarshalValue(val interface{}) (bsontype.Type, []byte, error) { // Deprecated: Using a custom registry to marshal individual BSON values will not be supported in Go // Driver 2.0. func MarshalValueWithRegistry(r *bsoncodec.Registry, val interface{}) (bsontype.Type, []byte, error) { - var sw bsonrw.SliceWriter + sw := bsonrw.SliceWriter(make([]byte, 0)) vwFlusher := bvwPool.GetAtModeElement(&sw) // get an Encoder and encode the value @@ -101,135 +101,22 @@ func MarshalValueWithRegistry(r *bsoncodec.Registry, val interface{}) (bsontype. // MarshalExtJSON returns the extended JSON encoding of val. func MarshalExtJSON(val interface{}, canonical, escapeHTML bool) ([]byte, error) { - return MarshalExtJSONWithRegistry(DefaultRegistry, val, canonical, escapeHTML) -} - -// MarshalExtJSONAppend will append the extended JSON encoding of val to dst. -// If dst is not large enough to hold the extended JSON encoding of val, dst -// will be grown. -// -// Deprecated: Use [NewEncoder] and pass the dst byte slice (wrapped by a bytes.Buffer) into -// [bsonrw.NewExtJSONValueWriter] instead: -// -// buf := bytes.NewBuffer(dst) -// vw, err := bsonrw.NewExtJSONValueWriter(buf, true, false) -// if err != nil { -// panic(err) -// } -// enc, err := bson.NewEncoder(vw) -// if err != nil { -// panic(err) -// } -// -// See [Encoder] for more examples. -func MarshalExtJSONAppend(dst []byte, val interface{}, canonical, escapeHTML bool) ([]byte, error) { - return MarshalExtJSONAppendWithRegistry(DefaultRegistry, dst, val, canonical, escapeHTML) -} - -// MarshalExtJSONWithRegistry returns the extended JSON encoding of val using Registry r. -// -// Deprecated: Use [NewEncoder] and specify the Registry by calling [Encoder.SetRegistry] instead: -// -// buf := new(bytes.Buffer) -// vw, err := bsonrw.NewBSONValueWriter(buf) -// if err != nil { -// panic(err) -// } -// enc, err := bson.NewEncoder(vw) -// if err != nil { -// panic(err) -// } -// enc.SetRegistry(reg) -// -// See [Encoder] for more examples. -func MarshalExtJSONWithRegistry(r *bsoncodec.Registry, val interface{}, canonical, escapeHTML bool) ([]byte, error) { - dst := make([]byte, 0, defaultDstCap) - return MarshalExtJSONAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val, canonical, escapeHTML) -} - -// MarshalExtJSONWithContext returns the extended JSON encoding of val using Registry r. -// -// Deprecated: Use [NewEncoder] and use the Encoder configuration methods to set the desired marshal -// behavior instead: -// -// buf := new(bytes.Buffer) -// vw, err := bsonrw.NewBSONValueWriter(buf) -// if err != nil { -// panic(err) -// } -// enc, err := bson.NewEncoder(vw) -// if err != nil { -// panic(err) -// } -// enc.IntMinSize() -// -// See [Encoder] for more examples. -func MarshalExtJSONWithContext(ec bsoncodec.EncodeContext, val interface{}, canonical, escapeHTML bool) ([]byte, error) { - dst := make([]byte, 0, defaultDstCap) - return MarshalExtJSONAppendWithContext(ec, dst, val, canonical, escapeHTML) -} - -// MarshalExtJSONAppendWithRegistry will append the extended JSON encoding of -// val to dst using Registry r. If dst is not large enough to hold the BSON -// encoding of val, dst will be grown. -// -// Deprecated: Use [NewEncoder], pass the dst byte slice (wrapped by a bytes.Buffer) into -// [bsonrw.NewExtJSONValueWriter], and specify the Registry by calling [Encoder.SetRegistry] -// instead: -// -// buf := bytes.NewBuffer(dst) -// vw, err := bsonrw.NewExtJSONValueWriter(buf, true, false) -// if err != nil { -// panic(err) -// } -// enc, err := bson.NewEncoder(vw) -// if err != nil { -// panic(err) -// } -// -// See [Encoder] for more examples. -func MarshalExtJSONAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{}, canonical, escapeHTML bool) ([]byte, error) { - return MarshalExtJSONAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val, canonical, escapeHTML) -} - -// MarshalExtJSONAppendWithContext will append the extended JSON encoding of -// val to dst using Registry r. If dst is not large enough to hold the BSON -// encoding of val, dst will be grown. -// -// Deprecated: Use [NewEncoder], pass the dst byte slice (wrapped by a bytes.Buffer) into -// [bsonrw.NewExtJSONValueWriter], and use the Encoder configuration methods to set the desired marshal -// behavior instead: -// -// buf := bytes.NewBuffer(dst) -// vw, err := bsonrw.NewExtJSONValueWriter(buf, true, false) -// if err != nil { -// panic(err) -// } -// enc, err := bson.NewEncoder(vw) -// if err != nil { -// panic(err) -// } -// enc.IntMinSize() -// -// See [Encoder] for more examples. -func MarshalExtJSONAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}, canonical, escapeHTML bool) ([]byte, error) { - sw := new(bsonrw.SliceWriter) - *sw = dst - ejvw := extjPool.Get(sw, canonical, escapeHTML) + sw := bsonrw.SliceWriter(make([]byte, 0, defaultDstCap)) + ejvw := extjPool.Get(&sw, canonical, escapeHTML) defer extjPool.Put(ejvw) enc := encPool.Get().(*Encoder) defer encPool.Put(enc) enc.Reset(ejvw) - enc.ec = ec + enc.ec = bsoncodec.EncodeContext{Registry: DefaultRegistry} err := enc.Encode(val) if err != nil { return nil, err } - return *sw, nil + return sw, nil } // IndentExtJSON will prefix and indent the provided extended JSON src and append it to dst. diff --git a/bson/marshal_test.go b/bson/marshal_test.go index 4031bf38c7..461abaf8cb 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -80,28 +80,11 @@ func TestMarshalWithContext(t *testing.T) { } } -func TestMarshalExtJSONAppendWithContext(t *testing.T) { - t.Run("MarshalExtJSONAppendWithContext", func(t *testing.T) { - dst := make([]byte, 0, 1024) +func TestMarshalExtJSON(t *testing.T) { + t.Run("MarshalExtJSON", func(t *testing.T) { type teststruct struct{ Foo int } val := teststruct{1} - ec := bsoncodec.EncodeContext{Registry: DefaultRegistry} - got, err := MarshalExtJSONAppendWithContext(ec, dst, val, true, false) - noerr(t, err) - want := []byte(`{"foo":{"$numberInt":"1"}}`) - if !bytes.Equal(got, want) { - t.Errorf("Bytes are not equal. got %v; want %v", got, want) - t.Errorf("Bytes:\n%s\n%s", got, want) - } - }) -} - -func TestMarshalExtJSONWithContext(t *testing.T) { - t.Run("MarshalExtJSONWithContext", func(t *testing.T) { - type teststruct struct{ Foo int } - val := teststruct{1} - ec := bsoncodec.EncodeContext{Registry: DefaultRegistry} - got, err := MarshalExtJSONWithContext(ec, val, true, false) + got, err := MarshalExtJSON(val, true, false) noerr(t, err) want := []byte(`{"foo":{"$numberInt":"1"}}`) if !bytes.Equal(got, want) { From 99c10fd419d5d0ab30c5ab364e694d8cb6c00b10 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Thu, 7 Dec 2023 18:42:34 -0500 Subject: [PATCH 11/12] WIP --- bson/bson_test.go | 3 ++- bson/decoder_test.go | 2 +- bson/marshal_test.go | 5 ++--- bson/raw_value_test.go | 14 +++++++------- bson/unmarshal_test.go | 5 ++--- bson/unmarshal_value_test.go | 10 ++++++---- mongo/database_test.go | 4 ++-- mongo/integration/client_test.go | 5 ++++- mongo/integration/crud_spec_test.go | 8 ++++++-- mongo/integration/database_test.go | 9 ++++++--- mongo/integration/unified_spec_test.go | 9 ++++++--- mongo/options/clientoptions_test.go | 2 +- mongo/read_write_concern_spec_test.go | 8 ++++++-- x/mongo/driver/topology/server_options.go | 2 +- 14 files changed, 52 insertions(+), 34 deletions(-) diff --git a/bson/bson_test.go b/bson/bson_test.go index 357308a7d1..76b937a668 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -183,7 +183,8 @@ func TestMapCodec(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { mapCodec := bsoncodec.NewMapCodec(tc.opts) - mapRegistry := NewRegistryBuilder().RegisterDefaultEncoder(reflect.Map, mapCodec).Build() + mapRegistry := NewRegistry() + mapRegistry.RegisterKindEncoder(reflect.Map, mapCodec) buf.Reset() vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err) diff --git a/bson/decoder_test.go b/bson/decoder_test.go index 1295b2617a..ac87e0950b 100644 --- a/bson/decoder_test.go +++ b/bson/decoder_test.go @@ -230,7 +230,7 @@ func TestDecoderv2(t *testing.T) { t.Run("SetRegistry", func(t *testing.T) { t.Parallel() - r1, r2 := DefaultRegistry, NewRegistryBuilder().Build() + r1, r2 := DefaultRegistry, NewRegistry() dc1 := bsoncodec.DecodeContext{Registry: r1} dc2 := bsoncodec.DecodeContext{Registry: r2} dec, err := NewDecoder(bsonrw.NewBSONDocumentReader([]byte{})) diff --git a/bson/marshal_test.go b/bson/marshal_test.go index 461abaf8cb..6a71190ea0 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -166,9 +166,8 @@ func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) { return vw.WriteInt32(int32(val.Int()) * -1) } - customReg := NewRegistryBuilder(). - RegisterTypeEncoder(tInt32, encodeInt32). - Build() + customReg := NewRegistry() + customReg.RegisterTypeEncoder(tInt32, encodeInt32) // Helper function to run the test and make assertions. The provided original value should result in the document // {"x": {$numberInt: 1}} when marshalled with the default registry. diff --git a/bson/raw_value_test.go b/bson/raw_value_test.go index 87f08c4a55..f2d7918e21 100644 --- a/bson/raw_value_test.go +++ b/bson/raw_value_test.go @@ -26,7 +26,7 @@ func TestRawValue(t *testing.T) { t.Run("Uses registry attached to value", func(t *testing.T) { t.Parallel() - reg := bsoncodec.NewRegistryBuilder().Build() + reg := bsoncodec.NewRegistry() val := RawValue{Type: bsontype.String, Value: bsoncore.AppendString(nil, "foobar"), r: reg} var s string want := bsoncodec.ErrNoDecoder{Type: reflect.TypeOf(s)} @@ -64,7 +64,7 @@ func TestRawValue(t *testing.T) { t.Run("Returns lookup error", func(t *testing.T) { t.Parallel() - reg := bsoncodec.NewRegistryBuilder().Build() + reg := bsoncodec.NewRegistry() var val RawValue var s string want := bsoncodec.ErrNoDecoder{Type: reflect.TypeOf(s)} @@ -76,7 +76,7 @@ func TestRawValue(t *testing.T) { t.Run("Returns DecodeValue error", func(t *testing.T) { t.Parallel() - reg := NewRegistryBuilder().Build() + reg := NewRegistry() val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, 3.14159)} var s string want := fmt.Errorf("cannot decode %v into a string type", bsontype.Double) @@ -88,7 +88,7 @@ func TestRawValue(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - reg := NewRegistryBuilder().Build() + reg := NewRegistry() want := float64(3.14159) val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, want)} var got float64 @@ -115,7 +115,7 @@ func TestRawValue(t *testing.T) { t.Run("Returns lookup error", func(t *testing.T) { t.Parallel() - dc := bsoncodec.DecodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()} + dc := bsoncodec.DecodeContext{Registry: bsoncodec.NewRegistry()} var val RawValue var s string want := bsoncodec.ErrNoDecoder{Type: reflect.TypeOf(s)} @@ -127,7 +127,7 @@ func TestRawValue(t *testing.T) { t.Run("Returns DecodeValue error", func(t *testing.T) { t.Parallel() - dc := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()} + dc := bsoncodec.DecodeContext{Registry: NewRegistry()} val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, 3.14159)} var s string want := fmt.Errorf("cannot decode %v into a string type", bsontype.Double) @@ -139,7 +139,7 @@ func TestRawValue(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - dc := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()} + dc := bsoncodec.DecodeContext{Registry: NewRegistry()} want := float64(3.14159) val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, want)} var got float64 diff --git a/bson/unmarshal_test.go b/bson/unmarshal_test.go index 11452a895c..667f3f5094 100644 --- a/bson/unmarshal_test.go +++ b/bson/unmarshal_test.go @@ -229,9 +229,8 @@ func TestCachingDecodersNotSharedAcrossRegistries(t *testing.T) { val.SetInt(int64(-1 * i32)) return nil } - customReg := NewRegistryBuilder(). - RegisterTypeDecoder(tInt32, decodeInt32). - Build() + customReg := NewRegistry() + customReg.RegisterTypeDecoder(tInt32, decodeInt32) docBytes := bsoncore.BuildDocumentFromElements( nil, diff --git a/bson/unmarshal_value_test.go b/bson/unmarshal_value_test.go index ef91da1659..f25e25ba8b 100644 --- a/bson/unmarshal_value_test.go +++ b/bson/unmarshal_value_test.go @@ -77,7 +77,8 @@ func TestUnmarshalValue(t *testing.T) { bytes: bsoncore.AppendString(nil, "hello world"), }, } - rb := NewRegistryBuilder().RegisterTypeDecoder(reflect.TypeOf([]byte{}), bsoncodec.NewSliceCodec()).Build() + reg := NewRegistry() + reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), bsoncodec.NewSliceCodec()) for _, tc := range testCases { tc := tc @@ -85,7 +86,7 @@ func TestUnmarshalValue(t *testing.T) { t.Parallel() gotValue := reflect.New(reflect.TypeOf(tc.val)) - err := UnmarshalValueWithRegistry(rb, tc.bsontype, tc.bytes, gotValue.Interface()) + err := UnmarshalValueWithRegistry(reg, tc.bsontype, tc.bytes, gotValue.Interface()) assert.Nil(t, err, "UnmarshalValueWithRegistry error: %v", err) assert.Equal(t, tc.val, gotValue.Elem().Interface(), "value mismatch; expected %s, got %s", tc.val, gotValue.Elem()) }) @@ -111,12 +112,13 @@ func BenchmarkSliceCodecUnmarshal(b *testing.B) { bytes: bsoncore.AppendString(nil, strings.Repeat("t", 4096)), }, } - rb := NewRegistryBuilder().RegisterTypeDecoder(reflect.TypeOf([]byte{}), bsoncodec.NewSliceCodec()).Build() + reg := NewRegistry() + reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), bsoncodec.NewSliceCodec()) for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - err := UnmarshalValueWithRegistry(rb, bm.bsontype, bm.bytes, &[]byte{}) + err := UnmarshalValueWithRegistry(reg, bm.bsontype, bm.bytes, &[]byte{}) if err != nil { b.Fatal(err) } diff --git a/mongo/database_test.go b/mongo/database_test.go index 745f533b73..9b32af3548 100644 --- a/mongo/database_test.go +++ b/mongo/database_test.go @@ -54,7 +54,7 @@ func TestDatabase(t *testing.T) { wc2 := &writeconcern.WriteConcern{W: 10} rcLocal := readconcern.Local() rcMajority := readconcern.Majority() - reg := bsoncodec.NewRegistryBuilder().Build() + reg := bsoncodec.NewRegistry() opts := options.Database().SetReadPreference(rpPrimary).SetReadConcern(rcLocal).SetWriteConcern(wc1). SetReadPreference(rpSecondary).SetReadConcern(rcMajority).SetWriteConcern(wc2).SetRegistry(reg) @@ -71,7 +71,7 @@ func TestDatabase(t *testing.T) { rpPrimary := readpref.Primary() rcLocal := readconcern.Local() wc1 := &writeconcern.WriteConcern{W: 10} - reg := bsoncodec.NewRegistryBuilder().Build() + reg := bsoncodec.NewRegistry() client := setupClient(options.Client().SetReadPreference(rpPrimary).SetReadConcern(rcLocal).SetRegistry(reg)) got := client.Database("foo", options.Database().SetWriteConcern(wc1)) diff --git a/mongo/integration/client_test.go b/mongo/integration/client_test.go index af90b6b45e..3b50b8855c 100644 --- a/mongo/integration/client_test.go +++ b/mongo/integration/client_test.go @@ -103,8 +103,11 @@ func (sc *slowConn) Read(b []byte) (n int, err error) { func TestClient(t *testing.T) { mt := mtest.New(t, noClientOpts) + reg := bson.NewRegistry() + reg.RegisterTypeEncoder(reflect.TypeOf(int64(0)), &negateCodec{}) + reg.RegisterTypeDecoder(reflect.TypeOf(int64(0)), &negateCodec{}) registryOpts := options.Client(). - SetRegistry(bson.NewRegistryBuilder().RegisterCodec(reflect.TypeOf(int64(0)), &negateCodec{}).Build()) + SetRegistry(reg) mt.RunOpts("registry passed to cursors", mtest.NewOptions().ClientOptions(registryOpts), func(mt *mtest.T) { _, err := mt.Coll.InsertOne(context.Background(), negateCodec{ID: 10}) assert.Nil(mt, err, "InsertOne error: %v", err) diff --git a/mongo/integration/crud_spec_test.go b/mongo/integration/crud_spec_test.go index a80a9dd53c..6876ee8180 100644 --- a/mongo/integration/crud_spec_test.go +++ b/mongo/integration/crud_spec_test.go @@ -17,6 +17,7 @@ import ( "testing" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/bsonutil" "go.mongodb.org/mongo-driver/mongo" @@ -55,8 +56,11 @@ type crudOutcome struct { Collection *outcomeCollection `bson:"collection"` } -var crudRegistry = bson.NewRegistryBuilder(). - RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})).Build() +var crudRegistry = func() *bsoncodec.Registry { + reg := bson.NewRegistry() + reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) + return reg +}() func TestCrudSpec(t *testing.T) { for _, dir := range []string{crudReadDir, crudWriteDir} { diff --git a/mongo/integration/database_test.go b/mongo/integration/database_test.go index 31aba79719..4368be5bbc 100644 --- a/mongo/integration/database_test.go +++ b/mongo/integration/database_test.go @@ -14,6 +14,7 @@ import ( "testing" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/internal/assert" @@ -31,9 +32,11 @@ const ( ) var ( - interfaceAsMapRegistry = bson.NewRegistryBuilder(). - RegisterTypeMapEntry(bsontype.EmbeddedDocument, reflect.TypeOf(bson.M{})). - Build() + interfaceAsMapRegistry = func() *bsoncodec.Registry { + reg := bson.NewRegistry() + reg.RegisterTypeMapEntry(bsontype.EmbeddedDocument, reflect.TypeOf(bson.M{})) + return reg + }() ) func TestDatabase(t *testing.T) { diff --git a/mongo/integration/unified_spec_test.go b/mongo/integration/unified_spec_test.go index 4da42e6a68..d18728d69f 100644 --- a/mongo/integration/unified_spec_test.go +++ b/mongo/integration/unified_spec_test.go @@ -185,9 +185,12 @@ var directories = []string{ } var checkOutcomeOpts = options.Collection().SetReadPreference(readpref.Primary()).SetReadConcern(readconcern.Local()) -var specTestRegistry = bson.NewRegistryBuilder(). - RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})). - RegisterTypeDecoder(reflect.TypeOf(testData{}), bsoncodec.ValueDecoderFunc(decodeTestData)).Build() +var specTestRegistry = func() *bsoncodec.Registry { + reg := bson.NewRegistry() + reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) + reg.RegisterTypeDecoder(reflect.TypeOf(testData{}), bsoncodec.ValueDecoderFunc(decodeTestData)) + return reg +}() func TestUnifiedSpecs(t *testing.T) { for _, specDir := range directories { diff --git a/mongo/options/clientoptions_test.go b/mongo/options/clientoptions_test.go index 29136557b2..d3f29ad774 100644 --- a/mongo/options/clientoptions_test.go +++ b/mongo/options/clientoptions_test.go @@ -81,7 +81,7 @@ func TestClientOptions(t *testing.T) { {"Monitor", (*ClientOptions).SetMonitor, &event.CommandMonitor{}, "Monitor", false}, {"ReadConcern", (*ClientOptions).SetReadConcern, readconcern.Majority(), "ReadConcern", false}, {"ReadPreference", (*ClientOptions).SetReadPreference, readpref.SecondaryPreferred(), "ReadPreference", false}, - {"Registry", (*ClientOptions).SetRegistry, bson.NewRegistryBuilder().Build(), "Registry", false}, + {"Registry", (*ClientOptions).SetRegistry, bson.NewRegistry(), "Registry", false}, {"ReplicaSet", (*ClientOptions).SetReplicaSet, "example-replicaset", "ReplicaSet", true}, {"RetryWrites", (*ClientOptions).SetRetryWrites, true, "RetryWrites", true}, {"ServerSelectionTimeout", (*ClientOptions).SetServerSelectionTimeout, 5 * time.Second, "ServerSelectionTimeout", true}, diff --git a/mongo/read_write_concern_spec_test.go b/mongo/read_write_concern_spec_test.go index dfbe14e3f4..7a784c54a3 100644 --- a/mongo/read_write_concern_spec_test.go +++ b/mongo/read_write_concern_spec_test.go @@ -15,6 +15,7 @@ import ( "time" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/mongo/readconcern" @@ -31,8 +32,11 @@ const ( var ( serverDefaultConcern = []byte{5, 0, 0, 0, 0} // server default read concern and write concern is empty document - specTestRegistry = bson.NewRegistryBuilder(). - RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})).Build() + specTestRegistry = func() *bsoncodec.Registry { + reg := bson.NewRegistry() + reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) + return reg + }() ) type connectionStringTestFile struct { diff --git a/x/mongo/driver/topology/server_options.go b/x/mongo/driver/topology/server_options.go index 4504a25355..84229c4401 100644 --- a/x/mongo/driver/topology/server_options.go +++ b/x/mongo/driver/topology/server_options.go @@ -18,7 +18,7 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) -var defaultRegistry = bson.NewRegistryBuilder().Build() +var defaultRegistry = bson.NewRegistry() type serverConfig struct { clock *session.ClusterClock From 7ed7a6cb4aefbb871050f1ebd36498a6223159dc Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Fri, 5 Jan 2024 14:05:02 -0500 Subject: [PATCH 12/12] minor fixes --- bson/bson_test.go | 6 ++---- bson/marshal_test.go | 12 ++++-------- .../unified/gridfs_bucket_operation_execution.go | 7 ++++++- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/bson/bson_test.go b/bson/bson_test.go index 76b937a668..9c312c9861 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -178,17 +178,15 @@ func TestMapCodec(t *testing.T) { {"true", bsonoptions.MapCodec().SetEncodeKeysWithStringer(true), "bar"}, {"false", bsonoptions.MapCodec().SetEncodeKeysWithStringer(false), "foo"}, } - buf := new(bytes.Buffer) - enc := new(Encoder) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { mapCodec := bsoncodec.NewMapCodec(tc.opts) mapRegistry := NewRegistry() mapRegistry.RegisterKindEncoder(reflect.Map, mapCodec) - buf.Reset() + buf := new(bytes.Buffer) vw, err := bsonrw.NewBSONValueWriter(buf) assert.Nil(t, err) - enc.Reset(vw) + enc := NewEncoder(vw) enc.SetRegistry(mapRegistry) err = enc.Encode(mapObj) assert.Nil(t, err, "Encode error: %v", err) diff --git a/bson/marshal_test.go b/bson/marshal_test.go index bf59616815..5010541354 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -27,8 +27,6 @@ import ( var tInt32 = reflect.TypeOf(int32(0)) func TestMarshalWithRegistry(t *testing.T) { - buf := new(bytes.Buffer) - enc := new(Encoder) for _, tc := range marshalingTestCases { t.Run(tc.name, func(t *testing.T) { var reg *bsoncodec.Registry @@ -37,10 +35,10 @@ func TestMarshalWithRegistry(t *testing.T) { } else { reg = DefaultRegistry } - buf.Reset() + buf := new(bytes.Buffer) vw, err := bsonrw.NewBSONValueWriter(buf) noerr(t, err) - enc.Reset(vw) + enc := NewEncoder(vw) enc.SetRegistry(reg) err = enc.Encode(tc.val) noerr(t, err) @@ -54,8 +52,6 @@ func TestMarshalWithRegistry(t *testing.T) { } func TestMarshalWithContext(t *testing.T) { - buf := new(bytes.Buffer) - enc := new(Encoder) for _, tc := range marshalingTestCases { t.Run(tc.name, func(t *testing.T) { var reg *bsoncodec.Registry @@ -64,10 +60,10 @@ func TestMarshalWithContext(t *testing.T) { } else { reg = DefaultRegistry } - buf.Reset() + buf := new(bytes.Buffer) vw, err := bsonrw.NewBSONValueWriter(buf) noerr(t, err) - enc.Reset(vw) + enc := NewEncoder(vw) enc.IntMinSize() enc.SetRegistry(reg) err = enc.Encode(tc.val) diff --git a/internal/integration/unified/gridfs_bucket_operation_execution.go b/internal/integration/unified/gridfs_bucket_operation_execution.go index d53e714f1b..cfabd0025a 100644 --- a/internal/integration/unified/gridfs_bucket_operation_execution.go +++ b/internal/integration/unified/gridfs_bucket_operation_execution.go @@ -12,6 +12,7 @@ import ( "encoding/hex" "fmt" "io" + "math" "time" "go.mongodb.org/mongo-driver/bson" @@ -148,7 +149,11 @@ func executeBucketDownloadByName(ctx context.Context, operation *operation) (*op case "filename": filename = val.StringValue() case "revision": - opts.SetRevision(int32(val.AsInt64())) + revision := val.AsInt64() + if revision < math.MinInt32 || revision > math.MaxInt32 { + return nil, fmt.Errorf("revision overflows int32: %d", revision) + } + opts.SetRevision(int32(revision)) default: return nil, fmt.Errorf("unrecognized bucket download option %q", key) }