Skip to content

Commit

Permalink
Add fscontext.AfterFuncDelayed function
Browse files Browse the repository at this point in the history
This is useful in scenarios where a context gets canceled, but the code bound to the lifetime of the context does not return within a reasonable time. `AfterFuncDelayed` can be used to try exit the code more aggressively.

PiperOrigin-RevId: 686477044
  • Loading branch information
torsm authored and copybara-github committed Oct 16, 2024
1 parent 19279ad commit 266062c
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 4 deletions.
28 changes: 28 additions & 0 deletions fleetspeak/src/common/fscontext/fscontext.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"context"
"fmt"
"sync"
"time"
)

// ErrStopRequested is the cancelation cause to be used when callers
Expand Down Expand Up @@ -60,3 +61,30 @@ func WithDoneChan(ctx context.Context, cause error, done <-chan struct{}) (newCt
wg.Wait()
}
}

// AfterFuncDelayed behaves like context.AfterFunc, but with a delay.
// The returned stop function must be called eventually to free resources.
func AfterFuncDelayed(ctx context.Context, f func(), delay time.Duration) func() bool {
stopCh := make(chan func() bool)

cancelInitiation := context.AfterFunc(ctx, func() {
timer := time.AfterFunc(delay, f)
stopCh <- timer.Stop
close(stopCh)
})

return func() bool {
if cancelInitiation() {
// context.AfterFunc will not fire
close(stopCh)
return true
}

stop, ok := <-stopCh
if ok {
return stop()
}

return false
}
}
129 changes: 125 additions & 4 deletions fleetspeak/src/common/fscontext/fscontext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package fscontext_test
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"

Expand All @@ -30,7 +31,7 @@ var (

var shortDuration = 100 * time.Millisecond

func TestWithDoneChanNotCanceled(t *testing.T) {
func TestWithDoneChan_NotCanceled(t *testing.T) {
// Given
doneCh := make(chan struct{})
defer close(doneCh)
Expand All @@ -49,7 +50,7 @@ func TestWithDoneChanNotCanceled(t *testing.T) {
}
}

func TestWithDoneChanCanceledThroughChannel(t *testing.T) {
func TestWithDoneChan_CanceledThroughChannel(t *testing.T) {
// Given
doneCh := make(chan struct{})

Expand All @@ -73,7 +74,7 @@ func TestWithDoneChanCanceledThroughChannel(t *testing.T) {
}
}

func TestWithDoneChanCanceledThroughOuterContext(t *testing.T) {
func TestWithDoneChan_CanceledThroughOuterContext(t *testing.T) {
// Given
doneCh := make(chan struct{})
defer close(doneCh)
Expand All @@ -100,7 +101,7 @@ func TestWithDoneChanCanceledThroughOuterContext(t *testing.T) {
}
}

func TestWithDoneChanCanceledThroughOwnContext(t *testing.T) {
func TestWithDoneChan_CanceledThroughOwnContext(t *testing.T) {
// Given
doneCh := make(chan struct{})
defer close(doneCh)
Expand All @@ -124,3 +125,123 @@ func TestWithDoneChanCanceledThroughOwnContext(t *testing.T) {
}
}
}

func TestAfterFuncDelayed_DelayReached(t *testing.T) {
// Given
const delay = 100 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

flag := atomic.Bool{}
setFlag := func() {
flag.Store(true)
}

stop := fscontext.AfterFuncDelayed(ctx, setFlag, delay)
defer stop()

// When
cancel()
time.Sleep(2 * delay)

// Then
if !flag.Load() {
t.Errorf("flag not set after %v", delay)
}
if stop() {
t.Errorf("stop() returned true, but flag was set")
}
}

func TestAfterFuncDelayed_StopBeforeCancel(t *testing.T) {
// Given
const delay = 100 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

flag := atomic.Bool{}
setFlag := func() {
flag.Store(true)
}

stop := fscontext.AfterFuncDelayed(ctx, setFlag, delay)
defer stop()

// When
stopped := stop()
cancel()
time.Sleep(2 * delay)

// Then
if !stopped {
t.Errorf("stop() returned false, but context was not canceled")
}
if flag.Load() {
t.Errorf("flag was set, but stop() was called before")
}
if stop() {
t.Errorf("stop() returned true, but stop() was called before")
}
}

func TestAfterFuncDelayed_StopAfterCancel(t *testing.T) {
// Given
const delay = 100 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

flag := atomic.Bool{}
setFlag := func() {
flag.Store(true)
}

stop := fscontext.AfterFuncDelayed(ctx, setFlag, delay)
defer stop()

// When
cancel()
stopped := stop()
time.Sleep(2 * delay)

// Then
if !stopped {
t.Errorf("stop() returned false, but delay was not reached")
}
if flag.Load() {
t.Errorf("flag was set, but stop() was called before")
}
if stop() {
t.Errorf("stop() returned true, but stop() was called before")
}
}

func TestAfterFuncDelayed_StopAfterDelay(t *testing.T) {
// Given
const delay = 100 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

flag := atomic.Bool{}
setFlag := func() {
flag.Store(true)
}

stop := fscontext.AfterFuncDelayed(ctx, setFlag, delay)
defer stop()

// When
cancel()
time.Sleep(2 * delay)
stopped := stop()

// Then
if stopped {
t.Errorf("stop() returned true, but delay was reached")
}
if !flag.Load() {
t.Errorf("flag was not set, but delay was reached")
}
if stop() {
t.Errorf("stop() returned true, but stop() was called before")
}
}

0 comments on commit 266062c

Please sign in to comment.