Skip to content

Commit

Permalink
revert errors.As to use type assertion
Browse files Browse the repository at this point in the history
  • Loading branch information
kumarlokesh committed Dec 8, 2023
1 parent daa7298 commit 6612270
Show file tree
Hide file tree
Showing 17 changed files with 30 additions and 61 deletions.
6 changes: 2 additions & 4 deletions bson/bsoncodec/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,7 @@ func TestRegistryBuilder(t *testing.T) {
})
t.Run("Decoder", func(t *testing.T) {
wanterr := tc.wanterr
var ene ErrNoEncoder
if errors.As(tc.wanterr, &ene) {
if ene, ok := tc.wanterr.(ErrNoEncoder); ok {
wanterr = ErrNoDecoder(ene)
}

Expand Down Expand Up @@ -777,8 +776,7 @@ func TestRegistry(t *testing.T) {
t.Parallel()

wanterr := tc.wanterr
var ene ErrNoEncoder
if errors.As(tc.wanterr, &ene) {
if ene, ok := tc.wanterr.(ErrNoEncoder); ok {
wanterr = ErrNoDecoder(ene)
}

Expand Down
7 changes: 2 additions & 5 deletions examples/documentation_examples/examples.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package documentation_examples

import (
"context"
"errors"
"fmt"
"io/ioutil"
logger "log"
Expand Down Expand Up @@ -1817,8 +1816,7 @@ func RunTransactionWithRetry(sctx mongo.SessionContext, txnFn func(mongo.Session
log.Println("Transaction aborted. Caught exception during transaction.")

// If transient error, retry the whole transaction
var cmdErr mongo.CommandError
if errors.As(err, &cmdErr) && cmdErr.HasErrorLabel("TransientTransactionError") {
if cmdErr, ok := err.(mongo.CommandError); ok && cmdErr.HasErrorLabel("TransientTransactionError") {
log.Println("TransientTransactionError, retrying transaction...")
continue
}
Expand Down Expand Up @@ -1885,8 +1883,7 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error {
log.Println("Transaction aborted. Caught exception during transaction.")

// If transient error, retry the whole transaction
var cmdErr mongo.CommandError
if errors.As(err, &cmdErr) && cmdErr.HasErrorLabel("TransientTransactionError") {
if cmdErr, ok := err.(mongo.CommandError); ok && cmdErr.HasErrorLabel("TransientTransactionError") {
log.Println("TransientTransactionError, retrying transaction...")
continue
}
Expand Down
4 changes: 1 addition & 3 deletions internal/aws/awserr/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
package awserr

import (
"errors"
"fmt"
)

Expand Down Expand Up @@ -107,8 +106,7 @@ func (b baseError) OrigErr() error {
case 1:
return b.errs[0]
default:
var err Error
if errors.As(b.errs[0], &err) {
if err, ok := b.errs[0].(Error); ok {
return NewBatchError(err.Code(), err.Message(), b.errs[1:])
}
return NewBatchError("BatchedErrors",
Expand Down
7 changes: 3 additions & 4 deletions mongo/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -929,8 +929,7 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) {

err = op.Execute(a.ctx)
if err != nil {
var wce driver.WriteCommandError
if errors.As(err, &wce) && wce.WriteConcernError != nil {
if wce, ok := err.(driver.WriteCommandError); ok && wce.WriteConcernError != nil {
return nil, *convertDriverWriteConcernError(wce.WriteConcernError)
}
return nil, replaceErrors(err)
Expand Down Expand Up @@ -1869,8 +1868,8 @@ func (coll *Collection) drop(ctx context.Context) error {
err = op.Execute(ctx)

// ignore namespace not found errors
var driverErr driver.Error
if !errors.As(err, &driverErr) || !driverErr.NamespaceNotFound() {
driverErr, ok := err.(driver.Error)
if !ok || (ok && !driverErr.NamespaceNotFound()) {
return replaceErrors(err)
}
return nil
Expand Down
4 changes: 2 additions & 2 deletions mongo/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,8 @@ func (db *Database) Drop(ctx context.Context) error {

err = op.Execute(ctx)

var driverErr driver.Error
if err != nil && (!errors.As(err, &driverErr) || !driverErr.NamespaceNotFound()) {
driverErr, ok := err.(driver.Error)
if err != nil && (!ok || !driverErr.NamespaceNotFound()) {
return replaceErrors(err)
}
return nil
Expand Down
18 changes: 5 additions & 13 deletions mongo/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@ func replaceErrors(err error) error {
if errors.Is(err, topology.ErrTopologyClosed) {
return ErrClientDisconnected
}

var de driver.Error
if errors.As(err, &de) {
if de, ok := err.(driver.Error); ok {
return CommandError{
Code: de.Code,
Message: de.Message,
Expand All @@ -67,9 +65,7 @@ func replaceErrors(err error) error {
Raw: bson.Raw(de.Raw),
}
}

var qe driver.QueryFailureError
if errors.As(err, &qe) {
if qe, ok := err.(driver.QueryFailureError); ok {
// qe.Message is "command failure"
ce := CommandError{
Name: qe.Message,
Expand All @@ -88,18 +84,15 @@ func replaceErrors(err error) error {

return ce
}

var me mongocrypt.Error
if errors.As(err, &me) {
if me, ok := err.(mongocrypt.Error); ok {
return MongocryptError{Code: me.Code, Message: me.Message}
}

if errors.Is(err, codecutil.ErrNilValue) {
return ErrNilValue
}

var marshalErr codecutil.MarshalError
if errors.As(err, &marshalErr) {
if marshalErr, ok := err.(codecutil.MarshalError); ok {
return MarshalError{
Value: marshalErr.Value,
Err: marshalErr.Err,
Expand Down Expand Up @@ -178,8 +171,7 @@ func unwrap(err error) error {
// errorHasLabel returns true if err contains the specified label
func errorHasLabel(err error, label string) bool {
for ; err != nil; err = unwrap(err) {
var le LabeledError
if errors.As(err, &le) && le.HasErrorLabel(label) {
if le, ok := err.(LabeledError); ok && le.HasErrorLabel(label) {
return true
}
}
Expand Down
3 changes: 1 addition & 2 deletions mongo/index_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ type IndexModel struct {
}

func isNamespaceNotFoundError(err error) bool {
var de driver.Error
if errors.As(err, &de) {
if de, ok := err.(driver.Error); ok {
return de.Code == 26
}
return false
Expand Down
4 changes: 1 addition & 3 deletions mongo/search_index_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package mongo

import (
"context"
"errors"
"fmt"
"strconv"

Expand Down Expand Up @@ -215,8 +214,7 @@ func (siv SearchIndexView) DropOne(
Timeout(siv.coll.client.timeout)

err = op.Execute(ctx)
var de driver.Error
if errors.As(err, &de) && de.NamespaceNotFound() {
if de, ok := err.(driver.Error); ok && de.NamespaceNotFound() {
return nil
}
return err
Expand Down
3 changes: 1 addition & 2 deletions mongo/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,7 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(ctx SessionCo
default:
}

var cerr CommandError
if errors.As(err, &cerr) {
if cerr, ok := err.(CommandError); ok {
if cerr.HasErrorLabel(driver.UnknownTransactionCommitResult) && !cerr.IsMaxTimeMSExpiredError() {
continue
}
Expand Down
3 changes: 1 addition & 2 deletions mongo/with_transactions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ func TestConvenientTransactions(t *testing.T) {
{"killAllSessions", bson.A{}},
}).Err()
if err != nil {
var ce CommandError
if !errors.As(err, &ce) || ce.Code != errorInterrupted {
if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted {
t.Fatalf("killAllSessions error: %v", err)
}
}
Expand Down
3 changes: 1 addition & 2 deletions x/mongo/driver/batch_cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,7 @@ func (bc *BatchCursor) getMore(ctx context.Context) {
// If we're in load balanced mode and the pinned connection encounters a network error, we should not use it for
// future commands. Per the spec, the connection will not be unpinned until the cursor is actually closed, but
// we set the cursor ID to 0 to ensure the Close() call will not execute a killCursors command.
var driverErr Error
if errors.As(bc.err, &driverErr) && driverErr.NetworkError() && bc.connection != nil {
if driverErr, ok := bc.err.(Error); ok && driverErr.NetworkError() && bc.connection != nil {
bc.id = 0
}

Expand Down
3 changes: 1 addition & 2 deletions x/mongo/driver/connstring/connstring.go
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,7 @@ func (p *parser) addHost(host string) error {
// this is unfortunate that SplitHostPort actually requires
// a port to exist.
if err != nil {
var addrError *net.AddrError
if !errors.As(err, &addrError) || addrError.Err != "missing port in address" {
if addrError, ok := err.(*net.AddrError); !ok || addrError.Err != "missing port in address" {
return err
}
}
Expand Down
7 changes: 2 additions & 5 deletions x/mongo/driver/integration/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package integration

import (
"context"
"errors"
"flag"
"fmt"
"os"
Expand Down Expand Up @@ -70,8 +69,7 @@ func autherr(t *testing.T, err error) {
t.Helper()
switch e := err.(type) {
case topology.ConnectionError:
var authErr *auth.Error
if !errors.As(e.Wrapped, &authErr) {
if _, ok := e.Wrapped.(*auth.Error); !ok {
t.Fatal("Expected auth error and didn't get one")
}
case *auth.Error:
Expand Down Expand Up @@ -135,8 +133,7 @@ func dropCollection(t *testing.T, dbname, colname string) {
err := operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendStringElement(nil, "drop", colname))).
Database(dbname).ServerSelector(description.WriteSelector()).Deployment(integtest.Topology(t)).
Execute(context.Background())
var de driver.Error
if err != nil && !(errors.As(err, &de) && de.NamespaceNotFound()) {
if de, ok := err.(driver.Error); err != nil && !(ok && de.NamespaceNotFound()) {
require.NoError(t, err)
}
}
Expand Down
6 changes: 2 additions & 4 deletions x/mongo/driver/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,7 @@ func convertInt64PtrToInt32Ptr(i64 *int64) *int32 {
// write errors are included since the actual command did succeed, only writes
// failed.
func (info finishedInformation) success() bool {
var writeCmdErr WriteCommandError
if errors.As(info.cmdErr, &writeCmdErr) {
if _, ok := info.cmdErr.(WriteCommandError); ok {
return true
}

Expand Down Expand Up @@ -627,8 +626,7 @@ func (op Operation) Execute(ctx context.Context) error {
// If the returned error is retryable and there are retries remaining (negative
// retries means retry indefinitely), then retry the operation. Set the server
// and connection to nil to request a new server and connection.
var rerr RetryablePoolError
if errors.As(err, &rerr) && rerr.Retryable() && retries != 0 {
if rerr, ok := err.(RetryablePoolError); ok && rerr.Retryable() && retries != 0 {
resetForRetry(err)
continue
}
Expand Down
4 changes: 2 additions & 2 deletions x/mongo/driver/operation/count.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ func (c *Count) Execute(ctx context.Context) error {

// Swallow error if NamespaceNotFound(26) is returned from aggregate on non-existent namespace
if err != nil {
var dErr driver.Error
if errors.As(err, &dErr) && dErr.Code == 26 {
dErr, ok := err.(driver.Error)
if ok && dErr.Code == 26 {
err = nil
}
}
Expand Down
3 changes: 1 addition & 2 deletions x/mongo/driver/topology/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,7 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead
if !contextDeadlineUsed {
return originalError
}
var netErr net.Error
if errors.As(originalError, &netErr) && netErr.Timeout() {
if netErr, ok := originalError.(net.Error); ok && netErr.Timeout() {
return context.DeadlineExceeded
}

Expand Down
6 changes: 2 additions & 4 deletions x/mongo/driver/topology/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,8 +523,7 @@ func (s *Server) ProcessError(err error, conn driver.Connection) driver.ProcessE
}

// Ignore transient timeout errors.
var netErr net.Error
if errors.As(wrappedConnErr, &netErr) && netErr.Timeout() {
if netErr, ok := wrappedConnErr.(net.Error); ok && netErr.Timeout() {
return driver.NoChange
}
if errors.Is(wrappedConnErr, context.Canceled) || errors.Is(wrappedConnErr, context.DeadlineExceeded) {
Expand Down Expand Up @@ -630,8 +629,7 @@ func (s *Server) update() {
// We want to immediately retry on timeout error. Continue to next loop.
return true
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
if err, ok := err.(net.Error); ok && err.Timeout() {
timeoutCnt++
// We want to immediately retry on timeout error. Continue to next loop.
return true
Expand Down

0 comments on commit 6612270

Please sign in to comment.