From 086724d3a914531cec35e567ff6ff607843ac300 Mon Sep 17 00:00:00 2001 From: Evan Anderson Date: Thu, 9 Jan 2025 23:58:22 -0800 Subject: [PATCH] Limit allowed redirect URLs from GetAuthorizationURL --- internal/controlplane/handlers_oauth.go | 39 ++++++++++++++++++++ internal/controlplane/handlers_oauth_test.go | 22 +++++++++++ 2 files changed, 61 insertions(+) diff --git a/internal/controlplane/handlers_oauth.go b/internal/controlplane/handlers_oauth.go index 9e2cd38b08..e0aab715ab 100644 --- a/internal/controlplane/handlers_oauth.go +++ b/internal/controlplane/handlers_oauth.go @@ -49,6 +49,11 @@ func (s *Server) GetAuthorizationURL(ctx context.Context, entityCtx := engcontext.EntityFromContext(ctx) projectID := entityCtx.Project.ID + redirectUrl, err := url.Parse(req.GetRedirectUrl()) + if err != nil || !s.alllowedRedirectURL(redirectUrl) { + return nil, util.UserVisibleError(codes.InvalidArgument, "invalid redirect URL") + } + var providerName string if req.GetContext().GetProvider() == "" { providerName = defaultProvider @@ -186,6 +191,40 @@ func (s *Server) GetAuthorizationURL(ctx context.Context, }, nil } +func (s *Server) alllowedRedirectURL(redirectUrl *url.URL) bool { + if redirectUrl == nil || redirectUrl.String() == "" { + return true // Empty URL is allowed + } + if redirectUrl.Host == "localhost" { + return true + } + hostUrl, err := redirectUrl.Parse("/") + if err != nil { + return false + } + hostUrlString := hostUrl.String() + + providerCfg := s.cfg.Provider + + if providerCfg.GitHub != nil && strings.HasPrefix(providerCfg.GitHub.RedirectURI, hostUrlString) { + return true + } + if providerCfg.GitHubApp != nil && strings.HasPrefix(providerCfg.GitHubApp.RedirectURI, hostUrlString) { + return true + } + if providerCfg.GitLab != nil && strings.HasPrefix(providerCfg.GitLab.RedirectURI, hostUrlString) { + return true + } + + if slices.ContainsFunc(s.cfg.HTTPServer.CORS.AllowOrigins, func(u string) bool { + return u == hostUrlString || u+"/" == hostUrlString + }) { + return true + } + + return false +} + // HandleOAuthCallback handles the OAuth 2.0 authorization code callback from the enrolled // provider. This function gathers the state from the database and compares it to the state // passed in. If they match, the provider code is exchanged for a provider token. diff --git a/internal/controlplane/handlers_oauth_test.go b/internal/controlplane/handlers_oauth_test.go index eb0237c4a8..4b829a1d9e 100644 --- a/internal/controlplane/handlers_oauth_test.go +++ b/internal/controlplane/handlers_oauth_test.go @@ -196,6 +196,7 @@ func TestGetAuthorizationURL(t *testing.T) { githubAppProviderClass := "github-app" nonGithubProviderName := "non-github" projectIdStr := projectID.String() + attackerUrl := "https://www.attacker.com/collect/here" testCases := []struct { name string @@ -231,6 +232,7 @@ func TestGetAuthorizationURL(t *testing.T) { if err != nil { t.Errorf("Unexpected error: %v", err) + return } if res.Url == "" { @@ -265,6 +267,7 @@ func TestGetAuthorizationURL(t *testing.T) { if err != nil { t.Errorf("Unexpected error: %v", err) + return } if res.Url == "" { @@ -293,6 +296,24 @@ func TestGetAuthorizationURL(t *testing.T) { expectedStatusCode: codes.InvalidArgument, }, + { + name: "Bad redirect URL", + req: &pb.GetAuthorizationURLRequest{ + Context: &pb.Context{ + Provider: &githubProviderClass, + Project: &projectIdStr, + }, + Cli: true, + RedirectUrl: &attackerUrl, + }, + buildStubs: func(_ *mockdb.MockStore) {}, + checkResponse: func(t *testing.T, _ *pb.GetAuthorizationURLResponse, err error) { + t.Helper() + + assert.Error(t, err, "Expected error in GetAuthorizationURL") + }, + expectedStatusCode: codes.InvalidArgument, + }, { name: "No GitHub id", req: &pb.GetAuthorizationURLRequest{ @@ -325,6 +346,7 @@ func TestGetAuthorizationURL(t *testing.T) { if err != nil { t.Errorf("Unexpected error: %v", err) + return } if res.Url == "" {