Release 20240220 #343

Merged (5 commits) on Feb 21, 2024
2 changes: 2 additions & 0 deletions .github/workflows/deploy-prod.yml
@@ -42,6 +42,8 @@ env:

EMAIL_SENDER: ${{ secrets.EMAIL_SENDER }}

EXPLORER_API_KEY: ${{ secrets.EXPLORER_API_KEY }}


jobs:
deploy:
3 changes: 3 additions & 0 deletions .github/workflows/deploy-staging.yml
@@ -40,6 +40,9 @@ env:
RATE_TIME: 1

EMAIL_SENDER: ${{ secrets.EMAIL_SENDER }}

EXPLORER_API_KEY: ${{ secrets.EXPLORER_API_KEY }}


jobs:
deploy:
1 change: 1 addition & 0 deletions .github/workflows/test.yml
@@ -17,6 +17,7 @@ env:
DATABASE_DB: "placeholder"
DATABASE_HOST: "placeholder"
DATABASE_PORT: 42
EXPLORER_API_KEY: "placeholder"

jobs:
test:
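
The three workflow changes above only export the new EXPLORER_API_KEY secret into the runtime environment. A minimal sketch of the application-side counterpart, assuming the project's settings object is a pydantic-settings class (the field name matches the settings.EXPLORER_API_KEY reads in main.py and middleware.py below; everything else here is illustrative):

from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    # filled from the EXPLORER_API_KEY environment variable that the
    # deploy-prod, deploy-staging and test workflows now define
    EXPLORER_API_KEY: str | None = None


settings = Settings()
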
82 changes: 79 additions & 3 deletions openaq_api/openaq_api/db.py
@@ -21,7 +21,6 @@
allowed_config_params = ["work_mem"]



DEFAULT_CONNECTION_TIMEOUT = 6
MAX_CONNECTION_TIMEOUT = 15

@@ -34,6 +33,7 @@ def default(obj):
# function is used in the `cached` decorator and without it
# we will get a number of arguments error


def dbkey(m, f, query, args, timeout=None, config=None):
j = orjson.dumps(
args, option=orjson.OPT_OMIT_MICROSECONDS, default=default
@@ -115,7 +115,9 @@ async def fetch(
q = f"SELECT set_config('{param}', $1, TRUE)"
s = await con.execute(q, str(value))
if not isinstance(timeout, (str, int)):
logger.warning(f"Non int or string timeout value passed - {timeout}")
logger.warning(
f"Non int or string timeout value passed - {timeout}"
)
timeout = DEFAULT_CONNECTION_TIMEOUT
r = await wait_for(con.fetch(rquery, *args), timeout=timeout)
await tr.commit()
@@ -193,9 +195,83 @@ async def create_user(self, user: User) -> str:
await conn.close()
return verification_token[0][0]

async def get_user(self, users_id: int) -> str:
"""
gets user info from users table and entities table
"""
query = """
SELECT
e.full_name
, u.email_address
, u.verification_code
FROM
users u
JOIN
users_entities USING (users_id)
JOIN
entities e USING (entities_id)
WHERE
u.users_id = :users_id
"""
conn = await asyncpg.connect(settings.DATABASE_READ_URL)
rquery, args = render(query, **{"users_id": users_id})
user = await conn.fetch(rquery, *args)
await conn.close()
return user[0]

async def generate_verification_code(self, email_address: str) -> str:
"""
generates a new short-lived verification code for the user with the given email address
"""
query = """
UPDATE
users
SET
verification_code = generate_token()
, expires_on = (timestamptz (NOW() + INTERVAL '30min'))
WHERE
email_address = :email_address
RETURNING verification_code as "verificationCode"
"""
conn = await asyncpg.connect(settings.DATABASE_WRITE_URL)
rquery, args = render(query, **{"email_address": email_address})
row = await conn.fetch(rquery, *args)
await conn.close()
return row[0][0]

async def regenerate_user_token(self, users_id: int, token: str) -> str:
"""
rotates an existing API token by replacing the matching user_keys row with a newly generated token
"""
query = """
UPDATE
user_keys
SET
token = generate_token()
WHERE
users_id = :users_id
AND
token = :token
"""
conn = await asyncpg.connect(settings.DATABASE_WRITE_URL)
rquery, args = render(query, **{"users_id": users_id, "token": token})
await conn.fetch(rquery, *args)
await conn.close()

async def get_user_token(self, users_id: int) -> str:
""" """
query = """
SELECT token FROM user_keys WHERE users_id = :users_id
"""
conn = await asyncpg.connect(settings.DATABASE_WRITE_URL)
rquery, args = render(query, **{"users_id": users_id})
api_token = await conn.fetch(rquery, *args)
await conn.close()
return api_token[0][0]

async def generate_user_token(self, users_id: int) -> str:
"""
calls the get_user_token plpgsql function to vefiry user email and generate API token
calls the get_user_token plpgsql function to verify user email and generate API token
"""
query = """
SELECT * FROM get_user_token(:users_id)
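
The four new helpers above (get_user, generate_verification_code, regenerate_user_token, get_user_token) cover the read and rotate sides of the token workflow. A hedged sketch of how they might be combined, assuming `db` is an instance of the class they are defined on (in the API it would normally arrive through FastAPI dependency injection; the function name and return shape below are illustrative):

async def rotate_token(db, users_id: int, current_token: str) -> dict:
    # look up the account so the caller can confirm the address on file
    user = await db.get_user(users_id)
    # replace the matching user_keys row with a newly generated token
    await db.regenerate_user_token(users_id, current_token)
    # read the fresh token back for the response
    token = await db.get_user_token(users_id)
    return {"email": user["email_address"], "token": token}
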
98 changes: 51 additions & 47 deletions openaq_api/openaq_api/main.py
@@ -23,6 +23,7 @@
from openaq_api.middleware import (
CacheControlMiddleware,
LoggingMiddleware,
PrivatePathsMiddleware,
RateLimiterMiddleWare,
)
from openaq_api.models.logging import (
@@ -47,6 +48,7 @@

# V3 routers
from openaq_api.v3.routers import (
auth,
countries,
instruments,
locations,
@@ -91,6 +93,8 @@ def default(obj):
return round(obj, 5)
if isinstance(obj, datetime.datetime):
return obj.strptime("%Y-%m-%dT%H:%M:%SZ")
if isinstance(obj, datetime.date):
return obj.strptime("%Y-%m-%d")


class ORJSONResponse(JSONResponse):
@@ -99,16 +103,38 @@ def render(self, content: Any) -> bytes:
return orjson.dumps(content, default=default)


redis_client = None # initialize for generalize_schema.py


@asynccontextmanager
async def lifespan(app: FastAPI):
if not hasattr(app.state, "pool"):
logger.debug("initializing connection pool")
app.state.pool = await db_pool(None)
logger.debug("Connection pool established")

if hasattr(app.state, "counter"):
app.state.counter += 1
else:
app.state.counter = 0
app.state.redis_client = redis_client
yield
if hasattr(app.state, "pool") and not settings.USE_SHARED_POOL:
logger.debug("Closing connection")
await app.state.pool.close()
delattr(app.state, "pool")
logger.debug("Connection closed")


app = FastAPI(
title="OpenAQ",
description="OpenAQ API",
version="2.0.0",
default_response_class=ORJSONResponse,
docs_url="/docs",
lifespan=lifespan,
)

redis_client = None # initialize for generalize_schema.py


if settings.RATE_LIMITING is True:
if settings.RATE_LIMITING:
@@ -126,6 +152,7 @@ def render(self, content: Any) -> bytes:
logging.error(
InfrastructureErrorLog(detail=f"failed to connect to redis: {e}")
)
print(redis_client)
logger.debug("Redis connected")
if redis_client:
app.add_middleware(
@@ -152,8 +179,7 @@ def render(self, content: Any) -> bytes:
app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=900")
app.add_middleware(LoggingMiddleware)
app.add_middleware(GZipMiddleware, minimum_size=1000)

app.include_router(auth_router)
app.add_middleware(PrivatePathsMiddleware)


class OpenAQValidationResponseDetail(BaseModel):
@@ -171,57 +197,33 @@ async def openaq_request_validation_exception_handler(
request: Request, exc: RequestValidationError
):
return ORJSONResponse(status_code=422, content=jsonable_encoder(str(exc)))
return PlainTextResponse(str(exc))
print("\n\n\n\n\n")
print(str(exc))
print("\n\n\n\n\n")
detail = orjson.loads(str(exc))
logger.debug(traceback.format_exc())
logger.info(
UnprocessableEntityLog(request=request, detail=str(exc)).model_dump_json()
)
detail = OpenAQValidationResponse(detail=detail)
return ORJSONResponse(status_code=422, content=jsonable_encoder(detail))
#return PlainTextResponse(str(exc))
# print("\n\n\n\n\n")
# print(str(exc))
# print("\n\n\n\n\n")
# detail = orjson.loads(str(exc))
# logger.debug(traceback.format_exc())
# logger.info(
# UnprocessableEntityLog(request=request, detail=str(exc)).model_dump_json()
# )
# detail = OpenAQValidationResponse(detail=detail)
#return ORJSONResponse(status_code=422, content=jsonable_encoder(detail))


@app.exception_handler(ValidationError)
async def openaq_exception_handler(request: Request, exc: ValidationError):
return ORJSONResponse(status_code=422, content=jsonable_encoder(str(exc)))

detail = orjson.loads(exc.model_dump_json())
logger.debug(traceback.format_exc())
logger.error(
ModelValidationError(
request=request, detail=exc.jsmodel_dump_jsonon()
).model_dump_json()
)
return ORJSONResponse(status_code=422, content=jsonable_encoder(detail))
# detail = orjson.loads(exc.model_dump_json())
# logger.debug(traceback.format_exc())
# logger.error(
# ModelValidationError(
# request=request, detail=exc.jsmodel_dump_jsonon()
# ).model_dump_json()
# )
#return ORJSONResponse(status_code=422, content=jsonable_encoder(detail))
# return ORJSONResponse(status_code=500, content={"message": "internal server error"})


@app.on_event("startup")
async def startup_event():
if not hasattr(app.state, "pool"):
logger.debug("initializing connection pool")
app.state.pool = await db_pool(None)
logger.debug("Connection pool established")

if hasattr(app.state, "counter"):
app.state.counter += 1
else:
app.state.counter = 0
app.state.redis_client = redis_client


@app.on_event("shutdown")
async def shutdown_event():
if hasattr(app.state, "pool") and not settings.USE_SHARED_POOL:
logger.debug("Closing connection")
await app.state.pool.close()
delattr(app.state, "pool")
logger.debug("Connection closed")


@app.get("/ping", include_in_schema=False)
def pong():
"""
@@ -239,6 +241,7 @@ def favico():


# v3
app.include_router(auth.router)
app.include_router(instruments.router)
app.include_router(locations.router)
app.include_router(parameters.router)
@@ -267,6 +270,7 @@ def favico():

static_dir = Path.joinpath(Path(__file__).resolve().parent, "static")


app.mount("/", StaticFiles(directory=str(static_dir), html=True))


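
The @app.on_event("startup") and @app.on_event("shutdown") handlers at the bottom of the old main.py are folded into the lifespan context manager near the top of the file. A standalone sketch of that pattern, with a stand-in pool object in place of the project's db_pool():

from contextlib import asynccontextmanager

from fastapi import FastAPI


class FakePool:
    # stand-in for the asyncpg pool returned by db_pool(); for illustration only
    async def close(self) -> None:
        pass


@asynccontextmanager
async def lifespan(app: FastAPI):
    # everything before `yield` runs once at startup (old @app.on_event("startup"))
    app.state.pool = FakePool()
    yield
    # everything after `yield` runs once at shutdown (old @app.on_event("shutdown"))
    await app.state.pool.close()


app = FastAPI(lifespan=lifespan)
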
22 changes: 22 additions & 0 deletions openaq_api/openaq_api/middleware.py
@@ -116,6 +116,24 @@ async def dispatch(self, request: Request, call_next):
return response


class PrivatePathsMiddleware(BaseHTTPMiddleware):
"""
Middleware to protect private endpoints with an API key
"""

async def dispatch(self, request: Request, call_next):
route = request.url.path
if "/auth" in route:
auth = request.headers.get("x-api-key", None)
if auth != settings.EXPLORER_API_KEY:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"message": "invalid credentials"},
)
response = await call_next(request)
return response


class RateLimiterMiddleWare(BaseHTTPMiddleware):
def __init__(
self,
@@ -167,8 +185,12 @@ def limited_path(route: str) -> bool:
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
print("RATE LIMIT\n\n\n")
route = request.url.path
auth = request.headers.get("x-api-key", None)
if auth == settings.EXPLORER_API_KEY:
response = await call_next(request)
return response
limit = self.rate_amount
now = datetime.now()
key = f"{request.client.host}:{now.year}{now.month}{now.day}{now.hour}{now.minute}"
Expand Down