Commit

Merge pull request #155 from wilhelm-lab/chore/improve_koina_error_handling

Improve Koina client error handling.
picciama authored Dec 7, 2023
2 parents f4eebb7 + f3407a8 commit c6c45f3
Showing 1 changed file with 75 additions and 29 deletions.
104 changes: 75 additions & 29 deletions oktoberfest/predict/koina.py
@@ -1,4 +1,5 @@
import time
import warnings
from functools import partial
from typing import Dict, Generator, KeysView, List, Optional, Union

@@ -20,6 +21,7 @@ class Koina:
model_inputs: Dict[str, str]
model_outputs: Dict[str, np.ndarray]
batch_size: int
_response_dict: Dict[int, Union[InferResult, InferenceServerException]]

def __init__(
self,
@@ -45,7 +47,7 @@ def __init__(
"""
self.model_inputs = {}
self.model_outputs = {}
# self.batchsize = No
self._response_dict = {}

self.model_name = model_name
self.url = server_url
@@ -67,6 +69,11 @@ def __init__(
self.__get_outputs(targets)
self.__get_batchsize()

@property
def response_dict(self):
"""The dictionary containing raw InferenceResult/InferenceServerException objects (values) for a given request_id (key)."""
return self._response_dict
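
For illustration, a minimal usage sketch of the new property after a debug run. This is not part of the commit; it assumes a reachable Koina server, that the class is importable as oktoberfest.predict.koina.Koina, and the Prosit_2019_intensity input names from the docstring example further down:

    import numpy as np
    from tritonclient.grpc import InferenceServerException

    from oktoberfest.predict.koina import Koina

    model = Koina("Prosit_2019_intensity")
    data = {
        "peptide_sequences": np.array(["PEPTIDEK", "AAAAAKR"]),
        "precursor_charges": np.array([2, 3]),
        "collision_energies": np.array([25, 30]),
    }
    model.predict(data, debug=True)

    # response_dict maps each integer request_id to the raw InferResult,
    # or to an InferenceServerException if that batch ultimately failed.
    failed = {i: r for i, r in model.response_dict.items() if isinstance(r, InferenceServerException)}
    for request_id, err in failed.items():
        print(f"Batch {request_id} failed: {err.message()}")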


def _is_server_ready(self):
"""
Check if the inference server is live and accessible.
@@ -350,7 +357,13 @@ def __merge_list_dict_array(dict_list: List[Dict[str, np.ndarray]]) -> Dict[str,
out[k] = np.concatenate([x[k] for x in dict_list])
return out

def __async_callback(self, infer_results: List[InferResult], result: InferResult, error):
def __async_callback(
self,
infer_results: Dict[int, Union[InferResult, InferenceServerException]],
request_id: int,
result: InferResult,
error: Optional[InferenceServerException],
):
"""
Callback function for asynchronous inference.
@@ -359,18 +372,22 @@ def __async_callback(self, infer_results: List[InferResult], result: InferResult
encountered error is checked and handled appropriately. Note: This method is for internal use and is typically
called during asynchronous inference.
:param infer_results: A list to which the results of asynchronous inference will be appended.
:param infer_results: A dictionary to which the results of asynchronous inference will be added.
:param request_id: The request id used as the key in the infer_results dictionary.
:param result: The result of an asynchronous inference operation.
:param error: An error, if any, encountered during asynchronous inference.
:raises error: if any exception was encountered during asynchronous inference.
"""
if error:
raise error
infer_results[request_id] = error

else:
infer_results.append(result)
infer_results[request_id] = result
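
The callback receives only (result, error) from the Triton client; the shared dictionary and the request id are bound in advance via functools.partial (see __async_predict_batch below). A stripped-down sketch of that dispatch pattern, with illustrative names and no server involved:

    from functools import partial

    def store_response(results: dict, request_id: int, result, error):
        # Keep the error if one occurred, otherwise keep the successful result.
        results[request_id] = error if error is not None else result

    results: dict = {}
    on_done = partial(store_response, results, 7)  # what async_infer would receive as `callback`
    on_done("fake InferResult", None)              # simulates the client firing the callback on success
    assert results[7] == "fake InferResult"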

def __async_predict_batch(
self, data: Dict[str, np.ndarray], infer_results: List[InferResult], request_id: int, timeout: int = 10
self,
data: Dict[str, np.ndarray],
infer_results: Dict[int, Union[InferResult, InferenceServerException]],
request_id: int,
timeout: int = 60000,
):
"""
Perform asynchronous batch inference on the given data using the Koina model.
@@ -381,24 +398,36 @@ def __async_predict_batch(
'timeout' is reached.
:param data: A dictionary containing input data for batch inference. Keys are input names, and values are numpy arrays.
:param infer_results: A list to which the results of asynchronous inference will be appended.
:param infer_results: A dictionary to which the results of asynchronous inference will be added.
:param request_id: An identifier for the inference request, used to track the order of completion.
:param timeout: The maximum time (in seconds) to wait for the inference to complete. Defaults to 60000.
"""
batch_outputs = self.__get_batch_outputs(self.model_outputs.keys())
batch_inputs = self.__get_batch_inputs(data)

self.client.async_infer(
model_name=self.model_name,
request_id=str(request_id),
inputs=batch_inputs,
callback=partial(self.__async_callback, infer_results),
outputs=batch_outputs,
client_timeout=timeout,
)
max_requests = 3

for _ in range(max_requests):
self.client.async_infer(
model_name=self.model_name,
request_id=str(request_id),
inputs=batch_inputs,
callback=partial(self.__async_callback, infer_results, request_id),
outputs=batch_outputs,
client_timeout=timeout,
)
while infer_results.get(request_id) is None:
time.sleep(0.1)
if isinstance(infer_results.get(request_id), InferResult):
break
del infer_results[request_id]
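
The loop above resubmits a batch up to three times, polling the shared dictionary until the callback has stored a response for this request id and keeping only a successful InferResult. A standalone sketch of the same poll-and-retry pattern with a stubbed server call (all names illustrative):

    import time

    def fake_async_infer(results, request_id, succeed):
        # Stand-in for client.async_infer: the "server" fills the dict via the callback.
        results[request_id] = "result" if succeed else RuntimeError("server error")

    infer_results = {}
    max_requests = 3
    for attempt in range(max_requests):
        fake_async_infer(infer_results, 0, succeed=(attempt == 2))  # succeeds on the third try
        while infer_results.get(0) is None:  # wait for the callback to fire
            time.sleep(0.1)
        if not isinstance(infer_results[0], Exception):
            break  # keep the successful result
        del infer_results[0]  # discard the failed attempt and retry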


def predict(
self, data: Union[Dict[str, np.ndarray], pd.DataFrame], disable_progress_bar: bool = False, _async: bool = True
self,
data: Union[Dict[str, np.ndarray], pd.DataFrame],
disable_progress_bar: bool = False,
_async: bool = True,
debug=False,
) -> Dict[str, np.ndarray]:
"""
Perform inference on the given data using the Koina model.
@@ -414,12 +443,13 @@ def predict(
in the column names.
:param disable_progress_bar: If True, disable the progress bar during inference. Defaults to False.
:param _async: If True, perform asynchronous inference; if False, perform sequential inference. Defaults to True.
:param debug: If True and using _async mode, store the raw InferResult / InferenceServerException dictionary for later analysis.
:return: A dictionary containing the model's predictions. Keys are output names, and values are numpy arrays
representing the model's output.
Example::
model = KoinaModel("Prosit_2019_intensity")
model = Koina("Prosit_2019_intensity")
input_data = {
"peptide_sequences": np.array(["PEPTIDEK" for _ in range(size)]),
"precursor_charges": np.array([2 for _ in range(size)]),
@@ -432,12 +462,13 @@
if isinstance(data, pd.DataFrame):
data = {input_field: data[input_field].to_numpy() for input_field in self.model_inputs.keys()}
if _async:
pred_func = self.__predict_async
return self.__predict_async(data, disable_progress_bar=disable_progress_bar, debug=debug)
else:
pred_func = self.__predict_sequential
return pred_func(data, disable_progress_bar=disable_progress_bar)
return self.__predict_sequential(data, disable_progress_bar=disable_progress_bar)
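
As described in the docstring, predict() also accepts a pandas DataFrame whose column names match the model's input names. A hedged sketch of that path, reusing the model and the input names assumed in the earlier example:

    import pandas as pd

    df = pd.DataFrame(
        {
            "peptide_sequences": ["PEPTIDEK", "AAAAAKR"],
            "precursor_charges": [2, 3],
            "collision_energies": [25, 30],
        }
    )
    # Matching columns are converted to numpy arrays internally before inference.
    predictions = model.predict(df, disable_progress_bar=True)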


def __predict_async(self, data: Dict[str, np.ndarray], disable_progress_bar: bool = False) -> Dict[str, np.ndarray]:
def __predict_async(
self, data: Dict[str, np.ndarray], disable_progress_bar: bool = False, debug=False
) -> Dict[str, np.ndarray]:
"""
Perform asynchronous inference on the given data using the Koina model.
@@ -448,11 +479,13 @@ def __predict_async(self, data: Dict[str, np.ndarray], disable_progress_bar: boo
:param data: A dictionary containing input data for inference. Keys are input names, and values are numpy arrays.
:param disable_progress_bar: If True, disable the progress bar during asynchronous inference. Defaults to False.
:param debug: If True, store the raw InferResult / InferenceServerException dictionary for later analysis.
:raises InferenceServerException: If at least one batch of predictions could not be inferred.
:return: A dictionary containing the model's predictions. Keys are output names, and values are numpy arrays
representing the model's output.
"""
infer_results: List[InferResult] = []
infer_results: Dict[int, Union[InferResult, InferenceServerException]] = {}
for i, data_batch in enumerate(self.__slice_dict(data, self.batchsize)):
self.__async_predict_batch(data_batch, infer_results, request_id=i)

@@ -464,10 +497,23 @@ def __predict_async(self, data: Dict[str, np.ndarray], disable_progress_bar: boo
pbar.n = len(infer_results)
pbar.refresh()

# sort according to request id
infer_results_to_return = [
self.__extract_predictions(infer_results[i])
for i in np.argsort(np.array([int(y.get_response("id")["id"]) for y in infer_results]))
]
if debug:
self._response_dict = infer_results


return self.__merge_list_dict_array(infer_results_to_return)
try:
# sort according to request id
infer_results_to_return = [
self.__extract_predictions(infer_results[i]) for i in np.argsort(list(infer_results.keys()))
]
return self.__merge_list_dict_array(infer_results_to_return)
except AttributeError:

for res in infer_results.values():
if isinstance(res, InferenceServerException):
warnings.warn(res.message(), stacklevel=1)

else:
raise InferenceServerException(

"""
At least one request failed. Check the error message above and try again.
To get a list of responses run koina.predict(..., debug = True), then call koina.response_dict
"""
) from None
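
From the caller's side, the new failure behaviour could be handled roughly as follows. This is a sketch under the assumption that model and data are set up as in the earlier example; because _response_dict is populated before the exception is raised when debug=True, the raw responses remain available in the except block:

    from tritonclient.grpc import InferenceServerException

    try:
        predictions = model.predict(data, debug=True)
    except InferenceServerException:
        # Inspect which batches failed and why, then re-raise.
        for request_id, response in model.response_dict.items():
            if isinstance(response, InferenceServerException):
                print(f"Batch {request_id}: {response.message()}")
        raise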
