Skip to content

Commit

Permalink
GODRIVER-2935 Resolve merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
prestonvasquez committed Sep 7, 2023
2 parents c0085e1 + 68bf155 commit 9569e0c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 18 deletions.
28 changes: 17 additions & 11 deletions x/mongo/driver/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ func (op Operation) shouldEncrypt() bool {
}

// selectServer handles performing server selection for an operation.
func (op Operation) selectServer(ctx context.Context) (Server, error) {
func (op Operation) selectServer(ctx context.Context, requestID int32) (Server, error) {
if err := op.Validate(); err != nil {
return nil, err
}
Expand All @@ -341,14 +341,14 @@ func (op Operation) selectServer(ctx context.Context) (Server, error) {
}

ctx = logger.WithOperationName(ctx, op.Name)
ctx = logger.WithOperationID(ctx, wiremessage.CurrentRequestID())
ctx = logger.WithOperationID(ctx, requestID)

return op.Deployment.SelectServer(ctx, selector)
}

// getServerAndConnection should be used to retrieve a Server and Connection to execute an operation.
func (op Operation) getServerAndConnection(ctx context.Context) (Server, Connection, error) {
server, err := op.selectServer(ctx)
func (op Operation) getServerAndConnection(ctx context.Context, requestID int32) (Server, Connection, error) {
server, err := op.selectServer(ctx, requestID)
if err != nil {
if op.Client != nil &&
!(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() {
Expand Down Expand Up @@ -531,11 +531,11 @@ func (op Operation) Execute(ctx context.Context) error {
}
}()
for {
wiremessage.NextRequestID()
requestID := wiremessage.NextRequestID()

// If the server or connection are nil, try to select a new server and get a new connection.
if srvr == nil || conn == nil {
srvr, conn, err = op.getServerAndConnection(ctx)
srvr, conn, err = op.getServerAndConnection(ctx, requestID)
if err != nil {
// If the returned error is retryable and there are retries remaining (negative
// retries means retry indefinitely), then retry the operation. Set the server
Expand Down Expand Up @@ -630,7 +630,7 @@ func (op Operation) Execute(ctx context.Context) error {
}

var startedInfo startedInformation
*wm, startedInfo, err = op.createWireMessage(ctx, (*wm)[:0], desc, maxTimeMS, conn)
*wm, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID)

if err != nil {
return err
Expand Down Expand Up @@ -1183,8 +1183,13 @@ func (op Operation) createLegacyHandshakeWireMessage(
return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil
}

func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64, dst []byte, desc description.SelectedServer,
func (op Operation) createMsgWireMessage(
ctx context.Context,
maxTimeMS uint64,
dst []byte,
desc description.SelectedServer,
conn Connection,
requestID int32,
) ([]byte, startedInformation, error) {
var info startedInformation
var flags wiremessage.MsgFlag
Expand All @@ -1200,7 +1205,7 @@ func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64,
flags |= wiremessage.ExhaustAllowed
}

info.requestID = wiremessage.CurrentRequestID()
info.requestID = requestID
wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpMsg)
dst = wiremessage.AppendMsgFlags(dst, flags)
// Body
Expand Down Expand Up @@ -1276,16 +1281,17 @@ func isLegacyHandshake(op Operation, desc description.SelectedServer) bool {

func (op Operation) createWireMessage(
ctx context.Context,
maxTimeMS uint64,
dst []byte,
desc description.SelectedServer,
maxTimeMS uint64,
conn Connection,
requestID int32,
) ([]byte, startedInformation, error) {
if isLegacyHandshake(op, desc) {
return op.createLegacyHandshakeWireMessage(maxTimeMS, dst, desc)
}

return op.createMsgWireMessage(ctx, maxTimeMS, dst, desc, conn)
return op.createMsgWireMessage(ctx, maxTimeMS, dst, desc, conn, requestID)
}

// addCommandFields adds the fields for a command to the wire message in dst. This assumes that the start of the document
Expand Down
9 changes: 5 additions & 4 deletions x/mongo/driver/operation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestOperation(t *testing.T) {
t.Run("selectServer", func(t *testing.T) {
t.Run("returns validation error", func(t *testing.T) {
op := &Operation{}
_, err := op.selectServer(context.Background())
_, err := op.selectServer(context.Background(), 1)
if err == nil {
t.Error("Expected a validation error from selectServer, but got <nil>")
}
Expand All @@ -76,7 +76,7 @@ func TestOperation(t *testing.T) {
Database: "testing",
Selector: want,
}
_, err := op.selectServer(context.Background())
_, err := op.selectServer(context.Background(), 1)
noerr(t, err)
got := d.params.selector
if !cmp.Equal(got, want) {
Expand All @@ -90,7 +90,7 @@ func TestOperation(t *testing.T) {
Deployment: d,
Database: "testing",
}
_, err := op.selectServer(context.Background())
_, err := op.selectServer(context.Background(), 1)
noerr(t, err)
if d.params.selector == nil {
t.Error("The selectServer method should use a default selector when not specified on Operation, but it passed <nil>.")
Expand Down Expand Up @@ -652,7 +652,8 @@ func TestOperation(t *testing.T) {
}

func createExhaustServerResponse(response bsoncore.Document, moreToCome bool) []byte {
idx, wm := wiremessage.AppendHeaderStart(nil, 0, wiremessage.CurrentRequestID()+1, wiremessage.OpMsg)
const psuedoRequestID = 1
idx, wm := wiremessage.AppendHeaderStart(nil, 0, psuedoRequestID, wiremessage.OpMsg)
var flags wiremessage.MsgFlag
if moreToCome {
flags = wiremessage.MoreToCome
Expand Down
3 changes: 0 additions & 3 deletions x/mongo/driver/wiremessage/wiremessage.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ type WireMessage []byte

var globalRequestID int32

// CurrentRequestID returns the current request ID.
func CurrentRequestID() int32 { return atomic.LoadInt32(&globalRequestID) }

// NextRequestID returns the next request ID.
func NextRequestID() int32 { return atomic.AddInt32(&globalRequestID, 1) }

Expand Down

0 comments on commit 9569e0c

Please sign in to comment.