diff --git a/authentication/authentication.go b/authentication/authentication.go index 2c996cd2..574de67f 100644 --- a/authentication/authentication.go +++ b/authentication/authentication.go @@ -42,6 +42,7 @@ type Provider interface { Middleware() Middleware GRPCMiddleware() grpc.StreamServerInterceptor Handler() (string, http.Handler) + LoginPath(tenant string) string } type tenantHandlers map[string]http.Handler @@ -52,6 +53,7 @@ type ProviderManager struct { patternHandlers map[string]tenantHandlers middlewares map[string]Middleware gRPCInterceptors map[string]grpc.StreamServerInterceptor + authenticator map[string]Provider logger log.Logger registrationRetryCount *prometheus.CounterVec } @@ -63,6 +65,7 @@ func NewProviderManager(l log.Logger, registrationRetryCount *prometheus.Counter patternHandlers: make(map[string]tenantHandlers), middlewares: make(map[string]Middleware), gRPCInterceptors: make(map[string]grpc.StreamServerInterceptor), + authenticator: make(map[string]Provider), logger: l, } } @@ -91,6 +94,7 @@ func (ah *ProviderManager) InitializeProvider(config map[string]interface{}, ah.mtx.Lock() ah.middlewares[tenant] = authenticator.Middleware() ah.gRPCInterceptors[tenant] = authenticator.GRPCMiddleware() + ah.authenticator[tenant] = authenticator pattern, handler := authenticator.Handler() if pattern != "" && handler != nil { if ah.patternHandlers[pattern] == nil { @@ -148,6 +152,18 @@ func (ah *ProviderManager) PatternHandler(pattern string) http.HandlerFunc { }) } +func (ah *ProviderManager) GetLoginPath(tenant string) (string, bool) { + ah.mtx.RLock() + provider, ok := ah.authenticator[tenant] + ah.mtx.RUnlock() + + if !ok { + return "", false + } + + return provider.LoginPath(tenant), true +} + func getProviderFactory(authType string) (ProviderFactory, error) { providersMtx.RLock() defer providersMtx.RUnlock() diff --git a/authentication/authentication_test.go b/authentication/authentication_test.go index 7e990316..2f55287d 100644 --- a/authentication/authentication_test.go +++ b/authentication/authentication_test.go @@ -46,6 +46,10 @@ func (a dummyAuthenticator) Handler() (string, http.Handler) { return "", nil } +func (a dummyAuthenticator) LoginPath(tenant string) string { + return "" +} + func newdummyAuthenticator(c map[string]interface{}, tenant string, registrationRetryCount *prometheus.CounterVec, logger log.Logger) (Provider, error) { var config dummyAuthenticatorConfig diff --git a/authentication/mtls.go b/authentication/mtls.go index e6608490..e00db19e 100644 --- a/authentication/mtls.go +++ b/authentication/mtls.go @@ -156,3 +156,7 @@ func (a MTLSAuthenticator) GRPCMiddleware() grpc.StreamServerInterceptor { func (a MTLSAuthenticator) Handler() (string, http.Handler) { return "", nil } + +func (a MTLSAuthenticator) LoginPath(tenant string) string { + return "" +} diff --git a/authentication/oidc.go b/authentication/oidc.go index ce9f375d..cf2798cf 100644 --- a/authentication/oidc.go +++ b/authentication/oidc.go @@ -434,3 +434,7 @@ func (a oidcAuthenticator) checkAuth(ctx context.Context, token string) (context return ctx, "", http.StatusOK, codes.OK } + +func (a oidcAuthenticator) LoginPath(tenant string) string { + return strings.ReplaceAll("/oidc/{tenant}/login", "{tenant}", tenant) +} diff --git a/authentication/openshift.go b/authentication/openshift.go index 6d1f2a01..f5eaa766 100644 --- a/authentication/openshift.go +++ b/authentication/openshift.go @@ -507,3 +507,7 @@ func (a OpenShiftAuthenticator) GRPCMiddleware() grpc.StreamServerInterceptor { func (a OpenShiftAuthenticator) Handler() (string, http.Handler) { return "/openshift/{tenant}", a.handler } + +func (a OpenShiftAuthenticator) LoginPath(tenant string) string { + return strings.ReplaceAll("/openshift/{tenant}/login", "{tenant}", tenant) +} diff --git a/main.go b/main.go index 3943cbb2..a1cbb0d5 100644 --- a/main.go +++ b/main.go @@ -625,6 +625,24 @@ func main() { }) } + // Redirect "/api/traces/v1/{tenant}/login" to e.g. "/oidc/{tenant}/login". + r.Get("/api/traces/v1/{tenant}/login", func(w http.ResponseWriter, r *http.Request) { + tenant, ok := authentication.GetTenant(r.Context()) + if !ok { + w.WriteHeader(http.StatusNotFound) + return + } + + // Skip providers without login URLs. + loginPath, ok := pm.GetLoginPath(tenant) + if !ok || loginPath == "" { + w.WriteHeader(http.StatusNotFound) + return + } + + http.Redirect(w, r, loginPath, http.StatusMovedPermanently) + }) + r.Mount("/api/traces/v1/{tenant}", stripTenantPrefix("/api/traces/v1", tracesv1.NewV2Handler(