From 664d122d44580d4854a2ff4326f95106975e4e3c Mon Sep 17 00:00:00 2001 From: pdmurray Date: Thu, 16 Jan 2025 13:52:09 -0800 Subject: [PATCH] Cleanup; remove ANN101, ANN102, and PT004 deprecated ruff rules --- .../conda_store_server/_internal/schema.py | 2 +- .../_internal/server/views/api.py | 4 +- .../_internal/server/views/pagination.py | 43 ++++++---- conda-store-server/pyproject.toml | 3 - .../tests/_internal/server/views/test_api.py | 84 ++++++++++++++----- 5 files changed, 90 insertions(+), 46 deletions(-) diff --git a/conda-store-server/conda_store_server/_internal/schema.py b/conda-store-server/conda_store_server/_internal/schema.py index 679ddd250..ee649f674 100644 --- a/conda-store-server/conda_store_server/_internal/schema.py +++ b/conda-store-server/conda_store_server/_internal/schema.py @@ -503,7 +503,7 @@ class APICursorPaginatedResponse(BaseModel): status: APIStatus message: Optional[str] = None cursor: Optional[str] = None - count: int + count: int # the total number of results available to fetch class APIAckResponse(BaseModel): diff --git a/conda-store-server/conda_store_server/_internal/server/views/api.py b/conda-store-server/conda_store_server/_internal/server/views/api.py index 603f612cc..ca5a4726c 100644 --- a/conda-store-server/conda_store_server/_internal/server/views/api.py +++ b/conda-store-server/conda_store_server/_internal/server/views/api.py @@ -742,7 +742,7 @@ async def api_list_environments( role_bindings=auth.entity_bindings(entity), ) - paginated, next_cursor = paginate( + paginated, next_cursor, count = paginate( query=query, ordering_metadata=OrderingMetadata( order_names=["namespace", "name"], @@ -759,7 +759,7 @@ async def api_list_environments( data=paginated, status="ok", cursor=next_cursor.dump(), - count=1000, + count=count, ) diff --git a/conda-store-server/conda_store_server/_internal/server/views/pagination.py b/conda-store-server/conda_store_server/_internal/server/views/pagination.py index f21b37278..7d363c12b 100644 --- a/conda-store-server/conda_store_server/_internal/server/views/pagination.py +++ b/conda-store-server/conda_store_server/_internal/server/views/pagination.py @@ -22,7 +22,6 @@ class Ordering(Enum): class Cursor(pydantic.BaseModel): last_id: int | None = 0 - count: int | None = 0 # List query parameters to order by, and the last value of the ordered attribute # { @@ -57,7 +56,7 @@ def load(cls, b64_cursor: str | None = None) -> Cursor: Cursor representation of the b64-encoded string """ if b64_cursor is None: - return cls(last_id=None, count=0, last_value=None) + return cls(last_id=None, last_value=None) return cls.model_validate_json(base64.b64decode(b64_cursor).decode("utf8")) def get_last_values(self, order_names: list[str]) -> list[Any]: @@ -87,11 +86,19 @@ def end(cls) -> Cursor: Cursor An empty cursor """ - return cls(last_id=None, count=0, last_value=None) + return cls(last_id=None, last_value=None) @classmethod def begin(cls) -> Cursor: - return cls(last_id=0, count=None, last_value=None) + """Cursor representing the beginning of a set of paginated results. + + Returns + ------- + Cursor + A cursor that points at the first result; the count is 0 + because this cursor + """ + return cls(last_id=0, last_value=None) def paginate( @@ -101,7 +108,7 @@ def paginate( order_by: list[str] | None = None, order: Ordering = Ordering.ASCENDING, limit: int = 10, -) -> tuple[SqlQuery, Cursor]: +) -> tuple[list[Base], Cursor, int]: """Paginate the query using the cursor and the requested sort_bys. This function assumes that the first column of the query contains @@ -124,9 +131,9 @@ def paginate( Returns ------- - tuple[SqlQuery, Cursor] - Query containing the paginated results, and Cursor for retrieving - the next page + tuple[Base, Cursor, int] + Query containing the paginated results, Cursor for retrieving + the next page, and total number of results """ if order_by is None: order_by = [] @@ -143,6 +150,9 @@ def paginate( detail=f"Invalid query parameter: order = {order}; must be one of ['asc', 'desc']", ) + # Fetch the total number of objects in the database before filtering + count = query.count() + # Get the python type of the objects being queried queried_type = query.column_descriptions[0]["type"] columns = ordering_metadata.get_requested_columns(order_by) @@ -159,23 +169,22 @@ def paginate( ) ) - order_by_args = [order_func(col) for col in columns] + [order_func(queried_type.id)] - - query = query.order_by(*order_by_args) + # Order the query by the requested columns, and also by the object's primary key + query = query.order_by( + *([order_func(col) for col in columns] + [order_func(queried_type.id)]) + ) data = query.limit(limit).all() - count = query.count() - if count > 0: + if len(data) > 0: last_result = data[-1] - last_value = ordering_metadata.get_attr_values(last_result, order_by) - next_cursor = Cursor( - last_id=data[-1].id, last_value=last_value, count=query.count() + last_id=last_result.id, + last_value=ordering_metadata.get_attr_values(last_result, order_by), ) else: next_cursor = Cursor.end() - return (data, next_cursor) + return (data, next_cursor, count) class CursorPaginatedArgs(pydantic.BaseModel): diff --git a/conda-store-server/pyproject.toml b/conda-store-server/pyproject.toml index fc437da64..7f5e1e0d5 100644 --- a/conda-store-server/pyproject.toml +++ b/conda-store-server/pyproject.toml @@ -142,8 +142,6 @@ ignore = [ "ANN001", # missing-type-function-argument "ANN002", # missing-type-args "ANN003", # missing-type-kwargs - "ANN101", # missing-type-self - "ANN102", # missing-type-cls "ANN201", # missing-return-type-undocumented-public-function "ANN202", # missing-return-type-private-function "ANN204", # missing-return-type-special-method @@ -175,7 +173,6 @@ ignore = [ "FIX002", # line-contains-todo "N805", # invalid-first-argument-name-for-method "N815", # mixed-case-variable-in-class-scope - "PT004", # pytest-missing-fixture-name-underscore "PT006", # pytest-parametrize-names-wrong-type "PT011", # pytest-raises-too-broad "PT012", # pytest-raises-with-multiple-statements diff --git a/conda-store-server/tests/_internal/server/views/test_api.py b/conda-store-server/tests/_internal/server/views/test_api.py index 188006b12..0fd04984e 100644 --- a/conda-store-server/tests/_internal/server/views/test_api.py +++ b/conda-store-server/tests/_internal/server/views/test_api.py @@ -1078,6 +1078,7 @@ def test_default_conda_store_dir(): [ "asc", "desc", + None, # If none is specified, results will be sorted ascending ], ) @pytest.mark.parametrize( @@ -1089,7 +1090,7 @@ def test_default_conda_store_dir(): ("namespace,name", lambda x: (x.namespace.name, x.name, x.id)), ], ) -def test_api_list_environments( +def test_api_list_environments_paginate_order_by( conda_store_server, testclient, seed_conda_store_big, @@ -1098,39 +1099,61 @@ def test_api_list_environments( sort_by_param, attr_func, ): - """Test the REST API lists the paginated envs when sorting by name.""" - response = testclient.get( - f"api/v1/environment/?sort_by={sort_by_param}&order={order}" - ) - response.raise_for_status() + """Test the REST API lists the paginated envs when given sort_by query parameters.""" + limit = 10 + nfetches = 0 + envs = [] + + order_param = "" if order is None else f"&order={order}" + cursor = None + cursor_param = "" + while cursor is None or cursor != Cursor.end(): + response = testclient.get( + f"api/v1/environment/?limit={limit}&sort_by={sort_by_param}{order_param}{cursor_param}" + ) + response.raise_for_status() + + model = schema.APIListEnvironment.model_validate(response.json()) + assert model.status == schema.APIStatus.OK + + envs.extend(model.data) + + # Get the next cursor and the next query parameters + cursor = Cursor.load(model.cursor) + cursor_param = f"&cursor={model.cursor}" - model = schema.APIListEnvironment.model_validate(response.json()) - assert model.status == schema.APIStatus.OK + nfetches += 1 - # Pull out the attributes that we are sorting on from each environment - envs = [attr_func(env) for env in model.data] + env_attrs = [attr_func(env) for env in envs] + + # Check that the number of results reported by the server corresponds to the number of + # results retrieved + assert model.count == len(envs) + + # Check that number of results requested corresponds to the number of results retrieved. + # Since the last fetch isn't always of length `limit`, we subtract off the remainder + # before checking. + assert len(envs) - (len(envs) % limit) == limit * (nfetches - 1) # The environments should already be sorted; check that this is the case - assert sorted(envs, reverse=(order == "desc")) == envs + assert sorted(env_attrs, reverse=(order == "desc")) == env_attrs -def test_api_list_environments_no_qparam( +def test_api_list_environments_paginate( conda_store_server, testclient, seed_conda_store_big, authenticate, ): """Test the REST API lists the envs by id when no query params are specified.""" - response = testclient.get("api/v1/environment/?limit=10") - response.raise_for_status() - - model = schema.APIListEnvironment.model_validate(response.json()) - assert model.status == schema.APIStatus.OK - - envs = model.data - - while Cursor.load(model.cursor).last_id is not None: - response = testclient.get(f"api/v1/environment/?limit=10&cursor={model.cursor}") + limit = 10 + nfetches = 0 + envs = [] + + cursor = None + cursor_param = "" + while cursor is None or cursor != Cursor.end(): + response = testclient.get(f"api/v1/environment/?limit={limit}{cursor_param}") response.raise_for_status() model = schema.APIListEnvironment.model_validate(response.json()) @@ -1138,7 +1161,22 @@ def test_api_list_environments_no_qparam( envs.extend(model.data) - env_ids = [env.id for env in model.data] + # Get the next cursor and the next query parameters + cursor = Cursor.load(model.cursor) + cursor_param = f"&cursor={model.cursor}" + + nfetches += 1 + + env_ids = [env.id for env in envs] + + # Check that the number of results reported by the server corresponds to the number of + # results retrieved + assert model.count == len(env_ids) + + # Check that number of results requested corresponds to the number of results retrieved. + # Since the last fetch isn't always of length `limit`, we subtract off the remainder + # before checking. + assert len(env_ids) - (len(env_ids) % limit) == limit * (nfetches - 1) # Check that the environments are sorted by ID assert sorted(env_ids) == env_ids