From 9122e616f9783cdb688c7930052e23ec5f3fd102 Mon Sep 17 00:00:00 2001 From: pdmurray Date: Fri, 17 Jan 2025 12:35:00 -0800 Subject: [PATCH] Address review comments; add bad query parameter handling --- .../_internal/server/dependencies.py | 55 +++++++++++++- .../_internal/server/views/api.py | 76 +++++-------------- .../_internal/server/views/pagination.py | 72 +++++++++++------- conda-store-server/pyproject.toml | 1 + .../tests/_internal/server/views/test_api.py | 17 +++++ 5 files changed, 137 insertions(+), 84 deletions(-) diff --git a/conda-store-server/conda_store_server/_internal/server/dependencies.py b/conda-store-server/conda_store_server/_internal/server/dependencies.py index 7226760bd..b72c6462e 100644 --- a/conda-store-server/conda_store_server/_internal/server/dependencies.py +++ b/conda-store-server/conda_store_server/_internal/server/dependencies.py @@ -2,7 +2,15 @@ # Use of this source code is governed by a BSD-style # license that can be found in the LICENSE file. -from fastapi import Depends, Request +from typing import Optional, TypedDict + +from fastapi import Depends, Query, Request + +from conda_store_server._internal.server.views.pagination import ( + Cursor, + CursorPaginatedArgs, + Ordering, +) async def get_conda_store(request: Request): @@ -27,3 +35,48 @@ async def get_templates(request: Request): async def get_url_prefix(request: Request, server=Depends(get_server)): return server.url_prefix + + +def get_cursor(cursor: str | None = None) -> Cursor: + return Cursor.load(cursor) + + +def get_cursor_paginated_args( + order: Ordering = Ordering.ASCENDING, + limit: int | None = None, + sort_by: list[str] = Query([]), + server=Depends(get_server), +) -> CursorPaginatedArgs: + return CursorPaginatedArgs( + limit=server.max_page_size if limit is None else limit, + order=order, + sort_by=sort_by, + ) + + +class PaginatedArgs(TypedDict): + """Dictionary type holding information about paginated requests.""" + + limit: int + offset: int + sort_by: list[str] + order: str + + +def get_paginated_args( + page: int = 1, + order: Optional[str] = None, + size: Optional[int] = None, + sort_by: list[str] = Query([]), + server=Depends(get_server), +) -> PaginatedArgs: + if size is None: + size = server.max_page_size + size = min(size, server.max_page_size) + offset = (page - 1) * size + return { + "limit": size, + "offset": offset, + "sort_by": sort_by, + "order": order, + } diff --git a/conda-store-server/conda_store_server/_internal/server/views/api.py b/conda-store-server/conda_store_server/_internal/server/views/api.py index c32e96adc..08c9ed1e7 100644 --- a/conda-store-server/conda_store_server/_internal/server/views/api.py +++ b/conda-store-server/conda_store_server/_internal/server/views/api.py @@ -3,7 +3,7 @@ # license that can be found in the LICENSE file. import datetime -from typing import Any, Dict, List, Optional, TypedDict +from typing import Any, Dict, List, Optional import pydantic import yaml @@ -18,7 +18,6 @@ from conda_store_server._internal.server.views.pagination import ( Cursor, CursorPaginatedArgs, - Ordering, OrderingMetadata, paginate, ) @@ -28,58 +27,12 @@ from conda_store_server.server.auth import Authentication from conda_store_server.server.schema import AuthenticationToken, Permissions - -def get_cursor(cursor: Optional[str] = None) -> Cursor: - return Cursor.load(cursor) - - -def get_cursor_paginated_args( - order: Optional[Ordering] = Ordering.ASCENDING, - limit: Optional[int] = None, - sort_by: List[str] = Query([]), - server=Depends(dependencies.get_server), -) -> CursorPaginatedArgs: - return CursorPaginatedArgs( - limit=server.max_page_size if limit is None else limit, - order=order, - sort_by=sort_by, - ) - - -class PaginatedArgs(TypedDict): - """Dictionary type holding information about paginated requests.""" - - limit: int - offset: int - sort_by: List[str] - order: str - - router_api = APIRouter( tags=["api"], prefix="/api/v1", ) -def get_paginated_args( - page: int = 1, - order: Optional[str] = None, - size: Optional[int] = None, - sort_by: List[str] = Query([]), - server=Depends(dependencies.get_server), -) -> PaginatedArgs: - if size is None: - size = server.max_page_size - size = min(size, server.max_page_size) - offset = (page - 1) * size - return { - "limit": size, - "offset": offset, - "sort_by": sort_by, - "order": order, - } - - def filter_distinct_on( query, distinct_on: List[str] = [], @@ -290,7 +243,9 @@ async def api_post_token( async def api_list_namespaces( auth=Depends(dependencies.get_auth), entity=Depends(dependencies.get_entity), - paginated_args: PaginatedArgs = Depends(get_paginated_args), + paginated_args: dependencies.PaginatedArgs = Depends( + dependencies.get_paginated_args + ), conda_store=Depends(dependencies.get_conda_store), ): with conda_store.get_db() as db: @@ -656,14 +611,17 @@ async def api_delete_namespace( @router_api.get( "/environment/", response_model=schema.APIListEnvironment, + response_model_exclude={"data": {"__all__": {"current_build"}}}, ) async def api_list_environments( request: Request, auth: Authentication = Depends(dependencies.get_auth), conda_store: CondaStore = Depends(dependencies.get_conda_store), entity: AuthenticationToken = Depends(dependencies.get_entity), - paginated_args: CursorPaginatedArgs = Depends(get_cursor_paginated_args), - cursor: Cursor = Depends(get_cursor), + paginated_args: CursorPaginatedArgs = Depends( + dependencies.get_cursor_paginated_args + ), + cursor: Cursor = Depends(dependencies.get_cursor), artifact: Optional[schema.BuildArtifactType] = None, jwt: Optional[str] = None, name: Optional[str] = None, @@ -711,6 +669,10 @@ async def api_list_environments( envrionment's build's scheduled_on time to ensure all results are returned when iterating over pages in systems where the number of environments is changing while results are being requested; see https://github.com/conda-incubator/conda-store/issues/859 for context + + Note that the Environment objects returned here have their `current_build` fields omitted + to keep the repsonse size down; these fields otherwise drastically increase the response + size. """ with conda_store.get_db() as db: if jwt: @@ -744,12 +706,12 @@ async def api_list_environments( paginated, next_cursor, count = paginate( query=query, ordering_metadata=OrderingMetadata( - order_names=["namespace", "name"], + valid_orderings=["namespace", "name"], column_names=["namespace.name", "name"], column_objects=[orm.Namespace.name, orm.Environment.name], ), cursor=cursor, - order_by=paginated_args.sort_by, + sort_by=paginated_args.sort_by, order=paginated_args.order, limit=paginated_args.limit, ) @@ -977,7 +939,7 @@ async def api_list_builds( conda_store=Depends(dependencies.get_conda_store), auth=Depends(dependencies.get_auth), entity=Depends(dependencies.get_entity), - paginated_args=Depends(get_paginated_args), + paginated_args=Depends(dependencies.get_paginated_args), ): with conda_store.get_db() as db: orm_builds = auth.filter_builds( @@ -1172,7 +1134,7 @@ async def api_get_build_packages( build: Optional[str] = None, auth=Depends(dependencies.get_auth), conda_store=Depends(dependencies.get_conda_store), - paginated_args=Depends(get_paginated_args), + paginated_args=Depends(dependencies.get_paginated_args), ): with conda_store.get_db() as db: build_orm = api.get_build(db, build_id) @@ -1229,7 +1191,7 @@ async def api_get_build_logs( ) async def api_list_channels( conda_store=Depends(dependencies.get_conda_store), - paginated_args=Depends(get_paginated_args), + paginated_args=Depends(dependencies.get_paginated_args), ): with conda_store.get_db() as db: orm_channels = api.list_conda_channels(db) @@ -1250,7 +1212,7 @@ async def api_list_packages( search: Optional[str] = None, exact: Optional[str] = None, build: Optional[str] = None, - paginated_args=Depends(get_paginated_args), + paginated_args=Depends(dependencies.get_paginated_args), conda_store=Depends(dependencies.get_conda_store), distinct_on: List[str] = Query([]), ): diff --git a/conda-store-server/conda_store_server/_internal/server/views/pagination.py b/conda-store-server/conda_store_server/_internal/server/views/pagination.py index 7d363c12b..5bb60c154 100644 --- a/conda-store-server/conda_store_server/_internal/server/views/pagination.py +++ b/conda-store-server/conda_store_server/_internal/server/views/pagination.py @@ -3,7 +3,7 @@ import base64 import operator from enum import Enum -from typing import Any, Optional +from typing import Any import pydantic from fastapi import HTTPException @@ -105,7 +105,7 @@ def paginate( query: SqlQuery, ordering_metadata: OrderingMetadata, cursor: Cursor | None = None, - order_by: list[str] | None = None, + sort_by: list[str] | None = None, order: Ordering = Ordering.ASCENDING, limit: int = 10, ) -> tuple[list[Base], Cursor, int]: @@ -127,7 +127,7 @@ def paginate( Cursor object containing information about the last item on the previous page. If None, the first page is returned. order_by : list[str] | None - List of sort_by query parameters + List of query parameters to order the results by Returns ------- @@ -135,8 +135,8 @@ def paginate( Query containing the paginated results, Cursor for retrieving the next page, and total number of results """ - if order_by is None: - order_by = [] + if sort_by is None: + sort_by = [] if order == Ordering.ASCENDING: comparison = operator.gt @@ -147,7 +147,20 @@ def paginate( else: raise HTTPException( status_code=400, - detail=f"Invalid query parameter: order = {order}; must be one of ['asc', 'desc']", + detail=( + f"Cannot order results: {order}" + f"Valid order values are [{Ordering.ASCENDING.value}, {Ordering.DESCENDING.value}]", + ), + ) + + invalid_params = ordering_metadata.get_invalid_orderings(sort_by) + if invalid_params: + raise HTTPException( + status_code=400, + detail=( + f"Cannot sort results by {invalid_params}. " + f"Valid sort_by values are {ordering_metadata.valid_orderings}" + ), ) # Fetch the total number of objects in the database before filtering @@ -155,13 +168,13 @@ def paginate( # Get the python type of the objects being queried queried_type = query.column_descriptions[0]["type"] - columns = ordering_metadata.get_requested_columns(order_by) + columns = ordering_metadata.get_requested_columns(sort_by) # If there's a cursor already, use the last attributes to filter # the results by (*attributes, id) >/< (*last_values, last_id) # Order by desc or asc if cursor is not None and cursor != Cursor.end(): - last_values = cursor.get_last_values(order_by) + last_values = cursor.get_last_values(sort_by) query = query.filter( comparison( tuple_(*columns, queried_type.id), @@ -179,7 +192,7 @@ def paginate( last_result = data[-1] next_cursor = Cursor( last_id=last_result.id, - last_value=ordering_metadata.get_attr_values(last_result, order_by), + last_value=ordering_metadata.get_attr_values(last_result, sort_by), ) else: next_cursor = Cursor.end() @@ -188,9 +201,9 @@ def paginate( class CursorPaginatedArgs(pydantic.BaseModel): - limit: Optional[int] = 10 - order: Optional[Ordering] = Ordering.ASCENDING - sort_by: Optional[list[str]] = [] + limit: int | None + order: Ordering + sort_by: list[str] @pydantic.field_validator("sort_by") def validate_sort_by(cls, v: list[str]) -> list[str]: @@ -218,24 +231,31 @@ def validate_sort_by(cls, v: list[str]) -> list[str]: class OrderingMetadata: def __init__( self, - order_names: list[str] | None = None, + valid_orderings: list[str] | None = None, column_names: list[str] | None = None, column_objects: list[InstrumentedAttribute] | None = None, ): - self.order_names = order_names + self.valid_orderings = valid_orderings self.column_names = column_names self.column_objects = column_objects - def validate(self, model: Base): - if len(self.order_names) != len(self.column_names): - raise ValueError( - "Each name of a valid ordering available to the order_by query parameter" - "must have an associated column name to select in the table." - ) + def get_invalid_orderings(self, query_params: list[str] | None) -> list[str]: + """Return a list of invalid ordering query parameters. + + Parameters + ---------- + query_params : list[str] | None + A list of ordering query parameters + + Returns + ------- + list[str] + A list of the query parameters which cannot be used to order the results + """ + if query_params is None: + return [] - for col in self.column_names: - if not hasattr(model, col): - raise ValueError(f"No column named {col} found on model {model}.") + return [param for param in query_params if param not in self.valid_orderings] def get_requested_columns( self, @@ -258,13 +278,13 @@ def get_requested_columns( columns = [] if order_by: for order_name in order_by: - idx = self.order_names.index(order_name) + idx = self.valid_orderings.index(order_name) columns.append(self.column_objects[idx]) return columns def __str__(self) -> str: - return f"OrderingMetadata" + return f"OrderingMetadata" def __repr__(self) -> str: return str(self) @@ -293,7 +313,7 @@ def get_attr_values( """ values = {} for order_name in order_by: - idx = self.order_names.index(order_name) + idx = self.valid_orderings.index(order_name) attr = self.column_names[idx] values[order_name] = get_nested_attribute(obj, attr) diff --git a/conda-store-server/pyproject.toml b/conda-store-server/pyproject.toml index 2e0f10b26..bf746f0e7 100644 --- a/conda-store-server/pyproject.toml +++ b/conda-store-server/pyproject.toml @@ -94,6 +94,7 @@ dependencies = [ "build", "docker-compose", "docker-py<7", + "httpx", "pre-commit", "pytest", "pytest-celery", diff --git a/conda-store-server/tests/_internal/server/views/test_api.py b/conda-store-server/tests/_internal/server/views/test_api.py index 35fcf12a7..6234288ce 100644 --- a/conda-store-server/tests/_internal/server/views/test_api.py +++ b/conda-store-server/tests/_internal/server/views/test_api.py @@ -8,6 +8,7 @@ import sys import time +import httpx import pytest import traitlets import yaml @@ -1181,3 +1182,19 @@ def test_api_list_environments_paginate( # Check that the environments are sorted by ID assert sorted(env_ids) == env_ids + + +def test_api_list_environments_invalid_query_params( + conda_store_server, + testclient, + seed_conda_store_big, + authenticate, +): + """Test that invalid query parameters return 400 status codes.""" + response = testclient.get("api/v1/environment/?order=foo") + with pytest.raises(httpx.HTTPStatusError): + response.raise_for_status() + + response = testclient.get("api/v1/environment/?sort_by=foo") + with pytest.raises(httpx.HTTPStatusError): + response.raise_for_status()