Skip to content

Commit

Permalink
Re-introduce Proper State Management (#306)
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-markewich authored Oct 14, 2024
1 parent 27244ea commit 36adfa3
Show file tree
Hide file tree
Showing 17 changed files with 895 additions and 553 deletions.
14 changes: 14 additions & 0 deletions e2e_tests/basic_session/launch_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from llama_deploy import deploy_core, ControlPlaneConfig, SimpleMessageQueueConfig


async def main():
    """Deploy the llama_deploy core using default configuration objects.

    Awaits ``deploy_core`` with a default-constructed ``ControlPlaneConfig``
    and a default-constructed ``SimpleMessageQueueConfig``.
    """
    control_plane_config = ControlPlaneConfig()
    message_queue_config = SimpleMessageQueueConfig()
    await deploy_core(control_plane_config, message_queue_config)


if __name__ == "__main__":
    # Script entry point: drive the async ``main`` coroutine to completion.
    import asyncio

    core_deployment = main()
    asyncio.run(core_deployment)
45 changes: 45 additions & 0 deletions e2e_tests/basic_session/launch_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import asyncio
from llama_index.core.workflow import (
Context,
Workflow,
StartEvent,
StopEvent,
step,
)
from llama_deploy import deploy_workflow, ControlPlaneConfig, WorkflowServiceConfig


class SessionWorkflow(Workflow):
    """Single-step workflow that increments a ``"count"`` value kept in the context."""

    @step()
    async def step_1(self, ctx: Context, ev: StartEvent) -> StopEvent:
        """Increment the stored counter and return the new value.

        Reads ``"count"`` from the context (0 when absent), writes back the
        value plus one, and returns that new value as the ``StopEvent`` result.
        """
        previous = await ctx.get("count", default=0)
        updated = previous + 1
        await ctx.set("count", updated)

        return StopEvent(result=updated)


# Module-level workflow instance, constructed with timeout=10. main() uses it
# both for the local sanity run and as the workflow handed to deploy_workflow.
session_workflow = SessionWorkflow(timeout=10)


async def main():
    """Sanity-check the workflow locally, then deploy it as a service.

    The workflow is first run once in-process and must return 1. It is then
    deployed through ``deploy_workflow`` under the service name
    ``"session_workflow"`` on 127.0.0.1:8002, using a default
    ``ControlPlaneConfig``. The coroutine waits on the deployment task.
    """
    # sanity check
    first_result = await session_workflow.run(arg1="hello_world")
    assert first_result == 1, "Sanity check failed"

    service_config = WorkflowServiceConfig(
        host="127.0.0.1",
        port=8002,
        service_name="session_workflow",
    )
    deployment = deploy_workflow(
        session_workflow,
        service_config,
        ControlPlaneConfig(),
    )
    outer_task = asyncio.create_task(deployment)

    await asyncio.gather(outer_task)


if __name__ == "__main__":
    # Script entry point: run the sanity check and deployment coroutine.
    workflow_deployment = main()
    asyncio.run(workflow_deployment)
25 changes: 25 additions & 0 deletions e2e_tests/basic_session/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/bin/bash
#
# End-to-end test driver for the basic session example.
# Starts the core and the session workflow in the background, runs the
# client tests against them, and always tears the background processes down.

# Stop every process whose command line matches "launch_".
# pkill exits non-zero when nothing matched; that is not an error here,
# so the status is deliberately ignored.
cleanup() {
    echo "Killing any previously running scripts"
    pkill -f "launch_" || true
}

# Kill any previously running scripts
cleanup

# Wait for processes to terminate
sleep 2

set -e

# Run the teardown on every exit path. Without this, a failing step below
# makes `set -e` abort the script before the final kill, leaking the
# background core/workflow processes. The EXIT trap also leaves the script's
# exit status equal to that of the failing (or last) command, so the result
# of the client tests is what callers observe.
trap cleanup EXIT

echo "Launching core"
python ./launch_core.py &
sleep 5

echo "Launching workflow"
python ./launch_workflow.py &
sleep 5

echo "Running client tests"
python ./test_run_client.py
67 changes: 67 additions & 0 deletions e2e_tests/basic_session/test_run_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from llama_deploy import AsyncLlamaDeployClient, ControlPlaneConfig, LlamaDeployClient


def test_run_client():
    """Check with the synchronous client that session state persists across runs.

    Runs ``session_workflow`` twice in one session and expects the second
    result to be the string ``"2"``. The session list must be empty both
    before the session is created and after it is deleted.
    """
    client = LlamaDeployClient(ControlPlaneConfig(), timeout=10)

    # sanity check: start from a control plane with no sessions
    sessions_before = client.list_sessions()
    assert len(sessions_before) == 0, "Sessions list is not empty"

    session = client.create_session()

    # first run; its result is not inspected
    session.run("session_workflow", arg1="hello_world")

    # second run in the same session
    second_result = session.run("session_workflow", arg1="hello_world")

    # if the session state is working across runs,
    # the count should be 2
    assert (
        second_result == "2"
    ), f"Session state is not working across runs, result was {second_result}"

    client.delete_session(session.session_id)

    # sanity check: deleting the session leaves the list empty again
    sessions_after = client.list_sessions()
    assert len(sessions_after) == 0, "Sessions list is not empty"


async def test_run_client_async():
    """Check with the async client that session state persists across runs.

    Runs ``session_workflow`` twice in one session and expects the second
    result to be the string ``"2"``. The session list must be empty both
    before the session is created and after it is deleted.
    """
    client = AsyncLlamaDeployClient(ControlPlaneConfig(), timeout=10)

    # sanity check: start from a control plane with no sessions
    sessions_before = await client.list_sessions()
    assert len(sessions_before) == 0, "Sessions list is not empty"

    session = await client.create_session()

    # first run; its result is not inspected
    await session.run("session_workflow", arg1="hello_world")

    # second run in the same session
    second_result = await session.run("session_workflow", arg1="hello_world")

    # if the session state is working across runs,
    # the count should be 2
    assert (
        second_result == "2"
    ), f"Session state is not working across runs, result was {second_result}"

    await client.delete_session(session.session_id)

    # sanity check: deleting the session leaves the list empty again
    sessions_after = await client.list_sessions()
    assert len(sessions_after) == 0, "Sessions list is not empty"


if __name__ == "__main__":
    # Script entry point: the async variant runs first, then the sync one.
    import asyncio

    print("Running async test")
    async_test = test_run_client_async()
    asyncio.run(async_test)

    print("Running sync test")
    test_run_client()
31 changes: 30 additions & 1 deletion llama_deploy/control_plane/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from fastapi.responses import StreamingResponse
from logging import getLogger
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import AsyncGenerator, Dict, List, Optional
from typing import AsyncGenerator, Any, Dict, List, Optional

from llama_index.core.storage.kvstore.types import BaseKVStore
from llama_index.core.storage.kvstore import SimpleKVStore
Expand Down Expand Up @@ -237,6 +237,18 @@ def __init__(
methods=["GET"],
tags=["Sessions"],
)
self.app.add_api_route(
"/sessions/{session_id}/state",
self.get_session_state,
methods=["GET"],
tags=["Sessions"],
)
self.app.add_api_route(
"/sessions/{session_id}/state",
self.update_session_state,
methods=["POST"],
tags=["Sessions"],
)

@property
def message_queue(self) -> BaseMessageQueue:
Expand Down Expand Up @@ -593,6 +605,23 @@ async def event_generator(
media_type="application/x-ndjson",
)

async def get_session_state(self, session_id: str) -> Dict[str, Any]:
    """Return the state dictionary stored on a session.

    Args:
        session_id: Identifier passed to ``self.get_session``.

    Returns:
        The ``state`` attribute of the looked-up session.

    Raises:
        HTTPException: 404 ("Session not found") when the session's
            ``task_ids`` is ``None``.
    """
    session = await self.get_session(session_id)
    if session.task_ids is not None:
        return session.state

    raise HTTPException(status_code=404, detail="Session not found")

async def update_session_state(
    self, session_id: str, state: Dict[str, Any]
) -> None:
    """Merge ``state`` into a session's state and persist the session.

    The session is looked up via ``self.get_session``, its ``state`` dict is
    updated in place with the given key/value pairs, and the dumped session
    is written to ``self.state_store`` under ``session_id`` in the
    ``self.session_store_key`` collection.

    Args:
        session_id: Identifier of the session to update.
        state: Key/value pairs to merge into the existing state.
    """
    session = await self.get_session(session_id)
    session.state.update(state)

    serialized_session = session.model_dump()
    await self.state_store.aput(
        session_id,
        serialized_session,
        collection=self.session_store_key,
    )

async def get_message_queue_config(self) -> Dict[str, dict]:
"""
Gets the config dict for the message queue being used.
Expand Down
6 changes: 2 additions & 4 deletions llama_deploy/orchestrators/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from llama_deploy.messages.base import QueueMessage
from llama_deploy.orchestrators.base import BaseOrchestrator
from llama_deploy.orchestrators.utils import get_result_key
from llama_deploy.types import ActionTypes, NewTask, TaskDefinition, TaskResult
from llama_deploy.types import ActionTypes, TaskDefinition, TaskResult


class SimpleOrchestratorConfig(BaseSettings):
Expand Down Expand Up @@ -75,9 +75,7 @@ async def get_next_messages(
QueueMessage(
type=destination,
action=ActionTypes.NEW_TASK,
data=NewTask(
task=task_def, state=state[task_def.task_id]
).model_dump(),
data=task_def.model_dump(),
)
]

Expand Down
4 changes: 1 addition & 3 deletions llama_deploy/services/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
ActionTypes,
ChatMessage,
MessageRole,
NewTask,
TaskResult,
TaskDefinition,
ToolCall,
Expand Down Expand Up @@ -310,8 +309,7 @@ async def processing_loop(self) -> None:
async def process_message(self, message: QueueMessage) -> None:
"""Handling for when a message is received."""
if message.action == ActionTypes.NEW_TASK:
new_task = NewTask(**message.data or {})
task_def = new_task.task
task_def = TaskDefinition(**message.data or {})
self.agent.create_task(task_def.input, task_id=task_def.task_id)
logger.info(f"Created new task: {task_def.task_id}")
elif message.action == ActionTypes.NEW_TOOL_CALL:
Expand Down
42 changes: 39 additions & 3 deletions llama_deploy/services/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import httpx
from abc import ABC, abstractmethod
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, PrivateAttr
from typing import Any

from llama_deploy.messages.base import QueueMessage
Expand Down Expand Up @@ -34,6 +34,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

model_config = ConfigDict(arbitrary_types_allowed=True)
service_name: str
_control_plane_url: str | None = PrivateAttr(default=None)

@property
@abstractmethod
Expand Down Expand Up @@ -68,6 +69,7 @@ async def launch_server(self) -> None:

async def register_to_control_plane(self, control_plane_url: str) -> None:
"""Register the service to the control plane."""
self._control_plane_url = control_plane_url
service_def = self.service_definition
async with httpx.AsyncClient() as client:
response = await client.post(
Expand All @@ -76,15 +78,49 @@ async def register_to_control_plane(self, control_plane_url: str) -> None:
)
response.raise_for_status()

async def deregister_from_control_plane(self, control_plane_url: str) -> None:
async def deregister_from_control_plane(self) -> None:
"""Deregister the service from the control plane."""
if not self._control_plane_url:
raise ValueError(
"Control plane URL not set. Call register_to_control_plane first."
)
async with httpx.AsyncClient() as client:
response = await client.post(
f"{control_plane_url}/services/deregister",
f"{self._control_plane_url}/services/deregister",
json={"service_name": self.service_name},
)
response.raise_for_status()

async def get_session_state(self, session_id: str) -> dict[str, Any] | None:
"""Get the session state from the control plane."""
if not self._control_plane_url:
return None

async with httpx.AsyncClient() as client:
response = await client.get(
f"{self._control_plane_url}/sessions/{session_id}/state"
)
if response.status_code == 404:
return None
else:
response.raise_for_status()

return response.json()

async def update_session_state(
    self, session_id: str, state: dict[str, Any]
) -> None:
    """Update the session state in the control plane.

    Posts ``state`` as JSON to
    ``{control_plane_url}/sessions/{session_id}/state``. Does nothing when no
    control plane URL has been recorded on this service.

    Args:
        session_id: Session whose state is being updated.
        state: State payload sent as the JSON request body.

    Raises:
        httpx.HTTPStatusError: When the control plane answers with an error
            status, via ``raise_for_status``.
    """
    base_url = self._control_plane_url
    if not base_url:
        return

    endpoint = f"{base_url}/sessions/{session_id}/state"
    async with httpx.AsyncClient() as http:
        reply = await http.post(endpoint, json=state)
        reply.raise_for_status()

async def register_to_message_queue(self) -> StartConsumingCallable:
    """Register the service to the message queue.

    Builds this service's consumer with ``remote=True`` and registers it on
    ``self.message_queue``.

    Returns:
        Whatever ``register_consumer`` returns for the registered consumer.
    """
    register = self.message_queue.register_consumer
    consumer = self.as_consumer(remote=True)
    return await register(consumer)
10 changes: 4 additions & 6 deletions llama_deploy/services/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from llama_deploy.services.base import BaseService
from llama_deploy.types import (
ActionTypes,
NewTask,
TaskDefinition,
TaskResult,
ServiceDefinition,
CONTROL_PLANE_NAME,
Expand Down Expand Up @@ -190,12 +190,10 @@ async def processing_loop(self) -> None:
async def process_message(self, message: QueueMessage) -> None:
"""Process a message received from the message queue."""
if message.action == ActionTypes.NEW_TASK:
new_task = NewTask(**message.data or {})
task_def = new_task.task
task_def = TaskDefinition(**message.data or {})
input_dict = json.loads(task_def.input)
async with self.lock:
self._outstanding_calls[task_def.task_id] = new_task.state[
"__input_dict__"
]
self._outstanding_calls[task_def.task_id] = input_dict["__input_dict__"]
else:
raise ValueError(f"Unhandled action: {message.action}")

Expand Down
4 changes: 1 addition & 3 deletions llama_deploy/services/human.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
ActionTypes,
ChatMessage,
HumanResponse,
NewTask,
TaskDefinition,
TaskResult,
ToolCall,
Expand Down Expand Up @@ -288,8 +287,7 @@ class HumanTask(BaseModel):
async def process_message(self, message: QueueMessage) -> None:
"""Process a message received from the message queue."""
if message.action == ActionTypes.NEW_TASK:
new_task = NewTask(**message.data or {})
task_def = new_task.task
task_def = TaskDefinition(**message.data or {})
human_task = self.HumanTask(task_def=task_def)
logger.info(f"Created new task: {task_def.task_id}")
elif message.action == ActionTypes.NEW_TOOL_CALL:
Expand Down
Loading

0 comments on commit 36adfa3

Please sign in to comment.