Skip to content

Commit

Permalink
GODRIVER-2117 - Check clientSession is not nil inside executeTestRunn…
Browse files Browse the repository at this point in the history
…erOperation (#1457)
  • Loading branch information
kumarlokesh authored Nov 8, 2023
1 parent d52c9e1 commit fd1ca35
Showing 1 changed file with 31 additions and 14 deletions.
45 changes: 31 additions & 14 deletions mongo/integration/unified_spec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,46 +462,64 @@ func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation,

var fp mtest.FailPoint
if err := bson.Unmarshal(fpDoc.Document(), &fp); err != nil {
return fmt.Errorf("Unmarshal error: %v", err)
return fmt.Errorf("Unmarshal error: %w", err)
}

if clientSession == nil {
return errors.New("expected valid session, got nil")
}
targetHost := clientSession.PinnedServer.Addr.String()
opts := options.Client().ApplyURI(mtest.ClusterURI()).SetHosts([]string{targetHost})
integtest.AddTestServerAPIVersion(opts)
client, err := mongo.Connect(context.Background(), opts)
if err != nil {
return fmt.Errorf("Connect error for targeted client: %v", err)
return fmt.Errorf("Connect error for targeted client: %w", err)
}
defer func() { _ = client.Disconnect(context.Background()) }()

if err = client.Database("admin").RunCommand(context.Background(), fp).Err(); err != nil {
return fmt.Errorf("error setting targeted fail point: %v", err)
return fmt.Errorf("error setting targeted fail point: %w", err)
}
mt.TrackFailPoint(fp.ConfigureFailPoint)
case "configureFailPoint":
fp, err := op.Arguments.LookupErr("failPoint")
assert.Nil(mt, err, "failPoint not found in arguments")
if err != nil {
return fmt.Errorf("unable to find 'failPoint' in arguments: %w", err)
}
mt.SetFailPointFromDocument(fp.Document())
case "assertSessionTransactionState":
stateVal, err := op.Arguments.LookupErr("state")
assert.Nil(mt, err, "state not found in arguments")
if err != nil {
return fmt.Errorf("unable to find 'state' in arguments: %w", err)
}
expectedState, ok := stateVal.StringValueOK()
assert.True(mt, ok, "state argument is not a string")
if !ok {
return errors.New("expected 'state' argument to be string")
}

assert.NotNil(mt, clientSession, "expected valid session, got nil")
if clientSession == nil {
return errors.New("expected valid session, got nil")
}
actualState := clientSession.TransactionState.String()

// actualState should match expectedState, but "in progress" is the same as
// "in_progress".
stateMatch := actualState == expectedState ||
actualState == "in progress" && expectedState == "in_progress"
assert.True(mt, stateMatch, "expected transaction state %v, got %v",
expectedState, actualState)
if !stateMatch {
return fmt.Errorf("expected transaction state %v, got %v", expectedState, actualState)
}
case "assertSessionPinned":
if clientSession == nil {
return errors.New("expected valid session, got nil")
}
if clientSession.PinnedServer == nil {
return errors.New("expected pinned server, got nil")
}
case "assertSessionUnpinned":
if clientSession == nil {
return errors.New("expected valid session, got nil")
}
// We don't use a combined helper for assertSessionPinned and assertSessionUnpinned because the unpinned
// case provides the pinned server address in the error msg for debugging.
if clientSession.PinnedServer != nil {
Expand Down Expand Up @@ -544,7 +562,7 @@ func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation,
case "waitForThread":
waitForThread(mt, testCase, op)
default:
mt.Fatalf("unrecognized testRunner operation %v", op.Name)
return fmt.Errorf("unrecognized testRunner operation %v", op.Name)
}

return nil
Expand All @@ -571,7 +589,7 @@ func indexExists(dbName, collName, indexName string) (bool, error) {
iv := mtest.GlobalClient().Database(dbName).Collection(collName).Indexes()
cursor, err := iv.List(context.Background())
if err != nil {
return false, fmt.Errorf("IndexView.List error: %v", err)
return false, fmt.Errorf("IndexView.List error: %w", err)
}
defer cursor.Close(context.Background())

Expand Down Expand Up @@ -606,7 +624,7 @@ func collectionExists(dbName, collName string) (bool, error) {
// Use global client because listCollections cannot be executed inside a transaction.
collections, err := mtest.GlobalClient().Database(dbName).ListCollectionNames(context.Background(), filter)
if err != nil {
return false, fmt.Errorf("ListCollectionNames error: %v", err)
return false, fmt.Errorf("ListCollectionNames error: %w", err)
}

return len(collections) > 0, nil
Expand Down Expand Up @@ -636,9 +654,8 @@ func executeSessionOperation(mt *mtest.T, op *operation, sess mongo.Session) err
case "withTransaction":
return executeWithTransaction(mt, sess, op.Arguments)
default:
mt.Fatalf("unrecognized session operation: %v", op.Name)
return fmt.Errorf("unrecognized session operation: %v", op.Name)
}
return nil
}

func executeCollectionOperation(mt *mtest.T, op *operation, sess mongo.Session) error {
Expand Down

0 comments on commit fd1ca35

Please sign in to comment.