diff --git a/fleetspeak/src/common/fscontext/fscontext.go b/fleetspeak/src/common/fscontext/fscontext.go index 736a0a43..45a8006e 100644 --- a/fleetspeak/src/common/fscontext/fscontext.go +++ b/fleetspeak/src/common/fscontext/fscontext.go @@ -21,6 +21,7 @@ import ( "context" "fmt" "sync" + "time" ) // ErrStopRequested is the cancelation cause to be used when callers @@ -60,3 +61,30 @@ func WithDoneChan(ctx context.Context, cause error, done <-chan struct{}) (newCt wg.Wait() } } + +// AfterFuncDelayed behaves like context.AfterFunc, but with a delay. +// The returned stop function must be called eventually to free resources. +func AfterFuncDelayed(ctx context.Context, f func(), delay time.Duration) func() bool { + stopCh := make(chan func() bool) + + cancelInitiation := context.AfterFunc(ctx, func() { + timer := time.AfterFunc(delay, f) + stopCh <- timer.Stop + close(stopCh) + }) + + return func() bool { + if cancelInitiation() { + // context.AfterFunc will not fire + close(stopCh) + return true + } + + stop, ok := <-stopCh + if ok { + return stop() + } + + return false + } +} diff --git a/fleetspeak/src/common/fscontext/fscontext_test.go b/fleetspeak/src/common/fscontext/fscontext_test.go index 45e6c6af..15e2be3f 100644 --- a/fleetspeak/src/common/fscontext/fscontext_test.go +++ b/fleetspeak/src/common/fscontext/fscontext_test.go @@ -30,7 +30,7 @@ var ( var shortDuration = 100 * time.Millisecond -func TestWithDoneChanNotCanceled(t *testing.T) { +func TestWithDoneChan_NotCanceled(t *testing.T) { // Given doneCh := make(chan struct{}) defer close(doneCh) @@ -49,7 +49,7 @@ func TestWithDoneChanNotCanceled(t *testing.T) { } } -func TestWithDoneChanCanceledThroughChannel(t *testing.T) { +func TestWithDoneChan_CanceledThroughChannel(t *testing.T) { // Given doneCh := make(chan struct{}) @@ -73,7 +73,7 @@ func TestWithDoneChanCanceledThroughChannel(t *testing.T) { } } -func TestWithDoneChanCanceledThroughOuterContext(t *testing.T) { +func TestWithDoneChan_CanceledThroughOuterContext(t *testing.T) { // Given doneCh := make(chan struct{}) defer close(doneCh) @@ -100,7 +100,7 @@ func TestWithDoneChanCanceledThroughOuterContext(t *testing.T) { } } -func TestWithDoneChanCanceledThroughOwnContext(t *testing.T) { +func TestWithDoneChan_CanceledThroughOwnContext(t *testing.T) { // Given doneCh := make(chan struct{}) defer close(doneCh) @@ -124,3 +124,123 @@ func TestWithDoneChanCanceledThroughOwnContext(t *testing.T) { } } } + +func TestAfterFuncDelayed_DelayReached(t *testing.T) { + // Given + const delay = 100 * time.Millisecond + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + flag := false + setFlag := func() { + flag = true + } + + stop := fscontext.AfterFuncDelayed(ctx, setFlag, delay) + defer stop() + + // When + cancel() + time.Sleep(2 * delay) + + // Then + if !flag { + t.Errorf("flag not set after %v", delay) + } + if stop() { + t.Errorf("stop() returned true, but flag was set") + } +} + +func TestAfterFuncDelayed_StopBeforeCancel(t *testing.T) { + // Given + const delay = 100 * time.Millisecond + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + flag := false + setFlag := func() { + flag = true + } + + stop := fscontext.AfterFuncDelayed(ctx, setFlag, delay) + defer stop() + + // When + stopped := stop() + cancel() + time.Sleep(2 * delay) + + // Then + if !stopped { + t.Errorf("stop() returned false, but context was not canceled") + } + if flag { + t.Errorf("flag was set, but stop() was called before") + } + if stop() { + t.Errorf("stop() returned true, but stop() was called before") + } +} + +func TestAfterFuncDelayed_StopAfterCancel(t *testing.T) { + // Given + const delay = 100 * time.Millisecond + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + flag := false + setFlag := func() { + flag = true + } + + stop := fscontext.AfterFuncDelayed(ctx, setFlag, delay) + defer stop() + + // When + cancel() + stopped := stop() + time.Sleep(2 * delay) + + // Then + if !stopped { + t.Errorf("stop() returned false, but delay was not reached") + } + if flag { + t.Errorf("flag was set, but stop() was called before") + } + if stop() { + t.Errorf("stop() returned true, but stop() was called before") + } +} + +func TestAfterFuncDelayed_StopAfterDelay(t *testing.T) { + // Given + const delay = 100 * time.Millisecond + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + flag := false + setFlag := func() { + flag = true + } + + stop := fscontext.AfterFuncDelayed(ctx, setFlag, delay) + defer stop() + + // When + cancel() + time.Sleep(2 * delay) + stopped := stop() + + // Then + if stopped { + t.Errorf("stop() returned true, but delay was reached") + } + if !flag { + t.Errorf("flag was not set, but delay was reached") + } + if stop() { + t.Errorf("stop() returned true, but stop() was called before") + } +}