From 6af7a95bd81eea1f6e6518d04857569ac32a425c Mon Sep 17 00:00:00 2001
From: MyonKeminta
Date: Thu, 19 Sep 2024 03:00:32 +0800
Subject: [PATCH] client: Support both sync and async interfaces for tsoStream,
 to avoid potential performance regression when concurrent RPC is not enabled

Signed-off-by: MyonKeminta
---
 client/tso_dispatcher.go  |  80 +++++++++++++++++----------
 client/tso_stream.go      | 111 ++++++++++++++++++++++++--------------
 client/tso_stream_test.go |  37 ++++++++++---
 3 files changed, 153 insertions(+), 75 deletions(-)

diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go
index a1e0b03a1fa..bd014f174eb 100644
--- a/client/tso_dispatcher.go
+++ b/client/tso_dispatcher.go
@@ -328,10 +328,10 @@ tsoBatchLoop:
 		case td.tsDeadlineCh <- dl:
 		}
 		// processRequests guarantees that the collected requests could be finished properly.
-		err = td.processRequests(stream, dc, batchController, done)
+		tbcConsumed, err := td.processRequests(stream, dc, batchController, done, false)
 		// If error happens during tso stream handling, reset stream and run the next trial.
-		if err == nil {
-			// A nil error returned by `processRequests` indicates that the request batch is started successfully.
+		if tbcConsumed {
+			// The function `processRequests` has *consumed* the batchController.
 			// In this case, the `batchController` will be put back to the pool when the request is finished
 			// asynchronously (either successful or not). This infers that the current `batchController` object will
 			// be asynchronously accessed after the `processRequests` call. As a result, we need to use another
@@ -340,7 +340,8 @@ tsoBatchLoop:
 			// Otherwise, the `batchController` won't be processed in other goroutines concurrently, and it can be
 			// reused in the next loop safely.
 			batchController = nil
-		} else {
+		}
+		if err != nil {
 			exit := !td.handleProcessRequestError(ctx, bo, streamURL, cancel, err)
 			stream = nil
 			if exit {
@@ -462,10 +463,16 @@ func chooseStream(connectionCtxs *sync.Map) (connectionCtx *tsoConnectionContext
 // processRequests sends the RPC request for the batch. It's guaranteed that after calling this function, requests
 // in the batch must be eventually finished (done or canceled), either synchronously or asynchronously.
 // `close(done)` will be called at the same time when finishing the requests.
-// If this function returns a non-nil error, the requests will always be canceled synchronously.
+// If this function returns a non-nil error, it implies that the requests have been canceled synchronously,
+// regardless of how `asyncMode` is set.
+//
+// When the batch of requests is scheduled to be processed asynchronously, the `tbc` may later be used by other
+// goroutines, which will put it back to `td.batchBufferPool` after finishing with it. In this case, the `tbc` object
+// cannot be used after this function returns; it can be considered that this function *consumes* the `tbc` object.
+// To notify the caller about this, the function returns a bool indicating whether the `tbc` was consumed.
 func (td *tsoDispatcher) processRequests(
-	stream *tsoStream, dcLocation string, tbc *tsoBatchController, done chan struct{},
-) error {
+	stream *tsoStream, dcLocation string, tbc *tsoBatchController, done chan struct{}, asyncMode bool,
+) (bool, error) {
 	// `done` must be guaranteed to be eventually called.
 	var (
 		requests = tbc.getCollectedRequests()
@@ -501,35 +508,50 @@ func (td *tsoDispatcher) processRequests(
 		close(done)
 		defer td.batchBufferPool.Put(tbc)
+
+		td.handleTSOResultsForBatch(tbc, stream, result, reqKeyspaceGroupID, err)
+	}
+
+	if asyncMode {
+		err := stream.ProcessRequestsAsync(
+			clusterID, keyspaceID, reqKeyspaceGroupID,
+			dcLocation, count, tbc.extraBatchingStartTime, cb)
 		if err != nil {
+			close(done)
+			td.cancelCollectedRequests(tbc, stream.streamID, err)
-			return
+			return false, err
+		} else {
+			return true, nil
 		}
+	} else {
+		result, err := stream.ProcessRequests(
+			clusterID, keyspaceID, reqKeyspaceGroupID, dcLocation, count, tbc.extraBatchingStartTime)
+		close(done)
-		curTSOInfo := &tsoInfo{
-			tsoServer:           stream.getServerURL(),
-			reqKeyspaceGroupID:  reqKeyspaceGroupID,
-			respKeyspaceGroupID: result.respKeyspaceGroupID,
-			respReceivedAt:      time.Now(),
-			physical:            result.physical,
-			logical:             result.logical,
-		}
-		// `logical` is the largest ts's logical part here, we need to do the subtracting before we finish each TSO request.
-		firstLogical := tsoutil.AddLogical(result.logical, -int64(result.count)+1, result.suffixBits)
-		td.compareAndSwapTS(curTSOInfo, firstLogical)
-		td.doneCollectedRequests(tbc, result.physical, firstLogical, result.suffixBits, stream.streamID)
+		td.handleTSOResultsForBatch(tbc, stream, result, reqKeyspaceGroupID, err)
+		return false, err
 	}
+}
 
-	err := stream.processRequests(
-		clusterID, keyspaceID, reqKeyspaceGroupID,
-		dcLocation, count, tbc.extraBatchingStartTime, cb)
+func (td *tsoDispatcher) handleTSOResultsForBatch(tbc *tsoBatchController, stream *tsoStream, result tsoRequestResult, reqKeyspaceGroupID uint32, err error) {
 	if err != nil {
-		close(done)
-		td.cancelCollectedRequests(tbc, stream.streamID, err)
-		return err
+		return
 	}
-	return nil
+
+	curTSOInfo := &tsoInfo{
+		tsoServer:           stream.getServerURL(),
+		reqKeyspaceGroupID:  reqKeyspaceGroupID,
+		respKeyspaceGroupID: result.respKeyspaceGroupID,
+		respReceivedAt:      time.Now(),
+		physical:            result.physical,
+		logical:             result.logical,
+	}
+	// `logical` is the largest ts's logical part here; we need to do the subtraction before finishing each TSO request.
+	firstLogical := tsoutil.AddLogical(result.logical, -int64(result.count)+1, result.suffixBits)
+	td.compareAndSwapTS(curTSOInfo, firstLogical)
+	td.doneCollectedRequests(tbc, result.physical, firstLogical, result.suffixBits, stream.streamID)
 }
 
 func (td *tsoDispatcher) cancelCollectedRequests(tbc *tsoBatchController, streamID string, err error) {
@@ -580,3 +602,7 @@ func (td *tsoDispatcher) compareAndSwapTS(
 	}
 	td.lastTSOInfo = curTSOInfo
 }
+
+func (td *tsoDispatcher) useAsyncStream() bool {
+	return false
+}
diff --git a/client/tso_stream.go b/client/tso_stream.go
index 479beff2c6a..3b4d9171cac 100644
--- a/client/tso_stream.go
+++ b/client/tso_stream.go
@@ -207,8 +207,7 @@ type tsoStream struct {
 	pendingRequests chan batchedRequests
 
-	cancel context.CancelFunc
-	wg     sync.WaitGroup
+	wg sync.WaitGroup // For syncing between sender and receiver to guarantee that all requests are finished when closing.
 
 	state atomic.Int32
@@ -230,18 +229,12 @@ const invalidStreamID = ""
 func newTSOStream(ctx context.Context, serverURL string, stream grpcTSOStreamAdapter) *tsoStream {
 	streamID := fmt.Sprintf("%s-%d", serverURL, streamIDAlloc.Add(1))
 
-	// To make error handling in `tsoDispatcher` work, the internal `cancel` and external `cancel` is better to be
-	// distinguished.
-	ctx, cancel := context.WithCancel(ctx)
 	s := &tsoStream{
-		serverURL: serverURL,
-		stream:    stream,
-		streamID:  streamID,
-
+		serverURL:       serverURL,
+		stream:          stream,
+		streamID:        streamID,
 		pendingRequests: make(chan batchedRequests, 64),
 
-		cancel: cancel,
-
 		ongoingRequestCountGauge: ongoingRequestCountGauge.WithLabelValues(streamID),
 	}
 	s.wg.Add(1)
@@ -253,14 +246,70 @@ func (s *tsoStream) getServerURL() string {
 	return s.serverURL
 }
 
-// processRequests starts an RPC to get a batch of timestamps without waiting for the result. When the result is ready,
-// it will be passed th `notifier.finish`.
+// ProcessRequests starts an RPC to get a batch of timestamps and synchronously waits for the result.
+//
+// This function is NOT thread-safe. Don't call this function concurrently in multiple goroutines. Neither should this
+// function be called concurrently with ProcessRequestsAsync.
+//
+// WARNING: The caller is responsible for guaranteeing that when using this function, there must NOT be any unfinished
+// async requests started by ProcessRequestsAsync. Otherwise, multiple goroutines might call gRPC's `RecvMsg`
+// concurrently, which is unsafe.
+//
+// Calling this may implicitly switch the stream to sync mode.
+func (s *tsoStream) ProcessRequests(
+	clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, count int64, batchStartTime time.Time,
+) (tsoRequestResult, error) {
+	start := time.Now()
+	if err := s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil {
+		if err == io.EOF {
+			err = errs.ErrClientTSOStreamClosed
+		} else {
+			err = errors.WithStack(err)
+		}
+		return tsoRequestResult{}, err
+	}
+	tsoBatchSendLatency.Observe(time.Since(batchStartTime).Seconds())
+	res, err := s.stream.Recv()
+	duration := time.Since(start).Seconds()
+	if err != nil {
+		requestFailedDurationTSO.Observe(duration)
+		if err == io.EOF {
+			err = errs.ErrClientTSOStreamClosed
+		} else {
+			err = errors.WithStack(err)
+		}
+		return tsoRequestResult{}, err
+	}
+	requestDurationTSO.Observe(duration)
+	tsoBatchSize.Observe(float64(count))
+
+	if res.count != uint32(count) {
+		err = errors.WithStack(errTSOLength)
+		return tsoRequestResult{}, err
+	}
+
+	respKeyspaceGroupID := res.respKeyspaceGroupID
+	physical, logical, suffixBits := res.physical, res.logical, res.suffixBits
+	return tsoRequestResult{
+		physical:            physical,
+		logical:             logical,
+		count:               uint32(count),
+		suffixBits:          suffixBits,
+		respKeyspaceGroupID: respKeyspaceGroupID,
+	}, nil
+}
+
+// ProcessRequestsAsync starts an RPC to get a batch of timestamps without waiting for the result. When the result is
+// ready, it will be passed to the provided `callback`.
 //
-// This function is NOT thread-safe. Don't call this function concurrently in multiple goroutines.
+// This function is NOT thread-safe. Don't call this function concurrently in multiple goroutines. Neither should this
+// function be called concurrently with ProcessRequests.
 //
 // It's guaranteed that the `callback` will be called, but when the request is failed to be scheduled, the callback
 // will be ignored.
-func (s *tsoStream) processRequests(
+//
+// Calling this may implicitly switch the stream to async mode.
+func (s *tsoStream) ProcessRequestsAsync(
 	clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, count int64, batchStartTime time.Time,
 	callback onFinishedCallback,
 ) error {
 	start := time.Now()
@@ -300,11 +349,7 @@ func (s *tsoStream) processRequests(
 	if err := s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil {
 		// As the request is already put into `pendingRequests`, the request should finally be canceled by the recvLoop.
-		// So skip returning error here to avoid
-		// if err == io.EOF {
-		// 	return errors.WithStack(errs.ErrClientTSOStreamClosed)
-		// }
-		// return errors.WithStack(err)
+		// So do not return an error here, to avoid double-cancelling.
 		log.Warn("failed to send RPC request through tsoStream", zap.String("stream", s.streamID), zap.Error(err))
 		return nil
 	}
@@ -335,7 +380,6 @@ func (s *tsoStream) recvLoop(ctx context.Context) {
 		}
 
 		s.stoppedWithErr.Store(&finishWithErr)
-		s.cancel()
 		for !s.state.CompareAndSwap(streamStateIdle, streamStateClosing) {
 			switch state := s.state.Load(); state {
 			case streamStateIdle, streamStateSending:
@@ -365,42 +409,29 @@ func (s *tsoStream) recvLoop(ctx context.Context) {
 recvLoop:
 	for {
-		select {
-		case <-ctx.Done():
-			finishWithErr = context.Canceled
-			break recvLoop
-		default:
-		}
-
-		res, err := s.stream.Recv()
-
 		// Try to load the corresponding `batchedRequests`. If `Recv` is successful, there must be a request pending
 		// in the queue.
 		select {
 		case currentReq = <-s.pendingRequests:
 			hasReq = true
-		default:
-			hasReq = false
+		case <-ctx.Done():
+			finishWithErr = ctx.Err()
+			return
 		}
 
+		res, err := s.stream.Recv()
+
 		durationSeconds := time.Since(currentReq.startTime).Seconds()
 
 		if err != nil {
 			// If a request is pending and error occurs, observe the duration it has cost.
-			// Note that it's also possible that the stream is broken due to network without being requested. In this
-			// case, `Recv` may return an error while no request is pending.
-			if hasReq {
-				requestFailedDurationTSO.Observe(durationSeconds)
-			}
+			requestFailedDurationTSO.Observe(durationSeconds)
 			if err == io.EOF {
 				finishWithErr = errors.WithStack(errs.ErrClientTSOStreamClosed)
 			} else {
 				finishWithErr = errors.WithStack(err)
 			}
 			break recvLoop
-		} else if !hasReq {
-			finishWithErr = errors.New("tsoStream timing order broken")
-			break recvLoop
 		}
 
 		latencySeconds := durationSeconds
diff --git a/client/tso_stream_test.go b/client/tso_stream_test.go
index b09c54baf3a..623d645bd1e 100644
--- a/client/tso_stream_test.go
+++ b/client/tso_stream_test.go
@@ -229,15 +229,20 @@ type testTSOStreamSuite struct {
 	inner  *mockTSOStreamImpl
 	stream *tsoStream
+
+	ctx    context.Context
+	cancel context.CancelFunc
 }
 
 func (s *testTSOStreamSuite) SetupTest() {
 	s.re = require.New(s.T())
 
-	s.inner = newMockTSOStreamImpl(context.Background(), false)
-	s.stream = newTSOStream(context.Background(), mockStreamURL, s.inner)
+	s.ctx, s.cancel = context.WithCancel(context.Background())
+	s.inner = newMockTSOStreamImpl(s.ctx, false)
+	s.stream = newTSOStream(s.ctx, mockStreamURL, s.inner)
 }
 
 func (s *testTSOStreamSuite) TearDownTest() {
+	s.cancel()
 	s.inner.stop()
 	s.stream.WaitForClosed()
 	s.inner = nil
@@ -268,7 +273,7 @@ func (s *testTSOStreamSuite) getResult(ch <-chan callbackInvocation) callbackInv
 func (s *testTSOStreamSuite) processRequestWithResultCh(count int64) (<-chan callbackInvocation, error) {
 	ch := make(chan callbackInvocation, 1)
-	err := s.stream.processRequests(1, 2, 3, globalDCLocation, count, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) {
+	err := s.stream.ProcessRequestsAsync(1, 2, 3, globalDCLocation, count, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) {
 		if err == nil {
 			s.re.Equal(uint32(3), reqKeyspaceGroupID)
 			s.re.Equal(uint32(0), result.suffixBits)
@@ -320,7 +325,7 @@ func (s *testTSOStreamSuite) TestTSOStreamBasic() {
 
 	// After an error from the (simulated) RPC stream, the tsoStream should be in a broken status and can't accept
 	// new request anymore.
-	err := s.stream.processRequests(1, 2, 3, globalDCLocation, 1, time.Now(), func(_result tsoRequestResult, _reqKeyspaceGroupID uint32, _err error) {
+	err := s.stream.ProcessRequestsAsync(1, 2, 3, globalDCLocation, 1, time.Now(), func(_result tsoRequestResult, _reqKeyspaceGroupID uint32, _err error) {
 		panic("unreachable")
 	})
 	s.re.Error(err)
@@ -341,6 +346,20 @@ func (s *testTSOStreamSuite) testTSOStreamBrokenImpl(err error, pendingRequests
 		s.stream.WaitForClosed()
 		closedCh <- struct{}{}
 	}()
+
+	if pendingRequests == 0 {
+		// As the recvLoop retrieves a pending request first, before trying to receive from the stream, it doesn't
+		// immediately detect that the stream is broken when there are no pending requests; the breakage is only
+		// detected once a new incoming request arrives.
+		select {
+		case <-closedCh:
+			s.re.FailNow("stream receiver loop exits unexpectedly")
+		case <-time.After(time.Millisecond * 50):
+		}
+		ch := s.mustProcessRequestWithResultCh(1)
+		resultCh = append(resultCh, ch)
+	}
+
 	select {
 	case <-closedCh:
 	case <-time.After(time.Second):
@@ -443,7 +462,7 @@ func (s *testTSOStreamSuite) TestTSOStreamConcurrentRunning() {
 	}
 
 	// After handling all these requests, the stream is ended by an EOF error. The next request won't succeed.
-	// So, either the `processRequests` function returns an error or the callback is called with an error.
+	// So, either the `ProcessRequestsAsync` function returns an error or the callback is called with an error.
 	ch, err := s.processRequestWithResultCh(1)
 	if err != nil {
 		s.re.ErrorIs(err, errs.ErrClientTSOStreamClosed)
@@ -457,9 +476,11 @@ func (s *testTSOStreamSuite) TestTSOStreamConcurrentRunning() {
 
 func BenchmarkTSOStreamSendRecv(b *testing.B) {
 	log.SetLevel(zapcore.FatalLevel)
 
-	streamInner := newMockTSOStreamImpl(context.Background(), true)
-	stream := newTSOStream(context.Background(), mockStreamURL, streamInner)
+	ctx, cancel := context.WithCancel(context.Background())
+	streamInner := newMockTSOStreamImpl(ctx, true)
+	stream := newTSOStream(ctx, mockStreamURL, streamInner)
 	defer func() {
+		cancel()
 		streamInner.stop()
 		stream.WaitForClosed()
 	}()
@@ -469,7 +490,7 @@ func BenchmarkTSOStreamSendRecv(b *testing.B) {
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
-		err := stream.processRequests(1, 1, 1, globalDCLocation, 1, now, func(result tsoRequestResult, _ uint32, err error) {
+		err := stream.ProcessRequestsAsync(1, 1, 1, globalDCLocation, 1, now, func(result tsoRequestResult, _ uint32, err error) {
 			if err != nil {
 				panic(err)
 			}
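
Reviewer note, not part of the patch: the subtlest point in the new `processRequests` signature is the ownership contract around `tbc`. The standalone sketch below models that contract with hypothetical stand-ins (`batch`, `process`, and `pool` are illustrative names, not the client's real types): a `consumed` result transfers ownership of the batch to the async path, while the error result independently tells the caller whether to reset the stream, mirroring how `tsoBatchLoop` checks `tbcConsumed` and `err` separately instead of conflating both in a single nil-error check.

package main

import (
	"errors"
	"fmt"
	"sync"
)

// batch is a hypothetical stand-in for the client's tsoBatchController.
type batch struct{ id int }

var pool = sync.Pool{New: func() any { return new(batch) }}

// process mirrors the contract of tsoDispatcher.processRequests after this
// patch: the bool reports whether the batch was *consumed* (ownership moved
// to the async path, which returns it to the pool itself), and a non-nil
// error means the requests were already canceled synchronously.
func process(b *batch, asyncMode bool) (consumed bool, err error) {
	if asyncMode {
		go pool.Put(b) // the async callback owns b from now on
		return true, nil
	}
	// Sync mode: the caller keeps ownership of b and may reuse it.
	return false, errors.New("simulated RPC failure")
}

func main() {
	b := pool.Get().(*batch)
	consumed, err := process(b, true)
	if consumed {
		// Drop our reference, as tsoBatchLoop does with `batchController = nil`:
		// b may now be accessed concurrently by the async path.
		b = nil
	}
	if err != nil {
		fmt.Println("reset the stream and retry:", err)
	}
	fmt.Println("consumed:", consumed, "reusable:", b != nil)
}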
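
A second reviewer note, also not part of the patch: the reordered `recvLoop` now blocks on `pendingRequests` (or `ctx.Done()`) *before* calling `Recv`, instead of calling `Recv` first and polling the queue afterwards. The toy model below (`recv`, `receiverLoop`, and the channels are illustrative stand-ins) demonstrates the resulting behavior that `testTSOStreamBrokenImpl` now accounts for when `pendingRequests == 0`: a broken stream is only detected once the next request is queued.

package main

import (
	"context"
	"fmt"
	"time"
)

// recv simulates a blocking stream.Recv that fails once the stream is broken.
func recv(broken <-chan struct{}) error {
	<-broken
	return fmt.Errorf("EOF")
}

// receiverLoop mirrors the reordered recvLoop in this patch: it first waits
// for a pending request (or cancellation), and only then calls Recv. A broken
// stream is therefore not noticed until the next request arrives.
func receiverLoop(ctx context.Context, pending <-chan int, broken <-chan struct{}) error {
	for {
		select {
		case req := <-pending:
			if err := recv(broken); err != nil {
				return fmt.Errorf("request %d failed: %w", req, err)
			}
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	pending := make(chan int, 1)
	broken := make(chan struct{})
	close(broken) // the stream is already broken
	pending <- 1  // detection happens only once this request is queued
	fmt.Println(receiverLoop(ctx, pending, broken))
}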