Skip to content

Commit

Permalink
agentic orchestrator (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-markewich authored Jun 19, 2024
1 parent 260c12c commit 135b5fd
Show file tree
Hide file tree
Showing 15 changed files with 633 additions and 101 deletions.
12 changes: 1 addition & 11 deletions agentfile/control_plane/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async def create_task(self, task_def: TaskDefinition) -> None:
...

@abstractmethod
async def send_task_to_service(self, task_def: TaskDefinition) -> None:
async def send_task_to_service(self, task_def: TaskDefinition) -> TaskDefinition:
"""
        Send a task to a service.
Expand All @@ -94,16 +94,6 @@ async def handle_service_completion(
"""
...

@abstractmethod
async def get_next_service(self, task_id: str) -> str:
"""
Get the next service for a task.
:param task_id: Unique identifier of the task.
:return: Unique identifier of the next service.
"""
...

@abstractmethod
async def get_task_state(self, task_id: str) -> dict:
"""
Expand Down
119 changes: 46 additions & 73 deletions agentfile/control_plane/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,17 @@
from typing import Any, Callable, Dict, List, Optional

from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.llms import LLM
from llama_index.core.objects import ObjectIndex, SimpleObjectNodeMapping
from llama_index.core.storage.kvstore.types import BaseKVStore
from llama_index.core.storage.kvstore import SimpleKVStore
from llama_index.core.selectors import PydanticMultiSelector
from llama_index.core.settings import Settings
from llama_index.core.tools import ToolMetadata
from llama_index.core.vector_stores.types import BasePydanticVectorStore

from agentfile.control_plane.base import BaseControlPlane
from agentfile.message_consumers.base import BaseMessageQueueConsumer
from agentfile.message_queues.base import BaseMessageQueue, PublishCallback
from agentfile.messages.base import QueueMessage
from agentfile.orchestrators.base import BaseOrchestrator
from agentfile.orchestrators.service_tool import ServiceTool
from agentfile.types import (
ActionTypes,
ServiceDefinition,
Expand Down Expand Up @@ -51,7 +49,7 @@ class FastAPIControlPlane(BaseControlPlane):
def __init__(
self,
message_queue: BaseMessageQueue,
llm: Optional[LLM] = None,
orchestrator: BaseOrchestrator,
vector_store: Optional[BasePydanticVectorStore] = None,
publish_callback: Optional[PublishCallback] = None,
state_store: Optional[BaseKVStore] = None,
Expand All @@ -62,7 +60,7 @@ def __init__(
step_interval: float = 0.1,
running: bool = True,
) -> None:
self.llm = llm or Settings.llm
self.orchestrator = orchestrator
self.object_index = ObjectIndex(
VectorStoreIndex(
nodes=[],
Expand Down Expand Up @@ -168,38 +166,16 @@ async def deregister_flow(self, flow_id: str) -> None:
await self.state_store.adelete(flow_id, collection=self.flows_store_key)

async def create_task(self, task_def: TaskDefinition) -> None:
"""
TODO:
Ideally, this would
- get/create state for the task
- call orchestrator.get_next_messages(task_def, state)
- publish messages to the next services
"""

await self.state_store.aput(
task_def.task_id, task_def.dict(), collection=self.tasks_store_key
)

await self.send_task_to_service(task_def)

async def get_task_state(self, task_id: str) -> TaskDefinition:
state_dict = await self.state_store.aget(
task_id, collection=self.tasks_store_key
task_def = await self.send_task_to_service(task_def)
await self.state_store.aput(
task_def.task_id, task_def.dict(), collection=self.tasks_store_key
)
if state_dict is None:
raise ValueError(f"Task with id {task_id} not found")

return TaskDefinition.parse_obj(state_dict)

async def get_all_tasks(self) -> Dict[str, TaskDefinition]:
state_dicts = await self.state_store.aget_all(collection=self.tasks_store_key)
return {
task_id: TaskDefinition.parse_obj(state_dict)
for task_id, state_dict in state_dicts.items()
}

async def send_task_to_service(self, task_def: TaskDefinition) -> None:
async def send_task_to_service(self, task_def: TaskDefinition) -> TaskDefinition:
service_retriever = self.object_index.as_retriever(similarity_top_k=5)

# could also route based on similarity alone.
Expand All @@ -211,55 +187,52 @@ async def send_task_to_service(self, task_def: TaskDefinition) -> None:
ServiceDefinition.parse_obj(service_def_dict)
for service_def_dict in service_def_dicts
]
if len(service_defs) > 1:
selector = PydanticMultiSelector.from_defaults(
llm=self.llm,
)
service_def_metadata = [
ToolMetadata(
description=service_def.description,
name=service_def.service_name,
)
for service_def in service_defs
]
result = await selector.aselect(service_def_metadata, task_def.input)

selected_service_name = service_defs[result.inds[0]].service_name
else:
selected_service_name = service_defs[0].service_name

await self.publish(
QueueMessage(
type=selected_service_name,
data=task_def.dict(),
action=ActionTypes.NEW_TASK,
),

service_tools = [
ServiceTool.from_service_definition(service_def)
for service_def in service_defs
]
next_messages, task_state = await self.orchestrator.get_next_messages(
task_def, service_tools, task_def.state
)

for message in next_messages:
await self.publish(message)

task_def.state.update(task_state)
return task_def

async def handle_service_completion(
self,
task_result: TaskResult,
) -> None:
"""
TODO:
Ideally, this would
- get state for the task
- call orchestrator.add_result_to_state(state, task_result)
- call orchestrator.get_next_messages(task_def, state)
- publish messages to the next services (if any)
- if no more, send result to human and remove task from control plane
"""
await self.publish(
QueueMessage(
type="human",
action=ActionTypes.COMPLETED_TASK,
data=task_result.result,
)
# add result to task state
task_def = await self.get_task_state(task_result.task_id)
state = await self.orchestrator.add_result_to_state(task_result, task_def.state)
task_def.state.update(state)

# generate and send new tasks (if any)
task_def = await self.send_task_to_service(task_def)

await self.state_store.aput(
task_def.task_id, task_def.dict(), collection=self.tasks_store_key
)

async def get_next_service(self, task_id: str) -> str:
return ""
async def get_task_state(self, task_id: str) -> TaskDefinition:
state_dict = await self.state_store.aget(
task_id, collection=self.tasks_store_key
)
if state_dict is None:
raise ValueError(f"Task with id {task_id} not found")

return TaskDefinition.parse_obj(state_dict)

async def get_all_tasks(self) -> Dict[str, TaskDefinition]:
state_dicts = await self.state_store.aget_all(collection=self.tasks_store_key)
return {
task_id: TaskDefinition.parse_obj(state_dict)
for task_id, state_dict in state_dicts.items()
}

    async def request_user_input(self, task_id: str, message: str) -> None:
        """Ask the user for input on a task.

        Currently a no-op stub — no user-input channel is wired up yet.

        :param task_id: Unique identifier of the task.
        :param message: The message/question to surface to the user.
        """
        pass
Expand Down
9 changes: 3 additions & 6 deletions agentfile/launchers/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from agentfile.message_queues.simple import SimpleMessageQueue
from agentfile.message_queues.base import PublishCallback
from agentfile.messages.base import QueueMessage
from agentfile.types import ActionTypes, TaskDefinition
from agentfile.types import ActionTypes, TaskDefinition, TaskResult
from agentfile.message_publishers.publisher import MessageQueuePublisherMixin


Expand Down Expand Up @@ -52,11 +52,8 @@ def publish_callback(self) -> Optional[PublishCallback]:
return self._publish_callback

async def handle_human_message(self, **kwargs: Any) -> None:
message_data = kwargs["message_data"]
result = (
message_data["result"] if "result" in message_data else str(message_data)
)
print("Got response:\n", result, flush=True)
result = TaskResult(**kwargs["message_data"])
print("Got response:\n", result.result, flush=True)

async def register_consumers(
self, consumers: Optional[List[BaseMessageQueueConsumer]] = None
Expand Down
120 changes: 120 additions & 0 deletions agentfile/orchestrators/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from typing import Any, Dict, List, Tuple

from llama_index.core.llms import LLM, ChatMessage
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools import BaseTool

from agentfile.messages.base import QueueMessage
from agentfile.orchestrators.base import BaseOrchestrator
from agentfile.orchestrators.service_tool import ServiceTool
from agentfile.types import ActionTypes, TaskDefinition, TaskResult

HISTORY_KEY = "chat_history"
DEFAULT_SUMMARIZE_TMPL = "{history}\n\nThe above represents the progress so far, please condense the messages into a single message."
DEFAULT_FOLLOWUP_TMPL = "Pick the next action to take, or return a final response if my original input is satisfied. As a reminder, the original input was: {original_input}"


class AgentOrchestrator(BaseOrchestrator):
    """Orchestrator that drives a multi-service task loop with an LLM.

    On each step the LLM either calls a service "tool" (producing a NEW_TASK
    queue message for that service) or produces a final answer / calls the
    synthetic ``human`` tool (producing a COMPLETED_TASK message back to the
    user). Conversation state is carried between steps as a list of
    chat-message dicts stored under ``HISTORY_KEY``.
    """

    def __init__(
        self,
        llm: LLM,
        human_description: str = "Useful for sending a final response.",
        summarize_prompt: str = DEFAULT_SUMMARIZE_TMPL,
        followup_prompt: str = DEFAULT_FOLLOWUP_TMPL,
    ):
        """
        :param llm: LLM used for tool selection and history summarization.
        :param human_description: Tool description for the synthetic 'human' tool.
        :param summarize_prompt: Template with a ``{history}`` placeholder.
        :param followup_prompt: Template with an ``{original_input}`` placeholder.
        """
        self.llm = llm
        self.summarize_prompt = summarize_prompt
        self.followup_prompt = followup_prompt
        # synthetic tool the LLM selects to finish the task and reply to the user
        self.human_tool = ServiceTool(name="human", description=human_description)

    async def get_next_messages(
        self, task_def: TaskDefinition, tools: List[BaseTool], state: Dict[str, Any]
    ) -> Tuple[List[QueueMessage], Dict[str, Any]]:
        """Decide the next queue messages for a task.

        :param task_def: The task being advanced.
        :param tools: Service tools the LLM may dispatch to.
        :param state: Current task state (may contain prior chat history).
        :return: (queue messages to publish, new task state)
        """
        tools_plus_human = [self.human_tool, *tools]

        chat_dicts = state.get(HISTORY_KEY, [])
        chat_history = [ChatMessage.parse_obj(x) for x in chat_dicts]

        # TODO: how to make memory configurable?
        memory = ChatMemoryBuffer.from_defaults(chat_history=chat_history, llm=self.llm)

        if len(chat_history) == 0:
            # first turn: the human tool is deliberately excluded so the LLM
            # cannot "finish" before any service has run
            memory.put(ChatMessage(role="user", content=task_def.input))
            response = await self.llm.apredict_and_call(
                tools,
                user_msg=task_def.input,
                # error_on_no_tool_call=False,
            )
        else:
            messages = memory.get()
            response = await self.llm.apredict_and_call(
                tools_plus_human,
                chat_history=messages,
                # error_on_no_tool_call=False,
            )

        queue_messages: List[QueueMessage] = []
        # no tool call (or an explicit 'human' call) means the task is complete
        if len(response.sources) == 0 or response.sources[0].tool_name == "human":
            queue_messages.append(
                QueueMessage(
                    type="human",
                    data=TaskResult(
                        task_id=task_def.task_id,
                        history=memory.get_all(),
                        result=response.response,
                    ).dict(),
                    action=ActionTypes.NEW_TASK if False else ActionTypes.COMPLETED_TASK,
                )
                if False
                else QueueMessage(
                    type="human",
                    data=TaskResult(
                        task_id=task_def.task_id,
                        history=memory.get_all(),
                        result=response.response,
                    ).dict(),
                    action=ActionTypes.COMPLETED_TASK,
                )
            )
        else:
            for source in response.sources:
                name = source.tool_name
                input_data = source.raw_input
                # services take a single string input; use the first kwarg value
                # NOTE(review): assumes every tool call has at least one argument
                input_str = next(iter(input_data.values()))
                queue_messages.append(
                    QueueMessage(
                        type=name,
                        data=TaskDefinition(
                            task_id=task_def.task_id, input=input_str
                        ).dict(),
                        action=ActionTypes.NEW_TASK,
                    )
                )

        new_state = {HISTORY_KEY: [x.dict() for x in memory.get_all()]}
        return queue_messages, new_state

    async def add_result_to_state(
        self,
        result: TaskResult,
        state: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Add the result of processing a message to the state. Returns the new state."""

        # summarize the result
        new_history = result.history
        new_history_str = "\n".join([str(x) for x in new_history])
        # TODO: Better logic for when to summarize?
        if len(new_history) > 1:
            summarize_prompt_str = self.summarize_prompt.format(history=new_history_str)
            summary = await self.llm.acomplete(summarize_prompt_str)
        else:
            # FIX: with zero or one message there is nothing to condense; fall
            # back to the raw history so ``summary`` is always defined below
            # (previously this path raised NameError)
            summary = new_history_str

        # get the current chat history, add the summary to it
        chat_dicts = state.get(HISTORY_KEY, [])
        chat_history = [ChatMessage.parse_obj(x) for x in chat_dicts]

        chat_history.append(ChatMessage(role="assistant", content=str(summary)))

        # add the followup prompt to the chat history
        # NOTE(review): assumes chat_history is non-empty and that its first
        # message is the original user input — confirm against get_next_messages
        original_input = chat_history[0].content
        chat_history.append(
            ChatMessage(
                role="user",
                content=self.followup_prompt.format(original_input=original_input),
            )
        )

        new_state = {HISTORY_KEY: [x.dict() for x in chat_history]}
        return new_state
6 changes: 4 additions & 2 deletions agentfile/orchestrators/base.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple

from llama_index.core.tools import BaseTool

from agentfile.messages.base import QueueMessage
from agentfile.types import TaskDefinition, TaskResult


class BaseOrchestrator(ABC):
    """Abstract interface for deciding how a task moves between services."""

    @abstractmethod
    async def get_next_messages(
        self, task_def: TaskDefinition, tools: List[BaseTool], state: Dict[str, Any]
    ) -> Tuple[List[QueueMessage], Dict[str, Any]]:
        """Get the next message to process. Returns the message and the new state."""
        ...

    @abstractmethod
    async def add_result_to_state(
        self, result: TaskResult, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Add the result of processing a message to the state. Returns the new state."""
        ...
Loading

0 comments on commit 135b5fd

Please sign in to comment.