From 590a3c8bfc28bef326d55186e46c0155f0ce742e Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Thu, 20 Jun 2024 14:13:34 -0400 Subject: [PATCH] GODRIVER-2911: Add more tests that do not require fail points --- cmd/testoidcauth/main.go | 80 +++++++++++++++++++++++++++++++++++++ mongo/client.go | 5 +++ x/mongo/driver/auth/oidc.go | 8 ++++ 3 files changed, 93 insertions(+) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index 5459dda410..71f67bf41b 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -19,6 +19,7 @@ import ( "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/auth" ) var uriAdmin = os.Getenv("MONGODB_URI") @@ -68,6 +69,7 @@ func main() { aux("machine_2_1_validCallbackInputs", machine_2_1_validCallbackInputs) aux("machine_2_3_oidcCallbackReturnMissingData", machine_2_3_oidcCallbackReturnMissingData) aux("machine_2_4_invalidClientConfigurationWithCallback", machine_2_4_invalidClientConfigurationWithCallback) + aux("machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth", machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth) if hasError { log.Fatal("One or more tests failed") } @@ -268,3 +270,81 @@ func machine_2_4_invalidClientConfigurationWithCallback() error { } return nil } + +func machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth() error { + callbackCount := 0 + var callbackFailed error = nil + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_3_1: failed reading token file: %v\n", err) + } + return &driver.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + if err != nil { + return fmt.Errorf("machine_3_1: failed connecting client: %v", err) + } + + // Poison the cache with a random token + client.GetAuthenticator().(*auth.OIDCAuthenticator).SetAccessToken("some random happy sunshine string") + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_3_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_3_1: expected callback count to be 1, got %d\n", callbackCount) + } + return callbackFailed +} + +func machine_3_2_authFailuresWithoutCachedTokensReturnsAnError() error { + callbackCount := 0 + var callbackFailed error = nil + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + return &driver.OIDCCredential{ + AccessToken: "this is a bad, bad token", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + if err != nil { + return fmt.Errorf("machine_3_2: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_3_2: failed succeeded Find when it should fail") + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_3_2: expected callback count to be 1, got %d\n", callbackCount) + } + return callbackFailed +} diff --git a/mongo/client.go b/mongo/client.go index 082554adbb..fec2c45287 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -83,6 +83,11 @@ type Client struct { authenticator driver.Authenticator } +// GetAuthenticator returns the authenticator for the client, used for testing purposes. +func (c *Client) GetAuthenticator() driver.Authenticator { + return c.authenticator +} + // Connect creates a new Client and then initializes it using the Connect method. This is equivalent to calling // NewClient followed by Client.Connect. // diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index 5138ae4402..6414d238ef 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -82,6 +82,14 @@ type OIDCAuthenticator struct { tokenGenID uint64 } +// SetAccessToken allows for manually setting the access token for the OIDCAuthenticator, this is +// only for testing purposes. +func (oa *OIDCAuthenticator) SetAccessToken(accessToken string) { + oa.mu.Lock() + defer oa.mu.Unlock() + oa.accessToken = accessToken +} + func newOIDCAuthenticator(cred *Cred) (Authenticator, error) { if cred.Password != "" { return nil, fmt.Errorf("password cannot be specified for %q", MongoDBOIDC)