diff --git a/x/mongo/driver/compression.go b/x/mongo/driver/compression.go index 7f355f61a4..03a057df5e 100644 --- a/x/mongo/driver/compression.go +++ b/x/mongo/driver/compression.go @@ -26,48 +26,70 @@ type CompressionOpts struct { UncompressedSize int32 } -var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder +func zstdNewWriter(lvl zstd.EncoderLevel) *zstd.Encoder { + enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(lvl)) + if err != nil { + panic(err) + } + return enc +} + +var zstdEncoders = [zstd.SpeedBestCompression + 1]*zstd.Encoder{ + 0: nil, // zstd.speedNotSet + zstd.SpeedFastest: zstdNewWriter(zstd.SpeedFastest), + zstd.SpeedDefault: zstdNewWriter(zstd.SpeedDefault), + zstd.SpeedBetterCompression: zstdNewWriter(zstd.SpeedBetterCompression), + zstd.SpeedBestCompression: zstdNewWriter(zstd.SpeedBestCompression), +} func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) { - if v, ok := zstdEncoders.Load(level); ok { - return v.(*zstd.Encoder), nil - } - encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level)) - if err != nil { - return nil, err + if zstd.SpeedFastest <= level && level <= zstd.SpeedBestCompression { + return zstdEncoders[level], nil } - zstdEncoders.Store(level, encoder) - return encoder, nil + // The level is invalid so call zstd.NewWriter for the error. + return zstd.NewWriter(nil, zstd.WithEncoderLevel(level)) } -var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder +// zlibEncodersOffset is the offset into the zlibEncoders array for a given +// compression level. 
+const zlibEncodersOffset = -zlib.HuffmanOnly // HuffmanOnly == -2 + +var zlibEncoders [zlib.BestCompression + zlibEncodersOffset + 1]sync.Pool func getZlibEncoder(level int) (*zlibEncoder, error) { - if v, ok := zlibEncoders.Load(level); ok { - return v.(*zlibEncoder), nil - } - writer, err := zlib.NewWriterLevel(nil, level) - if err != nil { - return nil, err + if zlib.HuffmanOnly <= level && level <= zlib.BestCompression { + if enc, _ := zlibEncoders[level+zlibEncodersOffset].Get().(*zlibEncoder); enc != nil { + return enc, nil + } + writer, err := zlib.NewWriterLevel(nil, level) + if err != nil { + return nil, err + } + enc := &zlibEncoder{writer: writer, level: level} + return enc, nil } - encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)} - zlibEncoders.Store(level, encoder) + // The level is invalid so call zlib.NewWriterLevel for the error. + _, err := zlib.NewWriterLevel(nil, level) + return nil, err +} - return encoder, nil +func putZlibEncoder(enc *zlibEncoder) { + if enc != nil { + zlibEncoders[enc.level+zlibEncodersOffset].Put(enc) + } } type zlibEncoder struct { - mu sync.Mutex writer *zlib.Writer - buf *bytes.Buffer + buf bytes.Buffer + level int } func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) { - e.mu.Lock() - defer e.mu.Unlock() + defer putZlibEncoder(e) e.buf.Reset() - e.writer.Reset(e.buf) + e.writer.Reset(&e.buf) _, err := e.writer.Write(src) if err != nil { @@ -105,8 +127,15 @@ func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { } } +var zstdReaderPool = sync.Pool{ + New: func() interface{} { + r, _ := zstd.NewReader(nil) + return r + }, +} + // DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed -func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) { +func DecompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { switch opts.Compressor { case wiremessage.CompressorNoOp: return in, nil @@ 
-117,34 +146,28 @@ func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, er } else if int32(l) != opts.UncompressedSize { return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l) } - uncompressed = make([]byte, opts.UncompressedSize) - return snappy.Decode(uncompressed, in) + out := make([]byte, opts.UncompressedSize) + return snappy.Decode(out, in) case wiremessage.CompressorZLib: r, err := zlib.NewReader(bytes.NewReader(in)) if err != nil { return nil, err } - defer func() { - err = r.Close() - }() - uncompressed = make([]byte, opts.UncompressedSize) - _, err = io.ReadFull(r, uncompressed) - if err != nil { + out := make([]byte, opts.UncompressedSize) + if _, err := io.ReadFull(r, out); err != nil { return nil, err } - return uncompressed, nil - case wiremessage.CompressorZstd: - r, err := zstd.NewReader(bytes.NewBuffer(in)) - if err != nil { - return nil, err - } - defer r.Close() - uncompressed = make([]byte, opts.UncompressedSize) - _, err = io.ReadFull(r, uncompressed) - if err != nil { + if err := r.Close(); err != nil { return nil, err } - return uncompressed, nil + return out, nil + case wiremessage.CompressorZstd: + // Using a pool here is about ~20% faster + // than using a single global zstd.Reader + r := zstdReaderPool.Get().(*zstd.Decoder) + out, err := r.DecodeAll(in, nil) + zstdReaderPool.Put(r) + return out, err default: return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) } diff --git a/x/mongo/driver/compression_test.go b/x/mongo/driver/compression_test.go index 5557257334..acca1317c6 100644 --- a/x/mongo/driver/compression_test.go +++ b/x/mongo/driver/compression_test.go @@ -46,6 +46,46 @@ func TestCompression(t *testing.T) { } } +func TestCompressionLevels(t *testing.T) { + errEq := func(e1, e2 error) bool { + if e1 == nil || e2 == nil { + return (e1 == nil) == (e2 == nil) + } + return e1.Error() == e2.Error() + } + + in := []byte("abc") + wr := 
new(bytes.Buffer) + + t.Run("ZLib", func(t *testing.T) { + opts := CompressionOpts{ + Compressor: wiremessage.CompressorZLib, + } + for lvl := zlib.HuffmanOnly - 2; lvl < zlib.BestCompression+2; lvl++ { + opts.ZlibLevel = lvl + _, err1 := CompressPayload(in, opts) + _, err2 := zlib.NewWriterLevel(wr, lvl) + if !errEq(err1, err2) { + t.Fatalf("%d: error: %v, want: %v", lvl, err1, err2) + } + } + }) + + t.Run("Zstd", func(t *testing.T) { + opts := CompressionOpts{ + Compressor: wiremessage.CompressorZstd, + } + for lvl := zstd.SpeedFastest - 2; lvl < zstd.SpeedBestCompression+2; lvl++ { + opts.ZstdLevel = int(lvl) + _, err1 := CompressPayload(in, opts) + _, err2 := zstd.NewWriter(wr, zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(opts.ZstdLevel))) + if !errEq(err1, err2) { + t.Fatalf("%d: error: %v, want: %v", lvl, err1, err2) + } + } + }) +} + func TestDecompressFailures(t *testing.T) { t.Parallel() diff --git a/x/mongo/driver/testdata/compression.go b/x/mongo/driver/testdata/compression.go new file mode 100644 index 0000000000..7f355f61a4 --- /dev/null +++ b/x/mongo/driver/testdata/compression.go @@ -0,0 +1,151 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. 
You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package driver + +import ( + "bytes" + "compress/zlib" + "fmt" + "io" + "sync" + + "github.com/golang/snappy" + "github.com/klauspost/compress/zstd" + "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" +) + +// CompressionOpts holds settings for how to compress a payload +type CompressionOpts struct { + Compressor wiremessage.CompressorID + ZlibLevel int + ZstdLevel int + UncompressedSize int32 +} + +var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder + +func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) { + if v, ok := zstdEncoders.Load(level); ok { + return v.(*zstd.Encoder), nil + } + encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level)) + if err != nil { + return nil, err + } + zstdEncoders.Store(level, encoder) + return encoder, nil +} + +var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder + +func getZlibEncoder(level int) (*zlibEncoder, error) { + if v, ok := zlibEncoders.Load(level); ok { + return v.(*zlibEncoder), nil + } + writer, err := zlib.NewWriterLevel(nil, level) + if err != nil { + return nil, err + } + encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)} + zlibEncoders.Store(level, encoder) + + return encoder, nil +} + +type zlibEncoder struct { + mu sync.Mutex + writer *zlib.Writer + buf *bytes.Buffer +} + +func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) { + e.mu.Lock() + defer e.mu.Unlock() + + e.buf.Reset() + e.writer.Reset(e.buf) + + _, err := e.writer.Write(src) + if err != nil { + return nil, err + } + err = e.writer.Close() + if err != nil { + return nil, err + } + dst = append(dst[:0], e.buf.Bytes()...) 
+ return dst, nil +} + +// CompressPayload takes a byte slice and compresses it according to the options passed +func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { + switch opts.Compressor { + case wiremessage.CompressorNoOp: + return in, nil + case wiremessage.CompressorSnappy: + return snappy.Encode(nil, in), nil + case wiremessage.CompressorZLib: + encoder, err := getZlibEncoder(opts.ZlibLevel) + if err != nil { + return nil, err + } + return encoder.Encode(nil, in) + case wiremessage.CompressorZstd: + encoder, err := getZstdEncoder(zstd.EncoderLevelFromZstd(opts.ZstdLevel)) + if err != nil { + return nil, err + } + return encoder.EncodeAll(in, nil), nil + default: + return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) + } +} + +// DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed +func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) { + switch opts.Compressor { + case wiremessage.CompressorNoOp: + return in, nil + case wiremessage.CompressorSnappy: + l, err := snappy.DecodedLen(in) + if err != nil { + return nil, fmt.Errorf("decoding compressed length %w", err) + } else if int32(l) != opts.UncompressedSize { + return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l) + } + uncompressed = make([]byte, opts.UncompressedSize) + return snappy.Decode(uncompressed, in) + case wiremessage.CompressorZLib: + r, err := zlib.NewReader(bytes.NewReader(in)) + if err != nil { + return nil, err + } + defer func() { + err = r.Close() + }() + uncompressed = make([]byte, opts.UncompressedSize) + _, err = io.ReadFull(r, uncompressed) + if err != nil { + return nil, err + } + return uncompressed, nil + case wiremessage.CompressorZstd: + r, err := zstd.NewReader(bytes.NewBuffer(in)) + if err != nil { + return nil, err + } + defer r.Close() + uncompressed = make([]byte, opts.UncompressedSize) + _, err = 
io.ReadFull(r, uncompressed) + if err != nil { + return nil, err + } + return uncompressed, nil + default: + return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) + } +}