Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/img gen tool and display #55

Merged
merged 3 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions backend/app/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def init_db(session: Session) -> None:
description=skill_info.description,
managed=True,
owner_id=user.id,
display_name=skill_info.display_name
display_name=skill_info.display_name,
)
session.add(new_skill) # Prepare new skill for addition to the database

Expand All @@ -99,11 +99,7 @@ def init_modelprovider_model_db(session: Session) -> None:
(1, 'Ollama', 'fakeurl', 'fakeapikey', 'string', 'string fake'),
(2, 'Siliconflow', 'fakeurl', 'fakeapikey', 'string', 'siliconflow'),
(3, 'zhipuai', 'https://open.bigmodel.cn/api/paas/v4', 'fakeapikey', 'zhipuai', '智谱AI')
ON CONFLICT (id) DO UPDATE
SET base_url = EXCLUDED.base_url,
api_key = EXCLUDED.api_key,
icon = EXCLUDED.icon,
description = EXCLUDED.description;
ON CONFLICT (id) DO NOTHING;
"""

# Insert Models data
Expand Down
51 changes: 51 additions & 0 deletions backend/app/core/tools/siliconflow/siliconflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import requests
import json
import requests
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import StructuredTool


class Text2ImageInput(BaseModel):
"""Input for the text2img tool."""

prompt: str = Field(description="the prompt for generating image ")


def text2img(
prompt: str,
):
"""
invoke tools
"""

try:
# request URL
url = "https://api.siliconflow.cn/v1/image/generations"

payload = {
# "model": "black-forest-labs/FLUX.1-schnell",
"model": "stabilityai/stable-diffusion-3-medium",
"prompt": prompt,
"image_size": "1024x1024",
}
headers = {
"accept": "application/json",
"content-type": "application/json",
"authorization": "Bearer sk-uaxgsvfwwwpeuguzhsjpqigwopyhblsiesbptxnuxaoefqrb",
}

response = requests.post(url, json=payload, headers=headers)

return response.json()

except Exception as e:
return json.dumps(f"Openweather API Key is invalid. {e}")


siliconflow = StructuredTool.from_function(
func=text2img,
name="Image Generation",
description="Image Generation is a tool that can generate images from text prompts.",
args_schema=Text2ImageInput,
return_direct=True,
)
42 changes: 32 additions & 10 deletions backend/app/core/workflow/init_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
from functools import lru_cache
from typing import Any, Dict, Set

from langchain.pydantic_v1 import BaseModel
from langchain.tools import BaseTool
from langchain_core.messages import AIMessage, AnyMessage
from langchain_core.runnables import RunnableLambda
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph import END, StateGraph
from langgraph.graph import END, START, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt import ToolNode

Expand Down Expand Up @@ -200,31 +199,54 @@ def initialize_graph(
source_node = next(node for node in nodes if node["id"] == edge["source"])
target_node = next(node for node in nodes if node["id"] == edge["target"])

if source_node["type"] == "start":
if edge["type"] == "default":
graph_builder.add_edge(START, edge["target"])
else:
raise ValueError("Start node can only have normal edge.")
if source_node["type"] == "llm":
if target_node["type"] == "tool":
conditional_edges[source_node["id"]]["call_tools"][
target_node["id"]
] = target_node["id"]
if edge["type"] == "default":
graph_builder.add_edge(edge["source"], edge["target"])
else:
conditional_edges[source_node["id"]]["call_tools"][
target_node["id"]
] = target_node["id"]
elif target_node["type"] == "end":
conditional_edges[source_node["id"]]["default"][END] = END
if edge["type"] == "default":
graph_builder.add_edge(edge["source"], END)
else:
conditional_edges[source_node["id"]]["default"][END] = END
elif target_node["type"] == "llm":
if edge["type"] == "default":
graph_builder.add_edge(edge["source"], edge["target"])
else:
conditional_edges[source_node["id"]]["default"][
target_node["id"]
] = target_node["id"]
else:
conditional_edges[source_node["id"]]["default"][
target_node["id"]
] = target_node["id"]
if edge["type"] == "default":
graph_builder.add_edge(edge["source"], edge["target"])
else:
conditional_edges[source_node["id"]]["default"][
target_node["id"]
] = target_node["id"]

elif source_node["type"] == "tool" and target_node["type"] == "llm":
# Tool to LLM edge
graph_builder.add_edge(edge["source"], edge["target"])

# Add conditional edges

for llm_id, conditions in conditional_edges.items():
edges_dict = {
"default": next(iter(conditions["default"].values()), END),
**conditions["call_tools"],
}
if conditions["call_human"]:
edges_dict["call_human"] = next(iter(conditions["call_human"].values()))
graph_builder.add_conditional_edges(llm_id, should_continue, edges_dict)
if edges_dict != {"default": END}:
graph_builder.add_conditional_edges(llm_id, should_continue, edges_dict)

# Set entry point
graph_builder.set_entry_point(metadata["entry_point"])
Expand Down
2 changes: 1 addition & 1 deletion backend/app/core/workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ async def work(self, state: TeamState, config: RunnableConfig) -> ReturnTeamStat
"If you are unable to perform the task, that's OK, you can ask human for help, or just say that you are unable to perform the task."
"Execute what you can to make progress. "
"And your role is:" + self.system_prompt + "\n"
"And your name is:" + self.agent_name + "\n"
"And your name is:" + self.agent_name + "please remember your name\n"
"Stay true to your role and use your tools if necessary.\n\n",
),
(
Expand Down
2 changes: 1 addition & 1 deletion backend/app/initial_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
def init() -> None:
with Session(engine) as session:
init_db(session)
# init_modelprovider_model_db(session)
init_modelprovider_model_db(session)


def main() -> None:
Expand Down
1 change: 1 addition & 0 deletions web/src/components/Playground/ChatMain.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ const ChatMain = ({ isPlayground }: { isPlayground?: boolean }) => {
key={index}
message={message}
onResume={onResumeHandler}
isPlayground = {isPlayground}
/>
))}
</Box>
Expand Down
13 changes: 8 additions & 5 deletions web/src/components/Playground/MessageBox.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ import Markdown from "../Markdown/Markdown";
interface MessageBoxProps {
message: ChatResponse;
onResume: (decision: InterruptDecision, toolMessage: string | null) => void;
isPlayground?: boolean;
}

const MessageBox = ({ message, onResume }: MessageBoxProps) => {
const MessageBox = ({ message, onResume, isPlayground }: MessageBoxProps) => {
const { type, name, next, content, tool_calls, tool_output, documents } =
message;
const [decision, setDecision] = useState<InterruptDecision | null>(null);
Expand Down Expand Up @@ -89,10 +90,12 @@ const MessageBox = ({ message, onResume }: MessageBoxProps) => {
<VStack spacing={0} my={4} onMouseEnter={onOpen} onMouseLeave={onClose}>
<Box
w="full"
ml={10}
mr={10}
pl={10}
pr={10}

ml={isPlayground ? "10" : "0"}
mr={isPlayground ? "10" : "0"}
pl={isPlayground ? "10" : "0"}
pr={isPlayground ? "10" : "0"}

display="flex"
alignItems="center"
justifyContent={type === "human" ? "flex-end" : "flex-start"}
Expand Down
2 changes: 1 addition & 1 deletion web/src/components/WorkFlow/nodes/nodeConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export const nodeConfig: Record<string, NodeConfigItem> = {
targets: ["left", "right"],
},
initialData: {
tools: ["open-weather"],
tools: ["Open Weather"],
},
},
questionClassifier: {
Expand Down
Loading