From f2e6ad12d3b019e6ab60f393be2b566c80af8842 Mon Sep 17 00:00:00 2001 From: Matt Lord Date: Tue, 22 Oct 2024 14:39:07 -0400 Subject: [PATCH] Add unit test Signed-off-by: Matt Lord --- go/mysql/binlog_event_compression.go | 11 ++-- go/mysql/binlog_event_compression_test.go | 63 +++++++++++++++++++++++ 2 files changed, 68 insertions(+), 6 deletions(-) create mode 100644 go/mysql/binlog_event_compression_test.go diff --git a/go/mysql/binlog_event_compression.go b/go/mysql/binlog_event_compression.go index b6b76822e8a..b1fffbc5284 100644 --- a/go/mysql/binlog_event_compression.go +++ b/go/mysql/binlog_event_compression.go @@ -325,11 +325,8 @@ func (tp *TransactionPayload) decompress() error { func (tp *TransactionPayload) Close() { switch reader := tp.reader.(type) { case *zstd.Decoder: - if err := reader.Reset(nil); err == nil || err == io.EOF { - statefulDecoderPool.Put(reader) - } + statefulDecoderPool.Put(reader) default: - reader = nil } tp.iterator = nil } @@ -361,8 +358,10 @@ type decoderPool struct { // Get gets a pooled OR new *zstd.Decoder. func (dp *decoderPool) Get(reader io.Reader) (*zstd.Decoder, error) { - decoder := dp.pool.Get().(*zstd.Decoder) - if decoder == nil { + var decoder *zstd.Decoder + if pooled := dp.pool.Get(); pooled != nil { + decoder = pooled.(*zstd.Decoder) + } else { d, err := zstd.NewReader(nil, zstd.WithDecoderMaxMemory(zstdInMemoryDecompressorMaxSize)) if err != nil { // Should only happen e.g. due to ENOMEM return nil, vterrors.New(vtrpcpb.Code_INTERNAL, "failed to create stateful stream decoder") diff --git a/go/mysql/binlog_event_compression_test.go b/go/mysql/binlog_event_compression_test.go new file mode 100644 index 00000000000..36ff68609b7 --- /dev/null +++ b/go/mysql/binlog_event_compression_test.go @@ -0,0 +1,63 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mysql + +import ( + "bytes" + "io" + "testing" + + "github.com/klauspost/compress/zstd" + "github.com/stretchr/testify/require" +) + +func TestDecoderPool(t *testing.T) { + type args struct { + r io.Reader + } + tests := []struct { + name string + reader io.Reader + wantErr bool + }{ + { + name: "happy path", + reader: bytes.NewReader([]byte{0x68, 0x61, 0x70, 0x70, 0x79}), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + decoder, err := statefulDecoderPool.Get(tt.reader) + require.NoError(t, err) + require.NotNil(t, decoder) + require.IsType(t, &zstd.Decoder{}, decoder) + statefulDecoderPool.Put(decoder) + decoder2, err := statefulDecoderPool.Get(tt.reader) + require.NoError(t, err) + require.NotNil(t, decoder2) + require.IsType(t, &zstd.Decoder{}, decoder) + statefulDecoderPool.Put(decoder) + require.True(t, (decoder2 == decoder)) + statefulDecoderPool.Put(decoder2) + decoder3, err := statefulDecoderPool.Get(tt.reader) + require.NoError(t, err) + require.IsType(t, &zstd.Decoder{}, decoder) + statefulDecoderPool.Put(decoder) + require.True(t, (decoder3 == decoder2)) + }) + } +}