From b565a23108ee970b71368ed551412ccd5f21d9a8 Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Fri, 23 Aug 2024 17:56:43 -0400 Subject: [PATCH] code cleanup --- x/mongo/driver/topology/connection.go | 161 ++++++++++++++------------ x/mongo/driver/topology/pool.go | 22 +--- x/mongo/driver/topology/pool_test.go | 2 +- 3 files changed, 91 insertions(+), 94 deletions(-) diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index b99ea53536..d0dfe08789 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -79,9 +79,9 @@ type connection struct { driverConnectionID uint64 generation uint64 - // awaitingResponse indicates the size of server response that was not completely + // awaitRemainingBytes indicates the size of server response that was not completely // read before returning the connection to the pool. - awaitingResponse *int32 + awaitRemainingBytes *int32 // oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate // accessTokens in the OIDC authenticator cache. @@ -115,12 +115,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection { return c } -// DriverConnectionID returns the driver connection ID. -// TODO(GODRIVER-2824): change return type to int64. -func (c *connection) DriverConnectionID() uint64 { - return c.driverConnectionID -} - // setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection // configuration. func (c *connection) setGenerationNumber() { @@ -142,6 +136,39 @@ func (c *connection) hasGenerationNumber() bool { return c.desc.LoadBalanced() } +func configureTLS(ctx context.Context, + tlsConnSource tlsConnectionSource, + nc net.Conn, + addr address.Address, + config *tls.Config, + ocspOpts *ocsp.VerifyOptions, +) (net.Conn, error) { + // Ensure config.ServerName is always set for SNI. + if config.ServerName == "" { + hostname := addr.String() + colonPos := strings.LastIndex(hostname, ":") + if colonPos == -1 { + colonPos = len(hostname) + } + + hostname = hostname[:colonPos] + config.ServerName = hostname + } + + client := tlsConnSource.Client(nc, config) + if err := clientHandshake(ctx, client); err != nil { + return nil, err + } + + // Only do OCSP verification if TLS verification is requested. + if !config.InsecureSkipVerify { + if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil { + return nil, ocspErr + } + } + return client, nil +} + // connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization // handshakes. All errors returned by connect are considered "before the handshake completes" and // must be handled by calling the appropriate SDAM handshake error handler. @@ -317,6 +344,10 @@ func (c *connection) closeConnectContext() { } } +func (c *connection) cancellationListenerCallback() { + _ = c.close() +} + func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error { if originalError == nil { return nil @@ -339,10 +370,6 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead return originalError } -func (c *connection) cancellationListenerCallback() { - _ = c.close() -} - func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error { var err error if atomic.LoadInt64(&c.state) != connConnected { @@ -423,7 +450,7 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) { dst, errMsg, err := c.read(ctx) if err != nil { - if c.awaitingResponse == nil { + if c.awaitRemainingBytes == nil { // If the connection was not marked as awaiting response, use the // pre-CSOT behavior and close the connection because we don't know // if there are other bytes left to read. @@ -443,6 +470,29 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) { return dst, nil } +func (c *connection) parseWmSizeBytes(wmSizeBytes [4]byte) (int32, error) { + // read the length as an int32 + size := (int32(wmSizeBytes[0])) | + (int32(wmSizeBytes[1]) << 8) | + (int32(wmSizeBytes[2]) << 16) | + (int32(wmSizeBytes[3]) << 24) + + if size < 4 { + return 0, fmt.Errorf("malformed message length: %d", size) + } + // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded + // defaultMaxMessageSize instead. + maxMessageSize := c.desc.MaxMessageSize + if maxMessageSize == 0 { + maxMessageSize = defaultMaxMessageSize + } + if uint32(size) > maxMessageSize { + return 0, errResponseTooLarge + } + + return size, nil +} + func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, err error) { go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback) defer func() { @@ -475,35 +525,23 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, n, err := io.ReadFull(c.nc, sizeBuf[:]) if err != nil { if l := int32(n); l == 0 && needToWait(err) { - c.awaitingResponse = &l + c.awaitRemainingBytes = &l } return nil, "incomplete read of message header", err } - - // read the length as an int32 - size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24) - - if size < 4 { - err = fmt.Errorf("malformed message length: %d", size) + size, err := c.parseWmSizeBytes(sizeBuf) + if err != nil { return nil, err.Error(), err } - // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded - // defaultMaxMessageSize instead. - maxMessageSize := c.desc.MaxMessageSize - if maxMessageSize == 0 { - maxMessageSize = defaultMaxMessageSize - } - if uint32(size) > maxMessageSize { - return nil, errResponseTooLarge.Error(), errResponseTooLarge - } dst := make([]byte, size) copy(dst, sizeBuf[:]) n, err = io.ReadFull(c.nc, dst[4:]) if err != nil { - if l := size - 4 - int32(n); l > 0 && needToWait(err) { - c.awaitingResponse = &l + remainingBytes := size - 4 - int32(n) + if remainingBytes > 0 && needToWait(err) { + c.awaitRemainingBytes = &remainingBytes } return dst, "incomplete read of full message", err } @@ -551,10 +589,6 @@ func (c *connection) setCanStream(canStream bool) { c.canStream = canStream } -func (c initConnection) supportsStreaming() bool { - return c.canStream -} - func (c *connection) setStreaming(streaming bool) { c.currentlyStreaming = streaming } @@ -568,6 +602,12 @@ func (c *connection) setSocketTimeout(timeout time.Duration) { c.writeTimeout = timeout } +// DriverConnectionID returns the driver connection ID. +// TODO(GODRIVER-2824): change return type to int64. +func (c *connection) DriverConnectionID() uint64 { + return c.driverConnectionID +} + func (c *connection) ID() string { return c.id } @@ -576,6 +616,14 @@ func (c *connection) ServerConnectionID() *int64 { return c.serverConnectionID } +func (c *connection) OIDCTokenGenID() uint64 { + return c.oidcTokenGenID +} + +func (c *connection) SetOIDCTokenGenID(genID uint64) { + c.oidcTokenGenID = genID +} + // initConnection is an adapter used during connection initialization. It has the minimum // functionality necessary to implement the driver.Connection interface, which is required to pass a // *connection to a Handshaker. @@ -613,7 +661,7 @@ func (c initConnection) CurrentlyStreaming() bool { return c.getCurrentlyStreaming() } func (c initConnection) SupportsStreaming() bool { - return c.supportsStreaming() + return c.canStream } // Connection implements the driver.Connection interface to allow reading and writing wire @@ -847,39 +895,6 @@ func (c *Connection) DriverConnectionID() uint64 { return c.connection.DriverConnectionID() } -func configureTLS(ctx context.Context, - tlsConnSource tlsConnectionSource, - nc net.Conn, - addr address.Address, - config *tls.Config, - ocspOpts *ocsp.VerifyOptions, -) (net.Conn, error) { - // Ensure config.ServerName is always set for SNI. - if config.ServerName == "" { - hostname := addr.String() - colonPos := strings.LastIndex(hostname, ":") - if colonPos == -1 { - colonPos = len(hostname) - } - - hostname = hostname[:colonPos] - config.ServerName = hostname - } - - client := tlsConnSource.Client(nc, config) - if err := clientHandshake(ctx, client); err != nil { - return nil, err - } - - // Only do OCSP verification if TLS verification is requested. - if !config.InsecureSkipVerify { - if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil { - return nil, ocspErr - } - } - return client, nil -} - // OIDCTokenGenID returns the OIDC token generation ID. func (c *Connection) OIDCTokenGenID() uint64 { return c.oidcTokenGenID @@ -933,11 +948,3 @@ func (c *cancellListener) StopListening() bool { c.done <- struct{}{} return c.aborted } - -func (c *connection) OIDCTokenGenID() uint64 { - return c.oidcTokenGenID -} - -func (c *connection) SetOIDCTokenGenID(genID uint64) { - c.oidcTokenGenID = genID -} diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index b4001cb17a..5d232f1ebc 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -10,7 +10,6 @@ import ( "context" "fmt" "io" - "io/ioutil" "net" "sync" "sync/atomic" @@ -833,22 +832,13 @@ func bgRead(pool *pool, conn *connection, size int32) { err = fmt.Errorf("error reading the message size: %w", err) return } - size = (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24) - if size < 4 { - err = fmt.Errorf("malformed message length: %d", size) - return - } - maxMessageSize := conn.desc.MaxMessageSize - if maxMessageSize == 0 { - maxMessageSize = defaultMaxMessageSize - } - if uint32(size) > maxMessageSize { - err = errResponseTooLarge + size, err = conn.parseWmSizeBytes(sizeBuf) + if err != nil { return } size -= 4 } - _, err = io.CopyN(ioutil.Discard, conn.nc, int64(size)) + _, err = io.CopyN(io.Discard, conn.nc, int64(size)) if err != nil { err = fmt.Errorf("error reading message of %d: %w", size, err) } @@ -901,9 +891,9 @@ func (p *pool) checkInNoEvent(conn *connection) error { // means that connections in "awaiting response" state are checked in but // not usable, which is not covered by the current pool events. We may need // to add pool event information in the future to communicate that. - if conn.awaitingResponse != nil { - size := *conn.awaitingResponse - conn.awaitingResponse = nil + if conn.awaitRemainingBytes != nil { + size := *conn.awaitRemainingBytes + conn.awaitRemainingBytes = nil go bgRead(p, conn, size) return nil } diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index 35185e954c..ebb342e17c 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -1197,7 +1197,7 @@ func TestPool(t *testing.T) { `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read unix .*->\.\/test.sock: i\/o timeout$`, ) assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) - assert.Nil(t, conn.awaitingResponse, "conn.awaitingResponse should be nil") + assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitingResponse should be nil") wg.Wait() p.close(context.Background()) close(errCh)