Skip to content

Commit

Permalink
Working prototype of execute_async, get_query_state and get_execution_result
Browse files Browse the repository at this point in the history
  • Loading branch information
jprakash-db committed Nov 4, 2024
1 parent 925b2a3 commit 756ac17
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 44 deletions.
32 changes: 20 additions & 12 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def execute(
self,
operation: str,
parameters: Optional[TParameterCollection] = None,
perform_async = False
async_op=False,
) -> "Cursor":
"""
Execute a query and wait for execution to complete.
Expand Down Expand Up @@ -797,14 +797,15 @@ def execute(
cursor=self,
use_cloud_fetch=self.connection.use_cloud_fetch,
parameters=prepared_params,
perform_async=perform_async,
async_op=async_op,
)
self.active_result_set = ResultSet(
self.connection,
execute_response,
self.thrift_backend,
self.buffer_size_bytes,
self.arraysize,
async_op,
)

if execute_response.is_staging_operation:
Expand All @@ -814,21 +815,25 @@ def execute(

return self

def execute_async(
    self,
    operation: str,
    parameters: Optional[TParameterCollection] = None,
):
    """
    Submit a query through ``execute`` with ``async_op=True``.

    :param operation: The SQL text to run; forwarded unchanged to ``execute``.
    :param parameters: Optional query parameters; forwarded unchanged to ``execute``.
    :returns: Whatever ``execute`` returns (``execute`` is annotated as returning this cursor).

    NOTE(review): presumably the call returns before the query has finished, and the caller
    is expected to poll ``get_query_state`` and then call ``get_execution_result`` - confirm
    against the backend's async response handler.
    """
    # Pass the flag by keyword: a bare positional ``True`` would silently bind to the wrong
    # parameter if ``execute`` ever gains another positional argument before ``async_op``.
    return self.execute(operation, parameters, async_op=True)

def get_query_state(self):
    """
    Return the backend-reported state of the operation in ``self.active_op_handle``.

    ``self._check_not_closed()`` is called first, then the lookup is delegated to
    ``self.thrift_backend.get_query_state``.

    :returns: The value returned by the backend for the active operation handle.
        NOTE(review): presumably a ``ttypes.TOperationState`` member, since callers compare
        it with ``TOperationState.FINISHED_STATE`` - confirm.
    """
    self._check_not_closed()
    return self.thrift_backend.get_query_state(self.active_op_handle)

def get_execution_result(self):
self._check_not_closed()

operation_state = self.get_query_status()
if operation_state.statusCode == ttypes.TStatusCode.SUCCESS_STATUS or operation_state.statusCode == ttypes.TStatusCode.SUCCESS_WITH_INFO_STATUS:
execute_response=self.thrift_backend.get_execution_result(self.active_op_handle)
operation_state = self.get_query_state()
if operation_state == ttypes.TOperationState.FINISHED_STATE:
execute_response = self.thrift_backend.get_execution_result(
self.active_op_handle, self
)
self.active_result_set = ResultSet(
self.connection,
execute_response,
Expand All @@ -844,7 +849,9 @@ def get_execution_result(self):

return self
else:
raise Error(f"get_execution_result failed with status code {operation_state.statusCode}")
raise Error(
f"get_execution_result failed with Operation status {operation_state}"
)

def executemany(self, operation, seq_of_parameters):
"""
Expand Down Expand Up @@ -1131,6 +1138,7 @@ def __init__(
thrift_backend: ThriftBackend,
result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
arraysize: int = 10000,
async_op=False,
):
"""
A ResultSet manages the results of a single command.
Expand All @@ -1153,7 +1161,7 @@ def __init__(
self._arrow_schema_bytes = execute_response.arrow_schema_bytes
self._next_row_index = 0

if execute_response.arrow_queue or True:
if execute_response.arrow_queue or async_op:
# In this case the server has taken the fast path and returned an initial batch of
# results
self.results = execute_response.arrow_queue
Expand Down
12 changes: 0 additions & 12 deletions src/databricks/sql/constants.py

This file was deleted.

37 changes: 17 additions & 20 deletions src/databricks/sql/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
arrow_schema_bytes=schema_bytes,
)

def get_execution_result(self, op_handle):
def get_execution_result(self, op_handle, cursor):

assert op_handle is not None

Expand All @@ -780,15 +780,15 @@ def get_execution_result(self, op_handle):
False,
op_handle.modifiedRowCount,
),
maxRows=max_rows,
maxBytes=max_bytes,
maxRows=cursor.arraysize,
maxBytes=cursor.buffer_size_bytes,
orientation=ttypes.TFetchOrientation.FETCH_NEXT,
includeResultSetMetadata=True,
)

resp = self.make_request(self._client.FetchResults, req)

t_result_set_metadata_resp = resp.resultSetMetaData
t_result_set_metadata_resp = resp.resultSetMetadata

lz4_compressed = t_result_set_metadata_resp.lz4Compressed
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
Expand All @@ -797,15 +797,12 @@ def get_execution_result(self, op_handle):
t_result_set_metadata_resp.schema
)

if pyarrow:
schema_bytes = (
t_result_set_metadata_resp.arrowSchema
or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
.serialize()
.to_pybytes()
)
else:
schema_bytes = None
schema_bytes = (
t_result_set_metadata_resp.arrowSchema
or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
.serialize()
.to_pybytes()
)

queue = ResultSetQueueFactory.build_queue(
row_set_type=resp.resultSetMetadata.resultFormat,
Expand All @@ -820,11 +817,11 @@ def get_execution_result(self, op_handle):
return ExecuteResponse(
arrow_queue=queue,
status=resp.status,
has_been_closed_server_side=has_been_closed_server_side,
has_been_closed_server_side=False,
has_more_rows=has_more_rows,
lz4_compressed=lz4_compressed,
is_staging_operation=is_staging_operation,
command_handle=resp.operationHandle,
command_handle=op_handle,
description=description,
arrow_schema_bytes=schema_bytes,
)
Expand All @@ -847,9 +844,9 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
return operation_state

def get_query_state(self, op_handle):
    """
    Poll the server once for ``op_handle`` and return the operation state it reports.

    :param op_handle: Handle of the operation to poll; passed to ``_poll_for_status``.
    :returns: The ``operationState`` field of the poll response.

    The poll response is also handed to ``_check_command_not_in_error_or_closed_state``
    before returning.
    NOTE(review): that helper presumably raises when the operation is in an error or closed
    state - confirm.
    """
    poll_resp = self._poll_for_status(op_handle)
    # Read the operation state, not the RPC status: the caller compares this value with
    # TOperationState members.
    operation_state = poll_resp.operationState
    self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
    return operation_state

Expand Down Expand Up @@ -883,7 +880,7 @@ def execute_command(
cursor,
use_cloud_fetch=True,
parameters=[],
perform_async=False,
async_op=False,
):
assert session_handle is not None

Expand Down Expand Up @@ -914,7 +911,7 @@ def execute_command(
)
resp = self.make_request(self._client.ExecuteStatement, req)

if perform_async:
if async_op:
return self._handle_execute_response_async(resp, cursor)
else:
return self._handle_execute_response(resp, cursor)
Expand Down Expand Up @@ -1012,7 +1009,7 @@ def _handle_execute_response(self, resp, cursor):
final_operation_state = self._wait_until_command_done(
resp.operationHandle,
resp.directResults and resp.directResults.operationStatus,
)
)

return self._results_message_to_execute_response(resp, final_operation_state)

Expand Down

0 comments on commit 756ac17

Please sign in to comment.