Skip to content

Commit

Permalink
fix: rbac custom group privilege level check (#39164)
Browse files Browse the repository at this point in the history
related: #39086

Signed-off-by: shaoting-huang <[email protected]>
  • Loading branch information
shaoting-huang authored Jan 13, 2025
1 parent 5f94954 commit 5c5948c
Show file tree
Hide file tree
Showing 11 changed files with 235 additions and 185 deletions.
4 changes: 0 additions & 4 deletions internal/proxy/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -5439,10 +5439,6 @@ func (node *Proxy) validateOperatePrivilegeV2Params(req *milvuspb.OperatePrivile
if err := ValidatePrivilege(req.Grantor.Privilege.Name); err != nil {
return err
}
// validate built-in privilege group params
if err := ValidateBuiltInPrivilegeGroup(req.Grantor.Privilege.Name, req.DbName, req.CollectionName); err != nil {
return err
}
if req.Type != milvuspb.OperatePrivilegeType_Grant && req.Type != milvuspb.OperatePrivilegeType_Revoke {
return merr.WrapErrParameterInvalidMsg("the type in the request not grant or revoke")
}
Expand Down
25 changes: 0 additions & 25 deletions internal/proxy/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -1120,31 +1120,6 @@ func ValidatePrivilege(entity string) error {
return validateName(entity, "Privilege")
}

func ValidateBuiltInPrivilegeGroup(entity string, dbName string, collectionName string) error {
if !util.IsBuiltinPrivilegeGroup(entity) {
return nil
}
switch {
case strings.HasPrefix(entity, milvuspb.PrivilegeLevel_Cluster.String()):
if !util.IsAnyWord(dbName) || !util.IsAnyWord(collectionName) {
return merr.WrapErrParameterInvalidMsg("dbName and collectionName should be * for the cluster level privilege: %s", entity)
}
return nil
case strings.HasPrefix(entity, milvuspb.PrivilegeLevel_Database.String()):
if collectionName != "" && collectionName != util.AnyWord {
return merr.WrapErrParameterInvalidMsg("collectionName should be * for the database level privilege: %s", entity)
}
return nil
case strings.HasPrefix(entity, milvuspb.PrivilegeLevel_Collection.String()):
if util.IsAnyWord(dbName) && !util.IsAnyWord(collectionName) && collectionName != "" {
return merr.WrapErrParameterInvalidMsg("please specify database name for the collection level privilege: %s", entity)
}
return nil
default:
return nil
}
}

func GetCurUserFromContext(ctx context.Context) (string, error) {
return contextutil.GetCurUserFromContext(ctx)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/rootcoord/meta_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -1525,7 +1525,7 @@ func (mt *MetaTable) RestoreRBAC(ctx context.Context, tenant string, meta *milvu
return mt.catalog.RestoreRBAC(ctx, tenant, meta)
}

// check if the privielge group name is defined by users
// check if the privilege group name is defined by users
func (mt *MetaTable) IsCustomPrivilegeGroup(ctx context.Context, groupName string) (bool, error) {
privGroups, err := mt.catalog.ListPrivilegeGroups(ctx)
if err != nil {
Expand Down Expand Up @@ -1641,7 +1641,7 @@ func (mt *MetaTable) OperatePrivilegeGroup(ctx context.Context, groupName string
if group.GroupName == p.Name {
privileges = append(privileges, group.Privileges...)
} else {
return merr.WrapErrParameterInvalidMsg("there is no privilege name or privielge group name [%s] defined in system to operate", p.Name)
return merr.WrapErrParameterInvalidMsg("there is no privilege name or privilege group name [%s] defined in system to operate", p.Name)
}
}
}
Expand Down
12 changes: 6 additions & 6 deletions internal/rootcoord/rbac_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func executeOperatePrivilegeTaskSteps(ctx context.Context, core *Core, in *milvu
}
grants := []*milvuspb.GrantEntity{in.Entity}

allGroups, err := core.getPrivilegeGroups(ctx)
allGroups, err := core.getDefaultAndCustomPrivilegeGroups(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -275,12 +275,12 @@ func executeOperatePrivilegeGroupTaskSteps(ctx context.Context, core *Core, in *
return p.Name
})

// check if privileges are the same object type
objectTypes := lo.SliceToMap(newPrivs, func(p *milvuspb.PrivilegeEntity) (string, struct{}) {
return util.GetObjectType(p.Name), struct{}{}
// check if privileges are the same privilege level
privilegeLevels := lo.SliceToMap(newPrivs, func(p *milvuspb.PrivilegeEntity) (string, struct{}) {
return util.GetPrivilegeLevel(p.Name), struct{}{}
})
if len(objectTypes) > 1 {
return nil, errors.New("privileges are not the same object type")
if len(privilegeLevels) > 1 {
return nil, errors.New("privileges are not the same privilege level")
}
case milvuspb.OperatePrivilegeGroupType_RemovePrivilegesFromGroup:
newPrivs, _ := lo.Difference(v, in.Privileges)
Expand Down
46 changes: 43 additions & 3 deletions internal/rootcoord/root_coord.go
Original file line number Diff line number Diff line change
Expand Up @@ -2595,6 +2595,10 @@ func (c *Core) OperatePrivilege(ctx context.Context, in *milvuspb.OperatePrivile
ctxLog.Error("", zap.Error(err))
return merr.StatusWithErrorCode(err, commonpb.ErrorCode_OperatePrivilegeFailure), nil
}
if err := c.validatePrivilegeGroupParams(ctx, privName, in.Entity.DbName, in.Entity.ObjectName); err != nil {
ctxLog.Error("", zap.Error(err))
return merr.StatusWithErrorCode(err, commonpb.ErrorCode_OperatePrivilegeFailure), nil
}
// set up object type for metastore, to be compatible with v1 version
in.Entity.Object.Name = util.GetObjectType(privName)
default:
Expand Down Expand Up @@ -2656,6 +2660,42 @@ func (c *Core) operatePrivilegeCommonCheck(ctx context.Context, in *milvuspb.Ope
return nil
}

func (c *Core) validatePrivilegeGroupParams(ctx context.Context, entity string, dbName string, collectionName string) error {
allGroups, err := c.getDefaultAndCustomPrivilegeGroups(ctx)
if err != nil {
return err
}
groups := lo.SliceToMap(allGroups, func(group *milvuspb.PrivilegeGroupInfo) (string, []*milvuspb.PrivilegeEntity) {
return group.GroupName, group.Privileges
})
privs, exists := groups[entity]
if !exists || len(privs) == 0 {
// it is a privilege, no need to check with other params
return nil
}
// since all privileges are same level in a group, just check the first privilege
level := util.GetPrivilegeLevel(privs[0].GetName())
switch level {
case milvuspb.PrivilegeLevel_Cluster.String():
if !util.IsAnyWord(dbName) || !util.IsAnyWord(collectionName) {
return merr.WrapErrParameterInvalidMsg("dbName and collectionName should be * for the cluster level privilege: %s", entity)
}
return nil
case milvuspb.PrivilegeLevel_Database.String():
if collectionName != "" && collectionName != util.AnyWord {
return merr.WrapErrParameterInvalidMsg("collectionName should be * for the database level privilege: %s", entity)
}
return nil
case milvuspb.PrivilegeLevel_Collection.String():
if util.IsAnyWord(dbName) && !util.IsAnyWord(collectionName) && collectionName != "" {
return merr.WrapErrParameterInvalidMsg("please specify database name for the collection level privilege: %s", entity)
}
return nil
default:
return errors.New("not found the privilege level")
}
}

func (c *Core) getMetastorePrivilegeName(ctx context.Context, privName string) (string, error) {
// if it is built-in privilege, return the privilege name directly
if util.IsPrivilegeNameDefined(privName) {
Expand Down Expand Up @@ -2757,7 +2797,7 @@ func (c *Core) ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest)
}, nil
}
// expand privilege groups and turn to policies
allGroups, err := c.getPrivilegeGroups(ctx)
allGroups, err := c.getDefaultAndCustomPrivilegeGroups(ctx)
if err != nil {
errMsg := "fail to get privilege groups"
ctxLog.Warn(errMsg, zap.Error(err))
Expand Down Expand Up @@ -3131,8 +3171,8 @@ func (c *Core) expandPrivilegeGroups(ctx context.Context, grants []*milvuspb.Gra
}), nil
}

// getPrivilegeGroups returns default privilege groups and user-defined privilege groups.
func (c *Core) getPrivilegeGroups(ctx context.Context) ([]*milvuspb.PrivilegeGroupInfo, error) {
// getDefaultAndCustomPrivilegeGroups returns default privilege groups and user-defined privilege groups.
func (c *Core) getDefaultAndCustomPrivilegeGroups(ctx context.Context) ([]*milvuspb.PrivilegeGroupInfo, error) {
allGroups, err := c.meta.ListPrivilegeGroups(ctx)
allGroups = append(allGroups, Params.RbacConfig.GetDefaultPrivilegeGroups()...)
if err != nil {
Expand Down
119 changes: 119 additions & 0 deletions pkg/util/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/samber/lo"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
Expand Down Expand Up @@ -292,6 +293,124 @@ var (
}
)

// rbac v2 uses privilege level to group privileges rather than object type
var (
CollectionReadOnlyPrivileges = ConvertPrivileges([]string{
commonpb.ObjectPrivilege_PrivilegeQuery.String(),
commonpb.ObjectPrivilege_PrivilegeSearch.String(),
commonpb.ObjectPrivilege_PrivilegeIndexDetail.String(),
commonpb.ObjectPrivilege_PrivilegeGetFlushState.String(),
commonpb.ObjectPrivilege_PrivilegeGetLoadState.String(),
commonpb.ObjectPrivilege_PrivilegeGetLoadingProgress.String(),
commonpb.ObjectPrivilege_PrivilegeHasPartition.String(),
commonpb.ObjectPrivilege_PrivilegeShowPartitions.String(),
commonpb.ObjectPrivilege_PrivilegeDescribeCollection.String(),
commonpb.ObjectPrivilege_PrivilegeDescribeAlias.String(),
commonpb.ObjectPrivilege_PrivilegeGetStatistics.String(),
commonpb.ObjectPrivilege_PrivilegeListAliases.String(),
})

CollectionReadWritePrivileges = append(CollectionReadOnlyPrivileges,
ConvertPrivileges([]string{
commonpb.ObjectPrivilege_PrivilegeLoad.String(),
commonpb.ObjectPrivilege_PrivilegeRelease.String(),
commonpb.ObjectPrivilege_PrivilegeInsert.String(),
commonpb.ObjectPrivilege_PrivilegeDelete.String(),
commonpb.ObjectPrivilege_PrivilegeUpsert.String(),
commonpb.ObjectPrivilege_PrivilegeImport.String(),
commonpb.ObjectPrivilege_PrivilegeFlush.String(),
commonpb.ObjectPrivilege_PrivilegeCompaction.String(),
commonpb.ObjectPrivilege_PrivilegeLoadBalance.String(),
commonpb.ObjectPrivilege_PrivilegeCreateIndex.String(),
commonpb.ObjectPrivilege_PrivilegeDropIndex.String(),
commonpb.ObjectPrivilege_PrivilegeCreatePartition.String(),
commonpb.ObjectPrivilege_PrivilegeDropPartition.String(),
})...,
)

CollectionAdminPrivileges = append(CollectionReadWritePrivileges,
ConvertPrivileges([]string{
commonpb.ObjectPrivilege_PrivilegeCreateAlias.String(),
commonpb.ObjectPrivilege_PrivilegeDropAlias.String(),
})...,
)

DatabaseReadOnlyPrivileges = ConvertPrivileges([]string{
commonpb.ObjectPrivilege_PrivilegeShowCollections.String(),
commonpb.ObjectPrivilege_PrivilegeDescribeDatabase.String(),
})

DatabaseReadWritePrivileges = append(DatabaseReadOnlyPrivileges,
ConvertPrivileges([]string{
commonpb.ObjectPrivilege_PrivilegeAlterDatabase.String(),
})...,
)

DatabaseAdminPrivileges = append(DatabaseReadWritePrivileges,
ConvertPrivileges([]string{
commonpb.ObjectPrivilege_PrivilegeCreateCollection.String(),
commonpb.ObjectPrivilege_PrivilegeDropCollection.String(),
})...,
)

ClusterReadOnlyPrivileges = ConvertPrivileges([]string{
commonpb.ObjectPrivilege_PrivilegeListDatabases.String(),
commonpb.ObjectPrivilege_PrivilegeSelectOwnership.String(),
commonpb.ObjectPrivilege_PrivilegeSelectUser.String(),
commonpb.ObjectPrivilege_PrivilegeDescribeResourceGroup.String(),
commonpb.ObjectPrivilege_PrivilegeListResourceGroups.String(),
commonpb.ObjectPrivilege_PrivilegeListPrivilegeGroups.String(),
})

ClusterReadWritePrivileges = append(ClusterReadOnlyPrivileges,
ConvertPrivileges([]string{
commonpb.ObjectPrivilege_PrivilegeFlushAll.String(),
commonpb.ObjectPrivilege_PrivilegeTransferNode.String(),
commonpb.ObjectPrivilege_PrivilegeTransferReplica.String(),
commonpb.ObjectPrivilege_PrivilegeUpdateResourceGroups.String(),
})...,
)

ClusterAdminPrivileges = append(ClusterReadWritePrivileges,
ConvertPrivileges([]string{
commonpb.ObjectPrivilege_PrivilegeBackupRBAC.String(),
commonpb.ObjectPrivilege_PrivilegeRestoreRBAC.String(),
commonpb.ObjectPrivilege_PrivilegeCreateDatabase.String(),
commonpb.ObjectPrivilege_PrivilegeDropDatabase.String(),
commonpb.ObjectPrivilege_PrivilegeCreateOwnership.String(),
commonpb.ObjectPrivilege_PrivilegeDropOwnership.String(),
commonpb.ObjectPrivilege_PrivilegeManageOwnership.String(),
commonpb.ObjectPrivilege_PrivilegeCreateResourceGroup.String(),
commonpb.ObjectPrivilege_PrivilegeDropResourceGroup.String(),
commonpb.ObjectPrivilege_PrivilegeUpdateUser.String(),
commonpb.ObjectPrivilege_PrivilegeRenameCollection.String(),
commonpb.ObjectPrivilege_PrivilegeCreatePrivilegeGroup.String(),
commonpb.ObjectPrivilege_PrivilegeDropPrivilegeGroup.String(),
commonpb.ObjectPrivilege_PrivilegeOperatePrivilegeGroup.String(),
})...,
)
)

// ConvertPrivileges converts each privilege from metastore format to API format.
func ConvertPrivileges(privileges []string) []string {
return lo.Map(privileges, func(name string, _ int) string {
return MetaStore2API(name)
})
}

func GetPrivilegeLevel(privilege string) string {
if lo.Contains(ClusterAdminPrivileges, privilege) {
return milvuspb.PrivilegeLevel_Cluster.String()
}
if lo.Contains(DatabaseAdminPrivileges, privilege) {
return milvuspb.PrivilegeLevel_Database.String()
}
if lo.Contains(CollectionAdminPrivileges, privilege) {
return milvuspb.PrivilegeLevel_Collection.String()
}
return ""
}

// StringSet convert array to map for conveniently check if the array contains an element
func StringSet(strings []string) map[string]struct{} {
stringsMap := make(map[string]struct{})
Expand Down
24 changes: 11 additions & 13 deletions pkg/util/paramtable/rbac_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,24 @@ import (
"github.com/stretchr/testify/assert"
)

func TestRbacConfig_Init(t *testing.T) {
func TestRbacConfig_DefaultPrivileges(t *testing.T) {
params := ComponentParam{}
params.Init(NewBaseTable(SkipRemote(true)))
cfg := &params.RbacConfig
assert.Equal(t, len(cfg.GetDefaultPrivilegeGroupNames()), 9)
assert.True(t, cfg.IsCollectionPrivilegeGroup("CollectionReadOnly"))
assert.False(t, cfg.IsCollectionPrivilegeGroup("DatabaseReadOnly"))
assert.Equal(t, cfg.Enabled.GetAsBool(), false)
assert.Equal(t, cfg.ClusterReadOnlyPrivileges.GetAsStrings(), builtinPrivilegeGroups["ClusterReadOnly"])
assert.Equal(t, cfg.ClusterReadWritePrivileges.GetAsStrings(), builtinPrivilegeGroups["ClusterReadWrite"])
assert.Equal(t, cfg.ClusterAdminPrivileges.GetAsStrings(), builtinPrivilegeGroups["ClusterAdmin"])
assert.Equal(t, cfg.DBReadOnlyPrivileges.GetAsStrings(), builtinPrivilegeGroups["DatabaseReadOnly"])
assert.Equal(t, cfg.DBReadWritePrivileges.GetAsStrings(), builtinPrivilegeGroups["DatabaseReadWrite"])
assert.Equal(t, cfg.DBAdminPrivileges.GetAsStrings(), builtinPrivilegeGroups["DatabaseAdmin"])
assert.Equal(t, cfg.CollectionReadOnlyPrivileges.GetAsStrings(), builtinPrivilegeGroups["CollectionReadOnly"])
assert.Equal(t, cfg.CollectionReadWritePrivileges.GetAsStrings(), builtinPrivilegeGroups["CollectionReadWrite"])
assert.Equal(t, cfg.CollectionAdminPrivileges.GetAsStrings(), builtinPrivilegeGroups["CollectionAdmin"])
assert.Equal(t, cfg.ClusterReadOnlyPrivileges.GetAsStrings(), cfg.GetDefaultPrivilegeGroupPrivileges("ClusterReadOnly"))
assert.Equal(t, cfg.ClusterReadWritePrivileges.GetAsStrings(), cfg.GetDefaultPrivilegeGroupPrivileges("ClusterReadWrite"))
assert.Equal(t, cfg.ClusterAdminPrivileges.GetAsStrings(), cfg.GetDefaultPrivilegeGroupPrivileges("ClusterAdmin"))
assert.Equal(t, cfg.DBReadOnlyPrivileges.GetAsStrings(), cfg.GetDefaultPrivilegeGroupPrivileges("DatabaseReadOnly"))
assert.Equal(t, cfg.DBReadWritePrivileges.GetAsStrings(), cfg.GetDefaultPrivilegeGroupPrivileges("DatabaseReadWrite"))
assert.Equal(t, cfg.DBAdminPrivileges.GetAsStrings(), cfg.GetDefaultPrivilegeGroupPrivileges("DatabaseAdmin"))
assert.Equal(t, cfg.CollectionReadOnlyPrivileges.GetAsStrings(), cfg.GetDefaultPrivilegeGroupPrivileges("CollectionReadOnly"))
assert.Equal(t, cfg.CollectionReadWritePrivileges.GetAsStrings(), cfg.GetDefaultPrivilegeGroupPrivileges("CollectionReadWrite"))
assert.Equal(t, cfg.CollectionAdminPrivileges.GetAsStrings(), cfg.GetDefaultPrivilegeGroupPrivileges("CollectionAdmin"))
}

func TestRbacConfig_Override(t *testing.T) {
func TestRbacConfig_OverridePrivileges(t *testing.T) {
params := ComponentParam{}
params.Init(NewBaseTable(SkipRemote(true)))

Expand Down
Loading

0 comments on commit 5c5948c

Please sign in to comment.