Skip to content

Commit

Permalink
update rate limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
russbiggs committed Oct 29, 2024
1 parent eb241f6 commit 7ea1688
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions openaq_api/openaq_api/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import logging
from datetime import datetime
from .settings import settings
from fastapi import (
Security,
Response
)
from fastapi import Security, Response
from starlette.requests import Request

from fastapi.security import (
Expand All @@ -20,7 +17,7 @@
from openaq_api.exceptions import (
NOT_AUTHENTICATED_EXCEPTION,
TOO_MANY_REQUESTS,
)
)

logger = logging.getLogger("dependencies")

Expand All @@ -46,8 +43,8 @@ def in_allowed_list(route: str) -> bool:
async def check_api_key(
request: Request,
response: Response,
api_key=Security(APIKeyHeader(name='X-API-Key', auto_error=False)),
):
api_key=Security(APIKeyHeader(name="X-API-Key", auto_error=False)),
):
"""
Check for an api key and then to see if they are rate limited. Throws a
`not authenticated` or `to many reqests` error if appropriate.
Expand All @@ -64,14 +61,14 @@ async def check_api_key(
redis = request.app.redis

if redis is None:
logger.warning('No redis client found')
logger.warning("No redis client found")
return api_key
elif api_key is None:
logging.info(
UnauthorizedLog(
request=request, detail=f"api key not provided"
).model_dump_json()
)
UnauthorizedLog(
request=request, detail=f"api key not provided"
).model_dump_json()
)
raise NOT_AUTHENTICATED_EXCEPTION
else:
# check api key
Expand Down Expand Up @@ -99,26 +96,30 @@ async def check_api_key(
if value is None:
# if the value is none than we need to add that key to the redis db
# and set it, increment it and set it to timeout/delete is 60 seconds
logger.debug('redis no key for current minute so not limited')
logger.debug("redis no key for current minute so not limited")
async with redis.pipeline() as pipe:
[incr, _] = await pipe.incr(key).expire(key, 60).execute()
requests_used = limit - incr
elif int(value) < limit:
# if that key does exist and the value is below the allowed number of requests
# wea re going to increment it and move on
logger.debug(f'redis - has key for current minute value ({value}) < limit ({limit})')
logger.debug(
f"redis - has key for current minute value ({value}) < limit ({limit})"
)
async with redis.pipeline() as pipe:
[incr, _] = await pipe.incr(key).execute()
[incr] = await pipe.incr(key).execute()
requests_used = limit - incr
else:
# otherwise the user is over their limit and so we are going to throw a 429
# after we set the headers
logger.debug(f'redis - has key for current minute and value ({value}) >= limit ({limit})')
logger.debug(
f"redis - has key for current minute and value ({value}) >= limit ({limit})"
)
limited = True
requests_used = int(value)

response.headers["x-ratelimit-limit"] = str(limit)
response.headers["x-ratelimit-remaining"] = str(limit-requests_used)
response.headers["x-ratelimit-remaining"] = str(limit - requests_used)
response.headers["x-ratelimit-used"] = str(requests_used)
response.headers["x-ratelimit-reset"] = str(ttl)

Expand Down

0 comments on commit 7ea1688

Please sign in to comment.