diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 1e97451..b6981c5 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -36,12 +36,10 @@ jobs: run: go get . - name: Test run: go test -cover -vet all -coverprofile cover.out . - - name: Coverage Check - run: | - go tool cover -func ./cover.out - val=$(go tool cover -func cover.out | fgrep total | awk '{print $3}') - if [[ "100.0%" != $val ]] - then - echo 'Test coverage is less than 100.0%' - exit 1 - fi + - name: Coverage report + run: go tool cover -html ./cover.out -o cover.html + - name: Archive coverage report + uses: actions/upload-artifact@v4 + with: + name: coverage.html + path: ./cover.html diff --git a/bucket.go b/bucket.go index 890a08d..0fe83fb 100644 --- a/bucket.go +++ b/bucket.go @@ -1,9 +1,7 @@ package leaky import ( - "bytes" "encoding/binary" - "encoding/gob" "errors" "fmt" "io" @@ -13,10 +11,6 @@ import ( var ErrBucketFull = errors.New("leaky: bucket full or would overflow") -func init() { - gob.Register(&Bucket{}) -} - type Bucket struct { DrainBy int64 DrainInterval time.Duration @@ -53,7 +47,7 @@ func DecodeBucket(r io.Reader) (*Bucket, error) { // Check format version format := int32(0) if err := binary.Read(r, binary.BigEndian, &format); err != nil { - return nil, err + return nil, errors.Join(errors.New("leaky: unable to read format version"), err) } if format != 1 { return nil, fmt.Errorf("leaky: unsupported format version %d", format) @@ -61,57 +55,68 @@ func DecodeBucket(r io.Reader) (*Bucket, error) { // Read fields in write order if err := binary.Read(r, binary.BigEndian, &bucket.DrainBy); err != nil { - return nil, err + return nil, errors.Join(errors.New("leaky: unable to read `DrainBy`"), err) } if err := binary.Read(r, binary.BigEndian, &bucket.DrainInterval); err != nil { - return nil, err + return nil, errors.Join(errors.New("leaky: unable to read `DrainInterval`"), err) } if err := binary.Read(r, binary.BigEndian, &bucket.Capacity); err != nil { - return nil, err + return nil, errors.Join(errors.New("leaky: unable to read `Capacity`"), err) } if err := binary.Read(r, binary.BigEndian, &bucket.value); err != nil { - return nil, err + return nil, errors.Join(errors.New("leaky: unable to read `value`"), err) + } + timestampSize := int32(0) + if err := binary.Read(r, binary.BigEndian, ×tampSize); err != nil { + return nil, errors.Join(errors.New("leaky: unable to read size of `lastDrain`"), err) } - lastDrainMs := int64(0) - if err := binary.Read(r, binary.BigEndian, &lastDrainMs); err != nil { - return nil, err + timestampBytes := make([]byte, timestampSize) + if c, err := r.Read(timestampBytes); err != nil { + return nil, errors.Join(errors.New("leaky: unable to read `lastDrain`"), err) + } else if int32(c) != timestampSize { + return nil, errors.New("leaky: did not read entire timestamp") + } + if err := bucket.lastDrain.UnmarshalBinary(timestampBytes); err != nil { + return nil, errors.Join(errors.New("leaky: unable to unmarshal `lastDrain`"), err) } - bucket.lastDrain = time.UnixMilli(lastDrainMs) return bucket, nil } func (b *Bucket) Encode(w io.Writer) error { - buf := &bytes.Buffer{} - b.lock.Lock() defer b.lock.Unlock() // Format version - if err := binary.Write(buf, binary.BigEndian, int32(1)); err != nil { + if err := binary.Write(w, binary.BigEndian, int32(1)); err != nil { return errors.Join(errors.New("leaky: unable to write format version"), err) } // Fields, ordered - if err := binary.Write(buf, binary.BigEndian, b.DrainBy); err != nil { + if err := binary.Write(w, binary.BigEndian, b.DrainBy); err != nil { return errors.Join(errors.New("leaky: unable to write `DrainBy`"), err) } - if err := binary.Write(buf, binary.BigEndian, b.DrainInterval); err != nil { + if err := binary.Write(w, binary.BigEndian, b.DrainInterval); err != nil { return errors.Join(errors.New("leaky: unable to write `DrainInterval`"), err) } - if err := binary.Write(buf, binary.BigEndian, b.Capacity); err != nil { + if err := binary.Write(w, binary.BigEndian, b.Capacity); err != nil { return errors.Join(errors.New("leaky: unable to write `Capacity`"), err) } - if err := binary.Write(buf, binary.BigEndian, b.value); err != nil { + if err := binary.Write(w, binary.BigEndian, b.value); err != nil { return errors.Join(errors.New("leaky: unable to write `value`"), err) } - if err := binary.Write(buf, binary.BigEndian, b.lastDrain.UnixMilli()); err != nil { - return errors.Join(errors.New("leaky: unable to write `lastDrain`"), err) + if timestampBytes, err := b.lastDrain.MarshalBinary(); err != nil { + return errors.Join(errors.New("leaky: unable to marshal `lastDrain`"), err) + } else { + if err := binary.Write(w, binary.BigEndian, int32(len(timestampBytes))); err != nil { + return errors.Join(errors.New("leaky: unable to write length of `lastDrain`"), err) + } + if _, err := w.Write(timestampBytes); err != nil { + return errors.Join(errors.New("leaky: unable to write `lastDrain`"), err) + } } - // Write and return - _, err := w.Write(buf.Bytes()) - return err + return nil } func (b *Bucket) drain() { diff --git a/bucket_test.go b/bucket_test.go index 39b73f7..f23d0f4 100644 --- a/bucket_test.go +++ b/bucket_test.go @@ -1,13 +1,53 @@ package leaky import ( + "bytes" "errors" + "io" "testing" "time" "github.com/stretchr/testify/assert" ) +type faultyReaderWriter struct { + io.Reader + io.Writer + + FailOnReadOp int // 1 indexed + FailOnWriteOp int // 1 indexed + Buffer *bytes.Buffer + + readOp int + writeOp int +} + +func newFaultyReaderWriter(failReadOp int, failWriteOp int) *faultyReaderWriter { + return &faultyReaderWriter{ + FailOnReadOp: failReadOp, + FailOnWriteOp: failWriteOp, + Buffer: &bytes.Buffer{}, + readOp: 0, + writeOp: 0, + } +} + +func (rw *faultyReaderWriter) Read(b []byte) (int, error) { + rw.readOp++ + if rw.readOp == rw.FailOnReadOp { + return 0, errors.New("read error") + } + return rw.Buffer.Read(b) +} + +func (rw *faultyReaderWriter) Write(b []byte) (int, error) { + rw.writeOp++ + if rw.writeOp == rw.FailOnWriteOp { + return 0, errors.New("write error") + } + return rw.Buffer.Write(b) +} + var createCaseFunctions = []func(drainBy int64, drainEvery time.Duration, capacity int64) (*Bucket, error){ func(drainBy int64, drainEvery time.Duration, capacity int64) (*Bucket, error) { return &Bucket{ @@ -53,6 +93,112 @@ func TestNewBucket(t *testing.T) { assert.Equal(t, false, bucket.lastDrain.IsZero()) // ensure we set a timestamp } +func TestBucketEncodeThenDecode(t *testing.T) { + for i, createFn := range createCaseFunctions { + bucket, err := createFn(5, time.Minute, 300) + if err != nil { + t.Errorf("TestBucketEncodeThenDecode(case:%d): unexpected error %v", i, err) + continue + } + bucket.value = 42 // force a given value + bucket.lastDrain = time.Now().Add(-1 * bucket.DrainInterval) // prepare for 1 drain operation + + // Encode + buf := &bytes.Buffer{} + if err = bucket.Encode(buf); err != nil { + t.Errorf("TestBucketEncodeThenDecode(case:%d): unexpected encode error %v", i, err) + continue + } + + // Decode + var bucket2 *Bucket + if bucket2, err = DecodeBucket(buf); err != nil { + t.Errorf("TestBucketEncodeThenDecode(case:%d): unexpected decode error %v", i, err) + continue + } + assert.NotEqualf(t, bucket, bucket2, "TestBucketEncodeThenDecode(case:%d)", i) + assert.Equalf(t, bucket.DrainBy, bucket2.DrainBy, "TestBucketEncodeThenDecode(case:%d)", i) + assert.Equalf(t, bucket.DrainInterval, bucket2.DrainInterval, "TestBucketEncodeThenDecode(case:%d)", i) + assert.Equalf(t, bucket.Capacity, bucket2.Capacity, "TestBucketEncodeThenDecode(case:%d)", i) + assert.Equalf(t, bucket.value, bucket2.value, "TestBucketEncodeThenDecode(case:%d)", i) + assert.Equalf(t, 0, bucket2.lastDrain.Compare(bucket.lastDrain), "TestBucketEncodeThenDecode(case:%d)", i) + assert.Equalf(t, bucket.lastDrain.UnixNano(), bucket2.lastDrain.UnixNano(), "TestBucketEncodeThenDecode(case:%d)", i) + } +} + +func TestBucket_Encode(t *testing.T) { + for i, createFn := range createCaseFunctions { + bucket, err := createFn(5, time.Minute, 300) + if err != nil { + t.Errorf("TestBucket_Encode(case:%d): unexpected error %v", i, err) + continue + } + bucket.value = 42 // force a given value + bucket.lastDrain = time.Now().Add(-1 * bucket.DrainInterval) // prepare for 1 drain operation + + errorMessages := []string{ + "leaky: unable to write format version", + "leaky: unable to write `DrainBy`", + "leaky: unable to write `DrainInterval`", + "leaky: unable to write `Capacity`", + "leaky: unable to write `value`", + //"leaky: unable to marshal `lastDrain`", + "leaky: unable to write length of `lastDrain`", + "leaky: unable to write `lastDrain`", + } + for j, message := range errorMessages { + rw := newFaultyReaderWriter(j+1, j+1) + if err = bucket.Encode(rw); err != nil { + assert.ErrorContainsf(t, err, message, "TestBucket_Encode(case:%d,msg:%d)", i, j) + } else { + t.Errorf("TestBucket_Encode(case:%d,msg:%d): expected error %s", i, j, message) + } + } + } +} + +func TestBucket_Decode(t *testing.T) { + for i, createFn := range createCaseFunctions { + bucket, err := createFn(5, time.Minute, 300) + if err != nil { + t.Errorf("TestBucket_Decode(case:%d): unexpected error %v", i, err) + continue + } + bucket.value = 42 // force a given value + bucket.lastDrain = time.Now().Add(-1 * bucket.DrainInterval) // prepare for 1 drain operation + + buf := &bytes.Buffer{} + if err = bucket.Encode(buf); err != nil { + t.Errorf("TestBucket_Decode(case:%d): unexpected error %v", i, err) + continue + } + + errorMessages := []string{ + "leaky: unable to read format version", + //"leaky: unsupported format version %d", + "leaky: unable to read `DrainBy`", + "leaky: unable to read `DrainInterval`", + "leaky: unable to read `Capacity`", + "leaky: unable to read `value`", + "leaky: unable to read size of `lastDrain`", + "leaky: unable to read `lastDrain`", + //"leaky: did not read entire timestamp", + //"leaky: unable to unmarshal `lastDrain`", + } + for j, message := range errorMessages { + rw := newFaultyReaderWriter(j+1, j+1) + rw.Buffer = bytes.NewBuffer(buf.Bytes()) + bucket2, err := DecodeBucket(rw) + assert.Nilf(t, bucket2, "TestBucket_Decode(case:%d,msg:%d)", i, j) + if err != nil { + assert.ErrorContainsf(t, err, message, "TestBucket_Decode(case:%d,msg:%d)", i, j) + } else { + t.Errorf("TestBucket_Decode(case:%d,msg:%d): expected error %s", i, j, message) + } + } + } +} + func TestBucket_drain(t *testing.T) { for i, createFn := range createCaseFunctions { bucket, err := createFn(5, time.Minute, 300)