From 1883e8792decc49ab7d5ab3ba2fc52519aa49265 Mon Sep 17 00:00:00 2001 From: Andrei Fajardo <92402603+nerdai@users.noreply.github.com> Date: Thu, 10 Oct 2024 08:44:00 -0400 Subject: [PATCH] endpoints for create task nowait and getting event stream (#301) * endpoints for create task nowait and getting stream * rm trailing backslash * unit test for create task nowait * add unit test for stream events * unit test for get task result * nit * session endpoints * delete session endpoint * make create_nowait default create and introduce run endpoint * re-jig endpoints task_id as path param --- llama_deploy/apiserver/routers/deployments.py | 114 +++++++++++++++-- tests/apiserver/routers/test_deployments.py | 119 +++++++++++++++++- 2 files changed, 217 insertions(+), 16 deletions(-) diff --git a/llama_deploy/apiserver/routers/deployments.py b/llama_deploy/apiserver/routers/deployments.py index f2fb441f..c4112f70 100644 --- a/llama_deploy/apiserver/routers/deployments.py +++ b/llama_deploy/apiserver/routers/deployments.py @@ -1,12 +1,14 @@ import json from fastapi import APIRouter, File, UploadFile, HTTPException -from fastapi.responses import JSONResponse +from fastapi.responses import JSONResponse, StreamingResponse +from typing import AsyncGenerator from llama_deploy.apiserver.server import manager from llama_deploy.apiserver.config_parser import Config from llama_deploy.types import TaskDefinition + deployments_router = APIRouter( prefix="/deployments", ) @@ -35,11 +37,25 @@ async def read_deployment(deployment_name: str) -> JSONResponse: ) -@deployments_router.post("/{deployment_name}/tasks/create") +@deployments_router.post("/create") +async def create_deployment(config_file: UploadFile = File(...)) -> JSONResponse: + """Creates a new deployment by uploading a configuration file.""" + config = Config.from_yaml_bytes(await config_file.read()) + manager.deploy(config) + + # Return some details about the file + return JSONResponse( + { + "name": config.name, + } + ) + + 
+@deployments_router.post("/{deployment_name}/tasks/run") async def create_deployment_task( deployment_name: str, task_definition: TaskDefinition ) -> JSONResponse: - """Create a task for the deployment.""" + """Create a task for the deployment, wait for result and delete associated session.""" deployment = manager.get_deployment(deployment_name) if deployment is None: raise HTTPException(status_code=404, detail="Deployment not found") @@ -61,15 +77,87 @@ async def create_deployment_task( return JSONResponse(result) -@deployments_router.post("/create") -async def create_deployment(config_file: UploadFile = File(...)) -> JSONResponse: - """Creates a new deployment by uploading a configuration file.""" - config = Config.from_yaml_bytes(await config_file.read()) - manager.deploy(config) +@deployments_router.post("/{deployment_name}/tasks/create") +async def create_deployment_task_nowait( + deployment_name: str, task_definition: TaskDefinition +) -> JSONResponse: + """Create a task for the deployment but don't wait for result.""" + deployment = manager.get_deployment(deployment_name) + if deployment is None: + raise HTTPException(status_code=404, detail="Deployment not found") - # Return some details about the file - return JSONResponse( - { - "name": config.name, - } + if task_definition.agent_id is None: + if deployment.default_service is None: + raise HTTPException( + status_code=400, + detail="Service is None and deployment has no default service", + ) + task_definition.agent_id = deployment.default_service + + session = await deployment.client.create_session() + task_id = await session.run_nowait( + task_definition.agent_id or "", **json.loads(task_definition.input) + ) + + return JSONResponse({"session_id": session.session_id, "task_id": task_id}) + + +@deployments_router.get("/{deployment_name}/tasks/{task_id}/events") +async def get_events( + deployment_name: str, session_id: str, task_id: str +) -> StreamingResponse: + """Get the stream of events from a given 
task and session.""" + deployment = manager.get_deployment(deployment_name) + if deployment is None: + raise HTTPException(status_code=404, detail="Deployment not found") + + session = await deployment.client.get_session(session_id) + + async def event_stream() -> AsyncGenerator[str, None]: + # serialize each event as one JSON line (newline-delimited JSON stream) + async for event in session.get_task_result_stream(task_id): + yield json.dumps(event) + "\n" + + return StreamingResponse( + event_stream(), + media_type="application/x-ndjson", ) + + +@deployments_router.get("/{deployment_name}/tasks/{task_id}/results") +async def get_task_result( + deployment_name: str, session_id: str, task_id: str +) -> JSONResponse: + """Get the task result associated with a task and session.""" + deployment = manager.get_deployment(deployment_name) + if deployment is None: + raise HTTPException(status_code=404, detail="Deployment not found") + + session = await deployment.client.get_session(session_id) + result = await session.get_task_result(task_id) + + return JSONResponse(result.result if result else "") + + +@deployments_router.get("/{deployment_name}/sessions") +async def get_sessions( + deployment_name: str, +) -> JSONResponse: + """Get the active sessions in a deployment and service.""" + deployment = manager.get_deployment(deployment_name) + if deployment is None: + raise HTTPException(status_code=404, detail="Deployment not found") + + sessions = await deployment.client.list_sessions() + return JSONResponse(sessions) + + +@deployments_router.post("/{deployment_name}/sessions/delete") +async def delete_session(deployment_name: str, session_id: str) -> JSONResponse: + """Delete the session with the given session_id from a deployment.""" + deployment = manager.get_deployment(deployment_name) + if deployment is None: + raise HTTPException(status_code=404, detail="Deployment not found") + + await deployment.client.delete_session(session_id) + return JSONResponse({"session_id": session_id, "status": "Deleted"}) diff
--git a/tests/apiserver/routers/test_deployments.py b/tests/apiserver/routers/test_deployments.py index ad2393a2..58f3245a 100644 --- a/tests/apiserver/routers/test_deployments.py +++ b/tests/apiserver/routers/test_deployments.py @@ -1,9 +1,14 @@ +import json +import pytest from pathlib import Path from unittest import mock from fastapi.testclient import TestClient from llama_deploy.apiserver import Config +from llama_deploy.types import TaskResult + +from llama_index.core.workflow import Event def test_read_deployments(http_client: TestClient) -> None: @@ -78,7 +83,7 @@ def test_create_deployment_task_missing_service( ) -def test_create_deployment_task(http_client: TestClient, data_path: Path) -> None: +def test_run_deployment_task(http_client: TestClient, data_path: Path) -> None: with mock.patch( "llama_deploy.apiserver.routers.deployments.manager" ) as mocked_manager: @@ -87,11 +92,119 @@ def test_create_deployment_task(http_client: TestClient, data_path: Path) -> Non session = mock.AsyncMock() deployment.client.create_session.return_value = session session.run.return_value = {"result": "test_result"} - session.session_id = 42 + session.session_id = "42" + mocked_manager.get_deployment.return_value = deployment + response = http_client.post( + "/deployments/test-deployment/tasks/run/", + json={"input": "{}"}, + ) + assert response.status_code == 200 + deployment.client.delete_session.assert_called_with("42") + + +def test_create_deployment_task(http_client: TestClient, data_path: Path) -> None: + with mock.patch( + "llama_deploy.apiserver.routers.deployments.manager" + ) as mocked_manager: + deployment = mock.AsyncMock() + deployment.default_service = "TestService" + session = mock.AsyncMock() + deployment.client.create_session.return_value = session + session.session_id = "42" + session.run_nowait.return_value = "test_task_id" mocked_manager.get_deployment.return_value = deployment response = http_client.post( "/deployments/test-deployment/tasks/create/", 
json={"input": "{}"}, ) assert response.status_code == 200 - deployment.client.delete_session.assert_called_with(42) + assert response.json() == {"session_id": "42", "task_id": "test_task_id"} + + +@pytest.mark.asyncio +async def test_get_event_stream(http_client: TestClient, data_path: Path) -> None: + mock_events = [ + Event(msg="mock event 1"), + Event(msg="mock event 2"), + Event(msg="mock event 3"), + ] + + with mock.patch( + "llama_deploy.apiserver.routers.deployments.manager" + ) as mocked_manager: + deployment = mock.AsyncMock() + deployment.default_service = "TestService" + session = mock.MagicMock() + deployment.client.get_session.return_value = session + mocked_manager.get_deployment.return_value = deployment + mocked_get_task_result_stream = mock.MagicMock() + mocked_get_task_result_stream.__aiter__.return_value = ( + ev.dict() for ev in mock_events + ) + session.get_task_result_stream.return_value = mocked_get_task_result_stream + + response = http_client.get( + "/deployments/test-deployment/tasks/test_task_id/events/?session_id=42", + ) + assert response.status_code == 200 + ix = 0 + async for line in response.aiter_lines(): + data = json.loads(line) + assert data == mock_events[ix].dict() + ix += 1 + deployment.client.get_session.assert_called_with("42") + session.get_task_result_stream.assert_called_with("test_task_id") + + +def test_get_task_result(http_client: TestClient, data_path: Path) -> None: + with mock.patch( + "llama_deploy.apiserver.routers.deployments.manager" + ) as mocked_manager: + deployment = mock.AsyncMock() + deployment.default_service = "TestService" + session = mock.AsyncMock() + deployment.client.get_session.return_value = session + session.get_task_result.return_value = TaskResult( + result="test_result", history=[], task_id="test_task_id" + ) + mocked_manager.get_deployment.return_value = deployment + + response = http_client.get( + "/deployments/test-deployment/tasks/test_task_id/results/?session_id=42", + ) + assert 
response.status_code == 200 + assert response.json() == "test_result" + session.get_task_result.assert_called_with("test_task_id") + deployment.client.get_session.assert_called_with("42") + + +def test_get_sessions(http_client: TestClient, data_path: Path) -> None: + with mock.patch( + "llama_deploy.apiserver.routers.deployments.manager" + ) as mocked_manager: + deployment = mock.AsyncMock() + deployment.default_service = "TestService" + deployment.client.list_sessions.return_value = [] + mocked_manager.get_deployment.return_value = deployment + + response = http_client.get( + "/deployments/test-deployment/sessions/", + ) + assert response.status_code == 200 + assert response.json() == [] + + +def test_delete_session(http_client: TestClient, data_path: Path) -> None: + with mock.patch( + "llama_deploy.apiserver.routers.deployments.manager" + ) as mocked_manager: + deployment = mock.AsyncMock() + deployment.default_service = "TestService" + mocked_manager.get_deployment.return_value = deployment + + response = http_client.post( + "/deployments/test-deployment/sessions/delete/?session_id=42", + ) + assert response.status_code == 200 + assert response.json() == {"session_id": "42", "status": "Deleted"} + deployment.client.delete_session.assert_called_with("42")