From e5f2c8e3e04a5586df200251d76b42491c92ae43 Mon Sep 17 00:00:00 2001 From: Morten Torkildsen Date: Sat, 29 Jan 2022 16:35:10 -0800 Subject: [PATCH] Avoid goroutine being blocked forever in WaitTask --- pkg/apply/taskrunner/task.go | 11 ++++++++--- pkg/apply/taskrunner/task_test.go | 4 +++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/pkg/apply/taskrunner/task.go b/pkg/apply/taskrunner/task.go index a48a707c..a1a9bd29 100644 --- a/pkg/apply/taskrunner/task.go +++ b/pkg/apply/taskrunner/task.go @@ -4,6 +4,7 @@ package taskrunner import ( + "context" "fmt" "reflect" "time" @@ -101,12 +102,16 @@ func (w *WaitTask) Start(taskContext *TaskContext) { // the WaitTask struct. Once the timer expires, it will send // a message on the EventChannel provided in the taskContext. func (w *WaitTask) setTimer(taskContext *TaskContext) { - timer := time.NewTimer(w.Timeout) + ctx, cancel := context.WithTimeout(context.Background(), w.Timeout) go func() { // TODO(mortent): See if there is a better way to do this. This // solution will cause the goroutine to hang forever if the // Timeout is cancelled. - <-timer.C + <-ctx.Done() + // If the context was cancelled, we don't want to send any events. + if ctx.Err() == context.Canceled { + return + } select { // We only send the TimeoutError to the eventChannel if no one has gotten // to the token first. @@ -130,7 +135,7 @@ func (w *WaitTask) setTimer(taskContext *TaskContext) { } }() w.cancelFunc = func() { - timer.Stop() + cancel() } } diff --git a/pkg/apply/taskrunner/task_test.go b/pkg/apply/taskrunner/task_test.go index 3504ea8d..ed8bf01e 100644 --- a/pkg/apply/taskrunner/task_test.go +++ b/pkg/apply/taskrunner/task_test.go @@ -60,8 +60,10 @@ func TestWaitTask_TimeoutCancelled(t *testing.T) { timer := time.NewTimer(3 * time.Second) select { + case res := <-taskContext.EventChannel(): + t.Errorf("didn't expect error on eventChannel, but got %v", res) case res := <-taskContext.TaskChannel(): - t.Errorf("didn't expect timeout error, but got %v", res.Err) + t.Errorf("didn't expect event on taskChannel, but got %v", res) case <-timer.C: return }