diff --git a/.github/workflows/deploy-prod.yml b/.github/workflows/deploy-prod.yml index 96a93c9..063629f 100644 --- a/.github/workflows/deploy-prod.yml +++ b/.github/workflows/deploy-prod.yml @@ -42,6 +42,8 @@ env: EMAIL_SENDER: ${{ secrets.EMAIL_SENDER }} + EXPLORER_API_KEY: ${{ secrets.EXPLORER_API_KEY }} + jobs: deploy: diff --git a/.github/workflows/deploy-staging.yml b/.github/workflows/deploy-staging.yml index cd9beca..9192076 100644 --- a/.github/workflows/deploy-staging.yml +++ b/.github/workflows/deploy-staging.yml @@ -40,6 +40,9 @@ env: RATE_TIME: 1 EMAIL_SENDER: ${{ secrets.EMAIL_SENDER }} + + EXPLORER_API_KEY: ${{ secrets.EXPLORER_API_KEY }} + jobs: deploy: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 409e70f..e3eab0b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,6 +17,7 @@ env: DATABASE_DB: "placeholder" DATABASE_HOST: "placeholder" DATABASE_PORT: 42 + EXPLORER_API_KEY: "placeholder" jobs: test: diff --git a/openaq_api/openaq_api/db.py b/openaq_api/openaq_api/db.py index cde570d..02fd542 100644 --- a/openaq_api/openaq_api/db.py +++ b/openaq_api/openaq_api/db.py @@ -21,7 +21,6 @@ allowed_config_params = ["work_mem"] - DEFAULT_CONNECTION_TIMEOUT = 6 MAX_CONNECTION_TIMEOUT = 15 @@ -34,6 +33,7 @@ def default(obj): # function is used in the `cached` decorator and without it # we will get a number of arguments error + def dbkey(m, f, query, args, timeout=None, config=None): j = orjson.dumps( args, option=orjson.OPT_OMIT_MICROSECONDS, default=default @@ -115,7 +115,9 @@ async def fetch( q = f"SELECT set_config('{param}', $1, TRUE)" s = await con.execute(q, str(value)) if not isinstance(timeout, (str, int)): - logger.warning(f"Non int or string timeout value passed - {timeout}") + logger.warning( + f"Non int or string timeout value passed - {timeout}" + ) timeout = DEFAULT_CONNECTION_TIMEOUT r = await wait_for(con.fetch(rquery, *args), timeout=timeout) await tr.commit() @@ -193,9 +195,83 @@ async def create_user(self, user: User) -> str: await conn.close() return verification_token[0][0] + async def get_user(self, users_id: int) -> str: + """ + gets user info from users table and entities table + """ + query = """ + SELECT + e.full_name + , u.email_address + , u.verification_code + FROM + users u + JOIN + users_entities USING (users_id) + JOIN + entities e USING (entities_id) + WHERE + u.users_id = :users_id + """ + conn = await asyncpg.connect(settings.DATABASE_READ_URL) + rquery, args = render(query, **{"users_id": users_id}) + user = await conn.fetch(rquery, *args) + await conn.close() + return user[0] + + async def generate_verification_code(self, email_address: str) -> str: + """ + gets user info from users table and entities table + """ + query = """ + UPDATE + users + SET + verification_code = generate_token() + , expires_on = (timestamptz (NOW() + INTERVAL '30min')) + WHERE + email_address = :email_address + RETURNING verification_code as "verificationCode" + """ + conn = await asyncpg.connect(settings.DATABASE_WRITE_URL) + rquery, args = render(query, **{"email_address": email_address}) + row = await conn.fetch(rquery, *args) + await conn.close() + return row[0][0] + + async def regenerate_user_token(self, users_id: int, token: str) -> str: + """ + calls the get_user_token plpgsql function to verify user email and generate API token + """ + query = """ + UPDATE + user_keys + SET + token = generate_token() + WHERE + users_id = :users_id + AND + token = :token + """ + conn = await asyncpg.connect(settings.DATABASE_WRITE_URL) + rquery, args = render(query, **{"users_id": users_id, "token": token}) + await conn.fetch(rquery, *args) + await conn.close() + async def get_user_token(self, users_id: int) -> str: + """ """ + query = """ + SELECT token FROM user_keys WHERE users_id = :users_id + """ + conn = await asyncpg.connect(settings.DATABASE_WRITE_URL) + rquery, args = render(query, **{"users_id": users_id}) + api_token = await conn.fetch(rquery, *args) + await conn.close() + return api_token[0][0] + + async def generate_user_token(self, users_id: int) -> str: """ - calls the get_user_token plpgsql function to vefiry user email and generate API token + calls the get_user_token plpgsql function to verify user email and generate API token """ query = """ SELECT * FROM get_user_token(:users_id) diff --git a/openaq_api/openaq_api/main.py b/openaq_api/openaq_api/main.py index 9a1f04b..9e53dd0 100644 --- a/openaq_api/openaq_api/main.py +++ b/openaq_api/openaq_api/main.py @@ -23,6 +23,7 @@ from openaq_api.middleware import ( CacheControlMiddleware, LoggingMiddleware, + PrivatePathsMiddleware, RateLimiterMiddleWare, ) from openaq_api.models.logging import ( @@ -47,6 +48,7 @@ # V3 routers from openaq_api.v3.routers import ( + auth, countries, instruments, locations, @@ -91,6 +93,8 @@ def default(obj): return round(obj, 5) if isinstance(obj, datetime.datetime): return obj.strptime("%Y-%m-%dT%H:%M:%SZ") + if isinstance(obj, datetime.date): + return obj.strptime("%Y-%m-%d") class ORJSONResponse(JSONResponse): @@ -99,16 +103,38 @@ def render(self, content: Any) -> bytes: return orjson.dumps(content, default=default) +redis_client = None # initialize for generalize_schema.py + + +@asynccontextmanager +async def lifespan(app: FastAPI): + if not hasattr(app.state, "pool"): + logger.debug("initializing connection pool") + app.state.pool = await db_pool(None) + logger.debug("Connection pool established") + + if hasattr(app.state, "counter"): + app.state.counter += 1 + else: + app.state.counter = 0 + app.state.redis_client = redis_client + yield + if hasattr(app.state, "pool") and not settings.USE_SHARED_POOL: + logger.debug("Closing connection") + await app.state.pool.close() + delattr(app.state, "pool") + logger.debug("Connection closed") + + app = FastAPI( title="OpenAQ", description="OpenAQ API", version="2.0.0", default_response_class=ORJSONResponse, docs_url="/docs", + lifespan=lifespan, ) -redis_client = None # initialize for generalize_schema.py - if settings.RATE_LIMITING is True: if settings.RATE_LIMITING: @@ -126,6 +152,7 @@ def render(self, content: Any) -> bytes: logging.error( InfrastructureErrorLog(detail=f"failed to connect to redis: {e}") ) + print(redis_client) logger.debug("Redis connected") if redis_client: app.add_middleware( @@ -152,8 +179,7 @@ def render(self, content: Any) -> bytes: app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=900") app.add_middleware(LoggingMiddleware) app.add_middleware(GZipMiddleware, minimum_size=1000) - -app.include_router(auth_router) +app.add_middleware(PrivatePathsMiddleware) class OpenAQValidationResponseDetail(BaseModel): @@ -171,57 +197,33 @@ async def openaq_request_validation_exception_handler( request: Request, exc: RequestValidationError ): return ORJSONResponse(status_code=422, content=jsonable_encoder(str(exc))) - return PlainTextResponse(str(exc)) - print("\n\n\n\n\n") - print(str(exc)) - print("\n\n\n\n\n") - detail = orjson.loads(str(exc)) - logger.debug(traceback.format_exc()) - logger.info( - UnprocessableEntityLog(request=request, detail=str(exc)).model_dump_json() - ) - detail = OpenAQValidationResponse(detail=detail) - return ORJSONResponse(status_code=422, content=jsonable_encoder(detail)) + #return PlainTextResponse(str(exc)) + # print("\n\n\n\n\n") + # print(str(exc)) + # print("\n\n\n\n\n") + # detail = orjson.loads(str(exc)) + # logger.debug(traceback.format_exc()) + # logger.info( + # UnprocessableEntityLog(request=request, detail=str(exc)).model_dump_json() + # ) + # detail = OpenAQValidationResponse(detail=detail) + #return ORJSONResponse(status_code=422, content=jsonable_encoder(detail)) @app.exception_handler(ValidationError) async def openaq_exception_handler(request: Request, exc: ValidationError): return ORJSONResponse(status_code=422, content=jsonable_encoder(str(exc))) - - detail = orjson.loads(exc.model_dump_json()) - logger.debug(traceback.format_exc()) - logger.error( - ModelValidationError( - request=request, detail=exc.jsmodel_dump_jsonon() - ).model_dump_json() - ) - return ORJSONResponse(status_code=422, content=jsonable_encoder(detail)) + # detail = orjson.loads(exc.model_dump_json()) + # logger.debug(traceback.format_exc()) + # logger.error( + # ModelValidationError( + # request=request, detail=exc.jsmodel_dump_jsonon() + # ).model_dump_json() + # ) + #return ORJSONResponse(status_code=422, content=jsonable_encoder(detail)) # return ORJSONResponse(status_code=500, content={"message": "internal server error"}) -@app.on_event("startup") -async def startup_event(): - if not hasattr(app.state, "pool"): - logger.debug("initializing connection pool") - app.state.pool = await db_pool(None) - logger.debug("Connection pool established") - - if hasattr(app.state, "counter"): - app.state.counter += 1 - else: - app.state.counter = 0 - app.state.redis_client = redis_client - - -@app.on_event("shutdown") -async def shutdown_event(): - if hasattr(app.state, "pool") and not settings.USE_SHARED_POOL: - logger.debug("Closing connection") - await app.state.pool.close() - delattr(app.state, "pool") - logger.debug("Connection closed") - - @app.get("/ping", include_in_schema=False) def pong(): """ @@ -239,6 +241,7 @@ def favico(): # v3 +app.include_router(auth.router) app.include_router(instruments.router) app.include_router(locations.router) app.include_router(parameters.router) @@ -267,6 +270,7 @@ def favico(): static_dir = Path.joinpath(Path(__file__).resolve().parent, "static") + app.mount("/", StaticFiles(directory=str(static_dir), html=True)) diff --git a/openaq_api/openaq_api/middleware.py b/openaq_api/openaq_api/middleware.py index b70a570..f8e64cb 100644 --- a/openaq_api/openaq_api/middleware.py +++ b/openaq_api/openaq_api/middleware.py @@ -116,6 +116,24 @@ async def dispatch(self, request: Request, call_next): return response +class PrivatePathsMiddleware(BaseHTTPMiddleware): + """ + Middleware to protect private endpoints with an API key + """ + + async def dispatch(self, request: Request, call_next): + route = request.url.path + if "/auth" in route: + auth = request.headers.get("x-api-key", None) + if auth != settings.EXPLORER_API_KEY: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"message": "invalid credentials"}, + ) + response = await call_next(request) + return response + + class RateLimiterMiddleWare(BaseHTTPMiddleware): def __init__( self, @@ -167,8 +185,12 @@ def limited_path(route: str) -> bool: async def dispatch( self, request: Request, call_next: RequestResponseEndpoint ) -> Response: + print("RATE LIMIT\n\n\n") route = request.url.path auth = request.headers.get("x-api-key", None) + if auth == settings.EXPLORER_API_KEY: + response = await call_next(request) + return response limit = self.rate_amount now = datetime.now() key = f"{request.client.host}:{now.year}{now.month}{now.day}{now.hour}{now.minute}" diff --git a/openaq_api/openaq_api/routers/auth.py b/openaq_api/openaq_api/routers/auth.py index 1cf5943..8354d13 100644 --- a/openaq_api/openaq_api/routers/auth.py +++ b/openaq_api/openaq_api/routers/auth.py @@ -6,10 +6,11 @@ from email.message import EmailMessage import boto3 -from fastapi import APIRouter, Depends, Form, HTTPException, Request, status +from fastapi import APIRouter, Body, Depends, Form, HTTPException, Request, status from fastapi.responses import RedirectResponse from fastapi.templating import Jinja2Templates from passlib.hash import pbkdf2_sha256 +from pydantic import BaseModel from ..db import DB from ..forms.register import RegisterForm, UserExistsException @@ -220,6 +221,49 @@ async def verify(request: Request, verification_code: str, db: DB = Depends()): ) +class RegenerateTokenBody(BaseModel): + users_id: int + token: str + + +@router.post("/regenerate-token") +async def regenerate_token( + token: int = Body(..., embed=True) + # request: Request, + # db: DB = Depends(), +): + """ """ + _token = token + print(_token) + try: + # db.get_user_token + # await db.regenerate_user_token(body.users_id, _token) + # token = await db.get_user_token(body.users_id) + # redis_client = getattr(request.app.state, "redis_client") + # print("REDIS", redis_client) + # if redis_client: + # await redis_client.srem("keys", _token) + # await redis_client.sadd("keys", token) + return {"success"} + except Exception as e: + return e + + +@router.post("/send-verification") +async def get_register( + request: Request, + users_id: int, + db: DB = Depends(), +): + user = db.get_user(users_id=users_id) + full_name = user[0] + email_address = user[1] + verification_code = user[2] + response = send_verification_email(verification_code, full_name, email_address) + logger.info(InfoLog(detail=json.dumps(response)).model_dump_json()) + return RedirectResponse("/check-email", status_code=status.HTTP_303_SEE_OTHER) + + @router.get("/register") async def get_register(request: Request): return templates.TemplateResponse("register/index.html", {"request": request}) diff --git a/openaq_api/openaq_api/settings.py b/openaq_api/openaq_api/settings.py index d96839e..1b09a90 100644 --- a/openaq_api/openaq_api/settings.py +++ b/openaq_api/openaq_api/settings.py @@ -37,6 +37,8 @@ class Settings(BaseSettings): EMAIL_SENDER: str | None = None + EXPLORER_API_KEY: str + @computed_field(return_type=str, alias="DATABASE_READ_URL") @property def DATABASE_READ_URL(self): diff --git a/openaq_api/openaq_api/v3/models/queries.py b/openaq_api/openaq_api/v3/models/queries.py index 55104c0..aade5b8 100644 --- a/openaq_api/openaq_api/v3/models/queries.py +++ b/openaq_api/openaq_api/v3/models/queries.py @@ -524,6 +524,7 @@ class PeriodNames(StrEnum): hod = "hod" dow = "dow" moy = "moy" + raw = "raw" class PeriodNameQuery(QueryBaseModel): @@ -698,13 +699,13 @@ def validate_bbox_in_range(cls, v): errors = [] bbox = [float(x) for x in v.split(",")] minx, miny, maxx, maxy = bbox - if not minx >= -180 and minx <= 180: + if not (minx >= -180 and minx <= 180): errors.append("X min must be between -180 and 180") - if not miny >= -90 and miny <= 90: + if not (miny >= -90 and miny <= 90): errors.append("Y min must be between -90 and 90") - if not maxx >= -180 and maxx <= 180: + if not (maxx >= -180 and maxx <= 180): errors.append("X max must be between -180 and 180") - if not maxy >= -90 and maxy <= 90: + if not (maxy >= -90 and maxy <= 90): errors.append("Y max must be between -90 and 90") if minx > maxx: errors.append("X max must be greater than or equal to X min") diff --git a/openaq_api/openaq_api/v3/models/responses.py b/openaq_api/openaq_api/v3/models/responses.py index ba70f97..e7bf9fb 100644 --- a/openaq_api/openaq_api/v3/models/responses.py +++ b/openaq_api/openaq_api/v3/models/responses.py @@ -22,9 +22,6 @@ class OpenAQResult(JsonBase): results: list[Any] = [] -# - - class DatetimeObject(JsonBase): utc: str local: str @@ -214,11 +211,12 @@ class Location(JsonBase): class Measurement(JsonBase): - period: Period + #datetime: DatetimeObject value: float parameter: ParameterBase + period: Period | None = None coordinates: Coordinates | None = None - summary: Summary | None = None + #summary: Summary | None = None coverage: Coverage | None = None diff --git a/openaq_api/openaq_api/v3/routers/auth.py b/openaq_api/openaq_api/v3/routers/auth.py new file mode 100644 index 0000000..c7ceee4 --- /dev/null +++ b/openaq_api/openaq_api/v3/routers/auth.py @@ -0,0 +1,323 @@ +import json +import logging +import os +import pathlib +from email.message import EmailMessage + +import boto3 +from fastapi import APIRouter, Body, Depends, HTTPException, Request, status +from fastapi.responses import RedirectResponse +from fastapi.templating import Jinja2Templates + +from openaq_api.db import DB +from openaq_api.models.logging import AuthLog, ErrorLog, InfoLog, SESEmailLog +from openaq_api.settings import settings +from openaq_api.v3.models.responses import JsonBase + +logger = logging.getLogger("auth") + + +templates = Jinja2Templates( + directory=os.path.join(str(pathlib.Path(__file__).parent.parent), "templates") +) + +router = APIRouter( + prefix="/auth", + include_in_schema=False, +) + + +def send_change_password_email(full_name: str, email: str): + ses_client = boto3.client("ses") + TEXT_EMAIL_CONTENT = """ + We are contacting you to notify you that your OpenAQ Explorer password has been changed. + + If you did not make this change, please contact info@openaq.org. + """ + HTML_EMAIL_CONTENT = """ + +
+ +
+ We are contacting you to notify you that your OpenAQ Explorer password has been changed. +If you did not make this change, please contact info@openaq.org. + |
+
+ Thank you for signing up for an OpenAQ Explorer Account+Click the following link to verify your email: + https://explore.openaq.org/verify/{verification_code} + |
+
+ OpenAQ password reset requests+You have requested a password reset for your OpenAQ Explorer account. Please visit the following link (expires in 30 minutes): + https://explore.openaq.org/new-password?code={verification_code} + |
+
+ OpenAQ password reset+This email confirms you have successfully changed your password for you OpenAQ Explorer account. +If you believe you have recieved this email in error please reach out to info@openaq.org + |
+