From bb3a445604c50790b9450ff5059529efac91c157 Mon Sep 17 00:00:00 2001 From: Christian Parker Date: Mon, 21 Oct 2024 11:39:26 -0700 Subject: [PATCH] Updated based on review --- openaq_api/openaq_api/dependencies.py | 136 ++++++++++++++++++++++++++ openaq_api/openaq_api/exceptions.py | 15 +++ openaq_api/openaq_api/main.py | 6 +- openaq_api/openaq_api/middleware.py | 126 +----------------------- 4 files changed, 155 insertions(+), 128 deletions(-) create mode 100644 openaq_api/openaq_api/dependencies.py create mode 100644 openaq_api/openaq_api/exceptions.py diff --git a/openaq_api/openaq_api/dependencies.py b/openaq_api/openaq_api/dependencies.py new file mode 100644 index 0000000..3374e69 --- /dev/null +++ b/openaq_api/openaq_api/dependencies.py @@ -0,0 +1,136 @@ +import logging +from datetime import datetime +from .settings import settings +from fastapi import ( + Security, + Response + ) +from starlette.requests import Request + +from fastapi.security import ( + APIKeyHeader, +) + +from openaq_api.models.logging import ( + TooManyRequestsLog, + UnauthorizedLog, + RedisErrorLog, +) + +from openaq_api.exceptions import ( + NOT_AUTHENTICATED_EXCEPTION, + TOO_MANY_REQUESTS, + ) + +logger = logging.getLogger("dependencies") + + +def is_whitelisted_route(route: str) -> bool: + logger.debug(f"Checking if '{route}' is whitelisted") + allow_list = ["/", "/openapi.json", "/docs", "/register"] + if route in allow_list: + return True + if "/v2/locations/tiles" in route: + return True + if "/v3/locations/tiles" in route: + return True + if "/assets" in route: + return True + if ".css" in route: + return True + if ".js" in route: + return True + return False + + +async def check_api_key( + request: Request, + response: Response, + api_key=Security(APIKeyHeader(name='X-API-Key', auto_error=False)), + ): + """ + Check for an api key and then to see if they are rate limited. Throws a + `not authenticated` or `to many reqests` error if appropriate. + Meant to be used as a dependency either at the app, router or function level + """ + route = request.url.path + # no checking or limiting for whitelistted routes + if is_whitelisted_route(route): + return api_key + elif api_key == settings.EXPLORER_API_KEY: + return api_key + else: + # check to see if we are limiting + redis = request.app.redis + + if redis is None: + logger.warning('No redis client found') + return api_key + elif api_key is None: + logging.info( + UnauthorizedLog( + request=request, detail=f"api key not provided" + ).model_dump_json() + ) + raise NOT_AUTHENTICATED_EXCEPTION + else: + # check api key + limit = settings.RATE_AMOUNT_KEY + limited = False + # check valid key + if await redis.sismember("keys", api_key) == 0: + logging.info( + UnauthorizedLog( + request=request, detail=f"api key not found" + ).model_dump_json() + ) + raise NOT_AUTHENTICATED_EXCEPTION + + # check if its limited + now = datetime.now() + # Using a sliding window rate limiting algorithm + # we add the current time to the minute to the api key and use that as our check + key = f"{api_key}:{now.year}{now.month}{now.day}{now.hour}{now.minute}" + # if the that key is in our redis db it will return the number of requests + # that key has made during the current minute + value = await redis.get(key) + ttl = await redis.ttl(key) + + if value is None: + # if the value is none than we need to add that key to the redis db + # and set it, increment it and set it to timeout/delete is 60 seconds + logger.debug('redis no key for current minute so not limited') + async with redis.pipeline() as pipe: + [incr, _] = await pipe.incr(key).expire(key, 60).execute() + requests_used = limit - incr + elif int(value) < limit: + # if that key does exist and the value is below the allowed number of requests + # wea re going to increment it and move on + logger.debug(f'redis - has key for current minute value ({value}) < limit ({limit})') + async with redis.pipeline() as pipe: + [incr, _] = await pipe.incr(key).execute() + requests_used = limit - incr + else: + # otherwise the user is over their limit and so we are going to throw a 429 + # after we set the headers + logger.debug(f'redis - has key for current minute and value ({value}) >= limit ({limit})') + limited = True + requests_used = int(value) + + response.headers["x-ratelimit-limit"] = str(limit) + response.headers["x-ratelimit-remaining"] = str(limit-requests_used) + response.headers["x-ratelimit-used"] = str(requests_used) + response.headers["x-ratelimit-reset"] = str(ttl) + + if limited: + logging.info( + TooManyRequestsLog( + request=request, + rate_limiter=f"{key}/{limit}/{requests_used}", + ).model_dump_json() + ) + raise TOO_MANY_REQUESTS + + # it would be ideal if we were returing the user information right here + # even it was just an email address it might be useful + return api_key diff --git a/openaq_api/openaq_api/exceptions.py b/openaq_api/openaq_api/exceptions.py new file mode 100644 index 0000000..aac4b91 --- /dev/null +++ b/openaq_api/openaq_api/exceptions.py @@ -0,0 +1,15 @@ + +from fastapi import ( + HTTPException, + status, + ) + +NOT_AUTHENTICATED_EXCEPTION = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid credentials", +) + +TOO_MANY_REQUESTS = HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail="To many requests", +) diff --git a/openaq_api/openaq_api/main.py b/openaq_api/openaq_api/main.py index 9decef3..837ae90 100644 --- a/openaq_api/openaq_api/main.py +++ b/openaq_api/openaq_api/main.py @@ -20,8 +20,10 @@ from starlette.responses import JSONResponse, RedirectResponse from openaq_api.db import db_pool +from openaq_api.dependencies import ( + check_api_key + ) from openaq_api.middleware import ( - check_api_key, CacheControlMiddleware, LoggingMiddleware, ) @@ -32,8 +34,6 @@ WarnLog, ) - - # from openaq_api.routers.auth import router as auth_router from openaq_api.routers.averages import router as averages_router from openaq_api.routers.cities import router as cities_router diff --git a/openaq_api/openaq_api/middleware.py b/openaq_api/openaq_api/middleware.py index aa48473..c9e6fdb 100644 --- a/openaq_api/openaq_api/middleware.py +++ b/openaq_api/openaq_api/middleware.py @@ -1,143 +1,19 @@ import logging import time -from datetime import timedelta, datetime from os import environ -from fastapi import Response, status, Security, HTTPException +from fastapi import Response from fastapi.responses import JSONResponse -from redis.asyncio.cluster import RedisCluster from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.types import ASGIApp -from fastapi.security import ( - APIKeyHeader, -) - from openaq_api.models.logging import ( HTTPLog, LogType, - TooManyRequestsLog, - UnauthorizedLog, - RedisErrorLog, ) -from .settings import settings - logger = logging.getLogger("middleware") -NOT_AUTHENTICATED_EXCEPTION = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid credentials", -) - -TOO_MANY_REQUESTS = HTTPException( - status_code=status.HTTP_429_TOO_MANY_REQUESTS, - detail="To many requests", -) - -def is_whitelisted_route(route: str) -> bool: - logger.debug(f"Checking if '{route}' is whitelisted") - allow_list = ["/", "/openapi.json", "/docs", "/register"] - if route in allow_list: - return True - if "/v2/locations/tiles" in route: - return True - if "/v3/locations/tiles" in route: - return True - if "/assets" in route: - return True - if ".css" in route: - return True - if ".js" in route: - return True - return False - - -async def check_api_key( - request: Request, - response: Response, - api_key=Security(APIKeyHeader(name='X-API-Key', auto_error=False)), - ): - """ - Check for an api key and then to see if they are rate limited. Throws a - `not authenticated` or `to many reqests` error if appropriate. - Meant to be used as a dependency either at the app, router or function level - """ - route = request.url.path - # no checking or limiting for whitelistted routes - if is_whitelisted_route(route): - return api_key - elif api_key == settings.EXPLORER_API_KEY: - return api_key - else: - # check to see if we are limiting - redis = request.app.redis - - if redis is None: - logger.warning('No redis client found') - return api_key - elif api_key is None: - logger.debug('No api key provided') - raise NOT_AUTHENTICATED_EXCEPTION - else: - # check api key - limit = settings.RATE_AMOUNT_KEY - limited = False - # check valid key - if await redis.sismember("keys", api_key) == 0: - logger.debug('Api key not found') - raise NOT_AUTHENTICATED_EXCEPTION - - # check if its limited - now = datetime.now() - # Using a sliding window rate limiting algorithm - # we add the current time to the minute to the api key and use that as our check - key = f"{api_key}:{now.year}{now.month}{now.day}{now.hour}{now.minute}" - # if the that key is in our redis db it will return the number of requests - # that key has made during the current minute - value = await redis.get(key) - ttl = await redis.ttl(key) - - if value is None: - # if the value is none than we need to add that key to the redis db - # and set it, increment it and set it to timeout/delete is 60 seconds - logger.debug('redis no key for current minute so not limited') - async with redis.pipeline() as pipe: - [incr, _] = await pipe.incr(key).expire(key, 60).execute() - requests_used = limit - incr - elif int(value) < limit: - # if that key does exist and the value is below the allowed number of requests - # wea re going to increment it and move on - logger.debug(f'redis - has key for current minute value ({value}) < limit ({limit})') - async with redis.pipeline() as pipe: - [incr, _] = await pipe.incr(key).execute() - requests_used = limit - incr - else: - # otherwise the user is over their limit and so we are going to throw a 429 - # after we set the headers - logger.debug(f'redis - has key for current minute and value ({value}) >= limit ({limit})') - limited = True - requests_used = value - - response.headers["x-ratelimit-limit"] = str(limit) - response.headers["x-ratelimit-remaining"] = "0" - response.headers["x-ratelimit-used"] = str(requests_used) - response.headers["x-ratelimit-reset"] = str(ttl) - - if limited: - logging.info( - TooManyRequestsLog( - request=request, - rate_limiter=f"{key}/{limit}/{requests_used}", - ).model_dump_json() - ) - raise TOO_MANY_REQUESTS - - # it would be ideal if we were returing the user information right here - # even it was just an email address it might be useful - return api_key - - class CacheControlMiddleware(BaseHTTPMiddleware): """MiddleWare to add CacheControl in response headers."""