From 3c00307d4f2ed6303eef0d6387a44753f59170cd Mon Sep 17 00:00:00 2001 From: Patrick Meredith Date: Thu, 20 Jun 2024 13:52:43 -0400 Subject: [PATCH] GODRIVER-2911: Change to using errors --- cmd/testoidcauth/main.go | 120 ++++++++++++++++++------------------ x/mongo/driver/auth/oidc.go | 22 +++---- 2 files changed, 71 insertions(+), 71 deletions(-) diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go index b12acfc371..5459dda410 100644 --- a/cmd/testoidcauth/main.go +++ b/cmd/testoidcauth/main.go @@ -35,15 +35,11 @@ func tokenFile(user string) string { return path.Join(oidcTokenDir, user) } -func connectWithMachineCB(uri string, cb driver.OIDCCallback) *mongo.Client { +func connectWithMachineCB(uri string, cb driver.OIDCCallback) (*mongo.Client, error) { opts := options.Client().ApplyURI(uri) opts.Auth.OIDCMachineCallback = cb - client, err := mongo.Connect(context.Background(), opts) - if err != nil { - fmt.Printf("Error connecting client: %v", err) - } - return client + return mongo.Connect(context.Background(), opts) } func connectWithMachineCBAndProperties(uri string, cb driver.OIDCCallback, props map[string]string) (*mongo.Client, error) { @@ -56,16 +52,16 @@ func connectWithMachineCBAndProperties(uri string, cb driver.OIDCCallback, props func main() { hasError := false - aux := func(test_name string, f func() bool) { + aux := func(test_name string, f func() error) { fmt.Printf("%s...", test_name) - testResult := f() - if testResult { + err := f() + if err != nil { + fmt.Println("Test Error: ", err) fmt.Println("...Failed") + hasError = true } else { fmt.Println("...Ok") } - fmt.Println("hasError: ", hasError, "testResult: ", testResult) - hasError = hasError || testResult } aux("machine_1_1_callbackIsCalled", machine_1_1_callbackIsCalled) aux("machine_1_2_callbackIsCalledOnlyOneForMultipleConnections", machine_1_2_callbackIsCalledOnlyOneForMultipleConnections) @@ -77,12 +73,12 @@ func main() { } } -func machine_1_1_callbackIsCalled() bool { +func machine_1_1_callbackIsCalled() error { callbackCount := 0 - callbackFailed := false + var callbackFailed error = nil countMutex := sync.Mutex{} - client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { countMutex.Lock() defer countMutex.Unlock() callbackCount++ @@ -90,8 +86,7 @@ func machine_1_1_callbackIsCalled() bool { tokenFile := tokenFile("test_user1") accessToken, err := os.ReadFile(tokenFile) if err != nil { - fmt.Printf("machine_1_1: failed reading token file: %v\n", err) - callbackFailed = true + callbackFailed = fmt.Errorf("machine_1_1: failed reading token file: %v\n", err) } return &driver.OIDCCredential{ AccessToken: string(accessToken), @@ -100,28 +95,30 @@ func machine_1_1_callbackIsCalled() bool { }, nil }) + if err != nil { + return fmt.Errorf("machine_1_1: failed connecting client: %v", err) + } + coll := client.Database("test").Collection("test") - _, err := coll.Find(context.Background(), bson.D{}) + _, err = coll.Find(context.Background(), bson.D{}) if err != nil { - fmt.Printf("machine_1_1: failed executing Find: %v", err) - return true + return fmt.Errorf("machine_1_1: failed executing Find: %v", err) } countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - fmt.Printf("machine_1_1: expected callback count to be 1, got %d\n", callbackCount) - return true + return fmt.Errorf("machine_1_1: expected callback count to be 1, got %d\n", callbackCount) } return callbackFailed } -func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() bool { +func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() error { callbackCount := 0 - callbackFailed := false + var callbackFailed error = nil countMutex := sync.Mutex{} - client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { countMutex.Lock() defer countMutex.Unlock() callbackCount++ @@ -129,8 +126,7 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() bool { tokenFile := tokenFile("test_user1") accessToken, err := os.ReadFile(tokenFile) if err != nil { - fmt.Printf("machine_1_2: failed reading token file: %v\n", err) - callbackFailed = true + callbackFailed = fmt.Errorf("machine_1_2: failed reading token file: %v\n", err) } return &driver.OIDCCredential{ AccessToken: string(accessToken), @@ -139,9 +135,13 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() bool { }, nil }) + if err != nil { + return fmt.Errorf("machine_1_2: failed connecting client: %v", err) + } + var wg sync.WaitGroup - findFailed := false + var findFailed error = nil for i := 0; i < 10; i++ { wg.Add(1) go func() { @@ -149,8 +149,7 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() bool { coll := client.Database("test").Collection("test") _, err := coll.Find(context.Background(), bson.D{}) if err != nil { - fmt.Printf("machine_1_2: failed executing Find: %v\n", err) - findFailed = true + findFailed = fmt.Errorf("machine_1_2: failed executing Find: %v\n", err) } }() } @@ -159,33 +158,31 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() bool { countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - fmt.Printf("machine_1_2: expected callback count to be 1, got %d\n", callbackCount) - return true + return fmt.Errorf("machine_1_2: expected callback count to be 1, got %d\n", callbackCount) + } + if callbackFailed != nil { + return callbackFailed } - return callbackFailed || findFailed + return findFailed } -func machine_2_1_validCallbackInputs() bool { +func machine_2_1_validCallbackInputs() error { callbackCount := 0 - callbackFailed := false + var callbackFailed error = nil countMutex := sync.Mutex{} - client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { if args.RefreshToken != nil { - fmt.Printf("machine_2_1: expected RefreshToken to be nil, got %v\n", args.RefreshToken) - callbackFailed = true + callbackFailed = fmt.Errorf("machine_2_1: expected RefreshToken to be nil, got %v\n", args.RefreshToken) } if args.Timeout.Before(time.Now()) { - fmt.Printf("machine_2_1: expected timeout to be in the future, got %v\n", args.Timeout) - callbackFailed = true + callbackFailed = fmt.Errorf("machine_2_1: expected timeout to be in the future, got %v\n", args.Timeout) } if args.Version < 1 { - fmt.Printf("machine_2_1: expected Version to be at least 1, got %d\n", args.Version) - callbackFailed = true + callbackFailed = fmt.Errorf("machine_2_1: expected Version to be at least 1, got %d\n", args.Version) } if args.IDPInfo != nil { - fmt.Printf("machine_2_1: expected IdpID to be nil for Machine flow, got %v\n", args.IDPInfo) - callbackFailed = true + callbackFailed = fmt.Errorf("machine_2_1: expected IdpID to be nil for Machine flow, got %v\n", args.IDPInfo) } countMutex.Lock() defer countMutex.Unlock() @@ -203,27 +200,29 @@ func machine_2_1_validCallbackInputs() bool { }, nil }) + if err != nil { + return fmt.Errorf("machine_2_1: failed connecting client: %v", err) + } + coll := client.Database("test").Collection("test") - _, err := coll.Find(context.Background(), bson.D{}) + _, err = coll.Find(context.Background(), bson.D{}) if err != nil { - fmt.Printf("machine_2_1: failed executing Find: %v", err) - return true + return fmt.Errorf("machine_2_1: failed executing Find: %v", err) } countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - fmt.Printf("machine_2_1: expected callback count to be 1, got %d\n", callbackCount) - return true + return fmt.Errorf("machine_2_1: expected callback count to be 1, got %d\n", callbackCount) } return callbackFailed } -func machine_2_3_oidcCallbackReturnMissingData() bool { +func machine_2_3_oidcCallbackReturnMissingData() error { callbackCount := 0 countMutex := sync.Mutex{} - client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { countMutex.Lock() defer countMutex.Unlock() callbackCount++ @@ -235,23 +234,25 @@ func machine_2_3_oidcCallbackReturnMissingData() bool { }, nil }) + if err != nil { + return fmt.Errorf("machine_2_3: failed connecting client: %v\n", err) + } + coll := client.Database("test").Collection("test") - _, err := coll.Find(context.Background(), bson.D{}) + _, err = coll.Find(context.Background(), bson.D{}) if err == nil { - fmt.Println("machine_2_3: should have failed to executed Find, but succeeded") - return true + return fmt.Errorf("machine_2_3: should have failed to executed Find, but succeeded") } countMutex.Lock() defer countMutex.Unlock() if callbackCount != 1 { - fmt.Printf("machine_2_3: expected callback count to be 1, got %d\n", callbackCount) - return true + return fmt.Errorf("machine_2_3: expected callback count to be 1, got %d\n", callbackCount) } - return true + return nil } -func machine_2_4_invalidClientConfigurationWithCallback() bool { +func machine_2_4_invalidClientConfigurationWithCallback() error { _, err := connectWithMachineCBAndProperties(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { t := time.Now().Add(time.Hour) return &driver.OIDCCredential{ @@ -263,8 +264,7 @@ func machine_2_4_invalidClientConfigurationWithCallback() bool { map[string]string{"ENVIRONMENT": "test"}, ) if err == nil { - fmt.Println("machine_2_4: succeeded building client when it should fail") - return true + return fmt.Errorf("machine_2_4: succeeded building client when it should fail") } - return false + return nil } diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index 16aa613101..5138ae4402 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -31,6 +31,7 @@ const resourceProp = "TOKEN_RESOURCE" const azureEnvironmentValue = "azure" const gcpEnvironmentValue = "gcp" +const testEnvironmentValue = "test" const apiVersion = 1 const invalidateSleepTimeout = 100 * time.Millisecond @@ -82,19 +83,22 @@ type OIDCAuthenticator struct { } func newOIDCAuthenticator(cred *Cred) (Authenticator, error) { + if cred.Password != "" { + return nil, fmt.Errorf("password cannot be specified for %q", MongoDBOIDC) + } if cred.Props != nil { if env, ok := cred.Props[environmentProp]; ok { switch strings.ToLower(env) { - case "azure": + case azureEnvironmentValue: fallthrough - case "gcp": + case gcpEnvironmentValue: if _, ok := cred.Props[resourceProp]; !ok { - return nil, fmt.Errorf("%s must be specified for %s %s", resourceProp, env, environmentProp) + return nil, fmt.Errorf("%q must be specified for %q %q", resourceProp, env, environmentProp) } fallthrough - case "test": + case testEnvironmentValue: if cred.OIDCMachineCallback != nil || cred.OIDCHumanCallback != nil { - return nil, fmt.Errorf("OIDC callbacks are not allowed for %s %s", env, environmentProp) + return nil, fmt.Errorf("OIDC callbacks are not allowed for %q %q", env, environmentProp) } } } @@ -146,15 +150,11 @@ func (oa *OIDCAuthenticator) providerCallback() (OIDCCallback, error) { return nil, nil } - switch env { + //switch env { // TODO GODRIVER-2728: Automatic token acquisition for Azure Identity Provider // TODO GODRIVER-2806: Automatic token acquisition for GCP Identity Provider // This is here just to pass the linter, it will be fixed in one of the above tickets. - case azureEnvironmentValue, gcpEnvironmentValue: - return func(ctx context.Context, args *OIDCArgs) (*OIDCCredential, error) { - return nil, fmt.Errorf("automatic token acquisition for %q not implemented yet", env) - }, fmt.Errorf("automatic token acquisition for %q not implemented yet", env) - } + //} return nil, fmt.Errorf("%q %q not supported for MONGODB-OIDC", environmentProp, env) }