diff --git a/pkg/kgo/client.go b/pkg/kgo/client.go
index 6b3263d8..4f57ea37 100644
--- a/pkg/kgo/client.go
+++ b/pkg/kgo/client.go
@@ -25,7 +25,6 @@ import (
 	"time"
 
 	"github.com/twmb/franz-go/pkg/kerr"
-	"github.com/twmb/franz-go/pkg/kgo/internal/pool"
 	"github.com/twmb/franz-go/pkg/kmsg"
 	"github.com/twmb/franz-go/pkg/sasl"
 )
@@ -454,15 +453,6 @@ func NewClient(opts ...Opt) (*Client, error) {
 		}
 	}
 
-	// Allow reusing decompression buffers if record pooling has been enabled
-	// via EnableRecordsPool option.
-	var decompressorPool *pool.BucketedPool[byte]
-	if cfg.recordsPool.p != nil {
-		decompressorPool = pool.NewBucketedPool[byte](1024, maxPoolDecodedBufferSize, 2, func(size int) []byte {
-			return make([]byte, 0, size)
-		})
-	}
-
 	ctx, cancel := context.WithCancel(context.Background())
 	cl := &Client{
 		cfg: cfg,
diff --git a/pkg/kgo/compression.go b/pkg/kgo/compression.go
index ba123253..569050be 100644
--- a/pkg/kgo/compression.go
+++ b/pkg/kgo/compression.go
@@ -12,10 +12,9 @@ import (
 	"github.com/klauspost/compress/s2"
 	"github.com/klauspost/compress/zstd"
 	"github.com/pierrec/lz4/v4"
-)
 
-const maxPoolDecodedBufferSize = 8<<20 // 8MB
-var decodedBuffers = sync.Pool{New: func() any { return bytes.NewBuffer(make([]byte, maxPoolDecodedBufferSize)) }}
+	"github.com/twmb/franz-go/pkg/kgo/internal/pool"
+)
 
 var byteBuffers = sync.Pool{New: func() any { return bytes.NewBuffer(make([]byte, 8<<10)) }}
@@ -233,9 +232,9 @@ func (c *compressor) compress(dst *bytes.Buffer, src []byte, produceRequestVersi
 type decompressor struct {
-	ungzPool      sync.Pool
-	unlz4Pool     sync.Pool
-	unzstdPool    sync.Pool
+	ungzPool   sync.Pool
+	unlz4Pool  sync.Pool
+	unzstdPool sync.Pool
 }
 
 var defaultDecompressor = newDecompressor()
@@ -269,21 +268,26 @@ type zstdDecoder struct {
 	inner *zstd.Decoder
 }
 
-func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) {
+func (d *decompressor) decompress(src []byte, codec byte, pool *pool.BucketedPool[byte]) ([]byte, error) {
 	// Early return in case there is no compression
 	compCodec := codecType(codec)
 	if compCodec == codecNone {
 		return src, nil
 	}
-
-	out := decodedBuffers.Get().(*bytes.Buffer)
-	out.Reset()
-	defer func() {
-		if out.Cap() > maxPoolDecodedBufferSize {
-			return // avoid keeping large buffers in the pool
-		}
-		decodedBuffers.Put(out)
-	}()
+	var out *bytes.Buffer
+
+	if pool != nil {
+		// Assume the worst case scenario here, where the decompressed buffer size is FetchMaxBytes.
+		outBuf := pool.Get(pool.MaxSize())[:0]
+		defer func() {
+			pool.Put(outBuf)
+		}()
+		out = bytes.NewBuffer(outBuf)
+	} else {
+		out = byteBuffers.Get().(*bytes.Buffer)
+		out.Reset()
+		defer byteBuffers.Put(out)
+	}
 
 	switch compCodec {
 	case codecGzip:
@@ -295,7 +299,7 @@ func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) {
 		if _, err := io.Copy(out, ungz); err != nil {
 			return nil, err
 		}
-		return d.copyDecodedBuffer(out.Bytes()), nil
+		return d.copyDecodedBuffer(out.Bytes(), pool), nil
 	case codecSnappy:
 		if len(src) > 16 && bytes.HasPrefix(src, xerialPfx) {
 			return xerialDecode(src)
 		}
@@ -304,7 +308,7 @@ func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) {
 		if err != nil {
 			return nil, err
 		}
-		return d.copyDecodedBuffer(decoded), nil
+		return d.copyDecodedBuffer(decoded, pool), nil
 	case codecLZ4:
 		unlz4 := d.unlz4Pool.Get().(*lz4.Reader)
 		defer d.unlz4Pool.Put(unlz4)
@@ -312,7 +316,7 @@ func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) {
 		if _, err := io.Copy(out, unlz4); err != nil {
 			return nil, err
 		}
-		return d.copyDecodedBuffer(out.Bytes()), nil
+		return d.copyDecodedBuffer(out.Bytes(), pool), nil
 	case codecZstd:
 		unzstd := d.unzstdPool.Get().(*zstdDecoder)
 		defer d.unzstdPool.Put(unzstd)
@@ -320,17 +324,17 @@ func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) {
 		if err != nil {
 			return nil, err
 		}
-		return d.copyDecodedBuffer(decoded), nil
+		return d.copyDecodedBuffer(decoded, pool), nil
 	default:
 		return nil, errors.New("unknown compression codec")
 	}
 }
 
-func (d *decompressor) copyDecodedBuffer(decoded []byte) []byte {
-	if d.outBufferPool == nil {
+func (d *decompressor) copyDecodedBuffer(decoded []byte, pool *pool.BucketedPool[byte]) []byte {
+	if pool == nil {
 		return append([]byte(nil), decoded...)
 	}
-	out := d.outBufferPool.Get(len(decoded))
+	out := pool.Get(len(decoded))
 	return append(out[:0], decoded...)
 }
diff --git a/pkg/kgo/compression_test.go b/pkg/kgo/compression_test.go
index 1468b39d..5dd82b8b 100644
--- a/pkg/kgo/compression_test.go
+++ b/pkg/kgo/compression_test.go
@@ -71,7 +71,7 @@ func TestCompressDecompress(t *testing.T) {
 	}
 	t.Parallel()
 
-	d := newDecompressor(nil)
+	d := newDecompressor()
 	inputs := [][]byte{
 		randStr(1 << 2),
 		randStr(1 << 5),
@@ -110,7 +110,7 @@ func TestCompressDecompress(t *testing.T) {
 					w.Reset()
 					got, used := c.compress(w, in, produceVersion)
 
-					got, err := d.decompress(got, byte(used))
+					got, err := d.decompress(got, byte(used), nil)
 					if err != nil {
 						t.Errorf("unexpected decompress err: %v", err)
 						return
@@ -155,8 +155,8 @@ func BenchmarkDecompress(b *testing.B) {
 		b.Run(fmt.Sprint(codec), func(b *testing.B) {
 			for i := 0; i < b.N; i++ {
-				d := newDecompressor(nil)
-				d.decompress(w.Bytes(), byte(codec))
+				d := newDecompressor()
+				d.decompress(w.Bytes(), byte(codec), nil)
 			}
 		})
 		byteBuffers.Put(w)
diff --git a/pkg/kgo/config.go b/pkg/kgo/config.go
index 512e68df..59e1a3f9 100644
--- a/pkg/kgo/config.go
+++ b/pkg/kgo/config.go
@@ -16,6 +16,8 @@ import (
 	"github.com/twmb/franz-go/pkg/kmsg"
 	"github.com/twmb/franz-go/pkg/kversion"
 	"github.com/twmb/franz-go/pkg/sasl"
+
+	"github.com/twmb/franz-go/pkg/kgo/internal/pool"
 )
 
 // Opt is an option to configure a client.
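Editor's note, not part of the patch: the decompress change above replaces the dedicated `decodedBuffers` pool with a single worst-case buffer taken from the caller-supplied bucketed pool (`pool.Get(pool.MaxSize())`, where `MaxSize()` ends up being `FetchMaxBytes`), falling back to the shared `byteBuffers` pool when no pool is configured. Below is a minimal, self-contained sketch of that get/wrap/copy/put shape, with a plain `sync.Pool` standing in for `BucketedPool`; `decodeInto`, `maxDecoded`, and `bufPool` are invented names:

```go
package main

import (
	"bytes"
	"fmt"
	"sync"
)

// maxDecoded stands in for pool.MaxSize(), i.e. the FetchMaxBytes worst case.
const maxDecoded = 1 << 20

// bufPool hands out zero-length slices with maxDecoded capacity, mirroring
// how decompress sizes its output buffer for the worst case up front.
var bufPool = sync.Pool{
	New: func() any { return make([]byte, 0, maxDecoded) },
}

// decodeInto writes "decompressed" data into a pooled buffer, then copies the
// result out, since the pooled backing array is reused after Put.
func decodeInto(fill func(out *bytes.Buffer)) []byte {
	outBuf := bufPool.Get().([]byte)[:0]
	defer bufPool.Put(outBuf) // sketch only; a real pool avoids boxing the slice each time
	out := bytes.NewBuffer(outBuf)
	fill(out)
	return append([]byte(nil), out.Bytes()...) // callers never see pooled memory
}

func main() {
	got := decodeInto(func(out *bytes.Buffer) { out.WriteString("decompressed payload") })
	fmt.Printf("%s\n", got)
}
```

The copy at the end mirrors `copyDecodedBuffer` above: without a pool the decoded bytes go into a fresh slice the GC owns, and with a pool they go into a pool-owned slice whose lifetime the fetch path then tracks with reference counts (see source.go below).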
@@ -151,7 +153,8 @@ type cfg struct {
 	partitions map[string]map[int32]Offset // partitions to directly consume from
 	regex      bool
 
-	recordsPool recordsPool
+	recordsPool          *recordsPool
+	decompressBufferPool *pool.BucketedPool[byte]
 
 	////////////////////////////
 	// CONSUMER GROUP SECTION //
@@ -391,6 +394,11 @@ func (cfg *cfg) validate() error {
 	}
 	cfg.hooks = processedHooks
 
+	if cfg.recordsPool != nil {
+		cfg.decompressBufferPool = pool.NewBucketedPool[byte](4096, int(cfg.maxBytes.load()), 2, func(sz int) []byte {
+			return make([]byte, sz)
+		})
+	}
 	return nil
 }
diff --git a/pkg/kgo/internal/pool/bucketed_pool.go b/pkg/kgo/internal/pool/bucketed_pool.go
index d6af0390..0710130b 100644
--- a/pkg/kgo/internal/pool/bucketed_pool.go
+++ b/pkg/kgo/internal/pool/bucketed_pool.go
@@ -90,3 +90,8 @@ func (p *BucketedPool[T]) Put(s []T) {
 		return
 	}
 }
+
+// MaxSize returns the maximum size of a slice in the pool.
+func (p *BucketedPool[T]) MaxSize() int {
+	return p.sizes[len(p.sizes)-1]
+}
diff --git a/pkg/kgo/record_and_fetch.go b/pkg/kgo/record_and_fetch.go
index 0961387c..dbfb9451 100644
--- a/pkg/kgo/record_and_fetch.go
+++ b/pkg/kgo/record_and_fetch.go
@@ -155,7 +155,7 @@ type Record struct {
 	// recordsPool is the pool that this record was fetched from, if any.
 	//
 	// When reused, record is returned to this pool.
-	recordsPool recordsPool
+	recordsPool *recordsPool
 
 	// rcBatchBuffer is used to keep track of the raw buffer that this record was
 	// derived from when consuming, after decompression.
@@ -178,13 +178,11 @@ type Record struct {
 // Once this method has been called, any reference to the passed record should be considered invalid by the caller,
 // as it may be reused as a result of future calls to the PollFetches/PollRecords method.
 func (r *Record) Reuse() {
-	if r.rcRawRecordsBuffer != nil {
+	if r.recordsPool != nil {
 		r.rcRawRecordsBuffer.release()
-	}
-	if r.rcBatchBuffer != nil {
 		r.rcBatchBuffer.release()
+		r.recordsPool.put(r)
 	}
-	r.recordsPool.put(r)
 }
 
 func (r *Record) userSize() int64 {
diff --git a/pkg/kgo/source.go b/pkg/kgo/source.go
index c5815989..7ea01f76 100644
--- a/pkg/kgo/source.go
+++ b/pkg/kgo/source.go
@@ -20,23 +20,17 @@ import (
 
 type recordsPool struct{ p *sync.Pool }
 
-func newRecordsPool() recordsPool {
-	return recordsPool{
+func newRecordsPool() *recordsPool {
+	return &recordsPool{
 		p: &sync.Pool{New: func() any { return &Record{} }},
 	}
 }
 
-func (p recordsPool) get() *Record {
-	if p.p == nil {
-		return &Record{}
-	}
+func (p *recordsPool) get() *Record {
 	return p.p.Get().(*Record)
 }
 
-func (p recordsPool) put(r *Record) {
-	if p.p == nil {
-		return
-	}
+func (p *recordsPool) put(r *Record) {
 	*r = Record{} // zero out the record
 	p.p.Put(r)
 }
@@ -166,7 +160,10 @@ type ProcessFetchPartitionOptions struct {
 	Partition int32
 
 	// recordsPool is for internal use only.
-	recordPool recordsPool
+	recordPool *recordsPool
+
+	// decompressBufferPool is for internal use only.
+	decompressBufferPool *pool.BucketedPool[byte]
 }
 
 // cursor is where we are consuming from for an individual partition.
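Editor's note, not part of the patch: `validate()` above wires a `BucketedPool` into the config whenever record pooling is enabled. The idea of a bucketed pool is one free list per size class, with classes growing by `factor` from the minimum to the maximum size, so `Get(n)` is served from the smallest class that fits and `Put` only recycles slices whose capacity matches a class. The sketch below is a hedged approximation of `pkg/kgo/internal/pool.BucketedPool`; only the `MaxSize` method is taken from the patch, the rest of the internal implementation is assumed:

```go
package main

import (
	"fmt"
	"sync"
)

// BucketedPool keeps a sync.Pool per size class. Size classes grow by factor
// from minSize up to and including maxSize.
type BucketedPool[T any] struct {
	buckets []sync.Pool
	sizes   []int
	newFn   func(int) []T
}

func NewBucketedPool[T any](minSize, maxSize, factor int, newFn func(int) []T) *BucketedPool[T] {
	p := &BucketedPool[T]{newFn: newFn}
	for sz := minSize; sz < maxSize; sz *= factor {
		p.sizes = append(p.sizes, sz)
	}
	p.sizes = append(p.sizes, maxSize)
	p.buckets = make([]sync.Pool, len(p.sizes))
	return p
}

// Get returns a slice that can hold size elements, taken from the smallest
// fitting bucket if one is free, else freshly allocated at the bucket size.
func (p *BucketedPool[T]) Get(size int) []T {
	for i, bktSize := range p.sizes {
		if size > bktSize {
			continue
		}
		if s, ok := p.buckets[i].Get().([]T); ok {
			return s
		}
		return p.newFn(bktSize)
	}
	return p.newFn(size) // larger than any bucket: allocate, never pooled
}

// Put recycles s into the bucket whose size matches its capacity exactly;
// odd-sized slices are dropped for the GC to collect.
func (p *BucketedPool[T]) Put(s []T) {
	for i, bktSize := range p.sizes {
		if cap(s) == bktSize {
			p.buckets[i].Put(s[:0])
			return
		}
	}
}

// MaxSize reports the largest pooled size class (the method the patch adds).
func (p *BucketedPool[T]) MaxSize() int { return p.sizes[len(p.sizes)-1] }

func main() {
	p := NewBucketedPool(4096, 1<<20, 2, func(sz int) []byte { return make([]byte, 0, sz) })
	b := p.Get(p.MaxSize())
	fmt.Println(cap(b) >= p.MaxSize()) // true
	p.Put(b)
}
```

With `NewBucketedPool[byte](4096, int(cfg.maxBytes.load()), 2, ...)` as in `validate()` above, `MaxSize()` lands on the configured `FetchMaxBytes`, which is what lets `decompress` reserve a worst-case buffer via `pool.Get(pool.MaxSize())`.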
@@ -1145,7 +1142,7 @@ func (s *source) handleReqResp(br *broker, req *fetchRequest, resp *kmsg.FetchRe
 				continue
 			}
 
-			fp := partOffset.processRespPartition(br, rp, s.cl.cfg.hooks, s.cl.cfg.recordsPool)
+			fp := partOffset.processRespPartition(br, rp, s.cl.cfg.hooks, s.cl.cfg.recordsPool, s.cl.cfg.decompressBufferPool)
 			if fp.Err != nil {
 				if moving := kmove.maybeAddFetchPartition(resp, rp, partOffset.from); moving {
 					strip(topic, partition, fp.Err)
@@ -1322,17 +1319,18 @@ func (s *source) handleReqResp(br *broker, req *fetchRequest, resp *kmsg.FetchRe
 
 // processRespPartition processes all records in all potentially compressed
 // batches (or message sets).
-func (o *cursorOffsetNext) processRespPartition(br *broker, rp *kmsg.FetchResponseTopicPartition, hooks hooks, recordsPool recordsPool) (fp FetchPartition) {
+func (o *cursorOffsetNext) processRespPartition(br *broker, rp *kmsg.FetchResponseTopicPartition, hooks hooks, recordsPool *recordsPool, decompressBufferPool *pool.BucketedPool[byte]) (fp FetchPartition) {
 	if rp.ErrorCode == 0 {
 		o.hwm = rp.HighWatermark
 	}
 	opts := ProcessFetchPartitionOptions{
-		KeepControlRecords: br.cl.cfg.keepControl,
-		Offset:             o.offset,
-		IsolationLevel:     IsolationLevel{br.cl.cfg.isolationLevel},
-		Topic:              o.from.topic,
-		Partition:          o.from.partition,
-		recordPool:         recordsPool,
+		KeepControlRecords:   br.cl.cfg.keepControl,
+		Offset:               o.offset,
+		IsolationLevel:       IsolationLevel{br.cl.cfg.isolationLevel},
+		Topic:                o.from.topic,
+		Partition:            o.from.partition,
+		recordPool:           recordsPool,
+		decompressBufferPool: decompressBufferPool,
 	}
 	observeMetrics := func(m FetchBatchMetrics) {
 		hooks.each(func(h Hook) {
@@ -1486,7 +1484,7 @@ func ProcessRespPartition(o ProcessFetchPartitionOptions, rp *kmsg.FetchResponse
 		case *kmsg.RecordBatch:
 			m.CompressedBytes = len(t.Records) // for record batches, we only track the record batch length
 			m.CompressionType = uint8(t.Attributes) & 0b0000_0111
-			m.NumRecords, m.UncompressedBytes = processRecordBatch(&o, &fp, t, aborter, defaultDecompressor, o.recordPool)
+			m.NumRecords, m.UncompressedBytes = processRecordBatch(&o, &fp, t, aborter, defaultDecompressor)
 		}
 
 		if m.UncompressedBytes == 0 {
@@ -1552,7 +1550,6 @@ func readRawRecords(n int, in []byte) []kmsg.Record {
 	rs = rs[:n]
 	for i := 0; i < n; i++ {
 		rs[i] = kmsg.Record{}
-
 		length, used := kbin.Varint(in)
 		total := used + int(length)
 		if used == 0 || length < 0 || len(in) < total {
@@ -1572,7 +1569,6 @@ func processRecordBatch(
 	batch *kmsg.RecordBatch,
 	aborter aborter,
 	decompressor *decompressor,
-	recordsPool recordsPool,
 ) (int, int) {
 	if batch.Magic != 2 {
 		fp.Err = fmt.Errorf("unknown batch magic %d", batch.Magic)
@@ -1589,7 +1585,7 @@ func processRecordBatch(
 	rawRecords := batch.Records
 	if compression := byte(batch.Attributes & 0x0007); compression != 0 {
 		var err error
-		if rawRecords, err = decompressor.decompress(rawRecords, compression); err != nil {
+		if rawRecords, err = decompressor.decompress(rawRecords, compression, o.decompressBufferPool); err != nil {
 			return 0, 0 // truncated batch
 		}
 	}
@@ -1617,11 +1613,11 @@ func processRecordBatch(
 	}()
 
 	var (
-		rcBatchBuff *rcBuffer[byte]
+		rcBatchBuff      *rcBuffer[byte]
 		rcRawRecordsBuff *rcBuffer[kmsg.Record]
 	)
-	if decompressor.outBufferPool != nil {
-		rcBatchBuff = newRCBuffer(rawRecords, decompressor.outBufferPool)
+	if o.decompressBufferPool != nil {
+		rcBatchBuff = newRCBuffer(rawRecords, o.decompressBufferPool)
 		rcRawRecordsBuff = newRCBuffer(krecords, rawRecordsPool)
 	}
 
@@ -1632,7 +1628,7 @@ func processRecordBatch(
 			fp.Partition,
 			batch,
 			&krecords[i],
-			recordsPool,
+			o.recordPool,
 		)
 		o.maybeKeepRecord(fp, record, rcBatchBuff, rcRawRecordsBuff, abortBatch)
@@ -1660,7 +1656,7 @@ func processV1OuterMessage(o *ProcessFetchPartitionOptions, fp *FetchPartition,
 		return 1, 0
 	}
 
-	rawInner, err := decompressor.decompress(message.Value, compression)
+	rawInner, err := decompressor.decompress(message.Value, compression, nil)
 	if err != nil {
 		return 0, 0 // truncated batch
 	}
@@ -1773,7 +1769,7 @@ func processV0OuterMessage(
 		return 1, 0 // uncompressed bytes is 0; set to compressed bytes on return
 	}
 
-	rawInner, err := decompressor.decompress(message.Value, compression)
+	rawInner, err := decompressor.decompress(message.Value, compression, nil)
 	if err != nil {
 		return 0, 0 // truncated batch
 	}
@@ -1841,7 +1837,7 @@ func processV0Message(
 //
 // If the record is being aborted or the record is a control record and the
 // client does not want to keep control records, this does not keep the record.
-func (o *cursorOffsetNext) maybeKeepRecord(fp *FetchPartition, record *Record, rcBatchBuff *rcBuffer[byte], rcRawRecordsBuff *rcBuffer[kmsg.Record], abort bool) {
+func (o *ProcessFetchPartitionOptions) maybeKeepRecord(fp *FetchPartition, record *Record, rcBatchBuff *rcBuffer[byte], rcRawRecordsBuff *rcBuffer[kmsg.Record], abort bool) {
 	if record.Offset < o.Offset {
 		// We asked for offset 5, but that was in the middle of a
 		// batch; we got offsets 0 thru 4 that we need to skip.
@@ -1853,12 +1849,10 @@ func (o *cursorOffsetNext) maybeKeepRecord(fp *FetchPartition, record *Record, r
 		abort = !o.KeepControlRecords
 	}
 	if !abort {
-		if rcBatchBuff != nil {
+		if rcBatchBuff != nil && rcRawRecordsBuff != nil {
 			rcBatchBuff.acquire()
-			record.rcBatchBuffer = rcBatchBuff
-		}
-		if rcRawRecordsBuff != nil {
 			rcRawRecordsBuff.acquire()
+			record.rcBatchBuffer = rcBatchBuff
 			record.rcRawRecordsBuffer = rcRawRecordsBuff
 		}
 		fp.Records = append(fp.Records, record)
@@ -1883,7 +1877,7 @@ func recordToRecord(
 	partition int32,
 	batch *kmsg.RecordBatch,
 	record *kmsg.Record,
-	recordsPool recordsPool,
+	recordsPool *recordsPool,
 ) *Record {
 	h := make([]RecordHeader, 0, len(record.Headers))
 	for _, kv := range record.Headers {
 		h = append(h, RecordHeader{
 			Key:   kv.Key,
 			Value: kv.Value,
 		})
 	}
-	r := recordsPool.get()
+	var r *Record
+	if recordsPool != nil {
+		r = recordsPool.get()
+	} else {
+		r = new(Record)
+	}
 	r.Key = record.Key
 	r.Value = record.Value
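Editor's note, not part of the patch: records handed back from `PollFetches` may alias the pooled decompression buffer, so that buffer can only be recycled after every record cut from it has been `Reuse()`d. The patch tracks this with the `rcBuffer` acquire/release pairs seen above: `maybeKeepRecord` acquires both buffers for each kept record, and `Record.Reuse` releases them before returning the record to its pool. Below is a small self-contained sketch of that counting scheme; the real `rcBuffer` lives in the kgo package, and the plain `sync.Pool` and `main` here are invented for illustration:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// rcBuffer is a reference-counted handle on a pooled slice: every record cut
// from the slice acquires it, and the slice returns to its pool when the last
// holder releases it. (The real type's fields and init differ; this is a sketch.)
type rcBuffer[T any] struct {
	refs atomic.Int32
	buf  []T
	pool *sync.Pool
}

func newRCBuffer[T any](buf []T, pool *sync.Pool) *rcBuffer[T] {
	return &rcBuffer[T]{buf: buf, pool: pool}
}

func (b *rcBuffer[T]) acquire() { b.refs.Add(1) }

func (b *rcBuffer[T]) release() {
	if b.refs.Add(-1) == 0 {
		b.pool.Put(b.buf)
		b.buf = nil // the backing array now belongs to the pool again
	}
}

func main() {
	pool := &sync.Pool{New: func() any { return make([]byte, 0, 1024) }}
	batch := newRCBuffer(pool.Get().([]byte), pool)

	// Two records were cut from the same decompressed batch buffer.
	batch.acquire()
	batch.acquire()

	batch.release() // first Record.Reuse
	batch.release() // last Record.Reuse: the buffer goes back to the pool
	fmt.Println("buffer recycled:", batch.buf == nil)
}
```

This is also why `Reuse` above became a no-op for records that did not come from a pool: a nil `recordsPool` marks the non-pooled path, where `recordToRecord` allocates plain `Record`s and `copyDecodedBuffer` hands out unpooled copies that the GC owns.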