Skip to content

Commit

Permalink
code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Aug 23, 2024
1 parent ddbd3e9 commit b565a23
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 94 deletions.
161 changes: 84 additions & 77 deletions x/mongo/driver/topology/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ type connection struct {
driverConnectionID uint64
generation uint64

// awaitingResponse indicates the size of server response that was not completely
// awaitRemainingBytes indicates the size of server response that was not completely
// read before returning the connection to the pool.
awaitingResponse *int32
awaitRemainingBytes *int32

// oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate
// accessTokens in the OIDC authenticator cache.
Expand Down Expand Up @@ -115,12 +115,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
return c
}

// DriverConnectionID returns the driver connection ID.
// TODO(GODRIVER-2824): change return type to int64.
func (c *connection) DriverConnectionID() uint64 {
return c.driverConnectionID
}

// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
// configuration.
func (c *connection) setGenerationNumber() {
Expand All @@ -142,6 +136,39 @@ func (c *connection) hasGenerationNumber() bool {
return c.desc.LoadBalanced()
}

func configureTLS(ctx context.Context,
tlsConnSource tlsConnectionSource,
nc net.Conn,
addr address.Address,
config *tls.Config,
ocspOpts *ocsp.VerifyOptions,
) (net.Conn, error) {
// Ensure config.ServerName is always set for SNI.
if config.ServerName == "" {
hostname := addr.String()
colonPos := strings.LastIndex(hostname, ":")
if colonPos == -1 {
colonPos = len(hostname)
}

hostname = hostname[:colonPos]
config.ServerName = hostname
}

client := tlsConnSource.Client(nc, config)
if err := clientHandshake(ctx, client); err != nil {
return nil, err
}

// Only do OCSP verification if TLS verification is requested.
if !config.InsecureSkipVerify {
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
return nil, ocspErr
}
}
return client, nil
}

// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
// handshakes. All errors returned by connect are considered "before the handshake completes" and
// must be handled by calling the appropriate SDAM handshake error handler.
Expand Down Expand Up @@ -317,6 +344,10 @@ func (c *connection) closeConnectContext() {
}
}

func (c *connection) cancellationListenerCallback() {
_ = c.close()
}

func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error {
if originalError == nil {
return nil
Expand All @@ -339,10 +370,6 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead
return originalError
}

func (c *connection) cancellationListenerCallback() {
_ = c.close()
}

func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
var err error
if atomic.LoadInt64(&c.state) != connConnected {
Expand Down Expand Up @@ -423,7 +450,7 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {

dst, errMsg, err := c.read(ctx)
if err != nil {
if c.awaitingResponse == nil {
if c.awaitRemainingBytes == nil {
// If the connection was not marked as awaiting response, use the
// pre-CSOT behavior and close the connection because we don't know
// if there are other bytes left to read.
Expand All @@ -443,6 +470,29 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
return dst, nil
}

func (c *connection) parseWmSizeBytes(wmSizeBytes [4]byte) (int32, error) {
// read the length as an int32
size := (int32(wmSizeBytes[0])) |
(int32(wmSizeBytes[1]) << 8) |
(int32(wmSizeBytes[2]) << 16) |
(int32(wmSizeBytes[3]) << 24)

if size < 4 {
return 0, fmt.Errorf("malformed message length: %d", size)
}
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
// defaultMaxMessageSize instead.
maxMessageSize := c.desc.MaxMessageSize
if maxMessageSize == 0 {
maxMessageSize = defaultMaxMessageSize
}
if uint32(size) > maxMessageSize {
return 0, errResponseTooLarge
}

return size, nil
}

func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, err error) {
go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
defer func() {
Expand Down Expand Up @@ -475,35 +525,23 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
n, err := io.ReadFull(c.nc, sizeBuf[:])
if err != nil {
if l := int32(n); l == 0 && needToWait(err) {
c.awaitingResponse = &l
c.awaitRemainingBytes = &l
}
return nil, "incomplete read of message header", err
}

// read the length as an int32
size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)

if size < 4 {
err = fmt.Errorf("malformed message length: %d", size)
size, err := c.parseWmSizeBytes(sizeBuf)
if err != nil {
return nil, err.Error(), err
}
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
// defaultMaxMessageSize instead.
maxMessageSize := c.desc.MaxMessageSize
if maxMessageSize == 0 {
maxMessageSize = defaultMaxMessageSize
}
if uint32(size) > maxMessageSize {
return nil, errResponseTooLarge.Error(), errResponseTooLarge
}

dst := make([]byte, size)
copy(dst, sizeBuf[:])

n, err = io.ReadFull(c.nc, dst[4:])
if err != nil {
if l := size - 4 - int32(n); l > 0 && needToWait(err) {
c.awaitingResponse = &l
remainingBytes := size - 4 - int32(n)
if remainingBytes > 0 && needToWait(err) {
c.awaitRemainingBytes = &remainingBytes
}
return dst, "incomplete read of full message", err
}
Expand Down Expand Up @@ -551,10 +589,6 @@ func (c *connection) setCanStream(canStream bool) {
c.canStream = canStream
}

func (c initConnection) supportsStreaming() bool {
return c.canStream
}

func (c *connection) setStreaming(streaming bool) {
c.currentlyStreaming = streaming
}
Expand All @@ -568,6 +602,12 @@ func (c *connection) setSocketTimeout(timeout time.Duration) {
c.writeTimeout = timeout
}

// DriverConnectionID returns the driver connection ID.
// TODO(GODRIVER-2824): change return type to int64.
func (c *connection) DriverConnectionID() uint64 {
return c.driverConnectionID
}

func (c *connection) ID() string {
return c.id
}
Expand All @@ -576,6 +616,14 @@ func (c *connection) ServerConnectionID() *int64 {
return c.serverConnectionID
}

func (c *connection) OIDCTokenGenID() uint64 {
return c.oidcTokenGenID
}

func (c *connection) SetOIDCTokenGenID(genID uint64) {
c.oidcTokenGenID = genID
}

// initConnection is an adapter used during connection initialization. It has the minimum
// functionality necessary to implement the driver.Connection interface, which is required to pass a
// *connection to a Handshaker.
Expand Down Expand Up @@ -613,7 +661,7 @@ func (c initConnection) CurrentlyStreaming() bool {
return c.getCurrentlyStreaming()
}
func (c initConnection) SupportsStreaming() bool {
return c.supportsStreaming()
return c.canStream
}

// Connection implements the driver.Connection interface to allow reading and writing wire
Expand Down Expand Up @@ -847,39 +895,6 @@ func (c *Connection) DriverConnectionID() uint64 {
return c.connection.DriverConnectionID()
}

func configureTLS(ctx context.Context,
tlsConnSource tlsConnectionSource,
nc net.Conn,
addr address.Address,
config *tls.Config,
ocspOpts *ocsp.VerifyOptions,
) (net.Conn, error) {
// Ensure config.ServerName is always set for SNI.
if config.ServerName == "" {
hostname := addr.String()
colonPos := strings.LastIndex(hostname, ":")
if colonPos == -1 {
colonPos = len(hostname)
}

hostname = hostname[:colonPos]
config.ServerName = hostname
}

client := tlsConnSource.Client(nc, config)
if err := clientHandshake(ctx, client); err != nil {
return nil, err
}

// Only do OCSP verification if TLS verification is requested.
if !config.InsecureSkipVerify {
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
return nil, ocspErr
}
}
return client, nil
}

// OIDCTokenGenID returns the OIDC token generation ID.
func (c *Connection) OIDCTokenGenID() uint64 {
return c.oidcTokenGenID
Expand Down Expand Up @@ -933,11 +948,3 @@ func (c *cancellListener) StopListening() bool {
c.done <- struct{}{}
return c.aborted
}

func (c *connection) OIDCTokenGenID() uint64 {
return c.oidcTokenGenID
}

func (c *connection) SetOIDCTokenGenID(genID uint64) {
c.oidcTokenGenID = genID
}
22 changes: 6 additions & 16 deletions x/mongo/driver/topology/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"context"
"fmt"
"io"
"io/ioutil"
"net"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -833,22 +832,13 @@ func bgRead(pool *pool, conn *connection, size int32) {
err = fmt.Errorf("error reading the message size: %w", err)
return
}
size = (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)
if size < 4 {
err = fmt.Errorf("malformed message length: %d", size)
return
}
maxMessageSize := conn.desc.MaxMessageSize
if maxMessageSize == 0 {
maxMessageSize = defaultMaxMessageSize
}
if uint32(size) > maxMessageSize {
err = errResponseTooLarge
size, err = conn.parseWmSizeBytes(sizeBuf)
if err != nil {
return
}
size -= 4
}
_, err = io.CopyN(ioutil.Discard, conn.nc, int64(size))
_, err = io.CopyN(io.Discard, conn.nc, int64(size))
if err != nil {
err = fmt.Errorf("error reading message of %d: %w", size, err)
}
Expand Down Expand Up @@ -901,9 +891,9 @@ func (p *pool) checkInNoEvent(conn *connection) error {
// means that connections in "awaiting response" state are checked in but
// not usable, which is not covered by the current pool events. We may need
// to add pool event information in the future to communicate that.
if conn.awaitingResponse != nil {
size := *conn.awaitingResponse
conn.awaitingResponse = nil
if conn.awaitRemainingBytes != nil {
size := *conn.awaitRemainingBytes
conn.awaitRemainingBytes = nil
go bgRead(p, conn, size)
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/topology/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1197,7 +1197,7 @@ func TestPool(t *testing.T) {
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read unix .*->\.\/test.sock: i\/o timeout$`,
)
assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err)
assert.Nil(t, conn.awaitingResponse, "conn.awaitingResponse should be nil")
assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitingResponse should be nil")
wg.Wait()
p.close(context.Background())
close(errCh)
Expand Down

0 comments on commit b565a23

Please sign in to comment.