From ab8a125839c13457e06bf0101aec07920cdf7103 Mon Sep 17 00:00:00 2001 From: FZambia Date: Wed, 30 Oct 2024 10:06:28 +0200 Subject: [PATCH] separate method to keep backwards compatibility --- decode_stream.go | 27 ++++++++++++++++----------- decode_stream_test.go | 27 +++++++++++++++++++++++---- 2 files changed, 39 insertions(+), 15 deletions(-) diff --git a/decode_stream.go b/decode_stream.go index 21b76a7..d696087 100644 --- a/decode_stream.go +++ b/decode_stream.go @@ -15,7 +15,11 @@ var ( streamProtobufCommandDecoderPool sync.Pool ) -func GetStreamCommandDecoder(protoType Type, reader io.Reader, messageSizeLimit int) StreamCommandDecoder { +func GetStreamCommandDecoder(protoType Type, reader io.Reader) StreamCommandDecoder { + return GetStreamCommandDecoderLimited(protoType, reader, 0) +} + +func GetStreamCommandDecoderLimited(protoType Type, reader io.Reader, messageSizeLimit int64) StreamCommandDecoder { if protoType == TypeJSON { e := streamJsonCommandDecoderPool.Get() if e == nil { @@ -45,7 +49,7 @@ func PutStreamCommandDecoder(protoType Type, e StreamCommandDecoder) { type StreamCommandDecoder interface { Decode() (*Command, int, error) - Reset(reader io.Reader, messageSizeLimit int) + Reset(reader io.Reader, messageSizeLimit int64) } // ErrMessageTooLarge for when the message exceeds the limit. @@ -54,14 +58,14 @@ var ErrMessageTooLarge = errors.New("message size exceeds the limit") type JSONStreamCommandDecoder struct { reader *bufio.Reader limitedReader *io.LimitedReader - messageSizeLimit int + messageSizeLimit int64 } -func NewJSONStreamCommandDecoder(reader io.Reader, messageSizeLimit int) *JSONStreamCommandDecoder { +func NewJSONStreamCommandDecoder(reader io.Reader, messageSizeLimit int64) *JSONStreamCommandDecoder { var limitedReader *io.LimitedReader var bufioReader *bufio.Reader if messageSizeLimit > 0 { - limitedReader = &io.LimitedReader{R: reader, N: int64(messageSizeLimit) + 1} + limitedReader = &io.LimitedReader{R: reader, N: messageSizeLimit + 1} bufioReader = bufio.NewReader(limitedReader) } else { bufioReader = bufio.NewReader(reader) @@ -79,7 +83,7 @@ func (d *JSONStreamCommandDecoder) Decode() (*Command, int, error) { } cmdBytes, err := d.reader.ReadBytes('\n') if err != nil { - if d.messageSizeLimit > 0 && len(cmdBytes) > d.messageSizeLimit { + if d.messageSizeLimit > 0 && int64(len(cmdBytes)) > d.messageSizeLimit { return nil, 0, ErrMessageTooLarge } if err == io.EOF && len(cmdBytes) > 0 { @@ -101,24 +105,25 @@ func (d *JSONStreamCommandDecoder) Decode() (*Command, int, error) { return &c, len(cmdBytes), nil } -func (d *JSONStreamCommandDecoder) Reset(reader io.Reader, messageSizeLimit int) { +func (d *JSONStreamCommandDecoder) Reset(reader io.Reader, messageSizeLimit int64) { d.messageSizeLimit = messageSizeLimit if messageSizeLimit > 0 { - limitedReader := &io.LimitedReader{R: reader, N: int64(messageSizeLimit) + 1} + limitedReader := &io.LimitedReader{R: reader, N: messageSizeLimit + 1} bufioReader := bufio.NewReader(limitedReader) d.limitedReader = limitedReader d.reader.Reset(bufioReader) } else { + d.limitedReader = nil d.reader.Reset(reader) } } type ProtobufStreamCommandDecoder struct { reader *bufio.Reader - messageSizeLimit int + messageSizeLimit int64 } -func NewProtobufStreamCommandDecoder(reader io.Reader, messageSizeLimit int) *ProtobufStreamCommandDecoder { +func NewProtobufStreamCommandDecoder(reader io.Reader, messageSizeLimit int64) *ProtobufStreamCommandDecoder { return &ProtobufStreamCommandDecoder{reader: bufio.NewReader(reader), messageSizeLimit: messageSizeLimit} } @@ -150,7 +155,7 @@ func (d *ProtobufStreamCommandDecoder) Decode() (*Command, int, error) { return &c, int(msgLength) + 8, nil } -func (d *ProtobufStreamCommandDecoder) Reset(reader io.Reader, messageSizeLimit int) { +func (d *ProtobufStreamCommandDecoder) Reset(reader io.Reader, messageSizeLimit int64) { d.messageSizeLimit = messageSizeLimit d.reader.Reset(reader) } diff --git a/decode_stream_test.go b/decode_stream_test.go index 075651e..f6e57a5 100644 --- a/decode_stream_test.go +++ b/decode_stream_test.go @@ -60,14 +60,14 @@ func TestStreamingDecode_JSON(t *testing.T) { func TestStreamingDecode_JSON_MessageLimit(t *testing.T) { frame := getTestFrame(t, TypeJSON, 10000) - dec := GetStreamCommandDecoder(TypeJSON, bytes.NewReader(frame), 100) + dec := GetStreamCommandDecoderLimited(TypeJSON, bytes.NewReader(frame), 100) _, _, err := dec.Decode() require.ErrorIs(t, err, ErrMessageTooLarge) } func TestStreamingDecode_Protobuf_MessageLimit(t *testing.T) { frame := getTestFrame(t, TypeProtobuf, 10000) - dec := GetStreamCommandDecoder(TypeProtobuf, bytes.NewReader(frame), 100) + dec := GetStreamCommandDecoderLimited(TypeProtobuf, bytes.NewReader(frame), 100) _, _, err := dec.Decode() require.ErrorIs(t, err, ErrMessageTooLarge) } @@ -95,7 +95,7 @@ func BenchmarkStreamingDecode_JSON(b *testing.B) { } func testDecodingFrame(tb testing.TB, frame []byte, protoType Type) { - dec := GetStreamCommandDecoder(protoType, bytes.NewReader(frame), 200000) + dec := GetStreamCommandDecoder(protoType, bytes.NewReader(frame)) _, size, err := dec.Decode() require.NoError(tb, err) if protoType == TypeProtobuf { @@ -129,7 +129,7 @@ func TestJSONStreamCommandDecoder(t *testing.T) { testCases := []struct { name string - messageSizeLimit int + messageSizeLimit int64 }{ { name: "no limit", @@ -169,3 +169,22 @@ func TestJSONStreamCommandDecoder(t *testing.T) { }) } } + +func TestJSONStreamCommandDecoder_ReuseDifferentLimit(t *testing.T) { + // Sample data emulating a network stream of JSON commands with newlines + data := `{"publish":{"channel":"1","data":{}}} +{"publish":{"channel":"1","data":{}}}` + decoder := GetStreamCommandDecoderLimited(TypeJSON, bytes.NewBufferString(data), 10) + _, _, err := decoder.Decode() + require.ErrorIs(t, err, ErrMessageTooLarge) + PutStreamCommandDecoder(TypeJSON, decoder) + decoder = GetStreamCommandDecoderLimited(TypeJSON, bytes.NewBufferString(data), 0) + cmd, _, err := decoder.Decode() + require.NoError(t, err) + require.NotNil(t, cmd) + require.NotNil(t, cmd.Publish) + cmd, _, err = decoder.Decode() + require.ErrorIs(t, err, io.EOF) + require.NotNil(t, cmd) + require.NotNil(t, cmd.Publish) +}