Skip to content

Commit

Permalink
(4/x) wire-in a way for the retry policy to evaluate the overall retry
Browse files Browse the repository at this point in the history
duration across successive requests

Signed-off-by: Jesse Whitehouse <[email protected]>
  • Loading branch information
Jesse Whitehouse committed Jul 26, 2023
1 parent 6a5635b commit 61b2f4e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 6 deletions.
50 changes: 45 additions & 5 deletions src/databricks/sql/auth/thrift_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import six
import thrift
import time

logger = logging.getLogger(__name__)

Expand All @@ -23,6 +24,8 @@ class DatabricksRetryPolicy(Retry):
Same as urllib3.Retry but implements retry_wait_min and retry_wait_max as part of its exponential backoff calculation
Retry Policy arguments are defined in the docstring for our ThriftBackend class
delay_default was introduced in #24 under the previous retry regime. It's currently a no-op.
"""

def __init__(
Expand All @@ -33,11 +36,40 @@ def __init__(
stop_after_attempts_duration: float,
delay_default: float,
*args,
**kwargs
**kwargs,
):
super().__init__(*args, **kwargs)
super().__init__(
total=stop_after_attempts_count,
respect_retry_after_header=True,
backoff_max=delay_max,
*args,
**kwargs,
)

self.backoff_factor: float = delay_min
self.stop_after_attempts_duration: float = stop_after_attempts_duration
logger.info("DatabricksRetryPolicy is in effect!")

def start_retry_timer(self):
self._retry_start_time = time.time()

def check_retry_timer(self):
return time.time() - self._retry_start_time

def get_backoff_time(self) -> float:
"""Calls urllib3's built-in get_backoff_time but raises an exception if stop_after_attempts_duration would be exceeded"""

proposed_backoff = super(self).get_backoff_time()

if (
time.time() - self.check_retry_timer() + proposed_backoff
) > self.stop_after_attempts_duration:
raise Exception(
f"Retry request would exceed Retry policy max retry duration of {self.stop_after_attempts_duration} seconds"
)

return proposed_backoff


class THttpClient(thrift.transport.THttpClient.THttpClient):
def __init__(
Expand All @@ -51,7 +83,8 @@ def __init__(
key_file=None,
ssl_context=None,
max_connections: int = 1,
retry_policy: Retry = Retry(False),
# this type annotation is slightly incorrect because the default value is of type urllib3.Retry
retry_policy: DatabricksRetryPolicy = Retry(False),
):
if port is not None:
warnings.warn(
Expand Down Expand Up @@ -118,6 +151,13 @@ def setCustomHeaders(self, headers: Dict[str, str]):
self._headers = headers
super().setCustomHeaders(headers)

def startRetryTimer(self):
"""Notify DatabricksRetryPolicy of the request start time
This is used to enforce the retry_stop_after_attempts_duration
"""
self.retry_policy.start_retry_timer()

def setAllowRetries(self, value: bool):
logger.info(f"urllib3 is allowed to retry: {value}")
self._allow_retries = value
Expand Down Expand Up @@ -193,7 +233,7 @@ def flush(self):
toggle_retries = {"retries": False}
else:
toggle_retries = {}

# HTTP request
self.__resp = self.__pool.request(
"POST",
Expand All @@ -202,7 +242,7 @@ def flush(self):
headers=headers,
preload_content=False,
timeout=self.__timeout,
**toggle_retries
**toggle_retries,
)

# Get reply to flush the request
Expand Down
3 changes: 2 additions & 1 deletion src/databricks/sql/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,10 +351,11 @@ def attempt_request(attempt):
logger.debug("Sending request: {}(<REDACTED>)".format(this_method_name))
unsafe_logger.debug("Sending request: {}".format(request))

# TODO: only allow retries when it's a 429 or a 503
# TODO: only allow retries when it's a 429 or a 503
if this_method_name == "ExecuteStatement":
self._transport.setAllowRetries(False)

self._transport.startRetryTimer()
response = method(request)

# Always default to retry after this request is complete
Expand Down

0 comments on commit 61b2f4e

Please sign in to comment.