Skip to content

Commit

Permalink
feat(go/adbc/driver/flightsql): add context to gRPC errors
Browse files Browse the repository at this point in the history
See #862.
  • Loading branch information
lidavidm committed Jul 20, 2023
1 parent 5620b03 commit 0e28c06
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 33 deletions.
48 changes: 24 additions & 24 deletions go/adbc/driver/flightsql/flightsql_adbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -892,10 +892,10 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
info, err := c.cl.GetSqlInfo(ctx, translated, c.timeouts)
if err == nil {
for _, endpoint := range info.Endpoint {
for i, endpoint := range info.Endpoint {
rdr, err := doGet(ctx, c.cl, endpoint, c.clientCache, c.timeouts)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location)
}

for rdr.Next() {
Expand All @@ -922,11 +922,11 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode) (array.Re
}

if rdr.Err() != nil {
return nil, adbcFromFlightStatus(rdr.Err())
return nil, adbcFromFlightStatus(rdr.Err(), "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location)
}
}
} else if grpcstatus.Code(err) != grpccodes.Unimplemented {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetInfo(GetSqlInfo)")
}

final := bldr.NewRecord()
Expand Down Expand Up @@ -1032,12 +1032,12 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *
// To avoid an N+1 query problem, we assume result sets here will fit in memory and build up a single response.
info, err := c.cl.GetCatalogs(ctx)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)")
}

rdr, err := c.readInfo(ctx, schema_ref.Catalogs, info)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)")
}
defer rdr.Release()

Expand All @@ -1058,7 +1058,7 @@ func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog *
}

if err = rdr.Err(); err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)")
}

return g.Finish()
Expand All @@ -1069,7 +1069,7 @@ func (c *cnxn) readInfo(ctx context.Context, expectedSchema *arrow.Schema, info
// use a default queueSize for the reader
rdr, err := newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "DoGet")
}

if !rdr.Schema().Equal(expectedSchema) {
Expand All @@ -1091,12 +1091,12 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth,
// Pre-populate the map of which schemas are in which catalogs
info, err := c.cl.GetDBSchemas(ctx, &flightsql.GetDBSchemasOpts{DbSchemaFilterPattern: dbSchema})
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetObjects(GetDBSchemas)")
}

rdr, err := c.readInfo(ctx, schema_ref.DBSchemas, info)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetObjects(GetDBSchemas)")
}
defer rdr.Release()

Expand All @@ -1117,7 +1117,7 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth,

if rdr.Err() != nil {
result = nil
err = adbcFromFlightStatus(rdr.Err())
err = adbcFromFlightStatus(rdr.Err(), "GetObjects(GetDBSchemas)")
}
return
}
Expand All @@ -1137,7 +1137,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat
IncludeSchema: includeSchema,
})
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetObjects(GetTables)")
}

expectedSchema := schema_ref.Tables
Expand All @@ -1146,7 +1146,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat
}
rdr, err := c.readInfo(ctx, expectedSchema, info)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetObjects(GetTables)")
}
defer rdr.Release()

Expand Down Expand Up @@ -1195,7 +1195,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, cat

if rdr.Err() != nil {
result = nil
err = adbcFromFlightStatus(rdr.Err())
err = adbcFromFlightStatus(rdr.Err(), "GetObjects(GetTables)")
}
return
}
Expand All @@ -1211,12 +1211,12 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
info, err := c.cl.GetTables(ctx, opts, c.timeouts)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetTableSchema(GetTables)")
}

rdr, err := doGet(ctx, c.cl, info.Endpoint[0], c.clientCache, c.timeouts)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)")
}
defer rdr.Release()

Expand All @@ -1228,7 +1228,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st
Code: adbc.StatusNotFound,
}
}
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)")
}

if rec.NumRows() == 0 {
Expand All @@ -1246,7 +1246,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema *st
schemaBytes := rec.Column(4).(*array.Binary).Value(0)
s, err := flight.DeserializeSchema(schemaBytes, c.db.alloc)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetTableSchema")
}
return s, nil
}
Expand All @@ -1262,7 +1262,7 @@ func (c *cnxn) GetTableTypes(ctx context.Context) (array.RecordReader, error) {
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
info, err := c.cl.GetTableTypes(ctx, c.timeouts)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "GetTableTypes")
}

return newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5)
Expand All @@ -1289,12 +1289,12 @@ func (c *cnxn) Commit(ctx context.Context) error {
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
err := c.txn.Commit(ctx, c.timeouts)
if err != nil {
return adbcFromFlightStatus(err)
return adbcFromFlightStatus(err, "Commit")
}

c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts)
if err != nil {
return adbcFromFlightStatus(err)
return adbcFromFlightStatus(err, "BeginTransaction")
}
return nil
}
Expand All @@ -1320,12 +1320,12 @@ func (c *cnxn) Rollback(ctx context.Context) error {
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
err := c.txn.Rollback(ctx, c.timeouts)
if err != nil {
return adbcFromFlightStatus(err)
return adbcFromFlightStatus(err, "Rollback")
}

c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts)
if err != nil {
return adbcFromFlightStatus(err)
return adbcFromFlightStatus(err, "BeginTransaction")
}
return nil
}
Expand Down Expand Up @@ -1428,7 +1428,7 @@ func (c *cnxn) ReadPartition(ctx context.Context, serializedPartition []byte) (r
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
rdr, err = doGet(ctx, c.cl, info.Endpoint[0], c.clientCache, c.timeouts)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "ReadPartition(DoGet)")
}
return rdr, nil
}
Expand Down
10 changes: 5 additions & 5 deletions go/adbc/driver/flightsql/flightsql_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, n
}

if err != nil {
return nil, -1, adbcFromFlightStatus(err)
return nil, -1, adbcFromFlightStatus(err, "ExecuteQuery")
}

nrec = info.TotalRecords
Expand All @@ -259,7 +259,7 @@ func (s *statement) ExecuteUpdate(ctx context.Context) (n int64, err error) {
}

if err != nil {
err = adbcFromFlightStatus(err)
err = adbcFromFlightStatus(err, "ExecuteUpdate")
}

return
Expand All @@ -271,7 +271,7 @@ func (s *statement) Prepare(ctx context.Context) error {
ctx = metadata.NewOutgoingContext(ctx, s.hdrs)
prep, err := s.query.prepare(ctx, s.cnxn, s.timeouts)
if err != nil {
return adbcFromFlightStatus(err)
return adbcFromFlightStatus(err, "Prepare")
}
s.prepared = prep
return nil
Expand Down Expand Up @@ -394,13 +394,13 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
}

if err != nil {
return nil, out, -1, adbcFromFlightStatus(err)
return nil, out, -1, adbcFromFlightStatus(err, "ExecutePartitions")
}

if len(info.Schema) > 0 {
sc, err = flight.DeserializeSchema(info.Schema, s.alloc)
if err != nil {
return nil, out, -1, adbcFromFlightStatus(err)
return nil, out, -1, adbcFromFlightStatus(err, "ExecutePartitions: could not deserialize FlightInfo schema:")
}
}

Expand Down
4 changes: 2 additions & 2 deletions go/adbc/driver/flightsql/record_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql.
} else {
rdr, err := doGet(ctx, cl, endpoints[0], clCache, opts...)
if err != nil {
return nil, adbcFromFlightStatus(err)
return nil, adbcFromFlightStatus(err, "DoGet: endpoint 0: remote: %s", endpoints[0].Location)
}
schema = rdr.Schema()
group.Go(func() error {
Expand Down Expand Up @@ -135,7 +135,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, cl *flightsql.

rdr, err := doGet(ctx, cl, endpoint, clCache, opts...)
if err != nil {
return err
return adbcFromFlightStatus(err, "DoGet: endpoint %d: %s", endpointIndex, endpoint.Location)
}
defer rdr.Release()

Expand Down
7 changes: 5 additions & 2 deletions go/adbc/driver/flightsql/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
package flightsql

import (
"fmt"

"github.com/apache/arrow-adbc/go/adbc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

func adbcFromFlightStatus(err error) error {
func adbcFromFlightStatus(err error, context string, args ...any) error {
if _, ok := err.(adbc.Error); ok {
return err
}
Expand Down Expand Up @@ -70,8 +72,9 @@ func adbcFromFlightStatus(err error) error {
adbcCode = adbc.StatusUnknown
}

// People don't read error messages, so backload the context and frontload the server error
return adbc.Error{
Msg: grpcStatus.Message(),
Msg: fmt.Sprintf("[FlightSQL] %s (%s; %s)", grpcStatus.Message(), grpcStatus.Code(), fmt.Sprintf(context, args...)),
Code: adbcCode,
}
}

0 comments on commit 0e28c06

Please sign in to comment.