Merge branch 'main' into wanhan/add_signature_for_evals
YingChen1996 authored May 7, 2024
2 parents 345f129 + e376e86 commit 05771a7
Showing 2 changed files with 46 additions and 16 deletions.
36 changes: 28 additions & 8 deletions src/promptflow-tools/promptflow/tools/common.py
@@ -519,19 +519,25 @@ def is_retriable_api_connection_error(e: APIConnectionError):


# TODO(2971352): revisit this tries=100 when there is any change to the 10min timeout logic
-def handle_openai_error(tries: int = 100):
+def handle_openai_error(tries: int = 100, unprocessable_entity_error_tries: int = 3):
"""
-A decorator function that used to handle OpenAI error.
-OpenAI Error falls into retriable vs non-retriable ones.
-For retriable error, the decorator use below parameters to control its retry activity with exponential backoff:
-`tries` : max times for the function invocation, type is int
-'delay': base delay seconds for exponential delay, type is float
-"""
+A decorator function for handling OpenAI errors.
+OpenAI errors are categorized into retriable and non-retriable.
+For retriable errors, the decorator uses the following parameters to control its retry behavior:
+`tries`: max times for the function invocation, type is int
+`unprocessable_entity_error_tries`: max times for the function invocation when consecutive
+    422 errors occur, type is int
+Note:
+- The retry policy for UnprocessableEntityError is different because retrying may not be beneficial,
+  so it uses a small threshold and requires consecutive errors.
+"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
+consecutive_422_error_count = 0
for i in range(tries + 1):
try:
return func(*args, **kwargs)
@@ -542,6 +548,7 @@ def wrapper(*args, **kwargs):
# Handle retriable exception, please refer to
# https://platform.openai.com/docs/guides/error-codes/api-errors
print(f"Exception occurs: {type(e).__name__}: {str(e)}", file=sys.stderr)
+# Firstly, exclude some non-retriable errors.
# Vision model does not support all chat api parameters, e.g. response_format and function_call.
# Recommend user to use vision model in vision tools, rather than LLM tool.
# Related issue https://github.com/microsoft/promptflow/issues/1683
@@ -558,7 +565,11 @@ def wrapper(*args, **kwargs):
if isinstance(e, APIConnectionError) and not isinstance(e, APITimeoutError) \
and not is_retriable_api_connection_error(e):
raise WrappedOpenAIError(e)

+# Retry InternalServerError(>=500), RateLimitError(429), UnprocessableEntityError(422)
+# Solution references:
+# https://platform.openai.com/docs/guides/error-codes/api-errors
+# https://platform.openai.com/docs/guides/error-codes/python-library-error-types
if isinstance(e, APIStatusError):
status_code = e.response.status_code
if status_code < 500 and status_code not in [429, 422]:
@@ -567,7 +578,16 @@ def wrapper(*args, **kwargs):
# Exit retry if this is quota insufficient error
print(f"{type(e).__name__} with insufficient quota. Throw user error.", file=sys.stderr)
raise WrappedOpenAIError(e)
-if i == tries:
+
+# Retriable errors.
+# To fix issue #2296, retry on api connection error, but with a separate retry policy.
+if isinstance(e, APIStatusError) and e.response.status_code == 422:
+    consecutive_422_error_count += 1
+else:
+    # If other retriable errors, reset consecutive_422_error_count.
+    consecutive_422_error_count = 0
+
+if i == tries or consecutive_422_error_count == unprocessable_entity_error_tries:
# Exit retry if max retry reached
print(f"{type(e).__name__} reached max retry. Exit retry with user error.", file=sys.stderr)
raise ExceedMaxRetryTimes(e)
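In short, the new logic only aborts early when 422 responses arrive back to back: any other retriable error resets the counter, and the overall `tries` cap still applies. The standalone sketch below mirrors that counting behavior outside of promptflow; the error classes, `call_with_retry`, and the delay values are illustrative stand-ins, not the repository's implementation.

import time


class RetriableError(Exception):
    """Stand-in for a retriable OpenAI error such as RateLimitError (429) or a 5xx."""


class UnprocessableEntityLikeError(Exception):
    """Stand-in for UnprocessableEntityError (HTTP 422)."""


def call_with_retry(func, tries=100, unprocessable_entity_error_tries=3, delay=0.01):
    """Retry `func` with exponential backoff, but give up early after
    `unprocessable_entity_error_tries` consecutive 422-style errors."""
    consecutive_422_error_count = 0
    for i in range(tries + 1):
        try:
            return func()
        except (RetriableError, UnprocessableEntityLikeError) as e:
            if isinstance(e, UnprocessableEntityLikeError):
                consecutive_422_error_count += 1
            else:
                # Any other retriable error resets the consecutive-422 counter.
                consecutive_422_error_count = 0
            if i == tries or consecutive_422_error_count == unprocessable_entity_error_tries:
                raise RuntimeError(f"{type(e).__name__} reached max retry") from e
            time.sleep(delay * (2 ** i))  # exponential backoff


if __name__ == "__main__":
    # 422, 429, 422, 422 -> the counter goes 0 -> 1 -> 0 -> 1 -> 2, then retrying stops.
    errors = [UnprocessableEntityLikeError(), RetriableError(),
              UnprocessableEntityLikeError(), UnprocessableEntityLikeError()]

    def flaky_call():
        raise errors.pop(0)

    try:
        call_with_retry(flaky_call, unprocessable_entity_error_tries=2)
    except RuntimeError as err:
        print(err)  # the fourth call triggers the consecutive-422 limit
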
26 changes: 18 additions & 8 deletions src/promptflow-tools/tests/test_handle_openai_error.py
@@ -12,7 +12,7 @@
from promptflow.tools.aoai import chat, completion
from promptflow.tools.common import handle_openai_error
from promptflow.tools.exception import ChatAPIInvalidRole, WrappedOpenAIError, to_openai_error_message, \
-JinjaTemplateError, LLMError, ChatAPIFunctionRoleInvalidFormat
+JinjaTemplateError, LLMError, ChatAPIFunctionRoleInvalidFormat, ExceedMaxRetryTimes
from promptflow.tools.openai import chat as openai_chat
from promptflow.tools.aoai_gpt4v import AzureOpenAI as AzureOpenAIVision
from pytest_mock import MockerFixture
@@ -115,8 +115,6 @@ def create_api_connection_error_with_cause():
create_api_connection_error_with_cause(),
InternalServerError("Something went wrong", response=httpx.Response(
503, request=httpx.Request('GET', 'https://www.example.com')), body=None),
-UnprocessableEntityError("Something went wrong", response=httpx.Response(
-    422, request=httpx.Request('GET', 'https://www.example.com')), body=None)
]
),
],
@@ -155,9 +153,6 @@ def test_retriable_openai_error_handle(self, mocker: MockerFixture, dummyExcepti
InternalServerError("Something went wrong", response=httpx.Response(
503, request=httpx.Request('GET', 'https://www.example.com'), headers={"retry-after": "0.3"}),
body=None),
-UnprocessableEntityError("Something went wrong", response=httpx.Response(
-    422, request=httpx.Request('GET', 'https://www.example.com'), headers={"retry-after": "0.3"}),
-    body=None)
]
),
],
@@ -188,6 +183,23 @@ def test_retriable_openai_error_handle_with_header(
]
mock_sleep.assert_has_calls(expected_calls)

+def test_unprocessable_entity_error(self, mocker: MockerFixture):
+    unprocessable_entity_error = UnprocessableEntityError(
+        "Something went wrong", response=httpx.Response(
+            422, request=httpx.Request('GET', 'https://www.example.com')), body=None)
+    rate_limit_error = RateLimitError("Something went wrong", response=httpx.Response(
+        429, request=httpx.Request('GET', 'https://www.example.com'), headers={"retry-after": "0.3"}),
+        body=None)
+    # For the exception sequence below, "consecutive_422_error_count" changes: 0 -> 1 -> 0 -> 1 -> 2.
+    exception_sequence = [
+        unprocessable_entity_error, rate_limit_error, unprocessable_entity_error, unprocessable_entity_error]
+    patched_test_method = mocker.patch("promptflow.tools.aoai.AzureOpenAI.chat", side_effect=exception_sequence)
+    # Limit the unprocessable entity (422) error retry threshold to 2.
+    decorated_test_method = handle_openai_error(unprocessable_entity_error_tries=2)(patched_test_method)
+    with pytest.raises(ExceedMaxRetryTimes):
+        decorated_test_method()
+    assert patched_test_method.call_count == 4

@pytest.mark.parametrize(
"dummyExceptionList",
[
@@ -197,8 +209,6 @@ def test_retriable_openai_error_handle_with_header(
body=None),
BadRequestError("Something went wrong", response=httpx.get('https://www.example.com'),
body=None),
-APIConnectionError(message="Something went wrong",
-    request=httpx.Request('GET', 'https://www.example.com')),
]
),
],
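The new test above drives the decorator with a 422, 429, 422, 422 sequence and expects exactly four calls before ExceedMaxRetryTimes. For callers, the extra keyword only matters when the default of 3 consecutive 422s is too lenient; a hypothetical usage (the `my_chat_call` function and the chosen values are illustrative, not from the repository) might look like:

from promptflow.tools.common import handle_openai_error


@handle_openai_error(tries=10, unprocessable_entity_error_tries=2)
def my_chat_call(client, **kwargs):
    # 429/5xx responses are retried with backoff; two consecutive 422 responses
    # (or exhausting `tries`) end the retry loop with ExceedMaxRetryTimes.
    return client.chat.completions.create(**kwargs)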
