Skip to content

Commit

Permalink
GODRIVER-2911: Tests all passing
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeredit committed Jun 23, 2024
1 parent 9dd40c9 commit b343ebb
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 9 deletions.
170 changes: 164 additions & 6 deletions cmd/testoidcauth/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ func main() {
aux("machine_2_4_invalidClientConfigurationWithCallback", machine24invalidClientConfigurationWithCallback)
aux("machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth", machine31failureWithCachedTokensFetchANewTokenAndRetryAuth)
aux("machine_3_2_authFailuresWithoutCachedTokensReturnsAnError", machine32authFailuresWithoutCachedTokensReturnsAnError)
//aux("machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache", machine33UnexpectedErrorCodeDoesNotClearTheCache)
aux("machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache", machine33UnexpectedErrorCodeDoesNotClearTheCache)
aux("machine_4_1_reauthenticationSucceeds", machine41ReauthenticationSucceeds)
aux("machine_4_2_readCommandsFailIfReauthenticationFails", machine42ReadCommandsFailIfReauthenticationFails)
aux("machine_4_3_writeCommandsFailIfReauthenticationFails", machine43WriteCommandsFailIfReauthenticationFails)
if hasError {
log.Fatal("One or more tests failed")
}
Expand Down Expand Up @@ -482,8 +484,6 @@ func machine33UnexpectedErrorCodeDoesNotClearTheCache() error {
if err != nil {
return fmt.Errorf("machine_3_3: failed executing Find: %v", err)
}
countMutex.Lock()
defer countMutex.Unlock()
if callbackCount != 1 {
return fmt.Errorf("machine_3_3: expected callback count to be 1, got %d", callbackCount)
}
Expand Down Expand Up @@ -543,18 +543,176 @@ func machine41ReauthenticationSucceeds() error {
return fmt.Errorf("machine_4_1: failed setting failpoint: %v", res.Err())
}

_, err = coll.Find(context.Background(), bson.D{})
if err != nil {
return fmt.Errorf("machine_4_1: failed executing Find: %v", err)
}
countMutex.Lock()
defer countMutex.Unlock()
if callbackCount != 2 {
return fmt.Errorf("machine_4_1: expected callback count to be 2, got %d", callbackCount)
}
return callbackFailed
}

func machine42ReadCommandsFailIfReauthenticationFails() error {
callbackCount := 0
var callbackFailed error
firstCall := true
countMutex := sync.Mutex{}

adminClient, err := connectAdminClinet()
defer adminClient.Disconnect(context.Background())

if err != nil {
return fmt.Errorf("machine_4_2: failed connecting admin client: %v", err)
}

client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) {
countMutex.Lock()
defer countMutex.Unlock()
callbackCount++
t := time.Now().Add(time.Hour)
if firstCall {
firstCall = false
tokenFile := tokenFile("test_user1")
accessToken, err := os.ReadFile(tokenFile)
if err != nil {
callbackFailed = fmt.Errorf("machine_4_2: failed reading token file: %v", err)
}
return &driver.OIDCCredential{
AccessToken: string(accessToken),
ExpiresAt: &t,
RefreshToken: nil,
}, nil
} else {
return &driver.OIDCCredential{
AccessToken: "this is a bad, bad token",
ExpiresAt: &t,
RefreshToken: nil,
}, nil
}
})

defer client.Disconnect(context.Background())

if err != nil {
return fmt.Errorf("machine_4_1: failed setting failpoint: %v", err)
return fmt.Errorf("machine_4_2: failed connecting client: %v", err)
}

coll := client.Database("test").Collection("test")
_, err = coll.Find(context.Background(), bson.D{})
if err != nil {
return fmt.Errorf("machine_4_1: failed executing Find: %v", err)
return fmt.Errorf("machine_4_2: failed executing Find: %v", err)
}

res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{
{Key: "configureFailPoint", Value: "failCommand"},
{Key: "mode", Value: bson.D{
{Key: "times", Value: 1},
}},
{Key: "data", Value: bson.D{
{Key: "failCommands", Value: bson.A{
"find",
}},
{Key: "errorCode", Value: 391},
}},
})

if res.Err() != nil {
return fmt.Errorf("machine_4_2: failed setting failpoint: %v", res.Err())
}

_, err = coll.Find(context.Background(), bson.D{})
if err == nil {
return fmt.Errorf("machine_4_2: Find succeeded when it should fail")
}

countMutex.Lock()
defer countMutex.Unlock()
if callbackCount != 2 {
return fmt.Errorf("machine_4_1: expected callback count to be 2, got %d", callbackCount)
return fmt.Errorf("machine_4_2: expected callback count to be 2, got %d", callbackCount)
}
return callbackFailed
}

func machine43WriteCommandsFailIfReauthenticationFails() error {
callbackCount := 0
var callbackFailed error
firstCall := true
countMutex := sync.Mutex{}

adminClient, err := connectAdminClinet()
defer adminClient.Disconnect(context.Background())

if err != nil {
return fmt.Errorf("machine_4_3: failed connecting admin client: %v", err)
}

client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) {
countMutex.Lock()
defer countMutex.Unlock()
callbackCount++
t := time.Now().Add(time.Hour)
if firstCall {
firstCall = false
tokenFile := tokenFile("test_user1")
accessToken, err := os.ReadFile(tokenFile)
if err != nil {
callbackFailed = fmt.Errorf("machine_4_3: failed reading token file: %v", err)
}
return &driver.OIDCCredential{
AccessToken: string(accessToken),
ExpiresAt: &t,
RefreshToken: nil,
}, nil
} else {
return &driver.OIDCCredential{
AccessToken: "this is a bad, bad token",
ExpiresAt: &t,
RefreshToken: nil,
}, nil
}
})

defer client.Disconnect(context.Background())

if err != nil {
return fmt.Errorf("machine_4_3: failed connecting client: %v", err)
}

coll := client.Database("test").Collection("test")
_, err = coll.InsertOne(context.Background(), bson.D{})
if err != nil {
return fmt.Errorf("machine_4_3: failed executing Insert: %v", err)
}

res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{
{Key: "configureFailPoint", Value: "failCommand"},
{Key: "mode", Value: bson.D{
{Key: "times", Value: 1},
}},
{Key: "data", Value: bson.D{
{Key: "failCommands", Value: bson.A{
"insert",
}},
{Key: "errorCode", Value: 391},
}},
})

if res.Err() != nil {
return fmt.Errorf("machine_4_3: failed setting failpoint: %v", res.Err())
}

_, err = coll.InsertOne(context.Background(), bson.D{})
if err == nil {
return fmt.Errorf("machine_4_3: Insert succeeded when it should fail")
}

countMutex.Lock()
defer countMutex.Unlock()
if callbackCount != 2 {
return fmt.Errorf("machine_4_3: expected callback count to be 2, got %d", callbackCount)
}
return callbackFailed
}
3 changes: 0 additions & 3 deletions x/mongo/driver/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -916,12 +916,9 @@ func (op Operation) Execute(ctx context.Context) error {
operationErr.Labels = tt.Labels
operationErr.Raw = tt.Raw
case Error:
fmt.Println("!!!!")
// 391 is the reauthentication required error code, so we will attempt a reauth and
// retry the operation, if it is successful.
fmt.Println("code", tt.Code)
if tt.Code == 391 {
fmt.Println("!!!!")
if op.Authenticator != nil {
if err := op.Authenticator.Reauth(ctx); err != nil {
return fmt.Errorf("error reauthenticating: %w", err)
Expand Down

0 comments on commit b343ebb

Please sign in to comment.