diff --git a/neo4j/session.go b/neo4j/session.go index a06a8536..0e1fd7ce 100644 --- a/neo4j/session.go +++ b/neo4j/session.go @@ -259,7 +259,7 @@ func (s *session) runRetriable( DatabaseName: s.databaseName, } for state.Continue() { - if workResult := s.tryRun(&state, mode, &config, work); workResult != nil { + if workResult, successfullyCompleted := s.tryRun(&state, mode, &config, work); successfullyCompleted { return workResult, nil } } @@ -292,12 +292,12 @@ func (s *session) WriteTransaction( return s.runRetriable(db.WriteMode, work, configurers...) } -func (s *session) tryRun(state *retry.State, mode db.AccessMode, config *TransactionConfig, work TransactionWork) interface{} { +func (s *session) tryRun(state *retry.State, mode db.AccessMode, config *TransactionConfig, work TransactionWork) (interface{}, bool) { // Establish new connection conn, err := s.getConnection(mode) if err != nil { state.OnFailure(conn, err, false) - return nil + return nil, false } defer s.pool.Return(conn) txHandle, err := conn.TxBegin(db.TxConfig{ @@ -308,7 +308,7 @@ func (s *session) tryRun(state *retry.State, mode db.AccessMode, config *Transac }) if err != nil { state.OnFailure(conn, err, false) - return nil + return nil, false } tx := retryableTransaction{conn: conn, fetchSize: s.fetchSize, txHandle: txHandle} @@ -321,17 +321,17 @@ func (s *session) tryRun(state *retry.State, mode db.AccessMode, config *Transac // but instead rely on pool invoking reset on the connection, that // will do an implicit rollback. state.OnFailure(conn, err, false) - return nil + return nil, false } err = conn.TxCommit(txHandle) if err != nil { state.OnFailure(conn, err, true) - return nil + return nil, false } s.retrieveBookmarks(conn) - return x + return x, true } func (s *session) getServers(ctx context.Context, mode db.AccessMode) ([]string, error) { diff --git a/neo4j/session_test.go b/neo4j/session_test.go index f846f5b7..59ddbf20 100644 --- a/neo4j/session_test.go +++ b/neo4j/session_test.go @@ -152,6 +152,22 @@ func TestSession(st *testing.T) { }() _, _ = newSession.WriteTransaction(transactionFunction) }) + + rt.Run("tx function panic returns conn to pool and bubbles up", func(t *testing.T) { + _, pool, newSession := createSession() + pool.BorrowConn = &ConnFake{Alive: true} + transactionFunction := func(Transaction) (interface{}, error) { + return nil, nil + } + + result, err := newSession.WriteTransaction(transactionFunction) + if result != nil { + t.Errorf("expected nil result") + } + if err != nil { + t.Errorf("expected nil error") + } + }) }) st.Run("Bookmarking", func(bt *testing.T) {