diff --git a/internal/integration/mtest/opmsg_deployment.go b/internal/integration/mtest/opmsg_deployment.go index 23b258354a..dc15831fe5 100644 --- a/internal/integration/mtest/opmsg_deployment.go +++ b/internal/integration/mtest/opmsg_deployment.go @@ -16,6 +16,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/topology" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) @@ -50,15 +51,16 @@ type connection struct { responses []bson.D // responses to send when ReadWireMessage is called } -var _ driver.Connection = &connection{} +var _ mnet.ReadWriteCloser = &connection{} +var _ mnet.Describer = &connection{} -// WriteWireMessage is a no-op. -func (c *connection) WriteWireMessage(context.Context, []byte) error { +// Write is a no-op. +func (c *connection) Write(context.Context, []byte) error { return nil } -// ReadWireMessage returns the next response in the connection's list of responses. -func (c *connection) ReadWireMessage(_ context.Context) ([]byte, error) { +// Read returns the next response in the connection's list of responses. +func (c *connection) Read(_ context.Context) ([]byte, error) { var dst []byte if len(c.responses) == 0 { return dst, errors.New("no responses remaining") @@ -137,8 +139,8 @@ func (md *mockDeployment) Kind() description.TopologyKind { } // Connection implements the driver.Server interface. -func (md *mockDeployment) Connection(context.Context) (driver.Connection, error) { - return md.conn, nil +func (md *mockDeployment) Connection(context.Context) (*mnet.Connection, error) { + return mnet.NewConnection(md.conn), nil } // RTTMonitor implements the driver.Server interface. diff --git a/mongo/change_stream.go b/mongo/change_stream.go index 1d2bd703ac..cfdbb1f2b8 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -24,6 +24,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -282,7 +283,7 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in return cs, cs.Err() } -func (cs *ChangeStream) createOperationDeployment(server driver.Server, connection driver.Connection) driver.Deployment { +func (cs *ChangeStream) createOperationDeployment(server driver.Server, connection *mnet.Connection) driver.Deployment { return &changeStreamDeployment{ topologyKind: cs.client.deployment.Kind(), server: server, @@ -292,7 +293,7 @@ func (cs *ChangeStream) createOperationDeployment(server driver.Server, connecti func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) error { var server driver.Server - var conn driver.Connection + var conn *mnet.Connection if server, cs.err = cs.client.deployment.SelectServer(ctx, cs.selector); cs.err != nil { return cs.Err() diff --git a/mongo/change_stream_deployment.go b/mongo/change_stream_deployment.go index 4dca59f91c..a84b43f05c 100644 --- a/mongo/change_stream_deployment.go +++ b/mongo/change_stream_deployment.go @@ -11,12 +11,13 @@ import ( "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" ) type changeStreamDeployment struct { topologyKind description.TopologyKind server driver.Server - conn driver.Connection + conn *mnet.Connection } var _ driver.Deployment = (*changeStreamDeployment)(nil) @@ -31,7 +32,7 @@ func (c *changeStreamDeployment) Kind() description.TopologyKind { return c.topologyKind } -func (c *changeStreamDeployment) Connection(context.Context) (driver.Connection, error) { +func (c *changeStreamDeployment) Connection(context.Context) (*mnet.Connection, error) { return c.conn, nil } @@ -39,11 +40,11 @@ func (c *changeStreamDeployment) RTTMonitor() driver.RTTMonitor { return c.server.RTTMonitor() } -func (c *changeStreamDeployment) ProcessError(err error, conn driver.Connection) driver.ProcessErrorResult { +func (c *changeStreamDeployment) ProcessError(err error, describer mnet.Describer) driver.ProcessErrorResult { ep, ok := c.server.(driver.ErrorProcessor) if !ok { return driver.NoChange } - return ep.ProcessError(err, conn) + return ep.ProcessError(err, describer) } diff --git a/x/mongo/driver/auth/auth.go b/x/mongo/driver/auth/auth.go index 6eeaf0ee01..ac6540233b 100644 --- a/x/mongo/driver/auth/auth.go +++ b/x/mongo/driver/auth/auth.go @@ -15,6 +15,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -76,7 +77,11 @@ var _ driver.Handshaker = (*authHandshaker)(nil) // GetHandshakeInformation performs the initial MongoDB handshake to retrieve the required information for the provided // connection. -func (ah *authHandshaker) GetHandshakeInformation(ctx context.Context, addr address.Address, conn driver.Connection) (driver.HandshakeInformation, error) { +func (ah *authHandshaker) GetHandshakeInformation( + ctx context.Context, + addr address.Address, + conn *mnet.Connection, +) (driver.HandshakeInformation, error) { if ah.wrapped != nil { return ah.wrapped.GetHandshakeInformation(ctx, addr, conn) } @@ -115,7 +120,7 @@ func (ah *authHandshaker) GetHandshakeInformation(ctx context.Context, addr addr } // FinishHandshake performs authentication for conn if necessary. -func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn driver.Connection) error { +func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn *mnet.Connection) error { performAuth := ah.options.PerformAuthentication if performAuth == nil { performAuth = func(serv description.Server) bool { @@ -124,10 +129,8 @@ func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn driver.Conne } } - desc := conn.Description() - if performAuth(desc) && ah.options.Authenticator != nil { + if performAuth(conn.Description()) && ah.options.Authenticator != nil { cfg := &Config{ - Description: desc, Connection: conn, ClusterClock: ah.options.ClusterClock, HandshakeInfo: ah.handshakeInfo, @@ -172,8 +175,7 @@ func Handshaker(h driver.Handshaker, options *HandshakeOptions) driver.Handshake // Config holds the information necessary to perform an authentication attempt. type Config struct { - Description description.Server - Connection driver.Connection + Connection *mnet.Connection ClusterClock *session.ClusterClock HandshakeInfo driver.HandshakeInformation ServerAPI *driver.ServerAPIOptions diff --git a/x/mongo/driver/auth/gssapi.go b/x/mongo/driver/auth/gssapi.go index 4b860ba63f..9181280887 100644 --- a/x/mongo/driver/auth/gssapi.go +++ b/x/mongo/driver/auth/gssapi.go @@ -44,7 +44,7 @@ type GSSAPIAuthenticator struct { // Auth authenticates the connection. func (a *GSSAPIAuthenticator) Auth(ctx context.Context, cfg *Config) error { - target := cfg.Description.Addr.String() + target := cfg.Connection.Description().Addr.String() hostname, _, err := net.SplitHostPort(target) if err != nil { return newAuthError(fmt.Sprintf("invalid endpoint (%s) specified: %s", target, err), nil) diff --git a/x/mongo/driver/auth/gssapi_test.go b/x/mongo/driver/auth/gssapi_test.go index 857d171ce5..59f4e20c12 100644 --- a/x/mongo/driver/auth/gssapi_test.go +++ b/x/mongo/driver/auth/gssapi_test.go @@ -15,6 +15,8 @@ import ( "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" ) func TestGSSAPIAuthenticator(t *testing.T) { @@ -36,7 +38,13 @@ func TestGSSAPIAuthenticator(t *testing.T) { }, Addr: address.Address("foo:27017"), } - err := authenticator.Auth(context.Background(), &Config{Description: desc}) + chanconn := &drivertest.ChannelConn{ + Desc: desc, + } + + mnetconn := mnet.NewConnection(chanconn) + + err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn}) if err == nil { t.Fatalf("expected err, got nil") } diff --git a/x/mongo/driver/auth/mongodbcr.go b/x/mongo/driver/auth/mongodbcr.go index 6e2c2f4dcb..dcbdea82cf 100644 --- a/x/mongo/driver/auth/mongodbcr.go +++ b/x/mongo/driver/auth/mongodbcr.go @@ -60,7 +60,7 @@ func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, cfg *Config) error { doc := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendInt32Element(nil, "getnonce", 1)) cmd := operation.NewCommand(doc). Database(db). - Deployment(driver.SingleConnectionDeployment{cfg.Connection}). + Deployment(driver.SingleConnectionDeployment{C: cfg.Connection}). ClusterClock(cfg.ClusterClock). ServerAPI(cfg.ServerAPI) err := cmd.Execute(ctx) @@ -86,7 +86,7 @@ func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, cfg *Config) error { ) cmd = operation.NewCommand(doc). Database(db). - Deployment(driver.SingleConnectionDeployment{cfg.Connection}). + Deployment(driver.SingleConnectionDeployment{C: cfg.Connection}). ClusterClock(cfg.ClusterClock). ServerAPI(cfg.ServerAPI) err = cmd.Execute(ctx) diff --git a/x/mongo/driver/auth/mongodbcr_test.go b/x/mongo/driver/auth/mongodbcr_test.go index e2f43b2f21..8fcc59820b 100644 --- a/x/mongo/driver/auth/mongodbcr_test.go +++ b/x/mongo/driver/auth/mongodbcr_test.go @@ -16,6 +16,7 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" . "go.mongodb.org/mongo-driver/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" ) func TestMongoDBCRAuthenticator_Fails(t *testing.T) { @@ -46,7 +47,9 @@ func TestMongoDBCRAuthenticator_Fails(t *testing.T) { Desc: desc, } - err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c}) + mnetconn := mnet.NewConnection(c) + + err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn}) if err == nil { t.Fatalf("expected an error but got none") } @@ -85,7 +88,9 @@ func TestMongoDBCRAuthenticator_Succeeds(t *testing.T) { Desc: desc, } - err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c}) + mnetconn := mnet.NewConnection(c) + + err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn}) if err != nil { t.Fatalf("expected no error but got \"%s\"", err) } diff --git a/x/mongo/driver/auth/plain_test.go b/x/mongo/driver/auth/plain_test.go index baaf175d85..c0dc8d760f 100644 --- a/x/mongo/driver/auth/plain_test.go +++ b/x/mongo/driver/auth/plain_test.go @@ -18,6 +18,7 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" . "go.mongodb.org/mongo-driver/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" ) func TestPlainAuthenticator_Fails(t *testing.T) { @@ -48,7 +49,9 @@ func TestPlainAuthenticator_Fails(t *testing.T) { Desc: desc, } - err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c}) + mnetconn := mnet.NewConnection(c) + + err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn}) if err == nil { t.Fatalf("expected an error but got none") } @@ -91,7 +94,9 @@ func TestPlainAuthenticator_Extra_server_message(t *testing.T) { Desc: desc, } - err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c}) + mnetconn := mnet.NewConnection(c) + + err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn}) if err == nil { t.Fatalf("expected an error but got none") } @@ -129,7 +134,9 @@ func TestPlainAuthenticator_Succeeds(t *testing.T) { Desc: desc, } - err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c}) + mnetconn := mnet.NewConnection(c) + + err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn}) if err != nil { t.Fatalf("expected no error but got \"%s\"", err) } @@ -174,7 +181,9 @@ func TestPlainAuthenticator_SucceedsBoolean(t *testing.T) { Desc: desc, } - err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c}) + mnetconn := mnet.NewConnection(c) + + err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn}) require.NoError(t, err, "Auth error") require.Len(t, c.Written, 1, "expected 1 messages to be sent") diff --git a/x/mongo/driver/auth/scram_test.go b/x/mongo/driver/auth/scram_test.go index ef30a07364..46e6ed9111 100644 --- a/x/mongo/driver/auth/scram_test.go +++ b/x/mongo/driver/auth/scram_test.go @@ -14,6 +14,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" ) const ( @@ -68,18 +69,20 @@ func TestSCRAM(t *testing.T) { Max: 21, }, } - conn := &drivertest.ChannelConn{ + chanconn := &drivertest.ChannelConn{ Written: make(chan []byte, len(tc.payloads)), ReadResp: responses, Desc: desc, } - err = authenticator.Auth(context.Background(), &Config{Description: desc, Connection: conn}) + conn := mnet.NewConnection(chanconn) + + err = authenticator.Auth(context.Background(), &Config{Connection: conn}) assert.Nil(t, err, "Auth error: %v\n", err) // Verify that the first command sent is saslStart. - assert.True(t, len(conn.Written) > 1, "wire messages were written to the connection") - startCmd, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + assert.True(t, len(chanconn.Written) > 1, "wire messages were written to the connection") + startCmd, err := drivertest.GetCommandFromMsgWireMessage(<-chanconn.Written) assert.Nil(t, err, "error parsing wire message: %v", err) cmdName := startCmd.Index(0).Key() assert.Equal(t, cmdName, "saslStart", "cmd name mismatch; expected 'saslStart', got %v", cmdName) diff --git a/x/mongo/driver/auth/speculative_scram_test.go b/x/mongo/driver/auth/speculative_scram_test.go index a159891adc..6ea2b8afcd 100644 --- a/x/mongo/driver/auth/speculative_scram_test.go +++ b/x/mongo/driver/auth/speculative_scram_test.go @@ -17,6 +17,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" ) var ( @@ -80,13 +81,15 @@ func TestSpeculativeSCRAM(t *testing.T) { ReadResp: responses, } + mnetconn := mnet.NewConnection(conn) + // Do both parts of the handshake. - info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn) + info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), mnetconn) assert.Nil(t, err, "GetHandshakeInformation error: %v", err) assert.NotNil(t, info.SpeculativeAuthenticate, "desc.SpeculativeAuthenticate not set") conn.Desc = info.Description // Set conn.Desc so the new description will be used for the authentication. - err = handshaker.FinishHandshake(context.Background(), conn) + err = handshaker.FinishHandshake(context.Background(), mnetconn) assert.Nil(t, err, "FinishHandshake error: %v", err) assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp)) @@ -165,13 +168,15 @@ func TestSpeculativeSCRAM(t *testing.T) { ReadResp: responses, } - info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn) + mnetconn := mnet.NewConnection(conn) + + info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), mnetconn) assert.Nil(t, err, "GetHandshakeInformation error: %v", err) assert.Nil(t, info.SpeculativeAuthenticate, "expected desc.SpeculativeAuthenticate to be unset, got %s", bson.Raw(info.SpeculativeAuthenticate)) conn.Desc = info.Description - err = handshaker.FinishHandshake(context.Background(), conn) + err = handshaker.FinishHandshake(context.Background(), mnetconn) assert.Nil(t, err, "FinishHandshake error: %v", err) assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp)) diff --git a/x/mongo/driver/auth/speculative_x509_test.go b/x/mongo/driver/auth/speculative_x509_test.go index cf46de6ffd..6ec2b8ea64 100644 --- a/x/mongo/driver/auth/speculative_x509_test.go +++ b/x/mongo/driver/auth/speculative_x509_test.go @@ -17,6 +17,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" ) var ( @@ -47,12 +48,14 @@ func TestSpeculativeX509(t *testing.T) { ReadResp: responses, } - info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn) + mnetconn := mnet.NewConnection(conn) + + info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), mnetconn) assert.Nil(t, err, "GetDescription error: %v", err) assert.NotNil(t, info.SpeculativeAuthenticate, "desc.SpeculativeAuthenticate not set") conn.Desc = info.Description - err = handshaker.FinishHandshake(context.Background(), conn) + err = handshaker.FinishHandshake(context.Background(), mnetconn) assert.Nil(t, err, "FinishHandshake error: %v", err) assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp)) @@ -91,13 +94,15 @@ func TestSpeculativeX509(t *testing.T) { ReadResp: responses, } - info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn) + mnetconn := mnet.NewConnection(conn) + + info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), mnetconn) assert.Nil(t, err, "GetDescription error: %v", err) assert.Nil(t, info.SpeculativeAuthenticate, "expected desc.SpeculativeAuthenticate to be unset, got %s", bson.Raw(info.SpeculativeAuthenticate)) conn.Desc = info.Description - err = handshaker.FinishHandshake(context.Background(), conn) + err = handshaker.FinishHandshake(context.Background(), mnetconn) assert.Nil(t, err, "FinishHandshake error: %v", err) assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp)) diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index bcfc8af95d..99db007c31 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -21,6 +21,7 @@ import ( "go.mongodb.org/mongo-driver/internal/csot" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -42,7 +43,7 @@ type BatchCursor struct { server Server serverDescription description.Server errorProcessor ErrorProcessor // This will only be set when pinning to a connection. - connection PinnedConnection + connection *mnet.Connection batchSize int32 maxTimeMS int64 currentBatch *bsoncore.Iterator @@ -62,7 +63,7 @@ type BatchCursor struct { type CursorResponse struct { Server Server ErrorProcessor ErrorProcessor // This will only be set when pinning to a connection. - Connection PinnedConnection + Connection *mnet.Connection Desc description.Server FirstBatch *bsoncore.Iterator Database string @@ -138,14 +139,15 @@ func NewCursorResponse(info ResponseInfo) (CursorResponse, error) { } curresp.ErrorProcessor = ep - refConn, ok := info.Connection.(PinnedConnection) - if !ok { + refConn := info.Connection.Pinner + if refConn == nil { + //debug.PrintStack() return CursorResponse{}, fmt.Errorf("expected Connection used to establish a cursor to implement PinnedConnection, but got %T", info.Connection) } if err := refConn.PinToCursor(); err != nil { return CursorResponse{}, fmt.Errorf("error incrementing connection reference count when creating a cursor: %w", err) } - curresp.Connection = refConn + curresp.Connection = info.Connection } return curresp, nil @@ -277,7 +279,7 @@ func (bc *BatchCursor) Close(ctx context.Context) error { } func (bc *BatchCursor) unpinConnection() error { - if bc.connection == nil { + if bc.connection == nil || bc.connection.Pinner == nil { return nil } @@ -499,7 +501,8 @@ func (bc *BatchCursor) getOperationDeployment() Deployment { // handled for these commands in this mode. type loadBalancedCursorDeployment struct { errorProcessor ErrorProcessor - conn PinnedConnection + //conn PinnedConnection + conn *mnet.Connection } var _ Deployment = (*loadBalancedCursorDeployment)(nil) @@ -514,7 +517,7 @@ func (lbcd *loadBalancedCursorDeployment) Kind() description.TopologyKind { return description.LoadBalanced } -func (lbcd *loadBalancedCursorDeployment) Connection(_ context.Context) (Connection, error) { +func (lbcd *loadBalancedCursorDeployment) Connection(context.Context) (*mnet.Connection, error) { return lbcd.conn, nil } @@ -523,6 +526,6 @@ func (lbcd *loadBalancedCursorDeployment) RTTMonitor() RTTMonitor { return &csot.ZeroRTTMonitor{} } -func (lbcd *loadBalancedCursorDeployment) ProcessError(err error, conn Connection) ProcessErrorResult { - return lbcd.errorProcessor.ProcessError(err, conn) +func (lbcd *loadBalancedCursorDeployment) ProcessError(err error, desc mnet.Describer) ProcessErrorResult { + return lbcd.errorProcessor.ProcessError(err, desc) } diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index d0c0ee5c22..57afebba0e 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -14,7 +14,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" - "go.mongodb.org/mongo-driver/x/mongo/driver/session" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" ) // Deployment is implemented by types that can select a server from a deployment. @@ -50,30 +50,12 @@ type Subscriber interface { // Server represents a MongoDB server. Implementations should pool connections and handle the // retrieving and returning of connections. type Server interface { - Connection(context.Context) (Connection, error) + Connection(context.Context) (*mnet.Connection, error) // RTTMonitor returns the round-trip time monitor associated with this server. RTTMonitor() RTTMonitor } -// Connection represents a connection to a MongoDB server. -type Connection interface { - WriteWireMessage(context.Context, []byte) error - ReadWireMessage(ctx context.Context) ([]byte, error) - Description() description.Server - - // Close closes any underlying connection and returns or frees any resources held by the - // connection. Close is idempotent and can be called multiple times, although subsequent calls - // to Close may return an error. A connection cannot be used after it is closed. - Close() error - - ID() string - ServerConnectionID() *int64 - DriverConnectionID() int64 - Address() address.Address - Stale() bool -} - // RTTMonitor represents a round-trip-time monitor. type RTTMonitor interface { // EWMA returns the exponentially weighted moving average observed round-trip time. @@ -88,25 +70,6 @@ type RTTMonitor interface { var _ RTTMonitor = &csot.ZeroRTTMonitor{} -// PinnedConnection represents a Connection that can be pinned by one or more cursors or transactions. Implementations -// of this interface should maintain the following invariants: -// -// 1. Each Pin* call should increment the number of references for the connection. -// 2. Each Unpin* call should decrement the number of references for the connection. -// 3. Calls to Close() should be ignored until all resources have unpinned the connection. -type PinnedConnection interface { - Connection - PinToCursor() error - PinToTransaction() error - UnpinFromCursor() error - UnpinFromTransaction() error -} - -// The session.LoadBalancedTransactionConnection type is a copy of PinnedConnection that was introduced to avoid -// import cycles. This compile-time assertion ensures that these types remain in sync if the PinnedConnection interface -// is changed in the future. -var _ PinnedConnection = (session.LoadBalancedTransactionConnection)(nil) - // LocalAddresser is a type that is able to supply its local address type LocalAddresser interface { LocalAddress() address.Address @@ -118,28 +81,6 @@ type Expirable interface { Alive() bool } -// StreamerConnection represents a Connection that supports streaming wire protocol messages using the moreToCome and -// exhaustAllowed flags. -// -// The SetStreaming and CurrentlyStreaming functions correspond to the moreToCome flag on server responses. If a -// response has moreToCome set, SetStreaming(true) will be called and CurrentlyStreaming() should return true. -// -// CanStream corresponds to the exhaustAllowed flag. The operations layer will set exhaustAllowed on outgoing wire -// messages to inform the server that the driver supports streaming. -type StreamerConnection interface { - Connection - SetStreaming(bool) - CurrentlyStreaming() bool - SupportsStreaming() bool -} - -// Compressor is an interface used to compress wire messages. If a Connection supports compression -// it should implement this interface as well. The CompressWireMessage method will be called during -// the execution of an operation if the wire message is allowed to be compressed. -type Compressor interface { - CompressWireMessage(src, dst []byte) ([]byte, error) -} - // ProcessErrorResult represents the result of a ErrorProcessor.ProcessError() call. Exact values for this type can be // checked directly (e.g. res == ServerMarkedUnknown), but it is recommended that applications use the ServerChanged() // function instead. @@ -159,7 +100,7 @@ const ( // If this type is implemented by a Server, then Operation.Execute will call it's ProcessError // method after it decodes a wire message. type ErrorProcessor interface { - ProcessError(err error, conn Connection) ProcessErrorResult + ProcessError(err error, desc mnet.Describer) ProcessErrorResult } // HandshakeInformation contains information extracted from a MongoDB connection handshake. This is a helper type that @@ -178,8 +119,8 @@ type HandshakeInformation struct { // handshake over a provided driver.Connection. This is used during connection // initialization. Implementations must be goroutine safe. type Handshaker interface { - GetHandshakeInformation(context.Context, address.Address, Connection) (HandshakeInformation, error) - FinishHandshake(context.Context, Connection) error + GetHandshakeInformation(context.Context, address.Address, *mnet.Connection) (HandshakeInformation, error) + FinishHandshake(context.Context, *mnet.Connection) error } // SingleServerDeployment is an implementation of Deployment that always returns a single server. @@ -199,7 +140,7 @@ func (SingleServerDeployment) Kind() description.TopologyKind { return descripti // SingleConnectionDeployment is an implementation of Deployment that always returns the same Connection. This // implementation should only be used for connection handshakes and server heartbeats as it does not implement // ErrorProcessor, which is necessary for application operations. -type SingleConnectionDeployment struct{ C Connection } +type SingleConnectionDeployment struct{ C *mnet.Connection } var _ Deployment = SingleConnectionDeployment{} var _ Server = SingleConnectionDeployment{} @@ -215,7 +156,7 @@ func (scd SingleConnectionDeployment) SelectServer(context.Context, description. func (SingleConnectionDeployment) Kind() description.TopologyKind { return description.Single } // Connection implements the Server interface. It always returns the embedded connection. -func (scd SingleConnectionDeployment) Connection(context.Context) (Connection, error) { +func (scd SingleConnectionDeployment) Connection(context.Context) (*mnet.Connection, error) { return scd.C, nil } diff --git a/x/mongo/driver/drivertest/channel_conn.go b/x/mongo/driver/drivertest/channel_conn.go index 85471bc94e..c1eb6f19c6 100644 --- a/x/mongo/driver/drivertest/channel_conn.go +++ b/x/mongo/driver/drivertest/channel_conn.go @@ -27,7 +27,7 @@ type ChannelConn struct { } // WriteWireMessage implements the driver.Connection interface. -func (c *ChannelConn) WriteWireMessage(ctx context.Context, wm []byte) error { +func (c *ChannelConn) Write(ctx context.Context, wm []byte) error { // Copy wm in case it came from a buffer pool. b := make([]byte, len(wm)) copy(b, wm) @@ -42,7 +42,7 @@ func (c *ChannelConn) WriteWireMessage(ctx context.Context, wm []byte) error { } // ReadWireMessage implements the driver.Connection interface. -func (c *ChannelConn) ReadWireMessage(ctx context.Context) ([]byte, error) { +func (c *ChannelConn) Read(ctx context.Context) ([]byte, error) { var wm []byte var err error select { diff --git a/x/mongo/driver/mnet/connection.go b/x/mongo/driver/mnet/connection.go new file mode 100644 index 0000000000..f8542c4e0a --- /dev/null +++ b/x/mongo/driver/mnet/connection.go @@ -0,0 +1,118 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mnet + +import ( + "context" + "io" + + "go.mongodb.org/mongo-driver/mongo/address" + "go.mongodb.org/mongo-driver/mongo/description" +) + +// ReadWriteCloser represents a Connection where server operations +// can read from, written to, and closed. +type ReadWriteCloser interface { + Read(ctx context.Context) ([]byte, error) + Write(ctx context.Context, wm []byte) error + io.Closer +} + +// Describer represents a Connection that can be described. +type Describer interface { + Description() description.Server + ID() string + ServerConnectionID() *int64 + DriverConnectionID() int64 + Address() address.Address + Stale() bool +} + +// Streamer represents a Connection that supports streaming wire protocol +// messages using the moreToCome and exhaustAllowed flags. +// +// The SetStreaming and CurrentlyStreaming functions correspond to the +// moreToCome flag on server responses. If a response has moreToCome set, +// SetStreaming(true) will be called and CurrentlyStreaming() should return +// true. +// +// CanStream corresponds to the exhaustAllowed flag. The operations layer will +// set exhaustAllowed on outgoing wire messages to inform the server that the +// driver supports streaming. +type Streamer interface { + SetStreaming(bool) + CurrentlyStreaming() bool + SupportsStreaming() bool +} + +// Compressor is an interface used to compress wire messages. If a Connection +// supports compression it should implement this interface as well. The +// CompressWireMessage method will be called during the execution of an +// operation if the wire message is allowed to be compressed. +type Compressor interface { + CompressWireMessage(src, dst []byte) ([]byte, error) +} + +// Pinner represents a Connection that can be pinned by one or more cursors or +// transactions. Implementations of this interface should maintain the following +// invariants: +// +// 1. Each Pin* call should increment the number of references for the +// connection. +// 2. Each Unpin* call should decrement the number of references for the +// connection. +// 3. Calls to Close() should be ignored until all resources have unpinned the +// connection. +type Pinner interface { + PinToCursor() error + PinToTransaction() error + UnpinFromCursor() error + UnpinFromTransaction() error +} + +// Connection represents a connection to a MongoDB server. +type Connection struct { + ReadWriteCloser + Describer + Streamer + Compressor + Pinner +} + +// NewConnection creates a new Connection with the provided component. This +// constructor returns a component that is already a Connection to avoid +// mis-asserting the composite interfaces. +func NewConnection(component interface { + ReadWriteCloser + Describer +}) *Connection { + if _, ok := component.(*Connection); ok { + return component.(*Connection) + } + + conn := &Connection{ + ReadWriteCloser: component, + } + + if describer, ok := component.(Describer); ok { + conn.Describer = describer + } + + if streamer, ok := component.(Streamer); ok { + conn.Streamer = streamer + } + + if compressor, ok := component.(Compressor); ok { + conn.Compressor = compressor + } + + if pinner, ok := component.(Pinner); ok { + conn.Pinner = pinner + } + + return conn +} diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 41588069d6..4f97dd7a86 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -31,6 +31,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/session" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) @@ -136,7 +137,7 @@ func (info finishedInformation) success() bool { type ResponseInfo struct { ServerResponse bsoncore.Document Server Server - Connection Connection + Connection *mnet.Connection ConnectionDescription description.Server CurrentIndex int } @@ -404,7 +405,7 @@ func (op Operation) getServerAndConnection( ctx context.Context, requestID int32, deprioritized []description.Server, -) (Server, Connection, error) { +) (Server, *mnet.Connection, error) { server, err := op.selectServer(ctx, requestID, deprioritized) if err != nil { if op.Client != nil && @@ -421,7 +422,8 @@ func (op Operation) getServerAndConnection( // If the provided client session has a pinned connection, it should be used for the operation because this // indicates that we're in a transaction and the target server is behind a load balancer. if op.Client != nil && op.Client.PinnedConnection != nil { - return server, op.Client.PinnedConnection, nil + conn := mnet.NewConnection(op.Client.PinnedConnection) + return server, conn, nil } // Otherwise, default to checking out a connection from the server's pool. @@ -432,18 +434,17 @@ func (op Operation) getServerAndConnection( // If we're in load balanced mode and this is the first operation in a transaction, pin the session to a connection. if conn.Description().LoadBalanced() && op.Client != nil && op.Client.TransactionStarting() { - pinnedConn, ok := conn.(PinnedConnection) - if !ok { + if conn.Pinner == nil { // Close the original connection to avoid a leak. _ = conn.Close() return nil, nil, fmt.Errorf("expected Connection used to start a transaction to be a PinnedConnection, but got %T", conn) } - if err := pinnedConn.PinToTransaction(); err != nil { + if err := conn.PinToTransaction(); err != nil { // Close the original connection to avoid a leak. _ = conn.Close() return nil, nil, fmt.Errorf("error incrementing connection reference count when starting a transaction: %w", err) } - op.Client.PinnedConnection = pinnedConn + op.Client.PinnedConnection = conn } return server, conn, nil @@ -528,7 +529,7 @@ func (op Operation) Execute(ctx context.Context) error { } var srvr Server - var conn Connection + var conn *mnet.Connection var res bsoncore.Document var operationErr WriteCommandError var prevErr error @@ -729,7 +730,7 @@ func (op Operation) Execute(ctx context.Context) error { moreToCome := wiremessage.IsMsgMoreToCome(*wm) // compress wiremessage if allowed - if compressor, ok := conn.(Compressor); ok && op.canCompress(startedInfo.cmdName) { + if compressor := conn.Compressor; compressor != nil && op.canCompress(startedInfo.cmdName) { b := memoryPool.Get().(*[]byte) *b, err = compressor.CompressWireMessage(*wm, (*b)[:0]) memoryPool.Put(wm) @@ -1032,23 +1033,23 @@ func (op Operation) retryable(desc description.Server) bool { // roundTrip writes a wiremessage to the connection and then reads a wiremessage. The wm parameter // is reused when reading the wiremessage. -func (op Operation) roundTrip(ctx context.Context, conn Connection, wm []byte) ([]byte, error) { - err := conn.WriteWireMessage(ctx, wm) +func (op Operation) roundTrip(ctx context.Context, conn *mnet.Connection, wm []byte) ([]byte, error) { + err := conn.Write(ctx, wm) if err != nil { return nil, op.networkError(err) } return op.readWireMessage(ctx, conn) } -func (op Operation) readWireMessage(ctx context.Context, conn Connection) (result []byte, err error) { - wm, err := conn.ReadWireMessage(ctx) +func (op Operation) readWireMessage(ctx context.Context, conn *mnet.Connection) (result []byte, err error) { + wm, err := conn.Read(ctx) if err != nil { return nil, op.networkError(err) } // If we're using a streamable connection, we set its streaming state based on the moreToCome flag in the server // response. - if streamer, ok := conn.(StreamerConnection); ok { + if streamer := conn.Streamer; streamer != nil { streamer.SetStreaming(wiremessage.IsMsgMoreToCome(wm)) } @@ -1112,8 +1113,8 @@ func (op Operation) networkError(err error) error { // moreToComeRoundTrip writes a wiremessage to the provided connection. This is used when an OP_MSG is // being sent with the moreToCome bit set. -func (op *Operation) moreToComeRoundTrip(ctx context.Context, conn Connection, wm []byte) (result []byte, err error) { - err = conn.WriteWireMessage(ctx, wm) +func (op *Operation) moreToComeRoundTrip(ctx context.Context, conn *mnet.Connection, wm []byte) (result []byte, err error) { + err = conn.Write(ctx, wm) if err != nil { if op.Client != nil { op.Client.MarkDirty() @@ -1251,7 +1252,7 @@ func (op Operation) createMsgWireMessage( maxTimeMS uint64, dst []byte, desc description.SelectedServer, - conn Connection, + conn *mnet.Connection, requestID int32, ) ([]byte, startedInformation, error) { var info startedInformation @@ -1264,7 +1265,7 @@ func (op Operation) createMsgWireMessage( } // Set the ExhaustAllowed flag if the connection supports streaming. This will tell the server that it can // respond with the MoreToCome flag and then stream responses over this connection. - if streamer, ok := conn.(StreamerConnection); ok && streamer.SupportsStreaming() { + if streamer := conn.Streamer; streamer != nil && streamer.SupportsStreaming() { flags |= wiremessage.ExhaustAllowed } @@ -1347,7 +1348,7 @@ func (op Operation) createWireMessage( maxTimeMS uint64, dst []byte, desc description.SelectedServer, - conn Connection, + conn *mnet.Connection, requestID int32, ) ([]byte, startedInformation, error) { if isLegacyHandshake(op, desc) { diff --git a/x/mongo/driver/operation/find.go b/x/mongo/driver/operation/find.go index 656ef409e3..a62b9522c0 100644 --- a/x/mongo/driver/operation/find.go +++ b/x/mongo/driver/operation/find.go @@ -111,7 +111,6 @@ func (f *Find) Execute(ctx context.Context) error { Logger: f.logger, Name: driverutil.FindOp, }.Execute(ctx) - } func (f *Find) command(dst []byte, desc description.SelectedServer) ([]byte, error) { diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index 16f2ebf6c0..5a9d9bb36b 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -23,6 +23,7 @@ import ( "go.mongodb.org/mongo-driver/version" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -567,7 +568,7 @@ func (h *Hello) Execute(ctx context.Context) error { } // StreamResponse gets the next streaming Hello response from the server. -func (h *Hello) StreamResponse(ctx context.Context, conn driver.StreamerConnection) error { +func (h *Hello) StreamResponse(ctx context.Context, conn *mnet.Connection) error { return h.createOperation().ExecuteExhaust(ctx, conn) } @@ -601,8 +602,8 @@ func (h *Hello) createOperation() driver.Operation { // GetHandshakeInformation performs the MongoDB handshake for the provided connection and returns the relevant // information about the server. This function implements the driver.Handshaker interface. -func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, c driver.Connection) (driver.HandshakeInformation, error) { - deployment := driver.SingleConnectionDeployment{C: c} +func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, conn *mnet.Connection) (driver.HandshakeInformation, error) { + deployment := driver.SingleConnectionDeployment{C: conn} op := driver.Operation{ Clock: h.clock, @@ -625,7 +626,7 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, } info := driver.HandshakeInformation{ - Description: h.Result(c.Address()), + Description: h.Result(conn.Address()), } if speculativeAuthenticate, ok := h.res.Lookup("speculativeAuthenticate").DocumentOK(); ok { info.SpeculativeAuthenticate = speculativeAuthenticate @@ -646,6 +647,6 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, // FinishHandshake implements the Handshaker interface. This is a no-op function because a non-authenticated connection // does not do anything besides the initial Hello for a handshake. -func (h *Hello) FinishHandshake(context.Context, driver.Connection) error { +func (h *Hello) FinishHandshake(context.Context, *mnet.Connection) error { return nil } diff --git a/x/mongo/driver/operation_exhaust.go b/x/mongo/driver/operation_exhaust.go index e0879de316..e3a220b3ac 100644 --- a/x/mongo/driver/operation_exhaust.go +++ b/x/mongo/driver/operation_exhaust.go @@ -9,11 +9,13 @@ package driver import ( "context" "errors" + + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" ) // ExecuteExhaust reads a response from the provided StreamerConnection. This will error if the connection's // CurrentlyStreaming function returns false. -func (op Operation) ExecuteExhaust(ctx context.Context, conn StreamerConnection) error { +func (op Operation) ExecuteExhaust(ctx context.Context, conn *mnet.Connection) error { if !conn.CurrentlyStreaming() { return errors.New("exhaust read must be done with a connection that is currently streaming") } diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index 9fbfaae133..befe51f2db 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -28,6 +28,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/tag" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/session" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) @@ -569,9 +570,10 @@ func TestOperation(t *testing.T) { }) t.Run("ExecuteExhaust", func(t *testing.T) { t.Run("errors if connection is not streaming", func(t *testing.T) { - conn := &mockConnection{ + conn := mnet.NewConnection(&mockConnection{ rStreaming: false, - } + }) + err := Operation{}.ExecuteExhaust(context.TODO(), conn) assert.NotNil(t, err, "expected error, got nil") }) @@ -596,12 +598,15 @@ func TestOperation(t *testing.T) { rReadWM: nonStreamingResponse, rCanStream: false, } + + mnetconn := mnet.NewConnection(conn) + op := Operation{ CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) { return bsoncore.AppendInt32Element(dst, handshake.LegacyHello, 1), nil }, Database: "admin", - Deployment: SingleConnectionDeployment{conn}, + Deployment: SingleConnectionDeployment{C: mnetconn}, } err := op.Execute(context.TODO()) assert.Nil(t, err, "Execute error: %v", err) @@ -623,12 +628,13 @@ func TestOperation(t *testing.T) { // Reset the server response and go through ExecuteExhaust to mimic streaming the next response. After // execution, the connection should still be in a streaming state. conn.rReadWM = streamingResponse - err = op.ExecuteExhaust(context.TODO(), conn) + err = op.ExecuteExhaust(context.TODO(), mnetconn) assert.Nil(t, err, "ExecuteExhaust error: %v", err) assert.True(t, conn.CurrentlyStreaming(), "expected CurrentlyStreaming to be true") }) t.Run("context deadline exceeded not marked as TransientTransactionError", func(t *testing.T) { - conn := new(mockConnection) + conn := mnet.NewConnection(&mockConnection{}) + // Create a context that's already timed out. ctx, cancel := context.WithDeadline(context.Background(), time.Unix(893934480, 0)) defer cancel() @@ -636,7 +642,7 @@ func TestOperation(t *testing.T) { op := Operation{ Database: "foobar", Deployment: SingleConnectionDeployment{C: conn}, - CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) { + CommandFn: func(dst []byte, _ description.SelectedServer) ([]byte, error) { dst = bsoncore.AppendInt32Element(dst, "ping", 1) return dst, nil }, @@ -649,7 +655,8 @@ func TestOperation(t *testing.T) { assert.Equal(t, err, context.DeadlineExceeded, "expected context.DeadlineExceeded error, got %v", err) }) t.Run("canceled context not marked as TransientTransactionError", func(t *testing.T) { - conn := new(mockConnection) + conn := mnet.NewConnection(&mockConnection{}) + // Create a context and cancel it immediately. ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -757,12 +764,12 @@ func (m *mockConnection) Stale() bool { return false } func (m *mockConnection) DriverConnectionID() int64 { return 0 } -func (m *mockConnection) WriteWireMessage(_ context.Context, wm []byte) error { +func (m *mockConnection) Write(_ context.Context, wm []byte) error { m.pWriteWM = wm return m.rWriteErr } -func (m *mockConnection) ReadWireMessage(_ context.Context) ([]byte, error) { +func (m *mockConnection) Read(_ context.Context) ([]byte, error) { return m.rReadWM, m.rReadErr } @@ -782,7 +789,7 @@ type mockRetryServer struct { // Connection records the number of calls and returns retryable errors until the provided context // times out or is cancelled, then returns the context error. -func (ms *mockRetryServer) Connection(ctx context.Context) (Connection, error) { +func (ms *mockRetryServer) Connection(ctx context.Context) (*mnet.Connection, error) { ms.numCallsToConnection++ if ctx.Err() != nil { diff --git a/x/mongo/driver/session/client_session.go b/x/mongo/driver/session/client_session.go index 35c2c3e89d..d99ad101ae 100644 --- a/x/mongo/driver/session/client_session.go +++ b/x/mongo/driver/session/client_session.go @@ -7,19 +7,18 @@ package session import ( - "context" "errors" "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/internal/uuid" - "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" ) // ErrSessionEnded is returned when a client session is used after a call to endSession(). @@ -76,26 +75,15 @@ func (s TransactionState) String() string { } } +var _ mnet.Pinner = (LoadBalancedTransactionConnection)(nil) + // LoadBalancedTransactionConnection represents a connection that's pinned by a ClientSession because it's being used // to execute a transaction when running against a load balancer. This interface is a copy of driver.PinnedConnection // and exists to be able to pin transactions to a connection without causing an import cycle. type LoadBalancedTransactionConnection interface { - // Functions copied over from driver.Connection. - WriteWireMessage(context.Context, []byte) error - ReadWireMessage(ctx context.Context) ([]byte, error) - Description() description.Server - Close() error - ID() string - ServerConnectionID() *int64 - DriverConnectionID() int64 - Address() address.Address - Stale() bool - - // Functions copied over from driver.PinnedConnection that are not part of Connection or Expirable. - PinToCursor() error - PinToTransaction() error - UnpinFromCursor() error - UnpinFromTransaction() error + mnet.ReadWriteCloser + mnet.Describer + mnet.Pinner } // Client is a session for clients to run commands. diff --git a/x/mongo/driver/topology/CMAP_spec_test.go b/x/mongo/driver/topology/CMAP_spec_test.go index 5f0e1c2eef..764d25db9d 100644 --- a/x/mongo/driver/topology/CMAP_spec_test.go +++ b/x/mongo/driver/topology/CMAP_spec_test.go @@ -23,6 +23,7 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/internal/spectest" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" ) @@ -517,7 +518,7 @@ func runOperationInThread(t *testing.T, operation map[string]interface{}, testIn t.Fatalf("was unable to find %v in objects when expected", cName) } - c, ok := cEmptyInterface.(*Connection) + c, ok := cEmptyInterface.(*mnet.Connection) if !ok { t.Fatalf("object in objects was expected to be a connection, but was instead a %T", cEmptyInterface) } diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 6bfc827a6d..00f44bee6b 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -22,6 +22,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/ocsp" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) @@ -222,7 +223,11 @@ func (c *connection) connect(ctx context.Context) (err error) { var handshakeInfo driver.HandshakeInformation handshakeStartTime := time.Now() - handshakeConn := initConnection{c} + + iconn := initConnection{c} + + handshakeConn := mnet.NewConnection(iconn) + handshakeInfo, err = handshaker.GetHandshakeInformation(handshakeCtx, c.addr, handshakeConn) if err == nil { // We only need to retain the Description field as the connection's description. The authentication-related @@ -547,8 +552,9 @@ func (c *connection) ServerConnectionID() *int64 { // *connection to a Handshaker. type initConnection struct{ *connection } -var _ driver.Connection = initConnection{} -var _ driver.StreamerConnection = initConnection{} +var _ mnet.ReadWriteCloser = initConnection{} +var _ mnet.Describer = initConnection{} +var _ mnet.Streamer = initConnection{} func (c initConnection) Description() description.Server { if c.connection == nil { @@ -566,10 +572,10 @@ func (c initConnection) LocalAddress() address.Address { } return address.Address(c.nc.LocalAddr().String()) } -func (c initConnection) WriteWireMessage(ctx context.Context, wm []byte) error { +func (c initConnection) Write(ctx context.Context, wm []byte) error { return c.writeWireMessage(ctx, wm) } -func (c initConnection) ReadWireMessage(ctx context.Context) ([]byte, error) { +func (c initConnection) Read(ctx context.Context) ([]byte, error) { return c.readWireMessage(ctx) } func (c initConnection) SetStreaming(streaming bool) { @@ -597,12 +603,14 @@ type Connection struct { mu sync.RWMutex } -var _ driver.Connection = (*Connection)(nil) +var _ mnet.ReadWriteCloser = (*Connection)(nil) +var _ mnet.Describer = (*Connection)(nil) +var _ mnet.Compressor = (*Connection)(nil) +var _ mnet.Pinner = (*Connection)(nil) var _ driver.Expirable = (*Connection)(nil) -var _ driver.PinnedConnection = (*Connection)(nil) // WriteWireMessage handles writing a wire message to the underlying connection. -func (c *Connection) WriteWireMessage(ctx context.Context, wm []byte) error { +func (c *Connection) Write(ctx context.Context, wm []byte) error { c.mu.RLock() defer c.mu.RUnlock() if c.connection == nil { @@ -613,7 +621,7 @@ func (c *Connection) WriteWireMessage(ctx context.Context, wm []byte) error { // ReadWireMessage handles reading a wire message from the underlying connection. The dst parameter // will be overwritten with the new wire message. -func (c *Connection) ReadWireMessage(ctx context.Context) ([]byte, error) { +func (c *Connection) Read(ctx context.Context) ([]byte, error) { c.mu.RLock() defer c.mu.RUnlock() if c.connection == nil { diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index dc774b469b..0294a35be5 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -22,16 +22,17 @@ import ( "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) type testHandshaker struct { - getHandshakeInformation func(context.Context, address.Address, driver.Connection) (driver.HandshakeInformation, error) - finishHandshake func(context.Context, driver.Connection) error + getHandshakeInformation func(context.Context, address.Address, *mnet.Connection) (driver.HandshakeInformation, error) + finishHandshake func(context.Context, *mnet.Connection) error } // GetHandshakeInformation implements the Handshaker interface. -func (th *testHandshaker) GetHandshakeInformation(ctx context.Context, addr address.Address, conn driver.Connection) (driver.HandshakeInformation, error) { +func (th *testHandshaker) GetHandshakeInformation(ctx context.Context, addr address.Address, conn *mnet.Connection) (driver.HandshakeInformation, error) { if th.getHandshakeInformation != nil { return th.getHandshakeInformation(ctx, addr, conn) } @@ -39,7 +40,7 @@ func (th *testHandshaker) GetHandshakeInformation(ctx context.Context, addr addr } // FinishHandshake implements the Handshaker interface. -func (th *testHandshaker) FinishHandshake(ctx context.Context, conn driver.Connection) error { +func (th *testHandshaker) FinishHandshake(ctx context.Context, conn *mnet.Connection) error { if th.finishHandshake != nil { return th.finishHandshake(ctx, conn) } @@ -78,7 +79,7 @@ func TestConnection(t *testing.T) { conn := newConnection(address.Address(""), WithHandshaker(func(Handshaker) Handshaker { return &testHandshaker{ - finishHandshake: func(context.Context, driver.Connection) error { + finishHandshake: func(context.Context, *mnet.Connection) error { return err }, } @@ -302,11 +303,11 @@ func TestConnection(t *testing.T) { var getInfoCtx, finishCtx context.Context handshaker := &testHandshaker{ - getHandshakeInformation: func(ctx context.Context, _ address.Address, _ driver.Connection) (driver.HandshakeInformation, error) { + getHandshakeInformation: func(ctx context.Context, _ address.Address, _ *mnet.Connection) (driver.HandshakeInformation, error) { getInfoCtx = ctx return driver.HandshakeInformation{}, nil }, - finishHandshake: func(ctx context.Context, _ driver.Connection) error { + finishHandshake: func(ctx context.Context, _ *mnet.Connection) error { finishCtx = ctx return nil }, @@ -667,7 +668,7 @@ func TestConnection(t *testing.T) { conn := newConnection(address.Address(""), WithHandshaker(func(Handshaker) Handshaker { return &testHandshaker{ - finishHandshake: func(context.Context, driver.Connection) error { + finishHandshake: func(context.Context, *mnet.Connection) error { return errors.New("handshake err") }, } @@ -712,11 +713,11 @@ func TestConnection(t *testing.T) { var want, got interface{} want = ErrConnectionClosed - got = conn.WriteWireMessage(context.Background(), nil) + got = conn.Write(context.Background(), nil) if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { t.Errorf("errors do not match. got %v; want %v", got, want) } - _, got = conn.ReadWireMessage(context.Background()) + _, got = conn.Read(context.Background()) if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { t.Errorf("errors do not match. got %v; want %v", got, want) } diff --git a/x/mongo/driver/topology/rtt_monitor.go b/x/mongo/driver/topology/rtt_monitor.go index 8b0a4b4950..54b37de048 100644 --- a/x/mongo/driver/topology/rtt_monitor.go +++ b/x/mongo/driver/topology/rtt_monitor.go @@ -14,6 +14,7 @@ import ( "time" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" ) @@ -34,7 +35,7 @@ type rttConfig struct { minRTTWindow time.Duration createConnectionFn func() *connection - createOperationFn func(driver.Connection) *operation.Hello + createOperationFn func(*mnet.Connection) *operation.Hello } type rttMonitor struct { @@ -173,7 +174,9 @@ func (r *rttMonitor) runHellos(conn *connection) { ctx, cancel := context.WithTimeout(r.ctx, timeout) start := time.Now() - err := r.cfg.createOperationFn(initConnection{conn}).Execute(ctx) + iconn := mnet.NewConnection(initConnection{conn}) + + err := r.cfg.createOperationFn(iconn).Execute(ctx) cancel() if err != nil { return diff --git a/x/mongo/driver/topology/rtt_monitor_test.go b/x/mongo/driver/topology/rtt_monitor_test.go index 5fa1cb9bf1..7abfe024fc 100644 --- a/x/mongo/driver/topology/rtt_monitor_test.go +++ b/x/mongo/driver/topology/rtt_monitor_test.go @@ -23,6 +23,7 @@ import ( "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" ) @@ -94,7 +95,7 @@ func TestRTTMonitor(t *testing.T) { createConnectionFn: func() *connection { return newConnection("", WithDialer(func(Dialer) Dialer { return dialer })) }, - createOperationFn: func(conn driver.Connection) *operation.Hello { + createOperationFn: func(conn *mnet.Connection) *operation.Hello { return operation.NewHello().Deployment(driver.SingleConnectionDeployment{C: conn}) }, }) @@ -132,7 +133,7 @@ func TestRTTMonitor(t *testing.T) { return dialer })) }, - createOperationFn: func(conn driver.Connection) *operation.Hello { + createOperationFn: func(conn *mnet.Connection) *operation.Hello { return operation.NewHello().Deployment(driver.SingleConnectionDeployment{C: conn}) }, }) @@ -153,7 +154,7 @@ func TestRTTMonitor(t *testing.T) { createConnectionFn: func() *connection { return newConnection("", WithDialer(func(Dialer) Dialer { return dialer })) }, - createOperationFn: func(conn driver.Connection) *operation.Hello { + createOperationFn: func(conn *mnet.Connection) *operation.Hello { return operation.NewHello().Deployment(driver.SingleConnectionDeployment{C: conn}) }, }) @@ -252,7 +253,7 @@ func TestRTTMonitor(t *testing.T) { createConnectionFn: func() *connection { return newConnection(address.Address(l.Addr().String())) }, - createOperationFn: func(conn driver.Connection) *operation.Hello { + createOperationFn: func(conn *mnet.Connection) *operation.Hello { return operation.NewHello().Deployment(driver.SingleConnectionDeployment{C: conn}) }, }) diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index c62efc0a2a..373f3a861c 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -24,6 +24,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" ) @@ -295,7 +296,7 @@ func (s *Server) Disconnect(ctx context.Context) error { } // Connection gets a connection to the server. -func (s *Server) Connection(ctx context.Context) (driver.Connection, error) { +func (s *Server) Connection(ctx context.Context) (*mnet.Connection, error) { if atomic.LoadInt64(&s.state) != serverConnected { return nil, ErrServerClosed } @@ -310,7 +311,7 @@ func (s *Server) Connection(ctx context.Context) (driver.Connection, error) { return nil, err } - return &Connection{ + serverConn := &Connection{ connection: conn, cleanupServerFn: func() { // Decrement the operation count whenever the caller is done with the connection. Note @@ -321,7 +322,9 @@ func (s *Server) Connection(ctx context.Context) (driver.Connection, error) { // make the server much less selectable. atomic.AddInt64(&s.operationCount, -1) }, - }, nil + } + + return mnet.NewConnection(serverConn), nil } // ProcessHandshakeError implements SDAM error handling for errors that occur before a connection @@ -429,7 +432,7 @@ func getWriteConcernErrorForProcessing(err error) (*driver.WriteConcernError, bo } // ProcessError handles SDAM error handling and implements driver.ErrorProcessor. -func (s *Server) ProcessError(err error, conn driver.Connection) driver.ProcessErrorResult { +func (s *Server) ProcessError(err error, describer mnet.Describer) driver.ProcessErrorResult { // Ignore nil errors. if err == nil { return driver.NoChange @@ -440,7 +443,7 @@ func (s *Server) ProcessError(err error, conn driver.Connection) driver.ProcessE // the pool generation to increment. Processing errors for stale connections could result in // handling the same error root cause multiple times (e.g. a temporary network interrupt causing // all connections to the same server to return errors). - if conn.Stale() { + if describer.Stale() { return driver.NoChange } @@ -453,7 +456,7 @@ func (s *Server) ProcessError(err error, conn driver.Connection) driver.ProcessE // Get the wire version and service ID from the connection description because they will never // change for the lifetime of a connection and can possibly be different between connections to // the same server. - connDesc := conn.Description() + connDesc := describer.Description() wireVersion := connDesc.WireVersion serviceID := connDesc.ServiceID @@ -796,7 +799,7 @@ func (s *Server) checkWasCancelled() bool { return s.heartbeatCtx.Err() != nil } -func (s *Server) createBaseOperation(conn driver.Connection) *operation.Hello { +func (s *Server) createBaseOperation(conn *mnet.Connection) *operation.Hello { return operation. NewHello(). ClusterClock(s.cfg.clock). @@ -854,7 +857,9 @@ func (s *Server) check() (description.Server, error) { // An existing connection is being used. Use the server description properties to execute the right heartbeat. // Wrap conn in a type that implements driver.StreamerConnection. - heartbeatConn := initConnection{s.conn} + iconn := initConnection{s.conn} + heartbeatConn := mnet.NewConnection(iconn) + baseOperation := s.createBaseOperation(heartbeatConn) previousDescription := s.Description() streamable := isStreamingEnabled(s) && isStreamable(s) diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index e23c604156..1bf7aa8f1a 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -35,6 +35,7 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" + "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) @@ -384,7 +385,7 @@ func TestServer(t *testing.T) { return append(connOpts, WithHandshaker(func(Handshaker) Handshaker { return &testHandshaker{ - finishHandshake: func(context.Context, driver.Connection) error { + finishHandshake: func(context.Context, *mnet.Connection) error { var err error if tt.connectionError && returnConnectionError { err = authErr.Wrapped @@ -510,7 +511,7 @@ func TestServer(t *testing.T) { } handshaker := &testHandshaker{ - getHandshakeInformation: func(_ context.Context, addr address.Address, _ driver.Connection) (driver.HandshakeInformation, error) { + getHandshakeInformation: func(_ context.Context, addr address.Address, _ *mnet.Connection) (driver.HandshakeInformation, error) { if tc.getInfoErr != nil && returnConnectionError { return driver.HandshakeInformation{}, tc.getInfoErr } @@ -521,7 +522,7 @@ func TestServer(t *testing.T) { } return driver.HandshakeInformation{Description: desc}, nil }, - finishHandshake: func(context.Context, driver.Connection) error { + finishHandshake: func(context.Context, *mnet.Connection) error { if tc.finishHandshakeErr != nil && returnConnectionError { return tc.finishHandshakeErr } @@ -858,8 +859,8 @@ func TestServer_ProcessError(t *testing.T) { startDescription description.Server // Initial server description at the start of the test. - inputErr error // ProcessError error input. - inputConn driver.Connection // ProcessError conn input. + inputErr error // ProcessError error input. + inputConn *mnet.Connection // ProcessError conn input. want driver.ProcessErrorResult // Expected ProcessError return value. wantGeneration uint64 // Expected resulting connection pool generation. @@ -884,12 +885,8 @@ func TestServer_ProcessError(t *testing.T) { startDescription: description.Server{ Kind: description.RSPrimary, }, - inputErr: errors.New("foo"), - inputConn: newProcessErrorTestConn( - &description.VersionRange{ - Max: 17, - }, - true), + inputErr: errors.New("foo"), + inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 17}, true), want: driver.NoChange, wantGeneration: 0, wantDescription: description.Server{ @@ -988,7 +985,7 @@ func TestServer_ProcessError(t *testing.T) { Counter: 1, }, }, - inputConn: &processErrorTestConn{ + inputConn: mnet.NewConnection(&processErrorTestConn{ description: description.Server{ WireVersion: &description.VersionRange{Max: 17}, TopologyVersion: &description.TopologyVersion{ @@ -997,7 +994,7 @@ func TestServer_ProcessError(t *testing.T) { }, }, stale: false, - }, + }), want: driver.NoChange, wantGeneration: 0, wantDescription: newServerDescription(description.RSPrimary, processID, 0, nil), @@ -1262,20 +1259,23 @@ func includesClientMetadata(t *testing.T, wm []byte) bool { // for Server.ProcessError. This type should not be used for other tests // because it does not implement all of the functions of the interface. type processErrorTestConn struct { + mnet.ReadWriteCloser + mnet.Describer // Embed a driver.Connection to quickly implement the interface without // implementing all methods. - driver.Connection description description.Server stale bool } -func newProcessErrorTestConn(wireVersion *description.VersionRange, stale bool) *processErrorTestConn { - return &processErrorTestConn{ +func newProcessErrorTestConn(wireVersion *description.VersionRange, stale bool) *mnet.Connection { + peconn := &processErrorTestConn{ description: description.Server{ WireVersion: wireVersion, }, stale: stale, } + + return mnet.NewConnection(peconn) } func (p *processErrorTestConn) Stale() bool {