langchain: included concurrency limit in RecursiveUrlLoader #15537

Closed (wants to merge 3 commits)
@@ -93,6 +93,7 @@ def __init__(
link_regex: Union[str, re.Pattern, None] = None,
headers: Optional[dict] = None,
check_response_status: bool = False,
concurrent_requests_limit: Optional[int] = 100
) -> None:
"""Initialize with URL to crawl and any subdirectories to exclude.

@@ -117,6 +118,8 @@ def __init__(
link_regex: Regex for extracting sub-links from the raw html of a web page.
check_response_status: If True, check HTTP response status and skip
URLs with error responses (400-599).
concurrent_requests_limit: Maximum number of concurrent http requests when
using asynchronous loading.
"""

self.url = url
@@ -142,6 +145,7 @@ def __init__(
self._lock = asyncio.Lock() if self.use_async else None
self.headers = headers
self.check_response_status = check_response_status
self.concurrent_requests_limit = concurrent_requests_limit

def _get_child_links_recursive(
self, url: str, visited: Set[str], *, depth: int = 0
@@ -196,16 +200,18 @@ async def _async_get_child_links_recursive(
self,
url: str,
visited: Set[str],
semaphore: asyncio.Semaphore,
*,
session: Optional[aiohttp.ClientSession] = None,
depth: int = 0,
depth: int = 0
) -> List[Document]:
"""Recursively get all child links starting with the path of the input URL.

Args:
url: The URL to crawl.
visited: A set of visited URLs.
depth: To reach the current url, how many pages have been visited.
semaphore: An object to keep track of the number of concurrent requests.
"""
try:
import aiohttp
@@ -231,19 +237,26 @@ async def _async_get_child_links_recursive(
)
async with self._lock: # type: ignore
visited.add(url)
try:
async with session.get(url) as response:
text = await response.text()
if self.check_response_status and 400 <= response.status <= 599:
raise ValueError(f"Received HTTP status {response.status}")
except (aiohttp.client_exceptions.InvalidURL, Exception) as e:
logger.warning(
f"Unable to load {url}. Received error {e} of type "
f"{e.__class__.__name__}"
)
if close_session:
await session.close()
return []

async with semaphore:
try:
async with session.get(url) as response:
text = await response.text()
if self.check_response_status and 400 <= response.status <= 599:
raise ValueError(f"Received HTTP status {response.status}")
except (aiohttp.client_exceptions.InvalidURL, Exception) as e:
logger.warning(
f"Unable to load {url}. Received error {e} of type "
f"{e.__class__.__name__}"
)
if close_session:
await session.close()
return []

# Wait if the concurrent requests limit is reached
if semaphore.locked():

Collaborator:
what's this needed for, doesn't L241 already handle this?

Author:
I included it so that, once the limit is reached, the loader can finish as many of the in-flight requests as possible before sending new ones. My idea was to insert a delay to avoid HTTP errors caused by an excessive number of simultaneous requests.

If you think that's redundant I can quickly update the code.

Author:
@baskaryan could you please tell me if I need to update the code?
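
As an aside to the exchange above, here is a minimal, standalone sketch (not part of the PR; all names are illustrative) of the asyncio.Semaphore behaviour being discussed: entering `async with semaphore` already suspends a task whenever all permits are in use, so the number of concurrent bodies never exceeds the semaphore's initial value.

```python
import asyncio

async def fetch(i: int, semaphore: asyncio.Semaphore) -> None:
    # Entering the semaphore suspends this coroutine while all permits
    # are held, so at most `limit` bodies execute concurrently.
    async with semaphore:
        print(f"request {i} started")
        await asyncio.sleep(0.1)  # stand-in for an HTTP call
        print(f"request {i} finished")

async def main(limit: int = 3) -> None:
    semaphore = asyncio.Semaphore(limit)
    await asyncio.gather(*(fetch(i, semaphore) for i in range(10)))

asyncio.run(main())
```

The extra `if semaphore.locked(): await asyncio.sleep(1)` in the PR adds a pause on top of that gating; whether that pause is needed is what this thread is debating.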

await asyncio.sleep(1)

results = []
content = self.extractor(text)
if content:
@@ -270,7 +283,7 @@ async def _async_get_child_links_recursive(
for link in to_visit:
sub_tasks.append(
self._async_get_child_links_recursive(
link, visited, session=session, depth=depth + 1
link, visited, semaphore, session=session, depth=depth + 1
)
)
next_results = await asyncio.gather(*sub_tasks)
@@ -291,8 +304,9 @@ def lazy_load(self) -> Iterator[Document]:
but it will still work in the expected way, just not lazy."""
visited: Set[str] = set()
if self.use_async:
semaphore = asyncio.Semaphore(self.concurrent_requests_limit)
results = asyncio.run(
self._async_get_child_links_recursive(self.url, visited)
self._async_get_child_links_recursive(self.url, visited, semaphore)
)
return iter(results or [])
else:
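
If this change were merged, a caller might use the new option as in the sketch below. This is a hypothetical usage example: `concurrent_requests_limit` exists only on this PR's branch, and the import path and the other keyword arguments (`url`, `max_depth`, `use_async`) are assumed to match the loader's existing signature.

```python
# Hypothetical usage of the option proposed in this PR; the
# `concurrent_requests_limit` keyword exists only on this branch, and the
# import path is assumed from langchain's layout at the time.
from langchain.document_loaders import RecursiveUrlLoader

loader = RecursiveUrlLoader(
    url="https://docs.python.org/3.9/",
    max_depth=2,
    use_async=True,                 # the limit only applies to async loading
    concurrent_requests_limit=10,   # at most 10 HTTP requests in flight
)
docs = loader.load()
```

Per the diff, the parameter defaults to 100, so callers that never pass it would still get a fairly generous cap when `use_async=True`.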