core[minor]: Enhance cache flexibility in BaseChatModel #17386

Merged
21 commits merged on Mar 19, 2024
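In short, the change lets the per-model `cache` field accept a concrete `BaseCache` instance in addition to the existing boolean flag, so one model can carry its own cache rather than relying only on the globally configured one. A minimal usage sketch under that reading of the diff; `ChatOpenAI` and `InMemoryCache` are stand-ins for illustration and are not part of this PR:

```python
from langchain_core.caches import InMemoryCache
from langchain_openai import ChatOpenAI  # illustrative; any BaseChatModel subclass works

# Existing behaviour: `cache` is a tri-state flag.
#   None/True -> use the globally configured cache (set via set_llm_cache), False -> no caching.
flag_cached_model = ChatOpenAI(cache=True)

# With this change: attach a BaseCache instance to a single model, so different
# models can use different cache backends.
my_cache = InMemoryCache()
instance_cached_model = ChatOpenAI(cache=my_cache)

instance_cached_model.invoke("Hello!")  # computed by the provider, then written to my_cache
instance_cached_model.invoke("Hello!")  # the repeated prompt is served from my_cache
```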
libs/core/langchain_core/language_models/chat_models.py: 30 changes (22 additions, 8 deletions)
@@ -47,6 +47,7 @@
 from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
 from langchain_core.pydantic_v1 import Field, root_validator
 from langchain_core.runnables.config import ensure_config, run_in_executor
+from langchain_core.caches import BaseCache

 if TYPE_CHECKING:
     from langchain_core.runnables import RunnableConfig
@@ -103,13 +104,14 @@ async def agenerate_from_stream(
 class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     """Base class for Chat models."""

-    cache: Optional[bool] = None
+    cache: Optional[bool | BaseCache] = None
"""Whether to cache the response."""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to add to the run trace."""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
callback_manager: Optional[BaseCallbackManager] = Field(
default=None, exclude=True)
"""[DEPRECATED] Callback manager to add to the run trace."""
tags: Optional[List[str]] = Field(default=None, exclude=True)
"""Tags to add to the run trace."""
@@ -205,7 +207,8 @@ def stream(
         if type(self)._stream == BaseChatModel._stream:
             # model doesn't implement streaming, so use default implementation
             yield cast(
-                BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
+                BaseMessageChunk, self.invoke(
+                    input, config=config, stop=stop, **kwargs)
             )
         else:
             config = ensure_config(config)
@@ -404,13 +407,15 @@ def generate(
                 )
             except BaseException as e:
                 if run_managers:
-                    run_managers[i].on_llm_error(e, response=LLMResult(generations=[]))
+                    run_managers[i].on_llm_error(
+                        e, response=LLMResult(generations=[]))
                 raise e
         flattened_outputs = [
             LLMResult(generations=[res.generations], llm_output=res.llm_output)
             for res in results
         ]
-        llm_output = self._combine_llm_outputs([res.llm_output for res in results])
+        llm_output = self._combine_llm_outputs(
+            [res.llm_output for res in results])
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
         if run_managers:
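As a reading aid for the hunk above (not part of the diff): `generate` collects one `ChatResult` per input prompt, and the combined `LLMResult` holds one inner list of generations per prompt plus the provider-specific `llm_output` merged by `_combine_llm_outputs`. A small sketch of those shapes:

```python
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult

# Combined output: one inner list of generations per input prompt.
output = LLMResult(
    generations=[
        [ChatGeneration(message=AIMessage(content="answer to prompt 1"))],
        [ChatGeneration(message=AIMessage(content="answer to prompt 2"))],
    ],
    llm_output={"token_usage": {}},  # contents are provider-specific
)

# Per-prompt "flattened" view, as used for the on_llm_end callbacks: each inner
# list is wrapped in its own LLMResult.
flattened_outputs = [
    LLMResult(generations=[gen_list], llm_output=output.llm_output)
    for gen_list in output.generations
]
```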
@@ -504,7 +509,8 @@ async def agenerate(
                 *[
                     run_manager.on_llm_end(
                         LLMResult(
-                            generations=[res.generations], llm_output=res.llm_output
+                            generations=[
+                                res.generations], llm_output=res.llm_output
                         )
                     )
                     for run_manager, res in zip(run_managers, results)
@@ -516,7 +522,8 @@
             LLMResult(generations=[res.generations], llm_output=res.llm_output)
             for res in results
         ]
-        llm_output = self._combine_llm_outputs([res.llm_output for res in results])
+        llm_output = self._combine_llm_outputs(
+            [res.llm_output for res in results])
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
         await asyncio.gather(
@@ -566,6 +573,12 @@ def _generate_with_cache(
             "run_manager"
         )
         disregard_cache = self.cache is not None and not self.cache
+        # Add custom cache check
+        if isinstance(self.cache, BaseCache):
+            cache_key = self._get_cache_key(messages, stop, **kwargs)
+            cached_result = self.cache.lookup(cache_key)
+            if cached_result is not None:
+                return cached_result
         llm_cache = get_llm_cache()
         if llm_cache is None or disregard_cache:
             # This happens when langchain.cache is None, but self.cache is True
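For reference, anything assigned to the new field would subclass `langchain_core.caches.BaseCache`. Below is a minimal dict-backed sketch of such a cache; the class is illustrative, and note that the published `BaseCache` interface keys entries by a `(prompt, llm_string)` pair, whereas the hunk above routes through its own `_get_cache_key` helper:

```python
from typing import Any, Dict, Optional, Sequence, Tuple

from langchain_core.caches import BaseCache
from langchain_core.outputs import Generation


class DictCache(BaseCache):
    """Illustrative in-memory cache keyed by (prompt, llm_string)."""

    def __init__(self) -> None:
        self._store: Dict[Tuple[str, str], Sequence[Generation]] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[Sequence[Generation]]:
        # Return previously stored generations, or None on a cache miss.
        return self._store.get((prompt, llm_string))

    def update(self, prompt: str, llm_string: str, return_val: Sequence[Generation]) -> None:
        # Store the generations produced for this prompt and model configuration.
        self._store[(prompt, llm_string)] = return_val

    def clear(self, **kwargs: Any) -> None:
        self._store.clear()
```

With the field change above, such an instance can be passed directly as `cache=DictCache()` when constructing a chat model.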
@@ -803,7 +816,8 @@ def _generate(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
+        output_str = self._call(messages, stop=stop,
+                                run_manager=run_manager, **kwargs)
         message = AIMessage(content=output_str)
         generation = ChatGeneration(message=message)
         return ChatResult(generations=[generation])
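The hunk above only re-wraps the `_call` invocation, but it also shows the `SimpleChatModel` contract: a subclass returns a plain string from `_call`, and this `_generate` wraps it into `AIMessage`, `ChatGeneration`, and finally `ChatResult`. A minimal illustrative subclass (the `EchoChatModel` name is made up for this sketch):

```python
from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import SimpleChatModel
from langchain_core.messages import BaseMessage


class EchoChatModel(SimpleChatModel):
    """Toy model: echoes the last message back to the caller."""

    @property
    def _llm_type(self) -> str:
        return "echo-chat-model"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # The string returned here is what _generate wraps into a ChatResult.
        return str(messages[-1].content)


# EchoChatModel().invoke("hi") -> AIMessage(content="hi")
```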