Skip to content

Commit

Permalink
ratelimit fixes and token regeneration fix (#377)
Browse files Browse the repository at this point in the history
  • Loading branch information
russbiggs authored Nov 10, 2024
1 parent 42af5ec commit e8acf0e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 37 deletions.
18 changes: 11 additions & 7 deletions openaq_api/openaq_api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ async def check_api_key(
)
raise NOT_AUTHENTICATED_EXCEPTION
else:
# check api key
limit = settings.RATE_AMOUNT_KEY
limited = False

# check valid key
if await redis.sismember("keys", api_key) == 0:
logging.info(
Expand All @@ -82,7 +80,13 @@ async def check_api_key(
).model_dump_json()
)
raise NOT_AUTHENTICATED_EXCEPTION

# check api key
limit = await redis.hget(api_key, "rate")
try:
limit = int(limit)
except TypeError:
limit = 60
limited = False
# check if its limited
now = datetime.now()
# Using a sliding window rate limiting algorithm
Expand All @@ -91,7 +95,6 @@ async def check_api_key(
# if the that key is in our redis db it will return the number of requests
# that key has made during the current minute
value = await redis.get(key)
ttl = await redis.ttl(key)

if value is None:
# if the value is none than we need to add that key to the redis db
Expand All @@ -117,10 +120,11 @@ async def check_api_key(
)
limited = True
requests_used = int(value)
ttl = await redis.ttl(key)

response.headers["x-ratelimit-limit"] = str(limit)
response.headers["x-ratelimit-remaining"] = str(limit - requests_used)
response.headers["x-ratelimit-used"] = str(requests_used)
response.headers["x-ratelimit-remaining"] = str(requests_used)
response.headers["x-ratelimit-used"] = str(limit - requests_used)
response.headers["x-ratelimit-reset"] = str(ttl)

if limited:
Expand Down
25 changes: 1 addition & 24 deletions openaq_api/openaq_api/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ async def verify(request: Request, verification_code: str, db: DB = Depends()):
else:
try:
token = await db.get_user_token(row[0])
redis_client = getattr(request.app.state, "redis_client")
redis_client = getattr(request.app.state, "redis")
if redis_client:
await redis_client.sadd("keys", token)
send_api_key_email(token, row[3], row[4])
Expand All @@ -226,29 +226,6 @@ class RegenerateTokenBody(BaseModel):
token: str


@router.post("/regenerate-token")
async def regenerate_token(
token: int = Body(..., embed=True)
# request: Request,
# db: DB = Depends(),
):
""" """
_token = token
print(_token)
try:
# db.get_user_token
# await db.regenerate_user_token(body.users_id, _token)
# token = await db.get_user_token(body.users_id)
# redis_client = getattr(request.app.state, "redis_client")
# print("REDIS", redis_client)
# if redis_client:
# await redis_client.srem("keys", _token)
# await redis_client.sadd("keys", token)
return {"success"}
except Exception as e:
return e


@router.post("/send-verification")
async def get_register(
request: Request,
Expand Down
20 changes: 14 additions & 6 deletions openaq_api/openaq_api/v3/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,17 @@

router = APIRouter(
prefix="/auth",
include_in_schema=False,
include_in_schema=True,
)


def send_email(destination_email: str, msg: EmailMessage):
if settings.USE_SMTP_EMAIL:
return send_smtp_email(destination_email, msg)
else:
return send_ses_email(destination_email, msg)


def send_smtp_email(destination_email: str, msg: EmailMessage):
with smtplib.SMTP_SSL(settings.SMTP_EMAIL_HOST, 465) as s:
s.login(settings.SMTP_EMAIL_USER, settings.SMTP_EMAIL_PASSWORD)
Expand All @@ -46,6 +48,7 @@ def send_ses_email(destination_email: str, msg: EmailMessage):
RawMessage={"Data": msg.as_string()},
)


def send_change_password_email(full_name: str, email: str):
ses_client = boto3.client("ses")
TEXT_EMAIL_CONTENT = """
Expand Down Expand Up @@ -238,7 +241,10 @@ async def get_register(
return HTTPException(401)
redis_client = getattr(request.app, "redis")
if redis_client:
await redis_client.sadd("keys", user_token)
async with redis_client.pipeline() as pipe:
await pipe.sadd("keys", user_token).hset(
user_token, mapping={"rate": 2000}
).execute()
return {"message": "success"}
except Exception as e:
return e
Expand All @@ -262,10 +268,12 @@ async def get_register(
return HTTPException(401)
await db.regenerate_user_token(body.users_id, body.token)
new_token = await db.get_user_token(body.users_id)
redis_client = getattr(request.app.state, "redis_client")
redis_client = getattr(request.app, "redis")
if redis_client:
await redis_client.srem("keys", body.token)
await redis_client.sadd("keys", new_token)
async with redis_client.pipeline() as pipe:
await pipe.srem("keys", body.token).sadd("keys", new_token).hset(
new_token, mapping={"rate": 60}
).execute()
return {"message": "success"}
except Exception as e:
return e
Expand Down Expand Up @@ -347,7 +355,7 @@ async def verify_email(
):
try:
token = await db.get_user_token(body.users_id)
redis_client = getattr(request.app.state, "redis_client")
redis_client = getattr(request.app, "redis")
if redis_client:
await redis_client.sadd("keys", token)
except Exception as e:
Expand Down

0 comments on commit e8acf0e

Please sign in to comment.