diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index be43ae06..7373be38 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -733,7 +733,7 @@ def execute( self, operation: str, parameters: Optional[TParameterCollection] = None, - perform_async = False + async_op=False, ) -> "Cursor": """ Execute a query and wait for execution to complete. @@ -797,7 +797,7 @@ def execute( cursor=self, use_cloud_fetch=self.connection.use_cloud_fetch, parameters=prepared_params, - perform_async=perform_async, + async_op=async_op, ) self.active_result_set = ResultSet( self.connection, @@ -805,6 +805,7 @@ def execute( self.thrift_backend, self.buffer_size_bytes, self.arraysize, + async_op, ) if execute_response.is_staging_operation: @@ -814,21 +815,25 @@ def execute( return self - def execute_async(self, - operation: str, - parameters: Optional[TParameterCollection] = None,): + def execute_async( + self, + operation: str, + parameters: Optional[TParameterCollection] = None, + ): return self.execute(operation, parameters, True) - def get_query_status(self): + def get_query_state(self): self._check_not_closed() - return self.thrift_backend.get_query_status(self.active_op_handle) + return self.thrift_backend.get_query_state(self.active_op_handle) def get_execution_result(self): self._check_not_closed() - operation_state = self.get_query_status() - if operation_state.statusCode == ttypes.TStatusCode.SUCCESS_STATUS or operation_state.statusCode == ttypes.TStatusCode.SUCCESS_WITH_INFO_STATUS: - execute_response=self.thrift_backend.get_execution_result(self.active_op_handle) + operation_state = self.get_query_state() + if operation_state == ttypes.TOperationState.FINISHED_STATE: + execute_response = self.thrift_backend.get_execution_result( + self.active_op_handle, self + ) self.active_result_set = ResultSet( self.connection, execute_response, @@ -844,7 +849,9 @@ def get_execution_result(self): return self else: - raise Error(f"get_execution_result failed with status code {operation_state.statusCode}") + raise Error( + f"get_execution_result failed with Operation status {operation_state}" + ) def executemany(self, operation, seq_of_parameters): """ @@ -1131,6 +1138,7 @@ def __init__( thrift_backend: ThriftBackend, result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, arraysize: int = 10000, + async_op=False, ): """ A ResultSet manages the results of a single command. @@ -1153,7 +1161,7 @@ def __init__( self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._next_row_index = 0 - if execute_response.arrow_queue or True: + if execute_response.arrow_queue or async_op: # In this case the server has taken the fast path and returned an initial batch of # results self.results = execute_response.arrow_queue diff --git a/src/databricks/sql/constants.py b/src/databricks/sql/constants.py deleted file mode 100644 index 2245f4f7..00000000 --- a/src/databricks/sql/constants.py +++ /dev/null @@ -1,12 +0,0 @@ -from databricks.sql.thrift_api.TCLIService import ttypes - -class QueryExecutionStatus: - INITIALIZED_STATE=ttypes.TOperationState.INITIALIZED_STATE - RUNNING_STATE = ttypes.TOperationState.RUNNING_STATE - FINISHED_STATE = ttypes.TOperationState.FINISHED_STATE - CANCELED_STATE = ttypes.TOperationState.CANCELED_STATE - CLOSED_STATE = ttypes.TOperationState.CLOSED_STATE - ERROR_STATE = ttypes.TOperationState.ERROR_STATE - UKNOWN_STATE = ttypes.TOperationState.UKNOWN_STATE - PENDING_STATE = ttypes.TOperationState.PENDING_STATE - TIMEDOUT_STATE = ttypes.TOperationState.TIMEDOUT_STATE \ No newline at end of file diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index bbb90f1d..f6c10649 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -769,7 +769,7 @@ def _results_message_to_execute_response(self, resp, operation_state): arrow_schema_bytes=schema_bytes, ) - def get_execution_result(self, op_handle): + def get_execution_result(self, op_handle, cursor): assert op_handle is not None @@ -780,15 +780,15 @@ def get_execution_result(self, op_handle): False, op_handle.modifiedRowCount, ), - maxRows=max_rows, - maxBytes=max_bytes, + maxRows=cursor.arraysize, + maxBytes=cursor.buffer_size_bytes, orientation=ttypes.TFetchOrientation.FETCH_NEXT, includeResultSetMetadata=True, ) resp = self.make_request(self._client.FetchResults, req) - t_result_set_metadata_resp = resp.resultSetMetaData + t_result_set_metadata_resp = resp.resultSetMetadata lz4_compressed = t_result_set_metadata_resp.lz4Compressed is_staging_operation = t_result_set_metadata_resp.isStagingOperation @@ -797,15 +797,12 @@ def get_execution_result(self, op_handle): t_result_set_metadata_resp.schema ) - if pyarrow: - schema_bytes = ( - t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) - .serialize() - .to_pybytes() - ) - else: - schema_bytes = None + schema_bytes = ( + t_result_set_metadata_resp.arrowSchema + or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + .serialize() + .to_pybytes() + ) queue = ResultSetQueueFactory.build_queue( row_set_type=resp.resultSetMetadata.resultFormat, @@ -820,11 +817,11 @@ def get_execution_result(self, op_handle): return ExecuteResponse( arrow_queue=queue, status=resp.status, - has_been_closed_server_side=has_been_closed_server_side, + has_been_closed_server_side=False, has_more_rows=has_more_rows, lz4_compressed=lz4_compressed, is_staging_operation=is_staging_operation, - command_handle=resp.operationHandle, + command_handle=op_handle, description=description, arrow_schema_bytes=schema_bytes, ) @@ -847,9 +844,9 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp): self._check_command_not_in_error_or_closed_state(op_handle, poll_resp) return operation_state - def get_query_status(self, op_handle): + def get_query_state(self, op_handle): poll_resp = self._poll_for_status(op_handle) - operation_state = poll_resp.status + operation_state = poll_resp.operationState self._check_command_not_in_error_or_closed_state(op_handle, poll_resp) return operation_state @@ -883,7 +880,7 @@ def execute_command( cursor, use_cloud_fetch=True, parameters=[], - perform_async=False, + async_op=False, ): assert session_handle is not None @@ -914,7 +911,7 @@ def execute_command( ) resp = self.make_request(self._client.ExecuteStatement, req) - if perform_async: + if async_op: return self._handle_execute_response_async(resp, cursor) else: return self._handle_execute_response(resp, cursor) @@ -1012,7 +1009,7 @@ def _handle_execute_response(self, resp, cursor): final_operation_state = self._wait_until_command_done( resp.operationHandle, resp.directResults and resp.directResults.operationStatus, - ) + ) return self._results_message_to_execute_response(resp, final_operation_state)