Skip to content

Commit

Permalink
add non-blocking client api
Browse files Browse the repository at this point in the history
  • Loading branch information
tohtana committed Mar 12, 2023
1 parent ed265f2 commit 3053a3f
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions mii/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ def create_channel(host, port):
GRPC_MAX_MSG_SIZE)])


class QueryResultFuture():
def __init__(self, asyncio_loop, coro):
self.asyncio_loop = asyncio_loop
self.coro = coro

def result(self):
return self.asyncio_loop.run_until_complete(self.coro)


class MIIClient():
"""
Client to send queries to a single endpoint.
Expand All @@ -73,11 +82,15 @@ async def _request_async_response(self, request_dict, **query_kwargs):
proto_response
) if "unpack_response_from_proto" in conversions else proto_response

def query(self, request_dict, **query_kwargs):
return self.asyncio_loop.run_until_complete(
def query_async(self, request_dict, **query_kwargs):
return QueryResultFuture(
self.asyncio_loop,
self._request_async_response(request_dict,
**query_kwargs))

def query(self, request_dict, **query_kwargs):
return self.query_async(request_dict, **query_kwargs).result()

async def terminate_async(self):
await self.stub.Terminate(
modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty())
Expand Down Expand Up @@ -106,7 +119,13 @@ async def _query_in_tensor_parallel(self, request_string, query_kwargs):
**query_kwargs)))

await responses[0]
return responses[0]
return responses[0].result()

def query_async(self, request_dict, **query_kwargs):
return QueryResultFuture(
self.asyncio_loop,
self._query_in_tensor_parallel(request_dict,
query_kwargs))

def query(self, request_dict, **query_kwargs):
"""Query a local deployment:
Expand All @@ -121,11 +140,7 @@ def query(self, request_dict, **query_kwargs):
Returns:
response: Response of the model
"""
response = self.asyncio_loop.run_until_complete(
self._query_in_tensor_parallel(request_dict,
query_kwargs))
ret = response.result()
return ret
return self.query_async(request_dict, **query_kwargs).result()

def terminate(self):
"""Terminates the deployment"""
Expand Down

0 comments on commit 3053a3f

Please sign in to comment.