diff --git a/python/example.py b/python/example.py index 784259ac..0d745dc9 100644 --- a/python/example.py +++ b/python/example.py @@ -17,6 +17,7 @@ import sys from pyarrow import flight +from http.cookies import SimpleCookie class DremioClientAuthMiddlewareFactory(flight.ClientMiddlewareFactory): @@ -58,6 +59,53 @@ def received_headers(self, headers): self.factory.set_call_credential([ b'authorization', authorization_header[0].encode("utf-8")]) +class CookieMiddlewareFactory(flight.ClientMiddlewareFactory): + """A factory that creates CookieMiddleware(s).""" + + def __init__(self): + self.cookies = {} + + def start_call(self, info): + return CookieMiddleware(self) + + +class CookieMiddleware(flight.ClientMiddleware): + """ + A ClientMiddleware that receives and retransmits cookies. + For simplicity, this does not auto-expire cookies. + + Parameters + ---------- + factory : CookieMiddlewareFactory + The factory containing the currently cached cookies. + """ + + def __init__(self, factory): + self.factory = factory + + def received_headers(self, headers): + for key in headers: + if key.lower() == 'set-cookie': + cookie = SimpleCookie() + cookie.load(headers.get(key)) + self.factory.cookies.update(cookie.items) + + def sending_headers(self): + if self.factory.cookies: + cookie_string = '; '.join("{!s}={!r}".format(key,val) for (key,val) in k.items()) + return {'Cookie': cookie_string} + return {} + +class KVParser(argpase.Action) + def __call__( self , parser, namespace, + values, option_string = None): + setattr(namespace, self.dest, dict()) + + for value in values: + # split it into key and value + key, value = value.split('=') + # assign into dictionary + getattr(namespace, self.dest)[key] = value def parse_arguments(): """ @@ -81,11 +129,13 @@ def parse_arguments(): required=False, default=False) parser.add_argument('-certs', '--trustedCertificates', type=str, help='Path to trusted certificates for encrypted connection', required=False) + parser.add_argument('-authToken', '--authToken', type=str, help="PAT or OAuth Token", required=False) + parser.add_argument('-sessionProperties', '--sessionProperties', nargs='*' action=KVParser) return parser.parse_args() def connect_to_dremio_flight_server_endpoint(hostname, flightport, username, password, sqlquery, - tls, certs, disableServerVerification): + tls, certs, disableServerVerification, authToken, sessionProperties): """ Connects to Dremio Flight server endpoint with the provided credentials. It also runs the query and retrieves the result set. @@ -113,21 +163,35 @@ def connect_to_dremio_flight_server_endpoint(hostname, flightport, username, pas print('[ERROR] Trusted certificates must be provided to establish a TLS connection') sys.exit() - # Two WLM settings can be provided upon initial authneitcation + # Two WLM settings can be provided upon initial authentication # with the Dremio Server Flight Endpoint: # - routing-tag # - routing queue - initial_options = flight.FlightCallOptions(headers=[ + headers = sessionProperties + headers.append([ (b'routing-tag', b'test-routing-tag'), - (b'routing-queue', b'Low Cost User Queries') - ]) - client_auth_middleware = DremioClientAuthMiddlewareFactory() - client = flight.FlightClient("{}://{}:{}".format(scheme, hostname, flightport), - middleware=[client_auth_middleware], **connection_args) - - # Authenticate with the server endpoint. - bearer_token = client.authenticate_basic_token(username, password, initial_options) - print('[INFO] Authentication was successful') + (b'routing-queue', b'Low Cost User Queries')]) + + client_cookie_middleware = CookieMiddlewareFactory() + + if username and password: + client_auth_middleware = DremioClientAuthMiddlewareFactory() + client = flight.FlightClient("{}://{}:{}".format(scheme, hostname, flightport), + middleware=[client_auth_middleware, client_cookie_middleware], **connection_args) + + # Authenticate with the server endpoint. + bearer_token = client.authenticate_basic_token(username, password, flight.FlightCallOptions(headers=headers)) + headers = [bearer_token] + print('[INFO] Authentication was successful') + elif authToken: + headers.append([b'authorization', b'Bearer ' + authToken]) + client = flight.FlightClient("{}://{}:{}".format(scheme, hostname, flightport), + middleware=[client_cookie_middleware], **connection_args) + + print('[INFO] Authentication skipped until first request') + else: + print('[ERROR] Username/password or Auth token must be supplied.') + sys.exit() if sqlquery: # Construct FlightDescriptor for the query result set. @@ -142,7 +206,7 @@ def connect_to_dremio_flight_server_endpoint(hostname, flightport, username, pas # ]) # Retrieve the schema of the result set. - options = flight.FlightCallOptions(headers=[bearer_token]) + options = flight.FlightCallOptions(headers=headers) schema = client.get_schema(flight_desc, options) print('[INFO] GetSchema was successful') print('[INFO] Schema: ', schema) @@ -170,3 +234,4 @@ def connect_to_dremio_flight_server_endpoint(hostname, flightport, username, pas # Connect to Dremio Arrow Flight server endpoint. connect_to_dremio_flight_server_endpoint(args.hostname, args.flightport, args.username, args.password, args.sqlquery, args.tls, args.trustedCertificates, args.disableServerVerification) + args.authToken, args.sessionProperties)