diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 97afdfaca0e46..8f10ce770f0a5 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -23,6 +23,7 @@ import uuid import warnings from abc import abstractmethod +from functools import partial from inspect import signature from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union @@ -32,9 +33,17 @@ BaseCallbackManager, CallbackManager, CallbackManagerForToolRun, +) +from langchain_core.callbacks.manager import ( Callbacks, ) from langchain_core.load.serializable import Serializable +from langchain_core.prompts import ( + BasePromptTemplate, + PromptTemplate, + aformat_document, + format_document, +) from langchain_core.pydantic_v1 import ( BaseModel, Extra, @@ -44,6 +53,7 @@ root_validator, validate_arguments, ) +from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import ( Runnable, RunnableConfig, @@ -920,3 +930,111 @@ def _partial(func: Callable[[str], str]) -> BaseTool: return _partial else: raise ValueError("Too many arguments for tool decorator") + + +class RetrieverInput(BaseModel): + """Input to the retriever.""" + + query: str = Field(description="query to look up in retriever") + + +def _get_relevant_documents( + query: str, + retriever: BaseRetriever, + document_prompt: BasePromptTemplate, + document_separator: str, + callbacks: Callbacks = None, +) -> str: + docs = retriever.get_relevant_documents(query, callbacks=callbacks) + return document_separator.join( + format_document(doc, document_prompt) for doc in docs + ) + + +async def _aget_relevant_documents( + query: str, + retriever: BaseRetriever, + document_prompt: BasePromptTemplate, + document_separator: str, + callbacks: Callbacks = None, +) -> str: + docs = await retriever.aget_relevant_documents(query, callbacks=callbacks) + return document_separator.join( + [await aformat_document(doc, document_prompt) for doc in docs] + ) + + +def create_retriever_tool( + retriever: BaseRetriever, + name: str, + description: str, + *, + document_prompt: Optional[BasePromptTemplate] = None, + document_separator: str = "\n\n", +) -> Tool: + """Create a tool to do retrieval of documents. + + Args: + retriever: The retriever to use for the retrieval + name: The name for the tool. This will be passed to the language model, + so should be unique and somewhat descriptive. + description: The description for the tool. This will be passed to the language + model, so should be descriptive. + + Returns: + Tool class to pass to an agent + """ + document_prompt = document_prompt or PromptTemplate.from_template("{page_content}") + func = partial( + _get_relevant_documents, + retriever=retriever, + document_prompt=document_prompt, + document_separator=document_separator, + ) + afunc = partial( + _aget_relevant_documents, + retriever=retriever, + document_prompt=document_prompt, + document_separator=document_separator, + ) + return Tool( + name=name, + description=description, + func=func, + coroutine=afunc, + args_schema=RetrieverInput, + ) + + +ToolsRenderer = Callable[[List[BaseTool]], str] + + +def render_text_description(tools: List[BaseTool]) -> str: + """Render the tool name and description in plain text. + + Output will be in the format of: + + .. code-block:: markdown + + search: This tool is used for search + calculator: This tool is used for math + """ + return "\n".join([f"{tool.name}: {tool.description}" for tool in tools]) + + +def render_text_description_and_args(tools: List[BaseTool]) -> str: + """Render the tool name, description, and args in plain text. + + Output will be in the format of: + + .. code-block:: markdown + + search: This tool is used for search, args: {"query": {"type": "string"}} + calculator: This tool is used for math, \ +args: {"expression": {"type": "string"}} + """ + tool_strings = [] + for tool in tools: + args_schema = str(tool.args) + tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}") + return "\n".join(tool_strings) diff --git a/libs/langchain/langchain/tools/render.py b/libs/langchain/langchain/tools/render.py index cb6fde55ead3b..f8494bde14e54 100644 --- a/libs/langchain/langchain/tools/render.py +++ b/libs/langchain/langchain/tools/render.py @@ -4,10 +4,13 @@ you may want Tools to be rendered in a different way. This module contains various ways to render tools. """ -from typing import Callable, List # For backwards compatibility -from langchain_core.tools import BaseTool +from langchain_core.tools import ( + ToolsRenderer, + render_text_description, + render_text_description_and_args, +) from langchain_core.utils.function_calling import ( format_tool_to_openai_function, format_tool_to_openai_tool, @@ -20,37 +23,3 @@ "format_tool_to_openai_tool", "format_tool_to_openai_function", ] - - -ToolsRenderer = Callable[[List[BaseTool]], str] - - -def render_text_description(tools: List[BaseTool]) -> str: - """Render the tool name and description in plain text. - - Output will be in the format of: - - .. code-block:: markdown - - search: This tool is used for search - calculator: This tool is used for math - """ - return "\n".join([f"{tool.name}: {tool.description}" for tool in tools]) - - -def render_text_description_and_args(tools: List[BaseTool]) -> str: - """Render the tool name, description, and args in plain text. - - Output will be in the format of: - - .. code-block:: markdown - - search: This tool is used for search, args: {"query": {"type": "string"}} - calculator: This tool is used for math, \ -args: {"expression": {"type": "string"}} - """ - tool_strings = [] - for tool in tools: - args_schema = str(tool.args) - tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}") - return "\n".join(tool_strings) diff --git a/libs/langchain/langchain/tools/retriever.py b/libs/langchain/langchain/tools/retriever.py index 5feeab6e04b77..6d76c02b5257f 100644 --- a/libs/langchain/langchain/tools/retriever.py +++ b/libs/langchain/langchain/tools/retriever.py @@ -1,90 +1,15 @@ -from functools import partial -from typing import Optional - -from langchain_core.callbacks.manager import ( - Callbacks, -) -from langchain_core.prompts import ( - BasePromptTemplate, - PromptTemplate, - aformat_document, - format_document, +from langchain_core.tools import ( + RetrieverInput, + ToolsRenderer, + create_retriever_tool, + render_text_description, + render_text_description_and_args, ) -from langchain_core.pydantic_v1 import BaseModel, Field -from langchain_core.retrievers import BaseRetriever - -from langchain.tools import Tool - - -class RetrieverInput(BaseModel): - """Input to the retriever.""" - - query: str = Field(description="query to look up in retriever") - - -def _get_relevant_documents( - query: str, - retriever: BaseRetriever, - document_prompt: BasePromptTemplate, - document_separator: str, - callbacks: Callbacks = None, -) -> str: - docs = retriever.get_relevant_documents(query, callbacks=callbacks) - return document_separator.join( - format_document(doc, document_prompt) for doc in docs - ) - - -async def _aget_relevant_documents( - query: str, - retriever: BaseRetriever, - document_prompt: BasePromptTemplate, - document_separator: str, - callbacks: Callbacks = None, -) -> str: - docs = await retriever.aget_relevant_documents(query, callbacks=callbacks) - return document_separator.join( - [await aformat_document(doc, document_prompt) for doc in docs] - ) - - -def create_retriever_tool( - retriever: BaseRetriever, - name: str, - description: str, - *, - document_prompt: Optional[BasePromptTemplate] = None, - document_separator: str = "\n\n", -) -> Tool: - """Create a tool to do retrieval of documents. - - Args: - retriever: The retriever to use for the retrieval - name: The name for the tool. This will be passed to the language model, - so should be unique and somewhat descriptive. - description: The description for the tool. This will be passed to the language - model, so should be descriptive. - Returns: - Tool class to pass to an agent - """ - document_prompt = document_prompt or PromptTemplate.from_template("{page_content}") - func = partial( - _get_relevant_documents, - retriever=retriever, - document_prompt=document_prompt, - document_separator=document_separator, - ) - afunc = partial( - _aget_relevant_documents, - retriever=retriever, - document_prompt=document_prompt, - document_separator=document_separator, - ) - return Tool( - name=name, - description=description, - func=func, - coroutine=afunc, - args_schema=RetrieverInput, - ) +__all__ = [ + "RetrieverInput", + "ToolsRenderer", + "create_retriever_tool", + "render_text_description", + "render_text_description_and_args", +]