diff --git a/bson/bsoncodec/slice_codec.go b/bson/bsoncodec/slice_codec.go index 20c3e7549c..a43daf005f 100644 --- a/bson/bsoncodec/slice_codec.go +++ b/bson/bsoncodec/slice_codec.go @@ -62,7 +62,7 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val re } // If we have a []primitive.E we want to treat it as a document instead of as an array. - if val.Type().ConvertibleTo(tD) { + if val.Type() == tD || val.Type().ConvertibleTo(tD) { d := val.Convert(tD).Interface().(primitive.D) dw, err := vw.WriteDocument() diff --git a/bson/bsoncodec/struct_codec.go b/bson/bsoncodec/struct_codec.go index 29ea76d19c..4cde0a4d6b 100644 --- a/bson/bsoncodec/struct_codec.go +++ b/bson/bsoncodec/struct_codec.go @@ -190,15 +190,14 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val encoder := desc.encoder var zero bool - rvInterface := rv.Interface() if cz, ok := encoder.(CodecZeroer); ok { - zero = cz.IsTypeZero(rvInterface) + zero = cz.IsTypeZero(rv.Interface()) } else if rv.Kind() == reflect.Interface { // isZero will not treat an interface rv as an interface, so we need to check for the // zero interface separately. zero = rv.IsNil() } else { - zero = isZero(rvInterface, sc.EncodeOmitDefaultStruct || ec.omitZeroStruct) + zero = isZero(rv, sc.EncodeOmitDefaultStruct || ec.omitZeroStruct) } if desc.omitEmpty && zero { continue @@ -392,56 +391,32 @@ func (sc *StructCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val return nil } -func isZero(i interface{}, omitZeroStruct bool) bool { - v := reflect.ValueOf(i) - - // check the value validity - if !v.IsValid() { - return true +func isZero(v reflect.Value, omitZeroStruct bool) bool { + kind := v.Kind() + if (kind != reflect.Ptr || !v.IsNil()) && v.Type().Implements(tZeroer) { + return v.Interface().(Zeroer).IsZero() } - - if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) { - return z.IsZero() - } - - switch v.Kind() { - case reflect.Array, reflect.Map, reflect.Slice, reflect.String: - return v.Len() == 0 - case reflect.Bool: - return !v.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return v.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return v.Uint() == 0 - case reflect.Float32, reflect.Float64: - return v.Float() == 0 - case reflect.Interface, reflect.Ptr: - return v.IsNil() - case reflect.Struct: + if kind == reflect.Struct { if !omitZeroStruct { return false } - - // TODO(GODRIVER-2820): Update the logic to be able to handle private struct fields. - // TODO Use condition "reflect.Zero(v.Type()).Equal(v)" instead. - vt := v.Type() if vt == tTime { return v.Interface().(time.Time).IsZero() } - for i := 0; i < v.NumField(); i++ { - if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous { + numField := vt.NumField() + for i := 0; i < numField; i++ { + ff := vt.Field(i) + if ff.PkgPath != "" && !ff.Anonymous { continue // Private field } - fld := v.Field(i) - if !isZero(fld.Interface(), omitZeroStruct) { + if !isZero(v.Field(i), omitZeroStruct) { return false } } return true } - - return false + return !v.IsValid() || v.IsZero() } type structDescription struct { @@ -708,21 +683,21 @@ func getInlineField(val reflect.Value, index []int) (reflect.Value, error) { // DeepZero returns recursive zero object func deepZero(st reflect.Type) (result reflect.Value) { - result = reflect.Indirect(reflect.New(st)) - - if result.Kind() == reflect.Struct { - for i := 0; i < result.NumField(); i++ { - if f := result.Field(i); f.Kind() == reflect.Ptr { - if f.CanInterface() { - if ft := reflect.TypeOf(f.Interface()); ft.Elem().Kind() == reflect.Struct { - result.Field(i).Set(recursivePointerTo(deepZero(ft.Elem()))) - } + if st.Kind() == reflect.Struct { + numField := st.NumField() + for i := 0; i < numField; i++ { + if result == emptyValue { + result = reflect.Indirect(reflect.New(st)) + } + f := result.Field(i) + if f.CanInterface() { + if f.Type().Kind() == reflect.Struct { + result.Field(i).Set(recursivePointerTo(deepZero(f.Type().Elem()))) } } } } - - return + return result } // recursivePointerTo calls reflect.New(v.Type) but recursively for its fields inside diff --git a/bson/bsoncodec/struct_codec_test.go b/bson/bsoncodec/struct_codec_test.go index 008fc11528..573b374b14 100644 --- a/bson/bsoncodec/struct_codec_test.go +++ b/bson/bsoncodec/struct_codec_test.go @@ -7,6 +7,7 @@ package bsoncodec import ( + "reflect" "testing" "time" @@ -147,7 +148,7 @@ func TestIsZero(t *testing.T) { t.Run(tc.description, func(t *testing.T) { t.Parallel() - got := isZero(tc.value, tc.omitZeroStruct) + got := isZero(reflect.ValueOf(tc.value), tc.omitZeroStruct) assert.Equal(t, tc.want, got, "expected and actual isZero return are different") }) } diff --git a/bson/bsoncodec/types.go b/bson/bsoncodec/types.go index 07f4b70e6d..6ade17b7d3 100644 --- a/bson/bsoncodec/types.go +++ b/bson/bsoncodec/types.go @@ -34,6 +34,7 @@ var tValueUnmarshaler = reflect.TypeOf((*ValueUnmarshaler)(nil)).Elem() var tMarshaler = reflect.TypeOf((*Marshaler)(nil)).Elem() var tUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem() var tProxy = reflect.TypeOf((*Proxy)(nil)).Elem() +var tZeroer = reflect.TypeOf((*Zeroer)(nil)).Elem() var tBinary = reflect.TypeOf(primitive.Binary{}) var tUndefined = reflect.TypeOf(primitive.Undefined{}) diff --git a/bson/bsonrw/copier.go b/bson/bsonrw/copier.go index 33d59bd258..c146d02e58 100644 --- a/bson/bsonrw/copier.go +++ b/bson/bsonrw/copier.go @@ -193,7 +193,7 @@ func (c Copier) AppendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) } vw := vwPool.Get().(*valueWriter) - defer vwPool.Put(vw) + defer putValueWriter(vw) vw.reset(dst) @@ -213,7 +213,7 @@ func (c Copier) AppendArrayBytes(dst []byte, src ValueReader) ([]byte, error) { } vw := vwPool.Get().(*valueWriter) - defer vwPool.Put(vw) + defer putValueWriter(vw) vw.reset(dst) @@ -258,7 +258,7 @@ func (c Copier) AppendValueBytes(dst []byte, src ValueReader) (bsontype.Type, [] } vw := vwPool.Get().(*valueWriter) - defer vwPool.Put(vw) + defer putValueWriter(vw) start := len(dst) diff --git a/bson/bsonrw/value_reader.go b/bson/bsonrw/value_reader.go index 9bf24fae0b..a242bb57cf 100644 --- a/bson/bsonrw/value_reader.go +++ b/bson/bsonrw/value_reader.go @@ -739,8 +739,7 @@ func (vr *valueReader) ReadValue() (ValueReader, error) { return nil, ErrEOA } - _, err = vr.readCString() - if err != nil { + if err := vr.skipCString(); err != nil { return nil, err } @@ -794,6 +793,15 @@ func (vr *valueReader) readByte() (byte, error) { return vr.d[vr.offset-1], nil } +func (vr *valueReader) skipCString() error { + idx := bytes.IndexByte(vr.d[vr.offset:], 0x00) + if idx < 0 { + return io.EOF + } + vr.offset += int64(idx) + 1 + return nil +} + func (vr *valueReader) readCString() (string, error) { idx := bytes.IndexByte(vr.d[vr.offset:], 0x00) if idx < 0 { diff --git a/bson/bsonrw/value_writer.go b/bson/bsonrw/value_writer.go index a6dd8d34f5..311518a80d 100644 --- a/bson/bsonrw/value_writer.go +++ b/bson/bsonrw/value_writer.go @@ -28,6 +28,13 @@ var vwPool = sync.Pool{ }, } +func putValueWriter(vw *valueWriter) { + if vw != nil { + vw.w = nil // don't leak the writer + vwPool.Put(vw) + } +} + // BSONValueWriterPool is a pool for BSON ValueWriters. // // Deprecated: BSONValueWriterPool will not be supported in Go Driver 2.0. @@ -149,32 +156,21 @@ type valueWriter struct { } func (vw *valueWriter) advanceFrame() { - if vw.frame+1 >= int64(len(vw.stack)) { // We need to grow the stack - length := len(vw.stack) - if length+1 >= cap(vw.stack) { - // double it - buf := make([]vwState, 2*cap(vw.stack)+1) - copy(buf, vw.stack) - vw.stack = buf - } - vw.stack = vw.stack[:length+1] - } vw.frame++ + if vw.frame >= int64(len(vw.stack)) { + vw.stack = append(vw.stack, vwState{}) + } } func (vw *valueWriter) push(m mode) { vw.advanceFrame() // Clean the stack - vw.stack[vw.frame].mode = m - vw.stack[vw.frame].key = "" - vw.stack[vw.frame].arrkey = 0 - vw.stack[vw.frame].start = 0 + vw.stack[vw.frame] = vwState{mode: m} - vw.stack[vw.frame].mode = m switch m { case mDocument, mArray, mCodeWithScope: - vw.reserveLength() + vw.reserveLength() // WARN: this is not needed } } @@ -213,6 +209,7 @@ func newValueWriter(w io.Writer) *valueWriter { return vw } +// TODO: only used in tests func newValueWriterFromSlice(buf []byte) *valueWriter { vw := new(valueWriter) stack := make([]vwState, 1, 5) @@ -249,17 +246,16 @@ func (vw *valueWriter) invalidTransitionError(destination mode, name string, mod } func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error { - switch vw.stack[vw.frame].mode { + frame := &vw.stack[vw.frame] + switch frame.mode { case mElement: - key := vw.stack[vw.frame].key + key := frame.key if !isValidCString(key) { return errors.New("BSON element key cannot contain null bytes") } - - vw.buf = bsoncore.AppendHeader(vw.buf, t, key) + vw.appendHeader(t, key) case mValue: - // TODO: Do this with a cache of the first 1000 or so array keys. - vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey)) + vw.appendIntHeader(t, frame.arrkey) default: modes := []mode{mElement, mValue} if addmodes != nil { @@ -601,9 +597,11 @@ func (vw *valueWriter) writeLength() error { if length > maxSize { return errMaxDocumentSizeExceeded{size: int64(len(vw.buf))} } - length = length - int(vw.stack[vw.frame].start) - start := vw.stack[vw.frame].start + frame := &vw.stack[vw.frame] + length = length - int(frame.start) + start := frame.start + _ = vw.buf[start+3] // BCE vw.buf[start+0] = byte(length) vw.buf[start+1] = byte(length >> 8) vw.buf[start+2] = byte(length >> 16) @@ -612,5 +610,31 @@ func (vw *valueWriter) writeLength() error { } func isValidCString(cs string) bool { - return !strings.ContainsRune(cs, '\x00') + // Disallow the zero byte in a cstring because the zero byte is used as the + // terminating character. + // + // It's safe to check bytes instead of runes because all multibyte UTF-8 + // code points start with (binary) 11xxxxxx or 10xxxxxx, so 00000000 (i.e. + // 0) will never be part of a multibyte UTF-8 code point. This logic is the + // same as the "r < utf8.RuneSelf" case in strings.IndexRune but can be + // inlined. + // + // https://cs.opensource.google/go/go/+/refs/tags/go1.21.1:src/strings/strings.go;l=127 + return strings.IndexByte(cs, 0) == -1 +} + +// appendHeader is the same as bsoncore.AppendHeader but does not check if the +// key is a valid C string since the caller has already checked for that. +// +// The caller of this function must check if key is a valid C string. +func (vw *valueWriter) appendHeader(t bsontype.Type, key string) { + vw.buf = bsoncore.AppendType(vw.buf, t) + vw.buf = append(vw.buf, key...) + vw.buf = append(vw.buf, 0x00) +} + +func (vw *valueWriter) appendIntHeader(t bsontype.Type, key int) { + vw.buf = bsoncore.AppendType(vw.buf, t) + vw.buf = strconv.AppendInt(vw.buf, int64(key), 10) + vw.buf = append(vw.buf, 0x00) } diff --git a/bson/marshal.go b/bson/marshal.go index f2c48d049e..17ce6697e0 100644 --- a/bson/marshal.go +++ b/bson/marshal.go @@ -9,6 +9,7 @@ package bson import ( "bytes" "encoding/json" + "sync" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsonrw" @@ -141,6 +142,13 @@ func MarshalAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{ return MarshalAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val) } +// Pool of buffers for marshalling BSON. +var bufPool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, +} + // MarshalAppendWithContext will encode val as a BSON document using Registry r and EncodeContext ec and append the // bytes to dst. If dst is not large enough to hold the bytes, it will be grown. If val is not a type that can be // transformed into a document, MarshalValueAppendWithContext should be used instead. @@ -162,8 +170,26 @@ func MarshalAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{ // // See [Encoder] for more examples. func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}) ([]byte, error) { - sw := new(bsonrw.SliceWriter) - *sw = dst + sw := bufPool.Get().(*bytes.Buffer) + defer func() { + // Proper usage of a sync.Pool requires each entry to have approximately + // the same memory cost. To obtain this property when the stored type + // contains a variably-sized buffer, we add a hard limit on the maximum + // buffer to place back in the pool. We limit the size to 16MiB because + // that's the maximum wire message size supported by any current MongoDB + // server. + // + // Comment based on + // https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/fmt/print.go;l=147 + // + // Recycle byte slices that are smaller than 16MiB and at least half + // occupied. + if sw.Cap() < 16*1024*1024 && sw.Cap()/2 < sw.Len() { + bufPool.Put(sw) + } + }() + + sw.Reset() vw := bvwPool.Get(sw) defer bvwPool.Put(vw) @@ -184,7 +210,7 @@ func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interf return nil, err } - return *sw, nil + return append(dst, sw.Bytes()...), nil } // MarshalValue returns the BSON encoding of val.