Skip to content

Commit

Permalink
Prototype data-out stream RPC
Browse files Browse the repository at this point in the history
  • Loading branch information
gongy committed Nov 20, 2024
1 parent 967e74c commit f0cd0b4
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 16 deletions.
2 changes: 1 addition & 1 deletion modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ async def run_input_async(io_context: IOContext) -> None:
# Send up to this many outputs at a time.
generator_queue: asyncio.Queue[Any] = await container_io_manager._queue_create.aio(1024)
generator_output_task = asyncio.create_task(
container_io_manager.generator_output_task.aio(
container_io_manager.generator_output_task_new.aio(
function_call_ids[0],
io_context.finalized_function.data_format,
generator_queue,
Expand Down
76 changes: 76 additions & 0 deletions modal/_runtime/container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,74 @@ async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]:
async for data in _stream_function_call_data(self._client, function_call_id, "data_in"):
yield data

async def put_data_out_request(
self,
function_call_id: str,
start_index: int,
data_format: int,
messages_bytes: List[Any],
) -> None:
"""Put data onto the `data_out` stream of a function call.
This is used for generator outputs, which includes web endpoint responses. Note that this
was introduced as a performance optimization in client version 0.57, so older clients will
still use the previous Postgres-backed system based on `FunctionPutOutputs()`.
"""
data_chunks: List[api_pb2.DataChunk] = []
for i, message_bytes in enumerate(messages_bytes):
chunk = api_pb2.DataChunk(data_format=data_format, index=start_index + i) # type: ignore
if len(message_bytes) > MAX_OBJECT_SIZE_BYTES:
chunk.data_blob_id = await blob_upload(message_bytes, self._client.stub)
else:
chunk.data = message_bytes
data_chunks.append(chunk)

req = api_pb2.FunctionCallPutDataRequest(function_call_id=function_call_id, data_chunks=data_chunks)
return req

async def generator_output_task_new(
self, function_call_id: str, data_format: int, message_rx: asyncio.Queue
) -> None:
"""Task that feeds generator outputs into a function call's `data_out` stream."""

async def request_stream():
index = 1
received_sentinel = False

t_start_push = 0
while not received_sentinel:
message = await message_rx.get()
if message is self._GENERATOR_STOP_SENTINEL:
break
# ASGI 'http.response.start' and 'http.response.body' msgs are observed to be separated by 1ms.
# If we don't sleep here for 1ms we end up with an extra call to .put_data_out().
if index == 1:
await asyncio.sleep(0.001)
messages_bytes = [serialize_data_format(message, data_format)]
total_size = len(messages_bytes[0]) + 512

if t_start_push == 0 and total_size > 1024:
t_start_push = time.time()

while total_size < 16 * 1024 * 1024: # 16 MiB, maximum size in a single message
try:
message = message_rx.get_nowait()
except asyncio.QueueEmpty:
break
if message is self._GENERATOR_STOP_SENTINEL:
received_sentinel = True
break
else:
messages_bytes.append(serialize_data_format(message, data_format))
total_size += len(messages_bytes[-1]) + 512 # 512 bytes for estimated framing overhead

req = await self.put_data_out_request(function_call_id, index, data_format, messages_bytes)
yield req
index += len(messages_bytes)
print(f"pushed {index} chunks after {(time.time() - t_start_push) * 1000:.0f} ms")

await self._client.stub.FunctionCallPutDataOutStreaming(request_stream())

async def put_data_out(
self,
function_call_id: str,
Expand Down Expand Up @@ -547,6 +615,7 @@ async def generator_output_task(self, function_call_id: str, data_format: int, m
"""Task that feeds generator outputs into a function call's `data_out` stream."""
index = 1
received_sentinel = False
t_start_push = 0
while not received_sentinel:
message = await message_rx.get()
if message is self._GENERATOR_STOP_SENTINEL:
Expand All @@ -557,6 +626,10 @@ async def generator_output_task(self, function_call_id: str, data_format: int, m
await asyncio.sleep(0.001)
messages_bytes = [serialize_data_format(message, data_format)]
total_size = len(messages_bytes[0]) + 512

if t_start_push == 0 and total_size > 1024:
t_start_push = time.time()

while total_size < 16 * 1024 * 1024: # 16 MiB, maximum size in a single message
try:
message = message_rx.get_nowait()
Expand All @@ -568,9 +641,12 @@ async def generator_output_task(self, function_call_id: str, data_format: int, m
else:
messages_bytes.append(serialize_data_format(message, data_format))
total_size += len(messages_bytes[-1]) + 512 # 512 bytes for estimated framing overhead

await self.put_data_out(function_call_id, index, data_format, messages_bytes)
index += len(messages_bytes)

print(f"pushed {index} chunks after {(time.time() - t_start_push) * 1000:.0f} ms")

async def _queue_create(self, size: int) -> asyncio.Queue:
"""Create a queue, on the synchronicity event loop (needed on Python 3.8 and 3.9)."""
return asyncio.Queue(size)
Expand Down
47 changes: 36 additions & 11 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,20 +374,45 @@ async def _stream_function_call_data(
else:
raise ValueError(f"Invalid variant {variant}")

import time

twait = 0.0
print(f"waiting {twait}s for data channel to be saturated")
await asyncio.sleep(twait)

while True:
req = api_pb2.FunctionCallGetDataRequest(function_call_id=function_call_id, last_index=last_index)
print("make req for data")
try:
async for chunk in stub_fn.unary_stream(req):
if chunk.index <= last_index:
continue
if chunk.data_blob_id:
message_bytes = await blob_download(chunk.data_blob_id, client.stub)
else:
message_bytes = chunk.data
message = deserialize_data_format(message_bytes, chunk.data_format, client)

last_index = chunk.index
yield message
grpclib_method = await client._get_grpclib_method(stub_fn._wrapped_method_name)
async with grpclib_method.open() as stream:
# async for chunk in stub_fn.unary_stream(req):
await stream.send_message(req, end=True)
begin = time.time()
t0 = time.time()
intervals = []
async for chunk in stream:
if chunk.index <= last_index:
continue
if chunk.data_blob_id:
assert False
message_bytes = await blob_download(chunk.data_blob_id, client.stub)
else:
message_bytes = chunk.data

message = deserialize_data_format(message_bytes, chunk.data_format, client)
last_index = chunk.index
intervals.append(time.time() - t0)
print(
f"got chunk index {last_index} of size {len(message_bytes) / 1024 / 1024} MiB, "
f"took {(time.time() - t0) * 1000:.2f}ms, "
f"cumulative {time.time() - begin:.2f}s, "
f"cumulative average {(time.time() - begin) / last_index * 1000:.2f}ms/chunk, "
f"individual average {sum(intervals) / len(intervals) * 1000:.2f}ms/chunk, "
# f"time since server timestamp {(time.time() - chunk.timestamp)n / * 1000:.2f}ms"
)
yield message
t0 = time.time()
except (GRPCError, StreamTerminatedError) as exc:
if retries_remaining > 0:
retries_remaining -= 1
Expand Down
45 changes: 43 additions & 2 deletions modal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ async def _call_unary(
return await self._call_safely(coro, grpclib_method.name)

@synchronizer.nowrap
async def _call_stream(
async def _call_unary_stream(
self,
method_name: str,
request: Any,
Expand All @@ -367,6 +367,31 @@ async def _call_stream(
else:
await stream_context.__aexit__(None, None, None)

@synchronizer.nowrap
async def _call_stream_unary(
self, method_name: str, request_stream: AsyncIterator[Any], *, metadata: Optional[_MetadataLike]
):
grpclib_method = await self._get_grpclib_method(method_name)
stream_context = grpclib_method.open(metadata=metadata)
stream = await self._call_safely(stream_context.__aenter__(), f"{grpclib_method.name}.open")
try:
while True:
try:
request = await self._call_safely(request_stream.__anext__(), f"{grpclib_method.name}.request_iter")
await self._call_safely(stream.send_message(request), f"{grpclib_method.name}.send_message")
except StopAsyncIteration:
break

await self._call_safely(stream.end(), f"{grpclib_method.name}.end")
# TODO(gongy): what if the stream never comes back with anything?
resp = await self._call_safely(stream.__anext__(), f"{grpclib_method.name}.recv")
await stream_context.__aexit__(None, None, None)
return resp
except BaseException as exc:
did_handle_exception = await stream_context.__aexit__(type(exc), exc, exc.__traceback__)
if not did_handle_exception:
raise


Client = synchronize_api(_Client)

Expand Down Expand Up @@ -421,5 +446,21 @@ async def unary_stream(
if self.client._snapshotted:
logger.debug(f"refreshing client after snapshot for {self._wrapped_method_name}")
self.client = await _Client.from_env()
async for response in self.client._call_stream(self._wrapped_method_name, request, metadata=metadata):
async for response in self.client._call_unary_stream(self._wrapped_method_name, request, metadata=metadata):
yield response


class StreamUnaryWrapper(Generic[RequestType, ResponseType]):
wrapped_method: grpclib.client.StreamUnaryMethod[RequestType, ResponseType]

def __init__(self, wrapped_method: grpclib.client.StreamUnaryMethod[RequestType, ResponseType], client: _Client):
self._wrapped_full_name = wrapped_method.name
self._wrapped_method_name = wrapped_method.name.rsplit("/", 1)[1]
self.client = client

@property
def name(self) -> str:
return self._wrapped_full_name

async def __call__(self, request_stream: AsyncIterator[RequestType], metadata: Optional[Any] = None):
return await self.client._call_stream_unary(self._wrapped_method_name, request_stream, metadata=metadata)
1 change: 1 addition & 0 deletions modal_proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2688,6 +2688,7 @@ service ModalClient {
rpc FunctionCallGetDataOut(FunctionCallGetDataRequest) returns (stream DataChunk);
rpc FunctionCallList(FunctionCallListRequest) returns (FunctionCallListResponse);
rpc FunctionCallPutDataOut(FunctionCallPutDataRequest) returns (google.protobuf.Empty);
rpc FunctionCallPutDataOutStreaming(stream FunctionCallPutDataRequest) returns (google.protobuf.Empty);
rpc FunctionCreate(FunctionCreateRequest) returns (FunctionCreateResponse);
rpc FunctionGet(FunctionGetRequest) returns (FunctionGetResponse);
rpc FunctionGetCallGraph(FunctionGetCallGraphRequest) returns (FunctionGetCallGraphResponse);
Expand Down
4 changes: 2 additions & 2 deletions protoc_plugin/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def render(
wrapper_cls = "modal.client.UnaryUnaryWrapper"
elif cardinality is const.Cardinality.UNARY_STREAM:
wrapper_cls = "modal.client.UnaryStreamWrapper"
# elif cardinality is const.Cardinality.STREAM_UNARY:
# wrapper_cls = StreamUnaryWrapper
elif cardinality is const.Cardinality.STREAM_UNARY:
wrapper_cls = "modal.client.StreamUnaryWrapper"
# elif cardinality is const.Cardinality.STREAM_STREAM:
# wrapper_cls = StreamStreamWrapper
else:
Expand Down

0 comments on commit f0cd0b4

Please sign in to comment.