diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index bbd3af9743..2b72d3efd8 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -801,10 +801,19 @@ func (s *Server) check() (description.Server, error) { if s.conn == nil || s.conn.closed() || s.checkWasCancelled() { // Create a new connection and add it's handshake RTT as a sample. err = s.setupHeartbeatConnection() + duration = time.Since(start) if err == nil { // Use the description from the connection handshake as the value for this check. s.rttMonitor.addSample(s.conn.helloRTT) descPtr = &s.conn.desc + if s.conn != nil { + s.publishServerHeartbeatSucceededEvent(s.conn.ID(), duration, s.conn.desc, false) + } + } else { + err = unwrapConnectionError(err) + if s.conn != nil { + s.publishServerHeartbeatFailedEvent(s.conn.ID(), duration, err, false) + } } } else { // An existing connection is being used. Use the server description properties to execute the right heartbeat. diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index e42caff804..3078f2eeb3 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -56,35 +56,32 @@ type errorQueue struct { mutex sync.Mutex } -func (eq *errorQueue) head() (int, error) { +func (eq *errorQueue) head() error { eq.mutex.Lock() defer eq.mutex.Unlock() - if l := len(eq.errors); l > 0 { - return l, eq.errors[0] + if len(eq.errors) > 0 { + return eq.errors[0] } - return 0, nil + return nil } -func (eq *errorQueue) dequeue() { +func (eq *errorQueue) dequeue() bool { eq.mutex.Lock() defer eq.mutex.Unlock() if len(eq.errors) > 0 { eq.errors = eq.errors[1:] + return true } + return false } type timeoutConn struct { net.Conn errors *errorQueue - ch chan int } func (c *timeoutConn) Read(b []byte) (int, error) { - var n int - l, err := c.errors.head() - defer func(l int) { - c.ch <- l - }(l) + n, err := 0, c.errors.head() if err == nil { n, err = c.Conn.Read(b) } @@ -92,11 +89,7 @@ func (c *timeoutConn) Read(b []byte) (int, error) { } func (c *timeoutConn) Write(b []byte) (int, error) { - var n int - l, err := c.errors.head() - defer func(l int) { - c.ch <- l - }(l) + n, err := 0, c.errors.head() if err == nil { n, err = c.Conn.Write(b) } @@ -106,7 +99,6 @@ func (c *timeoutConn) Write(b []byte) (int, error) { type timeoutDialer struct { Dialer errors *errorQueue - ch chan int } func (d *timeoutDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { @@ -129,7 +121,7 @@ func (d *timeoutDialer) DialContext(ctx context.Context, network, address string } c = tls.Client(c, config) } - return &timeoutConn{c, d.errors, d.ch}, e + return &timeoutConn{c, d.errors}, e } // TestServerHeartbeatTimeout tests timeout retry for GODRIVER-2577. @@ -145,19 +137,16 @@ func TestServerHeartbeatTimeout(t *testing.T) { testCases := []struct { desc string ioErrors []error - len int expectPoolCleared bool }{ { desc: "one single timeout should not clear the pool", ioErrors: []error{nil, networkTimeoutError, nil, networkTimeoutError, nil}, - len: 0, expectPoolCleared: false, }, { desc: "continuous timeouts should clear the pool", - ioErrors: []error{nil, networkTimeoutError, networkTimeoutError}, - len: 1, + ioErrors: []error{nil, networkTimeoutError, networkTimeoutError, nil}, expectPoolCleared: true, }, } @@ -166,9 +155,8 @@ func TestServerHeartbeatTimeout(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { t.Parallel() - const heartbeatInterval = 200 * time.Millisecond - - c := make(chan int) + var wg sync.WaitGroup + wg.Add(1) errors := &errorQueue{errors: tc.ioErrors} tpm := eventtest.NewTestPoolMonitor() @@ -182,40 +170,29 @@ func TestServerHeartbeatTimeout(t *testing.T) { return append(opts, WithDialer(func(d Dialer) Dialer { var dialer net.Dialer - return &timeoutDialer{&dialer, errors, c} + return &timeoutDialer{&dialer, errors} })) }), WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return &event.ServerMonitor{ ServerHeartbeatSucceeded: func(e *event.ServerHeartbeatSucceededEvent) { - errors.dequeue() + if !errors.dequeue() { + wg.Done() + } }, ServerHeartbeatFailed: func(e *event.ServerHeartbeatFailedEvent) { - errors.dequeue() + if !errors.dequeue() { + wg.Done() + } }, } }), WithHeartbeatInterval(func(time.Duration) time.Duration { - return heartbeatInterval + return 200 * time.Millisecond }), ) require.NoError(t, server.Connect(nil)) - - timeout := time.After(50 * heartbeatInterval) - var l int - loop: - for { - select { - case l = <-c: - if l == 0 || tpm.IsPoolCleared() { - break loop - } - case <-timeout: - assert.Fail(t, "timeout") - break loop - } - } - assert.Equal(t, tc.len, l, "pool has been cleared unexpectedly") + wg.Wait() assert.Equal(t, tc.expectPoolCleared, tpm.IsPoolCleared(), "expected pool cleared to be %v but was %v", tc.expectPoolCleared, tpm.IsPoolCleared()) }) }