diff --git a/bson/copier.go b/bson/copier.go index abdd7162e4..3b04caedac 100644 --- a/bson/copier.go +++ b/bson/copier.go @@ -7,6 +7,7 @@ package bson import ( + "bytes" "errors" "fmt" "io" @@ -205,10 +206,7 @@ func copyValueFromBytes(dst ValueWriter, t Type, src []byte) error { return wvb.WriteValueBytes(t, src) } - vr := vrPool.Get().(*valueReader) - defer vrPool.Put(vr) - - vr.reset(src) + vr := newValueReader(bytes.NewReader(src)) vr.pushElement(t) return copyValue(dst, vr) diff --git a/bson/copier_test.go b/bson/copier_test.go index 23c13447a4..21527217ac 100644 --- a/bson/copier_test.go +++ b/bson/copier_test.go @@ -40,7 +40,7 @@ func TestCopier(t *testing.T) { doc = bsoncore.AppendStringElement(doc, "Hello", "world") doc, err := bsoncore.AppendDocumentEnd(doc, idx) noerr(t, err) - src := newValueReader(doc) + src := newValueReader(bytes.NewReader(doc)) dst := newValueWriterFromSlice(make([]byte, 0)) want := doc err = copyDocument(dst, src) @@ -77,7 +77,7 @@ func TestCopier(t *testing.T) { noerr(t, err) doc, err = bsoncore.AppendDocumentEnd(doc, idx) noerr(t, err) - src := newValueReader(doc) + src := newValueReader(bytes.NewReader(doc)) _, err = src.ReadDocument() noerr(t, err) @@ -450,7 +450,7 @@ func TestCopier(t *testing.T) { idx, ) noerr(t, err) - vr := newValueReader(b) + vr := newValueReader(bytes.NewReader(b)) _, err = vr.ReadDocument() noerr(t, err) _, _, err = vr.ReadElement() @@ -489,7 +489,7 @@ func TestCopier(t *testing.T) { idx, ) noerr(t, err) - vr := newValueReader(b) + vr := newValueReader(bytes.NewReader(b)) _, err = vr.ReadDocument() noerr(t, err) _, _, err = vr.ReadElement() diff --git a/bson/decoder_example_test.go b/bson/decoder_example_test.go index 3e17e98927..ffe8bd6a48 100644 --- a/bson/decoder_example_test.go +++ b/bson/decoder_example_test.go @@ -30,7 +30,7 @@ func ExampleDecoder() { // Create a Decoder that reads the marshaled BSON document and use it to // unmarshal the document into a Product struct. - decoder := bson.NewDecoder(bson.NewValueReader(data)) + decoder := bson.NewDecoder(bson.NewValueReader(bytes.NewReader(data))) type Product struct { Name string `bson:"name"` @@ -66,7 +66,7 @@ func ExampleDecoder_DefaultDocumentM() { // Create a Decoder that reads the marshaled BSON document and use it to unmarshal the document // into a City struct. - decoder := bson.NewDecoder(bson.NewValueReader(data)) + decoder := bson.NewDecoder(bson.NewValueReader(bytes.NewReader(data))) type City struct { Name string `bson:"name"` @@ -104,7 +104,7 @@ func ExampleDecoder_UseJSONStructTags() { // Create a Decoder that reads the marshaled BSON document and use it to // unmarshal the document into a Product struct. - decoder := bson.NewDecoder(bson.NewValueReader(data)) + decoder := bson.NewDecoder(bson.NewValueReader(bytes.NewReader(data))) type Product struct { Name string `json:"name"` diff --git a/bson/decoder_test.go b/bson/decoder_test.go index dbef3e7fb0..f589b3ad56 100644 --- a/bson/decoder_test.go +++ b/bson/decoder_test.go @@ -28,7 +28,7 @@ func TestBasicDecode(t *testing.T) { t.Parallel() got := reflect.New(tc.sType).Elem() - vr := NewValueReader(tc.data) + vr := NewValueReader(bytes.NewReader(tc.data)) reg := DefaultRegistry decoder, err := reg.LookupDecoder(reflect.TypeOf(got)) noerr(t, err) @@ -184,7 +184,7 @@ func TestDecodingInterfaces(t *testing.T) { data, receiver, check := tc.stub() got := reflect.ValueOf(receiver).Elem() - vr := NewValueReader(data) + vr := NewValueReader(bytes.NewReader(data)) reg := DefaultRegistry decoder, err := reg.LookupDecoder(got.Type()) noerr(t, err) @@ -208,7 +208,7 @@ func TestDecoderv2(t *testing.T) { t.Parallel() got := reflect.New(tc.sType).Interface() - vr := NewValueReader(tc.data) + vr := NewValueReader(bytes.NewReader(tc.data)) dec := NewDecoder(vr) err := dec.Decode(got) noerr(t, err) @@ -223,7 +223,7 @@ func TestDecoderv2(t *testing.T) { _ = certainlydoesntexistelsewhereihope(func(string, string) string { return "" }) cdeih := func(string, string) string { return "certainlydoesntexistelsewhereihope" } - dec := NewDecoder(NewValueReader([]byte{})) + dec := NewDecoder(NewValueReader(bytes.NewReader([]byte{}))) want := ErrNoDecoder{Type: reflect.TypeOf(cdeih)} got := dec.Decode(&cdeih) assert.Equal(t, want, got, "Received unexpected error.") @@ -285,7 +285,7 @@ func TestDecoderv2(t *testing.T) { want := bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159)) unmarshaler := &testUnmarshaler{} - vr := NewValueReader(want) + vr := NewValueReader(bytes.NewReader(want)) dec := NewDecoder(vr) err := dec.Decode(unmarshaler) noerr(t, err) @@ -302,7 +302,7 @@ func TestDecoderv2(t *testing.T) { t.Run("success", func(t *testing.T) { t.Parallel() - got := NewDecoder(NewValueReader([]byte{})) + got := NewDecoder(NewValueReader(bytes.NewReader([]byte{}))) if got == nil { t.Errorf("Was expecting a non-nil Decoder, but got ") } @@ -314,7 +314,7 @@ func TestDecoderv2(t *testing.T) { t.Run("success", func(t *testing.T) { t.Parallel() - got := NewDecoder(NewValueReader([]byte{})) + got := NewDecoder(NewValueReader(bytes.NewReader([]byte{}))) if got == nil { t.Errorf("Was expecting a non-nil Decoder, but got ") } @@ -332,7 +332,7 @@ func TestDecoderv2(t *testing.T) { got.Item = "apple" got.Bonus = 2 data := docToBytes(D{{"item", "canvas"}, {"qty", 4}}) - vr := NewValueReader(data) + vr := NewValueReader(bytes.NewReader(data)) dec := NewDecoder(vr) err := dec.Decode(&got) noerr(t, err) @@ -342,7 +342,7 @@ func TestDecoderv2(t *testing.T) { t.Run("Reset", func(t *testing.T) { t.Parallel() - vr1, vr2 := NewValueReader([]byte{}), NewValueReader([]byte{}) + vr1, vr2 := NewValueReader(bytes.NewReader([]byte{})), NewValueReader(bytes.NewReader([]byte{})) dec := NewDecoder(vr1) if dec.vr != vr1 { t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr1) @@ -358,7 +358,7 @@ func TestDecoderv2(t *testing.T) { r1, r2 := DefaultRegistry, NewRegistry() dc1 := DecodeContext{Registry: r1} dc2 := DecodeContext{Registry: r2} - dec := NewDecoder(NewValueReader([]byte{})) + dec := NewDecoder(NewValueReader(bytes.NewReader([]byte{}))) if !reflect.DeepEqual(dec.dc, dc1) { t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc1) } @@ -371,7 +371,7 @@ func TestDecoderv2(t *testing.T) { t.Parallel() data := docToBytes(D{{"item", "canvas"}, {"qty", 4}}) - vr := NewValueReader(data) + vr := NewValueReader(bytes.NewReader(data)) dec := NewDecoder(vr) var got *D @@ -578,7 +578,7 @@ func TestDecoderConfiguration(t *testing.T) { t.Run(tc.description, func(t *testing.T) { t.Parallel() - dec := NewDecoder(NewValueReader(tc.input)) + dec := NewDecoder(NewValueReader(bytes.NewReader(tc.input))) tc.configure(dec) @@ -599,7 +599,7 @@ func TestDecoderConfiguration(t *testing.T) { Build()). Build() - dec := NewDecoder(NewValueReader(input)) + dec := NewDecoder(NewValueReader(bytes.NewReader(input))) dec.DefaultDocumentM() @@ -623,7 +623,7 @@ func TestDecoderConfiguration(t *testing.T) { Build()). Build() - dec := NewDecoder(NewValueReader(input)) + dec := NewDecoder(NewValueReader(bytes.NewReader(input))) dec.DefaultDocumentD() diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 31148ab644..0ab2862ad0 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -7,6 +7,7 @@ package bson import ( + "bytes" "encoding/json" "errors" "fmt" @@ -2428,7 +2429,7 @@ func TestDefaultValueDecoders(t *testing.T) { buildDocument(bsoncore.AppendNullElement(nil, "bar")), ), ) - dvr := NewValueReader(b) + dvr := NewValueReader(bytes.NewReader(b)) dr, err := dvr.ReadDocument() noerr(t, err) _, vr, err := dr.ReadElement() @@ -2481,7 +2482,7 @@ func TestDefaultValueDecoders(t *testing.T) { noerr(t, err) doc, err = bsoncore.AppendDocumentEnd(doc, idx) noerr(t, err) - dvr := NewValueReader(doc) + dvr := NewValueReader(bytes.NewReader(doc)) noerr(t, err) dr, err := dvr.ReadDocument() noerr(t, err) @@ -3132,7 +3133,7 @@ func TestDefaultValueDecoders(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - vr := NewValueReader(tc.b) + vr := NewValueReader(bytes.NewReader(tc.b)) reg := buildDefaultRegistry() vtype := reflect.TypeOf(tc.value) dec, err := reg.LookupDecoder(vtype) @@ -3181,7 +3182,7 @@ func TestDefaultValueDecoders(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - vr := NewValueReader(tc.b) + vr := NewValueReader(bytes.NewReader(tc.b)) reg := buildDefaultRegistry() vtype := reflect.TypeOf(tc.value) dec, err := reg.LookupDecoder(vtype) @@ -3403,7 +3404,7 @@ func TestDefaultValueDecoders(t *testing.T) { }) t.Run("top level document", func(t *testing.T) { data := bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159)) - vr := NewValueReader(data) + vr := NewValueReader(bytes.NewReader(data)) want := D{{"pi", 3.14159}} var got interface{} val := reflect.ValueOf(&got).Elem() @@ -3451,7 +3452,7 @@ func TestDefaultValueDecoders(t *testing.T) { } for _, tc := range testCases { var got interface{} - vr := NewValueReader(doc) + vr := NewValueReader(bytes.NewReader(doc)) val := reflect.ValueOf(&got).Elem() err := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: tc.registry}, vr, val) @@ -3487,7 +3488,7 @@ func TestDefaultValueDecoders(t *testing.T) { } var got D - vr := NewValueReader(doc) + vr := NewValueReader(bytes.NewReader(doc)) val := reflect.ValueOf(&got).Elem() err := defaultSliceCodec.DecodeValue(DecodeContext{Registry: reg}, vr, val) noerr(t, err) @@ -3577,7 +3578,7 @@ func TestDefaultValueDecoders(t *testing.T) { // DecodeValue error when decoding into a D. "D slice", D{}, - NewValueReader(docBytes), + NewValueReader(bytes.NewReader(docBytes)), emptyInterfaceErrorRegistry, defaultSliceCodec, docEmptyInterfaceErr, @@ -3600,7 +3601,7 @@ func TestDefaultValueDecoders(t *testing.T) { // the decodeD helper function. "D array", [1]E{}, - NewValueReader(docBytes), + NewValueReader(bytes.NewReader(docBytes)), emptyInterfaceErrorRegistry, ValueDecoderFunc(dvd.ArrayDecodeValue), docEmptyInterfaceErr, @@ -3623,7 +3624,7 @@ func TestDefaultValueDecoders(t *testing.T) { // DecodeValue error when decoding into a map. "map", map[string]interface{}{}, - NewValueReader(docBytes), + NewValueReader(bytes.NewReader(docBytes)), emptyInterfaceErrorRegistry, defaultMapCodec, docEmptyInterfaceErr, @@ -3632,7 +3633,7 @@ func TestDefaultValueDecoders(t *testing.T) { // DecodeValue error when decoding into a struct. "struct - DecodeValue error", emptyInterfaceStruct{}, - NewValueReader(docBytes), + NewValueReader(bytes.NewReader(docBytes)), emptyInterfaceErrorRegistry, defaultTestStructCodec, emptyInterfaceStructErr, @@ -3643,7 +3644,7 @@ func TestDefaultValueDecoders(t *testing.T) { // no decoder for strings. "struct - no decoder found", stringStruct{}, - NewValueReader(docBytes), + NewValueReader(bytes.NewReader(docBytes)), newTestRegistryBuilder().Build(), defaultTestStructCodec, stringStructErr, @@ -3651,7 +3652,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "deeply nested struct", outer{}, - NewValueReader(outerDoc), + NewValueReader(bytes.NewReader(outerDoc)), nestedRegistry, defaultTestStructCodec, nestedErr, @@ -3681,7 +3682,7 @@ func TestDefaultValueDecoders(t *testing.T) { type outer struct{ Foo inner } dc := DecodeContext{Registry: buildDefaultRegistry()} - vr := NewValueReader(outerBytes) + vr := NewValueReader(bytes.NewReader(outerBytes)) val := reflect.New(reflect.TypeOf(outer{})).Elem() err := defaultTestStructCodec.DecodeValue(dc, vr, val) @@ -3714,7 +3715,7 @@ func TestDefaultValueDecoders(t *testing.T) { reg := rb.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))).Build() dc := DecodeContext{Registry: reg} - vr := NewValueReader(docBytes) + vr := NewValueReader(bytes.NewReader(docBytes)) val := reflect.New(tD).Elem() err := defaultValueDecoders.DDecodeValue(dc, vr, val) assert.Nil(t, err, "DDecodeValue error: %v", err) @@ -3733,7 +3734,7 @@ func TestDefaultValueDecoders(t *testing.T) { type myMap map[string]mybool dc := DecodeContext{Registry: buildDefaultRegistry()} - vr := NewValueReader(docBytes) + vr := NewValueReader(bytes.NewReader(docBytes)) val := reflect.New(reflect.TypeOf(myMap{})).Elem() err := defaultMapCodec.DecodeValue(dc, vr, val) assert.Nil(t, err, "DecodeValue error: %v", err) diff --git a/bson/mgocompat/bson_test.go b/bson/mgocompat/bson_test.go index 6651509983..4d972059ac 100644 --- a/bson/mgocompat/bson_test.go +++ b/bson/mgocompat/bson_test.go @@ -478,7 +478,7 @@ func (t *prefixPtr) SetBSON(raw bson.RawValue) error { if err != nil { return err } - vr := bson.NewBSONValueReader(raw.Type, raw.Value) + vr := bson.NewBSONValueReader(raw.Type, bytes.NewReader(raw.Value)) err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) if err != nil { return err @@ -505,7 +505,7 @@ func (t *prefixVal) SetBSON(raw bson.RawValue) error { if err != nil { return err } - vr := bson.NewBSONValueReader(raw.Type, raw.Value) + vr := bson.NewBSONValueReader(raw.Type, bytes.NewReader(raw.Value)) err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) if err != nil { return err @@ -929,7 +929,7 @@ func (o *setterType) SetBSON(raw bson.RawValue) error { if raw.Type == 0x00 { raw.Type = bson.TypeEmbeddedDocument } - vr := bson.NewBSONValueReader(raw.Type, raw.Value) + vr := bson.NewBSONValueReader(raw.Type, bytes.NewReader(raw.Value)) err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) if err != nil { return err @@ -1288,7 +1288,7 @@ func (s *getterSetterD) SetBSON(raw bson.RawValue) error { if raw.Type == 0x00 { raw.Type = bson.TypeEmbeddedDocument } - vr := bson.NewBSONValueReader(raw.Type, raw.Value) + vr := bson.NewBSONValueReader(raw.Type, bytes.NewReader(raw.Value)) err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) if err != nil { return err @@ -1314,7 +1314,7 @@ func (i *getterSetterInt) SetBSON(raw bson.RawValue) error { if raw.Type == 0x00 { raw.Type = bson.TypeEmbeddedDocument } - vr := bson.NewBSONValueReader(raw.Type, raw.Value) + vr := bson.NewBSONValueReader(raw.Type, bytes.NewReader(raw.Value)) err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) if err != nil { return err @@ -1336,7 +1336,7 @@ func (s *ifaceSlice) SetBSON(raw bson.RawValue) error { if err != nil { return err } - vr := bson.NewBSONValueReader(raw.Type, raw.Value) + vr := bson.NewBSONValueReader(raw.Type, bytes.NewReader(raw.Value)) err = decoder.DecodeValue(bson.DecodeContext{Registry: Registry}, vr, rval) if err != nil { return err diff --git a/bson/primitive_codecs_test.go b/bson/primitive_codecs_test.go index be3aeab978..db0820f25d 100644 --- a/bson/primitive_codecs_test.go +++ b/bson/primitive_codecs_test.go @@ -1038,7 +1038,7 @@ func TestPrimitiveValueDecoders(t *testing.T) { t.Run("Decode", func(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - vr := NewValueReader(tc.b) + vr := NewValueReader(bytes.NewReader(tc.b)) dec := NewDecoder(vr) gotVal := reflect.New(reflect.TypeOf(tc.value)) err := dec.Decode(gotVal.Interface()) diff --git a/bson/raw_value.go b/bson/raw_value.go index a32b82e41d..2d6e367dc8 100644 --- a/bson/raw_value.go +++ b/bson/raw_value.go @@ -71,7 +71,7 @@ func (rv RawValue) UnmarshalWithRegistry(r *Registry, val interface{}) error { return ErrNilRegistry } - vr := NewBSONValueReader(rv.Type, rv.Value) + vr := NewBSONValueReader(rv.Type, bytes.NewReader(rv.Value)) rval := reflect.ValueOf(val) if rval.Kind() != reflect.Ptr { return fmt.Errorf("argument to Unmarshal* must be a pointer to a type, but got %v", rval) @@ -91,7 +91,7 @@ func (rv RawValue) UnmarshalWithContext(dc *DecodeContext, val interface{}) erro return ErrNilContext } - vr := NewBSONValueReader(rv.Type, rv.Value) + vr := NewBSONValueReader(rv.Type, bytes.NewReader(rv.Value)) rval := reflect.ValueOf(val) if rval.Kind() != reflect.Ptr { return fmt.Errorf("argument to Unmarshal* must be a pointer to a type, but got %v", rval) diff --git a/bson/unmarshal.go b/bson/unmarshal.go index 7caadc5dbc..f5e4576e40 100644 --- a/bson/unmarshal.go +++ b/bson/unmarshal.go @@ -47,7 +47,7 @@ func Unmarshal(data []byte, val interface{}) error { // // Deprecated: Use [NewDecoder] and specify the Registry by calling [Decoder.SetRegistry] instead: // -// dec, err := bson.NewDecoder(NewBSONDocumentReader(data)) +// dec, err := bson.NewDecoder(NewValueReader(data)) // if err != nil { // panic(err) // } @@ -55,7 +55,7 @@ func Unmarshal(data []byte, val interface{}) error { // // See [Decoder] for more examples. func UnmarshalWithRegistry(r *Registry, data []byte, val interface{}) error { - vr := NewValueReader(data) + vr := NewValueReader(bytes.NewReader(data)) return unmarshalFromReader(DecodeContext{Registry: r}, vr, val) } @@ -66,7 +66,7 @@ func UnmarshalWithRegistry(r *Registry, data []byte, val interface{}) error { // Deprecated: Use [NewDecoder] and use the Decoder configuration methods to set the desired unmarshal // behavior instead: // -// dec, err := bson.NewDecoder(NewBSONDocumentReader(data)) +// dec, err := bson.NewDecoder(NewValueReader(data)) // if err != nil { // panic(err) // } @@ -74,7 +74,7 @@ func UnmarshalWithRegistry(r *Registry, data []byte, val interface{}) error { // // See [Decoder] for more examples. func UnmarshalWithContext(dc DecodeContext, data []byte, val interface{}) error { - vr := NewValueReader(data) + vr := NewValueReader(bytes.NewReader(data)) return unmarshalFromReader(dc, vr, val) } @@ -92,7 +92,7 @@ func UnmarshalValue(t Type, data []byte, val interface{}) error { // Deprecated: Using a custom registry to unmarshal individual BSON values will not be supported in // Go Driver 2.0. func UnmarshalValueWithRegistry(r *Registry, t Type, data []byte, val interface{}) error { - vr := NewBSONValueReader(t, data) + vr := NewBSONValueReader(t, bytes.NewReader(data)) return unmarshalFromReader(DecodeContext{Registry: r}, vr, val) } diff --git a/bson/unmarshaling_cases_test.go b/bson/unmarshaling_cases_test.go index 4b8210415e..63706cb4dc 100644 --- a/bson/unmarshaling_cases_test.go +++ b/bson/unmarshaling_cases_test.go @@ -7,6 +7,7 @@ package bson import ( + "bytes" "reflect" ) @@ -196,11 +197,11 @@ type unmarshalerNonPtrStruct struct { type myInt64 int64 -func (mi *myInt64) UnmarshalBSON(bytes []byte) error { - if len(bytes) == 0 { +func (mi *myInt64) UnmarshalBSON(b []byte) error { + if len(b) == 0 { return nil } - i, err := NewBSONValueReader(TypeInt64, bytes).ReadInt64() + i, err := NewBSONValueReader(TypeInt64, bytes.NewReader(b)).ReadInt64() if err != nil { return err } @@ -222,11 +223,11 @@ func (mm *myMap) UnmarshalBSON(bytes []byte) error { type myBytes []byte -func (mb *myBytes) UnmarshalBSON(bytes []byte) error { - if len(bytes) == 0 { +func (mb *myBytes) UnmarshalBSON(b []byte) error { + if len(b) == 0 { return nil } - b, _, err := NewBSONValueReader(TypeBinary, bytes).ReadBinary() + b, _, err := NewBSONValueReader(TypeBinary, bytes.NewReader(b)).ReadBinary() if err != nil { return err } @@ -236,11 +237,11 @@ func (mb *myBytes) UnmarshalBSON(bytes []byte) error { type myString string -func (ms *myString) UnmarshalBSON(bytes []byte) error { - if len(bytes) == 0 { +func (ms *myString) UnmarshalBSON(b []byte) error { + if len(b) == 0 { return nil } - s, err := NewBSONValueReader(TypeString, bytes).ReadString() + s, err := NewBSONValueReader(TypeString, bytes.NewReader(b)).ReadString() if err != nil { return err } diff --git a/bson/value_reader.go b/bson/value_reader.go index 2726541305..4db306c443 100644 --- a/bson/value_reader.go +++ b/bson/value_reader.go @@ -13,61 +13,10 @@ import ( "fmt" "io" "math" - "sync" ) var _ ValueReader = (*valueReader)(nil) -var vrPool = sync.Pool{ - New: func() interface{} { - return new(valueReader) - }, -} - -// ValueReaderPool is a pool for ValueReaders that read BSON. -// -// Deprecated: ValueReaderPool will not be supported in Go Driver 2.0. -type ValueReaderPool struct { - pool sync.Pool -} - -// NewValueReaderPool instantiates a new ValueReaderPool. -// -// Deprecated: ValueReaderPool will not be supported in Go Driver 2.0. -func NewValueReaderPool() *ValueReaderPool { - return &ValueReaderPool{ - pool: sync.Pool{ - New: func() interface{} { - return new(valueReader) - }, - }, - } -} - -// Get retrieves a ValueReader from the pool and uses src as the underlying BSON. -// -// Deprecated: ValueReaderPool will not be supported in Go Driver 2.0. -func (bvrp *ValueReaderPool) Get(src []byte) ValueReader { - vr := bvrp.pool.Get().(*valueReader) - vr.reset(src) - return vr -} - -// Put inserts a ValueReader into the pool. If the ValueReader is not a BSON ValueReader nothing -// is inserted into the pool and ok will be false. -// -// Deprecated: ValueReaderPool will not be supported in Go Driver 2.0. -func (bvrp *ValueReaderPool) Put(vr ValueReader) (ok bool) { - bvr, ok := vr.(*valueReader) - if !ok { - return false - } - - bvr.reset(nil) - bvrp.pool.Put(bvr) - return true -} - // ErrEOA is the error returned when the end of a BSON array has been reached. var ErrEOA = errors.New("end of array") @@ -82,59 +31,46 @@ type vrState struct { // valueReader is for reading BSON values. type valueReader struct { - offset int64 - d []byte + offset int64 + d []byte + readerErr error + r io.Reader stack []vrState frame int64 } // NewValueReader returns a ValueReader using b for the underlying BSON -// representation. Parameter b must be a BSON Document. -func NewValueReader(b []byte) ValueReader { - // TODO(skriptble): There's a lack of symmetry between the reader and writer, since the reader takes a []byte while the - // TODO writer takes an io.Writer. We should have two versions of each, one that takes a []byte and one that takes an - // TODO io.Reader or io.Writer. The []byte version will need to return a thing that can return the finished []byte since - // TODO it might be reallocated when appended to. - return newValueReader(b) +// representation. +func NewValueReader(r io.Reader) ValueReader { + return newValueReader(r) } // NewBSONValueReader returns a ValueReader that starts in the Value mode instead of in top // level document mode. This enables the creation of a ValueReader for a single BSON value. -func NewBSONValueReader(t Type, val []byte) ValueReader { +func NewBSONValueReader(t Type, r io.Reader) ValueReader { stack := make([]vrState, 1, 5) stack[0] = vrState{ mode: mValue, vType: t, } return &valueReader{ - d: val, + r: r, stack: stack, } } -func newValueReader(b []byte) *valueReader { +func newValueReader(r io.Reader) *valueReader { stack := make([]vrState, 1, 5) stack[0] = vrState{ mode: mTopLevel, } return &valueReader{ - d: b, + r: r, stack: stack, } } -func (vr *valueReader) reset(b []byte) { - if vr.stack == nil { - vr.stack = make([]vrState, 1, 5) - } - vr.stack = vr.stack[:1] - vr.stack[0] = vrState{mode: mTopLevel} - vr.d = b - vr.offset = 0 - vr.frame = 0 -} - func (vr *valueReader) advanceFrame() { if vr.frame+1 >= int64(len(vr.stack)) { // We need to grow the stack length := len(vr.stack) @@ -286,15 +222,49 @@ func (vr *valueReader) nextElementLength() (int32, error) { case TypeObjectID: length = 12 case TypeRegex: - regex := bytes.IndexByte(vr.d[vr.offset:], 0x00) - if regex < 0 { - err = io.EOF - break + var offset int + var buf []byte + regex := -1 + for regex < 0 { + regex = bytes.IndexByte(vr.d[vr.offset+int64(offset):], 0x00) + if regex < 0 { + if vr.readerErr != nil { + return 0, vr.readerErr + } + if len(buf) == 0 { + buf = make([]byte, 512) + } + n, e := vr.r.Read(buf) + if e != nil { + vr.readerErr = e + } + offset += n + vr.d = append(vr.d, buf[0:n]...) + } else { + regex += offset + } } - pattern := bytes.IndexByte(vr.d[vr.offset+int64(regex)+1:], 0x00) - if pattern < 0 { - err = io.EOF - break + + offset = 0 + pattern := -1 + for pattern < 0 { + pattern = bytes.IndexByte(vr.d[vr.offset+int64(regex+offset)+1:], 0x00) + if pattern < 0 { + if vr.readerErr != nil { + return 0, vr.readerErr + } + if len(buf) == 0 { + buf = make([]byte, 512) + } + n, e := vr.r.Read(buf) + if e != nil { + vr.readerErr = e + } + offset += n + vr.d = append(vr.d, buf[0:n]...) + } else { + pattern += offset + } } length = int32(int64(regex) + 1 + int64(pattern) + 1) default: @@ -423,8 +393,19 @@ func (vr *valueReader) ReadDocument() (DocumentReader, error) { if err != nil { return nil, err } - if int(size) != len(vr.d) { - return nil, fmt.Errorf("invalid document length") + var buf []byte + for int(size) > len(vr.d) { + if vr.readerErr != nil { + return nil, fmt.Errorf("invalid document length") + } + if len(buf) == 0 { + buf = make([]byte, 512) + } + n, e := vr.r.Read(buf) + if e != nil { + vr.readerErr = e + } + vr.d = append(vr.d, buf[0:n]...) } vr.stack[vr.frame].end = int64(size) + vr.offset - 4 return vr, nil @@ -752,8 +733,19 @@ func (vr *valueReader) readBytes(length int32) ([]byte, error) { return nil, fmt.Errorf("invalid length: %d", length) } - if vr.offset+int64(length) > int64(len(vr.d)) { - return nil, io.EOF + var buf []byte + for vr.offset+int64(length) > int64(len(vr.d)) { + if vr.readerErr != nil { + return nil, vr.readerErr + } + if len(buf) == 0 { + buf = make([]byte, 512) + } + n, e := vr.r.Read(buf) + if e != nil { + vr.readerErr = e + } + vr.d = append(vr.d, buf[0:n]...) } start := vr.offset @@ -763,8 +755,19 @@ func (vr *valueReader) readBytes(length int32) ([]byte, error) { } func (vr *valueReader) appendBytes(dst []byte, length int32) ([]byte, error) { - if vr.offset+int64(length) > int64(len(vr.d)) { - return nil, io.EOF + var buf []byte + for vr.offset+int64(length) > int64(len(vr.d)) { + if vr.readerErr != nil { + return nil, vr.readerErr + } + if len(buf) == 0 { + buf = make([]byte, 512) + } + n, e := vr.r.Read(buf) + if e != nil { + vr.readerErr = e + } + vr.d = append(vr.d, buf[0:n]...) } start := vr.offset @@ -773,8 +776,19 @@ func (vr *valueReader) appendBytes(dst []byte, length int32) ([]byte, error) { } func (vr *valueReader) skipBytes(length int32) error { - if vr.offset+int64(length) > int64(len(vr.d)) { - return io.EOF + var buf []byte + for vr.offset+int64(length) > int64(len(vr.d)) { + if vr.readerErr != nil { + return vr.readerErr + } + if len(buf) == 0 { + buf = make([]byte, 512) + } + n, e := vr.r.Read(buf) + if e != nil { + vr.readerErr = e + } + vr.d = append(vr.d, buf[0:n]...) } vr.offset += int64(length) @@ -782,8 +796,19 @@ func (vr *valueReader) skipBytes(length int32) error { } func (vr *valueReader) readByte() (byte, error) { - if vr.offset+1 > int64(len(vr.d)) { - return 0x0, io.EOF + var buf []byte + for vr.offset+1 > int64(len(vr.d)) { + if vr.readerErr != nil { + return 0x0, vr.readerErr + } + if len(buf) == 0 { + buf = make([]byte, 512) + } + n, e := vr.r.Read(buf) + if e != nil { + vr.readerErr = e + } + vr.d = append(vr.d, buf[0:n]...) } vr.offset++ @@ -791,19 +816,57 @@ func (vr *valueReader) readByte() (byte, error) { } func (vr *valueReader) skipCString() error { - idx := bytes.IndexByte(vr.d[vr.offset:], 0x00) - if idx < 0 { - return io.EOF + var offset int + var buf []byte + idx := -1 + for idx < 0 { + idx = bytes.IndexByte(vr.d[vr.offset+int64(offset):], 0x00) + if idx < 0 { + if vr.readerErr != nil { + return vr.readerErr + } + if len(buf) == 0 { + buf = make([]byte, 512) + } + n, e := vr.r.Read(buf) + if e != nil { + vr.readerErr = e + } + offset += n + vr.d = append(vr.d, buf[0:n]...) + } else { + idx += offset + } } + vr.offset += int64(idx) + 1 return nil } func (vr *valueReader) readCString() (string, error) { - idx := bytes.IndexByte(vr.d[vr.offset:], 0x00) - if idx < 0 { - return "", io.EOF + var offset int + var buf []byte + idx := -1 + for idx < 0 { + idx = bytes.IndexByte(vr.d[vr.offset+int64(offset):], 0x00) + if idx < 0 { + if vr.readerErr != nil { + return "", vr.readerErr + } + if len(buf) == 0 { + buf = make([]byte, 512) + } + n, e := vr.r.Read(buf) + if e != nil { + vr.readerErr = e + } + offset += n + vr.d = append(vr.d, buf[0:n]...) + } else { + idx += offset + } } + start := vr.offset // idx does not include the null byte vr.offset += int64(idx) + 1 @@ -816,8 +879,19 @@ func (vr *valueReader) readString() (string, error) { return "", err } - if int64(length)+vr.offset > int64(len(vr.d)) { - return "", io.EOF + var buf []byte + for vr.offset+int64(length) > int64(len(vr.d)) { + if vr.readerErr != nil { + return "", vr.readerErr + } + if len(buf) == 0 { + buf = make([]byte, 512) + } + n, e := vr.r.Read(buf) + if e != nil { + vr.readerErr = e + } + vr.d = append(vr.d, buf[0:n]...) } if length <= 0 { @@ -834,8 +908,19 @@ func (vr *valueReader) readString() (string, error) { } func (vr *valueReader) peekLength() (int32, error) { - if vr.offset+4 > int64(len(vr.d)) { - return 0, io.EOF + var buf []byte + for vr.offset+4 > int64(len(vr.d)) { + if vr.readerErr != nil { + return 0, vr.readerErr + } + if len(buf) == 0 { + buf = make([]byte, 512) + } + n, e := vr.r.Read(buf) + if e != nil { + vr.readerErr = e + } + vr.d = append(vr.d, buf[0:n]...) } idx := vr.offset @@ -845,8 +930,19 @@ func (vr *valueReader) peekLength() (int32, error) { func (vr *valueReader) readLength() (int32, error) { return vr.readi32() } func (vr *valueReader) readi32() (int32, error) { - if vr.offset+4 > int64(len(vr.d)) { - return 0, io.EOF + var buf []byte + for vr.offset+4 > int64(len(vr.d)) { + if vr.readerErr != nil { + return 0, vr.readerErr + } + if len(buf) == 0 { + buf = make([]byte, 512) + } + n, e := vr.r.Read(buf) + if e != nil { + vr.readerErr = e + } + vr.d = append(vr.d, buf[0:n]...) } idx := vr.offset @@ -855,8 +951,19 @@ func (vr *valueReader) readi32() (int32, error) { } func (vr *valueReader) readu32() (uint32, error) { - if vr.offset+4 > int64(len(vr.d)) { - return 0, io.EOF + var buf []byte + for vr.offset+4 > int64(len(vr.d)) { + if vr.readerErr != nil { + return 0, vr.readerErr + } + if len(buf) == 0 { + buf = make([]byte, 512) + } + n, e := vr.r.Read(buf) + if e != nil { + vr.readerErr = e + } + vr.d = append(vr.d, buf[0:n]...) } idx := vr.offset @@ -865,8 +972,19 @@ func (vr *valueReader) readu32() (uint32, error) { } func (vr *valueReader) readi64() (int64, error) { - if vr.offset+8 > int64(len(vr.d)) { - return 0, io.EOF + var buf []byte + for vr.offset+8 > int64(len(vr.d)) { + if vr.readerErr != nil { + return 0, vr.readerErr + } + if len(buf) == 0 { + buf = make([]byte, 512) + } + n, e := vr.r.Read(buf) + if e != nil { + vr.readerErr = e + } + vr.d = append(vr.d, buf[0:n]...) } idx := vr.offset @@ -876,8 +994,19 @@ func (vr *valueReader) readi64() (int64, error) { } func (vr *valueReader) readu64() (uint64, error) { - if vr.offset+8 > int64(len(vr.d)) { - return 0, io.EOF + var buf []byte + for vr.offset+8 > int64(len(vr.d)) { + if vr.readerErr != nil { + return 0, vr.readerErr + } + if len(buf) == 0 { + buf = make([]byte, 512) + } + n, e := vr.r.Read(buf) + if e != nil { + vr.readerErr = e + } + vr.d = append(vr.d, buf[0:n]...) } idx := vr.offset diff --git a/mongo/client_encryption.go b/mongo/client_encryption.go index 97fe9b27b7..754378b76f 100644 --- a/mongo/client_encryption.go +++ b/mongo/client_encryption.go @@ -7,6 +7,7 @@ package mongo import ( + "bytes" "context" "errors" "fmt" @@ -106,7 +107,7 @@ func (ce *ClientEncryption) CreateEncryptedCollection(ctx context.Context, if err != nil { return nil, nil, err } - r := bson.NewValueReader(efBSON) + r := bson.NewValueReader(bytes.NewReader(efBSON)) dec := bson.NewDecoder(r) var m bson.M err = dec.Decode(&m) diff --git a/mongo/cursor.go b/mongo/cursor.go index 8f07b1ee9b..7be32f93f9 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -238,7 +238,7 @@ func getDecoder( opts *options.BSONOptions, reg *bson.Registry, ) *bson.Decoder { - dec := bson.NewDecoder(bson.NewValueReader(data)) + dec := bson.NewDecoder(bson.NewValueReader(bytes.NewReader(data))) if opts != nil { if opts.AllowTruncatingDoubles {