Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test(go/adbc/driver/flightsql): test handling of bad locations #1533

Merged
merged 1 commit into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/driver/flight_sql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,12 @@ The options are as follows:
For example, this controls the timeout of the underlying Flight
calls that implement bulk ingestion, or transaction support.

There is also a timeout that is set on the :cpp:class:`AdbcDatabase`:

``adbc.flight.sql.rpc.timeout_seconds.connect``
A timeout (in floating-point seconds) for establishing a connection. The
default is 20 seconds.

Transactions
------------

Expand Down
63 changes: 61 additions & 2 deletions go/adbc/driver/flightsql/flightsql_adbc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"context"
"errors"
"fmt"
"net"
"net/textproto"
"os"
"strconv"
Expand Down Expand Up @@ -810,11 +811,18 @@ func (ts *IncrementalPollTests) TestQueryTransaction() {

type TimeoutTestServer struct {
flightsql.BaseServer
badPort int
goodPort int
}

func (ts *TimeoutTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
if string(tkt.GetStatementHandle()) == "sleep and succeed" {
ticket := string(tkt.GetStatementHandle())
if ticket == "sleep and succeed" {
time.Sleep(1 * time.Second)
}

switch ticket {
case "bad endpoint", "sleep and succeed":
sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, sc, strings.NewReader(`[{"a": 5}]`))
if err != nil {
Expand Down Expand Up @@ -850,6 +858,23 @@ func (ts *TimeoutTestServer) GetFlightInfoStatement(ctx context.Context, cmd fli
switch cmd.GetQuery() {
case "timeout":
<-ctx.Done()
case "bad endpoint":
tkt, _ := flightsql.CreateStatementQueryTicket([]byte("bad endpoint"))
info := &flight.FlightInfo{
FlightDescriptor: desc,
Endpoint: []*flight.FlightEndpoint{
{
Ticket: &flight.Ticket{Ticket: tkt},
Location: []*flight.Location{
{Uri: fmt.Sprintf("grpc://localhost:%d", ts.badPort)},
{Uri: fmt.Sprintf("grpc://localhost:%d", ts.goodPort)},
},
},
},
TotalRecords: -1,
TotalBytes: -1,
}
return info, nil
case "fetch":
tkt, _ := flightsql.CreateStatementQueryTicket([]byte("fetch"))
info := &flight.FlightInfo{
Expand Down Expand Up @@ -884,10 +909,23 @@ func (ts *TimeoutTestServer) CreatePreparedStatement(ctx context.Context, req fl

type TimeoutTests struct {
ServerBasedTests
server net.Listener
}

func (suite *TimeoutTests) SetupSuite() {
suite.DoSetupSuite(&TimeoutTestServer{}, nil, nil)
var err error
suite.server, err = net.Listen("tcp", "localhost:0")
suite.NoError(err)

badPort := suite.server.Addr().(*net.TCPAddr).Port
server := &TimeoutTestServer{badPort: badPort}
suite.DoSetupSuite(server, nil, nil)
server.goodPort = suite.s.Addr().(*net.TCPAddr).Port
}

func (suite *TimeoutTests) TearDownSuite() {
suite.ServerBasedTests.TearDownSuite()
suite.NoError(suite.server.Close())
}

func (ts *TimeoutTests) TestInvalidValues() {
Expand Down Expand Up @@ -1075,6 +1113,27 @@ func (ts *TimeoutTests) TestDontTimeout() {
ts.Truef(array.RecordEqual(rec, expected), "expected: %s\nactual: %s", expected, rec)
}

func (ts *TimeoutTests) TestBadAddress() {
stmt, err := ts.cnxn.NewStatement()
ts.Require().NoError(err)
defer stmt.Close()
ts.Require().NoError(stmt.SetSqlQuery("bad endpoint"))

ts.Require().NoError(ts.db.(adbc.GetSetOptions).SetOptionDouble(driver.OptionTimeoutConnect, 5))

rr, _, err := stmt.ExecuteQuery(context.Background())
ts.Require().NoError(err)
defer rr.Release()

rr, _, err = stmt.ExecuteQuery(context.Background())
ts.Require().NoError(err)
defer rr.Release()

rr, _, err = stmt.ExecuteQuery(context.Background())
ts.Require().NoError(err)
defer rr.Release()
}

// ---- Cookie Tests --------------------
type CookieTestServer struct {
flightsql.BaseServer
Expand Down
23 changes: 20 additions & 3 deletions go/adbc/driver/flightsql/flightsql_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,13 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
delete(cnOptions, OptionTimeoutUpdate)
}

if tv, ok := cnOptions[OptionTimeoutConnect]; ok {
if err = d.timeout.setTimeoutString(OptionTimeoutConnect, tv); err != nil {
return err
}
delete(cnOptions, OptionTimeoutConnect)
}

if val, ok := cnOptions[OptionWithBlock]; ok {
if val == adbc.OptionValueEnabled {
d.dialOpts.block = true
Expand Down Expand Up @@ -257,6 +264,8 @@ func (d *databaseImpl) GetOption(key string) (string, error) {
return d.timeout.queryTimeout.String(), nil
case OptionTimeoutUpdate:
return d.timeout.updateTimeout.String(), nil
case OptionTimeoutConnect:
return d.timeout.connectTimeout.String(), nil
}
if val, ok := d.options[key]; ok {
return val, nil
Expand All @@ -271,6 +280,8 @@ func (d *databaseImpl) GetOptionInt(key string) (int64, error) {
case OptionTimeoutQuery:
fallthrough
case OptionTimeoutUpdate:
fallthrough
case OptionTimeoutConnect:
val, err := d.GetOptionDouble(key)
if err != nil {
return 0, err
Expand All @@ -289,6 +300,8 @@ func (d *databaseImpl) GetOptionDouble(key string) (float64, error) {
return d.timeout.queryTimeout.Seconds(), nil
case OptionTimeoutUpdate:
return d.timeout.updateTimeout.Seconds(), nil
case OptionTimeoutConnect:
return d.timeout.connectTimeout.Seconds(), nil
}

return d.DatabaseImplBase.GetOptionDouble(key)
Expand All @@ -297,7 +310,7 @@ func (d *databaseImpl) GetOptionDouble(key string) (float64, error) {
func (d *databaseImpl) SetOption(key, value string) error {
// We can't change most options post-init
switch key {
case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate:
case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate, OptionTimeoutConnect:
return d.timeout.setTimeoutString(key, value)
}
if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
Expand All @@ -313,6 +326,8 @@ func (d *databaseImpl) SetOptionInt(key string, value int64) error {
case OptionTimeoutQuery:
fallthrough
case OptionTimeoutUpdate:
fallthrough
case OptionTimeoutConnect:
return d.timeout.setTimeout(key, float64(value))
}

Expand All @@ -326,6 +341,8 @@ func (d *databaseImpl) SetOptionDouble(key string, value float64) error {
case OptionTimeoutQuery:
fallthrough
case OptionTimeoutUpdate:
fallthrough
case OptionTimeoutConnect:
return d.timeout.setTimeout(key, value)
}

Expand Down Expand Up @@ -366,8 +383,9 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl
creds = insecure.NewCredentials()
target = "unix:" + uri.Path
}
dialOpts := append(d.dialOpts.opts, grpc.WithTransportCredentials(creds))
dialOpts := append(d.dialOpts.opts, grpc.WithConnectParams(d.timeout.connectParams()), grpc.WithTransportCredentials(creds))

d.Logger.DebugContext(ctx, "new client", "location", loc)
cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...)
if err != nil {
return nil, adbc.Error{
Expand Down Expand Up @@ -395,7 +413,6 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl
}
}

d.Logger.DebugContext(ctx, "new client", "location", loc)
return cl, nil
}

Expand Down
8 changes: 7 additions & 1 deletion go/adbc/driver/flightsql/flightsql_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"net/url"
"runtime/debug"
"strings"
"time"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-adbc/go/adbc/driver/driverbase"
Expand All @@ -53,6 +54,7 @@ const (
OptionWithBlock = "adbc.flight.sql.client_option.with_block"
OptionWithMaxMsgSize = "adbc.flight.sql.client_option.with_max_msg_size"
OptionAuthorizationHeader = "adbc.flight.sql.authorization_header"
OptionTimeoutConnect = "adbc.flight.sql.rpc.timeout_seconds.connect"
OptionTimeoutFetch = "adbc.flight.sql.rpc.timeout_seconds.fetch"
OptionTimeoutQuery = "adbc.flight.sql.rpc.timeout_seconds.query"
OptionTimeoutUpdate = "adbc.flight.sql.rpc.timeout_seconds.update"
Expand Down Expand Up @@ -126,7 +128,11 @@ func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error)

db := &databaseImpl{
DatabaseImplBase: driverbase.NewDatabaseImplBase(&d.DriverImplBase),
hdrs: make(metadata.MD),
timeout: timeoutOption{
// Match gRPC default
connectTimeout: time.Second * 20,
},
hdrs: make(metadata.MD),
}

var err error
Expand Down
12 changes: 12 additions & 0 deletions go/adbc/driver/flightsql/timeouts.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

"github.com/apache/arrow-adbc/go/adbc"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/metadata"
)

Expand All @@ -40,6 +41,8 @@ type timeoutOption struct {
queryTimeout time.Duration
// timeout for DoPut or DoAction requests
updateTimeout time.Duration
// timeout for establishing a new connection
connectTimeout time.Duration
}

func (t *timeoutOption) setTimeout(key string, value float64) error {
Expand All @@ -60,6 +63,8 @@ func (t *timeoutOption) setTimeout(key string, value float64) error {
t.queryTimeout = timeout
case OptionTimeoutUpdate:
t.updateTimeout = timeout
case OptionTimeoutConnect:
t.connectTimeout = timeout
default:
return adbc.Error{
Msg: fmt.Sprintf("[Flight SQL] Unknown timeout option '%s'", key),
Expand All @@ -81,6 +86,13 @@ func (t *timeoutOption) setTimeoutString(key string, value string) error {
return t.setTimeout(key, timeout)
}

func (t *timeoutOption) connectParams() grpc.ConnectParams {
return grpc.ConnectParams{
Backoff: backoff.DefaultConfig,
MinConnectTimeout: t.connectTimeout,
}
}

func getTimeout(method string, callOptions []grpc.CallOption) (time.Duration, bool) {
for _, opt := range callOptions {
if to, ok := opt.(timeoutOption); ok {
Expand Down
Loading