Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplement process exit expect failure #13

Merged
merged 9 commits into from
Jun 28, 2024
10 changes: 5 additions & 5 deletions expect.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (tt *TermTest) ExpectCustom(consumer consumer, opts ...SetExpectOpt) (rerr
return fmt.Errorf("could not create expect options: %w", err)
}

cons, err := tt.outputProducer.addConsumer(tt, consumer, expectOpts.ToConsumerOpts()...)
cons, err := tt.outputProducer.addConsumer(consumer, expectOpts.ToConsumerOpts()...)
if err != nil {
return fmt.Errorf("could not add consumer: %w", err)
}
Expand Down Expand Up @@ -180,11 +180,11 @@ func (tt *TermTest) expectExitCode(exitCode int, match bool, opts ...SetExpectOp
select {
case <-time.After(timeoutV):
return fmt.Errorf("after %s: %w", timeoutV, TimeoutError)
case state := <-tt.Exited(false): // do not wait for unread output since it's not read by this select{}
if state.Err != nil && (state.ProcessState == nil || state.ProcessState.ExitCode() == 0) {
return fmt.Errorf("cmd wait failed: %w", state.Err)
case err := <-waitChan(tt.cmd.Wait):
if err != nil && (tt.cmd.ProcessState == nil || tt.cmd.ProcessState.ExitCode() == 0) {
return fmt.Errorf("cmd wait failed: %w", err)
}
if err := tt.assertExitCode(state.ProcessState.ExitCode(), exitCode, match); err != nil {
if err := tt.assertExitCode(tt.cmd.ProcessState.ExitCode(), exitCode, match); err != nil {
return err
}
}
Expand Down
21 changes: 21 additions & 0 deletions expect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,27 @@ func Test_Expect_Timeout(t *testing.T) {
tt.ExpectExitCode(0, OptExpectTimeout(time.Hour))
}

// Test_ExpectMet_ProcessExit tests a potential race condition; where the process exiting event might hit before all output
// has been fully consumed.
func Test_ExpectMet_ProcessExit(t *testing.T) {
tt := newTermTest(t, exec.Command("bash", "-c", "echo HELLO && exit 1"), false)
tt.Expect("HELLO")
tt.ExpectExitCode(1)
}

// Test_ExpectFail_ProcessExit tests that we don't wait for the timeout if an expect is still waiting for output after
// the process has exited.
func Test_ExpectFail_ProcessExit(t *testing.T) {
tt := newTermTest(t, exec.Command("bash", "-c", "echo HELLO && exit 1"), true)
start := time.Now()
err := tt.Expect("GOODBYE", OptExpectTimeout(5*time.Second), OptExpectSilenceErrorHandler())
require.ErrorIs(t, err, ptyEOF)
if time.Now().Sub(start) >= 5*time.Second {
t.Errorf("Should not have waited for timeout as process has exited")
}
tt.ExpectExitCode(1)
}

func Test_ExpectMatchTwiceSameBuffer(t *testing.T) {
tt := newTermTest(t, exec.Command("bash"), false)

Expand Down
27 changes: 18 additions & 9 deletions outputconsumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ type outputConsumer struct {
opts *OutputConsumerOpts
isalive bool
mutex *sync.Mutex
tt *TermTest
}

type OutputConsumerOpts struct {
Expand All @@ -37,7 +36,7 @@ func OptsConsTimeout(timeout time.Duration) func(o *OutputConsumerOpts) {
}
}

func newOutputConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) *outputConsumer {
func newOutputConsumer(consume consumer, opts ...SetConsOpt) *outputConsumer {
oc := &outputConsumer{
consume: consume,
opts: &OutputConsumerOpts{
Expand All @@ -47,7 +46,6 @@ func newOutputConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) *outp
waiter: make(chan error, 1),
isalive: true,
mutex: &sync.Mutex{},
tt: tt,
}

for _, optSetter := range opts {
Expand Down Expand Up @@ -83,6 +81,23 @@ func (e *outputConsumer) Report(buffer []byte) (int, error) {
return pos, err
}

type errConsumerStopped struct {
reason error
}

func (e errConsumerStopped) Error() string {
return fmt.Sprintf("consumer stopped, reason: %s", e.reason)
}

func (e errConsumerStopped) Unwrap() error {
return e.reason
}

func (e *outputConsumer) Stop(reason error) {
e.opts.Logger.Printf("stopping consumer, reason: %s\n", reason)
e.waiter <- errConsumerStopped{reason}
}

func (e *outputConsumer) wait() error {
e.opts.Logger.Println("started waiting")
defer e.opts.Logger.Println("stopped waiting")
Expand All @@ -103,11 +118,5 @@ func (e *outputConsumer) wait() error {
e.mutex.Lock()
e.opts.Logger.Println("Encountered timeout")
return fmt.Errorf("after %s: %w", e.opts.Timeout, TimeoutError)
case state := <-e.tt.Exited(true): // allow for output to be read first by first case in this select{}
e.mutex.Lock()
if state.Err != nil {
e.opts.Logger.Println("Encountered error waiting for process to exit: %s\n", state.Err.Error())
}
return fmt.Errorf("process exited (status: %d)", state.ProcessState.ExitCode())
}
}
20 changes: 17 additions & 3 deletions outputproducer.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ func (o *outputProducer) listen(r io.Reader, w io.Writer, appendBuffer func([]by
o.opts.Logger.Println("listen: loop")
if err := o.processNextRead(br, w, appendBuffer, size); err != nil {
if errors.Is(err, ptyEOF) {
o.opts.Logger.Println("listen: reached EOF")
return nil
} else {
return fmt.Errorf("could not poll reader: %w", err)
Expand All @@ -78,6 +77,7 @@ func (o *outputProducer) processNextRead(r io.Reader, w io.Writer, appendBuffer
pathError := &fs.PathError{}
if errors.Is(errRead, fs.ErrClosed) || errors.Is(errRead, io.EOF) || (runtime.GOOS == "linux" && errors.As(errRead, &pathError)) {
isEOF = true
o.opts.Logger.Println("reached EOF")
}
}

Expand All @@ -96,6 +96,7 @@ func (o *outputProducer) processNextRead(r io.Reader, w io.Writer, appendBuffer

if errRead != nil {
if isEOF {
o.closeConsumers(ptyEOF)
return errors.Join(errRead, ptyEOF)
}
return fmt.Errorf("could not read pty output: %w", errRead)
Expand Down Expand Up @@ -194,6 +195,19 @@ func (o *outputProducer) processDirtyOutput(output []byte, cursorPos int, cleanU
return append(append(alreadyCleanedOutput, processedOutput...), unprocessedOutput...), processedCursorPos, newCleanUptoPos, nil
}

func (o *outputProducer) closeConsumers(reason error) {
o.opts.Logger.Println("closing consumers")
defer o.opts.Logger.Println("closed consumers")

o.mutex.Lock()
defer o.mutex.Unlock()

for n := 0; n < len(o.consumers); n++ {
o.consumers[n].Stop(reason)
o.consumers = append(o.consumers[:n], o.consumers[n+1:]...)
}
}

func (o *outputProducer) flushConsumers() error {
o.opts.Logger.Println("flushing consumers")
defer o.opts.Logger.Println("flushed consumers")
Expand Down Expand Up @@ -238,12 +252,12 @@ func (o *outputProducer) flushConsumers() error {
return nil
}

func (o *outputProducer) addConsumer(tt *TermTest, consume consumer, opts ...SetConsOpt) (*outputConsumer, error) {
func (o *outputProducer) addConsumer(consume consumer, opts ...SetConsOpt) (*outputConsumer, error) {
o.opts.Logger.Printf("adding consumer")
defer o.opts.Logger.Printf("added consumer")

opts = append(opts, OptConsInherit(o.opts))
listener := newOutputConsumer(tt, consume, opts...)
listener := newOutputConsumer(consume, opts...)
o.consumers = append(o.consumers, listener)

if err := o.flushConsumers(); err != nil {
Expand Down
27 changes: 0 additions & 27 deletions termtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ type TermTest struct {
outputProducer *outputProducer
listenError chan error
opts *Opts
exited *cmdExit
}

type ErrorHandler func(*TermTest, error) error
Expand Down Expand Up @@ -238,10 +237,6 @@ func (tt *TermTest) start() (rerr error) {
}()
wg.Wait()

go func() {
tt.exited = <-waitForCmdExit(tt.cmd)
}()

return nil
}

Expand Down Expand Up @@ -324,28 +319,6 @@ func (tt *TermTest) SendCtrlC() {
tt.Send(string([]byte{0x03})) // 0x03 is ASCII character for ^C
}

// Exited returns a channel that sends the given termtest's command cmdExit info when available.
// This can be used within a select{} statement.
// If waitExtra is given, waits a little bit before sending cmdExit info. This allows any fellow
// switch cases with output consumers to handle unprocessed stdout. If there are no such cases
// (e.g. ExpectExit(), where we want to catch an exit ASAP), waitExtra should be false.
func (tt *TermTest) Exited(waitExtra bool) chan *cmdExit {
return waitChan(func() *cmdExit {
ticker := time.NewTicker(processExitPollInterval)
for {
select {
case <-ticker.C:
if tt.exited != nil {
if waitExtra { // allow sibling output consumer cases to handle their output
time.Sleep(processExitExtraWait)
}
return tt.exited
}
}
}
})
}

func (tt *TermTest) errorHandler(rerr *error) {
err := *rerr
if err == nil {
Expand Down
Loading