Skip to content

Commit

Permalink
test(go/adbc/driver/flightsql): test handling of bad locations
Browse files Browse the repository at this point in the history
Add a test to ensure that the driver can still fetch data from a
server that returns 1 unreachable and 1 reachable location.

Related to #1527.
  • Loading branch information
lidavidm committed Feb 8, 2024
1 parent 0b103c8 commit 9118a88
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 6 deletions.
6 changes: 6 additions & 0 deletions docs/source/driver/flight_sql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,12 @@ The options are as follows:
For example, this controls the timeout of the underlying Flight
calls that implement bulk ingestion, or transaction support.

There is also a timeout that is set on the :cpp:class:`AdbcDatabase`:

``adbc.flight.sql.rpc.timeout_seconds.connect``
A timeout (in floating-point seconds) for establishing a connection. The
default is 20 seconds.

Transactions
------------

Expand Down
63 changes: 61 additions & 2 deletions go/adbc/driver/flightsql/flightsql_adbc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"context"
"errors"
"fmt"
"net"
"net/textproto"
"os"
"strconv"
Expand Down Expand Up @@ -810,11 +811,18 @@ func (ts *IncrementalPollTests) TestQueryTransaction() {

type TimeoutTestServer struct {
flightsql.BaseServer
badPort int
goodPort int
}

func (ts *TimeoutTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
if string(tkt.GetStatementHandle()) == "sleep and succeed" {
ticket := string(tkt.GetStatementHandle())
if ticket == "sleep and succeed" {
time.Sleep(1 * time.Second)
}

switch ticket {
case "bad endpoint", "sleep and succeed":
sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, sc, strings.NewReader(`[{"a": 5}]`))
if err != nil {
Expand Down Expand Up @@ -850,6 +858,23 @@ func (ts *TimeoutTestServer) GetFlightInfoStatement(ctx context.Context, cmd fli
switch cmd.GetQuery() {
case "timeout":
<-ctx.Done()
case "bad endpoint":
tkt, _ := flightsql.CreateStatementQueryTicket([]byte("bad endpoint"))
info := &flight.FlightInfo{
FlightDescriptor: desc,
Endpoint: []*flight.FlightEndpoint{
{
Ticket: &flight.Ticket{Ticket: tkt},
Location: []*flight.Location{
{Uri: fmt.Sprintf("grpc://localhost:%d", ts.badPort)},
{Uri: fmt.Sprintf("grpc://localhost:%d", ts.goodPort)},
},
},
},
TotalRecords: -1,
TotalBytes: -1,
}
return info, nil
case "fetch":
tkt, _ := flightsql.CreateStatementQueryTicket([]byte("fetch"))
info := &flight.FlightInfo{
Expand Down Expand Up @@ -884,10 +909,23 @@ func (ts *TimeoutTestServer) CreatePreparedStatement(ctx context.Context, req fl

type TimeoutTests struct {
ServerBasedTests
server net.Listener
}

func (suite *TimeoutTests) SetupSuite() {
suite.DoSetupSuite(&TimeoutTestServer{}, nil, nil)
var err error
suite.server, err = net.Listen("tcp", "localhost:0")
suite.NoError(err)

badPort := suite.server.Addr().(*net.TCPAddr).Port
server := &TimeoutTestServer{badPort: badPort}
suite.DoSetupSuite(server, nil, nil)
server.goodPort = suite.s.Addr().(*net.TCPAddr).Port
}

func (suite *TimeoutTests) TearDownSuite() {
suite.ServerBasedTests.TearDownSuite()
suite.NoError(suite.server.Close())
}

func (ts *TimeoutTests) TestInvalidValues() {
Expand Down Expand Up @@ -1075,6 +1113,27 @@ func (ts *TimeoutTests) TestDontTimeout() {
ts.Truef(array.RecordEqual(rec, expected), "expected: %s\nactual: %s", expected, rec)
}

func (ts *TimeoutTests) TestBadAddress() {
stmt, err := ts.cnxn.NewStatement()
ts.Require().NoError(err)
defer stmt.Close()
ts.Require().NoError(stmt.SetSqlQuery("bad endpoint"))

ts.Require().NoError(ts.db.(adbc.GetSetOptions).SetOptionDouble(driver.OptionTimeoutConnect, 5))

rr, _, err := stmt.ExecuteQuery(context.Background())
ts.Require().NoError(err)
defer rr.Release()

rr, _, err = stmt.ExecuteQuery(context.Background())
ts.Require().NoError(err)
defer rr.Release()

rr, _, err = stmt.ExecuteQuery(context.Background())
ts.Require().NoError(err)
defer rr.Release()
}

// ---- Cookie Tests --------------------
type CookieTestServer struct {
flightsql.BaseServer
Expand Down
23 changes: 20 additions & 3 deletions go/adbc/driver/flightsql/flightsql_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,13 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
delete(cnOptions, OptionTimeoutUpdate)
}

if tv, ok := cnOptions[OptionTimeoutConnect]; ok {
if err = d.timeout.setTimeoutString(OptionTimeoutConnect, tv); err != nil {
return err
}
delete(cnOptions, OptionTimeoutConnect)
}

if val, ok := cnOptions[OptionWithBlock]; ok {
if val == adbc.OptionValueEnabled {
d.dialOpts.block = true
Expand Down Expand Up @@ -257,6 +264,8 @@ func (d *databaseImpl) GetOption(key string) (string, error) {
return d.timeout.queryTimeout.String(), nil
case OptionTimeoutUpdate:
return d.timeout.updateTimeout.String(), nil
case OptionTimeoutConnect:
return d.timeout.connectTimeout.String(), nil
}
if val, ok := d.options[key]; ok {
return val, nil
Expand All @@ -271,6 +280,8 @@ func (d *databaseImpl) GetOptionInt(key string) (int64, error) {
case OptionTimeoutQuery:
fallthrough
case OptionTimeoutUpdate:
fallthrough
case OptionTimeoutConnect:
val, err := d.GetOptionDouble(key)
if err != nil {
return 0, err
Expand All @@ -289,6 +300,8 @@ func (d *databaseImpl) GetOptionDouble(key string) (float64, error) {
return d.timeout.queryTimeout.Seconds(), nil
case OptionTimeoutUpdate:
return d.timeout.updateTimeout.Seconds(), nil
case OptionTimeoutConnect:
return d.timeout.connectTimeout.Seconds(), nil
}

return d.DatabaseImplBase.GetOptionDouble(key)
Expand All @@ -297,7 +310,7 @@ func (d *databaseImpl) GetOptionDouble(key string) (float64, error) {
func (d *databaseImpl) SetOption(key, value string) error {
// We can't change most options post-init
switch key {
case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate:
case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate, OptionTimeoutConnect:
return d.timeout.setTimeoutString(key, value)
}
if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
Expand All @@ -313,6 +326,8 @@ func (d *databaseImpl) SetOptionInt(key string, value int64) error {
case OptionTimeoutQuery:
fallthrough
case OptionTimeoutUpdate:
fallthrough
case OptionTimeoutConnect:
return d.timeout.setTimeout(key, float64(value))
}

Expand All @@ -326,6 +341,8 @@ func (d *databaseImpl) SetOptionDouble(key string, value float64) error {
case OptionTimeoutQuery:
fallthrough
case OptionTimeoutUpdate:
fallthrough
case OptionTimeoutConnect:
return d.timeout.setTimeout(key, value)
}

Expand Down Expand Up @@ -366,8 +383,9 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl
creds = insecure.NewCredentials()
target = "unix:" + uri.Path
}
dialOpts := append(d.dialOpts.opts, grpc.WithTransportCredentials(creds))
dialOpts := append(d.dialOpts.opts, grpc.WithConnectParams(d.timeout.connectParams()), grpc.WithTransportCredentials(creds))

d.Logger.DebugContext(ctx, "new client", "location", loc)
cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...)
if err != nil {
return nil, adbc.Error{
Expand Down Expand Up @@ -395,7 +413,6 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl
}
}

d.Logger.DebugContext(ctx, "new client", "location", loc)
return cl, nil
}

Expand Down
8 changes: 7 additions & 1 deletion go/adbc/driver/flightsql/flightsql_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"net/url"
"runtime/debug"
"strings"
"time"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-adbc/go/adbc/driver/driverbase"
Expand All @@ -53,6 +54,7 @@ const (
OptionWithBlock = "adbc.flight.sql.client_option.with_block"
OptionWithMaxMsgSize = "adbc.flight.sql.client_option.with_max_msg_size"
OptionAuthorizationHeader = "adbc.flight.sql.authorization_header"
OptionTimeoutConnect = "adbc.flight.sql.rpc.timeout_seconds.connect"
OptionTimeoutFetch = "adbc.flight.sql.rpc.timeout_seconds.fetch"
OptionTimeoutQuery = "adbc.flight.sql.rpc.timeout_seconds.query"
OptionTimeoutUpdate = "adbc.flight.sql.rpc.timeout_seconds.update"
Expand Down Expand Up @@ -126,7 +128,11 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error)

db := &databaseImpl{
DatabaseImplBase: driverbase.NewDatabaseImplBase(&d.DriverImplBase),
hdrs: make(metadata.MD),
timeout: timeoutOption{
// Match gRPC default
connectTimeout: time.Second * 20,
},
hdrs: make(metadata.MD),
}

var err error
Expand Down
12 changes: 12 additions & 0 deletions go/adbc/driver/flightsql/timeouts.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

"github.com/apache/arrow-adbc/go/adbc"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/metadata"
)

Expand All @@ -40,6 +41,8 @@ type timeoutOption struct {
queryTimeout time.Duration
// timeout for DoPut or DoAction requests
updateTimeout time.Duration
// timeout for establishing a new connection
connectTimeout time.Duration
}

func (t *timeoutOption) setTimeout(key string, value float64) error {
Expand All @@ -60,6 +63,8 @@ func (t *timeoutOption) setTimeout(key string, value float64) error {
t.queryTimeout = timeout
case OptionTimeoutUpdate:
t.updateTimeout = timeout
case OptionTimeoutConnect:
t.connectTimeout = timeout
default:
return adbc.Error{
Msg: fmt.Sprintf("[Flight SQL] Unknown timeout option '%s'", key),
Expand All @@ -81,6 +86,13 @@ func (t *timeoutOption) setTimeoutString(key string, value string) error {
return t.setTimeout(key, timeout)
}

func (t *timeoutOption) connectParams() grpc.ConnectParams {
return grpc.ConnectParams{
Backoff: backoff.DefaultConfig,
MinConnectTimeout: t.connectTimeout,
}
}

func getTimeout(method string, callOptions []grpc.CallOption) (time.Duration, bool) {
for _, opt := range callOptions {
if to, ok := opt.(timeoutOption); ok {
Expand Down

0 comments on commit 9118a88

Please sign in to comment.