Skip to content

Commit

Permalink
fix: replace OpenSensus with OpenTelemetry and fix exception handling
Browse files Browse the repository at this point in the history
  • Loading branch information
collinlokken committed Jan 9, 2025
1 parent 0caf81e commit 710e9f1
Show file tree
Hide file tree
Showing 14 changed files with 782 additions and 468 deletions.
976 changes: 647 additions & 329 deletions api/poetry.lock

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "1.4.0" # x-release-please-version
description = "API for Template Fastapi React"
authors = []
license = ""
package-mode = false

[tool.poetry.dependencies]
cachetools = "^5.3.0"
Expand All @@ -14,10 +15,12 @@ uvicorn = {extras = ["standard"], version = "^0.21.1"}
pymongo = "4.1.1"
certifi = "^2023.7.22"
httpx = "^0.23.3"
opencensus-ext-azure = "^1.1.9"
pydantic = "^2.1"
pydantic-settings = "^2.0.1"
pydantic-extra-types = "^2.0.0"
azure-monitor-opentelemetry = "^1.6.2"
opentelemetry-instrumentation-fastapi = "^0.48b0"
cryptography = "^44.0.0"

[tool.poetry.dev-dependencies]
pre-commit = ">=3"
Expand Down
12 changes: 7 additions & 5 deletions api/src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from starlette.middleware import Middleware

from authentication.authentication import auth_with_jwt
from common.exception_handlers import add_exception_handlers
from common.middleware import LocalLoggerMiddleware, OpenCensusRequestLoggingMiddleware
from common.middleware import LocalLoggerMiddleware
from common.responses import responses
from config import config
from features.health_check import health_check_feature
Expand Down Expand Up @@ -35,8 +34,6 @@ def create_app() -> FastAPI:
authenticated_routes.include_router(whoami_feature.router)

middleware = [Middleware(LocalLoggerMiddleware)]
if config.APPINSIGHTS_CONSTRING:
middleware.append(Middleware(OpenCensusRequestLoggingMiddleware))

app = FastAPI(
title="Template FastAPI React",
Expand All @@ -54,7 +51,12 @@ def create_app() -> FastAPI:
},
)

add_exception_handlers(app)
if config.APPINSIGHTS_CONSTRING:
from azure.monitor.opentelemetry import configure_azure_monitor
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor

configure_azure_monitor(connection_string=config.APPINSIGHTS_CONSTRING, logger_name="API")
FastAPIInstrumentor.instrument_app(app, excluded_urls="healthcheck")

app.include_router(authenticated_routes, dependencies=[Security(auth_with_jwt)])
app.include_router(public_routes)
Expand Down
10 changes: 5 additions & 5 deletions api/src/authentication/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi.security import OAuth2AuthorizationCodeBearer

from authentication.models import User
from common.exceptions import credentials_exception
from common.exceptions import UnauthorizedException
from common.logger import logger
from config import config

Expand All @@ -25,12 +25,12 @@ def get_JWK_client() -> jwt.PyJWKClient:
return jwt.PyJWKClient(oid_conf["jwks_uri"])
except Exception as error:
logger.error(f"Failed to fetch OpenId Connect configuration for '{config.OAUTH_WELL_KNOWN}': {error}")
raise credentials_exception
raise UnauthorizedException


def auth_with_jwt(jwt_token: str = Security(oauth2_scheme)) -> User:
if not jwt_token:
raise credentials_exception
raise UnauthorizedException
key = get_JWK_client().get_signing_key_from_jwt(jwt_token).key
try:
payload = jwt.decode(jwt_token, key, algorithms=["RS256"], audience=config.OAUTH_AUDIENCE)
Expand All @@ -41,8 +41,8 @@ def auth_with_jwt(jwt_token: str = Security(oauth2_scheme)) -> User:
user = User(user_id=payload["sub"], **payload)
except jwt.exceptions.InvalidTokenError as error:
logger.warning(f"Failed to decode JWT: {error}")
raise credentials_exception
raise UnauthorizedException

if user is None:
raise credentials_exception
raise UnauthorizedException
return user
53 changes: 36 additions & 17 deletions api/src/common/exception_handlers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import traceback
import uuid
from collections.abc import Callable, Coroutine
from typing import Any

from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.routing import APIRoute
from httpx import HTTPStatusError
from starlette import status
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.responses import JSONResponse, Response

from common.exceptions import (
ApplicationException,
Expand All @@ -16,29 +18,16 @@
ExceptionSeverity,
MissingPrivilegeException,
NotFoundException,
UnauthorizedException,
ValidationException,
)
from common.logger import logger


def add_exception_handlers(app: FastAPI) -> None:
# Handle custom exceptions
app.add_exception_handler(BadRequestException, generic_exception_handler)
app.add_exception_handler(ValidationException, generic_exception_handler)
app.add_exception_handler(NotFoundException, generic_exception_handler)
app.add_exception_handler(MissingPrivilegeException, generic_exception_handler)

# Override built-in default handler
app.add_exception_handler(RequestValidationError, validation_exception_handler) # type: ignore
app.add_exception_handler(HTTPStatusError, http_exception_handler)

# Fallback exception handler for all unexpected exceptions
app.add_exception_handler(Exception, fall_back_exception_handler)


def fall_back_exception_handler(request: Request, exc: Exception) -> JSONResponse:
error_id = uuid.uuid4()
traceback_string = " ".join(traceback.format_tb(tb=exc.__traceback__))
print(traceback_string)
logger.error(
f"Unexpected unhandled exception ({error_id}): {exc}",
extra={"custom_dimensions": {"Error ID": error_id, "Traceback": traceback_string}},
Expand Down Expand Up @@ -98,3 +87,33 @@ def http_exception_handler(request: Request, exc: HTTPStatusError) -> JSONRespon
debug=exc.response,
)
)


class ExceptionHandlingRoute(APIRoute):
"""APIRoute class for handling exceptions."""

def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
"""Intercept response and return correct exception response."""
original_route_handler = super().get_route_handler()

async def custom_route_handler(request: Request) -> Response:
try:
return await original_route_handler(request)
except BadRequestException as e:
return generic_exception_handler(request, e)
except ValidationException as e:
return generic_exception_handler(request, e)
except NotFoundException as e:
return generic_exception_handler(request, e)
except MissingPrivilegeException as e:
return generic_exception_handler(request, e)
except RequestValidationError as e:
return validation_exception_handler(request, e)
except HTTPStatusError as e:
return http_exception_handler(request, e)
except UnauthorizedException as e:
return generic_exception_handler(request, e)
except Exception as e:
return fall_back_exception_handler(request, e)

return custom_route_handler
15 changes: 9 additions & 6 deletions api/src/common/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from enum import Enum
from typing import Any

from fastapi import HTTPException
from pydantic import BaseModel
from starlette import status as request_status

Expand Down Expand Up @@ -98,8 +97,12 @@ def __init__(
self.type = self.__class__.__name__


credentials_exception = HTTPException(
status_code=request_status.HTTP_401_UNAUTHORIZED,
detail="Token validation failed",
headers={"WWW-Authenticate": "Bearer"},
)
class UnauthorizedException(ApplicationException):
def __init__(
self,
message: str = "Token validation failed",
debug: str = "Token was not valid for requested operation.",
extra: dict[str, Any] | None = None,
):
super().__init__(message, debug, extra, request_status.HTTP_401_UNAUTHORIZED)
self.type = self.__class__.__name__
46 changes: 0 additions & 46 deletions api/src/common/middleware.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
import time

from azure.core.tracing import SpanKind
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace.attributes_helper import COMMON_ATTRIBUTES
from opencensus.trace.samplers import ProbabilitySampler
from opencensus.trace.tracer import Tracer
from starlette.datastructures import MutableHeaders
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from common.logger import logger
from config import config


# These middlewares are written as "pure ASGI middleware", see: https://www.starlette.io/middleware/#pure-asgi-middleware
Expand Down Expand Up @@ -44,43 +38,3 @@ async def inner_send(message: Message) -> None:

await self.app(scope, receive, inner_send)
logger.info(f"{method} {path} - {process_time}ms - {response['status']}")


class OpenCensusRequestLoggingMiddleware:
exporter = AzureExporter(connection_string=config.APPINSIGHTS_CONSTRING) if config.APPINSIGHTS_CONSTRING else None
sampler = ProbabilitySampler(1.0)

def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
return await self.app(scope, receive, send)

tracer = Tracer(exporter=self.exporter, sampler=self.sampler)

path = scope["path"]
response: Message = {}

async def inner_send(message: Message) -> None:
nonlocal response
if message["type"] == "http.response.start":
response = message

await send(message)

if path == "/health-check": # Don't send health-check requests to Azure
return await self.app(scope, receive, send)

with tracer.span("main") as span:
span.span_kind = SpanKind.SERVER

await self.app(scope, receive, inner_send)

tracer.add_attribute_to_current_span(
attribute_key=COMMON_ATTRIBUTES["HTTP_STATUS_CODE"], attribute_value=response["status"]
)
host = next((header[1].decode() for header in scope["headers"] if header[0] == b"host"), "")
tracer.add_attribute_to_current_span(
attribute_key=COMMON_ATTRIBUTES["HTTP_URL"], attribute_value=f"{host}{path}"
)
4 changes: 4 additions & 0 deletions api/src/common/telemetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from opentelemetry import trace

# Creates a tracer from the global tracer provider
tracer = trace.get_tracer("tracer.global")
4 changes: 3 additions & 1 deletion api/src/features/health_check/health_check_feature.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from fastapi import APIRouter, status
from fastapi.responses import PlainTextResponse

router = APIRouter(tags=["health_check"], prefix="/health-check")
from common.exception_handlers import ExceptionHandlingRoute

router = APIRouter(tags=["health_check"], prefix="/health-check", route_class=ExceptionHandlingRoute)


@router.get(
Expand Down
3 changes: 2 additions & 1 deletion api/src/features/todo/todo_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from authentication.authentication import auth_with_jwt
from authentication.models import User
from common.exception_handlers import ExceptionHandlingRoute
from features.todo.repository.todo_repository import get_todo_repository
from features.todo.repository.todo_repository_interface import TodoRepositoryInterface
from features.todo.use_cases.add_todo import (
Expand All @@ -24,7 +25,7 @@
update_todo_use_case,
)

router = APIRouter(tags=["todos"], prefix="/todos")
router = APIRouter(tags=["todos"], prefix="/todos", route_class=ExceptionHandlingRoute)


@router.post("", operation_id="create")
Expand Down
8 changes: 8 additions & 0 deletions api/src/features/todo/use_cases/get_todo_all.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pydantic import BaseModel

from common.logger import logger
from common.telemetry import tracer
from features.todo.entities.todo_item import TodoItem
from features.todo.repository.todo_repository_interface import TodoRepositoryInterface

Expand All @@ -14,10 +16,16 @@ def from_entity(todo_item: TodoItem) -> "GetTodoAllResponse":
return GetTodoAllResponse(id=todo_item.id, title=todo_item.title, is_completed=todo_item.is_completed)


# Telemetry example: Initialize a span that will be used to log telemetry data
@tracer.start_as_current_span("get_todo_all_use_case") # type: ignore
def get_todo_all_use_case(
user_id: str,
todo_repository: TodoRepositoryInterface,
) -> list[GetTodoAllResponse]:
# Telemetry example
logger.info(
f"Get todos for user: {user_id}"
) # This log message will be logged within the span context defined above
return [
GetTodoAllResponse.from_entity(todo_item)
for todo_item in todo_repository.get_all()
Expand Down
3 changes: 2 additions & 1 deletion api/src/features/whoami/whoami_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from authentication.authentication import auth_with_jwt
from authentication.models import User
from common.exception_handlers import ExceptionHandlingRoute

router = APIRouter(tags=["whoami"], prefix="/whoami")
router = APIRouter(tags=["whoami"], prefix="/whoami", route_class=ExceptionHandlingRoute)


@router.get("", operation_id="whoami")
Expand Down
6 changes: 3 additions & 3 deletions api/src/tests/integration/mock_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from authentication.authentication import oauth2_scheme
from authentication.models import User
from common.exceptions import credentials_exception
from common.exceptions import UnauthorizedException
from config import config, default_user


Expand Down Expand Up @@ -89,7 +89,7 @@ def mock_auth_with_jwt(jwt_token: str = Security(oauth2_scheme)) -> User:
print(payload)
user = User(user_id=payload["sub"], **payload)
except jwt.exceptions.InvalidTokenError as error:
raise credentials_exception from error
raise UnauthorizedException from error
if user is None:
raise credentials_exception
raise UnauthorizedException
return user
Loading

0 comments on commit 710e9f1

Please sign in to comment.