Skip to content

Commit

Permalink
cr
Browse files Browse the repository at this point in the history
  • Loading branch information
nerdai committed Oct 19, 2024
1 parent 53bdf82 commit 7c03a1e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 18 deletions.
8 changes: 6 additions & 2 deletions llama_deploy/control_plane/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,9 +602,13 @@ async def send_event(
session_id=session_id,
input=event_def.event_obj_str,
agent_id=event_def.agent_id,
is_send_event=True,
)
await self.send_task_to_service(task_def)
message = QueueMessage(
type=event_def.agent_id,
action=ActionTypes.SEND_EVENT,
data=task_def.model_dump(),
)
await self.publish(message)

async def get_session_state(self, session_id: str) -> Dict[str, Any]:
session = await self.get_session(session_id)
Expand Down
29 changes: 14 additions & 15 deletions llama_deploy/services/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,23 +417,22 @@ async def process_message(self, message: QueueMessage) -> None:
if message.action == ActionTypes.NEW_TASK:
task_def = TaskDefinition(**message.data or {})

if not task_def.is_send_event:
run_kwargs = json.loads(task_def.input)
workflow_state = WorkflowState(
session_id=task_def.session_id,
task_id=task_def.task_id,
run_kwargs=run_kwargs,
)
run_kwargs = json.loads(task_def.input)
workflow_state = WorkflowState(
session_id=task_def.session_id,
task_id=task_def.task_id,
run_kwargs=run_kwargs,
)

async with self.lock:
self._outstanding_calls[task_def.task_id] = workflow_state
else:
serializer = JsonSerializer()
async with self.lock:
self._outstanding_calls[task_def.task_id] = workflow_state
elif message.action == ActionTypes.SEND_EVENT:
serializer = JsonSerializer()

task_def = TaskDefinition(**message.data or {})
event = serializer.deserialize(task_def.input)
async with self.lock:
self._events_buffer[task_def.task_id][type(event)].put_nowait(event)
task_def = TaskDefinition(**message.data or {})
event = serializer.deserialize(task_def.input)
async with self.lock:
self._events_buffer[task_def.task_id][type(event)].put_nowait(event)

else:
raise ValueError(f"Unhandled action: {message.action}")
Expand Down
2 changes: 1 addition & 1 deletion llama_deploy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class ActionTypes(str, Enum):
NEW_TOOL_CALL = "new_tool_call"
COMPLETED_TOOL_CALL = "completed_tool_call"
TASK_STREAM = "task_stream"
SEND_EVENT = "send_event"


class TaskDefinition(BaseModel):
Expand All @@ -100,7 +101,6 @@ class TaskDefinition(BaseModel):
task_id: str = Field(default_factory=generate_id)
session_id: Optional[str] = None
agent_id: Optional[str] = None
is_send_event: bool = False


class SessionDefinition(BaseModel):
Expand Down

0 comments on commit 7c03a1e

Please sign in to comment.