Skip to content

Commit

Permalink
enhance: refine pular related mq interfaces (#38007)
Browse files Browse the repository at this point in the history
issue: #35917 
Refines the pulsar-related mq APIs to allow the ctx to be passed down

Signed-off-by: tinswzy <[email protected]>
  • Loading branch information
tinswzy authored Dec 4, 2024
1 parent 73aa95f commit 5768dbb
Show file tree
Hide file tree
Showing 50 changed files with 380 additions and 367 deletions.
2 changes: 1 addition & 1 deletion internal/datacoord/compaction_trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ func (t *compactionTrigger) handleSignal(signal *compactionSignal) {
return
}

segment := t.meta.GetHealthySegment(t.meta.ctx, signal.segmentID)
segment := t.meta.GetHealthySegment(context.TODO(), signal.segmentID)
if segment == nil {
log.Warn("segment in compaction signal not found in meta", zap.Int64("segmentID", signal.segmentID))
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (mtm *mockTtMsgStream) Chan() <-chan *msgstream.MsgPack {
return make(chan *msgstream.MsgPack, 100)
}

func (mtm *mockTtMsgStream) AsProducer(channels []string) {}
func (mtm *mockTtMsgStream) AsProducer(ctx context.Context, channels []string) {}

func (mtm *mockTtMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error {
return nil
Expand All @@ -80,11 +80,11 @@ func (mtm *mockTtMsgStream) GetProduceChannels() []string {
return make([]string, 0)
}

func (mtm *mockTtMsgStream) Produce(*msgstream.MsgPack) error {
func (mtm *mockTtMsgStream) Produce(context.Context, *msgstream.MsgPack) error {
return nil
}

func (mtm *mockTtMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
func (mtm *mockTtMsgStream) Broadcast(context.Context, *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
return nil, nil
}

Expand Down
18 changes: 9 additions & 9 deletions internal/proxy/channels_mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import (
type channelsMgr interface {
getChannels(collectionID UniqueID) ([]pChan, error)
getVChannels(collectionID UniqueID) ([]vChan, error)
getOrCreateDmlStream(collectionID UniqueID) (msgstream.MsgStream, error)
getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error)
removeDMLStream(collectionID UniqueID)
removeAllDMLStream()
}
Expand Down Expand Up @@ -172,7 +172,7 @@ func (mgr *singleTypeChannelsMgr) streamExistPrivate(collectionID UniqueID) bool
return ok && streamInfos.stream != nil
}

func createStream(factory msgstream.Factory, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) {
func createStream(ctx context.Context, factory msgstream.Factory, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) {
var stream msgstream.MsgStream
var err error

Expand All @@ -181,7 +181,7 @@ func createStream(factory msgstream.Factory, pchans []pChan, repack repackFuncTy
return nil, err
}

stream.AsProducer(pchans)
stream.AsProducer(ctx, pchans)
if repack != nil {
stream.SetRepackFunc(repack)
}
Expand All @@ -202,7 +202,7 @@ func decPChanMetrics(pchans []pChan) {

// createMsgStream create message stream for specified collection. Idempotent.
// If stream already exists, directly return it and no error will be returned.
func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) (msgstream.MsgStream, error) {
func (mgr *singleTypeChannelsMgr) createMsgStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
mgr.mu.RLock()
infos, ok := mgr.infos[collectionID]
if ok && infos.stream != nil {
Expand All @@ -219,7 +219,7 @@ func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) (msgstr
return nil, err
}

stream, err := createStream(mgr.msgStreamFactory, channelInfos.pchans, mgr.repackFunc)
stream, err := createStream(ctx, mgr.msgStreamFactory, channelInfos.pchans, mgr.repackFunc)
if err != nil {
// What if stream created by other goroutines?
log.Error("failed to create message stream", zap.Error(err), zap.Int64("collection", collectionID))
Expand Down Expand Up @@ -253,12 +253,12 @@ func (mgr *singleTypeChannelsMgr) lockGetStream(collectionID UniqueID) (msgstrea

// getOrCreateStream get message stream of specified collection.
// If stream doesn't exist, call createMsgStream to create for it.
func (mgr *singleTypeChannelsMgr) getOrCreateStream(collectionID UniqueID) (msgstream.MsgStream, error) {
func (mgr *singleTypeChannelsMgr) getOrCreateStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
if stream, err := mgr.lockGetStream(collectionID); err == nil {
return stream, nil
}

return mgr.createMsgStream(collectionID)
return mgr.createMsgStream(ctx, collectionID)
}

// removeStream remove the corresponding stream of the specified collection. Idempotent.
Expand Down Expand Up @@ -315,8 +315,8 @@ func (mgr *channelsMgrImpl) getVChannels(collectionID UniqueID) ([]vChan, error)
return mgr.dmlChannelsMgr.getVChannels(collectionID)
}

func (mgr *channelsMgrImpl) getOrCreateDmlStream(collectionID UniqueID) (msgstream.MsgStream, error) {
return mgr.dmlChannelsMgr.getOrCreateStream(collectionID)
func (mgr *channelsMgrImpl) getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
return mgr.dmlChannelsMgr.getOrCreateStream(ctx, collectionID)
}

func (mgr *channelsMgrImpl) removeDMLStream(collectionID UniqueID) {
Expand Down
24 changes: 12 additions & 12 deletions internal/proxy/channels_mgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func Test_createStream(t *testing.T) {
factory.fQStream = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock")
}
_, err := createStream(factory, nil, nil)
_, err := createStream(context.TODO(), factory, nil, nil)
assert.Error(t, err)
})

Expand All @@ -223,7 +223,7 @@ func Test_createStream(t *testing.T) {
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock")
}
_, err := createStream(factory, nil, nil)
_, err := createStream(context.TODO(), factory, nil, nil)
assert.Error(t, err)
})

Expand All @@ -232,7 +232,7 @@ func Test_createStream(t *testing.T) {
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return newMockMsgStream(), nil
}
_, err := createStream(factory, []string{"111"}, func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
_, err := createStream(context.TODO(), factory, []string{"111"}, func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
return nil, nil
})
assert.NoError(t, err)
Expand All @@ -247,7 +247,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
100: {stream: newMockMsgStream()},
},
}
stream, err := m.createMsgStream(100)
stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
Expand Down Expand Up @@ -275,7 +275,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
stream, err := m.createMsgStream(100)
stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
}()
Expand All @@ -295,7 +295,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
return channelInfos{}, errors.New("mock")
},
}
_, err := m.createMsgStream(100)
_, err := m.createMsgStream(context.TODO(), 100)
assert.Error(t, err)
})

Expand All @@ -311,7 +311,7 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
msgStreamFactory: factory,
repackFunc: nil,
}
_, err := m.createMsgStream(100)
_, err := m.createMsgStream(context.TODO(), 100)
assert.Error(t, err)
})

Expand All @@ -328,10 +328,10 @@ func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
msgStreamFactory: factory,
repackFunc: nil,
}
stream, err := m.createMsgStream(100)
stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
stream, err = m.getOrCreateStream(100)
stream, err = m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
Expand Down Expand Up @@ -365,7 +365,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
100: {stream: newMockMsgStream()},
},
}
stream, err := m.getOrCreateStream(100)
stream, err := m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
Expand All @@ -377,7 +377,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
return channelInfos{}, errors.New("mock")
},
}
_, err := m.getOrCreateStream(100)
_, err := m.getOrCreateStream(context.TODO(), 100)
assert.Error(t, err)
})

Expand All @@ -394,7 +394,7 @@ func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
msgStreamFactory: factory,
repackFunc: nil,
}
stream, err := m.getOrCreateStream(100)
stream, err := m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
Expand Down
2 changes: 1 addition & 1 deletion internal/proxy/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -6323,7 +6323,7 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate
Status: merr.Status(err),
}, nil
}
messageIDsMap, err := msgStream.Broadcast(msgPack)
messageIDsMap, err := msgStream.Broadcast(ctx, msgPack)
if err != nil {
log.Ctx(ctx).Warn("failed to produce msg", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
Expand Down
16 changes: 8 additions & 8 deletions internal/proxy/impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ func TestProxy_FlushAll_DbCollection(t *testing.T) {
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err)
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel})
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})

Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000")
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
Expand Down Expand Up @@ -483,7 +483,7 @@ func TestProxy_FlushAll(t *testing.T) {
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err)
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel})
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})

Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000")
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
Expand Down Expand Up @@ -955,7 +955,7 @@ func TestProxyCreateDatabase(t *testing.T) {
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err)
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel})
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})

t.Run("create database fail", func(t *testing.T) {
rc := mocks.NewMockRootCoordClient(t)
Expand Down Expand Up @@ -1015,7 +1015,7 @@ func TestProxyDropDatabase(t *testing.T) {
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err)
node.replicateMsgStream.AsProducer([]string{rpcRequestChannel})
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})

t.Run("drop database fail", func(t *testing.T) {
rc := mocks.NewMockRootCoordClient(t)
Expand Down Expand Up @@ -1496,13 +1496,13 @@ func TestProxy_ReplicateMessage(t *testing.T) {
factory := newMockMsgStreamFactory()
msgStreamObj := msgstream.NewMockMsgStream(t)
msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return()
msgStreamObj.EXPECT().AsProducer(mock.Anything).Return()
msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return()
msgStreamObj.EXPECT().EnableProduce(mock.Anything).Return()
msgStreamObj.EXPECT().Close().Return()
mockMsgID1 := mqcommon.NewMockMessageID(t)
mockMsgID2 := mqcommon.NewMockMessageID(t)
mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2"))
broadcastMock := msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(map[string][]mqcommon.MessageID{
broadcastMock := msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
"unit_test_replicate_message": {mockMsgID1, mockMsgID2},
}, nil)

Expand Down Expand Up @@ -1581,7 +1581,7 @@ func TestProxy_ReplicateMessage(t *testing.T) {

{
broadcastMock.Unset()
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(nil, errors.New("mock error: broadcast"))
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: broadcast"))
resp, err := node.ReplicateMessage(context.TODO(), replicateRequest)
assert.NoError(t, err)
assert.NotEqualValues(t, 0, resp.GetStatus().GetCode())
Expand All @@ -1590,7 +1590,7 @@ func TestProxy_ReplicateMessage(t *testing.T) {
}
{
broadcastMock.Unset()
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(map[string][]mqcommon.MessageID{
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
"unit_test_replicate_message": {},
}, nil)
resp, err := node.ReplicateMessage(context.TODO(), replicateRequest)
Expand Down
31 changes: 17 additions & 14 deletions internal/proxy/mock_channels_manager.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion internal/proxy/mock_msgstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type mockMsgStream struct {
enableProduce func(bool)
}

func (m *mockMsgStream) AsProducer(producers []string) {
func (m *mockMsgStream) AsProducer(ctx context.Context, producers []string) {
if m.asProducer != nil {
m.asProducer(producers)
}
Expand Down
6 changes: 3 additions & 3 deletions internal/proxy/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.MsgPack {
return ms.msgChan
}

func (ms *simpleMockMsgStream) AsProducer(channels []string) {
func (ms *simpleMockMsgStream) AsProducer(ctx context.Context, channels []string) {
}

func (ms *simpleMockMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error {
Expand Down Expand Up @@ -283,15 +283,15 @@ func (ms *simpleMockMsgStream) decreaseMsgCount(delta int) {
ms.increaseMsgCount(-delta)
}

func (ms *simpleMockMsgStream) Produce(pack *msgstream.MsgPack) error {
func (ms *simpleMockMsgStream) Produce(ctx context.Context, pack *msgstream.MsgPack) error {
defer ms.increaseMsgCount(1)

ms.msgChan <- pack

return nil
}

func (ms *simpleMockMsgStream) Broadcast(pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
func (ms *simpleMockMsgStream) Broadcast(ctx context.Context, pack *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) {
return map[string][]msgstream.MessageID{}, nil
}

Expand Down
Loading

0 comments on commit 5768dbb

Please sign in to comment.