Skip to content

Commit

Permalink
Implement TokenIdentity
Browse files Browse the repository at this point in the history
  • Loading branch information
loafoe committed Jul 8, 2024
1 parent 4751cae commit fe5f4a5
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions connector/hsdp/hsdp.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ type ConnectorData struct {
User iam.Profile
}

type caller uint

const (
createCaller caller = iota
refreshCaller
exchangeCaller
)

// Open returns a connector which can be used to log in users through an upstream
// OpenID Connect provider.
func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, err error) {
Expand Down Expand Up @@ -306,15 +314,15 @@ func (c *HSDPConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
RefreshToken: tr.RefreshToken,
Expiry: time.Unix(tr.ExpiresIn, 0),
}
return c.createIdentity(r.Context(), identity, token, r)
return c.createIdentity(r.Context(), identity, token, r, createCaller)
}

token, err := c.oauth2Config.Exchange(r.Context(), q.Get("code"))
if err != nil {
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
}

return c.createIdentity(r.Context(), identity, token, r)
return c.createIdentity(r.Context(), identity, token, r, createCaller)
}

// Refresh is used to refresh a session with the refresh token provided by the IdP
Expand All @@ -334,15 +342,24 @@ func (c *HSDPConnector) Refresh(ctx context.Context, s connector.Scopes, identit
return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err)
}

return c.createIdentity(ctx, identity, token, nil)
return c.createIdentity(ctx, identity, token, nil, refreshCaller)
}

func (c *HSDPConnector) TokenIdentity(ctx context.Context, subjectTokenType, subjectToken string) (connector.Identity, error) {
var identity connector.Identity
token := &oauth2.Token{
AccessToken: subjectToken,
TokenType: "Bearer",
}
return c.createIdentity(ctx, identity, token, nil, exchangeCaller)
}

func (c *HSDPConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token, r *http.Request) (connector.Identity, error) {
func (c *HSDPConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token, r *http.Request, caller caller) (connector.Identity, error) {
var claims map[string]interface{}

cd := ConnectorData{}

if c.isSAML() && r != nil {
if caller == createCaller && c.isSAML() && r != nil {
// Save assertion
q := r.URL.Query()
assertion := q.Get("assertion")
Expand All @@ -352,10 +369,10 @@ func (c *HSDPConnector) createIdentity(ctx context.Context, identity connector.I
// We immediately want to run getUserInfo if configured before we validate the claims
userInfo, err := c.provider.UserInfo(ctx, oauth2.StaticTokenSource(token))
if err != nil {
return identity, fmt.Errorf("oidc: error loading userinfo: %v", err)
return identity, fmt.Errorf("hsdp: error loading userinfo: %v", err)
}
if err := userInfo.Claims(&claims); err != nil {
return identity, fmt.Errorf("oidc: failed to decode userinfo claims: %v", err)
return identity, fmt.Errorf("hsdp: failed to decode userinfo claims: %v", err)
}
// Introspect so we can get group assignments
introspectResponse, err := c.introspect(ctx, oauth2.StaticTokenSource(token))
Expand Down Expand Up @@ -398,7 +415,7 @@ func (c *HSDPConnector) createIdentity(ctx context.Context, identity connector.I
}
}
if !found {
return identity, fmt.Errorf("oidc: unexpected hd claim %v", hostedDomain)
return identity, fmt.Errorf("hsdp: unexpected hd claim %v", hostedDomain)
}
}

Expand Down

0 comments on commit fe5f4a5

Please sign in to comment.