-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Added the API key validator with mongoDB support!
- Loading branch information
1 parent
930da0c
commit a9b9e0d
Showing
2 changed files
with
111 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,53 @@ | ||
from fastapi import FastAPI, Security, HTTPException, Depends | ||
from fastapi import Security, HTTPException | ||
from fastapi.security.api_key import APIKeyHeader | ||
from typing import List | ||
from starlette.status import HTTP_403_FORBIDDEN | ||
from starlette.status import HTTP_401_UNAUTHORIZED | ||
|
||
from utils.mongo import MongoSingleton | ||
|
||
# List of valid API keys - in production, this should be stored securely | ||
VALID_API_KEYS = ["key1", "key2", "test_key"] | ||
API_KEY_NAME = "X-API-Key" | ||
|
||
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) | ||
|
||
|
||
async def get_api_key(api_key_header: str = Security(api_key_header)): | ||
validator = ValidateAPIKey() | ||
|
||
if not api_key_header: | ||
raise HTTPException( | ||
status_code=HTTP_403_FORBIDDEN, detail="No API key provided" | ||
status_code=HTTP_401_UNAUTHORIZED, detail="No API key provided" | ||
) | ||
|
||
if api_key_header not in VALID_API_KEYS: | ||
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Invalid API key") | ||
if api_key_header not in validator(api_key_header): | ||
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid API key") | ||
|
||
return api_key_header | ||
|
||
|
||
class ValidateAPIKey: | ||
def __init__(self) -> None: | ||
self.client = MongoSingleton.get_instance().get_client() | ||
self.db = "hivemind" | ||
self.tokens_collection = "tokens" | ||
|
||
def validate(self, api_key: str) -> bool: | ||
""" | ||
check if the api key is available in mongodb or not | ||
Parameters | ||
------------ | ||
api_key : str | ||
the provided key to check in db | ||
Returns | ||
--------- | ||
valid : bool | ||
if the key was available in mongo collection, then return True | ||
else, the token is not valid and return False | ||
""" | ||
document = self.client[self.db][self.tokens_collection].find_one( | ||
{"token": api_key} | ||
) | ||
|
||
return True if document else False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from unittest import TestCase | ||
|
||
from utils.mongo import MongoSingleton | ||
from services.api_key import ValidateAPIKey | ||
|
||
|
||
class TestValidateToken(TestCase): | ||
def setUp(self) -> None: | ||
self.client = MongoSingleton.get_instance().get_client() | ||
self.validator = ValidateAPIKey() | ||
|
||
# changing the db so not to overlap with the right ones | ||
self.validator.db = "hivemind_test" | ||
self.validator.tokens_collection = "tokens_test" | ||
|
||
self.client.drop_database(self.validator.db) | ||
|
||
def tearDown(self) -> None: | ||
self.client.drop_database(self.validator.db) | ||
|
||
def test_no_token_available(self): | ||
api_key = "1234" | ||
valid = self.validator.validate(api_key) | ||
|
||
self.assertEqual(valid, False) | ||
|
||
def test_no_matching_token_available(self): | ||
self.client[self.validator.db][self.validator.tokens_collection].insert_many( | ||
[ | ||
{ | ||
"id": 1, | ||
"token": "1111", | ||
"options": {}, | ||
}, | ||
{ | ||
"id": 2, | ||
"token": "2222", | ||
"options": {}, | ||
}, | ||
{ | ||
"id": 3, | ||
"token": "3333", | ||
"options": {}, | ||
}, | ||
] | ||
) | ||
api_key = "1234" | ||
valid = self.validator.validate(api_key) | ||
|
||
self.assertEqual(valid, False) | ||
|
||
def test_single_token_available(self): | ||
api_key = "1234" | ||
self.client[self.validator.db][self.validator.tokens_collection].insert_many( | ||
[ | ||
{ | ||
"id": 1, | ||
"token": api_key, | ||
"options": {}, | ||
}, | ||
{ | ||
"id": 2, | ||
"token": "2222", | ||
"options": {}, | ||
}, | ||
{ | ||
"id": 3, | ||
"token": "3333", | ||
"options": {}, | ||
}, | ||
] | ||
) | ||
valid = self.validator.validate(api_key) | ||
|
||
self.assertEqual(valid, True) |