From 852be152deece8a7ca4760acf2de66bc8ce3aa3b Mon Sep 17 00:00:00 2001 From: congqixia Date: Fri, 27 Oct 2023 01:08:12 +0800 Subject: [PATCH] Change task sourceID to stringer interface (#27965) Signed-off-by: Congqi Xia --- internal/querycoordv2/balance/utils.go | 8 +-- .../querycoordv2/checkers/balance_checker.go | 5 +- .../querycoordv2/checkers/channel_checker.go | 5 +- internal/querycoordv2/checkers/checker.go | 15 +---- internal/querycoordv2/checkers/controller.go | 66 +++++++++++-------- .../querycoordv2/checkers/index_checker.go | 5 +- .../querycoordv2/checkers/segment_checker.go | 5 +- internal/querycoordv2/handlers.go | 2 +- internal/querycoordv2/task/executor.go | 12 ++-- internal/querycoordv2/task/merger_test.go | 2 +- internal/querycoordv2/task/scheduler.go | 10 +-- internal/querycoordv2/task/task.go | 26 ++++---- internal/querycoordv2/task/task_test.go | 50 +++++++------- internal/querycoordv2/task/utils.go | 11 ++++ 14 files changed, 124 insertions(+), 98 deletions(-) diff --git a/internal/querycoordv2/balance/utils.go b/internal/querycoordv2/balance/utils.go index 16d5ede8d7cfb..9ba66902ddb7c 100644 --- a/internal/querycoordv2/balance/utils.go +++ b/internal/querycoordv2/balance/utils.go @@ -34,7 +34,7 @@ const ( DistInfoPrefix = "Balance-Dists:" ) -func CreateSegmentTasksFromPlans(ctx context.Context, checkerID int64, timeout time.Duration, plans []SegmentAssignPlan) []task.Task { +func CreateSegmentTasksFromPlans(ctx context.Context, source task.Source, timeout time.Duration, plans []SegmentAssignPlan) []task.Task { ret := make([]task.Task, 0) for _, p := range plans { actions := make([]task.Action, 0) @@ -49,7 +49,7 @@ func CreateSegmentTasksFromPlans(ctx context.Context, checkerID int64, timeout t t, err := task.NewSegmentTask( ctx, timeout, - checkerID, + source, p.Segment.GetCollectionID(), p.ReplicaID, actions..., @@ -86,7 +86,7 @@ func CreateSegmentTasksFromPlans(ctx context.Context, checkerID int64, timeout t return ret } -func CreateChannelTasksFromPlans(ctx context.Context, checkerID int64, timeout time.Duration, plans []ChannelAssignPlan) []task.Task { +func CreateChannelTasksFromPlans(ctx context.Context, source task.Source, timeout time.Duration, plans []ChannelAssignPlan) []task.Task { ret := make([]task.Task, 0, len(plans)) for _, p := range plans { actions := make([]task.Action, 0) @@ -98,7 +98,7 @@ func CreateChannelTasksFromPlans(ctx context.Context, checkerID int64, timeout t action := task.NewChannelAction(p.From, task.ActionTypeReduce, p.Channel.GetChannelName()) actions = append(actions, action) } - t, err := task.NewChannelTask(ctx, timeout, checkerID, p.Channel.GetCollectionID(), p.ReplicaID, actions...) + t, err := task.NewChannelTask(ctx, timeout, source, p.Channel.GetCollectionID(), p.ReplicaID, actions...) if err != nil { log.Warn("create channel task failed", zap.Int64("collection", p.Channel.GetCollectionID()), diff --git a/internal/querycoordv2/checkers/balance_checker.go b/internal/querycoordv2/checkers/balance_checker.go index 8444392fd67a3..a76e119002ab0 100644 --- a/internal/querycoordv2/checkers/balance_checker.go +++ b/internal/querycoordv2/checkers/balance_checker.go @@ -36,7 +36,6 @@ import ( // BalanceChecker checks the cluster distribution and generates balance tasks. type BalanceChecker struct { - baseChecker balance.Balance meta *meta.Meta nodeManager *session.NodeManager @@ -54,6 +53,10 @@ func NewBalanceChecker(meta *meta.Meta, balancer balance.Balance, nodeMgr *sessi } } +func (b *BalanceChecker) ID() task.Source { + return balanceChecker +} + func (b *BalanceChecker) Description() string { return "BalanceChecker checks the cluster distribution and generates balance tasks" } diff --git a/internal/querycoordv2/checkers/channel_checker.go b/internal/querycoordv2/checkers/channel_checker.go index 046cbc45fc243..3c6fbd0c71cac 100644 --- a/internal/querycoordv2/checkers/channel_checker.go +++ b/internal/querycoordv2/checkers/channel_checker.go @@ -34,7 +34,6 @@ import ( // TODO(sunby): have too much similar codes with SegmentChecker type ChannelChecker struct { - baseChecker meta *meta.Meta dist *meta.DistributionManager targetMgr *meta.TargetManager @@ -55,6 +54,10 @@ func NewChannelChecker( } } +func (c *ChannelChecker) ID() task.Source { + return channelChecker +} + func (c *ChannelChecker) Description() string { return "DmChannelChecker checks the lack of DmChannels, or some DmChannels are redundant" } diff --git a/internal/querycoordv2/checkers/checker.go b/internal/querycoordv2/checkers/checker.go index 8386e70f71a4d..5c5ea5ad4a473 100644 --- a/internal/querycoordv2/checkers/checker.go +++ b/internal/querycoordv2/checkers/checker.go @@ -23,20 +23,7 @@ import ( ) type Checker interface { - ID() int64 - SetID(id int64) + ID() task.Source Description() string Check(ctx context.Context) []task.Task } - -type baseChecker struct { - id int64 -} - -func (checker *baseChecker) ID() int64 { - return checker.id -} - -func (checker *baseChecker) SetID(id int64) { - checker.id = id -} diff --git a/internal/querycoordv2/checkers/controller.go b/internal/querycoordv2/checkers/controller.go index b238b16255dcf..dc65482d76001 100644 --- a/internal/querycoordv2/checkers/controller.go +++ b/internal/querycoordv2/checkers/controller.go @@ -32,20 +32,39 @@ import ( ) const ( - segmentChecker = "segment_checker" - channelChecker = "channel_checker" - balanceChecker = "balance_checker" - indexChecker = "index_checker" + segmentCheckerName = "segment_checker" + channelCheckerName = "channel_checker" + balanceCheckerName = "balance_checker" + indexCheckerName = "index_checker" +) + +type checkerType int32 + +const ( + channelChecker checkerType = iota + 1 + segmentChecker + balanceChecker + indexChecker ) var ( checkRoundTaskNumLimit = 256 - checkerOrder = []string{channelChecker, segmentChecker, balanceChecker, indexChecker} + checkerOrder = []string{channelCheckerName, segmentCheckerName, balanceCheckerName, indexCheckerName} + checkerNames = map[checkerType]string{ + segmentChecker: segmentCheckerName, + channelChecker: channelCheckerName, + balanceChecker: balanceCheckerName, + indexChecker: indexCheckerName, + } ) +func (s checkerType) String() string { + return checkerNames[s] +} + type CheckerController struct { cancel context.CancelFunc - manualCheckChs map[string]chan struct{} + manualCheckChs map[checkerType]chan struct{} meta *meta.Meta dist *meta.DistributionManager targetMgr *meta.TargetManager @@ -54,7 +73,7 @@ type CheckerController struct { balancer balance.Balance scheduler task.Scheduler - checkers map[string]Checker + checkers map[checkerType]Checker stopOnce sync.Once } @@ -70,19 +89,14 @@ func NewCheckerController( ) *CheckerController { // CheckerController runs checkers with the order, // the former checker has higher priority - checkers := map[string]Checker{ + checkers := map[checkerType]Checker{ channelChecker: NewChannelChecker(meta, dist, targetMgr, balancer), segmentChecker: NewSegmentChecker(meta, dist, targetMgr, balancer, nodeMgr), balanceChecker: NewBalanceChecker(meta, balancer, nodeMgr, scheduler), indexChecker: NewIndexChecker(meta, dist, broker), } - id := 0 - for _, checkerName := range checkerOrder { - checkers[checkerName].SetID(int64(id + 1)) - } - - manualCheckChs := map[string]chan struct{}{ + manualCheckChs := map[checkerType]chan struct{}{ channelChecker: make(chan struct{}, 1), segmentChecker: make(chan struct{}, 1), balanceChecker: make(chan struct{}, 1), @@ -103,13 +117,13 @@ func (controller *CheckerController) Start() { ctx, cancel := context.WithCancel(context.Background()) controller.cancel = cancel - for checkerType := range controller.checkers { - go controller.startChecker(ctx, checkerType) + for checker := range controller.checkers { + go controller.startChecker(ctx, checker) } } -func getCheckerInterval(checkerType string) time.Duration { - switch checkerType { +func getCheckerInterval(checker checkerType) time.Duration { + switch checker { case segmentChecker: return Params.QueryCoordCfg.SegmentCheckInterval.GetAsDuration(time.Millisecond) case channelChecker: @@ -123,8 +137,8 @@ func getCheckerInterval(checkerType string) time.Duration { } } -func (controller *CheckerController) startChecker(ctx context.Context, checkerType string) { - interval := getCheckerInterval(checkerType) +func (controller *CheckerController) startChecker(ctx context.Context, checker checkerType) { + interval := getCheckerInterval(checker) ticker := time.NewTicker(interval) defer ticker.Stop() @@ -132,15 +146,15 @@ func (controller *CheckerController) startChecker(ctx context.Context, checkerTy select { case <-ctx.Done(): log.Info("Checker stopped", - zap.String("type", checkerType)) + zap.String("type", checker.String())) return case <-ticker.C: - controller.check(ctx, checkerType) + controller.check(ctx, checker) - case <-controller.manualCheckChs[checkerType]: + case <-controller.manualCheckChs[checker]: ticker.Stop() - controller.check(ctx, checkerType) + controller.check(ctx, checker) ticker.Reset(interval) } } @@ -164,8 +178,8 @@ func (controller *CheckerController) Check() { } // check is the real implementation of Check -func (controller *CheckerController) check(ctx context.Context, checkerType string) { - checker := controller.checkers[checkerType] +func (controller *CheckerController) check(ctx context.Context, checkType checkerType) { + checker := controller.checkers[checkType] tasks := checker.Check(ctx) for _, task := range tasks { diff --git a/internal/querycoordv2/checkers/index_checker.go b/internal/querycoordv2/checkers/index_checker.go index 943582332bc15..764d2f3e5b9a5 100644 --- a/internal/querycoordv2/checkers/index_checker.go +++ b/internal/querycoordv2/checkers/index_checker.go @@ -35,7 +35,6 @@ var _ Checker = (*IndexChecker)(nil) // IndexChecker perform segment index check. type IndexChecker struct { - baseChecker meta *meta.Meta dist *meta.DistributionManager broker meta.Broker @@ -53,6 +52,10 @@ func NewIndexChecker( } } +func (c *IndexChecker) ID() task.Source { + return indexChecker +} + func (c *IndexChecker) Description() string { return "SegmentChecker checks index state change of segments and generates load index task" } diff --git a/internal/querycoordv2/checkers/segment_checker.go b/internal/querycoordv2/checkers/segment_checker.go index c971886e14335..7392c46414e14 100644 --- a/internal/querycoordv2/checkers/segment_checker.go +++ b/internal/querycoordv2/checkers/segment_checker.go @@ -36,7 +36,6 @@ import ( ) type SegmentChecker struct { - baseChecker meta *meta.Meta dist *meta.DistributionManager targetMgr *meta.TargetManager @@ -60,6 +59,10 @@ func NewSegmentChecker( } } +func (c *SegmentChecker) ID() task.Source { + return segmentChecker +} + func (c *SegmentChecker) Description() string { return "SegmentChecker checks the lack of segments, or some segments are redundant" } diff --git a/internal/querycoordv2/handlers.go b/internal/querycoordv2/handlers.go index 56bbec48b0931..1f431c6274c21 100644 --- a/internal/querycoordv2/handlers.go +++ b/internal/querycoordv2/handlers.go @@ -142,7 +142,7 @@ func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRe ) task, err := task.NewSegmentTask(ctx, Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), - req.GetBase().GetMsgID(), + task.WrapIDSource(req.GetBase().GetMsgID()), req.GetCollectionID(), replica.GetID(), task.NewSegmentActionWithScope(plan.To, task.ActionTypeGrow, plan.Segment.GetInsertChannel(), plan.Segment.GetID(), querypb.DataScope_Historical), diff --git a/internal/querycoordv2/task/executor.go b/internal/querycoordv2/task/executor.go index 0d1e5ccdd397c..b5e2b508a3756 100644 --- a/internal/querycoordv2/task/executor.go +++ b/internal/querycoordv2/task/executor.go @@ -104,7 +104,7 @@ func (ex *Executor) Execute(task Task, step int) bool { zap.Int64("collectionID", task.CollectionID()), zap.Int64("replicaID", task.ReplicaID()), zap.Int("step", step), - zap.Int64("source", task.SourceID()), + zap.String("source", task.Source().String()), ) go func() { @@ -169,7 +169,7 @@ func (ex *Executor) processMergeTask(mergeTask *LoadSegmentsTask) { zap.String("shard", task.Shard()), zap.Int64s("segmentIDs", segments), zap.Int64("nodeID", action.Node()), - zap.Int64("source", task.SourceID()), + zap.String("source", task.Source().String()), ) // Get shard leader for the given replica and segment @@ -231,7 +231,7 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error { zap.Int64("replicaID", task.ReplicaID()), zap.Int64("segmentID", task.segmentID), zap.Int64("node", action.Node()), - zap.Int64("source", task.SourceID()), + zap.String("source", task.Source().String()), ) var err error @@ -312,7 +312,7 @@ func (ex *Executor) releaseSegment(task *SegmentTask, step int) { zap.Int64("replicaID", task.ReplicaID()), zap.Int64("segmentID", task.segmentID), zap.Int64("node", action.Node()), - zap.Int64("source", task.SourceID()), + zap.String("source", task.Source().String()), ) ctx := task.Context() @@ -384,7 +384,7 @@ func (ex *Executor) subDmChannel(task *ChannelTask, step int) error { zap.Int64("replicaID", task.ReplicaID()), zap.String("channel", task.Channel()), zap.Int64("node", action.Node()), - zap.Int64("source", task.SourceID()), + zap.String("source", task.Source().String()), ) var err error @@ -467,7 +467,7 @@ func (ex *Executor) unsubDmChannel(task *ChannelTask, step int) error { zap.Int64("replicaID", task.ReplicaID()), zap.String("channel", task.Channel()), zap.Int64("node", action.Node()), - zap.Int64("source", task.SourceID()), + zap.String("source", task.Source().String()), ) var err error diff --git a/internal/querycoordv2/task/merger_test.go b/internal/querycoordv2/task/merger_test.go index 366270234db55..9f376649026ab 100644 --- a/internal/querycoordv2/task/merger_test.go +++ b/internal/querycoordv2/task/merger_test.go @@ -122,7 +122,7 @@ func (suite *MergerSuite) TestMerge() { ctx := context.Background() for segmentID := int64(1); segmentID <= 3; segmentID++ { - task, err := NewSegmentTask(ctx, timeout, 0, suite.collectionID, suite.replicaID, + task, err := NewSegmentTask(ctx, timeout, WrapIDSource(0), suite.collectionID, suite.replicaID, NewSegmentAction(suite.nodeID, ActionTypeGrow, "", segmentID)) suite.NoError(err) suite.merger.Add(NewLoadSegmentsTask(task, 0, suite.requests[segmentID])) diff --git a/internal/querycoordv2/task/scheduler.go b/internal/querycoordv2/task/scheduler.go index 37bb22627e963..ad93a2c7ed16d 100644 --- a/internal/querycoordv2/task/scheduler.go +++ b/internal/querycoordv2/task/scheduler.go @@ -421,7 +421,7 @@ func (scheduler *taskScheduler) promote(task Task) error { zap.Int64("taskID", task.ID()), zap.Int64("collectionID", task.CollectionID()), zap.Int64("replicaID", task.ReplicaID()), - zap.Int64("source", task.SourceID()), + zap.String("source", task.Source().String()), ) if err := scheduler.check(task); err != nil { @@ -643,7 +643,7 @@ func (scheduler *taskScheduler) process(task Task) bool { zap.Int64("collectionID", task.CollectionID()), zap.Int64("replicaID", task.ReplicaID()), zap.String("type", GetTaskType(task).String()), - zap.Int64("source", task.SourceID()), + zap.String("source", task.Source().String()), ) actions, step := task.Actions(), task.Step() @@ -733,7 +733,7 @@ func (scheduler *taskScheduler) checkStale(task Task) error { zap.Int64("taskID", task.ID()), zap.Int64("collectionID", task.CollectionID()), zap.Int64("replicaID", task.ReplicaID()), - zap.Int64("source", task.SourceID()), + zap.String("source", task.Source().String()), ) switch task := task.(type) { @@ -770,7 +770,7 @@ func (scheduler *taskScheduler) checkSegmentTaskStale(task *SegmentTask) error { zap.Int64("taskID", task.ID()), zap.Int64("collectionID", task.CollectionID()), zap.Int64("replicaID", task.ReplicaID()), - zap.Int64("source", task.SourceID()), + zap.String("source", task.Source().String()), ) for _, action := range task.Actions() { @@ -814,7 +814,7 @@ func (scheduler *taskScheduler) checkChannelTaskStale(task *ChannelTask) error { zap.Int64("taskID", task.ID()), zap.Int64("collectionID", task.CollectionID()), zap.Int64("replicaID", task.ReplicaID()), - zap.Int64("source", task.SourceID()), + zap.String("source", task.Source().String()), ) for _, action := range task.Actions() { diff --git a/internal/querycoordv2/task/task.go b/internal/querycoordv2/task/task.go index 9b7b507f5ae5f..f2d4446337ef4 100644 --- a/internal/querycoordv2/task/task.go +++ b/internal/querycoordv2/task/task.go @@ -64,9 +64,11 @@ func (p Priority) String() string { // All task priorities from low to high var TaskPriorities = []Priority{TaskPriorityLow, TaskPriorityNormal, TaskPriorityHigh} +type Source fmt.Stringer + type Task interface { Context() context.Context - SourceID() UniqueID + Source() Source ID() UniqueID CollectionID() UniqueID ReplicaID() UniqueID @@ -98,13 +100,13 @@ type baseTask struct { doneCh chan struct{} canceled *atomic.Bool - sourceID UniqueID // RequestID id UniqueID // Set by scheduler collectionID UniqueID replicaID UniqueID shard string loadType querypb.LoadType + source Source status *atomic.Int32 priority Priority err error @@ -116,12 +118,12 @@ type baseTask struct { span trace.Span } -func newBaseTask(ctx context.Context, sourceID, collectionID, replicaID UniqueID, shard string) *baseTask { +func newBaseTask(ctx context.Context, source Source, collectionID, replicaID UniqueID, shard string) *baseTask { ctx, cancel := context.WithCancel(ctx) ctx, span := otel.Tracer("QueryCoord").Start(ctx, "QueryCoord-BaseTask") return &baseTask{ - sourceID: sourceID, + source: source, collectionID: collectionID, replicaID: replicaID, shard: shard, @@ -140,8 +142,8 @@ func (task *baseTask) Context() context.Context { return task.ctx } -func (task *baseTask) SourceID() UniqueID { - return task.sourceID +func (task *baseTask) Source() Source { + return task.source } func (task *baseTask) ID() UniqueID { @@ -260,10 +262,10 @@ func (task *baseTask) String() string { } } return fmt.Sprintf( - "[id=%d] [type=%s] [source=%d] [reason=%s] [collectionID=%d] [replicaID=%d] [priority=%s] [actionsCount=%d] [actions=%s]", + "[id=%d] [type=%s] [source=%s] [reason=%s] [collectionID=%d] [replicaID=%d] [priority=%s] [actionsCount=%d] [actions=%s]", task.id, GetTaskType(task).String(), - task.sourceID, + task.source.String(), task.reason, task.collectionID, task.replicaID, @@ -284,7 +286,7 @@ type SegmentTask struct { // empty actions is not allowed func NewSegmentTask(ctx context.Context, timeout time.Duration, - sourceID, + source Source, collectionID, replicaID UniqueID, actions ...Action, @@ -308,7 +310,7 @@ func NewSegmentTask(ctx context.Context, } } - base := newBaseTask(ctx, sourceID, collectionID, replicaID, shard) + base := newBaseTask(ctx, source, collectionID, replicaID, shard) base.actions = actions return &SegmentTask{ baseTask: base, @@ -341,7 +343,7 @@ type ChannelTask struct { // empty actions is not allowed func NewChannelTask(ctx context.Context, timeout time.Duration, - sourceID, + source Source, collectionID, replicaID UniqueID, actions ...Action, @@ -363,7 +365,7 @@ func NewChannelTask(ctx context.Context, } } - base := newBaseTask(ctx, sourceID, collectionID, replicaID, channel) + base := newBaseTask(ctx, source, collectionID, replicaID, channel) base.actions = actions return &ChannelTask{ baseTask: base, diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go index a682ccffe97e0..601e586e24adf 100644 --- a/internal/querycoordv2/task/task_test.go +++ b/internal/querycoordv2/task/task_test.go @@ -242,7 +242,7 @@ func (suite *TaskSuite) TestSubscribeChannelTask() { task, err := NewChannelTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewChannelAction(targetNode, ActionTypeGrow, channel), @@ -291,7 +291,7 @@ func (suite *TaskSuite) TestSubmitDuplicateSubscribeChannelTask() { task, err := NewChannelTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewChannelAction(targetNode, ActionTypeGrow, channel), @@ -336,7 +336,7 @@ func (suite *TaskSuite) TestUnsubscribeChannelTask() { task, err := NewChannelTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, -1, NewChannelAction(targetNode, ActionTypeReduce, channel), @@ -426,7 +426,7 @@ func (suite *TaskSuite) TestLoadSegmentTask() { task, err := NewSegmentTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewSegmentAction(targetNode, ActionTypeGrow, channel.GetChannelName(), segment), @@ -522,7 +522,7 @@ func (suite *TaskSuite) TestLoadSegmentTaskNotIndex() { task, err := NewSegmentTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewSegmentAction(targetNode, ActionTypeGrow, channel.GetChannelName(), segment), @@ -612,7 +612,7 @@ func (suite *TaskSuite) TestLoadSegmentTaskFailed() { task, err := NewSegmentTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewSegmentAction(targetNode, ActionTypeGrow, channel.GetChannelName(), segment), @@ -677,7 +677,7 @@ func (suite *TaskSuite) TestReleaseSegmentTask() { task, err := NewSegmentTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewSegmentAction(targetNode, ActionTypeReduce, channel.GetChannelName(), segment), @@ -721,7 +721,7 @@ func (suite *TaskSuite) TestReleaseGrowingSegmentTask() { task, err := NewSegmentTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewSegmentActionWithScope(targetNode, ActionTypeReduce, "", segment, querypb.DataScope_Streaming), @@ -827,7 +827,7 @@ func (suite *TaskSuite) TestMoveSegmentTask() { task, err := NewSegmentTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewSegmentAction(targetNode, ActionTypeGrow, channel.GetChannelName(), segment), @@ -911,7 +911,7 @@ func (suite *TaskSuite) TestMoveSegmentTaskStale() { task, err := NewSegmentTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewSegmentAction(targetNode, ActionTypeGrow, channel.GetChannelName(), segment), @@ -986,7 +986,7 @@ func (suite *TaskSuite) TestTaskCanceled() { task, err := NewSegmentTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewSegmentAction(targetNode, ActionTypeGrow, channel.GetChannelName(), segment), @@ -1074,7 +1074,7 @@ func (suite *TaskSuite) TestSegmentTaskStale() { task, err := NewSegmentTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewSegmentAction(targetNode, ActionTypeGrow, channel.GetChannelName(), segment), @@ -1148,7 +1148,7 @@ func (suite *TaskSuite) TestChannelTaskReplace() { task, err := NewChannelTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewChannelAction(targetNode, ActionTypeGrow, channel), @@ -1165,7 +1165,7 @@ func (suite *TaskSuite) TestChannelTaskReplace() { task, err := NewChannelTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewChannelAction(targetNode, ActionTypeGrow, channel), @@ -1184,7 +1184,7 @@ func (suite *TaskSuite) TestChannelTaskReplace() { task, err := NewChannelTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewChannelAction(targetNode, ActionTypeGrow, channel), @@ -1199,34 +1199,34 @@ func (suite *TaskSuite) TestChannelTaskReplace() { } func (suite *TaskSuite) TestCreateTaskBehavior() { - chanelTask, err := NewChannelTask(context.TODO(), 5*time.Second, 0, 0, 0) + chanelTask, err := NewChannelTask(context.TODO(), 5*time.Second, WrapIDSource(0), 0, 0) suite.ErrorIs(err, merr.ErrParameterInvalid) suite.Nil(chanelTask) action := NewSegmentAction(0, 0, "", 0) - chanelTask, err = NewChannelTask(context.TODO(), 5*time.Second, 0, 0, 0, action) + chanelTask, err = NewChannelTask(context.TODO(), 5*time.Second, WrapIDSource(0), 0, 0, action) suite.ErrorIs(err, merr.ErrParameterInvalid) suite.Nil(chanelTask) action1 := NewChannelAction(0, 0, "fake-channel1") action2 := NewChannelAction(0, 0, "fake-channel2") - chanelTask, err = NewChannelTask(context.TODO(), 5*time.Second, 0, 0, 0, action1, action2) + chanelTask, err = NewChannelTask(context.TODO(), 5*time.Second, WrapIDSource(0), 0, 0, action1, action2) suite.ErrorIs(err, merr.ErrParameterInvalid) suite.Nil(chanelTask) - segmentTask, err := NewSegmentTask(context.TODO(), 5*time.Second, 0, 0, 0) + segmentTask, err := NewSegmentTask(context.TODO(), 5*time.Second, WrapIDSource(0), 0, 0) suite.ErrorIs(err, merr.ErrParameterInvalid) suite.Nil(segmentTask) channelAction := NewChannelAction(0, 0, "fake-channel1") - segmentTask, err = NewSegmentTask(context.TODO(), 5*time.Second, 0, 0, 0, channelAction) + segmentTask, err = NewSegmentTask(context.TODO(), 5*time.Second, WrapIDSource(0), 0, 0, channelAction) suite.ErrorIs(err, merr.ErrParameterInvalid) suite.Nil(segmentTask) segmentAction1 := NewSegmentAction(0, 0, "", 0) segmentAction2 := NewSegmentAction(0, 0, "", 1) - segmentTask, err = NewSegmentTask(context.TODO(), 5*time.Second, 0, 0, 0, segmentAction1, segmentAction2) + segmentTask, err = NewSegmentTask(context.TODO(), 5*time.Second, WrapIDSource(0), 0, 0, segmentAction1, segmentAction2) suite.ErrorIs(err, merr.ErrParameterInvalid) suite.Nil(segmentTask) } @@ -1240,7 +1240,7 @@ func (suite *TaskSuite) TestSegmentTaskReplace() { task, err := NewSegmentTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewSegmentAction(targetNode, ActionTypeGrow, "", segment), @@ -1257,7 +1257,7 @@ func (suite *TaskSuite) TestSegmentTaskReplace() { task, err := NewSegmentTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewSegmentAction(targetNode, ActionTypeGrow, "", segment), @@ -1276,7 +1276,7 @@ func (suite *TaskSuite) TestSegmentTaskReplace() { task, err := NewSegmentTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewSegmentAction(targetNode, ActionTypeGrow, "", segment), @@ -1317,7 +1317,7 @@ func (suite *TaskSuite) TestNoExecutor() { task, err := NewSegmentTask( ctx, timeout, - 0, + WrapIDSource(0), suite.collection, suite.replica, NewSegmentAction(targetNode, ActionTypeGrow, channel.GetChannelName(), segment), diff --git a/internal/querycoordv2/task/utils.go b/internal/querycoordv2/task/utils.go index 3d95c3903e853..f9ee116745ce8 100644 --- a/internal/querycoordv2/task/utils.go +++ b/internal/querycoordv2/task/utils.go @@ -36,6 +36,17 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) +// idSource helper type for using id as task source +type idSource int64 + +func (s idSource) String() string { + return fmt.Sprintf("ID-%d", s) +} + +func WrapIDSource(id int64) Source { + return idSource(id) +} + func Wait(ctx context.Context, timeout time.Duration, tasks ...Task) error { ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel()