Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add the rbac msg and send them to the replicate channel #39185

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions internal/proxy/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ func (node *Proxy) AlterDatabase(ctx context.Context, request *milvuspb.AlterDat
Condition: NewTaskCondition(ctx),
AlterDatabaseRequest: request,
rootCoord: node.rootCoord,
replicateMsgStream: node.replicateMsgStream,
}

log := log.Ctx(ctx).With(
Expand Down Expand Up @@ -4853,6 +4854,10 @@ func (node *Proxy) CreateCredential(ctx context.Context, req *milvuspb.CreateCre
err = errors.Wrap(err, "encrypt password failed")
return merr.Status(err), nil
}
if req.Base == nil {
req.Base = &commonpb.MsgBase{}
}
req.Base.MsgType = commonpb.MsgType_CreateCredential

credInfo := &internalpb.CredentialInfo{
Username: req.Username,
Expand All @@ -4865,6 +4870,9 @@ func (node *Proxy) CreateCredential(ctx context.Context, req *milvuspb.CreateCre
zap.Error(err))
return merr.Status(err), nil
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, err
}

Expand Down Expand Up @@ -4922,6 +4930,10 @@ func (node *Proxy) UpdateCredential(ctx context.Context, req *milvuspb.UpdateCre
err = errors.Wrap(err, "encrypt password failed")
return merr.Status(err), nil
}
if req.Base == nil {
req.Base = &commonpb.MsgBase{}
}
req.Base.MsgType = commonpb.MsgType_UpdateCredential
updateCredReq := &internalpb.CredentialInfo{
Username: req.Username,
Sha256Password: crypto.SHA256(rawNewPassword, req.Username),
Expand All @@ -4933,6 +4945,9 @@ func (node *Proxy) UpdateCredential(ctx context.Context, req *milvuspb.UpdateCre
zap.Error(err))
return merr.Status(err), nil
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, err
}

Expand All @@ -4953,12 +4968,19 @@ func (node *Proxy) DeleteCredential(ctx context.Context, req *milvuspb.DeleteCre
err := merr.WrapErrPrivilegeNotPermitted("root user cannot be deleted")
return merr.Status(err), nil
}
if req.Base == nil {
req.Base = &commonpb.MsgBase{}
}
req.Base.MsgType = commonpb.MsgType_DeleteCredential
result, err := node.rootCoord.DeleteCredential(ctx, req)
if err != nil { // for error like conntext timeout etc.
log.Error("delete credential fail",
zap.Error(err))
return merr.Status(err), nil
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, err
}

Expand All @@ -4973,6 +4995,10 @@ func (node *Proxy) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUser
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
return &milvuspb.ListCredUsersResponse{Status: merr.Status(err)}, nil
}
if req.Base == nil {
req.Base = &commonpb.MsgBase{}
}
req.Base.MsgType = commonpb.MsgType_ListCredUsernames
rootCoordReq := &milvuspb.ListCredUsersRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_ListCredUsernames),
Expand Down Expand Up @@ -5008,12 +5034,19 @@ func (node *Proxy) CreateRole(ctx context.Context, req *milvuspb.CreateRoleReque
if err := ValidateRoleName(roleName); err != nil {
return merr.Status(err), nil
}
if req.Base == nil {
req.Base = &commonpb.MsgBase{}
}
req.Base.MsgType = commonpb.MsgType_CreateRole

result, err := node.rootCoord.CreateRole(ctx, req)
if err != nil {
log.Warn("fail to create role", zap.Error(err))
return merr.Status(err), nil
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil
}

Expand All @@ -5031,6 +5064,10 @@ func (node *Proxy) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest)
if err := ValidateRoleName(req.RoleName); err != nil {
return merr.Status(err), nil
}
if req.Base == nil {
req.Base = &commonpb.MsgBase{}
}
req.Base.MsgType = commonpb.MsgType_DropRole
if IsDefaultRole(req.RoleName) {
err := merr.WrapErrPrivilegeNotPermitted("the role[%s] is a default role, which can't be dropped", req.GetRoleName())
return merr.Status(err), nil
Expand All @@ -5042,6 +5079,9 @@ func (node *Proxy) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest)
zap.Error(err))
return merr.Status(err), nil
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil
}

Expand All @@ -5061,12 +5101,19 @@ func (node *Proxy) OperateUserRole(ctx context.Context, req *milvuspb.OperateUse
if err := ValidateRoleName(req.RoleName); err != nil {
return merr.Status(err), nil
}
if req.Base == nil {
req.Base = &commonpb.MsgBase{}
}
req.Base.MsgType = commonpb.MsgType_OperateUserRole

result, err := node.rootCoord.OperateUserRole(ctx, req)
if err != nil {
log.Warn("fail to operate user role", zap.Error(err))
return merr.Status(err), nil
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil
}

Expand All @@ -5088,6 +5135,10 @@ func (node *Proxy) SelectRole(ctx context.Context, req *milvuspb.SelectRoleReque
}, nil
}
}
if req.Base == nil {
req.Base = &commonpb.MsgBase{}
}
req.Base.MsgType = commonpb.MsgType_SelectRole

result, err := node.rootCoord.SelectRole(ctx, req)
if err != nil {
Expand Down Expand Up @@ -5118,6 +5169,10 @@ func (node *Proxy) SelectUser(ctx context.Context, req *milvuspb.SelectUserReque
}, nil
}
}
if req.Base == nil {
req.Base = &commonpb.MsgBase{}
}
req.Base.MsgType = commonpb.MsgType_SelectUser

result, err := node.rootCoord.SelectUser(ctx, req)
if err != nil {
Expand Down Expand Up @@ -5175,6 +5230,10 @@ func (node *Proxy) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePr
if err := node.validPrivilegeParams(req); err != nil {
return merr.Status(err), nil
}
if req.Base == nil {
req.Base = &commonpb.MsgBase{}
}
req.Base.MsgType = commonpb.MsgType_OperatePrivilege
curUser, err := GetCurUserFromContext(ctx)
if err != nil {
log.Warn("fail to get current user", zap.Error(err))
Expand Down Expand Up @@ -5202,6 +5261,9 @@ func (node *Proxy) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePr
}
}
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil
}

Expand Down Expand Up @@ -5248,6 +5310,10 @@ func (node *Proxy) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantReq
Status: merr.Status(err),
}, nil
}
if req.Base == nil {
req.Base = &commonpb.MsgBase{}
}
req.Base.MsgType = commonpb.MsgType_SelectGrant

result, err := node.rootCoord.SelectGrant(ctx, req)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3895,7 +3895,7 @@ func testProxyRole(ctx context.Context, t *testing.T, proxy *Proxy) {
resp, _ := proxy.CreateRole(ctx, &milvuspb.CreateRoleRequest{Entity: entity})
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode)

entity.Name = "unit_test"
entity.Name = "unit_test1000"
resp, _ = proxy.CreateRole(ctx, &milvuspb.CreateRoleRequest{Entity: entity})
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)

Expand Down
3 changes: 3 additions & 0 deletions internal/proxy/task_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ type alterDatabaseTask struct {
ctx context.Context
rootCoord types.RootCoordClient
result *commonpb.Status

replicateMsgStream msgstream.MsgStream
}

func (t *alterDatabaseTask) TraceCtx() context.Context {
Expand Down Expand Up @@ -291,6 +293,7 @@ func (t *alterDatabaseTask) Execute(ctx context.Context) error {
return err
}

SendReplicateMessagePack(ctx, t.replicateMsgStream, t.AlterDatabaseRequest)
t.result = ret
return nil
}
Expand Down
40 changes: 40 additions & 0 deletions internal/proxy/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -1578,6 +1578,11 @@ func SendReplicateMessagePack(ctx context.Context, replicateMsgStream msgstream.
BaseMsg: getBaseMsg(ctx, ts),
DropDatabaseRequest: r,
}
case *milvuspb.AlterDatabaseRequest:
tsMsg = &msgstream.AlterDatabaseMsg{
BaseMsg: getBaseMsg(ctx, ts),
AlterDatabaseRequest: r,
}
case *milvuspb.FlushRequest:
tsMsg = &msgstream.FlushMsg{
BaseMsg: getBaseMsg(ctx, ts),
Expand Down Expand Up @@ -1618,6 +1623,41 @@ func SendReplicateMessagePack(ctx context.Context, replicateMsgStream msgstream.
BaseMsg: getBaseMsg(ctx, ts),
AlterIndexRequest: r,
}
case *milvuspb.CreateCredentialRequest:
tsMsg = &msgstream.CreateUserMsg{
BaseMsg: getBaseMsg(ctx, ts),
CreateCredentialRequest: r,
}
case *milvuspb.UpdateCredentialRequest:
tsMsg = &msgstream.UpdateUserMsg{
BaseMsg: getBaseMsg(ctx, ts),
UpdateCredentialRequest: r,
}
case *milvuspb.DeleteCredentialRequest:
tsMsg = &msgstream.DeleteUserMsg{
BaseMsg: getBaseMsg(ctx, ts),
DeleteCredentialRequest: r,
}
case *milvuspb.CreateRoleRequest:
tsMsg = &msgstream.CreateRoleMsg{
BaseMsg: getBaseMsg(ctx, ts),
CreateRoleRequest: r,
}
case *milvuspb.DropRoleRequest:
tsMsg = &msgstream.DropRoleMsg{
BaseMsg: getBaseMsg(ctx, ts),
DropRoleRequest: r,
}
case *milvuspb.OperateUserRoleRequest:
tsMsg = &msgstream.OperateUserRoleMsg{
BaseMsg: getBaseMsg(ctx, ts),
OperateUserRoleRequest: r,
}
case *milvuspb.OperatePrivilegeRequest:
tsMsg = &msgstream.OperatePrivilegeMsg{
BaseMsg: getBaseMsg(ctx, ts),
OperatePrivilegeRequest: r,
}
default:
log.Warn("unknown request", zap.Any("request", request))
return
Expand Down
2 changes: 1 addition & 1 deletion internal/querycoordv2/meta/segment_dist_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (f *ReplicaSegDistFilter) Match(s *Segment) bool {
return f.GetCollectionID() == s.GetCollectionID() && f.Contains(s.Node)
}

func (f ReplicaSegDistFilter) AddFilter(filter *segDistCriterion) {
func (f *ReplicaSegDistFilter) AddFilter(filter *segDistCriterion) {
filter.nodes = f.GetNodes()
filter.collectionID = f.GetCollectionID()
}
Expand Down
54 changes: 54 additions & 0 deletions pkg/mq/msgstream/msg_for_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,57 @@
func (d *DropDatabaseMsg) Size() int {
return proto.Size(d.DropDatabaseRequest)
}

type AlterDatabaseMsg struct {
BaseMsg
*milvuspb.AlterDatabaseRequest
}

var _ TsMsg = &AlterDatabaseMsg{}

func (a *AlterDatabaseMsg) ID() UniqueID {
return a.Base.MsgID
}

func (a *AlterDatabaseMsg) SetID(id UniqueID) {
a.Base.MsgID = id
}

func (a *AlterDatabaseMsg) Type() MsgType {
return a.Base.MsgType
}

func (a *AlterDatabaseMsg) SourceID() int64 {
return a.Base.SourceID
}

func (a *AlterDatabaseMsg) Marshal(input TsMsg) (MarshalType, error) {
alterDataBaseMsg := input.(*AlterDatabaseMsg)
alterDatabaseRequest := alterDataBaseMsg.AlterDatabaseRequest
mb, err := proto.Marshal(alterDatabaseRequest)
if err != nil {
return nil, err

Check warning on line 163 in pkg/mq/msgstream/msg_for_database.go

View check run for this annotation

Codecov / codecov/patch

pkg/mq/msgstream/msg_for_database.go#L163

Added line #L163 was not covered by tests
}
return mb, nil
}

func (a *AlterDatabaseMsg) Unmarshal(input MarshalType) (TsMsg, error) {
alterDatabaseRequest := &milvuspb.AlterDatabaseRequest{}
in, err := convertToByteArray(input)
if err != nil {
return nil, err
}
err = proto.Unmarshal(in, alterDatabaseRequest)
if err != nil {
return nil, err

Check warning on line 176 in pkg/mq/msgstream/msg_for_database.go

View check run for this annotation

Codecov / codecov/patch

pkg/mq/msgstream/msg_for_database.go#L176

Added line #L176 was not covered by tests
}
alterDatabaseMsg := &AlterDatabaseMsg{AlterDatabaseRequest: alterDatabaseRequest}
alterDatabaseMsg.BeginTimestamp = alterDatabaseMsg.GetBase().GetTimestamp()
alterDatabaseMsg.EndTimestamp = alterDatabaseMsg.GetBase().GetTimestamp()

return alterDatabaseMsg, nil
}

func (a *AlterDatabaseMsg) Size() int {
return proto.Size(a.AlterDatabaseRequest)

Check warning on line 186 in pkg/mq/msgstream/msg_for_database.go

View check run for this annotation

Codecov / codecov/patch

pkg/mq/msgstream/msg_for_database.go#L185-L186

Added lines #L185 - L186 were not covered by tests
}
43 changes: 43 additions & 0 deletions pkg/mq/msgstream/msg_for_database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,46 @@ func TestDropDatabase(t *testing.T) {

assert.True(t, msg.Size() > 0)
}

func TestAlterDatabase(t *testing.T) {
var msg TsMsg = &AlterDatabaseMsg{
AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_AlterDatabase,
MsgID: 100,
Timestamp: 1000,
SourceID: 10000,
TargetID: 100000,
ReplicateInfo: nil,
},
DbName: "unit_db",
Properties: []*commonpb.KeyValuePair{
{
Key: "key",
Value: "value",
},
},
},
}
assert.EqualValues(t, 100, msg.ID())
msg.SetID(200)
assert.EqualValues(t, 200, msg.ID())
assert.Equal(t, commonpb.MsgType_AlterDatabase, msg.Type())
assert.EqualValues(t, 10000, msg.SourceID())

msgBytes, err := msg.Marshal(msg)
assert.NoError(t, err)

var newMsg TsMsg = &AlterDatabaseMsg{}
_, err = newMsg.Unmarshal("1")
assert.Error(t, err)

newMsg, err = newMsg.Unmarshal(msgBytes)
assert.NoError(t, err)
assert.EqualValues(t, 200, newMsg.ID())
assert.EqualValues(t, 1000, newMsg.BeginTs())
assert.EqualValues(t, 1000, newMsg.EndTs())
assert.EqualValues(t, "unit_db", newMsg.(*AlterDatabaseMsg).DbName)
assert.EqualValues(t, "key", newMsg.(*AlterDatabaseMsg).Properties[0].Key)
assert.EqualValues(t, "value", newMsg.(*AlterDatabaseMsg).Properties[0].Value)
}
Loading
Loading