diff --git a/modal/_runtime/container_io_manager.py b/modal/_runtime/container_io_manager.py
index d0a167cd4..9d4c9acbd 100644
--- a/modal/_runtime/container_io_manager.py
+++ b/modal/_runtime/container_io_manager.py
@@ -63,7 +63,9 @@ class IOContext:
     """

     input_ids: list[str]
+    retry_counts: list[int]
     function_call_ids: list[str]
+    function_inputs: list[api_pb2.FunctionInput]
     finalized_function: "modal._runtime.user_code_imports.FinalizedFunction"

     _cancel_issued: bool = False
@@ -72,6 +74,7 @@ def __init__(
         self,
         input_ids: list[str],
+        retry_counts: list[int],
         function_call_ids: list[str],
         finalized_function: "modal._runtime.user_code_imports.FinalizedFunction",
         function_inputs: list[api_pb2.FunctionInput],
@@ -79,9 +82,10 @@ def __init__(
         client: _Client,
     ):
         self.input_ids = input_ids
+        self.retry_counts = retry_counts
        self.function_call_ids = function_call_ids
        self.finalized_function = finalized_function
-        self._function_inputs = function_inputs
+        self.function_inputs = function_inputs
        self._is_batched = is_batched
        self._client = client
@@ -90,11 +94,11 @@ async def create(
         cls,
         client: _Client,
         finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
-        inputs: list[tuple[str, str, api_pb2.FunctionInput]],
+        inputs: list[tuple[str, int, str, api_pb2.FunctionInput]],
         is_batched: bool,
     ) -> "IOContext":
         assert len(inputs) >= 1 if is_batched else len(inputs) == 1
-        input_ids, function_call_ids, function_inputs = zip(*inputs)
+        input_ids, retry_counts, function_call_ids, function_inputs = zip(*inputs)

         async def _populate_input_blobs(client: _Client, input: api_pb2.FunctionInput) -> api_pb2.FunctionInput:
             # If we got a pointer to a blob, download it from S3.
@@ -111,7 +115,7 @@ async def _populate_input_blobs(client: _Client, input: api_pb2.FunctionInput) -
         method_name = function_inputs[0].method_name
         assert all(method_name == input.method_name for input in function_inputs)
         finalized_function = finalized_functions[method_name]
-        return cls(input_ids, function_call_ids, finalized_function, function_inputs, is_batched, client)
+        return cls(input_ids, retry_counts, function_call_ids, finalized_function, function_inputs, is_batched, client)

     def set_cancel_callback(self, cb: Callable[[], None]):
         self._cancel_callback = cb
@@ -135,7 +139,7 @@ def _args_and_kwargs(self) -> tuple[tuple[Any, ...], dict[str, list[Any]]]:
         # to make sure we handle user exceptions properly
         # and don't retry
         deserialized_args = [
-            deserialize(input.args, self._client) if input.args else ((), {}) for input in self._function_inputs
+            deserialize(input.args, self._client) if input.args else ((), {}) for input in self.function_inputs
         ]
         if not self._is_batched:
             return deserialized_args[0]
@@ -564,7 +568,7 @@ async def _generate_inputs(
         self,
         batch_max_size: int,
         batch_wait_ms: int,
-    ) -> AsyncIterator[list[tuple[str, str, api_pb2.FunctionInput]]]:
+    ) -> AsyncIterator[list[tuple[str, int, str, api_pb2.FunctionInput]]]:
         request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id)
         iteration = 0
         while self._fetching_inputs:
@@ -599,8 +603,7 @@ async def _generate_inputs(
                     if item.kill_switch:
                         logger.debug(f"Task {self.task_id} input kill signal input.")
                         return
-
-                    inputs.append((item.input_id, item.function_call_id, item.input))
+                    inputs.append((item.input_id, item.retry_count, item.function_call_id, item.input))
                     if item.input.final_input:
                         if request.batch_max_size > 0:
                             logger.debug(f"Task {self.task_id} Final input not expected in batch input stream")
@@ -661,8 +664,9 @@ async def _push_outputs(
                     output_created_at=output_created_at,
                     result=result,
                     data_format=data_format,
+                    retry_count=retry_count,
                 )
-                for input_id, result in zip(io_context.input_ids, results)
+                for input_id, retry_count, result in zip(io_context.input_ids, io_context.retry_counts, results)
             ]
         await retry_transient_errors(
             self._client.stub.FunctionPutOutputs,
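
Note: after this change, each input travels through the container runtime as a 4-tuple that carries its retry count, and `_push_outputs` echoes that count back on the corresponding output. A minimal sketch of the new shape; the literal IDs and counts below are hypothetical, not taken from the diff:

```python
# Hypothetical example of the (input_id, retry_count, function_call_id, input)
# tuples that _generate_inputs() now yields and IOContext.create() consumes.
from modal_proto import api_pb2

inputs = [
    ("in-001", 0, "fc-123", api_pb2.FunctionInput()),  # first attempt
    ("in-002", 2, "fc-123", api_pb2.FunctionInput()),  # already retried twice
]

# IOContext.create() unpacks the batch column-wise:
input_ids, retry_counts, function_call_ids, function_inputs = zip(*inputs)
assert retry_counts == (0, 2)
# _push_outputs() later pairs each result with its retry_count via
# zip(io_context.input_ids, io_context.retry_counts, results).
```
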
diff --git a/modal/_utils/async_utils.py b/modal/_utils/async_utils.py
index 6320fea78..6d2182900 100644
--- a/modal/_utils/async_utils.py
+++ b/modal/_utils/async_utils.py
@@ -12,6 +12,7 @@ from typing import (
     Any,
     Callable,
+    Generic,
     Optional,
     TypeVar,
     Union,
@@ -26,6 +27,10 @@ from ..exception import InvalidError
 from .logger import logger

+T = TypeVar("T")
+P = ParamSpec("P")
+V = TypeVar("V")
+
 synchronizer = synchronicity.Synchronizer()

@@ -260,7 +265,59 @@ def run_coro_blocking(coro):
     return fut.result()


-async def queue_batch_iterator(q: asyncio.Queue, max_batch_size=100, debounce_time=0.015):
+class TimestampPriorityQueue(Generic[T]):
+    """
+    A priority queue that schedules items to be processed at specific timestamps.
+    """
+
+    _MAX_PRIORITY = float("inf")
+
+    def __init__(self, maxsize: int = 0):
+        self.condition = asyncio.Condition()
+        self._queue: asyncio.PriorityQueue[tuple[float, Union[T, None]]] = asyncio.PriorityQueue(maxsize=maxsize)
+
+    async def close(self):
+        await self.put(self._MAX_PRIORITY, None)
+
+    async def put(self, timestamp: float, item: Union[T, None]):
+        """
+        Add an item to the queue to be processed at a specific timestamp.
+        """
+        await self._queue.put((timestamp, item))
+        async with self.condition:
+            self.condition.notify_all()  # notify any waiting coroutines
+
+    async def get(self) -> Union[T, None]:
+        """
+        Get the next item from the queue that is ready to be processed.
+        """
+        while True:
+            async with self.condition:
+                while self.empty():
+                    await self.condition.wait()
+                # peek at the next item
+                timestamp, item = await self._queue.get()
+                now = time.time()
+                if timestamp < now:
+                    return item
+                if timestamp == self._MAX_PRIORITY:
+                    return None
+                # not ready yet, calculate sleep time
+                sleep_time = timestamp - now
+                self._queue.put_nowait((timestamp, item))  # put it back
+                # wait until either the timeout or a new item is added
+                try:
+                    await asyncio.wait_for(self.condition.wait(), timeout=sleep_time)
+                except asyncio.TimeoutError:
+                    continue
+
+    def empty(self) -> bool:
+        return self._queue.empty()
+
+
+async def queue_batch_iterator(
+    q: Union[asyncio.Queue, TimestampPriorityQueue], max_batch_size=100, debounce_time=0.015
+):
     """
     Read from a queue but return lists of items when queue is large

@@ -401,11 +458,6 @@ async def wrapper():
         _shutdown_tasks.append(asyncio.create_task(wrapper()))

-T = TypeVar("T")
-P = ParamSpec("P")
-V = TypeVar("V")
-
-
 def asyncify(f: Callable[P, T]) -> Callable[P, typing.Coroutine[None, None, T]]:
     """Convert a blocking function into one that runs in the current loop's executor."""
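
The new `TimestampPriorityQueue` is the scheduling primitive behind client-side retries: producers enqueue `(timestamp, item)` pairs, and `get()` only releases an item once its timestamp has passed. A small, self-contained usage sketch, assuming the class behaves as added above:

```python
import asyncio
import time

from modal._utils.async_utils import TimestampPriorityQueue


async def main() -> None:
    queue: TimestampPriorityQueue[str] = TimestampPriorityQueue()
    now = time.time()

    # Insertion order doesn't matter; readiness is determined by timestamp.
    await queue.put(now + 0.2, "second")
    await queue.put(now + 0.1, "first")

    assert await queue.get() == "first"   # released after ~0.1s
    assert await queue.get() == "second"  # released after ~0.2s

    # close() enqueues an (inf, None) sentinel, so a draining consumer such as
    # queue_batch_iterator() sees None and can shut down cleanly.
    await queue.close()
    assert await queue.get() is None


asyncio.run(main())
```
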
+ """ + while True: + async with self.condition: + while self.empty(): + await self.condition.wait() + # peek at the next item + timestamp, item = await self._queue.get() + now = time.time() + if timestamp < now: + return item + if timestamp == self._MAX_PRIORITY: + return None + # not ready yet, calculate sleep time + sleep_time = timestamp - now + self._queue.put_nowait((timestamp, item)) # put it back + # wait until either the timeout or a new item is added + try: + await asyncio.wait_for(self.condition.wait(), timeout=sleep_time) + except asyncio.TimeoutError: + continue + + def empty(self) -> bool: + return self._queue.empty() + + +async def queue_batch_iterator( + q: Union[asyncio.Queue, TimestampPriorityQueue], max_batch_size=100, debounce_time=0.015 +): """ Read from a queue but return lists of items when queue is large @@ -401,11 +458,6 @@ async def wrapper(): _shutdown_tasks.append(asyncio.create_task(wrapper())) -T = TypeVar("T") -P = ParamSpec("P") -V = TypeVar("V") - - def asyncify(f: Callable[P, T]) -> Callable[P, typing.Coroutine[None, None, T]]: """Convert a blocking function into one that runs in the current loop's executor.""" diff --git a/modal/functions.py b/modal/functions.py index 88c13d5de..fa2c98d71 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -1,4 +1,5 @@ # Copyright Modal Labs 2023 +import asyncio import dataclasses import inspect import textwrap @@ -256,7 +257,10 @@ async def run_function(self) -> Any: try: return await self._get_single_output(ctx.input_jwt) except (UserCodeException, FunctionTimeoutError) as exc: - await user_retry_manager.raise_or_sleep(exc) + delay_ms = user_retry_manager.get_delay_ms() + if delay_ms is None: + raise exc + await asyncio.sleep(delay_ms / 1000) except InternalFailure: # For system failures on the server, we retry immediately. pass @@ -1253,6 +1257,11 @@ async def _map( else: count_update_callback = None + if config.get("client_retries"): + function_call_invocation_type = api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC + else: + function_call_invocation_type = api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY + async with aclosing( _map_invocation( self, # type: ignore @@ -1261,6 +1270,7 @@ async def _map( order_outputs, return_exceptions, count_update_callback, + function_call_invocation_type, ) ) as stream: async for item in stream: diff --git a/modal/parallel_map.py b/modal/parallel_map.py index 11a943bfa..5cb938fbc 100644 --- a/modal/parallel_map.py +++ b/modal/parallel_map.py @@ -10,6 +10,7 @@ from modal._runtime.execution_context import current_input_id from modal._utils.async_utils import ( AsyncOrSyncIterable, + TimestampPriorityQueue, aclosing, async_map_ordered, async_merge, @@ -29,6 +30,7 @@ ) from modal._utils.grpc_utils import retry_transient_errors from modal.config import logger +from modal.retries import RetryManager from modal_proto import api_pb2 if typing.TYPE_CHECKING: @@ -62,6 +64,21 @@ class _OutputValue: value: Any +@dataclass +class _MapItemContext: + function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType" + input: api_pb2.FunctionInput + input_id: str + input_jwt: str + retry_manager: RetryManager + + +# maximum number of inputs that can be in progress (either queued to be sent, +# or waiting for completion). if this limit is reached, we will block sending +# more inputs to the server until some of the existing inputs are completed. 
diff --git a/modal/parallel_map.py b/modal/parallel_map.py
index 11a943bfa..5cb938fbc 100644
--- a/modal/parallel_map.py
+++ b/modal/parallel_map.py
@@ -10,6 +10,7 @@ from modal._runtime.execution_context import current_input_id
 from modal._utils.async_utils import (
     AsyncOrSyncIterable,
+    TimestampPriorityQueue,
     aclosing,
     async_map_ordered,
     async_merge,
@@ -29,6 +30,7 @@
 )
 from modal._utils.grpc_utils import retry_transient_errors
 from modal.config import logger
+from modal.retries import RetryManager
 from modal_proto import api_pb2

 if typing.TYPE_CHECKING:
@@ -62,6 +64,21 @@ class _OutputValue:
     value: Any


+@dataclass
+class _MapItemContext:
+    function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType"
+    input: api_pb2.FunctionInput
+    input_id: str
+    input_jwt: str
+    retry_manager: RetryManager
+
+
+# Maximum number of inputs that can be in progress (either queued to be sent,
+# or waiting for completion). If this limit is reached, we will block sending
+# more inputs to the server until some of the existing inputs are completed.
+MAP_MAX_INPUTS_OUTSTANDING = 1000
+
+# Maximum number of inputs to send to the server in a single request.
 MAP_INVOCATION_CHUNK_SIZE = 49

 if typing.TYPE_CHECKING:
@@ -75,6 +92,7 @@ async def _map_invocation(
     order_outputs: bool,
     return_exceptions: bool,
     count_update_callback: Optional[Callable[[int, int], None]],
+    function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType",
 ):
     assert client.stub
     request = api_pb2.FunctionMapRequest(
@@ -82,10 +100,13 @@ async def _map_invocation(
         parent_input_id=current_input_id() or "",
         function_call_type=api_pb2.FUNCTION_CALL_TYPE_MAP,
         return_exceptions=return_exceptions,
+        function_call_invocation_type=function_call_invocation_type,
     )
     response = await retry_transient_errors(client.stub.FunctionMap, request)
     function_call_id = response.function_call_id
+    function_call_jwt = response.function_call_jwt
+    retry_policy = response.retry_policy

     have_all_inputs = False
     num_inputs = 0
@@ -95,10 +116,13 @@ def count_update():
         if count_update_callback is not None:
             count_update_callback(num_outputs, num_inputs)

-    pending_outputs: dict[str, int] = {}  # Map input_id -> next expected gen_index value
+    retry_queue = TimestampPriorityQueue()
+    pending_outputs: dict[int, asyncio.Future[_MapItemContext]] = {}  # Map input idx -> context
     completed_outputs: set[str] = set()  # Set of input_ids whose outputs are complete (expecting no more values)
+    input_queue: asyncio.Queue[api_pb2.FunctionPutInputsItem | None] = asyncio.Queue()

-    input_queue: asyncio.Queue = asyncio.Queue()
+    # semaphore to limit the number of inputs that can be in progress at once
+    inputs_outstanding = asyncio.BoundedSemaphore(MAP_MAX_INPUTS_OUTSTANDING)

     async def create_input(argskwargs):
         nonlocal num_inputs
@@ -115,6 +139,8 @@ async def input_iter():
             yield raw_input  # args, kwargs

     async def drain_input_generator():
+        nonlocal have_all_inputs
+
         # Parallelize uploading blobs
         async with aclosing(
             async_map_ordered(input_iter(), create_input, concurrency=BLOB_MAX_PARALLELISM)
@@ -122,26 +148,40 @@ async def drain_input_generator():
         ) as streamer:
             async for item in streamer:
                 await input_queue.put(item)
-        # close queue iterator
-        await input_queue.put(None)
+        have_all_inputs = True
         yield

     async def pump_inputs():
         assert client.stub
         nonlocal have_all_inputs, num_inputs
-        async for items in queue_batch_iterator(input_queue, MAP_INVOCATION_CHUNK_SIZE):
+        async for items in queue_batch_iterator(input_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
+            event_loop = asyncio.get_event_loop()
+            for item in items:
+                # acquire semaphore to limit the number of inputs in progress
+                # (either queued to be sent, waiting for completion, or retrying)
+                await inputs_outstanding.acquire()
+
+                # Create a future for each input, to be resolved when we have
+                # received the input ID and JWT from the server. This addresses
+                # a race condition where we could receive outputs before we have
+                # recorded the input ID and JWT in `pending_outputs`.
+                pending_outputs[item.idx] = event_loop.create_future()
+
             request = api_pb2.FunctionPutInputsRequest(
-                function_id=function.object_id, inputs=items, function_call_id=function_call_id
+                function_id=function.object_id,
+                inputs=items,
+                function_call_id=function_call_id,
             )
             logger.debug(
                 f"Pushing {len(items)} inputs to server. Num queued inputs awaiting push is {input_queue.qsize()}."
             )
+
             while True:
                 try:
                     resp = await retry_transient_errors(
                         client.stub.FunctionPutInputs,
                         request,
-                        # with 8 retries we log the warning below about every 30 secondswhich isn't too spammy.
+                        # with 8 retries we log the warning below about every 30 seconds, which isn't too spammy.
                         max_retries=8,
                         max_delay=15,
                         additional_status_codes=[Status.RESOURCE_EXHAUSTED],
@@ -151,26 +191,80 @@ async def pump_inputs():
                     if err.status != Status.RESOURCE_EXHAUSTED:
                         raise err
                     logger.warning(
-                        f"Warning: map progress for function {function._function_name} is limited."
-                        " Common bottlenecks include slow iteration over results, or function backlogs."
+                        "Warning: map progress is limited. Common bottlenecks "
+                        "include slow iteration over results, or function backlogs."
                     )

             count_update()
-            for item in resp.inputs:
-                pending_outputs.setdefault(item.input_id, 0)
+
+            items_by_idx = {item.idx: item for item in items}
+            for response_item in resp.inputs:
+                original_item = items_by_idx[response_item.idx]
+                pending_outputs[response_item.idx].set_result(
+                    _MapItemContext(
+                        function_call_invocation_type=function_call_invocation_type,
+                        input=original_item.input,
+                        input_id=response_item.input_id,
+                        input_jwt=response_item.input_jwt,
+                        retry_manager=RetryManager(retry_policy),
+                    )
+                )
+
             logger.debug(
                 f"Successfully pushed {len(items)} inputs to server. "
                 f"Num queued inputs awaiting push is {input_queue.qsize()}."
             )
-        have_all_inputs = True
+
+        yield
+
+    async def retry_inputs():
+        async for retriable_idxs in queue_batch_iterator(retry_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
+            inputs = []
+            for retriable_idx in retriable_idxs:
+                item_context = await pending_outputs[retriable_idx]
+                inputs.append(
+                    api_pb2.FunctionRetryInputsItem(
+                        input_jwt=item_context.input_jwt,
+                        input=item_context.input,
+                        retry_count=item_context.retry_manager.attempt_count,
+                    )
+                )
+
+            request = api_pb2.FunctionRetryInputsRequest(
+                function_call_jwt=function_call_jwt,
+                inputs=inputs,
+            )
+
+            while True:
+                try:
+                    await retry_transient_errors(
+                        client.stub.FunctionRetryInputs,
+                        request,
+                        # with 8 retries we log the warning below about every 30 seconds, which isn't too spammy.
+                        max_retries=8,
+                        max_delay=15,
+                        additional_status_codes=[Status.RESOURCE_EXHAUSTED],
+                    )
+                    break
+                except GRPCError as err:
+                    if err.status != Status.RESOURCE_EXHAUSTED:
+                        raise err
+                    logger.warning(
+                        f"Warning: map progress for function {function._function_name} is limited."
+                        " Common bottlenecks include slow iteration over results, or function backlogs."
+                    )
+
+            logger.debug(f"Successfully pushed retry for {len(inputs)} inputs to server.")
         yield

     async def get_all_outputs():
         assert client.stub
         nonlocal num_inputs, num_outputs, have_all_inputs
         last_entry_id = "0-0"
-        while not have_all_inputs or len(pending_outputs) > len(completed_outputs):
+
+        while not have_all_inputs or num_outputs < num_inputs:
+            logger.debug(f"Requesting outputs. Have {num_outputs} outputs, {num_inputs} inputs.")
+
             request = api_pb2.FunctionGetOutputsRequest(
                 function_call_id=function_call_id,
                 timeout=OUTPUTS_TIMEOUT,
@@ -186,16 +280,43 @@ async def get_all_outputs():
             )
             if len(response.outputs) == 0:
+                logger.debug("No outputs received.")
                 continue
+            else:
+                logger.debug(f"Received {len(response.outputs)} outputs.")

             last_entry_id = response.last_entry_id
+            now_seconds = int(time.time())
             for item in response.outputs:
-                pending_outputs.setdefault(item.input_id, 0)
                 if item.input_id in completed_outputs:
                     # If this input is already completed, it means the output has already been
                     # processed and was received again due to a duplicate.
                    continue
+
+                future = pending_outputs.get(item.idx, None)
+                if future is None:
+                    # We've already processed this output, so we can skip it.
+                    # This can happen because the worker can sometimes send duplicate outputs.
+                    continue
+
+                item_context = await future
+
+                if item.result and item.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
+                    # clear the item context to allow it to be garbage collected
+                    del pending_outputs[item.idx]
+                else:
+                    # retry failed inputs when the function call invocation type is SYNC
+                    if item_context.function_call_invocation_type == api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC:
+                        delay_ms = item_context.retry_manager.get_delay_ms()
+
+                        if delay_ms is not None:
+                            await retry_queue.put(now_seconds + (delay_ms / 1000), item.idx)
+                            continue
+                        else:
+                            # we're out of retries, so we'll just output the error
+                            pass
+
                 completed_outputs.add(item.input_id)
+                inputs_outstanding.release()
                 num_outputs += 1
                 yield item
@@ -216,6 +337,10 @@ async def get_all_outputs_and_clean_up():
             )
             await retry_transient_errors(client.stub.FunctionGetOutputs, request)

+        # close the input queue iterator
+        await input_queue.put(None)
+        await retry_queue.close()
+
     async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> tuple[int, Any]:
         try:
             output = await _process_result(item.result, item.data_format, client.stub, client)
@@ -241,14 +366,23 @@ async def poll_outputs():
             else:
                 # hold on to outputs for function maps, so we can reorder them correctly.
                 received_outputs[idx] = output
-                while output_idx in received_outputs:
+
+                while True:
+                    if output_idx not in received_outputs:
+                        # We haven't received the output for the current index yet.
+                        # Stop returning outputs to the caller and instead wait for
+                        # the next output to arrive from the server.
+                        break
+
                     output = received_outputs.pop(output_idx)
                     yield _OutputValue(output)
                     output_idx += 1

         assert len(received_outputs) == 0

-    async with aclosing(async_merge(drain_input_generator(), pump_inputs(), poll_outputs())) as streamer:
+    async with aclosing(
+        async_merge(drain_input_generator(), pump_inputs(), poll_outputs(), retry_inputs())
+    ) as streamer:
        async for response in streamer:
            if response is not None:
                yield response.value
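
To summarize the map retry loop above: `pump_inputs` registers a future per input index before sending, `get_all_outputs` resolves successes and schedules failures on `retry_queue` at `now + delay`, and `retry_inputs` drains due indexes into `FunctionRetryInputs` batches. A sketch of the scheduling step in isolation; in the real flow the `RetryManager` lives in the item's `_MapItemContext` so attempts accumulate, and the retry-policy field names are assumed from modal_proto's `FunctionRetryPolicy` with illustrative values:

```python
import time

from modal._utils.async_utils import TimestampPriorityQueue
from modal.retries import RetryManager
from modal_proto import api_pb2

# Illustrative policy; field values here are not taken from the diff.
example_policy = api_pb2.FunctionRetryPolicy(
    retries=3, backoff_coefficient=2.0, initial_delay_ms=1000, max_delay_ms=8000
)


async def schedule_retry(
    retry_queue: TimestampPriorityQueue, retry_manager: RetryManager, idx: int
) -> bool:
    # Mirrors the failure branch in get_all_outputs(): returns False when the
    # retry budget is exhausted (the error is surfaced to the caller),
    # otherwise parks the input's idx until its backoff elapses.
    delay_ms = retry_manager.get_delay_ms()
    if delay_ms is None:
        return False
    await retry_queue.put(time.time() + delay_ms / 1000, idx)
    return True
```
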
""" self.attempt_count += 1 + if self.attempt_count > self.retry_policy.retries: - raise exc - delay_ms = self._retry_delay_ms(self.attempt_count, self.retry_policy) - await asyncio.sleep(delay_ms / 1000) + return None + + return self._retry_delay_ms(self.attempt_count, self.retry_policy) @staticmethod def _retry_delay_ms(attempt_count: int, retry_policy: api_pb2.FunctionRetryPolicy) -> float: diff --git a/test/async_utils_test.py b/test/async_utils_test.py index 1f3881b13..498ef7b0e 100644 --- a/test/async_utils_test.py +++ b/test/async_utils_test.py @@ -8,6 +8,7 @@ import subprocess import sys import textwrap +import time from test import helpers import pytest_asyncio @@ -16,6 +17,7 @@ from modal._utils import async_utils from modal._utils.async_utils import ( TaskContext, + TimestampPriorityQueue, aclosing, async_chain, async_map, @@ -1307,3 +1309,50 @@ def line(): assert p.wait() == 0 assert p.stdout.read() == "" assert p.stderr.read() == "" + + +@pytest.mark.asyncio +async def test_timed_priority_queue(): + queue: TimestampPriorityQueue = TimestampPriorityQueue() + now = time.time() + + async def producer(): + await queue.put(now + 0.2, 2) + await queue.put(now + 0.1, 1) + await queue.put(now + 0.3, 3) + + async def consumer(): + items = [] + for _ in range(3): + item = await queue.get() + items.append(item) + return items + + await producer() + items = await consumer() + assert items == [1, 2, 3] + + +@pytest.mark.asyncio +async def test_timed_priority_queue_duplicates(): + class _QueueItem: + pass + + queue: TimestampPriorityQueue = async_utils.TimestampPriorityQueue() + now = time.time() + x = now + 0.1 + + async def producer(): + await queue.put(1, x) + await queue.put(1, x) + + async def consumer(): + items = [] + for _ in range(2): + item = await queue.get() + items.append(item) + return items + + await producer() + items = await consumer() + assert len([it for it in items]) == 2 diff --git a/test/function_retry_test.py b/test/function_retry_test.py index 3df909611..826f33863 100644 --- a/test/function_retry_test.py +++ b/test/function_retry_test.py @@ -26,14 +26,14 @@ def __init__(self, function_call_count): self.function_call_count = function_call_count -def counting_function(attempt_to_return_success: int): +def counting_function(return_success_on_retry_count: int): """ A function that updates the global function_call_count counter each time it is called. 
""" global function_call_count function_call_count += 1 - if function_call_count < attempt_to_return_success: + if function_call_count < return_success_on_retry_count: raise FunctionCallCountException(function_call_count) return function_call_count @@ -78,7 +78,7 @@ def test_no_retries_when_first_call_succeeds(client, setup_app_and_function, mon assert function_call_count == 1 -def test_retry_dealy_ms(): +def test_retry_delay_ms(): with pytest.raises(ValueError): RetryManager._retry_delay_ms(0, api_pb2.FunctionRetryPolicy()) @@ -99,3 +99,37 @@ def test_lost_inputs_retried(client, setup_app_and_function, monkeypatch, servic f.remote(10) # Assert the function was called 10 times assert function_call_count == 10 + + +def test_map_fails_immediately_without_retries(client, setup_app_and_function, monkeypatch): + monkeypatch.setenv("MODAL_CLIENT_RETRIES", "false") + app, f = setup_app_and_function + with app.run(client=client): + with pytest.raises(FunctionCallCountException) as exc_info: + list(f.map([999, 999, 999])) + assert exc_info.value.function_call_count == 1 + + +def test_map_all_retries_fail_raises_error(client, setup_app_and_function, monkeypatch): + monkeypatch.setenv("MODAL_CLIENT_RETRIES", "true") + app, f = setup_app_and_function + with app.run(client=client): + with pytest.raises(FunctionCallCountException) as exc_info: + list(f.map([999])) + assert exc_info.value.function_call_count == 4 + + +def test_map_failures_followed_by_success(client, setup_app_and_function, monkeypatch): + monkeypatch.setenv("MODAL_CLIENT_RETRIES", "true") + app, f = setup_app_and_function + with app.run(client=client): + results = list(f.map([3, 3, 3])) + assert set(results) == {3, 4, 5} + + +def test_map_no_retries_when_first_call_succeeds(client, setup_app_and_function, monkeypatch): + monkeypatch.setenv("MODAL_CLIENT_RETRIES", "true") + app, f = setup_app_and_function + with app.run(client=client): + results = list(f.map([1, 1, 1])) + assert set(results) == {1, 2, 3}