Commit 95bcb25

refactor after rebase

Signed-off-by: Miguel Ángel Ortuño <[email protected]>

ortuman committed Oct 1, 2024
1 parent 22dba64 commit 95bcb25
Showing 7 changed files with 83 additions and 79 deletions.
10 changes: 0 additions & 10 deletions pkg/kgo/client.go
@@ -25,7 +25,6 @@ import (
 	"time"

 	"github.com/twmb/franz-go/pkg/kerr"
-	"github.com/twmb/franz-go/pkg/kgo/internal/pool"
 	"github.com/twmb/franz-go/pkg/kmsg"
 	"github.com/twmb/franz-go/pkg/sasl"
 )
@@ -454,15 +453,6 @@ func NewClient(opts ...Opt) (*Client, error) {
 		}
 	}

-	// Allow reusing decompression buffers if record pooling has been enabled
-	// via EnableRecordsPool option.
-	var decompressorPool *pool.BucketedPool[byte]
-	if cfg.recordsPool.p != nil {
-		decompressorPool = pool.NewBucketedPool[byte](1024, maxPoolDecodedBufferSize, 2, func(size int) []byte {
-			return make([]byte, 0, size)
-		})
-	}
-
 	ctx, cancel := context.WithCancel(context.Background())
 	cl := &Client{
 		cfg: cfg,
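Note: the buffer pool deleted from NewClient here is not gone; it is rebuilt during option validation, as the config.go hunk below shows, so pool construction now lives next to the rest of the config checks.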
50 changes: 27 additions & 23 deletions pkg/kgo/compression.go
@@ -12,10 +12,9 @@ import (
 	"github.com/klauspost/compress/s2"
 	"github.com/klauspost/compress/zstd"
 	"github.com/pierrec/lz4/v4"
-)
-
-const maxPoolDecodedBufferSize = 8<<20 // 8MB
-var decodedBuffers = sync.Pool{New: func() any { return bytes.NewBuffer(make([]byte, maxPoolDecodedBufferSize)) }}
+
+	"github.com/twmb/franz-go/pkg/kgo/internal/pool"
+)

 var byteBuffers = sync.Pool{New: func() any { return bytes.NewBuffer(make([]byte, 8<<10)) }}

@@ -233,9 +232,9 @@ func (c *compressor) compress(dst *bytes.Buffer, src []byte, produceRequestVersion
 }

 type decompressor struct {
-	ungzPool      sync.Pool
-	unlz4Pool     sync.Pool
-	unzstdPool    sync.Pool
+	ungzPool   sync.Pool
+	unlz4Pool  sync.Pool
+	unzstdPool sync.Pool
 }

 var defaultDecompressor = newDecompressor()
@@ -269,21 +268,26 @@ type zstdDecoder struct {
 	inner *zstd.Decoder
 }

-func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) {
+func (d *decompressor) decompress(src []byte, codec byte, pool *pool.BucketedPool[byte]) ([]byte, error) {
 	// Early return in case there is no compression
 	compCodec := codecType(codec)
 	if compCodec == codecNone {
 		return src, nil
 	}

-	out := decodedBuffers.Get().(*bytes.Buffer)
-	out.Reset()
-	defer func() {
-		if out.Cap() > maxPoolDecodedBufferSize {
-			return // avoid keeping large buffers in the pool
-		}
-		decodedBuffers.Put(out)
-	}()
+	var out *bytes.Buffer
+	if pool != nil {
+		// Assume the worst case scenario here, where decompressed buffer size is FetchMaxBytes.
+		outBuf := pool.Get(pool.MaxSize())[:0]
+		defer func() {
+			pool.Put(outBuf)
+		}()
+		out = bytes.NewBuffer(outBuf)
+	} else {
+		out = byteBuffers.Get().(*bytes.Buffer)
+		out.Reset()
+		defer byteBuffers.Put(out)
+	}

 	switch compCodec {
 	case codecGzip:
@@ -295,7 +299,7 @@ func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) {
 		if _, err := io.Copy(out, ungz); err != nil {
 			return nil, err
 		}
-		return d.copyDecodedBuffer(out.Bytes()), nil
+		return d.copyDecodedBuffer(out.Bytes(), pool), nil
 	case codecSnappy:
 		if len(src) > 16 && bytes.HasPrefix(src, xerialPfx) {
 			return xerialDecode(src)
@@ -304,33 +308,33 @@ func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) {
 		if err != nil {
 			return nil, err
 		}
-		return d.copyDecodedBuffer(decoded), nil
+		return d.copyDecodedBuffer(decoded, pool), nil
 	case codecLZ4:
 		unlz4 := d.unlz4Pool.Get().(*lz4.Reader)
 		defer d.unlz4Pool.Put(unlz4)
 		unlz4.Reset(bytes.NewReader(src))
 		if _, err := io.Copy(out, unlz4); err != nil {
 			return nil, err
 		}
-		return d.copyDecodedBuffer(out.Bytes()), nil
+		return d.copyDecodedBuffer(out.Bytes(), pool), nil
 	case codecZstd:
 		unzstd := d.unzstdPool.Get().(*zstdDecoder)
 		defer d.unzstdPool.Put(unzstd)
 		decoded, err := unzstd.inner.DecodeAll(src, out.Bytes())
 		if err != nil {
 			return nil, err
 		}
-		return d.copyDecodedBuffer(decoded), nil
+		return d.copyDecodedBuffer(decoded, pool), nil
 	default:
 		return nil, errors.New("unknown compression codec")
 	}
 }

-func (d *decompressor) copyDecodedBuffer(decoded []byte) []byte {
-	if d.outBufferPool == nil {
+func (d *decompressor) copyDecodedBuffer(decoded []byte, pool *pool.BucketedPool[byte]) []byte {
+	if pool == nil {
 		return append([]byte(nil), decoded...)
 	}
-	out := d.outBufferPool.Get(len(decoded))
+	out := pool.Get(len(decoded))
 	return append(out[:0], decoded...)
 }

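To make the new buffer flow concrete, here is a minimal sketch of the pooled branch, using only the pool API visible in this diff (NewBucketedPool, Get, Put, MaxSize). The sizes, the Write call, and the helper name are illustrative, and the import only resolves inside the kgo module, since the package is internal.

package kgo

import (
	"bytes"

	"github.com/twmb/franz-go/pkg/kgo/internal/pool"
)

// decompressSketch mirrors the pool != nil branch of decompress above:
// grab the largest bucket up front, decode into it, copy out, return it.
func decompressSketch(src []byte) []byte {
	p := pool.NewBucketedPool[byte](4096, 1<<20, 2, func(sz int) []byte {
		return make([]byte, sz)
	})

	outBuf := p.Get(p.MaxSize())[:0] // length 0, capacity = largest bucket
	defer p.Put(outBuf)              // hand the slice back once decoding is done

	out := bytes.NewBuffer(outBuf)
	out.Write(src) // stand-in for io.Copy from a gzip/lz4 reader

	// Like copyDecodedBuffer: the caller gets a right-sized pooled copy.
	cp := p.Get(out.Len())
	return append(cp[:0], out.Bytes()...)
}

Creating the pool per call is only for self-containment; in the client it is built once during config validation and threaded through each decompress call.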
8 changes: 4 additions & 4 deletions pkg/kgo/compression_test.go
@@ -71,7 +71,7 @@ func TestCompressDecompress(t *testing.T) {
 	}

 	t.Parallel()
-	d := newDecompressor(nil)
+	d := newDecompressor()
 	inputs := [][]byte{
 		randStr(1 << 2),
 		randStr(1 << 5),
@@ -110,7 +110,7 @@ func TestCompressDecompress(t *testing.T) {
 				w.Reset()

 				got, used := c.compress(w, in, produceVersion)
-				got, err := d.decompress(got, byte(used))
+				got, err := d.decompress(got, byte(used), nil)
 				if err != nil {
 					t.Errorf("unexpected decompress err: %v", err)
 					return
@@ -155,8 +155,8 @@ func BenchmarkDecompress(b *testing.B) {

 		b.Run(fmt.Sprint(codec), func(b *testing.B) {
 			for i := 0; i < b.N; i++ {
-				d := newDecompressor(nil)
-				d.decompress(w.Bytes(), byte(codec))
+				d := newDecompressor()
+				d.decompress(w.Bytes(), byte(codec), nil)
 			}
 		})
 		byteBuffers.Put(w)
10 changes: 9 additions & 1 deletion pkg/kgo/config.go
@@ -16,6 +16,8 @@ import (
 	"github.com/twmb/franz-go/pkg/kmsg"
 	"github.com/twmb/franz-go/pkg/kversion"
 	"github.com/twmb/franz-go/pkg/sasl"
+
+	"github.com/twmb/franz-go/pkg/kgo/internal/pool"
 )

 // Opt is an option to configure a client.
@@ -151,7 +153,8 @@ type cfg struct {
 	partitions map[string]map[int32]Offset // partitions to directly consume from
 	regex      bool

-	recordsPool recordsPool
+	recordsPool          *recordsPool
+	decompressBufferPool *pool.BucketedPool[byte]

 	////////////////////////////
 	// CONSUMER GROUP SECTION //
@@ -391,6 +394,11 @@ func (cfg *cfg) validate() error {
 	}
 	cfg.hooks = processedHooks

+	if cfg.recordsPool != nil {
+		cfg.decompressBufferPool = pool.NewBucketedPool[byte](4096, int(cfg.maxBytes.load()), 2, func(sz int) []byte {
+			return make([]byte, sz)
+		})
+	}
 	return nil
 }

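From the caller's side, the wiring might look as follows. EnableRecordsPool is inferred from the comment deleted in client.go above and should be treated as an assumption about this fork's option set; SeedBrokers and FetchMaxBytes are stock franz-go options, and the latter backs cfg.maxBytes, so it now also bounds the largest decompression bucket.

package main

import "github.com/twmb/franz-go/pkg/kgo"

func main() {
	// With record pooling on, validate() also builds decompressBufferPool,
	// with bucket sizes doubling from 4096 bytes up to the fetch max.
	cl, err := kgo.NewClient(
		kgo.SeedBrokers("localhost:9092"),
		kgo.EnableRecordsPool(), // assumed option name, per the removed comment
		kgo.FetchMaxBytes(50<<20),
	)
	if err != nil {
		panic(err)
	}
	defer cl.Close()
}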
5 changes: 5 additions & 0 deletions pkg/kgo/internal/pool/bucketed_pool.go
@@ -90,3 +90,8 @@ func (p *BucketedPool[T]) Put(s []T) {
 		return
 	}
 }
+
+// MaxSize returns the maximum size of a slice in the pool.
+func (p *BucketedPool[T]) MaxSize() int {
+	return p.sizes[len(p.sizes)-1]
+}
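decompress above sizes its worst-case Get through this accessor; the sketch following the compression.go diff shows the pattern in isolation.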
8 changes: 3 additions & 5 deletions pkg/kgo/record_and_fetch.go
Original file line number Diff line number Diff line change
@@ -155,7 +155,7 @@ type Record struct {
 	// recordsPool is the pool that this record was fetched from, if any.
 	//
 	// When reused, record is returned to this pool.
-	recordsPool recordsPool
+	recordsPool *recordsPool

 	// rcBatchBuffer is used to keep track of the raw buffer that this record was
 	// derived from when consuming, after decompression.
@@ -178,13 +178,11 @@ type Record struct {
 // Once this method has been called, any reference to the passed record should be considered invalid by the caller,
 // as it may be reused as a result of future calls to the PollFetches/PollRecords method.
 func (r *Record) Reuse() {
-	if r.rcRawRecordsBuffer != nil {
+	if r.recordsPool != nil {
 		r.rcRawRecordsBuffer.release()
-	}
-	if r.rcBatchBuffer != nil {
 		r.rcBatchBuffer.release()
+		r.recordsPool.put(r)
 	}
-	r.recordsPool.put(r)
 }

 func (r *Record) userSize() int64 {
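The tightened contract reads naturally in a consume loop: Reuse is now a no-op for records that were not pool-allocated, so callers may invoke it unconditionally. PollFetches and EachRecord are standard kgo APIs; the process callback is hypothetical.

package example

import (
	"context"

	"github.com/twmb/franz-go/pkg/kgo"
)

// consumeOnce drains one poll and recycles every record. The handler must
// not retain r after returning: Reuse may hand the Record back to the pool.
func consumeOnce(ctx context.Context, cl *kgo.Client, process func(*kgo.Record)) {
	fetches := cl.PollFetches(ctx)
	fetches.EachRecord(func(r *kgo.Record) {
		process(r)
		r.Reuse() // no-op when the record did not come from a pool
	})
}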
[Diff for the seventh changed file did not load and is not shown.]
