From 0de0293ce26f91d389586f0d10b83ab01e32da4c Mon Sep 17 00:00:00 2001 From: "James K. Glasbrenner" Date: Fri, 4 Oct 2024 11:18:04 -0400 Subject: [PATCH] feat: implement dioptra v1 rest api python client examples: begin creating demo of v1client feat: update test_client for users and auth feat: update client demo feat: correct location of demo client feat: update logging and add queues for dioptra v1 client feat: correct some comments feat: update demo of v1 client to include plugins feat: change id parameter names to be more descriptive and not conflict with python builtin id feat: ensure testing of more than one parameter per plugin task feat: put client in correct place and consolidate tags feature for most endpoints feat: fix tag usage examples in notebook; fix lint in client chore: lint fixes on client chore: fix linting issues feat: fix tags client class, move address definition, default Session chore: linting errors feat: update client feat: define a DioptraSession class and begin refactoring to use it refactor: work-in-progress to consolidate and streamline client refactor: add boilerplate comment refactor: update the dioptra client to have proper session management This is a work-in-progress. In addition to the session manager, the endpoint clients are being broken out into separate files to avoid having one very large file. Proper type annotations are being added, and the logic for building the endpoint urls has been consolidated so that there's less url building and passing in the code in general. chore: add stubs for the other endpoint client files test: implement DioptraFlaskClientSession for testing purposes This update creates the DioptraFlaskClientSession class that serves that adapts the FlaskClient and TestResponse classes to work within DioptraClient. This allows the DioptraClient to be directly exercised in unit tests. test: refactor users endpoint tests to use the DioptraClient refactor: remove __init__ from DioptraSession base class The __init__ should be customizable to the particular requests library. In addition, the url property is now an abstractmethod and must be defined in every inheriting class. Finally, the session property in DioptraRequestsSession has been demoted to a private method since there is no special reason the requests Session class should be accessible to the user as a public-facing attribute. refactor: add ClassVar type annotations fix: move abstractmethod below property decorator, set None defaults feat: create DioptraClientError base class for client exceptions feat: create a fields validation error exception for the client test: set V1_ROOT as the DioptraFlaskClientSession url property feat: create a SubEndpoint base class for the client feat: implement the drafts subendpoint component for the client feat: implement the tags subendpoint component for the client feat: implement the queues endpoint support in the client test: convert first unit test for queues to using new client fix: fix url construction for drafts with existing resources fix: add missing group_id argument in queues client test: migrate common drafts assert statements to new dioptra client test: migrate queue testing code to using new dioptra client test: update docstrings refactor: fix generic typevars and consolidate sessions code test: refactor test session client to consolidate code and use generic types test: change pytest.raises assertions back to status code checks refactor: remove unused debug logging functions refactor: hide client method kwargs from logger Secrets can end up here, so its not a good idea to just show everything in the logger messages. refactor: wrap_request_method is a better name than wrap_response test: wrap_request_method is a better name than wrap_response feat: add factory functions for constructing response and json dioptra clients refactor: move api address construction into client factory functions refactor: remove None options from modify_current_user in client refactor: move validate_ids_argument to bottom of file and prefix with underscore refactor: rename set to append in tags subclient test: update to use append instead of set name docs: add docstrings to drafts, queues, tags, users in client feat: create snapshots subendpoint and attach to queues client test: exercise queues snapshots subendpoint in unit tests chore: run code formatter refactor: move 2xx status code check into separate function Expand the check to the full 200 status code range. refactor: remove None option from address kwarg docs: add docstrings to all existing DioptraClient related classes refactor: reorder properties in DioptraResquestProtocol test: remove unused imports test: add docstrings to the dioptra test client and responses test: remove unused imports test: fix whitespace chore: run black and isort chore: run black and isort test: run black and isort test: run black and isort chore: satisfy flake8 feat: implement the tags endpoint client test: update tags unit tests to use the dioptra client refactor: append "Client" suffix to base endpoint and subendpoint classes refactor: append "Client" to all subendpoint classes test: propagate new subendpoint class names feat: declare a SubEndpointChildrenClient interface feat: begin implementing the SubEndpointChildren version of Tags refactor: rename build_child_url to build_resource_url chore: run black and isort docs: add docstrings refactor: moved to collection/sub collection naming scheme for client This refactor standardizes the naming scheme for the different endpoint clients on the terms collection and sub collection and updates the logic for the sub-collections to better handle arbitrary nesting of sub-collections. This follows the concepts outlined in this page: https://restful-api-design.readthedocs.io/en/latest/resources.html Co-authored-by: jtsextonMITRE <45762017+jtsextonMITRE@users.noreply.github.com> --- src/dioptra/client/__init__.py | 14 +- src/dioptra/client/_client.py | 1653 ++++++++++------- src/dioptra/client/artifacts.py | 16 + src/dioptra/client/auth.py | 63 + src/dioptra/client/base.py | 437 +++++ src/dioptra/client/client.py | 210 +++ src/dioptra/client/drafts.py | 337 ++++ src/dioptra/client/entrypoints.py | 16 + src/dioptra/client/experiments.py | 16 + src/dioptra/client/groups.py | 16 + src/dioptra/client/jobs.py | 16 + src/dioptra/client/models.py | 16 + src/dioptra/client/plugin_parameter_types.py | 16 + src/dioptra/client/plugins.py | 16 + src/dioptra/client/queues.py | 198 ++ src/dioptra/client/sessions.py | 587 ++++++ src/dioptra/client/snapshots.py | 85 + src/dioptra/client/tags.py | 293 +++ src/dioptra/client/users.py | 183 ++ tests/unit/restapi/conftest.py | 8 + tests/unit/restapi/lib/asserts_client.py | 196 ++ tests/unit/restapi/lib/client.py | 320 ++++ tests/unit/restapi/v1/conftest.py | 1 - tests/unit/restapi/v1/test_experiment.py | 1 - tests/unit/restapi/v1/test_group.py | 1 + .../restapi/v1/test_plugin_parameter_type.py | 30 +- tests/unit/restapi/v1/test_queue.py | 398 ++-- tests/unit/restapi/v1/test_tag.py | 203 +- tests/unit/restapi/v1/test_user.py | 324 ++-- 29 files changed, 4415 insertions(+), 1255 deletions(-) create mode 100644 src/dioptra/client/artifacts.py create mode 100644 src/dioptra/client/auth.py create mode 100644 src/dioptra/client/base.py create mode 100644 src/dioptra/client/client.py create mode 100644 src/dioptra/client/drafts.py create mode 100644 src/dioptra/client/entrypoints.py create mode 100644 src/dioptra/client/experiments.py create mode 100644 src/dioptra/client/groups.py create mode 100644 src/dioptra/client/jobs.py create mode 100644 src/dioptra/client/models.py create mode 100644 src/dioptra/client/plugin_parameter_types.py create mode 100644 src/dioptra/client/plugins.py create mode 100644 src/dioptra/client/queues.py create mode 100644 src/dioptra/client/sessions.py create mode 100644 src/dioptra/client/snapshots.py create mode 100644 src/dioptra/client/tags.py create mode 100644 src/dioptra/client/users.py create mode 100644 tests/unit/restapi/lib/asserts_client.py create mode 100644 tests/unit/restapi/lib/client.py diff --git a/src/dioptra/client/__init__.py b/src/dioptra/client/__init__.py index 9c9274cfe..0da6a6511 100644 --- a/src/dioptra/client/__init__.py +++ b/src/dioptra/client/__init__.py @@ -14,8 +14,14 @@ # # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode -from __future__ import annotations +from .client import ( + DioptraClient, + connect_json_dioptra_client, + connect_response_dioptra_client, +) -from ._client import DioptraClient - -__all__ = ["DioptraClient"] +__all__ = [ + "connect_response_dioptra_client", + "connect_json_dioptra_client", + "DioptraClient", +] diff --git a/src/dioptra/client/_client.py b/src/dioptra/client/_client.py index 176ea3ae9..f4350effe 100644 --- a/src/dioptra/client/_client.py +++ b/src/dioptra/client/_client.py @@ -14,31 +14,44 @@ # # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode -from __future__ import annotations - import os -from pathlib import Path +from abc import ABC, abstractmethod from posixpath import join as urljoin -from typing import Any, cast from urllib.parse import urlparse, urlunparse import requests +import structlog +from structlog.stdlib import BoundLogger +LOGGER: BoundLogger = structlog.stdlib.get_logger() -class DioptraClient(object): - """Connects to the Dioptra REST api, and provides access to endpoints. - Args: - address: Address of the Dioptra REST api or if no address is given the - DIOPTRA_RESTAPI_URI environment variable is used. - api_version: The version of the Dioptra REST API to use. Defaults to "v0". +class APIConnectionError(Exception): + """Class for connection errors""" + + +class StatusCodeError(Exception): + """Class for status code errors""" + + +class JSONDecodeError(Exception): + """Class for JSON decode errors""" + - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html for - more information on Dioptra's REST api. - """ +def create_data_dict(**kwargs): + return kwargs - def __init__(self, address: str | None = None, api_version: str = "v0") -> None: + +def debug_request(url, method, data=None): + LOGGER.debug("Request made.", url=url, method=method, data=data) + + +def debug_response(json): + LOGGER.debug("Response received.", json=json) + + +class DioptraSession(ABC): + def __init__(self, address=None, api_version="v1"): address = ( f"{address}/api/{api_version}" if address @@ -46,701 +59,1041 @@ def __init__(self, address: str | None = None, api_version: str = "v0") -> None: ) self._scheme, self._netloc, self._path, _, _, _ = urlparse(address) + @abstractmethod + def get_session(self): + raise NotImplementedError + + @abstractmethod + def make_request(self, method_name, endpoint, data, *features): + raise NotImplementedError + + @abstractmethod + def handle_error(self, url, method, data, response, error): + raise NotImplementedError + + def get(self, endpoint, *features): + debug_request(urljoin(endpoint, *features), "GET") + return self.make_request("get", endpoint, None, *features) + + def post(self, endpoint, data, *features): + debug_request(urljoin(endpoint, *features), "POST", data) + return self.make_request("post", endpoint, data, *features) + + def delete(self, endpoint, data, *features): + debug_request(urljoin(endpoint, *features), "DELETE", data) + return self.make_request("delete", endpoint, data, *features) + + def put(self, endpoint, data, *features): + debug_request(urljoin(endpoint, *features), "PUT", data) + return self.make_request("put", endpoint, data, *features) + + +class DioptraRequestsSession(DioptraSession): + def __init__(self, address=None, api_version="v1"): + super().__init__(address=address, api_version=api_version) + self._session = None + + def get_session(self): + if self._session is None: + self._session = requests.Session() + return self._session + + def make_request(self, method_name, endpoint, data, *features): + session = self.get_session() + url = urljoin(endpoint, *features) + method = getattr(session, method_name) + try: + if data: + response = method(url, json=data) + else: + response = method(url) + if response.status_code != 200: + raise StatusCodeError() + json = response.json() + except ( + requests.ConnectionError, + StatusCodeError, + requests.JSONDecodeError, + ) as e: + self.handle_error(session, url, method_name.upper(), data, response, e) + debug_response(json=json) + return json + + def handle_error(self, url, method, data, response, error): + if type(error) is requests.ConnectionError: + restapi = os.environ["DIOPTRA_RESTAPI_URI"] + message = ( + "Could not connect to the REST API. Is the server running at " + f"{restapi}?" + ) + LOGGER.error( + message, url=url, method=method, data=data, response=response.text + ) + raise APIConnectionError(message) + if type(error) is StatusCodeError: + message = f"Error code {response.status_code} returned." + LOGGER.error( + message, url=url, method=method, data=data, response=response.text + ) + raise StatusCodeError(message) + if type(error) is requests.JSONDecodeError: + message = "JSON response could not be decoded." + LOGGER.error( + message, url=url, method=method, data=data, response=response.text + ) + raise JSONDecodeError(message) + + +class DioptraClient(object): + def __init__(self, session): + self._session = session + self._users = UsersClient(session) + self._auth = AuthClient(session) + self._queues = QueuesClient(session) + self._groups = GroupsClient(session) + self._tags = TagsClient(session) + self._plugins = PluginsClient(session) + self._pluginParameterTypes = PluginParameterTypesClient(session) + self._experiments = ExperimentsClient(session) + self._jobs = JobsClient(session) + self._entrypoints = EntrypointsClient(session) + self._models = ModelsClient(session) + self._artifacts = ArtifactsClient(session) + @property - def experiment_endpoint(self) -> str: - """Experiment endpoint url""" - return urlunparse( - (self._scheme, self._netloc, urljoin(self._path, "experiment/"), "", "", "") - ) + def users(self): + return self._users @property - def job_endpoint(self) -> str: - """Job endpoint url""" - return urlunparse( - (self._scheme, self._netloc, urljoin(self._path, "job/"), "", "", "") - ) + def auth(self): + return self._auth @property - def task_plugin_endpoint(self) -> str: - """Task plugins endpoint url""" - return urlunparse( - (self._scheme, self._netloc, urljoin(self._path, "taskPlugin/"), "", "", "") - ) + def queues(self): + return self._queues @property - def task_plugin_builtins_endpoint(self) -> str: - """Builtin task plugins endpoint url""" - return urlunparse( - ( - self._scheme, - self._netloc, - urljoin(self._path, "taskPlugin/dioptra_builtins"), - "", - "", - "", - ) - ) + def groups(self): + return self._groups + + @property + def tags(self): + return self._tags + + @property + def plugins(self): + return self._plugins + + @property + def pluginParameterTypes(self): + return self._pluginParameterTypes + + @property + def experiments(self): + return self._experiments + + @property + def jobs(self): + return self._jobs + + @property + def entrypoints(self): + return self._entrypoints + + @property + def models(self): + return self._models + + @property + def artifacts(self): + return self._artifacts + + +class HasTagsProvider(object): + def __init__(self, url, session): + self._tags = TagsProvider(url, session) @property - def task_plugin_custom_endpoint(self) -> str: - """Custom task plugins endpoint url""" + def tags(self): + return self.get_endpoint(self._tags) + + def get_endpoint(self, ep): + ep.session = self._session + return ep + + +class HasDraftsEndpoint(object): + def __init__(self, url, session, address, fields, put_fields=None): + self.draft_fields = fields + self.put_fields = put_fields if put_fields is not None else fields + self._drafts = DraftsEndpoint(url, self, session, "draft", address) + + @property + def drafts(self): + return self.get_endpoint(self._drafts) + + def get_endpoint(self, ep): + ep.session = self._session + return ep + + +class HasSubEndpointProvider(object): + def __init__(self, url): + self._url = url + + def idurl(self, ep_id): + return urljoin(self._url, ep_id) + + +class Endpoint(object): + + _ep_name = "" + + def __init__(self, session): + self._session = session + + @property + def session(self): + return self._session + + @session.setter + def session(self, s): + self._session = s + + @property + def url(self): return urlunparse( ( - self._scheme, - self._netloc, - urljoin(self._path, "taskPlugin/dioptra_custom"), + self._session._scheme, + self._session._netloc, + urljoin(self._session._path, self._ep_name + "/"), "", "", "", ) ) - @property - def queue_endpoint(self) -> str: - """Queue endpoint url""" - return urlunparse( - (self._scheme, self._netloc, urljoin(self._path, "queue/"), "", "", "") - ) - def delete_custom_task_plugin(self, name: str) -> dict[str, Any]: - """Deletes a custom task plugin by its unique name. - - Args: - name: A unique string identifying a task plugin package within - dioptra_custom collection. - - Returns: - The Dioptra REST api's response. - - Example:: - - { - 'collection': 'dioptra_custom', - 'status': 'Success', - 'taskPluginName': ['evaluation'] - } - - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - plugin_name_query: str = urljoin(self.task_plugin_custom_endpoint, name) - result = cast(dict[str, Any], requests.delete(plugin_name_query).json()) - return result - - def get_experiment_by_id(self, id: int) -> dict[str, Any]: - """Gets an experiment by its unique identifier. - - Args: - id: An integer identifying a registered experiment. - - Returns: - The Dioptra REST api's response. - - Example:: - - { - 'lastModified': '2023-06-22T13:42:35.379462', - 'experimentId': 10, - 'name': 'mnist_feature_squeezing', - 'createdOn': '2023-06-22T13:42:35.379462' - } - - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - experiment_id_query: str = urljoin(self.experiment_endpoint, str(id)) - return cast(dict[str, Any], requests.get(experiment_id_query).json()) - - def get_experiment_by_name(self, name: str) -> dict[str, Any]: - """Gets an experiment by its unique name. - - Args: - name: The name of the experiment. - - Returns: - The Dioptra REST api's response containing the experiment id, name, and - metadata. - - Example:: - - { - 'experimentId': 1, - 'name': 'mnist', - 'createdOn': '2023-06-22T13:42:35.379462', - 'lastModified': '2023-06-22T13:42:35.379462' - } - - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - experiment_name_query: str = urljoin(self.experiment_endpoint, "name", name) - return cast(dict[str, Any], requests.get(experiment_name_query).json()) - - def get_job_by_id(self, id: str) -> dict[str, Any]: - """Gets a job by its unique identifier. - - Args: - id: A string specifying a job's UUID. - - Returns: - The Dioptra REST api's response. - - Example:: - - { - 'mlflowRunId': None, - 'lastModified': '2023-06-26T15:26:43.100093', - 'experimentId': 10, - 'queueId': 2, - 'workflowUri': 's3://workflow/268a7620/workflows.tar.gz', - 'entryPoint': 'train', - 'dependsOn': None, - 'status': 'queued', - 'timeout': '24h', - 'jobId': '4eb2305e-57c3-4867-a59f-1a1ecd2033d4', - 'entryPointKwargs': '-P model_architecture=shallow_net -P epochs=3', - 'createdOn': '2023-06-26T15:26:43.100093' - } - - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - job_id_query: str = urljoin(self.job_endpoint, id) - return cast(dict[str, Any], requests.get(job_id_query).json()) - - def get_queue_by_id(self, id: int) -> dict[str, Any]: - """Gets a queue by its unique identifier. - - Args: - id: An integer identifying a registered queue. - - Returns: - The Dioptra REST api's response. - - Example:: - - { - 'lastModified': '2023-04-24T20:53:09.801442', - 'name': 'tensorflow_cpu', - 'queueId': 1, - 'createdOn': '2023-04-24T20:53:09.801442' - } - - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - queue_id_query: str = urljoin(self.queue_endpoint, str(id)) - return cast(dict[str, Any], requests.get(queue_id_query).json()) - - def get_queue_by_name(self, name: str) -> dict[str, Any]: - """Gets a queue by its unique name. - - Args: - name: The name of the queue. - - Returns: - The Dioptra REST api's response. - - Example:: - - { - 'lastModified': '2023-04-24T20:53:09.801442', - 'name': 'tensorflow_cpu', - 'queueId': 1, - 'createdOn': '2023-04-24T20:53:09.801442' - } - - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - queue_name_query: str = urljoin(self.queue_endpoint, "name", name) - return cast(dict[str, Any], requests.get(queue_name_query).json()) - - def get_builtin_task_plugin(self, name: str) -> dict[str, Any]: - """Gets a custom builtin plugin by its unique name. - - Args: - name: A unique string identifying a task plugin package within - dioptra_builtins collection. - - Returns: - The Dioptra REST api's response. - - Example:: - - { - 'taskPluginName': 'attacks', - 'collection': 'dioptra_builtins', - 'modules': ['__init__.py', 'fgm.py'] - } - - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - task_plugin_name_query: str = urljoin(self.task_plugin_builtins_endpoint, name) - return cast(dict[str, Any], requests.get(task_plugin_name_query).json()) - - def get_custom_task_plugin(self, name: str) -> dict[str, Any]: - """Gets a custom task plugin by its unique name. - - Args: - name: A unique string identifying a task plugin package within - dioptra_builtins collection. - - Returns: - The Dioptra REST api's response. - - Example:: - - { - 'taskPluginName': 'custom_poisoning_plugins', - 'collection': 'dioptra_custom', - 'modules': [ - '__init__.py', - 'attacks_poison.py', - 'data_tensorflow.py', - 'datasetup.py', - 'defenses_image_preprocessing.py', - 'defenses_training.py', - 'estimators_keras_classifiers.py', - 'registry_art.py', - 'tensorflow.py' - ] - } - - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - task_plugin_name_query: str = urljoin(self.task_plugin_custom_endpoint, name) - return cast(dict[str, Any], requests.get(task_plugin_name_query).json()) - - def list_experiments(self) -> list[dict[str, Any]]: - """Gets a list of all registered experiments. - - Returns: - A list of responses detailing all experiments. - - Example:: - - [ - { - 'lastModified': '2023-04-24T20:20:27.315687', - 'experimentId': 1, - 'name': 'mnist', - 'createdOn': '2023-04-24T20:20:27.315687' - }, - ... - { - 'lastModified': '2023-06-22T13:42:35.379462', - 'experimentId': 10, - 'name': 'mnist_feature_squeezing', - 'createdOn': '2023-06-22T13:42:35.379462' - } - ] - - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - return cast(list[dict[str, Any]], requests.get(self.experiment_endpoint).json()) - - def list_jobs(self) -> list[dict[str, Any]]: - """Gets a list of all submitted jobs. - - Returns: - A list of responses detailing all jobs. - - Example:: - - [ - { - 'mlflowRunId': None, - 'lastModified': '2023-04-24T20:54:30.722304', - 'experimentId': 2, - 'queueId': 2, - 'workflowUri': 's3://workflow/268a7620/workflows.tar.gz', - 'entryPoint': 'train', - 'dependsOn': None, - 'status': 'queued', - 'timeout': '1h', - 'jobId': 'a4c574dd-cbd1-43c9-9afe-17d69cd1c73d', - 'entryPointKwargs': '-P data_dir=/nfs/data/Mnist', - 'createdOn': '2023-04-24T20:54:30.722304' - }, - ... - ] - - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - return cast(list[dict[str, Any]], requests.get(self.job_endpoint).json()) - - def list_queues(self) -> list[dict[str, Any]]: - """Gets a list of all registered queues. - - Returns: - A list of responses detailing all registered queues. - - Example:: - - [ - { - 'lastModified': '2023-04-24T20:53:09.801442', - 'name': 'tensorflow_cpu', - 'queueId': 1, - 'createdOn': '2023-04-24T20:53:09.801442' - }, - { - 'lastModified': '2023-04-24T20:53:09.824101', - 'name': 'tensorflow_gpu', - 'queueId': 2, - 'createdOn': '2023-04-24T20:53:09.824101' - }, - { - 'lastModified': '2023-04-24T20:53:09.867917', - 'name': 'pytorch_cpu', - 'queueId': 3, - 'createdOn': '2023-04-24T20:53:09.867917' - }, - { - 'lastModified': '2023-04-24T20:53:09.893451', - 'name': 'pytorch_gpu', - 'queueId': 4, - 'createdOn': '2023-04-24T20:53:09.893451' - } - ] - - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - return cast(list[dict[str, Any]], requests.get(self.queue_endpoint).json()) - - def list_all_task_plugins(self) -> list[dict[str, Any]]: - """Gets a list of all registered builtin task plugins. - - Returns: - A list of responses detailing all plugins. - - Example:: - - [ - { - 'taskPluginName': 'artifacts', - 'collection': 'dioptra_builtins', - 'modules': ['__init__.py', 'mlflow.py', 'utils.py'] - }, - ... - { - 'taskPluginName': 'pixel_threshold', - 'collection': 'dioptra_custom', - 'modules': ['__init__.py', 'pixelthreshold.py'] - } - ] - - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - - return cast( - list[dict[str, Any]], requests.get(self.task_plugin_endpoint).json() - ) +class SubEndpoint(Endpoint): + def __init__(self, parent, session, ep_name, address): + Endpoint.__init__(self, session, ep_name, address) + self._parent = parent # parent should extend HasSubEndpointProvider + + def suburl(self, ep_id): + return urljoin(self._parent.idurl(str(ep_id)), self.ep_name) - def list_builtin_task_plugins(self) -> list[dict[str, Any]]: - """Gets a list of all registered builtin task plugins. - - Returns: - A list of responses detailing all builtin plugins. - - Example:: - - [ - { - 'taskPluginName': 'artifacts', - 'collection': 'dioptra_builtins', - 'modules': ['__init__.py', 'mlflow.py', 'utils.py'] - }, - ... - { - 'taskPluginName': 'backend_configs', - 'collection': 'dioptra_builtins', - 'modules': ['__init__.py', 'tensorflow.py'] - } - ] - - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html for - more information on Dioptra's REST api. - """ - return cast( - list[dict[str, Any]], - requests.get(self.task_plugin_builtins_endpoint).json(), - ) - def list_custom_task_plugins(self) -> list[dict[str, Any]]: - """Gets a list of all registered custom task plugins. - - Returns: - A list of responses detailing all custom plugins. - - Example:: - - [ - { - 'taskPluginName': 'model_inversion', - 'collection': 'dioptra_custom', - 'modules': ['__init__.py', 'modelinversion.py'] - }, - ... - { - 'taskPluginName': 'pixel_threshold', - 'collection': 'dioptra_custom', - 'modules': ['__init__.py', 'pixelthreshold.py'] - } - ] - - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - return cast( - list[dict[str, Any]], requests.get(self.task_plugin_custom_endpoint).json() +class UsersClient(Endpoint): + + _ep_name = "users" + + def get_all(self): + """gets all users""" + return self.session.get(self.url) + + def create(self, username, email, password, confirm_password): + """creates a user""" + d = { + "username": username, + "email": email, + "password": password, + "confirmPassword": confirm_password, + } + return self.session.post(self.url, d) + + def get_by_id(self, user_id): + """get a user by id""" + return self.session.get(self.url, str(user_id)) + + def update_password_by_id( + self, user_id, old_password, new_password, confirm_new_password + ): + """change a user's password by id""" + d = { + "oldPassword": old_password, + "newPassword": new_password, + "confirmNewPassword": confirm_new_password, + } + return self.session.post(self.url, d, str(user_id), "password") + + def current(self): + """get the current user""" + return self.session.get(self.url, "current") + + def delete_current(self, password): + """delete the current user""" + d = {"password": password} + return self.session.delete(self.url, d, "current") + + def modify_current(self, username, email): + """modify the current user""" + d = {"username": username, "email": email} + return self.session.put(self.url, d, "current") + + def modify_current_password(self, old_password, new_password, confirm_new_password): + """modify the current user's password""" + d = { + "oldPassword": old_password, + "newPassword": new_password, + "confirmNewPassword": confirm_new_password, + } + return self.session.post(self.url, d, "current", "password") + + +class AuthClient(Endpoint): + + _ep_name = "auth" + + def login(self, username, password): + """login as the given user""" + d = {"username": username, "password": password} + return self.session.post(self.url, d, "login") + + def logout(self, everywhere): + """logout as the current user""" + d = {"everywhere": everywhere} + return self.session.post(self.url, d, "logout") + + +class GroupsClient(Endpoint): + + _ep_name = "groups" + + def get_all(self): + """get all groups""" + return self.session.get(self.url) + + def get_by_id(self, gid): + """get a group by id""" + return self.session.get(self.url, str(gid)) + + +class QueuesClient(Endpoint, HasDraftsEndpoint, HasSubEndpointProvider): + + _ep_name = "queues" + + def __init__(self, session, ep_name, address): + Endpoint.__init__(self, session, ep_name, address) + HasDraftsEndpoint.__init__( + self, self.url, self.session, address, ["name", "description"] ) + HasSubEndpointProvider.__init__(self, self.url) + + def get_all(self): + """gets all queues""" + return self.session.get(self.url) + + def create(self, group, name, description): + """create a queue""" + d = {"group": group, "name": name, "description": description} + return self.session.post(self.url, d) + + def modify_by_id(self, queue_id, name, description): + """modify a queue by id""" + d = {"name": name, "description": description} + return self.session.put(self.url, d, str(queue_id)) + + def delete_by_id(self, queue_id): + """delete a queue by id""" + d = None + return self.session.delete(self.url, d, str(queue_id)) + + def get_by_id(self, queue_id): + """get a queue by id""" + return self.session.get(self.url, str(queue_id)) - def lock_queue(self, name: str) -> dict[str, Any]: - """Locks the queue (name reference) if it is unlocked. - Args: - name: The name of the queue. +class TagsClient(Endpoint): - Returns: - The Dioptra REST api's response. + _ep_name = "tags" - Example:: + def get_all(self): + return self.session.get(self.url) - {'name': ['tensorflow_cpu'], 'status': 'Success'} + def create(self, group, name): + d = {"name": name, "group": group} + return self.session.post(self.url, d) - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - queue_name_query: str = urljoin(self.queue_endpoint, "name", name, "lock") - return cast(dict[str, Any], requests.put(queue_name_query).json()) + def delete_by_id(self, tag_id): + d = None + return self.session.delete(self.url, d, str(tag_id)) - def unlock_queue(self, name: str) -> dict[str, Any]: - """Removes the lock from the queue (name reference) if it exists. + def get_by_id(self, tag_id): + return self.session.get(self.url, str(tag_id)) - Args: - name: The name of the queue. + def modify_by_id(self, tag_id, name): + d = {"name": name} + return self.session.put(self.url, d, str(tag_id)) - Returns: - The Dioptra REST api's response. + def get_resources_by_id(self, tag_id): + return self.session.get(self.url, str(tag_id), "resources") - Example:: - {'name': ['tensorflow_cpu'], 'status': 'Success'} +class EntrypointsClient( + Endpoint, HasTagsProvider, HasDraftsEndpoint, HasSubEndpointProvider +): - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - queue_name_query: str = urljoin(self.queue_endpoint, "name", name, "lock") - return cast(dict[str, Any], requests.delete(queue_name_query).json()) + _ep_name = "entrypoints" - def register_experiment(self, name: str) -> dict[str, Any]: - """Creates a new experiment via an experiment registration form. + def __init__(self, session, ep_name, address): + Endpoint.__init__(self, session, ep_name, address) + HasTagsProvider.__init__(self, self.url, self.session) + HasDraftsEndpoint.__init__( + self, + self.url, + self.session, + address, + ["name", "description", "taskGraph", "parameters", "queues", "plugins"], + ) + HasSubEndpointProvider.__init__(self, self.url) + + def get_all(self): + return self.session.get(self.session, self.url) + + def create(self, group, name, description, taskGraph, parameters, queues, plugins): + d = { + "group": group, + "name": name, + "description": description, + "taskGraph": taskGraph, + "parameters": parameters, + "queues": queues, + "plugins": plugins, + } + return self.session.post(self.session, self.url, d) + + def modify_by_id( + self, entrypoint_id, name, description, taskGraph, parameters, queues + ): + d = { + "name": name, + "description": description, + "taskGraph": taskGraph, + "parameters": parameters, + "queues": queues, + } + return self.session.put(self.url, d, str(entrypoint_id)) - Args: - name: The name to register as a new experiment. + def get_by_id(self, entrypoint_id): + return self.session.get(self.url, str(entrypoint_id)) - Returns: - The Dioptra REST api's response. + def delete_by_id(self, entrypoint_id): + d = None + return self.session.delete(self.url, d, str(entrypoint_id)) - Example:: + def get_plugins_by_entrypoint_id(self, entrypoint_id): + return self.session.get(self.url, str(entrypoint_id), "plugins") - { - 'lastModified': '2023-06-26T15:45:09.232878', - 'experimentId': 11, - 'name': 'experiment1234', - 'createdOn': '2023-06-26T15:45:09.232878' - } + def add_plugins_by_entrypoint_id(self, entrypoint_id, plugins): + d = {"plugins": plugins} + return self.session.post(self.url, d, str(entrypoint_id), "plugins") - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - experiment_registration_form = {"name": name} + def get_plugins_by_entrypoint_id_plugin_id(self, entrypoint_id, plugin_id): + return self.session.get(self.url, str(entrypoint_id), "plugins", str(plugin_id)) - response = requests.post( - self.experiment_endpoint, - json=experiment_registration_form, + def delete_plugins_by_entrypoint_id_plugin_id(self, entrypoint_id, plugin_id): + d = None + return self.session.delete( + self.url, d, str(entrypoint_id), "plugins", str(plugin_id) ) - return cast(dict[str, Any], response.json()) + def modify_queues_by_entrypoint_id(self, entrypoint_id, ids): + d = {"ids": ids} + return self.session.put(self.url, d, str(entrypoint_id), "queues") - def register_queue(self, name: str = "tensorflow_cpu") -> dict[str, Any]: - """Creates a new queue via a queue registration form. + def add_queues_by_entrypoint_id(self, entrypoint_id, ids): + d = {"ids": ids} + return self.session.post(self.url, d, str(entrypoint_id), "queues") - Args: - name: The name to register as a new queue. Defaults to "tensorflow_cpu". + def get_queues_by_entrypoint_id(self, entrypoint_id): + return self.session.get(self.url, str(entrypoint_id), "queues") - Returns: - The Dioptra REST api's response. + def delete_queues_by_entrypoint_id(self, entrypoint_id): + d = None + return self.session.delete(self.url, d, str(entrypoint_id), "queues") - Example:: + def delete_queues_by_entrypoint_id_queue_id(self, entrypoint_id, queue_id): + d = None + return self.session.delete( + self.url, d, str(entrypoint_id), "queues", str(queue_id) + ) + + def get_snapshots_by_entrypoint_id(self, entrypoint_id): + return self.session.get(self.url, str(entrypoint_id), "snapshots") + + def get_snapshots_by_entrypoint_id_snapshot_id(self, entrypoint_id, snapshot_id): + return self.session.get( + self.url, str(entrypoint_id), "snapshots", str(snapshot_id) + ) - { - 'lastModified': '2023-06-26T15:48:47.662293', - 'name': 'queue', - 'queueId': 7, - 'createdOn': '2023-06-26T15:48:47.662293' - } - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - queue_registration_form = {"name": name} +class ExperimentsClient( + Endpoint, HasTagsProvider, HasDraftsEndpoint, HasSubEndpointProvider +): - response = requests.post( - self.queue_endpoint, - json=queue_registration_form, + _ep_name = "experiments" + + def __init__(self, session, ep_name, address): + Endpoint.__init__(self, session, ep_name, address) + HasTagsProvider.__init__(self, self.url, self.session) + HasDraftsEndpoint.__init__( + self, + self.url, + self.session, + address, + ["name", "description", "entrypoints"], + ) + HasSubEndpointProvider.__init__(self, self.url) + + def get_all(self): + return self.session.get(self.url) + + def create(self, group, name, description, entrypoints): + d = { + "group": group, + "name": name, + "description": description, + "entrypoints": entrypoints, + } + return self.session.post(self.url, d) + + def get_drafts(self): + return self.session.get(self.url, "drafts") + + def get_by_id(self, experiment_id): + return self.session.get(self.url, str(experiment_id)) + + def modify_by_id(self, experiment_id, name, description, entrypoints): + d = {"name": name, "description": description, "entrypoints": entrypoints} + return self.session.put(self.url, d, str(experiment_id)) + + def delete_by_id(self, experiment_id): + d = None + return self.session.delete(self.url, d, str(experiment_id)) + + def get_entrypoints_by_experiment_id(self, experiment_id): + return self.session.get(self.url, str(experiment_id), "entrypoints") + + def modify_entrypoints_by_experiment_id(self, experiment_id, ids): + d = {"ids": ids} + return self.session.put(self.url, d, str(experiment_id), "entrypoints") + + def add_entrypoints_by_experiment_id(self, experiment_id, ids): + d = {"ids": ids} + return self.session.post(self.url, d, str(experiment_id), "entrypoints") + + def delete_entrypoints_by_experiment_id(self, experiment_id): + d = None + return self.session.delete(self.url, d, str(experiment_id), "entrypoints") + + def delete_entrypoints_by_experiment_id_entrypoint_id( + self, experiment_id, entrypoint_id + ): + d = None + return self.session.delete( + self.url, + d, + str(experiment_id), + "entrypoints", + str(entrypoint_id), ) - return cast(dict[str, Any], response.json()) - - def submit_job( - self, - workflows_file: str | Path, - experiment_name: str, - entry_point: str, - entry_point_kwargs: str | None = None, - depends_on: str | None = None, - queue: str = "tensorflow_cpu", - timeout: str = "24h", - ) -> dict[str, Any]: - """Creates a new job via a job submission form with an attached file. - - Args: - workflows_file: A tarball archive or zip file containing, at a minimum, - a MLproject file and its associated entry point scripts. - experiment_name:The name of a registered experiment. - entry_point: Entrypoint name. - entry_point_kwargs: A string listing parameter values to pass to the - entry point for the job. The list of parameters is specified using the - following format: “-P param1=value1 -P param2=value2”. Defaults to None. - depends_on: A UUID for a previously submitted job to set as a dependency - for the current job. Defaults to None. - queue: Name of the queue the job is submitted to. Defaults to - "tensorflow_cpu". - timeout: The maximum alloted time for a job before it times out and is - stopped. Defaults to "24h". - - Returns: - The Dioptra REST api's response. - - Example:: - - { - 'createdOn': '2023-06-26T15:26:43.100093', - 'dependsOn': None, - 'entryPoint': 'train', - 'entryPointKwargs': '-P data_dir=/dioptra/data/Mnist', - 'experimentId': 10, - 'jobId': '4eb2305e-57c3-4867-a59f-1a1ecd2033d4', - 'lastModified': '2023-06-26T15:26:43.100093', - 'mlflowRunId': None, - 'queueId': 2, - 'status': 'queued', - 'timeout': '24h', - 'workflowUri': 's3://workflow/07d2c0a9/workflows.tar.gz' - } - - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - job_form: dict[str, Any] = { - "experimentName": experiment_name, + def get_jobs_by_experiment_id(self, experiment_id): + return self.session.get(self.url, str(experiment_id), "jobs") + + def create_jobs_by_experiment_id( + self, experiment_id, description, queue, entrypoint, values, timeout + ): + d = { + "description": description, "queue": queue, + "entrypoint": entrypoint, + "values": values, "timeout": timeout, - "entryPoint": entry_point, } + return self.session.post(self.url, d, str(experiment_id), "jobs") - if entry_point_kwargs is not None: - job_form["entryPointKwargs"] = entry_point_kwargs + def get_jobs_by_experiment_id_job_id(self, experiment_id, job_id): + return self.session.get(self.url, str(experiment_id), "jobs", str(job_id)) - if depends_on is not None: - job_form["dependsOn"] = depends_on + def delete_jobs_by_experiment_id_job_id(self, experiment_id, job_id): + d = None + return self.session.delete(self.url, d, str(experiment_id), "jobs", str(job_id)) - workflows_file = Path(workflows_file) + def get_jobs_status_by_experiment_id_job_id(self, experiment_id, job_id): + return self.session.get( + self.url, str(experiment_id), "jobs", str(job_id), "status" + ) - with workflows_file.open("rb") as f: - job_files = {"workflow": (workflows_file.name, f)} - response = requests.post( - self.job_endpoint, - data=job_form, - files=job_files, - ) + def modify_jobs_status_by_experiment_id_job_id(self, experiment_id, job_id, status): + d = {"status": status} + return self.session.put( + self.url, d, str(experiment_id), "jobs", str(job_id), "status" + ) + + def get_snapshots_by_experiment_id(self, experiment_id): + return self.session.get(self.url, str(experiment_id), "snapshots") + + def get_snapshots_by_experiment_id_snapshot_id(self, experiment_id, snapshot_id): + return self.session.get( + self.url, str(experiment_id), "snapshots", str(snapshot_id) + ) + + +class JobsClient(Endpoint, HasTagsProvider): + + _ep_name = "jobs" + + def __init__(self, session, ep_name, address): + Endpoint.__init__(self, session, ep_name, address) + HasTagsProvider.__init__(self, self.url, self.session) + + def get_all(self): + return self.session.get(self.url) + + def delete_by_id(self, job_id): + d = None + return self.session.delete(self.url, d, str(job_id)) - return cast(dict[str, Any], response.json()) - - def upload_custom_plugin_package( - self, - custom_plugin_name: str, - custom_plugin_file: str | Path, - collection: str = "dioptra_custom", - ) -> dict[str, Any]: - """Registers a new task plugin uploaded via the task plugin upload form. - - Args: - custom_plugin_name: Plugin name for the upload form. - custom_plugin_file: Path to custom plugin. - collection: Collection to upload the plugin to. Defaults to - "dioptra_custom". - - Returns: - The Dioptra REST api's response. - - Example:: - - { - 'taskPluginName': 'evaluation', - 'collection': 'dioptra_custom', - 'modules': [ - 'tensorflow.py', - 'import_keras.py', - '__init__.py' - ] - } - - Notes: - See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html - for more information on Dioptra's REST api. - """ - plugin_upload_form = { - "taskPluginName": custom_plugin_name, - "collection": collection, + def get_by_id(self, job_id): + return self.session.get(self.url, str(job_id)) + + def get_snapshots_by_job_id(self, job_id): + return self.session.get(self.url, str(job_id), "snapshots") + + def get_snapshots_by_job_id_snapshot_id(self, job_id, snapshot_id): + return self.session.get(self.url, str(job_id), "snapshots", str(snapshot_id)) + + def get_status_by_job_id(self, job_id): + return self.session.get(self.url, str(job_id), "status") + + +class PluginsClient( + Endpoint, HasDraftsEndpoint, HasSubEndpointProvider, HasTagsProvider +): + + _ep_name = "plugins" + + def __init__(self, session, ep_name, address): + Endpoint.__init__(self, session, ep_name, address) + HasTagsProvider.__init__(self, self.url, self.session) + HasDraftsEndpoint.__init__( + self, self.url, self.session, address, ["name", "description"] + ) + HasSubEndpointProvider.__init__(self, self.url) + self._files = PluginFilesClient(self, session, "files", address) + + @property + def files(self): + return self._files + + def get_all(self): + return self.session.get(self.url) + + def create(self, group, name, description): + d = {"group": group, "name": name, "description": description} + return self.session.post(self.url, d) + + def get_by_id(self, plugin_id): + return self.session.get(self.url, str(plugin_id)) + + def modify_by_id(self, plugin_id, name, description): + d = {"name": name, "description": description} + return self.session.put(self.url, d, str(plugin_id)) + + def delete_by_id(self, plugin_id): + d = None + return self.session.delete(self.url, d, str(plugin_id)) + + def get_snapshots_by_plugin_id(self, plugin_id): + return self.session.get(self.url, str(plugin_id), "snapshots") + + def get_snapshot_by_plugin_id_snapshot_id(self, plugin_id, snapshot_id): + return self.session.get(self.url, str(plugin_id), "snapshots", str(snapshot_id)) + + +class PluginFilesClient(SubEndpoint): + def __init__(self, parent, session, ep_name, address): + SubEndpoint.__init__(self, parent, session, ep_name, address) + # HasTagsProvider.__init__(self, self.url, self.session) + # HasDraftsEndpoint.__init__(self, self.url, self.session, address, + # ["filename", "description"] + # ) + # HasSubEndpointProvider.__init__(self, self.url) + + def get_files_by_plugin_id(self, plugin_id): + return self.session.get(self.suburl(plugin_id)) + + def create_files_by_plugin_id( + self, plugin_id, filename, contents, description, *plugins + ): + d = { + "filename": filename, + "contents": contents, + "description": description, + "tasks": [plugin.as_dict() for plugin in plugins], + } + return self.session.post(self.suburl(plugin_id), d) + + def delete_files_by_plugin_id(self, plugin_id): + d = None + return self.session.delete(self.suburl(plugin_id), d) + + def get_files_drafts_by_plugin_id(self, plugin_id): + return self.session.get(self.suburl(plugin_id), "drafts") + + def create_files_drafts_by_plugin_id( + self, plugin_id, filename, contents, description, *plugins + ): + d = { + "filename": filename, + "contents": contents, + "description": description, + "tasks": [plugin.as_dict() for plugin in plugins], + } + return self.session.post(self.suburl(plugin_id), d, "drafts") + + def get_files_drafts_by_plugin_id_draft_id(self, plugin_id, drafts_id): + return self.session.get(self.suburl(plugin_id), "drafts", str(drafts_id)) + + def modify_files_drafts_by_plugin_id_draft_id( + self, plugin_id, drafts_id, filename, contents, description, *plugins + ): + d = { + "filename": filename, + "contents": contents, + "description": description, + "tasks": [plugin.as_dict() for plugin in plugins], + } + return self.session.put(self.suburl(plugin_id), d, "drafts", str(drafts_id)) + + def delete_files_drafts_by_plugin_id_draft_id(self, plugin_id, drafts_id): + d = None + return self.session.delete(self.suburl(plugin_id), d, "drafts", str(drafts_id)) + + def get_files_by_plugin_id_file_id(self, plugin_id, file_id): + return self.session.get(self.suburl(plugin_id), str(file_id)) + + def modify_files_by_plugin_id_file_id( + self, plugin_id, file_id, filename, contents, description, *plugins + ): + d = { + "filename": filename, + "contents": contents, + "description": description, + "tasks": [plugin.as_dict() for plugin in plugins], } + return self.session.put(self.suburl(plugin_id), d, str(file_id)) + + def delete_files_by_plugin_id_file_id(self, plugin_id, file_id): + d = None + return self.session.delete(self.suburl(plugin_id), d, str(file_id)) + + def get_files_draft_by_plugin_id_file_id(self, plugin_id, file_id): + return self.session.get(self.suburl(plugin_id), str(file_id), "draft") + + def modify_files_draft_by_plugin_id_file_id( + self, plugin_id, file_id, filename, contents, description, *plugins + ): + d = { + "filename": filename, + "contents": contents, + "description": description, + "tasks": [plugin.as_dict() for plugin in plugins], + } + return self.session.put(self.suburl(plugin_id), d, str(file_id), "draft") + + def delete_files_draft_by_plugin_id_file_id(self, plugin_id, file_id): + d = None + return self.session.delete(self.suburl(plugin_id), d, str(file_id), "draft") + + def create_files_draft_by_plugin_id_file_id( + self, plugin_id, file_id, filename, contents, description, *plugins + ): + d = { + "filename": filename, + "contents": contents, + "description": description, + "tasks": [plugin.as_dict() for plugin in plugins], + } + return self.session.post(self.suburl(plugin_id), d, str(file_id), "draft") + + def get_snapshots_by_plugin_id_file_id(self, plugin_id, file_id): + return self.session.get(self.suburl(plugin_id), str(file_id), "snapshots") + + def get_snapshots_by_plugin_id_file_id_snapshot_id( + self, plugin_id, file_id, snapshot_id + ): + return self.session.get( + self.suburl(plugin_id), + str(file_id), + "snapshots", + str(snapshot_id), + ) - custom_plugin_file = Path(custom_plugin_file) + def get_tags_by_plugin_id_file_id(self, plugin_id, file_id): + return self.session.get(self.suburl(plugin_id), str(file_id), "tags") - with custom_plugin_file.open("rb") as f: - custom_plugin_file_dict = {"taskPluginFile": (custom_plugin_file.name, f)} - response = requests.post( - self.task_plugin_endpoint, - data=plugin_upload_form, - files=custom_plugin_file_dict, - ) + def modify_tags_by_plugin_id_file_id(self, plugin_id, file_id, ids): + d = {"ids": ids} + return self.session.put(self.suburl(plugin_id), d, str(file_id), "tags") + + def delete_tags_by_plugin_id_file_id(self, plugin_id, file_id): + d = None + return self.session.delete(self.suburl(plugin_id), d, str(file_id), "tags") + + def add_tags_by_plugin_id_file_id(self, plugin_id, file_id, ids): + d = {"ids": ids} + return self.session.post(self.suburl(plugin_id), d, str(file_id), "tags") + + def delete_tags_by_plugin_id_file_id_tag_id(self, plugin_id, file_id, tag_id): + d = None + return self.session.delete( + self.suburl(plugin_id), d, str(file_id), "tags", str(tag_id) + ) + + +class PluginParameterTypesClient(Endpoint): + + _ep_name = "pluginParameterTypes" + + def get_all(self): + return self.session.get(self.url) + + def create(self, group, name, description, structure): + d = { + "group": group, + "name": name, + "description": description, + "structure": structure, + } + return self.session.post(self.url, d) + + def get_by_id(self, type_id): + return self.session.get(self.url, str(type_id)) + + def modify_by_id(self, type_id, name, description, structure): + d = {"name": name, "description": description, "structure": structure} + return self.session.put(self.url, d, str(type_id)) + + def delete_by_id(self, type_id): + d = None + return self.session.delete(self.url, d, str(type_id)) + + +class PluginTask(object): + def __init__(self, name, inputs, outputs, client): + self.name = name + self.inputs = inputs # expects [(name1, type1), (name2, type2) ...] + self.outputs = outputs # expects [(name1, type1), (name2, type2) ...] + self.client = client + + def convert_params_to_ids(self, mappings): + """this converts parameters to registered ids using a mapping + from register_unregistered_types""" + return [(i[0], mappings[i[1]]) for i in self.inputs], [ + (o[0], mappings[o[1]]) for o in self.outputs + ] + + def register_unregistered_types(self, group=1): + """checks all the types in inputs/outputs and register things that + aren't registered""" + registered_types = ( + self.client.pluginParameterTypes.get_all() + ) # get all registered types + types_used_in_plugin = set( + [m[1] for m in self.inputs] + [m[1] for m in self.outputs] + ) # get all types for this plugin + types_to_id = {} + for registered in registered_types[ + "data" + ]: # add registered types to our dictionary + types_to_id[str(registered["name"])] = str(registered["id"]) + for used in types_used_in_plugin: + used = str(used) + if used not in types_to_id: # not yet registered, so register it + response = self.client.pluginParameterTypes.create( + group, used, used + " plugin parameter", structure={} + ) + types_to_id[used] = str(response["id"]) + return types_to_id # mapping of types to ids + + def as_dict(self, mappings=None): + """convert it to a dict to be sent to the RESTAPI""" + if mappings is None: + mappings = self.register_unregistered_types() + ins, outs = self.convert_params_to_ids(mappings) + return { + "name": self.name, + "inputParams": [ + {"name": param[0], "parameterType": param[1]} for param in ins + ], + "outputParams": [ + {"name": param[0], "parameterType": param[1]} for param in outs + ], + } + + +class ArtifactsClient(Endpoint): + + _ep_name = "artifacts" + + def get_all(self): + return self.session.get(self.url) + + def create(self, group, description, job, uri): + d = {"group": group, "description": description, "job": job, "uri": uri} + return self.session.post(self.url, d) + + def get_by_id(self, artifact_id): + return self.session.get(self.url, str(artifact_id)) + + def modify_by_id(self, artifact_id, description): + d = {"description": description} + return self.session.put(self.url, d, str(artifact_id)) + + def get_snapshots(self, artifact_id): + return self.session.get(self.url, str(artifact_id), "snapshots") + + def get_snapshots_by_artifact_id_snapshot_id(self, artifact_id, snapshot_id): + return self.session.get( + self.url, str(artifact_id), "snapshots", str(snapshot_id) + ) + + +class ModelsClient( + Endpoint, HasTagsProvider, HasDraftsEndpoint, HasSubEndpointProvider +): + + _ep_name = "models" + + def __init__(self, session, ep_name, address): + Endpoint.__init__(self, session, ep_name, address) + HasSubEndpointProvider.__init__(self, self.url) + HasTagsProvider.__init__(self, self.url, self.session) + HasDraftsEndpoint.__init__( + self, self.url, self.session, address, ["name", "description"] + ) + + def get_all(self): + return self.session.get(self.url) + + def create(self, group, name, description): + d = {"group": group, "name": name, "description": description} + return self.session.post(self.url, d) + + def get_by_id(self, model_id): + return self.session.get(self.url, str(model_id)) + + def modify_by_id(self, model_id, name, description): + d = {"name": name, "description": description} + return self.session.put(self.url, d, str(model_id)) + + def delete_by_id(self, model_id): + d = None + return self.session.delete(self.url, d, str(model_id)) + + def get_snapshots_by_model_id(self, model_id): + return self.session.get(self.url, str(model_id), "snapshots") + + def get_snapshot_by_plugin_id_model_id(self, model_id, snapshot_id): + return self.session.get(self.url, str(model_id), "snapshots", str(snapshot_id)) + + def get_versions_by_model_id(self, model_id): + return self.session.get(self.url, str(model_id), "versions") + + def create_version_by_model_id(self, model_id, description, artifact): + d = {"description": description, "artifact": artifact} + return self.session.post(self.url, d, str(model_id), "versions") + + def modify_version_by_model_id_version_id(self, model_id, version_id, description): + d = {"description": description} + return self.session.put(self.url, d, str(model_id), "versions", str(version_id)) + + def get_version_by_model_id_version_id(self, model_id, version_id): + return self.session.get(self.url, str(model_id), "versions", str(version_id)) + + +class DraftsEndpoint(SubEndpoint): + def __init__(self, base_url, parent, session, ep_name, address): + SubEndpoint.__init__(self, parent, session, ep_name, address) + self.base_url = base_url + self.fields = parent.draft_fields # array of field names + self.put_fields = parent.put_fields # used when PUT method differs from create + + @property + def drafts_url(self): + return urljoin(self.base_url, "drafts") + + # /something/id/draft + + def create_draft_for_resource( + self, parent_id, *fields + ): # TODO: what to do about these parameters? they can be different + d = {} + for f in zip(self.fields, fields): + d[f[0]] = f[1] + return self.session.post(self.suburl(parent_id), d) + + def get_draft_for_resource(self, parent_id): + return self.session.get(self.suburl(parent_id)) + + def modify_draft_for_resource(self, parent_id, *fields): + d = {} + for f in zip(self.put_fields, fields): + d[f[0]] = f[1] + return self.session.put(self.suburl(parent_id), d) + + def delete_draft_for_resource(self, parent_id): + d = None + return self.session.delete(self.suburl(parent_id), d) + + # /something/drafts/ + + def get_all(self): + return self.session.get(self.drafts_url) + + def create(self, group_id, *fields): + d = {"group": group_id} + for f in zip(self.fields, fields): + d[f[0]] = f[1] + return self.session.post(self.drafts_url, d) + + def modify_by_draft_id(self, draft_id, *fields): + d = {} + for f in zip(self.put_fields, fields): + d[f[0]] = f[1] + return self.session.put(self.drafts_url, d, str(draft_id)) + + def delete_by_draft_id(self, draft_id): + d = None + return self.session.delete(self.drafts_url, d, str(draft_id)) + + def get_by_draft_id(self, draft_id): + return self.session.get(self.drafts_url, str(draft_id)) + + +class TagsProvider(object): + def __init__(self, base_url, session): + # SubEndpoint.__init__(self, session) + self.url = base_url + self.session = session + + def get(self, parent_id): + return self.session.get(self.url, str(parent_id), "tags") + + def modify(self, parent_id, ids): + d = {"ids": ids} + return self.session.put(self.url, d, str(parent_id), "tags") + + def delete_all(self, parent_id): + d = None + return self.session.delete(self.url, d, str(parent_id), "tags") + + def add(self, parent_id, ids): + d = {"ids": ids} + return self.session.post(self.url, d, str(parent_id), "tags") - return cast(dict[str, Any], response.json()) + def delete(self, parent_id, tag_id): + d = None + return self.session.delete(self.url, d, str(parent_id), "tags", str(tag_id)) diff --git a/src/dioptra/client/artifacts.py b/src/dioptra/client/artifacts.py new file mode 100644 index 000000000..ab0a41a34 --- /dev/null +++ b/src/dioptra/client/artifacts.py @@ -0,0 +1,16 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode diff --git a/src/dioptra/client/auth.py b/src/dioptra/client/auth.py new file mode 100644 index 000000000..10538d93f --- /dev/null +++ b/src/dioptra/client/auth.py @@ -0,0 +1,63 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from typing import ClassVar, TypeVar + +import structlog +from structlog.stdlib import BoundLogger + +from .base import CollectionClient + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + +T = TypeVar("T") + + +class AuthCollectionClient(CollectionClient[T]): + """The client for interacting with the Dioptra API's auth endpoint. + + Attributes: + name: The name of the endpoint. + """ + + name: ClassVar[str] = "auth" + + def login(self, username: str, password: str) -> T: + """Send a login request to the Dioptra API. + + Args: + username: The username of the user. + password: The password of the user. + + Returns: + The response from the Dioptra API. + """ + return self._session.post( + self.url, + "login", + json_={"username": username, "password": password}, + ) + + def logout(self, everywhere: bool = False) -> T: + """Send a logout request to the Dioptra API. + + Args: + everywhere: If True, log out from all sessions. + + Returns: + The response from the Dioptra API. + """ + return self._session.post(self.url, "logout", params={"everywhere": everywhere}) diff --git a/src/dioptra/client/base.py b/src/dioptra/client/base.py new file mode 100644 index 000000000..9fa52008c --- /dev/null +++ b/src/dioptra/client/base.py @@ -0,0 +1,437 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from abc import ABC, abstractmethod +from posixpath import join as urljoin +from typing import Any, ClassVar, Generic, Protocol, TypeVar + +import structlog +from structlog.stdlib import BoundLogger + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + +T = TypeVar("T") + + +class DioptraClientError(Exception): + """Base class for client errors""" + + +class FieldsValidationError(DioptraClientError): + """Raised when one or more fields are invalid.""" + + +class APIConnectionError(DioptraClientError): + """Class for connection errors""" + + +class StatusCodeError(DioptraClientError): + """Class for status code errors""" + + +class JSONDecodeError(DioptraClientError): + """Class for JSON decode errors""" + + +class SubCollectionUrlError(DioptraClientError): + """Class for errors in the sub-collection URL""" + + +class DioptraRequestProtocol(Protocol): + """The interface for a request to the Dioptra API.""" + + @property + def method(self) -> str: + """The HTTP method used in the request.""" + ... # fmt: skip + + @property + def url(self) -> str: + """The URL the request was made to.""" + ... # fmt: skip + + +class DioptraResponseProtocol(Protocol): + """The interface for a response from the Dioptra API.""" + + @property + def request(self) -> DioptraRequestProtocol: + """The request that generated the response.""" + ... # fmt: skip + + @property + def status_code(self) -> int: + """The HTTP status code of the response.""" + ... # fmt: skip + + @property + def text(self) -> str: + """The response body as a string.""" + ... # fmt: skip + + def json(self) -> dict[str, Any]: + """Return the response body as a JSON-like Python dictionary. + + Returns: + The response body as a dictionary. + """ + ... # fmt: skip + + +class DioptraSession(ABC, Generic[T]): + """The interface for communicating with the Dioptra API.""" + + @property + @abstractmethod + def url(self) -> str: + """The base URL of the API endpoints.""" + raise NotImplementedError + + @abstractmethod + def connect(self) -> None: + """Connect to the API.""" + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the connection to the API.""" + raise NotImplementedError + + @abstractmethod + def make_request( + self, + method_name: str, + url: str, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> DioptraResponseProtocol: + """Make a request to the API. + + All response objects must implement the DioptraResponseProtocol interface. + + Args: + method_name: The HTTP method to use. Must be one of "get", "patch", "post", + "put", or "delete". + url: The URL of the API endpoint. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + The response from the API. + """ + raise NotImplementedError + + @abstractmethod + def get(self, endpoint: str, *parts, params: dict[str, Any] | None = None) -> T: + """Make a GET request to the API. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + + Returns: + The response from the API. + """ + raise NotImplementedError + + @abstractmethod + def patch( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> T: + """Make a PATCH request to the API. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + The response from the API. + """ + raise NotImplementedError + + @abstractmethod + def post( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> T: + """Make a POST request to the API. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + The response from the API. + """ + raise NotImplementedError + + @abstractmethod + def delete( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> T: + """Make a DELETE request to the API. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + The response from the API. + """ + raise NotImplementedError + + @abstractmethod + def put( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> T: + """Make a PUT request to the API. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + The response from the API. + """ + raise NotImplementedError + + def _get( + self, endpoint: str, *parts, params: dict[str, Any] | None = None + ) -> DioptraResponseProtocol: + """Make a GET request to the API. + + The response from this internal method always implements the + DioptraResponseProtocol interface. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + + Returns: + A response object that implements the DioptraResponseProtocol interface. + """ + return self.make_request("get", self.build_url(endpoint, *parts), params=params) + + def _patch( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> DioptraResponseProtocol: + """Make a PATCH request to the API. + + The response from this internal method always implements the + DioptraResponseProtocol interface. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + A response object that implements the DioptraResponseProtocol interface. + """ + return self.make_request( + "patch", self.build_url(endpoint, *parts), params=params, json_=json_ + ) + + def _post( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> DioptraResponseProtocol: + """Make a POST request to the API. + + The response from this internal method always implements the + DioptraResponseProtocol interface. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + A response object that implements the DioptraResponseProtocol interface. + """ + return self.make_request( + "post", self.build_url(endpoint, *parts), params=params, json_=json_ + ) + + def _delete( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> DioptraResponseProtocol: + """Make a DELETE request to the API. + + The response from this internal method always implements the + DioptraResponseProtocol interface. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + A response object that implements the DioptraResponseProtocol interface. + """ + return self.make_request( + "delete", self.build_url(endpoint, *parts), params=params, json_=json_ + ) + + def _put( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> DioptraResponseProtocol: + """Make a PUT request to the API. + + The response from this internal method always implements the + DioptraResponseProtocol interface. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + A response object that implements the DioptraResponseProtocol interface. + """ + return self.make_request( + "put", self.build_url(endpoint, *parts), params=params, json_=json_ + ) + + @staticmethod + def build_url(base: str, *parts) -> str: + """Build a URL from a base and one or more parts. + + Args: + base: The base URL. + *parts: The parts to join to the base URL. + + Returns: + The joined URL. + """ + return urljoin(base, *parts) + + +class CollectionClient(Generic[T]): + """The interface for an API collection client. + + Attributes: + name: The name of the collection. + """ + + name: ClassVar[str] + + def __init__(self, session: DioptraSession[T]) -> None: + """Initialize the Endpoint object. + + Args: + session: The Dioptra API session object. + """ + self._session = session + + @property + def url(self) -> str: + """The URL of the API endpoint.""" + return self._session.build_url(self._session.url, self.name) + + +class SubCollectionClient(Generic[T]): + name: ClassVar[str] + + def __init__( + self, + session: DioptraSession[T], + root_collection: CollectionClient[T], + parent_sub_collections: list["SubCollectionClient[T]"] | None = None, + ) -> None: + self._session = session + self._root_collection = root_collection + self._parent_sub_collections: list["SubCollectionClient[T]"] = ( + parent_sub_collections or [] + ) + + def build_sub_collection_url(self, *resource_ids: str | int) -> str: + self._validate_resource_ids_count(resource_ids) + parent_url_parts: list[str] = [ + self._root_collection.url, + str(resource_ids[0]), + ] + + for resource_id, parent_sub_collection in zip( + resource_ids[1:], self._parent_sub_collections + ): + parent_url_parts.extend([parent_sub_collection.name, str(resource_id)]) + + return self._session.build_url(*parent_url_parts, self.name) + + def _validate_resource_ids_count(self, resource_ids: tuple[str | int, ...]) -> None: + num_resource_ids = len(resource_ids) + expected_count = len(self._parent_sub_collections) + 1 + if num_resource_ids != expected_count: + raise SubCollectionUrlError( + f"Invalid number of resource ids (reason: expected {expected_count}): " + f"{num_resource_ids}" + ) diff --git a/src/dioptra/client/client.py b/src/dioptra/client/client.py new file mode 100644 index 000000000..2c6a7e57b --- /dev/null +++ b/src/dioptra/client/client.py @@ -0,0 +1,210 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +import os +from posixpath import join as urljoin +from typing import Any, Final, Generic, TypeVar + +import structlog +from structlog.stdlib import BoundLogger + +from .auth import AuthCollectionClient +from .base import DioptraResponseProtocol, DioptraSession +from .queues import QueuesCollectionClient +from .tags import TagsCollectionClient +from .users import UsersCollectionClient + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + +DIOPTRA_V1_ROOT: Final[str] = "api/v1" +ENV_DIOPTRA_API: Final[str] = "DIOPTRA_API" + +T = TypeVar("T") + + +class DioptraClient(Generic[T]): + """The Dioptra API client.""" + + def __init__(self, session: DioptraSession[T]) -> None: + """Initialize the DioptraClient instance. + + Args: + session: The Dioptra API session object. + """ + self._session = session + self._users = UsersCollectionClient[T](session) + self._auth = AuthCollectionClient[T](session) + self._queues = QueuesCollectionClient[T](session) + self._tags = TagsCollectionClient[T](session) + # self._groups = GroupsCollectionClient[T](session) + # self._plugins = PluginsCollectionClient[T](session) + # self._plugin_parameter_types = ( + # PluginParameterTypesCollectionClient[T](session) + # ) + # self._experiments = ExperimentsCollectionClient[T](session) + # self._jobs = JobsCollectionClient[T](session) + # self._entrypoints = EntrypointsCollectionClient[T](session) + # self._models = ModelsCollectionClient[T](session) + # self._artifacts = ArtifactsCollectionClient[T](session) + + @property + def users(self) -> UsersCollectionClient[T]: + """The Dioptra API's /users endpoint.""" + return self._users + + @property + def auth(self) -> AuthCollectionClient[T]: + """The Dioptra API's /auth endpoint.""" + return self._auth + + @property + def queues(self) -> QueuesCollectionClient[T]: + """The Dioptra API's /queues endpoint.""" + return self._queues + + @property + def tags(self) -> TagsCollectionClient[T]: + """The Dioptra API's /tags endpoint.""" + return self._tags + + # @property + # def groups(self) -> GroupsCollectionClient[T]: + # """The Dioptra API's /groups endpoint.""" + # return self._groups + + # @property + # def plugins(self) -> PluginsCollectionClient[T]: + # """The Dioptra API's /plugins endpoint.""" + # return self._plugins + + # @property + # def plugin_parameter_types(self) -> PluginParameterTypesCollectionClient[T]: + # """The Dioptra API's /pluginParameterTypes endpoint.""" + # return self._plugin_parameter_types + + # @property + # def experiments(self) -> ExperimentsCollectionClient[T]: + # """The Dioptra API's /experiments endpoint.""" + # return self._experiments + + # @property + # def jobs(self) -> JobsCollectionClient[T]: + # """The Dioptra API's /jobs endpoint.""" + # return self._jobs + + # @property + # def entrypoints(self) -> EntrypointsCollectionClient[T]: + # """The Dioptra API's /entrypoints endpoint.""" + # return self._entrypoints + + # @property + # def models(self) -> ModelsCollectionClient[T]: + # """The Dioptra API's /models endpoint.""" + # return self._models + + # @property + # def artifacts(self) -> ArtifactsCollectionClient[T]: + # """The Dioptra API's /artifacts endpoint.""" + # return self._artifacts + + def close(self) -> None: + """Close the client's connection to the API.""" + self._session.close() + + +def connect_response_dioptra_client( + address: str | None = None, +) -> DioptraClient[DioptraResponseProtocol]: + """Connect a client to the Dioptra API that returns response objects. + + This client always returns a response object regardless of the response status code. + It is the responsibility of the caller to check the status code and handle any + errors. + + Args: + address: The Dioptra web address. This is the same address used to access the + web GUI, e.g. "https://dioptra.example.org". Note that the + "/api/" suffix is omitted. If None, then the DIOPTRA_API + environment variable will be checked and used. + + Returns: + A Dioptra client. + + Raises: + ValueError: If address is None and the DIOPTRA_API environment variable is not + set. + """ + from .sessions import DioptraRequestsSession + + return DioptraClient[DioptraResponseProtocol]( + session=DioptraRequestsSession(_build_api_address(address)) + ) + + +def connect_json_dioptra_client( + address: str | None = None, +) -> DioptraClient[dict[str, Any]]: + """Connect a client to the Dioptra API that returns JSON-like Python dictionaries. + + In contrast to the client that returns response objects, this client will raise an + exception for any non-2xx response status code. + + Args: + address: The Dioptra web address. This is the same address used to access the + web GUI, e.g. "https://dioptra.example.org". Note that the + "/api/" suffix is omitted. If None, then the DIOPTRA_API + environment variable will be checked and used. + + Returns: + A Dioptra client. + + Raises: + ValueError: If address is None and the DIOPTRA_API environment variable is not + set. + """ + from .sessions import DioptraRequestsSessionJson + + return DioptraClient[dict[str, Any]]( + session=DioptraRequestsSessionJson(_build_api_address(address)) + ) + + +def _build_api_address(address: str | None) -> str: + """Build the Dioptra API address. + + Args: + address: The Dioptra web address. This is the same address used to access the + web GUI, e.g. "https://dioptra.example.org". Note that the + "/api/" suffix is omitted. If None, then the DIOPTRA_API + environment variable will be checked and used. + + Returns: + The Dioptra API address. + + Raises: + ValueError: If address is None and the DIOPTRA_API environment variable is not + set. + """ + if address is not None: + return urljoin(address, DIOPTRA_V1_ROOT) + + if (dioptra_api := os.getenv(ENV_DIOPTRA_API)) is None: + raise ValueError( + f"The {ENV_DIOPTRA_API} environment variable must be set if the " + "address is not provided." + ) + + return urljoin(dioptra_api, DIOPTRA_V1_ROOT) diff --git a/src/dioptra/client/drafts.py b/src/dioptra/client/drafts.py new file mode 100644 index 000000000..a11cb28bf --- /dev/null +++ b/src/dioptra/client/drafts.py @@ -0,0 +1,337 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from typing import Any, ClassVar, Generic, Protocol, TypeVar + +import structlog +from structlog.stdlib import BoundLogger + +from .base import ( + CollectionClient, + DioptraClientError, + DioptraSession, + SubCollectionClient, + SubCollectionUrlError, +) + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + +T = TypeVar("T") + + +class DraftFieldsValidationError(DioptraClientError): + """Raised when one or more draft fields are invalid.""" + + +class ValidateDraftFieldsProtocol(Protocol): + def __call__(self, json_: dict[str, Any]) -> dict[str, Any]: + ... # fmt: skip + + +def make_draft_fields_validator( + draft_fields: set[str], resource_name: str +) -> ValidateDraftFieldsProtocol: + """Create a function to validate the allowed draft fields. + + Args: + draft_fields: The allowed draft fields. + resource_name: The name of the resource the draft fields are for. + + Returns: + The function to validate the allowed draft fields. + """ + + def validate_draft_fields(json_: dict[str, Any]) -> dict[str, Any]: + """Validate the provided draft fields. + + Args: + json_: The draft fields to validate. + + Returns: + The validated draft fields. + + Raises: + DraftFieldsValidationError: If one or more draft fields are invalid or + missing. + """ + provided_fields = set(json_.keys()) + + if draft_fields != provided_fields: + invalid_fields = provided_fields - draft_fields + missing_fields = draft_fields - provided_fields + msg: list[str] = [f"Invalid or missing fields for {resource_name} draft."] + + if invalid_fields: + msg.append(f"Invalid fields: {invalid_fields}") + + if missing_fields: + msg.append(f"Missing fields: {missing_fields}") + + LOGGER.error( + "Invalid or missing draft fields.", + resource_name=resource_name, + invalid_fields=invalid_fields, + missing_fields=missing_fields, + ) + raise DraftFieldsValidationError(" ".join(msg)) + + return json_ + + return validate_draft_fields + + +class NewResourceDraftsSubCollectionClient(Generic[T]): + """The sub-endpoint client for managing drafts under an endpoint. + + Attributes: + name: The name of the sub-endpoint. + draft_for_resource_name: The name of the draft sub-endpoint for a specific + resource. + """ + + name: ClassVar[str] = "drafts" + + def __init__( + self, + session: DioptraSession[T], + validate_fields_fn: ValidateDraftFieldsProtocol, + root_collection: CollectionClient[T], + parent_sub_collections: list[SubCollectionClient[T]] | None = None, + ) -> None: + """Initialize the DraftsSubEndpoint instance. + + Args: + session: The Dioptra API session object. + parent_endpoint: The parent endpoint client. + validate_fields_fn: The function to validate the allowed draft fields. + """ + self._session = session + self._validate_fields = validate_fields_fn + self._root_collection = root_collection + self._parent_sub_collections: list[SubCollectionClient[T]] = ( + parent_sub_collections or [] + ) + + def get( + self, + *resource_ids: str | int, + draft_type: str | None = None, + group_id: int | None = None, + index: int = 0, + page_length: int = 10, + ) -> T: + """Get the list of endpoint drafts. + + Args: + draft_type: The type of drafts to return: all, existing, or new. + group_id: The group ID the drafts belong to. If None, return drafts from all + groups that the user has access to. + index: The paging index. + page_length: The maximum number of drafts to return in the paged response. + + Returns: + The response from the Dioptra API. + """ + params: dict[str, Any] = { + "index": index, + "pageLength": page_length, + } + + if group_id is not None: + params["groupId"] = group_id + + if draft_type is not None: + params["draftType"] = draft_type + + return self._session.get( + self.build_sub_collection_url(*resource_ids), params=params + ) + + def get_by_id(self, *resource_ids: str | int, draft_id: int) -> T: + """Get an endpoint draft by its ID. + + Args: + draft_id: The ID of the draft. + + Returns: + The response from the Dioptra API. + """ + return self._session.get( + self.build_sub_collection_url(*resource_ids), str(draft_id) + ) + + def create(self, *resource_ids: str | int, group_id: int, **kwargs) -> T: + """Create a new endpoint draft. + + Args: + group_id: The ID for the group that will own the resource when the draft is + published. + **kwargs: The draft fields. + + Returns: + The response from the Dioptra API. + + Raises: + ValueError: If "group" is specified in kwargs. + """ + + if "group" in kwargs: + raise ValueError('Cannot specify "group" in kwargs') + + data: dict[str, Any] = {"group": group_id} | self._validate_fields(kwargs) + return self._session.post( + self.build_sub_collection_url(*resource_ids), json_=data + ) + + def modify(self, *resource_ids: str | int, draft_id: int, **kwargs) -> T: + """Modify the endpoint draft matching the provided ID. + + Args: + draft_id: The draft ID. + **kwargs: The draft fields to modify. + + Returns: + The response from the Dioptra API. + + Raises: + ValueError: If "draftId" is specified in kwargs. + """ + if "draftId" in kwargs: + raise ValueError('Cannot specify "draftId" in kwargs') + + return self._session.put( + self.build_sub_collection_url(*resource_ids), + str(draft_id), + json_=self._validate_fields(kwargs), + ) + + def delete(self, *resource_ids: str | int, draft_id: int) -> T: + """Delete the endpoint draft matching the provided ID. + + Args: + draft_id: The draft ID. + + Returns: + The response from the Dioptra API. + """ + return self._session.delete( + self.build_sub_collection_url(*resource_ids), str(draft_id) + ) + + def build_sub_collection_url(self, *resource_ids: str | int) -> str: + self._validate_resource_ids_count(*resource_ids) + parent_url_parts: list[str] = [self._root_collection.url] + + for resource_id, parent_sub_collection in zip( + resource_ids, self._parent_sub_collections + ): + parent_url_parts.extend([str(resource_id), parent_sub_collection.name]) + + return self._session.build_url(*parent_url_parts, self.name) + + def _validate_resource_ids_count(self, *resource_ids: str | int) -> None: + num_resource_ids = len(resource_ids) + expected_count = len(self._parent_sub_collections) + if num_resource_ids != expected_count: + raise SubCollectionUrlError( + f"Invalid number of resource ids (reason: expected {expected_count}): " + f"{num_resource_ids}" + ) + + +class ExistingResourceDraftsSubCollectionClient(SubCollectionClient[T]): + """The sub-endpoint client for managing drafts under an endpoint. + + Attributes: + name: The name of the sub-endpoint. + draft_for_resource_name: The name of the draft sub-endpoint for a specific + resource. + """ + + name: ClassVar[str] = "draft" + + def __init__( + self, + session: DioptraSession[T], + validate_fields_fn: ValidateDraftFieldsProtocol, + root_collection: CollectionClient[T], + parent_sub_collections: list[SubCollectionClient[T]] | None = None, + ) -> None: + """Initialize the DraftsSubEndpoint instance. + + Args: + session: The Dioptra API session object. + parent_endpoint: The parent endpoint client. + validate_fields_fn: The function to validate the allowed draft fields. + """ + super().__init__( + session, + root_collection=root_collection, + parent_sub_collections=parent_sub_collections, + ) + self._validate_fields = validate_fields_fn + + def get_by_id(self, *resource_ids: str | int) -> T: + """Get the draft for a specific endpoint resource. + + Args: + resource_id: The ID of the endpoint resource. + + Returns: + The response from the Dioptra API. + """ + return self._session.get(self.build_sub_collection_url(*resource_ids)) + + def create(self, *resource_ids: str | int, **kwargs) -> T: + """Create a draft for a specific endpoint resource. + + Args: + resource_id: The ID of the endpoint resource. + **kwargs: The draft fields. + + Returns: + The response from the Dioptra API. + """ + return self._session.post( + self.build_sub_collection_url(*resource_ids), + json_=self._validate_fields(kwargs), + ) + + def modify(self, *resource_ids: str | int, **kwargs) -> T: + """Modify the draft for a specific endpoint resource. + + Args: + resource_id: The ID of the endpoint resource. + **kwargs: The draft fields to modify. + + Returns: + The response from the Dioptra API. + """ + return self._session.put( + self.build_sub_collection_url(*resource_ids), + json_=self._validate_fields(kwargs), + ) + + def delete(self, *resource_ids: str | int) -> T: + """Delete the draft for a specific endpoint resource. + + Args: + resource_id: The ID of the endpoint resource. + + Returns: + The response from the Dioptra API. + """ + return self._session.delete(self.build_sub_collection_url(*resource_ids)) diff --git a/src/dioptra/client/entrypoints.py b/src/dioptra/client/entrypoints.py new file mode 100644 index 000000000..ab0a41a34 --- /dev/null +++ b/src/dioptra/client/entrypoints.py @@ -0,0 +1,16 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode diff --git a/src/dioptra/client/experiments.py b/src/dioptra/client/experiments.py new file mode 100644 index 000000000..ab0a41a34 --- /dev/null +++ b/src/dioptra/client/experiments.py @@ -0,0 +1,16 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode diff --git a/src/dioptra/client/groups.py b/src/dioptra/client/groups.py new file mode 100644 index 000000000..ab0a41a34 --- /dev/null +++ b/src/dioptra/client/groups.py @@ -0,0 +1,16 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode diff --git a/src/dioptra/client/jobs.py b/src/dioptra/client/jobs.py new file mode 100644 index 000000000..ab0a41a34 --- /dev/null +++ b/src/dioptra/client/jobs.py @@ -0,0 +1,16 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode diff --git a/src/dioptra/client/models.py b/src/dioptra/client/models.py new file mode 100644 index 000000000..ab0a41a34 --- /dev/null +++ b/src/dioptra/client/models.py @@ -0,0 +1,16 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode diff --git a/src/dioptra/client/plugin_parameter_types.py b/src/dioptra/client/plugin_parameter_types.py new file mode 100644 index 000000000..ab0a41a34 --- /dev/null +++ b/src/dioptra/client/plugin_parameter_types.py @@ -0,0 +1,16 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode diff --git a/src/dioptra/client/plugins.py b/src/dioptra/client/plugins.py new file mode 100644 index 000000000..ab0a41a34 --- /dev/null +++ b/src/dioptra/client/plugins.py @@ -0,0 +1,16 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode diff --git a/src/dioptra/client/queues.py b/src/dioptra/client/queues.py new file mode 100644 index 000000000..30df37fee --- /dev/null +++ b/src/dioptra/client/queues.py @@ -0,0 +1,198 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from typing import Any, ClassVar, Final, TypeVar + +import structlog +from structlog.stdlib import BoundLogger + +from .base import CollectionClient, DioptraSession +from .drafts import ( + ExistingResourceDraftsSubCollectionClient, + NewResourceDraftsSubCollectionClient, + make_draft_fields_validator, +) +from .snapshots import SnapshotsSubCollectionClient +from .tags import TagsSubCollectionClient + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + +DRAFT_FIELDS: Final[set[str]] = {"name", "description"} + +T = TypeVar("T") + + +class QueuesCollectionClient(CollectionClient[T]): + """The client for interacting with the Dioptra API's queues endpoint. + + Attributes: + name: The name of the endpoint. + """ + + name: ClassVar[str] = "queues" + + def __init__(self, session: DioptraSession[T]) -> None: + """Initialize the QueuesClient instance. + + Args: + session: The Dioptra API session object. + """ + super().__init__(session) + self._new_resource_drafts = NewResourceDraftsSubCollectionClient[T]( + session=session, + validate_fields_fn=make_draft_fields_validator( + draft_fields=DRAFT_FIELDS, + resource_name=self.name, + ), + root_collection=self, + ) + self._existing_resource_drafts = ExistingResourceDraftsSubCollectionClient[T]( + session=session, + validate_fields_fn=make_draft_fields_validator( + draft_fields=DRAFT_FIELDS, + resource_name=self.name, + ), + root_collection=self, + ) + self._snapshots = SnapshotsSubCollectionClient[T]( + session=session, root_collection=self + ) + self._tags = TagsSubCollectionClient[T](session=session, root_collection=self) + + @property + def new_resource_drafts(self) -> NewResourceDraftsSubCollectionClient[T]: + return self._new_resource_drafts + + @property + def existing_resource_drafts(self) -> ExistingResourceDraftsSubCollectionClient[T]: + return self._existing_resource_drafts + + @property + def snapshots(self) -> SnapshotsSubCollectionClient[T]: + """The sub-endpoint client for retrieving queue resource snapshots.""" + return self._snapshots + + @property + def tags(self) -> TagsSubCollectionClient[T]: + """The sub-endpoint client for creating and managing queues tags.""" + return self._tags + + def get( + self, + group_id: int | None = None, + index: int = 0, + page_length: int = 10, + sort_by: str | None = None, + descending: bool | None = None, + search: str | None = None, + ) -> T: + """Get a list of queues. + + Args: + group_id: The group ID the queues belong to. If None, return queues from all + groups that the user has access to. + index: The paging index. + page_length: The maximum number of queues to return in the paged response. + search: Search for queues using the Dioptra API's query language. + + Returns: + The response from the Dioptra API. + """ + params: dict[str, Any] = { + "index": index, + "pageLength": page_length, + } + + if sort_by is not None: + params["sortBy"] = sort_by + + if descending is not None: + params["descending"] = descending + + if search is not None: + params["search"] = search + + if group_id is not None: + params["groupId"] = group_id + + return self._session.get( + self.url, + params=params, + ) + + def get_by_id(self, queue_id: str | int) -> T: + """Get the queue matching the provided id. + + Args: + queue_id: The queue id, an integer. + + Returns: + The response from the Dioptra API. + """ + return self._session.get(self.url, str(queue_id)) + + def create(self, group_id: int, name: str, description: str | None = None) -> T: + """Creates a queue. + + Args: + group_id: The ID of the group that will own the queue. + name: The name of the new queue. + description: The description of the new queue. Optional, defaults to None. + + Returns: + The response from the Dioptra API. + """ + json_ = { + "group": group_id, + "name": name, + } + + if description is not None: + json_["description"] = description + + return self._session.post(self.url, json_=json_) + + def modify_by_id( + self, queue_id: str | int, name: str, description: str | None + ) -> T: + """Modify the queue matching the provided id. + + Args: + queue_id: The queue id, an integer. + name: The new name of the queue. + description: The new description of the queue. To remove the description, + pass None. + + Returns: + The response from the Dioptra API. + """ + json_ = {"name": name} + + if description is not None: + json_["description"] = description + + return self._session.put(self.url, str(queue_id), json_=json_) + + def delete_by_id(self, queue_id: str | int) -> T: + """Delete the queue matching the provided id. + + Args: + queue_id: The queue id, an integer. + + Returns: + The response from the Dioptra API. + """ + return self._session.delete(self.url, str(queue_id)) diff --git a/src/dioptra/client/sessions.py b/src/dioptra/client/sessions.py new file mode 100644 index 000000000..fcd413a58 --- /dev/null +++ b/src/dioptra/client/sessions.py @@ -0,0 +1,587 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from abc import ABC, abstractmethod +from typing import Any, Callable, Final, Generic, TypeVar, cast +from urllib.parse import urlparse, urlunparse + +import requests +import structlog +from structlog.stdlib import BoundLogger + +from .base import ( + APIConnectionError, + DioptraResponseProtocol, + DioptraSession, + JSONDecodeError, + StatusCodeError, +) + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + +DIOPTRA_API_VERSION: Final[str] = "v1" + +T = TypeVar("T") + + +def wrap_request_method( + func: Any, +) -> Callable[..., DioptraResponseProtocol]: + """Wrap a requests method to log the request and response data. + + Args: + func: The requests method to wrap. + + Returns: + A wrapped version of the requests method that logs the request and response + data. + """ + + def wrapper(url: str, *args, **kwargs) -> DioptraResponseProtocol: + """Wrap the requests method to log the request and response data. + + The returned response object will follow the DioptraResponseProtocol interface. + + Args: + url: The URL of the API endpoint. + *args: Additional arguments to pass to the requests method. + **kwargs: Additional keyword arguments to pass to the requests method. + + Returns: + The response from the requests method. + + Raises: + APIConnectionError: If the connection to the REST API fails. + """ + LOGGER.debug( + "Request made.", + url=url, + method=str(func.__name__).upper(), + ) + + try: + response = cast(DioptraResponseProtocol, func(url, *args, **kwargs)) + + except requests.ConnectionError as err: + LOGGER.error( + "Connection to REST API failed", + url=url, + ) + raise APIConnectionError(f"Connection failed: {url}") from err + + LOGGER.debug("Response received.", status_code=response.status_code) + return response + + return wrapper + + +def convert_response_to_dict(response: DioptraResponseProtocol) -> dict[str, Any]: + """Convert a response object to a JSON-like Python dictionary. + + Args: + response: A response object that follows the DioptraResponseProtocol interface. + + Returns: + A Python dictionary containing the response data. + + Raises: + StatusCodeError: If the response status code is not in the 2xx range. + JSONDecodeError: If the response data cannot be parsed as JSON. + """ + if is_not_2xx(response.status_code): + LOGGER.error( + "HTTP error code returned", + status_code=response.status_code, + method=response.request.method, + text=response.text, + url=response.request.url, + ) + raise StatusCodeError(f"Error code returned: {response.status_code}") + + try: + response_dict = response.json() + + except requests.JSONDecodeError as err: + LOGGER.error( + "Failed to parse HTTP response data as JSON", + method=response.request.method, + text=response.text, + url=response.request.url, + ) + raise JSONDecodeError("Failed to parse HTTP response data as JSON") from err + + return response_dict + + +def is_not_2xx(status_code: int) -> bool: + """Check if the status code is not in the 2xx range. + + Args: + status_code: The HTTP status code to check. + + Returns: + True if the status code is not in the 2xx range, False otherwise. + """ + return status_code < 200 or status_code >= 300 + + +class BaseDioptraRequestsSession(DioptraSession[T], ABC, Generic[T]): + """ + The interface for communicating with the Dioptra API using the requests library. + """ + + def __init__(self, address: str) -> None: + """Initialize the Dioptra API session object. + + Args: + address: The base URL of the API endpoints. + """ + self._scheme, self._netloc, self._path, _, _, _ = urlparse(address) + self._session: requests.Session | None = None + + @property + def url(self) -> str: + """The base URL of the API endpoints.""" + return urlunparse((self._scheme, self._netloc, self._path, "", "", "")) + + def connect(self) -> None: + """Connect to the API using a requests Session.""" + if self._session is None: + self._session = requests.Session() + + def close(self) -> None: + """Close the connection to the API by closing the requests Session.""" + if self._session is None: + return None + + self._session.close() + self._session = None + + def make_request( + self, + method_name: str, + url: str, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> DioptraResponseProtocol: + """Make a request to the API. + + Args: + method_name: The HTTP method to use. Must be one of "get", "patch", "post", + "put", or "delete". + url: The URL of the API endpoint. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + The response from the API. + + Raises: + ValueError: If an unsupported method is requested. + """ + session = self._get_requests_session() + methods_registry: dict[str, Callable[..., DioptraResponseProtocol]] = { + "get": wrap_request_method(session.get), + "patch": wrap_request_method(session.patch), + "post": wrap_request_method(session.post), + "put": wrap_request_method(session.put), + "delete": wrap_request_method(session.delete), + } + + if method_name not in methods_registry: + LOGGER.error( + "Unsupported method requested. Must be one of " + f"{sorted(methods_registry.keys())}.", + name=method_name, + ) + raise ValueError( + f"Unsupported method requested. Must be one of " + f"{sorted(methods_registry.keys())}." + ) + + method = methods_registry[method_name] + method_kwargs: dict[str, Any] = {} + + if json_: + method_kwargs["json"] = json_ + + if params: + method_kwargs["params"] = params + + return method(url, **method_kwargs) + + @abstractmethod + def get(self, endpoint: str, *parts, params: dict[str, Any] | None = None) -> T: + """Make a GET request to the API. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + + Returns: + The response from the API. + """ + raise NotImplementedError + + @abstractmethod + def patch( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> T: + """Make a PATCH request to the API. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + The response from the API. + """ + raise NotImplementedError + + @abstractmethod + def post( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> T: + """Make a POST request to the API. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + The response from the API. + """ + raise NotImplementedError + + @abstractmethod + def delete( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> T: + """Make a DELETE request to the API. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + The response from the API. + """ + raise NotImplementedError + + @abstractmethod + def put( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> T: + """Make a PUT request to the API. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + The response from the API. + """ + raise NotImplementedError + + def _get_requests_session(self) -> requests.Session: + """Get the requests Session object. + + This method will start a new session if one does not already exist. + + Returns: + The requests Session object. + + Raises: + APIConnectionError: If the session connection fails. + """ + self.connect() + + if self._session is None: + LOGGER.error( + "Failed to start session connection.", + address=self.url, + ) + raise APIConnectionError("Failed to start session connection.") + + return self._session + + +class DioptraRequestsSession(BaseDioptraRequestsSession[DioptraResponseProtocol]): + def get( + self, endpoint: str, *parts, params: dict[str, Any] | None = None + ) -> DioptraResponseProtocol: + """Make a GET request to the API. + + The response will be a requests Response object, which follows the + DioptraResponseProtocol interface. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + + Returns: + A requests Response object. + """ + return self._get(endpoint, *parts, params=params) + + def patch( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> DioptraResponseProtocol: + """Make a PATCH request to the API. + + The response will be a requests Response object, which follows the + DioptraResponseProtocol interface. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + A requests Response object. + """ + return self._patch(endpoint, *parts, params=params, json_=json_) + + def post( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> DioptraResponseProtocol: + """Make a POST request to the API. + + The response will be a requests Response object, which follows the + DioptraResponseProtocol interface. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + A requests Response object. + """ + return self._post(endpoint, *parts, params=params, json_=json_) + + def delete( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> DioptraResponseProtocol: + """Make a DELETE request to the API. + + The response will be a requests Response object, which follows the + DioptraResponseProtocol interface. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + A requests Response object. + """ + return self._delete(endpoint, *parts, params=params, json_=json_) + + def put( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> DioptraResponseProtocol: + """Make a PUT request to the API. + + The response will be a requests Response object, which follows the + DioptraResponseProtocol interface. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + A requests Response object. + """ + return self._put(endpoint, *parts, params=params, json_=json_) + + +class DioptraRequestsSessionJson(BaseDioptraRequestsSession[dict[str, Any]]): + def get( + self, endpoint: str, *parts, params: dict[str, Any] | None = None + ) -> dict[str, Any]: + """Make a GET request to the API. + + The response will be a JSON-like Python dictionary. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + + Returns: + A Python dictionary containing the response data. + """ + return convert_response_to_dict(self._get(endpoint, *parts, params=params)) + + def patch( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Make a PATCH request to the API. + + The response will be a JSON-like Python dictionary. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + A Python dictionary containing the response data. + """ + return convert_response_to_dict( + self._patch(endpoint, *parts, params=params, json_=json_) + ) + + def post( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Make a POST request to the API. + + The response will be a JSON-like Python dictionary. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + A Python dictionary containing the response data. + """ + return convert_response_to_dict( + self._post(endpoint, *parts, params=params, json_=json_) + ) + + def delete( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Make a DELETE request to the API. + + The response will be a JSON-like Python dictionary. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + A Python dictionary containing the response data. + """ + return convert_response_to_dict( + self._delete(endpoint, *parts, params=params, json_=json_) + ) + + def put( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Make a PUT request to the API. + + The response will be a JSON-like Python dictionary. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + A Python dictionary containing the response data. + """ + return convert_response_to_dict( + self._put(endpoint, *parts, params=params, json_=json_) + ) diff --git a/src/dioptra/client/snapshots.py b/src/dioptra/client/snapshots.py new file mode 100644 index 000000000..d829e9949 --- /dev/null +++ b/src/dioptra/client/snapshots.py @@ -0,0 +1,85 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from typing import Any, ClassVar, TypeVar + +import structlog +from structlog.stdlib import BoundLogger + +from .base import SubCollectionClient + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + +T = TypeVar("T") + + +class SnapshotsSubCollectionClient(SubCollectionClient[T]): + """The sub-endpoint client for retrieving snapshots under an endpoint. + + Attributes: + name: The name of the sub-endpoint. + """ + + name: ClassVar[str] = "snapshots" + + def get( + self, + *resource_ids: str | int, + index: int = 0, + page_length: int = 10, + search: str | None = None, + ) -> T: + """Get the list of snapshots for a given resource. + + Args: + resource_id: The ID of the resource. + index: The paging index. + page_length: The maximum number of snapshots to return in the paged + response. + search: Search for snapshots using the Dioptra API's query language. + + Returns: + The response from the Dioptra API. + """ + params: dict[str, Any] = { + "index": index, + "pageLength": page_length, + } + + if search is not None: + params["search"] = search + + return self._session.get( + self.build_sub_collection_url(*resource_ids), params=params + ) + + def get_by_id( + self, + *resource_ids: str | int, + snapshot_id: int, + ) -> T: + """Get a snapshot by its ID for a specific resource. + + Args: + resource_id: The ID of the resource. + snapshot_id: The ID of the snapshot. + + Returns: + The response from the Dioptra API. + """ + return self._session.get( + self.build_sub_collection_url(*resource_ids), str(snapshot_id) + ) diff --git a/src/dioptra/client/tags.py b/src/dioptra/client/tags.py new file mode 100644 index 000000000..85b607e8d --- /dev/null +++ b/src/dioptra/client/tags.py @@ -0,0 +1,293 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from typing import Any, ClassVar, TypeVar + +import structlog +from structlog.stdlib import BoundLogger + +from .base import CollectionClient, FieldsValidationError, SubCollectionClient + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + +T = TypeVar("T") + + +class TagsCollectionClient(CollectionClient[T]): + """The client for interacting with the Dioptra API's /tags endpoint. + + Attributes: + name: The name of the endpoint. + """ + + name: ClassVar[str] = "tags" + + def get( + self, + group_id: int | None = None, + index: int = 0, + page_length: int = 10, + sort_by: str | None = None, + descending: bool | None = None, + search: str | None = None, + ) -> T: + """Get a list of tags. + + Args: + group_id: The group ID the tags belong to. If None, return tags from all + groups that the user has access to. + index: The paging index. + page_length: The maximum number of tags to return in the paged response. + search: Search for tags using the Dioptra API's query language. + + Returns: + The response from the Dioptra API. + """ + params: dict[str, Any] = { + "index": index, + "pageLength": page_length, + } + + if sort_by is not None: + params["sortBy"] = sort_by + + if descending is not None: + params["descending"] = descending + + if search is not None: + params["search"] = search + + if group_id is not None: + params["groupId"] = group_id + + return self._session.get( + self.url, + params=params, + ) + + def get_by_id(self, tag_id: str | int) -> T: + """Get the tag matching the provided id. + + Args: + tag_id: The tag id, an integer. + + Returns: + The response from the Dioptra API. + """ + return self._session.get(self.url, str(tag_id)) + + def create(self, group_id: int, name: str) -> T: + """Creates a tag. + + Args: + group_id: The ID of the group that will own the tag. + name: The name of the new tag. + + Returns: + The response from the Dioptra API. + """ + json_ = { + "group": group_id, + "name": name, + } + + return self._session.post(self.url, json_=json_) + + def modify_by_id(self, tag_id: str | int, name: str) -> T: + """Modify the tag matching the provided id. + + Args: + tag_id: The tag id, an integer. + name: The new name of the tag. + + Returns: + The response from the Dioptra API. + """ + json_ = {"name": name} + + return self._session.put(self.url, str(tag_id), json_=json_) + + def delete_by_id(self, tag_id: str | int) -> T: + """Delete the tag matching the provided id. + + Args: + tag_id: The tag id, an integer. + + Returns: + The response from the Dioptra API. + """ + return self._session.delete(self.url, str(tag_id)) + + def get_resources_for_tag( + self, + tag_id: str | int, + resource_type: str | None = None, + index: int = 0, + page_length: int = 10, + ) -> T: + """Get a list of resources labeled with a tag. + + Args: + tag_id: The tag id, an integer. + resource_type: The type of resource to filter by. If None, return all + resources associated with the tag. Optional, defaults to None. + index: The paging index. + page_length: The maximum number of tags to return in the paged response. + + Returns: + The response from the Dioptra API. + """ + params: dict[str, Any] = { + "index": index, + "pageLength": page_length, + } + + if resource_type is not None: + params["resourceType"] = resource_type + + return self._session.get( + self.url, + str(tag_id), + "resources", + params=params, + ) + + +class TagsSubCollectionClient(SubCollectionClient[T]): + """The sub-endpoint client for managing tags under an endpoint. + + Attributes: + name: The name of the sub-endpoint. + """ + + name: ClassVar[str] = "tags" + + def get(self, *resource_ids: str | int) -> T: + """Get a list of tags. + + Args: + resource_id: The ID of an endpoint resource. + + Returns: + The response from the Dioptra API. + """ + return self._session.get(self.build_sub_collection_url(*resource_ids)) + + def modify( + self, + *resource_ids: str | int, + ids: list[int], + ) -> T: + """Change the list of tags associated with an endpoint resource. + + This method overwrites the existing list of tags associated with an endpoint + resource. To non-destructively append multiple tags, use the `append` method. To + delete an individual tag, use the `remove` method. + + Args: + resource_id: The ID of an endpoint resource. + ids: The list of tag IDs to set on the resource. + + Returns: + The response from the Dioptra API. + """ + return self._session.put( + self.build_sub_collection_url(*resource_ids), + json_={"ids": _validate_ids_argument(ids)}, + ) + + def append( + self, + *resource_ids: str | int, + ids: list[int], + ) -> T: + """Append one or more tags to an endpoint resource. + + Tag IDs that have already been appended to the endpoint resource will be + ignored. + + Args: + resource_id: The ID of an endpoint resource. + ids: The list of tag IDs to append to the endpoint resource. + + Returns: + The response from the Dioptra API. + """ + return self._session.post( + self.build_sub_collection_url(*resource_ids), + json_={"ids": _validate_ids_argument(ids)}, + ) + + def remove( + self, + *resource_ids: int, + tag_id: int, + ) -> T: + """Remove a tag from an endpoint resource. + + Args: + resource_id: The ID of an endpoint resource. + tag_id: The ID of the tag to remove from the endpoint resource. + + Returns: + The response from the Dioptra API. + """ + return self._session.delete( + self.build_sub_collection_url(*resource_ids), str(tag_id) + ) + + def remove_all( + self, + *resource_ids: int, + ) -> T: + """Remove all tags from an endpoint resource. + + This method will remove all tags from the endpoint resource and cannot be + reversed. To remove individual tags, use the `remove` method. + + Args: + resource_id: The ID of an endpoint resource. + + Returns: + The response from the Dioptra API. + """ + return self._session.delete(self.build_sub_collection_url(*resource_ids)) + + +def _validate_ids_argument(ids: list[int]) -> list[int]: + """Validate the ids argument for tag operations. + + Args: + ids: The list of tag IDs to validate. + + Returns: + The validated list of tag IDs. + + Raises: + FieldsValidationError: If the `ids` argument is not a list or is an empty list. + """ + + if not isinstance(ids, list): + LOGGER.error('"ids" argument is invalid', reason=f"Not a list: {type(ids)}") + raise FieldsValidationError( + '"ids" argument is invalid (reason: Not a list): {type(ids)}' + ) + + if len(ids) == 0: + LOGGER.error('"ids" argument is invalid', reason="Empty list") + raise FieldsValidationError('"ids" argument is invalid (reason: Empty list)') + + return ids diff --git a/src/dioptra/client/users.py b/src/dioptra/client/users.py new file mode 100644 index 000000000..ad5b7a1d4 --- /dev/null +++ b/src/dioptra/client/users.py @@ -0,0 +1,183 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from typing import Any, ClassVar, TypeVar + +import structlog +from structlog.stdlib import BoundLogger + +from .base import CollectionClient + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + +T = TypeVar("T") + + +class UsersCollectionClient(CollectionClient[T]): + """The client for interacting with the Dioptra API's users endpoint. + + Attributes: + name: The name of the endpoint. + """ + + name: ClassVar[str] = "users" + + def get( + self, index: int = 0, page_length: int = 10, search: str | None = None + ) -> T: + """Get a list of Dioptra users. + + Args: + index: The paging index. + page_length: The maximum number of users to return in the paged response. + search: Search for users using the Dioptra API's query language. + + Returns: + The response from the Dioptra API. + """ + params: dict[str, Any] = { + "index": index, + "pageLength": page_length, + } + + if search is not None: + params["search"] = search + + return self._session.get( + self.url, + params=params, + ) + + def create(self, username: str, email: str, password: str) -> T: + """Creates a Dioptra user. + + Args: + username: The username of the new user. + email: The email address of the new user. + password: The password to set for the new user. + + Returns: + The response from the Dioptra API. + """ + + return self._session.post( + self.url, + json_={ + "username": username, + "email": email, + "password": password, + "confirmPassword": password, + }, + ) + + def get_by_id(self, user_id: str | int) -> T: + """Get the user matching the provided id. + + Args: + user_id: The user id, an integer. + + Returns: + The response from the Dioptra API. + """ + return self._session.get(self.url, str(user_id)) + + def change_password_by_id( + self, user_id: str | int, old_password: str, new_password: str + ) -> T: + """Change the password of the user matching the provided id. + + This primary use case for using this over `change_current_user_password` is if + your password has expired and you need to update it before you can log in. + + Args: + user_id: The user id, an integer. + old_password: The user's current password. The password change will fail if + this is incorrect. + new_password: The new password to set for the user. + + Returns: + The response from the Dioptra API. + """ + return self._session.post( + self.url, + str(user_id), + "password", + json_={ + "oldPassword": old_password, + "newPassword": new_password, + "confirmNewPassword": new_password, + }, + ) + + def get_current(self) -> T: + """Get details about the currently logged-in user. + + Returns: + The response from the Dioptra API. + """ + return self._session.get(self.url, "current") + + def delete_current_user(self, password: str) -> T: + """Delete the currently logged-in user. + + Args: + password: The password of the currently logged-in user. The deletion will + fail if this is incorrect. + + Returns: + The response from the Dioptra API. + """ + return self._session.delete(self.url, "current", json_={"password": password}) + + def modify_current_user(self, username: str, email: str) -> T: + """Modify details about the currently logged-in user. + + Args: + username: The new username for the currently logged-in user. If None, the + username will not be changed. + email: The new email address for the currently logged-in user. If None, the + email address will not be changed. + + Returns: + The response from the Dioptra API. + """ + return self._session.put( + self.url, + "current", + json_={"username": username, "email": email}, + ) + + def change_current_user_password(self, old_password: str, new_password: str) -> T: + """Change the currently logged-in user's password. + + Args: + old_password: The currently logged-in user's current password. The password + change will fail if this is incorrect. + new_password: The new password to set for the currently logged-in user. + + Returns: + The response from the Dioptra API. + """ + return self._session.post( + self.url, + "current", + "password", + json_={ + "oldPassword": old_password, + "newPassword": new_password, + "confirmNewPassword": new_password, + }, + ) diff --git a/tests/unit/restapi/conftest.py b/tests/unit/restapi/conftest.py index 265832fca..61d6e298f 100644 --- a/tests/unit/restapi/conftest.py +++ b/tests/unit/restapi/conftest.py @@ -35,10 +35,13 @@ from requests import ConnectionError from requests import Session as RequestsSession +from dioptra.client.base import DioptraResponseProtocol +from dioptra.client.client import DioptraClient from dioptra.restapi.db import db as restapi_db from dioptra.restapi.v1.shared.request_scope import request from .lib import db as libdb +from .lib.client import DioptraFlaskClientSession from .lib.server import FlaskTestServer @@ -140,6 +143,11 @@ def client(app: Flask) -> FlaskClient: return app.test_client() +@pytest.fixture +def dioptra_client(client: FlaskClient) -> DioptraClient[DioptraResponseProtocol]: + return DioptraClient[DioptraResponseProtocol](DioptraFlaskClientSession(client)) + + @pytest.fixture def flask_test_server(tmp_path: Path, http_client: RequestsSession): """Start a Flask test server. diff --git a/tests/unit/restapi/lib/asserts_client.py b/tests/unit/restapi/lib/asserts_client.py new file mode 100644 index 000000000..fee82716b --- /dev/null +++ b/tests/unit/restapi/lib/asserts_client.py @@ -0,0 +1,196 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Shared assertions for REST API unit tests.""" +from typing import Any + +from dioptra.client.drafts import ( + ExistingResourceDraftsSubCollectionClient, + NewResourceDraftsSubCollectionClient, +) +from dioptra.client.snapshots import SnapshotsSubCollectionClient + + +def assert_retrieving_draft_by_resource_id_works( + drafts_client: ExistingResourceDraftsSubCollectionClient, + *resource_ids: str | int, + expected: dict[str, Any], +) -> None: + """Assert that retrieving an existing resource draft by resource id works. + + Args: + drafts_client: The DraftsSubEndpointClient client. + resource_id: The id of the resource to retrieve the draft for. + expected: The expected response from the API. + + Raises: + AssertionError: If the response status code is not 200 or if the API response + does not match the expected response. + """ + response = drafts_client.get_by_id(*resource_ids) + assert response.status_code == 200 and response.json() == expected + + +def assert_retrieving_draft_by_id_works( + drafts_client: NewResourceDraftsSubCollectionClient, + *resource_ids: str | int, + draft_id: int, + expected: dict[str, Any], +) -> None: + """Assert that retrieving a draft by id works. + + Args: + drafts_client: The DraftsSubEndpointClient client. + draft_id: The id of the draft to retrieve. + expected: The expected response from the API. + + Raises: + AssertionError: If the response status code is not 200 or if the API response + does not match the expected response. + """ + response = drafts_client.get_by_id(*resource_ids, draft_id=draft_id) + assert response.status_code == 200 and response.json() == expected + + +def assert_retrieving_drafts_works( + drafts_client: NewResourceDraftsSubCollectionClient, + *resource_ids: str | int, + expected: list[dict[str, Any]], + group_id: int | None = None, + paging_info: dict[str, Any] | None = None, +) -> None: + """Assert that retrieving all drafts for a resource type works. + + Args: + drafts_client: The DraftsSubEndpointClient client. + expected: The expected response from the API. + group_id: The group ID used in query parameters. + paging_info: The paging information used in query parameters. + + Raises: + AssertionError: If the response status code is not 200 or if the API response + does not match the expected response. + """ + + query_string: dict[str, Any] = {} + + if group_id is not None: + query_string["group_id"] = group_id + + if paging_info is not None: + query_string["index"] = paging_info["index"] + query_string["page_length"] = paging_info["page_length"] + + response = drafts_client.get(*resource_ids, **query_string) + assert response.status_code == 200 and response.json()["data"] == expected + + +def assert_creating_another_existing_draft_fails( + drafts_client: ExistingResourceDraftsSubCollectionClient, + *resource_ids: str | int, + payload: dict[str, Any], +) -> None: + """Assert that registering another draft for the same resource fails + + Args: + drafts_client: The DraftsSubEndpointClient client. + resource_id: The id of the resource to retrieve the draft for. + payload: A dictionary containing the draft fields. + + Raises: + AssertionError: If the response status code is not 400. + """ + response = drafts_client.create( + *resource_ids, **payload + ) + assert response.status_code == 400 + + +def assert_existing_draft_is_not_found( + drafts_client: ExistingResourceDraftsSubCollectionClient, + *resource_ids: str | int, +) -> None: + """Assert that a draft of an existing resource is not found. + + Args: + drafts_client: The DraftsSubEndpointClient client. + resource_id: The id of the resource to retrieve the draft for. + + Raises: + AssertionError: If the response status code is not 404. + """ + response = drafts_client.get_by_id(*resource_ids) + assert response.status_code == 404 + + +def assert_new_draft_is_not_found( + drafts_client: NewResourceDraftsSubCollectionClient, + *resource_ids: str | int, + draft_id: int, +) -> None: + """Assert that a draft of an existing resource is not found. + + Args: + drafts_client: The DraftsSubEndpointClient client. + resource_id: The id of the resource to retrieve the draft for. + + Raises: + AssertionError: If the response status code is not 404. + """ + response = drafts_client.get_by_id(*resource_ids, draft_id=draft_id) + assert response.status_code == 404 + + +def assert_retrieving_snapshots_works( + snapshots_client: SnapshotsSubCollectionClient, + *resource_ids: int, + expected: dict[str, Any], +) -> None: + """Assert that retrieving a snapshot by id works. + + Args: + snapshots_client: The SnapshotsSubCollectionClient client. + resource_id: The id of the resource to retrieve snapshots for. + expected: The expected response from the API. + + Raises: + AssertionError: If the response status code is not 200 or if the API response + does not match the expected response. + """ + response = snapshots_client.get(*resource_ids) + assert response.status_code == 200 and response.json()["data"] == expected + + +def assert_retrieving_snapshot_by_id_works( + snapshots_client: SnapshotsSubCollectionClient, + *resource_ids: int, + snapshot_id: int, + expected: dict[str, Any], +) -> None: + """Assert that retrieving a resource snapshot by id works. + + Args: + snapshots_client: The SnapshotsSubCollectionClient client. + resource_id: The id of the resource to retrieve a snapshot of. + snapshot_id: The id to the snapshot to retrieve. + expected: The expected response from the API. + + Raises: + AssertionError: If the response status code is not 200 or if the API response + does not match the expected response. + """ + response = snapshots_client.get_by_id(*resource_ids, snapshot_id=snapshot_id) + assert response.status_code == 200 and response.json() == expected diff --git a/tests/unit/restapi/lib/client.py b/tests/unit/restapi/lib/client.py new file mode 100644 index 000000000..79c64f7eb --- /dev/null +++ b/tests/unit/restapi/lib/client.py @@ -0,0 +1,320 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from typing import Any, Callable, Protocol, cast + +import structlog +from flask.testing import FlaskClient +from structlog.stdlib import BoundLogger +from werkzeug.test import TestResponse + +from dioptra.client.base import ( + DioptraRequestProtocol, + DioptraResponseProtocol, + DioptraSession, +) +from dioptra.restapi.routes import V1_ROOT + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + + +class DioptraTestResponse(object): + """ + A wrapper for Flask TestResponse objects that follows the DioptraResponseProtocol + interface. + """ + + def __init__(self, test_response: TestResponse) -> None: + """Initialize the DioptraTestResponse instance. + + Args: + test_response: The Flask TestResponse object. + """ + self._test_response = test_response + + @property + def request(self) -> DioptraRequestProtocol: + """The request that generated the response.""" + return cast(DioptraRequestProtocol, self._test_response.request) + + @property + def status_code(self) -> int: + """The HTTP status code of the response.""" + return self._test_response.status_code + + @property + def text(self) -> str: + """The response body as a string.""" + return self._test_response.text + + def json(self) -> dict[str, Any]: + """Return the response body as a JSON-like Python dictionary. + + Returns: + The response body as a dictionary. + """ + return cast(dict[str, Any], self._test_response.get_json(silent=False)) + + +class RequestMethodProtocol(Protocol): + """The interface for a FlaskClient request method.""" + + def __call__(self, *args: Any, **kw: Any) -> TestResponse: + """The method signature for a FlaskClient request method. + + Args: + *args: Positional arguments to pass to the request method. + **kw: Keyword arguments to pass to the request method. + + Returns: + A Flask TestResponse object. + """ + ... # fmt: skip + + +def wrap_request_method( + func: RequestMethodProtocol, +) -> Callable[..., DioptraResponseProtocol]: + """ + Wrap a FlaskClient request method to log the requests and responses and wrap the + response in a DioptraTestResponse object. + + Args: + func: The FlaskClient request method to wrap. + + Returns: + The wrapped request method. + """ + + def wrapper(url: str, *args, **kwargs) -> DioptraResponseProtocol: + """Wrap the FlaskClient request method. + + The returned response object will follow the DioptraResponseProtocol interface. + + Args: + url: The URL of the API endpoint. + *args: Additional arguments to pass to the requests method. + **kwargs: Additional keyword arguments to pass to the requests method. + + Returns: + The response from the requests method. + """ + LOGGER.debug( + "Request made.", + url=url, + method=str(func.__name__).upper(), # type: ignore + method_kwargs=kwargs, + ) + response = DioptraTestResponse(func(url, *args, **kwargs)) + LOGGER.debug("Response received.", status_code=response.status_code) + return response + + return wrapper + + +class DioptraFlaskClientSession(DioptraSession[DioptraResponseProtocol]): + """ + The interface for communicating with the Dioptra API using the FlaskClient. + """ + + def __init__(self, flask_client: FlaskClient) -> None: + """Initialize the DioptraFlaskClientSession instance. + + Args: + flask_client: The FlaskClient object to use for making requests. + """ + self._session: FlaskClient = flask_client + + @property + def url(self) -> str: + """The base URL of the API endpoints.""" + return self.build_url("/", V1_ROOT) + + def connect(self) -> None: + """Connect to the API. A no-op for the FlaskClient.""" + pass + + def close(self) -> None: + """Close the connection to the API. A no-op for the FlaskClient.""" + pass + + def make_request( + self, + method_name: str, + url: str, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> DioptraResponseProtocol: + """Make a request to the API. + + Args: + method_name: The HTTP method to use. Must be one of "get", "patch", "post", + "put", or "delete". + url: The URL of the API endpoint. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + The response from the API. + + Raises: + ValueError: If an unsupported method is requested. + """ + methods_registry: dict[str, Callable[..., DioptraResponseProtocol]] = { + "get": wrap_request_method(self._session.get), + "patch": wrap_request_method(self._session.patch), + "post": wrap_request_method(self._session.post), + "put": wrap_request_method(self._session.put), + "delete": wrap_request_method(self._session.delete), + } + + if method_name not in methods_registry: + LOGGER.error( + "Unsupported method requested. Must be one of " + f"{sorted(methods_registry.keys())}.", + name=method_name, + ) + raise ValueError( + f"Unsupported method requested. Must be one of " + f"{sorted(methods_registry.keys())}." + ) + + method = methods_registry[method_name] + method_kwargs: dict[str, Any] = {"follow_redirects": True} + + if json_: + method_kwargs["json"] = json_ + + if params: + method_kwargs["query_string"] = params + + return method(url, **method_kwargs) + + def get( + self, endpoint: str, *parts, params: dict[str, Any] | None = None + ) -> DioptraResponseProtocol: + """Make a GET request to the API. + + The response will be a DioptraTestResponse object, which follows the + DioptraResponseProtocol interface. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + + Returns: + A DioptraTestResponse object. + """ + return self._get(endpoint, *parts, params=params) + + def patch( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> DioptraResponseProtocol: + """Make a PATCH request to the API. + + The response will be a DioptraTestResponse object, which follows the + DioptraResponseProtocol interface. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + A DioptraTestResponse object. + """ + return self._patch(endpoint, *parts, params=params, json_=json_) + + def post( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> DioptraResponseProtocol: + """Make a POST request to the API. + + The response will be a DioptraTestResponse object, which follows the + DioptraResponseProtocol interface. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + A DioptraTestResponse object. + """ + return self._post(endpoint, *parts, params=params, json_=json_) + + def delete( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> DioptraResponseProtocol: + """Make a DELETE request to the API. + + The response will be a DioptraTestResponse object, which follows the + DioptraResponseProtocol interface. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + A DioptraTestResponse object. + """ + return self._delete(endpoint, *parts, params=params, json_=json_) + + def put( + self, + endpoint: str, + *parts, + params: dict[str, Any] | None = None, + json_: dict[str, Any] | None = None, + ) -> DioptraResponseProtocol: + """Make a PUT request to the API. + + The response will be a DioptraTestResponse object, which follows the + DioptraResponseProtocol interface. + + Args: + endpoint: The base URL of the API endpoint. + *parts: Additional parts to append to the base URL. + params: The query parameters to include in the request. Optional, defaults + to None. + json_: The JSON data to include in the request. Optional, defaults to None. + + Returns: + A DioptraTestResponse object. + """ + return self._put(endpoint, *parts, params=params, json_=json_) diff --git a/tests/unit/restapi/v1/conftest.py b/tests/unit/restapi/v1/conftest.py index dade9d789..a0c4c821c 100644 --- a/tests/unit/restapi/v1/conftest.py +++ b/tests/unit/restapi/v1/conftest.py @@ -17,7 +17,6 @@ """Fixtures representing resources needed for test suites""" import textwrap from collections.abc import Iterator -from pathlib import Path from typing import Any, cast import pytest diff --git a/tests/unit/restapi/v1/test_experiment.py b/tests/unit/restapi/v1/test_experiment.py index 9a370a252..2911d4dd2 100644 --- a/tests/unit/restapi/v1/test_experiment.py +++ b/tests/unit/restapi/v1/test_experiment.py @@ -14,7 +14,6 @@ # # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode - """Test suite for experiment operations. This module contains a set of tests that validate the supported CRUD operations and diff --git a/tests/unit/restapi/v1/test_group.py b/tests/unit/restapi/v1/test_group.py index cc896dcfd..3113328a1 100644 --- a/tests/unit/restapi/v1/test_group.py +++ b/tests/unit/restapi/v1/test_group.py @@ -15,6 +15,7 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode """Test suite for group operations. + This module contains a set of tests that validate the CRUD operations and additional functionalities for the group entity. The tests ensure that the groups can be registered, queried, and renamed as expected through the REST API. diff --git a/tests/unit/restapi/v1/test_plugin_parameter_type.py b/tests/unit/restapi/v1/test_plugin_parameter_type.py index 8553513ca..910a32d43 100644 --- a/tests/unit/restapi/v1/test_plugin_parameter_type.py +++ b/tests/unit/restapi/v1/test_plugin_parameter_type.py @@ -130,9 +130,12 @@ def assert_retrieving_plugin_parameter_types_works( def assert_sorting_plugin_parameter_type_works( client: FlaskClient, - sortBy: str, - descending: bool, expected: list[str], + sort_by: str | None, + descending: bool | None, + group_id: int | None = None, + search: str | None = None, + paging_info: dict[str, Any] | None = None, ) -> None: """Assert that plugin parameter types can be sorted by column ascending/descending. @@ -148,8 +151,21 @@ def assert_sorting_plugin_parameter_type_works( query_string: dict[str, Any] = {} - query_string["sortBy"] = sortBy - query_string["descending"] = descending + if descending is not None: + query_string["descending"] = descending + + if sort_by is not None: + query_string["sortBy"] = sort_by + + if group_id is not None: + query_string["groupId"] = group_id + + if search is not None: + query_string["search"] = search + + if paging_info is not None: + query_string["index"] = paging_info["index"] + query_string["pageLength"] = paging_info["page_length"] response = client.get( f"/{V1_ROOT}/{V1_PLUGIN_PARAMETER_TYPES_ROUTE}", @@ -467,7 +483,7 @@ def test_get_all_plugin_parameter_types( @pytest.mark.parametrize( - "sortBy, descending , expected", + "sort_by, descending , expected", [ ( None, @@ -501,7 +517,7 @@ def test_plugin_parameter_type_sort( db: SQLAlchemy, auth_account: dict[str, Any], registered_plugin_parameter_types: dict[str, Any], - sortBy: str, + sort_by: str | None, descending: bool, expected: list[str], ) -> None: @@ -522,7 +538,7 @@ def test_plugin_parameter_type_sort( for expected_name in expected ] assert_sorting_plugin_parameter_type_works( - client, sortBy, descending, expected=expected_ids + client, sort_by=sort_by, descending=descending, expected=expected_ids ) diff --git a/tests/unit/restapi/v1/test_queue.py b/tests/unit/restapi/v1/test_queue.py index 79d6ab688..99ea824b5 100644 --- a/tests/unit/restapi/v1/test_queue.py +++ b/tests/unit/restapi/v1/test_queue.py @@ -25,60 +25,11 @@ import pytest from flask.testing import FlaskClient from flask_sqlalchemy import SQLAlchemy -from werkzeug.test import TestResponse -from dioptra.restapi.routes import V1_ENTRYPOINTS_ROUTE, V1_QUEUES_ROUTE, V1_ROOT - -from ..lib import actions, asserts, helpers - -# -- Actions --------------------------------------------------------------------------- - - -def modify_queue( - client: FlaskClient, - queue_id: int, - new_name: str, - new_description: str, -) -> TestResponse: - """Rename a queue using the API. - - Args: - client: The Flask test client. - queue_id: The id of the queue to rename. - new_name: The new name to assign to the queue. - new_description: The new description to assign to the queue. - - Returns: - The response from the API. - """ - payload = {"name": new_name, "description": new_description} - - return client.put( - f"/{V1_ROOT}/{V1_QUEUES_ROUTE}/{queue_id}", - json=payload, - follow_redirects=True, - ) - - -def delete_queue_with_id( - client: FlaskClient, - queue_id: int, -) -> TestResponse: - """Delete a queue using the API. - - Args: - client: The Flask test client. - queue_id: The id of the queue to delete. - - Returns: - The response from the API. - """ - - return client.delete( - f"/{V1_ROOT}/{V1_QUEUES_ROUTE}/{queue_id}", - follow_redirects=True, - ) +from dioptra.client.client import DioptraClient +from dioptra.restapi.routes import V1_ENTRYPOINTS_ROUTE, V1_ROOT +from ..lib import asserts, asserts_client, helpers # -- Assertions ------------------------------------------------------------------------ @@ -148,14 +99,14 @@ def assert_queue_response_contents_matches_expectations( def assert_retrieving_queue_by_id_works( - client: FlaskClient, + dioptra_client: DioptraClient, queue_id: int, expected: dict[str, Any], ) -> None: """Assert that retrieving a queue by id works. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. queue_id: The id of the queue to retrieve. expected: The expected response from the API. @@ -163,14 +114,12 @@ def assert_retrieving_queue_by_id_works( AssertionError: If the response status code is not 200 or if the API response does not match the expected response. """ - response = client.get( - f"/{V1_ROOT}/{V1_QUEUES_ROUTE}/{queue_id}", follow_redirects=True - ) - assert response.status_code == 200 and response.get_json() == expected + response = dioptra_client.queues.get_by_id(queue_id) + assert response.status_code == 200 and response.json() == expected def assert_retrieving_queues_works( - client: FlaskClient, + dioptra_client: DioptraClient, expected: list[dict[str, Any]], group_id: int | None = None, search: str | None = None, @@ -179,7 +128,7 @@ def assert_retrieving_queues_works( """Assert that retrieving all queues works. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. expected: The expected response from the API. group_id: The group ID used in query parameters. search: The search string used in query parameters. @@ -193,28 +142,27 @@ def assert_retrieving_queues_works( query_string: dict[str, Any] = {} if group_id is not None: - query_string["groupId"] = group_id + query_string["group_id"] = group_id if search is not None: query_string["search"] = search if paging_info is not None: query_string["index"] = paging_info["index"] - query_string["pageLength"] = paging_info["page_length"] + query_string["page_length"] = paging_info["page_length"] - response = client.get( - f"/{V1_ROOT}/{V1_QUEUES_ROUTE}", - query_string=query_string, - follow_redirects=True, - ) - assert response.status_code == 200 and response.get_json()["data"] == expected + response = dioptra_client.queues.get(**query_string) + assert response.status_code == 200 and response.json()["data"] == expected def assert_sorting_queue_works( - client: FlaskClient, - sortBy: str, - descending: bool, + dioptra_client: DioptraClient, expected: list[str], + sort_by: str | None, + descending: bool | None, + group_id: int | None = None, + search: str | None = None, + paging_info: dict[str, Any] | None = None, ) -> None: """Assert that queues can be sorted by column ascending/descending. @@ -230,46 +178,53 @@ def assert_sorting_queue_works( query_string: dict[str, Any] = {} - query_string["sortBy"] = sortBy - query_string["descending"] = descending + if descending is not None: + query_string["descending"] = descending - response = client.get( - f"/{V1_ROOT}/{V1_QUEUES_ROUTE}", - query_string=query_string, - follow_redirects=True, - ) + if sort_by is not None: + query_string["sort_by"] = sort_by - response_data = response.get_json() - queue_ids = [queue["id"] for queue in response_data["data"]] + if group_id is not None: + query_string["group_id"] = group_id + + if search is not None: + query_string["search"] = search + if paging_info is not None: + query_string["index"] = paging_info["index"] + query_string["page_length"] = paging_info["page_length"] + + response = dioptra_client.queues.get(**query_string) + response_data = response.json() + queue_ids = [queue["id"] for queue in response_data["data"]] assert response.status_code == 200 and queue_ids == expected def assert_registering_existing_queue_name_fails( - client: FlaskClient, name: str, group_id: int + dioptra_client: DioptraClient, name: str, group_id: int ) -> None: """Assert that registering a queue with an existing name fails. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. name: The name to assign to the new queue. Raises: AssertionError: If the response status code is not 400. """ - response = actions.register_queue( - client, name=name, description="", group_id=group_id + response = dioptra_client.queues.create( + group_id=group_id, name=name, description="" ) assert response.status_code == 409 def assert_queue_name_matches_expected_name( - client: FlaskClient, queue_id: int, expected_name: str + dioptra_client: DioptraClient, queue_id: int, expected_name: str ) -> None: """Assert that the name of a queue matches the expected name. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. queue_id: The id of the queue to retrieve. expected_name: The expected name of the queue. @@ -277,30 +232,24 @@ def assert_queue_name_matches_expected_name( AssertionError: If the response status code is not 200 or if the name of the queue does not match the expected name. """ - response = client.get( - f"/{V1_ROOT}/{V1_QUEUES_ROUTE}/{queue_id}", - follow_redirects=True, - ) - assert response.status_code == 200 and response.get_json()["name"] == expected_name + response = dioptra_client.queues.get_by_id(queue_id) + assert response.status_code == 200 and response.json()["name"] == expected_name def assert_queue_is_not_found( - client: FlaskClient, + dioptra_client: DioptraClient, queue_id: int, ) -> None: """Assert that a queue is not found. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. queue_id: The id of the queue to retrieve. Raises: AssertionError: If the response status code is not 404. """ - response = client.get( - f"/{V1_ROOT}/{V1_QUEUES_ROUTE}/{queue_id}", - follow_redirects=True, - ) + response = dioptra_client.queues.get_by_id(queue_id) assert response.status_code == 404 @@ -331,7 +280,7 @@ def assert_queue_is_not_associated_with_entrypoint( def assert_cannot_rename_queue_with_existing_name( - client: FlaskClient, + dioptra_client: DioptraClient, queue_id: int, existing_name: str, existing_description: str, @@ -339,18 +288,17 @@ def assert_cannot_rename_queue_with_existing_name( """Assert that renaming a queue with an existing name fails. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. queue_id: The id of the queue to rename. name: The name of an existing queue. Raises: AssertionError: If the response status code is not 400. """ - response = modify_queue( - client=client, + response = dioptra_client.queues.modify_by_id( queue_id=queue_id, - new_name=existing_name, - new_description=existing_description, + name=existing_name, + description=existing_description, ) assert response.status_code == 409 @@ -359,7 +307,7 @@ def assert_cannot_rename_queue_with_existing_name( def test_create_queue( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], ) -> None: @@ -375,10 +323,10 @@ def test_create_queue( description = "The first queue." user_id = auth_account["id"] group_id = auth_account["groups"][0]["id"] - queue1_response = actions.register_queue( - client, name=name, description=description, group_id=group_id + queue1_response = dioptra_client.queues.create( + group_id=group_id, name=name, description=description ) - queue1_expected = queue1_response.get_json() + queue1_expected = queue1_response.json() assert_queue_response_contents_matches_expectations( response=queue1_expected, expected_contents={ @@ -389,12 +337,12 @@ def test_create_queue( }, ) assert_retrieving_queue_by_id_works( - client, queue_id=queue1_expected["id"], expected=queue1_expected + dioptra_client, queue_id=queue1_expected["id"], expected=queue1_expected ) def test_queue_get_all( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_queues: dict[str, Any], @@ -409,11 +357,11 @@ def test_queue_get_all( - The returned list of queues matches the full list of registered queues. """ queue_expected_list = list(registered_queues.values()) - assert_retrieving_queues_works(client, expected=queue_expected_list) + assert_retrieving_queues_works(dioptra_client, expected=queue_expected_list) @pytest.mark.parametrize( - "sortBy, descending , expected", + "sort_by,descending,expected", [ (None, None, ["queue1", "queue2", "queue3"]), ("name", True, ["queue2", "queue1", "queue3"]), @@ -423,11 +371,11 @@ def test_queue_get_all( ], ) def test_queue_sort( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_queues: dict[str, Any], - sortBy: str, + sort_by: str | None, descending: bool, expected: list[str], ) -> None: @@ -445,11 +393,13 @@ def test_queue_sort( expected_ids = [ registered_queues[expected_name]["id"] for expected_name in expected ] - assert_sorting_queue_works(client, sortBy, descending, expected=expected_ids) + assert_sorting_queue_works( + dioptra_client, sort_by=sort_by, descending=descending, expected=expected_ids + ) def test_queue_search_query( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_queues: dict[str, Any], @@ -464,17 +414,19 @@ def test_queue_search_query( """ queue_expected_list = list(registered_queues.values())[:2] assert_retrieving_queues_works( - client, expected=queue_expected_list, search="description:*queue*" + dioptra_client, expected=queue_expected_list, search="description:*queue*" ) assert_retrieving_queues_works( - client, expected=queue_expected_list, search="*queue*, name:tensorflow*" + dioptra_client, expected=queue_expected_list, search="*queue*, name:tensorflow*" ) queue_expected_list = list(registered_queues.values()) - assert_retrieving_queues_works(client, expected=queue_expected_list, search="*") + assert_retrieving_queues_works( + dioptra_client, expected=queue_expected_list, search="*" + ) def test_queue_group_query( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_queues: dict[str, Any], @@ -490,14 +442,14 @@ def test_queue_group_query( """ queue_expected_list = list(registered_queues.values()) assert_retrieving_queues_works( - client, + dioptra_client, expected=queue_expected_list, group_id=auth_account["groups"][0]["id"], ) def test_cannot_register_existing_queue_name( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_queues: dict[str, Any], @@ -513,14 +465,14 @@ def test_cannot_register_existing_queue_name( existing_queue = registered_queues["queue1"] assert_registering_existing_queue_name_fails( - client, + dioptra_client, name=existing_queue["name"], group_id=existing_queue["group"]["id"], ) def test_rename_queue( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_queues: dict[str, Any], @@ -544,34 +496,32 @@ def test_rename_queue( queue_to_rename = registered_queues["queue1"] existing_queue = registered_queues["queue2"] - modified_queue = modify_queue( - client, + modified_queue = dioptra_client.queues.modify_by_id( queue_id=queue_to_rename["id"], - new_name=updated_queue_name, - new_description=queue_to_rename["description"], - ).get_json() + name=updated_queue_name, + description=queue_to_rename["description"], + ).json() assert_queue_name_matches_expected_name( - client, queue_id=queue_to_rename["id"], expected_name=updated_queue_name + dioptra_client, queue_id=queue_to_rename["id"], expected_name=updated_queue_name ) queue_expected_list = [ modified_queue, registered_queues["queue2"], registered_queues["queue3"], ] - assert_retrieving_queues_works(client, expected=queue_expected_list) + assert_retrieving_queues_works(dioptra_client, expected=queue_expected_list) - modified_queue = modify_queue( - client, + modified_queue = dioptra_client.queues.modify_by_id( queue_id=queue_to_rename["id"], - new_name=updated_queue_name, - new_description=queue_to_rename["description"], - ).get_json() + name=updated_queue_name, + description=queue_to_rename["description"], + ).json() assert_queue_name_matches_expected_name( - client, queue_id=queue_to_rename["id"], expected_name=updated_queue_name + dioptra_client, queue_id=queue_to_rename["id"], expected_name=updated_queue_name ) assert_cannot_rename_queue_with_existing_name( - client, + dioptra_client, queue_id=queue_to_rename["id"], existing_name=existing_queue["name"], existing_description=queue_to_rename["description"], @@ -580,6 +530,7 @@ def test_rename_queue( def test_delete_queue_by_id( client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_queues: dict[str, Any], @@ -598,15 +549,15 @@ def test_delete_queue_by_id( entrypoint = registered_entrypoints["entrypoint1"] queue_to_delete = entrypoint["queues"][0] - delete_queue_with_id(client, queue_id=queue_to_delete["id"]) - assert_queue_is_not_found(client, queue_id=queue_to_delete["id"]) + dioptra_client.queues.delete_by_id(queue_to_delete["id"]) + assert_queue_is_not_found(dioptra_client, queue_id=queue_to_delete["id"]) assert_queue_is_not_associated_with_entrypoint( client, entrypoint_id=entrypoint["id"], queue_id=queue_to_delete["id"] ) def test_manage_existing_queue_draft( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_queues: dict[str, Any], @@ -641,18 +592,17 @@ def test_manage_existing_queue_draft( "num_other_drafts": 0, "payload": payload, } - response = actions.create_existing_resource_draft( - client, resource_route=V1_QUEUES_ROUTE, resource_id=queue["id"], payload=payload - ).get_json() + response = dioptra_client.queues.existing_resource_drafts.create( + queue["id"], **payload + ).json() asserts.assert_draft_response_contents_matches_expectations(response, expected) - asserts.assert_retrieving_draft_by_resource_id_works( - client, - resource_route=V1_QUEUES_ROUTE, - resource_id=queue["id"], + asserts_client.assert_retrieving_draft_by_resource_id_works( + dioptra_client.queues.existing_resource_drafts, + queue["id"], expected=response, ) - asserts.assert_creating_another_existing_draft_fails( - client, resource_route=V1_QUEUES_ROUTE, resource_id=queue["id"] + asserts_client.assert_creating_another_existing_draft_fails( + dioptra_client.queues.existing_resource_drafts, queue["id"], payload=payload ) # test modification @@ -665,25 +615,20 @@ def test_manage_existing_queue_draft( "num_other_drafts": 0, "payload": payload, } - response = actions.modify_existing_resource_draft( - client, - resource_route=V1_QUEUES_ROUTE, - resource_id=queue["id"], - payload=payload, - ).get_json() + response = dioptra_client.queues.existing_resource_drafts.modify( + queue["id"], **payload + ).json() asserts.assert_draft_response_contents_matches_expectations(response, expected) # test deletion - actions.delete_existing_resource_draft( - client, resource_route=V1_QUEUES_ROUTE, resource_id=queue["id"] - ) - asserts.assert_existing_draft_is_not_found( - client, resource_route=V1_QUEUES_ROUTE, resource_id=queue["id"] + dioptra_client.queues.existing_resource_drafts.delete(queue["id"]) + asserts_client.assert_existing_draft_is_not_found( + dioptra_client.queues.existing_resource_drafts, queue["id"] ) def test_manage_new_queue_drafts( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], ) -> None: @@ -711,18 +656,14 @@ def test_manage_new_queue_drafts( "group_id": group_id, "payload": drafts["draft1"], } - draft1_response = actions.create_new_resource_draft( - client, - resource_route=V1_QUEUES_ROUTE, - group_id=group_id, - payload=drafts["draft1"], - ).get_json() + draft1_response = dioptra_client.queues.new_resource_drafts.create( + group_id=group_id, **drafts["draft1"] + ).json() asserts.assert_draft_response_contents_matches_expectations( draft1_response, draft1_expected ) - asserts.assert_retrieving_draft_by_id_works( - client, - resource_route=V1_QUEUES_ROUTE, + asserts_client.assert_retrieving_draft_by_id_works( + dioptra_client.queues.new_resource_drafts, draft_id=draft1_response["id"], expected=draft1_response, ) @@ -731,24 +672,19 @@ def test_manage_new_queue_drafts( "group_id": group_id, "payload": drafts["draft2"], } - draft2_response = actions.create_new_resource_draft( - client, - resource_route=V1_QUEUES_ROUTE, - group_id=group_id, - payload=drafts["draft2"], - ).get_json() + draft2_response = dioptra_client.queues.new_resource_drafts.create( + group_id=group_id, **drafts["draft2"] + ).json() asserts.assert_draft_response_contents_matches_expectations( draft2_response, draft2_expected ) - asserts.assert_retrieving_draft_by_id_works( - client, - resource_route=V1_QUEUES_ROUTE, + asserts_client.assert_retrieving_draft_by_id_works( + dioptra_client.queues.new_resource_drafts, draft_id=draft2_response["id"], expected=draft2_response, ) - asserts.assert_retrieving_drafts_works( - client, - resource_route=V1_QUEUES_ROUTE, + asserts_client.assert_retrieving_drafts_works( + dioptra_client.queues.new_resource_drafts, expected=[draft1_response, draft2_response], ) @@ -759,27 +695,22 @@ def test_manage_new_queue_drafts( "group_id": group_id, "payload": draft1_mod, } - response = actions.modify_new_resource_draft( - client, - resource_route=V1_QUEUES_ROUTE, - draft_id=draft1_response["id"], - payload=draft1_mod, - ).get_json() + response = dioptra_client.queues.new_resource_drafts.modify( + draft_id=draft1_response["id"], **draft1_mod + ).json() asserts.assert_draft_response_contents_matches_expectations( response, draft1_mod_expected ) # test deletion - actions.delete_new_resource_draft( - client, resource_route=V1_QUEUES_ROUTE, draft_id=draft1_response["id"] - ) - asserts.assert_new_draft_is_not_found( - client, resource_route=V1_QUEUES_ROUTE, draft_id=draft1_response["id"] + dioptra_client.queues.new_resource_drafts.delete(draft_id=draft1_response["id"]) + asserts_client.assert_new_draft_is_not_found( + dioptra_client.queues.new_resource_drafts, draft_id=draft1_response["id"] ) def test_manage_queue_snapshots( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_queues: dict[str, Any], @@ -798,41 +729,37 @@ def test_manage_queue_snapshots( response """ queue_to_rename = registered_queues["queue1"] - modified_queue = modify_queue( - client, + modified_queue = dioptra_client.queues.modify_by_id( queue_id=queue_to_rename["id"], - new_name=queue_to_rename["name"] + "modified", - new_description=queue_to_rename["description"], - ).get_json() + name=queue_to_rename["name"] + "modified", + description=queue_to_rename["description"], + ).json() modified_queue.pop("hasDraft") queue_to_rename.pop("hasDraft") queue_to_rename["latestSnapshot"] = False queue_to_rename["lastModifiedOn"] = modified_queue["lastModifiedOn"] - asserts.assert_retrieving_snapshot_by_id_works( - client, - resource_route=V1_QUEUES_ROUTE, - resource_id=queue_to_rename["id"], + asserts_client.assert_retrieving_snapshot_by_id_works( + dioptra_client.queues.snapshots, + queue_to_rename["id"], snapshot_id=queue_to_rename["snapshot"], expected=queue_to_rename, ) - asserts.assert_retrieving_snapshot_by_id_works( - client, - resource_route=V1_QUEUES_ROUTE, - resource_id=modified_queue["id"], + asserts_client.assert_retrieving_snapshot_by_id_works( + dioptra_client.queues.snapshots, + modified_queue["id"], snapshot_id=modified_queue["snapshot"], expected=modified_queue, ) expected_snapshots = [queue_to_rename, modified_queue] - asserts.assert_retrieving_snapshots_works( - client, - resource_route=V1_QUEUES_ROUTE, - resource_id=queue_to_rename["id"], + asserts_client.assert_retrieving_snapshots_works( + dioptra_client.queues.snapshots, + queue_to_rename["id"], expected=expected_snapshots, ) def test_tag_queue( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_queues: dict[str, Any], @@ -848,52 +775,29 @@ def test_tag_queue( tags = [tag["id"] for tag in registered_tags.values()] # test append - response = actions.append_tags( - client, - resource_route=V1_QUEUES_ROUTE, - resource_id=queue["id"], - tag_ids=[tags[0], tags[1]], - ) + response = dioptra_client.queues.tags.append(queue["id"], ids=tags[:2]) asserts.assert_tags_response_contents_matches_expectations( - response.get_json(), [tags[0], tags[1]] - ) - response = actions.append_tags( - client, - resource_route=V1_QUEUES_ROUTE, - resource_id=queue["id"], - tag_ids=[tags[1], tags[2]], + response.json(), tags[:2] ) + response = dioptra_client.queues.tags.append(queue["id"], ids=tags[1:3]) asserts.assert_tags_response_contents_matches_expectations( - response.get_json(), [tags[0], tags[1], tags[2]] + response.json(), tags[:3] ) # test remove - actions.remove_tag( - client, resource_route=V1_QUEUES_ROUTE, resource_id=queue["id"], tag_id=tags[1] - ) - response = actions.get_tags( - client, resource_route=V1_QUEUES_ROUTE, resource_id=queue["id"] - ) + dioptra_client.queues.tags.remove(queue["id"], tag_id=tags[1]) + response = dioptra_client.queues.tags.get(queue["id"]) asserts.assert_tags_response_contents_matches_expectations( - response.get_json(), [tags[0], tags[2]] + response.json(), [tags[0], tags[2]] ) # test modify - response = actions.modify_tags( - client, - resource_route=V1_QUEUES_ROUTE, - resource_id=queue["id"], - tag_ids=[tags[1], tags[2]], - ) + response = dioptra_client.queues.tags.modify(queue["id"], ids=tags[1:3]) asserts.assert_tags_response_contents_matches_expectations( - response.get_json(), [tags[1], tags[2]] + response.json(), tags[1:3] ) # test delete - response = actions.remove_tags( - client, resource_route=V1_QUEUES_ROUTE, resource_id=queue["id"] - ) - response = actions.get_tags( - client, resource_route=V1_QUEUES_ROUTE, resource_id=queue["id"] - ) - asserts.assert_tags_response_contents_matches_expectations(response.get_json(), []) + dioptra_client.queues.tags.remove_all(queue["id"]) + response = dioptra_client.queues.tags.get(queue["id"]) + asserts.assert_tags_response_contents_matches_expectations(response.json(), []) diff --git a/tests/unit/restapi/v1/test_tag.py b/tests/unit/restapi/v1/test_tag.py index 6ca000900..c97a3c5ae 100644 --- a/tests/unit/restapi/v1/test_tag.py +++ b/tests/unit/restapi/v1/test_tag.py @@ -23,60 +23,11 @@ from typing import Any import pytest -from flask.testing import FlaskClient from flask_sqlalchemy import SQLAlchemy -from werkzeug.test import TestResponse -from dioptra.restapi.routes import V1_ROOT, V1_TAGS_ROUTE - -from ..lib import actions, helpers - -# -- Actions --------------------------------------------------------------------------- - - -def modify_tag( - client: FlaskClient, - tag_id: int, - new_name: str, -) -> TestResponse: - """Rename a tag using the API. - - Args: - client: The Flask test client. - tag_id: The id of the tag to rename. - new_name: The new name to assign to the tag. - - Returns: - The response from the API. - """ - payload: dict[str, Any] = {"name": new_name} - - return client.put( - f"/{V1_ROOT}/{V1_TAGS_ROUTE}/{tag_id}", - json=payload, - follow_redirects=True, - ) - - -def delete_tag( - client: FlaskClient, - tag_id: int, -) -> TestResponse: - """Delete a tag using the API. - - Args: - client: The Flask test client. - tag_id: The id of the tag to delete. - - Returns: - The response from the API. - """ - - return client.delete( - f"/{V1_ROOT}/{V1_TAGS_ROUTE}/{tag_id}", - follow_redirects=True, - ) +from dioptra.client.client import DioptraClient +from ..lib import helpers # -- Assertions ------------------------------------------------------------------------ @@ -119,14 +70,14 @@ def assert_tag_response_contents_matches_expectations( def assert_retrieving_tag_by_id_works( - client: FlaskClient, + dioptra_client: DioptraClient, tag_id: int, expected: dict[str, Any], ) -> None: """Assert that retrieving a tag by id works. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. tag_id: The id of the tag to retrieve. expected: The expected response from the API. @@ -134,12 +85,12 @@ def assert_retrieving_tag_by_id_works( AssertionError: If the response status code is not 200 or if the API response does not match the expected response. """ - response = client.get(f"/{V1_ROOT}/{V1_TAGS_ROUTE}/{tag_id}", follow_redirects=True) - assert response.status_code == 200 and response.get_json() == expected + response = dioptra_client.tags.get_by_id(tag_id) + assert response.status_code == 200 and response.json() == expected def assert_retrieving_tags_works( - client: FlaskClient, + dioptra_client: DioptraClient, expected: list[dict[str, Any]], group_id: int | None = None, search: str | None = None, @@ -148,7 +99,7 @@ def assert_retrieving_tags_works( """Assert that retrieving all tags works. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. expected: The expected response from the API. group_id: The group ID used in query parameters. search: The search string used in query parameters. @@ -159,31 +110,30 @@ def assert_retrieving_tags_works( does not match the expected response. """ - query_string: dict[str, Any] = {} + query_kwargs: dict[str, Any] = {} if group_id is not None: - query_string["groupId"] = group_id + query_kwargs["group_id"] = group_id if search is not None: - query_string["search"] = search + query_kwargs["search"] = search if paging_info is not None: - query_string["index"] = paging_info["index"] - query_string["pageLength"] = paging_info["page_length"] + query_kwargs["index"] = paging_info["index"] + query_kwargs["page_length"] = paging_info["page_length"] - response = client.get( - f"/{V1_ROOT}/{V1_TAGS_ROUTE}", - query_string=query_string, - follow_redirects=True, - ) - assert response.status_code == 200 and response.get_json()["data"] == expected + response = dioptra_client.tags.get(**query_kwargs) + assert response.status_code == 200 and response.json()["data"] == expected def assert_sorting_tag_works( - client: FlaskClient, - sortBy: str, - descending: bool, + dioptra_client: DioptraClient, expected: list[str], + sort_by: str | None, + descending: bool | None, + group_id: int | None = None, + search: str | None = None, + paging_info: dict[str, Any] | None = None, ) -> None: """Assert that tags can be sorted by column ascending/descending. @@ -197,46 +147,54 @@ def assert_sorting_tag_works( does not match the expected response. """ - query_string: dict[str, Any] = {} + query_kwargs: dict[str, Any] = {} - query_string["sortBy"] = sortBy - query_string["descending"] = descending + if sort_by is not None: + query_kwargs["sort_by"] = sort_by - response = client.get( - f"/{V1_ROOT}/{V1_TAGS_ROUTE}", - query_string=query_string, - follow_redirects=True, - ) + if descending is not None: + query_kwargs["descending"] = descending + + if group_id is not None: + query_kwargs["group_id"] = group_id + + if search is not None: + query_kwargs["search"] = search + + if paging_info is not None: + query_kwargs["index"] = paging_info["index"] + query_kwargs["page_length"] = paging_info["page_length"] - response_data = response.get_json() + response = dioptra_client.tags.get(**query_kwargs) + response_data = response.json() tag_ids = [tag["id"] for tag in response_data["data"]] assert response.status_code == 200 and tag_ids == expected def assert_registering_existing_tag_name_fails( - client: FlaskClient, name: str, group_id: int + dioptra_client: DioptraClient, name: str, group_id: int ) -> None: """Assert that registering a tag with an existing name fails. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. name: The name to assign to the new tag. Raises: AssertionError: If the response status code is not 400. """ - response = actions.register_tag(client, name=name, group_id=group_id) + response = dioptra_client.tags.create(group_id=group_id, name=name) assert response.status_code == 409 def assert_tag_name_matches_expected_name( - client: FlaskClient, tag_id: int, expected_name: str + dioptra_client: DioptraClient, tag_id: int, expected_name: str ) -> None: """Assert that the name of a tag matches the expected name. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. tag_id: The id of the tag to retrieve. expected_name: The expected name of the tag. @@ -244,53 +202,43 @@ def assert_tag_name_matches_expected_name( AssertionError: If the response status code is not 200 or if the name of the tag does not match the expected name. """ - response = client.get( - f"/{V1_ROOT}/{V1_TAGS_ROUTE}/{tag_id}", - follow_redirects=True, - ) - assert response.status_code == 200 and response.get_json()["name"] == expected_name + response = dioptra_client.tags.get_by_id(tag_id) + assert response.status_code == 200 and response.json()["name"] == expected_name def assert_tag_is_not_found( - client: FlaskClient, + dioptra_client: DioptraClient, tag_id: int, ) -> None: """Assert that a tag is not found. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. tag_id: The id of the tag to retrieve. Raises: AssertionError: If the response status code is not 404. """ - response = client.get( - f"/{V1_ROOT}/{V1_TAGS_ROUTE}/{tag_id}", - follow_redirects=True, - ) + response = dioptra_client.tags.get_by_id(tag_id) assert response.status_code == 404 def assert_cannot_rename_tag_with_existing_name( - client: FlaskClient, + dioptra_client: DioptraClient, tag_id: int, existing_name: str, ) -> None: """Assert that renaming a tag with an existing name fails. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. tag_id: The id of the tag to rename. name: The name of an existing tag. Raises: AssertionError: If the response status code is not 400. """ - response = modify_tag( - client=client, - tag_id=tag_id, - new_name=existing_name, - ) + response = dioptra_client.tags.modify_by_id(tag_id=tag_id, name=existing_name) assert response.status_code == 409 @@ -298,7 +246,7 @@ def assert_cannot_rename_tag_with_existing_name( def test_create_tag( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], ) -> None: @@ -315,12 +263,11 @@ def test_create_tag( name = "tag" user_id = auth_account["default_group_id"] group_id = auth_account["default_group_id"] - tag_response = actions.register_tag( - client, - name=name, + tag_response = dioptra_client.tags.create( group_id=group_id, + name=name, ) - tag_expected = tag_response.get_json() + tag_expected = tag_response.json() assert_tag_response_contents_matches_expectations( response=tag_expected, @@ -331,12 +278,12 @@ def test_create_tag( }, ) assert_retrieving_tag_by_id_works( - client, tag_id=tag_expected["id"], expected=tag_expected + dioptra_client, tag_id=tag_expected["id"], expected=tag_expected ) @pytest.mark.parametrize( - "sortBy, descending , expected", + "sort_by, descending , expected", [ (None, None, ["tag1", "tag2", "tag3"]), ("name", True, ["tag2", "tag1", "tag3"]), @@ -346,11 +293,11 @@ def test_create_tag( ], ) def test_tag_sort( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_tags: dict[str, Any], - sortBy: str, + sort_by: str | None, descending: bool, expected: list[str], ) -> None: @@ -366,11 +313,13 @@ def test_tag_sort( """ expected_ids = [registered_tags[expected_name]["id"] for expected_name in expected] - assert_sorting_tag_works(client, sortBy, descending, expected=expected_ids) + assert_sorting_tag_works( + dioptra_client, sort_by=sort_by, descending=descending, expected=expected_ids + ) def test_tag_search_query( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_tags: dict[str, Any], @@ -390,14 +339,14 @@ def test_tag_search_query( tag_expected_list = [tag1_expected, tag2_expected] assert_retrieving_tags_works( - client, + dioptra_client, expected=tag_expected_list, search="name:*tag*", ) def test_cannot_register_existing_tag_name( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_tags: dict[str, Any], @@ -412,12 +361,12 @@ def test_cannot_register_existing_tag_name( existing_tag = registered_tags["tag1"] assert_registering_existing_tag_name_fails( - client, name=existing_tag["name"], group_id=existing_tag["group"]["id"] + dioptra_client, name=existing_tag["name"], group_id=existing_tag["group"]["id"] ) def test_rename_tag( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_tags: dict[str, Any], @@ -437,23 +386,19 @@ def test_rename_tag( tag_to_rename = registered_tags["tag1"] existing_tag = registered_tags["tag2"] - modify_tag( - client, - tag_id=tag_to_rename["id"], - new_name=updated_tag_name, - ) + dioptra_client.tags.modify_by_id(tag_id=tag_to_rename["id"], name=updated_tag_name) assert_tag_name_matches_expected_name( - client, tag_id=tag_to_rename["id"], expected_name=updated_tag_name + dioptra_client, tag_id=tag_to_rename["id"], expected_name=updated_tag_name ) assert_cannot_rename_tag_with_existing_name( - client, + dioptra_client, tag_id=tag_to_rename["id"], existing_name=existing_tag["name"], ) def test_delete_tag_by_id( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_tags: dict[str, Any], @@ -472,7 +417,7 @@ def test_delete_tag_by_id( tag_expected = registered_tags["tag1"] assert_retrieving_tag_by_id_works( - client, tag_id=tag_expected["id"], expected=tag_expected + dioptra_client, tag_id=tag_expected["id"], expected=tag_expected ) - delete_tag(client, tag_id=tag_expected["id"]) - assert_tag_is_not_found(client, tag_id=tag_expected["id"]) + dioptra_client.tags.delete_by_id(tag_id=tag_expected["id"]) + assert_tag_is_not_found(dioptra_client, tag_id=tag_expected["id"]) diff --git a/tests/unit/restapi/v1/test_user.py b/tests/unit/restapi/v1/test_user.py index 48c866ebd..5085ff931 100644 --- a/tests/unit/restapi/v1/test_user.py +++ b/tests/unit/restapi/v1/test_user.py @@ -15,96 +15,18 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode """Test suite for user operations. + This module contains a set of tests that validate the CRUD operations and additional -functionalities for the user entity. The tests ensure that the users can be -registered, modified, and deleted as expected through the REST API. +functionalities for the user entity. The tests ensure that the users can be registered, +modified, and deleted as expected through the REST API. """ from typing import Any -from flask.testing import FlaskClient from flask_sqlalchemy import SQLAlchemy -from werkzeug.test import TestResponse - -from dioptra.restapi.routes import V1_ROOT, V1_USERS_ROUTE - -from ..lib import actions, helpers - -# -- Actions --------------------------------------------------------------------------- - - -def modify_current_user( - client: FlaskClient, - new_username: str, - new_email: str, -) -> TestResponse: - """Change the current user's email using the API. - - Args: - client The Flask test client. - new_email: The new email to assign to the user. - - Returns: - The response from the API. - """ - payload = {"username": new_username, "email": new_email} - - return client.put( - f"/{V1_ROOT}/{V1_USERS_ROUTE}/current", json=payload, follow_redirects=True - ) - - -def delete_current_user( - client: FlaskClient, - password: str, -) -> TestResponse: - """Delete the current user using the API. - - Args: - client: The Flask test client. - - Returns: - The response from the API. - """ - payload = {"password": password} - return client.delete( - f"/{V1_ROOT}/{V1_USERS_ROUTE}/current", json=payload, follow_redirects=True - ) - - -def change_current_user_password( - client: FlaskClient, - old_password: str, - new_password: str, -): - """Change the current user password using the API.""" - payload = { - "oldPassword": old_password, - "newPassword": new_password, - "confirmNewPassword": new_password, - } - return client.post( - f"/{V1_ROOT}/{V1_USERS_ROUTE}/current/password", - json=payload, - follow_redirects=True, - ) - - -def change_user_password( - client: FlaskClient, user_id: int, old_password: str, new_password: str -): - """Change a user password using its ID using the API.""" - payload = { - "oldPassword": old_password, - "newPassword": new_password, - "confirmNewPassword": new_password, - } - return client.post( - f"/{V1_ROOT}/{V1_USERS_ROUTE}/{user_id}/password", - json=payload, - follow_redirects=True, - ) +from dioptra.client.client import DioptraClient +from ..lib import helpers # -- Assertions ---------------------------------------------------------------- @@ -170,12 +92,12 @@ def assert_user_response_contents_matches_expectations( def assert_retrieving_user_by_id_works( - client: FlaskClient, user_id: int, expected: dict[str, Any] + dioptra_client: DioptraClient, user_id: int, expected: dict[str, Any] ) -> None: """Assert that retrieving a user by id works. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. user_id: The id of the user to retrieve. expected: The expected response from the API. @@ -183,59 +105,35 @@ def assert_retrieving_user_by_id_works( AssertionError: If the response status code is not 200 or if the API response does not match the expected response. """ - response = client.get( - f"/{V1_ROOT}/{V1_USERS_ROUTE}/{user_id}", follow_redirects=True - ) - assert response.status_code == 200 and response.get_json() == expected + response = dioptra_client.users.get_by_id(user_id) + assert response.status_code == 200 and response.json() == expected def assert_retrieving_current_user_works( - client: FlaskClient, + dioptra_client: DioptraClient, expected: dict[str, Any], ) -> None: """Assert that retrieving the current user works. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. expected: The expected response from the API. Raises: AssertionError: If the response status code is not 200 or if the API response does not match the expected response. """ - response = client.get(f"/{V1_ROOT}/{V1_USERS_ROUTE}/current", follow_redirects=True) + response = dioptra_client.users.get_current() to_ignore = ["lastLoginOn", "lastModifiedOn"] response_info_filtered = { - k: v for k, v in response.get_json().items() if k not in to_ignore + k: v for k, v in response.json().items() if k not in to_ignore } expected_filtered = {k: v for k, v in expected.items() if k not in to_ignore} assert response.status_code == 200 and response_info_filtered == expected_filtered -def assert_retrieving_all_users_works( - client: FlaskClient, - expected: list[dict[str, Any]], -) -> None: - """Assert that retrieving all queues works. - - Args: - client: The Flask test client. - expected: The expected response from the API. - - Raises: - AssertionError: If the response status code is not 200 or if the API response - does not match the expected response. - """ - response = client.get( - f"/{V1_ROOT}/{V1_USERS_ROUTE}", - query_string={}, - follow_redirects=True, - ) - assert response.status_code == 200 and response.get_json()["data"] == expected - - def assert_retrieving_users_works( - client: FlaskClient, + dioptra_client: DioptraClient, expected: list[dict[str, Any]], search: str | None = None, paging_info: dict[str, int] | None = None, @@ -243,7 +141,7 @@ def assert_retrieving_users_works( """Assert that retrieving all users works. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. expected: The expected response from the API. search: The search string used in query parameters. paging_info: The paging information used in query parameters. @@ -252,7 +150,7 @@ def assert_retrieving_users_works( AssertionError: If the response status code is not 200 or if the API response does not match the expected response. """ - query_string = {} + query_string: dict[str, Any] = {} if search is not None: query_string["search"] = search @@ -261,63 +159,59 @@ def assert_retrieving_users_works( query_string["index"] = paging_info["index"] query_string["pageLength"] = paging_info["page_length"] - response = client.get( - f"/{V1_ROOT}/{V1_USERS_ROUTE}", - query_string=query_string, - follow_redirects=True, - ) - assert response.status_code == 200 and response.get_json()["data"] == expected + response = dioptra_client.users.get(**query_string) + assert response.status_code == 200 and response.json()["data"] == expected def assert_registering_existing_username_fails( - client: FlaskClient, + dioptra_client: DioptraClient, existing_username: str, non_existing_email: str, ) -> None: """Assert that registering a user with an existing username fails. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. username: The username to assign to the new user. Raises: AssertionError: If the response status code is not 400. """ password = "supersecurepassword" - response = actions.register_user( - client, existing_username, non_existing_email, password + response = dioptra_client.users.create( + username=existing_username, email=non_existing_email, password=password ) assert response.status_code == 409 def assert_registering_existing_email_fails( - client: FlaskClient, + dioptra_client: DioptraClient, non_existing_username: str, existing_email: str, ) -> None: """Assert that registering a user with an existing username fails. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. username: The username to assign to the new user. Raises: AssertionError: If the response status code is not 400. """ password = "supersecurepassword" - response = actions.register_user( - client, non_existing_username, existing_email, password + response = dioptra_client.users.create( + username=non_existing_username, email=existing_email, password=password ) assert response.status_code == 409 def assert_user_username_matches_expected_name( - client: FlaskClient, user_id: int, expected_name: str + dioptra_client: DioptraClient, user_id: int, expected_name: str ) -> None: """Assert that the name of a user matches the expected name. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. user_id: The id of the user to retrieve. expected_name: The expected name of the user. @@ -325,153 +219,144 @@ def assert_user_username_matches_expected_name( AssertionError: If the response status code is not 200 or if the name of the user does not match the expected name. """ - response = client.get( - f"/{V1_ROOT}/{V1_USERS_ROUTE}/{user_id}", - follow_redirects=True, - ) - assert response.status_code == 200 and response.get_json()["name"] == expected_name + response = dioptra_client.users.get_by_id(user_id) + assert response.status_code == 200 and response.json()["name"] == expected_name def assert_current_user_username_matches_expected_name( - client: FlaskClient, expected_name: str + dioptra_client: DioptraClient, expected_name: str ) -> None: """Assert that the name of the current user matches the expected name. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. expected_name: The expected name of the user. Raises: AssertionError: If the response status code is not 200 or if the name of the user does not match the expected name. """ - response = client.get( - f"/{V1_ROOT}/{V1_USERS_ROUTE}/current", - follow_redirects=True, - ) - assert response.status_code == 200 and response.get_json()["name"] == expected_name + response = dioptra_client.users.get_current() + assert response.status_code == 200 and response.json()["name"] == expected_name def assert_user_is_not_found( - client: FlaskClient, + dioptra_client: DioptraClient, user_id: int, ) -> None: """Assert that a user is not found. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. user_id: The id of the user to retrieve. Raises: AssertionError: If the response status code is not 404. """ - response = client.get( - f"/{V1_ROOT}/{V1_USERS_ROUTE}/{user_id}", - follow_redirects=True, - ) + response = dioptra_client.users.get_by_id(user_id) assert response.status_code == 404 def assert_cannot_rename_user_with_existing_username( - client: FlaskClient, + dioptra_client: DioptraClient, existing_username: str, ) -> None: """Assert that renaming a user with an existing username fails. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. existing_username: The username of the existing user. Raises: AssertionError: If the response status code is not 400. """ - response = modify_current_user( - client=client, - new_username=existing_username, - new_email="new_email", + response = dioptra_client.users.modify_current_user( + username=existing_username, + email="new_email", ) assert response.status_code == 400 def assert_cannot_rename_user_with_existing_email( - client: FlaskClient, + dioptra_client: DioptraClient, existing_email: str, ) -> None: """Assert that changing a user email with an existing email fails. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. existing_email: The email of the existing user. Raises: AssertionError: If the response status code is not 400. """ - response = modify_current_user( - client=client, - new_username="new_username", - new_email=existing_email, + response = dioptra_client.users.modify_current_user( + username="new_username", + email=existing_email, ) assert response.status_code == 400 def assert_login_works( - client: FlaskClient, + dioptra_client: DioptraClient, username: str, password: str, ): """Assert that logging in using a username and password works. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. username: The username of the user to be logged in. password: The password of the user to be logged in. Raises: AssertionError: If the response status code is not 200. """ - assert actions.login(client, username, password).status_code == 200 + assert dioptra_client.auth.login(username, password).status_code == 200 def assert_user_does_not_exist( - client: FlaskClient, + dioptra_client: DioptraClient, username: str, password: str, ): """Assert that the user does not exist. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. username: The username of the user to be logged in. password: The password of the user to be logged in. Raises: AssertionError: If the response status code is not 404. """ - assert actions.login(client, username, password).status_code == 404 + response = dioptra_client.auth.login(username, password) + assert response.status_code == 404 def assert_login_is_unauthorized( - client: FlaskClient, + dioptra_client: DioptraClient, username: str, password: str, ): """Assert that logging in using a username and password is unauthorized. Args: - client: The Flask test client. + dioptra_client: The Dioptra client. username: The username of the user to be logged in. password: The password of the user to be logged in. Raises: AssertionError: If the response status code is not 401. """ - assert actions.login(client, username, password).status_code == 401 + response = dioptra_client.auth.login(username, password) + assert response.status_code == 401 def assert_new_password_cannot_be_existing( - client: FlaskClient, + dioptra_client: DioptraClient, password: str, - user_id: str = None, + user_id: str | None = None, ): """Assert that changing a user (current or otherwise) password to the existing password fails. @@ -486,21 +371,21 @@ def assert_new_password_cannot_be_existing( AssertionError: If the response status code is not 400. """ # Means we are the current user. - if not user_id: - assert ( - change_current_user_password(client, password, password).status_code == 403 - ) + if user_id is None: + response = dioptra_client.users.change_current_user_password(password, password) + assert response.status_code == 403 else: - assert ( - change_user_password(client, user_id, password, password).status_code == 403 + response = dioptra_client.users.change_password_by_id( + user_id, password, password ) + assert response.status_code == 403 # -- Tests ------------------------------------------------------------- def test_create_user( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, ) -> None: """Test that we can create a user and its response is expected. @@ -520,7 +405,7 @@ def test_create_user( password = "supersecurepassword" # Posting a user returns CurrentUserSchema. - user_response = actions.register_user(client, username, email, password).get_json() + user_response = dioptra_client.users.create(username, email, password).json() assert_user_response_contents_matches_expectations( response=user_response, expected_contents={ @@ -530,18 +415,20 @@ def test_create_user( current_user=True, ) - actions.login(client, username, password).get_json() - assert_retrieving_current_user_works(client, expected=user_response) + dioptra_client.auth.login(username, password) + assert_retrieving_current_user_works(dioptra_client, expected=user_response) # Getting a user by id returns UserSchema. user_expected = { k: v for k, v in user_response.items() if k in ["username", "email", "id"] } - assert_retrieving_user_by_id_works(client, user_expected["id"], user_expected) + assert_retrieving_user_by_id_works( + dioptra_client, user_expected["id"], user_expected + ) def test_user_get_all( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_users: dict[str, Any], @@ -559,11 +446,11 @@ def test_user_get_all( {"username": user["username"], "email": user["email"], "id": user["id"]} for user in list(registered_users.values()) ] - assert_retrieving_users_works(client, expected=user_expected_list) + assert_retrieving_users_works(dioptra_client, expected=user_expected_list) def test_user_search_query( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_users: dict[str, Any], @@ -581,23 +468,25 @@ def test_user_search_query( for user in list(registered_users.values())[:2] ] assert_retrieving_users_works( - client, expected=user_expected_list, search="username:*user*" + dioptra_client, expected=user_expected_list, search="username:*user*" + ) + assert_retrieving_users_works( + dioptra_client, expected=user_expected_list, search="username:'*user*'" ) assert_retrieving_users_works( - client, expected=user_expected_list, search="username:'*user*'" + dioptra_client, expected=user_expected_list, search='username:"user?"' ) assert_retrieving_users_works( - client, expected=user_expected_list, search='username:"user?"' + dioptra_client, expected=user_expected_list, search="username:user?,email:user*" ) assert_retrieving_users_works( - client, expected=user_expected_list, search="username:user?,email:user*" + dioptra_client, expected=[], search=r"username:\*user*" ) - assert_retrieving_users_works(client, expected=[], search=r"username:\*user*") - assert_retrieving_users_works(client, expected=[], search="email:user?") + assert_retrieving_users_works(dioptra_client, expected=[], search="email:user?") def test_cannot_register_existing_username( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, registered_users: dict[str, Any], ) -> None: @@ -612,14 +501,14 @@ def test_cannot_register_existing_username( """ existing_user = registered_users["user1"] assert_registering_existing_username_fails( - client, + dioptra_client, existing_username=existing_user["username"], non_existing_email="unique" + existing_user["email"], ) def test_cannot_register_existing_email( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, registered_users: dict[str, Any], ) -> None: @@ -634,14 +523,14 @@ def test_cannot_register_existing_email( """ existing_user = registered_users["user1"] assert_registering_existing_email_fails( - client, + dioptra_client, non_existing_username="unique" + existing_user["username"], existing_email=existing_user["email"], ) def test_rename_current_user( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], ) -> None: @@ -654,12 +543,15 @@ def test_rename_current_user( that reflects the updated username. """ new_username = "new_name" - user = modify_current_user(client, new_username, auth_account["email"]).get_json() - assert_retrieving_current_user_works(client, expected=user) + user = dioptra_client.users.modify_current_user( + username=new_username, + email=auth_account["email"], + ).json() + assert_retrieving_current_user_works(dioptra_client, expected=user) def test_user_authorization_failure( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, registered_users: dict[str, Any], ) -> None: @@ -672,11 +564,11 @@ def test_user_authorization_failure( """ username = registered_users["user2"]["username"] password = registered_users["user2"]["password"] + "incorrect" - assert_login_is_unauthorized(client, username=username, password=password) + assert_login_is_unauthorized(dioptra_client, username=username, password=password) def test_delete_current_user( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], ) -> None: @@ -689,12 +581,12 @@ def test_delete_current_user( """ username = auth_account["username"] password = auth_account["password"] - delete_current_user(client, password) - assert_user_does_not_exist(client, username=username, password=password) + dioptra_client.users.delete_current_user(password) + assert_user_does_not_exist(dioptra_client, username=username, password=password) def test_change_current_user_password( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], ): @@ -708,12 +600,12 @@ def test_change_current_user_password( username = auth_account["username"] old_password = auth_account["password"] new_password = "new_password" - change_current_user_password(client, old_password, new_password) - assert_login_works(client, username=username, password=new_password) + dioptra_client.users.change_current_user_password(old_password, new_password) + assert_login_works(dioptra_client, username=username, password=new_password) def test_change_user_password( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], registered_users: dict[str, Any], @@ -729,12 +621,12 @@ def test_change_user_password( username = registered_users["user2"]["username"] old_password = registered_users["user2"]["password"] new_password = "new_password" - change_user_password(client, user_id, old_password, new_password) - assert_login_works(client, username=username, password=new_password) + dioptra_client.users.change_password_by_id(user_id, old_password, new_password) + assert_login_works(dioptra_client, username=username, password=new_password) def test_new_password_cannot_be_existing( - client: FlaskClient, + dioptra_client: DioptraClient, db: SQLAlchemy, auth_account: dict[str, Any], ): @@ -751,6 +643,6 @@ def test_new_password_cannot_be_existing( user_id = auth_account["id"] password = auth_account["password"] # test via /users/current - assert_new_password_cannot_be_existing(client, password) + assert_new_password_cannot_be_existing(dioptra_client, password) # test via /users/{user_id} - assert_new_password_cannot_be_existing(client, password, user_id) + assert_new_password_cannot_be_existing(dioptra_client, password, user_id)