diff --git a/eris.go b/eris.go index cfbcf2a..5b621c4 100644 --- a/eris.go +++ b/eris.go @@ -433,8 +433,18 @@ func (e *rootError) Format(s fmt.State, verb rune) { printError(e, s, verb) } -// Is returns true if both errors have the same message and code. Ignores additional KV pairs. +// Is returns true if both errors have the same message and code. +// In case of a joined error, returns true if at least one of the joined errors is equal to target. +// Ignores additional KV pairs. func (e *rootError) Is(target error) bool { + if joinErr, ok := e.ext.(joinError); ok { + for _, err := range joinErr.Unwrap() { + if Is(err, target) { + return true + } + } + return false + } if err, ok := target.(*rootError); ok { return e.msg == err.msg && e.code == err.code && reflect.DeepEqual(e.kvs, err.kvs) } @@ -446,6 +456,14 @@ func (e *rootError) Is(target error) bool { // As returns true if the error message in the target error is equivalent to the error message in the root error. func (e *rootError) As(target any) bool { + if joinErr, ok := e.ext.(joinError); ok { + for _, err := range joinErr.Unwrap() { + if As(err, target) { + return true + } + } + return false + } t := reflect.Indirect(reflect.ValueOf(target)).Interface() if err, ok := t.(*rootError); ok { if e.msg == err.msg { diff --git a/eris_test.go b/eris_test.go index d309802..05ba41c 100644 --- a/eris_test.go +++ b/eris_test.go @@ -427,6 +427,7 @@ func TestErrorUnwrap(t *testing.T) { } func TestErrorIs(t *testing.T) { + rootErr := eris.New("root error") externalErr := errors.New("external error") customErr := withLayer{ msg: "additional context", @@ -518,6 +519,53 @@ func TestErrorIs(t *testing.T) { compare: nil, output: true, }, + "join error (external)": { + cause: eris.Join(externalErr, rootErr), + compare: externalErr, + output: true, + }, + "join error (root)": { + cause: eris.Join(externalErr, rootErr), + compare: rootErr, + output: true, + }, + "join error (nil)": { + cause: eris.Join(nil, nil), + compare: nil, + output: true, + }, + "join error (wrap)": { + cause: eris.Join(externalErr, eris.Wrap(rootErr, "eris wrap error")), + compare: eris.New("eris wrap error").WithCode(eris.CodeInternal), + output: true, + }, + "join error not found (code don't match)": { + cause: eris.Join(externalErr, eris.Wrap(rootErr, "eris wrap error")), + compare: eris.New("eris wrap error"), + output: false, + }, + "join error not found (message don't match)": { + cause: eris.Join(externalErr, eris.Wrap(rootErr, "eris wrap error")), + compare: eris.New("eris root error message wrong"), + output: false, + }, + "join error not found (external don't match)": { + cause: eris.Join(externalErr, rootErr), + compare: errors.New("external error not match"), + output: false, + }, + "wrapped join error (match join errors)": { + cause: eris.Join(externalErr, rootErr), + input: []string{"additional context"}, + compare: rootErr, + output: true, + }, + "wrapped join error (match wrap)": { + cause: eris.Join(externalErr, rootErr), + input: []string{"additional context"}, + compare: eris.New("additional context").WithCode(eris.CodeUnknown), + output: true, + }, } for desc, tc := range tests { @@ -535,6 +583,7 @@ func TestErrorIs(t *testing.T) { func TestErrorAs(t *testing.T) { externalError := errors.New("external error") rootErr := eris.New("root error").WithCode(eris.CodeUnknown) + anotherRootErr := eris.New("another root error").WithCode(eris.CodeUnknown) wrappedErr := eris.WithCode(eris.Wrap(rootErr, "additional context"), eris.CodeUnknown) customErr := withLayer{ msg: "additional context", @@ -544,7 +593,6 @@ func TestErrorAs(t *testing.T) { }, }, } - tests := map[string]struct { cause error // root error target any // errors for comparison @@ -641,6 +689,29 @@ func TestErrorAs(t *testing.T) { match: true, output: customErr, }, + "join error (external)": { + cause: eris.Join(externalError, rootErr), + target: &externalError, + match: true, + output: externalError, + }, + "join error (root)": { + cause: eris.Join(externalError, rootErr), + target: &rootErr, + match: true, + output: rootErr, + }, + "join error (custom)": { + cause: eris.Join(externalError, withMessage{"test"}), + target: &withMessage{""}, + match: true, + output: withMessage{"test"}, + }, + "join error not found (message don't match)": { + cause: eris.Join(externalError, rootErr), + target: &anotherRootErr, + match: false, + }, } for desc, tc := range tests {