From c61efde64860ef8cd2a53e977db6eb18d4bc93e7 Mon Sep 17 00:00:00 2001 From: Lokesh Kumar Date: Thu, 30 Nov 2023 02:12:16 +0100 Subject: [PATCH] GODRIVER-2603 (Contd.) Revised error handling using Go 1.13 error APIs (#1476) --- bson/bsoncodec/default_value_decoders.go | 2 +- bson/bsoncodec/default_value_encoders.go | 6 +++--- bson/bsoncodec/map_codec.go | 5 +++-- bson/bsoncodec/registry_test.go | 7 +++++-- bson/bsonrw/copier.go | 5 +++-- bson/bsonrw/extjson_parser.go | 2 +- bson/bsonrw/extjson_parser_test.go | 3 ++- bson/bsonrw/extjson_reader.go | 5 +++-- bson/bsonrw/json_scanner.go | 12 ++++++------ bson/bsonrw/value_reader_test.go | 7 ++++--- bson/decoder_test.go | 2 +- bson/primitive_codecs_test.go | 4 ++-- bson/raw_test.go | 7 ++++--- examples/documentation_examples/examples.go | 7 +++++-- internal/logger/logger.go | 2 +- mongo/bulk_write.go | 13 +++++++------ mongo/change_stream.go | 4 ++-- mongo/client.go | 2 +- mongo/client_encryption.go | 2 +- mongo/collection.go | 10 +++++----- mongo/integration/mtest/proxy_dialer.go | 6 +++--- mongo/integration/mtest/received_message.go | 4 ++-- mongo/integration/mtest/sent_message.go | 6 +++--- mongo/integration/mtest/setup.go | 12 ++++++------ x/bsonx/bsoncore/array_test.go | 2 +- x/bsonx/bsoncore/document_sequence_test.go | 7 ++++--- x/bsonx/bsoncore/document_test.go | 7 ++++--- x/mongo/driver/batch_cursor.go | 3 ++- x/mongo/driver/ocsp/cache_test.go | 4 ++-- x/mongo/driver/ocsp/config.go | 2 +- x/mongo/driver/ocsp/ocsp.go | 6 +++--- x/mongo/driver/operation.go | 5 +++-- x/mongo/driver/topology/errors.go | 7 ++++--- 33 files changed, 98 insertions(+), 80 deletions(-) diff --git a/bson/bsoncodec/default_value_decoders.go b/bson/bsoncodec/default_value_decoders.go index 4ba1b61019..7e08aab35e 100644 --- a/bson/bsoncodec/default_value_decoders.go +++ b/bson/bsoncodec/default_value_decoders.go @@ -1787,7 +1787,7 @@ func (DefaultValueDecoders) decodeElemsFromDocumentReader(dc DecodeContext, dr b elems := make([]reflect.Value, 0) for { key, vr, err := dr.ReadElement() - if err == bsonrw.ErrEOD { + if errors.Is(err, bsonrw.ErrEOD) { break } if err != nil { diff --git a/bson/bsoncodec/default_value_encoders.go b/bson/bsoncodec/default_value_encoders.go index 91a48c3a5b..4751ae995e 100644 --- a/bson/bsoncodec/default_value_encoders.go +++ b/bson/bsoncodec/default_value_encoders.go @@ -352,7 +352,7 @@ func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw bsonrw.Docum return err } - if lookupErr == errInvalidValue { + if errors.Is(lookupErr, errInvalidValue) { err = vw.WriteNull() if err != nil { return err @@ -427,7 +427,7 @@ func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw bsonrw.Val return err } - if lookupErr == errInvalidValue { + if errors.Is(lookupErr, errInvalidValue) { err = vw.WriteNull() if err != nil { return err @@ -496,7 +496,7 @@ func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw bsonrw.Val return err } - if lookupErr == errInvalidValue { + if errors.Is(lookupErr, errInvalidValue) { err = vw.WriteNull() if err != nil { return err diff --git a/bson/bsoncodec/map_codec.go b/bson/bsoncodec/map_codec.go index 6a5292f2c0..868e39ccc0 100644 --- a/bson/bsoncodec/map_codec.go +++ b/bson/bsoncodec/map_codec.go @@ -8,6 +8,7 @@ package bsoncodec import ( "encoding" + "errors" "fmt" "reflect" "strconv" @@ -137,7 +138,7 @@ func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, v return err } - if lookupErr == errInvalidValue { + if errors.Is(lookupErr, errInvalidValue) { err = vw.WriteNull() if err != nil { return err @@ -200,7 +201,7 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref for { key, vr, err := dr.ReadElement() - if err == bsonrw.ErrEOD { + if errors.Is(err, bsonrw.ErrEOD) { break } if err != nil { diff --git a/bson/bsoncodec/registry_test.go b/bson/bsoncodec/registry_test.go index d09f32be5e..acc24a6e4d 100644 --- a/bson/bsoncodec/registry_test.go +++ b/bson/bsoncodec/registry_test.go @@ -7,6 +7,7 @@ package bsoncodec import ( + "errors" "reflect" "testing" @@ -351,7 +352,8 @@ func TestRegistryBuilder(t *testing.T) { }) t.Run("Decoder", func(t *testing.T) { wanterr := tc.wanterr - if ene, ok := tc.wanterr.(ErrNoEncoder); ok { + var ene ErrNoEncoder + if errors.As(tc.wanterr, &ene) { wanterr = ErrNoDecoder(ene) } @@ -775,7 +777,8 @@ func TestRegistry(t *testing.T) { t.Parallel() wanterr := tc.wanterr - if ene, ok := tc.wanterr.(ErrNoEncoder); ok { + var ene ErrNoEncoder + if errors.As(tc.wanterr, &ene) { wanterr = ErrNoDecoder(ene) } diff --git a/bson/bsonrw/copier.go b/bson/bsonrw/copier.go index 4d279b7fee..1e25570b85 100644 --- a/bson/bsonrw/copier.go +++ b/bson/bsonrw/copier.go @@ -7,6 +7,7 @@ package bsonrw import ( + "errors" "fmt" "io" @@ -442,7 +443,7 @@ func (c Copier) copyArray(dst ValueWriter, src ValueReader) error { for { vr, err := ar.ReadValue() - if err == ErrEOA { + if errors.Is(err, ErrEOA) { break } if err != nil { @@ -466,7 +467,7 @@ func (c Copier) copyArray(dst ValueWriter, src ValueReader) error { func (c Copier) copyDocumentCore(dw DocumentWriter, dr DocumentReader) error { for { key, vr, err := dr.ReadElement() - if err == ErrEOD { + if errors.Is(err, ErrEOD) { break } if err != nil { diff --git a/bson/bsonrw/extjson_parser.go b/bson/bsonrw/extjson_parser.go index 54c76bf746..bb52a0ec3d 100644 --- a/bson/bsonrw/extjson_parser.go +++ b/bson/bsonrw/extjson_parser.go @@ -313,7 +313,7 @@ func (ejp *extJSONParser) readValue(t bsontype.Type) (*extJSONValue, error) { // convert hex to bytes bytes, err := hex.DecodeString(uuidNoHyphens) if err != nil { - return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %v", err) + return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %w", err) } ejp.advanceState() diff --git a/bson/bsonrw/extjson_parser_test.go b/bson/bsonrw/extjson_parser_test.go index 6808b14174..5da5326688 100644 --- a/bson/bsonrw/extjson_parser_test.go +++ b/bson/bsonrw/extjson_parser_test.go @@ -7,6 +7,7 @@ package bsonrw import ( + "errors" "io" "strings" "testing" @@ -47,7 +48,7 @@ type readKeyValueTestCase struct { func expectSpecificError(expected error) expectedErrorFunc { return func(t *testing.T, err error, desc string) { - if err != expected { + if !errors.Is(err, expected) { t.Helper() t.Errorf("%s: Expected %v but got: %v", desc, expected, err) t.FailNow() diff --git a/bson/bsonrw/extjson_reader.go b/bson/bsonrw/extjson_reader.go index 2aca37a91f..59ddfc4485 100644 --- a/bson/bsonrw/extjson_reader.go +++ b/bson/bsonrw/extjson_reader.go @@ -7,6 +7,7 @@ package bsonrw import ( + "errors" "fmt" "io" "sync" @@ -613,7 +614,7 @@ func (ejvr *extJSONValueReader) ReadElement() (string, ValueReader, error) { name, t, err := ejvr.p.readKey() if err != nil { - if err == ErrEOD { + if errors.Is(err, ErrEOD) { if ejvr.stack[ejvr.frame].mode == mCodeWithScope { _, err := ejvr.p.peekType() if err != nil { @@ -640,7 +641,7 @@ func (ejvr *extJSONValueReader) ReadValue() (ValueReader, error) { t, err := ejvr.p.peekType() if err != nil { - if err == ErrEOA { + if errors.Is(err, ErrEOA) { ejvr.pop() } diff --git a/bson/bsonrw/json_scanner.go b/bson/bsonrw/json_scanner.go index cd4843a3a4..65a812ac18 100644 --- a/bson/bsonrw/json_scanner.go +++ b/bson/bsonrw/json_scanner.go @@ -58,7 +58,7 @@ func (js *jsonScanner) nextToken() (*jsonToken, error) { c, err = js.readNextByte() } - if err == io.EOF { + if errors.Is(err, io.EOF) { return &jsonToken{t: jttEOF}, nil } else if err != nil { return nil, err @@ -198,7 +198,7 @@ func (js *jsonScanner) scanString() (*jsonToken, error) { for { c, err = js.readNextByte() if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil, errors.New("end of input in JSON string") } return nil, err @@ -209,7 +209,7 @@ func (js *jsonScanner) scanString() (*jsonToken, error) { case '\\': c, err = js.readNextByte() if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil, errors.New("end of input in JSON string") } return nil, err @@ -248,7 +248,7 @@ func (js *jsonScanner) scanString() (*jsonToken, error) { if utf16.IsSurrogate(rn) { c, err = js.readNextByte() if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil, errors.New("end of input in JSON string") } return nil, err @@ -264,7 +264,7 @@ func (js *jsonScanner) scanString() (*jsonToken, error) { c, err = js.readNextByte() if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil, errors.New("end of input in JSON string") } return nil, err @@ -384,7 +384,7 @@ func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) { for { c, err = js.readNextByte() - if err != nil && err != io.EOF { + if err != nil && !errors.Is(err, io.EOF) { return nil, err } diff --git a/bson/bsonrw/value_reader_test.go b/bson/bsonrw/value_reader_test.go index 1716eb54c4..11b257277e 100644 --- a/bson/bsonrw/value_reader_test.go +++ b/bson/bsonrw/value_reader_test.go @@ -8,6 +8,7 @@ package bsonrw import ( "bytes" + "errors" "fmt" "io" "math" @@ -185,7 +186,7 @@ func TestValueReader(t *testing.T) { // invalid length vr.d = []byte{0x00, 0x00} _, err := vr.ReadDocument() - if err != io.EOF { + if !errors.Is(err, io.EOF) { t.Errorf("Expected io.EOF with document length too small. got %v; want %v", err, io.EOF) } @@ -239,7 +240,7 @@ func TestValueReader(t *testing.T) { vr.frame-- _, err = vr.ReadDocument() - if err != io.EOF { + if !errors.Is(err, io.EOF) { t.Errorf("Should return error when attempting to read length with not enough bytes. got %v; want %v", err, io.EOF) } }) @@ -1482,7 +1483,7 @@ func TestValueReader(t *testing.T) { frame: 0, } gotType, got, gotErr := vr.ReadValueBytes(nil) - if gotErr != tc.wantErr { + if !errors.Is(gotErr, tc.wantErr) { t.Errorf("Did not receive expected error. got %v; want %v", gotErr, tc.wantErr) } if tc.wantErr == nil && gotType != tc.wantType { diff --git a/bson/decoder_test.go b/bson/decoder_test.go index c91f4e0491..c4476dddab 100644 --- a/bson/decoder_test.go +++ b/bson/decoder_test.go @@ -279,7 +279,7 @@ func TestDecoderv2(t *testing.T) { var got *D err = dec.Decode(got) - if err != ErrDecodeToNil { + if !errors.Is(err, ErrDecodeToNil) { t.Fatalf("Decode error mismatch; expected %v, got %v", ErrDecodeToNil, err) } }) diff --git a/bson/primitive_codecs_test.go b/bson/primitive_codecs_test.go index 466f135e83..35e7ba9a91 100644 --- a/bson/primitive_codecs_test.go +++ b/bson/primitive_codecs_test.go @@ -28,7 +28,7 @@ import ( func bytesFromDoc(doc interface{}) []byte { b, err := Marshal(doc) if err != nil { - panic(fmt.Errorf("Couldn't marshal BSON document: %v", err)) + panic(fmt.Errorf("Couldn't marshal BSON document: %w", err)) } return b } @@ -471,7 +471,7 @@ func TestDefaultValueEncoders(t *testing.T) { enc, err := NewEncoder(vw) noerr(t, err) err = enc.Encode(tc.value) - if err != tc.err { + if !errors.Is(err, tc.err) { t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) } if diff := cmp.Diff([]byte(b), tc.b); diff != "" { diff --git a/bson/raw_test.go b/bson/raw_test.go index 02c9f63136..644a2eea16 100644 --- a/bson/raw_test.go +++ b/bson/raw_test.go @@ -9,6 +9,7 @@ package bson import ( "bytes" "encoding/binary" + "errors" "fmt" "io" "strings" @@ -52,7 +53,7 @@ func TestRaw(t *testing.T) { r := make(Raw, 5) binary.LittleEndian.PutUint32(r[0:4], 200) got := r.Validate() - if got != want { + if !errors.Is(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } }) @@ -62,7 +63,7 @@ func TestRaw(t *testing.T) { binary.LittleEndian.PutUint32(r[0:4], 8) r[4], r[5], r[6], r[7] = '\x02', 'f', 'o', 'o' got := r.Validate() - if got != want { + if !errors.Is(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } }) @@ -72,7 +73,7 @@ func TestRaw(t *testing.T) { binary.LittleEndian.PutUint32(r[0:4], 9) r[4], r[5], r[6], r[7], r[8] = '\x0A', 'f', 'o', 'o', '\x00' got := r.Validate() - if got != want { + if !errors.Is(got, want) { t.Errorf("Did not get expected error. got %v; want %v", got, want) } }) diff --git a/examples/documentation_examples/examples.go b/examples/documentation_examples/examples.go index ca92646865..c6bfd0faed 100644 --- a/examples/documentation_examples/examples.go +++ b/examples/documentation_examples/examples.go @@ -8,6 +8,7 @@ package documentation_examples import ( "context" + "errors" "fmt" "io/ioutil" logger "log" @@ -1816,7 +1817,8 @@ func RunTransactionWithRetry(sctx mongo.SessionContext, txnFn func(mongo.Session log.Println("Transaction aborted. Caught exception during transaction.") // If transient error, retry the whole transaction - if cmdErr, ok := err.(mongo.CommandError); ok && cmdErr.HasErrorLabel("TransientTransactionError") { + var cmdErr mongo.CommandError + if errors.As(err, &cmdErr) && cmdErr.HasErrorLabel("TransientTransactionError") { log.Println("TransientTransactionError, retrying transaction...") continue } @@ -1883,7 +1885,8 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error { log.Println("Transaction aborted. Caught exception during transaction.") // If transient error, retry the whole transaction - if cmdErr, ok := err.(mongo.CommandError); ok && cmdErr.HasErrorLabel("TransientTransactionError") { + var cmdErr mongo.CommandError + if errors.As(err, &cmdErr) && cmdErr.HasErrorLabel("TransientTransactionError") { log.Println("TransientTransactionError, retrying transaction...") continue } diff --git a/internal/logger/logger.go b/internal/logger/logger.go index 03d42814f4..2250286e4a 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -183,7 +183,7 @@ func selectLogSink(sink LogSink) (LogSink, *os.File, error) { if path != "" { logFile, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0666) if err != nil { - return nil, nil, fmt.Errorf("unable to open log file: %v", err) + return nil, nil, fmt.Errorf("unable to open log file: %w", err) } return NewIOSink(logFile), logFile, nil diff --git a/mongo/bulk_write.go b/mongo/bulk_write.go index 42d286ea7d..a7efd551e7 100644 --- a/mongo/bulk_write.go +++ b/mongo/bulk_write.go @@ -8,6 +8,7 @@ package mongo import ( "context" + "errors" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/primitive" @@ -108,8 +109,8 @@ func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWr case *InsertOneModel: res, err := bw.runInsert(ctx, batch) if err != nil { - writeErr, ok := err.(driver.WriteCommandError) - if !ok { + var writeErr driver.WriteCommandError + if !errors.As(err, &writeErr) { return BulkWriteResult{}, batchErr, err } writeErrors = writeErr.WriteErrors @@ -120,8 +121,8 @@ func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWr case *DeleteOneModel, *DeleteManyModel: res, err := bw.runDelete(ctx, batch) if err != nil { - writeErr, ok := err.(driver.WriteCommandError) - if !ok { + var writeErr driver.WriteCommandError + if !errors.As(err, &writeErr) { return BulkWriteResult{}, batchErr, err } writeErrors = writeErr.WriteErrors @@ -132,8 +133,8 @@ func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWr case *ReplaceOneModel, *UpdateOneModel, *UpdateManyModel: res, err := bw.runUpdate(ctx, batch) if err != nil { - writeErr, ok := err.(driver.WriteCommandError) - if !ok { + var writeErr driver.WriteCommandError + if !errors.As(err, &writeErr) { return BulkWriteResult{}, batchErr, err } writeErrors = writeErr.WriteErrors diff --git a/mongo/change_stream.go b/mongo/change_stream.go index 773cbb0e5d..c4c2fb2590 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -689,8 +689,8 @@ func (cs *ChangeStream) loopNext(ctx context.Context, nonBlocking bool) { } func (cs *ChangeStream) isResumableError() bool { - commandErr, ok := cs.err.(CommandError) - if !ok || commandErr.HasErrorLabel(networkErrorLabel) { + var commandErr CommandError + if !errors.As(cs.err, &commandErr) || commandErr.HasErrorLabel(networkErrorLabel) { // All non-server errors or network errors are resumable. return true } diff --git a/mongo/client.go b/mongo/client.go index 5929274831..280749c7dd 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -555,7 +555,7 @@ func (c *Client) newMongoCrypt(opts *options.AutoEncryptionOptions) (*mongocrypt kmsProviders, err := marshal(opts.KmsProviders, c.bsonOpts, c.registry) if err != nil { - return nil, fmt.Errorf("error creating KMS providers document: %v", err) + return nil, fmt.Errorf("error creating KMS providers document: %w", err) } // Set the crypt_shared library override path from the "cryptSharedLibPath" extra option if one diff --git a/mongo/client_encryption.go b/mongo/client_encryption.go index 01c2ec3193..b51f57b473 100644 --- a/mongo/client_encryption.go +++ b/mongo/client_encryption.go @@ -46,7 +46,7 @@ func NewClientEncryption(keyVaultClient *Client, opts ...*options.ClientEncrypti kmsProviders, err := marshal(ceo.KmsProviders, nil, nil) if err != nil { - return nil, fmt.Errorf("error creating KMS providers map: %v", err) + return nil, fmt.Errorf("error creating KMS providers map: %w", err) } mc, err := mongocrypt.NewMongoCrypt(mcopts.MongoCrypt(). diff --git a/mongo/collection.go b/mongo/collection.go index fcbfcc77a1..ac173307ff 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -313,8 +313,8 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{}, op = op.Retry(retry) err = op.Execute(ctx) - wce, ok := err.(driver.WriteCommandError) - if !ok { + var wce driver.WriteCommandError + if !errors.As(err, &wce) { return result, err } @@ -388,8 +388,8 @@ func (coll *Collection) InsertMany(ctx context.Context, documents []interface{}, } imResult := &InsertManyResult{InsertedIDs: result} - writeException, ok := err.(WriteException) - if !ok { + var writeException WriteException + if !errors.As(err, &writeException) { return imResult, err } @@ -1806,7 +1806,7 @@ func (coll *Collection) Drop(ctx context.Context) error { func (coll *Collection) dropEncryptedCollection(ctx context.Context, ef interface{}) error { efBSON, err := marshal(ef, coll.bsonOpts, coll.registry) if err != nil { - return fmt.Errorf("error transforming document: %v", err) + return fmt.Errorf("error transforming document: %w", err) } // Drop the two encryption-related, associated collections: `escCollection` and `ecocCollection`. diff --git a/mongo/integration/mtest/proxy_dialer.go b/mongo/integration/mtest/proxy_dialer.go index b50f37488a..c8e9e6d456 100644 --- a/mongo/integration/mtest/proxy_dialer.go +++ b/mongo/integration/mtest/proxy_dialer.go @@ -51,7 +51,7 @@ func newProxyDialer() *proxyDialer { } func newProxyErrorWithWireMsg(wm []byte, err error) error { - return fmt.Errorf("proxy error for wiremessage %v: %v", wm, err) + return fmt.Errorf("proxy error for wiremessage %v: %w", wm, err) } // DialContext creates a new proxyConnection. @@ -149,7 +149,7 @@ type proxyConn struct { // server. func (pc *proxyConn) Write(wm []byte) (n int, err error) { if err := pc.dialer.storeSentMessage(wm); err != nil { - wrapped := fmt.Errorf("error storing sent message: %v", err) + wrapped := fmt.Errorf("error storing sent message: %w", err) return 0, newProxyErrorWithWireMsg(wm, wrapped) } @@ -178,7 +178,7 @@ func (pc *proxyConn) Read(buffer []byte) (int, error) { wm = bsoncore.UpdateLength(wm, idx, int32(len(wm[idx:]))) if err := pc.dialer.storeReceivedMessage(wm, pc.RemoteAddr().String()); err != nil { - wrapped := fmt.Errorf("error storing received message: %v", err) + wrapped := fmt.Errorf("error storing received message: %w", err) return 0, newProxyErrorWithWireMsg(wm, wrapped) } diff --git a/mongo/integration/mtest/received_message.go b/mongo/integration/mtest/received_message.go index 3df507e5a5..2e2f952242 100644 --- a/mongo/integration/mtest/received_message.go +++ b/mongo/integration/mtest/received_message.go @@ -49,7 +49,7 @@ func parseReceivedMessage(wm []byte) (*ReceivedMessage, error) { } received, err := parseFn(remaining) if err != nil { - return nil, fmt.Errorf("error parsing wiremessage with opcode %s: %v", opcode, err) + return nil, fmt.Errorf("error parsing wiremessage with opcode %s: %w", opcode, err) } received.ResponseTo = responseTo @@ -97,7 +97,7 @@ func parseReceivedOpMsg(wm []byte) (*ReceivedMessage, error) { } if wm, err = assertMsgSectionType(wm, wiremessage.SingleDocument); err != nil { - return nil, fmt.Errorf("error verifying section type for response document: %v", err) + return nil, fmt.Errorf("error verifying section type for response document: %w", err) } response, wm, ok := wiremessage.ReadMsgSectionSingleDocument(wm) diff --git a/mongo/integration/mtest/sent_message.go b/mongo/integration/mtest/sent_message.go index 6b96e061bc..94eed12257 100644 --- a/mongo/integration/mtest/sent_message.go +++ b/mongo/integration/mtest/sent_message.go @@ -124,7 +124,7 @@ func parseSentMessage(wm []byte) (*SentMessage, error) { } sent, err := parseFn(remaining) if err != nil { - return nil, fmt.Errorf("error parsing wiremessage with opcode %s: %v", opcode, err) + return nil, fmt.Errorf("error parsing wiremessage with opcode %s: %w", opcode, err) } sent.RequestID = requestID @@ -142,7 +142,7 @@ func parseSentOpMsg(wm []byte) (*SentMessage, error) { } if wm, err = assertMsgSectionType(wm, wiremessage.SingleDocument); err != nil { - return nil, fmt.Errorf("error verifying section type for command document: %v", err) + return nil, fmt.Errorf("error verifying section type for command document: %w", err) } var commandDoc bsoncore.Document @@ -160,7 +160,7 @@ func parseSentOpMsg(wm []byte) (*SentMessage, error) { if len(wm) != 0 { // If there are bytes remaining in the wire message, they must correspond to a DocumentSequence section. if wm, err = assertMsgSectionType(wm, wiremessage.DocumentSequence); err != nil { - return nil, fmt.Errorf("error verifying section type for document sequence: %v", err) + return nil, fmt.Errorf("error verifying section type for document sequence: %w", err) } var data []byte diff --git a/mongo/integration/mtest/setup.go b/mongo/integration/mtest/setup.go index 303b7afdc2..49aacfd194 100644 --- a/mongo/integration/mtest/setup.go +++ b/mongo/integration/mtest/setup.go @@ -83,13 +83,13 @@ func Setup(setupOpts ...*SetupOptions) error { var err error uri, err = integtest.MongoDBURI() if err != nil { - return fmt.Errorf("error getting uri: %v", err) + return fmt.Errorf("error getting uri: %w", err) } } testContext.connString, err = connstring.ParseAndValidate(uri) if err != nil { - return fmt.Errorf("error parsing and validating connstring: %v", err) + return fmt.Errorf("error parsing and validating connstring: %w", err) } testContext.dataLake = os.Getenv("ATLAS_DATA_LAKE_INTEGRATION_TEST") == "true" @@ -100,20 +100,20 @@ func Setup(setupOpts ...*SetupOptions) error { cfg, err := topology.NewConfig(clientOpts, nil) if err != nil { - return fmt.Errorf("error constructing topology config: %v", err) + return fmt.Errorf("error constructing topology config: %w", err) } testContext.topo, err = topology.New(cfg) if err != nil { - return fmt.Errorf("error creating topology: %v", err) + return fmt.Errorf("error creating topology: %w", err) } if err = testContext.topo.Connect(); err != nil { - return fmt.Errorf("error connecting topology: %v", err) + return fmt.Errorf("error connecting topology: %w", err) } testContext.client, err = setupClient(options.Client().ApplyURI(uri)) if err != nil { - return fmt.Errorf("error connecting test client: %v", err) + return fmt.Errorf("error connecting test client: %w", err) } pingCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) diff --git a/x/bsonx/bsoncore/array_test.go b/x/bsonx/bsoncore/array_test.go index 8249d82808..4171adade7 100644 --- a/x/bsonx/bsoncore/array_test.go +++ b/x/bsonx/bsoncore/array_test.go @@ -115,7 +115,7 @@ func TestArray(t *testing.T) { t.Run("Out of bounds", func(t *testing.T) { rdr := Array{0xe, 0x0, 0x0, 0x0, 0xa, '0', 0x0, 0xa, '1', 0x0, 0xa, 0x7a, 0x0, 0x0} _, err := rdr.IndexErr(3) - if err != ErrOutOfBounds { + if !errors.Is(err, ErrOutOfBounds) { t.Errorf("Out of bounds should be returned when accessing element beyond end of Array. got %v; want %v", err, ErrOutOfBounds) } }) diff --git a/x/bsonx/bsoncore/document_sequence_test.go b/x/bsonx/bsoncore/document_sequence_test.go index c9a395d4f2..bf40fa878d 100644 --- a/x/bsonx/bsoncore/document_sequence_test.go +++ b/x/bsonx/bsoncore/document_sequence_test.go @@ -8,6 +8,7 @@ package bsoncore import ( "bytes" + "errors" "io" "strconv" "testing" @@ -113,7 +114,7 @@ func TestDocumentSequence(t *testing.T) { if !cmp.Equal(documents, tc.documents) { t.Errorf("Documents do not match. got %v; want %v", documents, tc.documents) } - if err != tc.err { + if !errors.Is(err, tc.err) { t.Errorf("Errors do not match. got %v; want %v", err, tc.err) } }) @@ -224,7 +225,7 @@ func TestDocumentSequence(t *testing.T) { if !bytes.Equal(document, tc.document) { t.Errorf("Documents do not match. got %v; want %v", document, tc.document) } - if err != tc.err { + if !errors.Is(err, tc.err) { t.Errorf("Errors do not match. got %v; want %v", err, tc.err) } }) @@ -275,7 +276,7 @@ func TestDocumentSequence(t *testing.T) { var docs []Document for { doc, err := ds.Next() - if err == io.EOF { + if errors.Is(err, io.EOF) { break } if err != nil { diff --git a/x/bsonx/bsoncore/document_test.go b/x/bsonx/bsoncore/document_test.go index 0d77b79d30..a5609e689e 100644 --- a/x/bsonx/bsoncore/document_test.go +++ b/x/bsonx/bsoncore/document_test.go @@ -9,6 +9,7 @@ package bsoncore import ( "bytes" "encoding/binary" + "errors" "fmt" "io" "testing" @@ -113,7 +114,7 @@ func TestDocument(t *testing.T) { t.Run("empty-key", func(t *testing.T) { rdr := Document{'\x05', '\x00', '\x00', '\x00', '\x00'} _, err := rdr.LookupErr() - if err != ErrEmptyKey { + if !errors.Is(err, ErrEmptyKey) { t.Errorf("Empty key lookup did not return expected result. got %v; want %v", err, ErrEmptyKey) } }) @@ -206,7 +207,7 @@ func TestDocument(t *testing.T) { }) t.Run("LookupErr", func(t *testing.T) { got, err := tc.r.LookupErr(tc.key...) - if err != tc.err { + if !errors.Is(err, tc.err) { t.Errorf("Returned error does not match. got %v; want %v", err, tc.err) } if !cmp.Equal(got, tc.want) { @@ -220,7 +221,7 @@ func TestDocument(t *testing.T) { t.Run("Out of bounds", func(t *testing.T) { rdr := Document{0xe, 0x0, 0x0, 0x0, 0xa, 0x78, 0x0, 0xa, 0x79, 0x0, 0xa, 0x7a, 0x0, 0x0} _, err := rdr.IndexErr(3) - if err != ErrOutOfBounds { + if !errors.Is(err, ErrOutOfBounds) { t.Errorf("Out of bounds should be returned when accessing element beyond end of document. got %v; want %v", err, ErrOutOfBounds) } }) diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index 7d3703f7be..827e536137 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -451,7 +451,8 @@ func (bc *BatchCursor) getMore(ctx context.Context) { // If we're in load balanced mode and the pinned connection encounters a network error, we should not use it for // future commands. Per the spec, the connection will not be unpinned until the cursor is actually closed, but // we set the cursor ID to 0 to ensure the Close() call will not execute a killCursors command. - if driverErr, ok := bc.err.(Error); ok && driverErr.NetworkError() && bc.connection != nil { + var driverErr Error + if errors.As(bc.err, &driverErr) && driverErr.NetworkError() && bc.connection != nil { bc.id = 0 } diff --git a/x/mongo/driver/ocsp/cache_test.go b/x/mongo/driver/ocsp/cache_test.go index 8558191f15..047b749969 100644 --- a/x/mongo/driver/ocsp/cache_test.go +++ b/x/mongo/driver/ocsp/cache_test.go @@ -34,8 +34,8 @@ func TestCache(t *testing.T) { err := Verify(ctx, tls.ConnectionState{}, &VerifyOptions{}) assert.NotNil(t, err, "expected error, got nil") - ocspErr, ok := err.(*Error) - assert.True(t, ok, "expected error of type %T, got %v of type %T", &Error{}, err, err) + var ocspErr *Error + assert.True(t, errors.As(err, &ocspErr), "expected error of type %T, got %v of type %T", &Error{}, err, err) expected := &Error{ wrapped: errors.New("no OCSP cache provided"), } diff --git a/x/mongo/driver/ocsp/config.go b/x/mongo/driver/ocsp/config.go index 94a5dd775f..5b720cd590 100644 --- a/x/mongo/driver/ocsp/config.go +++ b/x/mongo/driver/ocsp/config.go @@ -61,7 +61,7 @@ func newConfig(certChain []*x509.Certificate, opts *VerifyOptions) (config, erro } cfg.ocspRequest, err = ocsp.ParseRequest(cfg.ocspRequestBytes) if err != nil { - return cfg, fmt.Errorf("error parsing OCSP request bytes: %v", err) + return cfg, fmt.Errorf("error parsing OCSP request bytes: %w", err) } return cfg, nil diff --git a/x/mongo/driver/ocsp/ocsp.go b/x/mongo/driver/ocsp/ocsp.go index 849530fde9..8700728729 100644 --- a/x/mongo/driver/ocsp/ocsp.go +++ b/x/mongo/driver/ocsp/ocsp.go @@ -161,10 +161,10 @@ func processStaple(cfg config, staple []byte) (*ResponseDetails, error) { // If the stapled response could not be parsed correctly, error. This can happen if the response is malformed, // the response does not cover the certificate presented by the server, or if the response contains an error // status. - return nil, fmt.Errorf("error parsing stapled response: %v", err) + return nil, fmt.Errorf("error parsing stapled response: %w", err) } if err = verifyResponse(cfg, parsedResponse); err != nil { - return nil, fmt.Errorf("error validating stapled response: %v", err) + return nil, fmt.Errorf("error validating stapled response: %w", err) } return extractResponseDetails(parsedResponse), nil @@ -192,7 +192,7 @@ func isMustStapleCertificate(cert *x509.Certificate) (bool, error) { // Use []*big.Int to ensure that all values in the sequence can be successfully unmarshalled. var featureValues []*big.Int if _, err := asn1.Unmarshal(featureExtension.Value, &featureValues); err != nil { - return false, fmt.Errorf("error unmarshalling TLS feature extension values: %v", err) + return false, fmt.Errorf("error unmarshalling TLS feature extension values: %w", err) } for _, value := range featureValues { diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 905c9cfc55..33ed562426 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -142,7 +142,8 @@ func convertInt64PtrToInt32Ptr(i64 *int64) *int32 { // write errors are included since the actual command did succeed, only writes // failed. func (info finishedInformation) success() bool { - if _, ok := info.cmdErr.(WriteCommandError); ok { + var writeCmdErr WriteCommandError + if errors.As(info.cmdErr, &writeCmdErr) { return true } @@ -1492,7 +1493,7 @@ func (op Operation) addWriteConcern(dst []byte, desc description.SelectedServer) } t, data, err := wc.MarshalBSONValue() - if err == writeconcern.ErrEmptyWriteConcern { + if errors.Is(err, writeconcern.ErrEmptyWriteConcern) { return dst, nil } if err != nil { diff --git a/x/mongo/driver/topology/errors.go b/x/mongo/driver/topology/errors.go index 7ce41864e6..a6630aae76 100644 --- a/x/mongo/driver/topology/errors.go +++ b/x/mongo/driver/topology/errors.go @@ -8,6 +8,7 @@ package topology import ( "context" + "errors" "fmt" "time" @@ -86,9 +87,9 @@ type pinnedConnections struct { // Error implements the error interface. func (w WaitQueueTimeoutError) Error() string { errorMsg := "timed out while checking out a connection from connection pool" - switch w.Wrapped { - case nil: - case context.Canceled: + switch { + case w.Wrapped == nil: + case errors.Is(w.Wrapped, context.Canceled): errorMsg = fmt.Sprintf( "%s: %s", "canceled while checking out a connection from connection pool",