Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ASGI mounts): Prevent accidental scope overrides by mounted ASGI apps #3945

Merged
merged 2 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/examples/application_hooks/after_exception_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def my_handler() -> None:

async def after_exception_handler(exc: Exception, scope: "Scope") -> None:
"""Hook function that will be invoked after each exception."""
state = scope["app"].state
state = Litestar.from_scope(scope).state
if not hasattr(state, "error_count"):
state.error_count = 1
else:
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/application_hooks/before_send_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def before_send_hook_handler(message: Message, scope: Scope) -> None:
"""
if message["type"] == "http.response.start":
headers = MutableScopeHeaders.from_message(message=message)
headers["My Header"] = scope["app"].state.message
headers["My Header"] = Litestar.from_scope(scope).state.message


def on_startup(app: Litestar) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def middleware_factory(*, app: "ASGIApp") -> "ASGIApp":
"""A middleware can access application state via `scope`."""

async def my_middleware(scope: "Scope", receive: "Receive", send: "Send") -> None:
state = scope["app"].state
state = Litestar.from_scope(scope).state
logger.info("state value in middleware: %s", state.value)
await app(scope, receive, send)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/routing/mount_custom_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from litestar.types import Receive, Scope, Send


@asgi("/some/sub-path", is_mount=True)
@asgi("/some/sub-path", is_mount=True, copy_scope=True)
async def my_asgi_app(scope: "Scope", receive: "Receive", send: "Send") -> None:
"""
Args:
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/routing/mounting_starlette_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ async def index(request: "Request") -> JSONResponse:
return JSONResponse({"forwarded_path": request.url.path})


starlette_app = asgi(path="/some/sub-path", is_mount=True)(
starlette_app = asgi(path="/some/sub-path", is_mount=True, copy_scope=True)(
Starlette(
routes=[
Route("/", index),
Expand Down
8 changes: 4 additions & 4 deletions docs/usage/applications.rst
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ is accessible.
:ref:`reserved keyword arguments <usage/routing/handlers:"reserved" keyword arguments>`.

It is important to understand in this context that the application instance is injected into the ASGI ``scope`` mapping
for each connection (i.e. request or websocket connection) as ``scope["app"]``. This makes the application
accessible wherever the scope mapping is available, e.g. in middleware, on :class:`~.connection.request.Request` and
:class:`~.connection.websocket.WebSocket` instances (accessible as ``request.app`` / ``socket.app``), and many
other places.
for each connection (i.e. request or websocket connection) as ``scope["litestar_app"]``, and can be retrieved using
:meth:`~.Litestar.from_scope`. This makes the application accessible wherever the scope mapping is available,
e.g. in middleware, on :class:`~.connection.request.Request` and :class:`~.connection.websocket.WebSocket` instances
(accessible as ``request.app`` / ``socket.app``), and many other places.

Therefore, :paramref:`~.app.Litestar.state` offers an easy way to share contextual data between disparate parts
of the application, as seen below:
Expand Down
7 changes: 6 additions & 1 deletion litestar/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,10 +610,15 @@ async def __call__(
await self.asgi_router.lifespan(receive=receive, send=send) # type: ignore[arg-type]
return

scope["app"] = self
scope["app"] = scope["litestar_app"] = self
scope.setdefault("state", {})
await self.asgi_handler(scope, receive, self._wrap_send(send=send, scope=scope)) # type: ignore[arg-type]

@classmethod
def from_scope(cls, scope: Scope) -> Litestar:
"""Retrieve the Litestar application from the current ASGI scope"""
return scope["litestar_app"]

async def _call_lifespan_hook(self, hook: LifespanHook) -> None:
ret = hook(self) if inspect.signature(hook).parameters else hook() # type: ignore[call-arg]

Expand Down
6 changes: 3 additions & 3 deletions litestar/connection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def app(self) -> Litestar:
Returns:
The :class:`Litestar <litestar.app.Litestar>` application instance
"""
return self.scope["app"]
return self.scope["litestar_app"]

@property
def route_handler(self) -> HandlerT:
Expand Down Expand Up @@ -321,7 +321,7 @@ def url_for(self, name: str, **path_parameters: Any) -> str:
Returns:
A string representing the absolute url of the route handler.
"""
litestar_instance = self.scope["app"]
litestar_instance = self.scope["litestar_app"]
url_path = litestar_instance.route_reverse(name, **path_parameters)

return make_absolute_url(url_path, self.base_url)
Expand All @@ -339,7 +339,7 @@ def url_for_static_asset(self, name: str, file_path: str) -> str:
Returns:
A string representing absolute url to the asset.
"""
litestar_instance = self.scope["app"]
litestar_instance = self.scope["litestar_app"]
url_path = litestar_instance.url_for_static_asset(name, file_path)

return make_absolute_url(url_path, self.base_url)
22 changes: 21 additions & 1 deletion litestar/handlers/asgi_handlers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Mapping, Sequence

from litestar.exceptions import ImproperlyConfiguredException
Expand All @@ -11,6 +12,7 @@


if TYPE_CHECKING:
from litestar import Litestar
from litestar.types import (
ExceptionHandlersMap,
Guard,
Expand All @@ -24,7 +26,7 @@ class ASGIRouteHandler(BaseRouteHandler):
Use this decorator to decorate ASGI applications.
"""

__slots__ = ("is_mount", "is_static")
__slots__ = ("copy_scope", "is_mount", "is_static")

def __init__(
self,
Expand All @@ -37,6 +39,7 @@ def __init__(
is_mount: bool = False,
is_static: bool = False,
signature_namespace: Mapping[str, Any] | None = None,
copy_scope: bool | None = None,
**kwargs: Any,
) -> None:
"""Initialize ``ASGIRouteHandler``.
Expand All @@ -58,10 +61,14 @@ def __init__(
are used to deliver static files.
signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling.
type_encoders: A mapping of types to callables that transform them into types supported for serialization.
copy_scope: Copy the ASGI 'scope' before calling the mounted application. Should be set to 'True' unless
side effects via scope mutations by the mounted ASGI application are intentional
**kwargs: Any additional kwarg - will be set in the opt dictionary.
"""
self.is_mount = is_mount or is_static
self.is_static = is_static
self.copy_scope = copy_scope

super().__init__(
path,
exception_handlers=exception_handlers,
Expand All @@ -72,6 +79,19 @@ def __init__(
**kwargs,
)

def on_registration(self, app: Litestar) -> None:
super().on_registration(app)

if self.copy_scope is None:
warnings.warn(
f"{self}: 'copy_scope' not set for ASGI handler. Leaving 'copy_scope' unset will warn about mounted "
"ASGI applications modifying the scope. Set 'copy_scope=True' to ensure calling into mounted ASGI apps "
"does not cause any side effects via scope mutations, or set 'copy_scope=False' if those mutations are "
"desired. 'copy'scope' will default to 'True' in Litestar 3",
category=DeprecationWarning,
stacklevel=1,
)

def _validate_handler_function(self) -> None:
"""Validate the route handler function once it's set by inspecting its return annotations."""
super()._validate_handler_function()
Expand Down
2 changes: 1 addition & 1 deletion litestar/middleware/_internal/cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
origin = headers.get("origin")

if scope["type"] == ScopeType.HTTP and scope["method"] == HttpMethod.OPTIONS and origin:
request = scope["app"].request_class(scope=scope, receive=receive, send=send)
request = scope["litestar_app"].request_class(scope=scope, receive=receive, send=send)
asgi_response = self._create_preflight_response(origin=origin, request_headers=headers).to_asgi_response(
app=None, request=request
)
Expand Down
4 changes: 2 additions & 2 deletions litestar/middleware/_internal/exceptions/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __init__(

@staticmethod
def _get_debug_scope(scope: Scope) -> bool:
return scope["app"].debug
return scope["litestar_app"].debug

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""ASGI-callable.
Expand All @@ -161,7 +161,7 @@ async def capture_response_started(event: Message) -> None:
if scope_state.response_started:
raise LitestarException("Exception caught after response started") from e

litestar_app = scope["app"]
litestar_app = scope["litestar_app"]

if litestar_app.logging_config and (logger := litestar_app.logger):
self.handle_exception_logging(logger=logger, logging_config=litestar_app.logging_config, scope=scope)
Expand Down
2 changes: 1 addition & 1 deletion litestar/middleware/csrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return

request: Request[Any, Any, Any] = scope["app"].request_class(scope=scope, receive=receive)
request: Request[Any, Any, Any] = scope["litestar_app"].request_class(scope=scope, receive=receive)
content_type, _ = request.content_type
csrf_cookie = request.cookies.get(self.config.cookie_name)
existing_csrf_token = request.headers.get(self.config.header_name)
Expand Down
4 changes: 2 additions & 2 deletions litestar/middleware/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
None
"""
if not hasattr(self, "logger"):
self.logger = scope["app"].get_logger(self.config.logger_name)
self.logger = scope["litestar_app"].get_logger(self.config.logger_name)
self.is_struct_logger = structlog_installed and repr(self.logger).startswith("<BoundLoggerLazyProxy")

if self.config.response_log_fields:
Expand All @@ -121,7 +121,7 @@ async def log_request(self, scope: Scope, receive: Receive) -> None:
Returns:
None
"""
extracted_data = await self.extract_request_data(request=scope["app"].request_class(scope, receive))
extracted_data = await self.extract_request_data(request=scope["litestar_app"].request_class(scope, receive))
self.log_message(values=extracted_data)

def log_response(self, scope: Scope) -> None:
Expand Down
2 changes: 1 addition & 1 deletion litestar/middleware/rate_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
Returns:
None
"""
app = scope["app"]
app = scope["litestar_app"]
request: Request[Any, Any, Any] = app.request_class(scope)
store = self.config.get_store_from_app(app)
if await self.should_check_request(request=request):
Expand Down
2 changes: 1 addition & 1 deletion litestar/middleware/response_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def wrapped_send(message: Message) -> None:

if messages and message["type"] == HTTP_RESPONSE_BODY and not message.get("more_body"):
key = (route_handler.cache_key_builder or self.config.key_builder)(Request(scope))
store = self.config.get_store_from_app(scope["app"])
store = self.config.get_store_from_app(scope["litestar_app"])
await store.set(key, encode_msgpack(messages), expires_in=expires_in)
await send(message)

Expand Down
21 changes: 20 additions & 1 deletion litestar/routes/asgi.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any

from litestar.connection import ASGIConnection
from litestar.enums import ScopeType
from litestar.exceptions import LitestarWarning
from litestar.routes.base import BaseRoute

if TYPE_CHECKING:
Expand Down Expand Up @@ -51,4 +53,21 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
connection = ASGIConnection["ASGIRouteHandler", Any, Any, Any](scope=scope, receive=receive)
await self.route_handler.authorize_connection(connection=connection)

await self.route_handler.fn(scope=scope, receive=receive, send=send)
handler_scope = scope.copy()
copy_scope = self.route_handler.copy_scope

await self.route_handler.fn(
scope=handler_scope if copy_scope is True else scope,
receive=receive,
send=send,
)

if copy_scope is None and handler_scope != scope:
warnings.warn(
f"{self.route_handler}: Mounted ASGI app {self.route_handler.fn} modified 'scope' with 'copy_scope' "
"set to 'None'. Set 'copy_scope=True' to avoid mutating the original scope or set 'copy_scope=False' "
"if mutating the scope from within the mounted ASGI app is intentional. Note: 'copy_scope' will "
"default to 'True' by default in Litestar 3",
category=LitestarWarning,
stacklevel=1,
)
4 changes: 3 additions & 1 deletion litestar/routes/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ async def _call_handler_function(
route_handler=route_handler, parameter_model=parameter_model, request=request
)

response: ASGIApp = await route_handler.to_response(app=scope["app"], data=response_data, request=request)
response: ASGIApp = await route_handler.to_response(
app=scope["litestar_app"], data=response_data, request=request
)

if cleanup_group:
await cleanup_group.cleanup()
Expand Down
1 change: 1 addition & 0 deletions litestar/testing/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def fake_asgi_connection(app: ASGIApp, cookies: dict[str, str]) -> ASGIConnectio
"http_version": "1.1",
"extensions": {"http.response.template": {}},
"app": app, # type: ignore[typeddict-item]
provinzkraut marked this conversation as resolved.
Show resolved Hide resolved
"litestar_app": app,
"state": {},
"path_params": {},
"route_handler": None,
Expand Down
3 changes: 2 additions & 1 deletion litestar/testing/request_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
"""Initialize ``RequestFactory``

Args:
app: An instance of :class:`Litestar <litestar.app.Litestar>` to set as ``request.scope["app"]``.
app: An instance of :class:`Litestar <litestar.app.Litestar>` to set as ``request.scope["litestar_app"]``.
server: The server's domain.
port: The server's port.
root_path: Root path for the server.
Expand Down Expand Up @@ -175,6 +175,7 @@ def _create_scope(
path=path,
headers=[],
app=self.app,
litestar_app=self.app,
session=session,
user=user,
auth=auth,
Expand Down
3 changes: 2 additions & 1 deletion litestar/types/asgi_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ class HeaderScope(TypedDict):
class BaseScope(HeaderScope):
"""Base ASGI-scope."""

app: Litestar
app: Litestar # deprecated
litestar_app: Litestar
asgi: ASGIVersion
auth: Any
client: tuple[str, int] | None
Expand Down
2 changes: 1 addition & 1 deletion litestar/utils/scope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_serializer_from_scope(scope: Scope) -> Serializer:
A serializer function
"""
route_handler = scope["route_handler"]
app = scope["app"]
app = scope["litestar_app"]

if hasattr(route_handler, "resolve_type_encoders"):
type_encoders = route_handler.resolve_type_encoders()
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def inner(
) -> Scope:
scope = {
"app": app,
"litestar_app": app,
"asgi": asgi or {"spec_version": "2.0", "version": "3.0"},
"auth": auth,
"type": type,
Expand Down
18 changes: 17 additions & 1 deletion tests/unit/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def handler() -> Dict[str, str]:
async def before_send_hook_handler(message: Message, scope: Scope) -> None:
if message["type"] == "http.response.start":
headers = MutableScopeHeaders(message)
headers.add("My Header", scope["app"].state.message)
headers.add("My Header", Litestar.from_scope(scope).state.message)

def on_startup(app: Litestar) -> None:
app.state.message = "value injected during send"
Expand Down Expand Up @@ -466,3 +466,19 @@ def my_route_handler() -> None: ...
with create_test_client(my_route_handler, path="/abc") as client:
response = client.get("/abc")
assert response.status_code == HTTP_200_OK


def test_from_scope() -> None:
mock = MagicMock()

@get()
def handler(scope: Scope) -> None:
mock(Litestar.from_scope(scope))
return

app = Litestar(route_handlers=[handler])

with TestClient(app) as client:
client.get("/")

mock.assert_called_once_with(app)
Loading
Loading