Skip to content

Commit

Permalink
use background text to send events to workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
nerdai committed Oct 21, 2024
1 parent 7c03a1e commit 80106c3
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions llama_deploy/services/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from logging import getLogger
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import AsyncGenerator, Dict, Optional, Any, Type
from typing import AsyncGenerator, Dict, Optional, Any

from llama_index.core.workflow import Context, Workflow, Event, BlockingEvent
from llama_index.core.workflow import Context, Workflow
from llama_index.core.workflow.context_serializers import (
JsonPickleSerializer,
JsonSerializer,
Expand Down Expand Up @@ -129,7 +129,7 @@ class WorkflowService(BaseService):
_publish_callback: Optional[PublishCallback] = PrivateAttr()
_lock: asyncio.Lock = PrivateAttr()
_outstanding_calls: Dict[str, WorkflowState] = PrivateAttr()
_events_buffer: Dict[str, Dict[Type[Event], asyncio.Queue]] = PrivateAttr()
_events_buffer: Dict[str, asyncio.Queue] = PrivateAttr()

def __init__(
self,
Expand Down Expand Up @@ -168,9 +168,7 @@ def __init__(

self._outstanding_calls: Dict[str, WorkflowState] = {}
self._ongoing_tasks: Dict[str, asyncio.Task] = {}
self._events_buffer: Dict[str, Dict[Type[Event], asyncio.Queue]] = defaultdict(
lambda: defaultdict(asyncio.Queue)
)
self._events_buffer: Dict[str, asyncio.Queue] = defaultdict(asyncio.Queue)

self._app = FastAPI(lifespan=self.lifespan)

Expand Down Expand Up @@ -283,20 +281,22 @@ async def process_call(self, current_call: WorkflowState) -> None:
# run the workflow
handler = self.workflow.run(ctx=ctx, **current_call.run_kwargs)

# create send_event background task
close_send_events = asyncio.Event()

async def send_events(handler: Any, close_event: asyncio.Event) -> None:
while not close_event.is_set():
try:
event = self._events_buffer[current_call.task_id].get_nowait()
handler.ctx.send_event(event)
except asyncio.QueueEmpty:
pass
await asyncio.sleep(self.step_interval)

_ = asyncio.create_task(send_events(handler, close_send_events))

index = 0
async for ev in handler.stream_events():
# check if blocking event
if isinstance(ev, BlockingEvent):
# wait for the unblocking event type
logger.debug(f"Waiting for unblocking event: {ev}")
unblocking_ev = await self._events_buffer[current_call.task_id][
ev.unblocking_event_type
].get()

# send event
logger.debug(f"Sending unblocking event: {ev}")
handler.ctx.send_event(unblocking_ev)

# send the event to control plane for client / api server streaming
logger.debug(f"Publishing event: {ev}")
await self.message_queue.publish(
Expand Down Expand Up @@ -354,6 +354,7 @@ async def process_call(self, current_call: WorkflowState) -> None:
)
finally:
# clean up
close_send_events.set()
async with self.lock:
self._outstanding_calls.pop(current_call.task_id, None)
self._ongoing_tasks.pop(current_call.task_id, None)
Expand Down Expand Up @@ -432,7 +433,7 @@ async def process_message(self, message: QueueMessage) -> None:
task_def = TaskDefinition(**message.data or {})
event = serializer.deserialize(task_def.input)
async with self.lock:
self._events_buffer[task_def.task_id][type(event)].put_nowait(event)
self._events_buffer[task_def.task_id].put_nowait(event)

else:
raise ValueError(f"Unhandled action: {message.action}")
Expand Down

0 comments on commit 80106c3

Please sign in to comment.