Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make streaming vercel compatible #62

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# required.
NEUROAGENT_TOOLS__LITERATURE__URL=
NEUROAGENT_KNOWLEDGE_GRAPH__BASE_URL=
NEUROAGENT_GENERATIVE__OPENAI__TOKEN=
NEUROAGENT_OPENAI__TOKEN=

# Important but not required
NEUROAGENT_AGENT__MODEL=

NEUROAGENT_KNOWLEDGE_GRAPH__DOWNLOAD_HIERARCHY=

Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- Return model dumps of DB schema objects.
- Moved swarm_copy to neuroagent and delete old code.
- Made streaming Vercel-compatible.

### Added
- LLM evaluation logic
Expand Down
75 changes: 52 additions & 23 deletions src/neuroagent/agent_routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ async def get_chat_completion(
"tool_choice": agent.tool_choice,
"stream": stream,
}
if stream:
create_params["stream_options"] = {"include_usage": True} # type: ignore

if tools:
create_params["parallel_tool_calls"] = agent.parallel_tool_calls
Expand Down Expand Up @@ -155,6 +157,7 @@ async def handle_tool_call(
if tool_call.validated is None:
return HILResponse(
message="Please validate the following inputs before proceeding.",
name=tool_call.name,
inputs=input_schema.model_dump(),
tool_call_id=tool_call.tool_call_id,
), None
Expand Down Expand Up @@ -221,6 +224,7 @@ async def arun(
model_override=model_override,
stream=False,
)

message = completion.choices[0].message # type: ignore
message.sender = active_agent.name

Expand Down Expand Up @@ -266,7 +270,7 @@ async def arun(
# If the tool call response contains HIL validation, do not update anything and return
if partial_response.hil_messages:
return Response(
messages=history[init_len:],
messages=[],
agent=active_agent,
context_variables=context_variables,
hil_messages=partial_response.hil_messages,
Expand Down Expand Up @@ -307,7 +311,6 @@ async def astream(
content = await messages_to_openai_content(messages)
history = copy.deepcopy(content)
init_len = len(messages)
is_streaming = False

while len(history) - init_len < max_turns:
message: dict[str, Any] = {
Expand All @@ -333,27 +336,53 @@ async def astream(
model_override=model_override,
stream=True,
)
draft_tool_calls = [] # type: ignore
draft_tool_calls_index = -1
async for chunk in completion: # type: ignore
delta = json.loads(chunk.choices[0].delta.model_dump_json())

# Check for tool calls
if delta["tool_calls"]:
tool = delta["tool_calls"][0]["function"]
if tool["name"]:
yield f"\nCalling tool : {tool['name']} with arguments : "
if tool["arguments"]:
yield tool["arguments"]

# Check for content
if delta["content"]:
if not is_streaming:
yield "\n<begin_llm_response>\n"
is_streaming = True
yield delta["content"]

delta.pop("role", None)
merge_chunk(message, delta)

for choice in chunk.choices:
if choice.finish_reason == "stop":
continue

elif choice.finish_reason == "tool_calls":
for tool_call in draft_tool_calls:
yield f"9:{{'toolCallId':'{tool_call['id']}','toolName':'{tool_call['name']}','args':{tool_call['arguments']}}}\n"

# Check for tool calls
elif choice.delta.tool_calls:
for tool_call in choice.delta.tool_calls:
id = tool_call.id
name = tool_call.function.name
arguments = tool_call.function.arguments
if id is not None:
draft_tool_calls_index += 1
draft_tool_calls.append(
{"id": id, "name": name, "arguments": ""}
)
yield f"b:{{'toolCallId':{id},'toolName':{name}}}\n"

else:
draft_tool_calls[draft_tool_calls_index][
"arguments"
] += arguments
yield f"c:{{toolCallId:{id}; argsTextDelta:{arguments}}}\n"

else:
yield f"0:{json.dumps(choice.delta.content)}\n"

delta_json = choice.delta.model_dump()
delta_json.pop("role", None)
merge_chunk(message, delta_json)

if chunk.choices == []:
usage = chunk.usage
prompt_tokens = usage.prompt_tokens
completion_tokens = usage.completion_tokens

yield 'd:{{"finishReason":"{reason}","usage":{{"promptTokens":{prompt},"completionTokens":{completion}}}}}\n'.format(
reason="tool-calls" if len(draft_tool_calls) > 0 else "stop",
prompt=prompt_tokens,
completion=completion_tokens,
)
message["tool_calls"] = list(message.get("tool_calls", {}).values())
if not message["tool_calls"]:
message["tool_calls"] = None
Expand Down Expand Up @@ -396,7 +425,7 @@ async def astream(
# If the tool call response contains HIL validation, do not update anything and return
if partial_response.hil_messages:
yield Response(
messages=history[init_len:],
messages=[],
agent=active_agent,
context_variables=context_variables,
hil_messages=partial_response.hil_messages,
Expand Down
26 changes: 13 additions & 13 deletions src/neuroagent/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Literal, Optional

from dotenv import dotenv_values
from pydantic import BaseModel, ConfigDict, SecretStr, model_validator
from pydantic import BaseModel, ConfigDict, SecretStr
from pydantic_settings import BaseSettings, SettingsConfigDict


Expand Down Expand Up @@ -228,21 +228,21 @@ class Settings(BaseSettings):
frozen=True,
)

@model_validator(mode="after")
def check_consistency(self) -> "Settings":
"""Check if consistent.
# @model_validator(mode="after")
# def check_consistency(self) -> "Settings":
# """Check if consistent.

ATTENTION: Do not put model validators into the child settings. The
model validator is run during instantiation.
# ATTENTION: Do not put model validators into the child settings. The
# model validator is run during instantiation.

"""
# If you don't enforce keycloak auth, you need a way to communicate with the APIs the tools leverage
if not self.keycloak.password and not self.keycloak.validate_token:
raise ValueError(
"Need an auth method for subsequent APIs called by the tools."
)
# """
# # If you don't enforce keycloak auth, you need a way to communicate with the APIs the tools leverage
# if not self.keycloak.password and not self.keycloak.validate_token:
# raise ValueError(
# "Need an auth method for subsequent APIs called by the tools."
# )

return self
# return self


# Load the remaining variables into the environment
Expand Down
13 changes: 10 additions & 3 deletions src/neuroagent/app/routers/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
AgentRequest,
AgentResponse,
HILResponse,
VercelRequest,
)
from neuroagent.stream import stream_agent_response

Expand Down Expand Up @@ -95,7 +96,7 @@ async def run_chat_agent(

@router.post("/chat_streamed/{thread_id}")
async def stream_chat_agent(
user_request: AgentRequest,
user_request: VercelRequest,
request: Request,
agents_routine: Annotated[AgentsRoutine, Depends(get_agents_routine)],
agent: Annotated[Agent, Depends(get_starting_agent)],
Expand All @@ -114,7 +115,9 @@ async def stream_chat_agent(
order=len(messages),
thread_id=thread.thread_id,
entity=Entity.USER,
content=json.dumps({"role": "user", "content": user_request.query}),
content=json.dumps(
{"role": "user", "content": user_request.messages[0].content}
),
)
)
stream_generator = stream_agent_response(
Expand All @@ -125,4 +128,8 @@ async def stream_chat_agent(
thread,
request,
)
return StreamingResponse(stream_generator, media_type="text/event-stream")
return StreamingResponse(
stream_generator,
media_type="text/event-stream",
headers={"x-vercel-ai-data-stream": "v1"},
)
47 changes: 44 additions & 3 deletions src/neuroagent/app/routers/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from neuroagent.app.database.schemas import ToolCallSchema
from neuroagent.app.database.sql_schemas import Entity, Messages, Threads, ToolCalls
from neuroagent.app.dependencies import get_session, get_starting_agent
from neuroagent.new_types import Agent, HILValidation
from neuroagent.new_types import Agent, HILResponse, HILValidation

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -113,16 +113,57 @@ async def get_tool_returns(
return tool_output


@router.patch("/validate/{thread_id}")
@router.get("/validation/{thread_id}/")
async def get_required_validation(
    _: Annotated[Threads, Depends(get_thread)],
    thread_id: str,
    session: Annotated[AsyncSession, Depends(get_session)],
    starting_agent: Annotated[Agent, Depends(get_starting_agent)],
) -> list[HILResponse]:
    """List tool calls currently requiring validation in a thread.

    Looks at the most recent message of the thread only: pending
    human-in-the-loop (HIL) validations can exist solely when the latest
    message is an AI tool-call message. For each of its tool calls that
    belongs to an HIL tool and has not been validated yet, re-parse the
    stored arguments through the tool's input schema and return them to
    the client for confirmation.

    Parameters
    ----------
    _ : thread ownership check (dependency side effect only).
    thread_id : identifier of the thread to inspect.
    session : async DB session.
    starting_agent : agent whose registered tools define HIL behavior.

    Returns
    -------
    list[HILResponse]
        One entry per tool call awaiting validation; empty if none.
    """
    # Only the latest message matters: anything older has already been
    # processed (or superseded) by the conversation loop.
    message_query = await session.execute(
        select(Messages)
        .where(Messages.thread_id == thread_id)
        .order_by(desc(Messages.order))
        .limit(1)
    )
    message = message_query.scalar_one_or_none()
    if not message or message.entity != Entity.AI_TOOL:
        return []

    # awaitable_attrs: async lazy-load of the relationship (SQLAlchemy asyncio).
    tool_calls = await message.awaitable_attrs.tool_calls
    need_validation = []
    for tool_call in tool_calls:
        # Default of None guards against a tool that was removed from the
        # agent after the call was persisted; previously this raised
        # StopIteration (-> RuntimeError in an async context).
        tool = next(
            (tool for tool in starting_agent.tools if tool.name == tool_call.name),
            None,
        )
        if tool is None:
            continue
        if tool.hil and tool_call.validated is None:
            # Rebuild the validated input model from the stored JSON arguments
            # so the client sees them in canonical (schema-normalized) form.
            input_schema = tool.__annotations__["input_schema"](
                **json.loads(tool_call.arguments)
            )
            need_validation.append(
                HILResponse(
                    message="Please validate the following inputs before proceeding.",
                    name=tool_call.name,
                    inputs=input_schema.model_dump(),
                    tool_call_id=tool_call.tool_call_id,
                )
            )
    return need_validation


@router.patch("/validation/{thread_id}/{tool_call_id}")
async def validate_input(
user_request: HILValidation,
_: Annotated[Threads, Depends(get_thread)],
tool_call_id: str,
session: Annotated[AsyncSession, Depends(get_session)],
starting_agent: Annotated[Agent, Depends(get_starting_agent)],
) -> ToolCallSchema:
"""Validate HIL inputs."""
# We first find the AI TOOL message to modify.
tool_call = await session.get(ToolCalls, user_request.tool_call_id)
tool_call = await session.get(ToolCalls, tool_call_id)
if not tool_call:
raise HTTPException(status_code=404, detail="Specified tool call not found.")
if tool_call.validated is not None:
Expand Down
34 changes: 33 additions & 1 deletion src/neuroagent/new_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ class HILResponse(BaseModel):
"""Response for tools that require HIL validation."""

message: str
name: str
inputs: dict[str, Any]
tool_call_id: str


class HILValidation(BaseModel):
"""Class to send the validated json to the api."""

tool_call_id: str
validated_inputs: dict[str, Any] | None = None
is_validated: bool = True

Expand All @@ -56,6 +56,38 @@ class AgentResponse(BaseModel):
message: str = ""


class ClientAttachment(BaseModel):
    """File attachment in a Vercel AI SDK client message.

    Field names use camelCase to match the Vercel AI SDK wire format.
    """

    name: str
    contentType: str  # MIME type of the attachment
    url: str  # where the attachment content can be fetched


class ToolInvocation(BaseModel):
    """Tool call (with its result) as carried in a Vercel AI SDK client message.

    Field names use camelCase to match the Vercel AI SDK wire format.
    """

    toolCallId: str
    toolName: str
    args: dict[str, Any]  # arguments the tool was invoked with
    result: dict[str, Any]  # output returned by the tool


class ClientMessage(BaModel if False else BaseModel):
    """Single chat message as sent by a Vercel AI SDK client.

    Optional attachments and prior tool invocations may accompany the text.
    """

    role: str  # e.g. "user" / "assistant" — presumably the OpenAI-style role; TODO confirm
    content: str
    experimental_attachments: list[ClientAttachment] | None = None
    toolInvocations: list[ToolInvocation] | None = None


class VercelRequest(BaseModel):
    """Request body posted by the Vercel AI SDK `useChat` frontend: the message list."""

    messages: list[ClientMessage]


class Result(BaseModel):
"""
Encapsulates the possible return values for an agent function.
Expand Down
5 changes: 2 additions & 3 deletions src/neuroagent/stream.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Wrapper around streaming methods to reinitiate connections due to the way fastAPI StreamingResponse works."""

import json
from typing import Any, AsyncIterator

from fastapi import Request
Expand Down Expand Up @@ -51,9 +52,7 @@ async def stream_agent_response(
yield chunk
# Final chunk that contains the whole response
elif chunk.hil_messages:
yield str(
[hil_message.model_dump_json() for hil_message in chunk.hil_messages]
)
yield f"2:{json.dumps([hil_message.model_dump_json() for hil_message in chunk.hil_messages])}\n"

# Save the new messages in DB
thread.update_date = utc_now()
Expand Down
Loading
Loading