client: Support both sync and async interface for tsoStream, to avoid potential performance regression when concurrent RPC is not enabled

Signed-off-by: MyonKeminta <[email protected]>
MyonKeminta committed Sep 18, 2024
1 parent 71f6f96 commit 6af7a95
Showing 3 changed files with 153 additions and 75 deletions.
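The change splits the old callback-only `processRequests` into an explicit synchronous path and an asynchronous path on the stream. As a rough standalone sketch of the calling convention this introduces (simplified, hypothetical types and values, not the PD client's actual API):

package main

import (
	"errors"
	"fmt"
)

// result and stream are simplified stand-ins for tsoRequestResult and
// tsoStream; this sketch is illustrative, not the PD client's real code.
type result struct{ physical, logical int64 }

type stream struct{ broken bool }

// Sync path: send one batched request and block until the response arrives.
func (s *stream) ProcessRequests(count int64) (result, error) {
	if s.broken {
		return result{}, errors.New("stream closed")
	}
	return result{physical: 100, logical: count - 1}, nil
}

// Async path: schedule the request and return immediately; cb runs later.
// If scheduling fails, an error is returned and cb is never invoked.
func (s *stream) ProcessRequestsAsync(count int64, cb func(result, error)) error {
	if s.broken {
		return errors.New("stream closed")
	}
	go cb(result{physical: 100, logical: count - 1}, nil)
	return nil
}

func main() {
	s := &stream{}

	// Sync mode: no callback and no handoff to another goroutine, which is
	// what avoids the regression when concurrent RPC is not enabled.
	res, err := s.ProcessRequests(8)
	fmt.Println(res, err)

	// Async mode: the callback observes the outcome on another goroutine.
	done := make(chan struct{})
	_ = s.ProcessRequestsAsync(8, func(r result, err error) {
		fmt.Println(r, err)
		close(done)
	})
	<-done
}

The real signatures also carry cluster and keyspace identifiers plus timing metadata; only the sync-versus-async shape is shown here.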
80 changes: 53 additions & 27 deletions client/tso_dispatcher.go
@@ -328,10 +328,10 @@ tsoBatchLoop:
case td.tsDeadlineCh <- dl:
}
// processRequests guarantees that the collected requests could be finished properly.
err = td.processRequests(stream, dc, batchController, done)
tbcConsumed, err := td.processRequests(stream, dc, batchController, done, false)
// If error happens during tso stream handling, reset stream and run the next trial.
if err == nil {
// A nil error returned by `processRequests` indicates that the request batch is started successfully.
if tbcConsumed {
// The function `processRequests` *consumed* the batchController.
// In this case, the `batchController` will be put back to the pool when the request is finished
// asynchronously (either successful or not). This infers that the current `batchController` object will
// be asynchronously accessed after the `processRequests` call. As a result, we need to use another
@@ -340,7 +340,8 @@ tsoBatchLoop:
// Otherwise, the `batchController` won't be processed in other goroutines concurrently, and it can be
// reused in the next loop safely.
batchController = nil
} else {
}
if err != nil {
exit := !td.handleProcessRequestError(ctx, bo, streamURL, cancel, err)
stream = nil
if exit {
@@ -462,10 +463,16 @@ func chooseStream(connectionCtxs *sync.Map) (connectionCtx *tsoConnectionContext
// processRequests sends the RPC request for the batch. It's guaranteed that after calling this function, requests
// in the batch must be eventually finished (done or canceled), either synchronously or asynchronously.
// `close(done)` will be called at the same time the requests are finished.
// If this function returns a non-nil error, the requests will always be canceled synchronously.
// If this function returns a non-nil error, it implies that the requests must have been canceled synchronously, no
// matter how `asyncMode` is set.
//
// When the batch of requests is scheduled to be processed asynchronously, the `tbc` may be used by other goroutines
// later, which will put it back to `td.batchBufferPool` after finishing with it. In this case, the `tbc` object
// cannot be used after this function returns; the function can be considered to have *consumed* the `tbc` object.
// To notify the caller about this, the function returns a bool to indicate whether the `tbc` is consumed.
func (td *tsoDispatcher) processRequests(
stream *tsoStream, dcLocation string, tbc *tsoBatchController, done chan struct{},
) error {
stream *tsoStream, dcLocation string, tbc *tsoBatchController, done chan struct{}, asyncMode bool,
) (bool, error) {
// `done` must be guaranteed to be eventually called.
var (
requests = tbc.getCollectedRequests()
Expand Down Expand Up @@ -501,35 +508,50 @@ func (td *tsoDispatcher) processRequests(
close(done)

defer td.batchBufferPool.Put(tbc)

td.handleTSOResultsForBatch(tbc, stream, result, reqKeyspaceGroupID, err)

}

if asyncMode {
err := stream.ProcessRequestsAsync(
clusterID, keyspaceID, reqKeyspaceGroupID,
dcLocation, count, tbc.extraBatchingStartTime, cb)

if err != nil {
close(done)

td.cancelCollectedRequests(tbc, stream.streamID, err)
return
return false, err
} else {
return true, nil

}
} else {
result, err := stream.ProcessRequests(
clusterID, keyspaceID, reqKeyspaceGroupID, dcLocation, count, tbc.extraBatchingStartTime)
close(done)

curTSOInfo := &tsoInfo{
tsoServer: stream.getServerURL(),
reqKeyspaceGroupID: reqKeyspaceGroupID,
respKeyspaceGroupID: result.respKeyspaceGroupID,
respReceivedAt: time.Now(),
physical: result.physical,
logical: result.logical,
}
// `logical` is the largest ts's logical part here, so we need to do the subtraction before finishing each TSO request.
firstLogical := tsoutil.AddLogical(result.logical, -int64(result.count)+1, result.suffixBits)
td.compareAndSwapTS(curTSOInfo, firstLogical)
td.doneCollectedRequests(tbc, result.physical, firstLogical, result.suffixBits, stream.streamID)
td.handleTSOResultsForBatch(tbc, stream, result, reqKeyspaceGroupID, err)
return false, err
}
}

err := stream.processRequests(
clusterID, keyspaceID, reqKeyspaceGroupID,
dcLocation, count, tbc.extraBatchingStartTime, cb)
func (td *tsoDispatcher) handleTSOResultsForBatch(tbc *tsoBatchController, stream *tsoStream, result tsoRequestResult, reqKeyspaceGroupID uint32, err error) {
if err != nil {
close(done)

td.cancelCollectedRequests(tbc, stream.streamID, err)
return err
return
}
return nil

curTSOInfo := &tsoInfo{
tsoServer: stream.getServerURL(),
reqKeyspaceGroupID: reqKeyspaceGroupID,
respKeyspaceGroupID: result.respKeyspaceGroupID,
respReceivedAt: time.Now(),
physical: result.physical,
logical: result.logical,
}
// `logical` is the largest ts's logical part here, so we need to do the subtraction before finishing each TSO request.
firstLogical := tsoutil.AddLogical(result.logical, -int64(result.count)+1, result.suffixBits)
td.compareAndSwapTS(curTSOInfo, firstLogical)
td.doneCollectedRequests(tbc, result.physical, firstLogical, result.suffixBits, stream.streamID)
}

func (td *tsoDispatcher) cancelCollectedRequests(tbc *tsoBatchController, streamID string, err error) {
@@ -580,3 +602,7 @@ func (td *tsoDispatcher) compareAndSwapTS(
}
td.lastTSOInfo = curTSOInfo
}

func (td *tsoDispatcher) useAsyncStream() bool {
return false

}
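Note that `useAsyncStream` is hard-wired to return `false` in this commit, so the synchronous path remains the default. The ownership rule behind the new `tbcConsumed` return value can be sketched as follows (a minimal illustration with a hypothetical `batch` type and a `sync.Pool`, not the actual `tsoBatchController`):

package main

import (
	"fmt"
	"sync"
)

// batch is a simplified, hypothetical stand-in for tsoBatchController.
type batch struct{ reqs []int }

var pool = sync.Pool{New: func() any { return new(batch) }}

// process mirrors the new contract: it reports whether it took ownership of
// ("consumed") the batch, independently of whether an error occurred.
func process(b *batch, async bool, done chan<- struct{}) (consumed bool, err error) {
	if !async {
		close(done)
		return false, nil // caller keeps ownership and may reuse b
	}
	go func() {
		defer close(done)
		defer pool.Put(b) // the async owner returns b to the pool when finished
		fmt.Println("finished", len(b.reqs), "requests asynchronously")
	}()
	return true, nil
}

func main() {
	b := pool.Get().(*batch)
	done := make(chan struct{})
	consumed, err := process(b, true, done)
	if consumed {
		b = nil // mirror of `batchController = nil`: never touch it again
	}
	if err != nil {
		fmt.Println("requests were canceled synchronously:", err)
	}
	<-done
	_ = b
}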
111 changes: 71 additions & 40 deletions client/tso_stream.go
@@ -207,8 +207,7 @@ type tsoStream struct {

pendingRequests chan batchedRequests

cancel context.CancelFunc
wg sync.WaitGroup
wg sync.WaitGroup

// For syncing between sender and receiver to guarantee all requests are finished when closing.
state atomic.Int32
@@ -230,18 +229,12 @@ const invalidStreamID = "<invalid>"

func newTSOStream(ctx context.Context, serverURL string, stream grpcTSOStreamAdapter) *tsoStream {
streamID := fmt.Sprintf("%s-%d", serverURL, streamIDAlloc.Add(1))
// To make error handling in `tsoDispatcher` work, the internal `cancel` and external `cancel` is better to be
// distinguished.
ctx, cancel := context.WithCancel(ctx)
s := &tsoStream{
serverURL: serverURL,
stream: stream,
streamID: streamID,

serverURL: serverURL,
stream: stream,
streamID: streamID,
pendingRequests: make(chan batchedRequests, 64),

cancel: cancel,

ongoingRequestCountGauge: ongoingRequestCountGauge.WithLabelValues(streamID),
}
s.wg.Add(1)
@@ -253,14 +246,70 @@ func (s *tsoStream) getServerURL() string {
return s.serverURL
}

// processRequests starts an RPC to get a batch of timestamps without waiting for the result. When the result is ready,
// it will be passed to `notifier.finish`.
// ProcessRequests starts an RPC to get a batch of timestamps and synchronously waits for the result.
//
// This function is NOT thread-safe. Don't call this function concurrently in multiple goroutines. Neither should this
// function be called concurrently with ProcessRequestsAsync.
//
// WARNING: The caller is responsible for guaranteeing that when using this function, there are NO unfinished
// async requests started by ProcessRequestsAsync. Otherwise, multiple goroutines might call `RecvMsg` on the same
// gRPC stream concurrently, which is unsafe.
//
// Calling this may implicitly switch the stream to sync mode.
func (s *tsoStream) ProcessRequests(
clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, count int64, batchStartTime time.Time,
) (tsoRequestResult, error) {
start := time.Now()
if err := s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil {
if err == io.EOF {
err = errs.ErrClientTSOStreamClosed
} else {
err = errors.WithStack(err)

}
return tsoRequestResult{}, err

}
tsoBatchSendLatency.Observe(time.Since(batchStartTime).Seconds())
res, err := s.stream.Recv()
duration := time.Since(start).Seconds()
if err != nil {
requestFailedDurationTSO.Observe(duration)
if err == io.EOF {
err = errs.ErrClientTSOStreamClosed

} else {
err = errors.WithStack(err)
}
return tsoRequestResult{}, err
}
requestDurationTSO.Observe(duration)
tsoBatchSize.Observe(float64(count))

if res.count != uint32(count) {
err = errors.WithStack(errTSOLength)
return tsoRequestResult{}, err

}

respKeyspaceGroupID := res.respKeyspaceGroupID
physical, logical, suffixBits := res.physical, res.logical, res.suffixBits
return tsoRequestResult{
physical: physical,
logical: logical,
count: uint32(count),
suffixBits: suffixBits,
respKeyspaceGroupID: respKeyspaceGroupID,
}, nil
}
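One detail in the synchronous path above: `io.EOF` from the underlying stream is normalized to `errs.ErrClientTSOStreamClosed`, so callers can match a single sentinel error for a closed stream. A self-contained sketch of that pattern (with a stand-in sentinel rather than the real `errs` package):

package main

import (
	"errors"
	"fmt"
	"io"
)

// errStreamClosed is a hypothetical stand-in for errs.ErrClientTSOStreamClosed.
var errStreamClosed = errors.New("tso stream closed")

// normalize maps the transport-level io.EOF to the client's sentinel error
// and leaves other errors intact, mirroring the error handling in the diff.
func normalize(err error) error {
	if err == io.EOF {
		return errStreamClosed
	}
	return err
}

func main() {
	fmt.Println(errors.Is(normalize(io.EOF), errStreamClosed)) // true
	fmt.Println(normalize(nil))                                // <nil>
}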

// ProcessRequestsAsync starts an RPC to get a batch of timestamps without waiting for the result. When the result is ready,
// it will be passed to `notifier.finish`.
//
// This function is NOT thread-safe. Don't call this function concurrently in multiple goroutines.
// This function is NOT thread-safe. Don't call this function concurrently in multiple goroutines. Neither should this
// function be called concurrently with ProcessRequests.
//
// It's guaranteed that the `callback` will be called once the request is scheduled; if the request fails to be
// scheduled, an error is returned instead and the callback is ignored.
func (s *tsoStream) processRequests(
//
// Calling this may implicitly switch the stream to async mode.
func (s *tsoStream) ProcessRequestsAsync(
clusterID uint64, keyspaceID, keyspaceGroupID uint32, dcLocation string, count int64, batchStartTime time.Time, callback onFinishedCallback,
) error {
start := time.Now()
@@ -300,11 +349,7 @@ func (s *tsoStream) processRequests(

if err := s.stream.Send(clusterID, keyspaceID, keyspaceGroupID, dcLocation, count); err != nil {
// As the request is already put into `pendingRequests`, the request should finally be canceled by the recvLoop.
// So skip returning error here to avoid
// if err == io.EOF {
// return errors.WithStack(errs.ErrClientTSOStreamClosed)
// }
// return errors.WithStack(err)
// So do not return error here to avoid double-cancelling.
log.Warn("failed to send RPC request through tsoStream", zap.String("stream", s.streamID), zap.Error(err))
return nil
}
@@ -335,7 +380,6 @@ func (s *tsoStream) recvLoop(ctx context.Context) {
}

s.stoppedWithErr.Store(&finishWithErr)
s.cancel()
for !s.state.CompareAndSwap(streamStateIdle, streamStateClosing) {
switch state := s.state.Load(); state {
case streamStateIdle, streamStateSending:
@@ -365,42 +409,29 @@

recvLoop:
for {
select {
case <-ctx.Done():
finishWithErr = context.Canceled
break recvLoop
default:
}

res, err := s.stream.Recv()

// Try to load the corresponding `batchedRequests`. If `Recv` is successful, there must be a request pending
// in the queue.
select {
case currentReq = <-s.pendingRequests:
hasReq = true
default:
hasReq = false
case <-ctx.Done():
finishWithErr = ctx.Err()
return
}

res, err := s.stream.Recv()

durationSeconds := time.Since(currentReq.startTime).Seconds()

if err != nil {
// If a request is pending and error occurs, observe the duration it has cost.
// Note that it's also possible that the stream is broken due to network without being requested. In this
// case, `Recv` may return an error while no request is pending.
if hasReq {
requestFailedDurationTSO.Observe(durationSeconds)
}
requestFailedDurationTSO.Observe(durationSeconds)
if err == io.EOF {
finishWithErr = errors.WithStack(errs.ErrClientTSOStreamClosed)
} else {
finishWithErr = errors.WithStack(err)
}
break recvLoop
} else if !hasReq {
finishWithErr = errors.New("tsoStream timing order broken")
break recvLoop
}

latencySeconds := durationSeconds
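The reworked `recvLoop` above now blocks on `pendingRequests` before calling `Recv`, so a broken stream with no pending requests is only noticed once the next request arrives; the updated test below accounts for exactly that. A minimal model of the send/receive pairing (hypothetical types, FIFO matching only):

package main

import "fmt"

// pending stands in for the stream's batchedRequests entries: the sender
// enqueues one entry per RPC sent, and the receiver loop dequeues exactly
// one entry per response, keeping callbacks matched to responses in order.
type pending struct{ cb func(int, error) }

func main() {
	pendingRequests := make(chan pending, 64)
	responses := make(chan int, 64) // stands in for the gRPC stream

	send := func(cb func(int, error)) {
		pendingRequests <- pending{cb: cb} // enqueue before sending the RPC
		responses <- 42                    // stands in for Send + server reply
	}

	closed := make(chan struct{})
	go func() { // receiver loop
		defer close(closed)
		for i := 0; i < 2; i++ {
			req := <-pendingRequests // block on a pending entry first
			res, ok := <-responses   // then receive its response
			if !ok {
				req.cb(0, fmt.Errorf("stream closed"))
				return
			}
			req.cb(res, nil)
		}
	}()

	for i := 0; i < 2; i++ {
		send(func(ts int, err error) { fmt.Println("got", ts, err) })
	}
	<-closed
}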
37 changes: 29 additions & 8 deletions client/tso_stream_test.go
@@ -229,15 +229,20 @@ type testTSOStreamSuite struct {

inner *mockTSOStreamImpl
stream *tsoStream

ctx context.Context
cancel context.CancelFunc
}

func (s *testTSOStreamSuite) SetupTest() {
s.re = require.New(s.T())
s.inner = newMockTSOStreamImpl(context.Background(), false)
s.stream = newTSOStream(context.Background(), mockStreamURL, s.inner)
s.ctx, s.cancel = context.WithCancel(context.Background())
s.inner = newMockTSOStreamImpl(s.ctx, false)
s.stream = newTSOStream(s.ctx, mockStreamURL, s.inner)
}

func (s *testTSOStreamSuite) TearDownTest() {
s.cancel()
s.inner.stop()
s.stream.WaitForClosed()
s.inner = nil
@@ -268,7 +273,7 @@ func (s *testTSOStreamSuite) getResult(ch <-chan callbackInvocation) callbackInv

func (s *testTSOStreamSuite) processRequestWithResultCh(count int64) (<-chan callbackInvocation, error) {
ch := make(chan callbackInvocation, 1)
err := s.stream.processRequests(1, 2, 3, globalDCLocation, count, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) {
err := s.stream.ProcessRequestsAsync(1, 2, 3, globalDCLocation, count, time.Now(), func(result tsoRequestResult, reqKeyspaceGroupID uint32, err error) {
if err == nil {
s.re.Equal(uint32(3), reqKeyspaceGroupID)
s.re.Equal(uint32(0), result.suffixBits)
Expand Down Expand Up @@ -320,7 +325,7 @@ func (s *testTSOStreamSuite) TestTSOStreamBasic() {

// After an error from the (simulated) RPC stream, the tsoStream should be in a broken status and can't accept
// new request anymore.
err := s.stream.processRequests(1, 2, 3, globalDCLocation, 1, time.Now(), func(_result tsoRequestResult, _reqKeyspaceGroupID uint32, _err error) {
err := s.stream.ProcessRequestsAsync(1, 2, 3, globalDCLocation, 1, time.Now(), func(_result tsoRequestResult, _reqKeyspaceGroupID uint32, _err error) {
panic("unreachable")
})
s.re.Error(err)
@@ -341,6 +346,20 @@ func (s *testTSOStreamSuite) testTSOStreamBrokenImpl(err error, pendingRequests
s.stream.WaitForClosed()
closedCh <- struct{}{}
}()

if pendingRequests == 0 {
// As the recvLoop retrieves the pending requests before trying to receive from the stream, if there are
// no pending requests, it doesn't immediately detect that the stream is broken until a new incoming
// request arrives.
select {
case <-closedCh:
s.re.FailNow("stream receiver loop exits unexpectedly")
case <-time.After(time.Millisecond * 50):
}
ch := s.mustProcessRequestWithResultCh(1)
resultCh = append(resultCh, ch)
}

select {
case <-closedCh:
case <-time.After(time.Second):
@@ -443,7 +462,7 @@ func (s *testTSOStreamSuite) TestTSOStreamConcurrentRunning() {
}

// After handling all these requests, the stream is ended by an EOF error. The next request won't succeed.
// So, either the `processRequests` function returns an error or the callback is called with an error.
// So, either the `ProcessRequestsAsync` function returns an error or the callback is called with an error.
ch, err := s.processRequestWithResultCh(1)
if err != nil {
s.re.ErrorIs(err, errs.ErrClientTSOStreamClosed)
@@ -457,9 +476,11 @@ func BenchmarkTSOStreamSendRecv(b *testing.B) {
func BenchmarkTSOStreamSendRecv(b *testing.B) {
log.SetLevel(zapcore.FatalLevel)

streamInner := newMockTSOStreamImpl(context.Background(), true)
stream := newTSOStream(context.Background(), mockStreamURL, streamInner)
ctx, cancel := context.WithCancel(context.Background())
streamInner := newMockTSOStreamImpl(ctx, true)
stream := newTSOStream(ctx, mockStreamURL, streamInner)
defer func() {
cancel()
streamInner.stop()
stream.WaitForClosed()
}()
Expand All @@ -469,7 +490,7 @@ func BenchmarkTSOStreamSendRecv(b *testing.B) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
err := stream.processRequests(1, 1, 1, globalDCLocation, 1, now, func(result tsoRequestResult, _ uint32, err error) {
err := stream.ProcessRequestsAsync(1, 1, 1, globalDCLocation, 1, now, func(result tsoRequestResult, _ uint32, err error) {
if err != nil {
panic(err)
}
