diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 708852708395e..f0c4a73f9f0ac 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -5334,7 +5334,7 @@ func (node *Proxy) validPrivilegeParams(req *milvuspb.OperatePrivilegeRequest) e func (node *Proxy) validateOperatePrivilegeV2Params(req *milvuspb.OperatePrivilegeV2Request) error { if req.Role == nil { - return fmt.Errorf("the role in the request is nil") + return merr.WrapErrParameterInvalidMsg("the role in the request is nil") } if err := ValidateRoleName(req.Role.Name); err != nil { return err @@ -5342,8 +5342,12 @@ func (node *Proxy) validateOperatePrivilegeV2Params(req *milvuspb.OperatePrivile if err := ValidatePrivilege(req.Grantor.Privilege.Name); err != nil { return err } + // validate built-in privilege group params + if err := ValidateBuiltInPrivilegeGroup(req.Grantor.Privilege.Name, req.DbName, req.CollectionName); err != nil { + return err + } if req.Type != milvuspb.OperatePrivilegeType_Grant && req.Type != milvuspb.OperatePrivilegeType_Revoke { - return fmt.Errorf("the type in the request not grant or revoke") + return merr.WrapErrParameterInvalidMsg("the type in the request not grant or revoke") } if req.DbName != "" && !util.IsAnyWord(req.DbName) { if err := ValidateDatabaseName(req.DbName); err != nil { diff --git a/internal/proxy/util.go b/internal/proxy/util.go index a41f8598c1607..ecbc37c62b9b9 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -1113,6 +1113,31 @@ func ValidatePrivilege(entity string) error { return validateName(entity, "Privilege") } +func ValidateBuiltInPrivilegeGroup(entity string, dbName string, collectionName string) error { + if !util.IsBuiltinPrivilegeGroup(entity) { + return nil + } + switch { + case strings.HasPrefix(entity, milvuspb.PrivilegeLevel_Cluster.String()): + if !util.IsAnyWord(dbName) || !util.IsAnyWord(collectionName) { + return merr.WrapErrParameterInvalidMsg("dbName and collectionName should be * for the cluster level privilege: %s", entity) + } + return nil + case strings.HasPrefix(entity, milvuspb.PrivilegeLevel_Database.String()): + if collectionName != "" && collectionName != util.AnyWord { + return merr.WrapErrParameterInvalidMsg("collectionName should be * for the database level privilege: %s", entity) + } + return nil + case strings.HasPrefix(entity, milvuspb.PrivilegeLevel_Collection.String()): + if util.IsAnyWord(dbName) && !util.IsAnyWord(collectionName) && collectionName != "" { + return merr.WrapErrParameterInvalidMsg("please specify database name for the collection level privilege: %s", entity) + } + return nil + default: + return nil + } +} + func GetCurUserFromContext(ctx context.Context) (string, error) { return contextutil.GetCurUserFromContext(ctx) } diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index 4d3c9f93c3592..13bff93517145 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -22,7 +22,6 @@ import ( "math/rand" "os" "strconv" - "strings" "sync" "time" @@ -2588,43 +2587,21 @@ func (c *Core) isValidPrivilege(ctx context.Context, privilegeName string, objec return fmt.Errorf("not found the privilege name[%s] in object[%s]", privilegeName, object) } -func (c *Core) isValidPrivilegeV2(ctx context.Context, privilegeName, dbName, collectionName string) error { +func (c *Core) isValidPrivilegeV2(ctx context.Context, privilegeName string) error { if util.IsAnyWord(privilegeName) { return nil } - var privilegeLevel string - for group, privileges := range util.BuiltinPrivilegeGroups { - if privilegeName == group || lo.Contains(privileges, privilegeName) { - privilegeLevel = group - break - } - } - if privilegeLevel == "" { - customPrivGroup, err := c.meta.IsCustomPrivilegeGroup(ctx, privilegeName) - if err != nil { - return err - } - if customPrivGroup { - return nil - } - return fmt.Errorf("not found the privilege name[%s] in the custom privilege groups", privilegeName) + customPrivGroup, err := c.meta.IsCustomPrivilegeGroup(ctx, privilegeName) + if err != nil { + return err } - switch { - case strings.HasPrefix(privilegeLevel, milvuspb.PrivilegeLevel_Cluster.String()): - if !util.IsAnyWord(dbName) || !util.IsAnyWord(collectionName) { - return fmt.Errorf("dbName and collectionName should be * for the cluster level privilege: %s", privilegeName) - } - return nil - case strings.HasPrefix(privilegeLevel, milvuspb.PrivilegeLevel_Database.String()): - if collectionName != "" && collectionName != util.AnyWord { - return fmt.Errorf("collectionName should be empty or * for the database level privilege: %s", privilegeName) - } - return nil - case strings.HasPrefix(privilegeLevel, milvuspb.PrivilegeLevel_Collection.String()): + if customPrivGroup { return nil - default: + } + if util.IsPrivilegeNameDefined(privilegeName) { return nil } + return fmt.Errorf("not found the privilege name[%s]", privilegeName) } // OperatePrivilege operate the privilege, including grant and revoke @@ -2648,7 +2625,7 @@ func (c *Core) OperatePrivilege(ctx context.Context, in *milvuspb.OperatePrivile privName := in.Entity.Grantor.Privilege.Name switch in.Version { case "v2": - if err := c.isValidPrivilegeV2(ctx, privName, in.Entity.DbName, in.Entity.ObjectName); err != nil { + if err := c.isValidPrivilegeV2(ctx, privName); err != nil { ctxLog.Error("", zap.Error(err)) return merr.StatusWithErrorCode(err, commonpb.ErrorCode_OperatePrivilegeFailure), nil } diff --git a/tests/integration/rbac/privilege_group_test.go b/tests/integration/rbac/privilege_group_test.go index 8be1474a4f772..7e44c5301e119 100644 --- a/tests/integration/rbac/privilege_group_test.go +++ b/tests/integration/rbac/privilege_group_test.go @@ -179,7 +179,7 @@ func (s *PrivilegeGroupTestSuite) TestGrantV2BuiltinPrivilegeGroup() { resp, _ = s.operatePrivilegeV2(ctx, roleName, "CollectionAdmin", "db1", "col1", milvuspb.OperatePrivilegeType_Grant) s.True(merr.Ok(resp)) resp, _ = s.operatePrivilegeV2(ctx, roleName, "CollectionAdmin", util.AnyWord, "col1", milvuspb.OperatePrivilegeType_Grant) - s.True(merr.Ok(resp)) + s.False(merr.Ok(resp)) } func (s *PrivilegeGroupTestSuite) TestGrantV2CustomPrivilegeGroup() {