From af8c5c185bf1daf6e42fd31ea3b833b9f578a73a Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Wed, 31 Jan 2024 20:08:11 +0100 Subject: [PATCH] langchain[minor],community[minor]: Add async methods in BaseLoader (#16634) Adds: * methods `aload()` and `alazy_load()` to interface `BaseLoader` * implementation for class `MergedDataLoader` * support for class `BaseLoader` in async function `aindex()` with unit tests Note: this is compatible with existing `aload()` methods that some loaders already had. **Twitter handle:** @cbornet_ --------- Co-authored-by: Eugene Yurtsev --- .../document_loaders/base.py | 17 ++++-- .../document_loaders/merge.py | 8 ++- .../unit_tests/document_loaders/test_base.py | 21 +++++++- libs/langchain/langchain/indexes/_api.py | 24 +++++---- .../tests/unit_tests/indexes/test_indexing.py | 53 ++++++------------- 5 files changed, 71 insertions(+), 52 deletions(-) diff --git a/libs/community/langchain_community/document_loaders/base.py b/libs/community/langchain_community/document_loaders/base.py index 8474fa0a579ec..7a3e5a2706c80 100644 --- a/libs/community/langchain_community/document_loaders/base.py +++ b/libs/community/langchain_community/document_loaders/base.py @@ -2,9 +2,10 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Iterator, List, Optional +from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Optional from langchain_core.documents import Document +from langchain_core.runnables import run_in_executor from langchain_community.document_loaders.blob_loaders import Blob @@ -52,14 +53,22 @@ def load_and_split( # Attention: This method will be upgraded into an abstractmethod once it's # implemented in all the existing subclasses. 
- def lazy_load( - self, - ) -> Iterator[Document]: + def lazy_load(self) -> Iterator[Document]: """A lazy loader for Documents.""" raise NotImplementedError( f"{self.__class__.__name__} does not implement lazy_load()" ) + async def alazy_load(self) -> AsyncIterator[Document]: + """A lazy loader for Documents.""" + iterator = await run_in_executor(None, self.lazy_load) + done = object() + while True: + doc = await run_in_executor(None, next, iterator, done) + if doc is done: + break + yield doc + class BaseBlobParser(ABC): """Abstract interface for blob parsers. diff --git a/libs/community/langchain_community/document_loaders/merge.py b/libs/community/langchain_community/document_loaders/merge.py index c93963e70cace..9ef1a0fd3c102 100644 --- a/libs/community/langchain_community/document_loaders/merge.py +++ b/libs/community/langchain_community/document_loaders/merge.py @@ -1,4 +1,4 @@ -from typing import Iterator, List +from typing import AsyncIterator, Iterator, List from langchain_core.documents import Document @@ -26,3 +26,9 @@ def lazy_load(self) -> Iterator[Document]: def load(self) -> List[Document]: """Load docs.""" return list(self.lazy_load()) + + async def alazy_load(self) -> AsyncIterator[Document]: + """Lazy load docs from each individual loader.""" + for loader in self.loaders: + async for document in loader.alazy_load(): + yield document diff --git a/libs/community/tests/unit_tests/document_loaders/test_base.py b/libs/community/tests/unit_tests/document_loaders/test_base.py index 18a3646aa9acd..e966cf193b6bf 100644 --- a/libs/community/tests/unit_tests/document_loaders/test_base.py +++ b/libs/community/tests/unit_tests/document_loaders/test_base.py @@ -1,9 +1,9 @@ """Test Base Schema of documents.""" -from typing import Iterator +from typing import Iterator, List from langchain_core.documents import Document -from langchain_community.document_loaders.base import BaseBlobParser +from langchain_community.document_loaders.base import BaseBlobParser, 
BaseLoader from langchain_community.document_loaders.blob_loaders import Blob @@ -27,3 +27,20 @@ def lazy_parse(self, blob: Blob) -> Iterator[Document]: docs = parser.parse(Blob(data="who?")) assert len(docs) == 1 assert docs[0].page_content == "foo" + + +async def test_default_aload() -> None: + class FakeLoader(BaseLoader): + def load(self) -> List[Document]: + return list(self.lazy_load()) + + def lazy_load(self) -> Iterator[Document]: + yield from [ + Document(page_content="foo"), + Document(page_content="bar"), + ] + + loader = FakeLoader() + docs = loader.load() + assert docs == [Document(page_content="foo"), Document(page_content="bar")] + assert docs == [doc async for doc in loader.alazy_load()] diff --git a/libs/langchain/langchain/indexes/_api.py b/libs/langchain/langchain/indexes/_api.py index 2f91c2ae45c6c..8dcf6e0c93a47 100644 --- a/libs/langchain/langchain/indexes/_api.py +++ b/libs/langchain/langchain/indexes/_api.py @@ -391,7 +391,7 @@ async def _to_async_iterator(iterator: Iterable[T]) -> AsyncIterator[T]: async def aindex( - docs_source: Union[Iterable[Document], AsyncIterator[Document]], + docs_source: Union[BaseLoader, Iterable[Document], AsyncIterator[Document]], record_manager: RecordManager, vector_store: VectorStore, *, @@ -469,16 +469,22 @@ async def aindex( # implementation which just raises a NotImplementedError raise ValueError("Vectorstore has not implemented the delete method") - if isinstance(docs_source, BaseLoader): - raise NotImplementedError( - "Not supported yet. Please pass an async iterator of documents." - ) async_doc_iterator: AsyncIterator[Document] - - if hasattr(docs_source, "__aiter__"): - async_doc_iterator = docs_source # type: ignore[assignment] + if isinstance(docs_source, BaseLoader): + try: + async_doc_iterator = docs_source.alazy_load() + except NotImplementedError: + # Exception triggered when neither lazy_load nor alazy_load are implemented. + # * The default implementation of alazy_load uses lazy_load. 
+ # * The default implementation of lazy_load raises NotImplementedError. + # In such a case, we use the load method and convert it to an async + # iterator. + async_doc_iterator = _to_async_iterator(docs_source.load()) else: - async_doc_iterator = _to_async_iterator(docs_source) + if hasattr(docs_source, "__aiter__"): + async_doc_iterator = docs_source # type: ignore[assignment] + else: + async_doc_iterator = _to_async_iterator(docs_source) source_id_assigner = _get_source_id_assigner(source_id_key) diff --git a/libs/langchain/tests/unit_tests/indexes/test_indexing.py b/libs/langchain/tests/unit_tests/indexes/test_indexing.py index 5febe24ffeeb8..59ab527543bae 100644 --- a/libs/langchain/tests/unit_tests/indexes/test_indexing.py +++ b/libs/langchain/tests/unit_tests/indexes/test_indexing.py @@ -43,15 +43,8 @@ def load(self) -> List[Document]: async def alazy_load( self, ) -> AsyncIterator[Document]: - async def async_generator() -> AsyncIterator[Document]: - for document in self.documents: - yield document - - return async_generator() - - async def aload(self) -> List[Document]: - """Load the documents from the source.""" - return [doc async for doc in await self.alazy_load()] + for document in self.documents: + yield document class InMemoryVectorStore(VectorStore): @@ -232,7 +225,7 @@ async def test_aindexing_same_content( ] ) - assert await aindex(await loader.alazy_load(), arecord_manager, vector_store) == { + assert await aindex(loader, arecord_manager, vector_store) == { "num_added": 2, "num_deleted": 0, "num_skipped": 0, @@ -243,9 +236,7 @@ async def test_aindexing_same_content( for _ in range(2): # Run the indexing again - assert await aindex( - await loader.alazy_load(), arecord_manager, vector_store - ) == { + assert await aindex(loader, arecord_manager, vector_store) == { "num_added": 0, "num_deleted": 0, "num_skipped": 2, @@ -347,9 +338,7 @@ async def test_aindex_simple_delete_full( with patch.object( arecord_manager, "aget_time", 
return_value=datetime(2021, 1, 1).timestamp() ): - assert await aindex( - await loader.alazy_load(), arecord_manager, vector_store, cleanup="full" - ) == { + assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == { "num_added": 2, "num_deleted": 0, "num_skipped": 0, @@ -359,9 +348,7 @@ async def test_aindex_simple_delete_full( with patch.object( arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp() ): - assert await aindex( - await loader.alazy_load(), arecord_manager, vector_store, cleanup="full" - ) == { + assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == { "num_added": 0, "num_deleted": 0, "num_skipped": 2, @@ -382,9 +369,7 @@ async def test_aindex_simple_delete_full( with patch.object( arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() ): - assert await aindex( - await loader.alazy_load(), arecord_manager, vector_store, cleanup="full" - ) == { + assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == { "num_added": 1, "num_deleted": 1, "num_skipped": 1, @@ -402,9 +387,7 @@ async def test_aindex_simple_delete_full( with patch.object( arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() ): - assert await aindex( - await loader.alazy_load(), arecord_manager, vector_store, cleanup="full" - ) == { + assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == { "num_added": 0, "num_deleted": 0, "num_skipped": 2, @@ -473,7 +456,7 @@ async def test_aincremental_fails_with_bad_source_ids( with pytest.raises(ValueError): # Should raise an error because no source id function was specified await aindex( - await loader.alazy_load(), + loader, arecord_manager, vector_store, cleanup="incremental", @@ -482,7 +465,7 @@ async def test_aincremental_fails_with_bad_source_ids( with pytest.raises(ValueError): # Should raise an error because no source id function was specified await aindex( - await loader.alazy_load(), + 
loader, arecord_manager, vector_store, cleanup="incremental", @@ -593,7 +576,7 @@ async def test_ano_delete( arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() ): assert await aindex( - await loader.alazy_load(), + loader, arecord_manager, vector_store, cleanup=None, @@ -610,7 +593,7 @@ async def test_ano_delete( arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() ): assert await aindex( - await loader.alazy_load(), + loader, arecord_manager, vector_store, cleanup=None, @@ -640,7 +623,7 @@ async def test_ano_delete( arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() ): assert await aindex( - await loader.alazy_load(), + loader, arecord_manager, vector_store, cleanup=None, @@ -779,7 +762,7 @@ async def test_aincremental_delete( arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() ): assert await aindex( - await loader.alazy_load(), + loader.lazy_load(), arecord_manager, vector_store, cleanup="incremental", @@ -803,7 +786,7 @@ async def test_aincremental_delete( arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() ): assert await aindex( - await loader.alazy_load(), + loader.lazy_load(), arecord_manager, vector_store, cleanup="incremental", @@ -838,7 +821,7 @@ async def test_aincremental_delete( arecord_manager, "aget_time", return_value=datetime(2021, 1, 3).timestamp() ): assert await aindex( - await loader.alazy_load(), + loader.lazy_load(), arecord_manager, vector_store, cleanup="incremental", @@ -883,9 +866,7 @@ async def test_aindexing_with_no_docs( """Check edge case when loader returns no new docs.""" loader = ToyLoader(documents=[]) - assert await aindex( - await loader.alazy_load(), arecord_manager, vector_store, cleanup="full" - ) == { + assert await aindex(loader, arecord_manager, vector_store, cleanup="full") == { "num_added": 0, "num_deleted": 0, "num_skipped": 0,