Skip to content

Commit

Permalink
GODRIVER-3058 Add mnet constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
prestonvasquez committed Nov 28, 2023
1 parent f444c07 commit 3625419
Show file tree
Hide file tree
Showing 17 changed files with 64 additions and 187 deletions.
1 change: 0 additions & 1 deletion mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ func newClient(opts ...*options.ClientOptions) (*Client, error) {
client.serverAPI = topology.ServerAPIFromServerOptions(cfg.ServerOpts)

if client.deployment == nil {
fmt.Println("Get a new depy")
client.deployment, err = topology.New(cfg)
if err != nil {
return nil, replaceErrors(err)
Expand Down
4 changes: 1 addition & 3 deletions mongo/integration/mtest/opmsg_deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,7 @@ func (md *mockDeployment) Kind() description.TopologyKind {

// Connection implements the driver.Server interface.
func (md *mockDeployment) Connection(context.Context) (*mnet.Connection, error) {
return &mnet.Connection{
WireMessageReadWriteCloser: md.conn,
Describer: md.conn}, nil
return mnet.NewConnection(md.conn), nil
}

// RTTMonitor implements the driver.Server interface.
Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/auth/gssapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestGSSAPIAuthenticator(t *testing.T) {
chanconn := &drivertest.ChannelConn{
Desc: desc,
}
err := authenticator.Auth(context.Background(), &Config{Connection: &mnet.Connection{Describer: chanconn}})
err := authenticator.Auth(context.Background(), &Config{Connection: mnet.NewConnection(chanconn)})
if err == nil {
t.Fatalf("expected err, got nil")
}
Expand Down
10 changes: 2 additions & 8 deletions x/mongo/driver/auth/mongodbcr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ func TestMongoDBCRAuthenticator_Fails(t *testing.T) {
Desc: desc,
}

mnetconn := &mnet.Connection{
WireMessageReadWriteCloser: c,
Describer: c,
}
mnetconn := mnet.NewConnection(c)

err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn})
if err == nil {
Expand Down Expand Up @@ -91,10 +88,7 @@ func TestMongoDBCRAuthenticator_Succeeds(t *testing.T) {
Desc: desc,
}

mnetconn := &mnet.Connection{
WireMessageReadWriteCloser: c,
Describer: c,
}
mnetconn := mnet.NewConnection(c)

err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn})
if err != nil {
Expand Down
20 changes: 4 additions & 16 deletions x/mongo/driver/auth/plain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,7 @@ func TestPlainAuthenticator_Fails(t *testing.T) {
Desc: desc,
}

mnetconn := &mnet.Connection{
WireMessageReadWriteCloser: c,
Describer: c,
}
mnetconn := mnet.NewConnection(c)

err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn})
if err == nil {
Expand Down Expand Up @@ -97,10 +94,7 @@ func TestPlainAuthenticator_Extra_server_message(t *testing.T) {
Desc: desc,
}

mnetconn := &mnet.Connection{
WireMessageReadWriteCloser: c,
Describer: c,
}
mnetconn := mnet.NewConnection(c)

err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn})
if err == nil {
Expand Down Expand Up @@ -140,10 +134,7 @@ func TestPlainAuthenticator_Succeeds(t *testing.T) {
Desc: desc,
}

mnetconn := &mnet.Connection{
WireMessageReadWriteCloser: c,
Describer: c,
}
mnetconn := mnet.NewConnection(c)

err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn})
if err != nil {
Expand Down Expand Up @@ -190,10 +181,7 @@ func TestPlainAuthenticator_SucceedsBoolean(t *testing.T) {
Desc: desc,
}

mnetconn := &mnet.Connection{
WireMessageReadWriteCloser: c,
Describer: c,
}
mnetconn := mnet.NewConnection(c)

err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn})
require.NoError(t, err, "Auth error")
Expand Down
5 changes: 1 addition & 4 deletions x/mongo/driver/auth/scram_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,7 @@ func TestSCRAM(t *testing.T) {
Desc: desc,
}

conn := &mnet.Connection{
WireMessageReadWriteCloser: chanconn,
Describer: chanconn,
}
conn := mnet.NewConnection(chanconn)

err = authenticator.Auth(context.Background(), &Config{Connection: conn})
assert.Nil(t, err, "Auth error: %v\n", err)
Expand Down
10 changes: 2 additions & 8 deletions x/mongo/driver/auth/speculative_scram_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,7 @@ func TestSpeculativeSCRAM(t *testing.T) {
ReadResp: responses,
}

mnetconn := &mnet.Connection{
WireMessageReadWriteCloser: conn,
Describer: conn,
}
mnetconn := mnet.NewConnection(conn)

// Do both parts of the handshake.
info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), mnetconn)
Expand Down Expand Up @@ -171,10 +168,7 @@ func TestSpeculativeSCRAM(t *testing.T) {
ReadResp: responses,
}

mnetconn := &mnet.Connection{
WireMessageReadWriteCloser: conn,
Describer: conn,
}
mnetconn := mnet.NewConnection(conn)

info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), mnetconn)
assert.Nil(t, err, "GetHandshakeInformation error: %v", err)
Expand Down
10 changes: 2 additions & 8 deletions x/mongo/driver/auth/speculative_x509_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,7 @@ func TestSpeculativeX509(t *testing.T) {
ReadResp: responses,
}

mnetconn := &mnet.Connection{
WireMessageReadWriteCloser: conn,
Describer: conn,
}
mnetconn := mnet.NewConnection(conn)

info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), mnetconn)
assert.Nil(t, err, "GetDescription error: %v", err)
Expand Down Expand Up @@ -97,10 +94,7 @@ func TestSpeculativeX509(t *testing.T) {
ReadResp: responses,
}

mnetconn := &mnet.Connection{
WireMessageReadWriteCloser: conn,
Describer: conn,
}
mnetconn := mnet.NewConnection(conn)

info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), mnetconn)
assert.Nil(t, err, "GetDescription error: %v", err)
Expand Down
6 changes: 1 addition & 5 deletions x/mongo/driver/batch_cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,11 +524,7 @@ func (lbcd *loadBalancedCursorDeployment) Kind() description.TopologyKind {
}

func (lbcd *loadBalancedCursorDeployment) Connection(context.Context) (*mnet.Connection, error) {
return &mnet.Connection{
WireMessageReadWriteCloser: lbcd.conn,
Describer: lbcd.conn,
Pinner: lbcd.conn,
}, nil
return mnet.NewConnection(lbcd), nil
}

// RTTMonitor implements the driver.Server interface.
Expand Down
62 changes: 0 additions & 62 deletions x/mongo/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,24 +56,6 @@ type Server interface {
RTTMonitor() RTTMonitor
}

// Connection represents a connection to a MongoDB server.
//type Connection interface {
// WriteWireMessage(context.Context, []byte) error
// ReadWireMessage(ctx context.Context) ([]byte, error)
// Description() description.Server
//
// // Close closes any underlying connection and returns or frees any resources held by the
// // connection. Close is idempotent and can be called multiple times, although subsequent calls
// // to Close may return an error. A connection cannot be used after it is closed.
// Close() error
//
// ID() string
// ServerConnectionID() *int64
// DriverConnectionID() int64
// Address() address.Address
// Stale() bool
//}

// RTTMonitor represents a round-trip-time monitor.
type RTTMonitor interface {
// EWMA returns the exponentially weighted moving average observed round-trip time.
Expand All @@ -91,27 +73,6 @@ type RTTMonitor interface {

var _ RTTMonitor = &csot.ZeroRTTMonitor{}

// PinnedConnection represents a Connection that can be pinned by one or more cursors or transactions. Implementations
// of this interface should maintain the following invariants:
//
// 1. Each Pin* call should increment the number of references for the connection.
// 2. Each Unpin* call should decrement the number of references for the connection.
// 3. Calls to Close() should be ignored until all resources have unpinned the connection.
//type PinnedConnection interface {
// //Connection
// mnet.WireMessageReadWriteCloser
// mnet.Describer
// PinToCursor() error
// PinToTransaction() error
// UnpinFromCursor() error
// UnpinFromTransaction() error
//}

// The session.LoadBalancedTransactionConnection type is a copy of PinnedConnection that was introduced to avoid
// import cycles. This compile-time assertion ensures that these types remain in sync if the PinnedConnection interface
// is changed in the future.
// var _ PinnedConnection = (session.LoadBalancedTransactionConnection)(nil)

// LocalAddresser is a type that is able to supply its local address
type LocalAddresser interface {
LocalAddress() address.Address
Expand All @@ -123,29 +84,6 @@ type Expirable interface {
Alive() bool
}

// StreamerConnection represents a Connection that supports streaming wire protocol messages using the moreToCome and
// exhaustAllowed flags.
//
// The SetStreaming and CurrentlyStreaming functions correspond to the moreToCome flag on server responses. If a
// response has moreToCome set, SetStreaming(true) will be called and CurrentlyStreaming() should return true.
//
// CanStream corresponds to the exhaustAllowed flag. The operations layer will set exhaustAllowed on outgoing wire
// messages to inform the server that the driver supports streaming.
type StreamerConnection interface {
mnet.WireMessageReadWriteCloser
mnet.Describer
SetStreaming(bool)
CurrentlyStreaming() bool
SupportsStreaming() bool
}

// Compressor is an interface used to compress wire messages. If a Connection supports compression
// it should implement this interface as well. The CompressWireMessage method will be called during
// the execution of an operation if the wire message is allowed to be compressed.
type Compressor interface {
CompressWireMessage(src, dst []byte) ([]byte, error)
}

// ProcessErrorResult represents the result of a ErrorProcessor.ProcessError() call. Exact values for this type can be
// checked directly (e.g. res == ServerMarkedUnknown), but it is recommended that applications use the ServerChanged()
// function instead.
Expand Down
27 changes: 27 additions & 0 deletions x/mongo/driver/mnet/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,30 @@ type Connection struct {
Compressor
Pinner
}

// NewConnection creates a new Connection with the provided component.
func NewConnection(component interface{}) *Connection {
conn := &Connection{}

if describer, ok := component.(Describer); ok {
conn.Describer = describer
}

if streamer, ok := component.(Streamer); ok {
conn.Streamer = streamer
}

if compressor, ok := component.(Compressor); ok {
conn.Compressor = compressor
}

if pinner, ok := component.(Pinner); ok {
conn.Pinner = pinner
}

if rwc, ok := component.(WireMessageReadWriteCloser); ok {
conn.WireMessageReadWriteCloser = rwc
}

return conn
}
8 changes: 1 addition & 7 deletions x/mongo/driver/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,13 +422,7 @@ func (op Operation) getServerAndConnection(
// If the provided client session has a pinned connection, it should be used for the operation because this
// indicates that we're in a transaction and the target server is behind a load balancer.
if op.Client != nil && op.Client.PinnedConnection != nil {
pinnedConn := &mnet.Connection{
WireMessageReadWriteCloser: op.Client.PinnedConnection,
Describer: op.Client.PinnedConnection,
Pinner: op.Client.PinnedConnection,
}

return server, pinnedConn, nil
return server, mnet.NewConnection(op.Client.PinnedConnection), nil
}

// Otherwise, default to checking out a connection from the server's pool.
Expand Down
32 changes: 6 additions & 26 deletions x/mongo/driver/operation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -554,11 +554,7 @@ func TestOperation(t *testing.T) {
conn := &mockConnection{
rStreaming: false,
}
err := Operation{}.ExecuteExhaust(context.TODO(), &mnet.Connection{
WireMessageReadWriteCloser: conn,
Describer: conn,
Streamer: conn,
})
err := Operation{}.ExecuteExhaust(context.TODO(), mnet.NewConnection(conn))
assert.NotNil(t, err, "expected error, got nil")
})
})
Expand All @@ -582,11 +578,7 @@ func TestOperation(t *testing.T) {
rReadWM: nonStreamingResponse,
rCanStream: false,
}
mnetconn := &mnet.Connection{
WireMessageReadWriteCloser: conn,
Describer: conn,
Streamer: conn,
}
mnetconn := mnet.NewConnection(conn)
op := Operation{
CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
return bsoncore.AppendInt32Element(dst, handshake.LegacyHello, 1), nil
Expand All @@ -611,12 +603,6 @@ func TestOperation(t *testing.T) {
assertExhaustAllowedSet(t, conn.pWriteWM, true)
assert.True(t, conn.CurrentlyStreaming(), "expected CurrentlyStreaming to be true")

//mnetstreamer := &mnet.StreamerConnection{
// WireMessageReadWriteCloser: conn,
// Describer: conn,
// Streamer: conn,
//}

// Reset the server response and go through ExecuteExhaust to mimic streaming the next response. After
// execution, the connection should still be in a streaming state.
conn.rReadWM = streamingResponse
Expand All @@ -631,11 +617,8 @@ func TestOperation(t *testing.T) {
defer cancel()

op := Operation{
Database: "foobar",
Deployment: SingleConnectionDeployment{C: &mnet.Connection{
WireMessageReadWriteCloser: conn,
Describer: conn,
}},
Database: "foobar",
Deployment: SingleConnectionDeployment{C: mnet.NewConnection(conn)},
CommandFn: func(dst []byte, _ description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendInt32Element(dst, "ping", 1)
return dst, nil
Expand All @@ -655,11 +638,8 @@ func TestOperation(t *testing.T) {
cancel()

op := Operation{
Database: "foobar",
Deployment: SingleConnectionDeployment{C: &mnet.Connection{
WireMessageReadWriteCloser: conn,
Describer: conn,
}},
Database: "foobar",
Deployment: SingleConnectionDeployment{C: mnet.NewConnection(conn)},
CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendInt32Element(dst, "ping", 1)
return dst, nil
Expand Down
8 changes: 3 additions & 5 deletions x/mongo/driver/topology/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,8 @@ func (c *connection) connect(ctx context.Context) (err error) {
handshakeStartTime := time.Now()

iconn := initConnection{c}
handshakeConn := &mnet.Connection{
WireMessageReadWriteCloser: iconn,
Describer: iconn,
}

handshakeConn := mnet.NewConnection(iconn)

handshakeInfo, err = handshaker.GetHandshakeInformation(handshakeCtx, c.addr, handshakeConn)
if err == nil {
Expand Down Expand Up @@ -550,7 +548,7 @@ type initConnection struct{ *connection }

var _ mnet.WireMessageReadWriteCloser = initConnection{}
var _ mnet.Describer = initConnection{}
var _ driver.StreamerConnection = initConnection{}
var _ mnet.Streamer = initConnection{}

func (c initConnection) Description() description.Server {
if c.connection == nil {
Expand Down
Loading

0 comments on commit 3625419

Please sign in to comment.