diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index f427c1c1e7..354efa2805 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -29,6 +29,7 @@ import ( // Connection state constants. const ( connDisconnected int64 = iota + connInterrupted connConnected connInitialized ) @@ -484,7 +485,9 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, func (c *connection) close() error { // Overwrite the connection state as the first step so only the first close call will execute. - if !atomic.CompareAndSwapInt64(&c.state, connConnected, connDisconnected) { + connected := atomic.CompareAndSwapInt64(&c.state, connConnected, connDisconnected) + interrupted := atomic.CompareAndSwapInt64(&c.state, connInterrupted, connDisconnected) + if !connected && !interrupted { return nil } @@ -496,13 +499,14 @@ func (c *connection) close() error { return err } -func (c *connection) closeWithErr(err error) error { +func (c *connection) interrupt(err error) { c.err = err - return c.close() + atomic.CompareAndSwapInt64(&c.state, connConnected, connInterrupted) } func (c *connection) closed() bool { - return atomic.LoadInt64(&c.state) == connDisconnected + state := atomic.LoadInt64(&c.state) + return state == connDisconnected || state == connInterrupted } func (c *connection) idleTimeoutExpired() bool { diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index ff886c326b..8a0b7155bb 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -814,6 +814,9 @@ func (p *pool) checkInWithCallback(conn *connection, cb func() (reason, bool)) e if conn.pool != p { return ErrWrongPool } + if atomic.LoadInt64(&conn.state) == connInterrupted { + return nil + } conn.inUse = false @@ -868,10 +871,6 @@ func (p *pool) clearAll(err error, serviceID *primitive.ObjectID) { } for _, conn := range p.conns { if conn.inUse && p.stale(conn) { - _ = conn.closeWithErr(poolClearedError{ - err: fmt.Errorf("interrupted"), - address: p.address, - }) _ = p.checkInWithCallback(conn, func() (reason, bool) { if mustLogPoolMessage(p) { keysAndValues := logger.KeyValues{ @@ -889,6 +888,11 @@ func (p *pool) clearAll(err error, serviceID *primitive.ObjectID) { }) } + conn.interrupt(poolClearedError{ + err: fmt.Errorf("interrupted"), + address: p.address, + }) + r, ok := connectionPerished(conn) if ok { r = reason{