core[minor]: Add BaseModel.rate_limiter, RateLimiter abstraction and in-memory implementation (#24669)

This PR proposes to create a rate limiter in the chat model directly and
would replace #21992.

It resolves most of the constraints that the Runnable rate limiter
introduced:

1. It's not annoying to apply the rate limiter to existing code; i.e., it is
   possible to roll out the change at the location where the model is
   instantiated rather than at every location where the model is used (which
   is necessary if the model is used in different ways in a given application).
2. Batch rate limiting is enforced properly.
3. The rate limiter works correctly with streaming.
4. The rate limiter is aware of the cache.
5. The rate limiter can take into account information about the inputs to the
   model (we can add optional inputs to it down the road, together with
   outputs); a sketch of a custom limiter built on the abstraction follows
   this list.
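
As a rough illustration of point 5, the new abstraction only requires `acquire`/`aacquire`, so a custom limiter can be plugged into any chat model. The `NoOpRateLimiter` below is a hypothetical, minimal sketch and not part of this PR:

```python
from langchain_core.rate_limiters import BaseRateLimiter


class NoOpRateLimiter(BaseRateLimiter):
    """Hypothetical limiter that never blocks; a template for custom limiters."""

    def acquire(self, *, blocking: bool = True) -> bool:
        # A real limiter would block here (or return False when blocking=False)
        # until capacity is available, e.g. based on request metadata.
        return True

    async def aacquire(self, *, blocking: bool = True) -> bool:
        # Async counterpart, used by the async code paths of the chat model.
        return True
```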

The only downside is that this information will not be properly reflected in
tracing, as we don't have any metadata events about a rate limiter. So the
total time spent on a model invocation will be:

* time spent waiting for the rate limiter
* time spent on the actual model request

## Example

```python
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_groq import ChatGroq

groq = ChatGroq(rate_limiter=InMemoryRateLimiter(check_every_n_seconds=1))
groq.invoke('hello')
```
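
The snippet above only sets the polling interval; a slightly fuller sketch is shown below. The `requests_per_second` and `max_bucket_size` arguments are assumed to be the other `InMemoryRateLimiter` constructor parameters, and `ChatGroq` still needs its usual credentials (e.g. `GROQ_API_KEY`):

```python
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_groq import ChatGroq

# Allow roughly one request every 10 seconds, checking for a free slot once a second.
rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.1,
    check_every_n_seconds=1,
    max_bucket_size=1,  # at most one request's worth of capacity is stored, so no bursting
)

groq = ChatGroq(rate_limiter=rate_limiter)

# Each request in the batch acquires the limiter individually, so the batch as a
# whole respects the configured rate instead of bursting.
groq.batch(["hello", "how are you?"])
```

Since the limiter is an in-memory, thread-safe object, the same instance can also be shared across several model objects so that they all draw from a single bucket.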
eyurtsev authored Jul 26, 2024
1 parent c623ae6 commit 20690db
Showing 7 changed files with 300 additions and 123 deletions.
25 changes: 25 additions & 0 deletions libs/core/langchain_core/language_models/chat_models.py
@@ -60,6 +60,7 @@
Field,
root_validator,
)
from langchain_core.rate_limiters import BaseRateLimiter
from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_core.runnables.config import ensure_config, run_in_executor
from langchain_core.tracers._streaming import _StreamingCallbackHandler
@@ -210,6 +211,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED] Callback manager to add to the run trace."""

rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
"""An optional rate limiter to use for limiting the number of requests."""

@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used.
@@ -341,6 +345,10 @@ def stream(
batch_size=1,
)
generation: Optional[ChatGenerationChunk] = None

if self.rate_limiter:
self.rate_limiter.acquire(blocking=True)

try:
for chunk in self._stream(messages, stop=stop, **kwargs):
if chunk.message.id is None:
@@ -412,6 +420,9 @@ async def astream(
batch_size=1,
)

if self.rate_limiter:
self.rate_limiter.acquire(blocking=True)

generation: Optional[ChatGenerationChunk] = None
try:
async for chunk in self._astream(
@@ -742,6 +753,13 @@ def _generate_with_cache(
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)

# Apply the rate limiter after checking the cache, since
# we usually don't want to rate limit cache lookups, but
# we do want to rate limit API requests.
if self.rate_limiter:
self.rate_limiter.acquire(blocking=True)

# If stream is not explicitly set, check if implicitly requested by
# astream_events() or astream_log(). Bail out if _stream not implemented
if type(self)._stream != BaseChatModel._stream and kwargs.pop(
@@ -822,6 +840,13 @@ async def _agenerate_with_cache(
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)

# Apply the rate limiter after checking the cache, since
# we usually don't want to rate limit cache lookups, but
# we do want to rate limit API requests.
if self.rate_limiter:
self.rate_limiter.acquire(blocking=True)

# If stream is not explicitly set, check if implicitly requested by
# astream_events() or astream_log(). Bail out if _astream not implemented
if (
libs/core/langchain_core/rate_limiters.py
@@ -1,11 +1,4 @@
"""Interface and implementation for time based rate limiters.
This module defines an interface for rate limiting requests based on time.
The interface cannot account for the size of the request or any other factors.
The module also provides an in-memory implementation of the rate limiter.
"""
"""Interface for a rate limiter and an in-memory rate limiter."""

from __future__ import annotations

@@ -14,22 +7,14 @@
import threading
import time
from typing import (
Any,
Optional,
cast,
)

from langchain_core._api import beta
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.base import (
Input,
Output,
Runnable,
)


@beta(message="Introduced in 0.2.24. API subject to change.")
class BaseRateLimiter(Runnable[Input, Output], abc.ABC):
class BaseRateLimiter(abc.ABC):
"""Base class for rate limiters.
Usage of the base limiter is through the acquire and aacquire methods depending
@@ -41,18 +26,10 @@ class BaseRateLimiter(Runnable[Input, Output], abc.ABC):
Current limitations:
- The rate limiter is not designed to work across different processes. It is
an in-memory rate limiter, but it is thread safe.
- The rate limiter only supports time-based rate limiting. It does not take
into account the size of the request or any other factors.
- The current implementation does not handle streaming inputs well and will
consume all inputs even if the rate limit has not been reached. Better support
for streaming inputs will be added in the future.
- When the rate limiter is combined with another runnable via a RunnableSequence,
usage of .batch() or .abatch() will only respect the average rate limit.
There will be bursty behavior as .batch() and .abatch() wait for each step
to complete before starting the next step. One way to mitigate this is to
use batch_as_completed() or abatch_as_completed().
- Rate limiting information is not surfaced in tracing or callbacks. This means
that the total time it takes to invoke a chat model will encompass both
the time spent waiting for tokens and the time spent making the request.
.. versionadded:: 0.2.24
"""
@@ -95,55 +72,10 @@ async def aacquire(self, *, blocking: bool = True) -> bool:
True if the tokens were successfully acquired, False otherwise.
"""

def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""Invoke the rate limiter.
This is a blocking call that waits until the given number of tokens are
available.
Args:
input: The input to the rate limiter.
config: The configuration for the rate limiter.
**kwargs: Additional keyword arguments.
Returns:
The output of the rate limiter.
"""

def _invoke(input: Input) -> Output:
"""Invoke the rate limiter. Internal function."""
self.acquire(blocking=True)
return cast(Output, input)

return self._call_with_config(_invoke, input, config, **kwargs)

async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""Invoke the rate limiter. Async version.
This is a blocking call that waits until the given number of tokens are
available.
Args:
input: The input to the rate limiter.
config: The configuration for the rate limiter.
**kwargs: Additional keyword arguments.
"""

async def _ainvoke(input: Input) -> Output:
"""Invoke the rate limiter. Internal function."""
await self.aacquire(blocking=True)
return cast(Output, input)

return await self._acall_with_config(_ainvoke, input, config, **kwargs)


@beta(message="Introduced in 0.2.24. API subject to change.")
class InMemoryRateLimiter(BaseRateLimiter):
"""An in memory rate limiter.
"""An in memory rate limiter based on a token bucket algorithm.
This is an in memory rate limiter, so it cannot rate limit across
different processes.
@@ -168,19 +100,13 @@ class InMemoryRateLimiter(BaseRateLimiter):
an in-memory rate limiter, but it is thread safe.
- The rate limiter only supports time-based rate limiting. It does not take
into account the size of the request or any other factors.
- The current implementation does not handle streaming inputs well and will
consume all inputs even if the rate limit has not been reached. Better support
for streaming inputs will be added in the future.
- When the rate limiter is combined with another runnable via a RunnableSequence,
usage of .batch() or .abatch() will only respect the average rate limit.
There will be bursty behavior as .batch() and .abatch() wait for each step
to complete before starting the next step. One way to mitigate this is to
use batch_as_completed() or abatch_as_completed().
Example:
.. code-block:: python
from langchain_core import InMemoryRateLimiter
from langchain_core.runnables import RunnableLambda, InMemoryRateLimiter
rate_limiter = InMemoryRateLimiter(
@@ -239,7 +165,7 @@ def __init__(
self.check_every_n_seconds = check_every_n_seconds

def _consume(self) -> bool:
"""Consume the given amount of tokens if possible.
"""Try to consume a token.
Returns:
True means that the tokens were consumed, and the caller can proceed to
@@ -317,3 +243,9 @@ async def aacquire(self, *, blocking: bool = True) -> bool:
while not self._consume():
await asyncio.sleep(self.check_every_n_seconds)
return True


__all__ = [
"BaseRateLimiter",
"InMemoryRateLimiter",
]
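
For readers skimming the diff, the refill-and-consume step behind the token bucket mentioned in the `InMemoryRateLimiter` docstring can be sketched as follows. This is a simplified, hypothetical standalone version, not the code from this commit:

```python
import time


class SimpleTokenBucket:
    """Minimal token bucket: capacity refills at a fixed rate up to a cap."""

    def __init__(self, requests_per_second: float, max_bucket_size: float) -> None:
        self.rate = requests_per_second
        self.max_bucket_size = max_bucket_size
        self.tokens = 0.0
        self.last = time.monotonic()

    def consume(self) -> bool:
        """Spend one token and return True if a request may proceed right now."""
        now = time.monotonic()
        # Refill in proportion to elapsed time, never exceeding the cap.
        self.tokens = min(self.max_bucket_size, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False
```

A blocking `acquire` then amounts to looping on `consume()` and sleeping `check_every_n_seconds` between attempts, which matches the `while not self._consume()` loop visible at the end of the diff above.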
2 changes: 0 additions & 2 deletions libs/core/langchain_core/runnables/__init__.py
@@ -43,7 +43,6 @@
RunnablePassthrough,
RunnablePick,
)
from langchain_core.runnables.rate_limiter import InMemoryRateLimiter
from langchain_core.runnables.router import RouterInput, RouterRunnable
from langchain_core.runnables.utils import (
AddableDict,
@@ -65,7 +64,6 @@
"ensure_config",
"run_in_executor",
"patch_config",
"InMemoryRateLimiter",
"RouterInput",
"RouterRunnable",
"Runnable",