From 9118a8859528afd8915a8c6a6feddf5ebaa3d683 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 8 Feb 2024 12:41:39 -0500
Subject: [PATCH] test(go/adbc/driver/flightsql): test handling of bad
locations
Add a test to ensure that the driver can still fetch data from a
server that returns 1 unreachable and 1 reachable location.
Related to #1527.
---
docs/source/driver/flight_sql.rst | 6 ++
.../flightsql/flightsql_adbc_server_test.go | 63 ++++++++++++++++++-
.../driver/flightsql/flightsql_database.go | 23 ++++++-
go/adbc/driver/flightsql/flightsql_driver.go | 8 ++-
go/adbc/driver/flightsql/timeouts.go | 12 ++++
5 files changed, 106 insertions(+), 6 deletions(-)
diff --git a/docs/source/driver/flight_sql.rst b/docs/source/driver/flight_sql.rst
index 7473a7cb4c..a9067fcb74 100644
--- a/docs/source/driver/flight_sql.rst
+++ b/docs/source/driver/flight_sql.rst
@@ -326,6 +326,12 @@ The options are as follows:
For example, this controls the timeout of the underlying Flight
calls that implement bulk ingestion, or transaction support.
+There is also a timeout that is set on the :cpp:class:`AdbcDatabase`:
+
+``adbc.flight.sql.rpc.timeout_seconds.connect``
+ A timeout (in floating-point seconds) for establishing a connection. The
+ default is 20 seconds.
+
Transactions
------------
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index 3df7c9a33c..44ebb1b59f 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -23,6 +23,7 @@ import (
"context"
"errors"
"fmt"
+ "net"
"net/textproto"
"os"
"strconv"
@@ -810,11 +811,18 @@ func (ts *IncrementalPollTests) TestQueryTransaction() {
type TimeoutTestServer struct {
flightsql.BaseServer
+ badPort int
+ goodPort int
}
func (ts *TimeoutTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
- if string(tkt.GetStatementHandle()) == "sleep and succeed" {
+ ticket := string(tkt.GetStatementHandle())
+ if ticket == "sleep and succeed" {
time.Sleep(1 * time.Second)
+ }
+
+ switch ticket {
+ case "bad endpoint", "sleep and succeed":
sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, sc, strings.NewReader(`[{"a": 5}]`))
if err != nil {
@@ -850,6 +858,23 @@ func (ts *TimeoutTestServer) GetFlightInfoStatement(ctx context.Context, cmd fli
switch cmd.GetQuery() {
case "timeout":
<-ctx.Done()
+ case "bad endpoint":
+ tkt, _ := flightsql.CreateStatementQueryTicket([]byte("bad endpoint"))
+ info := &flight.FlightInfo{
+ FlightDescriptor: desc,
+ Endpoint: []*flight.FlightEndpoint{
+ {
+ Ticket: &flight.Ticket{Ticket: tkt},
+ Location: []*flight.Location{
+ {Uri: fmt.Sprintf("grpc://localhost:%d", ts.badPort)},
+ {Uri: fmt.Sprintf("grpc://localhost:%d", ts.goodPort)},
+ },
+ },
+ },
+ TotalRecords: -1,
+ TotalBytes: -1,
+ }
+ return info, nil
case "fetch":
tkt, _ := flightsql.CreateStatementQueryTicket([]byte("fetch"))
info := &flight.FlightInfo{
@@ -884,10 +909,23 @@ func (ts *TimeoutTestServer) CreatePreparedStatement(ctx context.Context, req fl
type TimeoutTests struct {
ServerBasedTests
+ server net.Listener
}
func (suite *TimeoutTests) SetupSuite() {
- suite.DoSetupSuite(&TimeoutTestServer{}, nil, nil)
+ var err error
+ suite.server, err = net.Listen("tcp", "localhost:0")
+ suite.NoError(err)
+
+ badPort := suite.server.Addr().(*net.TCPAddr).Port
+ server := &TimeoutTestServer{badPort: badPort}
+ suite.DoSetupSuite(server, nil, nil)
+ server.goodPort = suite.s.Addr().(*net.TCPAddr).Port
+}
+
+func (suite *TimeoutTests) TearDownSuite() {
+ suite.ServerBasedTests.TearDownSuite()
+ suite.NoError(suite.server.Close())
}
func (ts *TimeoutTests) TestInvalidValues() {
@@ -1075,6 +1113,27 @@ func (ts *TimeoutTests) TestDontTimeout() {
ts.Truef(array.RecordEqual(rec, expected), "expected: %s\nactual: %s", expected, rec)
}
+func (ts *TimeoutTests) TestBadAddress() {
+ stmt, err := ts.cnxn.NewStatement()
+ ts.Require().NoError(err)
+ defer stmt.Close()
+ ts.Require().NoError(stmt.SetSqlQuery("bad endpoint"))
+
+ ts.Require().NoError(ts.db.(adbc.GetSetOptions).SetOptionDouble(driver.OptionTimeoutConnect, 5))
+
+ rr, _, err := stmt.ExecuteQuery(context.Background())
+ ts.Require().NoError(err)
+ defer rr.Release()
+
+ rr, _, err = stmt.ExecuteQuery(context.Background())
+ ts.Require().NoError(err)
+ defer rr.Release()
+
+ rr, _, err = stmt.ExecuteQuery(context.Background())
+ ts.Require().NoError(err)
+ defer rr.Release()
+}
+
// ---- Cookie Tests --------------------
type CookieTestServer struct {
flightsql.BaseServer
diff --git a/go/adbc/driver/flightsql/flightsql_database.go b/go/adbc/driver/flightsql/flightsql_database.go
index 1407fedf18..5e5e3af978 100644
--- a/go/adbc/driver/flightsql/flightsql_database.go
+++ b/go/adbc/driver/flightsql/flightsql_database.go
@@ -194,6 +194,13 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
delete(cnOptions, OptionTimeoutUpdate)
}
+ if tv, ok := cnOptions[OptionTimeoutConnect]; ok {
+ if err = d.timeout.setTimeoutString(OptionTimeoutConnect, tv); err != nil {
+ return err
+ }
+ delete(cnOptions, OptionTimeoutConnect)
+ }
+
if val, ok := cnOptions[OptionWithBlock]; ok {
if val == adbc.OptionValueEnabled {
d.dialOpts.block = true
@@ -257,6 +264,8 @@ func (d *databaseImpl) GetOption(key string) (string, error) {
return d.timeout.queryTimeout.String(), nil
case OptionTimeoutUpdate:
return d.timeout.updateTimeout.String(), nil
+ case OptionTimeoutConnect:
+ return d.timeout.connectTimeout.String(), nil
}
if val, ok := d.options[key]; ok {
return val, nil
@@ -271,6 +280,8 @@ func (d *databaseImpl) GetOptionInt(key string) (int64, error) {
case OptionTimeoutQuery:
fallthrough
case OptionTimeoutUpdate:
+ fallthrough
+ case OptionTimeoutConnect:
val, err := d.GetOptionDouble(key)
if err != nil {
return 0, err
@@ -289,6 +300,8 @@ func (d *databaseImpl) GetOptionDouble(key string) (float64, error) {
return d.timeout.queryTimeout.Seconds(), nil
case OptionTimeoutUpdate:
return d.timeout.updateTimeout.Seconds(), nil
+ case OptionTimeoutConnect:
+ return d.timeout.connectTimeout.Seconds(), nil
}
return d.DatabaseImplBase.GetOptionDouble(key)
@@ -297,7 +310,7 @@ func (d *databaseImpl) GetOptionDouble(key string) (float64, error) {
func (d *databaseImpl) SetOption(key, value string) error {
// We can't change most options post-init
switch key {
- case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate:
+ case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate, OptionTimeoutConnect:
return d.timeout.setTimeoutString(key, value)
}
if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
@@ -313,6 +326,8 @@ func (d *databaseImpl) SetOptionInt(key string, value int64) error {
case OptionTimeoutQuery:
fallthrough
case OptionTimeoutUpdate:
+ fallthrough
+ case OptionTimeoutConnect:
return d.timeout.setTimeout(key, float64(value))
}
@@ -326,6 +341,8 @@ func (d *databaseImpl) SetOptionDouble(key string, value float64) error {
case OptionTimeoutQuery:
fallthrough
case OptionTimeoutUpdate:
+ fallthrough
+ case OptionTimeoutConnect:
return d.timeout.setTimeout(key, value)
}
@@ -366,8 +383,9 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl
creds = insecure.NewCredentials()
target = "unix:" + uri.Path
}
- dialOpts := append(d.dialOpts.opts, grpc.WithTransportCredentials(creds))
+ dialOpts := append(d.dialOpts.opts, grpc.WithConnectParams(d.timeout.connectParams()), grpc.WithTransportCredentials(creds))
+ d.Logger.DebugContext(ctx, "new client", "location", loc)
cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...)
if err != nil {
return nil, adbc.Error{
@@ -395,7 +413,6 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl
}
}
- d.Logger.DebugContext(ctx, "new client", "location", loc)
return cl, nil
}
diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go
index 727ed827ca..4914ad1cba 100644
--- a/go/adbc/driver/flightsql/flightsql_driver.go
+++ b/go/adbc/driver/flightsql/flightsql_driver.go
@@ -35,6 +35,7 @@ import (
"net/url"
"runtime/debug"
"strings"
+ "time"
"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-adbc/go/adbc/driver/driverbase"
@@ -53,6 +54,7 @@ const (
OptionWithBlock = "adbc.flight.sql.client_option.with_block"
OptionWithMaxMsgSize = "adbc.flight.sql.client_option.with_max_msg_size"
OptionAuthorizationHeader = "adbc.flight.sql.authorization_header"
+ OptionTimeoutConnect = "adbc.flight.sql.rpc.timeout_seconds.connect"
OptionTimeoutFetch = "adbc.flight.sql.rpc.timeout_seconds.fetch"
OptionTimeoutQuery = "adbc.flight.sql.rpc.timeout_seconds.query"
OptionTimeoutUpdate = "adbc.flight.sql.rpc.timeout_seconds.update"
@@ -126,7 +128,11 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error)
db := &databaseImpl{
DatabaseImplBase: driverbase.NewDatabaseImplBase(&d.DriverImplBase),
- hdrs: make(metadata.MD),
+ timeout: timeoutOption{
+ // Match gRPC default
+ connectTimeout: time.Second * 20,
+ },
+ hdrs: make(metadata.MD),
}
var err error
diff --git a/go/adbc/driver/flightsql/timeouts.go b/go/adbc/driver/flightsql/timeouts.go
index 7737526860..db39ca4622 100644
--- a/go/adbc/driver/flightsql/timeouts.go
+++ b/go/adbc/driver/flightsql/timeouts.go
@@ -28,6 +28,7 @@ import (
"github.com/apache/arrow-adbc/go/adbc"
"google.golang.org/grpc"
+ "google.golang.org/grpc/backoff"
"google.golang.org/grpc/metadata"
)
@@ -40,6 +41,8 @@ type timeoutOption struct {
queryTimeout time.Duration
// timeout for DoPut or DoAction requests
updateTimeout time.Duration
+ // timeout for establishing a new connection
+ connectTimeout time.Duration
}
func (t *timeoutOption) setTimeout(key string, value float64) error {
@@ -60,6 +63,8 @@ func (t *timeoutOption) setTimeout(key string, value float64) error {
t.queryTimeout = timeout
case OptionTimeoutUpdate:
t.updateTimeout = timeout
+ case OptionTimeoutConnect:
+ t.connectTimeout = timeout
default:
return adbc.Error{
Msg: fmt.Sprintf("[Flight SQL] Unknown timeout option '%s'", key),
@@ -81,6 +86,13 @@ func (t *timeoutOption) setTimeoutString(key string, value string) error {
return t.setTimeout(key, timeout)
}
+func (t *timeoutOption) connectParams() grpc.ConnectParams {
+ return grpc.ConnectParams{
+ Backoff: backoff.DefaultConfig,
+ MinConnectTimeout: t.connectTimeout,
+ }
+}
+
func getTimeout(method string, callOptions []grpc.CallOption) (time.Duration, bool) {
for _, opt := range callOptions {
if to, ok := opt.(timeoutOption); ok {