Skip to content

Commit

Permalink
GODRIVER-2911: Change to using errors
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeredit committed Jun 20, 2024
1 parent f33dca7 commit 3c00307
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 71 deletions.
120 changes: 60 additions & 60 deletions cmd/testoidcauth/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,11 @@ func tokenFile(user string) string {
return path.Join(oidcTokenDir, user)
}

func connectWithMachineCB(uri string, cb driver.OIDCCallback) *mongo.Client {
func connectWithMachineCB(uri string, cb driver.OIDCCallback) (*mongo.Client, error) {
opts := options.Client().ApplyURI(uri)

opts.Auth.OIDCMachineCallback = cb
client, err := mongo.Connect(context.Background(), opts)
if err != nil {
fmt.Printf("Error connecting client: %v", err)
}
return client
return mongo.Connect(context.Background(), opts)
}

func connectWithMachineCBAndProperties(uri string, cb driver.OIDCCallback, props map[string]string) (*mongo.Client, error) {
Expand All @@ -56,16 +52,16 @@ func connectWithMachineCBAndProperties(uri string, cb driver.OIDCCallback, props

func main() {
hasError := false
aux := func(test_name string, f func() bool) {
aux := func(test_name string, f func() error) {
fmt.Printf("%s...", test_name)
testResult := f()
if testResult {
err := f()
if err != nil {
fmt.Println("Test Error: ", err)
fmt.Println("...Failed")
hasError = true
} else {
fmt.Println("...Ok")
}
fmt.Println("hasError: ", hasError, "testResult: ", testResult)
hasError = hasError || testResult
}
aux("machine_1_1_callbackIsCalled", machine_1_1_callbackIsCalled)
aux("machine_1_2_callbackIsCalledOnlyOneForMultipleConnections", machine_1_2_callbackIsCalledOnlyOneForMultipleConnections)
Expand All @@ -77,21 +73,20 @@ func main() {
}
}

func machine_1_1_callbackIsCalled() bool {
func machine_1_1_callbackIsCalled() error {
callbackCount := 0
callbackFailed := false
var callbackFailed error = nil
countMutex := sync.Mutex{}

client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) {
client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) {
countMutex.Lock()
defer countMutex.Unlock()
callbackCount++
t := time.Now().Add(time.Hour)
tokenFile := tokenFile("test_user1")
accessToken, err := os.ReadFile(tokenFile)
if err != nil {
fmt.Printf("machine_1_1: failed reading token file: %v\n", err)
callbackFailed = true
callbackFailed = fmt.Errorf("machine_1_1: failed reading token file: %v\n", err)
}
return &driver.OIDCCredential{
AccessToken: string(accessToken),
Expand All @@ -100,37 +95,38 @@ func machine_1_1_callbackIsCalled() bool {
}, nil
})

if err != nil {
return fmt.Errorf("machine_1_1: failed connecting client: %v", err)
}

coll := client.Database("test").Collection("test")

_, err := coll.Find(context.Background(), bson.D{})
_, err = coll.Find(context.Background(), bson.D{})
if err != nil {
fmt.Printf("machine_1_1: failed executing Find: %v", err)
return true
return fmt.Errorf("machine_1_1: failed executing Find: %v", err)
}
countMutex.Lock()
defer countMutex.Unlock()
if callbackCount != 1 {
fmt.Printf("machine_1_1: expected callback count to be 1, got %d\n", callbackCount)
return true
return fmt.Errorf("machine_1_1: expected callback count to be 1, got %d\n", callbackCount)
}
return callbackFailed
}

func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() bool {
func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() error {
callbackCount := 0
callbackFailed := false
var callbackFailed error = nil
countMutex := sync.Mutex{}

client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) {
client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) {
countMutex.Lock()
defer countMutex.Unlock()
callbackCount++
t := time.Now().Add(time.Hour)
tokenFile := tokenFile("test_user1")
accessToken, err := os.ReadFile(tokenFile)
if err != nil {
fmt.Printf("machine_1_2: failed reading token file: %v\n", err)
callbackFailed = true
callbackFailed = fmt.Errorf("machine_1_2: failed reading token file: %v\n", err)
}
return &driver.OIDCCredential{
AccessToken: string(accessToken),
Expand All @@ -139,18 +135,21 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() bool {
}, nil
})

if err != nil {
return fmt.Errorf("machine_1_2: failed connecting client: %v", err)
}

var wg sync.WaitGroup

findFailed := false
var findFailed error = nil
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
coll := client.Database("test").Collection("test")
_, err := coll.Find(context.Background(), bson.D{})
if err != nil {
fmt.Printf("machine_1_2: failed executing Find: %v\n", err)
findFailed = true
findFailed = fmt.Errorf("machine_1_2: failed executing Find: %v\n", err)
}
}()
}
Expand All @@ -159,33 +158,31 @@ func machine_1_2_callbackIsCalledOnlyOneForMultipleConnections() bool {
countMutex.Lock()
defer countMutex.Unlock()
if callbackCount != 1 {
fmt.Printf("machine_1_2: expected callback count to be 1, got %d\n", callbackCount)
return true
return fmt.Errorf("machine_1_2: expected callback count to be 1, got %d\n", callbackCount)
}
if callbackFailed != nil {
return callbackFailed
}
return callbackFailed || findFailed
return findFailed
}

func machine_2_1_validCallbackInputs() bool {
func machine_2_1_validCallbackInputs() error {
callbackCount := 0
callbackFailed := false
var callbackFailed error = nil
countMutex := sync.Mutex{}

client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) {
client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) {
if args.RefreshToken != nil {
fmt.Printf("machine_2_1: expected RefreshToken to be nil, got %v\n", args.RefreshToken)
callbackFailed = true
callbackFailed = fmt.Errorf("machine_2_1: expected RefreshToken to be nil, got %v\n", args.RefreshToken)
}
if args.Timeout.Before(time.Now()) {
fmt.Printf("machine_2_1: expected timeout to be in the future, got %v\n", args.Timeout)
callbackFailed = true
callbackFailed = fmt.Errorf("machine_2_1: expected timeout to be in the future, got %v\n", args.Timeout)
}
if args.Version < 1 {
fmt.Printf("machine_2_1: expected Version to be at least 1, got %d\n", args.Version)
callbackFailed = true
callbackFailed = fmt.Errorf("machine_2_1: expected Version to be at least 1, got %d\n", args.Version)
}
if args.IDPInfo != nil {
fmt.Printf("machine_2_1: expected IdpID to be nil for Machine flow, got %v\n", args.IDPInfo)
callbackFailed = true
callbackFailed = fmt.Errorf("machine_2_1: expected IdpID to be nil for Machine flow, got %v\n", args.IDPInfo)
}
countMutex.Lock()
defer countMutex.Unlock()
Expand All @@ -203,27 +200,29 @@ func machine_2_1_validCallbackInputs() bool {
}, nil
})

if err != nil {
return fmt.Errorf("machine_2_1: failed connecting client: %v", err)
}

coll := client.Database("test").Collection("test")

_, err := coll.Find(context.Background(), bson.D{})
_, err = coll.Find(context.Background(), bson.D{})
if err != nil {
fmt.Printf("machine_2_1: failed executing Find: %v", err)
return true
return fmt.Errorf("machine_2_1: failed executing Find: %v", err)
}
countMutex.Lock()
defer countMutex.Unlock()
if callbackCount != 1 {
fmt.Printf("machine_2_1: expected callback count to be 1, got %d\n", callbackCount)
return true
return fmt.Errorf("machine_2_1: expected callback count to be 1, got %d\n", callbackCount)
}
return callbackFailed
}

func machine_2_3_oidcCallbackReturnMissingData() bool {
func machine_2_3_oidcCallbackReturnMissingData() error {
callbackCount := 0
countMutex := sync.Mutex{}

client := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) {
client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) {
countMutex.Lock()
defer countMutex.Unlock()
callbackCount++
Expand All @@ -235,23 +234,25 @@ func machine_2_3_oidcCallbackReturnMissingData() bool {
}, nil
})

if err != nil {
return fmt.Errorf("machine_2_3: failed connecting client: %v\n", err)
}

coll := client.Database("test").Collection("test")

_, err := coll.Find(context.Background(), bson.D{})
_, err = coll.Find(context.Background(), bson.D{})
if err == nil {
fmt.Println("machine_2_3: should have failed to executed Find, but succeeded")
return true
return fmt.Errorf("machine_2_3: should have failed to executed Find, but succeeded")
}
countMutex.Lock()
defer countMutex.Unlock()
if callbackCount != 1 {
fmt.Printf("machine_2_3: expected callback count to be 1, got %d\n", callbackCount)
return true
return fmt.Errorf("machine_2_3: expected callback count to be 1, got %d\n", callbackCount)
}
return true
return nil
}

func machine_2_4_invalidClientConfigurationWithCallback() bool {
func machine_2_4_invalidClientConfigurationWithCallback() error {
_, err := connectWithMachineCBAndProperties(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) {
t := time.Now().Add(time.Hour)
return &driver.OIDCCredential{
Expand All @@ -263,8 +264,7 @@ func machine_2_4_invalidClientConfigurationWithCallback() bool {
map[string]string{"ENVIRONMENT": "test"},
)
if err == nil {
fmt.Println("machine_2_4: succeeded building client when it should fail")
return true
return fmt.Errorf("machine_2_4: succeeded building client when it should fail")
}
return false
return nil
}
22 changes: 11 additions & 11 deletions x/mongo/driver/auth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ const resourceProp = "TOKEN_RESOURCE"

const azureEnvironmentValue = "azure"
const gcpEnvironmentValue = "gcp"
const testEnvironmentValue = "test"

const apiVersion = 1
const invalidateSleepTimeout = 100 * time.Millisecond
Expand Down Expand Up @@ -82,19 +83,22 @@ type OIDCAuthenticator struct {
}

func newOIDCAuthenticator(cred *Cred) (Authenticator, error) {
if cred.Password != "" {
return nil, fmt.Errorf("password cannot be specified for %q", MongoDBOIDC)
}
if cred.Props != nil {
if env, ok := cred.Props[environmentProp]; ok {
switch strings.ToLower(env) {
case "azure":
case azureEnvironmentValue:
fallthrough
case "gcp":
case gcpEnvironmentValue:
if _, ok := cred.Props[resourceProp]; !ok {
return nil, fmt.Errorf("%s must be specified for %s %s", resourceProp, env, environmentProp)
return nil, fmt.Errorf("%q must be specified for %q %q", resourceProp, env, environmentProp)
}
fallthrough
case "test":
case testEnvironmentValue:
if cred.OIDCMachineCallback != nil || cred.OIDCHumanCallback != nil {
return nil, fmt.Errorf("OIDC callbacks are not allowed for %s %s", env, environmentProp)
return nil, fmt.Errorf("OIDC callbacks are not allowed for %q %q", env, environmentProp)
}
}
}
Expand Down Expand Up @@ -146,15 +150,11 @@ func (oa *OIDCAuthenticator) providerCallback() (OIDCCallback, error) {
return nil, nil
}

switch env {
//switch env {
// TODO GODRIVER-2728: Automatic token acquisition for Azure Identity Provider
// TODO GODRIVER-2806: Automatic token acquisition for GCP Identity Provider
// This is here just to pass the linter, it will be fixed in one of the above tickets.
case azureEnvironmentValue, gcpEnvironmentValue:
return func(ctx context.Context, args *OIDCArgs) (*OIDCCredential, error) {
return nil, fmt.Errorf("automatic token acquisition for %q not implemented yet", env)
}, fmt.Errorf("automatic token acquisition for %q not implemented yet", env)
}
//}

return nil, fmt.Errorf("%q %q not supported for MONGODB-OIDC", environmentProp, env)
}
Expand Down

0 comments on commit 3c00307

Please sign in to comment.