From c878cfbf66947f3c669ebbd21b617175c26c36d6 Mon Sep 17 00:00:00 2001 From: Ilia Sergunin Date: Sat, 3 Feb 2024 14:15:10 +0400 Subject: [PATCH] 1. Fix redis goroutine leak --- .github/workflows/main.yaml | 1 + gorm/example_test.go | 4 ++ gorm/transaction.go | 22 +++++------ gorm/transaction_test.go | 3 +- redis/settings.go | 35 ++++++++++++++--- redis/transaction.go | 76 ++++++++++++++++++++++++------------- redis/transaction_test.go | 54 +++++++++++++++++++++----- redis/watcher.go | 9 ++++- sql/transaction_test.go | 3 +- sqlx/transaction_test.go | 3 +- 10 files changed, 153 insertions(+), 57 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 6c327e0..9fed4be 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -5,6 +5,7 @@ on: pull_request: branches: - main + - v1 name: Test env: GO_TARGET_VERSION: 1.21 diff --git a/gorm/example_test.go b/gorm/example_test.go index be97dac..73c44ec 100644 --- a/gorm/example_test.go +++ b/gorm/example_test.go @@ -20,7 +20,11 @@ import ( // Example demonstrates the implementation of the Repository pattern by trm.Manager. func Example() { db, err := gorm.Open(sqlite.Open("file:test.db?mode=memory")) + + checkErr(err) + sqlDB, err := db.DB() checkErr(err) + defer sqlDB.Close() // Migrate the schema checkErr(db.AutoMigrate(&userRow{})) diff --git a/gorm/transaction.go b/gorm/transaction.go index fbe2874..170d2a5 100644 --- a/gorm/transaction.go +++ b/gorm/transaction.go @@ -30,7 +30,7 @@ func NewTransaction( opts *sql.TxOptions, db *gorm.DB, ) (context.Context, *Transaction, error) { - tr := &Transaction{ + t := &Transaction{ tx: nil, txMutex: sync.Mutex{}, active: drivers.NewIsClosed(), @@ -46,28 +46,28 @@ func NewTransaction( db = db.WithContext(ctx) // Used closure to avoid implementing nested transactions. err = db.Transaction(func(tx *gorm.DB) error { - tr.tx = tx + t.tx = tx wg.Done() - <-tr.activeClosure.Closed() + <-t.activeClosure.Closed() - return tr.activeClosure.Err() + return t.activeClosure.Err() }, opts) - tr.txMutex.Lock() - defer tr.txMutex.Unlock() - tx := tr.tx + t.txMutex.Lock() + defer t.txMutex.Unlock() + tx := t.tx if tx != nil { // Return error from transaction rollback // Error from commit returns from db.Transaction closure if errors.Is(err, drivers.ErrRollbackTr) && tx.Error != nil { - err = tr.tx.Error + err = t.tx.Error } - tr.active.CloseWithCause(err) + t.active.CloseWithCause(err) } else { wg.Done() } @@ -79,9 +79,9 @@ func NewTransaction( return ctx, nil, err } - go tr.awaitDone(ctx) + go t.awaitDone(ctx) - return ctx, tr, nil + return ctx, t, nil } func (t *Transaction) awaitDone(ctx context.Context) { diff --git a/gorm/transaction_test.go b/gorm/transaction_test.go index b66c226..bcb8c57 100644 --- a/gorm/transaction_test.go +++ b/gorm/transaction_test.go @@ -272,7 +272,7 @@ func TestTransaction_awaitDone_byRollback(t *testing.T) { require.NoError(t, err) f := NewDefaultFactory(dbgorm) - ctx := context.Background() + ctx, _ := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(1) @@ -284,6 +284,7 @@ func TestTransaction_awaitDone_byRollback(t *testing.T) { require.NoError(t, tr.Rollback(ctx)) require.False(t, tr.IsActive()) + require.ErrorIs(t, tr.Rollback(ctx), sql.ErrTxDone) }() wg.Wait() diff --git a/redis/settings.go b/redis/settings.go index f834c60..5a88f49 100644 --- a/redis/settings.go +++ b/redis/settings.go @@ -1,6 +1,8 @@ package redis import ( + "sync" + "github.com/go-redis/redis/v8" "github.com/avito-tech/go-transaction-manager/trm" @@ -20,7 +22,9 @@ type Settings struct { isMulti *bool watchKeys []string txDecorator []TxDecorator - ret *[]redis.Cmder + + ret *[]redis.Cmder + muRet sync.RWMutex } // NewSettings creates Settings. @@ -31,6 +35,7 @@ func NewSettings(trms trm.Settings, oo ...Opt) (Settings, error) { watchKeys: nil, txDecorator: nil, ret: nil, + muRet: sync.RWMutex{}, } for _, o := range oo { @@ -68,8 +73,8 @@ func (s Settings) EnrichBy(in trm.Settings) trm.Settings { s = s.SetTxDecorators(external.TxDecorators()...) } - if s.Return() == nil { - s = s.SetReturn(external.Return()) + if s.ReturnPtr() == nil { + s = s.SetReturn(external.ReturnPtr()) } } @@ -135,11 +140,31 @@ func (s Settings) setTxDecorator(in ...TxDecorator) Settings { return s } -// Return returns []redis.Cmder from Transaction. -func (s Settings) Return() *[]redis.Cmder { +func (s Settings) ReturnPtr() *[]redis.Cmder { + s.muRet.RLock() + defer s.muRet.RUnlock() + return s.ret } +// Return returns []redis.Cmder from Transaction. +func (s Settings) Return() []redis.Cmder { + res := s.ReturnPtr() + if res != nil { + return *s.ReturnPtr() + } + + return nil +} + +// AppendReturn append []redis.Cmder from Transaction. +func (s *Settings) AppendReturn(cmds ...redis.Cmder) { + s.muRet.Lock() + defer s.muRet.Unlock() + + *s.ret = append(*s.ret, cmds...) +} + // SetReturn sets link to save []redis.Cmder from Transaction. func (s Settings) SetReturn(in *[]redis.Cmder) Settings { return s.setReturn(in) diff --git a/redis/transaction.go b/redis/transaction.go index 4f18541..8b5bc48 100644 --- a/redis/transaction.go +++ b/redis/transaction.go @@ -18,10 +18,9 @@ type TxDecorator func(tx Cmdable, db redis.Cmdable) Cmdable // Transaction is trm.Transaction for sqlx.Tx. type Transaction struct { - tx Cmdable - // err is used to close transaction and get error from it - err chan error - active *drivers.IsClose + tx txInterface + active *drivers.IsClose + activeClosure *drivers.IsClose } // NewTransaction creates trm.Transaction for sqlx.Tx. @@ -31,9 +30,9 @@ func NewTransaction( s Settings, ) (context.Context, *Transaction, error) { t := &Transaction{ - err: make(chan error), - tx: nil, - active: drivers.NewIsClosed(), + tx: nil, + active: drivers.NewIsClosed(), + activeClosure: drivers.NewIsClosed(), } var err error @@ -52,8 +51,8 @@ func NewTransaction( cmds, err = fn(ctx, func(pipe redis.Pipeliner) error { t.tx = &tx{ - tx: rtx, - Cmdable: pipe, + tx: rtx, + Pipeliner: pipe, } for _, d := range s.TxDecorators() { @@ -62,18 +61,20 @@ func NewTransaction( wg.Done() - return <-t.err + <-t.activeClosure.Closed() + + return t.activeClosure.Err() }) - if len(cmds) > 0 && s.Return() != nil { - *s.Return() = append(*s.Return(), cmds...) + if len(cmds) > 0 && s.ReturnPtr() != nil { + s.AppendReturn(cmds...) } return err }, s.WatchKeys()...) if t.tx != nil { - t.err <- err + t.active.CloseWithCause(err) } else { wg.Done() } @@ -97,7 +98,8 @@ func (t *Transaction) awaitDone(ctx context.Context) { select { case <-ctx.Done(): - t.active.Close() + // Rollback will be called by context.Err() + t.activeClosure.Close() case <-t.active.Closed(): } } @@ -109,32 +111,52 @@ func (t *Transaction) Transaction() interface{} { } // Commit closes the trm.Transaction. -func (t *Transaction) Commit(_ context.Context) error { - defer t.active.Close() +func (t *Transaction) Commit(ctx context.Context) error { + select { + case <-t.active.Closed(): + cmds, err := t.tx.Exec(ctx) + + // TODO process cmds + _ = cmds - // TODO deadlock - t.err <- nil + return err + default: + t.activeClosure.Close() - return <-t.err + <-t.active.Closed() + + return t.active.Err() + } } // Rollback the trm.Transaction. func (t *Transaction) Rollback(_ context.Context) error { - defer t.active.Close() + select { + case <-t.active.Closed(): + return t.tx.Discard() + default: + t.activeClosure.CloseWithCause(drivers.ErrRollbackTr) - // TODO deadlock - t.err <- errRollbackTx + <-t.active.Closed() - err := <-t.err + err := t.active.Err() + if errors.Is(err, drivers.ErrRollbackTr) { + return nil + } - if errors.Is(err, errRollbackTx) { - return nil - } + // unreachable code, because of go-redis doesn't process error from Close + // https://github.com/redis/go-redis/blob/v8.11.5/tx.go#L69 + // https://github.com/redis/go-redis/blob/v8.11.5/pipeline.go#L130 - return err + return err + } } // IsActive returns true if the transaction started but not committed or rolled back. func (t *Transaction) IsActive() bool { return t.active.IsActive() } + +func (t *Transaction) Closed() <-chan struct{} { + return t.active.Closed() +} diff --git a/redis/transaction_test.go b/redis/transaction_test.go index 370b9ad..13482bc 100644 --- a/redis/transaction_test.go +++ b/redis/transaction_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/go-redis/redis/v8" "github.com/go-redis/redismock/v8" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -43,12 +44,13 @@ func TestTransaction(t *testing.T) { testExp := time.Duration(0) tests := map[string]struct { - prepare func(t *testing.T, m redismock.ClientMock) - args args - ret error - wantErr assert.ErrorAssertionFunc + prepare func(t *testing.T, m redismock.ClientMock) + args args + ret error + wantErr assert.ErrorAssertionFunc + wantCmds int }{ - "success": { + "commit": { prepare: func(t *testing.T, m redismock.ClientMock) { m.ExpectWatch(testKey) m.ExpectTxPipeline() @@ -60,8 +62,9 @@ func TestTransaction(t *testing.T) { args: args{ ctx: ctx, }, - ret: nil, - wantErr: assert.NoError, + ret: nil, + wantErr: assert.NoError, + wantCmds: 1, }, "begin_error": { prepare: func(t *testing.T, m redismock.ClientMock) {}, @@ -85,10 +88,12 @@ func TestTransaction(t *testing.T) { args: args{ ctx: ctx, }, + ret: nil, wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { return assert.ErrorContains(t, err, "redis: nil") && assert.ErrorIs(t, err, trm.ErrCommit) }, + wantCmds: 1, }, "rollback": { prepare: func(t *testing.T, m redismock.ClientMock) { @@ -115,7 +120,7 @@ func TestTransaction(t *testing.T) { s := MustSettings(settings.Must( settings.WithPropagation(trm.PropagationNested), - ), WithWatchKeys(testKey)) + ), WithWatchKeys(testKey), WithRet(&[]redis.Cmder{})) m := manager.Must( NewDefaultFactory(db), manager.WithLog(log), @@ -154,6 +159,8 @@ func TestTransaction(t *testing.T) { if !tt.wantErr(t, err) { return } + + assert.Len(t, s.Return(), tt.wantCmds) assert.NoError(t, rmock.ExpectationsWereMet()) }) } @@ -176,12 +183,41 @@ func TestTransaction_awaitDone_byContext(t *testing.T) { _, tr, err := f(ctx, settings.Must()) cancel() - <-time.After(time.Second) <-ctx.Done() + require.True(t, tr.IsActive()) + <-tr.Closed() + require.False(t, tr.IsActive()) + + require.Equal(t, context.Canceled, ctx.Err()) + err = tr.Commit(ctx) + require.ErrorIs(t, err, redis.ErrClosed) + }() + wg.Wait() + assert.NoError(t, rmock.ExpectationsWereMet()) +} + +// TestTransaction_awaitDone_byRollback checks goroutine leak when we close transaction manually. +func TestTransaction_awaitDone_byRollback(t *testing.T) { + t.Parallel() + + db, rmock := redismock.NewClientMock() + + f := NewDefaultFactory(db) + ctx, _ := context.WithCancel(context.Background()) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + _, tr, err := f(ctx, settings.Must()) require.NoError(t, err) + + require.NoError(t, tr.Rollback(ctx)) require.False(t, tr.IsActive()) + require.NoError(t, tr.Rollback(ctx)) }() wg.Wait() diff --git a/redis/watcher.go b/redis/watcher.go index 25ed8c2..53d8628 100644 --- a/redis/watcher.go +++ b/redis/watcher.go @@ -8,8 +8,8 @@ import ( // Cmdable is an experimental interface to Watch and Unwatch keys in Transaction. type Cmdable interface { - redis.Cmdable Watch + redis.Pipeliner } // Watch is experimental functional for watching updated keys. @@ -20,10 +20,15 @@ type Watch interface { } type tx struct { - redis.Cmdable + redis.Pipeliner tx *redis.Tx } +type txInterface interface { + redis.Pipeliner + Watch +} + func (t *tx) Watch(ctx context.Context, keys ...string) *redis.StatusCmd { return t.tx.Watch(ctx, keys...) } diff --git a/sql/transaction_test.go b/sql/transaction_test.go index bb26812..6810523 100644 --- a/sql/transaction_test.go +++ b/sql/transaction_test.go @@ -270,7 +270,7 @@ func TestTransaction_awaitDone_byRollback(t *testing.T) { }) f := NewDefaultFactory(db) - ctx := context.Background() + ctx, _ := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(1) @@ -282,6 +282,7 @@ func TestTransaction_awaitDone_byRollback(t *testing.T) { require.NoError(t, tr.Rollback(ctx)) require.False(t, tr.IsActive()) + require.ErrorIs(t, tr.Rollback(ctx), sql.ErrTxDone) }() wg.Wait() diff --git a/sqlx/transaction_test.go b/sqlx/transaction_test.go index 3109229..111722f 100644 --- a/sqlx/transaction_test.go +++ b/sqlx/transaction_test.go @@ -273,7 +273,7 @@ func TestTransaction_awaitDone_byRollback(t *testing.T) { }) f := NewDefaultFactory(sqlx.NewDb(db, "sqlmock")) - ctx := context.Background() + ctx, _ := context.WithCancel(context.Background()) wg := sync.WaitGroup{} wg.Add(1) @@ -285,6 +285,7 @@ func TestTransaction_awaitDone_byRollback(t *testing.T) { require.NoError(t, tr.Rollback(ctx)) require.False(t, tr.IsActive()) + require.ErrorIs(t, tr.Rollback(ctx), sql.ErrTxDone) }() wg.Wait()