Skip to content

Commit

Permalink
GODRIVER-2929 Add the ability to join multiple errors into one.
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewdale committed Sep 2, 2023
1 parent 2f372fd commit 9396361
Show file tree
Hide file tree
Showing 4 changed files with 359 additions and 0 deletions.
79 changes: 79 additions & 0 deletions internal/errutil/join.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package errutil

import "errors"

// join is a Go 1.13-1.19 compatible version of [errors.Join]. It is only called
// by Join in join_go1.19.go. It is included here in a file without build
// constraints only for testing purposes.
//
// It is heavily based on Join from
// https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go
func join(errs ...error) error {
n := 0
for _, err := range errs {
if err != nil {
n++
}
}
if n == 0 {
return nil
}
e := &joinError{
errs: make([]error, 0, n),
}
for _, err := range errs {
if err != nil {
e.errs = append(e.errs, err)
}
}
return e
}

// joinError is a Go 1.13-1.19 compatible joinable error type. Its error
// message is identical to [errors.Join], but it implements "Unwrap() error"
// instead of "Unwrap() []error".
//
// It is heavily based on the joinError from
// https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go
type joinError struct {
errs []error
}

func (e *joinError) Error() string {
var b []byte
for i, err := range e.errs {
if i > 0 {
b = append(b, '\n')
}
b = append(b, err.Error()...)
}
return string(b)
}

// Unwrap returns another joinError with the same errors as the current
// joinError except the first error in the slice. Continuing to call Unwrap
// on each returned error will increment through every error in the slice. The
// resulting behavior when using [errors.Is] and [errors.As] is similar to an
// error created using [errors.Join] in Go 1.20+.
func (e *joinError) Unwrap() error {
if len(e.errs) == 1 {
return e.errs[0]
}
return &joinError{errs: e.errs[1:]}
}

// Is calls [errors.Is] with the first error in the slice.
func (e *joinError) Is(target error) bool {
if len(e.errs) == 0 {
return false
}
return errors.Is(e.errs[0], target)
}

// As calls [errors.As] with the first error in the slice.
func (e *joinError) As(target interface{}) bool {
if len(e.errs) == 0 {
return false
}
return errors.As(e.errs[0], target)
}
20 changes: 20 additions & 0 deletions internal/errutil/join_go1.19.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

//go:build !go1.20
// +build !go1.20

package errutil

// Join returns an error that wraps the given errors. Any nil error values are
// discarded. Join returns nil if every value in errs is nil. The error formats
// as the concatenation of the strings obtained by calling the Error method of
// each element of errs, with a newline between each string.
//
// A non-nil error returned by Join implements the "Unwrap() error" method.
func Join(errs ...error) error {
return join(errs...)
}
17 changes: 17 additions & 0 deletions internal/errutil/join_go1.20.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

//go:build go1.20
// +build go1.20

package errutil

import "errors"

// Join calls [errors.Join].
func Join(errs ...error) error {
return errors.Join(errs...)
}
243 changes: 243 additions & 0 deletions internal/errutil/join_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package errutil

import (
"context"
"errors"
"fmt"
"testing"

"go.mongodb.org/mongo-driver/internal/assert"
)

// TestJoin_Nil asserts that join returns a nil error for the same inputs that
// [errors.Join] returns a nil error.
func TestJoin_Nil(t *testing.T) {
t.Parallel()

assert.Equal(t, errors.Join(), join(), "errors.Join() != join()")
assert.Equal(t, errors.Join(nil), join(nil), "errors.Join(nil) != join(nil)")
assert.Equal(t, errors.Join(nil, nil), join(nil, nil), "errors.Join(nil, nil) != join(nil, nil)")
}

// TestJoin_Error asserts that join returns an error with the same error message
// as the error returned by [errors.Join].
func TestJoin_Error(t *testing.T) {
t.Parallel()

err1 := errors.New("err1")
err2 := errors.New("err2")

tests := []struct {
desc string
errs []error
}{{
desc: "single error",
errs: []error{err1},
}, {
desc: "two errors",
errs: []error{err1, err2},
}, {
desc: "two errors and a nil value",
errs: []error{err1, nil, err2},
}}

for _, test := range tests {
test := test // Capture range variable.

t.Run(test.desc, func(t *testing.T) {
t.Parallel()

want := errors.Join(test.errs...).Error()
got := join(test.errs...).Error()
assert.Equal(t,
want,
got,
"errors.Join().Error() != join().Error() for input %v",
test.errs)
})
}
}

// TestJoin_ErrorsIs asserts that join returns an error that behaves identically
// to the error returned by [errors.Join] when passed to [errors.Is].
func TestJoin_ErrorsIs(t *testing.T) {
t.Parallel()

err1 := errors.New("err1")
err2 := errors.New("err2")

tests := []struct {
desc string
errs []error
target error
}{{
desc: "one error with a matching target",
errs: []error{err1},
target: err1,
}, {
desc: "one error with a non-matching target",
errs: []error{err1},
target: err2,
}, {
desc: "nil error",
errs: []error{nil},
target: err1,
}, {
desc: "no errors",
errs: []error{},
target: err1,
}, {
desc: "two different errors with a matching target",
errs: []error{err1, err2},
target: err2,
}, {
desc: "two identical errors with a matching target",
errs: []error{err1, err1},
target: err1,
}, {
desc: "wrapped error with a matching target",
errs: []error{fmt.Errorf("error: %w", err1)},
target: err1,
}, {
desc: "nested joined error with a matching target",
errs: []error{err1, join(err2, errors.New("nope"))},
target: err2,
}, {
desc: "nested joined error with no matching targets",
errs: []error{err1, join(errors.New("nope"), errors.New("nope 2"))},
target: err2,
}, {
desc: "nested joined error with a wrapped matching target",
errs: []error{join(fmt.Errorf("error: %w", err1), errors.New("nope")), err2},
target: err1,
}, {
desc: "context.DeadlineExceeded",
errs: []error{err1, nil, context.DeadlineExceeded, err2},
target: context.DeadlineExceeded,
}, {
desc: "wrapped context.DeadlineExceeded",
errs: []error{err1, nil, fmt.Errorf("error: %w", context.DeadlineExceeded), err2},
target: context.DeadlineExceeded,
}}

for _, test := range tests {
test := test // Capture range variable.

t.Run(test.desc, func(t *testing.T) {
// Assert that top-level errors returned by errors.Join and join
// behave the same with errors.Is.
want := errors.Join(test.errs...)
got := join(test.errs...)
assert.Equal(t,
errors.Is(want, test.target),
errors.Is(got, test.target),
"errors.Join() and join() behave differently with errors.Is")

// Assert that wrapped errors returned by errors.Join and join
// behave the same with errors.Is.
want = fmt.Errorf("error: %w", errors.Join(test.errs...))
got = fmt.Errorf("error: %w", join(test.errs...))
assert.Equal(t,
errors.Is(want, test.target),
errors.Is(got, test.target),
"errors.Join() and join(), when wrapped, behave differently with errors.Is")
})
}
}

type errType1 struct{}

func (errType1) Error() string { return "" }

type errType2 struct{}

func (errType2) Error() string { return "" }

// TestJoin_ErrorsIs asserts that join returns an error that behaves identically
// to the error returned by [errors.Join] when passed to [errors.As].
func TestJoin_ErrorsAs(t *testing.T) {
t.Parallel()

err1 := errType1{}
err2 := errType2{}

tests := []struct {
desc string
errs []error
target interface{}
}{{
desc: "one error with a matching target",
errs: []error{err1},
target: &errType1{},
}, {
desc: "one error with a non-matching target",
errs: []error{err1},
target: &errType2{},
}, {
desc: "nil error",
errs: []error{nil},
target: &errType1{},
}, {
desc: "no errors",
errs: []error{},
target: &errType1{},
}, {
desc: "two different errors with a matching target",
errs: []error{err1, err2},
target: &errType2{},
}, {
desc: "two identical errors with a matching target",
errs: []error{err1, err1},
target: &errType1{},
}, {
desc: "wrapped error with a matching target",
errs: []error{fmt.Errorf("error: %w", err1)},
target: &errType1{},
}, {
desc: "nested joined error with a matching target",
errs: []error{err1, join(err2, errors.New("nope"))},
target: &errType2{},
}, {
desc: "nested joined error with no matching targets",
errs: []error{err1, join(errors.New("nope"), errors.New("nope 2"))},
target: &errType2{},
}, {
desc: "nested joined error with a wrapped matching target",
errs: []error{join(fmt.Errorf("error: %w", err1), errors.New("nope")), err2},
target: &errType1{},
}, {
desc: "context.DeadlineExceeded",
errs: []error{err1, nil, context.DeadlineExceeded, err2},
target: &errType2{},
}}

for _, test := range tests {
test := test // Capture range variable.

t.Run(test.desc, func(t *testing.T) {
// Assert that top-level errors returned by errors.Join and join
// behave the same with errors.As.
want := errors.Join(test.errs...)
got := join(test.errs...)
assert.Equal(t,
errors.As(want, test.target),
errors.As(got, test.target),
"errors.Join() and join() behave differently with errors.As")

// Assert that wrapped errors returned by errors.Join and join
// behave the same with errors.As.
want = fmt.Errorf("error: %w", errors.Join(test.errs...))
got = fmt.Errorf("error: %w", join(test.errs...))
assert.Equal(t,
errors.As(want, test.target),
errors.As(got, test.target),
"errors.Join() and join(), when wrapped, behave differently with errors.As")
})
}
}

0 comments on commit 9396361

Please sign in to comment.