Skip to content

Commit

Permalink
GODRIVER-2603 [master] - Revised error handling using Go 1.13 error A…
Browse files Browse the repository at this point in the history
…PIs (#1474)

Co-authored-by: Lokesh Kumar <[email protected]>
  • Loading branch information
blink1073 and kumarlokesh authored Nov 20, 2023
1 parent f93a990 commit d33301f
Show file tree
Hide file tree
Showing 22 changed files with 63 additions and 75 deletions.
9 changes: 5 additions & 4 deletions benchmark/harness_case.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package benchmark

import (
"context"
"errors"
"fmt"
"path/filepath"
"reflect"
Expand Down Expand Up @@ -95,12 +96,12 @@ benchRepeat:
res.Duration = c.elapsed
c.cumulativeRuntime += res.Duration

switch res.Error {
case context.DeadlineExceeded:
switch {
case errors.Is(res.Error, context.DeadlineExceeded):
break benchRepeat
case context.Canceled:
case errors.Is(res.Error, context.Canceled):
break benchRepeat
case nil:
case res.Error == nil:
out.Trials++
c.elapsed = 0
out.Raw = append(out.Raw, res)
Expand Down
8 changes: 4 additions & 4 deletions bson/bsoncodec/default_value_decoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func newDefaultStructCodec() *StructCodec {
if err != nil {
// This function is called from the codec registration path, so errors can't be propagated. If there's an error
// constructing the StructCodec, we panic to avoid losing it.
panic(fmt.Errorf("error creating default StructCodec: %v", err))
panic(fmt.Errorf("error creating default StructCodec: %w", err))
}
return codec
}
Expand Down Expand Up @@ -178,7 +178,7 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr bsonrw.ValueRe

for {
key, elemVr, err := dr.ReadElement()
if err == bsonrw.ErrEOD {
if errors.Is(err, bsonrw.ErrEOD) {
break
} else if err != nil {
return err
Expand Down Expand Up @@ -1379,7 +1379,7 @@ func (dvd DefaultValueDecoders) MapDecodeValue(dc DecodeContext, vr bsonrw.Value
keyType := val.Type().Key()
for {
key, vr, err := dr.ReadElement()
if err == bsonrw.ErrEOD {
if errors.Is(err, bsonrw.ErrEOD) {
break
}
if err != nil {
Expand Down Expand Up @@ -1675,7 +1675,7 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr bsonrw.ValueR
idx := 0
for {
vr, err := ar.ReadValue()
if err == bsonrw.ErrEOA {
if errors.Is(err, bsonrw.ErrEOA) {
break
}
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions bson/bsoncodec/default_value_decoders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2370,8 +2370,8 @@ func TestDefaultValueDecoders(t *testing.T) {
return
}
if rc.val == cansettest { // We're doing an IsValid and CanSet test
wanterr, ok := rc.err.(ValueDecoderError)
if !ok {
var wanterr ValueDecoderError
if !errors.As(rc.err, &wanterr) {
t.Fatalf("Error must be a DecodeValueError, but got a %T", rc.err)
}

Expand Down Expand Up @@ -3685,8 +3685,8 @@ func TestDefaultValueDecoders(t *testing.T) {
val := reflect.New(reflect.TypeOf(outer{})).Elem()
err := defaultTestStructCodec.DecodeValue(dc, vr, val)

decodeErr, ok := err.(*DecodeError)
assert.True(t, ok, "expected DecodeError, got %v of type %T", err, err)
var decodeErr *DecodeError
assert.True(t, errors.As(err, &decodeErr), "expected DecodeError, got %v of type %T", err, err)
expectedKeys := []string{"foo", "bar"}
assert.Equal(t, expectedKeys, decodeErr.Keys(), "expected keys slice %v, got %v", expectedKeys,
decodeErr.Keys())
Expand Down
6 changes: 3 additions & 3 deletions bson/bsoncodec/default_value_encoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw bsonrw.Docum
}

currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.MapIndex(key))
if lookupErr != nil && lookupErr != errInvalidValue {
if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) {
return lookupErr
}

Expand Down Expand Up @@ -418,7 +418,7 @@ func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw bsonrw.Val

for idx := 0; idx < val.Len(); idx++ {
currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx))
if lookupErr != nil && lookupErr != errInvalidValue {
if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) {
return lookupErr
}

Expand Down Expand Up @@ -487,7 +487,7 @@ func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw bsonrw.Val

for idx := 0; idx < val.Len(); idx++ {
currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx))
if lookupErr != nil && lookupErr != errInvalidValue {
if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) {
return lookupErr
}

Expand Down
2 changes: 1 addition & 1 deletion bson/bsoncodec/map_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value,
if mc.EncodeKeysWithStringer {
parsed, err := strconv.ParseFloat(key, 64)
if err != nil {
return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %v", keyType.Kind(), err)
return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %w", keyType.Kind(), err)
}
keyVal = reflect.ValueOf(parsed)
break
Expand Down
4 changes: 2 additions & 2 deletions bson/bsoncodec/struct_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val
}

func newDecodeError(key string, original error) error {
de, ok := original.(*DecodeError)
if !ok {
var de *DecodeError
if !errors.As(original, &de) {
return &DecodeError{
keys: []string{key},
wrapped: original,
Expand Down
3 changes: 2 additions & 1 deletion bson/decoder_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package bson_test

import (
"bytes"
"errors"
"fmt"
"io"

Expand Down Expand Up @@ -200,7 +201,7 @@ func ExampleDecoder_multipleExtendedJSONDocuments() {
for {
var res Coordinate
err = decoder.Decode(&res)
if err == io.EOF {
if errors.Is(err, io.EOF) {
break
}
if err != nil {
Expand Down
3 changes: 2 additions & 1 deletion bson/encoder_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package bson_test

import (
"bytes"
"errors"
"fmt"
"io"

Expand Down Expand Up @@ -162,7 +163,7 @@ func ExampleEncoder_multipleBSONDocuments() {
// Extended JSON by converting them to bson.Raw.
for {
doc, err := bson.ReadDocument(buf)
if err == io.EOF {
if errors.Is(err, io.EOF) {
return
}
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions bson/primitive/objectid.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func processUniqueBytes() [5]byte {
var b [5]byte
_, err := io.ReadFull(rand.Reader, b[:])
if err != nil {
panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %v", err))
panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %w", err))
}

return b
Expand All @@ -193,7 +193,7 @@ func readRandomUint32() uint32 {
var b [4]byte
_, err := io.ReadFull(rand.Reader, b[:])
if err != nil {
panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %v", err))
panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %w", err))
}

return (uint32(b[0]) << 0) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
Expand Down
9 changes: 5 additions & 4 deletions cmd/testatlas/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package main

import (
"context"
"errors"
"flag"
"fmt"
"time"
Expand Down Expand Up @@ -52,7 +53,7 @@ func main() {
func runTest(ctx context.Context, clientOpts *options.ClientOptions) error {
client, err := mongo.Connect(ctx, clientOpts)
if err != nil {
return fmt.Errorf("Connect error: %v", err)
return fmt.Errorf("Connect error: %w", err)
}

defer func() {
Expand All @@ -63,12 +64,12 @@ func runTest(ctx context.Context, clientOpts *options.ClientOptions) error {
cmd := bson.D{{handshake.LegacyHello, 1}}
err = db.RunCommand(ctx, cmd).Err()
if err != nil {
return fmt.Errorf("legacy hello error: %v", err)
return fmt.Errorf("legacy hello error: %w", err)
}

coll := db.Collection("test")
if err = coll.FindOne(ctx, bson.D{{"x", 1}}).Err(); err != nil && err != mongo.ErrNoDocuments {
return fmt.Errorf("FindOne error: %v", err)
if err = coll.FindOne(ctx, bson.D{{"x", 1}}).Err(); err != nil && !errors.Is(err, mongo.ErrNoDocuments) {
return fmt.Errorf("FindOne error: %w", err)
}
return nil
}
3 changes: 2 additions & 1 deletion cmd/testaws/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package main

import (
"context"
"errors"
"fmt"
"os"

Expand All @@ -33,7 +34,7 @@ func main() {

db := client.Database("aws")
coll := db.Collection("test")
if err = coll.FindOne(ctx, bson.D{{"x", 1}}).Err(); err != nil && err != mongo.ErrNoDocuments {
if err = coll.FindOne(ctx, bson.D{{"x", 1}}).Err(); err != nil && !errors.Is(err, mongo.ErrNoDocuments) {
panic(fmt.Sprintf("FindOne error: %v", err))
}
}
4 changes: 3 additions & 1 deletion internal/aws/awserr/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
package awserr

import (
"errors"
"fmt"
)

Expand Down Expand Up @@ -106,7 +107,8 @@ func (b baseError) OrigErr() error {
case 1:
return b.errs[0]
default:
if err, ok := b.errs[0].(Error); ok {
var err Error
if errors.As(b.errs[0], &err) {
return NewBatchError(err.Code(), err.Message(), b.errs[1:])
}
return NewBatchError("BatchedErrors",
Expand Down
3 changes: 2 additions & 1 deletion internal/csfle/csfle.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package csfle

import (
"errors"
"fmt"

"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
Expand All @@ -23,7 +24,7 @@ func GetEncryptedStateCollectionName(efBSON bsoncore.Document, dataCollectionNam
fieldName := stateCollection + "Collection"
val, err := efBSON.LookupErr(fieldName)
if err != nil {
if err != bsoncore.ErrElementNotFound {
if !errors.Is(err, bsoncore.ErrElementNotFound) {
return "", err
}
// Return default name.
Expand Down
12 changes: 6 additions & 6 deletions mongo/cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,13 @@ func (c *Cursor) next(ctx context.Context, nonBlocking bool) bool {
ctx = context.Background()
}
doc, err := c.batch.Next()
switch err {
case nil:
switch {
case err == nil:
// Consume the next document in the current batch.
c.batchLength--
c.Current = bson.Raw(doc)
return true
case io.EOF: // Need to do a getMore
case errors.Is(err, io.EOF): // Need to do a getMore
default:
c.err = err
return false
Expand Down Expand Up @@ -204,12 +204,12 @@ func (c *Cursor) next(ctx context.Context, nonBlocking bool) bool {
c.batch = c.bc.Batch()
c.batchLength = c.batch.DocumentCount()
doc, err = c.batch.Next()
switch err {
case nil:
switch {
case err == nil:
c.batchLength--
c.Current = bson.Raw(doc)
return true
case io.EOF: // Empty batch so we continue
case errors.Is(err, io.EOF): // Empty batch so we continue
default:
c.err = err
return false
Expand Down
21 changes: 5 additions & 16 deletions mongo/integration/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ package integration
import (
"context"
"errors"
"fmt"
"io"
"net"
"testing"
Expand Down Expand Up @@ -45,18 +46,6 @@ func (n netErr) Temporary() bool {

var _ net.Error = (*netErr)(nil)

type wrappedError struct {
err error
}

func (we wrappedError) Error() string {
return we.err.Error()
}

func (we wrappedError) Unwrap() error {
return we.err
}

func TestErrors(t *testing.T) {
mt := mtest.New(t, noClientOpts)

Expand Down Expand Up @@ -478,7 +467,7 @@ func TestErrors(t *testing.T) {
},
false,
},
{"wrapped error", wrappedError{mongo.CommandError{11000, "", nil, "blah", nil, nil}}, true},
{"wrapped error", fmt.Errorf("%w", mongo.CommandError{11000, "", nil, "blah", nil, nil}), true},
{"other error type", errors.New("foo"), false},
}
for _, tc := range testCases {
Expand All @@ -499,7 +488,7 @@ func TestErrors(t *testing.T) {
}{
{"ServerError true", mongo.CommandError{100, "", []string{networkLabel}, "blah", nil, nil}, true},
{"ServerError false", mongo.CommandError{100, "", []string{otherLabel}, "blah", nil, nil}, false},
{"wrapped error", wrappedError{mongo.CommandError{100, "", []string{networkLabel}, "blah", nil, nil}}, true},
{"wrapped error", fmt.Errorf("%w", mongo.CommandError{100, "", []string{networkLabel}, "blah", nil, nil}), true},
{"other error type", errors.New("foo"), false},
}
for _, tc := range testCases {
Expand Down Expand Up @@ -533,8 +522,8 @@ func TestErrors(t *testing.T) {
{"net error true", mongo.CommandError{
100, "", []string{"other"}, "blah", netErr{true}, nil}, true},
{"net error false", netErr{false}, false},
{"wrapped error", wrappedError{mongo.CommandError{
100, "", []string{"other"}, "blah", context.DeadlineExceeded, nil}}, true},
{"wrapped error", fmt.Errorf("%w", mongo.CommandError{
100, "", []string{"other"}, "blah", context.DeadlineExceeded, nil}), true},
{"other error", errors.New("foo"), false},
}
for _, tc := range testCases {
Expand Down
4 changes: 2 additions & 2 deletions mongo/integration/mtest/global_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func ServerVersion() string {
func SetFailPoint(fp FailPoint, client *mongo.Client) error {
admin := client.Database("admin")
if err := admin.RunCommand(context.Background(), fp).Err(); err != nil {
return fmt.Errorf("error creating fail point: %v", err)
return fmt.Errorf("error creating fail point: %w", err)
}
return nil
}
Expand All @@ -89,7 +89,7 @@ func SetFailPoint(fp FailPoint, client *mongo.Client) error {
func SetRawFailPoint(fp bson.Raw, client *mongo.Client) error {
admin := client.Database("admin")
if err := admin.RunCommand(context.Background(), fp).Err(); err != nil {
return fmt.Errorf("error creating fail point: %v", err)
return fmt.Errorf("error creating fail point: %w", err)
}
return nil
}
Loading

0 comments on commit d33301f

Please sign in to comment.