diff --git a/connector/hsdp/hsdp.go b/connector/hsdp/hsdp.go index 216df48653..77fc869dc5 100644 --- a/connector/hsdp/hsdp.go +++ b/connector/hsdp/hsdp.go @@ -81,6 +81,14 @@ type ConnectorData struct { User iam.Profile } +type caller uint + +const ( + createCaller caller = iota + refreshCaller + exchangeCaller +) + // Open returns a connector which can be used to log in users through an upstream // OpenID Connect provider. func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, err error) { @@ -306,7 +314,7 @@ func (c *HSDPConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide RefreshToken: tr.RefreshToken, Expiry: time.Unix(tr.ExpiresIn, 0), } - return c.createIdentity(r.Context(), identity, token, r) + return c.createIdentity(r.Context(), identity, token, r, createCaller) } token, err := c.oauth2Config.Exchange(r.Context(), q.Get("code")) @@ -314,7 +322,7 @@ func (c *HSDPConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide return identity, fmt.Errorf("oidc: failed to get token: %v", err) } - return c.createIdentity(r.Context(), identity, token, r) + return c.createIdentity(r.Context(), identity, token, r, createCaller) } // Refresh is used to refresh a session with the refresh token provided by the IdP @@ -334,15 +342,24 @@ func (c *HSDPConnector) Refresh(ctx context.Context, s connector.Scopes, identit return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err) } - return c.createIdentity(ctx, identity, token, nil) + return c.createIdentity(ctx, identity, token, nil, refreshCaller) +} + +func (c *HSDPConnector) TokenIdentity(ctx context.Context, subjectTokenType, subjectToken string) (connector.Identity, error) { + var identity connector.Identity + token := &oauth2.Token{ + AccessToken: subjectToken, + TokenType: "Bearer", + } + return c.createIdentity(ctx, identity, token, nil, exchangeCaller) } -func (c *HSDPConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token, r *http.Request) (connector.Identity, error) { +func (c *HSDPConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token, r *http.Request, caller caller) (connector.Identity, error) { var claims map[string]interface{} cd := ConnectorData{} - if c.isSAML() && r != nil { + if caller == createCaller && c.isSAML() && r != nil { // Save assertion q := r.URL.Query() assertion := q.Get("assertion") @@ -352,10 +369,10 @@ func (c *HSDPConnector) createIdentity(ctx context.Context, identity connector.I // We immediately want to run getUserInfo if configured before we validate the claims userInfo, err := c.provider.UserInfo(ctx, oauth2.StaticTokenSource(token)) if err != nil { - return identity, fmt.Errorf("oidc: error loading userinfo: %v", err) + return identity, fmt.Errorf("hsdp: error loading userinfo: %v", err) } if err := userInfo.Claims(&claims); err != nil { - return identity, fmt.Errorf("oidc: failed to decode userinfo claims: %v", err) + return identity, fmt.Errorf("hsdp: failed to decode userinfo claims: %v", err) } // Introspect so we can get group assignments introspectResponse, err := c.introspect(ctx, oauth2.StaticTokenSource(token)) @@ -398,7 +415,7 @@ func (c *HSDPConnector) createIdentity(ctx context.Context, identity connector.I } } if !found { - return identity, fmt.Errorf("oidc: unexpected hd claim %v", hostedDomain) + return identity, fmt.Errorf("hsdp: unexpected hd claim %v", hostedDomain) } }