Skip to content

Commit

Permalink
GODRIVER-2800 Convert Session Interface to a Struct (#1592)
Browse files Browse the repository at this point in the history
Co-authored-by: Matt Dale <[email protected]>
  • Loading branch information
prestonvasquez and matthewdale authored Apr 30, 2024
1 parent 21af53e commit 3365ea1
Show file tree
Hide file tree
Showing 16 changed files with 303 additions and 356 deletions.
80 changes: 44 additions & 36 deletions internal/docexamples/examples.go
Original file line number Diff line number Diff line change
Expand Up @@ -1758,30 +1758,32 @@ func UpdateEmployeeInfo(ctx context.Context, client *mongo.Client) error {
employees := client.Database("hr").Collection("employees")
events := client.Database("reporting").Collection("events")

return client.UseSession(ctx, func(sctx mongo.SessionContext) error {
err := sctx.StartTransaction(options.Transaction().
return client.UseSession(ctx, func(ctx context.Context) error {
sess := mongo.SessionFromContext(ctx)

err := sess.StartTransaction(options.Transaction().
SetReadConcern(readconcern.Snapshot()).
SetWriteConcern(writeconcern.Majority()),
)
if err != nil {
return err
}

_, err = employees.UpdateOne(sctx, bson.D{{"employee", 3}}, bson.D{{"$set", bson.D{{"status", "Inactive"}}}})
_, err = employees.UpdateOne(ctx, bson.D{{"employee", 3}}, bson.D{{"$set", bson.D{{"status", "Inactive"}}}})
if err != nil {
sctx.AbortTransaction(sctx)
sess.AbortTransaction(ctx)
log.Println("caught exception during transaction, aborting.")
return err
}
_, err = events.InsertOne(sctx, bson.D{{"employee", 3}, {"status", bson.D{{"new", "Inactive"}, {"old", "Active"}}}})
_, err = events.InsertOne(ctx, bson.D{{"employee", 3}, {"status", bson.D{{"new", "Inactive"}, {"old", "Active"}}}})
if err != nil {
sctx.AbortTransaction(sctx)
sess.AbortTransaction(ctx)
log.Println("caught exception during transaction, aborting.")
return err
}

for {
err = sctx.CommitTransaction(sctx)
err = sess.CommitTransaction(ctx)
switch e := err.(type) {
case nil:
return nil
Expand All @@ -1805,9 +1807,9 @@ func UpdateEmployeeInfo(ctx context.Context, client *mongo.Client) error {
// Start Transactions Retry Example 1

// RunTransactionWithRetry is an example function demonstrating transaction retry logic.
func RunTransactionWithRetry(sctx mongo.SessionContext, txnFn func(mongo.SessionContext) error) error {
func RunTransactionWithRetry(ctx context.Context, txnFn func(context.Context) error) error {
for {
err := txnFn(sctx) // Performs transaction.
err := txnFn(ctx) // Performs transaction.
if err == nil {
return nil
}
Expand All @@ -1828,9 +1830,11 @@ func RunTransactionWithRetry(sctx mongo.SessionContext, txnFn func(mongo.Session
// Start Transactions Retry Example 2

// CommitWithRetry is an example function demonstrating transaction commit with retry logic.
func CommitWithRetry(sctx mongo.SessionContext) error {
func CommitWithRetry(ctx context.Context) error {
sess := mongo.SessionFromContext(ctx)

for {
err := sctx.CommitTransaction(sctx)
err := sess.CommitTransaction(ctx)
switch e := err.(type) {
case nil:
log.Println("Transaction committed.")
Expand Down Expand Up @@ -1872,9 +1876,9 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error {
}
// Start Transactions Retry Example 3

runTransactionWithRetry := func(sctx mongo.SessionContext, txnFn func(mongo.SessionContext) error) error {
runTransactionWithRetry := func(ctx context.Context, txnFn func(context.Context) error) error {
for {
err := txnFn(sctx) // Performs transaction.
err := txnFn(ctx) // Performs transaction.
if err == nil {
return nil
}
Expand All @@ -1890,9 +1894,11 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error {
}
}

commitWithRetry := func(sctx mongo.SessionContext) error {
commitWithRetry := func(ctx context.Context) error {
sess := mongo.SessionFromContext(ctx)

for {
err := sctx.CommitTransaction(sctx)
err := sess.CommitTransaction(ctx)
switch e := err.(type) {
case nil:
log.Println("Transaction committed.")
Expand All @@ -1913,38 +1919,40 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error {
}

// Updates two collections in a transaction.
updateEmployeeInfo := func(sctx mongo.SessionContext) error {
updateEmployeeInfo := func(ctx context.Context) error {
employees := client.Database("hr").Collection("employees")
events := client.Database("reporting").Collection("events")

err := sctx.StartTransaction(options.Transaction().
sess := mongo.SessionFromContext(ctx)

err := sess.StartTransaction(options.Transaction().
SetReadConcern(readconcern.Snapshot()).
SetWriteConcern(writeconcern.Majority()),
)
if err != nil {
return err
}

_, err = employees.UpdateOne(sctx, bson.D{{"employee", 3}}, bson.D{{"$set", bson.D{{"status", "Inactive"}}}})
_, err = employees.UpdateOne(ctx, bson.D{{"employee", 3}}, bson.D{{"$set", bson.D{{"status", "Inactive"}}}})
if err != nil {
sctx.AbortTransaction(sctx)
sess.AbortTransaction(ctx)
log.Println("caught exception during transaction, aborting.")
return err
}
_, err = events.InsertOne(sctx, bson.D{{"employee", 3}, {"status", bson.D{{"new", "Inactive"}, {"old", "Active"}}}})
_, err = events.InsertOne(ctx, bson.D{{"employee", 3}, {"status", bson.D{{"new", "Inactive"}, {"old", "Active"}}}})
if err != nil {
sctx.AbortTransaction(sctx)
sess.AbortTransaction(ctx)
log.Println("caught exception during transaction, aborting.")
return err
}

return commitWithRetry(sctx)
return commitWithRetry(ctx)
}

return client.UseSessionWithOptions(
ctx, options.Session().SetDefaultReadPreference(readpref.Primary()),
func(sctx mongo.SessionContext) error {
return runTransactionWithRetry(sctx, updateEmployeeInfo)
func(ctx context.Context) error {
return runTransactionWithRetry(ctx, updateEmployeeInfo)
},
)
}
Expand Down Expand Up @@ -1976,13 +1984,13 @@ func WithTransactionExample(ctx context.Context) error {
barColl := client.Database("mydb1").Collection("bar", wcMajorityCollectionOpts)

// Step 1: Define the callback that specifies the sequence of operations to perform inside the transaction.
callback := func(sessCtx mongo.SessionContext) (interface{}, error) {
// Important: You must pass sessCtx as the Context parameter to the operations for them to be executed in the
callback := func(sesctx context.Context) (interface{}, error) {
// Important: You must pass sesctx as the Context parameter to the operations for them to be executed in the
// transaction.
if _, err := fooColl.InsertOne(sessCtx, bson.D{{"abc", 1}}); err != nil {
if _, err := fooColl.InsertOne(sesctx, bson.D{{"abc", 1}}); err != nil {
return nil, err
}
if _, err := barColl.InsertOne(sessCtx, bson.D{{"xyz", 999}}); err != nil {
if _, err := barColl.InsertOne(sesctx, bson.D{{"xyz", 999}}); err != nil {
return nil, err
}

Expand Down Expand Up @@ -2560,15 +2568,15 @@ func CausalConsistencyExamples(client *mongo.Client) error {
}
defer session1.EndSession(context.TODO())

err = client.UseSessionWithOptions(context.TODO(), opts, func(sctx mongo.SessionContext) error {
err = client.UseSessionWithOptions(context.TODO(), opts, func(ctx context.Context) error {
// Run an update with our causally-consistent session
_, err = coll.UpdateOne(sctx, bson.D{{"sku", 111}}, bson.D{{"$set", bson.D{{"end", currentDate}}}})
_, err = coll.UpdateOne(ctx, bson.D{{"sku", 111}}, bson.D{{"$set", bson.D{{"end", currentDate}}}})
if err != nil {
return err
}

// Run an insert with our causally-consistent session
_, err = coll.InsertOne(sctx, bson.D{{"sku", "nuts-111"}, {"name", "Pecans"}, {"start", currentDate}})
_, err = coll.InsertOne(ctx, bson.D{{"sku", "nuts-111"}, {"name", "Pecans"}, {"start", currentDate}})
if err != nil {
return err
}
Expand All @@ -2593,7 +2601,7 @@ func CausalConsistencyExamples(client *mongo.Client) error {
}
defer session2.EndSession(context.TODO())

err = client.UseSessionWithOptions(context.TODO(), opts, func(sctx mongo.SessionContext) error {
err = client.UseSessionWithOptions(context.TODO(), opts, func(ctx context.Context) error {
// Set cluster time of session2 to session1's cluster time
clusterTime := session1.ClusterTime()
session2.AdvanceClusterTime(clusterTime)
Expand All @@ -2602,13 +2610,13 @@ func CausalConsistencyExamples(client *mongo.Client) error {
operationTime := session1.OperationTime()
session2.AdvanceOperationTime(operationTime)
// Run a find on session2, which should find all the writes from session1
cursor, err := coll.Find(sctx, bson.D{{"end", nil}})
cursor, err := coll.Find(ctx, bson.D{{"end", nil}})

if err != nil {
return err
}

for cursor.Next(sctx) {
for cursor.Next(ctx) {
doc := cursor.Current
fmt.Printf("Document: %v\n", doc.String())
}
Expand Down Expand Up @@ -2984,7 +2992,7 @@ func snapshotQueryPetExample(mt *mtest.T) error {
defer sess.EndSession(ctx)

var adoptablePetsCount int32
err = mongo.WithSession(ctx, sess, func(ctx mongo.SessionContext) error {
err = mongo.WithSession(ctx, sess, func(ctx context.Context) error {
// Count the adoptable cats
const adoptableCatsOutput = "adoptableCatsCount"
cursor, err := db.Collection("cats").Aggregate(ctx, mongo.Pipeline{
Expand Down Expand Up @@ -3048,7 +3056,7 @@ func snapshotQueryRetailExample(mt *mtest.T) error {
defer sess.EndSession(ctx)

var totalDailySales int32
err = mongo.WithSession(ctx, sess, func(ctx mongo.SessionContext) error {
err = mongo.WithSession(ctx, sess, func(ctx context.Context) error {
// Count the total daily sales
const totalDailySalesOutput = "totalDailySales"
cursor, err := db.Collection("sales").Aggregate(ctx, mongo.Pipeline{
Expand Down
42 changes: 21 additions & 21 deletions internal/integration/causal_consistency_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ func TestCausalConsistency_Supported(t *testing.T) {
// first read in a causally consistent session must not send afterClusterTime to the server

ccOpts := options.Session().SetCausalConsistency(true)
_ = mt.Client.UseSessionWithOptions(context.Background(), ccOpts, func(sc mongo.SessionContext) error {
_, _ = mt.Coll.Find(sc, bson.D{})
_ = mt.Client.UseSessionWithOptions(context.Background(), ccOpts, func(ctx context.Context) error {
_, _ = mt.Coll.Find(ctx, bson.D{})
return nil
})

Expand All @@ -57,8 +57,8 @@ func TestCausalConsistency_Supported(t *testing.T) {
assert.Nil(mt, err, "StartSession error: %v", err)
defer sess.EndSession(context.Background())

_ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error {
_, _ = mt.Coll.Find(sc, bson.D{})
_ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error {
_, _ = mt.Coll.Find(ctx, bson.D{})
return nil
})

Expand All @@ -85,8 +85,8 @@ func TestCausalConsistency_Supported(t *testing.T) {
assert.Nil(mt, err, "StartSession error: %v", err)
defer sess.EndSession(context.Background())

_ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error {
_ = mt.Coll.FindOne(sc, bson.D{})
_ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error {
_ = mt.Coll.FindOne(ctx, bson.D{})
return nil
})
currOptime := sess.OperationTime()
Expand Down Expand Up @@ -120,8 +120,8 @@ func TestCausalConsistency_Supported(t *testing.T) {
assert.NotNil(mt, currOptime, "expected session operation time, got nil")

mt.ClearEvents()
_ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error {
_ = mt.Coll.FindOne(sc, bson.D{})
_ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error {
_ = mt.Coll.FindOne(ctx, bson.D{})
return nil
})
_, sentOptime := getReadConcernFields(mt, mt.GetStartedEvent().Command)
Expand All @@ -134,10 +134,10 @@ func TestCausalConsistency_Supported(t *testing.T) {
// a read operation in a non causally-consistent session should not include afterClusterTime

sessOpts := options.Session().SetCausalConsistency(false)
_ = mt.Client.UseSessionWithOptions(context.Background(), sessOpts, func(sc mongo.SessionContext) error {
_, _ = mt.Coll.Find(sc, bson.D{})
_ = mt.Client.UseSessionWithOptions(context.Background(), sessOpts, func(ctx context.Context) error {
_, _ = mt.Coll.Find(ctx, bson.D{})
mt.ClearEvents()
_, _ = mt.Coll.Find(sc, bson.D{})
_, _ = mt.Coll.Find(ctx, bson.D{})
return nil
})
evt := mt.GetStartedEvent()
Expand All @@ -152,14 +152,14 @@ func TestCausalConsistency_Supported(t *testing.T) {
assert.Nil(mt, err, "StartSession error: %v", err)
defer sess.EndSession(context.Background())

_ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error {
_ = mt.Coll.FindOne(sc, bson.D{})
_ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error {
_ = mt.Coll.FindOne(ctx, bson.D{})
return nil
})
currOptime := sess.OperationTime()
mt.ClearEvents()
_ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error {
_ = mt.Coll.FindOne(sc, bson.D{})
_ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error {
_ = mt.Coll.FindOne(ctx, bson.D{})
return nil
})

Expand All @@ -174,14 +174,14 @@ func TestCausalConsistency_Supported(t *testing.T) {
assert.Nil(mt, err, "StartSession error: %v", err)
defer sess.EndSession(context.Background())

_ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error {
_ = mt.Coll.FindOne(sc, bson.D{})
_ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error {
_ = mt.Coll.FindOne(ctx, bson.D{})
return nil
})
currOptime := sess.OperationTime()
mt.ClearEvents()
_ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error {
_ = mt.Coll.FindOne(sc, bson.D{})
_ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error {
_ = mt.Coll.FindOne(ctx, bson.D{})
return nil
})

Expand Down Expand Up @@ -215,8 +215,8 @@ func TestCausalConsistency_NotSupported(t *testing.T) {
// support cluster times

sessOpts := options.Session().SetCausalConsistency(true)
_ = mt.Client.UseSessionWithOptions(context.Background(), sessOpts, func(sc mongo.SessionContext) error {
_, _ = mt.Coll.Find(sc, bson.D{})
_ = mt.Client.UseSessionWithOptions(context.Background(), sessOpts, func(ctx context.Context) error {
_, _ = mt.Coll.Find(ctx, bson.D{})
return nil
})

Expand Down
3 changes: 1 addition & 2 deletions internal/integration/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,7 @@ func TestClient(t *testing.T) {
sess, err := mt.Client.StartSession(tc.opts)
assert.Nil(mt, err, "StartSession error: %v", err)
defer sess.EndSession(context.Background())
xs := sess.(mongo.XSession)
consistent := xs.ClientSession().Consistent
consistent := sess.ClientSession().Consistent
assert.Equal(mt, tc.consistent, consistent, "expected consistent to be %v, got %v", tc.consistent, consistent)
})
}
Expand Down
Loading

0 comments on commit 3365ea1

Please sign in to comment.