From daa72982db1b6983dbe096c35bd427c8535fab41 Mon Sep 17 00:00:00 2001 From: Lokesh Kumar Date: Mon, 4 Dec 2023 21:30:14 +0100 Subject: [PATCH 1/2] handle errors using go-1.13 apis --- bson/bsoncodec/default_value_encoders_test.go | 2 +- bson/bsoncodec/map_codec.go | 2 +- bson/bsoncodec/registry_test.go | 4 +- bson/bsoncodec/slice_codec.go | 5 +- bson/bsoncodec/struct_codec.go | 6 +- bson/bsonrw/extjson_reader_test.go | 5 +- bson/bsonrw/json_scanner.go | 14 ++-- bson/bsonrw/value_reader_test.go | 2 +- bson/raw_test.go | 8 +-- bson/raw_value_test.go | 5 +- bson/unmarshal_test.go | 3 +- mongo/bulk_write.go | 2 +- mongo/collection.go | 7 +- mongo/crud_examples_test.go | 9 +-- mongo/database.go | 10 +-- mongo/errors.go | 22 ++++-- mongo/gridfs/bucket.go | 4 +- mongo/gridfs/download_stream.go | 4 +- mongo/index_view.go | 3 +- mongo/integration/collection_test.go | 3 +- mongo/integration/crud_helpers_test.go | 4 +- mongo/integration/index_view_test.go | 3 +- mongo/integration/json_helpers_test.go | 9 +-- mongo/integration/mtest/setup.go | 22 +++--- .../integration/mtest/wiremessage_helpers.go | 2 +- mongo/integration/sessions_test.go | 8 +-- mongo/integration/unified/admin_helpers.go | 6 +- mongo/integration/unified/bucket_options.go | 6 +- .../integration/unified/bulkwrite_helpers.go | 24 +++---- .../client_encryption_operation_execution.go | 17 ++--- .../unified/client_operation_execution.go | 6 +- mongo/integration/unified/collection_data.go | 12 ++-- .../unified/collection_operation_execution.go | 71 ++++++++++--------- mongo/integration/unified/crud_helpers.go | 6 +- .../unified/cursor_operation_execution.go | 4 +- .../unified/database_operation_execution.go | 8 +-- .../unified/db_collection_options.go | 6 +- mongo/integration/unified/entity.go | 18 ++--- mongo/integration/unified/error.go | 23 +++--- .../integration/unified/event_verification.go | 4 +- .../gridfs_bucket_operation_execution.go | 4 +- .../unified_runner_thread_helpers_test.go | 2 +- mongo/options/clientoptions_test.go | 5 +- mongo/read_write_concern_spec_test.go | 3 +- mongo/search_index_view.go | 4 +- mongo/session.go | 3 +- mongo/with_transactions_test.go | 3 +- x/mongo/driver/auth/sasl.go | 4 +- x/mongo/driver/batch_cursor.go | 2 +- x/mongo/driver/crypt.go | 3 +- x/mongo/driver/integration/main_test.go | 8 ++- x/mongo/driver/integration/scram_test.go | 2 +- x/mongo/driver/operation.go | 7 +- x/mongo/driver/operation/count.go | 4 +- x/mongo/driver/session/client_session_test.go | 15 ++-- x/mongo/driver/topology/connection.go | 5 +- x/mongo/driver/topology/rtt_monitor.go | 4 +- x/mongo/driver/topology/sdam_spec_test.go | 4 +- x/mongo/driver/topology/server.go | 12 ++-- 59 files changed, 253 insertions(+), 220 deletions(-) diff --git a/bson/bsoncodec/default_value_encoders_test.go b/bson/bsoncodec/default_value_encoders_test.go index 0cb35a1ae2..12410a0b19 100644 --- a/bson/bsoncodec/default_value_encoders_test.go +++ b/bson/bsoncodec/default_value_encoders_test.go @@ -1776,7 +1776,7 @@ func TestDefaultValueEncoders(t *testing.T) { enc, err := reg.LookupEncoder(reflect.TypeOf(tc.value)) noerr(t, err) err = enc.EncodeValue(EncodeContext{Registry: reg}, vw, reflect.ValueOf(tc.value)) - if err != tc.err { + if !errors.Is(err, tc.err) { t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) } if diff := cmp.Diff([]byte(b), tc.b); diff != "" { diff --git a/bson/bsoncodec/map_codec.go b/bson/bsoncodec/map_codec.go index 868e39ccc0..fe304b2e03 100644 --- a/bson/bsoncodec/map_codec.go +++ b/bson/bsoncodec/map_codec.go @@ -129,7 +129,7 @@ func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, v } currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.MapIndex(key)) - if lookupErr != nil && lookupErr != errInvalidValue { + if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } diff --git a/bson/bsoncodec/registry_test.go b/bson/bsoncodec/registry_test.go index 2a7d50a719..39863ac874 100644 --- a/bson/bsoncodec/registry_test.go +++ b/bson/bsoncodec/registry_test.go @@ -424,7 +424,7 @@ func TestRegistryBuilder(t *testing.T) { want = nil wanterr := ErrNoTypeMapEntry{Type: bsontype.ObjectID} got, err = reg.LookupTypeMapEntry(bsontype.ObjectID) - if err != wanterr { + if !errors.Is(err, wanterr) { t.Errorf("did not get expected error: got %#v, want %#v", err, wanterr) } if got != want { @@ -884,7 +884,7 @@ func TestRegistry(t *testing.T) { want = nil wanterr := ErrNoTypeMapEntry{Type: bsontype.ObjectID} got, err = reg.LookupTypeMapEntry(bsontype.ObjectID) - if err != wanterr { + if !errors.Is(err, wanterr) { t.Errorf("unexpected error: got %#v, want %#v", err, wanterr) } if got != want { diff --git a/bson/bsoncodec/slice_codec.go b/bson/bsoncodec/slice_codec.go index a43daf005f..0ffd92d92b 100644 --- a/bson/bsoncodec/slice_codec.go +++ b/bson/bsoncodec/slice_codec.go @@ -7,6 +7,7 @@ package bsoncodec import ( + "errors" "fmt" "reflect" @@ -93,7 +94,7 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val re for idx := 0; idx < val.Len(); idx++ { currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.Index(idx)) - if lookupErr != nil && lookupErr != errInvalidValue { + if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) { return lookupErr } @@ -102,7 +103,7 @@ func (sc SliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val re return err } - if lookupErr == errInvalidValue { + if errors.Is(lookupErr, errInvalidValue) { err = vw.WriteNull() if err != nil { return err diff --git a/bson/bsoncodec/struct_codec.go b/bson/bsoncodec/struct_codec.go index d7d129d314..54cde81700 100644 --- a/bson/bsoncodec/struct_codec.go +++ b/bson/bsoncodec/struct_codec.go @@ -164,11 +164,11 @@ func (sc *StructCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(ec, desc.encoder, rv) - if err != nil && err != errInvalidValue { + if err != nil && !errors.Is(err, errInvalidValue) { return err } - if err == errInvalidValue { + if errors.Is(err, errInvalidValue) { if desc.omitEmpty { continue } @@ -308,7 +308,7 @@ func (sc *StructCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val for { name, vr, err := dr.ReadElement() - if err == bsonrw.ErrEOD { + if errors.Is(err, bsonrw.ErrEOD) { break } if err != nil { diff --git a/bson/bsonrw/extjson_reader_test.go b/bson/bsonrw/extjson_reader_test.go index 8a9f0cc24d..4f790033a1 100644 --- a/bson/bsonrw/extjson_reader_test.go +++ b/bson/bsonrw/extjson_reader_test.go @@ -7,6 +7,7 @@ package bsonrw import ( + "errors" "fmt" "io" "strings" @@ -131,7 +132,7 @@ func readAllDocuments(vr ValueReader) ([][]byte, error) { for { result, err := c.CopyDocumentToBytes(vr) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { break } return nil, err @@ -147,7 +148,7 @@ func readAllDocuments(vr ValueReader) ([][]byte, error) { for { evr, err := ar.ReadValue() if err != nil { - if err == ErrEOA { + if errors.Is(err, ErrEOA) { break } return nil, err diff --git a/bson/bsonrw/json_scanner.go b/bson/bsonrw/json_scanner.go index 65a812ac18..43f3e4f383 100644 --- a/bson/bsonrw/json_scanner.go +++ b/bson/bsonrw/json_scanner.go @@ -325,17 +325,17 @@ func (js *jsonScanner) scanLiteral(first byte) (*jsonToken, error) { c5, err := js.readNextByte() - if bytes.Equal([]byte("true"), lit) && (isValueTerminator(c5) || err == io.EOF) { + if bytes.Equal([]byte("true"), lit) && (isValueTerminator(c5) || errors.Is(err, io.EOF)) { js.pos = int(math.Max(0, float64(js.pos-1))) return &jsonToken{t: jttBool, v: true, p: p}, nil - } else if bytes.Equal([]byte("null"), lit) && (isValueTerminator(c5) || err == io.EOF) { + } else if bytes.Equal([]byte("null"), lit) && (isValueTerminator(c5) || errors.Is(err, io.EOF)) { js.pos = int(math.Max(0, float64(js.pos-1))) return &jsonToken{t: jttNull, v: nil, p: p}, nil } else if bytes.Equal([]byte("fals"), lit) { if c5 == 'e' { c5, err = js.readNextByte() - if isValueTerminator(c5) || err == io.EOF { + if isValueTerminator(c5) || errors.Is(err, io.EOF) { js.pos = int(math.Max(0, float64(js.pos-1))) return &jsonToken{t: jttBool, v: false, p: p}, nil } @@ -413,7 +413,7 @@ func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) { case '}', ']', ',': s = nssDone default: - if isWhiteSpace(c) || err == io.EOF { + if isWhiteSpace(c) || errors.Is(err, io.EOF) { s = nssDone } else { s = nssInvalid @@ -430,7 +430,7 @@ func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) { case '}', ']', ',': s = nssDone default: - if isWhiteSpace(c) || err == io.EOF { + if isWhiteSpace(c) || errors.Is(err, io.EOF) { s = nssDone } else if isDigit(c) { s = nssSawIntegerDigits @@ -455,7 +455,7 @@ func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) { case '}', ']', ',': s = nssDone default: - if isWhiteSpace(c) || err == io.EOF { + if isWhiteSpace(c) || errors.Is(err, io.EOF) { s = nssDone } else if isDigit(c) { s = nssSawFractionDigits @@ -490,7 +490,7 @@ func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) { case '}', ']', ',': s = nssDone default: - if isWhiteSpace(c) || err == io.EOF { + if isWhiteSpace(c) || errors.Is(err, io.EOF) { s = nssDone } else if isDigit(c) { s = nssSawExponentDigits diff --git a/bson/bsonrw/value_reader_test.go b/bson/bsonrw/value_reader_test.go index 11b257277e..0617acf930 100644 --- a/bson/bsonrw/value_reader_test.go +++ b/bson/bsonrw/value_reader_test.go @@ -1527,7 +1527,7 @@ func errequal(t *testing.T, err1, err2 error) bool { return false } - if err1 == err2 { // They are the same error, they are equal + if errors.Is(err1, err2) { // They are the same error, they are equal return true } diff --git a/bson/raw_test.go b/bson/raw_test.go index 644a2eea16..d078012290 100644 --- a/bson/raw_test.go +++ b/bson/raw_test.go @@ -118,7 +118,7 @@ func TestRaw(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { err := tc.r.Validate() - if err != tc.err { + if !errors.Is(err, tc.err) { t.Errorf("Returned error does not match. got %v; want %v", err, tc.err) } }) @@ -128,7 +128,7 @@ func TestRaw(t *testing.T) { t.Run("empty-key", func(t *testing.T) { rdr := Raw{'\x05', '\x00', '\x00', '\x00', '\x00'} _, err := rdr.LookupErr() - if err != bsoncore.ErrEmptyKey { + if !errors.Is(err, bsoncore.ErrEmptyKey) { t.Errorf("Empty key lookup did not return expected result. got %v; want %v", err, bsoncore.ErrEmptyKey) } }) @@ -211,7 +211,7 @@ func TestRaw(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { got, err := tc.r.LookupErr(tc.key...) - if err != tc.err { + if !errors.Is(err, tc.err) { t.Errorf("Returned error does not match. got %v; want %v", err, tc.err) } if !cmp.Equal(got, tc.want) { @@ -224,7 +224,7 @@ func TestRaw(t *testing.T) { t.Run("Out of bounds", func(t *testing.T) { rdr := Raw{0xe, 0x0, 0x0, 0x0, 0xa, 0x78, 0x0, 0xa, 0x79, 0x0, 0xa, 0x7a, 0x0, 0x0} _, err := rdr.IndexErr(3) - if err != bsoncore.ErrOutOfBounds { + if !errors.Is(err, bsoncore.ErrOutOfBounds) { t.Errorf("Out of bounds should be returned when accessing element beyond end of document. got %v; want %v", err, bsoncore.ErrOutOfBounds) } }) diff --git a/bson/raw_value_test.go b/bson/raw_value_test.go index 87f08c4a55..6e6b78aebb 100644 --- a/bson/raw_value_test.go +++ b/bson/raw_value_test.go @@ -7,6 +7,7 @@ package bson import ( + "errors" "fmt" "reflect" "testing" @@ -57,7 +58,7 @@ func TestRawValue(t *testing.T) { want := ErrNilRegistry var val RawValue got := val.UnmarshalWithRegistry(nil, &D{}) - if got != want { + if !errors.Is(got, want) { t.Errorf("Expected errors to match. got %v; want %v", got, want) } }) @@ -108,7 +109,7 @@ func TestRawValue(t *testing.T) { want := ErrNilContext var val RawValue got := val.UnmarshalWithContext(nil, &D{}) - if got != want { + if !errors.Is(got, want) { t.Errorf("Expected errors to match. got %v; want %v", got, want) } }) diff --git a/bson/unmarshal_test.go b/bson/unmarshal_test.go index 2283b96771..70f2c321d0 100644 --- a/bson/unmarshal_test.go +++ b/bson/unmarshal_test.go @@ -7,6 +7,7 @@ package bson import ( + "errors" "math/rand" "reflect" "sync" @@ -100,7 +101,7 @@ func TestUnmarshalExtJSONWithRegistry(t *testing.T) { t.Run("UnmarshalExtJSONInvalidInput", func(t *testing.T) { data := []byte("invalid") err := UnmarshalExtJSONWithRegistry(DefaultRegistry, data, true, &M{}) - if err != bsonrw.ErrInvalidJSON { + if !errors.Is(err, bsonrw.ErrInvalidJSON) { t.Fatalf("wanted ErrInvalidJSON, got %v", err) } }) diff --git a/mongo/bulk_write.go b/mongo/bulk_write.go index a7efd551e7..87f896aec5 100644 --- a/mongo/bulk_write.go +++ b/mongo/bulk_write.go @@ -72,7 +72,7 @@ func (bw *bulkWrite) execute(ctx context.Context) error { bwErr.WriteErrors = append(bwErr.WriteErrors, batchErr.WriteErrors...) - commandErrorOccurred := err != nil && err != driver.ErrUnacknowledgedWrite + commandErrorOccurred := err != nil && !errors.Is(err, driver.ErrUnacknowledgedWrite) writeErrorOccurred := len(batchErr.WriteErrors) > 0 || batchErr.WriteConcernError != nil if !continueOnError && (commandErrorOccurred || writeErrorOccurred) { if err != nil { diff --git a/mongo/collection.go b/mongo/collection.go index ac173307ff..74b55bb2cf 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -929,7 +929,8 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { err = op.Execute(a.ctx) if err != nil { - if wce, ok := err.(driver.WriteCommandError); ok && wce.WriteConcernError != nil { + var wce driver.WriteCommandError + if errors.As(err, &wce) && wce.WriteConcernError != nil { return nil, *convertDriverWriteConcernError(wce.WriteConcernError) } return nil, replaceErrors(err) @@ -1868,8 +1869,8 @@ func (coll *Collection) drop(ctx context.Context) error { err = op.Execute(ctx) // ignore namespace not found errors - driverErr, ok := err.(driver.Error) - if !ok || (ok && !driverErr.NamespaceNotFound()) { + var driverErr driver.Error + if !errors.As(err, &driverErr) || !driverErr.NamespaceNotFound() { return replaceErrors(err) } return nil diff --git a/mongo/crud_examples_test.go b/mongo/crud_examples_test.go index d657ed6965..70e1e8a1c9 100644 --- a/mongo/crud_examples_test.go +++ b/mongo/crud_examples_test.go @@ -8,6 +8,7 @@ package mongo_test import ( "context" + "errors" "fmt" "log" "sync" @@ -387,7 +388,7 @@ func ExampleCollection_FindOne() { if err != nil { // ErrNoDocuments means that the filter did not match any documents in // the collection. - if err == mongo.ErrNoDocuments { + if errors.Is(err, mongo.ErrNoDocuments) { return } log.Fatal(err) @@ -413,7 +414,7 @@ func ExampleCollection_FindOneAndDelete() { if err != nil { // ErrNoDocuments means that the filter did not match any documents in // the collection. - if err == mongo.ErrNoDocuments { + if errors.Is(err, mongo.ErrNoDocuments) { return } log.Fatal(err) @@ -442,7 +443,7 @@ func ExampleCollection_FindOneAndReplace() { if err != nil { // ErrNoDocuments means that the filter did not match any documents in // the collection. - if err == mongo.ErrNoDocuments { + if errors.Is(err, mongo.ErrNoDocuments) { return } log.Fatal(err) @@ -471,7 +472,7 @@ func ExampleCollection_FindOneAndUpdate() { if err != nil { // ErrNoDocuments means that the filter did not match any documents in // the collection. - if err == mongo.ErrNoDocuments { + if errors.Is(err, mongo.ErrNoDocuments) { return } log.Fatal(err) diff --git a/mongo/database.go b/mongo/database.go index 6760f0d014..69f1e36bb4 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -312,8 +312,8 @@ func (db *Database) Drop(ctx context.Context) error { err = op.Execute(ctx) - driverErr, ok := err.(driver.Error) - if err != nil && (!ok || !driverErr.NamespaceNotFound()) { + var driverErr driver.Error + if err != nil && (!errors.As(err, &driverErr) || !driverErr.NamespaceNotFound()) { return replaceErrors(err) } return nil @@ -566,7 +566,7 @@ func (db *Database) getEncryptedFieldsFromServer(ctx context.Context, collection } collSpec := collSpecs[0] rawValue, err := collSpec.Options.LookupErr("encryptedFields") - if err == bsoncore.ErrElementNotFound { + if errors.Is(err, bsoncore.ErrElementNotFound) { return nil, nil } else if err != nil { return nil, err @@ -602,7 +602,7 @@ func (db *Database) getEncryptedFieldsFromMap(collectionName string) interface{} func (db *Database) createCollectionWithEncryptedFields(ctx context.Context, name string, ef interface{}, opts ...*options.CreateCollectionOptions) error { efBSON, err := marshal(ef, db.bsonOpts, db.registry) if err != nil { - return fmt.Errorf("error transforming document: %v", err) + return fmt.Errorf("error transforming document: %w", err) } // Check the wire version to ensure server is 7.0.0 or newer. @@ -662,7 +662,7 @@ func (db *Database) createCollectionWithEncryptedFields(ctx context.Context, nam // Create an index on the __safeContent__ field in the collection @collectionName. if _, err := db.Collection(name).Indexes().CreateOne(ctx, IndexModel{Keys: bson.D{{"__safeContent__", 1}}}); err != nil { - return fmt.Errorf("error creating safeContent index: %v", err) + return fmt.Errorf("error creating safeContent index: %w", err) } return nil diff --git a/mongo/errors.go b/mongo/errors.go index 72c3bcc243..777746d5a0 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -52,10 +52,12 @@ func replaceErrors(err error) error { return nil } - if err == topology.ErrTopologyClosed { + if errors.Is(err, topology.ErrTopologyClosed) { return ErrClientDisconnected } - if de, ok := err.(driver.Error); ok { + + var de driver.Error + if errors.As(err, &de) { return CommandError{ Code: de.Code, Message: de.Message, @@ -65,7 +67,9 @@ func replaceErrors(err error) error { Raw: bson.Raw(de.Raw), } } - if qe, ok := err.(driver.QueryFailureError); ok { + + var qe driver.QueryFailureError + if errors.As(err, &qe) { // qe.Message is "command failure" ce := CommandError{ Name: qe.Message, @@ -84,7 +88,9 @@ func replaceErrors(err error) error { return ce } - if me, ok := err.(mongocrypt.Error); ok { + + var me mongocrypt.Error + if errors.As(err, &me) { return MongocryptError{Code: me.Code, Message: me.Message} } @@ -92,7 +98,8 @@ func replaceErrors(err error) error { return ErrNilValue } - if marshalErr, ok := err.(codecutil.MarshalError); ok { + var marshalErr codecutil.MarshalError + if errors.As(err, &marshalErr) { return MarshalError{ Value: marshalErr.Value, Err: marshalErr.Err, @@ -171,7 +178,8 @@ func unwrap(err error) error { // errorHasLabel returns true if err contains the specified label func errorHasLabel(err error, label string) bool { for ; err != nil; err = unwrap(err) { - if le, ok := err.(LabeledError); ok && le.HasErrorLabel(label) { + var le LabeledError + if errors.As(err, &le) && le.HasErrorLabel(label) { return true } } @@ -630,7 +638,7 @@ const ( // WriteConcernError will be returned over WriteErrors if both are present. func processWriteError(err error) (returnResult, error) { switch { - case err == driver.ErrUnacknowledgedWrite: + case errors.Is(err, driver.ErrUnacknowledgedWrite): return rrAll, ErrUnacknowledgedWrite case err != nil: switch tt := err.(type) { diff --git a/mongo/gridfs/bucket.go b/mongo/gridfs/bucket.go index 61e2cb9e74..b231d1dd77 100644 --- a/mongo/gridfs/bucket.go +++ b/mongo/gridfs/bucket.go @@ -429,7 +429,7 @@ func (b *Bucket) openDownloadStream(filter interface{}, opts ...*options.FindOpt // in the File type. After parsing it, use RawValue.Unmarshal to ensure File.ID is set to the appropriate value. var foundFile File if err = cursor.Decode(&foundFile); err != nil { - return nil, fmt.Errorf("error decoding files collection document: %v", err) + return nil, fmt.Errorf("error decoding files collection document: %w", err) } if foundFile.Length == 0 { @@ -594,7 +594,7 @@ func (b *Bucket) createIndexes(ctx context.Context) error { docRes := cloned.FindOne(ctx, bson.D{}, options.FindOne().SetProjection(bson.D{{"_id", 1}})) _, err = docRes.Raw() - if err != mongo.ErrNoDocuments { + if !errors.Is(err, mongo.ErrNoDocuments) { // nil, or error that occurred during the FindOne operation return err } diff --git a/mongo/gridfs/download_stream.go b/mongo/gridfs/download_stream.go index 20c8df8a6f..7c75813f54 100644 --- a/mongo/gridfs/download_stream.go +++ b/mongo/gridfs/download_stream.go @@ -160,7 +160,7 @@ func (ds *DownloadStream) Read(p []byte) (int, error) { // Buffer is empty and can load in data from new chunk. err = ds.fillBuffer(ctx) if err != nil { - if err == errNoMoreChunks { + if errors.Is(err, errNoMoreChunks) { if bytesCopied == 0 { ds.done = true return 0, io.EOF @@ -203,7 +203,7 @@ func (ds *DownloadStream) Skip(skip int64) (int64, error) { // Buffer is empty and can load in data from new chunk. err = ds.fillBuffer(ctx) if err != nil { - if err == errNoMoreChunks { + if errors.Is(err, errNoMoreChunks) { return skipped, nil } return skipped, err diff --git a/mongo/index_view.go b/mongo/index_view.go index 8d3555d0b0..84fe026a17 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -53,7 +53,8 @@ type IndexModel struct { } func isNamespaceNotFoundError(err error) bool { - if de, ok := err.(driver.Error); ok { + var de driver.Error + if errors.As(err, &de) { return de.Code == 26 } return false diff --git a/mongo/integration/collection_test.go b/mongo/integration/collection_test.go index da03258738..15c0e9324c 100644 --- a/mongo/integration/collection_test.go +++ b/mongo/integration/collection_test.go @@ -8,6 +8,7 @@ package integration import ( "context" + "errors" "strings" "testing" "time" @@ -1697,7 +1698,7 @@ func TestCollection(t *testing.T) { mongo.NewInsertOneModel().SetDocument(bson.D{{"x", 1}}), } _, err := mt.Coll.BulkWrite(context.Background(), models) - if err != mongo.ErrUnacknowledgedWrite { + if !errors.Is(err, mongo.ErrUnacknowledgedWrite) { // Use a direct comparison rather than assert.Equal because assert.Equal will compare the error strings, // so the assertion would succeed even if the error had not been wrapped. mt.Fatalf("expected BulkWrite error %v, got %v", mongo.ErrUnacknowledgedWrite, err) diff --git a/mongo/integration/crud_helpers_test.go b/mongo/integration/crud_helpers_test.go index 2b0c743c87..3344ac297b 100644 --- a/mongo/integration/crud_helpers_test.go +++ b/mongo/integration/crud_helpers_test.go @@ -126,7 +126,7 @@ func runCommandOnAllServers(commandFn func(client *mongo.Client) error) error { if mtest.ClusterTopologyKind() != mtest.Sharded { client, err := mongo.Connect(context.Background(), opts) if err != nil { - return fmt.Errorf("error creating replica set client: %v", err) + return fmt.Errorf("error creating replica set client: %w", err) } defer func() { _ = client.Disconnect(context.Background()) }() @@ -136,7 +136,7 @@ func runCommandOnAllServers(commandFn func(client *mongo.Client) error) error { for _, host := range opts.Hosts { shardClient, err := mongo.Connect(context.Background(), opts.SetHosts([]string{host})) if err != nil { - return fmt.Errorf("error creating client for mongos %v: %v", host, err) + return fmt.Errorf("error creating client for mongos %v: %w", host, err) } err = commandFn(shardClient) diff --git a/mongo/integration/index_view_test.go b/mongo/integration/index_view_test.go index e0cc6e2f87..bff69150d1 100644 --- a/mongo/integration/index_view_test.go +++ b/mongo/integration/index_view_test.go @@ -8,6 +8,7 @@ package integration import ( "context" + "errors" "testing" "time" @@ -270,7 +271,7 @@ func TestIndexView(t *testing.T) { MinServerVersion("3.6") mt.RunOpts("unacknowledged write", unackMtOpts, func(mt *mtest.T) { _, err := mt.Coll.Indexes().CreateOne(context.Background(), mongo.IndexModel{Keys: bson.D{{"x", 1}}}) - if err != mongo.ErrUnacknowledgedWrite { + if !errors.Is(err, mongo.ErrUnacknowledgedWrite) { // Use a direct comparison rather than assert.Equal because assert.Equal will compare the error strings, // so the assertion would succeed even if the error had not been wrapped. mt.Fatalf("expected CreateOne error %v, got %v", mongo.ErrUnacknowledgedWrite, err) diff --git a/mongo/integration/json_helpers_test.go b/mongo/integration/json_helpers_test.go index 749de6a5b1..8ddc2b6867 100644 --- a/mongo/integration/json_helpers_test.go +++ b/mongo/integration/json_helpers_test.go @@ -8,6 +8,7 @@ package integration import ( "crypto/tls" + "errors" "fmt" "io/ioutil" "math" @@ -514,12 +515,12 @@ func extractErrorDetails(err error) (errorDetails, bool) { func verifyError(expected *operationError, actual error) error { // The spec test format doesn't treat ErrNoDocuments or ErrUnacknowledgedWrite as errors, so set actual to nil // to indicate that no error occurred. - if actual == mongo.ErrNoDocuments || actual == mongo.ErrUnacknowledgedWrite { + if errors.Is(actual, mongo.ErrNoDocuments) || errors.Is(actual, mongo.ErrUnacknowledgedWrite) { actual = nil } if expected == nil && actual != nil { - return fmt.Errorf("did not expect error but got %v", actual) + return fmt.Errorf("did not expect error but got %w", actual) } if expected != nil && actual == nil { return fmt.Errorf("expected error but got nil") @@ -554,12 +555,12 @@ func verifyError(expected *operationError, actual error) error { } for _, label := range expected.ErrorLabelsContain { if !stringSliceContains(details.labels, label) { - return fmt.Errorf("expected error %v to contain label %q", actual, label) + return fmt.Errorf("expected error %w to contain label %q", actual, label) } } for _, label := range expected.ErrorLabelsOmit { if stringSliceContains(details.labels, label) { - return fmt.Errorf("expected error %v to not contain label %q", actual, label) + return fmt.Errorf("expected error %w to not contain label %q", actual, label) } } return nil diff --git a/mongo/integration/mtest/setup.go b/mongo/integration/mtest/setup.go index 49aacfd194..fac18d471a 100644 --- a/mongo/integration/mtest/setup.go +++ b/mongo/integration/mtest/setup.go @@ -119,12 +119,12 @@ func Setup(setupOpts ...*SetupOptions) error { pingCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() if err := testContext.client.Ping(pingCtx, readpref.Primary()); err != nil { - return fmt.Errorf("ping error: %v; make sure the deployment is running on URI %v", err, + return fmt.Errorf("ping error: %w; make sure the deployment is running on URI %v", err, testContext.connString.Original) } if testContext.serverVersion, err = getServerVersion(); err != nil { - return fmt.Errorf("error getting server version: %v", err) + return fmt.Errorf("error getting server version: %w", err) } switch testContext.topo.Kind() { @@ -145,7 +145,7 @@ func Setup(setupOpts ...*SetupOptions) error { // Run a find against config.shards and get each document in the collection. cursor, err := testContext.client.Database("config").Collection("shards").Find(context.Background(), bson.D{}) if err != nil { - return fmt.Errorf("error running find against config.shards: %v", err) + return fmt.Errorf("error running find against config.shards: %w", err) } defer cursor.Close(context.Background()) @@ -153,7 +153,7 @@ func Setup(setupOpts ...*SetupOptions) error { Host string `bson:"host"` } if err := cursor.All(context.Background(), &shards); err != nil { - return fmt.Errorf("error getting results find against config.shards: %v", err) + return fmt.Errorf("error getting results find against config.shards: %w", err) } // Each document's host field will contain a single hostname if the shard is a standalone. If it's a replica @@ -181,7 +181,7 @@ func Setup(setupOpts ...*SetupOptions) error { } testContext.singleMongosLoadBalancerURI, err = addNecessaryParamsToURI(singleMongosURI) if err != nil { - return fmt.Errorf("error getting single mongos load balancer uri: %v", err) + return fmt.Errorf("error getting single mongos load balancer uri: %w", err) } multiMongosURI := os.Getenv("MULTI_MONGOS_LB_URI") @@ -190,7 +190,7 @@ func Setup(setupOpts ...*SetupOptions) error { } testContext.multiMongosLoadBalancerURI, err = addNecessaryParamsToURI(multiMongosURI) if err != nil { - return fmt.Errorf("error getting multi mongos load balancer uri: %v", err) + return fmt.Errorf("error getting multi mongos load balancer uri: %w", err) } } @@ -198,7 +198,7 @@ func Setup(setupOpts ...*SetupOptions) error { testContext.sslEnabled = os.Getenv("SSL") == "ssl" biRes, err := testContext.client.Database("admin").RunCommand(context.Background(), bson.D{{"buildInfo", 1}}).Raw() if err != nil { - return fmt.Errorf("buildInfo error: %v", err) + return fmt.Errorf("buildInfo error: %w", err) } modulesRaw, err := biRes.LookupErr("modules") if err == nil { @@ -217,7 +217,7 @@ func Setup(setupOpts ...*SetupOptions) error { db := testContext.client.Database("admin") testContext.serverParameters, err = db.RunCommand(context.Background(), bson.D{{"getParameter", "*"}}).Raw() if err != nil { - return fmt.Errorf("error getting serverParameters: %v", err) + return fmt.Errorf("error getting serverParameters: %w", err) } } return nil @@ -229,14 +229,14 @@ func Teardown() error { // Dropping the test database causes an error against Atlas Data Lake. if !testContext.dataLake { if err := testContext.client.Database(TestDb).Drop(context.Background()); err != nil { - return fmt.Errorf("error dropping test database: %v", err) + return fmt.Errorf("error dropping test database: %w", err) } } if err := testContext.client.Disconnect(context.Background()); err != nil { - return fmt.Errorf("error disconnecting test client: %v", err) + return fmt.Errorf("error disconnecting test client: %w", err) } if err := testContext.topo.Disconnect(context.Background()); err != nil { - return fmt.Errorf("error disconnecting test topology: %v", err) + return fmt.Errorf("error disconnecting test topology: %w", err) } return nil } diff --git a/mongo/integration/mtest/wiremessage_helpers.go b/mongo/integration/mtest/wiremessage_helpers.go index 192271a7f4..c6d8a677f6 100644 --- a/mongo/integration/mtest/wiremessage_helpers.go +++ b/mongo/integration/mtest/wiremessage_helpers.go @@ -65,7 +65,7 @@ func parseOpCompressed(wm []byte) (wiremessage.OpCode, []byte, error) { } decompressed, err := driver.DecompressPayload(compressedMsg, opts) if err != nil { - return originalOpcode, nil, fmt.Errorf("error decompressing payload: %v", err) + return originalOpcode, nil, fmt.Errorf("error decompressing payload: %w", err) } return originalOpcode, decompressed, nil diff --git a/mongo/integration/sessions_test.go b/mongo/integration/sessions_test.go index dfd3e2d260..da7fe8a8d1 100644 --- a/mongo/integration/sessions_test.go +++ b/mongo/integration/sessions_test.go @@ -415,7 +415,7 @@ func TestSessionsProse(t *testing.T) { }, "findOneAndDelete": func(ctx context.Context) error { result := mt.Coll.FindOneAndDelete(ctx, bson.D{}) - if err := result.Err(); err != nil && err != mongo.ErrNoDocuments { + if err := result.Err(); err != nil && !errors.Is(err, mongo.ErrNoDocuments) { return err } return nil @@ -424,14 +424,14 @@ func TestSessionsProse(t *testing.T) { result := mt.Coll.FindOneAndUpdate(ctx, bson.D{}, bson.D{{"$set", bson.D{{"a", 1}}}}) - if err := result.Err(); err != nil && err != mongo.ErrNoDocuments { + if err := result.Err(); err != nil && !errors.Is(err, mongo.ErrNoDocuments) { return err } return nil }, "findOneAndReplace": func(ctx context.Context) error { result := mt.Coll.FindOneAndReplace(ctx, bson.D{}, bson.D{{"a", 1}}) - if err := result.Err(); err != nil && err != mongo.ErrNoDocuments { + if err := result.Err(); err != nil && !errors.Is(err, mongo.ErrNoDocuments) { return err } return nil @@ -469,7 +469,7 @@ func TestSessionsProse(t *testing.T) { cmd := cmd errs.Go(func() error { if err := op(ctx); err != nil { - return fmt.Errorf("error running %s operation: %v", cmd, err) + return fmt.Errorf("error running %s operation: %w", cmd, err) } return nil }) diff --git a/mongo/integration/unified/admin_helpers.go b/mongo/integration/unified/admin_helpers.go index 29955aed82..e6a0bda9cc 100644 --- a/mongo/integration/unified/admin_helpers.go +++ b/mongo/integration/unified/admin_helpers.go @@ -70,7 +70,7 @@ func performDistinctWorkaround(ctx context.Context) error { _, err := newColl.Distinct(ctx, "x", bson.D{}) if err != nil { ns := fmt.Sprintf("%s.%s", coll.Database().Name(), coll.Name()) - return fmt.Errorf("error running distinct for collection %q: %v", ns, err) + return fmt.Errorf("error running distinct for collection %q: %w", ns, err) } } @@ -88,7 +88,7 @@ func runCommandOnHost(ctx context.Context, host string, commandFn func(context.C client, err := mongo.Connect(ctx, clientOpts) if err != nil { - return fmt.Errorf("error creating client to host %q: %v", host, err) + return fmt.Errorf("error creating client to host %q: %w", host, err) } defer client.Disconnect(ctx) @@ -98,7 +98,7 @@ func runCommandOnHost(ctx context.Context, host string, commandFn func(context.C func runAgainstAllMongoses(ctx context.Context, commandFn func(context.Context, *mongo.Client) error) error { for _, host := range mtest.ClusterConnString().Hosts { if err := runCommandOnHost(ctx, host, commandFn); err != nil { - return fmt.Errorf("error executing callback against host %q: %v", host, err) + return fmt.Errorf("error executing callback against host %q: %w", host, err) } } return nil diff --git a/mongo/integration/unified/bucket_options.go b/mongo/integration/unified/bucket_options.go index a738a42aa6..fe5f3d625f 100644 --- a/mongo/integration/unified/bucket_options.go +++ b/mongo/integration/unified/bucket_options.go @@ -31,7 +31,7 @@ func (bo *gridFSBucketOptions) UnmarshalBSON(data []byte) error { Extra map[string]interface{} `bson:",inline"` } if err := bson.Unmarshal(data, &temp); err != nil { - return fmt.Errorf("error unmarshalling to temporary gridFSBucketOptions object: %v", err) + return fmt.Errorf("error unmarshalling to temporary gridFSBucketOptions object: %w", err) } if len(temp.Extra) > 0 { return fmt.Errorf("unrecognized fields for gridFSBucketOptions: %v", mapKeys(temp.Extra)) @@ -50,14 +50,14 @@ func (bo *gridFSBucketOptions) UnmarshalBSON(data []byte) error { if temp.RP != nil { rp, err := temp.RP.ToReadPrefOption() if err != nil { - return fmt.Errorf("error parsing read preference document: %v", err) + return fmt.Errorf("error parsing read preference document: %w", err) } bo.SetReadPreference(rp) } if temp.WC != nil { wc, err := temp.WC.toWriteConcernOption() if err != nil { - return fmt.Errorf("error parsing write concern document: %v", err) + return fmt.Errorf("error parsing write concern document: %w", err) } bo.SetWriteConcern(wc) } diff --git a/mongo/integration/unified/bulkwrite_helpers.go b/mongo/integration/unified/bulkwrite_helpers.go index 5e5ade0f16..1d43fca40a 100644 --- a/mongo/integration/unified/bulkwrite_helpers.go +++ b/mongo/integration/unified/bulkwrite_helpers.go @@ -27,7 +27,7 @@ func createBulkWriteModels(rawModels bson.Raw) ([]mongo.WriteModel, error) { for idx, val := range vals { model, err := createBulkWriteModel(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating model at index %d: %v", idx, err) + return nil, fmt.Errorf("error creating model at index %d: %w", idx, err) } models = append(models, model) } @@ -79,7 +79,7 @@ func createBulkWriteModel(rawModel bson.Raw) (mongo.WriteModel, error) { case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } uom.SetCollation(collation) case "filter": @@ -87,13 +87,13 @@ func createBulkWriteModel(rawModel bson.Raw) (mongo.WriteModel, error) { case "hint": hint, err := createHint(val) if err != nil { - return nil, fmt.Errorf("error creating hint: %v", err) + return nil, fmt.Errorf("error creating hint: %w", err) } uom.SetHint(hint) case "update": update, err = createUpdateValue(val) if err != nil { - return nil, fmt.Errorf("error creating update: %v", err) + return nil, fmt.Errorf("error creating update: %w", err) } case "upsert": uom.SetUpsert(val.Boolean()) @@ -128,7 +128,7 @@ func createBulkWriteModel(rawModel bson.Raw) (mongo.WriteModel, error) { case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } umm.SetCollation(collation) case "filter": @@ -136,13 +136,13 @@ func createBulkWriteModel(rawModel bson.Raw) (mongo.WriteModel, error) { case "hint": hint, err := createHint(val) if err != nil { - return nil, fmt.Errorf("error creating hint: %v", err) + return nil, fmt.Errorf("error creating hint: %w", err) } umm.SetHint(hint) case "update": update, err = createUpdateValue(val) if err != nil { - return nil, fmt.Errorf("error creating update: %v", err) + return nil, fmt.Errorf("error creating update: %w", err) } case "upsert": umm.SetUpsert(val.Boolean()) @@ -173,7 +173,7 @@ func createBulkWriteModel(rawModel bson.Raw) (mongo.WriteModel, error) { case "hint": hint, err := createHint(val) if err != nil { - return nil, fmt.Errorf("error creating hint: %v", err) + return nil, fmt.Errorf("error creating hint: %w", err) } dom.SetHint(hint) default: @@ -198,7 +198,7 @@ func createBulkWriteModel(rawModel bson.Raw) (mongo.WriteModel, error) { case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } dmm.SetCollation(collation) case "filter": @@ -206,7 +206,7 @@ func createBulkWriteModel(rawModel bson.Raw) (mongo.WriteModel, error) { case "hint": hint, err := createHint(val) if err != nil { - return nil, fmt.Errorf("error creating hint: %v", err) + return nil, fmt.Errorf("error creating hint: %w", err) } dmm.SetHint(hint) default: @@ -231,7 +231,7 @@ func createBulkWriteModel(rawModel bson.Raw) (mongo.WriteModel, error) { case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } rom.SetCollation(collation) case "filter": @@ -239,7 +239,7 @@ func createBulkWriteModel(rawModel bson.Raw) (mongo.WriteModel, error) { case "hint": hint, err := createHint(val) if err != nil { - return nil, fmt.Errorf("error creating hint: %v", err) + return nil, fmt.Errorf("error creating hint: %w", err) } rom.SetHint(hint) case "replacement": diff --git a/mongo/integration/unified/client_encryption_operation_execution.go b/mongo/integration/unified/client_encryption_operation_execution.go index f5345f5620..d30087593d 100644 --- a/mongo/integration/unified/client_encryption_operation_execution.go +++ b/mongo/integration/unified/client_encryption_operation_execution.go @@ -8,6 +8,7 @@ package unified import ( "context" + "errors" "fmt" "go.mongodb.org/mongo-driver/bson" @@ -31,19 +32,19 @@ func parseDataKeyOptions(opts bson.Raw) (*options.DataKeyOptions, error) { case "masterKey": masterKey := make(map[string]interface{}) if err := val.Unmarshal(&masterKey); err != nil { - return nil, fmt.Errorf("error unmarshaling 'masterKey': %v", err) + return nil, fmt.Errorf("error unmarshaling 'masterKey': %w", err) } dko.SetMasterKey(masterKey) case "keyAltNames": keyAltNames := []string{} if err := val.Unmarshal(&keyAltNames); err != nil { - return nil, fmt.Errorf("error unmarshaling 'keyAltNames': %v", err) + return nil, fmt.Errorf("error unmarshaling 'keyAltNames': %w", err) } dko.SetKeyAltNames(keyAltNames) case "keyMaterial": bin := primitive.Binary{} if err := val.Unmarshal(&bin); err != nil { - return nil, fmt.Errorf("error unmarshaling 'keyMaterial': %v", err) + return nil, fmt.Errorf("error unmarshaling 'keyMaterial': %w", err) } dko.SetKeyMaterial(bin.Data) default: @@ -86,7 +87,7 @@ func executeAddKeyAltName(ctx context.Context, operation *operation) (*operation res, err := cee.AddKeyAltName(ctx, id, keyAltName).Raw() // Ignore ErrNoDocuments errors from Raw. In the event that the cursor returned in a find operation has no // associated documents, Raw will return ErrNoDocuments. - if err == mongo.ErrNoDocuments { + if errors.Is(err, mongo.ErrNoDocuments) { err = nil } return newDocumentResult(res, err), nil @@ -202,7 +203,7 @@ func executeGetKeyByAltName(ctx context.Context, operation *operation) (*operati res, err := cee.GetKeyByAltName(ctx, keyAltName).Raw() // Ignore ErrNoDocuments errors from Raw. In the event that the cursor returned in a find operation has no // associated documents, Raw will return ErrNoDocuments. - if err == mongo.ErrNoDocuments { + if errors.Is(err, mongo.ErrNoDocuments) { err = nil } return newDocumentResult(res, err), nil @@ -238,7 +239,7 @@ func executeGetKey(ctx context.Context, operation *operation) (*operationResult, res, err := cee.GetKey(ctx, id).Raw() // Ignore ErrNoDocuments errors from Raw. In the event that the cursor returned in a find operation has no // associated documents, Raw will return ErrNoDocuments. - if err == mongo.ErrNoDocuments { + if errors.Is(err, mongo.ErrNoDocuments) { err = nil } return newDocumentResult(res, err), nil @@ -294,7 +295,7 @@ func executeRemoveKeyAltName(ctx context.Context, operation *operation) (*operat res, err := cee.RemoveKeyAltName(ctx, id, keyAltName).Raw() // Ignore ErrNoDocuments errors from Raw. In the event that the cursor returned in a find operation has no // associated documents, Raw will return ErrNoDocuments. - if err == mongo.ErrNoDocuments { + if errors.Is(err, mongo.ErrNoDocuments) { err = nil } return newDocumentResult(res, err), nil @@ -333,7 +334,7 @@ func rewrapManyDataKeyResultsOpResult(result *mongo.RewrapManyDataKeyResult) (*o if res.UpsertedIDs != nil { rawUpsertedIDs, marshalErr = bson.Marshal(res.UpsertedIDs) if marshalErr != nil { - return nil, fmt.Errorf("error marshalling UpsertedIDs map to BSON: %v", marshalErr) + return nil, fmt.Errorf("error marshalling UpsertedIDs map to BSON: %w", marshalErr) } } bulkWriteResult := bsoncore.NewDocumentBuilder() diff --git a/mongo/integration/unified/client_operation_execution.go b/mongo/integration/unified/client_operation_execution.go index 506ca4f351..5a69e77b1e 100644 --- a/mongo/integration/unified/client_operation_execution.go +++ b/mongo/integration/unified/client_operation_execution.go @@ -51,13 +51,13 @@ func executeCreateChangeStream(ctx context.Context, operation *operation) (*oper case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } opts.SetCollation(*collation) case "comment": commentString, err := createCommentString(val) if err != nil { - return nil, fmt.Errorf("error creating comment: %v", err) + return nil, fmt.Errorf("error creating comment: %w", err) } opts.SetComment(commentString) case "fullDocument": @@ -112,7 +112,7 @@ func executeCreateChangeStream(ctx context.Context, operation *operation) (*oper // empty result in this case. if operation.ResultEntityID != nil { if err := entities(ctx).addCursorEntity(*operation.ResultEntityID, stream); err != nil { - return nil, fmt.Errorf("error storing result as cursor entity: %v", err) + return nil, fmt.Errorf("error storing result as cursor entity: %w", err) } } return newEmptyResult(), nil diff --git a/mongo/integration/unified/collection_data.go b/mongo/integration/unified/collection_data.go index 979ecccf05..6d4c59c282 100644 --- a/mongo/integration/unified/collection_data.go +++ b/mongo/integration/unified/collection_data.go @@ -37,7 +37,7 @@ func (c *collectionData) createCollection(ctx context.Context) error { db := mtest.GlobalClient().Database(c.DatabaseName, options.Database().SetWriteConcern(mtest.MajorityWc)) coll := db.Collection(c.CollectionName) if err := coll.Drop(ctx); err != nil { - return fmt.Errorf("error dropping collection: %v", err) + return fmt.Errorf("error dropping collection: %w", err) } // Explicitly create collection if Options are specified. @@ -51,7 +51,7 @@ func (c *collectionData) createCollection(ctx context.Context) error { } if err := db.CreateCollection(ctx, c.CollectionName, createOpts); err != nil { - return fmt.Errorf("error creating collection: %v", err) + return fmt.Errorf("error creating collection: %w", err) } } @@ -66,14 +66,14 @@ func (c *collectionData) createCollection(ctx context.Context) error { }}, } if err := db.RunCommand(ctx, create).Err(); err != nil { - return fmt.Errorf("error creating collection: %v", err) + return fmt.Errorf("error creating collection: %w", err) } return nil } docs := bsonutil.RawToInterfaces(c.Documents...) if _, err := coll.InsertMany(ctx, docs); err != nil { - return fmt.Errorf("error inserting data: %v", err) + return fmt.Errorf("error inserting data: %w", err) } return nil } @@ -88,13 +88,13 @@ func (c *collectionData) verifyContents(ctx context.Context) error { cursor, err := coll.Find(ctx, bson.D{}, options.Find().SetSort(bson.M{"_id": 1})) if err != nil { - return fmt.Errorf("Find error: %v", err) + return fmt.Errorf("Find error: %w", err) } defer cursor.Close(ctx) var docs []bson.Raw if err := cursor.All(ctx, &docs); err != nil { - return fmt.Errorf("cursor iteration error: %v", err) + return fmt.Errorf("cursor iteration error: %w", err) } // Verify the slice lengths are equal. This also covers the case of asserting that the collection is empty if diff --git a/mongo/integration/unified/collection_operation_execution.go b/mongo/integration/unified/collection_operation_execution.go index 0d4930b335..978ce13f00 100644 --- a/mongo/integration/unified/collection_operation_execution.go +++ b/mongo/integration/unified/collection_operation_execution.go @@ -8,6 +8,7 @@ package unified import ( "context" + "errors" "fmt" "time" @@ -56,7 +57,7 @@ func executeAggregate(ctx context.Context, operation *operation) (*operationResu case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } opts.SetCollation(collation) case "comment": @@ -64,13 +65,13 @@ func executeAggregate(ctx context.Context, operation *operation) (*operationResu // TODO with `opts.SetComment(val)` commentString, err := createCommentString(val) if err != nil { - return nil, fmt.Errorf("error creating comment: %v", err) + return nil, fmt.Errorf("error creating comment: %w", err) } opts.SetComment(commentString) case "hint": hint, err := createHint(val) if err != nil { - return nil, fmt.Errorf("error creating hint: %v", err) + return nil, fmt.Errorf("error creating hint: %w", err) } opts.SetHint(hint) case "maxTimeMS": @@ -127,7 +128,7 @@ func executeBulkWrite(ctx context.Context, operation *operation) (*operationResu case "requests": models, err = createBulkWriteModels(val.Array()) if err != nil { - return nil, fmt.Errorf("error creating write models: %v", err) + return nil, fmt.Errorf("error creating write models: %w", err) } case "let": opts.SetLet(val.Document()) @@ -147,7 +148,7 @@ func executeBulkWrite(ctx context.Context, operation *operation) (*operationResu if res.UpsertedIDs != nil { rawUpsertedIDs, marshalErr = bson.Marshal(res.UpsertedIDs) if marshalErr != nil { - return nil, fmt.Errorf("error marshalling UpsertedIDs map to BSON: %v", marshalErr) + return nil, fmt.Errorf("error marshalling UpsertedIDs map to BSON: %w", marshalErr) } } @@ -184,7 +185,7 @@ func executeCountDocuments(ctx context.Context, operation *operation) (*operatio case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } opts.SetCollation(collation) case "comment": @@ -192,7 +193,7 @@ func executeCountDocuments(ctx context.Context, operation *operation) (*operatio // TODO with `opts.SetComment(val)` commentString, err := createCommentString(val) if err != nil { - return nil, fmt.Errorf("error creating comment: %v", err) + return nil, fmt.Errorf("error creating comment: %w", err) } opts.SetComment(commentString) case "filter": @@ -200,7 +201,7 @@ func executeCountDocuments(ctx context.Context, operation *operation) (*operatio case "hint": hint, err := createHint(val) if err != nil { - return nil, fmt.Errorf("error creating hint: %v", err) + return nil, fmt.Errorf("error creating hint: %w", err) } opts.SetHint(hint) case "limit": @@ -253,7 +254,7 @@ func executeCreateIndex(ctx context.Context, operation *operation) (*operationRe case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } indexOpts.SetCollation(collation) case "defaultLanguage": @@ -414,7 +415,7 @@ func executeDeleteOne(ctx context.Context, operation *operation) (*operationResu case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } opts.SetCollation(collation) case "comment": @@ -424,7 +425,7 @@ func executeDeleteOne(ctx context.Context, operation *operation) (*operationResu case "hint": hint, err := createHint(val) if err != nil { - return nil, fmt.Errorf("error creating hint: %v", err) + return nil, fmt.Errorf("error creating hint: %w", err) } opts.SetHint(hint) case "let": @@ -470,7 +471,7 @@ func executeDeleteMany(ctx context.Context, operation *operation) (*operationRes case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } opts.SetCollation(collation) case "filter": @@ -478,7 +479,7 @@ func executeDeleteMany(ctx context.Context, operation *operation) (*operationRes case "hint": hint, err := createHint(val) if err != nil { - return nil, fmt.Errorf("error creating hint: %v", err) + return nil, fmt.Errorf("error creating hint: %w", err) } opts.SetHint(hint) case "let": @@ -523,7 +524,7 @@ func executeDistinct(ctx context.Context, operation *operation) (*operationResul case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } opts.SetCollation(collation) case "comment": @@ -551,7 +552,7 @@ func executeDistinct(ctx context.Context, operation *operation) (*operationResul } _, rawRes, err := bson.MarshalValue(res) if err != nil { - return nil, fmt.Errorf("error converting Distinct result to raw BSON: %v", err) + return nil, fmt.Errorf("error converting Distinct result to raw BSON: %w", err) } return newValueResult(bsontype.Array, rawRes, nil), nil } @@ -688,7 +689,7 @@ func executeCreateFindCursor(ctx context.Context, operation *operation) (*operat return nil, fmt.Errorf("no entity name provided to store executeCreateFindCursor result") } if err := entities(ctx).addCursorEntity(*operation.ResultEntityID, result.cursor); err != nil { - return nil, fmt.Errorf("error storing result as cursor entity: %v", err) + return nil, fmt.Errorf("error storing result as cursor entity: %w", err) } return newEmptyResult(), nil } @@ -727,7 +728,7 @@ func executeFindOne(ctx context.Context, operation *operation) (*operationResult case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } opts.SetCollation(collation) case "filter": @@ -735,7 +736,7 @@ func executeFindOne(ctx context.Context, operation *operation) (*operationResult case "hint": hint, err := createHint(val) if err != nil { - return nil, fmt.Errorf("error creating hint: %v", err) + return nil, fmt.Errorf("error creating hint: %w", err) } opts.SetHint(hint) case "maxTimeMS": @@ -756,7 +757,7 @@ func executeFindOne(ctx context.Context, operation *operation) (*operationResult // Ignore ErrNoDocuments errors from Raw. In the event that the cursor // returned in a find operation has no associated documents, Raw will // return ErrNoDocuments. - if err == mongo.ErrNoDocuments { + if errors.Is(err, mongo.ErrNoDocuments) { err = nil } @@ -784,7 +785,7 @@ func executeFindOneAndDelete(ctx context.Context, operation *operation) (*operat case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } opts.SetCollation(collation) case "comment": @@ -794,7 +795,7 @@ func executeFindOneAndDelete(ctx context.Context, operation *operation) (*operat case "hint": hint, err := createHint(val) if err != nil { - return nil, fmt.Errorf("error creating hint: %v", err) + return nil, fmt.Errorf("error creating hint: %w", err) } opts.SetHint(hint) case "maxTimeMS": @@ -817,7 +818,7 @@ func executeFindOneAndDelete(ctx context.Context, operation *operation) (*operat // Ignore ErrNoDocuments errors from Raw. In the event that the cursor // returned in a find operation has no associated documents, Raw will // return ErrNoDocuments. - if err == mongo.ErrNoDocuments { + if errors.Is(err, mongo.ErrNoDocuments) { err = nil } @@ -848,7 +849,7 @@ func executeFindOneAndReplace(ctx context.Context, operation *operation) (*opera case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } opts.SetCollation(collation) case "comment": @@ -858,7 +859,7 @@ func executeFindOneAndReplace(ctx context.Context, operation *operation) (*opera case "hint": hint, err := createHint(val) if err != nil { - return nil, fmt.Errorf("error creating hint: %v", err) + return nil, fmt.Errorf("error creating hint: %w", err) } opts.SetHint(hint) case "let": @@ -897,7 +898,7 @@ func executeFindOneAndReplace(ctx context.Context, operation *operation) (*opera // Ignore ErrNoDocuments errors from Raw. In the event that the cursor // returned in a find operation has no associated documents, Raw will // return ErrNoDocuments. - if err == mongo.ErrNoDocuments { + if errors.Is(err, mongo.ErrNoDocuments) { err = nil } @@ -932,7 +933,7 @@ func executeFindOneAndUpdate(ctx context.Context, operation *operation) (*operat case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } opts.SetCollation(collation) case "comment": @@ -942,7 +943,7 @@ func executeFindOneAndUpdate(ctx context.Context, operation *operation) (*operat case "hint": hint, err := createHint(val) if err != nil { - return nil, fmt.Errorf("error creating hint: %v", err) + return nil, fmt.Errorf("error creating hint: %w", err) } opts.SetHint(hint) case "let": @@ -984,7 +985,7 @@ func executeFindOneAndUpdate(ctx context.Context, operation *operation) (*operat // Ignore ErrNoDocuments errors from Raw. In the event that the cursor // returned in a find operation has no associated documents, Raw will // return ErrNoDocuments. - if err == mongo.ErrNoDocuments { + if errors.Is(err, mongo.ErrNoDocuments) { err = nil } @@ -1078,7 +1079,7 @@ func executeInsertOne(ctx context.Context, operation *operation) (*operationResu if res != nil { t, data, err := bson.MarshalValue(res.InsertedID) if err != nil { - return nil, fmt.Errorf("error converting InsertedID field to BSON: %v", err) + return nil, fmt.Errorf("error converting InsertedID field to BSON: %w", err) } raw = bsoncore.NewDocumentBuilder(). AppendValue("insertedId", bsoncore.Value{Type: t, Data: data}). @@ -1230,7 +1231,7 @@ func executeReplaceOne(ctx context.Context, operation *operation) (*operationRes case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } opts.SetCollation(collation) case "comment": @@ -1240,7 +1241,7 @@ func executeReplaceOne(ctx context.Context, operation *operation) (*operationRes case "hint": hint, err := createHint(val) if err != nil { - return nil, fmt.Errorf("error creating hint: %v", err) + return nil, fmt.Errorf("error creating hint: %w", err) } opts.SetHint(hint) case "replacement": @@ -1347,7 +1348,7 @@ func buildUpdateResultDocument(res *mongo.UpdateResult) (bsoncore.Document, erro if res.UpsertedID != nil { t, data, err := bson.MarshalValue(res.UpsertedID) if err != nil { - return nil, fmt.Errorf("error converting UpsertedID to BSON: %v", err) + return nil, fmt.Errorf("error converting UpsertedID to BSON: %w", err) } builder.AppendValue("upsertedId", bsoncore.Value{Type: t, Data: data}) } @@ -1391,7 +1392,7 @@ func createFindCursor(ctx context.Context, operation *operation) (*cursorResult, case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } opts.SetCollation(collation) case "comment": @@ -1399,7 +1400,7 @@ func createFindCursor(ctx context.Context, operation *operation) (*cursorResult, // TODO with `opts.SetComment(val)` commentString, err := createCommentString(val) if err != nil { - return nil, fmt.Errorf("error creating comment: %v", err) + return nil, fmt.Errorf("error creating comment: %w", err) } opts.SetComment(commentString) case "filter": @@ -1407,7 +1408,7 @@ func createFindCursor(ctx context.Context, operation *operation) (*cursorResult, case "hint": hint, err := createHint(val) if err != nil { - return nil, fmt.Errorf("error creating hint: %v", err) + return nil, fmt.Errorf("error creating hint: %w", err) } opts.SetHint(hint) case "let": diff --git a/mongo/integration/unified/crud_helpers.go b/mongo/integration/unified/crud_helpers.go index 245d0e02a9..0a01685988 100644 --- a/mongo/integration/unified/crud_helpers.go +++ b/mongo/integration/unified/crud_helpers.go @@ -48,7 +48,7 @@ func createUpdateArguments(args bson.Raw) (*updateArguments, error) { case "collation": collation, err := createCollation(val.Document()) if err != nil { - return nil, fmt.Errorf("error creating collation: %v", err) + return nil, fmt.Errorf("error creating collation: %w", err) } ua.opts.SetCollation(collation) case "comment": @@ -58,7 +58,7 @@ func createUpdateArguments(args bson.Raw) (*updateArguments, error) { case "hint": hint, err := createHint(val) if err != nil { - return nil, fmt.Errorf("error creating hint: %v", err) + return nil, fmt.Errorf("error creating hint: %w", err) } ua.opts.SetHint(hint) case "let": @@ -66,7 +66,7 @@ func createUpdateArguments(args bson.Raw) (*updateArguments, error) { case "update": ua.update, err = createUpdateValue(val) if err != nil { - return nil, fmt.Errorf("error processing update value: %v", err) + return nil, fmt.Errorf("error processing update value: %w", err) } case "upsert": ua.opts.SetUpsert(val.Boolean()) diff --git a/mongo/integration/unified/cursor_operation_execution.go b/mongo/integration/unified/cursor_operation_execution.go index 390e844ad0..1f9f5fca60 100644 --- a/mongo/integration/unified/cursor_operation_execution.go +++ b/mongo/integration/unified/cursor_operation_execution.go @@ -25,7 +25,7 @@ func executeIterateOnce(ctx context.Context, operation *operation) (*operationRe // as fatal. var res bson.Raw if err := cursor.Decode(&res); err != nil { - return nil, fmt.Errorf("error decoding cursor result: %v", err) + return nil, fmt.Errorf("error decoding cursor result: %w", err) } return newDocumentResult(res, nil), nil @@ -44,7 +44,7 @@ func executeIterateUntilDocumentOrError(ctx context.Context, operation *operatio // We don't expect the server to return malformed documents, so any errors from Decode are treated as fatal. var res bson.Raw if err := cursor.Decode(&res); err != nil { - return nil, fmt.Errorf("error decoding cursor result: %v", err) + return nil, fmt.Errorf("error decoding cursor result: %w", err) } return newDocumentResult(res, nil), nil diff --git a/mongo/integration/unified/database_operation_execution.go b/mongo/integration/unified/database_operation_execution.go index 3c3361a45b..675ab480b7 100644 --- a/mongo/integration/unified/database_operation_execution.go +++ b/mongo/integration/unified/database_operation_execution.go @@ -219,7 +219,7 @@ func executeListCollectionNames(ctx context.Context, operation *operation) (*ope } _, data, err := bson.MarshalValue(names) if err != nil { - return nil, fmt.Errorf("error converting collection names slice to BSON: %v", err) + return nil, fmt.Errorf("error converting collection names slice to BSON: %w", err) } return newValueResult(bsontype.Array, data, nil), nil } @@ -250,12 +250,12 @@ func executeRunCommand(ctx context.Context, operation *operation) (*operationRes case "readPreference": var temp ReadPreference if err := bson.Unmarshal(val.Document(), &temp); err != nil { - return nil, fmt.Errorf("error unmarshalling readPreference option: %v", err) + return nil, fmt.Errorf("error unmarshalling readPreference option: %w", err) } rp, err := temp.ToReadPrefOption() if err != nil { - return nil, fmt.Errorf("error creating readpref.ReadPref object: %v", err) + return nil, fmt.Errorf("error creating readpref.ReadPref object: %w", err) } opts.SetReadPreference(rp) case "writeConcern": @@ -406,7 +406,7 @@ func executeCreateRunCursorCommand(ctx context.Context, operation *operation) (* if cursorID := operation.ResultEntityID; cursorID != nil { err := entities(ctx).addCursorEntity(*cursorID, cursor) if err != nil { - return nil, fmt.Errorf("failed to store result as cursor entity: %v", err) + return nil, fmt.Errorf("failed to store result as cursor entity: %w", err) } } diff --git a/mongo/integration/unified/db_collection_options.go b/mongo/integration/unified/db_collection_options.go index 934f2530ef..2d0dbeb804 100644 --- a/mongo/integration/unified/db_collection_options.go +++ b/mongo/integration/unified/db_collection_options.go @@ -30,7 +30,7 @@ func (d *dbOrCollectionOptions) UnmarshalBSON(data []byte) error { Extra map[string]interface{} `bson:",inline"` } if err := bson.Unmarshal(data, &temp); err != nil { - return fmt.Errorf("error unmarshalling to temporary dbOrCollectionOptions object: %v", err) + return fmt.Errorf("error unmarshalling to temporary dbOrCollectionOptions object: %w", err) } if len(temp.Extra) > 0 { return fmt.Errorf("unrecognized fields for dbOrCollectionOptions: %v", mapKeys(temp.Extra)) @@ -46,7 +46,7 @@ func (d *dbOrCollectionOptions) UnmarshalBSON(data []byte) error { if temp.RP != nil { rp, err := temp.RP.ToReadPrefOption() if err != nil { - return fmt.Errorf("error parsing read preference document: %v", err) + return fmt.Errorf("error parsing read preference document: %w", err) } d.DBOptions.SetReadPreference(rp) @@ -55,7 +55,7 @@ func (d *dbOrCollectionOptions) UnmarshalBSON(data []byte) error { if temp.WC != nil { wc, err := temp.WC.toWriteConcernOption() if err != nil { - return fmt.Errorf("error parsing write concern document: %v", err) + return fmt.Errorf("error parsing write concern document: %w", err) } d.DBOptions.SetWriteConcern(wc) diff --git a/mongo/integration/unified/entity.go b/mongo/integration/unified/entity.go index 06de9b6e78..19c6952ef6 100644 --- a/mongo/integration/unified/entity.go +++ b/mongo/integration/unified/entity.go @@ -292,7 +292,7 @@ func (em *EntityMap) addEntity(ctx context.Context, entityType string, entityOpt } if err != nil { - return fmt.Errorf("error constructing entity of type %q: %v", entityType, err) + return fmt.Errorf("error constructing entity of type %q: %w", entityType, err) } em.allEntities[entityOptions.ID] = struct{}{} return nil @@ -424,7 +424,7 @@ func (em *EntityMap) close(ctx context.Context) []error { var errs []error for id, cursor := range em.cursorEntities { if err := cursor.Close(ctx); err != nil { - errs = append(errs, fmt.Errorf("error closing cursor with ID %q: %v", id, err)) + errs = append(errs, fmt.Errorf("error closing cursor with ID %q: %w", id, err)) } } @@ -435,13 +435,13 @@ func (em *EntityMap) close(ctx context.Context) []error { } if err := client.disconnect(ctx); err != nil { - errs = append(errs, fmt.Errorf("error closing client with ID %q: %v", id, err)) + errs = append(errs, fmt.Errorf("error closing client with ID %q: %w", id, err)) } } for id, clientEncryption := range em.clientEncryptionEntities { if err := clientEncryption.Close(ctx); err != nil { - errs = append(errs, fmt.Errorf("error closing clientEncryption with ID: %q: %v", id, err)) + errs = append(errs, fmt.Errorf("error closing clientEncryption with ID: %q: %w", id, err)) } } @@ -463,7 +463,7 @@ func (em *EntityMap) addClientEntity(ctx context.Context, entityOptions *entityO client, err := newClientEntity(ctx, em, entityOptions) if err != nil { - return fmt.Errorf("error creating client entity: %v", err) + return fmt.Errorf("error creating client entity: %w", err) } em.clientEntities[entityOptions.ID] = client @@ -490,7 +490,7 @@ func (em *EntityMap) addDatabaseEntity(entityOptions *entityOptions) error { // A string is returned as-is. func getKmsCredential(kmsDocument bson.Raw, credentialName string, envVar string, defaultValue string) (string, error) { credentialVal, err := kmsDocument.LookupErr(credentialName) - if err == bsoncore.ErrElementNotFound { + if errors.Is(err, bsoncore.ErrElementNotFound) { return "", nil } if err != nil { @@ -638,7 +638,7 @@ func (em *EntityMap) addClientEncryptionEntity(entityOptions *entityOptions) err "tlsCAFile": tlsCAFile, }) if err != nil { - return fmt.Errorf("error constructing tls config: %v", err) + return fmt.Errorf("error constructing tls config: %w", err) } tlsconf["kmip"] = cfg } @@ -710,7 +710,7 @@ func (em *EntityMap) addSessionEntity(entityOptions *entityOptions) error { sess, err := client.StartSession(sessionOpts) if err != nil { - return fmt.Errorf("error starting session: %v", err) + return fmt.Errorf("error starting session: %w", err) } em.sessions[entityOptions.ID] = sess @@ -730,7 +730,7 @@ func (em *EntityMap) addGridFSBucketEntity(entityOptions *entityOptions) error { bucket, err := gridfs.NewBucket(db, bucketOpts) if err != nil { - return fmt.Errorf("error creating GridFS bucket: %v", err) + return fmt.Errorf("error creating GridFS bucket: %w", err) } em.gridfsBuckets[entityOptions.ID] = bucket diff --git a/mongo/integration/unified/error.go b/mongo/integration/unified/error.go index 2137fa3432..0edc79428a 100644 --- a/mongo/integration/unified/error.go +++ b/mongo/integration/unified/error.go @@ -8,6 +8,7 @@ package unified import ( "context" + "errors" "fmt" "strings" @@ -35,13 +36,13 @@ type expectedError struct { func verifyOperationError(ctx context.Context, expected *expectedError, result *operationResult) error { // The unified spec test format doesn't treat ErrUnacknowledgedWrite as an error, so set result.Err to nil // to indicate that no error occurred. - if result.Err == mongo.ErrUnacknowledgedWrite { + if errors.Is(result.Err, mongo.ErrUnacknowledgedWrite) { result.Err = nil } if expected == nil { if result.Err != nil { - return fmt.Errorf("expected no error, but got %v", result.Err) + return fmt.Errorf("expected no error, but got %w", result.Err) } return nil } @@ -57,7 +58,7 @@ func verifyOperationError(ctx context.Context, expected *expectedError, result * expectedErrMsg := strings.ToLower(*expected.ErrorSubstring) actualErrMsg := strings.ToLower(result.Err.Error()) if !strings.Contains(actualErrMsg, expectedErrMsg) { - return fmt.Errorf("expected error %v to contain substring %s", result.Err, *expected.ErrorSubstring) + return fmt.Errorf("expected error %w to contain substring %s", result.Err, *expected.ErrorSubstring) } } @@ -68,14 +69,14 @@ func verifyOperationError(ctx context.Context, expected *expectedError, result * // The unified test format spec considers network errors to be client-side errors. isClientError := !serverError || mongo.IsNetworkError(result.Err) if *expected.IsClientError != isClientError { - return fmt.Errorf("expected error %v to be a client error: %v, is client error: %v", result.Err, + return fmt.Errorf("expected error %w to be a client error: %v, is client error: %v", result.Err, *expected.IsClientError, isClientError) } } if expected.IsTimeoutError != nil { isTimeoutError := mongo.IsTimeout(result.Err) if *expected.IsTimeoutError != isTimeoutError { - return fmt.Errorf("expected error %v to be a timeout error: %v, is timeout error: %v", result.Err, + return fmt.Errorf("expected error %w to be a timeout error: %v, is timeout error: %v", result.Err, *expected.IsTimeoutError, isTimeoutError) } } @@ -95,7 +96,7 @@ func verifyOperationError(ctx context.Context, expected *expectedError, result * } } if !found { - return fmt.Errorf("expected error %v to have code %d", result.Err, *expected.Code) + return fmt.Errorf("expected error %w to have code %d", result.Err, *expected.Code) } } if expected.CodeName != nil { @@ -107,23 +108,23 @@ func verifyOperationError(ctx context.Context, expected *expectedError, result * } } if !found { - return fmt.Errorf("expected error %v to have code name %q", result.Err, *expected.CodeName) + return fmt.Errorf("expected error %w to have code name %q", result.Err, *expected.CodeName) } } for _, label := range expected.IncludedLabels { if !stringSliceContains(details.labels, label) { - return fmt.Errorf("expected error %v to contain label %q", result.Err, label) + return fmt.Errorf("expected error %w to contain label %q", result.Err, label) } } for _, label := range expected.OmittedLabels { if stringSliceContains(details.labels, label) { - return fmt.Errorf("expected error %v to not contain label %q", result.Err, label) + return fmt.Errorf("expected error %w to not contain label %q", result.Err, label) } } if expected.ExpectedResult != nil { if err := verifyOperationResult(ctx, *expected.ExpectedResult, result); err != nil { - return fmt.Errorf("result comparison error: %v", err) + return fmt.Errorf("result comparison error: %w", err) } } @@ -136,7 +137,7 @@ func verifyOperationError(ctx context.Context, expected *expectedError, result * gotValue := documentToRawValue(details.raw) expectedValue := documentToRawValue(*expected.ErrorResponse) if err := verifyValuesMatch(ctx, expectedValue, gotValue, true); err != nil { - return fmt.Errorf("error response comparison error: %v", err) + return fmt.Errorf("error response comparison error: %w", err) } } return nil diff --git a/mongo/integration/unified/event_verification.go b/mongo/integration/unified/event_verification.go index 1d54e3fb2a..6516000416 100644 --- a/mongo/integration/unified/event_verification.go +++ b/mongo/integration/unified/event_verification.go @@ -113,7 +113,7 @@ func (e *expectedEvents) UnmarshalBSON(data []byte) error { Extra map[string]interface{} `bson:",inline"` } if err := bson.Unmarshal(data, &temp); err != nil { - return fmt.Errorf("error unmarshalling to temporary expectedEvents object: %v", err) + return fmt.Errorf("error unmarshalling to temporary expectedEvents object: %w", err) } if len(temp.Extra) > 0 { return fmt.Errorf("unrecognized fields for expectedEvents: %v", temp.Extra) @@ -137,7 +137,7 @@ func (e *expectedEvents) UnmarshalBSON(data []byte) error { } if err := temp.Events.Unmarshal(target); err != nil { - return fmt.Errorf("error unmarshalling events array: %v", err) + return fmt.Errorf("error unmarshalling events array: %w", err) } if temp.IgnoreExtraEvents != nil { diff --git a/mongo/integration/unified/gridfs_bucket_operation_execution.go b/mongo/integration/unified/gridfs_bucket_operation_execution.go index 3be6fded0c..8a9adc540d 100644 --- a/mongo/integration/unified/gridfs_bucket_operation_execution.go +++ b/mongo/integration/unified/gridfs_bucket_operation_execution.go @@ -235,7 +235,7 @@ func executeBucketUpload(ctx context.Context, operation *operation) (*operationR case "source": fileBytes, err = hex.DecodeString(val.Document().Lookup("$$hexBytes").StringValue()) if err != nil { - return nil, fmt.Errorf("error converting source string to bytes: %v", err) + return nil, fmt.Errorf("error converting source string to bytes: %w", err) } case "contentType": return nil, newSkipTestError("the deprecated contentType file option is not supported") @@ -263,7 +263,7 @@ func executeBucketUpload(ctx context.Context, operation *operation) (*operationR Value: fileID[:], } if err := entities(ctx).addBSONEntity(*operation.ResultEntityID, fileIDValue); err != nil { - return nil, fmt.Errorf("error storing result as BSON entity: %v", err) + return nil, fmt.Errorf("error storing result as BSON entity: %w", err) } } diff --git a/mongo/integration/unified_runner_thread_helpers_test.go b/mongo/integration/unified_runner_thread_helpers_test.go index 8ab2530b3d..2feb5ecfc3 100644 --- a/mongo/integration/unified_runner_thread_helpers_test.go +++ b/mongo/integration/unified_runner_thread_helpers_test.go @@ -52,7 +52,7 @@ func (b *backgroundRoutine) start() { } if err := runOperation(b.mt, b.testCase, op, nil, nil); err != nil { - b.err = fmt.Errorf("error running operation %s: %v", op.Name, err) + b.err = fmt.Errorf("error running operation %s: %w", op.Name, err) } } }() diff --git a/mongo/options/clientoptions_test.go b/mongo/options/clientoptions_test.go index fb304ed225..cd89508933 100644 --- a/mongo/options/clientoptions_test.go +++ b/mongo/options/clientoptions_test.go @@ -889,9 +889,8 @@ func compareErrors(err1, err2 error) bool { return false } - ospe1, ok1 := err1.(*os.PathError) - ospe2, ok2 := err2.(*os.PathError) - if ok1 && ok2 { + var ospe1, ospe2 *os.PathError + if errors.As(err1, &ospe1) && errors.As(err2, &ospe2) { return ospe1.Op == ospe2.Op && ospe1.Path == ospe2.Path } diff --git a/mongo/read_write_concern_spec_test.go b/mongo/read_write_concern_spec_test.go index 51d46b4c2a..5e196895be 100644 --- a/mongo/read_write_concern_spec_test.go +++ b/mongo/read_write_concern_spec_test.go @@ -8,6 +8,7 @@ package mongo import ( "bytes" + "errors" "io/ioutil" "path" "reflect" @@ -175,7 +176,7 @@ func runDocumentTest(t *testing.T, test documentTest) { } expected := *test.WriteConcernDocument - if err == writeconcern.ErrEmptyWriteConcern { + if errors.Is(err, writeconcern.ErrEmptyWriteConcern) { elems, _ := expected.Elements() if len(elems) == 0 { assert.NotNil(t, test.IsServerDefault, "expected write concern %s, got empty", expected) diff --git a/mongo/search_index_view.go b/mongo/search_index_view.go index 6a7871531e..56e1ffc3f3 100644 --- a/mongo/search_index_view.go +++ b/mongo/search_index_view.go @@ -8,6 +8,7 @@ package mongo import ( "context" + "errors" "fmt" "strconv" @@ -214,7 +215,8 @@ func (siv SearchIndexView) DropOne( Timeout(siv.coll.client.timeout) err = op.Execute(ctx) - if de, ok := err.(driver.Error); ok && de.NamespaceNotFound() { + var de driver.Error + if errors.As(err, &de) && de.NamespaceNotFound() { return nil } return err diff --git a/mongo/session.go b/mongo/session.go index 8f1e029b95..5b5c4ceeb6 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -245,7 +245,8 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(ctx SessionCo default: } - if cerr, ok := err.(CommandError); ok { + var cerr CommandError + if errors.As(err, &cerr) { if cerr.HasErrorLabel(driver.UnknownTransactionCommitResult) && !cerr.IsMaxTimeMSExpiredError() { continue } diff --git a/mongo/with_transactions_test.go b/mongo/with_transactions_test.go index 9a387264f9..4917fe0fed 100644 --- a/mongo/with_transactions_test.go +++ b/mongo/with_transactions_test.go @@ -50,7 +50,8 @@ func TestConvenientTransactions(t *testing.T) { {"killAllSessions", bson.A{}}, }).Err() if err != nil { - if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted { + var ce CommandError + if !errors.As(err, &ce) || ce.Code != errorInterrupted { t.Fatalf("killAllSessions error: %v", err) } } diff --git a/x/mongo/driver/auth/sasl.go b/x/mongo/driver/auth/sasl.go index a7ae3368f0..2a84b53a64 100644 --- a/x/mongo/driver/auth/sasl.go +++ b/x/mongo/driver/auth/sasl.go @@ -102,7 +102,7 @@ func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstRespon var saslResp saslResponse err := bson.Unmarshal(firstResponse, &saslResp) if err != nil { - fullErr := fmt.Errorf("unmarshal error: %v", err) + fullErr := fmt.Errorf("unmarshal error: %w", err) return newError(fullErr, sc.mechanism) } @@ -146,7 +146,7 @@ func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstRespon err = bson.Unmarshal(rdr, &saslResp) if err != nil { - fullErr := fmt.Errorf("unmarshal error: %v", err) + fullErr := fmt.Errorf("unmarshal error: %w", err) return newError(fullErr, sc.mechanism) } } diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index 827e536137..df676aa103 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -79,7 +79,7 @@ type CursorResponse struct { func NewCursorResponse(info ResponseInfo) (CursorResponse, error) { response := info.ServerResponse cur, err := response.LookupErr("cursor") - if err == bsoncore.ErrElementNotFound { + if errors.Is(err, bsoncore.ErrElementNotFound) { return CursorResponse{}, ErrNoCursor } if err != nil { diff --git a/x/mongo/driver/crypt.go b/x/mongo/driver/crypt.go index 4c254c03cf..576c007d67 100644 --- a/x/mongo/driver/crypt.go +++ b/x/mongo/driver/crypt.go @@ -9,6 +9,7 @@ package driver import ( "context" "crypto/tls" + "errors" "fmt" "io" "strings" @@ -399,7 +400,7 @@ func (c *crypt) decryptKey(kmsCtx *mongocrypt.KmsContext) error { res := make([]byte, bytesNeeded) bytesRead, err := conn.Read(res) - if err != nil && err != io.EOF { + if err != nil && !errors.Is(err, io.EOF) { return err } diff --git a/x/mongo/driver/integration/main_test.go b/x/mongo/driver/integration/main_test.go index ef6331853d..563462d65f 100644 --- a/x/mongo/driver/integration/main_test.go +++ b/x/mongo/driver/integration/main_test.go @@ -8,6 +8,7 @@ package integration import ( "context" + "errors" "flag" "fmt" "os" @@ -69,8 +70,8 @@ func autherr(t *testing.T, err error) { t.Helper() switch e := err.(type) { case topology.ConnectionError: - _, ok := e.Wrapped.(*auth.Error) - if !ok { + var authErr *auth.Error + if !errors.As(e.Wrapped, &authErr) { t.Fatal("Expected auth error and didn't get one") } case *auth.Error: @@ -134,7 +135,8 @@ func dropCollection(t *testing.T, dbname, colname string) { err := operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendStringElement(nil, "drop", colname))). Database(dbname).ServerSelector(description.WriteSelector()).Deployment(integtest.Topology(t)). Execute(context.Background()) - if de, ok := err.(driver.Error); err != nil && !(ok && de.NamespaceNotFound()) { + var de driver.Error + if err != nil && !(errors.As(err, &de) && de.NamespaceNotFound()) { require.NoError(t, err) } } diff --git a/x/mongo/driver/integration/scram_test.go b/x/mongo/driver/integration/scram_test.go index 18d99a5c40..bfd735abed 100644 --- a/x/mongo/driver/integration/scram_test.go +++ b/x/mongo/driver/integration/scram_test.go @@ -171,7 +171,7 @@ func createScramUsers(t *testing.T, s driver.Server, cases []scramTestCase) erro ) _, err := runCommand(s, db, newUserCmd) if err != nil { - return fmt.Errorf("Couldn't create user '%s' on db '%s': %v", c.username, integtest.DBName(t), err) + return fmt.Errorf("Couldn't create user '%s' on db '%s': %w", c.username, integtest.DBName(t), err) } } return nil diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 33ed562426..9064a7bd96 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -459,7 +459,7 @@ func (op Operation) getServerAndConnection( if err := pinnedConn.PinToTransaction(); err != nil { // Close the original connection to avoid a leak. _ = conn.Close() - return nil, nil, fmt.Errorf("error incrementing connection reference count when starting a transaction: %v", err) + return nil, nil, fmt.Errorf("error incrementing connection reference count when starting a transaction: %w", err) } op.Client.PinnedConnection = pinnedConn } @@ -627,7 +627,8 @@ func (op Operation) Execute(ctx context.Context) error { // If the returned error is retryable and there are retries remaining (negative // retries means retry indefinitely), then retry the operation. Set the server // and connection to nil to request a new server and connection. - if rerr, ok := err.(RetryablePoolError); ok && rerr.Retryable() && retries != 0 { + var rerr RetryablePoolError + if errors.As(err, &rerr) && rerr.Retryable() && retries != 0 { resetForRetry(err) continue } @@ -1749,7 +1750,7 @@ func (op Operation) createReadPref(desc description.SelectedServer, isOpQuery bo doc = bsoncore.AppendBooleanElement(doc, "enabled", *hedgeEnabled) doc, err = bsoncore.AppendDocumentEnd(doc, hedgeIdx) if err != nil { - return nil, fmt.Errorf("error creating hedge document: %v", err) + return nil, fmt.Errorf("error creating hedge document: %w", err) } } diff --git a/x/mongo/driver/operation/count.go b/x/mongo/driver/operation/count.go index 8de1e9f8d9..3de9b6b9ca 100644 --- a/x/mongo/driver/operation/count.go +++ b/x/mongo/driver/operation/count.go @@ -132,8 +132,8 @@ func (c *Count) Execute(ctx context.Context) error { // Swallow error if NamespaceNotFound(26) is returned from aggregate on non-existent namespace if err != nil { - dErr, ok := err.(driver.Error) - if ok && dErr.Code == 26 { + var dErr driver.Error + if errors.As(err, &dErr) && dErr.Code == 26 { err = nil } } diff --git a/x/mongo/driver/session/client_session_test.go b/x/mongo/driver/session/client_session_test.go index 245f3255b9..1ba1198802 100644 --- a/x/mongo/driver/session/client_session_test.go +++ b/x/mongo/driver/session/client_session_test.go @@ -8,6 +8,7 @@ package session import ( "bytes" + "errors" "testing" "go.mongodb.org/mongo-driver/bson/primitive" @@ -125,12 +126,12 @@ func TestClientSession(t *testing.T) { require.Nil(t, err, "Unexpected error") err = sess.CommitTransaction() - if err != ErrNoTransactStarted { + if !errors.Is(err, ErrNoTransactStarted) { t.Errorf("expected error, got %v", err) } err = sess.AbortTransaction() - if err != ErrNoTransactStarted { + if !errors.Is(err, ErrNoTransactStarted) { t.Errorf("expected error, got %v", err) } @@ -145,7 +146,7 @@ func TestClientSession(t *testing.T) { } err = sess.StartTransaction(nil) - if err != ErrTransactInProgress { + if !errors.Is(err, ErrTransactInProgress) { t.Errorf("expected error, got %v", err) } @@ -156,7 +157,7 @@ func TestClientSession(t *testing.T) { } err = sess.StartTransaction(nil) - if err != ErrTransactInProgress { + if !errors.Is(err, ErrTransactInProgress) { t.Errorf("expected error, got %v", err) } @@ -167,7 +168,7 @@ func TestClientSession(t *testing.T) { } err = sess.AbortTransaction() - if err != ErrAbortAfterCommit { + if !errors.Is(err, ErrAbortAfterCommit) { t.Errorf("expected error, got %v", err) } @@ -184,12 +185,12 @@ func TestClientSession(t *testing.T) { } err = sess.AbortTransaction() - if err != ErrAbortTwice { + if !errors.Is(err, ErrAbortTwice) { t.Errorf("expected error, got %v", err) } err = sess.CommitTransaction() - if err != ErrCommitAfterAbort { + if !errors.Is(err, ErrCommitAfterAbort) { t.Errorf("expected error, got %v", err) } }) diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index ac78c12045..80fbdf08c4 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -323,7 +323,8 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead if !contextDeadlineUsed { return originalError } - if netErr, ok := originalError.(net.Error); ok && netErr.Timeout() { + var netErr net.Error + if errors.As(originalError, &netErr) && netErr.Timeout() { return context.DeadlineExceeded } @@ -411,7 +412,7 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) { // We closeConnection the connection because we don't know if there are other bytes left to read. c.close() message := errMsg - if err == io.EOF { + if errors.Is(err, io.EOF) { message = "socket was unexpectedly closed" } return nil, ConnectionError{ diff --git a/x/mongo/driver/topology/rtt_monitor.go b/x/mongo/driver/topology/rtt_monitor.go index 0934beed89..3dd031f2ea 100644 --- a/x/mongo/driver/topology/rtt_monitor.go +++ b/x/mongo/driver/topology/rtt_monitor.go @@ -267,7 +267,7 @@ func percentile(perc float64, samples []time.Duration, minSamples int) time.Dura p, err := stats.Percentile(floatSamples, perc) if err != nil { - panic(fmt.Errorf("x/mongo/driver/topology: error calculating %f percentile RTT: %v for samples:\n%v", perc, err, floatSamples)) + panic(fmt.Errorf("x/mongo/driver/topology: error calculating %f percentile RTT: %w for samples:\n%v", perc, err, floatSamples)) } return time.Duration(p) } @@ -318,7 +318,7 @@ func (r *rttMonitor) Stats() string { var err error stdDev, err = stats.StandardDeviation(floatSamples) if err != nil { - panic(fmt.Errorf("x/mongo/driver/topology: error calculating standard deviation RTT: %v for samples:\n%v", err, floatSamples)) + panic(fmt.Errorf("x/mongo/driver/topology: error calculating standard deviation RTT: %w for samples:\n%v", err, floatSamples)) } } diff --git a/x/mongo/driver/topology/sdam_spec_test.go b/x/mongo/driver/topology/sdam_spec_test.go index e09ffb87f0..ee43a2bd7d 100644 --- a/x/mongo/driver/topology/sdam_spec_test.go +++ b/x/mongo/driver/topology/sdam_spec_test.go @@ -211,11 +211,11 @@ var lock sync.Mutex func (r *response) UnmarshalBSON(buf []byte) error { doc := bson.Raw(buf) if err := doc.Index(0).Value().Unmarshal(&r.Host); err != nil { - return fmt.Errorf("error unmarshalling Host: %v", err) + return fmt.Errorf("error unmarshalling Host: %w", err) } if err := doc.Index(1).Value().Unmarshal(&r.Hello); err != nil { - return fmt.Errorf("error unmarshalling Hello: %v", err) + return fmt.Errorf("error unmarshalling Hello: %w", err) } return nil diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 41f93a7df2..4ccd562093 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -416,8 +416,8 @@ func (s *Server) RequestImmediateCheck() { // (error, true) if the error is a WriteConcernError and the falls under the requirements for SDAM error // handling and (nil, false) otherwise. func getWriteConcernErrorForProcessing(err error) (*driver.WriteConcernError, bool) { - writeCmdErr, ok := err.(driver.WriteCommandError) - if !ok { + var writeCmdErr driver.WriteCommandError + if !errors.As(err, &writeCmdErr) { return nil, false } @@ -523,7 +523,8 @@ func (s *Server) ProcessError(err error, conn driver.Connection) driver.ProcessE } // Ignore transient timeout errors. - if netErr, ok := wrappedConnErr.(net.Error); ok && netErr.Timeout() { + var netErr net.Error + if errors.As(wrappedConnErr, &netErr) && netErr.Timeout() { return driver.NoChange } if errors.Is(wrappedConnErr, context.Canceled) || errors.Is(wrappedConnErr, context.DeadlineExceeded) { @@ -602,7 +603,7 @@ func (s *Server) update() { // Perform the next check. desc, err := s.check() - if err == errCheckCancelled { + if errors.Is(err, errCheckCancelled) { if atomic.LoadInt64(&s.state) != serverConnected { continue } @@ -629,7 +630,8 @@ func (s *Server) update() { // We want to immediately retry on timeout error. Continue to next loop. return true } - if err, ok := err.(net.Error); ok && err.Timeout() { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { timeoutCnt++ // We want to immediately retry on timeout error. Continue to next loop. return true From 6612270b1fd616ff664d811eae07683080b99091 Mon Sep 17 00:00:00 2001 From: Lokesh Kumar Date: Thu, 7 Dec 2023 23:02:43 +0100 Subject: [PATCH 2/2] revert errors.As to use type assertion --- bson/bsoncodec/registry_test.go | 6 ++---- examples/documentation_examples/examples.go | 7 ++----- internal/aws/awserr/types.go | 4 +--- mongo/collection.go | 7 +++---- mongo/database.go | 4 ++-- mongo/errors.go | 18 +++++------------- mongo/index_view.go | 3 +-- mongo/search_index_view.go | 4 +--- mongo/session.go | 3 +-- mongo/with_transactions_test.go | 3 +-- x/mongo/driver/batch_cursor.go | 3 +-- x/mongo/driver/connstring/connstring.go | 3 +-- x/mongo/driver/integration/main_test.go | 7 ++----- x/mongo/driver/operation.go | 6 ++---- x/mongo/driver/operation/count.go | 4 ++-- x/mongo/driver/topology/connection.go | 3 +-- x/mongo/driver/topology/server.go | 6 ++---- 17 files changed, 30 insertions(+), 61 deletions(-) diff --git a/bson/bsoncodec/registry_test.go b/bson/bsoncodec/registry_test.go index 39863ac874..03500dca44 100644 --- a/bson/bsoncodec/registry_test.go +++ b/bson/bsoncodec/registry_test.go @@ -352,8 +352,7 @@ func TestRegistryBuilder(t *testing.T) { }) t.Run("Decoder", func(t *testing.T) { wanterr := tc.wanterr - var ene ErrNoEncoder - if errors.As(tc.wanterr, &ene) { + if ene, ok := tc.wanterr.(ErrNoEncoder); ok { wanterr = ErrNoDecoder(ene) } @@ -777,8 +776,7 @@ func TestRegistry(t *testing.T) { t.Parallel() wanterr := tc.wanterr - var ene ErrNoEncoder - if errors.As(tc.wanterr, &ene) { + if ene, ok := tc.wanterr.(ErrNoEncoder); ok { wanterr = ErrNoDecoder(ene) } diff --git a/examples/documentation_examples/examples.go b/examples/documentation_examples/examples.go index c6bfd0faed..ca92646865 100644 --- a/examples/documentation_examples/examples.go +++ b/examples/documentation_examples/examples.go @@ -8,7 +8,6 @@ package documentation_examples import ( "context" - "errors" "fmt" "io/ioutil" logger "log" @@ -1817,8 +1816,7 @@ func RunTransactionWithRetry(sctx mongo.SessionContext, txnFn func(mongo.Session log.Println("Transaction aborted. Caught exception during transaction.") // If transient error, retry the whole transaction - var cmdErr mongo.CommandError - if errors.As(err, &cmdErr) && cmdErr.HasErrorLabel("TransientTransactionError") { + if cmdErr, ok := err.(mongo.CommandError); ok && cmdErr.HasErrorLabel("TransientTransactionError") { log.Println("TransientTransactionError, retrying transaction...") continue } @@ -1885,8 +1883,7 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error { log.Println("Transaction aborted. Caught exception during transaction.") // If transient error, retry the whole transaction - var cmdErr mongo.CommandError - if errors.As(err, &cmdErr) && cmdErr.HasErrorLabel("TransientTransactionError") { + if cmdErr, ok := err.(mongo.CommandError); ok && cmdErr.HasErrorLabel("TransientTransactionError") { log.Println("TransientTransactionError, retrying transaction...") continue } diff --git a/internal/aws/awserr/types.go b/internal/aws/awserr/types.go index b70168f7d3..18cb4cda28 100644 --- a/internal/aws/awserr/types.go +++ b/internal/aws/awserr/types.go @@ -11,7 +11,6 @@ package awserr import ( - "errors" "fmt" ) @@ -107,8 +106,7 @@ func (b baseError) OrigErr() error { case 1: return b.errs[0] default: - var err Error - if errors.As(b.errs[0], &err) { + if err, ok := b.errs[0].(Error); ok { return NewBatchError(err.Code(), err.Message(), b.errs[1:]) } return NewBatchError("BatchedErrors", diff --git a/mongo/collection.go b/mongo/collection.go index 74b55bb2cf..ac173307ff 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -929,8 +929,7 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { err = op.Execute(a.ctx) if err != nil { - var wce driver.WriteCommandError - if errors.As(err, &wce) && wce.WriteConcernError != nil { + if wce, ok := err.(driver.WriteCommandError); ok && wce.WriteConcernError != nil { return nil, *convertDriverWriteConcernError(wce.WriteConcernError) } return nil, replaceErrors(err) @@ -1869,8 +1868,8 @@ func (coll *Collection) drop(ctx context.Context) error { err = op.Execute(ctx) // ignore namespace not found errors - var driverErr driver.Error - if !errors.As(err, &driverErr) || !driverErr.NamespaceNotFound() { + driverErr, ok := err.(driver.Error) + if !ok || (ok && !driverErr.NamespaceNotFound()) { return replaceErrors(err) } return nil diff --git a/mongo/database.go b/mongo/database.go index 69f1e36bb4..c5cda9e5bd 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -312,8 +312,8 @@ func (db *Database) Drop(ctx context.Context) error { err = op.Execute(ctx) - var driverErr driver.Error - if err != nil && (!errors.As(err, &driverErr) || !driverErr.NamespaceNotFound()) { + driverErr, ok := err.(driver.Error) + if err != nil && (!ok || !driverErr.NamespaceNotFound()) { return replaceErrors(err) } return nil diff --git a/mongo/errors.go b/mongo/errors.go index 777746d5a0..d92c9ca9bd 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -55,9 +55,7 @@ func replaceErrors(err error) error { if errors.Is(err, topology.ErrTopologyClosed) { return ErrClientDisconnected } - - var de driver.Error - if errors.As(err, &de) { + if de, ok := err.(driver.Error); ok { return CommandError{ Code: de.Code, Message: de.Message, @@ -67,9 +65,7 @@ func replaceErrors(err error) error { Raw: bson.Raw(de.Raw), } } - - var qe driver.QueryFailureError - if errors.As(err, &qe) { + if qe, ok := err.(driver.QueryFailureError); ok { // qe.Message is "command failure" ce := CommandError{ Name: qe.Message, @@ -88,9 +84,7 @@ func replaceErrors(err error) error { return ce } - - var me mongocrypt.Error - if errors.As(err, &me) { + if me, ok := err.(mongocrypt.Error); ok { return MongocryptError{Code: me.Code, Message: me.Message} } @@ -98,8 +92,7 @@ func replaceErrors(err error) error { return ErrNilValue } - var marshalErr codecutil.MarshalError - if errors.As(err, &marshalErr) { + if marshalErr, ok := err.(codecutil.MarshalError); ok { return MarshalError{ Value: marshalErr.Value, Err: marshalErr.Err, @@ -178,8 +171,7 @@ func unwrap(err error) error { // errorHasLabel returns true if err contains the specified label func errorHasLabel(err error, label string) bool { for ; err != nil; err = unwrap(err) { - var le LabeledError - if errors.As(err, &le) && le.HasErrorLabel(label) { + if le, ok := err.(LabeledError); ok && le.HasErrorLabel(label) { return true } } diff --git a/mongo/index_view.go b/mongo/index_view.go index 84fe026a17..8d3555d0b0 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -53,8 +53,7 @@ type IndexModel struct { } func isNamespaceNotFoundError(err error) bool { - var de driver.Error - if errors.As(err, &de) { + if de, ok := err.(driver.Error); ok { return de.Code == 26 } return false diff --git a/mongo/search_index_view.go b/mongo/search_index_view.go index 56e1ffc3f3..6a7871531e 100644 --- a/mongo/search_index_view.go +++ b/mongo/search_index_view.go @@ -8,7 +8,6 @@ package mongo import ( "context" - "errors" "fmt" "strconv" @@ -215,8 +214,7 @@ func (siv SearchIndexView) DropOne( Timeout(siv.coll.client.timeout) err = op.Execute(ctx) - var de driver.Error - if errors.As(err, &de) && de.NamespaceNotFound() { + if de, ok := err.(driver.Error); ok && de.NamespaceNotFound() { return nil } return err diff --git a/mongo/session.go b/mongo/session.go index 5b5c4ceeb6..8f1e029b95 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -245,8 +245,7 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(ctx SessionCo default: } - var cerr CommandError - if errors.As(err, &cerr) { + if cerr, ok := err.(CommandError); ok { if cerr.HasErrorLabel(driver.UnknownTransactionCommitResult) && !cerr.IsMaxTimeMSExpiredError() { continue } diff --git a/mongo/with_transactions_test.go b/mongo/with_transactions_test.go index 4917fe0fed..9a387264f9 100644 --- a/mongo/with_transactions_test.go +++ b/mongo/with_transactions_test.go @@ -50,8 +50,7 @@ func TestConvenientTransactions(t *testing.T) { {"killAllSessions", bson.A{}}, }).Err() if err != nil { - var ce CommandError - if !errors.As(err, &ce) || ce.Code != errorInterrupted { + if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted { t.Fatalf("killAllSessions error: %v", err) } } diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index df676aa103..23b4a6539d 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -451,8 +451,7 @@ func (bc *BatchCursor) getMore(ctx context.Context) { // If we're in load balanced mode and the pinned connection encounters a network error, we should not use it for // future commands. Per the spec, the connection will not be unpinned until the cursor is actually closed, but // we set the cursor ID to 0 to ensure the Close() call will not execute a killCursors command. - var driverErr Error - if errors.As(bc.err, &driverErr) && driverErr.NetworkError() && bc.connection != nil { + if driverErr, ok := bc.err.(Error); ok && driverErr.NetworkError() && bc.connection != nil { bc.id = 0 } diff --git a/x/mongo/driver/connstring/connstring.go b/x/mongo/driver/connstring/connstring.go index b3dc97cf38..cd43136471 100644 --- a/x/mongo/driver/connstring/connstring.go +++ b/x/mongo/driver/connstring/connstring.go @@ -624,8 +624,7 @@ func (p *parser) addHost(host string) error { // this is unfortunate that SplitHostPort actually requires // a port to exist. if err != nil { - var addrError *net.AddrError - if !errors.As(err, &addrError) || addrError.Err != "missing port in address" { + if addrError, ok := err.(*net.AddrError); !ok || addrError.Err != "missing port in address" { return err } } diff --git a/x/mongo/driver/integration/main_test.go b/x/mongo/driver/integration/main_test.go index 563462d65f..f82b0175e2 100644 --- a/x/mongo/driver/integration/main_test.go +++ b/x/mongo/driver/integration/main_test.go @@ -8,7 +8,6 @@ package integration import ( "context" - "errors" "flag" "fmt" "os" @@ -70,8 +69,7 @@ func autherr(t *testing.T, err error) { t.Helper() switch e := err.(type) { case topology.ConnectionError: - var authErr *auth.Error - if !errors.As(e.Wrapped, &authErr) { + if _, ok := e.Wrapped.(*auth.Error); !ok { t.Fatal("Expected auth error and didn't get one") } case *auth.Error: @@ -135,8 +133,7 @@ func dropCollection(t *testing.T, dbname, colname string) { err := operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendStringElement(nil, "drop", colname))). Database(dbname).ServerSelector(description.WriteSelector()).Deployment(integtest.Topology(t)). Execute(context.Background()) - var de driver.Error - if err != nil && !(errors.As(err, &de) && de.NamespaceNotFound()) { + if de, ok := err.(driver.Error); err != nil && !(ok && de.NamespaceNotFound()) { require.NoError(t, err) } } diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 9064a7bd96..b39a63abe4 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -142,8 +142,7 @@ func convertInt64PtrToInt32Ptr(i64 *int64) *int32 { // write errors are included since the actual command did succeed, only writes // failed. func (info finishedInformation) success() bool { - var writeCmdErr WriteCommandError - if errors.As(info.cmdErr, &writeCmdErr) { + if _, ok := info.cmdErr.(WriteCommandError); ok { return true } @@ -627,8 +626,7 @@ func (op Operation) Execute(ctx context.Context) error { // If the returned error is retryable and there are retries remaining (negative // retries means retry indefinitely), then retry the operation. Set the server // and connection to nil to request a new server and connection. - var rerr RetryablePoolError - if errors.As(err, &rerr) && rerr.Retryable() && retries != 0 { + if rerr, ok := err.(RetryablePoolError); ok && rerr.Retryable() && retries != 0 { resetForRetry(err) continue } diff --git a/x/mongo/driver/operation/count.go b/x/mongo/driver/operation/count.go index 3de9b6b9ca..8de1e9f8d9 100644 --- a/x/mongo/driver/operation/count.go +++ b/x/mongo/driver/operation/count.go @@ -132,8 +132,8 @@ func (c *Count) Execute(ctx context.Context) error { // Swallow error if NamespaceNotFound(26) is returned from aggregate on non-existent namespace if err != nil { - var dErr driver.Error - if errors.As(err, &dErr) && dErr.Code == 26 { + dErr, ok := err.(driver.Error) + if ok && dErr.Code == 26 { err = nil } } diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 80fbdf08c4..88bfc03cdd 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -323,8 +323,7 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead if !contextDeadlineUsed { return originalError } - var netErr net.Error - if errors.As(originalError, &netErr) && netErr.Timeout() { + if netErr, ok := originalError.(net.Error); ok && netErr.Timeout() { return context.DeadlineExceeded } diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 4ccd562093..751d05de93 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -523,8 +523,7 @@ func (s *Server) ProcessError(err error, conn driver.Connection) driver.ProcessE } // Ignore transient timeout errors. - var netErr net.Error - if errors.As(wrappedConnErr, &netErr) && netErr.Timeout() { + if netErr, ok := wrappedConnErr.(net.Error); ok && netErr.Timeout() { return driver.NoChange } if errors.Is(wrappedConnErr, context.Canceled) || errors.Is(wrappedConnErr, context.DeadlineExceeded) { @@ -630,8 +629,7 @@ func (s *Server) update() { // We want to immediately retry on timeout error. Continue to next loop. return true } - var netErr net.Error - if errors.As(err, &netErr) && netErr.Timeout() { + if err, ok := err.(net.Error); ok && err.Timeout() { timeoutCnt++ // We want to immediately retry on timeout error. Continue to next loop. return true