Add ability to send_event to a Workflow service's ongoing task/run.
nerdai authored Oct 22, 2024
1 parent 73ea594 commit 3c8410b
Showing 10 changed files with 355 additions and 2 deletions.
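Note: the diff references EventDefinition and ActionTypes.SEND_EVENT, whose definitions (in llama_deploy/types and the message enums) are among the changed files not rendered below. A minimal sketch of the shapes implied by the call sites in this diff — the field names come from usage, while the enum string values are assumptions:

from enum import Enum

from pydantic import BaseModel


class EventDefinition(BaseModel):
    # JSON-serialized llama_index Event, produced by JsonSerializer.serialize()
    event_obj_str: str
    # name of the workflow service that should receive the event
    agent_id: str


class ActionTypes(str, Enum):
    # pre-existing action, shown for context; the string values are guesses
    NEW_TASK = "new_task"
    # action introduced by this commit
    SEND_EVENT = "send_event"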
55 changes: 55 additions & 0 deletions e2e_tests/basic_hitl/conftest.py
@@ -0,0 +1,55 @@
import asyncio
import multiprocessing
import time

import pytest

from llama_deploy import (
    ControlPlaneConfig,
    SimpleMessageQueueConfig,
    WorkflowServiceConfig,
    deploy_core,
    deploy_workflow,
)

from .workflow import HumanInTheLoopWorkflow


def run_async_core():
    asyncio.run(deploy_core(ControlPlaneConfig(), SimpleMessageQueueConfig()))


@pytest.fixture(scope="package")
def core():
    p = multiprocessing.Process(target=run_async_core)
    p.start()
    time.sleep(5)

    yield

    p.kill()


def run_async_workflow():
    asyncio.run(
        deploy_workflow(
            HumanInTheLoopWorkflow(timeout=60),
            WorkflowServiceConfig(
                host="127.0.0.1",
                port=8002,
                service_name="hitl_workflow",
            ),
            ControlPlaneConfig(),
        )
    )


@pytest.fixture(scope="package")
def services(core):
    p = multiprocessing.Process(target=run_async_workflow)
    p.start()
    time.sleep(5)

    yield

    p.kill()
75 changes: 75 additions & 0 deletions e2e_tests/basic_hitl/test_run_client.py
@@ -0,0 +1,75 @@
import asyncio
import pytest
import time

from llama_deploy import AsyncLlamaDeployClient, ControlPlaneConfig, LlamaDeployClient
from llama_index.core.workflow.events import HumanResponseEvent


@pytest.mark.e2ehitl
def test_run_client(services):
    client = LlamaDeployClient(ControlPlaneConfig(), timeout=10)

    # sanity check
    sessions = client.list_sessions()
    assert len(sessions) == 0, "Sessions list is not empty"

    # create a session
    session = client.create_session()

    # kick off run
    task_id = session.run_nowait("hitl_workflow")

    # send event
    session.send_event(
        ev=HumanResponseEvent(response="42"),
        service_name="hitl_workflow",
        task_id=task_id,
    )

    # get the final result, polling until the workflow finishes after the event is sent
    final_result = None
    while final_result is None:
        final_result = session.get_task_result(task_id)
        time.sleep(0.1)
    assert final_result.result == "42", "The final result does not match the human's response."

    # delete the session
    client.delete_session(session.session_id)
    sessions = client.list_sessions()
    assert len(sessions) == 0, "Sessions list is not empty"


@pytest.mark.e2ehitl
@pytest.mark.asyncio
async def test_run_client_async(services):
    client = AsyncLlamaDeployClient(ControlPlaneConfig(), timeout=10)

    # sanity check
    sessions = await client.list_sessions()
    assert len(sessions) == 0, "Sessions list is not empty"

    # create a session
    session = await client.create_session()

    # kick off run
    task_id = await session.run_nowait("hitl_workflow")

    # send event
    await session.send_event(
        ev=HumanResponseEvent(response="42"),
        service_name="hitl_workflow",
        task_id=task_id,
    )

    # get the final result, polling until the workflow finishes after the event is sent
    final_result = None
    while final_result is None:
        final_result = await session.get_task_result(task_id)
        await asyncio.sleep(0.1)
    assert final_result.result == "42", "The final result does not match the human's response."

    # delete the session
    await client.delete_session(session.session_id)
    sessions = await client.list_sessions()
    assert len(sessions) == 0, "Sessions list is not empty"
20 changes: 20 additions & 0 deletions e2e_tests/basic_hitl/workflow.py
@@ -0,0 +1,20 @@
from llama_index.core.workflow import (
    StartEvent,
    StopEvent,
    Workflow,
    step,
)
from llama_index.core.workflow.events import (
    HumanResponseEvent,
    InputRequiredEvent,
)


class HumanInTheLoopWorkflow(Workflow):
    @step
    async def step1(self, ev: StartEvent) -> InputRequiredEvent:
        return InputRequiredEvent(prefix="Enter a number: ")

    @step
    async def step2(self, ev: HumanResponseEvent) -> StopEvent:
        return StopEvent(result=ev.response)
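For context, the same send_event mechanism this commit exposes remotely can drive HumanInTheLoopWorkflow locally on the handler's Context — a minimal sketch, assuming llama_index's standard workflow handler API and a local import of the workflow above:

import asyncio

from llama_index.core.workflow.events import HumanResponseEvent, InputRequiredEvent

from workflow import HumanInTheLoopWorkflow  # hypothetical local import


async def main() -> None:
    handler = HumanInTheLoopWorkflow(timeout=60).run()
    async for ev in handler.stream_events():
        if isinstance(ev, InputRequiredEvent):
            # answer the workflow's prompt the way a human would
            handler.ctx.send_event(HumanResponseEvent(response="42"))
    print(await handler)  # prints "42"


asyncio.run(main())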
23 changes: 23 additions & 0 deletions llama_deploy/client/async_client.py
@@ -7,11 +7,14 @@

from llama_deploy.control_plane.server import ControlPlaneConfig
from llama_deploy.types import (
    EventDefinition,
    ServiceDefinition,
    SessionDefinition,
    TaskDefinition,
    TaskResult,
)
from llama_index.core.workflow import Event
from llama_index.core.workflow.context_serializers import JsonSerializer

DEFAULT_TIMEOUT = 120.0
DEFAULT_POLL_INTERVAL = 0.5
@@ -149,6 +152,26 @@ async def get_task_result_stream(
f"Task result not available after waiting for {self.timeout} seconds"
)

    async def send_event(self, service_name: str, task_id: str, ev: Event) -> None:
        """Send an event to a Workflow service for an ongoing task.

        Args:
            service_name (str): The name of the workflow service running the task.
            task_id (str): The ID of the task the event is submitted to.
            ev (Event): The event to be submitted to the workflow.

        Returns:
            None
        """
        serializer = JsonSerializer()
        event_def = EventDefinition(
            event_obj_str=serializer.serialize(ev), agent_id=service_name
        )

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            await client.post(
                f"{self.control_plane_url}/sessions/{self.session_id}/tasks/{task_id}/send_event",
                json=event_def.model_dump(),
            )


class AsyncLlamaDeployClient:
    def __init__(
24 changes: 24 additions & 0 deletions llama_deploy/client/sync_client.py
@@ -5,11 +5,15 @@

from llama_deploy.control_plane.server import ControlPlaneConfig
from llama_deploy.types import (
    EventDefinition,
    TaskDefinition,
    ServiceDefinition,
    TaskResult,
    SessionDefinition,
)
from llama_index.core.workflow import Event
from llama_index.core.workflow.context_serializers import JsonSerializer


DEFAULT_TIMEOUT = 120.0
DEFAULT_POLL_INTERVAL = 0.5
@@ -151,6 +155,26 @@ def get_task_result_stream(
f"Task result not available after waiting for {self.timeout} seconds"
)

    def send_event(self, service_name: str, task_id: str, ev: Event) -> None:
        """Send an event to a Workflow service for an ongoing task.

        Args:
            service_name (str): The name of the workflow service running the task.
            task_id (str): The ID of the task the event is submitted to.
            ev (Event): The event to be submitted to the workflow.

        Returns:
            None
        """
        serializer = JsonSerializer()
        event_def = EventDefinition(
            event_obj_str=serializer.serialize(ev), agent_id=service_name
        )

        with httpx.Client(timeout=self.timeout) as client:
            client.post(
                f"{self.control_plane_url}/sessions/{self.session_id}/tasks/{task_id}/send_event",
                json=event_def.model_dump(),
            )


class LlamaDeployClient:
    def __init__(
26 changes: 26 additions & 0 deletions llama_deploy/control_plane/server.py
@@ -25,6 +25,7 @@
from llama_deploy.orchestrators.utils import get_result_key, get_stream_key
from llama_deploy.types import (
    ActionTypes,
    EventDefinition,
    ServiceDefinition,
    SessionDefinition,
    TaskDefinition,
@@ -237,6 +238,12 @@ def __init__(
methods=["GET"],
tags=["Sessions"],
)
self.app.add_api_route(
"/sessions/{session_id}/tasks/{task_id}/send_event",
self.send_event,
methods=["POST"],
tags=["Sessions"],
)
self.app.add_api_route(
"/sessions/{session_id}/state",
self.get_session_state,
@@ -584,6 +591,25 @@ async def event_generator(
            media_type="application/x-ndjson",
        )

    async def send_event(
        self,
        session_id: str,
        task_id: str,
        event_def: EventDefinition,
    ) -> None:
        # Wrap the serialized event in a TaskDefinition and publish it to the
        # message queue of the workflow service named by agent_id.
        task_def = TaskDefinition(
            task_id=task_id,
            session_id=session_id,
            input=event_def.event_obj_str,
            agent_id=event_def.agent_id,
        )
        message = QueueMessage(
            type=event_def.agent_id,
            action=ActionTypes.SEND_EVENT,
            data=task_def.model_dump(),
        )
        await self.publish(message)

    async def get_session_state(self, session_id: str) -> Dict[str, Any]:
        session = await self.get_session(session_id)
        if session.task_ids is None:
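The new route can also be exercised directly over HTTP, bypassing the client wrappers — a sketch, assuming a control plane on the default 127.0.0.1:8000; the session and task IDs are placeholders:

import httpx

from llama_deploy.types import EventDefinition
from llama_index.core.workflow.context_serializers import JsonSerializer
from llama_index.core.workflow.events import HumanResponseEvent

event_def = EventDefinition(
    event_obj_str=JsonSerializer().serialize(HumanResponseEvent(response="42")),
    agent_id="hitl_workflow",
)
httpx.post(
    "http://127.0.0.1:8000/sessions/<session_id>/tasks/<task_id>/send_event",
    json=event_def.model_dump(),
)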
40 changes: 38 additions & 2 deletions llama_deploy/services/workflow.py
@@ -3,6 +3,7 @@
import os
import uuid
import uvicorn
from collections import defaultdict
from contextlib import asynccontextmanager
from fastapi import FastAPI
from logging import getLogger
@@ -11,7 +12,11 @@
from typing import AsyncGenerator, Dict, Optional, Any

from llama_index.core.workflow import Context, Workflow
from llama_index.core.workflow.context_serializers import JsonPickleSerializer
from llama_index.core.workflow.handler import WorkflowHandler
from llama_index.core.workflow.context_serializers import (
    JsonPickleSerializer,
    JsonSerializer,
)

from llama_deploy.message_consumers.base import BaseMessageQueueConsumer
from llama_deploy.message_consumers.callable import CallableMessageConsumer
@@ -127,6 +132,7 @@ class WorkflowService(BaseService):
    _publish_callback: Optional[PublishCallback] = PrivateAttr()
    _lock: asyncio.Lock = PrivateAttr()
    _outstanding_calls: Dict[str, WorkflowState] = PrivateAttr()
    _events_buffer: Dict[str, asyncio.Queue] = PrivateAttr()

    def __init__(
        self,
@@ -165,6 +171,7 @@ def __init__(

        self._outstanding_calls: Dict[str, WorkflowState] = {}
        self._ongoing_tasks: Dict[str, asyncio.Task] = {}
        self._events_buffer: Dict[str, asyncio.Queue] = defaultdict(asyncio.Queue)

        self._app = FastAPI(lifespan=self.lifespan)

@@ -285,10 +292,29 @@ async def process_call(self, current_call: WorkflowState) -> None:
            # run the workflow
            handler = self.workflow.run(ctx=ctx, **current_call.run_kwargs)

            # create send_event background task
            close_send_events = asyncio.Event()

            async def send_events(
                handler: WorkflowHandler, close_event: asyncio.Event
            ) -> None:
                if handler.ctx is None:
                    raise ValueError("handler does not have a valid Context.")

                while not close_event.is_set():
                    try:
                        event = self._events_buffer[current_call.task_id].get_nowait()
                        handler.ctx.send_event(event)
                    except asyncio.QueueEmpty:
                        pass
                    await asyncio.sleep(self.step_interval)

            _ = asyncio.create_task(send_events(handler, close_send_events))

            index = 0
            async for ev in handler.stream_events():
                # send the event to the control plane for client / api server streaming
                logger.debug(f"Publishing event: {ev}")

                await self.message_queue.publish(
                    QueueMessage(
                        type=CONTROL_PLANE_NAME,
@@ -344,6 +370,7 @@ async def process_call(self, current_call: WorkflowState) -> None:
                )
        finally:
            # clean up
            close_send_events.set()
            async with self.lock:
                self._outstanding_calls.pop(current_call.task_id, None)
                self._ongoing_tasks.pop(current_call.task_id, None)
@@ -406,6 +433,7 @@ async def process_message(self, message: QueueMessage) -> None:
"""Process a message received from the message queue."""
if message.action == ActionTypes.NEW_TASK:
task_def = TaskDefinition(**message.data or {})

run_kwargs = json.loads(task_def.input)
workflow_state = WorkflowState(
session_id=task_def.session_id,
@@ -415,6 +443,14 @@

async with self.lock:
self._outstanding_calls[task_def.task_id] = workflow_state
elif message.action == ActionTypes.SEND_EVENT:
serializer = JsonSerializer()

task_def = TaskDefinition(**message.data or {})
event = serializer.deserialize(task_def.input)
async with self.lock:
self._events_buffer[task_def.task_id].put_nowait(event)

else:
raise ValueError(f"Unhandled action: {message.action}")

Expand Down