Skip to content

Commit

Permalink
Optional ExtendPayload interface support
Browse files Browse the repository at this point in the history
Signed-off-by: Andy Lo-A-Foe <[email protected]>
  • Loading branch information
loafoe committed Oct 3, 2024
1 parent 16f873d commit 9e1ae0e
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 14 deletions.
4 changes: 4 additions & 0 deletions connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,7 @@ type RefreshConnector interface {
type TokenIdentityConnector interface {
TokenIdentity(ctx context.Context, subjectTokenType, subjectToken string) (Identity, error)
}

type PayloadExtender interface {
ExtendPayload(scopes []string, payload []byte, connectorData []byte) ([]byte, error)
}
16 changes: 8 additions & 8 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -728,14 +728,14 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
implicitOrHybrid = true
var err error

accessToken, _, err = s.newAccessToken(r.Context(), authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID)
accessToken, _, err = s.newAccessToken(r.Context(), authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID, authReq.ConnectorData)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to create new access token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}

idToken, idTokenExpiry, err = s.newIDToken(r.Context(), authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, accessToken, code.ID, authReq.ConnectorID)
idToken, idTokenExpiry, err = s.newIDToken(r.Context(), authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, accessToken, code.ID, authReq.ConnectorID, authReq.ConnectorData)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to create ID token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
Expand Down Expand Up @@ -943,14 +943,14 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
}

func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) {
accessToken, _, err := s.newAccessToken(ctx, client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
accessToken, _, err := s.newAccessToken(ctx, client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID, authCode.ConnectorData)
if err != nil {
s.logger.ErrorContext(ctx, "failed to create new access token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return nil, err
}

idToken, expiry, err := s.newIDToken(ctx, client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ID, authCode.ConnectorID)
idToken, expiry, err := s.newIDToken(ctx, client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ID, authCode.ConnectorID, authCode.ConnectorData)
if err != nil {
s.logger.ErrorContext(ctx, "failed to create ID token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
Expand Down Expand Up @@ -1217,14 +1217,14 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
Groups: identity.Groups,
}

accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, scopes, nonce, connID)
accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, scopes, nonce, connID, identity.ConnectorData)
if err != nil {
s.logger.ErrorContext(r.Context(), "password grant failed to create new access token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}

idToken, expiry, err := s.newIDToken(r.Context(), client.ID, claims, scopes, nonce, accessToken, "", connID)
idToken, expiry, err := s.newIDToken(r.Context(), client.ID, claims, scopes, nonce, accessToken, "", connID, identity.ConnectorData)
if err != nil {
s.logger.ErrorContext(r.Context(), "password grant failed to create new ID token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
Expand Down Expand Up @@ -1421,9 +1421,9 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli
var expiry time.Time
switch requestedTokenType {
case tokenTypeID:
resp.AccessToken, expiry, err = s.newIDToken(r.Context(), client.ID, claims, scopes, "", "", "", connID)
resp.AccessToken, expiry, err = s.newIDToken(r.Context(), client.ID, claims, scopes, "", "", "", connID, identity.ConnectorData)
case tokenTypeAccess:
resp.AccessToken, expiry, err = s.newAccessToken(r.Context(), client.ID, claims, scopes, "", connID)
resp.AccessToken, expiry, err = s.newAccessToken(r.Context(), client.ID, claims, scopes, "", connID, identity.ConnectorData)
default:
s.tokenErrHelper(w, errRequestNotSupported, "Invalid requested_token_type.", http.StatusBadRequest)
return
Expand Down
2 changes: 1 addition & 1 deletion server/introspectionhandler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ func TestHandleIntrospect(t *testing.T) {
Email: "[email protected]",
EmailVerified: true,
Groups: []string{"a", "b"},
}, []string{"openid", "email", "profile", "groups"}, "foo", "", "", "test")
}, []string{"openid", "email", "profile", "groups"}, "foo", "", "", "test", nil)
require.NoError(t, err)

activeRefreshToken, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "bar"})
Expand Down
24 changes: 21 additions & 3 deletions server/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,8 @@ type federatedIDClaims struct {
UserID string `json:"user_id,omitempty"`
}

func (s *Server) newAccessToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, connID string) (accessToken string, expiry time.Time, err error) {
return s.newIDToken(ctx, clientID, claims, scopes, nonce, storage.NewID(), "", connID)
func (s *Server) newAccessToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, connID string, connectorData []byte) (accessToken string, expiry time.Time, err error) {
return s.newIDToken(ctx, clientID, claims, scopes, nonce, storage.NewID(), "", connID, connectorData)
}

func getClientID(aud audience, azp string) (string, error) {
Expand Down Expand Up @@ -351,13 +351,20 @@ func genSubject(userID string, connID string) (string, error) {
return internal.Marshal(sub)
}

func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string) (idToken string, expiry time.Time, err error) {

func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string, connectorData []byte) (idToken string, expiry time.Time, err error) {
keys, err := s.storage.GetKeys()
if err != nil {
s.logger.ErrorContext(ctx, "failed to get keys", "err", err)
return "", expiry, err
}

conn, err := s.getConnector(connID)
if err != nil {
s.logger.ErrorContext(ctx, "failed to get connector", "connector", connID, "err", err)
return "", expiry, err
}

signingKey := keys.SigningKey
if signingKey == nil {
return "", expiry, fmt.Errorf("no key to sign payload with")
Expand Down Expand Up @@ -446,6 +453,17 @@ func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage
return "", expiry, fmt.Errorf("could not serialize claims: %v", err)
}

switch c := conn.Connector.(type) {
case connector.PayloadExtender:
extendedPayload, err := c.ExtendPayload(scopes, payload, connectorData)
if err != nil {
s.logger.WarnContext(ctx, "failed to enhance payload", "err", err)
break
}
payload = extendedPayload
default:
}

if idToken, err = signPayload(signingKey, signingAlg, payload); err != nil {
return "", expiry, fmt.Errorf("failed to sign payload: %v", err)
}
Expand Down
5 changes: 3 additions & 2 deletions server/refreshhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,14 +364,15 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
Groups: ident.Groups,
}

accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID)
accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID, rCtx.connectorData)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to create new access token", "err", err)
s.refreshTokenErrHelper(w, newInternalServerError())
return
}

idToken, expiry, err := s.newIDToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, accessToken, "", rCtx.storageToken.ConnectorID)

idToken, expiry, err := s.newIDToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, accessToken, "", rCtx.storageToken.ConnectorID, rCtx.connectorData)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to create ID token", "err", err)
s.refreshTokenErrHelper(w, newInternalServerError())
Expand Down

0 comments on commit 9e1ae0e

Please sign in to comment.