Skip to content

Commit

Permalink
Black the code and tests
Browse files Browse the repository at this point in the history
Signed-off-by: Jesse Whitehouse <[email protected]>
  • Loading branch information
Jesse Whitehouse committed Aug 7, 2023
1 parent bea6371 commit 0947c64
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 62 deletions.
61 changes: 35 additions & 26 deletions src/databricks/sql/auth/thrift_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@
from io import BytesIO

from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager, Retry
from databricks.sql.exc import MaxRetryDurationError, NonRecoverableNetworkError, UnsafeToRetryError, SessionAlreadyClosedError, CursorAlreadyClosedError
from databricks.sql.exc import (
MaxRetryDurationError,
NonRecoverableNetworkError,
UnsafeToRetryError,
SessionAlreadyClosedError,
CursorAlreadyClosedError,
)
from enum import Enum, auto



class CommandType(Enum):
EXECUTE_STATEMENT = "ExecuteStatement"
CLOSE_SESSION = "CloseSession"
Expand Down Expand Up @@ -69,8 +74,10 @@ def __init__(
self._retry_start_time = kwargs.pop("_retry_start_time", None)
self.command_type = kwargs.pop("_command_type", CommandType.OTHER)

_total = min(stop_after_attempts_count, kwargs.pop("total", stop_after_attempts_count))

_total = min(
stop_after_attempts_count, kwargs.pop("total", stop_after_attempts_count)
)

super().__init__(
total=_total,
respect_retry_after_header=True,
Expand All @@ -81,7 +88,6 @@ def __init__(
**kwargs,
)


def new(self, **kw: typing.Any) -> Retry:
"""urllib3 calls Retry.new() between successive requests
We need to override this method to include mapping of our extra keyword arguments:
Expand Down Expand Up @@ -109,7 +115,7 @@ def new(self, **kw: typing.Any) -> Retry:
stop_after_attempts_duration=self.stop_after_attempts_duration,
delay_default=self.delay_default,
_retry_start_time=self._retry_start_time,
_command_type=self.command_type
_command_type=self.command_type,
)

params.update(kw)
Expand Down Expand Up @@ -139,10 +145,9 @@ def start_retry_timer(self):
def check_timer_duration(self):
"""Return time in seconds since the timer was started"""
return time.time() - self._retry_start_time

def check_proposed_wait(self, proposed_wait: int) -> None:
"""Raise an exception if the proposed wait would exceed the configured max_attempts_duration
"""
"""Raise an exception if the proposed wait would exceed the configured stop_after_attempts_duration"""

proposed_overall_time = self.check_timer_duration() + proposed_wait
if proposed_overall_time > self.stop_after_attempts_duration:
Expand All @@ -157,8 +162,9 @@ def get_backoff_time(self) -> float:
self.check_proposed_wait(proposed_backoff)

return proposed_backoff

from urllib3.response import BaseHTTPResponse

def sleep_for_retry(self, response: BaseHTTPResponse) -> bool:
retry_after = self.get_retry_after(response)
if retry_after:
Expand All @@ -167,60 +173,65 @@ def sleep_for_retry(self, response: BaseHTTPResponse) -> bool:
return True

return False

def should_retry(self, method: str, status_code: int) -> bool:
"""We retry by default with the following exceptions:
Any 200 status code -> Because the request succeeded
Any 501 status code -> Because it's not recoverable ever. Raise a NonRecoverableNetworkError.
Any 404 if the command was CloseOperation or CloseSession and this is not the first request
Any ExecuteStatement command unless it has code 429 or 503
"""

# TODO: Port in the NoRetryReason enum for logging?
# TODO: something needs to go here about max wall-clock duration.
# It doesn't belong in the backoff calculation method

# Request succeeded. Don't retry.
if status_code == 200:
return False

# Request failed and server said NotImplemented. This isn't recoverable. Don't retry.
if status_code == 501:
raise NonRecoverableNetworkError("Received code 501 from server.")

# Request failed and this method is not retryable. We only retry POST requests.
if not self._is_method_retryable(method):
return False

# Request failed with 404 because CloseSession returns 404 if you repeat the request.
if (
status_code == 404 and
self.command_type == CommandType.CLOSE_SESSION
status_code == 404
and self.command_type == CommandType.CLOSE_SESSION
and len(self.history) > 0
):
raise SessionAlreadyClosedError("CloseSession received 404 code from Databricks. Session is already closed.")
raise SessionAlreadyClosedError(
"CloseSession received 404 code from Databricks. Session is already closed."
)

# Request failed with 404 because CloseOperation returns 404 if you repeat the request.
if (
status_code == 404 and
self.command_type == CommandType.CLOSE_OPERATION
status_code == 404
and self.command_type == CommandType.CLOSE_OPERATION
and len(self.history) > 0
):
raise CursorAlreadyClosedError("CloseOperation received 404 code from Databricks. Cursor is already closed.")
raise CursorAlreadyClosedError(
"CloseOperation received 404 code from Databricks. Cursor is already closed."
)

# Request failed, was an ExecuteStatement and the command may have reached the server
if (
self.command_type == CommandType.EXECUTE_STATEMENT
and status_code not in self.status_forcelist
):
raise UnsafeToRetryError("ExecuteStatement command can only be retried for codes 429 and 503")
raise UnsafeToRetryError(
"ExecuteStatement command can only be retried for codes 429 and 503"
)

# None of the above conditions applied. Eagerly retry.
logger.debug(f"This request should be retried: {self.command_type.value}")
return True


def is_retry(
self, method: str, status_code: int, has_retry_after: bool = False
) -> bool:
Expand All @@ -229,8 +240,6 @@ def is_retry(
"""

return self.should_retry(method, status_code)




class THttpClient(thrift.transport.THttpClient.THttpClient):
Expand Down Expand Up @@ -399,7 +408,7 @@ def flush(self):
headers=headers,
preload_content=False,
timeout=self.__timeout,
retries=self.retry_policy
retries=self.retry_policy,
)

# Get reply to flush the request
Expand Down
6 changes: 5 additions & 1 deletion src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@

from databricks.sql import __version__
from databricks.sql import *
from databricks.sql.exc import OperationalError, SessionAlreadyClosedError, CursorAlreadyClosedError
from databricks.sql.exc import (
OperationalError,
SessionAlreadyClosedError,
CursorAlreadyClosedError,
)
from databricks.sql.thrift_backend import ThriftBackend
from databricks.sql.utils import ExecuteResponse, ParamEscaper, inject_parameters
from databricks.sql.types import Row
Expand Down
10 changes: 7 additions & 3 deletions src/databricks/sql/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,20 +94,24 @@ class RequestError(OperationalError):

pass


class MaxRetryDurationError(RequestError):
"""Thrown if the next HTTP request retry would exceed the configured
stop_after_attempts_duration
"""


class NonRecoverableNetworkError(RequestError):
"""Thrown if an HTTP code 501 is received
"""
"""Thrown if an HTTP code 501 is received"""


class UnsafeToRetryError(RequestError):
"""Thrown if ExecuteStatement request receives a code other than 200, 429, or 503"""


class SessionAlreadyClosedError(RequestError):
"""Thrown if CloseSession receives a code 404. ThriftBackend should gracefully proceed as this is expected."""


class CursorAlreadyClosedError(RequestError):
"""Thrown if CancelOperation receives a code 404. ThriftBackend should gracefully proceed as this is expected."""
"""Thrown if CloseOperation receives a code 404. ThriftBackend should gracefully proceed as this is expected."""
23 changes: 13 additions & 10 deletions src/databricks/sql/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ def __init__(
# Cloud fetch
self.max_download_threads = kwargs.get("max_download_threads", 10)


# Configure tls context
ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file"))
if kwargs.get("_tls_no_verify") is True:
Expand Down Expand Up @@ -178,20 +177,20 @@ def __init__(
additional_transport_args = {}
if self.enable_v3_retries:
self.retry_policy = databricks.sql.auth.thrift_http_client.DatabricksRetryPolicy(
delay_min=self._retry_delay_min,
delay_max=self._retry_delay_max,
stop_after_attempts_count=self._retry_stop_after_attempts_count,
stop_after_attempts_duration=self._retry_stop_after_attempts_duration,
delay_default=self._retry_delay_default,
)
delay_min=self._retry_delay_min,
delay_max=self._retry_delay_max,
stop_after_attempts_count=self._retry_stop_after_attempts_count,
stop_after_attempts_duration=self._retry_stop_after_attempts_duration,
delay_default=self._retry_delay_default,
)

additional_transport_args["retry_policy"] = self.retry_policy

self._transport = databricks.sql.auth.thrift_http_client.THttpClient(
auth_provider=self._auth_provider,
uri_or_host=uri,
ssl_context=ssl_context,
**additional_transport_args
**additional_transport_args,
)

timeout = kwargs.get("_socket_timeout", DEFAULT_SOCKET_TIMEOUT)
Expand Down Expand Up @@ -392,7 +391,11 @@ def attempt_request(attempt):

gos_name = TCLIServiceClient.GetOperationStatus.__name__
if method.__name__ == gos_name:
delay_default = self.enable_v3_retries and self.retry_policy.get_operation_status_delay or self._retry_delay_default
delay_default = (
self.enable_v3_retries
and self.retry_policy.get_operation_status_delay
or self._retry_delay_default
)
retry_delay = bound_retry_delay(attempt, delay_default)
logger.info(
f"GetOperationStatus failed with HTTP error and will be retried: {str(err)}"
Expand Down
52 changes: 30 additions & 22 deletions tests/e2e/common/retry_test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _test_retry_disabled_with_message(self, error_msg_substring, exception_type)
MaxRetryDurationError,
NonRecoverableNetworkError,
UnsafeToRetryError,
SessionAlreadyClosedError
SessionAlreadyClosedError,
)

from contextlib import contextmanager
Expand Down Expand Up @@ -115,18 +115,17 @@ class PySQLRetryTestsMixin:
}

def test_oserror_retries(self):
"""If a network error occurs during make_request, the request is retried according to policy
"""
with patch("urllib3.connectionpool.HTTPSConnectionPool._validate_conn",) as mock_validate_conn:
"""If a network error occurs during make_request, the request is retried according to policy"""
with patch(
"urllib3.connectionpool.HTTPSConnectionPool._validate_conn",
) as mock_validate_conn:
mock_validate_conn.side_effect = OSError("Some arbitrary network error")
with self.assertRaises(MaxRetryError) as cm:
with self.connection(extra_params=self._retry_policy) as conn:
pass

assert mock_validate_conn.call_count == 6



def test_retry_max_count_not_exceeded(self):
"""GIVEN the max_attempts_count is 5
WHEN the server sends nothing but 429 responses
Expand All @@ -150,7 +149,6 @@ def test_retry_max_duration_not_exceeded(self):
pass
assert isinstance(cm.exception.args[1], MaxRetryDurationError)


def test_retry_abort_non_recoverable_error(self):
"""GIVEN the server returns a code 501
WHEN the connector receives this response
Expand Down Expand Up @@ -199,9 +197,9 @@ def test_retry_safe_execute_statement_retry_condition(self):
assert mock_obj.return_value.getresponse.call_count == 2

def test_retry_abort_close_session_on_404(self):
""" GIVEN the connector sends a CloseSession command
WHEN server sends a 404 (which is normally retried)
THEN nothing is retried because 404 means the session already closed
"""GIVEN the connector sends a CloseSession command
WHEN server sends a 404 (which is normally retried)
THEN nothing is retried because 404 means the session already closed
"""

# First response is a Bad Gateway -> Result is the command actually goes through
Expand All @@ -213,20 +211,25 @@ def test_retry_abort_close_session_on_404(self):

with self.connection(extra_params={**self._retry_policy}) as conn:
with mock_sequential_server_responses(responses):
with self.assertLogs("databricks.sql", level="INFO",) as cm:
with self.assertLogs(
"databricks.sql",
level="INFO",
) as cm:
conn.close()
expected_message_was_found = False
for log in cm.output:
if expected_message_was_found:
break
target = "Session was closed by a prior request"
expected_message_was_found = target in log
self.assertTrue(expected_message_was_found, "Did not find expected log messages")
self.assertTrue(
expected_message_was_found, "Did not find expected log messages"
)

def test_retry_abort_close_operation_on_404(self):
""" GIVEN the connector sends a CancelOperation command
WHEN server sends a 404 (which is normally retried)
THEN nothing is retried because 404 means the operation was already canceled
"""GIVEN the connector sends a CloseOperation command
WHEN server sends a 404 (which is normally retried)
THEN nothing is retried because 404 means the operation was already canceled
"""

# First response is a Bad Gateway -> Result is the command actually goes through
Expand All @@ -238,20 +241,25 @@ def test_retry_abort_close_operation_on_404(self):

with self.connection(extra_params={**self._retry_policy}) as conn:
with conn.cursor() as curs:
with patch("databricks.sql.utils.ExecuteResponse.has_been_closed_server_side", new_callable=PropertyMock, return_value=False):
with patch(
"databricks.sql.utils.ExecuteResponse.has_been_closed_server_side",
new_callable=PropertyMock,
return_value=False,
):
# This call guarantees we have an open cursor at the server
curs.execute("SELECT 1")
with mock_sequential_server_responses(responses):
with self.assertLogs("databricks.sql", level="INFO",) as cm:
with self.assertLogs(
"databricks.sql",
level="INFO",
) as cm:
curs.close()
expected_message_was_found = False
for log in cm.output:
if expected_message_was_found:
break
target = "Operation was canceled by a prior request"
expected_message_was_found = target in log
self.assertTrue(expected_message_was_found, "Did not find expected log messages")




self.assertTrue(
expected_message_was_found, "Did not find expected log messages"
)

0 comments on commit 0947c64

Please sign in to comment.