Skip to content

Commit

Permalink
Fix race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Nov 17, 2023
1 parent 78e4d2c commit 62895ac
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
12 changes: 8 additions & 4 deletions x/mongo/driver/topology/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
// Connection state constants.
const (
connDisconnected int64 = iota
connInterrupted
connConnected
connInitialized
)
Expand Down Expand Up @@ -484,7 +485,9 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,

func (c *connection) close() error {
// Overwrite the connection state as the first step so only the first close call will execute.
if !atomic.CompareAndSwapInt64(&c.state, connConnected, connDisconnected) {
connected := atomic.CompareAndSwapInt64(&c.state, connConnected, connDisconnected)
interrupted := atomic.CompareAndSwapInt64(&c.state, connInterrupted, connDisconnected)
if !connected && !interrupted {
return nil
}

Expand All @@ -496,13 +499,14 @@ func (c *connection) close() error {
return err
}

func (c *connection) closeWithErr(err error) error {
func (c *connection) interrupt(err error) {
c.err = err
return c.close()
atomic.CompareAndSwapInt64(&c.state, connConnected, connInterrupted)
}

func (c *connection) closed() bool {
return atomic.LoadInt64(&c.state) == connDisconnected
state := atomic.LoadInt64(&c.state)
return state == connDisconnected || state == connInterrupted
}

func (c *connection) idleTimeoutExpired() bool {
Expand Down
12 changes: 8 additions & 4 deletions x/mongo/driver/topology/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,9 @@ func (p *pool) checkInWithCallback(conn *connection, cb func() (reason, bool)) e
if conn.pool != p {
return ErrWrongPool
}
if atomic.LoadInt64(&conn.state) == connInterrupted {
return nil
}

conn.inUse = false

Expand Down Expand Up @@ -868,10 +871,6 @@ func (p *pool) clearAll(err error, serviceID *primitive.ObjectID) {
}
for _, conn := range p.conns {
if conn.inUse && p.stale(conn) {
_ = conn.closeWithErr(poolClearedError{
err: fmt.Errorf("interrupted"),
address: p.address,
})
_ = p.checkInWithCallback(conn, func() (reason, bool) {
if mustLogPoolMessage(p) {
keysAndValues := logger.KeyValues{
Expand All @@ -889,6 +888,11 @@ func (p *pool) clearAll(err error, serviceID *primitive.ObjectID) {
})
}

conn.interrupt(poolClearedError{
err: fmt.Errorf("interrupted"),
address: p.address,
})

r, ok := connectionPerished(conn)
if ok {
r = reason{
Expand Down

0 comments on commit 62895ac

Please sign in to comment.