From 978f98fbf604d8ce769efc01f7a0da158aedc82d Mon Sep 17 00:00:00 2001 From: Adam Bouhenguel Date: Sat, 23 Jul 2016 14:56:01 -0700 Subject: [PATCH] Clean up run command Split up argument parsing into output-related and execution-related options. Also break up argument processing from actual command running so we can reuse these primitives for the upcoming `qa auto` functionality. Also switch to -squash argument instead of suffixes for runner types. Use io.Closer interface in a number of places rather than defer foo.Close(). Also start properly cleaning up signal handlers and other cruft that the run command generates. Avoid double-closing server.Server instances, so we can reuse them. This is the first step towards automatically running tests on file save. --- src/qa/cmd/run/execution-flags.go | 140 ++++++++++++ src/qa/cmd/run/output-flags.go | 265 ++++++++++++++++++++++ src/qa/cmd/run/run-flags.go | 65 ++++++ src/qa/cmd/run/run.go | 362 ++++-------------------------- src/qa/emitter/emitter.go | 79 ++----- src/qa/main/main.go | 23 +- src/qa/runner/ruby/ruby.go | 64 +++--- src/qa/runner/runner.go | 82 +++++++ src/qa/runner/server/server.go | 3 + src/qa/suite/suite.go | 34 +-- src/qa/tapjio/flamegraph.go | 4 +- src/qa/tapjio/stacktrace.go | 11 + src/qa/tapjio/tapj.go | 11 +- src/qa/tapjio/trace.go | 6 +- 14 files changed, 693 insertions(+), 456 deletions(-) create mode 100644 src/qa/cmd/run/execution-flags.go create mode 100644 src/qa/cmd/run/output-flags.go create mode 100644 src/qa/cmd/run/run-flags.go diff --git a/src/qa/cmd/run/execution-flags.go b/src/qa/cmd/run/execution-flags.go new file mode 100644 index 0000000..7ab6654 --- /dev/null +++ b/src/qa/cmd/run/execution-flags.go @@ -0,0 +1,140 @@ +package run + +import ( + "errors" + "flag" + "fmt" + "math/rand" + "runtime" + "strings" + + "qa/cmd" + "qa/emitter" + "qa/runner" + "qa/runner/server" +) + +type executionFlags struct { + jobs *int + squashPolicy *runner.SquashPolicy + listenNetwork *string + listenAddress *string + errorsCaptureLocals *string + captureStandardFds *bool + evalBeforeFork *string + evalAfterFork *string + sampleStack *bool + warmup *bool + seed *int +} + +type squashPolicyValue struct { + value *runner.SquashPolicy +} + +func (v *squashPolicyValue) String() string { + switch *v.value { + case runner.SquashAll: + return "all" + case runner.SquashNothing: + return "nothing" + case runner.SquashByFile: + return "file" + } + + return "" +} + +func (v *squashPolicyValue) Set(s string) error { + switch s { + case "all": + *v.value = runner.SquashAll + case "none": + *v.value = runner.SquashNothing + case "file": + *v.value = runner.SquashByFile + default: + return errors.New("Invalid squash policy: " + s) + } + + return nil +} + +func defineExecutionFlags(flags *flag.FlagSet) *executionFlags { + squashPolicyValue := &squashPolicyValue{new(runner.SquashPolicy)} + *squashPolicyValue.value = runner.SquashByFile + flags.Var(squashPolicyValue, "squash", "One of: all, none, file") + + return &executionFlags{ + seed: flags.Int("seed", int(rand.Int31()), "Set seed to use"), + jobs: flags.Int("jobs", runtime.NumCPU(), "Set number of jobs"), + squashPolicy: squashPolicyValue.value, + listenNetwork: flags.String("listen-network", "unix", "Specify unix or tcp socket for worker coordination"), + listenAddress: flags.String("listen-address", "/tmp/qa", "Listen address for worker coordination"), + errorsCaptureLocals: flags.String("errors-capture-locals", "false", "Use runtime debug API to capture locals from stack when raising errors"), + captureStandardFds: flags.Bool("capture-standard-fds", true, "Capture stdout and stderr"), + evalBeforeFork: flags.String("eval-before-fork", "", "Execute the given code before forking any workers or loading any files"), + evalAfterFork: flags.String("eval-after-fork", "", "Execute the given code after a work forks, but before work begins"), + sampleStack: flags.Bool("sample-stack", false, "Enable stack sampling"), + warmup: flags.Bool("warmup", false, "Use a variety of experimental heuristics to warm up worker caches"), + } +} + +func (f *executionFlags) Listen() (*server.Server, error) { + return server.Listen(*f.listenNetwork, *f.listenAddress) +} + +func (f *executionFlags) WorkerEnvs() []map[string]string { + workerEnvs := []map[string]string{} + for i := 0; i < *f.jobs; i++ { + workerEnvs = append(workerEnvs, + map[string]string{"QA_WORKER": fmt.Sprintf("%d", i)}) + } + + return workerEnvs +} + +func (f *executionFlags) RunnerConfigs(env *cmd.Env, runnerSpecs []string) []runner.Config { + var configs []runner.Config + for _, runnerSpec := range runnerSpecs { + runnerSpecSplit := strings.Split(runnerSpec, ":") + runnerName := runnerSpecSplit[0] + var lister runner.FileLister + if len(runnerSpecSplit) == 1 { + lister = runner.NewFileGlob(env.Dir, []string{emitter.DefaultGlob(runnerName)}) + } else { + lister = runner.NewFileGlob(env.Dir, runnerSpecSplit[1:]) + } + + configs = append(configs, runner.Config{ + Name: runnerName, + FileLister: lister, + Seed: *f.seed, + Dir: env.Dir, + EnvVars: env.Vars, + SquashPolicy: *f.squashPolicy, + // Enable entries below to add specific method calls (and optionally their arguments) to the trace. + TraceProbes: []string{ + // "Kernel#require(path)", + // "Kernel#load", + // "ActiveRecord::ConnectionAdapters::Mysql2Adapter#execute(sql,name)", + // "ActiveRecord::ConnectionAdapters::PostgresSQLAdapter#execute_and_clear(sql,name,binds)", + // "ActiveSupport::Dependencies::Loadable#require(path)", + // "ActiveRecord::ConnectionAdapters::QueryCache#clear_query_cache", + // "ActiveRecord::ConnectionAdapters::SchemaCache#initialize", + // "ActiveRecord::ConnectionAdapters::SchemaCache#clear!", + // "ActiveRecord::ConnectionAdapters::SchemaCache#clear_table_cache!", + }, + PassthroughConfig: map[string](interface{}){ + "warmup": *f.warmup, + "errorsCaptureLocals": *f.errorsCaptureLocals, + "captureStandardFds": *f.captureStandardFds, + "evalBeforeFork": *f.evalBeforeFork, + "evalAfterFork": *f.evalAfterFork, + "sampleStack": *f.sampleStack, + }, + }) + } + + return configs +} diff --git a/src/qa/cmd/run/output-flags.go b/src/qa/cmd/run/output-flags.go new file mode 100644 index 0000000..13c6b88 --- /dev/null +++ b/src/qa/cmd/run/output-flags.go @@ -0,0 +1,265 @@ +package run + +import ( + "errors" + "flag" + "fmt" + "io/ioutil" + "log" + "math/rand" + "os" + "path" + "time" + + "qa/cmd" + "qa/reporting" + "qa/tapjio" +) + +func maybeJoin(p string, dir string) string { + if p != "" && p[0] != '.' && p[0] != '/' { + return path.Join(dir, p) + } + + return p +} + +const archiveNonceLexicon = "abcdefghijklmnopqrstuvwxyz" +const archiveNonceLength = 8 + +func randomString(r *rand.Rand, lexicon string, length int) string { + bytes := make([]byte, length) + lexiconLen := len(lexicon) + for i := 0; i < length; i++ { + bytes[i] = lexicon[r.Intn(lexiconLen)] + } + + return string(bytes) +} + +func newArchiveTapjEmitter(archiveBaseDir string) (tapjio.Visitor, error) { + now := time.Now() + tapjArchiveDir := path.Join(archiveBaseDir, now.Format("2006-01-02")) + os.MkdirAll(tapjArchiveDir, 0755) + r := rand.New(rand.NewSource(now.UnixNano())) + nonce := randomString(r, archiveNonceLexicon, archiveNonceLength) + + tapjArchiveFilePath := path.Join(tapjArchiveDir, fmt.Sprintf("%d-%s.tapj", now.Unix(), nonce)) + tapjArchiveFile, err := os.Create(tapjArchiveFilePath) + if err != nil { + return nil, err + } + + return tapjio.NewTapjEmitCloser(tapjArchiveFile), nil +} + +type outputFlags struct { + archiveBaseDir *string + auditDir *string + quiet *bool + saveTapj *string + saveTrace *string + saveStacktraces *string + saveFlamegraph *string + saveIcegraph *string + savePalette *string + format *string + showUpdatingSummary *bool + elidePass *bool + elideOmit *bool +} + +func defineOutputFlags(flags *flag.FlagSet) *outputFlags { + return &outputFlags{ + archiveBaseDir: flags.String("archive-base-dir", "", "Base directory to store data for later analysis"), + auditDir: flags.String("audit-dir", "", "Directory to save any generated audits, e.g. TAP-J, JSON, SVG, etc."), + quiet: flags.Bool("quiet", false, "Whether or not to print anything at all"), + saveTapj: flags.String("save-tapj", "", "Path to save TAP-J"), + saveTrace: flags.String("save-trace", "", "Path to save trace JSON"), + saveStacktraces: flags.String("save-stacktraces", "", "Path to save stacktraces.txt, implies -sample-stack"), + saveFlamegraph: flags.String("save-flamegraph", "", "Path to save flamegraph SVG, implies -sample-stack"), + saveIcegraph: flags.String("save-icegraph", "", "Path to save icegraph SVG, implies -sample-stack"), + savePalette: flags.String("save-palette", "palette.map", "Path to save (flame|ice)graph palette"), + format: flags.String("format", "pretty", "Set output format"), + showUpdatingSummary: flags.Bool("pretty-overwrite", true, "Pretty reporter shows live updating summary"), + elidePass: flags.Bool("pretty-quiet-pass", true, "Pretty reporter elides passing tests without (std)output"), + elideOmit: flags.Bool("pretty-quiet-omit", true, "Pretty reporter elides omitted tests without (std)output"), + } +} + +func (f *outputFlags) newVisitor(env *cmd.Env, jobs int, svgTitleSuffix string) (tapjio.Visitor, error) { + saveTapj := *f.saveTapj + saveTrace := *f.saveTrace + saveStacktraces := *f.saveStacktraces + saveFlamegraph := *f.saveFlamegraph + saveIcegraph := *f.saveIcegraph + savePalette := *f.savePalette + + auditDir := *f.auditDir + + if auditDir != "" { + auditDir = maybeJoin(auditDir, env.Dir) + os.MkdirAll(auditDir, 0755) + + saveTapj = maybeJoin(saveTapj, auditDir) + saveTrace = maybeJoin(saveTrace, auditDir) + saveStacktraces = maybeJoin(saveStacktraces, auditDir) + saveFlamegraph = maybeJoin(saveFlamegraph, auditDir) + saveIcegraph = maybeJoin(saveIcegraph, auditDir) + savePalette = maybeJoin(savePalette, auditDir) + } + + var visitors []tapjio.Visitor + var err error + + if !*f.quiet { + switch *f.format { + case "tapj": + visitors = append(visitors, tapjio.NewTapjEmitter(env.Stdout)) + case "pretty": + pretty := reporting.NewPretty(env.Stdout, jobs) + pretty.ShowUpdatingSummary = *f.showUpdatingSummary + pretty.ElideQuietPass = *f.elidePass + pretty.ElideQuietOmit = *f.elideOmit + visitors = append(visitors, pretty) + default: + return nil, errors.New(fmt.Sprintf("Unknown format: %v", *f.format)) + } + } + + if saveTapj != "" { + tapjFile, err := os.Create(saveTapj) + if err != nil { + return nil, err + } + visitors = append(visitors, tapjio.NewTapjEmitCloser(tapjFile)) + } + + archiveBaseDir := *f.archiveBaseDir + if archiveBaseDir != "" { + archiveBaseDir = maybeJoin(archiveBaseDir, env.Dir) + visitor, err := newArchiveTapjEmitter(archiveBaseDir) + if err != nil { + return nil, err + } + visitors = append(visitors, visitor) + } + + if saveTrace != "" { + traceFile, err := os.Create(saveTrace) + if err != nil { + return nil, err + } + visitors = append(visitors, tapjio.NewTraceWriter(traceFile)) + } + + var stacktracesFile *os.File + if saveStacktraces != "" { + stacktracesFile, err = os.Create(saveStacktraces) + if err != nil { + log.Fatal(err) + } + } + + removeStacktracesFileAfterUse := false + if saveFlamegraph != "" || saveIcegraph != "" { + if stacktracesFile == nil { + stacktracesFile, err = ioutil.TempFile("", "stacktrace") + if err != nil { + log.Fatal(err) + } + removeStacktracesFileAfterUse = true + } + } + + if stacktracesFile != nil { + visitors = append(visitors, tapjio.NewStacktraceEmitCloser(stacktracesFile)) + } + + if saveFlamegraph != "" { + visitors = append(visitors, &tapjio.DecodingCallbacks{ + OnFinal: func(final tapjio.FinalEvent) error { + options := []string{ + "--title", "Flame Graph" + svgTitleSuffix, + "--minwidth=2", + } + if savePalette != "" { + options = append(options, "--cp", "--palfile="+savePalette) + } + + // There may be nothing to do if we didn't see any stacktrace data! + stacktraceFileInfo, err := stacktracesFile.Stat() + if err != nil { + return err + } + if stacktraceFileInfo.Size() == 0 { + return nil + } + + flamegraphFile, err := os.Create(saveFlamegraph) + if err != nil { + log.Fatal(err) + } + + stacktracesFile.Seek(0, 0) + err = tapjio.GenerateFlameGraph( + stacktracesFile, + flamegraphFile, + options...) + if err != nil { + return err + } + return nil + }, + }) + } + + if saveIcegraph != "" { + visitors = append(visitors, &tapjio.DecodingCallbacks{ + OnFinal: func(final tapjio.FinalEvent) error { + options := []string{ + "--title", "Icicle Graph" + svgTitleSuffix, + "--minwidth=2", + "--reverse", + "--inverted", + } + if savePalette != "" { + options = append(options, "--cp", "--palfile="+savePalette) + } + + // There may be nothing to do if we didn't see any stacktrace data! + stacktraceFileInfo, err := stacktracesFile.Stat() + if err != nil { + return err + } + if stacktraceFileInfo.Size() == 0 { + return nil + } + + icegraphFile, err := os.Create(saveIcegraph) + if err != nil { + log.Fatal(err) + } + + stacktracesFile.Seek(0, 0) + err = tapjio.GenerateFlameGraph( + stacktracesFile, + icegraphFile, + options...) + if err != nil { + return err + } + return nil + }, + }) + } + + if removeStacktracesFileAfterUse { + visitors = append(visitors, + &tapjio.DecodingCallbacks{ + OnEnd: func(err error) error { return os.Remove(stacktracesFile.Name()) }, + }) + } + + return tapjio.MultiVisitor(visitors), nil +} diff --git a/src/qa/cmd/run/run-flags.go b/src/qa/cmd/run/run-flags.go new file mode 100644 index 0000000..bf7f7c3 --- /dev/null +++ b/src/qa/cmd/run/run-flags.go @@ -0,0 +1,65 @@ +package run + +import ( + "encoding/json" + "flag" + "path/filepath" + "qa/cmd" + "strconv" +) + +type runFlags struct { + outputFlags *outputFlags + executionFlags *executionFlags + + chdir *string +} + +func DefineFlags(flags *flag.FlagSet) *runFlags { + return &runFlags{ + outputFlags: defineOutputFlags(flags), + executionFlags: defineExecutionFlags(flags), + chdir: flags.String("chdir", "", "Change to the given directory"), + } +} + +func (f *runFlags) NewEnv(env *cmd.Env, runnerSpecs []string) (*Env, error) { + executionFlags := *f.executionFlags + outputFlags := *f.outputFlags + e := *env + + if *f.chdir != "" { + if filepath.IsAbs(*f.chdir) { + e.Dir = *f.chdir + } else { + e.Dir = filepath.Join(e.Dir, *f.chdir) + } + } + + if *outputFlags.saveStacktraces != "" || + *outputFlags.saveFlamegraph != "" || + *outputFlags.saveIcegraph != "" { + *executionFlags.sampleStack = true + } + + svgTitleArgs, _ := json.Marshal(runnerSpecs) + svgTitleSuffix := " — jobs = " + strconv.Itoa(*executionFlags.jobs) + ", runnerSpecs = " + string(svgTitleArgs) + + visitor, err := outputFlags.newVisitor(&e, *executionFlags.jobs, string(svgTitleSuffix)) + if err != nil { + return nil, err + } + + srv, err := executionFlags.Listen() + if err != nil { + return nil, err + } + + return &Env{ + Seed: *executionFlags.seed, + WorkerEnvs: executionFlags.WorkerEnvs(), + RunnerConfigs: executionFlags.RunnerConfigs(&e, runnerSpecs), + Visitor: visitor, + Server: srv, + }, nil +} diff --git a/src/qa/cmd/run/run.go b/src/qa/cmd/run/run.go index 6419274..0c51d8f 100644 --- a/src/qa/cmd/run/run.go +++ b/src/qa/cmd/run/run.go @@ -1,373 +1,95 @@ package run -// cd && qa - import ( - // "bytes" - "encoding/json" "errors" "flag" "fmt" - "io/ioutil" - "log" - "math/rand" "os" - "os/exec" "os/signal" - "path" - "path/filepath" - "runtime" - "strconv" - "strings" "syscall" - "time" - - "github.com/mattn/go-zglob" "qa/cmd" "qa/emitter" - "qa/reporting" "qa/runner" "qa/runner/server" "qa/suite" "qa/tapjio" ) -func maybeJoin(p string, dir string) string { - if p != "" && p[0] != '.' && p[0] != '/' { - return path.Join(dir, p) - } - - return p +type Env struct { + Seed int + WorkerEnvs []map[string]string + RunnerConfigs []runner.Config + Visitor tapjio.Visitor + Server *server.Server } -const archiveNonceLexicon = "abcdefghijklmnopqrstuvwxyz" -const archiveNonceLength = 8 -func randomString(r *rand.Rand, lexicon string, length int) string { - bytes := make([]byte, length) - lexiconLen := len(lexicon) - for i := 0; i < length; i++ { - bytes[i] = lexicon[r.Intn(lexiconLen)] - } - - return string(bytes); -} - -func Main(env *cmd.Env, args []string) error { - flags := flag.NewFlagSet("run", flag.ContinueOnError) - archiveBaseDir := flags.String("archive-base-dir", "", "Base directory to store data for later analysis") - auditDir := flags.String("audit-dir", "", "Directory to save any generated audits, e.g. TAP-J, JSON, SVG, etc.") - quiet := flags.Bool("quiet", false, "Whether or not to print anything at all") - saveTapj := flags.String("save-tapj", "", "Path to save TAP-J") - saveTrace := flags.String("save-trace", "", "Path to save trace JSON") - saveStacktraces := flags.String("save-stacktraces", "", "Path to save stacktraces.txt, implies -sample-stack") - saveFlamegraph := flags.String("save-flamegraph", "", "Path to save flamegraph SVG, implies -sample-stack") - saveIcegraph := flags.String("save-icegraph", "", "Path to save icegraph SVG, implies -sample-stack") - savePalette := flags.String("save-palette", "palette.map", "Path to save (flame|ice)graph palette") - format := flags.String("format", "pretty", "Set output format") - jobs := flags.Int("jobs", runtime.NumCPU(), "Set number of jobs") - - showUpdatingSummary := flags.Bool("pretty-overwrite", true, "Pretty reporter shows live updating summary") - elidePass := flags.Bool("pretty-quiet-pass", true, "Pretty reporter elides passing tests without (std)output") - elideOmit := flags.Bool("pretty-quiet-omit", true, "Pretty reporter elides omitted tests without (std)output") - - errorsCaptureLocals := flags.String("errors-capture-locals", "false", "Use runtime debug API to capture locals from stack when raising errors") - captureStandardFds := flags.Bool("capture-standard-fds", true, "Capture stdout and stderr") - evalBeforeFork := flags.String("eval-before-fork", "", "Execute the given code before forking any workers or loading any files") - evalAfterFork := flags.String("eval-after-fork", "", "Execute the given code after a work forks, but before work begins") - sampleStack := flags.Bool("sample-stack", false, "Enable stack sampling") - - warmup := flags.Bool("warmup", false, "Use a variety of experimental heuristics to warm up worker caches") - - err := flags.Parse(args) - if err != nil { - return err - } - - if *saveStacktraces != "" || *saveFlamegraph != "" || *saveIcegraph != "" { - *sampleStack = true - } - - if *auditDir != "" { - os.MkdirAll(*auditDir, 0755) - - *saveTapj = maybeJoin(*saveTapj, *auditDir) - *saveTrace = maybeJoin(*saveTrace, *auditDir) - *saveStacktraces = maybeJoin(*saveStacktraces, *auditDir) - *saveFlamegraph = maybeJoin(*saveFlamegraph, *auditDir) - *saveIcegraph = maybeJoin(*saveIcegraph, *auditDir) - *savePalette = maybeJoin(*savePalette, *auditDir) - } - - var visitors []tapjio.Visitor - - if !*quiet { - switch *format { - case "tapj": - visitors = append(visitors, tapjio.NewTapjEmitter(env.Stdout)) - case "pretty": - pretty := reporting.NewPretty(env.Stdout, *jobs) - pretty.ShowUpdatingSummary = *showUpdatingSummary - pretty.ElideQuietPass = *elidePass - pretty.ElideQuietOmit = *elideOmit - visitors = append(visitors, pretty) - default: - return errors.New(fmt.Sprintf("Unknown format: %v", *format)) - } - } - - if *saveTapj != "" { - tapjFile, err := os.Create(*saveTapj) - if err != nil { - return err - } - defer tapjFile.Close() - visitors = append(visitors, tapjio.NewTapjEmitter(tapjFile)) - } - - if *archiveBaseDir != "" { - now := time.Now() - tapjArchiveDir := path.Join(*archiveBaseDir, now.Format("2006-01-02")) - os.MkdirAll(tapjArchiveDir, 0755) - r := rand.New(rand.NewSource(now.UnixNano())) - nonce := randomString(r, archiveNonceLexicon, archiveNonceLength) +func Run(env *Env) (tapjio.FinalEvent, error) { + var final tapjio.FinalEvent - tapjArchiveFilePath := path.Join(tapjArchiveDir, fmt.Sprintf("%d-%s.tapj", now.Unix(), nonce)) - tapjArchiveFile, err := os.Create(tapjArchiveFilePath) + var testRunners []runner.TestRunner + for _, runnerConfig := range env.RunnerConfigs { + em, err := emitter.Resolve(env.Server, env.WorkerEnvs, runnerConfig) + defer em.Close() if err != nil { - return err + fmt.Fprintf(os.Stderr, "Error! %v\n", err) + return final, err } - defer tapjArchiveFile.Close() - visitors = append(visitors, tapjio.NewTapjEmitter(tapjArchiveFile)) - } - if *saveTrace != "" { - traceFile, err := os.Create(*saveTrace) + traceEvents, runners, err := em.EnumerateTests() if err != nil { - return err + return final, err } - defer traceFile.Close() - visitors = append(visitors, tapjio.NewTraceWriter(traceFile)) - } - var stacktracesFile *os.File - if *saveStacktraces != "" { - stacktracesFile, err = os.Create(*saveStacktraces) - if err != nil { - log.Fatal(err) - } - defer stacktracesFile.Close() - } + testRunners = append(testRunners, runners...) - if *saveFlamegraph != "" || *saveIcegraph != "" { - if stacktracesFile == nil { - stacktracesFile, err = ioutil.TempFile("", "stacktrace") + visitor := env.Visitor + for _, traceEvent := range traceEvents { + err := visitor.TraceEvent(traceEvent) if err != nil { - log.Fatal(err) + return final, err } - defer stacktracesFile.Close() - defer os.Remove(stacktracesFile.Name()) } } - if stacktracesFile != nil { - visitors = append(visitors, tapjio.NewStacktraceEmitter(stacktracesFile)) - } - - if *saveFlamegraph != "" { - visitors = append(visitors, &tapjio.DecodingCallbacks{ - OnFinal: func(final tapjio.FinalEvent) error { - titleSuffix, _ := json.Marshal(flags.Args()) - options := []string{ - "--title", "Flame Graph — jobs = " + strconv.Itoa(*jobs) + ", args = " + string(titleSuffix), - "--minwidth=2", - } - if *savePalette != "" { - options = append(options, "--cp", "--palfile="+*savePalette) - } - - // There may be nothing to do if we didn't see any stacktrace data! - stacktraceFileInfo, err := stacktracesFile.Stat() - if err != nil { - return err - } - if stacktraceFileInfo.Size() == 0 { - return nil - } - - flamegraphFile, err := os.Create(*saveFlamegraph) - if err != nil { - log.Fatal(err) - } - defer flamegraphFile.Close() - - stacktracesFile.Seek(0, 0) - err = tapjio.GenerateFlameGraph( - stacktracesFile, - flamegraphFile, - options...) - if err != nil { - return err - } - return nil - }, - }) - } - - if *saveIcegraph != "" { - visitors = append(visitors, &tapjio.DecodingCallbacks{ - OnFinal: func(final tapjio.FinalEvent) error { - titleSuffix, _ := json.Marshal(flags.Args()) - options := []string{ - "--title", "Icicle Graph — jobs = " + strconv.Itoa(*jobs) + ", args = " + string(titleSuffix), - "--minwidth=2", - "--reverse", - "--inverted", - } - if *savePalette != "" { - options = append(options, "--cp", "--palfile="+*savePalette) - } - - // There may be nothing to do if we didn't see any stacktrace data! - stacktraceFileInfo, err := stacktracesFile.Stat() - if err != nil { - return err - } - if stacktraceFileInfo.Size() == 0 { - return nil - } + return suite.Run(env.Visitor, env.WorkerEnvs, env.Seed, testRunners) +} - icegraphFile, err := os.Create(*saveIcegraph) - if err != nil { - log.Fatal(err) - } - defer icegraphFile.Close() +func Main(env *cmd.Env, args []string) error { + flags := flag.NewFlagSet("run", flag.ContinueOnError) - stacktracesFile.Seek(0, 0) - err = tapjio.GenerateFlameGraph( - stacktracesFile, - icegraphFile, - options...) - if err != nil { - return err - } - return nil - }, - }) + f := DefineFlags(flags) + err := flags.Parse(args) + if err != nil { + return err } - visitor := tapjio.MultiVisitor(visitors) - - // srv, err := server.Listen("tcp", "127.0.0.1:0") - srv, err := server.Listen("unix", "/tmp/qa") + runEnv, err := f.NewEnv(env, flags.Args()) if err != nil { return err } + srv := runEnv.Server + defer srv.Close() + go srv.Run() + // Handle common process-killing signals so we can gracefully shut down: sigc := make(chan os.Signal, 1) signal.Notify(sigc, os.Interrupt, os.Kill, syscall.SIGTERM) go func(c chan os.Signal) { // Wait for signal - sig := <-c - fmt.Fprintln(env.Stderr, "Got signal:", sig) - srv.Close() - os.Exit(1) - }(sigc) - - defer srv.Close() - go srv.Run() - - workerEnvs := []map[string]string{} - for i := 0; i < *jobs; i++ { - workerEnvs = append(workerEnvs, - map[string]string{"QA_WORKER": fmt.Sprintf("%d", i)}) - } - - seed := int(rand.Int31()) - - // TODO(adamb) Parallelize this, after sanitizing name/globs specs. - var allRunners []runner.TestRunner - for _, runnerSpec := range flags.Args() { - runnerSpecSplit := strings.SplitN(runnerSpec, ":", 2) - var runnerName string - var globStr string - if len(runnerSpecSplit) != 2 { - runnerName = runnerSpecSplit[0] - globStr = emitter.DefaultGlob(runnerName) - } else { - runnerName = runnerSpecSplit[0] - globStr = runnerSpecSplit[1] + sig, ok := <-c + if ok { + fmt.Fprintln(env.Stderr, "Got signal:", sig) + srv.Close() } - - var files []string - - for _, glob := range strings.Split(globStr, ":") { - relative := !filepath.IsAbs(glob) - if relative && env.Dir != "" { - glob = filepath.Join(env.Dir, glob) - } - - globFiles, err := zglob.Glob(glob) - if !*quiet { - fmt.Fprintf(env.Stderr, "Resolved %v to %v\n", glob, globFiles) - if err != nil { - return err - } - } - - if relative && env.Dir != "" { - trimPrefix := fmt.Sprintf("%s%c", env.Dir, os.PathSeparator) - for _, file := range globFiles { - files = append(files, strings.TrimPrefix(file, trimPrefix)) - } - } else { - files = append(files, globFiles...) - } - } - - passthrough := map[string](interface{}){ - "warmup": *warmup, - "errorsCaptureLocals": *errorsCaptureLocals, - "captureStandardFds": *captureStandardFds, - "evalBeforeFork": *evalBeforeFork, - "evalAfterFork": *evalAfterFork, - "sampleStack": *sampleStack, - } - - em, err := emitter.Resolve(runnerName, srv, passthrough, workerEnvs, env.Dir, env.Vars, seed, files) - if err != nil { - return err - } - - traceEvents, runners, err := em.EnumerateTests() - if err != nil { - return err - } - - allRunners = append(allRunners, runners...) - - for _, traceEvent := range traceEvents { - err := visitor.TraceEvent(traceEvent) - if err != nil { - return err - } - } - } - - suite := suite.NewTestSuiteRunner(seed, srv, allRunners) + }(sigc) + defer signal.Stop(sigc) + defer close(sigc) var final tapjio.FinalEvent - - final, err = suite.Run(workerEnvs, visitor) + final, err = Run(runEnv) if err != nil { - fmt.Fprintln(env.Stderr, "Error in NewTestSuiteRunner", err) - if exitError, ok := err.(*exec.ExitError); ok { - if len(exitError.Stderr) > 0 { - fmt.Fprintln(env.Stderr, string(exitError.Stderr)) - } - } - return err } diff --git a/src/qa/emitter/emitter.go b/src/qa/emitter/emitter.go index 20b4095..375f8d6 100644 --- a/src/qa/emitter/emitter.go +++ b/src/qa/emitter/emitter.go @@ -9,54 +9,28 @@ import ( ) type Emitter interface { - TraceProbes() []string EnumerateTests() ([]tapjio.TraceEvent, []runner.TestRunner, error) + Close() error } type emitterStarter func( srv *server.Server, - passthroughConfig map[string](interface{}), workerEnvs []map[string]string, - dir string, - env map[string]string, - seed int, - files []string) (Emitter, error) + runnerConfig runner.Config) (Emitter, error) -// Enable entries below to add specific method calls (and optionally their arguments) to the trace. -var rubyTraceProbes = []string{ -// "Kernel#require(path)", -// "Kernel#load", -// "ActiveRecord::ConnectionAdapters::Mysql2Adapter#execute(sql,name)", -// "ActiveRecord::ConnectionAdapters::PostgresSQLAdapter#execute_and_clear(sql,name,binds)", -// "ActiveSupport::Dependencies::Loadable#require(path)", -// "ActiveRecord::ConnectionAdapters::QueryCache#clear_query_cache", -// "ActiveRecord::ConnectionAdapters::SchemaCache#initialize", -// "ActiveRecord::ConnectionAdapters::SchemaCache#clear!", -// "ActiveRecord::ConnectionAdapters::SchemaCache#clear_table_cache!", -} - -func rubyEmitterStarter(runnerAssetName string, policy ruby.SquashPolicy) emitterStarter { +func rubyEmitterStarter(runnerAssetName string) emitterStarter { return func( srv *server.Server, - passthroughConfig map[string](interface{}), workerEnvs []map[string]string, - dir string, - env map[string]string, - seed int, - files []string) (Emitter, error) { + runnerConfig runner.Config) (Emitter, error) { config := &ruby.ContextConfig{ - Dir: dir, - EnvVars: env, - Seed: seed, - Rubylib: []string{"spec", "lib", "test"}, - RunnerAssetName: runnerAssetName, - TraceProbes: rubyTraceProbes, - SquashPolicy: policy, - PassthroughConfig: passthroughConfig, + RunnerConfig: runnerConfig, + Rubylib: []string{"spec", "lib", "test"}, + RunnerAssetName: runnerAssetName, } - ctx, err := ruby.StartContext(config, srv, workerEnvs, files) + ctx, err := ruby.StartContext(config, srv, workerEnvs) if err != nil { return nil, err } @@ -66,27 +40,15 @@ func rubyEmitterStarter(runnerAssetName string, policy ruby.SquashPolicy) emitte } var starters = map[string]emitterStarter{ - "rspec": rubyEmitterStarter("ruby/rspec.rb", ruby.SquashByFile), - "rspec-squashall": rubyEmitterStarter("ruby/rspec.rb", ruby.SquashAll), - "rspec-pendantic": rubyEmitterStarter("ruby/rspec.rb", ruby.SquashNothing), - "minitest": rubyEmitterStarter("ruby/minitest.rb", ruby.SquashByFile), - "minitest-squashall": rubyEmitterStarter("ruby/minitest.rb", ruby.SquashAll), - "minitest-pendantic": rubyEmitterStarter("ruby/minitest.rb", ruby.SquashNothing), - "test-unit": rubyEmitterStarter("ruby/test-unit.rb", ruby.SquashByFile), - "test-unit-squashall": rubyEmitterStarter("ruby/test-unit.rb", ruby.SquashAll), - "test-unit-pendantic": rubyEmitterStarter("ruby/test-unit.rb", ruby.SquashNothing), + "rspec": rubyEmitterStarter("ruby/rspec.rb"), + "minitest": rubyEmitterStarter("ruby/minitest.rb"), + "test-unit": rubyEmitterStarter("ruby/test-unit.rb"), } -var defaultGlobs = map[string]string { - "rspec": "spec/**/*spec.rb", - "rspec-squashall": "spec/**/*spec.rb", - "rspec-pendantic": "spec/**/*spec.rb", - "minitest": "test/**/test*.rb", - "minitest-squashall": "test/**/test*.rb", - "minitest-pendantic": "test/**/test*.rb", +var defaultGlobs = map[string]string{ + "rspec": "spec/**/*spec.rb", + "minitest": "test/**/test*.rb", "test-unit": "test/**/test*.rb", - "test-unit-squashall": "test/**/test*.rb", - "test-unit-pendantic": "test/**/test*.rb", } func DefaultGlob(name string) string { @@ -94,18 +56,13 @@ func DefaultGlob(name string) string { } func Resolve( - name string, srv *server.Server, - passthroughConfig map[string](interface{}), workerEnvs []map[string]string, - dir string, - env map[string]string, - seed int, - files []string) (Emitter, error) { - starter, ok := starters[name] + config runner.Config) (Emitter, error) { + starter, ok := starters[config.Name] if !ok { - return nil, errors.New("Could not find starter: " + name) + return nil, errors.New("Could not find starter: " + config.Name) } - return starter(srv, passthroughConfig, workerEnvs, dir, env, seed, files) + return starter(srv, workerEnvs, config) } diff --git a/src/qa/main/main.go b/src/qa/main/main.go index c4fe7a6..1eb5309 100644 --- a/src/qa/main/main.go +++ b/src/qa/main/main.go @@ -4,39 +4,42 @@ import ( "errors" "fmt" "os" + "os/exec" "qa/cmd" + "qa/cmd/auto" "qa/cmd/discover" "qa/cmd/flaky" - "qa/cmd/grouping" - "qa/cmd/summary" "qa/cmd/flamegraph" + "qa/cmd/grouping" "qa/cmd/run" "qa/cmd/stackcollapse" + "qa/cmd/summary" ) func main() { var status int command := os.Args[1] + env := &cmd.Env{Stdin: os.Stdin, Stdout: os.Stdout, Stderr: os.Stderr} + var err error switch command { case "flaky": - env := &cmd.Env{Stdin: os.Stdin, Stdout: os.Stdout, Stderr: os.Stderr} err = flaky.Main(env, os.Args[2:]) case "discover": - env := &cmd.Env{Stdin: os.Stdin, Stdout: os.Stdout, Stderr: os.Stderr} err = discover.Main(env, os.Args[2:]) case "grouping": - env := &cmd.Env{Stdin: os.Stdin, Stdout: os.Stdout, Stderr: os.Stderr} err = grouping.Main(env, os.Args[2:]) case "summary": - env := &cmd.Env{Stdin: os.Stdin, Stdout: os.Stdout, Stderr: os.Stderr} err = summary.Main(env, os.Args[2:]) + case "auto": + err = auto.Main(env, os.Args[2:]) case "run": - env := &cmd.Env{Stdin: os.Stdin, Stdout: os.Stdout, Stderr: os.Stderr} err = run.Main(env, os.Args[2:]) case "flamegraph": + // TODO(adamb) Switch flamegraph to use env arg err = flamegraph.Main(os.Args[2:]) case "stackcollapse": + // TODO(adamb) Switch stackcollapse to use env arg err = stackcollapse.Main(os.Args[2:]) default: err = errors.New("Unknown command: " + command) @@ -45,6 +48,12 @@ func main() { if err != nil { fmt.Fprintln(os.Stderr, err) status = 1 + + if exitError, ok := err.(*exec.ExitError); ok { + if len(exitError.Stderr) > 0 { + fmt.Fprintln(env.Stderr, string(exitError.Stderr)) + } + } } else { status = 0 } diff --git a/src/qa/runner/ruby/ruby.go b/src/qa/runner/ruby/ruby.go index 16f6a89..e897384 100644 --- a/src/qa/runner/ruby/ruby.go +++ b/src/qa/runner/ruby/ruby.go @@ -12,23 +12,10 @@ import ( "sync" ) -type SquashPolicy int - -const ( - SquashNothing SquashPolicy = iota - SquashByFile - SquashAll -) - type ContextConfig struct { - Seed int - EnvVars map[string]string - Dir string - Rubylib []string - RunnerAssetName string - TraceProbes []string - SquashPolicy SquashPolicy - PassthroughConfig map[string](interface{}) + RunnerConfig runner.Config + Rubylib []string + RunnerAssetName string } type context struct { @@ -38,9 +25,16 @@ type context struct { config *ContextConfig } -func StartContext(cfg *ContextConfig, server *server.Server, workerEnvs []map[string]string, files []string) (*context, error) { +func StartContext(cfg *ContextConfig, server *server.Server, workerEnvs []map[string]string) (*context, error) { requestCh := make(chan interface{}, 1) + runnerCfg := cfg.RunnerConfig + + files, err := runnerCfg.Files() + if err != nil { + return nil, err + } + sharedData, err := assets.Asset("ruby/shared.rb") if err != nil { return nil, err @@ -67,18 +61,18 @@ func StartContext(cfg *ContextConfig, server *server.Server, workerEnvs []map[st requestCh <- map[string](interface{}){ "workerEnvs": workerEnvs, "files": files, - "passthrough": cfg.PassthroughConfig, + "passthrough": runnerCfg.PassthroughConfig, } - for _, traceProbe := range cfg.TraceProbes { + for _, traceProbe := range runnerCfg.TraceProbes { args = append(args, "--trace-probe", traceProbe) } cmd := exec.Command("ruby", args...) - if len(cfg.EnvVars) > 0 { + if len(runnerCfg.EnvVars) > 0 { baseEnv := os.Environ() - for envVarName, envVarValue := range cfg.EnvVars { + for envVarName, envVarValue := range runnerCfg.EnvVars { baseEnv = append(baseEnv, fmt.Sprintf("%s=%s", envVarName, envVarValue)) } cmd.Env = baseEnv @@ -88,7 +82,7 @@ func StartContext(cfg *ContextConfig, server *server.Server, workerEnvs []map[st // args = append([]string{"-ex=set follow-fork-mode child", "-ex=r", "--args", "ruby"}, args...) // cmd := exec.Command("gdb", args...) - cmd.Dir = cfg.Dir + cmd.Dir = runnerCfg.Dir cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr cmd.Stdout = os.Stderr @@ -107,23 +101,18 @@ func StartContext(cfg *ContextConfig, server *server.Server, workerEnvs []map[st // TODO(adamb) Should also cancel all existing waitgroups func (self *context) Close() (err error) { - s := *self.server + close(self.requestCh) + if self.process != nil { err = self.process.Kill() - } - - closeErr := s.Close() - if closeErr != nil { - err = closeErr + if err != nil { + _, err = self.process.Wait() + } } return } -func (self *context) TraceProbes() []string { - return self.config.TraceProbes -} - func (self *context) EnumerateTests() (traceEvents []tapjio.TraceEvent, testRunners []runner.TestRunner, err error) { var wg sync.WaitGroup wg.Add(1) @@ -136,9 +125,10 @@ func (self *context) EnumerateTests() (traceEvents []tapjio.TraceEvent, testRunn return }, OnTest: func(test tapjio.TestEvent) error { - if cfg.SquashPolicy == SquashNothing || - cfg.SquashPolicy == SquashByFile && (currentRunner == nil || currentRunner.file != test.File) || - cfg.SquashPolicy == SquashAll && currentRunner == nil { + squashPolicy := cfg.RunnerConfig.SquashPolicy + if squashPolicy == runner.SquashNothing || + squashPolicy == runner.SquashByFile && (currentRunner == nil || currentRunner.file != test.File) || + squashPolicy == runner.SquashAll && currentRunner == nil { if currentRunner != nil { testRunners = append(testRunners, *currentRunner) } @@ -164,7 +154,7 @@ func (self *context) EnumerateTests() (traceEvents []tapjio.TraceEvent, testRunn map[string]string{}, []string{ "--dry-run", - "--seed", fmt.Sprintf("%v", cfg.Seed), + "--seed", fmt.Sprintf("%v", cfg.RunnerConfig.Seed), "--tapj-sink", serverAddress, }) wg.Wait() @@ -264,7 +254,7 @@ func (self rubyRunner) Run(env map[string]string, callbacks tapjio.DecodingCallb self.ctx.request( env, append([]string{ - "--seed", fmt.Sprintf("%v", self.ctx.config.Seed), + "--seed", fmt.Sprintf("%v", self.ctx.config.RunnerConfig.Seed), "--tapj-sink", self.ctx.subscribeVisitor(&callbacks), }, self.filters...)) diff --git a/src/qa/runner/runner.go b/src/qa/runner/runner.go index ae96d83..9b2930f 100644 --- a/src/qa/runner/runner.go +++ b/src/qa/runner/runner.go @@ -1,8 +1,14 @@ package runner import ( + "fmt" + "os" + "path/filepath" "qa/tapjio" "sort" + "strings" + + "github.com/mattn/go-zglob" ) //go:generate go-bindata -o $GOGENPATH/qa/runner/assets/bindata.go -pkg assets -prefix ../runner-assets/ ../runner-assets/... @@ -35,3 +41,79 @@ func (s *testRunnerSorter) Swap(i, j int) { func (s *testRunnerSorter) Less(i, j int) bool { return s.by(&s.runners[i], &s.runners[j]) } + +type FileGlob struct { + dir string + patterns []string +} + +func NewFileGlob(dir string, patterns []string) *FileGlob { + return &FileGlob{dir: dir, patterns: patterns} +} + +type FileLister interface { + Patterns() []string + Dir() string + ListFiles() ([]string, error) +} + +func (f *FileGlob) Dir() string { + return f.dir +} + +func (f *FileGlob) Patterns() []string { + return f.patterns +} + +func (f *FileGlob) ListFiles() ([]string, error) { + var files []string + dir := f.dir + for _, pattern := range f.patterns { + // Make glob absolute, using dir + relative := !filepath.IsAbs(pattern) + if relative && dir != "" { + pattern = filepath.Join(dir, pattern) + } + + // Expand glob + globFiles, err := zglob.Glob(pattern) + if err != nil { + return files, err + } + + // Strip prefix from glob matches if needed. + if relative && dir != "" { + trimPrefix := fmt.Sprintf("%s%c", dir, os.PathSeparator) + for _, file := range globFiles { + files = append(files, strings.TrimPrefix(file, trimPrefix)) + } + } else { + files = append(files, globFiles...) + } + } + + return files, nil +} + +type SquashPolicy int + +const ( + SquashNothing SquashPolicy = iota + SquashByFile + SquashAll +) + +type Config struct { + Name string + FileLister FileLister + PassthroughConfig map[string](interface{}) + Dir string + EnvVars map[string]string + Seed int + SquashPolicy SquashPolicy + TraceProbes []string +} + +func (f *Config) Files() ([]string, error) { + return f.FileLister.ListFiles() +} diff --git a/src/qa/runner/server/server.go b/src/qa/runner/server/server.go index 1bbf5b1..3117d1c 100644 --- a/src/qa/runner/server/server.go +++ b/src/qa/runner/server/server.go @@ -84,6 +84,9 @@ func (s *Server) Run() error { } }() + defer close(s.registerCallbackChan) + defer close(s.registerChannelChan) + for acceptChan != nil { select { case entry := <-s.registerCallbackChan: diff --git a/src/qa/suite/suite.go b/src/qa/suite/suite.go index d19fb44..1b42f32 100644 --- a/src/qa/suite/suite.go +++ b/src/qa/suite/suite.go @@ -8,7 +8,6 @@ import ( "time" "qa/runner" - "qa/runner/server" "qa/tapjio" ) @@ -19,38 +18,21 @@ type eventUnion struct { error error } -type testSuiteRunner struct { - seed int - runners []runner.TestRunner - count int - srv *server.Server -} - -func NewTestSuiteRunner(seed int, - srv *server.Server, - runners []runner.TestRunner) *testSuiteRunner { +func Run( + visitor tapjio.Visitor, + workerEnvs []map[string]string, + seed int, + runners []runner.TestRunner) (final tapjio.FinalEvent, err error) { count := 0 for _, runner := range runners { count += runner.TestCount() } - return &testSuiteRunner{ - seed: seed, - runners: runners, - count: count, - srv: srv, - } -} - -func (self *testSuiteRunner) Run( - workerEnvs []map[string]string, - visitor tapjio.Visitor) (final tapjio.FinalEvent, err error) { - numWorkers := len(workerEnvs) startTime := time.Now().UTC() - suite := tapjio.NewSuiteEvent(startTime, self.count, self.seed) + suite := tapjio.NewSuiteEvent(startTime, count, seed) final = *tapjio.NewFinalEvent(suite) err = visitor.SuiteStarted(*suite) @@ -75,9 +57,9 @@ func (self *testSuiteRunner) Run( // near the end of the run by running testRunners with the most tests first, avoiding // scenarios where the last testRunner we run has many tests, causing the entire test // run to drag on needlessly while other workers are idle. - runner.By(func(r1, r2 *runner.TestRunner) bool { return (*r2).TestCount() < (*r1).TestCount() }).Sort(self.runners) + runner.By(func(r1, r2 *runner.TestRunner) bool { return (*r2).TestCount() < (*r1).TestCount() }).Sort(runners) - for _, testRunner := range self.runners { + for _, testRunner := range runners { testRunnerChan <- testRunner } close(testRunnerChan) diff --git a/src/qa/tapjio/flamegraph.go b/src/qa/tapjio/flamegraph.go index ac11a0a..fa37eaa 100644 --- a/src/qa/tapjio/flamegraph.go +++ b/src/qa/tapjio/flamegraph.go @@ -53,12 +53,14 @@ func emitSupportAsset(assetName string) (string, error) { } // GenerateFlameGraph runs the flamegraph script to generate a flame graph SVG. -func GenerateFlameGraph(stacktraceReader io.Reader, writer io.Writer, args ...string) error { +func GenerateFlameGraph(stacktraceReader io.Reader, writer io.WriteCloser, args ...string) error { flamegraphPl, err := emitSupportAsset("flamegraph.pl") if err != nil { return err } + defer writer.Close() + cmd := exec.Command("perl", append([]string{flamegraphPl}, args...)...) cmd.Stdin = stacktraceReader cmd.Stderr = os.Stderr diff --git a/src/qa/tapjio/stacktrace.go b/src/qa/tapjio/stacktrace.go index f7fa9e6..39bf40a 100644 --- a/src/qa/tapjio/stacktrace.go +++ b/src/qa/tapjio/stacktrace.go @@ -89,6 +89,7 @@ func (s *stacktraceWriter) EmitStacktrace(key string, weight int) error { type stacktraceEmitter struct { writer io.Writer + closer io.Closer } type encodedProfile struct { @@ -167,6 +168,13 @@ func decodeFlamegraphSample(writer io.Writer, b []byte) error { return nil } +func NewStacktraceEmitCloser(writer io.WriteCloser) *stacktraceEmitter { + return &stacktraceEmitter{ + writer: writer, + closer: writer, + } +} + func NewStacktraceEmitter(writer io.Writer) *stacktraceEmitter { return &stacktraceEmitter{ writer: writer, @@ -205,5 +213,8 @@ func (t *stacktraceEmitter) TestFinished(event TestEvent) error { } func (t *stacktraceEmitter) End(reason error) error { + if t.closer != nil { + return t.closer.Close() + } return nil } diff --git a/src/qa/tapjio/tapj.go b/src/qa/tapjio/tapj.go index 60c4b8d..1611ed0 100644 --- a/src/qa/tapjio/tapj.go +++ b/src/qa/tapjio/tapj.go @@ -6,12 +6,17 @@ import ( ) type tapj struct { + closer io.Closer encoder *json.Encoder currentCases []CaseEvent } +func NewTapjEmitCloser(writer io.WriteCloser) *tapj { + return &tapj{encoder: json.NewEncoder(writer), closer: writer} +} + func NewTapjEmitter(writer io.Writer) *tapj { - return &tapj{encoder: json.NewEncoder(writer)} + return &tapj{encoder: json.NewEncoder(writer), closer: nil} } func (t *tapj) TraceEvent(event TraceEvent) error { @@ -64,5 +69,9 @@ func (t *tapj) SuiteFinished(event FinalEvent) error { } func (t *tapj) End(reason error) error { + if t.closer != nil { + return t.closer.Close() + } + return nil } diff --git a/src/qa/tapjio/trace.go b/src/qa/tapjio/trace.go index 7c4bf3e..86b8110 100644 --- a/src/qa/tapjio/trace.go +++ b/src/qa/tapjio/trace.go @@ -9,12 +9,12 @@ import ( // Extracts trace events from a tapj stream. type TraceWriter struct { - writer io.Writer + writer io.WriteCloser encoder *json.Encoder delim string } -func NewTraceWriter(writer io.Writer) *TraceWriter { +func NewTraceWriter(writer io.WriteCloser) *TraceWriter { return &TraceWriter{ writer: writer, encoder: json.NewEncoder(writer), @@ -57,5 +57,5 @@ func (t *TraceWriter) TestFinished(event TestEvent) error { } func (t *TraceWriter) End(reason error) error { - return nil + return t.writer.Close() }