Skip to content

Commit

Permalink
Cleanup; remove ANN101, ANN102, and PT004 deprecated ruff rules
Browse files Browse the repository at this point in the history
  • Loading branch information
peytondmurray committed Jan 16, 2025
1 parent 1371f2c commit 664d122
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 46 deletions.
2 changes: 1 addition & 1 deletion conda-store-server/conda_store_server/_internal/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ class APICursorPaginatedResponse(BaseModel):
status: APIStatus
message: Optional[str] = None
cursor: Optional[str] = None
count: int
count: int # the total number of results available to fetch


class APIAckResponse(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ async def api_list_environments(
role_bindings=auth.entity_bindings(entity),
)

paginated, next_cursor = paginate(
paginated, next_cursor, count = paginate(
query=query,
ordering_metadata=OrderingMetadata(
order_names=["namespace", "name"],
Expand All @@ -759,7 +759,7 @@ async def api_list_environments(
data=paginated,
status="ok",
cursor=next_cursor.dump(),
count=1000,
count=count,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class Ordering(Enum):

class Cursor(pydantic.BaseModel):
last_id: int | None = 0
count: int | None = 0

# List query parameters to order by, and the last value of the ordered attribute
# {
Expand Down Expand Up @@ -57,7 +56,7 @@ def load(cls, b64_cursor: str | None = None) -> Cursor:
Cursor representation of the b64-encoded string
"""
if b64_cursor is None:
return cls(last_id=None, count=0, last_value=None)
return cls(last_id=None, last_value=None)
return cls.model_validate_json(base64.b64decode(b64_cursor).decode("utf8"))

def get_last_values(self, order_names: list[str]) -> list[Any]:
Expand Down Expand Up @@ -87,11 +86,19 @@ def end(cls) -> Cursor:
Cursor
An empty cursor
"""
return cls(last_id=None, count=0, last_value=None)
return cls(last_id=None, last_value=None)

@classmethod
def begin(cls) -> Cursor:
return cls(last_id=0, count=None, last_value=None)
"""Cursor representing the beginning of a set of paginated results.
Returns
-------
Cursor
A cursor that points at the first result; the count is 0
because this cursor
"""
return cls(last_id=0, last_value=None)


def paginate(
Expand All @@ -101,7 +108,7 @@ def paginate(
order_by: list[str] | None = None,
order: Ordering = Ordering.ASCENDING,
limit: int = 10,
) -> tuple[SqlQuery, Cursor]:
) -> tuple[list[Base], Cursor, int]:
"""Paginate the query using the cursor and the requested sort_bys.
This function assumes that the first column of the query contains
Expand All @@ -124,9 +131,9 @@ def paginate(
Returns
-------
tuple[SqlQuery, Cursor]
Query containing the paginated results, and Cursor for retrieving
the next page
tuple[Base, Cursor, int]
Query containing the paginated results, Cursor for retrieving
the next page, and total number of results
"""
if order_by is None:
order_by = []
Expand All @@ -143,6 +150,9 @@ def paginate(
detail=f"Invalid query parameter: order = {order}; must be one of ['asc', 'desc']",
)

# Fetch the total number of objects in the database before filtering
count = query.count()

# Get the python type of the objects being queried
queried_type = query.column_descriptions[0]["type"]
columns = ordering_metadata.get_requested_columns(order_by)
Expand All @@ -159,23 +169,22 @@ def paginate(
)
)

order_by_args = [order_func(col) for col in columns] + [order_func(queried_type.id)]

query = query.order_by(*order_by_args)
# Order the query by the requested columns, and also by the object's primary key
query = query.order_by(
*([order_func(col) for col in columns] + [order_func(queried_type.id)])
)
data = query.limit(limit).all()
count = query.count()

if count > 0:
if len(data) > 0:
last_result = data[-1]
last_value = ordering_metadata.get_attr_values(last_result, order_by)

next_cursor = Cursor(
last_id=data[-1].id, last_value=last_value, count=query.count()
last_id=last_result.id,
last_value=ordering_metadata.get_attr_values(last_result, order_by),
)
else:
next_cursor = Cursor.end()

return (data, next_cursor)
return (data, next_cursor, count)


class CursorPaginatedArgs(pydantic.BaseModel):
Expand Down
3 changes: 0 additions & 3 deletions conda-store-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,6 @@ ignore = [
"ANN001", # missing-type-function-argument
"ANN002", # missing-type-args
"ANN003", # missing-type-kwargs
"ANN101", # missing-type-self
"ANN102", # missing-type-cls
"ANN201", # missing-return-type-undocumented-public-function
"ANN202", # missing-return-type-private-function
"ANN204", # missing-return-type-special-method
Expand Down Expand Up @@ -175,7 +173,6 @@ ignore = [
"FIX002", # line-contains-todo
"N805", # invalid-first-argument-name-for-method
"N815", # mixed-case-variable-in-class-scope
"PT004", # pytest-missing-fixture-name-underscore
"PT006", # pytest-parametrize-names-wrong-type
"PT011", # pytest-raises-too-broad
"PT012", # pytest-raises-with-multiple-statements
Expand Down
84 changes: 61 additions & 23 deletions conda-store-server/tests/_internal/server/views/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,7 @@ def test_default_conda_store_dir():
[
"asc",
"desc",
None, # If none is specified, results will be sorted ascending
],
)
@pytest.mark.parametrize(
Expand All @@ -1089,7 +1090,7 @@ def test_default_conda_store_dir():
("namespace,name", lambda x: (x.namespace.name, x.name, x.id)),
],
)
def test_api_list_environments(
def test_api_list_environments_paginate_order_by(
conda_store_server,
testclient,
seed_conda_store_big,
Expand All @@ -1098,47 +1099,84 @@ def test_api_list_environments(
sort_by_param,
attr_func,
):
"""Test the REST API lists the paginated envs when sorting by name."""
response = testclient.get(
f"api/v1/environment/?sort_by={sort_by_param}&order={order}"
)
response.raise_for_status()
"""Test the REST API lists the paginated envs when given sort_by query parameters."""
limit = 10
nfetches = 0
envs = []

order_param = "" if order is None else f"&order={order}"
cursor = None
cursor_param = ""
while cursor is None or cursor != Cursor.end():
response = testclient.get(
f"api/v1/environment/?limit={limit}&sort_by={sort_by_param}{order_param}{cursor_param}"
)
response.raise_for_status()

model = schema.APIListEnvironment.model_validate(response.json())
assert model.status == schema.APIStatus.OK

envs.extend(model.data)

# Get the next cursor and the next query parameters
cursor = Cursor.load(model.cursor)
cursor_param = f"&cursor={model.cursor}"

model = schema.APIListEnvironment.model_validate(response.json())
assert model.status == schema.APIStatus.OK
nfetches += 1

# Pull out the attributes that we are sorting on from each environment
envs = [attr_func(env) for env in model.data]
env_attrs = [attr_func(env) for env in envs]

# Check that the number of results reported by the server corresponds to the number of
# results retrieved
assert model.count == len(envs)

# Check that number of results requested corresponds to the number of results retrieved.
# Since the last fetch isn't always of length `limit`, we subtract off the remainder
# before checking.
assert len(envs) - (len(envs) % limit) == limit * (nfetches - 1)

# The environments should already be sorted; check that this is the case
assert sorted(envs, reverse=(order == "desc")) == envs
assert sorted(env_attrs, reverse=(order == "desc")) == env_attrs


def test_api_list_environments_no_qparam(
def test_api_list_environments_paginate(
conda_store_server,
testclient,
seed_conda_store_big,
authenticate,
):
"""Test the REST API lists the envs by id when no query params are specified."""
response = testclient.get("api/v1/environment/?limit=10")
response.raise_for_status()

model = schema.APIListEnvironment.model_validate(response.json())
assert model.status == schema.APIStatus.OK

envs = model.data

while Cursor.load(model.cursor).last_id is not None:
response = testclient.get(f"api/v1/environment/?limit=10&cursor={model.cursor}")
limit = 10
nfetches = 0
envs = []

cursor = None
cursor_param = ""
while cursor is None or cursor != Cursor.end():
response = testclient.get(f"api/v1/environment/?limit={limit}{cursor_param}")
response.raise_for_status()

model = schema.APIListEnvironment.model_validate(response.json())
assert model.status == schema.APIStatus.OK

envs.extend(model.data)

env_ids = [env.id for env in model.data]
# Get the next cursor and the next query parameters
cursor = Cursor.load(model.cursor)
cursor_param = f"&cursor={model.cursor}"

nfetches += 1

env_ids = [env.id for env in envs]

# Check that the number of results reported by the server corresponds to the number of
# results retrieved
assert model.count == len(env_ids)

# Check that number of results requested corresponds to the number of results retrieved.
# Since the last fetch isn't always of length `limit`, we subtract off the remainder
# before checking.
assert len(env_ids) - (len(env_ids) % limit) == limit * (nfetches - 1)

# Check that the environments are sorted by ID
assert sorted(env_ids) == env_ids

0 comments on commit 664d122

Please sign in to comment.