-
Notifications
You must be signed in to change notification settings - Fork 16k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
core[minor], langchain[patch]:
tools
dependencies refactoring (#18759)
The `langchain.tools` [namespace](https://api.python.langchain.com/en/latest/langchain_api_reference.html#module-langchain.tools) can be completely eliminated by moving one class and 3 functions into `core`. It makes sense since the class and functions are very core.
- Loading branch information
Showing
3 changed files
with
136 additions
and
124 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,90 +1,15 @@ | ||
from functools import partial | ||
from typing import Optional | ||
|
||
from langchain_core.callbacks.manager import ( | ||
Callbacks, | ||
) | ||
from langchain_core.prompts import ( | ||
BasePromptTemplate, | ||
PromptTemplate, | ||
aformat_document, | ||
format_document, | ||
from langchain_core.tools import ( | ||
RetrieverInput, | ||
ToolsRenderer, | ||
create_retriever_tool, | ||
render_text_description, | ||
render_text_description_and_args, | ||
) | ||
from langchain_core.pydantic_v1 import BaseModel, Field | ||
from langchain_core.retrievers import BaseRetriever | ||
|
||
from langchain.tools import Tool | ||
|
||
|
||
class RetrieverInput(BaseModel): | ||
"""Input to the retriever.""" | ||
|
||
query: str = Field(description="query to look up in retriever") | ||
|
||
|
||
def _get_relevant_documents( | ||
query: str, | ||
retriever: BaseRetriever, | ||
document_prompt: BasePromptTemplate, | ||
document_separator: str, | ||
callbacks: Callbacks = None, | ||
) -> str: | ||
docs = retriever.get_relevant_documents(query, callbacks=callbacks) | ||
return document_separator.join( | ||
format_document(doc, document_prompt) for doc in docs | ||
) | ||
|
||
|
||
async def _aget_relevant_documents( | ||
query: str, | ||
retriever: BaseRetriever, | ||
document_prompt: BasePromptTemplate, | ||
document_separator: str, | ||
callbacks: Callbacks = None, | ||
) -> str: | ||
docs = await retriever.aget_relevant_documents(query, callbacks=callbacks) | ||
return document_separator.join( | ||
[await aformat_document(doc, document_prompt) for doc in docs] | ||
) | ||
|
||
|
||
def create_retriever_tool( | ||
retriever: BaseRetriever, | ||
name: str, | ||
description: str, | ||
*, | ||
document_prompt: Optional[BasePromptTemplate] = None, | ||
document_separator: str = "\n\n", | ||
) -> Tool: | ||
"""Create a tool to do retrieval of documents. | ||
Args: | ||
retriever: The retriever to use for the retrieval | ||
name: The name for the tool. This will be passed to the language model, | ||
so should be unique and somewhat descriptive. | ||
description: The description for the tool. This will be passed to the language | ||
model, so should be descriptive. | ||
|
||
Returns: | ||
Tool class to pass to an agent | ||
""" | ||
document_prompt = document_prompt or PromptTemplate.from_template("{page_content}") | ||
func = partial( | ||
_get_relevant_documents, | ||
retriever=retriever, | ||
document_prompt=document_prompt, | ||
document_separator=document_separator, | ||
) | ||
afunc = partial( | ||
_aget_relevant_documents, | ||
retriever=retriever, | ||
document_prompt=document_prompt, | ||
document_separator=document_separator, | ||
) | ||
return Tool( | ||
name=name, | ||
description=description, | ||
func=func, | ||
coroutine=afunc, | ||
args_schema=RetrieverInput, | ||
) | ||
__all__ = [ | ||
"RetrieverInput", | ||
"ToolsRenderer", | ||
"create_retriever_tool", | ||
"render_text_description", | ||
"render_text_description_and_args", | ||
] |