From 933c2e813e2b2b59444e6055137fd57d5aec8f0e Mon Sep 17 00:00:00 2001 From: Ilya Date: Wed, 24 Aug 2022 23:47:56 +0100 Subject: [PATCH] Proof of concept of using TestStepInputOutput that hides using channels behind interface. Previously each TestStep operated on targets using channels inside TestStepChannels. Which introduced an ability to close output channel as well as it didn't allow to override behaviour when TestStep obtains a new Target. Signed-off-by: Ilya --- pkg/runner/base_test_suite_test.go | 2 +- pkg/runner/job_runner_test.go | 14 +- pkg/runner/step_runner.go | 76 +- pkg/runner/step_runner_test.go | 23 +- pkg/runner/test_runner_test.go | 29 +- pkg/runner/test_step_input.go | 46 + pkg/test/step.go | 10 +- plugins/teststeps/cmd/cmd.go | 4 +- plugins/teststeps/cpucmd/cpucmd.go | 5 +- plugins/teststeps/echo/echo.go | 34 +- plugins/teststeps/example/example.go | 4 +- plugins/teststeps/exec/exec.go | 4 +- plugins/teststeps/gathercmd/gathercmd.go | 45 +- plugins/teststeps/randecho/randecho.go | 4 +- .../teststeps/s3fileupload/s3fileupload.go | 4 +- plugins/teststeps/sleep/sleep.go | 4 +- plugins/teststeps/sshcmd/sshcmd.go | 4 +- .../terminalexpect/terminalexpect.go | 4 +- plugins/teststeps/teststeps.go | 83 +- plugins/teststeps/teststeps_test.go | 848 ++++++++---------- plugins/teststeps/waitport/waitport.go | 4 +- plugins/teststeps/waitport/waitport_test.go | 25 +- .../mocks/test_step_input_output_mock.go | 55 ++ .../teststeps/badtargets/badtargets.go | 87 +- tests/plugins/teststeps/channels/channels.go | 56 -- tests/plugins/teststeps/hanging/hanging.go | 2 +- tests/plugins/teststeps/noreturn/noreturn.go | 15 +- .../plugins/teststeps/panicstep/panicstep.go | 2 +- tests/plugins/teststeps/teststep/teststep.go | 4 +- 29 files changed, 713 insertions(+), 784 deletions(-) create mode 100644 pkg/runner/test_step_input.go create mode 100644 tests/common/mocks/test_step_input_output_mock.go delete mode 100644 tests/plugins/teststeps/channels/channels.go diff --git a/pkg/runner/base_test_suite_test.go b/pkg/runner/base_test_suite_test.go index 6f3d4edf..5f884595 100644 --- a/pkg/runner/base_test_suite_test.go +++ b/pkg/runner/base_test_suite_test.go @@ -73,7 +73,7 @@ func (s *BaseTestSuite) TearDownTest() { } func (s *BaseTestSuite) RegisterStateFullStep( - runFunction func(ctx xcontext.Context, ch test.TestStepChannels, ev testevent.Emitter, + runFunction func(ctx xcontext.Context, io test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, resumeState json.RawMessage) (json.RawMessage, error), validateFunction func(ctx xcontext.Context, params test.TestStepParameters) error) error { diff --git a/pkg/runner/job_runner_test.go b/pkg/runner/job_runner_test.go index fa34658e..e04df640 100644 --- a/pkg/runner/job_runner_test.go +++ b/pkg/runner/job_runner_test.go @@ -58,9 +58,9 @@ func (s *JobRunnerSuite) TestSimpleJobStartFinish() { var resultTargets []*target.Target require.NoError(s.T(), s.RegisterStateFullStep( - func(ctx xcontext.Context, ch test.TestStepChannels, ev testevent.Emitter, + func(ctx xcontext.Context, io test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, resumeState json.RawMessage) (json.RawMessage, error) { - return teststeps.ForEachTarget(stateFullStepName, ctx, ch, func(ctx xcontext.Context, target *target.Target) error { + return teststeps.ForEachTarget(stateFullStepName, ctx, io, func(ctx xcontext.Context, target *target.Target) error { assert.NotNil(s.T(), target) mu.Lock() defer mu.Unlock() @@ -125,9 +125,9 @@ func (s *JobRunnerSuite) TestJobWithTestRetry() { var callsCount int require.NoError(s.T(), s.RegisterStateFullStep( - func(ctx xcontext.Context, ch test.TestStepChannels, ev testevent.Emitter, + func(ctx xcontext.Context, io test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, resumeState json.RawMessage) (json.RawMessage, error) { - return teststeps.ForEachTarget(stateFullStepName, ctx, ch, func(ctx xcontext.Context, target *target.Target) error { + return teststeps.ForEachTarget(stateFullStepName, ctx, io, func(ctx xcontext.Context, target *target.Target) error { assert.NotNil(s.T(), target) mu.Lock() defer mu.Unlock() @@ -456,7 +456,7 @@ func (s *JobRunnerSuite) TestResumeStateBadJobId() { const stateFullStepName = "statefull" type stateFullStep struct { - runFunction func(ctx xcontext.Context, ch test.TestStepChannels, ev testevent.Emitter, + runFunction func(ctx xcontext.Context, io test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, resumeState json.RawMessage) (json.RawMessage, error) validateFunction func(ctx xcontext.Context, params test.TestStepParameters) error } @@ -467,7 +467,7 @@ func (sfs *stateFullStep) Name() string { func (sfs *stateFullStep) Run( ctx xcontext.Context, - ch test.TestStepChannels, + io test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, @@ -476,7 +476,7 @@ func (sfs *stateFullStep) Run( if sfs.runFunction == nil { return nil, fmt.Errorf("stateFullStep run is not initialised") } - return sfs.runFunction(ctx, ch, ev, stepsVars, params, resumeState) + return sfs.runFunction(ctx, io, ev, stepsVars, params, resumeState) } func (sfs *stateFullStep) ValidateParameters(ctx xcontext.Context, params test.TestStepParameters) error { diff --git a/pkg/runner/step_runner.go b/pkg/runner/step_runner.go index 16de886d..42171f91 100644 --- a/pkg/runner/step_runner.go +++ b/pkg/runner/step_runner.go @@ -30,7 +30,7 @@ type StepResult struct { type StepRunner struct { mu sync.Mutex - input chan *target.Target + targetsCh chan targetInput inputWg sync.WaitGroup activeTargets map[string]*stepTargetInfo @@ -64,22 +64,13 @@ func (str *resultNotifier) postResult(err error) { } type stepTargetInfo struct { - targetInEmitted bool - result *resultNotifier -} - -func (sti *stepTargetInfo) acquireTargetInEmission() bool { - if sti.targetInEmitted { - return false - } - sti.targetInEmitted = true - return true + result *resultNotifier } // NewStepRunner creates a new StepRunner object func NewStepRunner() *StepRunner { return &StepRunner{ - input: make(chan *target.Target), + targetsCh: make(chan targetInput), activeTargets: make(map[string]*stepTargetInfo), notifyStopped: newResultNotifier(), stopped: make(chan struct{}), @@ -105,8 +96,7 @@ func (sr *StepRunner) Run( var resumedTargetsResults []ChanNotifier for _, resumeTarget := range resumeStateTargets { targetInfo := &stepTargetInfo{ - targetInEmitted: true, - result: newResultNotifier(), + result: newResultNotifier(), } sr.activeTargets[resumeTarget.ID] = targetInfo resumedTargetsResults = append(resumedTargetsResults, targetInfo.result) @@ -131,9 +121,23 @@ func (sr *StepRunner) Run( } stepOut := make(chan test.TestStepResult) + stepIO := newTestStepInputOutput(sr.targetsCh, func(_ctx xcontext.Context, tgt target.Target, err error) error { + var resultErr error + select { + case stepOut <- test.TestStepResult{Target: &tgt, Err: err}: + return nil + case <-_ctx.Done(): + resultErr = _ctx.Err() + case <-ctx.Done(): + resultErr = ctx.Err() + } + ctx.Debugf("canceled while reporting target '%s' result: %v", tgt.ID, err) + return resultErr + }) + go func() { defer finish() - sr.runningLoop(ctx, sr.input, stepOut, bundle, stepsVariables, ev, resumeState) + sr.runningLoop(ctx, stepIO, stepOut, bundle, stepsVariables, ev, resumeState) ctx.Debugf("Running loop finished") }() @@ -169,6 +173,12 @@ func (sr *StepRunner) addTarget( return nil, fmt.Errorf("step runner was stopped") } + onTargetConsumed := func() { + if err := emitEvent(ctx, ev, target.EventTargetIn, tgt, nil); err != nil { + sr.setErrLocked(ctx, fmt.Errorf("failed to report target injection: %w", err)) + } + } + targetInfo, err := func() (*stepTargetInfo, error) { targetInfo, err := func() (*stepTargetInfo, error) { sr.mu.Lock() @@ -190,17 +200,7 @@ func (sr *StepRunner) addTarget( defer sr.inputWg.Done() select { - case sr.input <- tgt: - // we should always emit TargetIn before TargetOut or TargetError - // we have a race condition that outputLoop may receive result for this target first - // in that case we will emit TargetIn in outputLoop and should not emit it here - sr.mu.Lock() - if targetInfo.acquireTargetInEmission() { - if err := emitEvent(ctx, ev, target.EventTargetIn, tgt, nil); err != nil { - sr.setErrLocked(ctx, fmt.Errorf("failed to report target injection: %w", err)) - } - } - sr.mu.Unlock() + case sr.targetsCh <- targetInput{tgt: *tgt, onConsumed: onTargetConsumed}: return targetInfo, nil case <-stopped: return nil, fmt.Errorf("step runner was stopped") @@ -273,7 +273,7 @@ func (sr *StepRunner) Stop() { } sr.inputWg.Wait() - close(sr.input) + close(sr.targetsCh) } func (sr *StepRunner) outputLoop( @@ -314,37 +314,28 @@ func (sr *StepRunner) outputLoop( } ctx.Infof("Obtained '%v' for target '%s'", res, res.Target.ID) - shouldEmitTargetIn, targetResult, err := func() (bool, *resultNotifier, error) { + targetResult, err := func() (*resultNotifier, error) { sr.mu.Lock() defer sr.mu.Unlock() info, found := sr.activeTargets[res.Target.ID] if !found { - return false, nil, &cerrors.ErrTestStepReturnedUnexpectedResult{ + return nil, &cerrors.ErrTestStepReturnedUnexpectedResult{ StepName: testStepLabel, Target: res.Target.ID, } } if info == nil { - return false, nil, &cerrors.ErrTestStepReturnedDuplicateResult{StepName: testStepLabel, Target: res.Target.ID} + return nil, &cerrors.ErrTestStepReturnedDuplicateResult{StepName: testStepLabel, Target: res.Target.ID} } sr.activeTargets[res.Target.ID] = nil - - shouldEmitTargetIn := info.acquireTargetInEmission() - return shouldEmitTargetIn, info.result, nil + return info.result, nil }() if err != nil { sr.setErr(ctx, err) return } - if shouldEmitTargetIn { - if err := emitEvent(ctx, ev, target.EventTargetIn, res.Target, nil); err != nil { - sr.setErr(ctx, fmt.Errorf("failed to report target injection: %w", err)) - return - } - } - if res.Err == nil { err = emitEvent(ctx, ev, target.EventTargetOut, res.Target, nil) } else { @@ -365,7 +356,7 @@ func (sr *StepRunner) outputLoop( func (sr *StepRunner) runningLoop( ctx xcontext.Context, - stepIn <-chan *target.Target, + stepIO *testStepInputOutput, stepOut chan test.TestStepResult, bundle test.TestStepBundle, stepsVariables test.StepsVariables, @@ -397,8 +388,7 @@ func (sr *StepRunner) runningLoop( } }() - inChannels := test.TestStepChannels{In: stepIn, Out: stepOut} - return bundle.TestStep.Run(ctx, inChannels, ev, stepsVariables, bundle.Parameters, resumeState) + return bundle.TestStep.Run(ctx, stepIO, ev, stepsVariables, bundle.Parameters, resumeState) }() ctx.Debugf("TestStep finished '%v', rs: '%s'", err, string(resultResumeState)) diff --git a/pkg/runner/step_runner_test.go b/pkg/runner/step_runner_test.go index 788e0846..5feed700 100644 --- a/pkg/runner/step_runner_test.go +++ b/pkg/runner/step_runner_test.go @@ -57,10 +57,10 @@ func (s *StepRunnerSuite) TestRunningStep() { var obtainedResumeState json.RawMessage err := s.RegisterStateFullStep( - func(ctx xcontext.Context, ch test.TestStepChannels, ev testevent.Emitter, + func(ctx xcontext.Context, io test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, resumeState json.RawMessage) (json.RawMessage, error) { obtainedResumeState = resumeState - _, err := teststeps.ForEachTarget(stateFullStepName, ctx, ch, func(ctx xcontext.Context, target *target.Target) error { + _, err := teststeps.ForEachTarget(stateFullStepName, ctx, io, func(ctx xcontext.Context, target *target.Target) error { require.NotNil(s.T(), target) mu.Lock() @@ -129,9 +129,9 @@ func (s *StepRunnerSuite) TestAddSameTargetSequentiallyTimes() { const inputTargetID = "input_target_id" err := s.RegisterStateFullStep( - func(ctx xcontext.Context, ch test.TestStepChannels, ev testevent.Emitter, + func(ctx xcontext.Context, io test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, resumeState json.RawMessage) (json.RawMessage, error) { - _, err := teststeps.ForEachTarget(stateFullStepName, ctx, ch, func(ctx xcontext.Context, target *target.Target) error { + _, err := teststeps.ForEachTarget(stateFullStepName, ctx, io, func(ctx xcontext.Context, target *target.Target) error { require.NotNil(s.T(), target) require.Equal(s.T(), inputTargetID, target.ID) return nil @@ -184,11 +184,14 @@ func (s *StepRunnerSuite) TestAddTargetReturnsErrorIfFailsToInput() { } }() err := s.RegisterStateFullStep( - func(ctx xcontext.Context, ch test.TestStepChannels, ev testevent.Emitter, + func(ctx xcontext.Context, io test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, resumeState json.RawMessage) (json.RawMessage, error) { <-hangCh - for range ch.In { - require.Fail(s.T(), "unexpected input") + for { + tgt, err := io.Get(ctx) + require.NoError(s.T(), err) + require.Nil(s.T(), tgt, "unexpected input") + break } return nil, nil }, @@ -244,7 +247,7 @@ func (s *StepRunnerSuite) TestStepPanics() { defer cancel() err := s.RegisterStateFullStep( - func(ctx xcontext.Context, ch test.TestStepChannels, ev testevent.Emitter, + func(ctx xcontext.Context, ch test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, resumeState json.RawMessage) (json.RawMessage, error) { panic("panic") }, @@ -296,9 +299,9 @@ func (s *StepRunnerSuite) TestCornerCases() { defer cancel() err := s.RegisterStateFullStep( - func(ctx xcontext.Context, ch test.TestStepChannels, ev testevent.Emitter, + func(ctx xcontext.Context, in test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, resumeState json.RawMessage) (json.RawMessage, error) { - _, err := teststeps.ForEachTarget(stateFullStepName, ctx, ch, func(ctx xcontext.Context, target *target.Target) error { + _, err := teststeps.ForEachTarget(stateFullStepName, ctx, in, func(ctx xcontext.Context, target *target.Target) error { return fmt.Errorf("should not be called") }) return nil, err diff --git a/pkg/runner/test_runner_test.go b/pkg/runner/test_runner_test.go index 77721047..4d84bc09 100644 --- a/pkg/runner/test_runner_test.go +++ b/pkg/runner/test_runner_test.go @@ -31,7 +31,6 @@ import ( "github.com/linuxboot/contest/tests/common" "github.com/linuxboot/contest/tests/common/goroutine_leak_check" "github.com/linuxboot/contest/tests/plugins/teststeps/badtargets" - "github.com/linuxboot/contest/tests/plugins/teststeps/channels" "github.com/linuxboot/contest/tests/plugins/teststeps/hanging" "github.com/linuxboot/contest/tests/plugins/teststeps/noreturn" "github.com/linuxboot/contest/tests/plugins/teststeps/panicstep" @@ -86,7 +85,6 @@ func (s *TestRunnerSuite) SetupTest() { events []event.Name }{ {badtargets.Name, badtargets.New, badtargets.Events}, - {channels.Name, channels.New, channels.Events}, {hanging.Name, hanging.New, hanging.Events}, {noreturn.Name, noreturn.New, noreturn.Events}, {panicstep.Name, panicstep.New, panicstep.Events}, @@ -332,29 +330,6 @@ func (s *TestRunnerSuite) TestStepPanics() { require.Contains(s.T(), s.MemoryStorage.GetStepEvents(ctx, testName, "Step1"), "step Step1 paniced") } -// A misbehaving step that closes its output channel. -func (s *TestRunnerSuite) TestStepClosesChannels() { - ctx, cancel := logrusctx.NewContext(logger.LevelDebug) - defer cancel() - - tr := newTestRunner() - _, _, err := s.runWithTimeout(ctx, tr, nil, 1, 2*time.Second, - []*target.Target{tgt("T1")}, - []test.TestStepBundle{ - s.NewStep(ctx, "Step1", channels.Name, nil), - }, - ) - require.Error(s.T(), err) - require.IsType(s.T(), &cerrors.ErrTestStepClosedChannels{}, err) - require.Equal(s.T(), ` -{[1 1 SimpleTest 0 Step1][Target{ID: "T1"} TargetIn]} -{[1 1 SimpleTest 0 Step1][Target{ID: "T1"} TargetOut]} -`, s.MemoryStorage.GetTargetEvents(ctx, testName, "T1")) - require.Equal(s.T(), ` -{[1 1 SimpleTest 0 Step1][(*Target)(nil) TestError &"\"test step Step1 closed output channels (api violation)\""]} -`, s.MemoryStorage.GetStepEvents(ctx, testName, "Step1")) -} - // A misbehaving step that yields a result for a target that does not exist. func (s *TestRunnerSuite) TestStepYieldsResultForNonexistentTarget() { ctx, cancel := logrusctx.NewContext(logger.LevelDebug) @@ -480,13 +455,13 @@ func (s *TestRunnerSuite) TestVariables() { ) require.NoError(s.T(), s.RegisterStateFullStep( func(ctx xcontext.Context, - ch test.TestStepChannels, + io test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, resumeState json.RawMessage, ) (json.RawMessage, error) { - _, err := teststeps.ForEachTargetWithResume(ctx, ch, resumeState, 1, + _, err := teststeps.ForEachTargetWithResume(ctx, io, resumeState, 1, func(ctx xcontext.Context, target *teststeps.TargetWithData) error { require.NoError(s.T(), stepsVars.Add(target.Target.ID, "target_id", target.Target.ID)) diff --git a/pkg/runner/test_step_input.go b/pkg/runner/test_step_input.go new file mode 100644 index 00000000..b69f7925 --- /dev/null +++ b/pkg/runner/test_step_input.go @@ -0,0 +1,46 @@ +package runner + +import ( + "github.com/linuxboot/contest/pkg/target" + "github.com/linuxboot/contest/pkg/xcontext" +) + +type onTargetResult func(ctx xcontext.Context, tgt target.Target, err error) error + +// TestStepChannels represents the input and output channels used by a TestStep +// to communicate with the TestRunner +type testStepInputOutput struct { + targetsCh chan targetInput + onTargetResult onTargetResult +} + +func newTestStepInputOutput(targetsCh chan targetInput, onTargetResult onTargetResult) *testStepInputOutput { + return &testStepInputOutput{ + targetsCh: targetsCh, + onTargetResult: onTargetResult, + } +} + +type targetInput struct { + tgt target.Target + onConsumed func() +} + +func (tsi *testStepInputOutput) Get(ctx xcontext.Context) (*target.Target, error) { + select { + case in, ok := <-tsi.targetsCh: + if !ok { + return nil, nil + } + if in.onConsumed != nil { + in.onConsumed() + } + return &in.tgt, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (tsi *testStepInputOutput) Report(ctx xcontext.Context, tgt target.Target, err error) error { + return tsi.onTargetResult(ctx, tgt, err) +} diff --git a/pkg/test/step.go b/pkg/test/step.go index e20c1a3c..2f9702a5 100644 --- a/pkg/test/step.go +++ b/pkg/test/step.go @@ -102,11 +102,9 @@ type TestStepResult struct { Err error } -// TestStepChannels represents the input and output channels used by a TestStep -// to communicate with the TestRunner -type TestStepChannels struct { - In <-chan *target.Target - Out chan<- TestStepResult +type TestStepInputOutput interface { + Get(ctx xcontext.Context) (*target.Target, error) + Report(ctx xcontext.Context, tgt target.Target, err error) error } // StepsVariablesReader represents a read access for step variables @@ -136,7 +134,7 @@ type TestStep interface { // Name returns the name of the step Name() string // Run runs the test step. The test step is expected to be synchronous. - Run(ctx xcontext.Context, ch TestStepChannels, ev testevent.Emitter, + Run(ctx xcontext.Context, inputOutput TestStepInputOutput, ev testevent.Emitter, stepsVars StepsVariables, params TestStepParameters, resumeState json.RawMessage) (json.RawMessage, error) // ValidateParameters checks that the parameters are correct before passing diff --git a/plugins/teststeps/cmd/cmd.go b/plugins/teststeps/cmd/cmd.go index 43dc4041..3eba933a 100644 --- a/plugins/teststeps/cmd/cmd.go +++ b/plugins/teststeps/cmd/cmd.go @@ -95,7 +95,7 @@ func emitEvent(ctx xcontext.Context, name event.Name, payload interface{}, tgt * // Run executes the cmd step. func (ts *Cmd) Run( ctx xcontext.Context, - ch test.TestStepChannels, + io test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, @@ -161,7 +161,7 @@ func (ts *Cmd) Run( cmd.Path, cmd.Args, stdout.Bytes(), stderr.Bytes(), runErr) return runErr } - return teststeps.ForEachTarget(Name, ctx, ch, f) + return teststeps.ForEachTarget(Name, ctx, io, f) } func (ts *Cmd) validateAndPopulate(params test.TestStepParameters) error { diff --git a/plugins/teststeps/cpucmd/cpucmd.go b/plugins/teststeps/cpucmd/cpucmd.go index 096ef95b..b377f889 100644 --- a/plugins/teststeps/cpucmd/cpucmd.go +++ b/plugins/teststeps/cpucmd/cpucmd.go @@ -24,7 +24,6 @@ import ( "errors" "fmt" "io" - "regexp" "strconv" "time" @@ -72,7 +71,7 @@ func (ts CPUCmd) Name() string { // Run executes the cmd step. func (ts *CPUCmd) Run( ctx xcontext.Context, - ch test.TestStepChannels, + stepIO test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, @@ -260,7 +259,7 @@ func (ts *CPUCmd) Run( } } } - return teststeps.ForEachTarget(Name, ctx, ch, f) + return teststeps.ForEachTarget(Name, ctx, stepIO, f) } func (ts *CPUCmd) validateAndPopulate(params test.TestStepParameters) error { diff --git a/plugins/teststeps/echo/echo.go b/plugins/teststeps/echo/echo.go index cd5975c4..c3390841 100644 --- a/plugins/teststeps/echo/echo.go +++ b/plugins/teststeps/echo/echo.go @@ -53,29 +53,29 @@ func (e Step) Name() string { // Run executes the step func (e Step) Run( ctx xcontext.Context, - ch test.TestStepChannels, + io test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, resumeState json.RawMessage, ) (json.RawMessage, error) { for { - select { - case target, ok := <-ch.In: - if !ok { - return nil, nil - } - output, err := params.GetOne("text").Expand(target, stepsVars) - if err != nil { - return nil, err - } - // guaranteed to work here - jobID, _ := types.JobIDFromContext(ctx) - runID, _ := types.RunIDFromContext(ctx) - ctx.Infof("This is job %d, run %d on target %s with text '%s'", jobID, runID, target.ID, output) - ch.Out <- test.TestStepResult{Target: target} - case <-ctx.Done(): - return nil, nil + tgt, _ := io.Get(ctx) + if tgt == nil { + break + } + + output, err := params.GetOne("text").Expand(tgt, stepsVars) + if err != nil { + return nil, err + } + // guaranteed to work here + jobID, _ := types.JobIDFromContext(ctx) + runID, _ := types.RunIDFromContext(ctx) + ctx.Infof("This is job %d, run %d on target %s with text '%s'", jobID, runID, tgt.ID, output) + if err := io.Report(ctx, *tgt, nil); err != nil { + return nil, err } } + return nil, nil } diff --git a/plugins/teststeps/example/example.go b/plugins/teststeps/example/example.go index 8d262473..59138756 100644 --- a/plugins/teststeps/example/example.go +++ b/plugins/teststeps/example/example.go @@ -63,7 +63,7 @@ func (ts *Step) shouldFail(t *target.Target) bool { // Run executes the example step. func (ts *Step) Run( ctx xcontext.Context, - ch test.TestStepChannels, + stepIO test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, @@ -89,7 +89,7 @@ func (ts *Step) Run( } return nil } - return teststeps.ForEachTarget(Name, ctx, ch, f) + return teststeps.ForEachTarget(Name, ctx, stepIO, f) } // ValidateParameters validates the parameters associated to the TestStep diff --git a/plugins/teststeps/exec/exec.go b/plugins/teststeps/exec/exec.go index 73f7f16d..94a8425a 100644 --- a/plugins/teststeps/exec/exec.go +++ b/plugins/teststeps/exec/exec.go @@ -54,7 +54,7 @@ func (ts TestStep) Name() string { // Run executes the step. func (ts *TestStep) Run( ctx xcontext.Context, - ch test.TestStepChannels, + stepIO test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, @@ -65,7 +65,7 @@ func (ts *TestStep) Run( } tr := NewTargetRunner(ts, ev, stepsVars) - return teststeps.ForEachTarget(Name, ctx, ch, tr.Run) + return teststeps.ForEachTarget(Name, ctx, stepIO, tr.Run) } func (ts *TestStep) populateParams(stepParams test.TestStepParameters) error { diff --git a/plugins/teststeps/gathercmd/gathercmd.go b/plugins/teststeps/gathercmd/gathercmd.go index 98f27062..ecbf5f19 100644 --- a/plugins/teststeps/gathercmd/gathercmd.go +++ b/plugins/teststeps/gathercmd/gathercmd.go @@ -96,39 +96,42 @@ func truncate(in string, maxsize uint) string { return in[:size] } -func (ts *GatherCmd) acquireTargets(ctx xcontext.Context, ch test.TestStepChannels) ([]*target.Target, error) { - var targets []*target.Target +func (ts *GatherCmd) acquireTargets(ctx xcontext.Context, stepIO test.TestStepInputOutput) ([]target.Target, error) { + ctx, cancel := xcontext.WithCancel(ctx, xcontext.ErrPaused) + defer cancel() - for { + go func() { select { - case target, ok := <-ch.In: - if !ok { - ctx.Debugf("acquired %d targets", len(targets)) - return targets, nil - } - targets = append(targets, target) - case <-ctx.Until(xcontext.ErrPaused): - ctx.Debugf("paused during target acquisition, acquired %d", len(targets)) - return nil, xcontext.ErrPaused - + cancel() case <-ctx.Done(): - ctx.Debugf("canceled during target acquisition, acquired %d", len(targets)) - return nil, ctx.Err() } + }() + + var targets []target.Target + for { + tgt, err := stepIO.Get(ctx) + if err != nil { + return nil, err + } + if tgt == nil { + ctx.Debugf("acquired %d targets", len(targets)) + return targets, nil + } + targets = append(targets, *tgt) } } -func (ts *GatherCmd) returnTargets(ctx xcontext.Context, ch test.TestStepChannels, targets []*target.Target) { +func (ts *GatherCmd) returnTargets(ctx xcontext.Context, stepIO test.TestStepInputOutput, targets []target.Target) { for _, target := range targets { - ch.Out <- test.TestStepResult{Target: target} + stepIO.Report(ctx, target, nil) } } // Run executes the step func (ts *GatherCmd) Run( ctx xcontext.Context, - ch test.TestStepChannels, + stepIO test.TestStepInputOutput, emitter testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, @@ -141,11 +144,11 @@ func (ts *GatherCmd) Run( } // acquire all targets and hold them hostage until the cmd is done - targets, err := ts.acquireTargets(ctx, ch) + targets, err := ts.acquireTargets(ctx, stepIO) if err != nil { return nil, err } - defer ts.returnTargets(ctx, ch, targets) + defer ts.returnTargets(ctx, stepIO, targets) if len(targets) == 0 { return nil, nil @@ -154,7 +157,7 @@ func (ts *GatherCmd) Run( // arbitrarily choose first target to associate events with, anyone would work // but it is unnecessary to have the same event on all targets since this is a // "gather" type plugin - eventTarget := targets[0] + eventTarget := &targets[0] // used to manually cancel the exec if step becomes paused ctx, cancel := xcontext.WithCancel(ctx) diff --git a/plugins/teststeps/randecho/randecho.go b/plugins/teststeps/randecho/randecho.go index a3b39820..911c76b6 100644 --- a/plugins/teststeps/randecho/randecho.go +++ b/plugins/teststeps/randecho/randecho.go @@ -56,13 +56,13 @@ func (e Step) Name() string { // Run executes the step func (e Step) Run( ctx xcontext.Context, - ch test.TestStepChannels, + stepIO test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, resumeState json.RawMessage, ) (json.RawMessage, error) { - return teststeps.ForEachTarget(Name, ctx, ch, + return teststeps.ForEachTarget(Name, ctx, stepIO, func(ctx xcontext.Context, target *target.Target) error { r := rand.Intn(2) if r == 0 { diff --git a/plugins/teststeps/s3fileupload/s3fileupload.go b/plugins/teststeps/s3fileupload/s3fileupload.go index 12d8e985..d6359f6c 100644 --- a/plugins/teststeps/s3fileupload/s3fileupload.go +++ b/plugins/teststeps/s3fileupload/s3fileupload.go @@ -85,7 +85,7 @@ func emitEvent(ctx xcontext.Context, name event.Name, payload interface{}, tgt * // Run executes the awsFileUpload. func (ts *FileUpload) Run( ctx xcontext.Context, - ch test.TestStepChannels, + stepIO test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, @@ -134,7 +134,7 @@ func (ts *FileUpload) Run( } return nil } - return teststeps.ForEachTarget(Name, ctx, ch, f) + return teststeps.ForEachTarget(Name, ctx, stepIO, f) } // Retrieve all the parameters defines through the jobDesc diff --git a/plugins/teststeps/sleep/sleep.go b/plugins/teststeps/sleep/sleep.go index 9040d061..6d759639 100644 --- a/plugins/teststeps/sleep/sleep.go +++ b/plugins/teststeps/sleep/sleep.go @@ -70,7 +70,7 @@ type sleepStepData struct { // Run executes the step func (ss *sleepStep) Run( ctx xcontext.Context, - ch test.TestStepChannels, + stepIO test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, @@ -119,5 +119,5 @@ func (ss *sleepStep) Run( return nil } - return teststeps.ForEachTargetWithResume(ctx, ch, resumeState, 1, fn) + return teststeps.ForEachTargetWithResume(ctx, stepIO, resumeState, 1, fn) } diff --git a/plugins/teststeps/sshcmd/sshcmd.go b/plugins/teststeps/sshcmd/sshcmd.go index 545ae45f..6e82c3fd 100644 --- a/plugins/teststeps/sshcmd/sshcmd.go +++ b/plugins/teststeps/sshcmd/sshcmd.go @@ -71,7 +71,7 @@ func (ts SSHCmd) Name() string { // Run executes the cmd step. func (ts *SSHCmd) Run( ctx xcontext.Context, - ch test.TestStepChannels, + stepIO test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, @@ -278,7 +278,7 @@ func (ts *SSHCmd) Run( } } } - return teststeps.ForEachTarget(Name, ctx, ch, f) + return teststeps.ForEachTarget(Name, ctx, stepIO, f) } func (ts *SSHCmd) validateAndPopulate(params test.TestStepParameters) error { diff --git a/plugins/teststeps/terminalexpect/terminalexpect.go b/plugins/teststeps/terminalexpect/terminalexpect.go index 3f44a857..2532a065 100644 --- a/plugins/teststeps/terminalexpect/terminalexpect.go +++ b/plugins/teststeps/terminalexpect/terminalexpect.go @@ -56,7 +56,7 @@ func match(match string, log xcontext.Logger) termhook.LineHandler { // Run executes the terminal step. func (ts *TerminalExpect) Run( ctx xcontext.Context, - ch test.TestStepChannels, + stepIO test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, @@ -90,7 +90,7 @@ func (ts *TerminalExpect) Run( } } log.Debugf("%s: waiting for string '%s' with timeout %s", Name, ts.Match, ts.Timeout) - return teststeps.ForEachTarget(Name, ctx, ch, f) + return teststeps.ForEachTarget(Name, ctx, stepIO, f) } func (ts *TerminalExpect) validateAndPopulate(params test.TestStepParameters) error { diff --git a/plugins/teststeps/teststeps.go b/plugins/teststeps/teststeps.go index f5bef9a5..0ebc70a2 100644 --- a/plugins/teststeps/teststeps.go +++ b/plugins/teststeps/teststeps.go @@ -29,43 +29,27 @@ type PerTargetFunc func(ctx xcontext.Context, target *target.Target) error // This function wraps the logic that handles target routing through the in/out // The implementation of the per-target function is responsible for // reacting to cancel/pause signals and return quickly. -func ForEachTarget(pluginName string, ctx xcontext.Context, ch test.TestStepChannels, f PerTargetFunc) (json.RawMessage, error) { - reportTarget := func(t *target.Target, err error) { +func ForEachTarget(pluginName string, ctx xcontext.Context, inputOutput test.TestStepInputOutput, f PerTargetFunc) (json.RawMessage, error) { + var wg sync.WaitGroup + for { + tgt, err := inputOutput.Get(ctx) if err != nil { - ctx.Errorf("%s: ForEachTarget: failed to apply test step function on target %s: %v", pluginName, t, err) - } else { - ctx.Debugf("%s: ForEachTarget: target %s completed successfully", pluginName, t) + ctx.Debugf("%s: ForEachTarget: incoming targets error: '%v'", err) + break } - select { - case ch.Out <- test.TestStepResult{Target: t, Err: err}: - case <-ctx.Done(): - ctx.Debugf("%s: ForEachTarget: received cancellation signal while reporting result", pluginName) + if tgt == nil { + ctx.Debugf("%s: ForEachTarget: all targets have been received", pluginName) + break } - } - var wg sync.WaitGroup - func() { - for { - select { - case tgt, ok := <-ch.In: - if !ok { - ctx.Debugf("%s: ForEachTarget: all targets have been received", pluginName) - return - } - ctx.Debugf("%s: ForEachTarget: received target %s", pluginName, tgt) - wg.Add(1) - go func() { - defer wg.Done() - - err := f(ctx, tgt) - reportTarget(tgt, err) - }() - case <-ctx.Done(): - ctx.Debugf("%s: ForEachTarget: incoming loop canceled", pluginName) - return - } - } - }() + wg.Add(1) + go func(tgt target.Target) { + defer wg.Done() + + tgtErr := f(ctx, &tgt) + inputOutput.Report(ctx, tgt, tgtErr) + }(*tgt) + } wg.Wait() return nil, nil } @@ -121,7 +105,7 @@ type parallelTargetsState struct { // with the same data on job resumption. The helper will not call functions again that succeeded or failed // before the pause signal was received. // The supplied PerTargetWithResumeFunc must react to pause and cancellation signals as normal. -func ForEachTargetWithResume(ctx xcontext.Context, ch test.TestStepChannels, resumeState json.RawMessage, currentStepStateVersion int, f PerTargetWithResumeFunc) (json.RawMessage, error) { +func ForEachTargetWithResume(ctx xcontext.Context, inputOutput test.TestStepInputOutput, resumeState json.RawMessage, currentStepStateVersion int, f PerTargetWithResumeFunc) (json.RawMessage, error) { var ss parallelTargetsState // Parse resume state, if any. @@ -157,11 +141,7 @@ func ForEachTargetWithResume(ctx xcontext.Context, ch test.TestStepChannels, res } else { ctx.Debugf("ForEachTargetWithResume: target %s completed successfully", tgt2.Target.ID) } - select { - case ch.Out <- test.TestStepResult{Target: tgt2.Target, Err: err}: - case <-ctx.Done(): - ctx.Debugf("ForEachTargetWithResume: received cancellation signal while reporting result") - } + inputOutput.Report(ctx, *tgt2.Target, err) } } @@ -175,22 +155,19 @@ func ForEachTargetWithResume(ctx xcontext.Context, ch test.TestStepChannels, res ss.Targets = nil var err error -mainloop: for { - select { - // no need to check for pause here, pausing closes the channel - case tgt, ok := <-ch.In: - if !ok { - break mainloop - } - ctx.Debugf("ForEachTargetWithResume: received target %s", tgt) - wg.Add(1) - go handleTarget(&TargetWithData{Target: tgt}) - case <-ctx.Done(): - ctx.Debugf("ForEachTargetWithResume: canceled, terminating") - err = xcontext.ErrCanceled - break mainloop + var tgt *target.Target + tgt, err = inputOutput.Get(ctx) + if err != nil { + ctx.Debugf("%s: ForEachTargetWithResume: incoming targets error: '%v'", err) + break + } + if tgt == nil { + break } + + wg.Add(1) + go handleTarget(&TargetWithData{Target: tgt}) } // close pauseStates to signal all handlers are done diff --git a/plugins/teststeps/teststeps_test.go b/plugins/teststeps/teststeps_test.go index e00ba3ce..ee3da116 100644 --- a/plugins/teststeps/teststeps_test.go +++ b/plugins/teststeps/teststeps_test.go @@ -6,502 +6,448 @@ package teststeps import ( - "context" - "encoding/json" "fmt" - "sync" - "sync/atomic" - "testing" - "time" - "github.com/linuxboot/contest/pkg/target" - "github.com/linuxboot/contest/pkg/test" "github.com/linuxboot/contest/pkg/xcontext" "github.com/linuxboot/contest/pkg/xcontext/bundles/logrusctx" "github.com/linuxboot/contest/pkg/xcontext/logger" + "github.com/linuxboot/contest/tests/common/mocks" + "testing" - "github.com/linuxboot/contest/tests/common/goroutine_leak_check" - - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type data struct { ctx xcontext.Context cancel, pause func() - inCh chan *target.Target - outCh chan test.TestStepResult - stepChans test.TestStepChannels } func newData() data { ctx, pause := xcontext.WithNotify(nil, xcontext.ErrPaused) ctx, cancel := xcontext.WithCancel(ctx) - inCh := make(chan *target.Target) - outCh := make(chan test.TestStepResult) return data{ ctx: ctx, cancel: cancel, pause: pause, - inCh: inCh, - outCh: outCh, - stepChans: test.TestStepChannels{ - In: inCh, - Out: outCh, - }, } } func TestForEachTargetOneTarget(t *testing.T) { - ctx, _ := logrusctx.NewContext(logger.LevelDebug) - log := ctx.Logger() - d := newData() - fn := func(ctx xcontext.Context, tgt *target.Target) error { - log.Debugf("Handling target %s", tgt) - return nil - } - go func() { - d.inCh <- &target.Target{ID: "target001"} - close(d.inCh) - }() - ctx, cancel := xcontext.WithCancel(ctx) - defer cancel() - go func() { - for { - select { - case <-ctx.Done(): - return - case res := <-d.outCh: - if res.Err == nil { - log.Debugf("Step for target %s completed as expected", res.Target) - } else { - t.Errorf("Expected no error but got one: %v", res.Err) - } - } - } - }() - _, err := ForEachTarget("test_one_target ", d.ctx, d.stepChans, fn) - require.NoError(t, err) -} - -func TestForEachTargetOneTargetAllFail(t *testing.T) { - ctx, _ := logrusctx.NewContext(logger.LevelDebug) - log := ctx.Logger() - d := newData() - fn := func(ctx xcontext.Context, t *target.Target) error { - log.Debugf("Handling target %s", t) - return fmt.Errorf("error with target %s", t) - } - go func() { - d.inCh <- &target.Target{ID: "target001"} - close(d.inCh) - }() - ctx, cancel := xcontext.WithCancel(ctx) + ctx, cancel := logrusctx.NewContext(logger.LevelDebug) defer cancel() - go func() { - for { - select { - case <-ctx.Done(): - return - case res := <-d.outCh: - if res.Err == nil { - t.Errorf("Step for target %s expected to fail but completed successfully instead", res.Target) - } else { - log.Debugf("Step for target failed as expected: %v", res.Err) - } - } - } - }() - _, err := ForEachTarget("test_one_target ", d.ctx, d.stepChans, fn) - require.NoError(t, err) -} -func TestForEachTargetTenTargets(t *testing.T) { - d := newData() - fn := func(ctx xcontext.Context, tgt *target.Target) error { + stepIO := mocks.NewTestStepInputOutputMock([]target.Target{ + {ID: "target001"}, + }) + _, err := ForEachTarget("test_one_target ", ctx, stepIO, func(ctx xcontext.Context, tgt *target.Target) error { ctx.Debugf("Handling target %s", tgt) return nil - } - go func() { - for i := 0; i < 10; i++ { - d.inCh <- &target.Target{ID: fmt.Sprintf("target%00d", i)} - } - close(d.inCh) - }() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - go func() { - for { - select { - case <-ctx.Done(): - return - case res := <-d.outCh: - if res.Err == nil { - d.ctx.Debugf("Step for target %s completed as expected", res.Target) - } else { - t.Errorf("Expected no error but got one: %v", res.Err) - } - } - } - }() - _, err := ForEachTarget("test_one_target ", d.ctx, d.stepChans, fn) + }) require.NoError(t, err) -} -func TestForEachTargetTenTargetsAllFail(t *testing.T) { - d := newData() - fn := func(ctx xcontext.Context, tgt *target.Target) error { - d.ctx.Debugf("Handling target %s", tgt) - return fmt.Errorf("error with target %s", tgt) - } - go func() { - for i := 0; i < 10; i++ { - d.inCh <- &target.Target{ID: fmt.Sprintf("target%00d", i)} - } - close(d.inCh) - }() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - go func() { - for { - select { - case <-ctx.Done(): - return - case res := <-d.outCh: - if res.Err == nil { - t.Errorf("Step for target %s expected to fail but completed successfully instead", res.Target) - } else { - d.ctx.Debugf("Step for target failed as expected: %v", res.Err) - } - } - } - }() - _, err := ForEachTarget("test_one_target ", d.ctx, d.stepChans, fn) - require.NoError(t, err) + require.Equal(t, map[string]error{ + "target001": nil, + }, stepIO.GetReportedTargets()) } -func TestForEachTargetTenTargetsOneFails(t *testing.T) { - d := newData() - // chosen by fair dice roll. - // guaranteed to be random. - failingTarget := "target004" - fn := func(ctx xcontext.Context, tgt *target.Target) error { - d.ctx.Debugf("Handling target %s", tgt) - if tgt.ID == failingTarget { - return fmt.Errorf("error with target %s", tgt) - } - return nil - } - go func() { - for i := 0; i < 10; i++ { - d.inCh <- &target.Target{ID: fmt.Sprintf("target%03d", i)} - } - close(d.inCh) - }() - ctx, cancel := context.WithCancel(context.Background()) +func TestForEachTargetOneTargetAllFail(t *testing.T) { + ctx, cancel := logrusctx.NewContext(logger.LevelDebug) defer cancel() - go func() { - for { - select { - case <-ctx.Done(): - return - case res := <-d.outCh: - if res.Err == nil { - if res.Target.ID == failingTarget { - t.Errorf("Step for target %s expected to fail but completed successfully instead", res.Target) - } else { - d.ctx.Debugf("Step for target %s completed as expected", res.Target) - } - } else { - if res.Target.ID == failingTarget { - d.ctx.Debugf("Step for target %s failed as expected: %v", res.Target, res.Err) - } else { - t.Errorf("Expected no error for %s but got one: %v", res.Target, res.Err) - } - } - } - } - }() - _, err := ForEachTarget("test_one_target ", d.ctx, d.stepChans, fn) - require.NoError(t, err) -} - -// TestForEachTargetTenTargetsParallelism checks if we didn't break the parallelism of -// ForEachTarget. It passes 10 targets and a function that takes 1 second for each -// target, so the whole process should not take more than ~1s if properly parallelized. -// I am using a deadline of 3s to give it some margin, knowing that if it is sequential -// it will take ~10s. -func TestForEachTargetTenTargetsParallelism(t *testing.T) { - sleepTime := 300 * time.Millisecond - d := newData() - fn := func(ctx xcontext.Context, tgt *target.Target) error { - d.ctx.Debugf("Handling target %s", tgt) - select { - case <-ctx.Done(): - d.ctx.Debugf("target %s cancelled", tgt) - case <-ctx.Until(xcontext.ErrPaused): - d.ctx.Debugf("target %s paused", tgt) - case <-time.After(sleepTime): - d.ctx.Debugf("target %s processed", tgt) - } - return nil - } - - numTargets := 10 - go func() { - for i := 0; i < numTargets; i++ { - d.inCh <- &target.Target{ID: fmt.Sprintf("target%03d", i)} - } - close(d.inCh) - }() - - deadlineExceeded := false - var targetError error - targetsRemain := numTargets - var wg sync.WaitGroup - - wg.Add(1) - go func() { - // try to cancel ForEachTarget in case it's still running - defer d.cancel() - defer wg.Done() - - maxWaitTime := sleepTime * 3 - deadline := time.Now().Add(maxWaitTime) - d.ctx.Debugf("Setting deadline to now+%s", maxWaitTime) - - for { - select { - case res := <-d.outCh: - targetsRemain-- - if res.Err == nil { - d.ctx.Debugf("Step for target %s completed successfully as expected", res.Target) - } else { - d.ctx.Debugf("Step for target %s expected to completed successfully but failed instead", res.Target, res.Err) - targetError = res.Err - } - if targetsRemain == 0 { - d.ctx.Debugf("All targets processed") - return - } - case <-time.After(time.Until(deadline)): - deadlineExceeded = true - d.ctx.Debugf("Deadline exceeded") - return - } - } - }() - - _, err := ForEachTarget("test_parallel", d.ctx, d.stepChans, fn) - - wg.Wait() //wait for receiver - - if deadlineExceeded { - t.Fatal("wait deadline exceeded, it's possible that parallelization is not working anymore") - } - require.NoError(t, targetError) - require.NoError(t, err) - assert.Equal(t, 0, targetsRemain) -} - -func TestForEachTargetCancelSignalPropagation(t *testing.T) { - sleepTime := 300 * time.Millisecond - numTargets := 10 - var canceledTargets int32 - d := newData() - - fn := func(ctx xcontext.Context, tgt *target.Target) error { - d.ctx.Debugf("Handling target %s", tgt) - select { - case <-ctx.Done(): - d.ctx.Debugf("target %s caneled", tgt) - atomic.AddInt32(&canceledTargets, 1) - case <-ctx.Until(xcontext.ErrPaused): - d.ctx.Debugf("target %s paused", tgt) - case <-time.After(sleepTime): - d.ctx.Debugf("target %s processed", tgt) - } - return nil - } - - go func() { - for i := 0; i < numTargets; i++ { - d.inCh <- &target.Target{ID: fmt.Sprintf("target%03d", i)} - } - close(d.inCh) - }() - - go func() { - time.Sleep(sleepTime / 3) - d.cancel() - }() - _, err := ForEachTarget("test_cancelation", d.ctx, d.stepChans, fn) - require.NoError(t, err) - - assert.Equal(t, int32(numTargets), canceledTargets) -} - -func TestForEachTargetCancelBeforeInputChannelClosed(t *testing.T) { - sleepTime := 300 * time.Millisecond - numTargets := 10 - var canceledTargets int32 - d := newData() - - fn := func(ctx xcontext.Context, tgt *target.Target) error { - d.ctx.Debugf("Handling target %s", tgt) - select { - case <-ctx.Done(): - d.ctx.Debugf("target %s cancelled", tgt) - atomic.AddInt32(&canceledTargets, 1) - case <-ctx.Until(xcontext.ErrPaused): - d.ctx.Debugf("target %s paused", tgt) - case <-time.After(sleepTime): - d.ctx.Debugf("target %s processed", tgt) - } - return nil - } - - var wg sync.WaitGroup - wg.Add(1) - go func() { - for i := 0; i < numTargets; i++ { - d.inCh <- &target.Target{ID: fmt.Sprintf("target%03d", i)} - } - wg.Wait() //Don't close the input channel until ForEachTarget returned - }() - - go func() { - time.Sleep(sleepTime / 3) - d.cancel() - }() - - _, err := ForEachTarget("test_cancelation", d.ctx, d.stepChans, fn) + stepIO := mocks.NewTestStepInputOutputMock([]target.Target{ + {ID: "target001"}, + }) + _, err := ForEachTarget("test_one_target ", ctx, stepIO, func(ctx xcontext.Context, t *target.Target) error { + ctx.Debugf("Handling target %s", t) + return fmt.Errorf("error with target %s", t) + }) require.NoError(t, err) - wg.Done() - assert.Equal(t, int32(numTargets), canceledTargets) + require.Equal(t, map[string]error{ + "target001": fmt.Errorf("error with target Target{ID: \"target001\"}"), + }, stepIO.GetReportedTargets()) } -func TestForEachTargetWithResumeAllReturn(t *testing.T) { - numTargets := 10 - d := newData() +func TestForEachTargetTenTargets(t *testing.T) { + ctx, cancel := logrusctx.NewContext(logger.LevelDebug) + defer cancel() - fn := func(ctx xcontext.Context, target *TargetWithData) error { - return nil // success + var inputTargets []target.Target + for i := 0; i < 10; i++ { + inputTargets = append(inputTargets, target.Target{ID: fmt.Sprintf("target%00d", i)}) } + stepIO := mocks.NewTestStepInputOutputMock(inputTargets) - var wg sync.WaitGroup - wg.Add(1) - // submit all, then close - go func() { - for i := 0; i < numTargets; i++ { - d.inCh <- &target.Target{ID: fmt.Sprintf("target%03d", i)} - } - close(d.inCh) - wg.Done() - }() - - wg.Add(1) - // read all results - go func() { - for i := 0; i < numTargets; i++ { - <-d.outCh - } - wg.Done() - }() - - res, err := ForEachTargetWithResume(d.ctx, d.stepChans, nil, 1, fn) + _, err := ForEachTarget("test_one_target ", ctx, stepIO, func(ctx xcontext.Context, tgt *target.Target) error { + ctx.Debugf("Handling target %s", tgt) + return nil + }) require.NoError(t, err) - assert.Nil(t, res) - // make sure all helpers are done - wg.Wait() - assert.NoError(t, goroutine_leak_check.CheckLeakedGoRoutines()) -} -type simpleStepData struct { - Foo string -} - -func TestForEachTargetWithResumeAllPause(t *testing.T) { - numTargets := 10 - targets := make([]target.Target, 10) - for i := 0; i < numTargets; i++ { - targets[i] = target.Target{ID: fmt.Sprintf("target%03d", i)} - } - d := newData() + targetsResults := stepIO.GetReportedTargets() + require.Len(t, targetsResults, len(inputTargets)) - fn := func(ctx xcontext.Context, target *TargetWithData) error { - stepData := simpleStepData{target.Target.ID} - json, err := json.Marshal(&stepData) + for _, tgt := range inputTargets { + err, ok := targetsResults[tgt.ID] + require.True(t, ok) require.NoError(t, err) - // block and pause - <-ctx.Until(xcontext.ErrPaused) - target.Data = json - return xcontext.ErrPaused } - var testingWg sync.WaitGroup - - // constantly read out channel, must not receive anything - outDone := make(chan struct{}) - testingWg.Add(1) - go func() { - select { - case res := <-d.outCh: - assert.Fail(t, "unexpected target in out channel", res) - case <-outDone: - testingWg.Done() - } - }() - - var inputWg sync.WaitGroup - inputWg.Add(1) - // submit all, then close - go func() { - for i := 0; i < numTargets; i++ { - d.inCh <- &targets[i] - } - close(d.inCh) - inputWg.Done() - }() - - // run helper so it accepts jobs - testingWg.Add(1) - go func() { - res, err := ForEachTargetWithResume(d.ctx, d.stepChans, nil, 1, fn) - assert.Equal(t, xcontext.ErrPaused, err) - // inspect result - state := parallelTargetsState{} - assert.NoError(t, json.Unmarshal(res, &state)) - assert.Equal(t, 1, state.Version) - assert.Equal(t, numTargets, len(state.Targets)) - targetSeen := make(map[string]*TargetWithData) - // check all targets were returned once - for _, twd := range state.Targets { - _, ok := targetSeen[twd.Target.ID] - if ok { - assert.Fail(t, "duplicate target data in serialized resume data", twd) - } - targetSeen[twd.Target.ID] = twd - } - for i := 0; i < numTargets; i++ { - twd, ok := targetSeen[targets[i].ID] - assert.True(t, ok) - // check serialized data - stepData := simpleStepData{} - assert.NoError(t, json.Unmarshal(twd.Data, &stepData)) - assert.Equal(t, targets[i].ID, stepData.Foo) - } - // done monitoring out channels now - outDone <- struct{}{} - testingWg.Done() - }() - - // pause when all were submitted - inputWg.Wait() - d.pause() - - // wait for pausing and all testing of pause result to be done - testingWg.Wait() - assert.NoError(t, goroutine_leak_check.CheckLeakedGoRoutines()) } + +//func TestForEachTargetTenTargetsAllFail(t *testing.T) { +// d := newData() +// fn := func(ctx xcontext.Context, tgt *target.Target) error { +// d.ctx.Debugf("Handling target %s", tgt) +// return fmt.Errorf("error with target %s", tgt) +// } +// go func() { +// for i := 0; i < 10; i++ { +// d.inCh <- &target.Target{ID: fmt.Sprintf("target%00d", i)} +// } +// close(d.inCh) +// }() +// ctx, cancel := context.WithCancel(context.Background()) +// defer cancel() +// go func() { +// for { +// select { +// case <-ctx.Done(): +// return +// case res := <-d.outCh: +// if res.Err == nil { +// t.Errorf("Step for target %s expected to fail but completed successfully instead", res.Target) +// } else { +// d.ctx.Debugf("Step for target failed as expected: %v", res.Err) +// } +// } +// } +// }() +// _, err := ForEachTarget("test_one_target ", d.ctx, d.stepChans, fn) +// require.NoError(t, err) +//} +// +//func TestForEachTargetTenTargetsOneFails(t *testing.T) { +// d := newData() +// // chosen by fair dice roll. +// // guaranteed to be random. +// failingTarget := "target004" +// fn := func(ctx xcontext.Context, tgt *target.Target) error { +// d.ctx.Debugf("Handling target %s", tgt) +// if tgt.ID == failingTarget { +// return fmt.Errorf("error with target %s", tgt) +// } +// return nil +// } +// go func() { +// for i := 0; i < 10; i++ { +// d.inCh <- &target.Target{ID: fmt.Sprintf("target%03d", i)} +// } +// close(d.inCh) +// }() +// ctx, cancel := context.WithCancel(context.Background()) +// defer cancel() +// go func() { +// for { +// select { +// case <-ctx.Done(): +// return +// case res := <-d.outCh: +// if res.Err == nil { +// if res.Target.ID == failingTarget { +// t.Errorf("Step for target %s expected to fail but completed successfully instead", res.Target) +// } else { +// d.ctx.Debugf("Step for target %s completed as expected", res.Target) +// } +// } else { +// if res.Target.ID == failingTarget { +// d.ctx.Debugf("Step for target %s failed as expected: %v", res.Target, res.Err) +// } else { +// t.Errorf("Expected no error for %s but got one: %v", res.Target, res.Err) +// } +// } +// } +// } +// }() +// _, err := ForEachTarget("test_one_target ", d.ctx, d.stepChans, fn) +// require.NoError(t, err) +//} +// +//// TestForEachTargetTenTargetsParallelism checks if we didn't break the parallelism of +//// ForEachTarget. It passes 10 targets and a function that takes 1 second for each +//// target, so the whole process should not take more than ~1s if properly parallelized. +//// I am using a deadline of 3s to give it some margin, knowing that if it is sequential +//// it will take ~10s. +//func TestForEachTargetTenTargetsParallelism(t *testing.T) { +// sleepTime := 300 * time.Millisecond +// d := newData() +// fn := func(ctx xcontext.Context, tgt *target.Target) error { +// d.ctx.Debugf("Handling target %s", tgt) +// select { +// case <-ctx.Done(): +// d.ctx.Debugf("target %s cancelled", tgt) +// case <-ctx.Until(xcontext.ErrPaused): +// d.ctx.Debugf("target %s paused", tgt) +// case <-time.After(sleepTime): +// d.ctx.Debugf("target %s processed", tgt) +// } +// return nil +// } +// +// numTargets := 10 +// go func() { +// for i := 0; i < numTargets; i++ { +// d.inCh <- &target.Target{ID: fmt.Sprintf("target%03d", i)} +// } +// close(d.inCh) +// }() +// +// deadlineExceeded := false +// var targetError error +// targetsRemain := numTargets +// var wg sync.WaitGroup +// +// wg.Add(1) +// go func() { +// // try to cancel ForEachTarget in case it's still running +// defer d.cancel() +// defer wg.Done() +// +// maxWaitTime := sleepTime * 3 +// deadline := time.Now().Add(maxWaitTime) +// d.ctx.Debugf("Setting deadline to now+%s", maxWaitTime) +// +// for { +// select { +// case res := <-d.outCh: +// targetsRemain-- +// if res.Err == nil { +// d.ctx.Debugf("Step for target %s completed successfully as expected", res.Target) +// } else { +// d.ctx.Debugf("Step for target %s expected to completed successfully but failed instead", res.Target, res.Err) +// targetError = res.Err +// } +// if targetsRemain == 0 { +// d.ctx.Debugf("All targets processed") +// return +// } +// case <-time.After(time.Until(deadline)): +// deadlineExceeded = true +// d.ctx.Debugf("Deadline exceeded") +// return +// } +// } +// }() +// +// _, err := ForEachTarget("test_parallel", d.ctx, d.stepChans, fn) +// +// wg.Wait() //wait for receiver +// +// if deadlineExceeded { +// t.Fatal("wait deadline exceeded, it's possible that parallelization is not working anymore") +// } +// require.NoError(t, targetError) +// require.NoError(t, err) +// assert.Equal(t, 0, targetsRemain) +//} +// +//func TestForEachTargetCancelSignalPropagation(t *testing.T) { +// sleepTime := 300 * time.Millisecond +// numTargets := 10 +// var canceledTargets int32 +// d := newData() +// +// fn := func(ctx xcontext.Context, tgt *target.Target) error { +// d.ctx.Debugf("Handling target %s", tgt) +// select { +// case <-ctx.Done(): +// d.ctx.Debugf("target %s caneled", tgt) +// atomic.AddInt32(&canceledTargets, 1) +// case <-ctx.Until(xcontext.ErrPaused): +// d.ctx.Debugf("target %s paused", tgt) +// case <-time.After(sleepTime): +// d.ctx.Debugf("target %s processed", tgt) +// } +// return nil +// } +// +// go func() { +// for i := 0; i < numTargets; i++ { +// d.inCh <- &target.Target{ID: fmt.Sprintf("target%03d", i)} +// } +// close(d.inCh) +// }() +// +// go func() { +// time.Sleep(sleepTime / 3) +// d.cancel() +// }() +// +// _, err := ForEachTarget("test_cancelation", d.ctx, d.stepChans, fn) +// require.NoError(t, err) +// +// assert.Equal(t, int32(numTargets), canceledTargets) +//} +// +//func TestForEachTargetCancelBeforeInputChannelClosed(t *testing.T) { +// sleepTime := 300 * time.Millisecond +// numTargets := 10 +// var canceledTargets int32 +// d := newData() +// +// fn := func(ctx xcontext.Context, tgt *target.Target) error { +// d.ctx.Debugf("Handling target %s", tgt) +// select { +// case <-ctx.Done(): +// d.ctx.Debugf("target %s cancelled", tgt) +// atomic.AddInt32(&canceledTargets, 1) +// case <-ctx.Until(xcontext.ErrPaused): +// d.ctx.Debugf("target %s paused", tgt) +// case <-time.After(sleepTime): +// d.ctx.Debugf("target %s processed", tgt) +// } +// return nil +// } +// +// var wg sync.WaitGroup +// wg.Add(1) +// go func() { +// for i := 0; i < numTargets; i++ { +// d.inCh <- &target.Target{ID: fmt.Sprintf("target%03d", i)} +// } +// wg.Wait() //Don't close the input channel until ForEachTarget returned +// }() +// +// go func() { +// time.Sleep(sleepTime / 3) +// d.cancel() +// }() +// +// _, err := ForEachTarget("test_cancelation", d.ctx, d.stepChans, fn) +// require.NoError(t, err) +// +// wg.Done() +// assert.Equal(t, int32(numTargets), canceledTargets) +//} +// +//func TestForEachTargetWithResumeAllReturn(t *testing.T) { +// numTargets := 10 +// d := newData() +// +// fn := func(ctx xcontext.Context, target *TargetWithData) error { +// return nil // success +// } +// +// var wg sync.WaitGroup +// wg.Add(1) +// // submit all, then close +// go func() { +// for i := 0; i < numTargets; i++ { +// d.inCh <- &target.Target{ID: fmt.Sprintf("target%03d", i)} +// } +// close(d.inCh) +// wg.Done() +// }() +// +// wg.Add(1) +// // read all results +// go func() { +// for i := 0; i < numTargets; i++ { +// <-d.outCh +// } +// wg.Done() +// }() +// +// res, err := ForEachTargetWithResume(d.ctx, d.stepChans, nil, 1, fn) +// require.NoError(t, err) +// assert.Nil(t, res) +// // make sure all helpers are done +// wg.Wait() +// assert.NoError(t, goroutine_leak_check.CheckLeakedGoRoutines()) +//} +// +//type simpleStepData struct { +// Foo string +//} +// +//func TestForEachTargetWithResumeAllPause(t *testing.T) { +// numTargets := 10 +// targets := make([]target.Target, 10) +// for i := 0; i < numTargets; i++ { +// targets[i] = target.Target{ID: fmt.Sprintf("target%03d", i)} +// } +// d := newData() +// +// fn := func(ctx xcontext.Context, target *TargetWithData) error { +// stepData := simpleStepData{target.Target.ID} +// json, err := json.Marshal(&stepData) +// require.NoError(t, err) +// // block and pause +// <-ctx.Until(xcontext.ErrPaused) +// target.Data = json +// return xcontext.ErrPaused +// } +// var testingWg sync.WaitGroup +// +// // constantly read out channel, must not receive anything +// outDone := make(chan struct{}) +// testingWg.Add(1) +// go func() { +// select { +// case res := <-d.outCh: +// assert.Fail(t, "unexpected target in out channel", res) +// case <-outDone: +// testingWg.Done() +// } +// }() +// +// var inputWg sync.WaitGroup +// inputWg.Add(1) +// // submit all, then close +// go func() { +// for i := 0; i < numTargets; i++ { +// d.inCh <- &targets[i] +// } +// close(d.inCh) +// inputWg.Done() +// }() +// +// // run helper so it accepts jobs +// testingWg.Add(1) +// go func() { +// res, err := ForEachTargetWithResume(d.ctx, d.stepChans, nil, 1, fn) +// assert.Equal(t, xcontext.ErrPaused, err) +// // inspect result +// state := parallelTargetsState{} +// assert.NoError(t, json.Unmarshal(res, &state)) +// assert.Equal(t, 1, state.Version) +// assert.Equal(t, numTargets, len(state.Targets)) +// targetSeen := make(map[string]*TargetWithData) +// // check all targets were returned once +// for _, twd := range state.Targets { +// _, ok := targetSeen[twd.Target.ID] +// if ok { +// assert.Fail(t, "duplicate target data in serialized resume data", twd) +// } +// targetSeen[twd.Target.ID] = twd +// } +// for i := 0; i < numTargets; i++ { +// twd, ok := targetSeen[targets[i].ID] +// assert.True(t, ok) +// // check serialized data +// stepData := simpleStepData{} +// assert.NoError(t, json.Unmarshal(twd.Data, &stepData)) +// assert.Equal(t, targets[i].ID, stepData.Foo) +// } +// // done monitoring out channels now +// outDone <- struct{}{} +// testingWg.Done() +// }() +// +// // pause when all were submitted +// inputWg.Wait() +// d.pause() +// +// // wait for pausing and all testing of pause result to be done +// testingWg.Wait() +// assert.NoError(t, goroutine_leak_check.CheckLeakedGoRoutines()) +//} diff --git a/plugins/teststeps/waitport/waitport.go b/plugins/teststeps/waitport/waitport.go index 514452f3..73277327 100644 --- a/plugins/teststeps/waitport/waitport.go +++ b/plugins/teststeps/waitport/waitport.go @@ -46,7 +46,7 @@ func (ts *WaitPort) Name() string { // Run executes the cmd step. func (ts *WaitPort) Run( ctx xcontext.Context, - ch test.TestStepChannels, + stepIO test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, inputParams test.TestStepParameters, @@ -141,7 +141,7 @@ func (ts *WaitPort) Run( ctx.Infof("wait port plugin finished, err: '%v'", resultErr) return resultErr } - return teststeps.ForEachTargetWithResume(ctx, ch, resumeState, 0, f) + return teststeps.ForEachTargetWithResume(ctx, stepIO, resumeState, 0, f) } // ValidateParameters validates the parameters associated to the TestStep diff --git a/plugins/teststeps/waitport/waitport_test.go b/plugins/teststeps/waitport/waitport_test.go index e04875fb..2e76478a 100644 --- a/plugins/teststeps/waitport/waitport_test.go +++ b/plugins/teststeps/waitport/waitport_test.go @@ -2,6 +2,8 @@ package waitport import ( "fmt" + "github.com/linuxboot/contest/tests/common/mocks" + "github.com/stretchr/testify/require" "net" "sync" "testing" @@ -50,23 +52,18 @@ func TestWaitForTCPPort(t *testing.T) { } }() - inCh := make(chan *target.Target, 1) - testStepChannels := test.TestStepChannels{ - In: inCh, - Out: make(chan test.TestStepResult, 1), - } + stepIO := mocks.NewTestStepInputOutputMock([]target.Target{ + { + ID: "some_id", + FQDN: "localhost", + }, + }) ev := storage.NewTestEventEmitterFetcher(storageEngineVault, testevent.Header{ JobID: 12345, TestName: "waitport_tests", TestStepLabel: "waitport", }) - inCh <- &target.Target{ - ID: "some_id", - FQDN: "localhost", - } - close(inCh) - params := test.TestStepParameters{ "protocol": []test.Param{*test.NewParam("tcp")}, "port": []test.Param{*test.NewParam(fmt.Sprintf("%d", listener.Addr().(*net.TCPAddr).Port))}, @@ -75,8 +72,12 @@ func TestWaitForTCPPort(t *testing.T) { } plugin := &WaitPort{} - if _, err = plugin.Run(ctx, testStepChannels, ev, nil, params, nil); err != nil { + if _, err = plugin.Run(ctx, stepIO, ev, nil, params, nil); err != nil { t.Errorf("Plugin run failed: '%v'", err) } wg.Wait() + + require.Equal(t, map[string]error{ + "some_id": nil, + }, stepIO.GetReportedTargets()) } diff --git a/tests/common/mocks/test_step_input_output_mock.go b/tests/common/mocks/test_step_input_output_mock.go new file mode 100644 index 00000000..b0a30a9d --- /dev/null +++ b/tests/common/mocks/test_step_input_output_mock.go @@ -0,0 +1,55 @@ +package mocks + +import ( + "github.com/linuxboot/contest/pkg/target" + "github.com/linuxboot/contest/pkg/test" + "github.com/linuxboot/contest/pkg/xcontext" + "sync" +) + +type TestStepInputOutputMock struct { + mu sync.Mutex + inputTargets []target.Target + targetsIdx int + + reportedTargets map[string]error +} + +func NewTestStepInputOutputMock(inputTargets []target.Target) *TestStepInputOutputMock { + return &TestStepInputOutputMock{ + inputTargets: inputTargets, + reportedTargets: make(map[string]error), + } +} + +func (ioMock *TestStepInputOutputMock) Get(ctx xcontext.Context) (*target.Target, error) { + ioMock.mu.Lock() + defer ioMock.mu.Unlock() + + if ioMock.targetsIdx >= len(ioMock.inputTargets) { + return nil, nil + } + ioMock.targetsIdx++ + return &ioMock.inputTargets[ioMock.targetsIdx-1], nil +} + +func (ioMock *TestStepInputOutputMock) Report(ctx xcontext.Context, tgt target.Target, err error) error { + ioMock.mu.Lock() + defer ioMock.mu.Unlock() + + ioMock.reportedTargets[tgt.ID] = err + return nil +} + +func (ioMock *TestStepInputOutputMock) GetReportedTargets() map[string]error { + ioMock.mu.Lock() + defer ioMock.mu.Unlock() + + result := make(map[string]error) + for tgtID, err := range ioMock.reportedTargets { + result[tgtID] = err + } + return result +} + +var _ test.TestStepInputOutput = (*TestStepInputOutputMock)(nil) diff --git a/tests/plugins/teststeps/badtargets/badtargets.go b/tests/plugins/teststeps/badtargets/badtargets.go index 2ea7e4ae..7aa452d4 100644 --- a/tests/plugins/teststeps/badtargets/badtargets.go +++ b/tests/plugins/teststeps/badtargets/badtargets.go @@ -8,7 +8,6 @@ package badtargets import ( "encoding/json" "fmt" - "github.com/linuxboot/contest/pkg/event" "github.com/linuxboot/contest/pkg/event/testevent" "github.com/linuxboot/contest/pkg/target" @@ -33,65 +32,49 @@ func (ts *badTargets) Name() string { // Run executes a step that messes up the flow of targets. func (ts *badTargets) Run( ctx xcontext.Context, - ch test.TestStepChannels, + io test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, inputParams test.TestStepParameters, resumeState json.RawMessage, ) (json.RawMessage, error) { for { - select { - case tgt, ok := <-ch.In: - if !ok { - return nil, nil + tgt, err := io.Get(ctx) + if err != nil { + return nil, err + } + if tgt == nil { + return nil, nil + } + + switch tgt.ID { + case "TDrop": + // ... crickets ... + case "TGood": + if err := io.Report(ctx, *tgt, nil); err != nil { + return nil, err + } + case "TDup": + if err := io.Report(ctx, *tgt, nil); err != nil { + return nil, err + } + if err := io.Report(ctx, *tgt, nil); err != nil { + return nil, err + } + case "TExtra": + if err := io.Report(ctx, *tgt, nil); err != nil { + return nil, err + } + if err := io.Report(ctx, target.Target{ID: "TExtra2"}, nil); err != nil { + return nil, err } - switch tgt.ID { - case "TDrop": - // ... crickets ... - case "TGood": - // We should not depend on pointer matching, so emit a copy. - tgt2 := *tgt - select { - case ch.Out <- test.TestStepResult{Target: &tgt2}: - case <-ctx.Done(): - return nil, xcontext.ErrCanceled - } - case "TDup": - select { - case ch.Out <- test.TestStepResult{Target: tgt}: - case <-ctx.Done(): - return nil, xcontext.ErrCanceled - } - select { - case ch.Out <- test.TestStepResult{Target: tgt}: - case <-ctx.Done(): - return nil, xcontext.ErrCanceled - } - case "TExtra": - tgt2 := &target.Target{ID: "TExtra2"} - select { - case ch.Out <- test.TestStepResult{Target: tgt}: - case <-ctx.Done(): - return nil, xcontext.ErrCanceled - } - select { - case ch.Out <- test.TestStepResult{Target: tgt2}: - case <-ctx.Done(): - return nil, xcontext.ErrCanceled - } - case "T1": - // Mangle the returned target name. - tgt2 := &target.Target{ID: tgt.ID + "XXX"} - select { - case ch.Out <- test.TestStepResult{Target: tgt2}: - case <-ctx.Done(): - return nil, xcontext.ErrCanceled - } - default: - return nil, fmt.Errorf("Unexpected target name: %q", tgt.ID) + case "T1": + // Mangle the returned target name. + if err := io.Report(ctx, target.Target{ID: tgt.ID + "XXX"}, nil); err != nil { + return nil, err } - case <-ctx.Done(): - return nil, xcontext.ErrCanceled + default: + return nil, fmt.Errorf("unexpected target name: %q", tgt.ID) } } } diff --git a/tests/plugins/teststeps/channels/channels.go b/tests/plugins/teststeps/channels/channels.go deleted file mode 100644 index 971cd387..00000000 --- a/tests/plugins/teststeps/channels/channels.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree. - -package channels - -import ( - "encoding/json" - - "github.com/linuxboot/contest/pkg/event" - "github.com/linuxboot/contest/pkg/event/testevent" - "github.com/linuxboot/contest/pkg/test" - "github.com/linuxboot/contest/pkg/xcontext" -) - -// Name is the name used to look this plugin up. -var Name = "Channels" - -// Events defines the events that a TestStep is allowed to emit -var Events = []event.Name{} - -type channels struct { -} - -// Name returns the name of the Step -func (ts *channels) Name() string { - return Name -} - -// Run executes a step that runs fine but closes its output channels on exit. -func (ts *channels) Run( - ctx xcontext.Context, - ch test.TestStepChannels, - ev testevent.Emitter, - stepsVars test.StepsVariables, - inputParams test.TestStepParameters, - resumeState json.RawMessage, -) (json.RawMessage, error) { - for target := range ch.In { - ch.Out <- test.TestStepResult{Target: target} - } - // This is bad, do not do this. - close(ch.Out) - return nil, nil -} - -// ValidateParameters validates the parameters associated to the TestStep -func (ts *channels) ValidateParameters(_ xcontext.Context, params test.TestStepParameters) error { - return nil -} - -// New creates a new Channels step -func New() test.TestStep { - return &channels{} -} diff --git a/tests/plugins/teststeps/hanging/hanging.go b/tests/plugins/teststeps/hanging/hanging.go index 434af72a..cfa9c520 100644 --- a/tests/plugins/teststeps/hanging/hanging.go +++ b/tests/plugins/teststeps/hanging/hanging.go @@ -31,7 +31,7 @@ func (ts *hanging) Name() string { // Run executes a step that does not process any targets and never returns. func (ts *hanging) Run( ctx xcontext.Context, - ch test.TestStepChannels, + io test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, inputParams test.TestStepParameters, diff --git a/tests/plugins/teststeps/noreturn/noreturn.go b/tests/plugins/teststeps/noreturn/noreturn.go index e145bbb9..708345a1 100644 --- a/tests/plugins/teststeps/noreturn/noreturn.go +++ b/tests/plugins/teststeps/noreturn/noreturn.go @@ -31,14 +31,23 @@ func (ts *noreturnStep) Name() string { // Run executes a step that never returns. func (ts *noreturnStep) Run( ctx xcontext.Context, - ch test.TestStepChannels, + io test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, inputParams test.TestStepParameters, resumeState json.RawMessage, ) (json.RawMessage, error) { - for target := range ch.In { - ch.Out <- test.TestStepResult{Target: target} + for { + tgt, err := io.Get(ctx) + if err != nil { + return nil, err + } + if tgt == nil { + break + } + if err := io.Report(ctx, *tgt, nil); err != nil { + return nil, err + } } channel := make(chan struct{}) <-channel diff --git a/tests/plugins/teststeps/panicstep/panicstep.go b/tests/plugins/teststeps/panicstep/panicstep.go index 5b565902..5aacc8b4 100644 --- a/tests/plugins/teststeps/panicstep/panicstep.go +++ b/tests/plugins/teststeps/panicstep/panicstep.go @@ -31,7 +31,7 @@ func (ts *panicStep) Name() string { // Run executes the example step. func (ts *panicStep) Run( ctx xcontext.Context, - ch test.TestStepChannels, + io test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, inputParams test.TestStepParameters, diff --git a/tests/plugins/teststeps/teststep/teststep.go b/tests/plugins/teststeps/teststep/teststep.go index 831f75ad..c08a3fa3 100644 --- a/tests/plugins/teststeps/teststep/teststep.go +++ b/tests/plugins/teststeps/teststep/teststep.go @@ -67,7 +67,7 @@ func (ts *Step) shouldFail(t *target.Target, params test.TestStepParameters) boo // Run executes the example step. func (ts *Step) Run( ctx xcontext.Context, - ch test.TestStepChannels, + io test.TestStepInputOutput, ev testevent.Emitter, stepsVars test.StepsVariables, params test.TestStepParameters, @@ -103,7 +103,7 @@ func (ts *Step) Run( if err := ev.Emit(ctx, testevent.Data{EventName: StepRunningEvent}); err != nil { return nil, fmt.Errorf("failed to emit failed event: %v", err) } - _, res := teststeps.ForEachTarget(Name, ctx, ch, f) + _, res := teststeps.ForEachTarget(Name, ctx, io, f) if err := ev.Emit(ctx, testevent.Data{EventName: StepFinishedEvent}); err != nil { return nil, fmt.Errorf("failed to emit failed event: %v", err) }