diff --git a/taskgroup.go b/taskgroup.go index e71d676..74faf91 100644 --- a/taskgroup.go +++ b/taskgroup.go @@ -10,24 +10,21 @@ import "sync" // tasks are collected and reported by the group. type Task func() error -// A Group manages a collection of cooperating goroutines. New tasks are added -// to the group with the Go method. Call the Wait method to wait for the tasks -// to complete. A zero value is ready for use, but must not be copied after its +// A Group manages a collection of cooperating goroutines. Add new tasks to +// the group with the Go method. Call the Wait method to wait for the tasks to +// complete. A zero value is ready for use, but must not be copied after its // first use. // // The group collects any errors returned by the tasks in the group. The first // non-nil error reported by any task (and not otherwise filtered) is returned // from the Wait method. type Group struct { - wg sync.WaitGroup // counter for active goroutines - err error // error returned from Wait + wg sync.WaitGroup // counter for active goroutines + onError ErrorFunc // called each time a task returns non-nil - setup sync.Once // set up and start the error collector - reset sync.Once // stop the error collector and set err - - onError func(error) error // called each time a task returns non-nil - errc chan<- error // errors generated by goroutines - edone chan struct{} // signals error completion + μ sync.Mutex // guards the fields below + setup sync.Once // set up and start the error collector + err error // error returned from Wait } // New constructs a new empty group. If ef != nil, it is called for each error @@ -42,54 +39,25 @@ func New(ef ErrorFunc) *Group { return &Group{onError: ef} } func (g *Group) Go(task Task) *Group { g.wg.Add(1) g.init() - errc := g.errc go func() { defer g.wg.Done() if err := task(); err != nil { - errc <- err + g.handleError(err) } }() return g } -func (g *Group) init() { - // The first time a task is added to an otherwise clear group, set up the - // error collector goroutine. We don't do this in the constructor so that - // an unused group can be abandoned without orphaning a goroutine. - g.setup.Do(func() { - if g.onError == nil { - g.onError = func(e error) error { return e } - } - g.err = nil - g.edone = make(chan struct{}) - g.reset = sync.Once{} - - errc := make(chan error) - g.errc = errc - go func() { - defer close(g.edone) - for err := range errc { - e := g.onError(err) - if e != nil && g.err == nil { - g.err = e // capture the first error always - } - } - }() - }) +func (g *Group) handleError(err error) { + g.μ.Lock() + defer g.μ.Unlock() + e := g.onError.filter(err) + if e != nil && g.err == nil { + g.err = e // capture the first unfiltered error always + } } -func (g *Group) cleanup() { - g.reset.Do(func() { - g.wg.Wait() - if g.errc == nil { - return - } - close(g.errc) - <-g.edone - g.errc = nil - g.setup = sync.Once{} - }) -} +func (g *Group) init() { g.setup.Do(func() { g.err = nil }) } // Wait blocks until all the goroutines currently active in the group have // returned, and all reported errors have been delivered to the callback. @@ -100,13 +68,26 @@ func (g *Group) cleanup() { // sync.WaitGroup, new tasks can be added to the group only if there is at // least one task active that started before all active Wait calls. Once all // Wait calls have returned, the group is ready for reuse. -func (g *Group) Wait() error { g.cleanup(); return g.err } +func (g *Group) Wait() error { + g.wg.Wait() + g.μ.Lock() + defer g.μ.Unlock() + g.setup = sync.Once{} + return g.err +} // An ErrorFunc is called by a group each time a task reports an error. Its // return value replaces the reported error, so the ErrorFunc can filter or // suppress errors by modifying or discarding the input error. type ErrorFunc func(error) error +func (ef ErrorFunc) filter(err error) error { + if ef == nil { + return err + } + return ef(err) +} + // Trigger creates an ErrorFunc that calls f each time a task reports an error. // The resulting ErrorFunc returns task errors unmodified. func Trigger(f func()) ErrorFunc { return func(e error) error { f(); return e } }