From b6bf259dd20ee02d4f365722f83194d174869e3f Mon Sep 17 00:00:00 2001 From: Jake Van Vorhis <83739412+jakedoublev@users.noreply.github.com> Date: Sun, 7 Jul 2024 10:33:35 -0600 Subject: [PATCH] feat(core): extend authz policy (#1105) Resolves #1104 --- service/internal/auth/authn.go | 10 +-- service/internal/auth/authn_test.go | 9 +++ service/internal/auth/casbin.go | 117 +++++++++++++++++++++------ service/internal/auth/casbin_test.go | 76 ++++++++++++++++- service/internal/server/server.go | 2 + service/pkg/server/options.go | 16 +++- service/pkg/server/start.go | 15 ++++ 7 files changed, 207 insertions(+), 38 deletions(-) diff --git a/service/internal/auth/authn.go b/service/internal/auth/authn.go index 3af75cde7..68f61104e 100644 --- a/service/internal/auth/authn.go +++ b/service/internal/auth/authn.go @@ -190,6 +190,10 @@ func normalizeURL(o string, u *url.URL) string { return ou.String() } +func (a *Authentication) ExtendAuthzDefaultPolicy(policies [][]string) error { + return a.enforcer.ExtendDefaultPolicy(policies) +} + // verifyTokenHandler is a http handler that verifies the token func (a Authentication) MuxHandler(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -217,7 +221,6 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler { u: normalizeURL(origin, r.URL), m: r.Method, }, r.Header["Dpop"]) - if err != nil { a.logger.WarnContext(r.Context(), "failed to validate token", slog.String("error", err.Error())) http.Error(w, "unauthenticated", http.StatusUnauthorized) @@ -323,9 +326,7 @@ func (a Authentication) UnaryServerInterceptor(ctx context.Context, req any, inf // checkToken is a helper function to verify the token. func (a Authentication) checkToken(ctx context.Context, authHeader []string, dpopInfo receiverInfo, dpopHeader []string) (jwt.Token, context.Context, error) { - var ( - tokenRaw string - ) + var tokenRaw string // If we don't get a DPoP/Bearer token type, we can't proceed switch { @@ -344,7 +345,6 @@ func (a Authentication) checkToken(ctx context.Context, authHeader []string, dpo jwt.WithIssuer(a.oidcConfiguration.Issuer), jwt.WithAudience(a.oidcConfiguration.Audience), ) - if err != nil { return nil, nil, err } diff --git a/service/internal/auth/authn_test.go b/service/internal/auth/authn_test.go index 00fc43da5..e7679cb3a 100644 --- a/service/internal/auth/authn_test.go +++ b/service/internal/auth/authn_test.go @@ -538,6 +538,15 @@ func (s *AuthSuite) TestDPoPEndToEnd_HTTP() { s.Equal(dpopJWK.N(), dpopJWKFromRequest.N()) } +func (s *AuthSuite) Test_AddAuthzPolicies() { + err := s.auth.ExtendAuthzDefaultPolicy([][]string{ + {"p", "role:admin", "/path", "*", "allow"}, + {"p", "role:standard", "/path2", "read", "deny"}, + }) + s.Require().NoError(err) + s.False(s.auth.enforcer.isDefaultPolicy) +} + func makeDPoPToken(t *testing.T, tc dpopTestCase) string { jtiBytes := make([]byte, sdkauth.JTILength) _, err := rand.Read(jtiBytes) diff --git a/service/internal/auth/casbin.go b/service/internal/auth/casbin.go index 5d93e5007..1091f3057 100644 --- a/service/internal/auth/casbin.go +++ b/service/internal/auth/casbin.go @@ -1,6 +1,7 @@ package auth import ( + "errors" "fmt" "log/slog" "strings" @@ -13,8 +14,10 @@ import ( ) var ( - rolePrefix = "role:" - defaultRole = "unknown" + ErrPolicyMalformed = errors.New("malformed authz policy") + rolePrefix = "role:" + defaultRole = "unknown" + defaultPolicyPartsLen = 5 ) var defaultRoleClaim = "realm_access.roles" @@ -120,6 +123,11 @@ type Enforcer struct { *casbin.Enforcer Config CasbinConfig Policy string + + isDefaultRoleClaim bool + isDefaultRoleMap bool + isDefaultPolicy bool + isDefaultModel bool } type casbinSubject struct { @@ -139,37 +147,101 @@ func NewCasbinEnforcer(c CasbinConfig) (*Enforcer, error) { // if err != nil { // return nil, err // } - mStr := defaultModel - if c.Model != "" { - mStr = c.Model + + // Set Casbin config defaults if not provided + isDefaultModel := false + if c.Model == "" { + c.Model = defaultModel + isDefaultModel = true + } + isDefaultPolicy := false + if c.Csv == "" { + c.Csv = defaultPolicy + isDefaultPolicy = true } - pStr := defaultPolicy - if c.Csv != "" { - pStr = c.Csv + policyString := c.Csv + + isDefaultRoleClaim := false + if c.RoleClaim == "" { + isDefaultRoleClaim = true + c.RoleClaim = defaultRoleClaim + } + + isDefaultRoleMap := false + if len(c.RoleMap) == 0 { + isDefaultRoleMap = true + c.RoleMap = defaultRoleMap } - slog.Debug("creating casbin enforcer", slog.Any("config", c)) + slog.Debug("creating casbin enforcer", + slog.Any("config", c), + slog.Bool("isDefaultModel", isDefaultModel), + slog.Bool("isDefaultPolicy", isDefaultPolicy), + slog.Bool("isDefaultRoleMap", isDefaultRoleMap), + slog.Bool("isDefaultRoleClaim", isDefaultRoleClaim), + ) - m, err := casbinModel.NewModelFromString(mStr) + m, err := casbinModel.NewModelFromString(c.Model) if err != nil { return nil, fmt.Errorf("failed to create casbin model: %w", err) } - a := stringadapter.NewAdapter(pStr) + a := stringadapter.NewAdapter(policyString) e, err := casbin.NewEnforcer(m, a) if err != nil { return nil, fmt.Errorf("failed to create casbin enforcer: %w", err) } return &Enforcer{ - Enforcer: e, - Config: c, - Policy: pStr, + Enforcer: e, + Config: c, + Policy: policyString, + isDefaultPolicy: isDefaultPolicy, + isDefaultModel: isDefaultModel, + isDefaultRoleClaim: isDefaultRoleClaim, + isDefaultRoleMap: isDefaultRoleMap, }, nil } +// Extend the default policy +func (e *Enforcer) ExtendDefaultPolicy(policies [][]string) error { + if !e.isDefaultPolicy { + // don't error out, just log a warning + slog.Warn("default authz policy could not be not extended because policies are not the default", slog.Any("unextended_policies", policies)) + return nil + } + + policy := strings.TrimSpace(defaultPolicy) + policy += "\n\n## Extended Policies" + for p := range policies { + pol := policies[p] + polCsv := strings.Join(policies[p], ", ") + if len(pol) < defaultPolicyPartsLen { + return fmt.Errorf("policy missing one of 'p, subject, resource, action, effect', pol: [%s] %w", polCsv, ErrPolicyMalformed) + } + if pol[0] != "p" { + return fmt.Errorf("policy must be prefixed with 'p', pol: [%s] %w", polCsv, ErrPolicyMalformed) + } + if !strings.HasPrefix(pol[1], rolePrefix) { + return fmt.Errorf("policy must contain default role prefix, pol: [%s] %w", polCsv, ErrPolicyMalformed) + } + policy += "\n" + polCsv + } + policy += "\n" + + // Load up new adapter then load the new policy + a := stringadapter.NewAdapter(policy) + e.SetAdapter(a) + if err := e.LoadPolicy(); err != nil { + return fmt.Errorf("failed to load extended default policy: %w", err) + } + e.isDefaultPolicy = false + + return nil +} + // casbinEnforce is a helper function to enforce the policy with casbin // TODO implement a common type so this can be used for both http and grpc -func (e Enforcer) Enforce(token jwt.Token, resource, action string) (bool, error) { +func (e *Enforcer) Enforce(token jwt.Token, resource, action string) (bool, error) { var err error permDeniedError := fmt.Errorf("permission denied") @@ -202,7 +274,7 @@ func (e Enforcer) Enforce(token jwt.Token, resource, action string) (bool, error return true, nil } -func (e Enforcer) buildSubjectFromToken(t jwt.Token) casbinSubject { +func (e *Enforcer) buildSubjectFromToken(t jwt.Token) casbinSubject { slog.Debug("building subject from token", slog.Any("token", t)) roles := e.extractRolesFromToken(t) @@ -212,19 +284,12 @@ func (e Enforcer) buildSubjectFromToken(t jwt.Token) casbinSubject { } } -func (e Enforcer) extractRolesFromToken(t jwt.Token) []string { +func (e *Enforcer) extractRolesFromToken(t jwt.Token) []string { slog.Debug("extracting roles from token", slog.Any("token", t)) roles := []string{} - roleClaim := defaultRoleClaim - if e.Config.RoleClaim != "" { - roleClaim = e.Config.RoleClaim - } - - roleMap := defaultRoleMap - if len(e.Config.RoleMap) > 0 { - roleMap = e.Config.RoleMap - } + roleClaim := e.Config.RoleClaim + roleMap := e.Config.RoleMap selectors := strings.Split(roleClaim, ".") claim, exists := t.Get(selectors[0]) diff --git a/service/internal/auth/casbin_test.go b/service/internal/auth/casbin_test.go index 98d308e76..8339d5263 100644 --- a/service/internal/auth/casbin_test.go +++ b/service/internal/auth/casbin_test.go @@ -55,13 +55,13 @@ func (s *AuthnCasbinSuite) buildTokenRoles(orgAdmin bool, admin bool, standard b return roles } -func (s *AuthnCasbinSuite) newTokWithDefaultClaim(orgAdmin bool, admin bool, standard bool) (string, jwt.Token) { +func (s *AuthnCasbinSuite) newTokWithDefaultClaim(orgAdmin bool, admin bool, standard bool) jwt.Token { tok := jwt.New() tokenRoles := s.buildTokenRoles(orgAdmin, admin, standard, nil) if err := tok.Set("realm_access", map[string]interface{}{"roles": tokenRoles}); err != nil { s.T().Fatal(err) } - return "", tok + return tok } func (s *AuthnCasbinSuite) newTokenWithCustomClaim(orgAdmin bool, admin bool, standard bool) (string, jwt.Token) { @@ -318,7 +318,7 @@ func (s *AuthnCasbinSuite) Test_Enforcement() { slog.Info("running test w/ default claim", slog.String("name", name)) enforcer, err := NewCasbinEnforcer(CasbinConfig{}) s.Require().NoError(err) - _, tok := s.newTokWithDefaultClaim(test.roles[0], test.roles[1], test.roles[2]) + tok := s.newTokWithDefaultClaim(test.roles[0], test.roles[1], test.roles[2]) allowed, err := enforcer.Enforce(tok, test.resource, test.action) if !test.allowed { s.Require().Error(err) @@ -391,3 +391,73 @@ func (s *AuthnCasbinSuite) Test_Enforcement() { s.Equal(test.allowed, allowed) } } + +func (s *AuthnCasbinSuite) Test_ExtendDefaultPolicies() { + enforcer, err := NewCasbinEnforcer(CasbinConfig{}) + s.Require().NoError(err) + tok := s.newTokWithDefaultClaim(true, false, false) + + // Org-admin role + err = enforcer.ExtendDefaultPolicy([][]string{{"p", "role:org-admin", "new.service.*", "*", "allow"}}) + s.Require().NoError(err) + + // original org-admin policy still evaluates correctly + allowed, err := enforcer.Enforce(tok, "policy.attributes.DoSomething", "write") + s.Require().NoError(err) + s.True(allowed) + + // allowed role for new policy is allowed + allowed, err = enforcer.Enforce(tok, "new.service.DoSomething", "read") + s.Require().NoError(err) + s.True(allowed) + allowed, err = enforcer.Enforce(tok, "new.service.DoSomething", "write") + s.Require().NoError(err) + s.True(allowed) + + // other roles denied new policy: admin + tok = s.newTokWithDefaultClaim(false, true, false) + allowed, err = enforcer.Enforce(tok, "new.service.DoSomething", "read") + s.Require().Error(err) + s.False(allowed) + allowed, err = enforcer.Enforce(tok, "new.service.DoSomething", "write") + s.Require().Error(err) + s.False(allowed) + + // other roles denied new policy: standard + tok = s.newTokWithDefaultClaim(false, false, true) + allowed, err = enforcer.Enforce(tok, "new.service.DoSomething", "read") + s.Require().Error(err) + s.False(allowed) + allowed, err = enforcer.Enforce(tok, "new.service.DoSomething", "write") + s.Require().Error(err) + s.False(allowed) +} + +func (s *AuthnCasbinSuite) Test_ExtendDefaultPolicies_MalformedErrors() { + enforcer, err := NewCasbinEnforcer(CasbinConfig{}) + s.Require().NoError(err) + tok := s.newTokWithDefaultClaim(true, false, false) + allowed, err := enforcer.Enforce(tok, "policy.attributes.DoSomething", "read") + s.Require().NoError(err) + s.True(allowed) + + // missing 'p' + err = enforcer.ExtendDefaultPolicy([][]string{{"role:org-admin", "new.service.DoSomething", "*"}}) + s.Require().Error(err) + s.Require().ErrorIs(err, ErrPolicyMalformed) + + // missing effect + err = enforcer.ExtendDefaultPolicy([][]string{{"p", "role:org-admin", "new.service.DoSomething", "*"}}) + s.Require().Error(err) + s.Require().ErrorIs(err, ErrPolicyMalformed) + + // empty + err = enforcer.ExtendDefaultPolicy([][]string{{}}) + s.Require().Error(err) + s.Require().ErrorIs(err, ErrPolicyMalformed) + + // missing role prefix + err = enforcer.ExtendDefaultPolicy([][]string{{"p", "org-admin", "new.service.DoSomething", "*"}}) + s.Require().Error(err) + s.Require().ErrorIs(err, ErrPolicyMalformed) +} diff --git a/service/internal/server/server.go b/service/internal/server/server.go index cd46b1989..69a31fccc 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -87,6 +87,7 @@ type CORSConfig struct { } type OpenTDFServer struct { + AuthN *auth.Authentication Mux *runtime.ServeMux HTTPServer *http.Server GRPCServer *grpc.Server @@ -148,6 +149,7 @@ func NewOpenTDFServer(config Config, logr *logger.Logger) (*OpenTDFServer, error } o := OpenTDFServer{ + AuthN: authN, Mux: mux, HTTPServer: httpServer, GRPCServer: grpcServer, diff --git a/service/pkg/server/options.go b/service/pkg/server/options.go index 3840e7f1a..49850c4b4 100644 --- a/service/pkg/server/options.go +++ b/service/pkg/server/options.go @@ -3,10 +3,11 @@ package server type StartOptions func(StartConfig) StartConfig type StartConfig struct { - ConfigKey string - ConfigFile string - WaitForShutdownSignal bool - PublicRoutes []string + ConfigKey string + ConfigFile string + WaitForShutdownSignal bool + PublicRoutes []string + authzDefaultPolicyExtension [][]string } // Deprecated: Use WithConfigKey @@ -44,3 +45,10 @@ func WithPublicRoutes(routes []string) StartOptions { return c } } + +func WithAuthZDefaultPolicyExtension(policies [][]string) StartOptions { + return func(c StartConfig) StartConfig { + c.authzDefaultPolicyExtension = policies + return c + } +} diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index a85d53723..92d46ac11 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -2,6 +2,7 @@ package server import ( "context" + "errors" "fmt" "log/slog" "os" @@ -81,6 +82,20 @@ func Start(f ...StartOptions) error { } defer otdf.Stop() + // Append the authz policies + if len(startConfig.authzDefaultPolicyExtension) > 0 { + if otdf.AuthN == nil { + err := errors.New("authn not enabled") + logger.Error("issue adding authz policies", "error", err) + return fmt.Errorf("issue adding authz policies: %w", err) + } + err := otdf.AuthN.ExtendAuthzDefaultPolicy(startConfig.authzDefaultPolicyExtension) + if err != nil { + logger.Error("issue adding authz policies", slog.String("error", err.Error())) + return fmt.Errorf("issue adding authz policies: %w", err) + } + } + logger.Info("registering services") if err := registerServices(); err != nil { logger.Error("issue registering services", slog.String("error", err.Error()))