Skip to content

Commit

Permalink
Merge Existing changes from spc-langchain to skypoint-langchain (#3)
Browse files Browse the repository at this point in the history
* Merge Existing changes from spc-langchain to skypoint-langchain

* Update agent_toolkits imports and repository URL

* Update tools and version in langchain library

* Version dump
  • Loading branch information
arunraja1 authored Jan 4, 2024
1 parent b08201e commit ff181d7
Show file tree
Hide file tree
Showing 15 changed files with 1,053 additions and 7 deletions.
2 changes: 2 additions & 0 deletions libs/langchain/langchain/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
BaseMultiActionAgent,
BaseSingleActionAgent,
LLMSingleActionAgent,
StateAgentExecutor,
)
from langchain.agents.agent_iterator import AgentExecutorIterator
from langchain.agents.agent_toolkits import (
Expand Down Expand Up @@ -111,6 +112,7 @@ def __getattr__(name: str) -> Any:
"ReActTextWorldAgent",
"SelfAskWithSearchChain",
"StructuredChatAgent",
"StateAgentExecutor",
"Tool",
"ZeroShotAgent",
"create_json_agent",
Expand Down
151 changes: 150 additions & 1 deletion libs/langchain/langchain/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage
from langchain.schema.runnable import Runnable
from langchain.tools.base import BaseTool
from langchain.tools.base import BaseTool,StateTool
from langchain.utilities.asyncio import asyncio_timeout
from langchain.utils.input import get_color_mapping

Expand Down Expand Up @@ -1263,3 +1263,152 @@ def _prepare_intermediate_steps(
return self.trim_intermediate_steps(intermediate_steps)
else:
return intermediate_steps


class StateAgentExecutor(AgentExecutor):
"""Agent that is using tools."""

agent: Union[BaseSingleActionAgent, BaseMultiActionAgent]
"""The agent to run for creating a plan and determining actions
to take at each step of the execution loop."""
tools: Sequence[BaseTool]
"""The valid tools the agent can call."""
return_intermediate_steps: bool = False
"""Whether to return the agent's trajectory of intermediate steps
at the end in addition to the final output."""
max_iterations: Optional[int] = 15
"""The maximum number of steps to take before ending the execution
loop. Setting to 'None' could lead to an infinite loop."""

max_execution_time: Optional[float] = None
"""The maximum amount of wall clock time to spend in the execution
loop.
"""
early_stopping_method: str = "force"
"""The method to use for early stopping if the agent never
returns `AgentFinish`. Either 'force' or 'generate'.
`"force"` returns a string saying that it stopped because it met a
time or iteration limit.
`"generate"` calls the agent's LLM Chain one final time to generate
a final answer based on the previous steps.
"""
handle_parsing_errors: Union[
bool, str, Callable[[OutputParserException], str]
] = False
"""How to handle errors raised by the agent's output parser.
Defaults to `False`, which raises the error.
s
If `true`, the error will be sent back to the LLM as an observation.
If a string, the string itself will be sent to the LLM as an observation.
If a callable function, the function will be called with the exception
as an argument, and the result of that function will be passed to the agent
as an observation.
"""
state: List[Dict[str, Any]] = None
trim_intermediate_steps: Union[
int, Callable[[List[Tuple[AgentAction, str]]], List[Tuple[AgentAction, str]]]
] = -1

def _take_next_step(
self,
name_to_tool_map: Dict[str, StateTool],
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
"""Take a single step in the thought-action-observation loop.
Override this to take control of how the agent makes and acts on choices.
"""
try:
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)

# Call the LLM to see what to do.
output = self.agent.plan(
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)
except OutputParserException as e:
if isinstance(self.handle_parsing_errors, bool):
raise_error = not self.handle_parsing_errors
else:
raise_error = False
if raise_error:
raise e
text = str(e)
if isinstance(self.handle_parsing_errors, bool):
if e.send_to_llm:
observation = str(e.observation)
text = str(e.llm_output)
else:
observation = "Invalid or incomplete response"
elif isinstance(self.handle_parsing_errors, str):
observation = self.handle_parsing_errors
elif callable(self.handle_parsing_errors):
observation = self.handle_parsing_errors(e)
else:
raise ValueError("Got unexpected type of `handle_parsing_errors`")
output = AgentAction("_Exception", observation, text)
if run_manager:
run_manager.on_agent_action(output, color="green")
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = ExceptionTool().run(
output.tool_input,
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
return [(output, observation)]
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
return output
actions: List[AgentAction]
if isinstance(output, AgentAction):
actions = [output]
else:
actions = output
result = []
for agent_action in actions:
if run_manager:
run_manager.on_agent_action(agent_action, color="green")
# Otherwise we lookup the tool
if agent_action.tool in name_to_tool_map:
tool = name_to_tool_map[agent_action.tool]
return_direct = tool.return_direct
color = color_mapping[agent_action.tool]
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
if intermediate_steps:
self.state.append(
{str(intermediate_steps[-1][0]): str(intermediate_steps[-1][1])}
)
else:
self.state = []
# metadata_dict = {}
if return_direct:
tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation
observation = tool.run(
agent_action.tool_input,
verbose=self.verbose,
color=color,
state=self.state,
callbacks=run_manager.get_child()
if run_manager
else None**tool_run_kwargs,
)
else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = InvalidTool().run(
{
"requested_tool_name": agent_action.tool,
"available_tool_names": list(name_to_tool_map.keys()),
},
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
result.append((agent_action, observation))
return result
3 changes: 2 additions & 1 deletion libs/langchain/langchain/agents/agent_toolkits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
from langchain.agents.agent_toolkits.sql.base import create_sql_agent
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.agent_toolkits.unitycatalog.toolkit import UCSQLDatabaseToolkit
from langchain.agents.agent_toolkits.vectorstore.base import (
create_vectorstore_agent,
create_vectorstore_router_agent,
Expand All @@ -55,7 +56,6 @@
VectorStoreToolkit,
)
from langchain.agents.agent_toolkits.zapier.toolkit import ZapierToolkit
from langchain.tools.retriever import create_retriever_tool

DEPRECATED_AGENTS = [
"create_csv_agent",
Expand Down Expand Up @@ -97,6 +97,7 @@ def __getattr__(name: str) -> Any:
"PowerBIToolkit",
"SQLDatabaseToolkit",
"SparkSQLToolkit",
"UCSQLDatabaseToolkit",
"VectorStoreInfo",
"VectorStoreRouterToolkit",
"VectorStoreToolkit",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Unity Catalog SQL agent."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""SQL agent."""
from typing import Any, Dict, List, Optional

from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.agent_toolkits.sql.prompt import (
SQL_FUNCTIONS_SUFFIX,
SQL_PREFIX,
SQL_SUFFIX,
)
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import AIMessage, SystemMessage


def create_sql_agent(
llm: BaseLanguageModel,
toolkit: SQLDatabaseToolkit,
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = SQL_PREFIX,
suffix: Optional[str] = None,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
top_k: int = 10,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct an SQL agent from an LLM and tools."""
tools = toolkit.get_tools()
prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
agent: BaseSingleActionAgent

if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix or SQL_SUFFIX,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)

elif agent_type == AgentType.OPENAI_FUNCTIONS:
messages = [
SystemMessage(content=prefix),
HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
input_variables = ["input", "agent_scratchpad"]
_prompt = ChatPromptTemplate(input_variables=input_variables, messages=messages)

agent = OpenAIFunctionsAgent(
llm=llm,
prompt=_prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)
else:
raise ValueError(f"Agent type {agent_type} not supported at the moment.")

return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# flake8: noqa

SQL_PREFIX = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
"""

SQL_SUFFIX = """Begin!
Question: {input}
Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
{agent_scratchpad}"""

SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."""
Loading

0 comments on commit ff181d7

Please sign in to comment.