diff --git a/internal/internal_task_pollers.go b/internal/internal_task_pollers.go index 792bf6f88..1b1e6a67a 100644 --- a/internal/internal_task_pollers.go +++ b/internal/internal_task_pollers.go @@ -364,7 +364,7 @@ func (wtp *workflowTaskPoller) ProcessTask(task interface{}) error { } } -func (wtp *workflowTaskPoller) processWorkflowTask(task *workflowTask) error { +func (wtp *workflowTaskPoller) processWorkflowTask(task *workflowTask) (retErr error) { if task.task == nil { // We didn't have task, poll might have timeout. traceLog(func() { @@ -385,6 +385,20 @@ func (wtp *workflowTaskPoller) processWorkflowTask(task *workflowTask) error { } var taskErr error defer func() { + // If we panic during processing the workflow task, we need to unlock the workflow context with an error to discard it. + if p := recover(); p != nil { + topLine := fmt.Sprintf("workflow task for %s [panic]:", wtp.taskQueueName) + st := getStackTraceRaw(topLine, 7, 0) + wtp.logger.Error("Workflow task processing panic.", + tagWorkflowID, task.task.WorkflowExecution.GetWorkflowId(), + tagRunID, task.task.WorkflowExecution.GetRunId(), + tagWorkerType, task.task.GetWorkflowType().Name, + tagAttempt, task.task.Attempt, + tagPanicError, fmt.Sprintf("%v", p), + tagPanicStack, st) + taskErr = newPanicError(p, st) + retErr = taskErr + } wfctx.Unlock(taskErr) }() diff --git a/internal/internal_task_pollers_test.go b/internal/internal_task_pollers_test.go index d34a7e4f9..046e3aa0a 100644 --- a/internal/internal_task_pollers_test.go +++ b/internal/internal_task_pollers_test.go @@ -377,3 +377,58 @@ func TestWFTReset(t *testing.T) { cachedExecution = cache.getWorkflowContext(runID) require.True(t, originalCachedExecution == cachedExecution) } + +type panickingTaskHandler struct { + WorkflowTaskHandler +} + +func (wth *panickingTaskHandler) ProcessWorkflowTask( + task *workflowTask, + wfctx *workflowExecutionContextImpl, + hb workflowTaskHeartbeatFunc, +) (interface{}, error) { + panic("panickingTaskHandler") +} + +func TestWFTPanicInTaskHandler(t *testing.T) { + cache := NewWorkerCache() + params := workerExecutionParameters{cache: cache} + ensureRequiredParams(¶ms) + wfType := commonpb.WorkflowType{Name: t.Name() + "-workflow-type"} + reg := newRegistry() + reg.RegisterWorkflowWithOptions(func(ctx Context) error { + return nil + }, RegisterWorkflowOptions{ + Name: wfType.Name, + }) + var ( + taskQueue = taskqueuepb.TaskQueue{Name: t.Name() + "task-queue"} + startedAttrs = historypb.WorkflowExecutionStartedEventAttributes{ + TaskQueue: &taskQueue, + } + startedEvent = createTestEventWorkflowExecutionStarted(1, &startedAttrs) + history = historypb.History{Events: []*historypb.HistoryEvent{startedEvent}} + runID = t.Name() + "-run-id" + wfID = t.Name() + "-workflow-id" + wfe = commonpb.WorkflowExecution{RunId: runID, WorkflowId: wfID} + ctrl = gomock.NewController(t) + client = workflowservicemock.NewMockWorkflowServiceClient(ctrl) + innerTaskHandler = newWorkflowTaskHandler(params, nil, newRegistry()) + taskHandler = &panickingTaskHandler{WorkflowTaskHandler: innerTaskHandler} + contextManager = taskHandler + codec = binary.LittleEndian + pollResp0 = workflowservice.PollWorkflowTaskQueueResponse{ + Attempt: 1, + WorkflowExecution: &wfe, + WorkflowType: &wfType, + History: &history, + TaskToken: codec.AppendUint32(nil, 0), + } + task0 = workflowTask{task: &pollResp0} + ) + + poller := newWorkflowTaskPoller(taskHandler, contextManager, client, params) + require.Error(t, poller.processWorkflowTask(&task0)) + // Workflow should not be in cache + require.Nil(t, cache.getWorkflowContext(runID)) +}