diff --git a/llama_deploy/services/workflow.py b/llama_deploy/services/workflow.py
index 74388b28..c141dc95 100644
--- a/llama_deploy/services/workflow.py
+++ b/llama_deploy/services/workflow.py
@@ -9,9 +9,9 @@
 from logging import getLogger
 from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
 from pydantic_settings import BaseSettings, SettingsConfigDict
-from typing import AsyncGenerator, Dict, Optional, Any, Type
+from typing import AsyncGenerator, Dict, Optional, Any
 
-from llama_index.core.workflow import Context, Workflow, Event, BlockingEvent
+from llama_index.core.workflow import Context, Workflow
 from llama_index.core.workflow.context_serializers import (
     JsonPickleSerializer,
     JsonSerializer,
@@ -129,7 +129,7 @@ class WorkflowService(BaseService):
     _publish_callback: Optional[PublishCallback] = PrivateAttr()
     _lock: asyncio.Lock = PrivateAttr()
     _outstanding_calls: Dict[str, WorkflowState] = PrivateAttr()
-    _events_buffer: Dict[str, Dict[Type[Event], asyncio.Queue]] = PrivateAttr()
+    _events_buffer: Dict[str, asyncio.Queue] = PrivateAttr()
 
     def __init__(
         self,
@@ -168,9 +168,7 @@ def __init__(
 
         self._outstanding_calls: Dict[str, WorkflowState] = {}
         self._ongoing_tasks: Dict[str, asyncio.Task] = {}
-        self._events_buffer: Dict[str, Dict[Type[Event], asyncio.Queue]] = defaultdict(
-            lambda: defaultdict(asyncio.Queue)
-        )
+        self._events_buffer: Dict[str, asyncio.Queue] = defaultdict(asyncio.Queue)
 
         self._app = FastAPI(lifespan=self.lifespan)
 
@@ -283,20 +281,22 @@ async def process_call(self, current_call: WorkflowState) -> None:
         # run the workflow
         handler = self.workflow.run(ctx=ctx, **current_call.run_kwargs)
 
+        # create send_event background task
+        close_send_events = asyncio.Event()
+
+        async def send_events(handler: Any, close_event: asyncio.Event) -> None:
+            while not close_event.is_set():
+                try:
+                    event = self._events_buffer[current_call.task_id].get_nowait()
+                    handler.ctx.send_event(event)
+                except asyncio.QueueEmpty:
+                    pass
+                await asyncio.sleep(self.step_interval)
+
+        _ = asyncio.create_task(send_events(handler, close_send_events))
+
         index = 0
         async for ev in handler.stream_events():
-            # check if blocking event
-            if isinstance(ev, BlockingEvent):
-                # wait for the unblocking event type
-                logger.debug(f"Waiting for unblocking event: {ev}")
-                unblocking_ev = await self._events_buffer[current_call.task_id][
-                    ev.unblocking_event_type
-                ].get()
-
-                # send event
-                logger.debug(f"Sending unblocking event: {ev}")
-                handler.ctx.send_event(unblocking_ev)
-
             # send the event to control plane for client / api server streaming
             logger.debug(f"Publishing event: {ev}")
             await self.message_queue.publish(
@@ -354,6 +354,7 @@ async def process_call(self, current_call: WorkflowState) -> None:
             )
         finally:
             # clean up
+            close_send_events.set()
             async with self.lock:
                 self._outstanding_calls.pop(current_call.task_id, None)
                 self._ongoing_tasks.pop(current_call.task_id, None)
@@ -432,7 +433,7 @@ async def process_message(self, message: QueueMessage) -> None:
             task_def = TaskDefinition(**message.data or {})
            event = serializer.deserialize(task_def.input)
             async with self.lock:
-                self._events_buffer[task_def.task_id][type(event)].put_nowait(event)
+                self._events_buffer[task_def.task_id].put_nowait(event)
         else:
             raise ValueError(f"Unhandled action: {message.action}")
 