Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-2929 Add the ability to join multiple errors into one. #1370

Merged
merged 1 commit into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions internal/errutil/join.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package errutil

import "errors"

// join is a Go 1.13-1.19 compatible version of [errors.Join]. It is only called
// by Join in join_go1.19.go. It is included here in a file without build
// constraints only for testing purposes.
//
// It is heavily based on Join from
// https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go
func join(errs ...error) error {
n := 0
for _, err := range errs {
if err != nil {
n++
}
}
if n == 0 {
return nil
}
e := &joinError{
errs: make([]error, 0, n),
}
for _, err := range errs {
if err != nil {
e.errs = append(e.errs, err)
}
}
return e
}

// joinError is a Go 1.13-1.19 compatible joinable error type. Its error
// message is identical to [errors.Join], but it implements "Unwrap() error"
// instead of "Unwrap() []error".
//
// It is heavily based on the joinError from
// https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go
type joinError struct {
errs []error
}

func (e *joinError) Error() string {
var b []byte
for i, err := range e.errs {
if i > 0 {
b = append(b, '\n')
}
b = append(b, err.Error()...)
}
return string(b)
}

// Unwrap returns another joinError with the same errors as the current
// joinError except the first error in the slice. Continuing to call Unwrap
// on each returned error will increment through every error in the slice. The
// resulting behavior when using [errors.Is] and [errors.As] is similar to an
// error created using [errors.Join] in Go 1.20+.
func (e *joinError) Unwrap() error {
if len(e.errs) == 1 {
return e.errs[0]
}
return &joinError{errs: e.errs[1:]}
}

// Is calls [errors.Is] with the first error in the slice.
func (e *joinError) Is(target error) bool {
if len(e.errs) == 0 {
return false
}
return errors.Is(e.errs[0], target)
}

// As calls [errors.As] with the first error in the slice.
func (e *joinError) As(target interface{}) bool {
if len(e.errs) == 0 {
return false
}
return errors.As(e.errs[0], target)
}
20 changes: 20 additions & 0 deletions internal/errutil/join_go1.19.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

//go:build !go1.20
// +build !go1.20

package errutil

// Join returns an error that wraps the given errors. Any nil error values are
// discarded. Join returns nil if every value in errs is nil. The error formats
// as the concatenation of the strings obtained by calling the Error method of
// each element of errs, with a newline between each string.
//
// A non-nil error returned by Join implements the "Unwrap() error" method.
func Join(errs ...error) error {
return join(errs...)
}
17 changes: 17 additions & 0 deletions internal/errutil/join_go1.20.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

//go:build go1.20
// +build go1.20

package errutil

import "errors"

// Join calls [errors.Join].
func Join(errs ...error) error {
return errors.Join(errs...)
}
243 changes: 243 additions & 0 deletions internal/errutil/join_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package errutil

import (
"context"
"errors"
"fmt"
"testing"

"go.mongodb.org/mongo-driver/internal/assert"
)

// TestJoin_Nil asserts that join returns a nil error for the same inputs that
// [errors.Join] returns a nil error.
func TestJoin_Nil(t *testing.T) {
t.Parallel()

assert.Equal(t, errors.Join(), join(), "errors.Join() != join()")
assert.Equal(t, errors.Join(nil), join(nil), "errors.Join(nil) != join(nil)")
assert.Equal(t, errors.Join(nil, nil), join(nil, nil), "errors.Join(nil, nil) != join(nil, nil)")
}

// TestJoin_Error asserts that join returns an error with the same error message
// as the error returned by [errors.Join].
func TestJoin_Error(t *testing.T) {
t.Parallel()

err1 := errors.New("err1")
err2 := errors.New("err2")

tests := []struct {
desc string
errs []error
}{{
desc: "single error",
errs: []error{err1},
}, {
desc: "two errors",
errs: []error{err1, err2},
}, {
desc: "two errors and a nil value",
errs: []error{err1, nil, err2},
}}

for _, test := range tests {
test := test // Capture range variable.

t.Run(test.desc, func(t *testing.T) {
t.Parallel()

want := errors.Join(test.errs...).Error()
got := join(test.errs...).Error()
assert.Equal(t,
want,
got,
"errors.Join().Error() != join().Error() for input %v",
test.errs)
})
}
}

// TestJoin_ErrorsIs asserts that join returns an error that behaves identically
// to the error returned by [errors.Join] when passed to [errors.Is].
func TestJoin_ErrorsIs(t *testing.T) {
t.Parallel()

err1 := errors.New("err1")
err2 := errors.New("err2")

tests := []struct {
desc string
errs []error
target error
}{{
desc: "one error with a matching target",
errs: []error{err1},
target: err1,
}, {
desc: "one error with a non-matching target",
errs: []error{err1},
target: err2,
}, {
desc: "nil error",
errs: []error{nil},
target: err1,
}, {
desc: "no errors",
errs: []error{},
target: err1,
}, {
desc: "two different errors with a matching target",
errs: []error{err1, err2},
target: err2,
}, {
desc: "two identical errors with a matching target",
errs: []error{err1, err1},
target: err1,
}, {
desc: "wrapped error with a matching target",
errs: []error{fmt.Errorf("error: %w", err1)},
target: err1,
}, {
desc: "nested joined error with a matching target",
errs: []error{err1, join(err2, errors.New("nope"))},
target: err2,
}, {
desc: "nested joined error with no matching targets",
errs: []error{err1, join(errors.New("nope"), errors.New("nope 2"))},
target: err2,
}, {
desc: "nested joined error with a wrapped matching target",
errs: []error{join(fmt.Errorf("error: %w", err1), errors.New("nope")), err2},
target: err1,
}, {
desc: "context.DeadlineExceeded",
errs: []error{err1, nil, context.DeadlineExceeded, err2},
target: context.DeadlineExceeded,
}, {
desc: "wrapped context.DeadlineExceeded",
errs: []error{err1, nil, fmt.Errorf("error: %w", context.DeadlineExceeded), err2},
target: context.DeadlineExceeded,
}}

for _, test := range tests {
test := test // Capture range variable.

t.Run(test.desc, func(t *testing.T) {
// Assert that top-level errors returned by errors.Join and join
// behave the same with errors.Is.
want := errors.Join(test.errs...)
got := join(test.errs...)
assert.Equal(t,
errors.Is(want, test.target),
errors.Is(got, test.target),
"errors.Join() and join() behave differently with errors.Is")

// Assert that wrapped errors returned by errors.Join and join
// behave the same with errors.Is.
want = fmt.Errorf("error: %w", errors.Join(test.errs...))
got = fmt.Errorf("error: %w", join(test.errs...))
assert.Equal(t,
errors.Is(want, test.target),
errors.Is(got, test.target),
"errors.Join() and join(), when wrapped, behave differently with errors.Is")
})
}
}

type errType1 struct{}

func (errType1) Error() string { return "" }

type errType2 struct{}

func (errType2) Error() string { return "" }

// TestJoin_ErrorsIs asserts that join returns an error that behaves identically
// to the error returned by [errors.Join] when passed to [errors.As].
func TestJoin_ErrorsAs(t *testing.T) {
t.Parallel()

err1 := errType1{}
err2 := errType2{}

tests := []struct {
desc string
errs []error
target interface{}
}{{
desc: "one error with a matching target",
errs: []error{err1},
target: &errType1{},
}, {
desc: "one error with a non-matching target",
errs: []error{err1},
target: &errType2{},
}, {
desc: "nil error",
errs: []error{nil},
target: &errType1{},
}, {
desc: "no errors",
errs: []error{},
target: &errType1{},
}, {
desc: "two different errors with a matching target",
errs: []error{err1, err2},
target: &errType2{},
}, {
desc: "two identical errors with a matching target",
errs: []error{err1, err1},
target: &errType1{},
}, {
desc: "wrapped error with a matching target",
errs: []error{fmt.Errorf("error: %w", err1)},
target: &errType1{},
}, {
desc: "nested joined error with a matching target",
errs: []error{err1, join(err2, errors.New("nope"))},
target: &errType2{},
}, {
desc: "nested joined error with no matching targets",
errs: []error{err1, join(errors.New("nope"), errors.New("nope 2"))},
target: &errType2{},
}, {
desc: "nested joined error with a wrapped matching target",
errs: []error{join(fmt.Errorf("error: %w", err1), errors.New("nope")), err2},
target: &errType1{},
}, {
desc: "context.DeadlineExceeded",
errs: []error{err1, nil, context.DeadlineExceeded, err2},
target: &errType2{},
}}

for _, test := range tests {
test := test // Capture range variable.

t.Run(test.desc, func(t *testing.T) {
// Assert that top-level errors returned by errors.Join and join
// behave the same with errors.As.
want := errors.Join(test.errs...)
got := join(test.errs...)
assert.Equal(t,
errors.As(want, test.target),
errors.As(got, test.target),
"errors.Join() and join() behave differently with errors.As")

// Assert that wrapped errors returned by errors.Join and join
// behave the same with errors.As.
want = fmt.Errorf("error: %w", errors.Join(test.errs...))
got = fmt.Errorf("error: %w", join(test.errs...))
assert.Equal(t,
errors.As(want, test.target),
errors.As(got, test.target),
"errors.Join() and join(), when wrapped, behave differently with errors.As")
})
}
}
Loading