Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-3058 Centralize x-package Connection interface as a struct #1475

Merged
merged 16 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions internal/integration/mtest/opmsg_deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/mnet"
"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)
Expand Down Expand Up @@ -50,15 +51,16 @@ type connection struct {
responses []bson.D // responses to send when ReadWireMessage is called
}

var _ driver.Connection = &connection{}
var _ mnet.ReadWriteCloser = &connection{}
var _ mnet.Describer = &connection{}

// WriteWireMessage is a no-op.
func (c *connection) WriteWireMessage(context.Context, []byte) error {
// Write is a no-op.
func (c *connection) Write(context.Context, []byte) error {
return nil
}

// ReadWireMessage returns the next response in the connection's list of responses.
func (c *connection) ReadWireMessage(_ context.Context) ([]byte, error) {
// Read returns the next response in the connection's list of responses.
func (c *connection) Read(_ context.Context) ([]byte, error) {
var dst []byte
if len(c.responses) == 0 {
return dst, errors.New("no responses remaining")
Expand Down Expand Up @@ -137,8 +139,8 @@ func (md *mockDeployment) Kind() description.TopologyKind {
}

// Connection implements the driver.Server interface.
func (md *mockDeployment) Connection(context.Context) (driver.Connection, error) {
return md.conn, nil
func (md *mockDeployment) Connection(context.Context) (*mnet.Connection, error) {
return mnet.NewConnection(md.conn), nil
}

// RTTMonitor implements the driver.Server interface.
Expand Down
5 changes: 3 additions & 2 deletions mongo/change_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/mnet"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
Expand Down Expand Up @@ -282,7 +283,7 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in
return cs, cs.Err()
}

func (cs *ChangeStream) createOperationDeployment(server driver.Server, connection driver.Connection) driver.Deployment {
func (cs *ChangeStream) createOperationDeployment(server driver.Server, connection *mnet.Connection) driver.Deployment {
return &changeStreamDeployment{
topologyKind: cs.client.deployment.Kind(),
server: server,
Expand All @@ -292,7 +293,7 @@ func (cs *ChangeStream) createOperationDeployment(server driver.Server, connecti

func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) error {
var server driver.Server
var conn driver.Connection
var conn *mnet.Connection

if server, cs.err = cs.client.deployment.SelectServer(ctx, cs.selector); cs.err != nil {
return cs.Err()
Expand Down
9 changes: 5 additions & 4 deletions mongo/change_stream_deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ import (

"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/mnet"
)

type changeStreamDeployment struct {
topologyKind description.TopologyKind
server driver.Server
conn driver.Connection
conn *mnet.Connection
}

var _ driver.Deployment = (*changeStreamDeployment)(nil)
Expand All @@ -31,19 +32,19 @@ func (c *changeStreamDeployment) Kind() description.TopologyKind {
return c.topologyKind
}

func (c *changeStreamDeployment) Connection(context.Context) (driver.Connection, error) {
func (c *changeStreamDeployment) Connection(context.Context) (*mnet.Connection, error) {
return c.conn, nil
}

func (c *changeStreamDeployment) RTTMonitor() driver.RTTMonitor {
return c.server.RTTMonitor()
}

func (c *changeStreamDeployment) ProcessError(err error, conn driver.Connection) driver.ProcessErrorResult {
func (c *changeStreamDeployment) ProcessError(err error, describer mnet.Describer) driver.ProcessErrorResult {
ep, ok := c.server.(driver.ErrorProcessor)
if !ok {
return driver.NoChange
}

return ep.ProcessError(err, conn)
return ep.ProcessError(err, describer)
}
16 changes: 9 additions & 7 deletions x/mongo/driver/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/mnet"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
Expand Down Expand Up @@ -76,7 +77,11 @@ var _ driver.Handshaker = (*authHandshaker)(nil)

// GetHandshakeInformation performs the initial MongoDB handshake to retrieve the required information for the provided
// connection.
func (ah *authHandshaker) GetHandshakeInformation(ctx context.Context, addr address.Address, conn driver.Connection) (driver.HandshakeInformation, error) {
func (ah *authHandshaker) GetHandshakeInformation(
ctx context.Context,
addr address.Address,
conn *mnet.Connection,
) (driver.HandshakeInformation, error) {
if ah.wrapped != nil {
return ah.wrapped.GetHandshakeInformation(ctx, addr, conn)
}
Expand Down Expand Up @@ -115,7 +120,7 @@ func (ah *authHandshaker) GetHandshakeInformation(ctx context.Context, addr addr
}

// FinishHandshake performs authentication for conn if necessary.
func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn driver.Connection) error {
func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn *mnet.Connection) error {
performAuth := ah.options.PerformAuthentication
if performAuth == nil {
performAuth = func(serv description.Server) bool {
Expand All @@ -124,10 +129,8 @@ func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn driver.Conne
}
}

desc := conn.Description()
if performAuth(desc) && ah.options.Authenticator != nil {
if performAuth(conn.Description()) && ah.options.Authenticator != nil {
cfg := &Config{
Description: desc,
Connection: conn,
ClusterClock: ah.options.ClusterClock,
HandshakeInfo: ah.handshakeInfo,
Expand Down Expand Up @@ -172,8 +175,7 @@ func Handshaker(h driver.Handshaker, options *HandshakeOptions) driver.Handshake

// Config holds the information necessary to perform an authentication attempt.
type Config struct {
Description description.Server
Connection driver.Connection
Connection *mnet.Connection
ClusterClock *session.ClusterClock
HandshakeInfo driver.HandshakeInformation
ServerAPI *driver.ServerAPIOptions
Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/auth/gssapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type GSSAPIAuthenticator struct {

// Auth authenticates the connection.
func (a *GSSAPIAuthenticator) Auth(ctx context.Context, cfg *Config) error {
target := cfg.Description.Addr.String()
target := cfg.Connection.Description().Addr.String()
hostname, _, err := net.SplitHostPort(target)
if err != nil {
return newAuthError(fmt.Sprintf("invalid endpoint (%s) specified: %s", target, err), nil)
Expand Down
10 changes: 9 additions & 1 deletion x/mongo/driver/auth/gssapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (

"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
"go.mongodb.org/mongo-driver/x/mongo/driver/mnet"
)

func TestGSSAPIAuthenticator(t *testing.T) {
Expand All @@ -36,7 +38,13 @@ func TestGSSAPIAuthenticator(t *testing.T) {
},
Addr: address.Address("foo:27017"),
}
err := authenticator.Auth(context.Background(), &Config{Description: desc})
chanconn := &drivertest.ChannelConn{
Desc: desc,
}

mnetconn := mnet.NewConnection(chanconn)

err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn})
if err == nil {
t.Fatalf("expected err, got nil")
}
Expand Down
4 changes: 2 additions & 2 deletions x/mongo/driver/auth/mongodbcr.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, cfg *Config) error {
doc := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendInt32Element(nil, "getnonce", 1))
cmd := operation.NewCommand(doc).
Database(db).
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
Deployment(driver.SingleConnectionDeployment{C: cfg.Connection}).
ClusterClock(cfg.ClusterClock).
ServerAPI(cfg.ServerAPI)
err := cmd.Execute(ctx)
Expand All @@ -86,7 +86,7 @@ func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, cfg *Config) error {
)
cmd = operation.NewCommand(doc).
Database(db).
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
Deployment(driver.SingleConnectionDeployment{C: cfg.Connection}).
ClusterClock(cfg.ClusterClock).
ServerAPI(cfg.ServerAPI)
err = cmd.Execute(ctx)
Expand Down
9 changes: 7 additions & 2 deletions x/mongo/driver/auth/mongodbcr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
. "go.mongodb.org/mongo-driver/x/mongo/driver/auth"
"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
"go.mongodb.org/mongo-driver/x/mongo/driver/mnet"
)

func TestMongoDBCRAuthenticator_Fails(t *testing.T) {
Expand Down Expand Up @@ -46,7 +47,9 @@ func TestMongoDBCRAuthenticator_Fails(t *testing.T) {
Desc: desc,
}

err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c})
mnetconn := mnet.NewConnection(c)

err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn})
if err == nil {
t.Fatalf("expected an error but got none")
}
Expand Down Expand Up @@ -85,7 +88,9 @@ func TestMongoDBCRAuthenticator_Succeeds(t *testing.T) {
Desc: desc,
}

err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c})
mnetconn := mnet.NewConnection(c)

err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn})
if err != nil {
t.Fatalf("expected no error but got \"%s\"", err)
}
Expand Down
17 changes: 13 additions & 4 deletions x/mongo/driver/auth/plain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
. "go.mongodb.org/mongo-driver/x/mongo/driver/auth"
"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
"go.mongodb.org/mongo-driver/x/mongo/driver/mnet"
)

func TestPlainAuthenticator_Fails(t *testing.T) {
Expand Down Expand Up @@ -48,7 +49,9 @@ func TestPlainAuthenticator_Fails(t *testing.T) {
Desc: desc,
}

err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c})
mnetconn := mnet.NewConnection(c)

err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn})
if err == nil {
t.Fatalf("expected an error but got none")
}
Expand Down Expand Up @@ -91,7 +94,9 @@ func TestPlainAuthenticator_Extra_server_message(t *testing.T) {
Desc: desc,
}

err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c})
mnetconn := mnet.NewConnection(c)

err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn})
if err == nil {
t.Fatalf("expected an error but got none")
}
Expand Down Expand Up @@ -129,7 +134,9 @@ func TestPlainAuthenticator_Succeeds(t *testing.T) {
Desc: desc,
}

err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c})
mnetconn := mnet.NewConnection(c)

err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn})
if err != nil {
t.Fatalf("expected no error but got \"%s\"", err)
}
Expand Down Expand Up @@ -174,7 +181,9 @@ func TestPlainAuthenticator_SucceedsBoolean(t *testing.T) {
Desc: desc,
}

err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c})
mnetconn := mnet.NewConnection(c)

err := authenticator.Auth(context.Background(), &Config{Connection: mnetconn})
require.NoError(t, err, "Auth error")
require.Len(t, c.Written, 1, "expected 1 messages to be sent")

Expand Down
11 changes: 7 additions & 4 deletions x/mongo/driver/auth/scram_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
"go.mongodb.org/mongo-driver/x/mongo/driver/mnet"
)

const (
Expand Down Expand Up @@ -68,18 +69,20 @@ func TestSCRAM(t *testing.T) {
Max: 21,
},
}
conn := &drivertest.ChannelConn{
chanconn := &drivertest.ChannelConn{
Written: make(chan []byte, len(tc.payloads)),
ReadResp: responses,
Desc: desc,
}

err = authenticator.Auth(context.Background(), &Config{Description: desc, Connection: conn})
conn := mnet.NewConnection(chanconn)

err = authenticator.Auth(context.Background(), &Config{Connection: conn})
assert.Nil(t, err, "Auth error: %v\n", err)

// Verify that the first command sent is saslStart.
assert.True(t, len(conn.Written) > 1, "wire messages were written to the connection")
startCmd, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
assert.True(t, len(chanconn.Written) > 1, "wire messages were written to the connection")
startCmd, err := drivertest.GetCommandFromMsgWireMessage(<-chanconn.Written)
assert.Nil(t, err, "error parsing wire message: %v", err)
cmdName := startCmd.Index(0).Key()
assert.Equal(t, cmdName, "saslStart", "cmd name mismatch; expected 'saslStart', got %v", cmdName)
Expand Down
13 changes: 9 additions & 4 deletions x/mongo/driver/auth/speculative_scram_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
"go.mongodb.org/mongo-driver/x/mongo/driver/mnet"
)

var (
Expand Down Expand Up @@ -80,13 +81,15 @@ func TestSpeculativeSCRAM(t *testing.T) {
ReadResp: responses,
}

mnetconn := mnet.NewConnection(conn)

// Do both parts of the handshake.
info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn)
info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), mnetconn)
assert.Nil(t, err, "GetHandshakeInformation error: %v", err)
assert.NotNil(t, info.SpeculativeAuthenticate, "desc.SpeculativeAuthenticate not set")
conn.Desc = info.Description // Set conn.Desc so the new description will be used for the authentication.

err = handshaker.FinishHandshake(context.Background(), conn)
err = handshaker.FinishHandshake(context.Background(), mnetconn)
assert.Nil(t, err, "FinishHandshake error: %v", err)
assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp))

Expand Down Expand Up @@ -165,13 +168,15 @@ func TestSpeculativeSCRAM(t *testing.T) {
ReadResp: responses,
}

info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn)
mnetconn := mnet.NewConnection(conn)

info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), mnetconn)
assert.Nil(t, err, "GetHandshakeInformation error: %v", err)
assert.Nil(t, info.SpeculativeAuthenticate, "expected desc.SpeculativeAuthenticate to be unset, got %s",
bson.Raw(info.SpeculativeAuthenticate))
conn.Desc = info.Description

err = handshaker.FinishHandshake(context.Background(), conn)
err = handshaker.FinishHandshake(context.Background(), mnetconn)
assert.Nil(t, err, "FinishHandshake error: %v", err)
assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp))

Expand Down
Loading
Loading