From 930da0c2538793aed37356bdbd8015ff1b07e76b Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 12 Nov 2024 12:20:17 +0330 Subject: [PATCH 1/4] wip: adding API key middleware! --- routers/http.py | 7 ++++--- services/__init__.py | 0 services/api_key.py | 23 +++++++++++++++++++++++ 3 files changed, 27 insertions(+), 3 deletions(-) create mode 100644 services/__init__.py create mode 100644 services/api_key.py diff --git a/routers/http.py b/routers/http.py index d53e790..fb45565 100644 --- a/routers/http.py +++ b/routers/http.py @@ -1,11 +1,12 @@ import logging from celery.result import AsyncResult -from fastapi import APIRouter +from fastapi import APIRouter, Depends from pydantic import BaseModel from schema import HTTPPayload, QuestionModel, ResponseModel from utils.persist_payload import PersistPayload from worker.tasks import ask_question_auto_search +from services.api_key import api_key_header class RequestPayload(BaseModel): @@ -16,7 +17,7 @@ class RequestPayload(BaseModel): router = APIRouter() -@router.post("/ask") +@router.post("/ask", dependencies=[Depends(api_key_header)]) async def ask(payload: RequestPayload): query = payload.question.message community_id = payload.communityId @@ -36,7 +37,7 @@ async def ask(payload: RequestPayload): return {"id": task.id} -@router.get("/status") +@router.get("/status", dependencies=[Depends(api_key_header)]) async def status(task_id: str): task = AsyncResult(task_id) if task.status == "SUCCESS": diff --git a/services/__init__.py b/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/services/api_key.py b/services/api_key.py new file mode 100644 index 0000000..82002cd --- /dev/null +++ b/services/api_key.py @@ -0,0 +1,23 @@ +from fastapi import FastAPI, Security, HTTPException, Depends +from fastapi.security.api_key import APIKeyHeader +from typing import List +from starlette.status import HTTP_403_FORBIDDEN + + +# List of valid API keys - in production, this should be stored securely +VALID_API_KEYS = ["key1", "key2", "test_key"] +API_KEY_NAME = "X-API-Key" + +api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) + + +async def get_api_key(api_key_header: str = Security(api_key_header)): + if not api_key_header: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="No API key provided" + ) + + if api_key_header not in VALID_API_KEYS: + raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Invalid API key") + + return api_key_header From a9b9e0d8b8ab1ae69052827070d59b4202623541 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 12 Nov 2024 13:01:26 +0330 Subject: [PATCH 2/4] feat: Added the API key validator with mongoDB support! --- services/api_key.py | 42 +++++++++++-- tests/integration/test_validate_token.py | 75 ++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 6 deletions(-) create mode 100644 tests/integration/test_validate_token.py diff --git a/services/api_key.py b/services/api_key.py index 82002cd..6eeecc8 100644 --- a/services/api_key.py +++ b/services/api_key.py @@ -1,23 +1,53 @@ -from fastapi import FastAPI, Security, HTTPException, Depends +from fastapi import Security, HTTPException from fastapi.security.api_key import APIKeyHeader from typing import List -from starlette.status import HTTP_403_FORBIDDEN +from starlette.status import HTTP_401_UNAUTHORIZED +from utils.mongo import MongoSingleton # List of valid API keys - in production, this should be stored securely -VALID_API_KEYS = ["key1", "key2", "test_key"] API_KEY_NAME = "X-API-Key" api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) async def get_api_key(api_key_header: str = Security(api_key_header)): + validator = ValidateAPIKey() + if not api_key_header: raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="No API key provided" + status_code=HTTP_401_UNAUTHORIZED, detail="No API key provided" ) - if api_key_header not in VALID_API_KEYS: - raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Invalid API key") + if api_key_header not in validator(api_key_header): + raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid API key") return api_key_header + + +class ValidateAPIKey: + def __init__(self) -> None: + self.client = MongoSingleton.get_instance().get_client() + self.db = "hivemind" + self.tokens_collection = "tokens" + + def validate(self, api_key: str) -> bool: + """ + check if the api key is available in mongodb or not + + Parameters + ------------ + api_key : str + the provided key to check in db + + Returns + --------- + valid : bool + if the key was available in mongo collection, then return True + else, the token is not valid and return False + """ + document = self.client[self.db][self.tokens_collection].find_one( + {"token": api_key} + ) + + return True if document else False diff --git a/tests/integration/test_validate_token.py b/tests/integration/test_validate_token.py new file mode 100644 index 0000000..b9d88a6 --- /dev/null +++ b/tests/integration/test_validate_token.py @@ -0,0 +1,75 @@ +from unittest import TestCase + +from utils.mongo import MongoSingleton +from services.api_key import ValidateAPIKey + + +class TestValidateToken(TestCase): + def setUp(self) -> None: + self.client = MongoSingleton.get_instance().get_client() + self.validator = ValidateAPIKey() + + # changing the db so not to overlap with the right ones + self.validator.db = "hivemind_test" + self.validator.tokens_collection = "tokens_test" + + self.client.drop_database(self.validator.db) + + def tearDown(self) -> None: + self.client.drop_database(self.validator.db) + + def test_no_token_available(self): + api_key = "1234" + valid = self.validator.validate(api_key) + + self.assertEqual(valid, False) + + def test_no_matching_token_available(self): + self.client[self.validator.db][self.validator.tokens_collection].insert_many( + [ + { + "id": 1, + "token": "1111", + "options": {}, + }, + { + "id": 2, + "token": "2222", + "options": {}, + }, + { + "id": 3, + "token": "3333", + "options": {}, + }, + ] + ) + api_key = "1234" + valid = self.validator.validate(api_key) + + self.assertEqual(valid, False) + + def test_single_token_available(self): + api_key = "1234" + self.client[self.validator.db][self.validator.tokens_collection].insert_many( + [ + { + "id": 1, + "token": api_key, + "options": {}, + }, + { + "id": 2, + "token": "2222", + "options": {}, + }, + { + "id": 3, + "token": "3333", + "options": {}, + }, + ] + ) + valid = self.validator.validate(api_key) + + self.assertEqual(valid, True) From c84c5192ddfe632c7c31a9a9c3b534219c05a058 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 12 Nov 2024 13:13:49 +0330 Subject: [PATCH 3/4] fix: linter issues error! --- routers/http.py | 2 +- services/api_key.py | 6 ++---- tests/integration/test_validate_token.py | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/routers/http.py b/routers/http.py index fb45565..e0e69bd 100644 --- a/routers/http.py +++ b/routers/http.py @@ -4,9 +4,9 @@ from fastapi import APIRouter, Depends from pydantic import BaseModel from schema import HTTPPayload, QuestionModel, ResponseModel +from services.api_key import api_key_header from utils.persist_payload import PersistPayload from worker.tasks import ask_question_auto_search -from services.api_key import api_key_header class RequestPayload(BaseModel): diff --git a/services/api_key.py b/services/api_key.py index 6eeecc8..8fe05d0 100644 --- a/services/api_key.py +++ b/services/api_key.py @@ -1,8 +1,6 @@ -from fastapi import Security, HTTPException +from fastapi import HTTPException, Security from fastapi.security.api_key import APIKeyHeader -from typing import List from starlette.status import HTTP_401_UNAUTHORIZED - from utils.mongo import MongoSingleton # List of valid API keys - in production, this should be stored securely @@ -19,7 +17,7 @@ async def get_api_key(api_key_header: str = Security(api_key_header)): status_code=HTTP_401_UNAUTHORIZED, detail="No API key provided" ) - if api_key_header not in validator(api_key_header): + if api_key_header not in validator.validate(api_key_header): raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid API key") return api_key_header diff --git a/tests/integration/test_validate_token.py b/tests/integration/test_validate_token.py index b9d88a6..b4771c1 100644 --- a/tests/integration/test_validate_token.py +++ b/tests/integration/test_validate_token.py @@ -1,7 +1,7 @@ from unittest import TestCase -from utils.mongo import MongoSingleton from services.api_key import ValidateAPIKey +from utils.mongo import MongoSingleton class TestValidateToken(TestCase): From b0dd8bbb3f785450c1023b89138a53dd3df6ac23 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 12 Nov 2024 13:16:12 +0330 Subject: [PATCH 4/4] fix: wrong condition! --- services/api_key.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/services/api_key.py b/services/api_key.py index 8fe05d0..a7c60c2 100644 --- a/services/api_key.py +++ b/services/api_key.py @@ -17,7 +17,7 @@ async def get_api_key(api_key_header: str = Security(api_key_header)): status_code=HTTP_401_UNAUTHORIZED, detail="No API key provided" ) - if api_key_header not in validator.validate(api_key_header): + if not validator.validate(api_key_header): raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid API key") return api_key_header