diff --git a/docs/source/driver/flight_sql.rst b/docs/source/driver/flight_sql.rst index 7473a7cb4c..a9067fcb74 100644 --- a/docs/source/driver/flight_sql.rst +++ b/docs/source/driver/flight_sql.rst @@ -326,6 +326,12 @@ The options are as follows: For example, this controls the timeout of the underlying Flight calls that implement bulk ingestion, or transaction support. +There is also a timeout that is set on the :cpp:class:`AdbcDatabase`: + +``adbc.flight.sql.rpc.timeout_seconds.connect`` + A timeout (in floating-point seconds) for establishing a connection. The + default is 20 seconds. + Transactions ------------ diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go index 3df7c9a33c..44ebb1b59f 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go @@ -23,6 +23,7 @@ import ( "context" "errors" "fmt" + "net" "net/textproto" "os" "strconv" @@ -810,11 +811,18 @@ func (ts *IncrementalPollTests) TestQueryTransaction() { type TimeoutTestServer struct { flightsql.BaseServer + badPort int + goodPort int } func (ts *TimeoutTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) { - if string(tkt.GetStatementHandle()) == "sleep and succeed" { + ticket := string(tkt.GetStatementHandle()) + if ticket == "sleep and succeed" { time.Sleep(1 * time.Second) + } + + switch ticket { + case "bad endpoint", "sleep and succeed": sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, sc, strings.NewReader(`[{"a": 5}]`)) if err != nil { @@ -850,6 +858,23 @@ func (ts *TimeoutTestServer) GetFlightInfoStatement(ctx context.Context, cmd fli switch cmd.GetQuery() { case "timeout": <-ctx.Done() + case "bad endpoint": + tkt, _ := flightsql.CreateStatementQueryTicket([]byte("bad endpoint")) + info := &flight.FlightInfo{ + FlightDescriptor: desc, + Endpoint: []*flight.FlightEndpoint{ + { + Ticket: &flight.Ticket{Ticket: tkt}, + Location: []*flight.Location{ + {Uri: fmt.Sprintf("grpc://localhost:%d", ts.badPort)}, + {Uri: fmt.Sprintf("grpc://localhost:%d", ts.goodPort)}, + }, + }, + }, + TotalRecords: -1, + TotalBytes: -1, + } + return info, nil case "fetch": tkt, _ := flightsql.CreateStatementQueryTicket([]byte("fetch")) info := &flight.FlightInfo{ @@ -884,10 +909,23 @@ func (ts *TimeoutTestServer) CreatePreparedStatement(ctx context.Context, req fl type TimeoutTests struct { ServerBasedTests + server net.Listener } func (suite *TimeoutTests) SetupSuite() { - suite.DoSetupSuite(&TimeoutTestServer{}, nil, nil) + var err error + suite.server, err = net.Listen("tcp", "localhost:0") + suite.NoError(err) + + badPort := suite.server.Addr().(*net.TCPAddr).Port + server := &TimeoutTestServer{badPort: badPort} + suite.DoSetupSuite(server, nil, nil) + server.goodPort = suite.s.Addr().(*net.TCPAddr).Port +} + +func (suite *TimeoutTests) TearDownSuite() { + suite.ServerBasedTests.TearDownSuite() + suite.NoError(suite.server.Close()) } func (ts *TimeoutTests) TestInvalidValues() { @@ -1075,6 +1113,27 @@ func (ts *TimeoutTests) TestDontTimeout() { ts.Truef(array.RecordEqual(rec, expected), "expected: %s\nactual: %s", expected, rec) } +func (ts *TimeoutTests) TestBadAddress() { + stmt, err := ts.cnxn.NewStatement() + ts.Require().NoError(err) + defer stmt.Close() + ts.Require().NoError(stmt.SetSqlQuery("bad endpoint")) + + ts.Require().NoError(ts.db.(adbc.GetSetOptions).SetOptionDouble(driver.OptionTimeoutConnect, 5)) + + rr, _, err := stmt.ExecuteQuery(context.Background()) + ts.Require().NoError(err) + defer rr.Release() + + rr, _, err = stmt.ExecuteQuery(context.Background()) + ts.Require().NoError(err) + defer rr.Release() + + rr, _, err = stmt.ExecuteQuery(context.Background()) + ts.Require().NoError(err) + defer rr.Release() +} + // ---- Cookie Tests -------------------- type CookieTestServer struct { flightsql.BaseServer diff --git a/go/adbc/driver/flightsql/flightsql_database.go b/go/adbc/driver/flightsql/flightsql_database.go index 1407fedf18..5e5e3af978 100644 --- a/go/adbc/driver/flightsql/flightsql_database.go +++ b/go/adbc/driver/flightsql/flightsql_database.go @@ -194,6 +194,13 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error { delete(cnOptions, OptionTimeoutUpdate) } + if tv, ok := cnOptions[OptionTimeoutConnect]; ok { + if err = d.timeout.setTimeoutString(OptionTimeoutConnect, tv); err != nil { + return err + } + delete(cnOptions, OptionTimeoutConnect) + } + if val, ok := cnOptions[OptionWithBlock]; ok { if val == adbc.OptionValueEnabled { d.dialOpts.block = true @@ -257,6 +264,8 @@ func (d *databaseImpl) GetOption(key string) (string, error) { return d.timeout.queryTimeout.String(), nil case OptionTimeoutUpdate: return d.timeout.updateTimeout.String(), nil + case OptionTimeoutConnect: + return d.timeout.connectTimeout.String(), nil } if val, ok := d.options[key]; ok { return val, nil @@ -271,6 +280,8 @@ func (d *databaseImpl) GetOptionInt(key string) (int64, error) { case OptionTimeoutQuery: fallthrough case OptionTimeoutUpdate: + fallthrough + case OptionTimeoutConnect: val, err := d.GetOptionDouble(key) if err != nil { return 0, err @@ -289,6 +300,8 @@ func (d *databaseImpl) GetOptionDouble(key string) (float64, error) { return d.timeout.queryTimeout.Seconds(), nil case OptionTimeoutUpdate: return d.timeout.updateTimeout.Seconds(), nil + case OptionTimeoutConnect: + return d.timeout.connectTimeout.Seconds(), nil } return d.DatabaseImplBase.GetOptionDouble(key) @@ -297,7 +310,7 @@ func (d *databaseImpl) GetOptionDouble(key string) (float64, error) { func (d *databaseImpl) SetOption(key, value string) error { // We can't change most options post-init switch key { - case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate: + case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate, OptionTimeoutConnect: return d.timeout.setTimeoutString(key, value) } if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { @@ -313,6 +326,8 @@ func (d *databaseImpl) SetOptionInt(key string, value int64) error { case OptionTimeoutQuery: fallthrough case OptionTimeoutUpdate: + fallthrough + case OptionTimeoutConnect: return d.timeout.setTimeout(key, float64(value)) } @@ -326,6 +341,8 @@ func (d *databaseImpl) SetOptionDouble(key string, value float64) error { case OptionTimeoutQuery: fallthrough case OptionTimeoutUpdate: + fallthrough + case OptionTimeoutConnect: return d.timeout.setTimeout(key, value) } @@ -366,8 +383,9 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl creds = insecure.NewCredentials() target = "unix:" + uri.Path } - dialOpts := append(d.dialOpts.opts, grpc.WithTransportCredentials(creds)) + dialOpts := append(d.dialOpts.opts, grpc.WithConnectParams(d.timeout.connectParams()), grpc.WithTransportCredentials(creds)) + d.Logger.DebugContext(ctx, "new client", "location", loc) cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...) if err != nil { return nil, adbc.Error{ @@ -395,7 +413,6 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl } } - d.Logger.DebugContext(ctx, "new client", "location", loc) return cl, nil } diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go index 727ed827ca..4914ad1cba 100644 --- a/go/adbc/driver/flightsql/flightsql_driver.go +++ b/go/adbc/driver/flightsql/flightsql_driver.go @@ -35,6 +35,7 @@ import ( "net/url" "runtime/debug" "strings" + "time" "github.com/apache/arrow-adbc/go/adbc" "github.com/apache/arrow-adbc/go/adbc/driver/driverbase" @@ -53,6 +54,7 @@ const ( OptionWithBlock = "adbc.flight.sql.client_option.with_block" OptionWithMaxMsgSize = "adbc.flight.sql.client_option.with_max_msg_size" OptionAuthorizationHeader = "adbc.flight.sql.authorization_header" + OptionTimeoutConnect = "adbc.flight.sql.rpc.timeout_seconds.connect" OptionTimeoutFetch = "adbc.flight.sql.rpc.timeout_seconds.fetch" OptionTimeoutQuery = "adbc.flight.sql.rpc.timeout_seconds.query" OptionTimeoutUpdate = "adbc.flight.sql.rpc.timeout_seconds.update" @@ -126,7 +128,11 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) db := &databaseImpl{ DatabaseImplBase: driverbase.NewDatabaseImplBase(&d.DriverImplBase), - hdrs: make(metadata.MD), + timeout: timeoutOption{ + // Match gRPC default + connectTimeout: time.Second * 20, + }, + hdrs: make(metadata.MD), } var err error diff --git a/go/adbc/driver/flightsql/timeouts.go b/go/adbc/driver/flightsql/timeouts.go index 7737526860..db39ca4622 100644 --- a/go/adbc/driver/flightsql/timeouts.go +++ b/go/adbc/driver/flightsql/timeouts.go @@ -28,6 +28,7 @@ import ( "github.com/apache/arrow-adbc/go/adbc" "google.golang.org/grpc" + "google.golang.org/grpc/backoff" "google.golang.org/grpc/metadata" ) @@ -40,6 +41,8 @@ type timeoutOption struct { queryTimeout time.Duration // timeout for DoPut or DoAction requests updateTimeout time.Duration + // timeout for establishing a new connection + connectTimeout time.Duration } func (t *timeoutOption) setTimeout(key string, value float64) error { @@ -60,6 +63,8 @@ func (t *timeoutOption) setTimeout(key string, value float64) error { t.queryTimeout = timeout case OptionTimeoutUpdate: t.updateTimeout = timeout + case OptionTimeoutConnect: + t.connectTimeout = timeout default: return adbc.Error{ Msg: fmt.Sprintf("[Flight SQL] Unknown timeout option '%s'", key), @@ -81,6 +86,13 @@ func (t *timeoutOption) setTimeoutString(key string, value string) error { return t.setTimeout(key, timeout) } +func (t *timeoutOption) connectParams() grpc.ConnectParams { + return grpc.ConnectParams{ + Backoff: backoff.DefaultConfig, + MinConnectTimeout: t.connectTimeout, + } +} + func getTimeout(method string, callOptions []grpc.CallOption) (time.Duration, bool) { for _, opt := range callOptions { if to, ok := opt.(timeoutOption); ok {