Skip to content

Commit

Permalink
Add get validation endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
WonderPG committed Jan 10, 2025
1 parent 48c45b8 commit 4d9549f
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 3 deletions.
4 changes: 3 additions & 1 deletion src/neuroagent/agent_routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ async def get_chat_completion(
"messages": messages,
"tools": tools or None,
"tool_choice": agent.tool_choice,
"stream_options": {"include_usage": True},
"stream": stream,
}
if stream:
create_params["stream_options"] = ({"include_usage": True},)

if tools:
create_params["parallel_tool_calls"] = agent.parallel_tool_calls
Expand Down Expand Up @@ -156,6 +157,7 @@ async def handle_tool_call(
if tool_call.validated is None:
return HILResponse(
message="Please validate the following inputs before proceeding.",
name=tool_call.name,
inputs=input_schema.model_dump(),
tool_call_id=tool_call.tool_call_id,
), None
Expand Down
44 changes: 42 additions & 2 deletions src/neuroagent/app/routers/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from neuroagent.app.database.schemas import ToolCallSchema
from neuroagent.app.database.sql_schemas import Entity, Messages, Threads, ToolCalls
from neuroagent.app.dependencies import get_session, get_starting_agent
from neuroagent.new_types import Agent, HILValidation
from neuroagent.new_types import Agent, HILResponse, HILValidation

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -113,7 +113,47 @@ async def get_tool_returns(
return tool_output


@router.patch("/validate/{thread_id}/{tool_call_id}")
@router.get("/validation/{thread_id}/")
async def get_required_validation(
_: Annotated[Threads, Depends(get_thread)],
thread_id: str,
session: Annotated[AsyncSession, Depends(get_session)],
starting_agent: Annotated[Agent, Depends(get_starting_agent)],
) -> list[HILResponse]:
"""List tool calls currently requiring validation in a thread."""
message_query = await session.execute(
select(Messages)
.where(Messages.thread_id == thread_id)
.order_by(desc(Messages.order))
.limit(1)
)
message = message_query.scalar_one_or_none()
if not message or message.entity != Entity.AI_TOOL:
return []

else:
tool_calls = await message.awaitable_attrs.tool_calls
need_validation = []
for tool_call in tool_calls:
tool = next(
tool for tool in starting_agent.tools if tool.name == tool_call.name
)
if tool.hil and tool_call.validated is None:
input_schema = tool.__annotations__["input_schema"](
**json.loads(tool_call.arguments)
)
need_validation.append(
HILResponse(
message="Please validate the following inputs before proceeding.",
name=tool_call.name,
inputs=input_schema.model_dump(),
tool_call_id=tool_call.tool_call_id,
)
)
return need_validation


@router.patch("/validation/{thread_id}/{tool_call_id}")
async def validate_input(
user_request: HILValidation,
_: Annotated[Threads, Depends(get_thread)],
Expand Down
1 change: 1 addition & 0 deletions src/neuroagent/new_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class HILResponse(BaseModel):
"""Response for tools that require HIL validation."""

message: str
name: str
inputs: dict[str, Any]
tool_call_id: str

Expand Down

0 comments on commit 4d9549f

Please sign in to comment.