diff --git a/app/api/auth/refresh_access_token.go b/app/api/auth/refresh_access_token.go index c2168381c..cb6e42a4b 100644 --- a/app/api/auth/refresh_access_token.go +++ b/app/api/auth/refresh_access_token.go @@ -50,38 +50,40 @@ func (srv *Service) refreshAccessToken(w http.ResponseWriter, r *http.Request) s var expiresIn int32 apiError := service.NoError - sessionMostRecentToken := store. - AccessTokens(). - GetMostRecentValidTokenForSession(sessionID) - if sessionMostRecentToken.Token != oldAccessToken || sessionMostRecentToken.TooNewToRefresh { - // We return the most recent token if the input token is not the most recent one or if it is too new to refresh. - // Note: we know that the token is valid because we checked it in the middleware. - newToken = sessionMostRecentToken.Token - expiresIn = sessionMostRecentToken.SecondsUntilExpiry - } else { - if user.IsTempUser { - service.MustNotBeError(store.InTransaction(func(store *database.DataStore) error { - var err error - newToken, expiresIn, err = auth.RefreshTempUserSession(store, user.GroupID, sessionID) - return err - })) + service.MustNotBeError(sessionIDsInProgress.withLock(sessionID, r, func() error { + sessionMostRecentToken := store. + AccessTokens(). + GetMostRecentValidTokenForSession(sessionID) + if sessionMostRecentToken.Token != oldAccessToken || sessionMostRecentToken.TooNewToRefresh { + // We return the most recent token if the input token is not the most recent one or if it is too new to refresh. + // Note: we know that the token is valid because we checked it in the middleware. + newToken = sessionMostRecentToken.Token + expiresIn = sessionMostRecentToken.SecondsUntilExpiry } else { - // We should not allow concurrency in this part because the login module generates not only - // a new access token, but also a new refresh token and revokes the old one. We want to prevent - // usage of the old refresh token for that reason. - service.MustNotBeError(sessionIDsInProgress.withLock(sessionID, r, func() error { + if user.IsTempUser { + service.MustNotBeError(store.InTransaction(func(store *database.DataStore) error { + var err error + newToken, expiresIn, err = auth.RefreshTempUserSession(store, user.GroupID, sessionID) + return err + })) + } else { + // We should not allow concurrency in this part because the login module generates not only + // a new access token, but also a new refresh token and revokes the old one. We want to prevent + // usage of the old refresh token for that reason. + newToken, expiresIn, apiError = srv.refreshTokens(r.Context(), store, user, sessionID) - return nil - })) + } } - if apiError != service.NoError { - return apiError - } + return nil + })) - store.AccessTokens().DeleteExpiredTokensOfUser(user.GroupID) + if apiError != service.NoError { + return apiError } + store.AccessTokens().DeleteExpiredTokensOfUser(user.GroupID) + srv.respondWithNewAccessToken(r, w, service.CreationSuccess, newToken, time.Now().Add(time.Duration(expiresIn)*time.Second), cookieAttributes) return service.NoError diff --git a/app/api/auth/refresh_access_token_test.go b/app/api/auth/refresh_access_token_test.go index 0c18e02dc..8f9dabe6a 100644 --- a/app/api/auth/refresh_access_token_test.go +++ b/app/api/auth/refresh_access_token_test.go @@ -42,18 +42,17 @@ func TestService_refreshAccessToken_NotAllowRefreshTokenRaces(t *testing.T) { response, mock, logs, err := servicetest.GetResponseForRouteWithMockedDBAndUser( "POST", "/auth/token", "", &database.User{GroupID: 2}, func(mock sqlmock.Sqlmock) { - mock.ExpectQuery("^" + - regexp.QuoteMeta( - "SELECT "+ - "token, "+ - "TIMESTAMPDIFF(SECOND, NOW(), expires_at) AS seconds_until_expiry, "+ - "issued_at > (NOW() - INTERVAL 5 MINUTE) AS too_new_to_refresh "+ - "FROM `access_tokens` WHERE (session_id = ?) ORDER BY expires_at DESC LIMIT 1") + "$"). - WithArgs(sqlmock.AnyArg()). - WillReturnRows(mock.NewRows([]string{"token", "seconds_until_expiry", "too_new_to_refresh"}). - AddRow("accesstoken", 600, false)) - if !timeout { + mock.ExpectQuery("^" + + regexp.QuoteMeta( + "SELECT "+ + "token, "+ + "TIMESTAMPDIFF(SECOND, NOW(), expires_at) AS seconds_until_expiry, "+ + "issued_at > (NOW() - INTERVAL 5 MINUTE) AS too_new_to_refresh "+ + "FROM `access_tokens` WHERE (session_id = ?) ORDER BY expires_at DESC LIMIT 1") + "$"). + WithArgs(sqlmock.AnyArg()). + WillReturnRows(mock.NewRows([]string{"token", "seconds_until_expiry", "too_new_to_refresh"}). + AddRow("accesstoken", 600, false)) mock.ExpectQuery("^" + regexp.QuoteMeta("SELECT refresh_token FROM `sessions` WHERE (session_id = ?) LIMIT 1") + "$"). WithArgs(sqlmock.AnyArg()).