Skip to content

Commit

Permalink
Simplify the implementation of the Group.
Browse files Browse the repository at this point in the history
Remove the separate goroutine collecting errors, and deliver them directly to
the error filter and the output field. Moreover simplify the setup and teardown
so that there is not so much coordinated state. Although performance was not a
primary consideration, benchmarking suggests this is actually faster than the
previous implementation, and uses less memory.

Also expand and clarify the documentation of the Wait method.

Co-Authored-By: David Anderson <[email protected]>
  • Loading branch information
creachadair and danderson committed Mar 19, 2024
1 parent 3287b54 commit 0ed7876
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 59 deletions.
4 changes: 2 additions & 2 deletions collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ func NewCollector[T any](value func(T)) *Collector[T] { return &Collector[T]{han
// Wait waits until the collector has finished processing.
//
// Deprecated: This method is now a noop; it is safe but unnecessary to call
// it. The state serviced by c is settled once all the goroutines writing to
// the collector have returned. It may be removed in a future version.
// it. Once all the tasks created from c have returned, any state accessed by
// the accumulator is settled. Wait may be removed in a future version.
func (c *Collector[T]) Wait() {}

// Task returns a Task wrapping a call to f. If f reports an error, that error
Expand Down
114 changes: 57 additions & 57 deletions taskgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,44 @@
// and respond to task errors.
package taskgroup

import "sync"
import (
"sync"
"sync/atomic"
)

// A Task function is the basic unit of work in a Group. Errors reported by
// tasks are collected and reported by the group.
type Task func() error

// A Group manages a collection of cooperating goroutines. New tasks are added
// to the group with the Go method. Call the Wait method to wait for the tasks
// to complete. A zero value is ready for use, but must not be copied after its
// A Group manages a collection of cooperating goroutines. Add new tasks to
// the group with the Go method. Call the Wait method to wait for the tasks to
// complete. A zero value is ready for use, but must not be copied after its
// first use.
//
// The group collects any errors returned by the tasks in the group. The first
// non-nil error reported by any task (and not otherwise filtered) is returned
// from the Wait method.
type Group struct {
wg sync.WaitGroup // counter for active goroutines
err error // error returned from Wait
wg sync.WaitGroup // counter for active goroutines
onError ErrorFunc // called each time a task returns non-nil

setup sync.Once // set up and start the error collector
reset sync.Once // stop the error collector and set err
// active is nonzero when the group is "active", meaning there has been at
// least one call to Go since the group was created or the last Wait.
active atomic.Uint32

onError func(error) error // called each time a task returns non-nil
errc chan<- error // errors generated by goroutines
edone chan struct{} // signals error completion
μ sync.Mutex // guards err
err error // error returned from Wait
}

// activate resets the state of the group and marks it as active. This is
// triggered by adding a goroutine to an empty group.
func (g *Group) activate() {
g.μ.Lock()
defer g.μ.Unlock()
if g.active.Load() == 0 { // still inactive
g.err = nil
g.active.Store(1)
}
}

// New constructs a new empty group. If ef != nil, it is called for each error
Expand All @@ -40,72 +54,58 @@ func New(ef ErrorFunc) *Group { return &Group{onError: ef} }

// Go runs task in a new goroutine in g, and returns g to permit chaining.
func (g *Group) Go(task Task) *Group {
if g.active.Load() == 0 {
g.activate()
}
g.wg.Add(1)
g.init()
errc := g.errc
go func() {
defer g.wg.Done()
if err := task(); err != nil {
errc <- err
g.handleError(err)
}
}()
return g
}

func (g *Group) init() {
// The first time a task is added to an otherwise clear group, set up the
// error collector goroutine. We don't do this in the constructor so that
// an unused group can be abandoned without orphaning a goroutine.
g.setup.Do(func() {
if g.onError == nil {
g.onError = func(e error) error { return e }
}
g.err = nil
g.edone = make(chan struct{})
g.reset = sync.Once{}

errc := make(chan error)
g.errc = errc
go func() {
defer close(g.edone)
for err := range errc {
e := g.onError(err)
if e != nil && g.err == nil {
g.err = e // capture the first error always
}
}
}()
})
}

func (g *Group) cleanup() {
g.reset.Do(func() {
g.wg.Wait()
if g.errc == nil {
return
}
close(g.errc)
<-g.edone
g.errc = nil
g.setup = sync.Once{}
})
func (g *Group) handleError(err error) {
g.μ.Lock()
defer g.μ.Unlock()
e := g.onError.filter(err)
if e != nil && g.err == nil {
g.err = e // capture the first unfiltered error always
}
}

// Wait blocks until all the goroutines currently active in the group have
// returned, and all reported errors have been delivered to the callback.
// It returns the first non-nil error returned by any of the goroutines in the
// returned, and all reported errors have been delivered to the callback. It
// returns the first non-nil error reported by any of the goroutines in the
// group and not filtered by an ErrorFunc.
//
// It is safe to call Wait concurrently from multiple goroutines, but as with
// sync.WaitGroup no tasks can be added to g while any call to Wait is in
// progress. Once all Wait calls have returned, the group is ready for reuse.
func (g *Group) Wait() error { g.cleanup(); return g.err }
// As with sync.WaitGroup, new tasks can be added to g during a call to Wait
// only if there was already at least one task active when Wait was called.
// After Wait has returned, the group is ready for reuse.
//
// Wait may be called from at most one goroutine at a time.
func (g *Group) Wait() error {
g.wg.Wait()
g.μ.Lock()
defer g.μ.Unlock()
defer g.active.Store(0)
return g.err
}

// An ErrorFunc is called by a group each time a task reports an error. Its
// return value replaces the reported error, so the ErrorFunc can filter or
// suppress errors by modifying or discarding the input error.
type ErrorFunc func(error) error

func (ef ErrorFunc) filter(err error) error {
if ef == nil {
return err
}
return ef(err)
}

// Trigger creates an ErrorFunc that calls f each time a task reports an error.
// The resulting ErrorFunc returns task errors unmodified.
func Trigger(f func()) ErrorFunc { return func(e error) error { f(); return e } }
Expand Down

0 comments on commit 0ed7876

Please sign in to comment.