From e8acf0e712d7d89c56a36cf896948ec9511d6026 Mon Sep 17 00:00:00 2001 From: Russ Biggs Date: Sun, 10 Nov 2024 09:23:00 -0700 Subject: [PATCH] ratelimit fixes and token regeneration fix (#377) --- openaq_api/openaq_api/dependencies.py | 18 ++++++++++------- openaq_api/openaq_api/routers/auth.py | 25 +----------------------- openaq_api/openaq_api/v3/routers/auth.py | 20 +++++++++++++------ 3 files changed, 26 insertions(+), 37 deletions(-) diff --git a/openaq_api/openaq_api/dependencies.py b/openaq_api/openaq_api/dependencies.py index 86341ab..0cfc976 100644 --- a/openaq_api/openaq_api/dependencies.py +++ b/openaq_api/openaq_api/dependencies.py @@ -71,9 +71,7 @@ async def check_api_key( ) raise NOT_AUTHENTICATED_EXCEPTION else: - # check api key - limit = settings.RATE_AMOUNT_KEY - limited = False + # check valid key if await redis.sismember("keys", api_key) == 0: logging.info( @@ -82,7 +80,13 @@ async def check_api_key( ).model_dump_json() ) raise NOT_AUTHENTICATED_EXCEPTION - + # check api key + limit = await redis.hget(api_key, "rate") + try: + limit = int(limit) + except TypeError: + limit = 60 + limited = False # check if its limited now = datetime.now() # Using a sliding window rate limiting algorithm @@ -91,7 +95,6 @@ async def check_api_key( # if the that key is in our redis db it will return the number of requests # that key has made during the current minute value = await redis.get(key) - ttl = await redis.ttl(key) if value is None: # if the value is none than we need to add that key to the redis db @@ -117,10 +120,11 @@ async def check_api_key( ) limited = True requests_used = int(value) + ttl = await redis.ttl(key) response.headers["x-ratelimit-limit"] = str(limit) - response.headers["x-ratelimit-remaining"] = str(limit - requests_used) - response.headers["x-ratelimit-used"] = str(requests_used) + response.headers["x-ratelimit-remaining"] = str(requests_used) + response.headers["x-ratelimit-used"] = str(limit - requests_used) response.headers["x-ratelimit-reset"] = str(ttl) if limited: diff --git a/openaq_api/openaq_api/routers/auth.py b/openaq_api/openaq_api/routers/auth.py index 05464d7..8b96a84 100644 --- a/openaq_api/openaq_api/routers/auth.py +++ b/openaq_api/openaq_api/routers/auth.py @@ -210,7 +210,7 @@ async def verify(request: Request, verification_code: str, db: DB = Depends()): else: try: token = await db.get_user_token(row[0]) - redis_client = getattr(request.app.state, "redis_client") + redis_client = getattr(request.app.state, "redis") if redis_client: await redis_client.sadd("keys", token) send_api_key_email(token, row[3], row[4]) @@ -226,29 +226,6 @@ class RegenerateTokenBody(BaseModel): token: str -@router.post("/regenerate-token") -async def regenerate_token( - token: int = Body(..., embed=True) - # request: Request, - # db: DB = Depends(), -): - """ """ - _token = token - print(_token) - try: - # db.get_user_token - # await db.regenerate_user_token(body.users_id, _token) - # token = await db.get_user_token(body.users_id) - # redis_client = getattr(request.app.state, "redis_client") - # print("REDIS", redis_client) - # if redis_client: - # await redis_client.srem("keys", _token) - # await redis_client.sadd("keys", token) - return {"success"} - except Exception as e: - return e - - @router.post("/send-verification") async def get_register( request: Request, diff --git a/openaq_api/openaq_api/v3/routers/auth.py b/openaq_api/openaq_api/v3/routers/auth.py index 0de6b6b..01faa91 100644 --- a/openaq_api/openaq_api/v3/routers/auth.py +++ b/openaq_api/openaq_api/v3/routers/auth.py @@ -24,15 +24,17 @@ router = APIRouter( prefix="/auth", - include_in_schema=False, + include_in_schema=True, ) + def send_email(destination_email: str, msg: EmailMessage): if settings.USE_SMTP_EMAIL: return send_smtp_email(destination_email, msg) else: return send_ses_email(destination_email, msg) + def send_smtp_email(destination_email: str, msg: EmailMessage): with smtplib.SMTP_SSL(settings.SMTP_EMAIL_HOST, 465) as s: s.login(settings.SMTP_EMAIL_USER, settings.SMTP_EMAIL_PASSWORD) @@ -46,6 +48,7 @@ def send_ses_email(destination_email: str, msg: EmailMessage): RawMessage={"Data": msg.as_string()}, ) + def send_change_password_email(full_name: str, email: str): ses_client = boto3.client("ses") TEXT_EMAIL_CONTENT = """ @@ -238,7 +241,10 @@ async def get_register( return HTTPException(401) redis_client = getattr(request.app, "redis") if redis_client: - await redis_client.sadd("keys", user_token) + async with redis_client.pipeline() as pipe: + await pipe.sadd("keys", user_token).hset( + user_token, mapping={"rate": 2000} + ).execute() return {"message": "success"} except Exception as e: return e @@ -262,10 +268,12 @@ async def get_register( return HTTPException(401) await db.regenerate_user_token(body.users_id, body.token) new_token = await db.get_user_token(body.users_id) - redis_client = getattr(request.app.state, "redis_client") + redis_client = getattr(request.app, "redis") if redis_client: - await redis_client.srem("keys", body.token) - await redis_client.sadd("keys", new_token) + async with redis_client.pipeline() as pipe: + await pipe.srem("keys", body.token).sadd("keys", new_token).hset( + new_token, mapping={"rate": 60} + ).execute() return {"message": "success"} except Exception as e: return e @@ -347,7 +355,7 @@ async def verify_email( ): try: token = await db.get_user_token(body.users_id) - redis_client = getattr(request.app.state, "redis_client") + redis_client = getattr(request.app, "redis") if redis_client: await redis_client.sadd("keys", token) except Exception as e: