Skip to content

Commit

Permalink
endpoints for create task nowait and getting event stream (#301)
Browse files Browse the repository at this point in the history
* endpoints for create task nowait and getting stream

* rm trailing backslash

* unit test for create task nowait

* add unit test for stream events

* unit test for get task result

* nit

* session endpoints

* delete session endpoint

* make create_nowait default create and introduce run endpoint

* re-jig endpoints task_id as path param
  • Loading branch information
nerdai authored Oct 10, 2024
1 parent d93849b commit 1883e87
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 16 deletions.
114 changes: 101 additions & 13 deletions llama_deploy/apiserver/routers/deployments.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import json

from fastapi import APIRouter, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, StreamingResponse
from typing import AsyncGenerator

from llama_deploy.apiserver.server import manager
from llama_deploy.apiserver.config_parser import Config
from llama_deploy.types import TaskDefinition


deployments_router = APIRouter(
prefix="/deployments",
)
Expand Down Expand Up @@ -35,11 +37,25 @@ async def read_deployment(deployment_name: str) -> JSONResponse:
)


@deployments_router.post("/{deployment_name}/tasks/create")
@deployments_router.post("/create")
async def create_deployment(config_file: UploadFile = File(...)) -> JSONResponse:
"""Creates a new deployment by uploading a configuration file."""
config = Config.from_yaml_bytes(await config_file.read())
manager.deploy(config)

# Return some details about the file
return JSONResponse(
{
"name": config.name,
}
)


@deployments_router.post("/{deployment_name}/tasks/run")
async def create_deployment_task(
deployment_name: str, task_definition: TaskDefinition
) -> JSONResponse:
"""Create a task for the deployment."""
"""Create a task for the deployment, wait for result and delete associated session."""
deployment = manager.get_deployment(deployment_name)
if deployment is None:
raise HTTPException(status_code=404, detail="Deployment not found")
Expand All @@ -61,15 +77,87 @@ async def create_deployment_task(
return JSONResponse(result)


@deployments_router.post("/create")
async def create_deployment(config_file: UploadFile = File(...)) -> JSONResponse:
"""Creates a new deployment by uploading a configuration file."""
config = Config.from_yaml_bytes(await config_file.read())
manager.deploy(config)
@deployments_router.post("/{deployment_name}/tasks/create")
async def create_deployment_task_nowait(
deployment_name: str, task_definition: TaskDefinition
) -> JSONResponse:
"""Create a task for the deployment but don't wait for result."""
deployment = manager.get_deployment(deployment_name)
if deployment is None:
raise HTTPException(status_code=404, detail="Deployment not found")

# Return some details about the file
return JSONResponse(
{
"name": config.name,
}
if task_definition.agent_id is None:
if deployment.default_service is None:
raise HTTPException(
status_code=400,
detail="Service is None and deployment has no default service",
)
task_definition.agent_id = deployment.default_service

session = await deployment.client.create_session()
task_id = await session.run_nowait(
task_definition.agent_id or "", **json.loads(task_definition.input)
)

return JSONResponse({"session_id": session.session_id, "task_id": task_id})


@deployments_router.get("/{deployment_name}/tasks/{task_id}/events")
async def get_events(
deployment_name: str, session_id: str, task_id: str
) -> StreamingResponse:
"""Get the stream of events from a given task and session."""
deployment = manager.get_deployment(deployment_name)
if deployment is None:
raise HTTPException(status_code=404, detail="Deployment not found")

session = await deployment.client.get_session(session_id)

async def event_stream() -> AsyncGenerator[str, None]:
# need to convert back to str to use SSE
async for event in session.get_task_result_stream(task_id):
yield json.dumps(event) + "\n"

return StreamingResponse(
event_stream(),
media_type="application/x-ndjson",
)


@deployments_router.get("/{deployment_name}/tasks/{task_id}/results")
async def get_task_result(
deployment_name: str, session_id: str, task_id: str
) -> JSONResponse:
"""Get the task result associated with a task and session."""
deployment = manager.get_deployment(deployment_name)
if deployment is None:
raise HTTPException(status_code=404, detail="Deployment not found")

session = await deployment.client.get_session(session_id)
result = await session.get_task_result(task_id)

return JSONResponse(result.result if result else "")


@deployments_router.get("/{deployment_name}/sessions")
async def get_sessions(
deployment_name: str,
) -> JSONResponse:
"""Get the active sessions in a deployment and service."""
deployment = manager.get_deployment(deployment_name)
if deployment is None:
raise HTTPException(status_code=404, detail="Deployment not found")

sessions = await deployment.client.list_sessions()
return JSONResponse(sessions)


@deployments_router.post("/{deployment_name}/sessions/delete")
async def delete_session(deployment_name: str, session_id: str) -> JSONResponse:
"""Get the active sessions in a deployment and service."""
deployment = manager.get_deployment(deployment_name)
if deployment is None:
raise HTTPException(status_code=404, detail="Deployment not found")

await deployment.client.delete_session(session_id)
return JSONResponse({"session_id": session_id, "status": "Deleted"})
119 changes: 116 additions & 3 deletions tests/apiserver/routers/test_deployments.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import json
import pytest
from pathlib import Path
from unittest import mock

from fastapi.testclient import TestClient

from llama_deploy.apiserver import Config
from llama_deploy.types import TaskResult

from llama_index.core.workflow import Event


def test_read_deployments(http_client: TestClient) -> None:
Expand Down Expand Up @@ -78,7 +83,7 @@ def test_create_deployment_task_missing_service(
)


def test_create_deployment_task(http_client: TestClient, data_path: Path) -> None:
def test_run_deployment_task(http_client: TestClient, data_path: Path) -> None:
with mock.patch(
"llama_deploy.apiserver.routers.deployments.manager"
) as mocked_manager:
Expand All @@ -87,11 +92,119 @@ def test_create_deployment_task(http_client: TestClient, data_path: Path) -> Non
session = mock.AsyncMock()
deployment.client.create_session.return_value = session
session.run.return_value = {"result": "test_result"}
session.session_id = 42
session.session_id = "42"
mocked_manager.get_deployment.return_value = deployment
response = http_client.post(
"/deployments/test-deployment/tasks/run/",
json={"input": "{}"},
)
assert response.status_code == 200
deployment.client.delete_session.assert_called_with("42")


def test_create_deployment_task(http_client: TestClient, data_path: Path) -> None:
with mock.patch(
"llama_deploy.apiserver.routers.deployments.manager"
) as mocked_manager:
deployment = mock.AsyncMock()
deployment.default_service = "TestService"
session = mock.AsyncMock()
deployment.client.create_session.return_value = session
session.session_id = "42"
session.run_nowait.return_value = "test_task_id"
mocked_manager.get_deployment.return_value = deployment
response = http_client.post(
"/deployments/test-deployment/tasks/create/",
json={"input": "{}"},
)
assert response.status_code == 200
deployment.client.delete_session.assert_called_with(42)
assert response.json() == {"session_id": "42", "task_id": "test_task_id"}


@pytest.mark.asyncio
async def test_get_event_stream(http_client: TestClient, data_path: Path) -> None:
mock_events = [
Event(msg="mock event 1"),
Event(msg="mock event 2"),
Event(msg="mock event 3"),
]

with mock.patch(
"llama_deploy.apiserver.routers.deployments.manager"
) as mocked_manager:
deployment = mock.AsyncMock()
deployment.default_service = "TestService"
session = mock.MagicMock()
deployment.client.get_session.return_value = session
mocked_manager.get_deployment.return_value = deployment
mocked_get_task_result_stream = mock.MagicMock()
mocked_get_task_result_stream.__aiter__.return_value = (
ev.dict() for ev in mock_events
)
session.get_task_result_stream.return_value = mocked_get_task_result_stream

response = http_client.get(
"/deployments/test-deployment/tasks/test_task_id/events/?session_id=42",
)
assert response.status_code == 200
ix = 0
async for line in response.aiter_lines():
data = json.loads(line)
assert data == mock_events[ix].dict()
ix += 1
deployment.client.get_session.assert_called_with("42")
session.get_task_result_stream.assert_called_with("test_task_id")


def test_get_task_result(http_client: TestClient, data_path: Path) -> None:
with mock.patch(
"llama_deploy.apiserver.routers.deployments.manager"
) as mocked_manager:
deployment = mock.AsyncMock()
deployment.default_service = "TestService"
session = mock.AsyncMock()
deployment.client.get_session.return_value = session
session.get_task_result.return_value = TaskResult(
result="test_result", history=[], task_id="test_task_id"
)
mocked_manager.get_deployment.return_value = deployment

response = http_client.get(
"/deployments/test-deployment/tasks/test_task_id/results/?session_id=42",
)
assert response.status_code == 200
assert response.json() == "test_result"
session.get_task_result.assert_called_with("test_task_id")
deployment.client.get_session.assert_called_with("42")


def test_get_sessions(http_client: TestClient, data_path: Path) -> None:
with mock.patch(
"llama_deploy.apiserver.routers.deployments.manager"
) as mocked_manager:
deployment = mock.AsyncMock()
deployment.default_service = "TestService"
deployment.client.list_sessions.return_value = []
mocked_manager.get_deployment.return_value = deployment

response = http_client.get(
"/deployments/test-deployment/sessions/",
)
assert response.status_code == 200
assert response.json() == []


def test_delete_session(http_client: TestClient, data_path: Path) -> None:
with mock.patch(
"llama_deploy.apiserver.routers.deployments.manager"
) as mocked_manager:
deployment = mock.AsyncMock()
deployment.default_service = "TestService"
mocked_manager.get_deployment.return_value = deployment

response = http_client.post(
"/deployments/test-deployment/sessions/delete/?session_id=42",
)
assert response.status_code == 200
assert response.json() == {"session_id": "42", "status": "Deleted"}
deployment.client.delete_session.assert_called_with("42")

0 comments on commit 1883e87

Please sign in to comment.