diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index a847c13e..3b162225 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -23,7 +23,7 @@ class CommandType(Enum): EXECUTE_STATEMENT = "ExecuteStatement" - CLOSE_SESSION = "OpenSession" + CLOSE_SESSION = "CloseSession" CANCEL_OPERATION = "CancelOperation" OTHER = "Other" diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index ac782c8d..347b9c10 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -8,7 +8,7 @@ from databricks.sql import __version__ from databricks.sql import * -from databricks.sql.exc import OperationalError +from databricks.sql.exc import OperationalError, SessionAlreadyClosedError from databricks.sql.thrift_backend import ThriftBackend from databricks.sql.utils import ExecuteResponse, ParamEscaper, inject_parameters from databricks.sql.types import Row @@ -257,6 +257,8 @@ def _close(self, close_cursors=True) -> None: try: self.thrift_backend.close_session(self._session_handle) + except SessionAlreadyClosedError as e: + logger.info("Session was closed by a prior request") except DatabaseError as e: if "Invalid SessionHandle" in str(e): logger.warning( diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index e05fe036..b2c522e9 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -104,4 +104,7 @@ class NonRecoverableNetworkError(RequestError): """ class UnsafeToRetryError(RequestError): - """Thrown if ExecuteStatement request receives a code other than 200, 429, or 503""" \ No newline at end of file + """Thrown if ExecuteStatement request receives a code other than 200, 429, or 503""" + +class SessionAlreadyClosedError(RequestError): + """Thrown if CloseSession receives a code 404. ThriftBackend should gracefully proceed as this is expected.""" \ No newline at end of file diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 51603e5e..6da75d45 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -360,6 +360,8 @@ def attempt_request(attempt): # TODO: Can I remove the line below since retry policy can now know the command type? # To do this, I need a way to intercept a retry in-flight from within the policy. Probably a method override self._transport.setAllowRetries(False) + elif this_method_name == "CloseSession": + self._transport.set_retry_command_type(databricks.sql.auth.thrift_http_client.CommandType.CLOSE_SESSION) else: self._transport.set_retry_command_type(databricks.sql.auth.thrift_http_client.CommandType.OTHER) diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index b63a5b3d..0fa9b137 100644 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -49,6 +49,7 @@ def _test_retry_disabled_with_message(self, error_msg_substring, exception_type) MaxRetryDurationError, NonRecoverableNetworkError, UnsafeToRetryError, + SessionAlreadyClosedError ) from contextlib import contextmanager @@ -119,7 +120,7 @@ def test_retry_max_count_not_exceeded(self): before raising an exception """ with mocked_server_response(status=404) as mock_obj: - with pytest.raises(MaxRetryError) as cm: + with self.assertRaises(MaxRetryError) as cm: with self.connection(extra_params=self._retry_policy) as conn: pass assert mock_obj.return_value.getresponse.call_count == 6 @@ -130,10 +131,11 @@ def test_retry_max_duration_not_exceeded(self): THEN the connector raises a MaxRetryDurationError """ with mocked_server_response(status=429, headers={"Retry-After": "60"}): - with pytest.raises(RequestError) as cm: + with self.assertRaises(RequestError) as cm: with self.connection(extra_params=self._retry_policy) as conn: pass - assert cm.exception == MaxRetryError + assert isinstance(cm.exception.args[1], MaxRetryDurationError) + def test_retry_abort_non_recoverable_error(self): """GIVEN the server returns a code 501 @@ -143,10 +145,10 @@ def test_retry_abort_non_recoverable_error(self): # Code 501 is a Not Implemented error with mocked_server_response(status=501, headers={"Retry-After": None}): - with pytest.raises(RequestError) as cm: + with self.assertRaises(RequestError) as cm: with self.connection(extra_params=self._retry_policy) as conn: pass - assert cm.exception == NonRecoverableNetworkError + assert isinstance(cm.exception.args[1], NonRecoverableNetworkError) def test_retry_abort_unsafe_execute_statement_retry_condition(self): """GIVEN the server sends a code other than 429 or 503 @@ -157,9 +159,9 @@ def test_retry_abort_unsafe_execute_statement_retry_condition(self): with conn.cursor() as cursor: # Code 502 is a Bad Gateway, which we commonly see in production under heavy load with mocked_server_response(status=502, headers={"Retry-After": None}): - with pytest.raises(RequestError) as cm: + with self.assertRaises(RequestError) as cm: cursor.execute("Not a real query") - assert cm.exception == UnsafeToRetryError + assert isinstance(cm.exception.args[1], UnsafeToRetryError) def test_retry_safe_execute_statement_retry_condition(self): """GIVEN the server sends a code other than 429 or 503 @@ -178,6 +180,37 @@ def test_retry_safe_execute_statement_retry_condition(self): with conn.cursor() as cursor: # Code 502 is a Bad Gateway, which we commonly see in production under heavy load with mock_sequential_server_responses(responses) as mock_obj: - with pytest.raises(MaxRetryError) as cm: + with pytest.raises(MaxRetryError): cursor.execute("This query never reaches the server") assert mock_obj.return_value.getresponse.call_count == 2 + + def test_retry_abort_close_session_on_404(self): + """ GIVEN the connector sends a CloseSession command + WHEN server sends a 404 (which is normally retried) + THEN nothing is retried because 404 means the session already closed + """ + + # First response is a Bad Gateway -> Result is the command actually goes through + # Second response is a 404 because the session is no longer found + responses = [ + {"status": 502, "headers": {"Retry-After": "1"}}, + {"status": 404, "headers": {"Retry-After": None}}, + ] + + with self.connection(extra_params={**self._retry_policy}) as conn: + conn.close() + conn.open = True + with mock_sequential_server_responses(responses): + with self.assertLogs("databricks.sql", level="INFO",) as cm: + conn.close() + expected_message_was_found = False + for log in cm.output: + if expected_message_was_found: + break + target = "Session was closed by a prior request" + expected_message_was_found = target in log + self.assertTrue(expected_message_was_found, "Did not find expected log messages") + + + +