diff --git a/README.md b/README.md index a9490035ac828..65985434439d5 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ Looking for the JS/TS library? Check out [LangChain.js](https://github.com/langc To help you ship LangChain apps to production faster, check out [LangSmith](https://smith.langchain.com). [LangSmith](https://smith.langchain.com) is a unified developer platform for building, testing, and monitoring LLM applications. -Fill out [this form](https://airtable.com/appwQzlErAS2qiP0L/shrGtGaVBVAz7NcV2) to get off the waitlist or speak with our sales team. +Fill out [this form](https://www.langchain.com/contact-sales) to speak with our sales team. ## Quick Install diff --git a/docs/docs/contributing/repo_structure.mdx b/docs/docs/contributing/repo_structure.mdx new file mode 100644 index 0000000000000..90f212265c85e --- /dev/null +++ b/docs/docs/contributing/repo_structure.mdx @@ -0,0 +1,54 @@ +--- +sidebar_position: 0.5 +--- +# Repository Structure + +If you plan on contributing to LangChain code or documentation, it can be useful +to understand the high-level structure of the repository. + +LangChain is organized as a [monorepo](https://en.wikipedia.org/wiki/Monorepo) that contains multiple packages. + +Here's the structure visualized as a tree: + +```text +. +├── cookbook # Tutorials and examples +├── docs # Contains content for the documentation here: https://python.langchain.com/ +├── libs +│ ├── langchain # Main package +│ │ ├── tests/unit_tests # Unit tests (present in each package not shown for brevity) +│ │ ├── tests/integration_tests # Integration tests (present in each package not shown for brevity) +│ ├── langchain-community # Third-party integrations +│ ├── langchain-core # Base interfaces for key abstractions +│ ├── langchain-experimental # Experimental components and chains +│ ├── partners +│ ├── langchain-partner-1 +│ ├── langchain-partner-2 +│ ├── ... 
+│ +├── templates # A collection of easily deployable reference architectures for a wide variety of tasks. +``` + +The root directory also contains the following files: + +* `pyproject.toml`: Dependencies for building the docs and for linting the docs and cookbook. +* `Makefile`: A file that contains shortcuts for building and linting the docs and cookbook. + +There are other files at the root directory level, but their presence should be self-explanatory. Feel free to browse around! + +## Documentation + +The `/docs` directory contains the content for the documentation that is shown +at https://python.langchain.com/ and the associated API Reference https://api.python.langchain.com/en/latest/langchain_api_reference.html. + +See the [documentation](./documentation) guidelines to learn how to contribute to the documentation. + +## Code + +The `/libs` directory contains the code for the LangChain packages. + +To learn more about how to contribute code see the following guidelines: + +- [Code](./code.mdx): learn how to develop in the LangChain codebase. +- [Integrations](./integrations.mdx): learn how to contribute to third-party integrations to langchain-community or to start a new partner package. +- [Testing](./testing.mdx): learn how to write tests for the packages. diff --git a/docs/docs/integrations/chat_loaders/langsmith_dataset.ipynb b/docs/docs/integrations/chat_loaders/langsmith_dataset.ipynb index 5b98b2cc96f3d..85586b03711a4 100644 --- a/docs/docs/integrations/chat_loaders/langsmith_dataset.ipynb +++ b/docs/docs/integrations/chat_loaders/langsmith_dataset.ipynb @@ -55,7 +55,7 @@ "source": [ "## 1. Select a dataset\n", "\n", - "This notebook fine-tunes a model directly on selecting which runs to fine-tune on. You will often curate these from traced runs. You can learn more about LangSmith datasets in the docs [docs](https://docs.smith.langchain.com/evaluation/datasets).\n", + "This notebook fine-tunes a model directly on selecting which runs to fine-tune on. 
You will often curate these from traced runs. You can learn more about LangSmith datasets in the docs [docs](https://docs.smith.langchain.com/evaluation/concepts#datasets).\n", "\n", "For the sake of this tutorial, we will upload an existing dataset here that you can use." ] diff --git a/libs/community/langchain_community/cache.py b/libs/community/langchain_community/cache.py index c8741c8f67ae0..5f36aa5856b44 100644 --- a/libs/community/langchain_community/cache.py +++ b/libs/community/langchain_community/cache.py @@ -29,12 +29,14 @@ import warnings from abc import ABC from datetime import timedelta -from functools import lru_cache +from functools import lru_cache, wraps from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Dict, + Generator, List, Optional, Sequence, @@ -56,20 +58,23 @@ from langchain_core.caches import RETURN_VAL_TYPE, BaseCache from langchain_core.embeddings import Embeddings -from langchain_core.language_models.llms import LLM, get_prompts +from langchain_core.language_models.llms import LLM, aget_prompts, get_prompts from langchain_core.load.dump import dumps from langchain_core.load.load import loads from langchain_core.outputs import ChatGeneration, Generation from langchain_core.utils import get_from_env -from langchain_community.utilities.astradb import AstraDBEnvironment +from langchain_community.utilities.astradb import ( + SetupMode, + _AstraDBCollectionEnvironment, +) from langchain_community.vectorstores.redis import Redis as RedisVectorstore logger = logging.getLogger(__file__) if TYPE_CHECKING: import momento - from astrapy.db import AstraDB + from astrapy.db import AstraDB, AsyncAstraDB from cassandra.cluster import Session as CassandraSession @@ -1371,6 +1376,10 @@ class AstraDBCache(BaseCache): (needed to prevent same-prompt-different-model collisions) """ + @staticmethod + def _make_id(prompt: str, llm_string: str) -> str: + return f"{_hash(prompt)}#{_hash(llm_string)}" + def __init__( self, *, @@ -1378,7 +1387,10 @@ 
def __init__( token: Optional[str] = None, api_endpoint: Optional[str] = None, astra_db_client: Optional[AstraDB] = None, + async_astra_db_client: Optional[AsyncAstraDB] = None, namespace: Optional[str] = None, + pre_delete_collection: bool = False, + setup_mode: SetupMode = SetupMode.SYNC, ): """ Create an AstraDB cache using a collection for storage. @@ -1388,29 +1400,35 @@ def __init__( token (Optional[str]): API token for Astra DB usage. api_endpoint (Optional[str]): full URL to the API endpoint, such as "https://-us-east1.apps.astra.datastax.com". - astra_db_client (Optional[Any]): *alternative to token+api_endpoint*, + astra_db_client (Optional[AstraDB]): + *alternative to token+api_endpoint*, you can pass an already-created 'astrapy.db.AstraDB' instance. + async_astra_db_client (Optional[AsyncAstraDB]): + *alternative to token+api_endpoint*, + you can pass an already-created 'astrapy.db.AsyncAstraDB' instance. namespace (Optional[str]): namespace (aka keyspace) where the collection is created. Defaults to the database's "default namespace". + pre_delete_collection (bool): whether to delete and re-create the + collection. Defaults to False. + setup_mode (SetupMode): mode used to create the collection in the DB + (SYNC, ASYNC or OFF; ASYNC needs a running asyncio event loop). Defaults to SYNC. 
""" - astra_env = AstraDBEnvironment( + self.astra_env = _AstraDBCollectionEnvironment( + collection_name=collection_name, token=token, api_endpoint=api_endpoint, astra_db_client=astra_db_client, + async_astra_db_client=async_astra_db_client, namespace=namespace, + setup_mode=setup_mode, + pre_delete_collection=pre_delete_collection, ) - self.astra_db = astra_env.astra_db - self.collection = self.astra_db.create_collection( - collection_name=collection_name, - ) - self.collection_name = collection_name - - @staticmethod - def _make_id(prompt: str, llm_string: str) -> str: - return f"{_hash(prompt)}#{_hash(llm_string)}" + self.collection = self.astra_env.collection + self.async_collection = self.astra_env.async_collection def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: """Look up based on prompt and llm_string.""" + self.astra_env.ensure_db_setup() doc_id = self._make_id(prompt, llm_string) item = self.collection.find_one( filter={ @@ -1420,18 +1438,27 @@ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: "body_blob": 1, }, )["data"]["document"] - if item is not None: - generations = _loads_generations(item["body_blob"]) - # this protects against malformed cached items: - if generations is not None: - return generations - else: - return None - else: - return None + return _loads_generations(item["body_blob"]) if item is not None else None + + async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + await self.astra_env.aensure_db_setup() + doc_id = self._make_id(prompt, llm_string) + item = ( + await self.async_collection.find_one( + filter={ + "_id": doc_id, + }, + projection={ + "body_blob": 1, + }, + ) + )["data"]["document"] + return _loads_generations(item["body_blob"]) if item is not None else None def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: """Update cache based on prompt and 
llm_string.""" + self.astra_env.ensure_db_setup() doc_id = self._make_id(prompt, llm_string) blob = _dumps_generations(return_val) self.collection.upsert( @@ -1441,6 +1468,20 @@ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> N }, ) + async def aupdate( + self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE + ) -> None: + """Update cache based on prompt and llm_string.""" + await self.astra_env.aensure_db_setup() + doc_id = self._make_id(prompt, llm_string) + blob = _dumps_generations(return_val) + await self.async_collection.upsert( + { + "_id": doc_id, + "body_blob": blob, + }, + ) + def delete_through_llm( self, prompt: str, llm: LLM, stop: Optional[List[str]] = None ) -> None: @@ -1454,14 +1495,42 @@ def delete_through_llm( )[1] return self.delete(prompt, llm_string=llm_string) + async def adelete_through_llm( + self, prompt: str, llm: LLM, stop: Optional[List[str]] = None + ) -> None: + """ + A wrapper around `adelete` with the LLM being passed. + In case the llm(prompt) calls have a `stop` param, you should pass it here + """ + llm_string = ( + await aget_prompts( + {**llm.dict(), **{"stop": stop}}, + [], + ) + )[1] + return await self.adelete(prompt, llm_string=llm_string) + def delete(self, prompt: str, llm_string: str) -> None: """Evict from cache if there's an entry.""" + self.astra_env.ensure_db_setup() doc_id = self._make_id(prompt, llm_string) self.collection.delete_one(doc_id) + async def adelete(self, prompt: str, llm_string: str) -> None: + """Evict from cache if there's an entry.""" + await self.astra_env.aensure_db_setup() + doc_id = self._make_id(prompt, llm_string) + await self.async_collection.delete_one(doc_id) + def clear(self, **kwargs: Any) -> None: """Clear cache. This is for all LLMs at once.""" - self.astra_db.truncate_collection(self.collection_name) + self.astra_env.ensure_db_setup() + self.collection.clear() + + async def aclear(self, **kwargs: Any) -> None: + """Clear cache. 
This is for all LLMs at once.""" + await self.astra_env.aensure_db_setup() + await self.async_collection.clear() ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD = 0.85 @@ -1469,6 +1538,42 @@ def clear(self, **kwargs: Any) -> None: ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE = 16 +_unset = ["unset"] + + +class _CachedAwaitable: + """Caches the result of an awaitable so it can be awaited multiple times""" + + def __init__(self, awaitable: Awaitable[Any]): + self.awaitable = awaitable + self.result = _unset + + def __await__(self) -> Generator: + if self.result is _unset: + self.result = yield from self.awaitable.__await__() + return self.result + + +def _reawaitable(func: Callable) -> Callable: + """Makes an async function result awaitable multiple times""" + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> _CachedAwaitable: + return _CachedAwaitable(func(*args, **kwargs)) + + return wrapper + + +def _async_lru_cache(maxsize: int = 128, typed: bool = False) -> Callable: + """Least-recently-used async cache decorator. + Equivalent to functools.lru_cache for async functions""" + + def decorating_function(user_function: Callable) -> Callable: + return lru_cache(maxsize, typed)(_reawaitable(user_function)) + + return decorating_function + + class AstraDBSemanticCache(BaseCache): """ Cache that uses Astra DB as a vector-store backend for semantic @@ -1479,7 +1584,7 @@ class AstraDBSemanticCache(BaseCache): in the document metadata. You can choose the preferred similarity (or use the API default) -- - remember the threshold might require metric-dependend tuning. + remember the threshold might require metric-dependent tuning. 
""" def __init__( @@ -1489,7 +1594,10 @@ def __init__( token: Optional[str] = None, api_endpoint: Optional[str] = None, astra_db_client: Optional[AstraDB] = None, + async_astra_db_client: Optional[AsyncAstraDB] = None, namespace: Optional[str] = None, + setup_mode: SetupMode = SetupMode.SYNC, + pre_delete_collection: bool = False, embedding: Embeddings, metric: Optional[str] = None, similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD, @@ -1502,10 +1610,17 @@ def __init__( token (Optional[str]): API token for Astra DB usage. api_endpoint (Optional[str]): full URL to the API endpoint, such as "https://-us-east1.apps.astra.datastax.com". - astra_db_client (Optional[Any]): *alternative to token+api_endpoint*, + astra_db_client (Optional[AstraDB]): *alternative to token+api_endpoint*, you can pass an already-created 'astrapy.db.AstraDB' instance. + async_astra_db_client (Optional[AsyncAstraDB]): + *alternative to token+api_endpoint*, + you can pass an already-created 'astrapy.db.AsyncAstraDB' instance. namespace (Optional[str]): namespace (aka keyspace) where the collection is created. Defaults to the database's "default namespace". + setup_mode (SetupMode): mode used to create the collection in the DB + (SYNC, ASYNC or OFF). Defaults to SYNC. + pre_delete_collection (bool): whether to delete and re-create the + collection. Defaults to False. embedding (Embedding): Embedding provider for semantic encoding and search. metric: the function to use for evaluating similarity of text embeddings. @@ -1516,17 +1631,10 @@ def __init__( The default score threshold is tuned to the default metric. Tune it carefully yourself if switching to another distance metric. 
""" - astra_env = AstraDBEnvironment( - token=token, - api_endpoint=api_endpoint, - astra_db_client=astra_db_client, - namespace=namespace, - ) - self.astra_db = astra_env.astra_db - self.embedding = embedding self.metric = metric self.similarity_threshold = similarity_threshold + self.collection_name = collection_name # The contract for this class has separate lookup and update: # in order to spare some embedding calculations we cache them between @@ -1538,25 +1646,47 @@ def _cache_embedding(text: str) -> List[float]: return self.embedding.embed_query(text=text) self._get_embedding = _cache_embedding - self.embedding_dimension = self._get_embedding_dimension() - self.collection_name = collection_name + @_async_lru_cache(maxsize=ASTRA_DB_SEMANTIC_CACHE_EMBEDDING_CACHE_SIZE) + async def _acache_embedding(text: str) -> List[float]: + return await self.embedding.aembed_query(text=text) + + self._aget_embedding = _acache_embedding + + embedding_dimension: Union[int, Awaitable[int], None] = None + if setup_mode == SetupMode.ASYNC: + embedding_dimension = self._aget_embedding_dimension() + elif setup_mode == SetupMode.SYNC: + embedding_dimension = self._get_embedding_dimension() - self.collection = self.astra_db.create_collection( - collection_name=self.collection_name, - dimension=self.embedding_dimension, - metric=self.metric, + self.astra_env = _AstraDBCollectionEnvironment( + collection_name=collection_name, + token=token, + api_endpoint=api_endpoint, + astra_db_client=astra_db_client, + async_astra_db_client=async_astra_db_client, + namespace=namespace, + setup_mode=setup_mode, + pre_delete_collection=pre_delete_collection, + embedding_dimension=embedding_dimension, + metric=metric, ) + self.collection = self.astra_env.collection + self.async_collection = self.astra_env.async_collection def _get_embedding_dimension(self) -> int: return len(self._get_embedding(text="This is a sample sentence.")) + async def _aget_embedding_dimension(self) -> int: + return len(await 
self._aget_embedding(text="This is a sample sentence.")) + @staticmethod def _make_id(prompt: str, llm_string: str) -> str: return f"{_hash(prompt)}#{_hash(llm_string)}" def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: """Update cache based on prompt and llm_string.""" + self.astra_env.ensure_db_setup() doc_id = self._make_id(prompt, llm_string) llm_string_hash = _hash(llm_string) embedding_vector = self._get_embedding(text=prompt) @@ -1571,6 +1701,25 @@ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> N } ) + async def aupdate( + self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE + ) -> None: + """Update cache based on prompt and llm_string.""" + await self.astra_env.aensure_db_setup() + doc_id = self._make_id(prompt, llm_string) + llm_string_hash = _hash(llm_string) + embedding_vector = await self._aget_embedding(text=prompt) + body = _dumps_generations(return_val) + # + await self.async_collection.upsert( + { + "_id": doc_id, + "body_blob": body, + "llm_string_hash": llm_string_hash, + "$vector": embedding_vector, + } + ) + def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: """Look up based on prompt and llm_string.""" hit_with_id = self.lookup_with_id(prompt, llm_string) @@ -1579,6 +1728,14 @@ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: else: return None + async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: + """Look up based on prompt and llm_string.""" + hit_with_id = await self.alookup_with_id(prompt, llm_string) + if hit_with_id is not None: + return hit_with_id[1] + else: + return None + def lookup_with_id( self, prompt: str, llm_string: str ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: @@ -1586,6 +1743,7 @@ def lookup_with_id( Look up based on prompt and llm_string. 
If there are hits, return (document_id, cached_entry) for the top hit """ + self.astra_env.ensure_db_setup() prompt_embedding: List[float] = self._get_embedding(text=prompt) llm_string_hash = _hash(llm_string) @@ -1604,7 +1762,37 @@ def lookup_with_id( generations = _loads_generations(hit["body_blob"]) if generations is not None: # this protects against malformed cached items: - return (hit["_id"], generations) + return hit["_id"], generations + else: + return None + + async def alookup_with_id( + self, prompt: str, llm_string: str + ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: + """ + Look up based on prompt and llm_string. + If there are hits, return (document_id, cached_entry) for the top hit + """ + await self.astra_env.aensure_db_setup() + prompt_embedding: List[float] = await self._aget_embedding(text=prompt) + llm_string_hash = _hash(llm_string) + + hit = await self.async_collection.vector_find_one( + vector=prompt_embedding, + filter={ + "llm_string_hash": llm_string_hash, + }, + fields=["body_blob", "_id"], + include_similarity=True, + ) + + if hit is None or hit["$similarity"] < self.similarity_threshold: + return None + else: + generations = _loads_generations(hit["body_blob"]) + if generations is not None: + # this protects against malformed cached items: + return hit["_id"], generations else: return None @@ -1617,14 +1805,41 @@ def lookup_with_id_through_llm( )[1] return self.lookup_with_id(prompt, llm_string=llm_string) + async def alookup_with_id_through_llm( + self, prompt: str, llm: LLM, stop: Optional[List[str]] = None + ) -> Optional[Tuple[str, RETURN_VAL_TYPE]]: + llm_string = ( + await aget_prompts( + {**llm.dict(), **{"stop": stop}}, + [], + ) + )[1] + return await self.alookup_with_id(prompt, llm_string=llm_string) + def delete_by_document_id(self, document_id: str) -> None: """ Given this is a "similarity search" cache, an invalidation pattern that makes sense is first a lookup to get an ID, and then deleting with that ID. 
This is for the second step. """ + self.astra_env.ensure_db_setup() self.collection.delete_one(document_id) + async def adelete_by_document_id(self, document_id: str) -> None: + """ + Given this is a "similarity search" cache, an invalidation pattern + that makes sense is first a lookup to get an ID, and then deleting + with that ID. This is for the second step. + """ + await self.astra_env.aensure_db_setup() + await self.async_collection.delete_one(document_id) + def clear(self, **kwargs: Any) -> None: """Clear the *whole* semantic cache.""" - self.astra_db.truncate_collection(self.collection_name) + self.astra_env.ensure_db_setup() + self.collection.clear() + + async def aclear(self, **kwargs: Any) -> None: + """Clear the *whole* semantic cache.""" + await self.astra_env.aensure_db_setup() + await self.async_collection.clear() diff --git a/libs/community/langchain_community/chat_message_histories/astradb.py b/libs/community/langchain_community/chat_message_histories/astradb.py index 7257476101a38..f820480ff26b7 100644 --- a/libs/community/langchain_community/chat_message_histories/astradb.py +++ b/libs/community/langchain_community/chat_message_histories/astradb.py @@ -3,9 +3,12 @@ import json import time -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Sequence -from langchain_community.utilities.astradb import AstraDBEnvironment +from langchain_community.utilities.astradb import ( + SetupMode, + _AstraDBCollectionEnvironment, +) if TYPE_CHECKING: from astrapy.db import AstraDB @@ -45,24 +48,30 @@ def __init__( api_endpoint: Optional[str] = None, astra_db_client: Optional[AstraDB] = None, namespace: Optional[str] = None, + setup_mode: SetupMode = SetupMode.SYNC, + pre_delete_collection: bool = False, ) -> None: """Create an Astra DB chat message history.""" - astra_env = AstraDBEnvironment( + self.astra_env = _AstraDBCollectionEnvironment( + collection_name=collection_name, token=token, 
api_endpoint=api_endpoint, astra_db_client=astra_db_client, namespace=namespace, + setup_mode=setup_mode, + pre_delete_collection=pre_delete_collection, ) - self.astra_db = astra_env.astra_db - self.collection = self.astra_db.create_collection(collection_name) + self.collection = self.astra_env.collection + self.async_collection = self.astra_env.async_collection self.session_id = session_id self.collection_name = collection_name @property - def messages(self) -> List[BaseMessage]: # type: ignore + def messages(self) -> List[BaseMessage]: """Retrieve all session messages from DB""" + self.astra_env.ensure_db_setup() message_blobs = [ doc["body_blob"] for doc in sorted( @@ -82,16 +91,63 @@ def messages(self) -> List[BaseMessage]: # type: ignore messages = messages_from_dict(items) return messages - def add_message(self, message: BaseMessage) -> None: + @messages.setter + def messages(self, messages: List[BaseMessage]) -> None: + raise NotImplementedError("Use add_messages instead") + + async def aget_messages(self) -> List[BaseMessage]: + """Retrieve all session messages from DB""" + await self.astra_env.aensure_db_setup() + docs = self.async_collection.paginated_find( + filter={ + "session_id": self.session_id, + }, + projection={ + "timestamp": 1, + "body_blob": 1, + }, + ) + sorted_docs = sorted( + [doc async for doc in docs], + key=lambda _doc: _doc["timestamp"], + ) + message_blobs = [doc["body_blob"] for doc in sorted_docs] + items = [json.loads(message_blob) for message_blob in message_blobs] + messages = messages_from_dict(items) + return messages + + def add_messages(self, messages: Sequence[BaseMessage]) -> None: """Write a message to the table""" - self.collection.insert_one( + self.astra_env.ensure_db_setup() + docs = [ { "timestamp": time.time(), "session_id": self.session_id, "body_blob": json.dumps(message_to_dict(message)), } - ) + for message in messages + ] + self.collection.chunked_insert_many(docs) + + async def aadd_messages(self, messages: 
Sequence[BaseMessage]) -> None: + """Write a message to the table""" + await self.astra_env.aensure_db_setup() + docs = [ + { + "timestamp": time.time(), + "session_id": self.session_id, + "body_blob": json.dumps(message_to_dict(message)), + } + for message in messages + ] + await self.async_collection.chunked_insert_many(docs) def clear(self) -> None: """Clear session memory from DB""" + self.astra_env.ensure_db_setup() self.collection.delete_many(filter={"session_id": self.session_id}) + + async def aclear(self) -> None: + """Clear session memory from DB""" + await self.astra_env.aensure_db_setup() + await self.async_collection.delete_many(filter={"session_id": self.session_id}) diff --git a/libs/community/langchain_community/document_loaders/astradb.py b/libs/community/langchain_community/document_loaders/astradb.py index 8a78e287a6e02..8dae5c4528a32 100644 --- a/libs/community/langchain_community/document_loaders/astradb.py +++ b/libs/community/langchain_community/document_loaders/astradb.py @@ -2,8 +2,6 @@ import json import logging -import threading -from queue import Queue from typing import ( TYPE_CHECKING, Any, @@ -16,10 +14,9 @@ ) from langchain_core.documents import Document -from langchain_core.runnables import run_in_executor from langchain_community.document_loaders.base import BaseLoader -from langchain_community.utilities.astradb import AstraDBEnvironment +from langchain_community.utilities.astradb import _AstraDBEnvironment if TYPE_CHECKING: from astrapy.db import AstraDB, AsyncAstraDB @@ -33,6 +30,7 @@ class AstraDBLoader(BaseLoader): def __init__( self, collection_name: str, + *, token: Optional[str] = None, api_endpoint: Optional[str] = None, astra_db_client: Optional[AstraDB] = None, @@ -44,7 +42,7 @@ def __init__( nb_prefetched: int = 1000, extraction_function: Callable[[Dict], str] = json.dumps, ) -> None: - astra_env = AstraDBEnvironment( + astra_env = _AstraDBEnvironment( token=token, api_endpoint=api_endpoint, 
astra_db_client=astra_db_client, @@ -65,38 +63,27 @@ def load(self) -> List[Document]: return list(self.lazy_load()) def lazy_load(self) -> Iterator[Document]: - queue = Queue(self.nb_prefetched) # type: ignore - t = threading.Thread(target=self.fetch_results, args=(queue,)) - t.start() - while True: - doc = queue.get() - if doc is None: - break - yield doc - t.join() + for doc in self.collection.paginated_find( + filter=self.filter, + options=self.find_options, + projection=self.projection, + sort=None, + prefetched=self.nb_prefetched, + ): + yield Document( + page_content=self.extraction_function(doc), + metadata={ + "namespace": self.collection.astra_db.namespace, + "api_endpoint": self.collection.astra_db.base_url, + "collection": self.collection_name, + }, + ) async def aload(self) -> List[Document]: """Load data into Document objects.""" return [doc async for doc in self.alazy_load()] async def alazy_load(self) -> AsyncIterator[Document]: - if not self.astra_env.async_astra_db: - iterator = run_in_executor( - None, - self.collection.paginated_find, - filter=self.filter, - options=self.find_options, - projection=self.projection, - sort=None, - prefetched=True, - ) - done = object() - while True: - item = await run_in_executor(None, lambda it: next(it, done), iterator) - if item is done: - break - yield item # type: ignore[misc] - return async_collection = await self.astra_env.async_astra_db.collection( self.collection_name ) @@ -105,7 +92,7 @@ async def alazy_load(self) -> AsyncIterator[Document]: options=self.find_options, projection=self.projection, sort=None, - prefetched=True, + prefetched=self.nb_prefetched, ): yield Document( page_content=self.extraction_function(doc), @@ -115,29 +102,3 @@ async def alazy_load(self) -> AsyncIterator[Document]: "collection": self.collection_name, }, ) - - def fetch_results(self, queue: Queue): # type: ignore[no-untyped-def] - self.fetch_page_result(queue) - while self.find_options.get("pageState"): - 
self.fetch_page_result(queue) - queue.put(None) - - def fetch_page_result(self, queue: Queue): # type: ignore[no-untyped-def] - res = self.collection.find( - filter=self.filter, - options=self.find_options, - projection=self.projection, - sort=None, - ) - self.find_options["pageState"] = res["data"].get("nextPageState") - for doc in res["data"]["documents"]: - queue.put( - Document( - page_content=self.extraction_function(doc), - metadata={ - "namespace": self.collection.astra_db.namespace, - "api_endpoint": self.collection.astra_db.base_url, - "collection": self.collection.collection_name, - }, - ) - ) diff --git a/libs/community/langchain_community/document_loaders/cassandra.py b/libs/community/langchain_community/document_loaders/cassandra.py index a3b7732c131cb..3cef56a1cbcc8 100644 --- a/libs/community/langchain_community/document_loaders/cassandra.py +++ b/libs/community/langchain_community/document_loaders/cassandra.py @@ -29,18 +29,18 @@ def __init__( table: Optional[str] = None, session: Optional[Session] = None, keyspace: Optional[str] = None, - query: Optional[Union[str, Statement]] = None, + query: Union[str, Statement, None] = None, page_content_mapper: Callable[[Any], str] = str, metadata_mapper: Callable[[Any], dict] = lambda _: {}, *, - query_parameters: Union[dict, Sequence] = None, # type: ignore[assignment] + query_parameters: Union[dict, Sequence, None] = None, query_timeout: Optional[float] = _NOT_SET, # type: ignore[assignment] query_trace: bool = False, - query_custom_payload: dict = None, # type: ignore[assignment] + query_custom_payload: Optional[dict] = None, query_execution_profile: Any = _NOT_SET, query_paging_state: Any = None, - query_host: Host = None, - query_execute_as: str = None, # type: ignore[assignment] + query_host: Optional[Host] = None, + query_execute_as: Optional[str] = None, ) -> None: """ Document Loader for Apache Cassandra. 
diff --git a/libs/community/langchain_community/storage/astradb.py b/libs/community/langchain_community/storage/astradb.py index f84ae1721c837..0cb2ea310aad2 100644 --- a/libs/community/langchain_community/storage/astradb.py +++ b/libs/community/langchain_community/storage/astradb.py @@ -16,7 +16,7 @@ from langchain_core.stores import BaseStore, ByteStore -from langchain_community.utilities.astradb import AstraDBEnvironment +from langchain_community.utilities.astradb import _AstraDBEnvironment if TYPE_CHECKING: from astrapy.db import AstraDB @@ -35,7 +35,7 @@ def __init__( astra_db_client: Optional[AstraDB] = None, namespace: Optional[str] = None, ) -> None: - astra_env = AstraDBEnvironment( + astra_env = _AstraDBEnvironment( token=token, api_endpoint=api_endpoint, astra_db_client=astra_db_client, diff --git a/libs/community/langchain_community/utilities/astradb.py b/libs/community/langchain_community/utilities/astradb.py index 3ad3d3274974d..c113d660792b6 100644 --- a/libs/community/langchain_community/utilities/astradb.py +++ b/libs/community/langchain_community/utilities/astradb.py @@ -1,6 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +import asyncio +import inspect +from asyncio import InvalidStateError, Task +from enum import Enum +from typing import TYPE_CHECKING, Awaitable, Optional, Union if TYPE_CHECKING: from astrapy.db import ( @@ -9,7 +13,13 @@ ) -class AstraDBEnvironment: +class SetupMode(Enum): + SYNC = 1 + ASYNC = 2 + OFF = 3 + + +class _AstraDBEnvironment: def __init__( self, token: Optional[str] = None, @@ -21,21 +31,20 @@ def __init__( self.token = token self.api_endpoint = api_endpoint astra_db = astra_db_client - self.async_astra_db = async_astra_db_client + async_astra_db = async_astra_db_client self.namespace = namespace - from astrapy import db - try: - from astrapy.db import AstraDB + from astrapy.db import ( + AstraDB, + AsyncAstraDB, + ) except (ImportError, ModuleNotFoundError): raise ImportError( 
"Could not import a recent astrapy python package. " "Please install it with `pip install --upgrade astrapy`." ) - supports_async = hasattr(db, "AsyncAstraDB") - # Conflicting-arg checks: if astra_db_client is not None or async_astra_db_client is not None: if token is not None or api_endpoint is not None: @@ -46,39 +55,115 @@ def __init__( if token and api_endpoint: astra_db = AstraDB( - token=self.token, - api_endpoint=self.api_endpoint, + token=token, + api_endpoint=api_endpoint, + namespace=self.namespace, + ) + async_astra_db = AsyncAstraDB( + token=token, + api_endpoint=api_endpoint, namespace=self.namespace, ) - if supports_async: - self.async_astra_db = db.AsyncAstraDB( - token=self.token, - api_endpoint=self.api_endpoint, - namespace=self.namespace, - ) if astra_db: self.astra_db = astra_db + if async_astra_db: + self.async_astra_db = async_astra_db + else: + self.async_astra_db = AsyncAstraDB( + token=self.astra_db.token, + api_endpoint=self.astra_db.base_url, + api_path=self.astra_db.api_path, + api_version=self.astra_db.api_version, + namespace=self.astra_db.namespace, + ) + elif async_astra_db: + self.async_astra_db = async_astra_db + self.astra_db = AstraDB( + token=self.async_astra_db.token, + api_endpoint=self.async_astra_db.base_url, + api_path=self.async_astra_db.api_path, + api_version=self.async_astra_db.api_version, + namespace=self.async_astra_db.namespace, + ) else: - if self.async_astra_db: - self.astra_db = AstraDB( - token=self.async_astra_db.token, - api_endpoint=self.async_astra_db.base_url, - api_path=self.async_astra_db.api_path, - api_version=self.async_astra_db.api_version, - namespace=self.async_astra_db.namespace, + raise ValueError( + "Must provide 'astra_db_client' or 'async_astra_db_client' or " + "'token' and 'api_endpoint'" + ) + + +class _AstraDBCollectionEnvironment(_AstraDBEnvironment): + def __init__( + self, + collection_name: str, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + astra_db_client: 
Optional[AstraDB] = None, + async_astra_db_client: Optional[AsyncAstraDB] = None, + namespace: Optional[str] = None, + setup_mode: SetupMode = SetupMode.SYNC, + pre_delete_collection: bool = False, + embedding_dimension: Union[int, Awaitable[int], None] = None, + metric: Optional[str] = None, + ) -> None: + from astrapy.db import AstraDBCollection, AsyncAstraDBCollection + + super().__init__( + token, api_endpoint, astra_db_client, async_astra_db_client, namespace + ) + self.collection_name = collection_name + self.collection = AstraDBCollection( + collection_name=collection_name, + astra_db=self.astra_db, + ) + + self.async_collection = AsyncAstraDBCollection( + collection_name=collection_name, + astra_db=self.async_astra_db, + ) + + self.async_setup_db_task: Optional[Task] = None + if setup_mode == SetupMode.ASYNC: + async_astra_db = self.async_astra_db + + async def _setup_db() -> None: + if pre_delete_collection: + await async_astra_db.delete_collection(collection_name) + if inspect.isawaitable(embedding_dimension): + dimension = await embedding_dimension + else: + dimension = embedding_dimension + await async_astra_db.create_collection( + collection_name, dimension=dimension, metric=metric ) - else: + + self.async_setup_db_task = asyncio.create_task(_setup_db()) + elif setup_mode == SetupMode.SYNC: + if pre_delete_collection: + self.astra_db.delete_collection(collection_name) + if inspect.isawaitable(embedding_dimension): raise ValueError( - "Must provide 'astra_db_client' or 'async_astra_db_client' or " - "'token' and 'api_endpoint'" + "Cannot use an awaitable embedding_dimension with async_setup " + "set to False" ) - - if not self.async_astra_db and self.astra_db and supports_async: - self.async_astra_db = db.AsyncAstraDB( - token=self.astra_db.token, - api_endpoint=self.astra_db.base_url, - api_path=self.astra_db.api_path, - api_version=self.astra_db.api_version, - namespace=self.astra_db.namespace, + self.astra_db.create_collection( + collection_name, + 
dimension=embedding_dimension, # type: ignore[arg-type] + metric=metric, ) + + def ensure_db_setup(self) -> None: + if self.async_setup_db_task: + try: + self.async_setup_db_task.result() + except InvalidStateError: + raise ValueError( + "Asynchronous setup of the DB not finished. " + "NB: AstraDB components sync methods shouldn't be called from the " + "event loop. Consider using their async equivalents." + ) + + async def aensure_db_setup(self) -> None: + if self.async_setup_db_task: + await self.async_setup_db_task diff --git a/libs/community/langchain_community/utilities/graphql.py b/libs/community/langchain_community/utilities/graphql.py index 87be94d09c362..a576419e5be07 100644 --- a/libs/community/langchain_community/utilities/graphql.py +++ b/libs/community/langchain_community/utilities/graphql.py @@ -37,7 +37,10 @@ def validate_environment(cls, values: Dict) -> Dict: url=values["graphql_endpoint"], headers=headers, ) - client = Client(transport=transport, fetch_schema_from_transport=True) + fetch_schema_from_transport = values.get("fetch_schema_from_transport", True) + client = Client( + transport=transport, fetch_schema_from_transport=fetch_schema_from_transport + ) values["gql_client"] = client values["gql_function"] = gql return values diff --git a/libs/community/tests/integration_tests/document_loaders/test_astradb.py b/libs/community/tests/integration_tests/document_loaders/test_astradb.py index 8f9146aacb51f..b0a1104f82316 100644 --- a/libs/community/tests/integration_tests/document_loaders/test_astradb.py +++ b/libs/community/tests/integration_tests/document_loaders/test_astradb.py @@ -15,7 +15,7 @@ import json import os import uuid -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, AsyncIterator, Iterator import pytest @@ -37,12 +37,12 @@ def _has_env_vars() -> bool: @pytest.fixture -def astra_db_collection() -> AstraDBCollection: +def astra_db_collection() -> Iterator[AstraDBCollection]: from astrapy.db import AstraDB astra_db = 
AstraDB( - token=ASTRA_DB_APPLICATION_TOKEN, - api_endpoint=ASTRA_DB_API_ENDPOINT, + token=ASTRA_DB_APPLICATION_TOKEN or "", + api_endpoint=ASTRA_DB_API_ENDPOINT or "", namespace=ASTRA_DB_KEYSPACE, ) collection_name = f"lc_test_loader_{str(uuid.uuid4()).split('-')[0]}" @@ -58,12 +58,12 @@ def astra_db_collection() -> AstraDBCollection: @pytest.fixture -async def async_astra_db_collection() -> AsyncAstraDBCollection: +async def async_astra_db_collection() -> AsyncIterator[AsyncAstraDBCollection]: from astrapy.db import AsyncAstraDB astra_db = AsyncAstraDB( - token=ASTRA_DB_APPLICATION_TOKEN, - api_endpoint=ASTRA_DB_API_ENDPOINT, + token=ASTRA_DB_APPLICATION_TOKEN or "", + api_endpoint=ASTRA_DB_API_ENDPOINT or "", namespace=ASTRA_DB_KEYSPACE, ) collection_name = f"lc_test_loader_{str(uuid.uuid4()).split('-')[0]}" @@ -167,5 +167,5 @@ async def test_extraction_function_async( find_options={"limit": 30}, extraction_function=lambda x: x["foo"], ) - doc = await anext(loader.alazy_load()) # type: ignore[name-defined] + doc = await loader.alazy_load().__anext__() assert doc.page_content == "bar" diff --git a/libs/community/tests/integration_tests/document_loaders/test_cassandra.py b/libs/community/tests/integration_tests/document_loaders/test_cassandra.py index 5562188eced86..a93a6abba6824 100644 --- a/libs/community/tests/integration_tests/document_loaders/test_cassandra.py +++ b/libs/community/tests/integration_tests/document_loaders/test_cassandra.py @@ -2,7 +2,7 @@ Test of Cassandra document loader class `CassandraLoader` """ import os -from typing import Any +from typing import Any, Iterator import pytest from langchain_core.documents import Document @@ -14,7 +14,7 @@ @pytest.fixture(autouse=True, scope="session") -def keyspace() -> str: # type: ignore[misc] +def keyspace() -> Iterator[str]: import cassio from cassandra.cluster import Cluster from cassio.config import check_resolve_session, resolve_keyspace diff --git a/libs/langchain/README.md 
b/libs/langchain/README.md index 2c8d69bcb4919..bd16bdf6e6fc8 100644 --- a/libs/langchain/README.md +++ b/libs/langchain/README.md @@ -20,7 +20,7 @@ Looking for the JS/TS version? Check out [LangChain.js](https://github.com/langc To help you ship LangChain apps to production faster, check out [LangSmith](https://smith.langchain.com). [LangSmith](https://smith.langchain.com) is a unified developer platform for building, testing, and monitoring LLM applications. -Fill out [this form](https://airtable.com/appwQzlErAS2qiP0L/shrGtGaVBVAz7NcV2) to get off the waitlist or speak with our sales team +Fill out [this form](https://www.langchain.com/contact-sales) to speak with our sales team. ## Quick Install diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 7f8aec6d21dbf..5c5bd1aadfa0c 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -464,7 +464,7 @@ def prep_outputs( return {**inputs, **outputs} def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]: - """Validate and prepare chain inputs, including adding inputs from memory. + """Prepare chain inputs, including adding inputs from memory. 
Args: inputs: Dictionary of raw inputs, or single input if chain expects diff --git a/libs/langchain/tests/integration_tests/cache/test_astradb.py b/libs/langchain/tests/integration_tests/cache/test_astradb.py index 37d538f8004c2..0c973e60fd1f2 100644 --- a/libs/langchain/tests/integration_tests/cache/test_astradb.py +++ b/libs/langchain/tests/integration_tests/cache/test_astradb.py @@ -12,9 +12,12 @@ """ import os -from typing import Iterator +from typing import AsyncIterator, Iterator import pytest +from langchain_community.utilities.astradb import SetupMode +from langchain_core.caches import BaseCache +from langchain_core.language_models import LLM from langchain_core.outputs import Generation, LLMResult from langchain.cache import AstraDBCache, AstraDBSemanticCache @@ -41,7 +44,22 @@ def astradb_cache() -> Iterator[AstraDBCache]: namespace=os.environ.get("ASTRA_DB_KEYSPACE"), ) yield cache - cache.astra_db.delete_collection("lc_integration_test_cache") + cache.collection.astra_db.delete_collection("lc_integration_test_cache") + + +@pytest.fixture +async def async_astradb_cache() -> AsyncIterator[AstraDBCache]: + cache = AstraDBCache( + collection_name="lc_integration_test_cache_async", + token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], + api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], + namespace=os.environ.get("ASTRA_DB_KEYSPACE"), + setup_mode=SetupMode.ASYNC, + ) + yield cache + await cache.async_collection.astra_db.delete_collection( + "lc_integration_test_cache_async" + ) @pytest.fixture(scope="module") @@ -55,46 +73,87 @@ def astradb_semantic_cache() -> Iterator[AstraDBSemanticCache]: embedding=fake_embe, ) yield sem_cache - sem_cache.astra_db.delete_collection("lc_integration_test_cache") + sem_cache.collection.astra_db.delete_collection("lc_integration_test_sem_cache") + + +@pytest.fixture +async def async_astradb_semantic_cache() -> AsyncIterator[AstraDBSemanticCache]: + fake_embe = FakeEmbeddings() + sem_cache = AstraDBSemanticCache( + 
collection_name="lc_integration_test_sem_cache_async", + token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], + api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], + namespace=os.environ.get("ASTRA_DB_KEYSPACE"), + embedding=fake_embe, + setup_mode=SetupMode.ASYNC, + ) + yield sem_cache + sem_cache.collection.astra_db.delete_collection( + "lc_integration_test_sem_cache_async" + ) @pytest.mark.requires("astrapy") @pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") class TestAstraDBCaches: def test_astradb_cache(self, astradb_cache: AstraDBCache) -> None: - set_llm_cache(astradb_cache) + self.do_cache_test(FakeLLM(), astradb_cache, "foo") + + async def test_astradb_cache_async(self, async_astradb_cache: AstraDBCache) -> None: + await self.ado_cache_test(FakeLLM(), async_astradb_cache, "foo") + + def test_astradb_semantic_cache( + self, astradb_semantic_cache: AstraDBSemanticCache + ) -> None: llm = FakeLLM() + self.do_cache_test(llm, astradb_semantic_cache, "bar") + output = llm.generate(["bar"]) # 'fizz' is erased away now + assert output != LLMResult( + generations=[[Generation(text="fizz")]], + llm_output={}, + ) + astradb_semantic_cache.clear() + + async def test_astradb_semantic_cache_async( + self, async_astradb_semantic_cache: AstraDBSemanticCache + ) -> None: + llm = FakeLLM() + await self.ado_cache_test(llm, async_astradb_semantic_cache, "bar") + output = await llm.agenerate(["bar"]) # 'fizz' is erased away now + assert output != LLMResult( + generations=[[Generation(text="fizz")]], + llm_output={}, + ) + await async_astradb_semantic_cache.aclear() + + @staticmethod + def do_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None: + set_llm_cache(cache) params = llm.dict() params["stop"] = None llm_string = str(sorted([(k, v) for k, v in params.items()])) get_llm_cache().update("foo", llm_string, [Generation(text="fizz")]) - output = llm.generate(["foo"]) - print(output) # noqa: T201 + output = llm.generate([prompt]) expected_output = 
LLMResult( generations=[[Generation(text="fizz")]], llm_output={}, ) - print(expected_output) # noqa: T201 assert output == expected_output - astradb_cache.clear() + # clear the cache + cache.clear() - def test_cassandra_semantic_cache( - self, astradb_semantic_cache: AstraDBSemanticCache - ) -> None: - set_llm_cache(astradb_semantic_cache) - llm = FakeLLM() + @staticmethod + async def ado_cache_test(llm: LLM, cache: BaseCache, prompt: str) -> None: + set_llm_cache(cache) params = llm.dict() params["stop"] = None llm_string = str(sorted([(k, v) for k, v in params.items()])) - get_llm_cache().update("foo", llm_string, [Generation(text="fizz")]) - output = llm.generate(["bar"]) # same embedding as 'foo' + await get_llm_cache().aupdate("foo", llm_string, [Generation(text="fizz")]) + output = await llm.agenerate([prompt]) expected_output = LLMResult( generations=[[Generation(text="fizz")]], llm_output={}, ) assert output == expected_output # clear the cache - astradb_semantic_cache.clear() - output = llm.generate(["bar"]) # 'fizz' is erased away now - assert output != expected_output - astradb_semantic_cache.clear() + await cache.aclear() diff --git a/libs/langchain/tests/integration_tests/memory/test_astradb.py b/libs/langchain/tests/integration_tests/memory/test_astradb.py index a8ed9e7574ba7..4caf39985ec8e 100644 --- a/libs/langchain/tests/integration_tests/memory/test_astradb.py +++ b/libs/langchain/tests/integration_tests/memory/test_astradb.py @@ -1,10 +1,11 @@ import os -from typing import Iterable +from typing import AsyncIterable, Iterable import pytest from langchain_community.chat_message_histories.astradb import ( AstraDBChatMessageHistory, ) +from langchain_community.utilities.astradb import SetupMode from langchain_core.messages import AIMessage, HumanMessage from langchain.memory import ConversationBufferMemory @@ -29,7 +30,7 @@ def history1() -> Iterable[AstraDBChatMessageHistory]: namespace=os.environ.get("ASTRA_DB_KEYSPACE"), ) yield history1 - 
history1.astra_db.delete_collection("langchain_cmh_test") + history1.collection.astra_db.delete_collection("langchain_cmh_test") @pytest.fixture(scope="function") @@ -42,7 +43,35 @@ def history2() -> Iterable[AstraDBChatMessageHistory]: namespace=os.environ.get("ASTRA_DB_KEYSPACE"), ) yield history2 - history2.astra_db.delete_collection("langchain_cmh_test") + history2.collection.astra_db.delete_collection("langchain_cmh_test") + + +@pytest.fixture +async def async_history1() -> AsyncIterable[AstraDBChatMessageHistory]: + history1 = AstraDBChatMessageHistory( + session_id="async-session-test-1", + collection_name="langchain_cmh_test", + token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], + api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], + namespace=os.environ.get("ASTRA_DB_KEYSPACE"), + setup_mode=SetupMode.ASYNC, + ) + yield history1 + await history1.async_collection.astra_db.delete_collection("langchain_cmh_test") + + +@pytest.fixture(scope="function") +async def async_history2() -> AsyncIterable[AstraDBChatMessageHistory]: + history2 = AstraDBChatMessageHistory( + session_id="async-session-test-2", + collection_name="langchain_cmh_test", + token=os.environ["ASTRA_DB_APPLICATION_TOKEN"], + api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"], + namespace=os.environ.get("ASTRA_DB_KEYSPACE"), + setup_mode=SetupMode.ASYNC, + ) + yield history2 + await history2.async_collection.astra_db.delete_collection("langchain_cmh_test") @pytest.mark.requires("astrapy") @@ -58,8 +87,12 @@ def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None: assert memory.chat_memory.messages == [] # add some messages - memory.chat_memory.add_ai_message("This is me, the AI") - memory.chat_memory.add_user_message("This is me, the human") + memory.chat_memory.add_messages( + [ + AIMessage(content="This is me, the AI"), + HumanMessage(content="This is me, the human"), + ] + ) messages = memory.chat_memory.messages expected = [ @@ -74,6 +107,41 @@ def 
test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None: assert memory.chat_memory.messages == [] +@pytest.mark.requires("astrapy") +@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") +async def test_memory_with_message_store_async( + async_history1: AstraDBChatMessageHistory, +) -> None: + """Test the memory with a message store.""" + memory = ConversationBufferMemory( + memory_key="baz", + chat_memory=async_history1, + return_messages=True, + ) + + assert await memory.chat_memory.aget_messages() == [] + + # add some messages + await memory.chat_memory.aadd_messages( + [ + AIMessage(content="This is me, the AI"), + HumanMessage(content="This is me, the human"), + ] + ) + + messages = await memory.chat_memory.aget_messages() + expected = [ + AIMessage(content="This is me, the AI"), + HumanMessage(content="This is me, the human"), + ] + assert messages == expected + + # clear the store + await memory.chat_memory.aclear() + + assert await memory.chat_memory.aget_messages() == [] + + @pytest.mark.requires("astrapy") @pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars") def test_memory_separate_session_ids( @@ -91,7 +159,7 @@ def test_memory_separate_session_ids( return_messages=True, ) - memory1.chat_memory.add_ai_message("Just saying.") + memory1.chat_memory.add_messages([AIMessage(content="Just saying.")]) assert memory2.chat_memory.messages == [] @@ -102,3 +170,33 @@ def test_memory_separate_session_ids( memory1.chat_memory.clear() assert memory1.chat_memory.messages == [] + + +@pytest.mark.requires("astrapy") +@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. 
vars") +async def test_memory_separate_session_ids_async( + async_history1: AstraDBChatMessageHistory, async_history2: AstraDBChatMessageHistory +) -> None: + """Test that separate session IDs do not share entries.""" + memory1 = ConversationBufferMemory( + memory_key="mk1", + chat_memory=async_history1, + return_messages=True, + ) + memory2 = ConversationBufferMemory( + memory_key="mk2", + chat_memory=async_history2, + return_messages=True, + ) + + await memory1.chat_memory.aadd_messages([AIMessage(content="Just saying.")]) + + assert await memory2.chat_memory.aget_messages() == [] + + await memory2.chat_memory.aclear() + + assert await memory1.chat_memory.aget_messages() != [] + + await memory1.chat_memory.aclear() + + assert await memory1.chat_memory.aget_messages() == []