From 4c98d14b11391ebb7503eea2ef0739c57d0cda32 Mon Sep 17 00:00:00 2001 From: Matt Dale <9760375+matthewdale@users.noreply.github.com> Date: Thu, 31 Aug 2023 14:23:35 -0700 Subject: [PATCH] GODRIVER-2929 Add the ability to join multiple errors into one. --- internal/errutil/join.go | 85 +++++++++++ internal/errutil/join_go1.19.go | 20 +++ internal/errutil/join_go1.20.go | 17 +++ internal/errutil/join_test.go | 243 ++++++++++++++++++++++++++++++++ 4 files changed, 365 insertions(+) create mode 100644 internal/errutil/join.go create mode 100644 internal/errutil/join_go1.19.go create mode 100644 internal/errutil/join_go1.20.go create mode 100644 internal/errutil/join_test.go diff --git a/internal/errutil/join.go b/internal/errutil/join.go new file mode 100644 index 0000000000..aa28b03327 --- /dev/null +++ b/internal/errutil/join.go @@ -0,0 +1,85 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package errutil + +import "errors" + +// join is a Go 1.13-1.19 compatible version of [errors.Join]. It is only called +// by Join in join_go1.19.go. It is included here in a file without build +// constraints only for testing purposes. +// +// It is heavily based on Join from +// https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go +func join(errs ...error) error { + n := 0 + for _, err := range errs { + if err != nil { + n++ + } + } + if n == 0 { + return nil + } + e := &joinError{ + errs: make([]error, 0, n), + } + for _, err := range errs { + if err != nil { + e.errs = append(e.errs, err) + } + } + return e +} + +// joinError is a Go 1.13-1.19 compatible joinable error type. Its error +// message is identical to [errors.Join], but it implements "Unwrap() error" +// instead of "Unwrap() []error". +// +// It is heavily based on the joinError from +// https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go +type joinError struct { + errs []error +} + +func (e *joinError) Error() string { + var b []byte + for i, err := range e.errs { + if i > 0 { + b = append(b, '\n') + } + b = append(b, err.Error()...) + } + return string(b) +} + +// Unwrap returns another joinError with the same errors as the current +// joinError except the first error in the slice. Continuing to call Unwrap +// on each returned error will increment through every error in the slice. The +// resulting behavior when using [errors.Is] and [errors.As] is similar to an +// error created using [errors.Join] in Go 1.20+. +func (e *joinError) Unwrap() error { + if len(e.errs) == 1 { + return e.errs[0] + } + return &joinError{errs: e.errs[1:]} +} + +// Is calls [errors.Is] with the first error in the slice. +func (e *joinError) Is(target error) bool { + if len(e.errs) == 0 { + return false + } + return errors.Is(e.errs[0], target) +} + +// As calls [errors.As] with the first error in the slice. +func (e *joinError) As(target interface{}) bool { + if len(e.errs) == 0 { + return false + } + return errors.As(e.errs[0], target) +} diff --git a/internal/errutil/join_go1.19.go b/internal/errutil/join_go1.19.go new file mode 100644 index 0000000000..569a0216b5 --- /dev/null +++ b/internal/errutil/join_go1.19.go @@ -0,0 +1,20 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +//go:build !go1.20 +// +build !go1.20 + +package errutil + +// Join returns an error that wraps the given errors. Any nil error values are +// discarded. Join returns nil if every value in errs is nil. The error formats +// as the concatenation of the strings obtained by calling the Error method of +// each element of errs, with a newline between each string. +// +// A non-nil error returned by Join implements the "Unwrap() error" method. +func Join(errs ...error) error { + return join(errs...) +} diff --git a/internal/errutil/join_go1.20.go b/internal/errutil/join_go1.20.go new file mode 100644 index 0000000000..69b9ad2231 --- /dev/null +++ b/internal/errutil/join_go1.20.go @@ -0,0 +1,17 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +//go:build go1.20 +// +build go1.20 + +package errutil + +import "errors" + +// Join calls [errors.Join]. +func Join(errs ...error) error { + return errors.Join(errs...) +} diff --git a/internal/errutil/join_test.go b/internal/errutil/join_test.go new file mode 100644 index 0000000000..333c2e9a2f --- /dev/null +++ b/internal/errutil/join_test.go @@ -0,0 +1,243 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package errutil + +import ( + "context" + "errors" + "fmt" + "testing" + + "go.mongodb.org/mongo-driver/internal/assert" +) + +// TestJoin_Nil asserts that join returns a nil error for the same inputs that +// [errors.Join] returns a nil error. +func TestJoin_Nil(t *testing.T) { + t.Parallel() + + assert.Equal(t, errors.Join(), join(), "errors.Join() != join()") + assert.Equal(t, errors.Join(nil), join(nil), "errors.Join(nil) != join(nil)") + assert.Equal(t, errors.Join(nil, nil), join(nil, nil), "errors.Join(nil, nil) != join(nil, nil)") +} + +// TestJoin_Error asserts that join returns an error with the same error message +// as the error returned by [errors.Join]. +func TestJoin_Error(t *testing.T) { + t.Parallel() + + err1 := errors.New("err1") + err2 := errors.New("err2") + + tests := []struct { + desc string + errs []error + }{{ + desc: "single error", + errs: []error{err1}, + }, { + desc: "two errors", + errs: []error{err1, err2}, + }, { + desc: "two errors and a nil value", + errs: []error{err1, nil, err2}, + }} + + for _, test := range tests { + test := test // Capture range variable. + + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + want := errors.Join(test.errs...).Error() + got := join(test.errs...).Error() + assert.Equal(t, + want, + got, + "errors.Join().Error() != join().Error() for input %v", + test.errs) + }) + } +} + +// TestJoin_ErrorsIs asserts that join returns an error that behaves identically +// to the error returned by [errors.Join] when passed to [errors.Is]. +func TestJoin_ErrorsIs(t *testing.T) { + t.Parallel() + + err1 := errors.New("err1") + err2 := errors.New("err2") + + tests := []struct { + desc string + errs []error + target error + }{{ + desc: "one error with a matching target", + errs: []error{err1}, + target: err1, + }, { + desc: "one error with a non-matching target", + errs: []error{err1}, + target: err2, + }, { + desc: "nil error", + errs: []error{nil}, + target: err1, + }, { + desc: "no errors", + errs: []error{}, + target: err1, + }, { + desc: "two different errors with a matching target", + errs: []error{err1, err2}, + target: err2, + }, { + desc: "two identical errors with a matching target", + errs: []error{err1, err1}, + target: err1, + }, { + desc: "wrapped error with a matching target", + errs: []error{fmt.Errorf("error: %w", err1)}, + target: err1, + }, { + desc: "nested joined error with a matching target", + errs: []error{err1, join(err2, errors.New("nope"))}, + target: err2, + }, { + desc: "nested joined error with no matching targets", + errs: []error{err1, join(errors.New("nope"), errors.New("nope 2"))}, + target: err2, + }, { + desc: "nested joined error with a wrapped matching target", + errs: []error{join(fmt.Errorf("error: %w", err1), errors.New("nope")), err2}, + target: err1, + }, { + desc: "context.DeadlineExceeded", + errs: []error{err1, nil, context.DeadlineExceeded, err2}, + target: context.DeadlineExceeded, + }, { + desc: "wrapped context.DeadlineExceeded", + errs: []error{err1, nil, fmt.Errorf("error: %w", context.DeadlineExceeded), err2}, + target: context.DeadlineExceeded, + }} + + for _, test := range tests { + test := test // Capture range variable. + + t.Run(test.desc, func(t *testing.T) { + // Assert that top-level errors returned by errors.Join and join + // behave the same with errors.Is. + want := errors.Join(test.errs...) + got := join(test.errs...) + assert.Equal(t, + errors.Is(want, test.target), + errors.Is(got, test.target), + "errors.Join() and join() behave differently with errors.Is") + + // Assert that wrapped errors returned by errors.Join and join + // behave the same with errors.Is. + want = fmt.Errorf("error: %w", errors.Join(test.errs...)) + got = fmt.Errorf("error: %w", join(test.errs...)) + assert.Equal(t, + errors.Is(want, test.target), + errors.Is(got, test.target), + "errors.Join() and join(), when wrapped, behave differently with errors.Is") + }) + } +} + +type errType1 struct{} + +func (errType1) Error() string { return "" } + +type errType2 struct{} + +func (errType2) Error() string { return "" } + +// TestJoin_ErrorsIs asserts that join returns an error that behaves identically +// to the error returned by [errors.Join] when passed to [errors.As]. +func TestJoin_ErrorsAs(t *testing.T) { + t.Parallel() + + err1 := errType1{} + err2 := errType2{} + + tests := []struct { + desc string + errs []error + target interface{} + }{{ + desc: "one error with a matching target", + errs: []error{err1}, + target: &errType1{}, + }, { + desc: "one error with a non-matching target", + errs: []error{err1}, + target: &errType2{}, + }, { + desc: "nil error", + errs: []error{nil}, + target: &errType1{}, + }, { + desc: "no errors", + errs: []error{}, + target: &errType1{}, + }, { + desc: "two different errors with a matching target", + errs: []error{err1, err2}, + target: &errType2{}, + }, { + desc: "two identical errors with a matching target", + errs: []error{err1, err1}, + target: &errType1{}, + }, { + desc: "wrapped error with a matching target", + errs: []error{fmt.Errorf("error: %w", err1)}, + target: &errType1{}, + }, { + desc: "nested joined error with a matching target", + errs: []error{err1, join(err2, errors.New("nope"))}, + target: &errType2{}, + }, { + desc: "nested joined error with no matching targets", + errs: []error{err1, join(errors.New("nope"), errors.New("nope 2"))}, + target: &errType2{}, + }, { + desc: "nested joined error with a wrapped matching target", + errs: []error{join(fmt.Errorf("error: %w", err1), errors.New("nope")), err2}, + target: &errType1{}, + }, { + desc: "context.DeadlineExceeded", + errs: []error{err1, nil, context.DeadlineExceeded, err2}, + target: &errType2{}, + }} + + for _, test := range tests { + test := test // Capture range variable. + + t.Run(test.desc, func(t *testing.T) { + // Assert that top-level errors returned by errors.Join and join + // behave the same with errors.As. + want := errors.Join(test.errs...) + got := join(test.errs...) + assert.Equal(t, + errors.As(want, test.target), + errors.As(got, test.target), + "errors.Join() and join() behave differently with errors.As") + + // Assert that wrapped errors returned by errors.Join and join + // behave the same with errors.As. + want = fmt.Errorf("error: %w", errors.Join(test.errs...)) + got = fmt.Errorf("error: %w", join(test.errs...)) + assert.Equal(t, + errors.As(want, test.target), + errors.As(got, test.target), + "errors.Join() and join(), when wrapped, behave differently with errors.As") + }) + } +}