diff --git a/.github/workflows/deploy-prod.yml b/.github/workflows/deploy-prod.yml index 96a93c9..063629f 100644 --- a/.github/workflows/deploy-prod.yml +++ b/.github/workflows/deploy-prod.yml @@ -42,6 +42,8 @@ env: EMAIL_SENDER: ${{ secrets.EMAIL_SENDER }} + EXPLORER_API_KEY: ${{ secrets.EXPLORER_API_KEY }} + jobs: deploy: diff --git a/.github/workflows/deploy-staging.yml b/.github/workflows/deploy-staging.yml index cd9beca..9192076 100644 --- a/.github/workflows/deploy-staging.yml +++ b/.github/workflows/deploy-staging.yml @@ -40,6 +40,9 @@ env: RATE_TIME: 1 EMAIL_SENDER: ${{ secrets.EMAIL_SENDER }} + + EXPLORER_API_KEY: ${{ secrets.EXPLORER_API_KEY }} + jobs: deploy: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 409e70f..e3eab0b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,6 +17,7 @@ env: DATABASE_DB: "placeholder" DATABASE_HOST: "placeholder" DATABASE_PORT: 42 + EXPLORER_API_KEY: "placeholder" jobs: test: diff --git a/openaq_api/openaq_api/db.py b/openaq_api/openaq_api/db.py index cde570d..02fd542 100644 --- a/openaq_api/openaq_api/db.py +++ b/openaq_api/openaq_api/db.py @@ -21,7 +21,6 @@ allowed_config_params = ["work_mem"] - DEFAULT_CONNECTION_TIMEOUT = 6 MAX_CONNECTION_TIMEOUT = 15 @@ -34,6 +33,7 @@ def default(obj): # function is used in the `cached` decorator and without it # we will get a number of arguments error + def dbkey(m, f, query, args, timeout=None, config=None): j = orjson.dumps( args, option=orjson.OPT_OMIT_MICROSECONDS, default=default @@ -115,7 +115,9 @@ async def fetch( q = f"SELECT set_config('{param}', $1, TRUE)" s = await con.execute(q, str(value)) if not isinstance(timeout, (str, int)): - logger.warning(f"Non int or string timeout value passed - {timeout}") + logger.warning( + f"Non int or string timeout value passed - {timeout}" + ) timeout = DEFAULT_CONNECTION_TIMEOUT r = await wait_for(con.fetch(rquery, *args), timeout=timeout) await tr.commit() @@ -193,9 +195,83 @@ async def create_user(self, user: User) -> str: await conn.close() return verification_token[0][0] + async def get_user(self, users_id: int) -> str: + """ + gets user info from users table and entities table + """ + query = """ + SELECT + e.full_name + , u.email_address + , u.verification_code + FROM + users u + JOIN + users_entities USING (users_id) + JOIN + entities e USING (entities_id) + WHERE + u.users_id = :users_id + """ + conn = await asyncpg.connect(settings.DATABASE_READ_URL) + rquery, args = render(query, **{"users_id": users_id}) + user = await conn.fetch(rquery, *args) + await conn.close() + return user[0] + + async def generate_verification_code(self, email_address: str) -> str: + """ + gets user info from users table and entities table + """ + query = """ + UPDATE + users + SET + verification_code = generate_token() + , expires_on = (timestamptz (NOW() + INTERVAL '30min')) + WHERE + email_address = :email_address + RETURNING verification_code as "verificationCode" + """ + conn = await asyncpg.connect(settings.DATABASE_WRITE_URL) + rquery, args = render(query, **{"email_address": email_address}) + row = await conn.fetch(rquery, *args) + await conn.close() + return row[0][0] + + async def regenerate_user_token(self, users_id: int, token: str) -> str: + """ + calls the get_user_token plpgsql function to verify user email and generate API token + """ + query = """ + UPDATE + user_keys + SET + token = generate_token() + WHERE + users_id = :users_id + AND + token = :token + """ + conn = await 
asyncpg.connect(settings.DATABASE_WRITE_URL) + rquery, args = render(query, **{"users_id": users_id, "token": token}) + await conn.fetch(rquery, *args) + await conn.close() + async def get_user_token(self, users_id: int) -> str: + """ """ + query = """ + SELECT token FROM user_keys WHERE users_id = :users_id + """ + conn = await asyncpg.connect(settings.DATABASE_WRITE_URL) + rquery, args = render(query, **{"users_id": users_id}) + api_token = await conn.fetch(rquery, *args) + await conn.close() + return api_token[0][0] + + async def generate_user_token(self, users_id: int) -> str: """ - calls the get_user_token plpgsql function to vefiry user email and generate API token + calls the get_user_token plpgsql function to verify user email and generate API token """ query = """ SELECT * FROM get_user_token(:users_id) diff --git a/openaq_api/openaq_api/main.py b/openaq_api/openaq_api/main.py index 9a1f04b..9e53dd0 100644 --- a/openaq_api/openaq_api/main.py +++ b/openaq_api/openaq_api/main.py @@ -23,6 +23,7 @@ from openaq_api.middleware import ( CacheControlMiddleware, LoggingMiddleware, + PrivatePathsMiddleware, RateLimiterMiddleWare, ) from openaq_api.models.logging import ( @@ -47,6 +48,7 @@ # V3 routers from openaq_api.v3.routers import ( + auth, countries, instruments, locations, @@ -91,6 +93,8 @@ def default(obj): return round(obj, 5) if isinstance(obj, datetime.datetime): return obj.strptime("%Y-%m-%dT%H:%M:%SZ") + if isinstance(obj, datetime.date): + return obj.strptime("%Y-%m-%d") class ORJSONResponse(JSONResponse): @@ -99,16 +103,38 @@ def render(self, content: Any) -> bytes: return orjson.dumps(content, default=default) +redis_client = None # initialize for generalize_schema.py + + +@asynccontextmanager +async def lifespan(app: FastAPI): + if not hasattr(app.state, "pool"): + logger.debug("initializing connection pool") + app.state.pool = await db_pool(None) + logger.debug("Connection pool established") + + if hasattr(app.state, "counter"): + app.state.counter += 1 + else: + app.state.counter = 0 + app.state.redis_client = redis_client + yield + if hasattr(app.state, "pool") and not settings.USE_SHARED_POOL: + logger.debug("Closing connection") + await app.state.pool.close() + delattr(app.state, "pool") + logger.debug("Connection closed") + + app = FastAPI( title="OpenAQ", description="OpenAQ API", version="2.0.0", default_response_class=ORJSONResponse, docs_url="/docs", + lifespan=lifespan, ) -redis_client = None # initialize for generalize_schema.py - if settings.RATE_LIMITING is True: if settings.RATE_LIMITING: @@ -126,6 +152,7 @@ def render(self, content: Any) -> bytes: logging.error( InfrastructureErrorLog(detail=f"failed to connect to redis: {e}") ) + print(redis_client) logger.debug("Redis connected") if redis_client: app.add_middleware( @@ -152,8 +179,7 @@ def render(self, content: Any) -> bytes: app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=900") app.add_middleware(LoggingMiddleware) app.add_middleware(GZipMiddleware, minimum_size=1000) - -app.include_router(auth_router) +app.add_middleware(PrivatePathsMiddleware) class OpenAQValidationResponseDetail(BaseModel): @@ -171,57 +197,33 @@ async def openaq_request_validation_exception_handler( request: Request, exc: RequestValidationError ): return ORJSONResponse(status_code=422, content=jsonable_encoder(str(exc))) - return PlainTextResponse(str(exc)) - print("\n\n\n\n\n") - print(str(exc)) - print("\n\n\n\n\n") - detail = orjson.loads(str(exc)) - logger.debug(traceback.format_exc()) - logger.info( - 
UnprocessableEntityLog(request=request, detail=str(exc)).model_dump_json() - ) - detail = OpenAQValidationResponse(detail=detail) - return ORJSONResponse(status_code=422, content=jsonable_encoder(detail)) + #return PlainTextResponse(str(exc)) + # print("\n\n\n\n\n") + # print(str(exc)) + # print("\n\n\n\n\n") + # detail = orjson.loads(str(exc)) + # logger.debug(traceback.format_exc()) + # logger.info( + # UnprocessableEntityLog(request=request, detail=str(exc)).model_dump_json() + # ) + # detail = OpenAQValidationResponse(detail=detail) + #return ORJSONResponse(status_code=422, content=jsonable_encoder(detail)) @app.exception_handler(ValidationError) async def openaq_exception_handler(request: Request, exc: ValidationError): return ORJSONResponse(status_code=422, content=jsonable_encoder(str(exc))) - - detail = orjson.loads(exc.model_dump_json()) - logger.debug(traceback.format_exc()) - logger.error( - ModelValidationError( - request=request, detail=exc.jsmodel_dump_jsonon() - ).model_dump_json() - ) - return ORJSONResponse(status_code=422, content=jsonable_encoder(detail)) + # detail = orjson.loads(exc.model_dump_json()) + # logger.debug(traceback.format_exc()) + # logger.error( + # ModelValidationError( + # request=request, detail=exc.jsmodel_dump_jsonon() + # ).model_dump_json() + # ) + #return ORJSONResponse(status_code=422, content=jsonable_encoder(detail)) # return ORJSONResponse(status_code=500, content={"message": "internal server error"}) -@app.on_event("startup") -async def startup_event(): - if not hasattr(app.state, "pool"): - logger.debug("initializing connection pool") - app.state.pool = await db_pool(None) - logger.debug("Connection pool established") - - if hasattr(app.state, "counter"): - app.state.counter += 1 - else: - app.state.counter = 0 - app.state.redis_client = redis_client - - -@app.on_event("shutdown") -async def shutdown_event(): - if hasattr(app.state, "pool") and not settings.USE_SHARED_POOL: - logger.debug("Closing connection") - await app.state.pool.close() - delattr(app.state, "pool") - logger.debug("Connection closed") - - @app.get("/ping", include_in_schema=False) def pong(): """ @@ -239,6 +241,7 @@ def favico(): # v3 +app.include_router(auth.router) app.include_router(instruments.router) app.include_router(locations.router) app.include_router(parameters.router) @@ -267,6 +270,7 @@ def favico(): static_dir = Path.joinpath(Path(__file__).resolve().parent, "static") + app.mount("/", StaticFiles(directory=str(static_dir), html=True)) diff --git a/openaq_api/openaq_api/middleware.py b/openaq_api/openaq_api/middleware.py index b70a570..f8e64cb 100644 --- a/openaq_api/openaq_api/middleware.py +++ b/openaq_api/openaq_api/middleware.py @@ -116,6 +116,24 @@ async def dispatch(self, request: Request, call_next): return response +class PrivatePathsMiddleware(BaseHTTPMiddleware): + """ + Middleware to protect private endpoints with an API key + """ + + async def dispatch(self, request: Request, call_next): + route = request.url.path + if "/auth" in route: + auth = request.headers.get("x-api-key", None) + if auth != settings.EXPLORER_API_KEY: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"message": "invalid credentials"}, + ) + response = await call_next(request) + return response + + class RateLimiterMiddleWare(BaseHTTPMiddleware): def __init__( self, @@ -167,8 +185,12 @@ def limited_path(route: str) -> bool: async def dispatch( self, request: Request, call_next: RequestResponseEndpoint ) -> Response: + print("RATE LIMIT\n\n\n") route 
= request.url.path auth = request.headers.get("x-api-key", None) + if auth == settings.EXPLORER_API_KEY: + response = await call_next(request) + return response limit = self.rate_amount now = datetime.now() key = f"{request.client.host}:{now.year}{now.month}{now.day}{now.hour}{now.minute}" diff --git a/openaq_api/openaq_api/routers/auth.py b/openaq_api/openaq_api/routers/auth.py index 1cf5943..8354d13 100644 --- a/openaq_api/openaq_api/routers/auth.py +++ b/openaq_api/openaq_api/routers/auth.py @@ -6,10 +6,11 @@ from email.message import EmailMessage import boto3 -from fastapi import APIRouter, Depends, Form, HTTPException, Request, status +from fastapi import APIRouter, Body, Depends, Form, HTTPException, Request, status from fastapi.responses import RedirectResponse from fastapi.templating import Jinja2Templates from passlib.hash import pbkdf2_sha256 +from pydantic import BaseModel from ..db import DB from ..forms.register import RegisterForm, UserExistsException @@ -220,6 +221,49 @@ async def verify(request: Request, verification_code: str, db: DB = Depends()): ) +class RegenerateTokenBody(BaseModel): + users_id: int + token: str + + +@router.post("/regenerate-token") +async def regenerate_token( + token: int = Body(..., embed=True) + # request: Request, + # db: DB = Depends(), +): + """ """ + _token = token + print(_token) + try: + # db.get_user_token + # await db.regenerate_user_token(body.users_id, _token) + # token = await db.get_user_token(body.users_id) + # redis_client = getattr(request.app.state, "redis_client") + # print("REDIS", redis_client) + # if redis_client: + # await redis_client.srem("keys", _token) + # await redis_client.sadd("keys", token) + return {"success"} + except Exception as e: + return e + + +@router.post("/send-verification") +async def get_register( + request: Request, + users_id: int, + db: DB = Depends(), +): + user = db.get_user(users_id=users_id) + full_name = user[0] + email_address = user[1] + verification_code = user[2] + response = send_verification_email(verification_code, full_name, email_address) + logger.info(InfoLog(detail=json.dumps(response)).model_dump_json()) + return RedirectResponse("/check-email", status_code=status.HTTP_303_SEE_OTHER) + + @router.get("/register") async def get_register(request: Request): return templates.TemplateResponse("register/index.html", {"request": request}) diff --git a/openaq_api/openaq_api/settings.py b/openaq_api/openaq_api/settings.py index d96839e..1b09a90 100644 --- a/openaq_api/openaq_api/settings.py +++ b/openaq_api/openaq_api/settings.py @@ -37,6 +37,8 @@ class Settings(BaseSettings): EMAIL_SENDER: str | None = None + EXPLORER_API_KEY: str + @computed_field(return_type=str, alias="DATABASE_READ_URL") @property def DATABASE_READ_URL(self): diff --git a/openaq_api/openaq_api/v3/models/queries.py b/openaq_api/openaq_api/v3/models/queries.py index 55104c0..aade5b8 100644 --- a/openaq_api/openaq_api/v3/models/queries.py +++ b/openaq_api/openaq_api/v3/models/queries.py @@ -524,6 +524,7 @@ class PeriodNames(StrEnum): hod = "hod" dow = "dow" moy = "moy" + raw = "raw" class PeriodNameQuery(QueryBaseModel): @@ -698,13 +699,13 @@ def validate_bbox_in_range(cls, v): errors = [] bbox = [float(x) for x in v.split(",")] minx, miny, maxx, maxy = bbox - if not minx >= -180 and minx <= 180: + if not (minx >= -180 and minx <= 180): errors.append("X min must be between -180 and 180") - if not miny >= -90 and miny <= 90: + if not (miny >= -90 and miny <= 90): errors.append("Y min must be between -90 and 90") - if not maxx >= 
-180 and maxx <= 180: + if not (maxx >= -180 and maxx <= 180): errors.append("X max must be between -180 and 180") - if not maxy >= -90 and maxy <= 90: + if not (maxy >= -90 and maxy <= 90): errors.append("Y max must be between -90 and 90") if minx > maxx: errors.append("X max must be greater than or equal to X min") diff --git a/openaq_api/openaq_api/v3/models/responses.py b/openaq_api/openaq_api/v3/models/responses.py index ba70f97..e7bf9fb 100644 --- a/openaq_api/openaq_api/v3/models/responses.py +++ b/openaq_api/openaq_api/v3/models/responses.py @@ -22,9 +22,6 @@ class OpenAQResult(JsonBase): results: list[Any] = [] -# - - class DatetimeObject(JsonBase): utc: str local: str @@ -214,11 +211,12 @@ class Location(JsonBase): class Measurement(JsonBase): - period: Period + #datetime: DatetimeObject value: float parameter: ParameterBase + period: Period | None = None coordinates: Coordinates | None = None - summary: Summary | None = None + #summary: Summary | None = None coverage: Coverage | None = None diff --git a/openaq_api/openaq_api/v3/routers/auth.py b/openaq_api/openaq_api/v3/routers/auth.py new file mode 100644 index 0000000..c7ceee4 --- /dev/null +++ b/openaq_api/openaq_api/v3/routers/auth.py @@ -0,0 +1,323 @@ +import json +import logging +import os +import pathlib +from email.message import EmailMessage + +import boto3 +from fastapi import APIRouter, Body, Depends, HTTPException, Request, status +from fastapi.responses import RedirectResponse +from fastapi.templating import Jinja2Templates + +from openaq_api.db import DB +from openaq_api.models.logging import AuthLog, ErrorLog, InfoLog, SESEmailLog +from openaq_api.settings import settings +from openaq_api.v3.models.responses import JsonBase + +logger = logging.getLogger("auth") + + +templates = Jinja2Templates( + directory=os.path.join(str(pathlib.Path(__file__).parent.parent), "templates") +) + +router = APIRouter( + prefix="/auth", + include_in_schema=False, +) + + +def send_change_password_email(full_name: str, email: str): + ses_client = boto3.client("ses") + TEXT_EMAIL_CONTENT = """ + We are contacting you to notify you that your OpenAQ Explorer password has been changed. + + If you did not make this change, please contact info@openaq.org. + """ + HTML_EMAIL_CONTENT = """ + + + + + + + +
+            <p>
+                We are contacting you to notify you that your OpenAQ Explorer password has been changed.
+            </p>
+            <p>
+                If you did not make this change, please contact info@openaq.org.
+            </p>
+ + + """ + msg = EmailMessage() + msg.set_content(TEXT_EMAIL_CONTENT) + msg.add_alternative(HTML_EMAIL_CONTENT, subtype="html") + msg["Subject"] = "OpenAQ Explorer - Password changed" + msg["From"] = settings.EMAIL_SENDER + msg["To"] = email + response = ses_client.send_raw_email( + Source=settings.EMAIL_SENDER, + Destinations=[f"{full_name} <{email}>"], + RawMessage={"Data": msg.as_string()}, + ) + logger.info( + SESEmailLog( + detail=json.dumps( + { + "email": email, + "name": full_name, + "reponse": response, + } + ) + ).model_dump_json() + ) + return response + + +def send_verification_email(verification_code: str, full_name: str, email: str): + ses_client = boto3.client("ses") + TEXT_EMAIL_CONTENT = f""" + Thank you for signing up for an OpenAQ Explorer Account + Visit the following URL to verify your email: + https://explore.openaq.org/verify/{verification_code} + """ + HTML_EMAIL_CONTENT = f""" + + + + + + + +
+            <p>
+                Thank you for signing up for an OpenAQ Explorer Account
+            </p>
+            <p>
+                Click the following link to verify your email:
+            </p>
+            <p>
+                https://explore.openaq.org/verify/{verification_code}
+            </p>
+ + + """ + msg = EmailMessage() + msg.set_content(TEXT_EMAIL_CONTENT) + msg.add_alternative(HTML_EMAIL_CONTENT, subtype="html") + msg["Subject"] = "OpenAQ Explorer - Verify your email" + msg["From"] = settings.EMAIL_SENDER + msg["To"] = email + response = ses_client.send_raw_email( + Source=settings.EMAIL_SENDER, + Destinations=[f"{full_name} <{email}>"], + RawMessage={"Data": msg.as_string()}, + ) + logger.info( + SESEmailLog( + detail=json.dumps( + { + "email": email, + "name": full_name, + "verificationCode": verification_code, + "reponse": response, + } + ) + ).model_dump_json() + ) + return response + + +def send_password_reset_email(verification_code: str, email: str): + ses_client = boto3.client("ses") + TEXT_EMAIL_CONTENT = f""" + You have requested a password reset for your OpenAQ Explorer account. Please visit the following link (expires in 30 minutes): + https://explore.openaq.org/new-password?code={verification_code} + """ + HTML_EMAIL_CONTENT = f""" + + + + + + + +
+            <p>
+                OpenAQ password reset request
+            </p>
+            <p>
+                You have requested a password reset for your OpenAQ Explorer account. Please visit the following link (expires in 30 minutes):
+            </p>
+            <p>
+                https://explore.openaq.org/new-password?code={verification_code}
+            </p>
+ + + """ + msg = EmailMessage() + msg.set_content(TEXT_EMAIL_CONTENT) + msg.add_alternative(HTML_EMAIL_CONTENT, subtype="html") + msg["Subject"] = "OpenAQ Explorer - Reset password request" + msg["From"] = settings.EMAIL_SENDER + msg["To"] = email + response = ses_client.send_raw_email( + Source=settings.EMAIL_SENDER, + Destinations=[f"<{email}>"], + RawMessage={"Data": msg.as_string()}, + ) + logger.info( + SESEmailLog( + detail=json.dumps( + { + "email": email, + "verificationCode": verification_code, + "reponse": response, + } + ) + ).model_dump_json() + ) + return response + + + +def send_password_changed_email(email: str): + ses_client = boto3.client("ses") + TEXT_EMAIL_CONTENT = """ + This email confirms you have successfully changed your password for you OpenAQ Explorer account. + + If you believe you have recieved this email in error please reach out to info@openaq.org. + """ + HTML_EMAIL_CONTENT = """ + + + + + + + +
+            <p>
+                OpenAQ password reset
+            </p>
+            <p>
+                This email confirms you have successfully changed your password for your OpenAQ Explorer account.
+            </p>
+            <p>
+                If you believe you have received this email in error, please reach out to info@openaq.org.
+            </p>
+ + + """ + msg = EmailMessage() + msg.set_content(TEXT_EMAIL_CONTENT) + msg.add_alternative(HTML_EMAIL_CONTENT, subtype="html") + msg["Subject"] = "OpenAQ Explorer - Reset password success" + msg["From"] = settings.EMAIL_SENDER + msg["To"] = email + response = ses_client.send_raw_email( + Source=settings.EMAIL_SENDER, + Destinations=[f"<{email}>"], + RawMessage={"Data": msg.as_string()}, + ) + logger.info( + SESEmailLog( + detail=json.dumps( + { + "email": email, + "reponse": response, + } + ) + ).model_dump_json() + ) + return response + +class RegenerateTokenBody(JsonBase): + users_id: int + token: str + + +@router.post("/regenerate-token") +async def get_register( + request: Request, + body: RegenerateTokenBody, + db: DB = Depends(), +): + """ """ + try: + user_token = await db.get_user_token(body.users_id) + if user_token != body.token: + return HTTPException(401) + await db.regenerate_user_token(body.users_id, body.token) + new_token = await db.get_user_token(body.users_id) + redis_client = getattr(request.app.state, "redis_client") + if redis_client: + await redis_client.srem("keys", body.token) + await redis_client.sadd("keys", new_token) + return {"message": "success"} + except Exception as e: + return e + + +class VerificationBody(JsonBase): + users_id: int + + +@router.post("/send-verification") +async def send_verification( + body: VerificationBody, + db: DB = Depends(), +): + user = await db.get_user(body.users_id) + full_name = user[0] + email_address = user[1] + verification_code = user[2] + response = send_verification_email(verification_code, full_name, email_address) + logger.info(InfoLog(detail=json.dumps(response)).model_dump_json()) + + +class PasswordResetEmailBody(JsonBase): + email_address: str + + +@router.post("/send-password-email") +async def request_password_reset_email( + body: PasswordResetEmailBody, + db: DB = Depends(), +): + email_address = body.email_address + verification_code = await db.generate_verification_code(email_address) + response = send_password_reset_email(verification_code, email_address) + logger.info(InfoLog(detail=json.dumps(response)).model_dump_json()) + + +@router.post("/send-password-changed-email") +async def password_changed_email( + body: PasswordResetEmailBody, +): + email_address = body.email_address + response = send_password_changed_email(email_address) + logger.info(InfoLog(detail=json.dumps(response)).model_dump_json()) + + +class VerifyBody(JsonBase): + users_id: int + + +@router.post("/verify") +async def verify_email( + request: Request, + body: VerificationBody, + db: DB = Depends(), +): + try: + token = await db.get_user_token(body.users_id) + redis_client = getattr(request.app.state, "redis_client") + if redis_client: + await redis_client.sadd("keys", token) + except Exception as e: + logger.error(ErrorLog(detail=f"something went wrong: {e}")) + return HTTPException(500) + return {"message": "success"} + + +class ChangePasswordEmailBody(JsonBase): + users_id: int + + +@router.post("/change-password-email") +async def change_password_email( + body: ChangePasswordEmailBody, + db: DB = Depends(), +): + """ """ + user = await db.get_user(body.users_id) + full_name = user[0] + email_address = user[1] + response = send_change_password_email(full_name, email_address) + logger.info(InfoLog(detail=json.dumps(response)).model_dump_json()) diff --git a/openaq_api/openaq_api/v3/routers/locations.py b/openaq_api/openaq_api/v3/routers/locations.py index 155d864..bc375e7 100644 --- a/openaq_api/openaq_api/v3/routers/locations.py +++ 
b/openaq_api/openaq_api/v3/routers/locations.py @@ -1,7 +1,7 @@ import logging from typing import Annotated from enum import StrEnum, auto -from fastapi import APIRouter, Depends, Path, Query +from fastapi import APIRouter, Depends, Path, Query, Request from openaq_api.db import DB from openaq_api.v3.models.queries import ( @@ -78,8 +78,7 @@ class LocationsQueries( MobileQuery, MonitorQuery, LocationsSorting, -): - ... +): ... @router.get( @@ -90,8 +89,10 @@ class LocationsQueries( ) async def location_get( locations: Annotated[LocationPathQuery, Depends(LocationPathQuery.depends())], + request: Request, db: DB = Depends(), ): + print("FOO", request.app.state.redis_client) response = await fetch_locations(locations, db) return response diff --git a/openaq_api/openaq_api/v3/routers/sensors.py b/openaq_api/openaq_api/v3/routers/sensors.py index 178adaf..0055af0 100644 --- a/openaq_api/openaq_api/v3/routers/sensors.py +++ b/openaq_api/openaq_api/v3/routers/sensors.py @@ -1,7 +1,9 @@ import logging -from typing import Annotated +from typing import Annotated, Any -from fastapi import APIRouter, Depends, Path +from fastapi import APIRouter, Depends, Path, HTTPException +from pydantic import field_validator +from datetime import date, datetime from openaq_api.db import DB from openaq_api.v3.models.queries import ( @@ -35,6 +37,7 @@ class SensorQuery(QueryBaseModel): def where(self): return "s.sensors_id = :sensors_id" + class LocationSensorQuery(QueryBaseModel): locations_id: int = Path( ..., description="Limit the results to a specific sensors id", ge=1 @@ -43,6 +46,7 @@ class LocationSensorQuery(QueryBaseModel): def where(self): return "n.sensor_nodes_id = :locations_id" + class SensorMeasurementsQueries( Paging, SensorQuery, @@ -50,7 +54,20 @@ class SensorMeasurementsQueries( DateToQuery, PeriodNameQuery, ): - ... + @field_validator('date_to', 'date_from') + @classmethod + def must_be_date_if_aggregating_to_day(cls, v: Any, values) -> str: + if values.data.get('period_name') in ['dow','day','moy','month']: + if isinstance(v, datetime): + # this is to deal with the error that is thrown when using ValueError with datetime objects + err = [{ + "type": "value_error", + "msg": "When aggregating data to daily values or higher you can only use whole dates in the `date_from` and `date_to` parameters. E.g. 
2024-01-01, 2024-01-01 00:00:00", + "input": str(v) + }] + raise HTTPException(status_code=422, detail=err) + return v + @router.get( @@ -192,6 +209,63 @@ async def fetch_measurements(q, db): ORDER BY datetime {query.pagination()} """ + elif q.period_name in ["raw"]: + sql = f""" + WITH sensor AS ( + SELECT s.sensors_id + , sn.sensor_nodes_id + , s.data_averaging_period_seconds + , s.data_logging_period_seconds + , format('%ssec', s.data_averaging_period_seconds)::interval as averaging_interval + , format('%ssec', s.data_logging_period_seconds)::interval as logging_interval + , tz.tzid as timezone + , m.measurands_id + , m.measurand + , m.units + , timezone(tz.tzid, :date_from::timestamp) as datetime_from + , timezone(tz.tzid, :date_to::timestamp) as datetime_to + FROM sensors s + , sensor_systems sy + , sensor_nodes sn + , timezones tz + , measurands m + WHERE s.sensor_systems_id = sy.sensor_systems_id + AND sy.sensor_nodes_id = sn.sensor_nodes_id + AND sn.timezones_id = tz.gid + AND s.sensors_id = :sensors_id + AND s.measurands_id = m.measurands_id) + SELECT m.sensors_id + , value + , get_datetime_object(m.datetime, s.timezone) + , json_build_object( + 'id', s.measurands_id + , 'units', s.units + , 'name', s.measurand + ) as parameter + , json_build_object( + 'label', 'raw' + , 'interval', s.logging_interval + , 'datetime_from', get_datetime_object(m.datetime - s.logging_interval, s.timezone) + , 'datetime_to', get_datetime_object(m.datetime, s.timezone) + ) as period + , json_build_object( + 'expected_count', 1 + , 'observed_count', 1 + , 'expected_interval', s.logging_interval + , 'observed_interval', s.averaging_interval + , 'datetime_from', get_datetime_object(m.datetime - s.averaging_interval, s.timezone) + , 'datetime_to', get_datetime_object(m.datetime, s.timezone) + , 'percent_complete', 100 + , 'percent_coverage', (s.data_averaging_period_seconds/s.data_logging_period_seconds)*100 + ) as coverage + FROM measurements m + JOIN sensor s USING (sensors_id) + WHERE datetime > datetime_from + AND datetime <= datetime_to + AND s.sensors_id = :sensors_id + ORDER BY datetime + {query.pagination()} + """ elif q.period_name in ["day", "month"]: # Query for the aggregate data if q.period_name == "day": @@ -273,36 +347,78 @@ async def fetch_measurements(q, db): {query.pagination()} """ elif q.period_name in ["hod","dow","moy"]: - - fmt = "" if q.period_name == "hod": - fmt = "HH24" - dur = "01:00:00" - prd = "hour" + q.period_name = "hour" + period_format = "'HH24'" + period_first_offset = "'-1sec'" + period_last_offset = "'+1sec'" elif q.period_name == "dow": - fmt = "ID" - dur = "24:00:00" - prd = "day" - elif q.period_name == "mod": - fmt = "MM" - dur = "1 month" - prd = "month" + q.period_name = "day" + period_format = "'ID'" + period_first_offset = "'0sec'" + period_last_offset = "'0sec'" + elif q.period_name == "moy": + q.period_name = "month" + period_format = "'MM'" + period_first_offset = "'-1sec'" + period_last_offset = "'+1sec'" - q.period_name = prd sql = f""" -WITH trends AS ( -SELECT - sn.id + ----------------------------------- + -- start by getting some basic sensor information + -- and transforming the timestamps + ----------------------------------- + WITH sensor AS ( + SELECT s.sensors_id + , sn.sensor_nodes_id + , s.data_averaging_period_seconds + , s.data_logging_period_seconds + , tz.tzid as timezone + , m.measurands_id + , m.measurand + , m.units + , timezone(tz.tzid, :date_from::timestamp) as datetime_from + , timezone(tz.tzid, :date_to::timestamp) as datetime_to + 
FROM sensors s + , sensor_systems sy + , sensor_nodes sn + , timezones tz + , measurands m + WHERE s.sensor_systems_id = sy.sensor_systems_id + AND sy.sensor_nodes_id = sn.sensor_nodes_id + AND sn.timezones_id = tz.gid + AND s.sensors_id = :sensors_id + AND s.measurands_id = m.measurands_id + -------------------------------- + -- Then we calculate what we expect to find in the data + -------------------------------- + ), expected AS ( + SELECT to_char(timezone(s.timezone, dd - '1sec'::interval), {period_format}) as factor + , s.timezone + , COUNT(1) as n + , MIN(date_trunc(:period_name, dd + {period_first_offset}::interval)) as period_first + , MAX(date_trunc(:period_name, dd + {period_last_offset}::interval)) as period_last + FROM sensor s + , generate_series(s.datetime_from + '1hour'::interval, s.datetime_to, ('1hour')::interval) dd + GROUP BY 1,2 + ------------------------------------ + -- Then we query what we have in the db + -- we join the sensor CTE here so that we have access to the timezone + ------------------------------------ + ), observed AS ( + SELECT + s.sensors_id + , s.data_averaging_period_seconds + , s.data_logging_period_seconds + , s.timezone , s.measurands_id - , sn.timezone - , to_char(timezone(sn.timezone, datetime - '1sec'::interval), '{fmt}') as factor - , AVG(s.data_averaging_period_seconds) as avg_seconds - , AVG(s.data_logging_period_seconds) as log_seconds -, MAX(truncate_timestamp(datetime, :period_name, sn.timezone, '1{prd}'::interval)) as last_period -, MIN(timezone(sn.timezone, datetime - '1sec'::interval)) as first_datetime -, MAX(timezone(sn.timezone, datetime - '1sec'::interval)) as last_datetime - , COUNT(1) as value_count + , s.measurand + , s.units + , to_char(timezone(s.timezone, datetime - '1sec'::interval), {period_format}) as factor + , MIN(datetime) as coverage_first + , MAX(datetime) as coverage_last + , COUNT(1) as n , AVG(value_avg) as value_avg , STDDEV(value_avg) as value_sd , MIN(value_avg) as value_min @@ -314,46 +430,49 @@ async def fetch_measurements(q, db): , PERCENTILE_CONT(0.98) WITHIN GROUP(ORDER BY value_avg) as value_p98 , current_timestamp as calculated_on FROM hourly_data m - JOIN sensors s ON (m.sensors_id = s.sensors_id) - JOIN sensor_systems sy ON (s.sensor_systems_id = sy.sensor_systems_id) - JOIN locations_view_cached sn ON (sy.sensor_nodes_id = sn.id) - {query.where()} - GROUP BY 1, 2, 3, 4) - SELECT t.id - , json_build_object( - 'label', factor - , 'datetime_from', get_datetime_object(first_datetime, t.timezone) - , 'datetime_to', get_datetime_object(last_datetime, t.timezone) - , 'interval', '{dur}' - ) as period - , sig_digits(value_avg, 2) as value - , json_build_object( - 'id', t.measurands_id - , 'units', m.units - , 'name', m.measurand - ) as parameter - , json_build_object( - 'sd', t.value_sd - , 'min', t.value_min - , 'q02', t.value_p02 - , 'q25', t.value_p25 - , 'median', t.value_p50 - , 'q75', t.value_p75 - , 'q98', t.value_p98 - , 'max', t.value_max - ) as summary - , calculate_coverage( - t.value_count::int - , t.avg_seconds - , t.log_seconds - , expected_hours(first_datetime, last_datetime, '{prd}', factor) * 3600.0 -)||jsonb_build_object( - 'datetime_from', get_datetime_object(first_datetime, t.timezone) - , 'datetime_to', get_datetime_object(last_datetime, t.timezone) - ) as coverage - FROM trends t - JOIN measurands m ON (t.measurands_id = m.measurands_id) - {query.pagination()} + JOIN sensor s ON (m.sensors_id = s.sensors_id) + WHERE datetime > datetime_from + AND datetime <= datetime_to + AND 
s.sensors_id = :sensors_id + GROUP BY 1, 2, 3, 4, 5, 6, 7, 8) +----------------------------------------- +-- And finally we tie it all together +----------------------------------------- + SELECT o.sensors_id + , sig_digits(value_avg, 2) as value + , json_build_object( + 'id', o.measurands_id + , 'units', o.units + , 'name', o.measurand + ) as parameter + , json_build_object( + 'sd', o.value_sd + , 'min', o.value_min + , 'q02', o.value_p02 + , 'q25', o.value_p25 + , 'median', o.value_p50 + , 'q75', o.value_p75 + , 'q98', o.value_p98 + , 'max', o.value_max + ) as summary + , json_build_object( + 'label', e.factor + , 'datetime_from', get_datetime_object(e.period_first, o.timezone) + , 'datetime_to', get_datetime_object(e.period_last, o.timezone) + , 'interval', :period_name + ) as period + , calculate_coverage( + o.n::int + , o.data_averaging_period_seconds + , o.data_logging_period_seconds + , e.n * 3600.0)|| + jsonb_build_object( + 'datetime_from', get_datetime_object(o.coverage_first, o.timezone) + , 'datetime_to', get_datetime_object(o.coverage_last, o.timezone) + ) as coverage + FROM expected e + JOIN observed o ON (e.factor = o.factor) + {query.pagination()} """ return await db.fetchPage(sql, query.params())
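
Usage sketch (not part of the diff) for the new private /auth routes: the PrivatePathsMiddleware added in openaq_api/openaq_api/middleware.py rejects any request whose path contains /auth unless the x-api-key header matches settings.EXPLORER_API_KEY, and the same key now bypasses the rate limiter. The base URL, users_id, and token values below are placeholders, and the camelCase body keys assume JsonBase's usual snake_case-to-camelCase alias behavior.

import os

import requests

BASE_URL = "http://localhost:8888"  # placeholder; deployment URL is not in the diff
HEADERS = {"x-api-key": os.environ["EXPLORER_API_KEY"]}  # checked by PrivatePathsMiddleware

# Rotate a user's API token; fields mirror RegenerateTokenBody(JsonBase),
# which (assumed) serializes users_id under the camelCase alias usersId.
r = requests.post(
    f"{BASE_URL}/auth/regenerate-token",
    headers=HEADERS,
    json={"usersId": 42, "token": "current-token-value"},  # placeholder values
)
print(r.status_code, r.json())  # expect {"message": "success"}

# Re-send the account verification email (VerificationBody).
r = requests.post(
    f"{BASE_URL}/auth/send-verification",
    headers=HEADERS,
    json={"usersId": 42},
)
print(r.status_code)

# Without the header, the middleware short-circuits with a 401.
r = requests.post(f"{BASE_URL}/auth/verify", json={"usersId": 42})
assert r.status_code == 401  # {"message": "invalid credentials"}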
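Continuing the sketch above, the new period_name=raw option on the sensor measurements route returns unaggregated rows from the measurements table, and the field_validator added to SensorMeasurementsQueries rejects timestamped bounds when aggregating to day or coarser. The /v3/sensors/{id}/measurements path and the exact 422 behavior are assumptions from the router layout, not confirmed by the diff.

# period_name=raw returns unaggregated measurements between the bounds.
r = requests.get(
    f"{BASE_URL}/v3/sensors/42/measurements",  # sensors_id is a placeholder
    headers=HEADERS,
    params={"period_name": "raw", "date_from": "2024-01-01", "date_to": "2024-01-02"},
)
print(r.status_code, len(r.json().get("results", [])))

# must_be_date_if_aggregating_to_day: a bound with a time component
# combined with a day-or-coarser period is assumed to yield a 422.
r = requests.get(
    f"{BASE_URL}/v3/sensors/42/measurements",
    headers=HEADERS,
    params={"period_name": "day", "date_from": "2024-01-01 12:00:00", "date_to": "2024-01-02"},
)
print(r.status_code)  # expect 422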