Skip to content

Commit

Permalink
GODRIVER-2603 (Contd.) Revised error handling using Go 1.13 error APIs (
Browse files Browse the repository at this point in the history
  • Loading branch information
kumarlokesh authored Nov 30, 2023
1 parent 8705829 commit c61efde
Show file tree
Hide file tree
Showing 33 changed files with 98 additions and 80 deletions.
2 changes: 1 addition & 1 deletion bson/bsoncodec/default_value_decoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -1787,7 +1787,7 @@ func (DefaultValueDecoders) decodeElemsFromDocumentReader(dc DecodeContext, dr b
elems := make([]reflect.Value, 0)
for {
key, vr, err := dr.ReadElement()
if err == bsonrw.ErrEOD {
if errors.Is(err, bsonrw.ErrEOD) {
break
}
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions bson/bsoncodec/default_value_encoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw bsonrw.Docum
return err
}

if lookupErr == errInvalidValue {
if errors.Is(lookupErr, errInvalidValue) {
err = vw.WriteNull()
if err != nil {
return err
Expand Down Expand Up @@ -427,7 +427,7 @@ func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw bsonrw.Val
return err
}

if lookupErr == errInvalidValue {
if errors.Is(lookupErr, errInvalidValue) {
err = vw.WriteNull()
if err != nil {
return err
Expand Down Expand Up @@ -496,7 +496,7 @@ func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw bsonrw.Val
return err
}

if lookupErr == errInvalidValue {
if errors.Is(lookupErr, errInvalidValue) {
err = vw.WriteNull()
if err != nil {
return err
Expand Down
5 changes: 3 additions & 2 deletions bson/bsoncodec/map_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package bsoncodec

import (
"encoding"
"errors"
"fmt"
"reflect"
"strconv"
Expand Down Expand Up @@ -137,7 +138,7 @@ func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, v
return err
}

if lookupErr == errInvalidValue {
if errors.Is(lookupErr, errInvalidValue) {
err = vw.WriteNull()
if err != nil {
return err
Expand Down Expand Up @@ -200,7 +201,7 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val ref

for {
key, vr, err := dr.ReadElement()
if err == bsonrw.ErrEOD {
if errors.Is(err, bsonrw.ErrEOD) {
break
}
if err != nil {
Expand Down
7 changes: 5 additions & 2 deletions bson/bsoncodec/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package bsoncodec

import (
"errors"
"reflect"
"testing"

Expand Down Expand Up @@ -351,7 +352,8 @@ func TestRegistryBuilder(t *testing.T) {
})
t.Run("Decoder", func(t *testing.T) {
wanterr := tc.wanterr
if ene, ok := tc.wanterr.(ErrNoEncoder); ok {
var ene ErrNoEncoder
if errors.As(tc.wanterr, &ene) {
wanterr = ErrNoDecoder(ene)
}

Expand Down Expand Up @@ -775,7 +777,8 @@ func TestRegistry(t *testing.T) {
t.Parallel()

wanterr := tc.wanterr
if ene, ok := tc.wanterr.(ErrNoEncoder); ok {
var ene ErrNoEncoder
if errors.As(tc.wanterr, &ene) {
wanterr = ErrNoDecoder(ene)
}

Expand Down
5 changes: 3 additions & 2 deletions bson/bsonrw/copier.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package bsonrw

import (
"errors"
"fmt"
"io"

Expand Down Expand Up @@ -442,7 +443,7 @@ func (c Copier) copyArray(dst ValueWriter, src ValueReader) error {

for {
vr, err := ar.ReadValue()
if err == ErrEOA {
if errors.Is(err, ErrEOA) {
break
}
if err != nil {
Expand All @@ -466,7 +467,7 @@ func (c Copier) copyArray(dst ValueWriter, src ValueReader) error {
func (c Copier) copyDocumentCore(dw DocumentWriter, dr DocumentReader) error {
for {
key, vr, err := dr.ReadElement()
if err == ErrEOD {
if errors.Is(err, ErrEOD) {
break
}
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion bson/bsonrw/extjson_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func (ejp *extJSONParser) readValue(t bsontype.Type) (*extJSONValue, error) {
// convert hex to bytes
bytes, err := hex.DecodeString(uuidNoHyphens)
if err != nil {
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %v", err)
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %w", err)
}

ejp.advanceState()
Expand Down
3 changes: 2 additions & 1 deletion bson/bsonrw/extjson_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package bsonrw

import (
"errors"
"io"
"strings"
"testing"
Expand Down Expand Up @@ -47,7 +48,7 @@ type readKeyValueTestCase struct {

func expectSpecificError(expected error) expectedErrorFunc {
return func(t *testing.T, err error, desc string) {
if err != expected {
if !errors.Is(err, expected) {
t.Helper()
t.Errorf("%s: Expected %v but got: %v", desc, expected, err)
t.FailNow()
Expand Down
5 changes: 3 additions & 2 deletions bson/bsonrw/extjson_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package bsonrw

import (
"errors"
"fmt"
"io"
"sync"
Expand Down Expand Up @@ -613,7 +614,7 @@ func (ejvr *extJSONValueReader) ReadElement() (string, ValueReader, error) {
name, t, err := ejvr.p.readKey()

if err != nil {
if err == ErrEOD {
if errors.Is(err, ErrEOD) {
if ejvr.stack[ejvr.frame].mode == mCodeWithScope {
_, err := ejvr.p.peekType()
if err != nil {
Expand All @@ -640,7 +641,7 @@ func (ejvr *extJSONValueReader) ReadValue() (ValueReader, error) {

t, err := ejvr.p.peekType()
if err != nil {
if err == ErrEOA {
if errors.Is(err, ErrEOA) {
ejvr.pop()
}

Expand Down
12 changes: 6 additions & 6 deletions bson/bsonrw/json_scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (js *jsonScanner) nextToken() (*jsonToken, error) {
c, err = js.readNextByte()
}

if err == io.EOF {
if errors.Is(err, io.EOF) {
return &jsonToken{t: jttEOF}, nil
} else if err != nil {
return nil, err
Expand Down Expand Up @@ -198,7 +198,7 @@ func (js *jsonScanner) scanString() (*jsonToken, error) {
for {
c, err = js.readNextByte()
if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
return nil, errors.New("end of input in JSON string")
}
return nil, err
Expand All @@ -209,7 +209,7 @@ func (js *jsonScanner) scanString() (*jsonToken, error) {
case '\\':
c, err = js.readNextByte()
if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
return nil, errors.New("end of input in JSON string")
}
return nil, err
Expand Down Expand Up @@ -248,7 +248,7 @@ func (js *jsonScanner) scanString() (*jsonToken, error) {
if utf16.IsSurrogate(rn) {
c, err = js.readNextByte()
if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
return nil, errors.New("end of input in JSON string")
}
return nil, err
Expand All @@ -264,7 +264,7 @@ func (js *jsonScanner) scanString() (*jsonToken, error) {

c, err = js.readNextByte()
if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
return nil, errors.New("end of input in JSON string")
}
return nil, err
Expand Down Expand Up @@ -384,7 +384,7 @@ func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) {
for {
c, err = js.readNextByte()

if err != nil && err != io.EOF {
if err != nil && !errors.Is(err, io.EOF) {
return nil, err
}

Expand Down
7 changes: 4 additions & 3 deletions bson/bsonrw/value_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package bsonrw

import (
"bytes"
"errors"
"fmt"
"io"
"math"
Expand Down Expand Up @@ -185,7 +186,7 @@ func TestValueReader(t *testing.T) {
// invalid length
vr.d = []byte{0x00, 0x00}
_, err := vr.ReadDocument()
if err != io.EOF {
if !errors.Is(err, io.EOF) {
t.Errorf("Expected io.EOF with document length too small. got %v; want %v", err, io.EOF)
}

Expand Down Expand Up @@ -239,7 +240,7 @@ func TestValueReader(t *testing.T) {

vr.frame--
_, err = vr.ReadDocument()
if err != io.EOF {
if !errors.Is(err, io.EOF) {
t.Errorf("Should return error when attempting to read length with not enough bytes. got %v; want %v", err, io.EOF)
}
})
Expand Down Expand Up @@ -1482,7 +1483,7 @@ func TestValueReader(t *testing.T) {
frame: 0,
}
gotType, got, gotErr := vr.ReadValueBytes(nil)
if gotErr != tc.wantErr {
if !errors.Is(gotErr, tc.wantErr) {
t.Errorf("Did not receive expected error. got %v; want %v", gotErr, tc.wantErr)
}
if tc.wantErr == nil && gotType != tc.wantType {
Expand Down
2 changes: 1 addition & 1 deletion bson/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ func TestDecoderv2(t *testing.T) {

var got *D
err = dec.Decode(got)
if err != ErrDecodeToNil {
if !errors.Is(err, ErrDecodeToNil) {
t.Fatalf("Decode error mismatch; expected %v, got %v", ErrDecodeToNil, err)
}
})
Expand Down
4 changes: 2 additions & 2 deletions bson/primitive_codecs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
func bytesFromDoc(doc interface{}) []byte {
b, err := Marshal(doc)
if err != nil {
panic(fmt.Errorf("Couldn't marshal BSON document: %v", err))
panic(fmt.Errorf("Couldn't marshal BSON document: %w", err))
}
return b
}
Expand Down Expand Up @@ -471,7 +471,7 @@ func TestDefaultValueEncoders(t *testing.T) {
enc, err := NewEncoder(vw)
noerr(t, err)
err = enc.Encode(tc.value)
if err != tc.err {
if !errors.Is(err, tc.err) {
t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err)
}
if diff := cmp.Diff([]byte(b), tc.b); diff != "" {
Expand Down
7 changes: 4 additions & 3 deletions bson/raw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package bson
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"strings"
Expand Down Expand Up @@ -52,7 +53,7 @@ func TestRaw(t *testing.T) {
r := make(Raw, 5)
binary.LittleEndian.PutUint32(r[0:4], 200)
got := r.Validate()
if got != want {
if !errors.Is(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
Expand All @@ -62,7 +63,7 @@ func TestRaw(t *testing.T) {
binary.LittleEndian.PutUint32(r[0:4], 8)
r[4], r[5], r[6], r[7] = '\x02', 'f', 'o', 'o'
got := r.Validate()
if got != want {
if !errors.Is(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
Expand All @@ -72,7 +73,7 @@ func TestRaw(t *testing.T) {
binary.LittleEndian.PutUint32(r[0:4], 9)
r[4], r[5], r[6], r[7], r[8] = '\x0A', 'f', 'o', 'o', '\x00'
got := r.Validate()
if got != want {
if !errors.Is(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
Expand Down
7 changes: 5 additions & 2 deletions examples/documentation_examples/examples.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package documentation_examples

import (
"context"
"errors"
"fmt"
"io/ioutil"
logger "log"
Expand Down Expand Up @@ -1816,7 +1817,8 @@ func RunTransactionWithRetry(sctx mongo.SessionContext, txnFn func(mongo.Session
log.Println("Transaction aborted. Caught exception during transaction.")

// If transient error, retry the whole transaction
if cmdErr, ok := err.(mongo.CommandError); ok && cmdErr.HasErrorLabel("TransientTransactionError") {
var cmdErr mongo.CommandError
if errors.As(err, &cmdErr) && cmdErr.HasErrorLabel("TransientTransactionError") {
log.Println("TransientTransactionError, retrying transaction...")
continue
}
Expand Down Expand Up @@ -1883,7 +1885,8 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error {
log.Println("Transaction aborted. Caught exception during transaction.")

// If transient error, retry the whole transaction
if cmdErr, ok := err.(mongo.CommandError); ok && cmdErr.HasErrorLabel("TransientTransactionError") {
var cmdErr mongo.CommandError
if errors.As(err, &cmdErr) && cmdErr.HasErrorLabel("TransientTransactionError") {
log.Println("TransientTransactionError, retrying transaction...")
continue
}
Expand Down
2 changes: 1 addition & 1 deletion internal/logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func selectLogSink(sink LogSink) (LogSink, *os.File, error) {
if path != "" {
logFile, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0666)
if err != nil {
return nil, nil, fmt.Errorf("unable to open log file: %v", err)
return nil, nil, fmt.Errorf("unable to open log file: %w", err)
}

return NewIOSink(logFile), logFile, nil
Expand Down
13 changes: 7 additions & 6 deletions mongo/bulk_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package mongo

import (
"context"
"errors"

"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/primitive"
Expand Down Expand Up @@ -108,8 +109,8 @@ func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWr
case *InsertOneModel:
res, err := bw.runInsert(ctx, batch)
if err != nil {
writeErr, ok := err.(driver.WriteCommandError)
if !ok {
var writeErr driver.WriteCommandError
if !errors.As(err, &writeErr) {
return BulkWriteResult{}, batchErr, err
}
writeErrors = writeErr.WriteErrors
Expand All @@ -120,8 +121,8 @@ func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWr
case *DeleteOneModel, *DeleteManyModel:
res, err := bw.runDelete(ctx, batch)
if err != nil {
writeErr, ok := err.(driver.WriteCommandError)
if !ok {
var writeErr driver.WriteCommandError
if !errors.As(err, &writeErr) {
return BulkWriteResult{}, batchErr, err
}
writeErrors = writeErr.WriteErrors
Expand All @@ -132,8 +133,8 @@ func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWr
case *ReplaceOneModel, *UpdateOneModel, *UpdateManyModel:
res, err := bw.runUpdate(ctx, batch)
if err != nil {
writeErr, ok := err.(driver.WriteCommandError)
if !ok {
var writeErr driver.WriteCommandError
if !errors.As(err, &writeErr) {
return BulkWriteResult{}, batchErr, err
}
writeErrors = writeErr.WriteErrors
Expand Down
Loading

0 comments on commit c61efde

Please sign in to comment.