diff --git a/internal/internal_workflow_testsuite.go b/internal/internal_workflow_testsuite.go index 7cbf77552..b0a5b5b35 100644 --- a/internal/internal_workflow_testsuite.go +++ b/internal/internal_workflow_testsuite.go @@ -157,6 +157,12 @@ type ( taskQueues map[string]struct{} } + updateResult struct { + success interface{} + err error + mu *sync.Mutex + } + // testWorkflowEnvironmentShared is the shared data between parent workflow and child workflow test environments testWorkflowEnvironmentShared struct { locker sync.Mutex @@ -229,6 +235,7 @@ type ( signalHandler func(name string, input *commonpb.Payloads, header *commonpb.Header) error queryHandler func(string, *commonpb.Payloads, *commonpb.Header) (*commonpb.Payloads, error) updateHandler func(name string, id string, input *commonpb.Payloads, header *commonpb.Header, resp UpdateCallbacks) + updateMap map[string]updateResult startedHandler func(r WorkflowExecution, e error) isWorkflowCompleted bool @@ -256,6 +263,13 @@ type ( *sessionEnvironmentImpl testWorkflowEnvironment *testWorkflowEnvironmentImpl } + + // UpdateCallbacksWrapper is a wrapper to UpdateCallbacks. It allows us to dedup duplicate update IDs in the test environment. + updateCallbacksWrapper struct { + uc UpdateCallbacks + env *testWorkflowEnvironmentImpl + updateID string + } ) func newTestWorkflowEnvironmentImpl(s *WorkflowTestSuite, parentRegistry *registry) *testWorkflowEnvironmentImpl { @@ -2910,10 +2924,32 @@ func (env *testWorkflowEnvironmentImpl) updateWorkflow(name string, id string, u if err != nil { panic(err) } - env.postCallback(func() { - // Do not send any headers on test invocations - env.updateHandler(name, id, data, nil, uc) - }, true) + + if env.updateMap == nil { + env.updateMap = make(map[string]updateResult) + } + + var ucWrapper = updateCallbacksWrapper{uc: uc, env: env, updateID: id} + + // check for duplicate update ID + if result, ok := env.updateMap[id]; ok { + result.mu.Lock() + // return cached result + env.postCallback(func() { + ucWrapper.uc.Accept() + ucWrapper.uc.Complete(result.success, result.err) + defer result.mu.Unlock() + }, false) + } else { + env.updateMap[id] = updateResult{nil, nil, &sync.Mutex{}} + env.updateMap[id].mu.Lock() + env.postCallback(func() { + // Do not send any headers on test invocations + env.updateHandler(name, id, data, nil, ucWrapper) + defer env.updateMap[id].mu.Unlock() + }, true) + } + } func (env *testWorkflowEnvironmentImpl) updateWorkflowByID(workflowID, name, id string, uc UpdateCallbacks, args ...interface{}) error { @@ -2925,9 +2961,30 @@ func (env *testWorkflowEnvironmentImpl) updateWorkflowByID(workflowID, name, id if err != nil { panic(err) } - workflowHandle.env.postCallback(func() { - workflowHandle.env.updateHandler(name, id, data, nil, uc) - }, true) + + if env.updateMap == nil { + env.updateMap = make(map[string]updateResult) + } + + var ucWrapper = updateCallbacksWrapper{uc: uc, env: env, updateID: id} + + // Check for duplicate update ID + if result, ok := env.updateMap[id]; ok { + result.mu.Lock() + workflowHandle.env.postCallback(func() { + ucWrapper.uc.Accept() + ucWrapper.uc.Complete(result.success, result.err) + defer result.mu.Unlock() + }, false) + } else { + env.updateMap[id] = updateResult{nil, nil, &sync.Mutex{}} + env.updateMap[id].mu.Lock() + workflowHandle.env.postCallback(func() { + workflowHandle.env.updateHandler(name, id, data, nil, ucWrapper) + defer env.updateMap[id].mu.Unlock() + }, true) + } + return nil } @@ -3068,6 +3125,29 @@ func mockFnGetVersion(string, Version, Version) Version { // make sure interface is implemented var _ WorkflowEnvironment = (*testWorkflowEnvironmentImpl)(nil) +func (uc updateCallbacksWrapper) Accept() { + uc.uc.Accept() +} + +func (uc updateCallbacksWrapper) Reject(err error) { + uc.uc.Reject(err) +} + +func (uc updateCallbacksWrapper) Complete(success interface{}, err error) { + // cache update result so we can dedup duplicate update IDs + if uc.env == nil { + panic("env is needed in updateCallback to cache update results for deduping purposes") + } + if result, ok := uc.env.updateMap[uc.updateID]; ok { + result.success = success + result.err = err + uc.env.updateMap[uc.updateID] = result + } else { + panic("updateMap[updateID] should already be created from updateWorkflow()") + } + uc.uc.Complete(success, err) +} + func (h *testNexusOperationHandle) newStartTask() *workflowservice.PollNexusTaskQueueResponse { return &workflowservice.PollNexusTaskQueueResponse{ TaskToken: []byte{}, diff --git a/internal/workflow_testsuite_test.go b/internal/workflow_testsuite_test.go index 3fc46146b..1cbb08cc3 100644 --- a/internal/workflow_testsuite_test.go +++ b/internal/workflow_testsuite_test.go @@ -491,6 +491,61 @@ func TestWorkflowUpdateOrderAcceptReject(t *testing.T) { require.Equal(t, "unknown update bad update. KnownUpdates=[update]", updateRejectionErr.Error()) } +func TestWorkflowDuplicateIDDedup(t *testing.T) { + var suite WorkflowTestSuite + // Test dev server dedups UpdateWorkflow with same ID + env := suite.NewTestWorkflowEnvironment() + env.RegisterDelayedCallback(func() { + env.UpdateWorkflow("update", "id", &updateCallback{ + reject: func(err error) { + require.Fail(t, fmt.Sprintf("update should not be rejected, err: %v", err)) + }, + accept: func() { + }, + complete: func(result interface{}, err error) { + intResult, ok := result.(int) + if !ok { + require.Fail(t, "result should be int") + } else { + require.Equal(t, 0, intResult) + } + }, + }, 0) + }, 0) + + env.RegisterDelayedCallback(func() { + env.UpdateWorkflow("update", "id", &updateCallback{ + reject: func(err error) { + require.Fail(t, fmt.Sprintf("update should not be rejected, err: %v", err)) + }, + accept: func() { + }, + complete: func(result interface{}, err error) { + intResult, ok := result.(int) + if !ok { + require.Fail(t, "result should be int") + } else { + // if dedup, this be okay, even if we pass in 1 as arg, since it's deduping, + // the result should match the first update's result, 0 + require.Equal(t, 0, intResult) + } + }, + }, 1) + + }, 1*time.Millisecond) + + env.ExecuteWorkflow(func(ctx Context) error { + err := SetUpdateHandler(ctx, "update", func(ctx Context, i int) (int, error) { + return i, nil + }, UpdateHandlerOptions{}) + if err != nil { + return err + } + return Sleep(ctx, time.Hour) + }) + require.NoError(t, env.GetWorkflowError()) +} + func TestAllHandlersFinished(t *testing.T) { var suite WorkflowTestSuite env := suite.NewTestWorkflowEnvironment()