diff --git a/src/dioptra/restapi/db/repository/errors.py b/src/dioptra/restapi/db/repository/errors.py index bd2d2d2c6..29bda2038 100644 --- a/src/dioptra/restapi/db/repository/errors.py +++ b/src/dioptra/restapi/db/repository/errors.py @@ -31,3 +31,17 @@ class UserEmailNotAvailableError(Exception): class QueueAlreadyExistsError(Exception): """The queue name already exists.""" + + +class QueueSortError(Exception): + """The requested sortBy column is not a sortable field.""" + + +class UnsupportedFilterField(Exception): + """A filter field is not supported for a particular repository method""" + + def __init__(self, field_name: str) -> None: + self.field_name = field_name + + message = f"{self.field_name!r} is not a valid field" + super().__init__(message) diff --git a/src/dioptra/restapi/db/repository/groups.py b/src/dioptra/restapi/db/repository/groups.py index 1d01f7c21..e90f6101b 100644 --- a/src/dioptra/restapi/db/repository/groups.py +++ b/src/dioptra/restapi/db/repository/groups.py @@ -17,7 +17,8 @@ """ The group repository: data operations related to groups """ -from typing import Final +from collections.abc import Sequence +from typing import Any, Final import sqlalchemy as sa @@ -33,6 +34,7 @@ assert_group_exists, assert_user_exists, check_user_collision, + construct_sql_query_filters, group_exists, user_exists, ) @@ -52,6 +54,10 @@ class GroupRepository: + SEARCHABLE_FIELDS: Final[dict[str, Any]] = { + "name": lambda x: Group.name.like(x, escape="/"), + } + def __init__(self, session: CompatibleSession[S]): self.session = session @@ -180,6 +186,40 @@ def get_by_name( return group + def get_by_filters_paged( + self, + filters: list[dict], + page_start: int, + page_length: int, + deletion_policy: DeletionPolicy = DeletionPolicy.NOT_DELETED, + ) -> tuple[Sequence[User], int]: + sql_filter = construct_sql_query_filters(filters, self.SEARCHABLE_FIELDS) + + count_stmt = sa.select(sa.func.count()).select_from(Group) + if sql_filter is not None: + count_stmt = count_stmt.where(sql_filter) + count_stmt = _apply_deletion_policy(count_stmt, deletion_policy) + current_count = self.session.scalar(count_stmt) + + # For mypy: a "SELECT count(*)..." query should never return NULL. + assert current_count is not None + + groups: Sequence[Group] + if current_count == 0: + groups = [] + else: + page_stmt = sa.select(Group) + if sql_filter is not None: + page_stmt = page_stmt.where(sql_filter) + page_stmt = _apply_deletion_policy(page_stmt, deletion_policy) + # *must* enforce a sort order for consistent paging + page_stmt = page_stmt.order_by(Group.group_id) + page_stmt = page_stmt.offset(page_start).limit(page_length) + + groups = self.session.scalars(page_stmt).all() + + return groups, current_count + def num_groups( self, deletion_policy: DeletionPolicy = DeletionPolicy.NOT_DELETED ) -> int: diff --git a/src/dioptra/restapi/db/repository/queues.py b/src/dioptra/restapi/db/repository/queues.py index 519ce2f4b..162db6650 100644 --- a/src/dioptra/restapi/db/repository/queues.py +++ b/src/dioptra/restapi/db/repository/queues.py @@ -19,11 +19,12 @@ """ from collections.abc import Iterable, Sequence +from typing import Any, Final import sqlalchemy as sa -from dioptra.restapi.db.models import Group, Queue, Resource -from dioptra.restapi.db.repository.errors import QueueAlreadyExistsError +from dioptra.restapi.db.models import Group, Queue, Resource, Tag +from dioptra.restapi.db.repository.errors import QueueAlreadyExistsError, QueueSortError from dioptra.restapi.db.repository.utils import ( CompatibleSession, DeletionPolicy, @@ -35,6 +36,7 @@ assert_snapshot_does_not_exist, assert_user_exists, assert_user_in_group, + construct_sql_query_filters, delete_resource, get_group_id, get_resource_id, @@ -43,6 +45,21 @@ class QueueRepository: + + SEARCHABLE_FIELDS: Final[dict[str, Any]] = { + "name": lambda x: Queue.name.like(x, escape="/"), + "description": lambda x: Queue.description.like(x, escape="/"), + "tag": lambda x: Queue.tags.any(Tag.name.like(x, escape="/")), + } + + # Maps a general sort criterion name to a Queue attribute name + SORTABLE_FIELDS: Final[dict[str, str]] = { + "name": "name", + "createdOn": "created_on", + "lastModifiedOn": "last_modified_on", + "description": "description", + } + def __init__(self, session: CompatibleSession[S]): self.session = session @@ -261,3 +278,94 @@ def get_by_name( queue = self.session.scalar(stmt) return queue + + def get_by_filters_paged( + self, + group: Group | int | None, + filters: list[dict], + page_start: int, + page_length: int, + sort_by: str | None, + descending: bool, + deletion_policy: DeletionPolicy = DeletionPolicy.NOT_DELETED, + ) -> tuple[Sequence[Queue], int]: + """ + Get a page of queues according to more complex criteria. + + Args: + group: Limit queues to those owned by this group; None to not limit + the search + filters: Search criteria, see parse_search_text() + page_start: Zero-based row index where the page should start + page_length: Maximum number of rows in the page + sort_by: Sort criterion; must be a key of SORTABLE_FIELDS. None + to sort in an implementation-dependent way. + descending: Whether to sort in descending order; only applicable + if sort_by is given + deletion_policy: Whether to look at deleted queues, non-deleted + queue, or all queues + + Returns: + A 2-tuple including the page of queues and total count of matching + queues which exist + """ + sql_filter = construct_sql_query_filters(filters, self.SEARCHABLE_FIELDS) + if sort_by: + sort_by = self.SORTABLE_FIELDS.get(sort_by) + if not sort_by: + raise QueueSortError + group_id = None if group is None else get_group_id(group) + + if group_id is not None: + assert_group_exists(self.session, group_id, DeletionPolicy.NOT_DELETED) + + count_stmt = ( + sa.select(sa.func.count()) + .select_from(Queue, Resource) + .where(Queue.resource_snapshot_id == Resource.latest_snapshot_id) + ) + + if group_id is not None: + count_stmt = count_stmt.where(Resource.group_id == group_id) + + if sql_filter is not None: + count_stmt = count_stmt.where(sql_filter) + + count_stmt = apply_resource_deletion_policy(count_stmt, deletion_policy) + current_count = self.session.scalar(count_stmt) + + # For mypy: a "SELECT count(*)..." query should never return NULL. + assert current_count is not None + + queues: Sequence[Queue] + if current_count == 0: + queues = [] + else: + page_stmt = ( + sa.select(Queue) + .join(Resource) + .where(Queue.resource_snapshot_id == Resource.latest_snapshot_id) + ) + + if group_id is not None: + page_stmt = page_stmt.where(Resource.group_id == group_id) + + if sql_filter is not None: + page_stmt = page_stmt.where(sql_filter) + + page_stmt = apply_resource_deletion_policy(page_stmt, deletion_policy) + + if sort_by: + sort_criteria = getattr(Queue, sort_by) + if descending: + sort_criteria = sort_criteria.desc() + else: + # *must* enforce a sort order for consistent paging + sort_criteria = Queue.resource_snapshot_id + page_stmt = page_stmt.order_by(sort_criteria) + + page_stmt = page_stmt.offset(page_start).limit(page_length) + + queues = self.session.scalars(page_stmt).all() + + return queues, current_count diff --git a/src/dioptra/restapi/db/repository/users.py b/src/dioptra/restapi/db/repository/users.py index dac5705f0..399010f1d 100644 --- a/src/dioptra/restapi/db/repository/users.py +++ b/src/dioptra/restapi/db/repository/users.py @@ -19,6 +19,7 @@ """ import uuid from collections.abc import Sequence +from typing import Any, Final import sqlalchemy as sa @@ -33,12 +34,18 @@ assert_user_does_not_exist, assert_user_exists, check_user_collision, + construct_sql_query_filters, user_exists, ) class UserRepository: + SEARCHABLE_FIELDS: Final[dict[str, Any]] = { + "username": lambda x: User.username.like(x, escape="/"), + "email": lambda x: User.email_address.like(x, escape="/"), + } + def __init__(self, session: CompatibleSession[S]): self.session = session @@ -208,6 +215,55 @@ def get_by_email( return user + def get_by_filters_paged( + self, + filters: list[dict], + page_start: int, + page_length: int, + deletion_policy: DeletionPolicy = DeletionPolicy.NOT_DELETED, + ) -> tuple[Sequence[User], int]: + """ + Get some users according to search criteria. + + Args: + filters: A structure representing search criteria. See + parse_search_text(). + page_start: A row index where the returned page should start + page_length: A row count representing the page length + deletion_policy: Whether to look at deleted users, non-deleted + users, or all users + + Returns: + A 2-tuple including a page of User objects, and a count of the + total number of users matching the criteria + """ + sql_filter = construct_sql_query_filters(filters, self.SEARCHABLE_FIELDS) + + count_stmt = sa.select(sa.func.count()).select_from(User) + if sql_filter is not None: + count_stmt = count_stmt.where(sql_filter) + count_stmt = _apply_deletion_policy(count_stmt, deletion_policy) + current_count = self.session.scalar(count_stmt) + + # For mypy: a "SELECT count(*)..." query should never return NULL. + assert current_count is not None + + users: Sequence[User] + if current_count == 0: + users = [] + else: + page_stmt = sa.select(User) + if sql_filter is not None: + page_stmt = page_stmt.where(sql_filter) + page_stmt = _apply_deletion_policy(page_stmt, deletion_policy) + # *must* enforce a sort order for consistent paging + page_stmt = page_stmt.order_by(User.user_id) + page_stmt = page_stmt.offset(page_start).limit(page_length) + + users = self.session.scalars(page_stmt).all() + + return users, current_count + def num_users( self, deletion_policy: DeletionPolicy = DeletionPolicy.NOT_DELETED ) -> int: diff --git a/src/dioptra/restapi/db/repository/utils.py b/src/dioptra/restapi/db/repository/utils.py index b7ad9e0bd..6ef2fd76f 100644 --- a/src/dioptra/restapi/db/repository/utils.py +++ b/src/dioptra/restapi/db/repository/utils.py @@ -16,8 +16,10 @@ # https://creativecommons.org/licenses/by/4.0/legalcode import enum import typing +from collections.abc import Callable, Iterable import sqlalchemy as sa +import sqlalchemy.sql.expression as sae from sqlalchemy.orm import Session, aliased, scoped_session from dioptra.restapi.db.models import ( @@ -36,6 +38,7 @@ user_lock_types, ) from dioptra.restapi.db.repository.errors import ( + UnsupportedFilterField, UserEmailNotAvailableError, UsernameNotAvailableError, ) @@ -51,6 +54,9 @@ S = typing.TypeVar("S", bound=Session) CompatibleSession = Session | scoped_session[S] +# Type alias for search field callbacks +SearchFieldCallback = Callable[[str], sae.ColumnElement[bool]] + class ExistenceResult(enum.Enum): """ @@ -767,3 +773,96 @@ def delete_resource( session.add(lock) # else: exists_result is DELETED; nothing to do. + + +def _construct_sql_search_value(search_term: str) -> str: + """ + Constructs a search value for a SQL query by replacing wildcards, + escaping and un-escaping. The escape character is assumed to be "/". + + Args: + search_value: A search term + + Returns: + A string to be used as the value in the WHERE clause of a SQL query. + """ + if search_term == "*": + search_term = "%" + elif search_term == "?": + search_term = "_" + else: + search_term = search_term.replace("/", r"//") + search_term = search_term.replace("%", r"/%") + search_term = search_term.replace("_", r"/_") + search_term = search_term.replace(r"\\", "\\") + search_term = search_term.replace(r"\*", "*") + search_term = search_term.replace(r"\?", "?") + search_term = search_term.replace(r"\"", '"') + search_term = search_term.replace(r"\'", "'") + search_term = search_term.replace(r"\n", "\n") + + return search_term + + +def construct_sql_query_filters( + parsed_search_terms: list[dict], searchable_fields: dict[str, SearchFieldCallback] +) -> sae.ColumnElement[bool] | None: + """ + Constructs a search filter to be used by sqlalchemy. + + Args: + parsed_search_terms: A data structure describing a search; see + parse_search_text() + searchable_fields: A dict which maps from a search field name to a + function of one argument which transforms a query string to an + SQLAlchemy expression usable in the WHERE clause of a SELECT + statement, i.e. to filter table rows. The query string will be + an SQL "LIKE" pattern. + + Returns: + A filter that can be used in a sqlalchemy query, or None if no search + terms were given. + + Raises: + SearchParseError: If a search string cannot be parsed. + """ + filter_fns: Iterable[SearchFieldCallback] + + query_filters = [] + for search_term in parsed_search_terms: + field = search_term["field"] + values = search_term["value"] + + sql_search_values = (_construct_sql_search_value(value) for value in values) + + if field is None: + # if no field, create a "fuzzier" combined search pattern + combined_search_value = "%" + "%".join(sql_search_values) + "%" + filter_fns = searchable_fields.values() + else: + combined_search_value = "".join(sql_search_values) + filter_fn = searchable_fields.get(field) + if filter_fn: + filter_fns = (filter_fn,) + else: + raise UnsupportedFilterField(field) + + search_exprs = [filter_fn(combined_search_value) for filter_fn in filter_fns] + + if len(search_exprs) == 1: + # avoid useless 1-arg OR + combined_search_expr = search_exprs[0] + else: + combined_search_expr = sa.or_(*search_exprs) + + query_filters.append(combined_search_expr) + + if not query_filters: + result = None + elif len(query_filters) == 1: + # avoid useless 1-arg AND + result = query_filters[0] + else: + result = sa.and_(*query_filters) + + return result diff --git a/src/dioptra/restapi/v1/groups/service.py b/src/dioptra/restapi/v1/groups/service.py index ac13c916c..c673e6f72 100644 --- a/src/dioptra/restapi/v1/groups/service.py +++ b/src/dioptra/restapi/v1/groups/service.py @@ -18,18 +18,16 @@ from __future__ import annotations import datetime -from typing import Any, Final, cast +from typing import Any, Final import structlog from injector import inject -from sqlalchemy import func, select from structlog.stdlib import BoundLogger -from dioptra.restapi.db import db, models +from dioptra.restapi.db import models from dioptra.restapi.db.repository.utils import DeletionPolicy from dioptra.restapi.db.unit_of_work import UnitOfWork -from dioptra.restapi.errors import BackendDatabaseError -from dioptra.restapi.v1.shared.search_parser import construct_sql_query_filters +from dioptra.restapi.v1.shared.search_parser import parse_search_text from .errors import GroupDoesNotExistError, GroupNameNotAvailableError @@ -45,9 +43,6 @@ "owner": False, "admin": True, } -SEARCHABLE_FIELDS: Final[dict[str, Any]] = { - "name": lambda x: models.Group.name.like(x, escape="/"), -} class GroupService(object): @@ -130,36 +125,12 @@ def get( log: BoundLogger = kwargs.get("log", LOGGER.new()) log.debug("Get list of groups") - search_filters = construct_sql_query_filters(search_string, SEARCHABLE_FIELDS) - - stmt = ( - select(func.count(models.Group.group_id)) - .filter_by(is_deleted=False) - .filter(search_filters) - ) - total_num_groups = db.session.scalars(stmt).first() - - if total_num_groups is None: - log.error( - "The database query returned a None when counting the number of " - "groups when it should return a number.", - sql=str(stmt), - ) - raise BackendDatabaseError - - if total_num_groups == 0: - return cast(list[models.Group], []), total_num_groups - - stmt = ( - select(models.Group) # type: ignore - .filter_by(is_deleted=False) - .filter(search_filters) - .offset(page_index) - .limit(page_length) + search_struct = parse_search_text(search_string) + groups, total_num_groups = self._uow.group_repo.get_by_filters_paged( + search_struct, page_index * page_length, page_length ) - groups = cast(list[models.Group], db.session.scalars(stmt).all()) - return groups, total_num_groups + return list(groups), total_num_groups class GroupIdService(object): diff --git a/src/dioptra/restapi/v1/queues/controller.py b/src/dioptra/restapi/v1/queues/controller.py index 2540370a5..122266d88 100644 --- a/src/dioptra/restapi/v1/queues/controller.py +++ b/src/dioptra/restapi/v1/queues/controller.py @@ -28,6 +28,7 @@ from structlog.stdlib import BoundLogger from dioptra.restapi.db import models +from dioptra.restapi.db.repository.queues import QueueRepository from dioptra.restapi.routes import V1_QUEUES_ROUTE from dioptra.restapi.v1 import utils from dioptra.restapi.v1.schemas import IdStatusResponseSchema @@ -51,7 +52,7 @@ QueuePageSchema, QueueSchema, ) -from .service import RESOURCE_TYPE, SEARCHABLE_FIELDS, QueueIdService, QueueService +from .service import RESOURCE_TYPE, QueueIdService, QueueService LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -212,7 +213,7 @@ def put(self, id: int): resource_model=models.Queue, resource_name=RESOURCE_TYPE, route_prefix=V1_QUEUES_ROUTE, - searchable_fields=SEARCHABLE_FIELDS, + searchable_fields=QueueRepository.SEARCHABLE_FIELDS, page_schema=QueuePageSchema, build_fn=utils.build_queue, ) diff --git a/src/dioptra/restapi/v1/queues/errors.py b/src/dioptra/restapi/v1/queues/errors.py index 01d0e80cc..610c93f82 100644 --- a/src/dioptra/restapi/v1/queues/errors.py +++ b/src/dioptra/restapi/v1/queues/errors.py @@ -19,7 +19,7 @@ from flask_restx import Api -from dioptra.restapi.db.repository.errors import QueueAlreadyExistsError +from dioptra.restapi.db.repository.errors import QueueAlreadyExistsError, QueueSortError class QueueDoesNotExistError(Exception): @@ -30,10 +30,6 @@ class QueueLockedError(Exception): """The requested queue is locked.""" -class QueueSortError(Exception): - """The requested sortBy column is not a sortable field.""" - - def register_error_handlers(api: Api) -> None: @api.errorhandler(QueueDoesNotExistError) def handle_queue_does_not_exist_error(error): diff --git a/src/dioptra/restapi/v1/queues/service.py b/src/dioptra/restapi/v1/queues/service.py index 1ff325d54..5b4138a38 100644 --- a/src/dioptra/restapi/v1/queues/service.py +++ b/src/dioptra/restapi/v1/queues/service.py @@ -23,33 +23,21 @@ import structlog from flask_login import current_user from injector import inject -from sqlalchemy import Integer, func, select +from sqlalchemy import Integer, select from structlog.stdlib import BoundLogger from dioptra.restapi.db import db, models from dioptra.restapi.db.repository.utils import DeletionPolicy from dioptra.restapi.db.shared_errors import ResourceDeletedError, ResourceNotFoundError from dioptra.restapi.db.unit_of_work import UnitOfWork -from dioptra.restapi.errors import BackendDatabaseError from dioptra.restapi.v1 import groups, utils -from dioptra.restapi.v1.shared.search_parser import construct_sql_query_filters +from dioptra.restapi.v1.shared.search_parser import parse_search_text -from .errors import QueueDoesNotExistError, QueueSortError +from .errors import QueueDoesNotExistError LOGGER: BoundLogger = structlog.stdlib.get_logger() RESOURCE_TYPE: Final[str] = "queue" -SEARCHABLE_FIELDS: Final[dict[str, Any]] = { - "name": lambda x: models.Queue.name.like(x, escape="/"), - "description": lambda x: models.Queue.description.like(x, escape="/"), - "tag": lambda x: models.Queue.tags.any(models.Tag.name.like(x, escape="/")), -} -SORTABLE_FIELDS: Final[dict[str, Any]] = { - "name": models.Queue.name, - "createdOn": models.Queue.created_on, - "lastModifiedOn": models.Resource.last_modified_on, - "description": models.Queue.description, -} class QueueService(object): @@ -149,62 +137,17 @@ def get( log: BoundLogger = kwargs.get("log", LOGGER.new()) log.debug("Get full list of queues") - filters = list() + search_struct = parse_search_text(search_string) - if group_id is not None: - filters.append(models.Resource.group_id == group_id) - - if search_string: - filters.append( - construct_sql_query_filters(search_string, SEARCHABLE_FIELDS) - ) - - stmt = ( - select(func.count(models.Queue.resource_id)) - .join(models.Resource) - .where( - *filters, - models.Resource.is_deleted == False, # noqa: E712 - models.Resource.latest_snapshot_id == models.Queue.resource_snapshot_id, - ) + queues, total_num_queues = self._uow.queue_repo.get_by_filters_paged( + group_id, + search_struct, + page_index * page_length, + page_length, + sort_by_string, + descending, + DeletionPolicy.NOT_DELETED, ) - total_num_queues = db.session.scalars(stmt).first() - - if total_num_queues is None: - log.error( - "The database query returned a None when counting the number of " - "queues when it should return a number.", - sql=str(stmt), - ) - raise BackendDatabaseError - - if total_num_queues == 0: - return [], total_num_queues - - queues_stmt = ( - select(models.Queue) - .join(models.Resource) - .where( - *filters, - models.Resource.is_deleted == False, # noqa: E712 - models.Resource.latest_snapshot_id == models.Queue.resource_snapshot_id, - ) - .offset(page_index) - .limit(page_length) - ) - - if sort_by_string and sort_by_string in SORTABLE_FIELDS: - sort_column = SORTABLE_FIELDS[sort_by_string] - if descending: - sort_column = sort_column.desc() - else: - sort_column = sort_column.asc() - queues_stmt = queues_stmt.order_by(sort_column) - elif sort_by_string and sort_by_string not in SORTABLE_FIELDS: - log.debug(f"sort_by_string: '{sort_by_string}' is not in SORTABLE_FIELDS") - raise QueueSortError - - queues = list(db.session.scalars(queues_stmt).all()) drafts_stmt = select( models.DraftResource.payload["resource_id"].as_string().cast(Integer) diff --git a/src/dioptra/restapi/v1/shared/search_parser.py b/src/dioptra/restapi/v1/shared/search_parser.py index c737d8701..99536b9e9 100644 --- a/src/dioptra/restapi/v1/shared/search_parser.py +++ b/src/dioptra/restapi/v1/shared/search_parser.py @@ -108,15 +108,20 @@ def parse_search_text(search_text: str) -> list[dict]: field. The value is a list of strings that represent the search value. """ - parsed_search = DIOPTRA_QUERY_GRAMMAR.parse_string( - search_text, parse_all=True - ).as_list() - formatted_result = [] - for term in parsed_search: - if len(term) > 1 and isinstance(term[1], list): - formatted_result.append({"field": term[0], "value": term[1]}) - else: - formatted_result.append({"field": None, "value": term}) + formatted_result: list[dict] + if not search_text: + formatted_result = [] + + else: + parsed_search = DIOPTRA_QUERY_GRAMMAR.parse_string( + search_text, parse_all=True + ).as_list() + formatted_result = [] + for term in parsed_search: + if len(term) > 1 and isinstance(term[1], list): + formatted_result.append({"field": term[0], "value": term[1]}) + else: + formatted_result.append({"field": None, "value": term}) return formatted_result diff --git a/src/dioptra/restapi/v1/users/service.py b/src/dioptra/restapi/v1/users/service.py index 0454b5b28..57993d0dc 100644 --- a/src/dioptra/restapi/v1/users/service.py +++ b/src/dioptra/restapi/v1/users/service.py @@ -24,19 +24,17 @@ import structlog from flask_login import current_user from injector import inject -from sqlalchemy import func, select from structlog.stdlib import BoundLogger -from dioptra.restapi.db import db, models +from dioptra.restapi.db import models from dioptra.restapi.db.repository.utils import DeletionPolicy from dioptra.restapi.db.unit_of_work import UnitOfWork -from dioptra.restapi.errors import BackendDatabaseError from dioptra.restapi.v1.groups.service import GroupMemberService from dioptra.restapi.v1.plugin_parameter_types.service import ( BuiltinPluginParameterTypeService, ) from dioptra.restapi.v1.shared.password_service import PasswordService -from dioptra.restapi.v1.shared.search_parser import construct_sql_query_filters +from dioptra.restapi.v1.shared.search_parser import parse_search_text from .errors import ( NoCurrentUserError, @@ -58,10 +56,6 @@ "share_write": False, } DAYS_TO_EXPIRE_PASSWORD_DEFAULT: Final[int] = 365 -SEARCHABLE_FIELDS: Final[dict[str, Any]] = { - "username": lambda x: models.User.username.like(x, escape="/"), - "email": lambda x: models.User.email_address.like(x, escape="/"), -} class UserService(object): @@ -162,7 +156,7 @@ def get( Args: search_string: A search string used to filter results. - page_index: The index of the first user to be returned. + page_index: The index of the first page to be returned. page_length: The maximum number of users to be returned. Returns: @@ -176,36 +170,12 @@ def get( log: BoundLogger = kwargs.get("log", LOGGER.new()) log.debug("Get list of users") - search_filters = construct_sql_query_filters(search_string, SEARCHABLE_FIELDS) - - stmt = ( - select(func.count(models.User.user_id)) - .filter_by(is_deleted=False) - .filter(search_filters) - ) - total_num_users = db.session.scalars(stmt).first() - - if total_num_users is None: - log.error( - "The database query returned a None when counting the number of " - "users when it should return a number.", - sql=str(stmt), - ) - raise BackendDatabaseError - - if total_num_users == 0: - return cast(list[models.User], []), total_num_users - - stmt = ( - select(models.User) # type: ignore - .filter_by(is_deleted=False) - .filter(search_filters) - .offset(page_index) - .limit(page_length) + search_struct = parse_search_text(search_string) + users, total_num_users = self._uow.user_repo.get_by_filters_paged( + search_struct, page_index * page_length, page_length ) - users = cast(list[models.User], db.session.scalars(stmt).all()) - return users, total_num_users + return list(users), total_num_users def _create_or_get_default_group( self, diff --git a/tests/unit/restapi/v1/test_queue.py b/tests/unit/restapi/v1/test_queue.py index fff422f18..b7189032c 100644 --- a/tests/unit/restapi/v1/test_queue.py +++ b/tests/unit/restapi/v1/test_queue.py @@ -207,7 +207,11 @@ def assert_retrieving_queues_works( query_string=query_string, follow_redirects=True, ) - assert response.status_code == 200 and response.get_json()["data"] == expected + # A sort order was not given in the request, so we must not assume a + # particular order in the response. + expected = sorted(expected, key=lambda d: d["id"]) + resp_data = sorted(response.get_json()["data"], key=lambda d: d["id"]) + assert response.status_code == 200 and resp_data == expected def assert_sorting_queue_works(