From 7704696b74446a69e486133362e670da6cf2d037 Mon Sep 17 00:00:00 2001
From: Onelevenvy <49232224+Onelevenvy@users.noreply.github.com>
Date: Tue, 24 Sep 2024 15:28:19 +0800
Subject: [PATCH] feat: add system prompt support -- backend (#44)

---
 backend/app/core/workflow/init_graph.py |  2 +
 backend/app/core/workflow/node.py       | 97 +++++++++++++++----------
 2 files changed, 60 insertions(+), 39 deletions(-)

diff --git a/backend/app/core/workflow/init_graph.py b/backend/app/core/workflow/init_graph.py
index eaf2151..3efabb9 100644
--- a/backend/app/core/workflow/init_graph.py
+++ b/backend/app/core/workflow/init_graph.py
@@ -140,6 +140,7 @@ def initialize_graph(
             break
     if model_info is None:
         raise ValueError(f"Model {model_name} not supported now.")
+    # TODO: in the future we can use more LangChain templates here and apply them to different node types
 
     if is_sequential:
         # node_class = SequentialWorkerNode
@@ -179,6 +180,7 @@ def initialize_graph(
                     openai_api_key=model_info["api_key"],
                     openai_api_base=model_info["base_url"],
                     temperature=node_data["temperature"],
+                    system_prompt=node_data["systemMessage"],
                 ).work
             ),
         )
diff --git a/backend/app/core/workflow/node.py b/backend/app/core/workflow/node.py
index ddfbd67..9d98734 100644
--- a/backend/app/core/workflow/node.py
+++ b/backend/app/core/workflow/node.py
@@ -162,8 +162,9 @@ def __init__(
         openai_api_key: str,
         openai_api_base: str,
         temperature: float,
+        system_prompt: str,
     ):
-
+        self.system_prompt = system_prompt
         if provider in ["zhipuai", "Siliconflow"]:
             self.model = ChatOpenAI(
                 model=model,
@@ -236,6 +237,62 @@ def get_team_members_name(
     return ",".join(list(team_members))
 
 
+class LLMNode(BaseNode):
+    """Perform LLM Node actions"""
+
+    async def work(self, state: TeamState, config: RunnableConfig) -> ReturnTeamState:
+        if self.system_prompt:
+            llm_node_prompts = ChatPromptTemplate.from_messages(
+                [
+                    (
+                        "system",
+                        "Perform the task given to you.\n"
+                        "If you are unable to perform the task, that's OK, you can ask human for help, or just say that you are unable to perform the task."
+                        "Execute what you can to make progress. "
+                        "And your role is:" + self.system_prompt + "\n"
+                        "Stay true to your role and use your tools if necessary.\n\n",
+                    ),
+                    (
+                        "human",
+                        "Here is the previous conversation: \n\n {history_string} \n\n Provide your response.",
+                    ),
+                    MessagesPlaceholder(variable_name="messages"),
+                ]
+            )
+
+        else:
+            llm_node_prompts = ChatPromptTemplate.from_messages(
+                [
+                    (
+                        "system",
+                        (
+                            "Perform the task given to you.\n"
+                            "If you are unable to perform the task, that's OK, you can ask human for help, or just say that you are unable to perform the task."
+                            "Execute what you can to make progress. "
+                            "Stay true to your role and use your tools if necessary.\n\n"
+                        ),
+                    ),
+                    (
+                        "human",
+                        "Here is the previous conversation: \n\n {history_string} \n\n Provide your response.",
+                    ),
+                    MessagesPlaceholder(variable_name="messages"),
+                ]
+            )
+        history = state.get("history", [])
+        messages = state.get("messages", [])
+        prompt = llm_node_prompts.partial(history_string=format_messages(history))
+        chain: RunnableSerializable[dict[str, Any], AnyMessage] = prompt | self.model
+        result: AIMessage = await chain.ainvoke(state, config)
+
+        return_state: ReturnTeamState = {
+            "history": history + [result],
+            "messages": [result] if result.tool_calls else [],
+            "all_messages": messages + [result],
+        }
+        return return_state
+
+
 class WorkerNode(BaseNode):
     worker_prompt = ChatPromptTemplate.from_messages(
         [
@@ -334,44 +391,6 @@ async def work(self, state: TeamState, config: RunnableConfig) -> ReturnTeamStat
         return return_state
 
 
-class LLMNode(BaseNode):
-    """Perform Sequential Worker actions"""
-
-    worker_prompt = ChatPromptTemplate.from_messages(
-        [
-            (
-                "system",
-                (
-                    "Perform the task given to you.\n"
-                    "If you are unable to perform the task, that's OK, you can ask human for help, or just say that you are unable to perform the task."
-                    "Execute what you can to make progress. "
-                    "Stay true to your role and use your tools if necessary.\n\n"
-                ),
-            ),
-            (
-                "human",
-                "Here is the previous conversation: \n\n {history_string} \n\n Provide your response.",
-            ),
-            MessagesPlaceholder(variable_name="messages"),
-        ]
-    )
-
-    async def work(self, state: TeamState, config: RunnableConfig) -> ReturnTeamState:
-        history = state.get("history", [])
-        messages = state.get("messages", [])
-        prompt = self.worker_prompt.partial(history_string=format_messages(history))
-        chain: RunnableSerializable[dict[str, Any], AnyMessage] = prompt | self.model
-        work_chain = chain
-        result: AIMessage = await work_chain.ainvoke(state, config)
-
-        return_state: ReturnTeamState = {
-            "history": history + [result],
-            "messages": [result] if result.tool_calls else [],
-            "all_messages": messages + [result],
-        }
-        return return_state
-
-
 class LeaderNode(BaseNode):
     leader_prompt = ChatPromptTemplate.from_messages(
         [