diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml index a74f70c54379..a86f236f49d7 100644 --- a/.github/workflows/black.yml +++ b/.github/workflows/black.yml @@ -8,7 +8,7 @@ jobs: - name: Run checks run: | - pipx install $(grep "^black" ./cvat-cli/requirements/development.txt) + pipx install $(grep "^black" ./dev/requirements.txt) echo "Black version: $(black --version)" diff --git a/.github/workflows/isort.yml b/.github/workflows/isort.yml index 19332d917030..620dc6c85d79 100644 --- a/.github/workflows/isort.yml +++ b/.github/workflows/isort.yml @@ -5,35 +5,11 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - id: files - uses: tj-actions/changed-files@v41.0.0 - with: - files: | - cvat-sdk/**/*.py - cvat-cli/**/*.py - tests/python/**/*.py - cvat/apps/quality_control/**/*.py - cvat/apps/analytics_report/**/*.py - dir_names: true - name: Run checks run: | - # If different modules use different isort configs, - # we need to run isort for each python component group separately. - # Otherwise, they all will use the same config. + pipx install $(grep "^isort" ./dev/requirements.txt) - UPDATED_DIRS="${{steps.files.outputs.all_changed_files}}" + echo "isort version: $(isort --version-number)" - if [[ ! -z $UPDATED_DIRS ]]; then - pipx install $(egrep "isort.*" ./cvat-cli/requirements/development.txt) - - echo "isort version: $(isort --version-number)" - echo "The dirs will be checked: $UPDATED_DIRS" - EXIT_CODE=0 - for DIR in $UPDATED_DIRS; do - isort --check $DIR || EXIT_CODE=$(($? | $EXIT_CODE)) || true - done - exit $EXIT_CODE - else - echo "No files with the \"py\" extension found" - fi + isort --check --diff --resolve-all-configs . diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index d808a823771f..05237f441988 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -19,11 +19,11 @@ jobs: CHANGED_FILES="${{steps.files.outputs.all_changed_files}}" if [[ ! -z $CHANGED_FILES ]]; then - pipx install $(egrep "^pylint==" ./cvat/requirements/development.txt) + pipx install $(grep "^pylint==" ./dev/requirements.txt) pipx inject pylint \ - $(egrep "^pylint-.+==" ./cvat/requirements/development.txt) \ - $(egrep "^django==" ./cvat/requirements/base.txt) + $(grep "^pylint-.\+==" ./dev/requirements.txt) \ + $(grep "^django==" ./cvat/requirements/base.txt) echo "Pylint version: "$(pylint --version | head -1) echo "The files will be checked: "$(echo $CHANGED_FILES) diff --git a/.gitignore b/.gitignore index c375c7df4e7e..5ea2759c829e 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,6 @@ /share/ /static/ /db.sqlite3 -/.*env* /keys /logs /profiles @@ -21,6 +20,11 @@ __pycache__ .coverage .husky/ .python-version +tmp*cvat/ +temp*/ + +# Ignore generated test files +docker-compose.tests.yml # Ignore npm logs file npm-debug.log* @@ -49,8 +53,8 @@ yarn-error.log* # Ignore all the installed packages node_modules -venv/ -.venv/ +/*env*/ +/.*env* # Ignore all js dists cvat-data/dist diff --git a/CHANGELOG.md b/CHANGELOG.md index 31a4aae3db7f..a18a0284f814 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,38 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 + +## \[2.25.0\] - 2025-01-09 + +### Added + +- \[CLI\] Added commands for working with native functions + () + +- Ultralytics YOLO formats now support tracks + () + +### Changed + +- YOLOv8 formats renamed to Ultralytics YOLO formats + () + +- The `match_empty_frames` quality setting is changed to `empty_is_annotated`. + The updated option includes any empty frames in the final metrics instead of only + matching empty frames. This makes metrics such as Precision much more representative and useful. + () + +### Fixed + +- Changing rotation after export/import in Ultralytics YOLO Oriented Boxes format + () + +- Export to yolo formats if both Train and default dataset are present + () + +- Issue with deleting frames + () + ## \[2.24.0\] - 2024-12-20 diff --git a/README.md b/README.md index 2a252b4eaed8..6ca7185523d5 100644 --- a/README.md +++ b/README.md @@ -175,11 +175,11 @@ For more information about the supported formats, see: | [Kitti Raw Format](https://www.cvlibs.net/datasets/kitti/raw_data.php) | ✔️ | ✔️ | | [LFW](http://vis-www.cs.umass.edu/lfw/) | ✔️ | ✔️ | | [Supervisely Point Cloud Format](https://docs.supervise.ly/data-organization/00_ann_format_navi) | ✔️ | ✔️ | -| [YOLOv8 Detection](https://docs.ultralytics.com/datasets/detect/) | ✔️ | ✔️ | -| [YOLOv8 Oriented Bounding Boxes](https://docs.ultralytics.com/datasets/obb/) | ✔️ | ✔️ | -| [YOLOv8 Segmentation](https://docs.ultralytics.com/datasets/segment/) | ✔️ | ✔️ | -| [YOLOv8 Pose](https://docs.ultralytics.com/datasets/pose/) | ✔️ | ✔️ | -| [YOLOv8 Classification](https://docs.ultralytics.com/datasets/classify/) | ✔️ | ✔️ | +| [Ultralytics YOLO Detection](https://docs.ultralytics.com/datasets/detect/) | ✔️ | ✔️ | +| [Ultralytics YOLO Oriented Bounding Boxes](https://docs.ultralytics.com/datasets/obb/) | ✔️ | ✔️ | +| [Ultralytics YOLO Segmentation](https://docs.ultralytics.com/datasets/segment/) | ✔️ | ✔️ | +| [Ultralytics YOLO Pose](https://docs.ultralytics.com/datasets/pose/) | ✔️ | ✔️ | +| [Ultralytics YOLO Classification](https://docs.ultralytics.com/datasets/classify/) | ✔️ | ✔️ | diff --git a/cvat-cli/README.md b/cvat-cli/README.md index bbd98c0980c9..fcee05dae1c4 100644 --- a/cvat-cli/README.md +++ b/cvat-cli/README.md @@ -22,6 +22,11 @@ The following subcommands are supported: - `backup` - back up a task - `auto-annotate` - automatically annotate a task using a local function +- Functions (Enterprise/Cloud only): + - `create-native` - create a function that can be powered by an agent + - `delete` - delete a function + - `run-agent` - process requests for a native function + ## Installation `pip install cvat-cli` diff --git a/cvat-cli/requirements/base.txt b/cvat-cli/requirements/base.txt index 664017edebe3..a40cefb84e01 100644 --- a/cvat-cli/requirements/base.txt +++ b/cvat-cli/requirements/base.txt @@ -1,3 +1,5 @@ -cvat-sdk~=2.24.0 +cvat-sdk==2.25.0 + +attrs>=24.2.0 Pillow>=10.3.0 setuptools>=70.0.0 # not directly required, pinned by Snyk to avoid a vulnerability diff --git a/cvat-cli/requirements/development.txt b/cvat-cli/requirements/development.txt deleted file mode 100644 index 42a144087213..000000000000 --- a/cvat-cli/requirements/development.txt +++ /dev/null @@ -1,5 +0,0 @@ --r base.txt - -black>=24.1 -isort>=5.10.1 -pylint>=2.7.0 \ No newline at end of file diff --git a/cvat-cli/src/cvat_cli/__main__.py b/cvat-cli/src/cvat_cli/__main__.py index c93569182c08..7c649747cb31 100755 --- a/cvat-cli/src/cvat_cli/__main__.py +++ b/cvat-cli/src/cvat_cli/__main__.py @@ -11,7 +11,12 @@ from cvat_sdk import exceptions from ._internal.commands_all import COMMANDS -from ._internal.common import build_client, configure_common_arguments, configure_logger +from ._internal.common import ( + CriticalError, + build_client, + configure_common_arguments, + configure_logger, +) from ._internal.utils import popattr logger = logging.getLogger(__name__) @@ -29,7 +34,7 @@ def main(args: list[str] = None): try: with build_client(parsed_args, logger=logger) as client: popattr(parsed_args, "_executor")(client, **vars(parsed_args)) - except (exceptions.ApiException, urllib3.exceptions.HTTPError) as e: + except (exceptions.ApiException, urllib3.exceptions.HTTPError, CriticalError) as e: logger.critical(e) return 1 diff --git a/cvat-cli/src/cvat_cli/_internal/agent.py b/cvat-cli/src/cvat_cli/_internal/agent.py new file mode 100644 index 000000000000..820a758e54d2 --- /dev/null +++ b/cvat-cli/src/cvat_cli/_internal/agent.py @@ -0,0 +1,351 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import concurrent.futures +import json +import multiprocessing +import random +import secrets +import shutil +import tempfile +import time +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Optional + +import cvat_sdk.auto_annotation as cvataa +import cvat_sdk.datasets as cvatds +import urllib3.exceptions +from cvat_sdk import Client, models +from cvat_sdk.auto_annotation.driver import ( + _AnnotationMapper, + _DetectionFunctionContextImpl, + _LabelNameMapping, + _SpecNameMapping, +) +from cvat_sdk.exceptions import ApiException + +from .common import CriticalError, FunctionLoader + +FUNCTION_PROVIDER_NATIVE = "native" +FUNCTION_KIND_DETECTOR = "detector" + +_POLLING_INTERVAL_MEAN = timedelta(seconds=60) +_POLLING_INTERVAL_MAX_OFFSET = timedelta(seconds=10) + +_UPDATE_INTERVAL = timedelta(seconds=30) + + +class _RecoverableExecutor: + # A wrapper around ProcessPoolExecutor that recreates the underlying + # executor when a worker crashes. + def __init__(self, initializer, initargs): + self._mp_context = multiprocessing.get_context("spawn") + self._initializer = initializer + self._initargs = initargs + self._executor = self._new_executor() + + def _new_executor(self): + return concurrent.futures.ProcessPoolExecutor( + max_workers=1, + mp_context=self._mp_context, + initializer=self._initializer, + initargs=self._initargs, + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._executor.shutdown() + + def submit(self, func, /, *args, **kwargs): + return self._executor.submit(func, *args, **kwargs) + + def result(self, future: concurrent.futures.Future): + try: + return future.result() + except concurrent.futures.BrokenExecutor: + self._executor.shutdown() + self._executor = self._new_executor() + raise + + +_current_function: cvataa.DetectionFunction + + +def _worker_init(function_loader: FunctionLoader): + global _current_function + _current_function = function_loader.load() + + +def _worker_job_get_function_spec(): + return _current_function.spec + + +def _worker_job_detect(context, image): + return _current_function.detect(context, image) + + +class _Agent: + def __init__(self, client: Client, executor: _RecoverableExecutor, function_id: int): + self._rng = random.Random() # nosec + + self._client = client + self._executor = executor + self._function_id = function_id + self._function_spec = self._executor.result( + self._executor.submit(_worker_job_get_function_spec) + ) + + _, response = self._client.api_client.call_api( + "/api/functions/{function_id}", + "GET", + path_params={"function_id": self._function_id}, + ) + + remote_function = json.loads(response.data) + + self._validate_function_compatibility(remote_function) + + self._agent_id = secrets.token_hex(16) + self._client.logger.info("Agent starting with ID %r", self._agent_id) + + self._cached_task_id = None + + def _validate_function_compatibility(self, remote_function: dict) -> None: + function_id = remote_function["id"] + + if remote_function["provider"] != FUNCTION_PROVIDER_NATIVE: + raise CriticalError( + f"Function #{function_id} has provider {remote_function['provider']!r}. " + f"Agents can only be run for functions with provider {FUNCTION_PROVIDER_NATIVE!r}." + ) + + if isinstance(self._function_spec, cvataa.DetectionFunctionSpec): + self._validate_detection_function_compatibility(remote_function) + self._calculate_result_for_ar = self._calculate_result_for_detection_ar + else: + raise CriticalError( + f"Unsupported function spec type: {type(self._function_spec).__name__}" + ) + + def _validate_detection_function_compatibility(self, remote_function: dict) -> None: + incompatible_msg = ( + f"Function #{remote_function['id']} is incompatible with function object: " + ) + + if remote_function["kind"] != FUNCTION_KIND_DETECTOR: + raise CriticalError( + incompatible_msg + + f"kind is {remote_function['kind']!r} (expected {FUNCTION_KIND_DETECTOR!r})." + ) + + labels_by_name = {label.name: label for label in self._function_spec.labels} + + for remote_label in remote_function["labels_v2"]: + label = labels_by_name.get(remote_label["name"]) + + if not label: + raise CriticalError( + incompatible_msg + f"label {remote_label['name']!r} is not supported." + ) + + if ( + remote_label["type"] not in {"any", "unknown"} + and remote_label["type"] != label.type + ): + raise CriticalError( + incompatible_msg + + f"label {remote_label['name']!r} has type {remote_label['type']!r}, " + f"but the function object expects type {label.type!r}." + ) + + if remote_label["attributes"]: + raise CriticalError( + incompatible_msg + + f"label {remote_label['name']!r} has attributes, which is not supported." + ) + + def _wait_between_polls(self): + # offset the interval randomly to avoid synchronization between workers + max_offset_sec = _POLLING_INTERVAL_MAX_OFFSET.total_seconds() + offset_sec = self._rng.uniform(-max_offset_sec, max_offset_sec) + time.sleep(_POLLING_INTERVAL_MEAN.total_seconds() + offset_sec) + + def run(self, *, burst: bool) -> None: + if burst: + while ar_assignment := self._poll_for_ar(): + self._process_ar(ar_assignment) + self._client.logger.info("No annotation requests left in queue; exiting.") + else: + while True: + if ar_assignment := self._poll_for_ar(): + self._process_ar(ar_assignment) + else: + self._wait_between_polls() + + def _process_ar(self, ar_assignment: dict) -> None: + self._client.logger.info("Got annotation request assignment: %r", ar_assignment) + + ar_id = ar_assignment["ar_id"] + + try: + result = self._calculate_result_for_ar(ar_id, ar_assignment["ar_params"]) + + self._client.logger.info("Submitting result for AR %r...", ar_id) + self._client.api_client.call_api( + "/api/functions/queues/{queue_id}/requests/{request_id}/complete", + "POST", + path_params={"queue_id": f"function:{self._function_id}", "request_id": ar_id}, + body={"agent_id": self._agent_id, "annotations": result}, + ) + self._client.logger.info("AR %r completed", ar_id) + except Exception as ex: + self._client.logger.error("Failed to process AR %r", ar_id, exc_info=True) + + # Arbitrary exceptions may contain details of the client's system or code, which + # shouldn't be exposed to the server (and to users of the function). + # Therefore, we only produce a limited amount of detail, and only in known failure cases. + error_message = "Unknown error" + + if isinstance(ex, ApiException): + if ex.status: + error_message = f"Received HTTP status {ex.status}" + else: + error_message = "Failed an API call" + elif isinstance(ex, urllib3.exceptions.RequestError): + if isinstance(ex, urllib3.exceptions.MaxRetryError): + ex_type = type(ex.reason) + else: + ex_type = type(ex) + + error_message = f"Failed to make an HTTP request to {ex.url} ({ex_type.__name__})" + elif isinstance(ex, urllib3.exceptions.HTTPError): + error_message = "Failed to make an HTTP request" + elif isinstance(ex, cvataa.BadFunctionError): + error_message = "Underlying function returned incorrect result: " + str(ex) + elif isinstance(ex, concurrent.futures.BrokenExecutor): + error_message = "Worker process crashed" + + try: + self._client.api_client.call_api( + "/api/functions/queues/{queue_id}/requests/{request_id}/fail", + "POST", + path_params={ + "queue_id": f"function:{self._function_id}", + "request_id": ar_id, + }, + body={"agent_id": self._agent_id, "exc_info": error_message}, + ) + except Exception: + self._client.logger.error("Couldn't fail AR %r", ar_id, exc_info=True) + else: + self._client.logger.info("AR %r failed", ar_id) + + def _poll_for_ar(self) -> Optional[dict]: + while True: + self._client.logger.info("Trying to acquire an annotation request...") + try: + _, response = self._client.api_client.call_api( + "/api/functions/queues/{queue_id}/requests/acquire", + "POST", + path_params={"queue_id": f"function:{self._function_id}"}, + body={"agent_id": self._agent_id, "request_category": "batch"}, + ) + break + except (urllib3.exceptions.HTTPError, ApiException) as ex: + if isinstance(ex, ApiException) and ex.status and 400 <= ex.status < 500: + # We did something wrong; no point in retrying. + raise + + self._client.logger.error("Acquire request failed; will retry", exc_info=True) + self._wait_between_polls() + + response_data = json.loads(response.data) + return response_data["ar_assignment"] + + def _calculate_result_for_detection_ar( + self, ar_id: str, ar_params + ) -> models.PatchedLabeledDataRequest: + if ar_params["type"] != "annotate_task": + raise RuntimeError(f"Unsupported AR type: {ar_params['type']!r}") + + if ar_params["task"] != self._cached_task_id: + # To avoid uncontrolled disk usage, + # we'll only keep one task in the cache at a time. + self._client.logger.info("Switched to a new task; clearing the cache...") + if self._client.config.cache_dir.exists(): + shutil.rmtree(self._client.config.cache_dir) + + ds = cvatds.TaskDataset(self._client, ar_params["task"], load_annotations=False) + + self._cached_task_id = ar_params["task"] + + # Fetching the dataset might take a while, so do a progress update to let the server + # know we're still alive. + self._update_ar(ar_id, 0) + last_update_timestamp = datetime.now(tz=timezone.utc) + + mapping = ar_params["mapping"] + conv_mask_to_poly = ar_params["conv_mask_to_poly"] + + spec_nm = _SpecNameMapping( + labels={k: _LabelNameMapping(v["name"]) for k, v in mapping.items()} + ) + + mapper = _AnnotationMapper( + self._client.logger, + self._function_spec.labels, + ds.labels, + allow_unmatched_labels=False, + spec_nm=spec_nm, + conv_mask_to_poly=conv_mask_to_poly, + ) + + all_annotations = models.PatchedLabeledDataRequest(shapes=[]) + + for sample_index, sample in enumerate(ds.samples): + context = _DetectionFunctionContextImpl( + frame_name=sample.frame_name, + conf_threshold=ar_params["threshold"], + conv_mask_to_poly=conv_mask_to_poly, + ) + shapes = self._executor.result( + self._executor.submit(_worker_job_detect, context, sample.media.load_image()) + ) + + mapper.validate_and_remap(shapes, sample.frame_index) + all_annotations.shapes.extend(shapes) + + current_timestamp = datetime.now(tz=timezone.utc) + + if current_timestamp >= last_update_timestamp + _UPDATE_INTERVAL: + self._update_ar(ar_id, (sample_index + 1) / len(ds.samples)) + last_update_timestamp = current_timestamp + + return all_annotations + + def _update_ar(self, ar_id: str, progress: float) -> None: + self._client.logger.info("Updating AR %r progress to %.2f%%", ar_id, progress * 100) + self._client.api_client.call_api( + "/api/functions/queues/{queue_id}/requests/{request_id}/update", + "POST", + path_params={"queue_id": f"function:{self._function_id}", "request_id": ar_id}, + body={"agent_id": self._agent_id, "progress": progress}, + ) + + +def run_agent( + client: Client, function_loader: FunctionLoader, function_id: int, *, burst: bool +) -> None: + with ( + _RecoverableExecutor(initializer=_worker_init, initargs=[function_loader]) as executor, + tempfile.TemporaryDirectory() as cache_dir, + ): + client.config.cache_dir = Path(cache_dir, "cache") + client.logger.info("Will store cache at %s", client.config.cache_dir) + + agent = _Agent(client, executor, function_id) + agent.run(burst=burst) diff --git a/cvat-cli/src/cvat_cli/_internal/commands_all.py b/cvat-cli/src/cvat_cli/_internal/commands_all.py index 758d6b1d05e8..5f293f0ce06f 100644 --- a/cvat-cli/src/cvat_cli/_internal/commands_all.py +++ b/cvat-cli/src/cvat_cli/_internal/commands_all.py @@ -3,11 +3,13 @@ # SPDX-License-Identifier: MIT from .command_base import CommandGroup, DeprecatedAlias +from .commands_functions import COMMANDS as COMMANDS_FUNCTIONS from .commands_projects import COMMANDS as COMMANDS_PROJECTS from .commands_tasks import COMMANDS as COMMANDS_TASKS COMMANDS = CommandGroup(description="Perform operations on CVAT resources.") +COMMANDS.add_command("function", COMMANDS_FUNCTIONS) COMMANDS.add_command("project", COMMANDS_PROJECTS) COMMANDS.add_command("task", COMMANDS_TASKS) diff --git a/cvat-cli/src/cvat_cli/_internal/commands_functions.py b/cvat-cli/src/cvat_cli/_internal/commands_functions.py new file mode 100644 index 000000000000..76ccc56b05e9 --- /dev/null +++ b/cvat-cli/src/cvat_cli/_internal/commands_functions.py @@ -0,0 +1,138 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import argparse +import json +import textwrap +from collections.abc import Sequence + +import cvat_sdk.auto_annotation as cvataa +from cvat_sdk import Client + +from .agent import FUNCTION_KIND_DETECTOR, FUNCTION_PROVIDER_NATIVE, run_agent +from .command_base import CommandGroup +from .common import FunctionLoader, configure_function_implementation_arguments + +COMMANDS = CommandGroup(description="Perform operations on CVAT lambda functions.") + + +@COMMANDS.command_class("create-native") +class FunctionCreateNative: + description = textwrap.dedent( + """\ + Create a CVAT function that can be powered by an agent running the given local function. + """ + ) + + def configure_parser(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "name", + help="a human-readable name for the function", + ) + + configure_function_implementation_arguments(parser) + + def execute( + self, + client: Client, + *, + name: str, + function_loader: FunctionLoader, + ) -> None: + function = function_loader.load() + + remote_function = { + "provider": FUNCTION_PROVIDER_NATIVE, + "name": name, + } + + if isinstance(function.spec, cvataa.DetectionFunctionSpec): + remote_function["kind"] = FUNCTION_KIND_DETECTOR + remote_function["labels_v2"] = [] + + for label_spec in function.spec.labels: + if getattr(label_spec, "sublabels", None): + raise cvataa.BadFunctionError( + f"Function label {label_spec.name!r} has sublabels. This is currently not supported." + ) + + remote_function["labels_v2"].append( + { + "name": label_spec.name, + } + ) + else: + raise cvataa.BadFunctionError( + f"Unsupported function spec type: {type(function.spec).__name__}" + ) + + _, response = client.api_client.call_api( + "/api/functions", + "POST", + body=remote_function, + ) + + remote_function = json.loads(response.data) + + client.logger.info( + "Created function #%d: %s", remote_function["id"], remote_function["name"] + ) + print(remote_function["id"]) + + +@COMMANDS.command_class("delete") +class FunctionDelete: + description = "Delete a list of functions, ignoring those which don't exist." + + def configure_parser(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument("function_ids", type=int, help="IDs of functions to delete", nargs="+") + + def execute(self, client: Client, *, function_ids: Sequence[int]) -> None: + for function_id in function_ids: + _, response = client.api_client.call_api( + "/api/functions/{function_id}", + "DELETE", + path_params={"function_id": function_id}, + _check_status=False, + ) + + if 200 <= response.status <= 299: + client.logger.info(f"Function #{function_id} deleted") + elif response.status == 404: + client.logger.warning(f"Function #{function_id} not found") + else: + client.logger.error( + f"Failed to delete function #{function_id}: " + f"{response.msg} (status {response.status})" + ) + + +@COMMANDS.command_class("run-agent") +class FunctionRunAgent: + description = "Process requests for a given native function, indefinitely." + + def configure_parser(self, parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "function_id", + type=int, + help="ID of the function to process requests for", + ) + + configure_function_implementation_arguments(parser) + + parser.add_argument( + "--burst", + action="store_true", + help="process all pending requests and then exit", + ) + + def execute( + self, + client: Client, + *, + function_id: int, + function_loader: FunctionLoader, + burst: bool, + ) -> None: + run_agent(client, function_loader, function_id, burst=burst) diff --git a/cvat-cli/src/cvat_cli/_internal/commands_tasks.py b/cvat-cli/src/cvat_cli/_internal/commands_tasks.py index 8c6782887d97..cbe2139cf457 100644 --- a/cvat-cli/src/cvat_cli/_internal/commands_tasks.py +++ b/cvat-cli/src/cvat_cli/_internal/commands_tasks.py @@ -5,12 +5,9 @@ from __future__ import annotations import argparse -import importlib -import importlib.util import textwrap from collections.abc import Sequence -from pathlib import Path -from typing import Any, Optional +from typing import Optional import cvat_sdk.auto_annotation as cvataa from attr.converters import to_bool @@ -19,13 +16,8 @@ from cvat_sdk.core.proxies.tasks import ResourceType from .command_base import CommandGroup, GenericCommand, GenericDeleteCommand, GenericListCommand -from .parsers import ( - BuildDictAction, - parse_function_parameter, - parse_label_arg, - parse_resource_type, - parse_threshold, -) +from .common import FunctionLoader, configure_function_implementation_arguments +from .parsers import parse_label_arg, parse_resource_type, parse_threshold COMMANDS = CommandGroup(description="Perform operations on CVAT tasks.") @@ -416,30 +408,7 @@ class TaskAutoAnnotate: def configure_parser(self, parser: argparse.ArgumentParser) -> None: parser.add_argument("task_id", type=int, help="task ID") - function_group = parser.add_mutually_exclusive_group(required=True) - - function_group.add_argument( - "--function-module", - metavar="MODULE", - help="qualified name of a module to use as the function", - ) - - function_group.add_argument( - "--function-file", - metavar="PATH", - type=Path, - help="path to a Python source file to use as the function", - ) - - parser.add_argument( - "--function-parameter", - "-p", - metavar="NAME=TYPE:VALUE", - type=parse_function_parameter, - action=BuildDictAction, - dest="function_parameters", - help="parameter for the function", - ) + configure_function_implementation_arguments(parser) parser.add_argument( "--clear-existing", @@ -471,29 +440,13 @@ def execute( client: Client, *, task_id: int, - function_module: Optional[str] = None, - function_file: Optional[Path] = None, - function_parameters: dict[str, Any], + function_loader: FunctionLoader, clear_existing: bool = False, allow_unmatched_labels: bool = False, conf_threshold: Optional[float], conv_mask_to_poly: bool, ) -> None: - if function_module is not None: - function = importlib.import_module(function_module) - elif function_file is not None: - module_spec = importlib.util.spec_from_file_location("__cvat_function__", function_file) - function = importlib.util.module_from_spec(module_spec) - module_spec.loader.exec_module(function) - else: - assert False, "function identification arguments missing" - - if hasattr(function, "create"): - # this is actually a function factory - function = function.create(**function_parameters) - else: - if function_parameters: - raise TypeError("function takes no parameters") + function = function_loader.load() cvataa.annotate_task( client, diff --git a/cvat-cli/src/cvat_cli/_internal/common.py b/cvat-cli/src/cvat_cli/_internal/common.py index 6f37e3d74eaa..e07d85c9b65e 100644 --- a/cvat-cli/src/cvat_cli/_internal/common.py +++ b/cvat-cli/src/cvat_cli/_internal/common.py @@ -5,17 +5,28 @@ import argparse import getpass +import importlib +import importlib.util import logging import os import sys from http.client import HTTPConnection +from pathlib import Path +from typing import Any, Optional +import attrs +import cvat_sdk.auto_annotation as cvataa from cvat_sdk.core.client import Client, Config from ..version import VERSION +from .parsers import BuildDictAction, parse_function_parameter from .utils import popattr +class CriticalError(Exception): + pass + + def get_auth(s): """Parse USER[:PASS] strings and prompt for password if none was supplied.""" @@ -102,3 +113,77 @@ def build_client(parsed_args: argparse.Namespace, logger: logging.Logger) -> Cli client.organization_slug = popattr(parsed_args, "organization") return client + + +def configure_function_implementation_arguments(parser: argparse.ArgumentParser) -> None: + function_group = parser.add_mutually_exclusive_group(required=True) + + function_group.add_argument( + "--function-module", + metavar="MODULE", + help="qualified name of a module to use as the function", + ) + + function_group.add_argument( + "--function-file", + metavar="PATH", + type=Path, + help="path to a Python source file to use as the function", + ) + + parser.add_argument( + "--function-parameter", + "-p", + metavar="NAME=TYPE:VALUE", + type=parse_function_parameter, + action=BuildDictAction, + dest="function_parameters", + help="parameter for the function", + ) + + original_executor = parser.get_default("_executor") + + def execute_with_function_loader( + client, + *, + function_module: Optional[str], + function_file: Optional[Path], + function_parameters: dict[str, Any], + **kwargs, + ): + original_executor( + client, + function_loader=FunctionLoader(function_module, function_file, function_parameters), + **kwargs, + ) + + parser.set_defaults(_executor=execute_with_function_loader) + + +@attrs.frozen +class FunctionLoader: + function_module: Optional[str] + function_file: Optional[Path] + function_parameters: dict[str, Any] + + def __attrs_post_init__(self): + assert self.function_module is not None or self.function_file is not None + + def load(self) -> cvataa.DetectionFunction: + if self.function_module is not None: + function = importlib.import_module(self.function_module) + else: + module_spec = importlib.util.spec_from_file_location( + "__cvat_function__", self.function_file + ) + function = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(function) + + if hasattr(function, "create"): + # this is actually a function factory + function = function.create(**self.function_parameters) + else: + if self.function_parameters: + raise TypeError("function takes no parameters") + + return function diff --git a/cvat-cli/src/cvat_cli/version.py b/cvat-cli/src/cvat_cli/version.py index 203e6c4bc9b2..3b01a28f6f60 100644 --- a/cvat-cli/src/cvat_cli/version.py +++ b/cvat-cli/src/cvat_cli/version.py @@ -1 +1 @@ -VERSION = "2.24.0" +VERSION = "2.25.0" diff --git a/cvat-core/src/annotations-actions/base-action.ts b/cvat-core/src/annotations-actions/base-action.ts index 2ec2148b24c7..8a0abba4b32d 100644 --- a/cvat-core/src/annotations-actions/base-action.ts +++ b/cvat-core/src/annotations-actions/base-action.ts @@ -53,7 +53,7 @@ export function validateClientIDs(collection: Partial) { collection.tracks ?? [], collection.tags ?? [], ).forEach((object) => { - // clientID is required to correct collection filtering and commiting in annotations actions logic + // clientID is required to correct collection filtering and committing in annotations actions logic if (typeof object.clientID !== 'number') { throw new Error('ClientID is undefined when running annotations action, but required'); } diff --git a/cvat-core/src/annotations-actions/base-shapes-action.ts b/cvat-core/src/annotations-actions/base-shapes-action.ts index 9eb65f052ee4..e5223f085d2d 100644 --- a/cvat-core/src/annotations-actions/base-shapes-action.ts +++ b/cvat-core/src/annotations-actions/base-shapes-action.ts @@ -129,7 +129,7 @@ export async function run( } } - await showMessageWithPause('Commiting handled objects', 100, 1500); + await showMessageWithPause('Committing handled objects', 100, 1500); if (cancelled()) { return; } diff --git a/cvat-core/src/frames.ts b/cvat-core/src/frames.ts index 3305edfc5aab..b772aca4ca4a 100644 --- a/cvat-core/src/frames.ts +++ b/cvat-core/src/frames.ts @@ -16,7 +16,6 @@ import config from './config'; // frame storage by job id const frameDataCache: Record; getChunk: (chunkIndex: number, quality: ChunkQuality) => Promise; + getMeta: () => Promise; }> = {}; // frame meta data storage by job id const frameMetaCache: Record> = {}; +enum DeletedFrameState { + DELETED = 'deleted', + RESTORED = 'restored', +} + +interface FramesMetaDataUpdatedData { + deletedFrames: Record; +} + export class FramesMetaData { public chunkSize: number; public deletedFrames: Record; @@ -82,10 +91,13 @@ export class FramesMetaData { if (Object.prototype.hasOwnProperty.call(data, property) && property in initialData) { if (property === 'deleted_frames') { const update = (frame: string, remove: boolean): void => { - if (this.#updateTrigger.get(`deletedFrames:${frame}:${!remove}`)) { - this.#updateTrigger.resetField(`deletedFrames:${frame}:${!remove}`); + const [state, oppositeState] = remove ? + [DeletedFrameState.DELETED, DeletedFrameState.RESTORED] : + [DeletedFrameState.RESTORED, DeletedFrameState.DELETED]; + if (this.#updateTrigger.get(`deletedFrames:${frame}:${oppositeState}`)) { + this.#updateTrigger.resetField(`deletedFrames:${frame}:${oppositeState}`); } else { - this.#updateTrigger.update(`deletedFrames:${frame}:${remove}`); + this.#updateTrigger.update(`deletedFrames:${frame}:${state}`); } }; @@ -178,8 +190,17 @@ export class FramesMetaData { return (dataFrameNumber - this.startFrame) / this.frameStep; } - getUpdated(): Record { - return this.#updateTrigger.getUpdated(this); + getUpdated(): FramesMetaDataUpdatedData { + const updatedFields = this.#updateTrigger.getUpdated(this); + const deletedFrames: FramesMetaDataUpdatedData['deletedFrames'] = {}; + for (const key in updatedFields) { + if (Object.hasOwn(updatedFields, key) && key.startsWith('deletedFrames')) { + const [, frame, state] = key.split(':'); + deletedFrames[frame] = state; + } + } + + return { deletedFrames }; } resetUpdated(): void { @@ -340,17 +361,18 @@ class PrefetchAnalyzer { } Object.defineProperty(FrameData.prototype.data, 'implementation', { - value(this: FrameData, onServerRequest) { + async value(this: FrameData, onServerRequest) { + const { + provider, prefetchAnalyzer, chunkSize, jobStartFrame, + decodeForward, forwardStep, decodedBlocksCacheSize, + } = frameDataCache[this.jobID]; + const meta = await frameDataCache[this.jobID].getMeta(); + return new Promise<{ renderWidth: number; renderHeight: number; imageData: ImageBitmap | Blob; } | Blob>((resolve, reject) => { - const { - meta, provider, prefetchAnalyzer, chunkSize, jobStartFrame, - decodeForward, forwardStep, decodedBlocksCacheSize, - } = frameDataCache[this.jobID]; - const requestId = +_.uniqueId(); const requestedDataFrameNumber = meta.getDataFrameNumber(this.number - jobStartFrame); const chunkIndex = meta.getFrameChunkIndex(requestedDataFrameNumber); @@ -536,6 +558,34 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { writable: false, }); +function mergeMetaData( + nextData: SerializedFramesMetaData, + previousData?: Promise, +): Promise { + const framesMetaData = new FramesMetaData({ + ...nextData, + deleted_frames: Object.fromEntries(nextData.deleted_frames.map((_frame) => [_frame, true])), + }); + + if (previousData instanceof Promise) { + return previousData.then((prevMeta) => { + const updatedFields = prevMeta.getUpdated(); + const updatedDeletedFrames = updatedFields.deletedFrames; + for (const [frame, state] of Object.entries(updatedDeletedFrames)) { + if (state === DeletedFrameState.DELETED) { + framesMetaData.deletedFrames[frame] = true; + } else if (state === DeletedFrameState.RESTORED) { + delete framesMetaData.deletedFrames[frame]; + } + } + + return framesMetaData; + }); + } + + return Promise.resolve(framesMetaData); +} + export function getFramesMeta(type: 'job' | 'task', id: number, forceReload = false): Promise { if (type === 'task') { // we do not cache task meta currently. So, each new call will results to the server request @@ -551,11 +601,11 @@ export function getFramesMeta(type: 'job' | 'task', id: number, forceReload = fa const previousCache = frameMetaCache[id]; frameMetaCache[id] = new Promise((resolve, reject) => { serverProxy.frames.getMeta('job', id).then((serialized) => { - const framesMetaData = new FramesMetaData({ - ...serialized, - deleted_frames: Object.fromEntries(serialized.deleted_frames.map((_frame) => [_frame, true])), + // When we get new framesMetaData from server there can be some unsaved data + // here we merge new meta data with cached one + mergeMetaData(serialized, previousCache).then((mergedData) => { + resolve(mergedData); }); - resolve(framesMetaData); }).catch((error: unknown) => { delete frameMetaCache[id]; if (previousCache instanceof Promise) { @@ -588,8 +638,9 @@ function saveJobMeta(meta: FramesMetaData, jobID: number): Promise { + const { mode, jobStartFrame } = frameDataCache[jobID]; + const meta = await frameDataCache[jobID].getMeta(); let frameMeta = null; if (mode === 'interpolation' && meta.frames.length === 1) { // video tasks have 1 frame info, but image tasks will have many infos @@ -616,12 +667,12 @@ async function refreshJobCacheIfOutdated(jobID: number): Promise { if (isOutdated) { // get metadata again if outdated + const prevMeta = await cached.getMeta(); const meta = await getFramesMeta('job', jobID, true); - if (new Date(meta.chunksUpdatedDate) > new Date(cached.meta.chunksUpdatedDate)) { + if (new Date(meta.chunksUpdatedDate) > new Date(prevMeta.chunksUpdatedDate)) { // chunks were re-defined. Existing data not relevant anymore // currently we only re-write meta, remove all cached frames from provider and clear cached context images // other parameters (e.g. chunkSize) are not supposed to be changed - cached.meta = meta; cached.provider.cleanup(Number.MAX_SAFE_INTEGER); for (const frame of Object.keys(cached.contextCache)) { for (const image of Object.values(cached.contextCache[+frame].data)) { @@ -636,7 +687,12 @@ async function refreshJobCacheIfOutdated(jobID: number): Promise { } } -export function getContextImage(jobID: number, frame: number): Promise> { +export async function getContextImage(jobID: number, frame: number): Promise> { + const frameData = frameDataCache[jobID]; + const meta = await frameData.getMeta(); + const requestId = frame; + const { jobStartFrame } = frameData; + const { related_files: relatedFiles } = meta.frames[frame - jobStartFrame]; return new Promise>((resolve, reject) => { if (!(jobID in frameDataCache)) { reject(new Error( @@ -644,11 +700,6 @@ export function getContextImage(jobID: number, frame: number): Promise { + const cached = frameMetaCache[jobID]; + if (!(cached instanceof Promise)) { + throw new Error('Frame meta data is not initialized'); + } + return cached; + }, }; } @@ -803,25 +860,27 @@ export async function getFrame( // Thus, it is better to only call `refreshJobCacheIfOutdated` from getFrame() await refreshJobCacheIfOutdated(jobID); - const frameMeta = getFrameMeta(jobID, frame); + const frameMeta = await getFrameMeta(jobID, frame); frameDataCache[jobID].provider.setRenderSize(frameMeta.width, frameMeta.height); frameDataCache[jobID].decodeForward = isPlaying; frameDataCache[jobID].forwardStep = step; + const meta = await frameDataCache[jobID].getMeta(); + return new FrameData({ width: frameMeta.width, height: frameMeta.height, name: frameMeta.name, related_files: frameMeta.related_files, frameNumber: frame, - deleted: frame in frameDataCache[jobID].meta.deletedFrames, + deleted: frame in meta.deletedFrames, jobID, }); } export async function getDeletedFrames(instanceType: 'job' | 'task', id: number): Promise> { if (instanceType === 'job') { - const { meta } = frameDataCache[id]; + const meta = await frameDataCache[id].getMeta(); return meta.deletedFrames; } @@ -900,12 +959,13 @@ export function getCachedChunks(jobID: number): number[] { return frameDataCache[jobID].provider.cachedChunks(true); } -export function getJobFrameNumbers(jobID: number): number[] { +export async function getJobFrameNumbers(jobID: number): Promise { if (!(jobID in frameDataCache)) { return []; } - const { meta, jobStartFrame } = frameDataCache[jobID]; + const { jobStartFrame } = frameDataCache[jobID]; + const meta = await frameDataCache[jobID].getMeta(); return meta.getSegmentFrameNumbers(jobStartFrame); } diff --git a/cvat-core/src/lambda-manager.ts b/cvat-core/src/lambda-manager.ts index 66733d7ed236..cfed3d474329 100644 --- a/cvat-core/src/lambda-manager.ts +++ b/cvat-core/src/lambda-manager.ts @@ -8,12 +8,6 @@ import { ArgumentError } from './exceptions'; import MLModel from './ml-model'; import { RQStatus, ShapeType } from './enums'; -export interface ModelProvider { - name: string; - icon: string; - attributes: Record; -} - export interface InteractorResults { mask: number[][]; points?: [number, number][]; diff --git a/cvat-core/src/quality-settings.ts b/cvat-core/src/quality-settings.ts index 7c591e371cc4..bc553105c181 100644 --- a/cvat-core/src/quality-settings.ts +++ b/cvat-core/src/quality-settings.ts @@ -38,7 +38,7 @@ export default class QualitySettings { #objectVisibilityThreshold: number; #panopticComparison: boolean; #compareAttributes: boolean; - #matchEmptyFrames: boolean; + #emptyIsAnnotated: boolean; #descriptions: Record; constructor(initialData: SerializedQualitySettingsData) { @@ -60,7 +60,7 @@ export default class QualitySettings { this.#objectVisibilityThreshold = initialData.object_visibility_threshold; this.#panopticComparison = initialData.panoptic_comparison; this.#compareAttributes = initialData.compare_attributes; - this.#matchEmptyFrames = initialData.match_empty_frames; + this.#emptyIsAnnotated = initialData.empty_is_annotated; this.#descriptions = initialData.descriptions; } @@ -200,12 +200,12 @@ export default class QualitySettings { this.#maxValidationsPerJob = newVal; } - get matchEmptyFrames(): boolean { - return this.#matchEmptyFrames; + get emptyIsAnnotated(): boolean { + return this.#emptyIsAnnotated; } - set matchEmptyFrames(newVal: boolean) { - this.#matchEmptyFrames = newVal; + set emptyIsAnnotated(newVal: boolean) { + this.#emptyIsAnnotated = newVal; } get descriptions(): Record { @@ -236,7 +236,7 @@ export default class QualitySettings { target_metric: this.#targetMetric, target_metric_threshold: this.#targetMetricThreshold, max_validations_per_job: this.#maxValidationsPerJob, - match_empty_frames: this.#matchEmptyFrames, + empty_is_annotated: this.#emptyIsAnnotated, }; return result; diff --git a/cvat-core/src/server-response-types.ts b/cvat-core/src/server-response-types.ts index ea97c0730aaa..ef635d12004e 100644 --- a/cvat-core/src/server-response-types.ts +++ b/cvat-core/src/server-response-types.ts @@ -258,7 +258,7 @@ export interface SerializedQualitySettingsData { object_visibility_threshold?: number; panoptic_comparison?: boolean; compare_attributes?: boolean; - match_empty_frames?: boolean; + empty_is_annotated?: boolean; descriptions?: Record; } diff --git a/cvat-core/src/session-implementation.ts b/cvat-core/src/session-implementation.ts index 7ea9e326fb8b..a5c008605749 100644 --- a/cvat-core/src/session-implementation.ts +++ b/cvat-core/src/session-implementation.ts @@ -265,7 +265,7 @@ export function implementJob(Job: typeof JobClass): typeof JobClass { value: function includedFramesImplementation( this: JobClass, ): ReturnType { - return Promise.resolve(getJobFrameNumbers(this.id)); + return getJobFrameNumbers(this.id); }, }); diff --git a/cvat-sdk/cvat_sdk/auto_annotation/driver.py b/cvat-sdk/cvat_sdk/auto_annotation/driver.py index 5ffdb36f5bee..42e17f93b6b2 100644 --- a/cvat-sdk/cvat_sdk/auto_annotation/driver.py +++ b/cvat-sdk/cvat_sdk/auto_annotation/driver.py @@ -23,24 +23,62 @@ class BadFunctionError(Exception): """ +@attrs.frozen +class _SublabelNameMapping: + name: str + + +@attrs.frozen +class _LabelNameMapping(_SublabelNameMapping): + sublabels: Optional[Mapping[str, _SublabelNameMapping]] = attrs.field( + kw_only=True, default=None + ) + + def map_sublabel(self, name: str): + if self.sublabels is None: + return _SublabelNameMapping(name) + + return self.sublabels.get(name) + + +@attrs.frozen +class _SpecNameMapping: + labels: Optional[Mapping[str, _LabelNameMapping]] = attrs.field(kw_only=True, default=None) + + def map_label(self, name: str): + if self.labels is None: + return _LabelNameMapping(name) + + return self.labels.get(name) + + class _AnnotationMapper: @attrs.frozen - class _MappedLabel: + class _LabelIdMapping: id: int - sublabel_mapping: Mapping[int, Optional[int]] + sublabels: Mapping[int, Optional[int]] expected_num_elements: int = 0 - _label_mapping: Mapping[int, Optional[_MappedLabel]] + _label_id_mappings: Mapping[int, Optional[_LabelIdMapping]] - def _build_mapped_label( - self, fun_label: models.ILabel, ds_labels_by_name: Mapping[str, models.ILabel] - ) -> Optional[_MappedLabel]: + def _build_label_id_mapping( + self, + fun_label: models.ILabel, + ds_labels_by_name: Mapping[str, models.ILabel], + *, + allow_unmatched_labels: bool, + spec_nm: _SpecNameMapping, + ) -> Optional[_LabelIdMapping]: if getattr(fun_label, "attributes", None): raise BadFunctionError(f"label attributes are currently not supported") - ds_label = ds_labels_by_name.get(fun_label.name) + label_nm = spec_nm.map_label(fun_label.name) + if label_nm is None: + return None + + ds_label = ds_labels_by_name.get(label_nm.name) if ds_label is None: - if not self._allow_unmatched_labels: + if not allow_unmatched_labels: raise BadFunctionError(f"label {fun_label.name!r} is not in dataset") self._logger.info( @@ -71,9 +109,14 @@ def _build_mapped_label( f"sublabel {fun_sl.name!r} of label {fun_label.name!r} has same ID as another sublabel ({fun_sl.id})" ) - ds_sl = ds_sublabels_by_name.get(fun_sl.name) + sublabel_nm = label_nm.map_sublabel(fun_sl.name) + if sublabel_nm is None: + sl_map[fun_sl.id] = None + continue + + ds_sl = ds_sublabels_by_name.get(sublabel_nm.name) if not ds_sl: - if not self._allow_unmatched_labels: + if not allow_unmatched_labels: raise BadFunctionError( f"sublabel {fun_sl.name!r} of label {fun_label.name!r} is not in dataset" ) @@ -88,8 +131,8 @@ def _build_mapped_label( sl_map[fun_sl.id] = ds_sl.id - return self._MappedLabel( - ds_label.id, sublabel_mapping=sl_map, expected_num_elements=len(ds_label.sublabels) + return self._LabelIdMapping( + ds_label.id, sublabels=sl_map, expected_num_elements=len(ds_label.sublabels) ) def __init__( @@ -100,26 +143,29 @@ def __init__( *, allow_unmatched_labels: bool, conv_mask_to_poly: bool, + spec_nm: _SpecNameMapping = _SpecNameMapping(), ) -> None: self._logger = logger - self._allow_unmatched_labels = allow_unmatched_labels self._conv_mask_to_poly = conv_mask_to_poly ds_labels_by_name = {ds_label.name: ds_label for ds_label in ds_labels} - self._label_mapping = {} + self._label_id_mappings = {} for fun_label in fun_labels: if not hasattr(fun_label, "id"): raise BadFunctionError(f"label {fun_label.name!r} has no ID") - if fun_label.id in self._label_mapping: + if fun_label.id in self._label_id_mappings: raise BadFunctionError( f"label {fun_label.name} has same ID as another label ({fun_label.id})" ) - self._label_mapping[fun_label.id] = self._build_mapped_label( - fun_label, ds_labels_by_name + self._label_id_mappings[fun_label.id] = self._build_label_id_mapping( + fun_label, + ds_labels_by_name, + allow_unmatched_labels=allow_unmatched_labels, + spec_nm=spec_nm, ) def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame: int) -> None: @@ -141,16 +187,16 @@ def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame: shape.frame = ds_frame try: - mapped_label = self._label_mapping[shape.label_id] + label_id_mapping = self._label_id_mappings[shape.label_id] except KeyError: raise BadFunctionError( f"function output shape with unknown label ID ({shape.label_id})" ) - if not mapped_label: + if not label_id_mapping: continue - shape.label_id = mapped_label.id + shape.label_id = label_id_mapping.id if getattr(shape, "attributes", None): raise BadFunctionError( @@ -184,7 +230,7 @@ def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame: ) try: - mapped_sl_id = mapped_label.sublabel_mapping[element.label_id] + mapped_sl_id = label_id_mapping.sublabels[element.label_id] except KeyError: raise BadFunctionError( f"function output shape with unknown sublabel ID ({element.label_id})" @@ -204,14 +250,14 @@ def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame: new_elements.append(element) - if len(new_elements) != mapped_label.expected_num_elements: + if len(new_elements) != label_id_mapping.expected_num_elements: # new_elements could only be shorter than expected, # because the reverse would imply that there are more distinct sublabel IDs # than are actually defined in the dataset. - assert len(new_elements) < mapped_label.expected_num_elements + assert len(new_elements) < label_id_mapping.expected_num_elements raise BadFunctionError( - f"function output skeleton with fewer elements than expected ({len(new_elements)} vs {mapped_label.expected_num_elements})" + f"function output skeleton with fewer elements than expected ({len(new_elements)} vs {label_id_mapping.expected_num_elements})" ) shape.elements[:] = new_elements diff --git a/cvat-sdk/gen/generate.sh b/cvat-sdk/gen/generate.sh index f4d78e868601..855a7d71c4f0 100755 --- a/cvat-sdk/gen/generate.sh +++ b/cvat-sdk/gen/generate.sh @@ -8,7 +8,7 @@ set -e GENERATOR_VERSION="v6.0.1" -VERSION="2.24.0" +VERSION="2.25.0" LIB_NAME="cvat_sdk" LAYER1_LIB_NAME="${LIB_NAME}/api_client" DST_DIR="$(cd "$(dirname -- "$0")/.." && pwd)" diff --git a/cvat-sdk/gen/requirements.txt b/cvat-sdk/gen/requirements.txt index 18f397e59dc6..54c28f0b0007 100644 --- a/cvat-sdk/gen/requirements.txt +++ b/cvat-sdk/gen/requirements.txt @@ -1,5 +1,4 @@ # can't have a dependency on base.txt, because it depends on the generated file inflection >= 0.5.1 -isort>=5.10.1 ruamel.yaml>=0.17.21 diff --git a/cvat-sdk/pyproject.toml b/cvat-sdk/pyproject.toml index ce8cba3ffba6..8d3fb7787504 100644 --- a/cvat-sdk/pyproject.toml +++ b/cvat-sdk/pyproject.toml @@ -7,3 +7,4 @@ profile = "black" forced_separate = ["tests"] line_length = 100 skip_gitignore = true # align tool behavior with Black +known_first_party = ["cvat_sdk"] diff --git a/cvat-ui/src/components/quality-control/quality-control-page.tsx b/cvat-ui/src/components/quality-control/quality-control-page.tsx index cbaa26a8dd09..afa166f6f5fa 100644 --- a/cvat-ui/src/components/quality-control/quality-control-page.tsx +++ b/cvat-ui/src/components/quality-control/quality-control-page.tsx @@ -223,7 +223,7 @@ function QualityControlPage(): JSX.Element { settings.lowOverlapThreshold = values.lowOverlapThreshold / 100; settings.iouThreshold = values.iouThreshold / 100; settings.compareAttributes = values.compareAttributes; - settings.matchEmptyFrames = values.matchEmptyFrames; + settings.emptyIsAnnotated = values.emptyIsAnnotated; settings.oksSigma = values.oksSigma / 100; settings.pointSizeBase = values.pointSizeBase; diff --git a/cvat-ui/src/components/quality-control/task-quality/quality-settings-form.tsx b/cvat-ui/src/components/quality-control/task-quality/quality-settings-form.tsx index 87a727f9772b..b5218475b418 100644 --- a/cvat-ui/src/components/quality-control/task-quality/quality-settings-form.tsx +++ b/cvat-ui/src/components/quality-control/task-quality/quality-settings-form.tsx @@ -34,7 +34,7 @@ export default function QualitySettingsForm(props: Readonly): JSX.Element lowOverlapThreshold: settings.lowOverlapThreshold * 100, iouThreshold: settings.iouThreshold * 100, compareAttributes: settings.compareAttributes, - matchEmptyFrames: settings.matchEmptyFrames, + emptyIsAnnotated: settings.emptyIsAnnotated, oksSigma: settings.oksSigma * 100, pointSizeBase: settings.pointSizeBase, @@ -81,7 +81,7 @@ export default function QualitySettingsForm(props: Readonly): JSX.Element {makeTooltipFragment('Target metric', targetMetricDescription)} {makeTooltipFragment('Target metric threshold', settings.descriptions.targetMetricThreshold)} {makeTooltipFragment('Compare attributes', settings.descriptions.compareAttributes)} - {makeTooltipFragment('Match empty frames', settings.descriptions.matchEmptyFrames)} + {makeTooltipFragment('Empty frames are annotated', settings.descriptions.emptyIsAnnotated)} , ); @@ -198,12 +198,12 @@ export default function QualitySettingsForm(props: Readonly): JSX.Element - Match empty frames + Empty frames are annotated diff --git a/cvat-ui/src/cvat-core-wrapper.ts b/cvat-ui/src/cvat-core-wrapper.ts index ba7b47fcfa54..fc255dd53324 100644 --- a/cvat-ui/src/cvat-core-wrapper.ts +++ b/cvat-ui/src/cvat-core-wrapper.ts @@ -10,7 +10,6 @@ import ObjectState from 'cvat-core/src/object-state'; import Webhook from 'cvat-core/src/webhook'; import MLModel from 'cvat-core/src/ml-model'; import CloudStorage from 'cvat-core/src/cloud-storage'; -import { ModelProvider } from 'cvat-core/src/lambda-manager'; import { Label, Attribute, } from 'cvat-core/src/labels'; @@ -121,7 +120,6 @@ export type { SerializedAttribute, SerializedLabel, StorageData, - ModelProvider, APIWrapperEnterOptions, QualitySummary, CVATCore, diff --git a/cvat-ui/src/utils/is-able-to-change-frame.ts b/cvat-ui/src/utils/is-able-to-change-frame.ts index 3cbc127a8a86..d86b6357cd88 100644 --- a/cvat-ui/src/utils/is-able-to-change-frame.ts +++ b/cvat-ui/src/utils/is-able-to-change-frame.ts @@ -21,7 +21,7 @@ export default function isAbleToChangeFrame(frame?: number): boolean { if (typeof frame === 'number') { if (meta.includedFrames) { // frame argument comes in job coordinates - // hovewer includedFrames contains absolute data values + // however includedFrames contains absolute data values frameInTheJob = meta.includedFrames.includes(meta.getDataFrameNumber(frame - job.startFrame)); } diff --git a/cvat/__init__.py b/cvat/__init__.py index 7586ced41d72..f9c5ae2659e5 100644 --- a/cvat/__init__.py +++ b/cvat/__init__.py @@ -4,6 +4,6 @@ from cvat.utils.version import get_version -VERSION = (2, 24, 0, "final", 0) +VERSION = (2, 25, 0, "final", 0) __version__ = get_version(VERSION) diff --git a/cvat/apps/dataset_manager/annotation.py b/cvat/apps/dataset_manager/annotation.py index 4ea10ba9619d..943e53d003e3 100644 --- a/cvat/apps/dataset_manager/annotation.py +++ b/cvat/apps/dataset_manager/annotation.py @@ -3,19 +3,19 @@ # # SPDX-License-Identifier: MIT -from copy import copy, deepcopy - import math from collections.abc import Container, Sequence +from copy import copy, deepcopy +from itertools import chain from typing import Optional + import numpy as np -from itertools import chain from scipy.optimize import linear_sum_assignment from shapely import geometry -from cvat.apps.engine.models import ShapeType, DimensionType -from cvat.apps.engine.serializers import LabeledDataSerializer from cvat.apps.dataset_manager.util import faster_deepcopy +from cvat.apps.engine.models import DimensionType, ShapeType +from cvat.apps.engine.serializers import LabeledDataSerializer class AnnotationIR: diff --git a/cvat/apps/dataset_manager/bindings.py b/cvat/apps/dataset_manager/bindings.py index 8b759f7b6316..7fddcb198f35 100644 --- a/cvat/apps/dataset_manager/bindings.py +++ b/cvat/apps/dataset_manager/bindings.py @@ -16,29 +16,39 @@ from types import SimpleNamespace from typing import Any, Callable, Literal, NamedTuple, Optional, Union -from attrs.converters import to_bool import datumaro as dm import defusedxml.ElementTree as ET import rq from attr import attrib, attrs +from attrs.converters import to_bool from datumaro.components.format_detection import RejectionReason +from django.conf import settings from django.db.models import Prefetch, QuerySet from django.utils import timezone -from django.conf import settings from cvat.apps.dataset_manager.formats.utils import get_label_color from cvat.apps.dataset_manager.util import add_prefetch_fields from cvat.apps.engine import models -from cvat.apps.engine.frame_provider import TaskFrameProvider, FrameQuality, FrameOutputType -from cvat.apps.engine.models import (AttributeSpec, AttributeType, DimensionType, Job, - JobType, Label, LabelType, Project, SegmentType, ShapeType, - Task) -from cvat.apps.engine.rq_job_handler import RQJobMetaField +from cvat.apps.engine.frame_provider import FrameOutputType, FrameQuality, TaskFrameProvider from cvat.apps.engine.lazy_list import LazyList +from cvat.apps.engine.models import ( + AttributeSpec, + AttributeType, + DimensionType, + Job, + JobType, + Label, + LabelType, + Project, + SegmentType, + ShapeType, + Task, +) +from cvat.apps.engine.rq_job_handler import RQJobMetaField -from .annotation import AnnotationIR, AnnotationManager, TrackManager -from .formats.transformations import MaskConverter, EllipsesToMasks from ..engine.log import ServerLogManager +from .annotation import AnnotationIR, AnnotationManager, TrackManager +from .formats.transformations import EllipsesToMasks, MaskConverter slogger = ServerLogManager(__name__) @@ -2175,7 +2185,11 @@ def import_dm_annotations(dm_dataset: dm.Dataset, instance_data: Union[ProjectDa 'coco', 'coco_instances', 'coco_person_keypoints', - 'voc' + 'voc', + 'yolo_ultralytics_detection', + 'yolo_ultralytics_segmentation', + 'yolo_ultralytics_oriented_boxes', + 'yolo_ultralytics_pose', ] label_cat = dm_dataset.categories()[dm.AnnotationType.label] diff --git a/cvat/apps/dataset_manager/formats/__init__.py b/cvat/apps/dataset_manager/formats/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/cvat/apps/dataset_manager/formats/camvid.py b/cvat/apps/dataset_manager/formats/camvid.py index 75cea9e98bd4..e995c5f1075d 100644 --- a/cvat/apps/dataset_manager/formats/camvid.py +++ b/cvat/apps/dataset_manager/formats/camvid.py @@ -6,12 +6,11 @@ from datumaro.components.dataset import Dataset from pyunpack import Archive -from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, - import_dm_annotations) +from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, import_dm_annotations from cvat.apps.dataset_manager.util import make_zip_archive -from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons from .registry import dm_env, exporter, importer +from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons from .utils import make_colormap diff --git a/cvat/apps/dataset_manager/formats/cityscapes.py b/cvat/apps/dataset_manager/formats/cityscapes.py index ea39578ea3f3..dce977b94d1a 100644 --- a/cvat/apps/dataset_manager/formats/cityscapes.py +++ b/cvat/apps/dataset_manager/formats/cityscapes.py @@ -6,15 +6,18 @@ import os.path as osp from datumaro.components.dataset import Dataset -from datumaro.plugins.cityscapes_format import write_label_map +from datumaro.plugins.data_formats.cityscapes import write_label_map from pyunpack import Archive -from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset, - import_dm_annotations) +from cvat.apps.dataset_manager.bindings import ( + GetCVATDataExtractor, + detect_dataset, + import_dm_annotations, +) from cvat.apps.dataset_manager.util import make_zip_archive -from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons from .registry import dm_env, exporter, importer +from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons from .utils import make_colormap diff --git a/cvat/apps/dataset_manager/formats/coco.py b/cvat/apps/dataset_manager/formats/coco.py index 1d1a8ce4d0d5..cab74bcb42e1 100644 --- a/cvat/apps/dataset_manager/formats/coco.py +++ b/cvat/apps/dataset_manager/formats/coco.py @@ -5,17 +5,21 @@ import zipfile -from datumaro.components.dataset import Dataset from datumaro.components.annotation import AnnotationType -from datumaro.plugins.coco_format.importer import CocoImporter +from datumaro.components.dataset import Dataset +from datumaro.plugins.data_formats.coco.importer import CocoImporter from cvat.apps.dataset_manager.bindings import ( - GetCVATDataExtractor, NoMediaInAnnotationFileError, import_dm_annotations, detect_dataset + GetCVATDataExtractor, + NoMediaInAnnotationFileError, + detect_dataset, + import_dm_annotations, ) from cvat.apps.dataset_manager.util import make_zip_archive from .registry import dm_env, exporter, importer + @exporter(name='COCO', ext='ZIP', version='1.0') def _export(dst_file, temp_dir, instance_data, save_images=False): with GetCVATDataExtractor(instance_data, include_images=save_images) as extractor: diff --git a/cvat/apps/dataset_manager/formats/cvat.py b/cvat/apps/dataset_manager/formats/cvat.py index fa46b58813bf..f5c7dc18fcda 100644 --- a/cvat/apps/dataset_manager/formats/cvat.py +++ b/cvat/apps/dataset_manager/formats/cvat.py @@ -11,29 +11,34 @@ from io import BufferedWriter from typing import Callable, Union -from datumaro.components.annotation import (AnnotationType, Bbox, Label, - LabelCategories, Points, Polygon, - PolyLine, Skeleton) +from datumaro.components.annotation import ( + AnnotationType, + Bbox, + Label, + LabelCategories, + Points, + Polygon, + PolyLine, + Skeleton, +) from datumaro.components.dataset import Dataset, DatasetItem -from datumaro.components.extractor import (DEFAULT_SUBSET_NAME, Extractor, - Importer) -from datumaro.plugins.cvat_format.extractor import CvatImporter as _CvatImporter - +from datumaro.components.extractor import DEFAULT_SUBSET_NAME, Extractor, Importer +from datumaro.plugins.data_formats.cvat.base import CvatImporter as _CvatImporter from datumaro.util.image import Image from defusedxml import ElementTree from cvat.apps.dataset_manager.bindings import ( + JobData, NoMediaInAnnotationFileError, ProjectData, TaskData, - JobData, detect_dataset, get_defaulted_subset, import_dm_annotations, - match_dm_item + match_dm_item, ) from cvat.apps.dataset_manager.util import make_zip_archive -from cvat.apps.engine.frame_provider import FrameQuality, FrameOutputType, make_frame_provider +from cvat.apps.engine.frame_provider import FrameOutputType, FrameQuality, make_frame_provider from .registry import dm_env, exporter, importer diff --git a/cvat/apps/dataset_manager/formats/datumaro.py b/cvat/apps/dataset_manager/formats/datumaro.py index 4fc1d246dd47..81f86cb32065 100644 --- a/cvat/apps/dataset_manager/formats/datumaro.py +++ b/cvat/apps/dataset_manager/formats/datumaro.py @@ -4,10 +4,14 @@ # SPDX-License-Identifier: MIT import zipfile + from datumaro.components.dataset import Dataset from cvat.apps.dataset_manager.bindings import ( - GetCVATDataExtractor, import_dm_annotations, NoMediaInAnnotationFileError, detect_dataset + GetCVATDataExtractor, + NoMediaInAnnotationFileError, + detect_dataset, + import_dm_annotations, ) from cvat.apps.dataset_manager.util import make_zip_archive from cvat.apps.engine.models import DimensionType diff --git a/cvat/apps/dataset_manager/formats/icdar.py b/cvat/apps/dataset_manager/formats/icdar.py index 5d031eef82b0..c72f9708fe11 100644 --- a/cvat/apps/dataset_manager/formats/icdar.py +++ b/cvat/apps/dataset_manager/formats/icdar.py @@ -5,17 +5,15 @@ import zipfile -from datumaro.components.annotation import (AnnotationType, Caption, Label, - LabelCategories) +from datumaro.components.annotation import AnnotationType, Caption, Label, LabelCategories from datumaro.components.dataset import Dataset from datumaro.components.extractor import ItemTransform -from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, - import_dm_annotations) +from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, import_dm_annotations from cvat.apps.dataset_manager.util import make_zip_archive -from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons from .registry import dm_env, exporter, importer +from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons class AddLabelToAnns(ItemTransform): diff --git a/cvat/apps/dataset_manager/formats/imagenet.py b/cvat/apps/dataset_manager/formats/imagenet.py index fd5e9a99a176..273f47616bc1 100644 --- a/cvat/apps/dataset_manager/formats/imagenet.py +++ b/cvat/apps/dataset_manager/formats/imagenet.py @@ -9,8 +9,7 @@ from datumaro.components.dataset import Dataset -from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, \ - import_dm_annotations +from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, import_dm_annotations from cvat.apps.dataset_manager.util import make_zip_archive from .registry import dm_env, exporter, importer diff --git a/cvat/apps/dataset_manager/formats/kitti.py b/cvat/apps/dataset_manager/formats/kitti.py index 01e1cd3fc4bc..631f903f7289 100644 --- a/cvat/apps/dataset_manager/formats/kitti.py +++ b/cvat/apps/dataset_manager/formats/kitti.py @@ -6,15 +6,18 @@ import os.path as osp from datumaro.components.dataset import Dataset -from datumaro.plugins.kitti_format.format import KittiPath, write_label_map - +from datumaro.plugins.data_formats.kitti.format import KittiPath, write_label_map from pyunpack import Archive -from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset, import_dm_annotations) +from cvat.apps.dataset_manager.bindings import ( + GetCVATDataExtractor, + detect_dataset, + import_dm_annotations, +) from cvat.apps.dataset_manager.util import make_zip_archive -from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons from .registry import dm_env, exporter, importer +from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons from .utils import make_colormap diff --git a/cvat/apps/dataset_manager/formats/labelme.py b/cvat/apps/dataset_manager/formats/labelme.py index be9679f268e8..179fb320f322 100644 --- a/cvat/apps/dataset_manager/formats/labelme.py +++ b/cvat/apps/dataset_manager/formats/labelme.py @@ -6,8 +6,11 @@ from datumaro.components.dataset import Dataset from pyunpack import Archive -from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset, - import_dm_annotations) +from cvat.apps.dataset_manager.bindings import ( + GetCVATDataExtractor, + detect_dataset, + import_dm_annotations, +) from cvat.apps.dataset_manager.formats.transformations import MaskToPolygonTransformation from cvat.apps.dataset_manager.util import make_zip_archive diff --git a/cvat/apps/dataset_manager/formats/lfw.py b/cvat/apps/dataset_manager/formats/lfw.py index 0af356332bb5..407240c5e0a3 100644 --- a/cvat/apps/dataset_manager/formats/lfw.py +++ b/cvat/apps/dataset_manager/formats/lfw.py @@ -6,8 +6,11 @@ from datumaro.components.dataset import Dataset from pyunpack import Archive -from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset, - import_dm_annotations) +from cvat.apps.dataset_manager.bindings import ( + GetCVATDataExtractor, + detect_dataset, + import_dm_annotations, +) from cvat.apps.dataset_manager.util import make_zip_archive from .registry import dm_env, exporter, importer diff --git a/cvat/apps/dataset_manager/formats/market1501.py b/cvat/apps/dataset_manager/formats/market1501.py index 6be8b2fcf75f..e9d46a095bc8 100644 --- a/cvat/apps/dataset_manager/formats/market1501.py +++ b/cvat/apps/dataset_manager/formats/market1501.py @@ -5,17 +5,20 @@ import zipfile -from datumaro.components.annotation import (AnnotationType, Label, - LabelCategories) +from datumaro.components.annotation import AnnotationType, Label, LabelCategories from datumaro.components.dataset import Dataset from datumaro.components.extractor import ItemTransform -from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset, - import_dm_annotations) +from cvat.apps.dataset_manager.bindings import ( + GetCVATDataExtractor, + detect_dataset, + import_dm_annotations, +) from cvat.apps.dataset_manager.util import make_zip_archive from .registry import dm_env, exporter, importer + class AttrToLabelAttr(ItemTransform): def __init__(self, extractor, label): super().__init__(extractor) diff --git a/cvat/apps/dataset_manager/formats/mask.py b/cvat/apps/dataset_manager/formats/mask.py index f003f68383e7..eab4238f4242 100644 --- a/cvat/apps/dataset_manager/formats/mask.py +++ b/cvat/apps/dataset_manager/formats/mask.py @@ -6,14 +6,18 @@ from datumaro.components.dataset import Dataset from pyunpack import Archive -from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset, - import_dm_annotations) +from cvat.apps.dataset_manager.bindings import ( + GetCVATDataExtractor, + detect_dataset, + import_dm_annotations, +) from cvat.apps.dataset_manager.util import make_zip_archive -from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons from .registry import dm_env, exporter, importer +from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons from .utils import make_colormap + @exporter(name='Segmentation mask', ext='ZIP', version='1.1') def _export(dst_file, temp_dir, instance_data, save_images=False): with GetCVATDataExtractor(instance_data, include_images=save_images) as extractor: diff --git a/cvat/apps/dataset_manager/formats/mots.py b/cvat/apps/dataset_manager/formats/mots.py index 9ed156e6cd4e..736ccb1ce0f8 100644 --- a/cvat/apps/dataset_manager/formats/mots.py +++ b/cvat/apps/dataset_manager/formats/mots.py @@ -8,12 +8,16 @@ from datumaro.components.extractor import ItemTransform from pyunpack import Archive -from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset, - find_dataset_root, match_dm_item) +from cvat.apps.dataset_manager.bindings import ( + GetCVATDataExtractor, + detect_dataset, + find_dataset_root, + match_dm_item, +) from cvat.apps.dataset_manager.util import make_zip_archive -from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons from .registry import dm_env, exporter, importer +from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons class KeepTracks(ItemTransform): diff --git a/cvat/apps/dataset_manager/formats/openimages.py b/cvat/apps/dataset_manager/formats/openimages.py index 51fcee29a2fb..2ae544238ee2 100644 --- a/cvat/apps/dataset_manager/formats/openimages.py +++ b/cvat/apps/dataset_manager/formats/openimages.py @@ -7,16 +7,21 @@ import os.path as osp from datumaro.components.dataset import Dataset, DatasetItem -from datumaro.plugins.open_images_format import OpenImagesPath +from datumaro.plugins.data_formats.open_images import OpenImagesPath from datumaro.util.image import DEFAULT_IMAGE_META_FILE_NAME from pyunpack import Archive -from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset, - find_dataset_root, import_dm_annotations, match_dm_item) +from cvat.apps.dataset_manager.bindings import ( + GetCVATDataExtractor, + detect_dataset, + find_dataset_root, + import_dm_annotations, + match_dm_item, +) from cvat.apps.dataset_manager.util import make_zip_archive -from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons from .registry import dm_env, exporter, importer +from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons def find_item_ids(path): diff --git a/cvat/apps/dataset_manager/formats/pascal_voc.py b/cvat/apps/dataset_manager/formats/pascal_voc.py index a0d84b745d73..3b55928e1f90 100644 --- a/cvat/apps/dataset_manager/formats/pascal_voc.py +++ b/cvat/apps/dataset_manager/formats/pascal_voc.py @@ -11,7 +11,11 @@ from datumaro.components.dataset import Dataset from pyunpack import Archive -from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset, import_dm_annotations) +from cvat.apps.dataset_manager.bindings import ( + GetCVATDataExtractor, + detect_dataset, + import_dm_annotations, +) from cvat.apps.dataset_manager.formats.transformations import MaskToPolygonTransformation from cvat.apps.dataset_manager.util import make_zip_archive diff --git a/cvat/apps/dataset_manager/formats/pointcloud.py b/cvat/apps/dataset_manager/formats/pointcloud.py index 6ddfbb495427..8743c6eb8f3c 100644 --- a/cvat/apps/dataset_manager/formats/pointcloud.py +++ b/cvat/apps/dataset_manager/formats/pointcloud.py @@ -7,8 +7,11 @@ from datumaro.components.dataset import Dataset -from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset, - import_dm_annotations) +from cvat.apps.dataset_manager.bindings import ( + GetCVATDataExtractor, + detect_dataset, + import_dm_annotations, +) from cvat.apps.dataset_manager.util import make_zip_archive from cvat.apps.engine.models import DimensionType diff --git a/cvat/apps/dataset_manager/formats/transformations.py b/cvat/apps/dataset_manager/formats/transformations.py index 99d754252378..496786126709 100644 --- a/cvat/apps/dataset_manager/formats/transformations.py +++ b/cvat/apps/dataset_manager/formats/transformations.py @@ -4,12 +4,12 @@ # SPDX-License-Identifier: MIT import math -import cv2 -import numpy as np from itertools import chain -from pycocotools import mask as mask_utils +import cv2 import datumaro as dm +import numpy as np +from pycocotools import mask as mask_utils class RotatedBoxesToPolygons(dm.ItemTransform): @@ -37,6 +37,7 @@ def transform_item(self, item): return item.wrap(annotations=annotations) + class MaskConverter: @staticmethod def cvat_rle_to_dm_rle(shape, img_h: int, img_w: int) -> dm.RleMask: @@ -100,6 +101,7 @@ def rle(cls, arr: np.ndarray) -> list[int]: return cvat_rle + class EllipsesToMasks: @staticmethod def convert_ellipse(ellipse, img_h, img_w): @@ -115,6 +117,7 @@ def convert_ellipse(ellipse, img_h, img_w): return dm.RleMask(rle=rle, label=ellipse.label, z_order=ellipse.z_order, attributes=ellipse.attributes, group=ellipse.group) + class MaskToPolygonTransformation: """ Manages common logic for mask to polygons conversion in dataset import. @@ -130,3 +133,13 @@ def convert_dataset(cls, dataset, **kwargs): if kwargs.get('conv_mask_to_poly', True): dataset.transform('masks_to_polygons') return dataset + + +class SetKeyframeForEveryTrackShape(dm.ItemTransform): + def transform_item(self, item): + annotations = [] + for ann in item.annotations: + if "track_id" in ann.attributes: + ann = ann.wrap(attributes=dict(ann.attributes, keyframe=True)) + annotations.append(ann) + return item.wrap(annotations=annotations) diff --git a/cvat/apps/dataset_manager/formats/utils.py b/cvat/apps/dataset_manager/formats/utils.py index 7811fbbfc902..f565c0aed687 100644 --- a/cvat/apps/dataset_manager/formats/utils.py +++ b/cvat/apps/dataset_manager/formats/utils.py @@ -2,13 +2,14 @@ # # SPDX-License-Identifier: MIT -import os.path as osp -from hashlib import blake2s import itertools import operator +import os.path as osp +from hashlib import blake2s from datumaro.util.os_util import make_file_name + def get_color_from_index(index): def get_bit(number, index): return (number >> index) & 1 diff --git a/cvat/apps/dataset_manager/formats/velodynepoint.py b/cvat/apps/dataset_manager/formats/velodynepoint.py index 9912d0b1d67b..d6051bf6fce8 100644 --- a/cvat/apps/dataset_manager/formats/velodynepoint.py +++ b/cvat/apps/dataset_manager/formats/velodynepoint.py @@ -8,14 +8,16 @@ from datumaro.components.dataset import Dataset from datumaro.components.extractor import ItemTransform -from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, detect_dataset, \ - import_dm_annotations -from .registry import dm_env - +from cvat.apps.dataset_manager.bindings import ( + GetCVATDataExtractor, + detect_dataset, + import_dm_annotations, +) from cvat.apps.dataset_manager.util import make_zip_archive from cvat.apps.engine.models import DimensionType -from .registry import exporter, importer +from .registry import dm_env, exporter, importer + class RemoveTrackingInformation(ItemTransform): def transform_item(self, item): diff --git a/cvat/apps/dataset_manager/formats/vggface2.py b/cvat/apps/dataset_manager/formats/vggface2.py index 642171f0f8d9..aa172f947db3 100644 --- a/cvat/apps/dataset_manager/formats/vggface2.py +++ b/cvat/apps/dataset_manager/formats/vggface2.py @@ -7,8 +7,12 @@ from datumaro.components.dataset import Dataset -from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, TaskData, detect_dataset, \ - import_dm_annotations +from cvat.apps.dataset_manager.bindings import ( + GetCVATDataExtractor, + TaskData, + detect_dataset, + import_dm_annotations, +) from cvat.apps.dataset_manager.util import make_zip_archive from .registry import dm_env, exporter, importer diff --git a/cvat/apps/dataset_manager/formats/widerface.py b/cvat/apps/dataset_manager/formats/widerface.py index 12a9bf0d21e5..99480bf1f8f5 100644 --- a/cvat/apps/dataset_manager/formats/widerface.py +++ b/cvat/apps/dataset_manager/formats/widerface.py @@ -7,8 +7,11 @@ from datumaro.components.dataset import Dataset -from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, detect_dataset, \ - import_dm_annotations +from cvat.apps.dataset_manager.bindings import ( + GetCVATDataExtractor, + detect_dataset, + import_dm_annotations, +) from cvat.apps.dataset_manager.util import make_zip_archive from .registry import dm_env, exporter, importer diff --git a/cvat/apps/dataset_manager/formats/yolo.py b/cvat/apps/dataset_manager/formats/yolo.py index 887232b8e666..2bcfdfca1325 100644 --- a/cvat/apps/dataset_manager/formats/yolo.py +++ b/cvat/apps/dataset_manager/formats/yolo.py @@ -4,28 +4,40 @@ # SPDX-License-Identifier: MIT import os.path as osp from glob import glob +from typing import Callable, Optional +from datumaro.components.annotation import AnnotationType +from datumaro.components.extractor import DatasetItem +from datumaro.components.project import Dataset from pyunpack import Archive from cvat.apps.dataset_manager.bindings import ( + CommonData, GetCVATDataExtractor, + ProjectData, detect_dataset, + find_dataset_root, import_dm_annotations, match_dm_item, - find_dataset_root, ) from cvat.apps.dataset_manager.util import make_zip_archive -from datumaro.components.annotation import AnnotationType -from datumaro.components.extractor import DatasetItem -from datumaro.components.project import Dataset from .registry import dm_env, exporter, importer +from .transformations import SetKeyframeForEveryTrackShape -def _export_common(dst_file, temp_dir, instance_data, format_name, *, save_images=False): +def _export_common( + dst_file: str, + temp_dir: str, + instance_data: ProjectData | CommonData, + format_name: str, + *, + save_images: bool = False, + **kwargs +): with GetCVATDataExtractor(instance_data, include_images=save_images) as extractor: dataset = Dataset.from_extractors(extractor, env=dm_env) - dataset.export(temp_dir, format_name, save_images=save_images) + dataset.export(temp_dir, format_name, save_images=save_images, **kwargs) make_zip_archive(temp_dir, dst_file) @@ -37,12 +49,12 @@ def _export_yolo(*args, **kwargs): def _import_common( src_file, - temp_dir, - instance_data, - format_name, + temp_dir: str, + instance_data: ProjectData | CommonData, + format_name: str, *, - load_data_callback=None, - import_kwargs=None, + load_data_callback: Optional[Callable] = None, + import_kwargs: dict | None = None, **kwargs ): Archive(src_file.name).extractall(temp_dir) @@ -67,6 +79,7 @@ def _import_common( detect_dataset(temp_dir, format_name=format_name, importer=dm_env.importers.get(format_name)) dataset = Dataset.import_from(temp_dir, format_name, env=dm_env, image_info=image_info, **(import_kwargs or {})) + dataset = dataset.transform(SetKeyframeForEveryTrackShape) if load_data_callback is not None: load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) @@ -77,53 +90,58 @@ def _import_yolo(*args, **kwargs): _import_common(*args, format_name="yolo", **kwargs) -@exporter(name='YOLOv8 Detection', ext='ZIP', version='1.0') -def _export_yolov8_detection(*args, **kwargs): - _export_common(*args, format_name='yolov8_detection', **kwargs) +@exporter(name='Ultralytics YOLO Detection', ext='ZIP', version='1.0') +def _export_yolo_ultralytics_detection(*args, **kwargs): + _export_common(*args, format_name='yolo_ultralytics_detection', **kwargs) + + +@exporter(name='Ultralytics YOLO Detection Track', ext='ZIP', version='1.0') +def _export_yolo_ultralytics_detection_track(*args, **kwargs): + _export_common(*args, format_name='yolo_ultralytics_detection', write_track_id=True, **kwargs) -@exporter(name='YOLOv8 Oriented Bounding Boxes', ext='ZIP', version='1.0') -def _export_yolov8_oriented_boxes(*args, **kwargs): - _export_common(*args, format_name='yolov8_oriented_boxes', **kwargs) +@exporter(name='Ultralytics YOLO Oriented Bounding Boxes', ext='ZIP', version='1.0') +def _export_yolo_ultralytics_oriented_boxes(*args, **kwargs): + _export_common(*args, format_name='yolo_ultralytics_oriented_boxes', **kwargs) -@exporter(name='YOLOv8 Segmentation', ext='ZIP', version='1.0') -def _export_yolov8_segmentation(dst_file, temp_dir, instance_data, *, save_images=False): +@exporter(name='Ultralytics YOLO Segmentation', ext='ZIP', version='1.0') +def _export_yolo_ultralytics_segmentation(dst_file, temp_dir, instance_data, *, save_images=False): with GetCVATDataExtractor(instance_data, include_images=save_images) as extractor: dataset = Dataset.from_extractors(extractor, env=dm_env) dataset = dataset.transform('masks_to_polygons') - dataset.export(temp_dir, 'yolov8_segmentation', save_images=save_images) + dataset.export(temp_dir, 'yolo_ultralytics_segmentation', save_images=save_images) make_zip_archive(temp_dir, dst_file) -@exporter(name='YOLOv8 Pose', ext='ZIP', version='1.0') -def _export_yolov8_pose(*args, **kwargs): - _export_common(*args, format_name='yolov8_pose', **kwargs) +@exporter(name='Ultralytics YOLO Pose', ext='ZIP', version='1.0') +def _export_yolo_ultralytics_pose(*args, **kwargs): + _export_common(*args, format_name='yolo_ultralytics_pose', **kwargs) -@exporter(name='YOLOv8 Classification', ext='ZIP', version='1.0') -def _export_yolov8_classification(*args, **kwargs): - _export_common(*args, format_name='yolov8_classification', **kwargs) +@exporter(name='Ultralytics YOLO Classification', ext='ZIP', version='1.0') +def _export_yolo_ultralytics_classification(*args, **kwargs): + _export_common(*args, format_name='yolo_ultralytics_classification', **kwargs) -@importer(name='YOLOv8 Detection', ext="ZIP", version="1.0") -def _import_yolov8_detection(*args, **kwargs): - _import_common(*args, format_name="yolov8_detection", **kwargs) +@importer(name='Ultralytics YOLO Detection', ext="ZIP", version="1.0") +def _import_yolo_ultralytics_detection(*args, **kwargs): + _import_common(*args, format_name="yolo_ultralytics_detection", **kwargs) -@importer(name='YOLOv8 Segmentation', ext="ZIP", version="1.0") -def _import_yolov8_segmentation(*args, **kwargs): - _import_common(*args, format_name="yolov8_segmentation", **kwargs) +@importer(name='Ultralytics YOLO Segmentation', ext="ZIP", version="1.0") +def _import_yolo_ultralytics_segmentation(*args, **kwargs): + _import_common(*args, format_name="yolo_ultralytics_segmentation", **kwargs) -@importer(name='YOLOv8 Oriented Bounding Boxes', ext="ZIP", version="1.0") -def _import_yolov8_oriented_boxes(*args, **kwargs): - _import_common(*args, format_name="yolov8_oriented_boxes", **kwargs) +@importer(name='Ultralytics YOLO Oriented Bounding Boxes', ext="ZIP", version="1.0") +def _import_yolo_ultralytics_oriented_boxes(*args, **kwargs): + _import_common(*args, format_name="yolo_ultralytics_oriented_boxes", **kwargs) -@importer(name='YOLOv8 Pose', ext="ZIP", version="1.0") -def _import_yolov8_pose(src_file, temp_dir, instance_data, **kwargs): +@importer(name='Ultralytics YOLO Pose', ext="ZIP", version="1.0") +def _import_yolo_ultralytics_pose(src_file, temp_dir, instance_data, **kwargs): with GetCVATDataExtractor(instance_data) as extractor: point_categories = extractor.categories().get(AnnotationType.points) label_categories = extractor.categories().get(AnnotationType.label) @@ -135,12 +153,12 @@ def _import_yolov8_pose(src_file, temp_dir, instance_data, **kwargs): src_file, temp_dir, instance_data, - format_name="yolov8_pose", + format_name="yolo_ultralytics_pose", import_kwargs=dict(skeleton_sub_labels=true_skeleton_point_labels), **kwargs ) -@importer(name='YOLOv8 Classification', ext="ZIP", version="1.0") -def _import_yolov8_classification(*args, **kwargs): - _import_common(*args, format_name="yolov8_classification", **kwargs) +@importer(name='Ultralytics YOLO Classification', ext="ZIP", version="1.0") +def _import_yolo_ultralytics_classification(*args, **kwargs): + _import_common(*args, format_name="yolo_ultralytics_classification", **kwargs) diff --git a/cvat/apps/dataset_manager/project.py b/cvat/apps/dataset_manager/project.py index 93ac651cf477..ad51370b04e1 100644 --- a/cvat/apps/dataset_manager/project.py +++ b/cvat/apps/dataset_manager/project.py @@ -6,22 +6,22 @@ import os from collections.abc import Mapping from tempfile import TemporaryDirectory -import rq from typing import Any, Callable -from datumaro.components.errors import DatasetError, DatasetImportError, DatasetNotFoundError -from django.db import transaction +import rq +from datumaro.components.errors import DatasetError, DatasetImportError, DatasetNotFoundError from django.conf import settings +from django.db import transaction +from cvat.apps.dataset_manager.task import TaskAnnotation from cvat.apps.engine import models from cvat.apps.engine.log import DatasetLogManager +from cvat.apps.engine.rq_job_handler import RQJobMetaField from cvat.apps.engine.serializers import DataSerializer, TaskWriteSerializer from cvat.apps.engine.task import _create_thread as create_task -from cvat.apps.engine.rq_job_handler import RQJobMetaField -from cvat.apps.dataset_manager.task import TaskAnnotation from .annotation import AnnotationIR -from .bindings import CvatDatasetNotFoundError, ProjectData, load_dataset_data, CvatImportError +from .bindings import CvatDatasetNotFoundError, CvatImportError, ProjectData, load_dataset_data from .formats.registry import make_exporter, make_importer dlogger = DatasetLogManager() diff --git a/cvat/apps/dataset_manager/task.py b/cvat/apps/dataset_manager/task.py index 83886d7e9cf1..74f035d40787 100644 --- a/cvat/apps/dataset_manager/task.py +++ b/cvat/apps/dataset_manager/task.py @@ -10,27 +10,34 @@ from enum import Enum from tempfile import TemporaryDirectory from typing import Optional, Union -from datumaro.components.errors import DatasetError, DatasetImportError, DatasetNotFoundError +from datumaro.components.errors import DatasetError, DatasetImportError, DatasetNotFoundError +from django.conf import settings from django.db import transaction from django.db.models.query import Prefetch, QuerySet -from django.conf import settings from rest_framework.exceptions import ValidationError +from cvat.apps.dataset_manager.annotation import AnnotationIR, AnnotationManager +from cvat.apps.dataset_manager.bindings import ( + CvatDatasetNotFoundError, + CvatImportError, + JobData, + TaskData, +) +from cvat.apps.dataset_manager.formats.registry import make_exporter, make_importer +from cvat.apps.dataset_manager.util import ( + add_prefetch_fields, + bulk_create, + faster_deepcopy, + get_cached, +) from cvat.apps.engine import models, serializers -from cvat.apps.engine.plugins import plugin_decorator from cvat.apps.engine.log import DatasetLogManager +from cvat.apps.engine.plugins import plugin_decorator from cvat.apps.engine.utils import take_by from cvat.apps.events.handlers import handle_annotations_change from cvat.apps.profiler import silk_profile -from cvat.apps.dataset_manager.annotation import AnnotationIR, AnnotationManager -from cvat.apps.dataset_manager.bindings import TaskData, JobData, CvatImportError, CvatDatasetNotFoundError -from cvat.apps.dataset_manager.formats.registry import make_exporter, make_importer -from cvat.apps.dataset_manager.util import ( - add_prefetch_fields, bulk_create, get_cached, faster_deepcopy -) - dlogger = DatasetLogManager() class dotdict(OrderedDict): diff --git a/cvat/apps/dataset_manager/tests/assets/annotations.json b/cvat/apps/dataset_manager/tests/assets/annotations.json index 9f7c27b94bcb..2a1d7f70696c 100644 --- a/cvat/apps/dataset_manager/tests/assets/annotations.json +++ b/cvat/apps/dataset_manager/tests/assets/annotations.json @@ -976,7 +976,7 @@ ], "tracks": [] }, - "YOLOv8 Classification 1.0": { + "Ultralytics YOLO Classification 1.0": { "version": 0, "tags": [ { @@ -990,7 +990,7 @@ "shapes": [], "tracks": [] }, - "YOLOv8 Detection 1.0": { + "Ultralytics YOLO Detection 1.0": { "version": 0, "tags": [], "shapes": [ @@ -1008,7 +1008,55 @@ ], "tracks": [] }, - "YOLOv8 Oriented Bounding Boxes 1.0": { + "Ultralytics YOLO Detection Track 1.0": { + "version": 0, + "tags": [], + "shapes": [ + { + "type": "rectangle", + "occluded": false, + "z_order": 0, + "points": [0.3, 0.1, 0.2, 0.8], + "frame": 0, + "label_id": null, + "group": 0, + "source": "manual", + "attributes": [] + } + ], + "tracks": [ + { + "frame": 0, + "label_id": null, + "group": 0, + "source": "manual", + "shapes": [ + { + "type": "rectangle", + "occluded": false, + "z_order": 0, + "points": [0.2, 0.1, 0.2, 0.8], + "frame": 0, + "outside": false, + "attributes": [], + "keyframe": true + }, + { + "type": "rectangle", + "occluded": false, + "z_order": 0, + "points": [0.4, 0.1, 0.2, 0.8], + "frame": 1, + "outside": true, + "attributes": [], + "keyframe": true + } + ], + "attributes": [] + } + ] + }, + "Ultralytics YOLO Oriented Bounding Boxes 1.0": { "version": 0, "tags": [], "shapes": [ @@ -1027,7 +1075,7 @@ ], "tracks": [] }, - "YOLOv8 Segmentation 1.0": { + "Ultralytics YOLO Segmentation 1.0": { "version": 0, "tags": [], "shapes": [ @@ -1045,7 +1093,7 @@ ], "tracks": [] }, - "YOLOv8 Pose 1.0": { + "Ultralytics YOLO Pose 1.0": { "version": 0, "tags": [], "shapes": [ diff --git a/cvat/apps/dataset_manager/tests/assets/tasks.json b/cvat/apps/dataset_manager/tests/assets/tasks.json index 2c29ce712929..ad68c6f5aa5f 100644 --- a/cvat/apps/dataset_manager/tests/assets/tasks.json +++ b/cvat/apps/dataset_manager/tests/assets/tasks.json @@ -634,8 +634,8 @@ } ] }, - "YOLOv8 Pose 1.0": { - "name": "YOLOv8 pose task", + "Ultralytics YOLO Pose 1.0": { + "name": "Ultralytics YOLO pose task", "overlap": 0, "segment_size": 100, "labels": [ diff --git a/cvat/apps/dataset_manager/tests/test_formats.py b/cvat/apps/dataset_manager/tests/test_formats.py index 91a3081ca089..097884092de0 100644 --- a/cvat/apps/dataset_manager/tests/test_formats.py +++ b/cvat/apps/dataset_manager/tests/test_formats.py @@ -4,28 +4,33 @@ # # SPDX-License-Identifier: MIT -import numpy as np import os.path as osp import tempfile import zipfile from io import BytesIO import datumaro -from datumaro.components.dataset import Dataset, DatasetItem +import numpy as np from datumaro.components.annotation import Mask +from datumaro.components.dataset import Dataset, DatasetItem from django.contrib.auth.models import Group, User - from rest_framework import status import cvat.apps.dataset_manager as dm from cvat.apps.dataset_manager.annotation import AnnotationIR -from cvat.apps.dataset_manager.bindings import (CvatTaskOrJobDataExtractor, - TaskData, find_dataset_root) +from cvat.apps.dataset_manager.bindings import ( + CvatTaskOrJobDataExtractor, + TaskData, + find_dataset_root, +) from cvat.apps.dataset_manager.task import TaskAnnotation from cvat.apps.dataset_manager.util import make_zip_archive from cvat.apps.engine.models import Task from cvat.apps.engine.tests.utils import ( - get_paginated_collection, ForceLogin, generate_image_file, ApiTestBase + ApiTestBase, + ForceLogin, + generate_image_file, + get_paginated_collection, ) @@ -292,11 +297,12 @@ def test_export_formats_query(self): 'LFW 1.0', 'Cityscapes 1.0', 'Open Images V6 1.0', - 'YOLOv8 Classification 1.0', - 'YOLOv8 Oriented Bounding Boxes 1.0', - 'YOLOv8 Detection 1.0', - 'YOLOv8 Pose 1.0', - 'YOLOv8 Segmentation 1.0', + 'Ultralytics YOLO Classification 1.0', + 'Ultralytics YOLO Oriented Bounding Boxes 1.0', + 'Ultralytics YOLO Detection 1.0', + 'Ultralytics YOLO Detection Track 1.0', + 'Ultralytics YOLO Pose 1.0', + 'Ultralytics YOLO Segmentation 1.0', }) def test_import_formats_query(self): @@ -329,11 +335,11 @@ def test_import_formats_query(self): 'Open Images V6 1.0', 'Datumaro 1.0', 'Datumaro 3D 1.0', - 'YOLOv8 Classification 1.0', - 'YOLOv8 Oriented Bounding Boxes 1.0', - 'YOLOv8 Detection 1.0', - 'YOLOv8 Pose 1.0', - 'YOLOv8 Segmentation 1.0', + 'Ultralytics YOLO Classification 1.0', + 'Ultralytics YOLO Oriented Bounding Boxes 1.0', + 'Ultralytics YOLO Detection 1.0', + 'Ultralytics YOLO Pose 1.0', + 'Ultralytics YOLO Segmentation 1.0', }) def test_exports(self): @@ -383,11 +389,11 @@ def test_empty_images_are_exported(self): # ('KITTI 1.0', 'kitti') format does not support empty annotations ('LFW 1.0', 'lfw'), # ('Cityscapes 1.0', 'cityscapes'), does not support, empty annotations - ('YOLOv8 Classification 1.0', 'yolov8_classification'), - ('YOLOv8 Oriented Bounding Boxes 1.0', 'yolov8_oriented_boxes'), - ('YOLOv8 Detection 1.0', 'yolov8_detection'), - ('YOLOv8 Pose 1.0', 'yolov8_pose'), - ('YOLOv8 Segmentation 1.0', 'yolov8_segmentation'), + ('Ultralytics YOLO Classification 1.0', 'yolo_ultralytics_classification'), + ('Ultralytics YOLO Oriented Bounding Boxes 1.0', 'yolo_ultralytics_oriented_boxes'), + ('Ultralytics YOLO Detection 1.0', 'yolo_ultralytics_detection'), + ('Ultralytics YOLO Pose 1.0', 'yolo_ultralytics_pose'), + ('Ultralytics YOLO Segmentation 1.0', 'yolo_ultralytics_segmentation'), ]: with self.subTest(format=format_name): if not dm.formats.registry.EXPORT_FORMATS[format_name].ENABLED: diff --git a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py index 50883826b5a5..fe1addd2cbc5 100644 --- a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py +++ b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py @@ -6,11 +6,9 @@ import copy import itertools import json -import os.path as osp -import os import multiprocessing -import av -import numpy as np +import os +import os.path as osp import random import shutil import xml.etree.ElementTree as ET @@ -22,8 +20,11 @@ from tempfile import TemporaryDirectory from time import sleep from typing import Any, Callable, ClassVar, Optional, overload -from unittest.mock import MagicMock, patch, DEFAULT as MOCK_DEFAULT +from unittest.mock import DEFAULT as MOCK_DEFAULT +from unittest.mock import MagicMock, patch +import av +import numpy as np from attr import define, field from datumaro.components.dataset import Dataset from datumaro.components.operations import ExactComparator @@ -38,7 +39,7 @@ from cvat.apps.dataset_manager.util import get_export_cache_lock from cvat.apps.dataset_manager.views import clear_export_cache, export, parse_export_file_path from cvat.apps.engine.models import Task -from cvat.apps.engine.tests.utils import get_paginated_collection, ApiTestBase, ForceLogin +from cvat.apps.engine.tests.utils import ApiTestBase, ForceLogin, get_paginated_collection projects_path = osp.join(osp.dirname(__file__), 'assets', 'projects.json') with open(projects_path) as file: @@ -55,12 +56,13 @@ DEFAULT_ATTRIBUTES_FORMATS = [ "VGGFace2 1.0", "WiderFace 1.0", - "YOLOv8 Classification 1.0", + "Ultralytics YOLO Classification 1.0", "YOLO 1.1", - "YOLOv8 Detection 1.0", - "YOLOv8 Segmentation 1.0", - "YOLOv8 Oriented Bounding Boxes 1.0", - "YOLOv8 Pose 1.0", + "Ultralytics YOLO Detection 1.0", + "Ultralytics YOLO Detection Track 1.0", + "Ultralytics YOLO Segmentation 1.0", + "Ultralytics YOLO Oriented Bounding Boxes 1.0", + "Ultralytics YOLO Pose 1.0", "PASCAL VOC 1.1", "Segmentation mask 1.1", "ImageNet 1.0", @@ -411,7 +413,7 @@ def test_api_v2_dump_and_upload_annotations_with_objects_type_is_shape(self): "Cityscapes 1.0", "COCO Keypoints 1.0", "ICDAR Localization 1.0", "ICDAR Recognition 1.0", "ICDAR Segmentation 1.0", "Market-1501 1.0", "MOT 1.1", - "YOLOv8 Pose 1.0", + "Ultralytics YOLO Pose 1.0", ]: task = self._create_task(tasks[dump_format_name], images) else: @@ -469,7 +471,7 @@ def test_api_v2_dump_and_upload_annotations_with_objects_type_is_shape(self): "Cityscapes 1.0", "COCO Keypoints 1.0", "ICDAR Localization 1.0", "ICDAR Recognition 1.0", "ICDAR Segmentation 1.0", "Market-1501 1.0", "MOT 1.1", - "YOLOv8 Pose 1.0", + "Ultralytics YOLO Pose 1.0", ]: task = self._create_task(tasks[upload_format_name], images) else: @@ -513,7 +515,7 @@ def test_api_v2_dump_annotations_with_objects_type_is_track(self): "Cityscapes 1.0", "COCO Keypoints 1.0", "ICDAR Localization 1.0", "ICDAR Recognition 1.0", "ICDAR Segmentation 1.0", "Market-1501 1.0", "MOT 1.1", - "YOLOv8 Pose 1.0", + "Ultralytics YOLO Pose 1.0", ]: task = self._create_task(tasks[dump_format_name], video) else: @@ -569,7 +571,7 @@ def test_api_v2_dump_annotations_with_objects_type_is_track(self): "Cityscapes 1.0", "COCO Keypoints 1.0", "ICDAR Localization 1.0", "ICDAR Recognition 1.0", "ICDAR Segmentation 1.0", "Market-1501 1.0", "MOT 1.1", - "YOLOv8 Pose 1.0", + "Ultralytics YOLO Pose 1.0", ]: task = self._create_task(tasks[upload_format_name], video) else: @@ -846,7 +848,7 @@ def test_api_v2_export_dataset(self): "Cityscapes 1.0", "COCO Keypoints 1.0", "ICDAR Localization 1.0", "ICDAR Recognition 1.0", "ICDAR Segmentation 1.0", "Market-1501 1.0", "MOT 1.1", - "YOLOv8 Pose 1.0", + "Ultralytics YOLO Pose 1.0", ]: task = self._create_task(tasks[dump_format_name], images) else: @@ -947,7 +949,7 @@ def test_api_v2_rewriting_annotations(self): if dump_format_name in [ "Market-1501 1.0", "ICDAR Localization 1.0", "ICDAR Recognition 1.0", - "ICDAR Segmentation 1.0", "COCO Keypoints 1.0", "YOLOv8 Pose 1.0", + "ICDAR Segmentation 1.0", "COCO Keypoints 1.0", "Ultralytics YOLO Pose 1.0", ]: task = self._create_task(tasks[dump_format_name], images) else: @@ -979,6 +981,8 @@ def test_api_v2_rewriting_annotations(self): if dump_format_name == "CVAT for images 1.1" or dump_format_name == "CVAT for video 1.1": dump_format_name = "CVAT 1.1" + elif dump_format_name == "Ultralytics YOLO Detection Track 1.0": + dump_format_name = "Ultralytics YOLO Detection 1.0" url = self._generate_url_upload_tasks_annotations(task_id, dump_format_name) with open(file_zip_name, 'rb') as binary_file: @@ -1058,7 +1062,7 @@ def test_api_v2_tasks_annotations_dump_and_upload_with_datumaro(self): "Market-1501 1.0", "Cityscapes 1.0", "ICDAR Localization 1.0", "ICDAR Recognition 1.0", "ICDAR Segmentation 1.0", "COCO Keypoints 1.0", - "YOLOv8 Pose 1.0", + "Ultralytics YOLO Pose 1.0", ]: task = self._create_task(tasks[dump_format_name], images) else: @@ -1092,6 +1096,8 @@ def test_api_v2_tasks_annotations_dump_and_upload_with_datumaro(self): # upload annotations if dump_format_name in ["CVAT for images 1.1", "CVAT for video 1.1"]: upload_format_name = "CVAT 1.1" + elif dump_format_name in ['Ultralytics YOLO Detection Track 1.0']: + upload_format_name = 'Ultralytics YOLO Detection 1.0' else: upload_format_name = dump_format_name url = self._generate_url_upload_tasks_annotations(task_id, upload_format_name) @@ -1451,8 +1457,8 @@ def _export(*_, task_id: int): import sys from os import replace as original_replace from os.path import exists as original_exists - from cvat.apps.dataset_manager.task import export_task as original_export_task + from cvat.apps.dataset_manager.task import export_task as original_export_task from cvat.apps.dataset_manager.views import log_exception as original_log_exception def patched_log_exception(logger=None, exc_info=True): diff --git a/cvat/apps/dataset_manager/views.py b/cvat/apps/dataset_manager/views.py index 52bc9cd15f7a..4dcd8304e43d 100644 --- a/cvat/apps/dataset_manager/views.py +++ b/cvat/apps/dataset_manager/views.py @@ -8,10 +8,10 @@ import os.path as osp import tempfile from datetime import timedelta +from os.path import exists as osp_exists import django_rq import rq -from os.path import exists as osp_exists from django.conf import settings from django.utils import timezone from rq_scheduler import Scheduler @@ -20,18 +20,20 @@ import cvat.apps.dataset_manager.task as task from cvat.apps.engine.log import ServerLogManager from cvat.apps.engine.models import Job, Project, Task -from cvat.apps.engine.utils import get_rq_lock_by_user from cvat.apps.engine.rq_job_handler import RQMeta +from cvat.apps.engine.utils import get_rq_lock_by_user from .formats.registry import EXPORT_FORMATS, IMPORT_FORMATS +from .util import EXPORT_CACHE_DIR_NAME # pylint: disable=unused-import from .util import ( LockNotAvailableError, - current_function_name, get_export_cache_lock, - get_export_cache_dir, make_export_filename, - parse_export_file_path, extend_export_file_lifetime + current_function_name, + extend_export_file_lifetime, + get_export_cache_dir, + get_export_cache_lock, + make_export_filename, + parse_export_file_path, ) -from .util import EXPORT_CACHE_DIR_NAME # pylint: disable=unused-import - slogger = ServerLogManager(__name__) diff --git a/cvat/apps/dataset_repo/migrations/0001_initial.py b/cvat/apps/dataset_repo/migrations/0001_initial.py index 2ecf9c17c9b9..fa02f8c54b5d 100644 --- a/cvat/apps/dataset_repo/migrations/0001_initial.py +++ b/cvat/apps/dataset_repo/migrations/0001_initial.py @@ -1,7 +1,7 @@ # Generated by Django 2.1.3 on 2018-12-05 13:24 -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models class Migration(migrations.Migration): @@ -9,23 +9,31 @@ class Migration(migrations.Migration): initial = True dependencies = [ - ('engine', '0014_job_max_shape_id'), + ("engine", "0014_job_max_shape_id"), ] - replaces = [('git', '0001_initial')] + replaces = [("git", "0001_initial")] operations = [ migrations.CreateModel( - name='GitData', + name="GitData", fields=[ - ('task', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, primary_key=True, serialize=False, to='engine.Task')), - ('url', models.URLField(max_length=2000)), - ('path', models.CharField(max_length=256)), - ('sync_date', models.DateTimeField(auto_now_add=True)), - ('status', models.CharField(default='!sync', max_length=20)), + ( + "task", + models.OneToOneField( + on_delete=django.db.models.deletion.CASCADE, + primary_key=True, + serialize=False, + to="engine.Task", + ), + ), + ("url", models.URLField(max_length=2000)), + ("path", models.CharField(max_length=256)), + ("sync_date", models.DateTimeField(auto_now_add=True)), + ("status", models.CharField(default="!sync", max_length=20)), ], options={ - 'db_table': 'git_gitdata', + "db_table": "git_gitdata", }, ), ] diff --git a/cvat/apps/dataset_repo/migrations/0002_auto_20190123_1305.py b/cvat/apps/dataset_repo/migrations/0002_auto_20190123_1305.py index 13fb92b8658e..ce0be5cbbc39 100644 --- a/cvat/apps/dataset_repo/migrations/0002_auto_20190123_1305.py +++ b/cvat/apps/dataset_repo/migrations/0002_auto_20190123_1305.py @@ -6,15 +6,15 @@ class Migration(migrations.Migration): dependencies = [ - ('dataset_repo', '0001_initial'), + ("dataset_repo", "0001_initial"), ] - replaces = [('git', '0002_auto_20190123_1305')] + replaces = [("git", "0002_auto_20190123_1305")] operations = [ migrations.AlterField( - model_name='gitdata', - name='status', - field=models.CharField(default='!sync', max_length=20), + model_name="gitdata", + name="status", + field=models.CharField(default="!sync", max_length=20), ), ] diff --git a/cvat/apps/dataset_repo/migrations/0003_gitdata_lfs.py b/cvat/apps/dataset_repo/migrations/0003_gitdata_lfs.py index b42ebd30db29..1e845e48a108 100644 --- a/cvat/apps/dataset_repo/migrations/0003_gitdata_lfs.py +++ b/cvat/apps/dataset_repo/migrations/0003_gitdata_lfs.py @@ -6,15 +6,15 @@ class Migration(migrations.Migration): dependencies = [ - ('dataset_repo', '0002_auto_20190123_1305'), + ("dataset_repo", "0002_auto_20190123_1305"), ] - replaces = [('git', '0003_gitdata_lfs')] + replaces = [("git", "0003_gitdata_lfs")] operations = [ migrations.AddField( - model_name='gitdata', - name='lfs', + model_name="gitdata", + name="lfs", field=models.BooleanField(default=True), ), ] diff --git a/cvat/apps/dataset_repo/migrations/0004_rename.py b/cvat/apps/dataset_repo/migrations/0004_rename.py index 9629165722d1..94b820dcaa56 100644 --- a/cvat/apps/dataset_repo/migrations/0004_rename.py +++ b/cvat/apps/dataset_repo/migrations/0004_rename.py @@ -1,16 +1,18 @@ from django.db import migrations + def update_contenttypes_table(apps, schema_editor): - content_type_model = apps.get_model('contenttypes', 'ContentType') - content_type_model.objects.filter(app_label='git').update(app_label='dataset_repo') + content_type_model = apps.get_model("contenttypes", "ContentType") + content_type_model.objects.filter(app_label="git").update(app_label="dataset_repo") + class Migration(migrations.Migration): dependencies = [ - ('dataset_repo', '0003_gitdata_lfs'), + ("dataset_repo", "0003_gitdata_lfs"), ] operations = [ - migrations.AlterModelTable('gitdata', 'dataset_repo_gitdata'), + migrations.AlterModelTable("gitdata", "dataset_repo_gitdata"), migrations.RunPython(update_contenttypes_table), ] diff --git a/cvat/apps/dataset_repo/migrations/0005_auto_20201019_1100.py b/cvat/apps/dataset_repo/migrations/0005_auto_20201019_1100.py index f26c280b7f84..8c07d05d29f3 100644 --- a/cvat/apps/dataset_repo/migrations/0005_auto_20201019_1100.py +++ b/cvat/apps/dataset_repo/migrations/0005_auto_20201019_1100.py @@ -6,12 +6,12 @@ class Migration(migrations.Migration): dependencies = [ - ('dataset_repo', '0004_rename'), + ("dataset_repo", "0004_rename"), ] operations = [ migrations.AlterModelTable( - name='gitdata', + name="gitdata", table=None, ), ] diff --git a/cvat/apps/dataset_repo/migrations/0006_gitdata_format.py b/cvat/apps/dataset_repo/migrations/0006_gitdata_format.py index 641d246743eb..1b42f2d3caea 100644 --- a/cvat/apps/dataset_repo/migrations/0006_gitdata_format.py +++ b/cvat/apps/dataset_repo/migrations/0006_gitdata_format.py @@ -4,21 +4,27 @@ def update_default_format_field(apps, schema_editor): - GitData = apps.get_model('dataset_repo', 'GitData') + GitData = apps.get_model("dataset_repo", "GitData") for git_data in GitData.objects.all(): if not git_data.format: - git_data.format = 'CVAT for images 1.1' if git_data.task.mode == 'annotation' else 'CVAT for video 1.1' + git_data.format = ( + "CVAT for images 1.1" + if git_data.task.mode == "annotation" + else "CVAT for video 1.1" + ) git_data.save() + + class Migration(migrations.Migration): dependencies = [ - ('dataset_repo', '0005_auto_20201019_1100'), + ("dataset_repo", "0005_auto_20201019_1100"), ] operations = [ migrations.AddField( - model_name='gitdata', - name='format', + model_name="gitdata", + name="format", field=models.CharField(blank=True, max_length=256), ), migrations.RunPython(update_default_format_field), diff --git a/cvat/apps/engine/__init__.py b/cvat/apps/engine/__init__.py index 325276288f0f..f6b1f2bb9381 100644 --- a/cvat/apps/engine/__init__.py +++ b/cvat/apps/engine/__init__.py @@ -4,4 +4,4 @@ # SPDX-License-Identifier: MIT -from .schema import * # force import of declared symbols +from .schema import * # force import of declared symbols diff --git a/cvat/apps/engine/admin.py b/cvat/apps/engine/admin.py index 05e4b40a0f9b..712e67fa5582 100644 --- a/cvat/apps/engine/admin.py +++ b/cvat/apps/engine/admin.py @@ -4,8 +4,21 @@ # SPDX-License-Identifier: MIT from django.contrib import admin -from .models import Task, Segment, Job, Label, AttributeSpec, Project, \ - CloudStorage, Storage, Data, AnnotationGuide, Asset + +from .models import ( + AnnotationGuide, + Asset, + AttributeSpec, + CloudStorage, + Data, + Job, + Label, + Project, + Segment, + Storage, + Task, +) + class JobInline(admin.TabularInline): model = Job diff --git a/cvat/apps/engine/apps.py b/cvat/apps/engine/apps.py index bcad84510f5d..1cea639842c8 100644 --- a/cvat/apps/engine/apps.py +++ b/cvat/apps/engine/apps.py @@ -20,6 +20,7 @@ def ready(self): # Required to define signals in application import cvat.apps.engine.signals + # Required in order to silent "unused-import" in pyflake assert cvat.apps.engine.signals diff --git a/cvat/apps/engine/background.py b/cvat/apps/engine/background.py index d9f9237e6d27..a3a2d34326b9 100644 --- a/cvat/apps/engine/background.py +++ b/cvat/apps/engine/background.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Optional, Union import django_rq from attrs.converters import to_bool @@ -170,7 +170,7 @@ class ExportArgs: format: str filename: str save_images: bool - location_config: Dict[str, Any] + location_config: dict[str, Any] @property def location(self) -> Location: @@ -515,7 +515,7 @@ class BackupExportManager(_ResourceExportManager): @dataclass class ExportArgs: filename: str - location_config: Dict[str, Any] + location_config: dict[str, Any] @property def location(self) -> Location: diff --git a/cvat/apps/engine/backup.py b/cvat/apps/engine/backup.py index 499700a3b4ef..f3790427f5ba 100644 --- a/cvat/apps/engine/backup.py +++ b/cvat/apps/engine/backup.py @@ -10,47 +10,68 @@ import shutil import tempfile import uuid +from collections.abc import Collection, Iterable from enum import Enum from logging import Logger from tempfile import NamedTemporaryFile -from typing import Any, Collection, Dict, Iterable, Optional, Union +from typing import Any, Optional, Union from zipfile import ZipFile import django_rq from django.conf import settings from django.db import transaction from django.utils import timezone - from rest_framework import serializers, status +from rest_framework.exceptions import ValidationError from rest_framework.parsers import JSONParser from rest_framework.renderers import JSONRenderer from rest_framework.response import Response -from rest_framework.exceptions import ValidationError import cvat.apps.dataset_manager as dm +from cvat.apps.dataset_manager.bindings import CvatImportError +from cvat.apps.dataset_manager.views import get_export_cache_dir, log_exception from cvat.apps.engine import models +from cvat.apps.engine.cloud_provider import import_resource_from_cloud_storage +from cvat.apps.engine.location import StorageType, get_location_configuration from cvat.apps.engine.log import ServerLogManager -from cvat.apps.engine.serializers import (AttributeSerializer, DataSerializer, JobWriteSerializer, - LabelSerializer, AnnotationGuideWriteSerializer, AssetWriteSerializer, - LabeledDataSerializer, SegmentSerializer, SimpleJobSerializer, TaskReadSerializer, - ProjectReadSerializer, ProjectFileSerializer, TaskFileSerializer, RqIdSerializer, - ValidationParamsSerializer) -from cvat.apps.engine.utils import ( - av_scan_paths, process_failed_job, - get_rq_job_meta, import_resource_with_clean_up_after, - define_dependent_job, get_rq_lock_by_user, +from cvat.apps.engine.models import ( + DataChoice, + Location, + Project, + RequestAction, + RequestSubresource, + RequestTarget, + StorageChoice, + StorageMethodChoice, ) +from cvat.apps.engine.permissions import get_cloud_storage_for_import_or_export from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField -from cvat.apps.engine.models import ( - StorageChoice, StorageMethodChoice, DataChoice, Project, Location, - RequestAction, RequestTarget, RequestSubresource, +from cvat.apps.engine.serializers import ( + AnnotationGuideWriteSerializer, + AssetWriteSerializer, + AttributeSerializer, + DataSerializer, + JobWriteSerializer, + LabeledDataSerializer, + LabelSerializer, + ProjectFileSerializer, + ProjectReadSerializer, + RqIdSerializer, + SegmentSerializer, + SimpleJobSerializer, + TaskFileSerializer, + TaskReadSerializer, + ValidationParamsSerializer, ) from cvat.apps.engine.task import JobFileMapping, _create_thread -from cvat.apps.engine.cloud_provider import import_resource_from_cloud_storage -from cvat.apps.engine.location import StorageType, get_location_configuration -from cvat.apps.engine.permissions import get_cloud_storage_for_import_or_export -from cvat.apps.dataset_manager.views import get_export_cache_dir, log_exception -from cvat.apps.dataset_manager.bindings import CvatImportError +from cvat.apps.engine.utils import ( + av_scan_paths, + define_dependent_job, + get_rq_job_meta, + get_rq_lock_by_user, + import_resource_with_clean_up_after, + process_failed_job, +) slogger = ServerLogManager(__name__) @@ -650,7 +671,7 @@ def _calculate_segment_size(jobs): return segment_size, overlap @staticmethod - def _parse_segment_frames(*, jobs: Dict[str, Any]) -> JobFileMapping: + def _parse_segment_frames(*, jobs: dict[str, Any]) -> JobFileMapping: segments = [] for i, segment in enumerate(jobs): diff --git a/cvat/apps/engine/cache.py b/cvat/apps/engine/cache.py index f89ca741501e..ffe8fe0cb920 100644 --- a/cvat/apps/engine/cache.py +++ b/cvat/apps/engine/cache.py @@ -13,22 +13,11 @@ import time import zipfile import zlib +from collections.abc import Collection, Generator, Iterator, Sequence from contextlib import ExitStack, closing from datetime import datetime, timezone from itertools import groupby, pairwise -from typing import ( - Any, - Callable, - Collection, - Generator, - Iterator, - Optional, - Sequence, - Tuple, - Type, - Union, - overload, -) +from typing import Any, Callable, Optional, Union, overload import attrs import av @@ -76,8 +65,8 @@ slogger = ServerLogManager(__name__) -DataWithMime = Tuple[io.BytesIO, str] -_CacheItem = Tuple[io.BytesIO, str, int, Union[datetime, None]] +DataWithMime = tuple[io.BytesIO, str] +_CacheItem = tuple[io.BytesIO, str, int, Union[datetime, None]] def enqueue_create_chunk_job( @@ -229,17 +218,19 @@ def _create_and_set_cache_item( item_data = create_callback() item_data_bytes = item_data[0].getvalue() item = (item_data[0], item_data[1], cls._get_checksum(item_data_bytes), timestamp) - if item_data_bytes: - cache = cls._cache() - with get_rq_lock_for_job( - cls._get_queue(), - key, - ): - cached_item = cache.get(key) - if cached_item is not None and timestamp <= cached_item[3]: - item = cached_item - else: - cache.set(key, item, timeout=cache_item_ttl or cache.default_timeout) + + # allow empty data to be set in cache to prevent + # future rq jobs from being enqueued to prepare the item + cache = cls._cache() + with get_rq_lock_for_job( + cls._get_queue(), + key, + ): + cached_item = cache.get(key) + if cached_item is not None and timestamp <= cached_item[3]: + item = cached_item + else: + cache.set(key, item, timeout=cache_item_ttl or cache.default_timeout) return item @@ -364,11 +355,18 @@ def _make_frame_context_images_chunk_key(self, db_data: models.Data, frame_numbe def _to_data_with_mime(self, cache_item: _CacheItem) -> DataWithMime: ... @overload - def _to_data_with_mime(self, cache_item: Optional[_CacheItem]) -> Optional[DataWithMime]: ... + def _to_data_with_mime( + self, cache_item: Optional[_CacheItem], *, allow_none: bool = False + ) -> Optional[DataWithMime]: ... - def _to_data_with_mime(self, cache_item: Optional[_CacheItem]) -> Optional[DataWithMime]: + def _to_data_with_mime( + self, cache_item: Optional[_CacheItem], *, allow_none: bool = False + ) -> Optional[DataWithMime]: if not cache_item: - return None + if allow_none: + return None + + raise ValueError("A cache item is not allowed to be None") return cache_item[:2] @@ -396,7 +394,8 @@ def get_task_chunk( return self._to_data_with_mime( self._get_cache_item( key=self._make_chunk_key(db_task, chunk_number, quality=quality), - ) + ), + allow_none=True, ) def get_or_set_task_chunk( @@ -424,7 +423,8 @@ def get_segment_task_chunk( return self._to_data_with_mime( self._get_cache_item( key=self._make_segment_task_chunk_key(db_segment, chunk_number, quality=quality), - ) + ), + allow_none=True, ) def get_or_set_segment_task_chunk( @@ -521,7 +521,9 @@ def remove_context_images_chunks(self, params: Sequence[dict[str, Any]]) -> None self._bulk_delete_cache_items(keys_to_remove) def get_cloud_preview(self, db_storage: models.CloudStorage) -> Optional[DataWithMime]: - return self._to_data_with_mime(self._get_cache_item(self._make_preview_key(db_storage))) + return self._to_data_with_mime( + self._get_cache_item(self._make_preview_key(db_storage)), allow_none=True + ) def get_or_set_cloud_preview(self, db_storage: models.CloudStorage) -> DataWithMime: return self._to_data_with_mime( @@ -636,7 +638,7 @@ def _read_raw_images( @staticmethod def _read_raw_frames( db_task: Union[models.Task, int], frame_ids: Sequence[int] - ) -> Generator[Tuple[Union[av.VideoFrame, PIL.Image.Image], str, str], None, None]: + ) -> Generator[tuple[Union[av.VideoFrame, PIL.Image.Image], str, str], None, None]: if isinstance(db_task, int): db_task = models.Task.objects.get(pk=db_task) @@ -962,7 +964,7 @@ def prepare_preview_image(image: PIL.Image.Image) -> DataWithMime: def prepare_chunk( - task_chunk_frames: Iterator[Tuple[Any, str, int]], + task_chunk_frames: Iterator[tuple[Any, str, int]], *, quality: FrameQuality, db_task: models.Task, @@ -972,7 +974,7 @@ def prepare_chunk( db_data = db_task.data - writer_classes: dict[FrameQuality, Type[IChunkWriter]] = { + writer_classes: dict[FrameQuality, type[IChunkWriter]] = { FrameQuality.COMPRESSED: ( Mpeg4CompressedChunkWriter if db_data.compressed_chunk_type == models.DataChoice.VIDEO @@ -1005,7 +1007,7 @@ def prepare_chunk( return buffer, get_chunk_mime_type_for_writer(writer_class) -def get_chunk_mime_type_for_writer(writer: Union[IChunkWriter, Type[IChunkWriter]]) -> str: +def get_chunk_mime_type_for_writer(writer: Union[IChunkWriter, type[IChunkWriter]]) -> str: if isinstance(writer, IChunkWriter): writer_class = type(writer) else: diff --git a/cvat/apps/engine/cloud_provider.py b/cvat/apps/engine/cloud_provider.py index f3fe3e6a28e1..06b2496ce16b 100644 --- a/cvat/apps/engine/cloud_provider.py +++ b/cvat/apps/engine/cloud_provider.py @@ -5,13 +5,14 @@ import functools import json -import os import math -from abc import ABC, abstractmethod, abstractproperty +import os +from abc import ABC, abstractmethod +from collections.abc import Iterator +from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait from enum import Enum from io import BytesIO -from typing import Dict, List, Optional, Any, Callable, TypeVar, Iterator -from concurrent.futures import ThreadPoolExecutor, wait, FIRST_EXCEPTION +from typing import Any, Callable, Optional, TypeVar import boto3 from azure.core.exceptions import HttpResponseError, ResourceExistsError @@ -26,14 +27,14 @@ from google.cloud.exceptions import Forbidden as GoogleCloudForbidden from google.cloud.exceptions import NotFound as GoogleCloudNotFound from PIL import Image, ImageFile -from rest_framework.exceptions import (NotFound, PermissionDenied, - ValidationError) +from rest_framework.exceptions import NotFound, PermissionDenied, ValidationError from cvat.apps.engine.log import ServerLogManager from cvat.apps.engine.models import CloudProviderChoice, CredentialsTypeChoice from cvat.apps.engine.utils import get_cpu_number, take_by from cvat.utils.http import PROXIES_FOR_UNTRUSTED_URLS + class NamedBytesIO(BytesIO): @property def filename(self) -> Optional[str]: @@ -135,7 +136,8 @@ class _CloudStorage(ABC): def __init__(self, prefix: Optional[str] = None): self.prefix = prefix - @abstractproperty + @property + @abstractmethod def name(self): pass @@ -232,7 +234,7 @@ def optimally_image_download(self, key: str, chunk_size: int = 65536) -> NamedBy def bulk_download_to_memory( self, - files: List[str], + files: list[str], *, threads_number: Optional[int] = None, _use_optimal_downloading: bool = True, @@ -246,7 +248,7 @@ def bulk_download_to_memory( def bulk_download_to_dir( self, - files: List[str], + files: list[str], upload_dir: str, *, threads_number: Optional[int] = None, @@ -274,7 +276,7 @@ def _list_raw_content_on_one_page( prefix: str = "", next_token: Optional[str] = None, page_size: int = settings.BUCKET_CONTENT_MAX_PAGE_SIZE, - ) -> Dict: + ) -> dict: pass def list_files_on_one_page( @@ -284,7 +286,7 @@ def list_files_on_one_page( page_size: int = settings.BUCKET_CONTENT_MAX_PAGE_SIZE, _use_flat_listing: bool = False, _use_sort: bool = False, - ) -> Dict: + ) -> dict: if self.prefix and prefix and not (self.prefix.startswith(prefix) or prefix.startswith(self.prefix)): return { @@ -337,7 +339,7 @@ def list_files( self, prefix: str = "", _use_flat_listing: bool = False, - ) -> List[str]: + ) -> list[str]: all_files = [] next_token = None while True: @@ -349,7 +351,8 @@ def list_files( return all_files - @abstractproperty + @property + @abstractmethod def supported_actions(self): pass @@ -365,7 +368,7 @@ def get_cloud_storage_instance( cloud_provider: CloudProviderChoice, resource: str, credentials: str, - specific_attributes: Optional[Dict[str, Any]] = None, + specific_attributes: Optional[dict[str, Any]] = None, ): instance = None if cloud_provider == CloudProviderChoice.AWS_S3: @@ -529,7 +532,7 @@ def _list_raw_content_on_one_page( prefix: str = "", next_token: Optional[str] = None, page_size: int = settings.BUCKET_CONTENT_MAX_PAGE_SIZE, - ) -> Dict: + ) -> dict: # The structure of response looks like this: # { # 'CommonPrefixes': [{'Prefix': 'sub/'}], @@ -736,7 +739,7 @@ def _list_raw_content_on_one_page( prefix: str = "", next_token: Optional[str] = None, page_size: int = settings.BUCKET_CONTENT_MAX_PAGE_SIZE, - ) -> Dict: + ) -> dict: page = self._client.walk_blobs( maxresults=page_size, results_per_page=page_size, delimiter='/', **({'name_starts_with': prefix} if prefix else {}) @@ -852,7 +855,7 @@ def _list_raw_content_on_one_page( prefix: str = "", next_token: Optional[str] = None, page_size: int = settings.BUCKET_CONTENT_MAX_PAGE_SIZE, - ) -> Dict: + ) -> dict: iterator = self._client.list_blobs( bucket_or_name=self.name, max_results=page_size, page_size=page_size, fields='items(name),nextPageToken,prefixes', # https://cloud.google.com/storage/docs/json_api/v1/parameters#fields diff --git a/cvat/apps/engine/field_validation.py b/cvat/apps/engine/field_validation.py index bbfa58b5f3ea..e411284b3cde 100644 --- a/cvat/apps/engine/field_validation.py +++ b/cvat/apps/engine/field_validation.py @@ -2,7 +2,8 @@ # # SPDX-License-Identifier: MIT -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any from rest_framework import serializers diff --git a/cvat/apps/engine/filters.py b/cvat/apps/engine/filters.py index 663b6554e168..6a80e94ad6cc 100644 --- a/cvat/apps/engine/filters.py +++ b/cvat/apps/engine/filters.py @@ -3,29 +3,30 @@ # # SPDX-License-Identifier: MIT -from typing import Any, Dict, Tuple, List, Iterator, Optional, Iterable -from functools import reduce -import operator import json +import operator +from collections.abc import Iterable, Iterator +from functools import reduce +from textwrap import dedent +from typing import Any, Optional +from django.db.models import Q +from django.db.models.query import QuerySet +from django.utils.encoding import force_str +from django.utils.translation import gettext_lazy as _ from django_filters import FilterSet from django_filters import filters as djf from django_filters.filterset import BaseFilterSet from django_filters.rest_framework import DjangoFilterBackend -from django.db.models import Q -from django.db.models.query import QuerySet -from django.utils.translation import gettext_lazy as _ -from django.utils.encoding import force_str -from rest_framework.request import Request from rest_framework import filters from rest_framework.compat import coreapi, coreschema from rest_framework.exceptions import ValidationError -from textwrap import dedent +from rest_framework.request import Request DEFAULT_FILTER_FIELDS_ATTR = 'filter_fields' DEFAULT_LOOKUP_MAP_ATTR = 'lookup_fields' -def get_lookup_fields(view, fields: Optional[Iterator[str]] = None) -> Dict[str, str]: +def get_lookup_fields(view, fields: Optional[Iterator[str]] = None) -> dict[str, str]: if fields is None: fields = getattr(view, DEFAULT_FILTER_FIELDS_ATTR, None) or [] @@ -134,7 +135,7 @@ def get_schema_operation_parameters(self, view): }] if ordering_fields else [] class JsonLogicFilter(filters.BaseFilterBackend): - Rules = Dict[str, Any] + Rules = dict[str, Any] filter_param = 'filter' filter_title = _('Filter') filter_description = _(dedent(""" @@ -191,7 +192,7 @@ def _parse_query(self, json_rules: str) -> Rules: return rules def apply_filter(self, - queryset: QuerySet, parsed_rules: Rules, *, lookup_fields: Dict[str, Any] + queryset: QuerySet, parsed_rules: Rules, *, lookup_fields: dict[str, Any] ) -> QuerySet: try: q_object = self._build_Q(parsed_rules, lookup_fields) @@ -362,7 +363,7 @@ class DotDict(dict): __setattr__ = dict.__setitem__ __delattr__ = dict.__delitem__ - def __init__(self, dct: Dict): + def __init__(self, dct: dict): for key, value in dct.items(): if isinstance(value, dict): value = self.__class__(value) @@ -454,7 +455,7 @@ class NonModelOrderingFilter(OrderingFilter, _NestedAttributeHandler): ?sort=-field1,-field2 """ - def get_ordering(self, request, queryset, view) -> Tuple[List[str], bool]: + def get_ordering(self, request, queryset, view) -> tuple[list[str], bool]: ordering = super().get_ordering(request, queryset, view) result, reverse = [], False for field in ordering: diff --git a/cvat/apps/engine/frame_provider.py b/cvat/apps/engine/frame_provider.py index a004256320aa..1a5fd1f40ebd 100644 --- a/cvat/apps/engine/frame_provider.py +++ b/cvat/apps/engine/frame_provider.py @@ -11,22 +11,11 @@ from abc import ABCMeta, abstractmethod from bisect import bisect from collections import OrderedDict +from collections.abc import Iterator, Sequence from dataclasses import dataclass from enum import Enum, auto from io import BytesIO -from typing import ( - Any, - Callable, - Generic, - Iterator, - Optional, - Sequence, - Tuple, - Type, - TypeVar, - Union, - overload, -) +from typing import Any, Callable, Generic, Optional, TypeVar, Union, overload import av import cv2 @@ -53,7 +42,7 @@ class _ChunkLoader(metaclass=ABCMeta): def __init__( self, - reader_class: Type[IMediaReader], + reader_class: type[IMediaReader], *, reader_params: Optional[dict] = None, ) -> None: @@ -62,7 +51,7 @@ def __init__( self.reader_class = reader_class self.reader_params = reader_params - def load(self, chunk_id: int) -> RandomAccessIterator[Tuple[Any, str, int]]: + def load(self, chunk_id: int) -> RandomAccessIterator[tuple[Any, str, int]]: if self.chunk_id != chunk_id: self.unload() @@ -88,7 +77,7 @@ def read_chunk(self, chunk_id: int) -> DataWithMime: ... class _FileChunkLoader(_ChunkLoader): def __init__( self, - reader_class: Type[IMediaReader], + reader_class: type[IMediaReader], get_chunk_path_callback: Callable[[int], str], *, reader_params: Optional[dict] = None, @@ -108,7 +97,7 @@ def read_chunk(self, chunk_id: int) -> DataWithMime: class _BufferChunkLoader(_ChunkLoader): def __init__( self, - reader_class: Type[IMediaReader], + reader_class: type[IMediaReader], get_chunk_callback: Callable[[int], DataWithMime], *, reader_params: Optional[dict] = None, @@ -154,7 +143,7 @@ def _av_frame_to_png_bytes(cls, av_frame: av.VideoFrame) -> BytesIO: return BytesIO(result.tobytes()) def _convert_frame( - self, frame: Any, reader_class: Type[IMediaReader], out_type: FrameOutputType + self, frame: Any, reader_class: type[IMediaReader], out_type: FrameOutputType ) -> AnyFrame: if out_type == FrameOutputType.BUFFER: return ( @@ -451,7 +440,7 @@ def __init__(self, db_segment: models.Segment) -> None: db_data = db_segment.task.data - reader_class: dict[models.DataChoice, Tuple[Type[IMediaReader], Optional[dict]]] = { + reader_class: dict[models.DataChoice, tuple[type[IMediaReader], Optional[dict]]] = { models.DataChoice.IMAGESET: (ZipReader, None), models.DataChoice.VIDEO: ( VideoReader, @@ -523,7 +512,7 @@ def get_frame_index(self, frame_number: int) -> Optional[int]: return frame_index - def validate_frame_number(self, frame_number: int) -> Tuple[int, int, int]: + def validate_frame_number(self, frame_number: int) -> tuple[int, int, int]: frame_index = self.get_frame_index(frame_number) if frame_index is None: raise ValidationError(f"Incorrect requested frame number: {frame_number}") @@ -576,7 +565,7 @@ def _get_raw_frame( frame_number: int, *, quality: FrameQuality = FrameQuality.ORIGINAL, - ) -> Tuple[Any, str, Type[IMediaReader]]: + ) -> tuple[Any, str, type[IMediaReader]]: _, chunk_number, frame_offset = self.validate_frame_number(frame_number) loader = self._loaders[quality] chunk_reader = loader.load(chunk_number) diff --git a/cvat/apps/engine/handlers.py b/cvat/apps/engine/handlers.py index d686bbf0ba5c..0a831a44827b 100644 --- a/cvat/apps/engine/handlers.py +++ b/cvat/apps/engine/handlers.py @@ -4,7 +4,9 @@ from pathlib import Path from time import time + from django.conf import settings + from cvat.apps.engine.log import ServerLogManager slogger = ServerLogManager(__name__) diff --git a/cvat/apps/engine/lazy_list.py b/cvat/apps/engine/lazy_list.py index 61d2c8956209..e8a36a09641f 100644 --- a/cvat/apps/engine/lazy_list.py +++ b/cvat/apps/engine/lazy_list.py @@ -2,9 +2,10 @@ # # SPDX-License-Identifier: MIT +from collections.abc import Iterator from functools import wraps from itertools import islice -from typing import Any, Callable, Iterator, TypeVar, overload +from typing import Any, Callable, TypeVar, overload import attrs from attr import field diff --git a/cvat/apps/engine/location.py b/cvat/apps/engine/location.py index ac6ab77dc073..deea541f09d3 100644 --- a/cvat/apps/engine/location.py +++ b/cvat/apps/engine/location.py @@ -3,9 +3,10 @@ # SPDX-License-Identifier: MIT from enum import Enum -from typing import Any, Dict, Union, Optional +from typing import Any, Optional, Union + +from cvat.apps.engine.models import Job, Location, Project, Task -from cvat.apps.engine.models import Location, Project, Task, Job class StorageType(str, Enum): TARGET = 'target_storage' @@ -15,11 +16,11 @@ def __str__(self): return self.value def get_location_configuration( - query_params: Dict[str, Any], + query_params: dict[str, Any], field_name: str, *, db_instance: Optional[Union[Project, Task, Job]] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: location = query_params.get('location') # handle resource import diff --git a/cvat/apps/engine/log.py b/cvat/apps/engine/log.py index 6f1740e74fd4..3cc2cecff37b 100644 --- a/cvat/apps/engine/log.py +++ b/cvat/apps/engine/log.py @@ -4,12 +4,15 @@ # SPDX-License-Identifier: MIT import logging -import sys import os.path as osp +import sys from contextlib import contextmanager -from cvat.apps.engine.utils import directory_tree + from django.conf import settings +from cvat.apps.engine.utils import directory_tree + + class _LoggerAdapter(logging.LoggerAdapter): def process(self, msg: str, kwargs): if msg_prefix := self.extra.get("msg_prefix"): diff --git a/cvat/apps/engine/management/commands/runperiodicjob.py b/cvat/apps/engine/management/commands/runperiodicjob.py new file mode 100644 index 000000000000..765f16541cfd --- /dev/null +++ b/cvat/apps/engine/management/commands/runperiodicjob.py @@ -0,0 +1,23 @@ +from argparse import ArgumentParser + +from django.conf import settings +from django.core.management.base import BaseCommand, CommandError +from django.utils.module_loading import import_string + + +class Command(BaseCommand): + help = "Run a configured periodic job immediately" + + def add_arguments(self, parser: ArgumentParser) -> None: + parser.add_argument("job_id", help="ID of the job to run") + + def handle(self, *args, **options): + job_id = options["job_id"] + + for job_definition in settings.PERIODIC_RQ_JOBS: + if job_definition["id"] == job_id: + job_func = import_string(job_definition["func"]) + job_func() + return + + raise CommandError(f"Job with ID {job_id} not found") diff --git a/cvat/apps/engine/management/commands/syncperiodicjobs.py b/cvat/apps/engine/management/commands/syncperiodicjobs.py index 097f468b337f..d78d3f247179 100644 --- a/cvat/apps/engine/management/commands/syncperiodicjobs.py +++ b/cvat/apps/engine/management/commands/syncperiodicjobs.py @@ -5,10 +5,10 @@ from argparse import ArgumentParser from collections import defaultdict -from django.core.management.base import BaseCommand +import django_rq from django.conf import settings +from django.core.management.base import BaseCommand -import django_rq class Command(BaseCommand): help = "Synchronize periodic jobs in Redis with the project configuration" diff --git a/cvat/apps/engine/media_extractors.py b/cvat/apps/engine/media_extractors.py index 3e7b8e17a31b..09c2ce2876de 100644 --- a/cvat/apps/engine/media_extractors.py +++ b/cvat/apps/engine/media_extractors.py @@ -5,23 +5,22 @@ from __future__ import annotations +import io +import itertools import os +import shutil +import struct import sysconfig import tempfile -import shutil import zipfile -import io -import itertools -import struct from abc import ABC, abstractmethod from bisect import bisect -from contextlib import ExitStack, closing, contextmanager +from collections.abc import Generator, Iterable, Iterator, Sequence +from contextlib import AbstractContextManager, ExitStack, closing, contextmanager from dataclasses import dataclass from enum import IntEnum -from typing import ( - Any, Callable, ContextManager, Generator, Iterable, Iterator, Optional, Protocol, - Sequence, Tuple, TypeVar, Union -) +from random import shuffle +from typing import Any, Callable, Optional, Protocol, TypeVar, Union import av import av.codec @@ -29,19 +28,19 @@ import av.video.stream import numpy as np from natsort import os_sorted -from pyunpack import Archive from PIL import Image, ImageFile, ImageOps -from random import shuffle -from cvat.apps.engine.utils import rotate_image -from cvat.apps.engine.models import DimensionType, SortingMethod +from pyunpack import Archive from rest_framework.exceptions import ValidationError +from cvat.apps.engine.models import DimensionType, SortingMethod +from cvat.apps.engine.utils import rotate_image + # fixes: "OSError:broken data stream" when executing line 72 while loading images downloaded from the web # see: https://stackoverflow.com/questions/42462431/oserror-broken-data-stream-when-reading-image-file ImageFile.LOAD_TRUNCATED_IMAGES = True from cvat.apps.engine.mime_types import mimetypes -from utils.dataset_manifest import VideoManifestManager, ImageManifestManager +from utils.dataset_manifest import ImageManifestManager, VideoManifestManager ORIENTATION_EXIF_TAG = 274 @@ -612,7 +611,7 @@ def iterate_frames( *, frame_filter: Union[bool, Iterable[int]] = True, video_stream: Optional[av.video.stream.VideoStream] = None, - ) -> Iterator[Tuple[av.VideoFrame, str, int]]: + ) -> Iterator[tuple[av.VideoFrame, str, int]]: """ If provided, frame_filter must be an ordered sequence in the ascending order. 'True' means using the frames configured in the reader object. @@ -673,14 +672,14 @@ def iterate_frames( if next_frame_filter_frame is None: return - def __iter__(self) -> Iterator[Tuple[av.VideoFrame, str, int]]: + def __iter__(self) -> Iterator[tuple[av.VideoFrame, str, int]]: return self.iterate_frames() def get_progress(self, pos): duration = self._get_duration() return pos / duration if duration else None - def _read_av_container(self) -> ContextManager[av.container.InputContainer]: + def _read_av_container(self) -> AbstractContextManager[av.container.InputContainer]: return _AvVideoReading().read_av_container(self._source_path[0]) def _decode_stream( @@ -771,7 +770,7 @@ def __init__(self, manifest_path: str, source_path: str, *, allow_threading: boo self.allow_threading = allow_threading - def _read_av_container(self) -> ContextManager[av.container.InputContainer]: + def _read_av_container(self) -> AbstractContextManager[av.container.InputContainer]: return _AvVideoReading().read_av_container(self.source_path) def _decode_stream( @@ -1032,11 +1031,11 @@ def _add_video_stream(self, container: av.container.OutputContainer, w, h, rate, return video_stream - FrameDescriptor = Tuple[av.VideoFrame, Any, Any] + FrameDescriptor = tuple[av.VideoFrame, Any, Any] def _peek_first_frame( self, frame_iter: Iterator[FrameDescriptor] - ) -> Tuple[Optional[FrameDescriptor], Iterator[FrameDescriptor]]: + ) -> tuple[Optional[FrameDescriptor], Iterator[FrameDescriptor]]: "Gets the first frame and returns the same full iterator" if not hasattr(frame_iter, '__next__'): @@ -1047,7 +1046,7 @@ def _peek_first_frame( def save_as_chunk( self, images: Iterator[FrameDescriptor], chunk_path: str - ) -> Sequence[Tuple[int, int]]: + ) -> Sequence[tuple[int, int]]: first_frame, images = self._peek_first_frame(images) if not first_frame: raise Exception('no images to save') diff --git a/cvat/apps/engine/middleware.py b/cvat/apps/engine/middleware.py index f2b990a14b50..2e8f116f4ecd 100644 --- a/cvat/apps/engine/middleware.py +++ b/cvat/apps/engine/middleware.py @@ -4,6 +4,7 @@ from uuid import uuid4 + class RequestTrackingMiddleware: def __init__(self, get_response): self.get_response = get_response diff --git a/cvat/apps/engine/migrations/0001_release_v0_1_0.py b/cvat/apps/engine/migrations/0001_release_v0_1_0.py index 64d030cc81c6..59edc03104f4 100644 --- a/cvat/apps/engine/migrations/0001_release_v0_1_0.py +++ b/cvat/apps/engine/migrations/0001_release_v0_1_0.py @@ -5,9 +5,9 @@ # Generated by Django 2.0.3 on 2018-05-23 11:51 +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0002_labeledpoints_labeledpointsattributeval_labeledpolygon_labeledpolygonattributeval_labeledpolyline_la.py b/cvat/apps/engine/migrations/0002_labeledpoints_labeledpointsattributeval_labeledpolygon_labeledpolygonattributeval_labeledpolyline_la.py index 0e7820999c38..fa3e6fe79b94 100644 --- a/cvat/apps/engine/migrations/0002_labeledpoints_labeledpointsattributeval_labeledpolygon_labeledpolygonattributeval_labeledpolyline_la.py +++ b/cvat/apps/engine/migrations/0002_labeledpoints_labeledpointsattributeval_labeledpolygon_labeledpolygonattributeval_labeledpolyline_la.py @@ -5,8 +5,8 @@ # Generated by Django 2.0.3 on 2018-05-30 09:53 -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0008_auto_20180917_1424.py b/cvat/apps/engine/migrations/0008_auto_20180917_1424.py index cf6b45500d90..a32051d585e4 100644 --- a/cvat/apps/engine/migrations/0008_auto_20180917_1424.py +++ b/cvat/apps/engine/migrations/0008_auto_20180917_1424.py @@ -1,8 +1,8 @@ # Generated by Django 2.0.3 on 2018-09-17 11:24 +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0011_add_task_source_and_safecharfield.py b/cvat/apps/engine/migrations/0011_add_task_source_and_safecharfield.py index bb96c1b588dd..4b168322d486 100644 --- a/cvat/apps/engine/migrations/0011_add_task_source_and_safecharfield.py +++ b/cvat/apps/engine/migrations/0011_add_task_source_and_safecharfield.py @@ -1,8 +1,9 @@ # Generated by Django 2.0.9 on 2018-10-24 10:50 -import cvat.apps.engine.models from django.db import migrations +import cvat.apps.engine.models + class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0013_auth_no_default_permissions.py b/cvat/apps/engine/migrations/0013_auth_no_default_permissions.py index bc735269eed6..2dabe07fe9a0 100644 --- a/cvat/apps/engine/migrations/0013_auth_no_default_permissions.py +++ b/cvat/apps/engine/migrations/0013_auth_no_default_permissions.py @@ -1,8 +1,8 @@ # Generated by Django 2.0.9 on 2018-11-07 12:25 +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0015_db_redesign_20190217.py b/cvat/apps/engine/migrations/0015_db_redesign_20190217.py index db9589d8b807..accac35b8187 100644 --- a/cvat/apps/engine/migrations/0015_db_redesign_20190217.py +++ b/cvat/apps/engine/migrations/0015_db_redesign_20190217.py @@ -1,11 +1,13 @@ # Generated by Django 2.1.5 on 2019-02-17 19:32 -from django.conf import settings -from django.db import migrations, models import django.db.migrations.operations.special import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + import cvat.apps.engine.models + def set_segment_size(apps, schema_editor): Task = apps.get_model('engine', 'Task') for task in Task.objects.all(): diff --git a/cvat/apps/engine/migrations/0016_attribute_spec_20190217.py b/cvat/apps/engine/migrations/0016_attribute_spec_20190217.py index 27d273af2790..ac060ad69326 100644 --- a/cvat/apps/engine/migrations/0016_attribute_spec_20190217.py +++ b/cvat/apps/engine/migrations/0016_attribute_spec_20190217.py @@ -1,12 +1,15 @@ +import csv import os import re -import csv from io import StringIO -from PIL import Image -from django.db import migrations + from django.conf import settings +from django.db import migrations +from PIL import Image + from cvat.apps.engine.media_extractors import get_mime + def parse_attribute(value): match = re.match(r'^([~@])(\w+)=(\w+):(.+)?$', value) if match: diff --git a/cvat/apps/engine/migrations/0017_db_redesign_20190221.py b/cvat/apps/engine/migrations/0017_db_redesign_20190221.py index 22b7e5d28881..d30d5fa0a73a 100644 --- a/cvat/apps/engine/migrations/0017_db_redesign_20190221.py +++ b/cvat/apps/engine/migrations/0017_db_redesign_20190221.py @@ -1,11 +1,13 @@ # Generated by Django 2.1.5 on 2019-02-21 12:25 -import cvat.apps.engine.models -from django.db import migrations, models import django.db.models.deletion from django.conf import settings +from django.db import migrations, models + +import cvat.apps.engine.models from cvat.apps.dataset_manager.task import merge_table_rows as _merge_table_rows + # some modified functions to transfer annotation def _bulk_create(db_model, db_alias, objects, flt_param): if objects: diff --git a/cvat/apps/engine/migrations/0018_jobcommit.py b/cvat/apps/engine/migrations/0018_jobcommit.py index c526cb896435..b25187c50f60 100644 --- a/cvat/apps/engine/migrations/0018_jobcommit.py +++ b/cvat/apps/engine/migrations/0018_jobcommit.py @@ -1,8 +1,8 @@ # Generated by Django 2.1.7 on 2019-04-17 09:25 +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0020_remove_task_flipped.py b/cvat/apps/engine/migrations/0020_remove_task_flipped.py index 7ca57e880417..7744def2b302 100644 --- a/cvat/apps/engine/migrations/0020_remove_task_flipped.py +++ b/cvat/apps/engine/migrations/0020_remove_task_flipped.py @@ -1,14 +1,15 @@ # Generated by Django 2.1.7 on 2019-06-18 11:08 -from django.db import migrations +import os +from ast import literal_eval + from django.conf import settings +from django.db import migrations +from PIL import Image -from cvat.apps.engine.models import Job, ShapeType from cvat.apps.engine.media_extractors import get_mime +from cvat.apps.engine.models import Job, ShapeType -from PIL import Image -from ast import literal_eval -import os def make_image_meta_cache(db_task): with open(db_task.get_image_meta_cache_path(), 'w') as meta_file: diff --git a/cvat/apps/engine/migrations/0022_auto_20191004_0817.py b/cvat/apps/engine/migrations/0022_auto_20191004_0817.py index b48a24f583db..6fd0ca45d8c3 100644 --- a/cvat/apps/engine/migrations/0022_auto_20191004_0817.py +++ b/cvat/apps/engine/migrations/0022_auto_20191004_0817.py @@ -1,9 +1,10 @@ # Generated by Django 2.2.3 on 2019-10-04 08:17 -import cvat.apps.engine.models +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion + +import cvat.apps.engine.models class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0023_auto_20200113_1323.py b/cvat/apps/engine/migrations/0023_auto_20200113_1323.py index 4089eb1a1a66..33c586398323 100644 --- a/cvat/apps/engine/migrations/0023_auto_20200113_1323.py +++ b/cvat/apps/engine/migrations/0023_auto_20200113_1323.py @@ -1,8 +1,9 @@ # Generated by Django 2.2.8 on 2020-01-13 13:23 -import cvat.apps.engine.models from django.db import migrations +import cvat.apps.engine.models + class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0024_auto_20191023_1025.py b/cvat/apps/engine/migrations/0024_auto_20191023_1025.py index c8aefe7b7774..945879ef7552 100644 --- a/cvat/apps/engine/migrations/0024_auto_20191023_1025.py +++ b/cvat/apps/engine/migrations/0024_auto_20191023_1025.py @@ -1,24 +1,32 @@ # Generated by Django 2.2.4 on 2019-10-23 10:25 +import glob +import itertools +import multiprocessing import os import re import shutil -import glob import sys -import traceback -import itertools -import multiprocessing import time +import traceback -from django.db import migrations, models import django.db.models.deletion from django.conf import settings +from django.db import migrations, models -from cvat.apps.engine.media_extractors import (VideoReader, ArchiveReader, ZipReader, - PdfReader , ImageListReader, Mpeg4ChunkWriter, - ZipChunkWriter, ZipCompressedChunkWriter, get_mime) -from cvat.apps.engine.models import DataChoice from cvat.apps.engine.log import get_migration_logger +from cvat.apps.engine.media_extractors import ( + ArchiveReader, + ImageListReader, + Mpeg4ChunkWriter, + PdfReader, + VideoReader, + ZipChunkWriter, + ZipCompressedChunkWriter, + ZipReader, + get_mime, +) +from cvat.apps.engine.models import DataChoice MIGRATION_THREAD_COUNT = 2 @@ -79,7 +87,7 @@ def migrate_task_data(db_task_id, db_data_id, original_video, original_images, s compressed_chunk_path = os.path.join(compressed_cache_dir, '{}.zip'.format(chunk_idx)) compressed_chunk_writer.save_as_chunk(chunk_images, compressed_chunk_path) - preview = reader.get_preview() + preview = reader.get_preview(0) preview.save(os.path.join(db_data_dir, 'preview.jpeg')) else: original_chunk_writer = ZipChunkWriter(100) @@ -146,7 +154,7 @@ def migrate_task_data(db_task_id, db_data_id, original_video, original_images, s original_chunk_path = os.path.join(original_cache_dir, '{}.zip'.format(chunk_idx)) original_chunk_writer.save_as_chunk(chunk_images, original_chunk_path) - preview = reader.get_preview() + preview = reader.get_preview(0) preview.save(os.path.join(db_data_dir, 'preview.jpeg')) shutil.rmtree(old_db_task_dir) return_dict[db_task_id] = (True, '') diff --git a/cvat/apps/engine/migrations/0028_labelcolor.py b/cvat/apps/engine/migrations/0028_labelcolor.py index af30fbabd8d2..eda6215ecdd6 100644 --- a/cvat/apps/engine/migrations/0028_labelcolor.py +++ b/cvat/apps/engine/migrations/0028_labelcolor.py @@ -1,7 +1,9 @@ # Generated by Django 2.2.13 on 2020-08-11 11:26 from django.db import migrations, models + from cvat.apps.dataset_manager.formats.utils import get_label_color + def alter_label_colors(apps, schema_editor): Label = apps.get_model('engine', 'Label') Task = apps.get_model('engine', 'Task') diff --git a/cvat/apps/engine/migrations/0029_data_storage_method.py b/cvat/apps/engine/migrations/0029_data_storage_method.py index 1c1aa814e4cd..e5ee36f33f06 100644 --- a/cvat/apps/engine/migrations/0029_data_storage_method.py +++ b/cvat/apps/engine/migrations/0029_data_storage_method.py @@ -1,12 +1,15 @@ # Generated by Django 2.2.13 on 2020-08-13 05:49 -from cvat.apps.engine.media_extractors import _is_archive, _is_zip -import cvat.apps.engine.models +import os + from django.conf import settings from django.db import migrations, models -import os from pyunpack import Archive +import cvat.apps.engine.models +from cvat.apps.engine.media_extractors import _is_archive, _is_zip + + def unzip(apps, schema_editor): Data = apps.get_model("engine", "Data") data_q_set = Data.objects.all() diff --git a/cvat/apps/engine/migrations/0033_projects_adjastment.py b/cvat/apps/engine/migrations/0033_projects_adjastment.py index e57bd0e6c568..8af73e6d1da5 100644 --- a/cvat/apps/engine/migrations/0033_projects_adjastment.py +++ b/cvat/apps/engine/migrations/0033_projects_adjastment.py @@ -1,7 +1,7 @@ # Generated by Django 3.1.1 on 2020-09-24 12:44 -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0034_auto_20201125_1426.py b/cvat/apps/engine/migrations/0034_auto_20201125_1426.py index 457861a3942c..d02582342893 100644 --- a/cvat/apps/engine/migrations/0034_auto_20201125_1426.py +++ b/cvat/apps/engine/migrations/0034_auto_20201125_1426.py @@ -1,17 +1,19 @@ # Generated by Django 3.1.1 on 2020-11-25 14:26 -import cvat.apps.engine.models +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion + +import cvat.apps.engine.models + def create_profile(apps, schema_editor): - User = apps.get_model('auth', 'User') - Profile = apps.get_model('engine', 'Profile') - for user in User.objects.all(): - profile = Profile() - profile.user = user - profile.save() + User = apps.get_model('auth', 'User') + Profile = apps.get_model('engine', 'Profile') + for user in User.objects.all(): + profile = Profile() + profile.user = user + profile.save() class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0035_data_storage.py b/cvat/apps/engine/migrations/0035_data_storage.py index 5a8a9903784f..075d7ce38015 100644 --- a/cvat/apps/engine/migrations/0035_data_storage.py +++ b/cvat/apps/engine/migrations/0035_data_storage.py @@ -1,8 +1,9 @@ # Generated by Django 3.1.1 on 2020-12-02 06:47 -import cvat.apps.engine.models from django.db import migrations, models +import cvat.apps.engine.models + class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0036_auto_20201216_0943.py b/cvat/apps/engine/migrations/0036_auto_20201216_0943.py index 6f2fde01250f..52cb5faca2a5 100644 --- a/cvat/apps/engine/migrations/0036_auto_20201216_0943.py +++ b/cvat/apps/engine/migrations/0036_auto_20201216_0943.py @@ -1,8 +1,9 @@ # Generated by Django 3.1.1 on 2020-12-16 09:43 -import cvat.apps.engine.models -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models + +import cvat.apps.engine.models class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0038_manifest.py b/cvat/apps/engine/migrations/0038_manifest.py index ec96045ae69c..33208ad4cf19 100644 --- a/cvat/apps/engine/migrations/0038_manifest.py +++ b/cvat/apps/engine/migrations/0038_manifest.py @@ -9,9 +9,8 @@ from django.db import migrations from cvat.apps.engine.log import get_logger -from cvat.apps.engine.models import (DimensionType, StorageChoice, - StorageMethodChoice) from cvat.apps.engine.media_extractors import get_mime +from cvat.apps.engine.models import DimensionType, StorageChoice, StorageMethodChoice from utils.dataset_manifest import ImageManifestManager, VideoManifestManager MIGRATION_NAME = os.path.splitext(os.path.basename(__file__))[0] @@ -110,7 +109,7 @@ def migrate2manifest(apps, shema_editor): if db_data.storage == StorageChoice.SHARE: def _get_frame_step(str_): - match = search("step\s*=\s*([1-9]\d*)", str_) + match = search(r"step\s*=\s*([1-9]\d*)", str_) return int(match.group(1)) if match else 1 logger.info('Data is located on the share, metadata update has been started') manifest.step = _get_frame_step(db_data.frame_filter) diff --git a/cvat/apps/engine/migrations/0039_auto_training.py b/cvat/apps/engine/migrations/0039_auto_training.py index a9f22ea7a03a..4594942d801e 100644 --- a/cvat/apps/engine/migrations/0039_auto_training.py +++ b/cvat/apps/engine/migrations/0039_auto_training.py @@ -1,7 +1,7 @@ # Generated by Django 3.1.7 on 2021-04-02 13:17 -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0040_cloud_storage.py b/cvat/apps/engine/migrations/0040_cloud_storage.py index c73609fd9fef..f7ecac010d19 100644 --- a/cvat/apps/engine/migrations/0040_cloud_storage.py +++ b/cvat/apps/engine/migrations/0040_cloud_storage.py @@ -1,9 +1,10 @@ # Generated by Django 3.1.8 on 2021-05-07 06:42 -import cvat.apps.engine.models +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion + +import cvat.apps.engine.models class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0042_auto_20210830_1056.py b/cvat/apps/engine/migrations/0042_auto_20210830_1056.py index 7b5a496af97c..69866f2c788a 100644 --- a/cvat/apps/engine/migrations/0042_auto_20210830_1056.py +++ b/cvat/apps/engine/migrations/0042_auto_20210830_1056.py @@ -1,7 +1,7 @@ # Generated by Django 3.1.13 on 2021-08-30 10:56 -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0046_data_sorting_method.py b/cvat/apps/engine/migrations/0046_data_sorting_method.py index f3880482fc33..cb58bce9ed69 100644 --- a/cvat/apps/engine/migrations/0046_data_sorting_method.py +++ b/cvat/apps/engine/migrations/0046_data_sorting_method.py @@ -1,8 +1,9 @@ # Generated by Django 3.1.13 on 2021-12-03 08:06 -import cvat.apps.engine.models from django.db import migrations, models +import cvat.apps.engine.models + class Migration(migrations.Migration): replaces = [('engine', '0045_data_sorting_method')] diff --git a/cvat/apps/engine/migrations/0047_auto_20211110_1938.py b/cvat/apps/engine/migrations/0047_auto_20211110_1938.py index 69434115f269..0113b1816c67 100644 --- a/cvat/apps/engine/migrations/0047_auto_20211110_1938.py +++ b/cvat/apps/engine/migrations/0047_auto_20211110_1938.py @@ -1,8 +1,9 @@ # Generated by Django 3.2.8 on 2021-11-10 19:38 -import cvat.apps.engine.models -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models + +import cvat.apps.engine.models class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0048_auto_20211112_1918.py b/cvat/apps/engine/migrations/0048_auto_20211112_1918.py index e1c54ab1206b..6c2106624397 100644 --- a/cvat/apps/engine/migrations/0048_auto_20211112_1918.py +++ b/cvat/apps/engine/migrations/0048_auto_20211112_1918.py @@ -1,8 +1,8 @@ # Generated by Django 3.2.8 on 2021-11-12 19:18 +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0053_data_deleted_frames.py b/cvat/apps/engine/migrations/0053_data_deleted_frames.py index 8bbf49792f49..e1421a0a2c1f 100644 --- a/cvat/apps/engine/migrations/0053_data_deleted_frames.py +++ b/cvat/apps/engine/migrations/0053_data_deleted_frames.py @@ -1,8 +1,9 @@ # Generated by Django 3.2.12 on 2022-05-20 09:21 -import cvat.apps.engine.models from django.db import migrations +import cvat.apps.engine.models + class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0054_auto_20220610_1829.py b/cvat/apps/engine/migrations/0054_auto_20220610_1829.py index 1c7ae1a802ec..25ed5b9c0617 100644 --- a/cvat/apps/engine/migrations/0054_auto_20220610_1829.py +++ b/cvat/apps/engine/migrations/0054_auto_20220610_1829.py @@ -1,8 +1,9 @@ # Generated by Django 3.2.12 on 2022-06-10 18:29 -import cvat.apps.engine.models -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models + +import cvat.apps.engine.models class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0055_jobs_directories.py b/cvat/apps/engine/migrations/0055_jobs_directories.py index ec97f2c8d3d5..89d7cd300b24 100644 --- a/cvat/apps/engine/migrations/0055_jobs_directories.py +++ b/cvat/apps/engine/migrations/0055_jobs_directories.py @@ -3,8 +3,9 @@ import os import shutil -from django.db import migrations from django.conf import settings +from django.db import migrations + from cvat.apps.engine.log import get_logger MIGRATION_NAME = os.path.splitext(os.path.basename(__file__))[0] diff --git a/cvat/apps/engine/migrations/0056_jobs_previews.py b/cvat/apps/engine/migrations/0056_jobs_previews.py index b8722018f92b..f3e6235fc780 100644 --- a/cvat/apps/engine/migrations/0056_jobs_previews.py +++ b/cvat/apps/engine/migrations/0056_jobs_previews.py @@ -2,8 +2,10 @@ import os import shutil -from django.db import migrations + from django.conf import settings +from django.db import migrations + from cvat.apps.engine.log import get_logger MIGRATION_NAME = os.path.splitext(os.path.basename(__file__))[0] diff --git a/cvat/apps/engine/migrations/0057_auto_20220726_0926.py b/cvat/apps/engine/migrations/0057_auto_20220726_0926.py index 459dbad6e783..22cd9f15e70b 100644 --- a/cvat/apps/engine/migrations/0057_auto_20220726_0926.py +++ b/cvat/apps/engine/migrations/0057_auto_20220726_0926.py @@ -1,8 +1,9 @@ # Generated by Django 3.2.14 on 2022-07-26 09:26 -import cvat.apps.engine.models -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models + +import cvat.apps.engine.models class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0058_auto_20220809_1236.py b/cvat/apps/engine/migrations/0058_auto_20220809_1236.py index 8a7eb002d0af..aafb9a3bfab0 100644 --- a/cvat/apps/engine/migrations/0058_auto_20220809_1236.py +++ b/cvat/apps/engine/migrations/0058_auto_20220809_1236.py @@ -1,7 +1,7 @@ # Generated by Django 3.2.15 on 2022-08-09 12:36 -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0060_alter_label_parent.py b/cvat/apps/engine/migrations/0060_alter_label_parent.py index 5eb698343413..a5e8a8df31f3 100644 --- a/cvat/apps/engine/migrations/0060_alter_label_parent.py +++ b/cvat/apps/engine/migrations/0060_alter_label_parent.py @@ -1,7 +1,7 @@ # Generated by Django 3.2.15 on 2022-09-09 09:00 -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0062_delete_previews.py b/cvat/apps/engine/migrations/0062_delete_previews.py index da986be097fb..ccf5e8f9f176 100644 --- a/cvat/apps/engine/migrations/0062_delete_previews.py +++ b/cvat/apps/engine/migrations/0062_delete_previews.py @@ -2,10 +2,12 @@ import sys import traceback -from django.db import migrations from django.conf import settings +from django.db import migrations + from cvat.apps.engine.log import get_migration_logger + def delete_previews(apps, schema_editor): migration_name = os.path.splitext(os.path.basename(__file__))[0] with get_migration_logger(migration_name) as log: diff --git a/cvat/apps/engine/migrations/0064_delete_or_rename_wrong_labels.py b/cvat/apps/engine/migrations/0064_delete_or_rename_wrong_labels.py index 63c167381529..97cad2c4f565 100644 --- a/cvat/apps/engine/migrations/0064_delete_or_rename_wrong_labels.py +++ b/cvat/apps/engine/migrations/0064_delete_or_rename_wrong_labels.py @@ -1,8 +1,10 @@ import os from django.db import migrations + from cvat.apps.engine.log import get_migration_logger + def delete_or_rename_wrong_labels(apps, schema_editor): migration_name = os.path.splitext(os.path.basename(__file__))[0] with get_migration_logger(migration_name) as log: diff --git a/cvat/apps/engine/migrations/0070_add_job_type_created_date.py b/cvat/apps/engine/migrations/0070_add_job_type_created_date.py index 034a6b275ae9..62d0293245cf 100644 --- a/cvat/apps/engine/migrations/0070_add_job_type_created_date.py +++ b/cvat/apps/engine/migrations/0070_add_job_type_created_date.py @@ -1,6 +1,7 @@ -import cvat.apps.engine.models -from django.db import migrations, models import django.utils.timezone +from django.db import migrations, models + +import cvat.apps.engine.models def add_created_date_to_existing_jobs(apps, schema_editor): diff --git a/cvat/apps/engine/migrations/0071_annotationguide_asset.py b/cvat/apps/engine/migrations/0071_annotationguide_asset.py index 1060c4576aba..a6b50c50861b 100644 --- a/cvat/apps/engine/migrations/0071_annotationguide_asset.py +++ b/cvat/apps/engine/migrations/0071_annotationguide_asset.py @@ -1,9 +1,10 @@ # Generated by Django 3.2.18 on 2023-06-13 13:14 +import uuid + +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion -import uuid class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0072_alter_issue_updated_date.py b/cvat/apps/engine/migrations/0072_alter_issue_updated_date.py index 4c549be10aa5..344036d12f65 100644 --- a/cvat/apps/engine/migrations/0072_alter_issue_updated_date.py +++ b/cvat/apps/engine/migrations/0072_alter_issue_updated_date.py @@ -2,6 +2,7 @@ from django.db import migrations, models + def forwards_func(apps, schema_editor): Issue = apps.get_model("engine", "Issue") diff --git a/cvat/apps/engine/migrations/0076_remove_storages_that_refer_to_deleted_cloud_storages.py b/cvat/apps/engine/migrations/0076_remove_storages_that_refer_to_deleted_cloud_storages.py index 50c1461319a7..41c902bb2500 100644 --- a/cvat/apps/engine/migrations/0076_remove_storages_that_refer_to_deleted_cloud_storages.py +++ b/cvat/apps/engine/migrations/0076_remove_storages_that_refer_to_deleted_cloud_storages.py @@ -1,6 +1,7 @@ # Generated by Django 4.2.6 on 2023-11-17 10:10 from django.db import migrations, models + from cvat.apps.engine.models import Location diff --git a/cvat/apps/engine/migrations/0077_auto_20231121_1952.py b/cvat/apps/engine/migrations/0077_auto_20231121_1952.py index 8b5c3648e068..831e83c8712a 100644 --- a/cvat/apps/engine/migrations/0077_auto_20231121_1952.py +++ b/cvat/apps/engine/migrations/0077_auto_20231121_1952.py @@ -1,7 +1,7 @@ # Generated by Django 4.2.6 on 2023-11-21 19:52 -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0079_alter_labeledimageattributeval_image_and_more.py b/cvat/apps/engine/migrations/0079_alter_labeledimageattributeval_image_and_more.py index ccafa6086b5e..58921bc97c92 100644 --- a/cvat/apps/engine/migrations/0079_alter_labeledimageattributeval_image_and_more.py +++ b/cvat/apps/engine/migrations/0079_alter_labeledimageattributeval_image_and_more.py @@ -1,7 +1,7 @@ # Generated by Django 4.2.13 on 2024-07-09 11:08 -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0080_alter_trackedshape_track.py b/cvat/apps/engine/migrations/0080_alter_trackedshape_track.py index d5997d15ff91..8266dbf4ba38 100644 --- a/cvat/apps/engine/migrations/0080_alter_trackedshape_track.py +++ b/cvat/apps/engine/migrations/0080_alter_trackedshape_track.py @@ -1,7 +1,7 @@ # Generated by Django 4.2.13 on 2024-07-12 19:01 -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0082_alter_labeledimage_job_and_more.py b/cvat/apps/engine/migrations/0082_alter_labeledimage_job_and_more.py index 50b91829b213..ecbc9d76f60d 100644 --- a/cvat/apps/engine/migrations/0082_alter_labeledimage_job_and_more.py +++ b/cvat/apps/engine/migrations/0082_alter_labeledimage_job_and_more.py @@ -1,7 +1,7 @@ # Generated by Django 4.2.14 on 2024-07-22 07:27 -from django.db import migrations, models import django.db.models.deletion +from django.db import migrations, models class Migration(migrations.Migration): diff --git a/cvat/apps/engine/migrations/0083_move_to_segment_chunks.py b/cvat/apps/engine/migrations/0083_move_to_segment_chunks.py index 8ef887d4c54b..4138d9295c87 100644 --- a/cvat/apps/engine/migrations/0083_move_to_segment_chunks.py +++ b/cvat/apps/engine/migrations/0083_move_to_segment_chunks.py @@ -1,8 +1,9 @@ # Generated by Django 4.2.13 on 2024-08-12 09:49 import os +from collections.abc import Iterable from itertools import islice -from typing import Iterable, TypeVar +from typing import TypeVar from django.db import migrations diff --git a/cvat/apps/engine/migrations/0084_honeypot_support.py b/cvat/apps/engine/migrations/0084_honeypot_support.py index 721d400ec386..fb44839c50bd 100644 --- a/cvat/apps/engine/migrations/0084_honeypot_support.py +++ b/cvat/apps/engine/migrations/0084_honeypot_support.py @@ -1,7 +1,7 @@ # Generated by Django 4.2.15 on 2024-09-23 13:11 -from typing import Collection from collections import defaultdict +from collections.abc import Collection import django.db.models.deletion from django.db import migrations, models diff --git a/cvat/apps/engine/migrations/0085_segment_chunks_updated_date.py b/cvat/apps/engine/migrations/0085_segment_chunks_updated_date.py index 52342d7db774..6fed44b22a6a 100644 --- a/cvat/apps/engine/migrations/0085_segment_chunks_updated_date.py +++ b/cvat/apps/engine/migrations/0085_segment_chunks_updated_date.py @@ -1,6 +1,7 @@ # Generated by Django 4.2.15 on 2024-09-25 13:52 from datetime import datetime + from django.db import migrations, models diff --git a/cvat/apps/engine/mime_types.py b/cvat/apps/engine/mime_types.py index 8e70c5cc4193..fad18ba6b6f8 100644 --- a/cvat/apps/engine/mime_types.py +++ b/cvat/apps/engine/mime_types.py @@ -2,9 +2,8 @@ # # SPDX-License-Identifier: MIT -import os import mimetypes - +import os _SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__)) MEDIA_MIMETYPES_FILES = [ diff --git a/cvat/apps/engine/mixins.py b/cvat/apps/engine/mixins.py index 3e48bf85327e..9e69ffdd5ccb 100644 --- a/cvat/apps/engine/mixins.py +++ b/cvat/apps/engine/mixins.py @@ -8,12 +8,13 @@ import os import os.path import uuid +from collections.abc import Mapping from dataclasses import asdict, dataclass from pathlib import Path from tempfile import NamedTemporaryFile -from unittest import mock from textwrap import dedent -from typing import Optional, Callable, Dict, Any, Mapping +from typing import Any, Callable, Optional +from unittest import mock from urllib.parse import urljoin import django_rq @@ -21,20 +22,18 @@ from django.conf import settings from django.http import HttpRequest from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import (OpenApiParameter, OpenApiResponse, - extend_schema) +from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema from rest_framework import mixins, status -from rest_framework.decorators import action from rest_framework.authentication import SessionAuthentication +from rest_framework.decorators import action from rest_framework.response import Response from rest_framework.views import APIView -from cvat.apps.engine.background import (BackupExportManager, - DatasetExportManager) +from cvat.apps.engine.background import BackupExportManager, DatasetExportManager from cvat.apps.engine.handlers import clear_import_cache from cvat.apps.engine.location import StorageType, get_location_configuration from cvat.apps.engine.log import ServerLogManager -from cvat.apps.engine.models import Location, RequestAction, RequestTarget, RequestSubresource +from cvat.apps.engine.models import Location, RequestAction, RequestSubresource, RequestTarget from cvat.apps.engine.rq_job_handler import RQId from cvat.apps.engine.serializers import DataSerializer, RqIdSerializer from cvat.apps.engine.utils import is_dataset_export @@ -424,7 +423,7 @@ def export_dataset_v1( request, save_images: bool, *, - get_data: Optional[Callable[[int], Dict[str, Any]]] = None, + get_data: Optional[Callable[[int], dict[str, Any]]] = None, ) -> Response: if request.query_params.get("format"): callback = self.get_export_callback(save_images) diff --git a/cvat/apps/engine/models.py b/cvat/apps/engine/models.py index 527741497531..c25c75404eaf 100644 --- a/cvat/apps/engine/models.py +++ b/cvat/apps/engine/models.py @@ -10,9 +10,10 @@ import re import shutil import uuid +from collections.abc import Collection, Sequence from enum import Enum from functools import cached_property -from typing import Any, ClassVar, Collection, Dict, Optional, Sequence +from typing import Any, ClassVar, Optional from django.conf import settings from django.contrib.auth.models import User @@ -824,7 +825,7 @@ def update_or_create(self, *args, **kwargs: Any): return super().update_or_create(*args, **kwargs) - def _validate_constraints(self, obj: Dict[str, Any]): + def _validate_constraints(self, obj: dict[str, Any]): if 'type' not in obj: return diff --git a/cvat/apps/engine/pagination.py b/cvat/apps/engine/pagination.py index 2bb417f5c0d1..6a1dd499b893 100644 --- a/cvat/apps/engine/pagination.py +++ b/cvat/apps/engine/pagination.py @@ -3,8 +3,10 @@ # SPDX-License-Identifier: MIT import sys + from rest_framework.pagination import PageNumberPagination + class CustomPagination(PageNumberPagination): page_size_query_param = "page_size" diff --git a/cvat/apps/engine/parsers.py b/cvat/apps/engine/parsers.py index d0cecc4b02d0..03b4ebd45da8 100644 --- a/cvat/apps/engine/parsers.py +++ b/cvat/apps/engine/parsers.py @@ -4,6 +4,7 @@ from rest_framework.parsers import BaseParser + class TusUploadParser(BaseParser): # The media type is sent by TUS protocol (tus.io) for uploading files media_type = 'application/offset+octet-stream' diff --git a/cvat/apps/engine/permissions.py b/cvat/apps/engine/permissions.py index d01036fc9004..a180410142cd 100644 --- a/cvat/apps/engine/permissions.py +++ b/cvat/apps/engine/permissions.py @@ -4,24 +4,28 @@ # SPDX-License-Identifier: MIT from collections import namedtuple -from typing import Any, Dict, List, Optional, Sequence, Union, cast +from collections.abc import Sequence +from typing import Any, Optional, Union, cast -from django.shortcuts import get_object_or_404 from django.conf import settings - -from rest_framework.exceptions import ValidationError, PermissionDenied +from django.shortcuts import get_object_or_404 +from rest_framework.exceptions import PermissionDenied, ValidationError from rq.job import Job as RQJob from cvat.apps.engine.rq_job_handler import is_rq_job_owner +from cvat.apps.engine.utils import is_dataset_export from cvat.apps.iam.permissions import ( - OpenPolicyAgentPermission, StrEnum, get_iam_context, get_membership + OpenPolicyAgentPermission, + StrEnum, + get_iam_context, + get_membership, ) from cvat.apps.organizations.models import Organization from .models import AnnotationGuide, CloudStorage, Issue, Job, Label, Project, Task -from cvat.apps.engine.utils import is_dataset_export -def _get_key(d: Dict[str, Any], key_path: Union[str, Sequence[str]]) -> Optional[Any]: + +def _get_key(d: dict[str, Any], key_path: Union[str, Sequence[str]]) -> Optional[Any]: """ Like dict.get(), but supports nested fields. If the field is missing, returns None. """ @@ -466,7 +470,7 @@ def __init__(self, **kwargs): self.url = settings.IAM_OPA_DATA_URL + '/tasks/allow' @staticmethod - def get_scopes(request, view, obj) -> List[Scopes]: + def get_scopes(request, view, obj) -> list[Scopes]: Scopes = __class__.Scopes scope = { ('list', 'GET'): Scopes.LIST, @@ -1191,7 +1195,7 @@ class Scopes(StrEnum): CANCEL = 'cancel' @classmethod - def create(cls, request, view, obj: Optional[RQJob], iam_context: Dict): + def create(cls, request, view, obj: Optional[RQJob], iam_context: dict): permissions = [] if view.basename == 'request': for scope in cls.get_scopes(request, view, obj): @@ -1207,7 +1211,7 @@ def __init__(self, **kwargs): self.url = settings.IAM_OPA_DATA_URL + '/requests/allow' @staticmethod - def get_scopes(request, view, obj) -> List[Scopes]: + def get_scopes(request, view, obj) -> list[Scopes]: Scopes = __class__.Scopes return [{ ('list', 'GET'): Scopes.LIST, diff --git a/cvat/apps/engine/renderers.py b/cvat/apps/engine/renderers.py index f56eb4d39808..542a322048ed 100644 --- a/cvat/apps/engine/renderers.py +++ b/cvat/apps/engine/renderers.py @@ -4,5 +4,6 @@ from rest_framework.renderers import JSONRenderer + class CVATAPIRenderer(JSONRenderer): media_type = 'application/vnd.cvat+json' diff --git a/cvat/apps/engine/rq_job_handler.py b/cvat/apps/engine/rq_job_handler.py index c5b31336ecdc..b4f146197afc 100644 --- a/cvat/apps/engine/rq_job_handler.py +++ b/cvat/apps/engine/rq_job_handler.py @@ -4,13 +4,14 @@ from __future__ import annotations -import attrs - -from typing import Optional, Union, Any +from typing import Any, Optional, Union from uuid import UUID + +import attrs from rq.job import Job as RQJob -from .models import RequestAction, RequestTarget, RequestSubresource +from .models import RequestAction, RequestSubresource, RequestTarget + class RQMeta: @staticmethod diff --git a/cvat/apps/engine/schema.py b/cvat/apps/engine/schema.py index 5931381b403d..f3914a03dddd 100644 --- a/cvat/apps/engine/schema.py +++ b/cvat/apps/engine/schema.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: MIT import textwrap -from typing import Type from drf_spectacular.extensions import OpenApiSerializerExtension from drf_spectacular.plumbing import build_basic_type, force_instance @@ -15,7 +14,7 @@ def _copy_serializer( instance: serializers.Serializer, *, - _new_type: Type[serializers.Serializer] = None, + _new_type: type[serializers.Serializer] = None, **kwargs ) -> serializers.Serializer: _new_type = _new_type or type(instance) diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index cf16d885163c..6c760b42ba65 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -5,46 +5,55 @@ from __future__ import annotations -from contextlib import closing +import os +import re +import shutil +import string +import textwrap import warnings +from collections import OrderedDict +from collections.abc import Iterable, Sequence +from contextlib import closing from copy import copy from datetime import timedelta from decimal import Decimal from inspect import isclass -import os -import re -import shutil -import string from tempfile import NamedTemporaryFile -import textwrap -from typing import Any, Dict, Iterable, Optional, OrderedDict, Sequence, Union +from typing import Any, Optional, Union import django_rq +import rq.defaults as rq_defaults from django.conf import settings -from django.contrib.auth.models import User, Group +from django.contrib.auth.models import Group, User from django.db import transaction -from django.db.models import prefetch_related_objects, Prefetch +from django.db.models import Prefetch, prefetch_related_objects from django.utils import timezone +from drf_spectacular.utils import OpenApiExample, extend_schema_field, extend_schema_serializer from numpy import random -from rest_framework import serializers, exceptions -import rq.defaults as rq_defaults -from rq.job import Job as RQJob, JobStatus as RQJobStatus +from rest_framework import exceptions, serializers +from rq.job import Job as RQJob +from rq.job import JobStatus as RQJobStatus from cvat.apps.dataset_manager.formats.utils import get_label_color from cvat.apps.engine import field_validation, models -from cvat.apps.engine.frame_provider import TaskFrameProvider, FrameQuality -from cvat.apps.engine.cloud_provider import get_cloud_storage_instance, Credentials, Status +from cvat.apps.engine.cloud_provider import Credentials, Status, get_cloud_storage_instance +from cvat.apps.engine.frame_provider import FrameQuality, TaskFrameProvider from cvat.apps.engine.log import ServerLogManager from cvat.apps.engine.permissions import TaskPermission +from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField from cvat.apps.engine.task_validation import HoneypotFrameSelector -from cvat.apps.engine.rq_job_handler import RQJobMetaField, RQId from cvat.apps.engine.utils import ( - format_list, grouped, parse_exception_message, CvatChunkTimestampMismatchError, - parse_specific_attributes, build_field_filter_params, get_list_view_name, reverse, take_by + CvatChunkTimestampMismatchError, + build_field_filter_params, + format_list, + get_list_view_name, + grouped, + parse_exception_message, + parse_specific_attributes, + reverse, + take_by, ) -from drf_spectacular.utils import OpenApiExample, extend_schema_field, extend_schema_serializer - slogger = ServerLogManager(__name__) class WriteOnceMixin: @@ -367,9 +376,9 @@ def check_attribute_names_unique(attrs): @transaction.atomic def update_label( cls, - validated_data: Dict[str, Any], + validated_data: dict[str, Any], svg: str, - sublabels: Iterable[Dict[str, Any]], + sublabels: Iterable[dict[str, Any]], *, parent_instance: Union[models.Project, models.Task], parent_label: Optional[models.Label] = None @@ -483,7 +492,7 @@ def update_label( @classmethod @transaction.atomic def create_labels(cls, - labels: Iterable[Dict[str, Any]], + labels: Iterable[dict[str, Any]], *, parent_instance: Union[models.Project, models.Task], parent_label: Optional[models.Label] = None @@ -534,7 +543,7 @@ def create_labels(cls, @classmethod @transaction.atomic def update_labels(cls, - labels: Iterable[Dict[str, Any]], + labels: Iterable[dict[str, Any]], *, parent_instance: Union[models.Project, models.Task], parent_label: Optional[models.Label] = None @@ -994,7 +1003,10 @@ def validate(self, attrs): @transaction.atomic def update(self, instance: models.Job, validated_data: dict[str, Any]) -> models.Job: from cvat.apps.engine.cache import ( - MediaCache, Callback, enqueue_create_chunk_job, wait_for_rq_job + Callback, + MediaCache, + enqueue_create_chunk_job, + wait_for_rq_job, ) from cvat.apps.engine.frame_provider import JobFrameProvider @@ -1099,7 +1111,7 @@ def _to_abs_frame(rel_frame: int) -> int: ) if bulk_context: - active_validation_frame_counts = bulk_context.active_validation_frame_counts + frame_selector = bulk_context.honeypot_frame_selector else: active_validation_frame_counts = { validation_frame: 0 for validation_frame in task_active_validation_frames @@ -1109,7 +1121,8 @@ def _to_abs_frame(rel_frame: int) -> int: if real_frame in task_active_validation_frames: active_validation_frame_counts[real_frame] += 1 - frame_selector = HoneypotFrameSelector(active_validation_frame_counts) + frame_selector = HoneypotFrameSelector(active_validation_frame_counts) + requested_frames = frame_selector.select_next_frames(segment_honeypots_count) requested_frames = list(map(_to_abs_frame, requested_frames)) else: @@ -1356,7 +1369,7 @@ def __init__( honeypot_frames: list[int], all_validation_frames: list[int], active_validation_frames: list[int], - validation_frame_counts: dict[int, int] | None = None + honeypot_frame_selector: HoneypotFrameSelector | None = None ): self.updated_honeypots: dict[int, models.Image] = {} self.updated_segments: list[int] = [] @@ -1368,7 +1381,7 @@ def __init__( self.honeypot_frames = honeypot_frames self.all_validation_frames = all_validation_frames self.active_validation_frames = active_validation_frames - self.active_validation_frame_counts = validation_frame_counts + self.honeypot_frame_selector = honeypot_frame_selector class TaskValidationLayoutWriteSerializer(serializers.Serializer): disabled_frames = serializers.ListField( @@ -1483,7 +1496,9 @@ def update(self, instance: models.Task, validated_data: dict[str, Any]) -> model ) elif frame_selection_method == models.JobFrameSelectionMethod.RANDOM_UNIFORM: # Reset distribution for active validation frames - bulk_context.active_validation_frame_counts = { f: 0 for f in active_validation_frames } + active_validation_frame_counts = { f: 0 for f in active_validation_frames } + frame_selector = HoneypotFrameSelector(active_validation_frame_counts) + bulk_context.honeypot_frame_selector = frame_selector # Could be done using Django ORM, but using order_by() and filter() # would result in an extra DB request @@ -3270,7 +3285,7 @@ class Meta: def _update_related_storages( instance: Union[models.Project, models.Task], - validated_data: Dict[str, Any], + validated_data: dict[str, Any], ) -> None: for storage_type in ('source_storage', 'target_storage'): new_conf = validated_data.pop(storage_type, None) @@ -3325,7 +3340,7 @@ def _update_related_storages( storage_instance.cloud_storage_id = new_cloud_storage_id storage_instance.save() -def _configure_related_storages(validated_data: Dict[str, Any]) -> Dict[str, Optional[models.Storage]]: +def _configure_related_storages(validated_data: dict[str, Any]) -> dict[str, Optional[models.Storage]]: storages = { 'source_storage': None, 'target_storage': None, @@ -3418,7 +3433,7 @@ class RequestDataOperationSerializer(serializers.Serializer): format = serializers.CharField(required=False, allow_null=True) function_id = serializers.CharField(required=False, allow_null=True) - def to_representation(self, rq_job: RQJob) -> Dict[str, Any]: + def to_representation(self, rq_job: RQJob) -> dict[str, Any]: parsed_rq_id: RQId = rq_job.parsed_rq_id return { @@ -3459,7 +3474,7 @@ class RequestSerializer(serializers.Serializer): result_id = serializers.IntegerField(required=False, allow_null=True) @extend_schema_field(UserIdentifiersSerializer()) - def get_owner(self, rq_job: RQJob) -> Dict[str, Any]: + def get_owner(self, rq_job: RQJob) -> dict[str, Any]: return UserIdentifiersSerializer(rq_job.meta[RQJobMetaField.USER]).data @extend_schema_field( @@ -3499,7 +3514,7 @@ def get_message(self, rq_job: RQJob) -> str: return message - def to_representation(self, rq_job: RQJob) -> Dict[str, Any]: + def to_representation(self, rq_job: RQJob) -> dict[str, Any]: representation = super().to_representation(rq_job) # FUTURE-TODO: support such statuses on UI diff --git a/cvat/apps/engine/signals.py b/cvat/apps/engine/signals.py index 3a964d90c2cc..456c6f228081 100644 --- a/cvat/apps/engine/signals.py +++ b/cvat/apps/engine/signals.py @@ -11,8 +11,7 @@ from django.db.models.signals import m2m_changed, post_delete, post_save from django.dispatch import receiver -from .models import CloudStorage, Data, Job, Profile, Project, StatusChoice, Task, Asset - +from .models import Asset, CloudStorage, Data, Job, Profile, Project, StatusChoice, Task # TODO: need to log any problems reported by shutil.rmtree when the new # analytics feature is available. Now the log system can write information diff --git a/cvat/apps/engine/task.py b/cvat/apps/engine/task.py index 3fac8f03fe65..7aa92acba2fd 100644 --- a/cvat/apps/engine/task.py +++ b/cvat/apps/engine/task.py @@ -4,23 +4,24 @@ # SPDX-License-Identifier: MIT import concurrent.futures -import itertools import fnmatch +import itertools import os import re -import rq import shutil -from copy import deepcopy +from collections.abc import Iterator, Sequence from contextlib import closing +from copy import deepcopy from datetime import datetime, timezone from pathlib import Path -from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple, Union +from typing import Any, NamedTuple, Optional, Union from urllib import parse as urlparse from urllib import request as urlrequest -import av import attrs +import av import django_rq +import rq from django.conf import settings from django.db import transaction from django.forms.models import model_to_dict @@ -28,25 +29,39 @@ from rest_framework.serializers import ValidationError from cvat.apps.engine import models -from cvat.apps.engine.log import ServerLogManager from cvat.apps.engine.frame_provider import TaskFrameProvider +from cvat.apps.engine.log import ServerLogManager from cvat.apps.engine.media_extractors import ( - MEDIA_TYPES, CachingMediaIterator, IMediaReader, ImageListReader, - Mpeg4ChunkWriter, Mpeg4CompressedChunkWriter, RandomAccessIterator, - ValidateDimension, ZipChunkWriter, ZipCompressedChunkWriter, get_mime, sort, + MEDIA_TYPES, + CachingMediaIterator, + ImageListReader, + IMediaReader, + Mpeg4ChunkWriter, + Mpeg4CompressedChunkWriter, + RandomAccessIterator, + ValidateDimension, + ZipChunkWriter, + ZipCompressedChunkWriter, + get_mime, load_image, + sort, ) from cvat.apps.engine.models import RequestAction, RequestTarget -from cvat.apps.engine.utils import ( - av_scan_paths, format_list, get_rq_job_meta, - define_dependent_job, get_rq_lock_by_user, take_by -) from cvat.apps.engine.rq_job_handler import RQId from cvat.apps.engine.task_validation import HoneypotFrameSelector -from cvat.utils.http import make_requests_session, PROXIES_FOR_UNTRUSTED_URLS +from cvat.apps.engine.utils import ( + av_scan_paths, + define_dependent_job, + format_list, + get_rq_job_meta, + get_rq_lock_by_user, + take_by, +) +from cvat.utils.http import PROXIES_FOR_UNTRUSTED_URLS, make_requests_session from utils.dataset_manifest import ImageManifestManager, VideoManifestManager, is_manifest from utils.dataset_manifest.core import VideoManifestValidator, is_dataset_manifest from utils.dataset_manifest.utils import detect_related_images + from .cloud_provider import db_storage_to_storage_instance slogger = ServerLogManager(__name__) @@ -77,7 +92,7 @@ def create( ############################# Internal implementation for server API -JobFileMapping = List[List[str]] +JobFileMapping = list[list[str]] class SegmentParams(NamedTuple): start_frame: int @@ -91,10 +106,10 @@ class SegmentsParams(NamedTuple): overlap: int def _copy_data_from_share_point( - server_files: List[str], + server_files: list[str], upload_dir: str, server_dir: Optional[str] = None, - server_files_exclude: Optional[List[str]] = None, + server_files_exclude: Optional[list[str]] = None, ): job = rq.get_current_job() job.meta['status'] = 'Data are being copied from source..' @@ -304,7 +319,7 @@ def _validate_data(counter, manifest_files=None): return counter, task_modes[0] def _validate_job_file_mapping( - db_task: models.Task, data: Dict[str, Any] + db_task: models.Task, data: dict[str, Any] ) -> Optional[JobFileMapping]: job_file_mapping = data.get('job_file_mapping', None) @@ -343,7 +358,7 @@ def _validate_job_file_mapping( return job_file_mapping def _validate_validation_params( - db_task: models.Task, data: Dict[str, Any], *, is_backup_restore: bool = False + db_task: models.Task, data: dict[str, Any], *, is_backup_restore: bool = False ) -> Optional[dict[str, Any]]: params = data.get('validation_params', {}) if not params: @@ -382,7 +397,7 @@ def _validate_validation_params( return params def _validate_manifest( - manifests: List[str], + manifests: list[str], root_dir: Optional[str], *, is_in_cloud: bool, @@ -455,7 +470,7 @@ def _download_data(urls, upload_dir): def _download_data_from_cloud_storage( db_storage: models.CloudStorage, - files: List[str], + files: list[str], upload_dir: str, ): cloud_storage_instance = db_storage_to_storage_instance(db_storage) @@ -479,7 +494,7 @@ def _read_dataset_manifest(path: str, *, create_index: bool = False) -> ImageMan def _restore_file_order_from_manifest( extractor: ImageListReader, manifest: ImageManifestManager, upload_dir: str -) -> List[str]: +) -> list[str]: """ Restores file ordering for the "predefined" file sorting method of the task creation. Checks for extra files in the input. @@ -511,7 +526,7 @@ def _restore_file_order_from_manifest( return [input_files[fn] for fn in manifest_files] def _create_task_manifest_based_on_cloud_storage_manifest( - sorted_media: List[str], + sorted_media: list[str], cloud_storage_manifest_prefix: str, cloud_storage_manifest: ImageManifestManager, manifest: ImageManifestManager, @@ -536,7 +551,7 @@ def _add_prefix(properties): def _create_task_manifest_from_cloud_data( db_storage: models.CloudStorage, - sorted_media: List[str], + sorted_media: list[str], manifest: ImageManifestManager, dimension: models.DimensionType = models.DimensionType.DIM_2D, *, @@ -557,7 +572,7 @@ def _create_task_manifest_from_cloud_data( @transaction.atomic def _create_thread( db_task: Union[int, models.Task], - data: Dict[str, Any], + data: dict[str, Any], *, is_backup_restore: bool = False, is_dataset_import: bool = False, @@ -1598,7 +1613,7 @@ def save_chunks( frame_map = {} # frame number -> extractor frame number if isinstance(media_extractor, MEDIA_TYPES['video']['extractor']): - def _get_frame_size(frame_tuple: Tuple[av.VideoFrame, Any, Any]) -> int: + def _get_frame_size(frame_tuple: tuple[av.VideoFrame, Any, Any]) -> int: # There is no need to be absolutely precise here, # just need to provide the reasonable upper boundary. # Return bytes needed for 1 frame diff --git a/cvat/apps/engine/task_validation.py b/cvat/apps/engine/task_validation.py index 3f15b7d79716..4734c153e8b4 100644 --- a/cvat/apps/engine/task_validation.py +++ b/cvat/apps/engine/task_validation.py @@ -2,25 +2,109 @@ # # SPDX-License-Identifier: MIT -from typing import Generic, Mapping, Sequence, TypeVar +from __future__ import annotations +from typing import Callable, Generic, Iterable, Mapping, Sequence, TypeVar + +import attrs import numpy as np -_T = TypeVar("_T") +_K = TypeVar("_K") + + +@attrs.define +class _BaggedCounter(Generic[_K]): + # Stores items with count = k in a single "bag". Bags are stored in the ascending order + bags: dict[ + int, + dict[_K, None], + # dict is used instead of a set to preserve item order. It's also more performant + ] + + @staticmethod + def from_dict(item_counts: Mapping[_K, int]) -> _BaggedCounter: + return _BaggedCounter.from_counts(item_counts, item_count=item_counts.__getitem__) + + @staticmethod + def from_counts(items: Sequence[_K], item_count: Callable[[_K], int]) -> _BaggedCounter: + bags = {} + for item in items: + count = item_count(item) + bags.setdefault(count, dict())[item] = None + + return _BaggedCounter(bags=bags) + + def __attrs_post_init__(self): + self._sort_bags() + + def _sort_bags(self): + self.bags = dict(sorted(self.bags.items(), key=lambda e: e[0])) + + def shuffle(self, *, rng: np.random.Generator | None): + if not rng: + rng = np.random.default_rng() + + for count, bag in self.bags.items(): + items = list(bag.items()) + rng.shuffle(items) + self.bags[count] = dict(items) + + def use_item(self, item: _K, *, count: int | None = None, bag: dict | None = None): + if count is not None: + if bag is None: + bag = self.bags[count] + elif count is None and bag is None: + count, bag = next((c, b) for c, b in self.bags.items() if item in b) + else: + raise AssertionError("'bag' can only be used together with 'count'") + bag.pop(item) -class HoneypotFrameSelector(Generic[_T]): + if not bag: + self.bags.pop(count) + + next_bag = self.bags.get(count + 1) + if next_bag is None: + next_bag = {} + self.bags[count + 1] = next_bag + self._sort_bags() # the new bag can be added in the wrong position if there were gaps + + next_bag[item] = None + + def __iter__(self) -> Iterable[tuple[int, _K, dict]]: + for count, bag in self.bags.items(): # bags must be ordered + for item in bag: + yield (count, item, bag) + + def select_next_least_used(self, count: int) -> Sequence[_K]: + pick = [None] * count + pick_original_use_counts = [(None, None)] * count + for i, (use_count, item, bag) in zip(range(count), self): + pick[i] = item + pick_original_use_counts[i] = (use_count, bag) + + for item, (use_count, bag) in zip(pick, pick_original_use_counts): + self.use_item(item, count=use_count, bag=bag) + + return pick + + +class HoneypotFrameSelector(Generic[_K]): def __init__( - self, validation_frame_counts: Mapping[_T, int], *, rng: np.random.Generator | None = None + self, + validation_frame_counts: Mapping[_K, int], + *, + rng: np.random.Generator | None = None, ): - self.validation_frame_counts = validation_frame_counts - if not rng: rng = np.random.default_rng() self.rng = rng - def select_next_frames(self, count: int) -> Sequence[_T]: + self._counter = _BaggedCounter.from_dict(validation_frame_counts) + self._counter.shuffle(rng=rng) + + def select_next_frames(self, count: int) -> Sequence[_K]: # This approach guarantees that: # - every GT frame is used # - GT frames are used uniformly (at most min count + 1) @@ -28,20 +112,8 @@ def select_next_frames(self, count: int) -> Sequence[_T]: # - honeypot sets are different in jobs # - honeypot sets are random # if possible (if the job and GT counts allow this). - pick = [] - - for random_number in self.rng.random(count): - least_count = min(c for f, c in self.validation_frame_counts.items() if f not in pick) - least_used_frames = tuple( - f - for f, c in self.validation_frame_counts.items() - if f not in pick - if c == least_count - ) - - selected_item = int(random_number * len(least_used_frames)) - selected_frame = least_used_frames[selected_item] - pick.append(selected_frame) - self.validation_frame_counts[selected_frame] += 1 - - return pick + # Picks must be reproducible for a given rng state. + """ + Selects 'count' least used items randomly, without repetition + """ + return self._counter.select_next_least_used(count) diff --git a/cvat/apps/engine/tests/test_lazy_list.py b/cvat/apps/engine/tests/test_lazy_list.py index 6ba4b07dd38f..2a021f89b94a 100644 --- a/cvat/apps/engine/tests/test_lazy_list.py +++ b/cvat/apps/engine/tests/test_lazy_list.py @@ -1,9 +1,9 @@ -import unittest import copy import pickle +import unittest from typing import TypeVar -from cvat.apps.engine.lazy_list import LazyList +from cvat.apps.engine.lazy_list import LazyList T = TypeVar('T') diff --git a/cvat/apps/engine/tests/test_rest_api.py b/cvat/apps/engine/tests/test_rest_api.py index e6ed6b6c0303..d59c310e5a3c 100644 --- a/cvat/apps/engine/tests/test_rest_api.py +++ b/cvat/apps/engine/tests/test_rest_api.py @@ -3,10 +3,10 @@ # # SPDX-License-Identifier: MIT -from contextlib import ExitStack -from datetime import timedelta +import copy import io -from itertools import product +import json +import logging import os import random import shutil @@ -15,40 +15,56 @@ import xml.etree.ElementTree as ET import zipfile from collections import defaultdict +from contextlib import ExitStack +from datetime import timedelta from enum import Enum from glob import glob from io import BytesIO, IOBase -from unittest import mock +from itertools import product from time import sleep -import logging -import copy -import json +from unittest import mock import av import django_rq import numpy as np -from pdf2image import convert_from_bytes -from pyunpack import Archive from django.conf import settings from django.contrib.auth.models import Group, User from django.http import HttpResponse +from pdf2image import convert_from_bytes from PIL import Image from pycocotools import coco as coco_loader +from pyunpack import Archive from rest_framework import status from rest_framework.test import APIClient from cvat.apps.dataset_manager.tests.utils import TestDir from cvat.apps.dataset_manager.util import current_function_name -from cvat.apps.engine.models import (AttributeSpec, AttributeType, Data, Job, - Project, Segment, StageChoice, StatusChoice, Task, Label, StorageMethodChoice, - StorageChoice, DimensionType, SortingMethod) from cvat.apps.engine.media_extractors import ValidateDimension, sort -from cvat.apps.engine.tests.utils import get_paginated_collection +from cvat.apps.engine.models import ( + AttributeSpec, + AttributeType, + Data, + DimensionType, + Job, + Label, + Project, + Segment, + SortingMethod, + StageChoice, + StatusChoice, + StorageChoice, + StorageMethodChoice, + Task, +) +from cvat.apps.engine.tests.utils import ( + ApiTestBase, + ForceLogin, + generate_image_file, + generate_video_file, + get_paginated_collection, +) from utils.dataset_manifest import ImageManifestManager, VideoManifestManager -from cvat.apps.engine.tests.utils import (ApiTestBase, ForceLogin, - generate_image_file, generate_video_file) - #suppress av warnings logging.getLogger('libav').setLevel(logging.ERROR) @@ -6127,13 +6143,13 @@ def _get_initial_annotation(annotation_format): elif annotation_format == "YOLO 1.1": annotations["shapes"] = rectangle_shapes_wo_attrs - elif annotation_format == "YOLOv8 Detection 1.0": + elif annotation_format == "Ultralytics YOLO Detection 1.0": annotations["shapes"] = rectangle_shapes_wo_attrs - elif annotation_format == "YOLOv8 Oriented Bounding Boxes 1.0": + elif annotation_format == "Ultralytics YOLO Oriented Bounding Boxes 1.0": annotations["shapes"] = rectangle_shapes_wo_attrs - elif annotation_format == "YOLOv8 Segmentation 1.0": + elif annotation_format == "Ultralytics YOLO Segmentation 1.0": annotations["shapes"] = polygon_shapes_wo_attrs elif annotation_format == "COCO 1.0": @@ -6382,6 +6398,9 @@ def _get_initial_annotation(annotation_format): formats['CVAT for video 1.1'] = 'CVAT 1.1' if 'CVAT for images 1.1' in export_formats: formats['CVAT for images 1.1'] = 'CVAT 1.1' + if 'Ultralytics YOLO Detection 1.0' in import_formats: + if 'Ultralytics YOLO Detection Track 1.0' in export_formats: + formats['Ultralytics YOLO Detection Track 1.0'] = 'Ultralytics YOLO Detection 1.0' if set(import_formats) ^ set(export_formats): # NOTE: this may not be an error, so we should not fail print("The following import formats have no pair:", @@ -6493,7 +6512,10 @@ def etree_to_dict(t): self.assertEqual(meta["task"]["name"], task["name"]) elif format_name == "PASCAL VOC 1.1": self.assertTrue(zipfile.is_zipfile(content)) - elif format_name in ["YOLO 1.1", "YOLOv8 Detection 1.0", "YOLOv8 Segmentation 1.0", "YOLOv8 Oriented Bounding Boxes 1.0", "YOLOv8 Pose 1.0"]: + elif format_name in [ + "YOLO 1.1", "Ultralytics YOLO Detection 1.0", "Ultralytics YOLO Segmentation 1.0", + "Ultralytics YOLO Oriented Bounding Boxes 1.0", "Ultralytics YOLO Pose 1.0", + ]: self.assertTrue(zipfile.is_zipfile(content)) elif format_name in ['Kitti Raw Format 1.0','Sly Point Cloud Format 1.0']: self.assertTrue(zipfile.is_zipfile(content)) diff --git a/cvat/apps/engine/tests/test_rest_api_3D.py b/cvat/apps/engine/tests/test_rest_api_3D.py index 67791c3c113c..087448c90dd2 100644 --- a/cvat/apps/engine/tests/test_rest_api_3D.py +++ b/cvat/apps/engine/tests/test_rest_api_3D.py @@ -4,7 +4,9 @@ # SPDX-License-Identifier: MIT +import copy import io +import itertools import os import os.path as osp import tempfile @@ -13,18 +15,15 @@ from collections import defaultdict from glob import glob from io import BytesIO -import copy from shutil import copyfile -import itertools from django.contrib.auth.models import Group, User from rest_framework import status +from cvat.apps.dataset_manager.task import TaskAnnotation from cvat.apps.dataset_manager.tests.utils import TestDir from cvat.apps.engine.media_extractors import ValidateDimension -from cvat.apps.dataset_manager.task import TaskAnnotation - -from cvat.apps.engine.tests.utils import get_paginated_collection, ApiTestBase, ForceLogin +from cvat.apps.engine.tests.utils import ApiTestBase, ForceLogin, get_paginated_collection CREATE_ACTION = "create" UPDATE_ACTION = "update" diff --git a/cvat/apps/engine/tests/utils.py b/cvat/apps/engine/tests/utils.py index 3d2a533d1e97..09fd850b2c19 100644 --- a/cvat/apps/engine/tests/utils.py +++ b/cvat/apps/engine/tests/utils.py @@ -2,21 +2,22 @@ # # SPDX-License-Identifier: MIT -from contextlib import contextmanager -from io import BytesIO -from typing import Any, Callable, Dict, Iterator, Sequence, TypeVar import itertools import logging import os +from collections.abc import Iterator, Sequence +from contextlib import contextmanager +from io import BytesIO +from typing import Any, Callable, TypeVar +import av +import django_rq +import numpy as np from django.conf import settings from django.core.cache import caches from django.http.response import HttpResponse from PIL import Image from rest_framework.test import APITestCase -import av -import django_rq -import numpy as np T = TypeVar('T') @@ -178,6 +179,6 @@ def get_paginated_collection( def filter_dict( - d: Dict[str, Any], *, keep: Sequence[str] = None, drop: Sequence[str] = None -) -> Dict[str, Any]: + d: dict[str, Any], *, keep: Sequence[str] = None, drop: Sequence[str] = None +) -> dict[str, Any]: return {k: v for k, v in d.items() if (not keep or k in keep) and (not drop or k not in drop)} diff --git a/cvat/apps/engine/urls.py b/cvat/apps/engine/urls.py index 1755197ebcdf..1380ae5f7961 100644 --- a/cvat/apps/engine/urls.py +++ b/cvat/apps/engine/urls.py @@ -3,14 +3,13 @@ # # SPDX-License-Identifier: MIT -from django.urls import path, include -from . import views -from rest_framework import routers - -from django.views.generic import RedirectView from django.conf import settings - +from django.urls import include, path +from django.views.generic import RedirectView from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView +from rest_framework import routers + +from . import views router = routers.DefaultRouter(trailing_slash=False) router.register('projects', views.ProjectViewSet) diff --git a/cvat/apps/engine/utils.py b/cvat/apps/engine/utils.py index 13d1d354dd3d..b3e3d48f69d6 100644 --- a/cvat/apps/engine/utils.py +++ b/cvat/apps/engine/utils.py @@ -4,43 +4,39 @@ # SPDX-License-Identifier: MIT import ast -from itertools import islice -import cv2 as cv -from collections import namedtuple import hashlib import importlib +import logging +import os +import platform +import re +import subprocess import sys import traceback -from contextlib import suppress, nullcontext -from typing import ( - Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Mapping, Sequence, TypeVar, Union -) -import subprocess -import os import urllib.parse -import re -import logging -import platform +from collections import namedtuple +from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence +from contextlib import nullcontext, suppress +from itertools import islice +from multiprocessing import cpu_count +from pathlib import Path +from typing import Any, Callable, Optional, TypeVar, Union +import cv2 as cv from attr.converters import to_bool +from av import VideoFrame from datumaro.util.os_util import walk -from rq.job import Job, Dependency -from django_rq.queues import DjangoRQ -from pathlib import Path - +from django.conf import settings +from django.core.exceptions import ValidationError from django.http.request import HttpRequest from django.utils import timezone from django.utils.http import urlencode -from rest_framework.reverse import reverse as _reverse - -from av import VideoFrame -from PIL import Image -from multiprocessing import cpu_count - -from django.core.exceptions import ValidationError +from django_rq.queues import DjangoRQ from django_sendfile import sendfile as _sendfile -from django.conf import settings +from PIL import Image from redis.lock import Lock +from rest_framework.reverse import reverse as _reverse +from rq.job import Dependency, Job Import = namedtuple("Import", ["module", "name", "alias"]) @@ -231,8 +227,8 @@ def get_rq_job_meta( result_url: Optional[str] = None, ): # to prevent circular import - from cvat.apps.webhooks.signals import project_id, organization_id - from cvat.apps.events.handlers import task_id, job_id, organization_slug + from cvat.apps.events.handlers import job_id, organization_slug, task_id + from cvat.apps.webhooks.signals import organization_id, project_id oid = organization_id(db_obj) oslug = organization_slug(db_obj) @@ -264,7 +260,7 @@ def get_rq_job_meta( return meta def reverse(viewname, *, args=None, kwargs=None, - query_params: Optional[Dict[str, str]] = None, + query_params: Optional[dict[str, str]] = None, request: Optional[HttpRequest] = None, ) -> str: """ @@ -283,7 +279,7 @@ def reverse(viewname, *, args=None, kwargs=None, def get_server_url(request: HttpRequest) -> str: return request.build_absolute_uri('/') -def build_field_filter_params(field: str, value: Any) -> Dict[str, str]: +def build_field_filter_params(field: str, value: Any) -> dict[str, str]: """ Builds a collection filter query params for a single field and value. """ diff --git a/cvat/apps/engine/view_utils.py b/cvat/apps/engine/view_utils.py index 2acb8bac780f..dbac90720b43 100644 --- a/cvat/apps/engine/view_utils.py +++ b/cvat/apps/engine/view_utils.py @@ -4,16 +4,16 @@ # NOTE: importing in the utils.py header leads to circular importing -from typing import Optional, Type +from typing import Optional from django.db.models.query import QuerySet from django.http.request import HttpRequest from django.http.response import HttpResponse +from drf_spectacular.utils import extend_schema from rest_framework.decorators import action from rest_framework.response import Response from rest_framework.serializers import Serializer from rest_framework.viewsets import GenericViewSet -from drf_spectacular.utils import extend_schema from cvat.apps.engine.mixins import UploadMixin from cvat.apps.engine.parsers import TusUploadParser @@ -23,9 +23,9 @@ def make_paginated_response( queryset: QuerySet, *, viewset: GenericViewSet, - response_type: Optional[Type[HttpResponse]] = None, - serializer_type: Optional[Type[Serializer]] = None, - request: Optional[Type[HttpRequest]] = None, + response_type: Optional[type[HttpResponse]] = None, + serializer_type: Optional[type[Serializer]] = None, + request: Optional[type[HttpRequest]] = None, **serializer_params ): # Adapted from the mixins.ListModelMixin.list() @@ -54,7 +54,7 @@ def make_paginated_response( return response_type(serializer.data) -def list_action(serializer_class: Type[Serializer], **kwargs): +def list_action(serializer_class: type[Serializer], **kwargs): params = dict( detail=True, methods=["GET"], diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index 9692cfc2f750..6b70836d53a1 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -13,104 +13,173 @@ import traceback import zlib from abc import ABCMeta, abstractmethod -from contextlib import suppress -from PIL import Image -from types import SimpleNamespace -from typing import Optional, Any, Dict, List, Union, cast, Callable, Mapping, Iterable from collections import namedtuple +from collections.abc import Iterable, Mapping +from contextlib import suppress from copy import copy from datetime import datetime -from redis.exceptions import ConnectionError as RedisConnectionError +from pathlib import Path from tempfile import NamedTemporaryFile +from types import SimpleNamespace +from typing import Any, Callable, Optional, Union, cast import django_rq from attr.converters import to_bool from django.conf import settings from django.contrib.auth.models import User -from django.db import IntegrityError, transaction +from django.db import IntegrityError from django.db import models as django_models +from django.db import transaction from django.db.models.query import Prefetch -from django.http import HttpResponse, HttpRequest, HttpResponseNotFound, HttpResponseBadRequest +from django.http import HttpRequest, HttpResponse, HttpResponseBadRequest, HttpResponseNotFound from django.utils import timezone from django.utils.decorators import method_decorator from django.views.decorators.cache import never_cache from django_rq.queues import DjangoRQ - from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import ( - OpenApiExample, OpenApiParameter, OpenApiResponse, PolymorphicProxySerializer, - extend_schema_view, extend_schema + OpenApiExample, + OpenApiParameter, + OpenApiResponse, + PolymorphicProxySerializer, + extend_schema, + extend_schema_view, ) - -from pathlib import Path +from PIL import Image +from redis.exceptions import ConnectionError as RedisConnectionError from rest_framework import mixins, serializers, status, viewsets from rest_framework.decorators import action -from rest_framework.exceptions import APIException, NotFound, ValidationError, PermissionDenied +from rest_framework.exceptions import APIException, NotFound, PermissionDenied, ValidationError from rest_framework.parsers import MultiPartParser from rest_framework.permissions import SAFE_METHODS from rest_framework.response import Response from rest_framework.settings import api_settings - -from rq.job import Job as RQJob, JobStatus as RQJobStatus +from rq.job import Job as RQJob +from rq.job import JobStatus as RQJobStatus import cvat.apps.dataset_manager as dm import cvat.apps.dataset_manager.views # pylint: disable=unused-import -from cvat.apps.engine.cloud_provider import db_storage_to_storage_instance, import_resource_from_cloud_storage -from cvat.apps.events.handlers import handle_dataset_import from cvat.apps.dataset_manager.bindings import CvatImportError from cvat.apps.dataset_manager.serializers import DatasetFormatsSerializer +from cvat.apps.engine import backup +from cvat.apps.engine.cache import CvatChunkTimestampMismatchError, LockError, MediaCache +from cvat.apps.engine.cloud_provider import ( + db_storage_to_storage_instance, + import_resource_from_cloud_storage, +) +from cvat.apps.engine.filters import ( + NonModelJsonLogicFilter, + NonModelOrderingFilter, + NonModelSimpleFilter, +) from cvat.apps.engine.frame_provider import ( - DataWithMeta, IFrameProvider, TaskFrameProvider, JobFrameProvider, FrameQuality + DataWithMeta, + FrameQuality, + IFrameProvider, + JobFrameProvider, + TaskFrameProvider, ) -from cvat.apps.engine.filters import NonModelSimpleFilter, NonModelOrderingFilter, NonModelJsonLogicFilter +from cvat.apps.engine.location import StorageType, get_location_configuration from cvat.apps.engine.media_extractors import get_mime -from cvat.apps.engine.permissions import AnnotationGuidePermission, get_iam_context +from cvat.apps.engine.mixins import ( + BackupMixin, + CsrfWorkaroundMixin, + DatasetMixin, + PartialUpdateModelMixin, + UploadMixin, +) +from cvat.apps.engine.models import AnnotationGuide, Asset, ClientFile, CloudProviderChoice +from cvat.apps.engine.models import CloudStorage as CloudStorageModel from cvat.apps.engine.models import ( - ClientFile, Job, JobType, Label, Task, Project, Issue, Data, - Comment, StorageMethodChoice, StorageChoice, - CloudProviderChoice, Location, CloudStorage as CloudStorageModel, - Asset, AnnotationGuide, RequestStatus, RequestAction, RequestTarget, RequestSubresource + Comment, + Data, + Issue, + Job, + JobType, + Label, + Location, + Project, + RequestAction, + RequestStatus, + RequestSubresource, + RequestTarget, + StorageChoice, + StorageMethodChoice, + Task, +) +from cvat.apps.engine.permissions import ( + AnnotationGuidePermission, + CloudStoragePermission, + CommentPermission, + IssuePermission, + JobPermission, + LabelPermission, + ProjectPermission, + TaskPermission, + UserPermission, + get_cloud_storage_for_import_or_export, + get_iam_context, ) +from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField, is_rq_job_owner from cvat.apps.engine.serializers import ( - AboutSerializer, AnnotationFileSerializer, BasicUserSerializer, - DataMetaReadSerializer, DataMetaWriteSerializer, DataSerializer, FileInfoSerializer, - JobDataMetaWriteSerializer, JobReadSerializer, JobWriteSerializer, - JobValidationLayoutReadSerializer, JobValidationLayoutWriteSerializer, - LabelSerializer, LabeledDataSerializer, - ProjectReadSerializer, ProjectWriteSerializer, - RqStatusSerializer, TaskReadSerializer, TaskValidationLayoutReadSerializer, TaskValidationLayoutWriteSerializer, TaskWriteSerializer, - UserSerializer, PluginsSerializer, IssueReadSerializer, - AnnotationGuideReadSerializer, AnnotationGuideWriteSerializer, - AssetReadSerializer, AssetWriteSerializer, - IssueWriteSerializer, CommentReadSerializer, CommentWriteSerializer, CloudStorageWriteSerializer, - CloudStorageReadSerializer, DatasetFileSerializer, - ProjectFileSerializer, TaskFileSerializer, RqIdSerializer, CloudStorageContentSerializer, + AboutSerializer, + AnnotationFileSerializer, + AnnotationGuideReadSerializer, + AnnotationGuideWriteSerializer, + AssetReadSerializer, + AssetWriteSerializer, + BasicUserSerializer, + CloudStorageContentSerializer, + CloudStorageReadSerializer, + CloudStorageWriteSerializer, + CommentReadSerializer, + CommentWriteSerializer, + DataMetaReadSerializer, + DataMetaWriteSerializer, + DataSerializer, + DatasetFileSerializer, + FileInfoSerializer, + IssueReadSerializer, + IssueWriteSerializer, + JobDataMetaWriteSerializer, + JobReadSerializer, + JobValidationLayoutReadSerializer, + JobValidationLayoutWriteSerializer, + JobWriteSerializer, + LabeledDataSerializer, + LabelSerializer, + PluginsSerializer, + ProjectFileSerializer, + ProjectReadSerializer, + ProjectWriteSerializer, RequestSerializer, + RqIdSerializer, + RqStatusSerializer, + TaskFileSerializer, + TaskReadSerializer, + TaskValidationLayoutReadSerializer, + TaskValidationLayoutWriteSerializer, + TaskWriteSerializer, + UserSerializer, ) -from cvat.apps.engine.permissions import get_cloud_storage_for_import_or_export - -from utils.dataset_manifest import ImageManifestManager from cvat.apps.engine.utils import ( - av_scan_paths, process_failed_job, - parse_exception_message, get_rq_job_meta, - import_resource_with_clean_up_after, sendfile, define_dependent_job, get_rq_lock_by_user, -) -from cvat.apps.engine.rq_job_handler import RQId, is_rq_job_owner, RQJobMetaField -from cvat.apps.engine import backup -from cvat.apps.engine.mixins import ( - PartialUpdateModelMixin, UploadMixin, DatasetMixin, BackupMixin, CsrfWorkaroundMixin + av_scan_paths, + define_dependent_job, + get_rq_job_meta, + get_rq_lock_by_user, + import_resource_with_clean_up_after, + parse_exception_message, + process_failed_job, + sendfile, ) -from cvat.apps.engine.location import get_location_configuration, StorageType +from cvat.apps.engine.view_utils import tus_chunk_action +from cvat.apps.events.handlers import handle_dataset_import +from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS +from cvat.apps.iam.permissions import IsAuthenticatedOrReadPublicResource, PolicyEnforcer +from utils.dataset_manifest import ImageManifestManager from . import models, task from .log import ServerLogManager -from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS -from cvat.apps.iam.permissions import PolicyEnforcer, IsAuthenticatedOrReadPublicResource -from cvat.apps.engine.cache import MediaCache, CvatChunkTimestampMismatchError, LockError -from cvat.apps.engine.permissions import (CloudStoragePermission, - CommentPermission, IssuePermission, JobPermission, LabelPermission, ProjectPermission, - TaskPermission, UserPermission) -from cvat.apps.engine.view_utils import tus_chunk_action slogger = ServerLogManager(__name__) @@ -1076,7 +1145,7 @@ def _maybe_append_upload_info_entry(self, filename: str): filename = self._prepare_upload_info_entry(filename) task_data.client_files.get_or_create(file=filename) - def _append_upload_info_entries(self, client_files: List[Dict[str, Any]]): + def _append_upload_info_entries(self, client_files: list[dict[str, Any]]): # batch version of _maybe_append_upload_info_entry() without optional insertion task_data = cast(Data, self._object.data) task_data.client_files.bulk_create([ @@ -1084,7 +1153,7 @@ def _append_upload_info_entries(self, client_files: List[Dict[str, Any]]): for cf in client_files ]) - def _sort_uploaded_files(self, uploaded_files: List[str], ordering: List[str]) -> List[str]: + def _sort_uploaded_files(self, uploaded_files: list[str], ordering: list[str]) -> list[str]: """ Applies file ordering for the "predefined" file sorting method of the task creation. @@ -3568,7 +3637,7 @@ def get_queryset(self): def queues(self) -> Iterable[DjangoRQ]: return (django_rq.get_queue(queue_name) for queue_name in self.SUPPORTED_QUEUES) - def _get_rq_jobs_from_queue(self, queue: DjangoRQ, user_id: int) -> List[RQJob]: + def _get_rq_jobs_from_queue(self, queue: DjangoRQ, user_id: int) -> list[RQJob]: job_ids = set(queue.get_job_ids() + queue.started_job_registry.get_job_ids() + queue.finished_job_registry.get_job_ids() + @@ -3588,7 +3657,7 @@ def _get_rq_jobs_from_queue(self, queue: DjangoRQ, user_id: int) -> List[RQJob]: return jobs - def _get_rq_jobs(self, user_id: int) -> List[RQJob]: + def _get_rq_jobs(self, user_id: int) -> list[RQJob]: """ Get all RQ jobs for a specific user and return them as a list of RQJob objects. diff --git a/cvat/apps/events/__init__.py b/cvat/apps/events/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/cvat/apps/events/apps.py b/cvat/apps/events/apps.py index f700758ad204..c4a7b0a3d9b4 100644 --- a/cvat/apps/events/apps.py +++ b/cvat/apps/events/apps.py @@ -6,10 +6,11 @@ class EventsConfig(AppConfig): - name = 'cvat.apps.events' + name = "cvat.apps.events" def ready(self): - from . import signals # pylint: disable=unused-import - from cvat.apps.iam.permissions import load_app_permissions + load_app_permissions(self) + + from . import signals # pylint: disable=unused-import diff --git a/cvat/apps/events/cache.py b/cvat/apps/events/cache.py index 30d1e67b8fc1..d17a8e703bc1 100644 --- a/cvat/apps/events/cache.py +++ b/cvat/apps/events/cache.py @@ -4,36 +4,42 @@ _caches = {} -class DeleteCache(): + +class DeleteCache: def __init__(self, cache_id): - from cvat.apps.engine.models import Task, Job, Issue, Comment - self._cache = _caches.setdefault(cache_id, { - Task: {}, - Job: {}, - Issue: {}, - Comment: {}, - }) + from cvat.apps.engine.models import Comment, Issue, Job, Task + + self._cache = _caches.setdefault( + cache_id, + { + Task: {}, + Job: {}, + Issue: {}, + Comment: {}, + }, + ) def set(self, instance_class, instance_id, value): self._cache[instance_class][instance_id] = value def pop(self, instance_class, instance_id, default=None): - if instance_class in self._cache and \ - instance_id in self._cache[instance_class]: + if instance_class in self._cache and instance_id in self._cache[instance_class]: return self._cache[instance_class].pop(instance_id, default) def has_key(self, instance_class, instance_id): - if instance_class in self._cache and \ - instance_id in self._cache[instance_class]: + if instance_class in self._cache and instance_id in self._cache[instance_class]: return True return False def clear(self): self._cache.clear() + def get_cache(): from .handlers import request_id + return DeleteCache(request_id()) + def clear_cache(): get_cache().clear() diff --git a/cvat/apps/events/const.py b/cvat/apps/events/const.py new file mode 100644 index 000000000000..9291d9397be3 --- /dev/null +++ b/cvat/apps/events/const.py @@ -0,0 +1,10 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import datetime + +MAX_EVENT_DURATION = datetime.timedelta(seconds=100) +WORKING_TIME_RESOLUTION = datetime.timedelta(milliseconds=1) +WORKING_TIME_SCOPE = "send:working_time" +COMPRESSED_EVENT_SCOPES = frozenset(("change:frame",)) diff --git a/cvat/apps/events/event.py b/cvat/apps/events/event.py index a4afff968549..5368367b70d7 100644 --- a/cvat/apps/events/event.py +++ b/cvat/apps/events/event.py @@ -2,14 +2,15 @@ # # SPDX-License-Identifier: MIT -from rest_framework.renderers import JSONRenderer from datetime import datetime, timezone from typing import Optional from django.db import transaction +from rest_framework.renderers import JSONRenderer from cvat.apps.engine.log import vlogger + def event_scope(action, resource): return f"{action}:{resource}" @@ -41,6 +42,7 @@ def select(cls, resources): for action in cls.RESOURCES.get(resource, []) ] + def record_server_event( *, scope: str, @@ -63,11 +65,11 @@ def record_server_event( "scope": scope, "timestamp": str(datetime.now(timezone.utc).timestamp()), "source": "server", - "payload": JSONRenderer().render(payload_with_request_id).decode('UTF-8'), + "payload": JSONRenderer().render(payload_with_request_id).decode("UTF-8"), **kwargs, } - rendered_data = JSONRenderer().render(data).decode('UTF-8') + rendered_data = JSONRenderer().render(data).decode("UTF-8") if on_commit: transaction.on_commit(lambda: vlogger.info(rendered_data), robust=True) @@ -80,6 +82,7 @@ class EventScopeChoice: def choices(cls): return sorted((val, val.upper()) for val in AllEvents.events) + class AllEvents: events = list( event_scope(action, resource) diff --git a/cvat/apps/events/export.py b/cvat/apps/events/export.py index 9225f1141162..770f84dda054 100644 --- a/cvat/apps/events/export.py +++ b/cvat/apps/events/export.py @@ -2,50 +2,49 @@ # # SPDX-License-Identifier: MIT -from logging import Logger -import os import csv -from datetime import datetime, timedelta, timezone -from dateutil import parser +import os import uuid +from datetime import datetime, timedelta, timezone +from logging import Logger +import clickhouse_connect import django_rq +from dateutil import parser from django.conf import settings -import clickhouse_connect - - from rest_framework import serializers, status from rest_framework.response import Response from cvat.apps.dataset_manager.views import log_exception from cvat.apps.engine.log import ServerLogManager -from cvat.apps.engine.utils import sendfile from cvat.apps.engine.rq_job_handler import RQJobMetaField +from cvat.apps.engine.utils import sendfile slogger = ServerLogManager(__name__) DEFAULT_CACHE_TTL = timedelta(hours=1) + def _create_csv(query_params, output_filename, cache_ttl): try: - clickhouse_settings = settings.CLICKHOUSE['events'] + clickhouse_settings = settings.CLICKHOUSE["events"] time_filter = { - 'from': query_params.pop('from'), - 'to': query_params.pop('to'), + "from": query_params.pop("from"), + "to": query_params.pop("to"), } query = "SELECT * FROM events" conditions = [] parameters = {} - if time_filter['from']: + if time_filter["from"]: conditions.append(f"timestamp >= {{from:DateTime64}}") - parameters['from'] = time_filter['from'] + parameters["from"] = time_filter["from"] - if time_filter['to']: + if time_filter["to"]: conditions.append(f"timestamp <= {{to:DateTime64}}") - parameters['to'] = time_filter['to'] + parameters["to"] = time_filter["to"] for param, value in query_params.items(): if value: @@ -58,22 +57,23 @@ def _create_csv(query_params, output_filename, cache_ttl): query += " ORDER BY timestamp ASC" with clickhouse_connect.get_client( - host=clickhouse_settings['HOST'], - database=clickhouse_settings['NAME'], - port=clickhouse_settings['PORT'], - username=clickhouse_settings['USER'], - password=clickhouse_settings['PASSWORD'], + host=clickhouse_settings["HOST"], + database=clickhouse_settings["NAME"], + port=clickhouse_settings["PORT"], + username=clickhouse_settings["USER"], + password=clickhouse_settings["PASSWORD"], ) as client: result = client.query(query, parameters=parameters) - with open(output_filename, 'w', encoding='UTF8') as f: + with open(output_filename, "w", encoding="UTF8") as f: writer = csv.writer(f) writer.writerow(result.column_names) writer.writerows(result.result_rows) archive_ctime = os.path.getctime(output_filename) scheduler = django_rq.get_scheduler(settings.CVAT_QUEUES.EXPORT_DATA.value) - cleaning_job = scheduler.enqueue_in(time_delta=cache_ttl, + cleaning_job = scheduler.enqueue_in( + time_delta=cache_ttl, func=_clear_export_cache, file_path=output_filename, file_ctime=archive_ctime, @@ -89,36 +89,37 @@ def _create_csv(query_params, output_filename, cache_ttl): log_exception(slogger.glob) raise + def export(request, filter_query, queue_name): - action = request.query_params.get('action', None) - filename = request.query_params.get('filename', None) + action = request.query_params.get("action", None) + filename = request.query_params.get("filename", None) query_params = { - 'org_id': filter_query.get('org_id', None), - 'project_id': filter_query.get('project_id', None), - 'task_id': filter_query.get('task_id', None), - 'job_id': filter_query.get('job_id', None), - 'user_id': filter_query.get('user_id', None), - 'from': filter_query.get('from', None), - 'to': filter_query.get('to', None), + "org_id": filter_query.get("org_id", None), + "project_id": filter_query.get("project_id", None), + "task_id": filter_query.get("task_id", None), + "job_id": filter_query.get("job_id", None), + "user_id": filter_query.get("user_id", None), + "from": filter_query.get("from", None), + "to": filter_query.get("to", None), } try: - if query_params['from']: - query_params['from'] = parser.parse(query_params['from']).timestamp() + if query_params["from"]: + query_params["from"] = parser.parse(query_params["from"]).timestamp() except parser.ParserError: raise serializers.ValidationError( f"Cannot parse 'from' datetime parameter: {query_params['from']}" ) try: - if query_params['to']: - query_params['to'] = parser.parse(query_params['to']).timestamp() + if query_params["to"]: + query_params["to"] = parser.parse(query_params["to"]).timestamp() except parser.ParserError: raise serializers.ValidationError( f"Cannot parse 'to' datetime parameter: {query_params['to']}" ) - if query_params['from'] and query_params['to'] and query_params['from'] > query_params['to']: + if query_params["from"] and query_params["to"] and query_params["from"] > query_params["to"]: raise serializers.ValidationError("'from' must be before than 'to'") # Set the default time interval to last 30 days @@ -126,14 +127,13 @@ def export(request, filter_query, queue_name): query_params["to"] = datetime.now(timezone.utc) query_params["from"] = query_params["to"] - timedelta(days=30) - if action not in (None, 'download'): - raise serializers.ValidationError( - "Unexpected action specified for the request") + if action not in (None, "download"): + raise serializers.ValidationError("Unexpected action specified for the request") - query_id = request.query_params.get('query_id', None) or uuid.uuid4() + query_id = request.query_params.get("query_id", None) or uuid.uuid4() rq_id = f"export:csv-logs-{query_id}-by-{request.user}" response_data = { - 'query_id': query_id, + "query_id": query_id, } queue = django_rq.get_queue(queue_name) @@ -147,16 +147,14 @@ def export(request, filter_query, queue_name): timestamp = datetime.strftime(datetime.now(), "%Y_%m_%d_%H_%M_%S") filename = filename or f"logs_{timestamp}.csv" - return sendfile(request, file_path, attachment=True, - attachment_filename=filename) + return sendfile(request, file_path, attachment=True, attachment_filename=filename) else: if os.path.exists(file_path): return Response(status=status.HTTP_201_CREATED) elif rq_job.is_failed: exc_info = rq_job.meta.get(RQJobMetaField.FORMATTED_EXCEPTION, str(rq_job.exc_info)) rq_job.delete() - return Response(exc_info, - status=status.HTTP_500_INTERNAL_SERVER_ERROR) + return Response(exc_info, status=status.HTTP_500_INTERNAL_SERVER_ERROR) else: return Response(data=response_data, status=status.HTTP_202_ACCEPTED) @@ -167,18 +165,19 @@ def export(request, filter_query, queue_name): args=(query_params, output_filename, DEFAULT_CACHE_TTL), job_id=rq_id, meta={}, - result_ttl=ttl, failure_ttl=ttl) + result_ttl=ttl, + failure_ttl=ttl, + ) return Response(data=response_data, status=status.HTTP_202_ACCEPTED) + def _clear_export_cache(file_path: str, file_ctime: float, logger: Logger) -> None: try: if os.path.exists(file_path) and os.path.getctime(file_path) == file_ctime: os.remove(file_path) - logger.info( - "Export cache file '{}' successfully removed" \ - .format(file_path)) + logger.info("Export cache file '{}' successfully removed".format(file_path)) except Exception: log_exception(logger) raise diff --git a/cvat/apps/events/handlers.py b/cvat/apps/events/handlers.py index 8f29f91d9a1a..69dd4b11cdd8 100644 --- a/cvat/apps/events/handlers.py +++ b/cvat/apps/events/handlers.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: MIT -import datetime import traceback from typing import Any, Optional, Union @@ -12,25 +11,41 @@ from rest_framework.exceptions import NotAuthenticated from rest_framework.views import exception_handler -from cvat.apps.engine.models import (CloudStorage, Comment, Issue, Job, Label, - Project, ShapeType, Task, User) -from cvat.apps.engine.serializers import (BasicUserSerializer, - CloudStorageReadSerializer, - CommentReadSerializer, - IssueReadSerializer, - JobReadSerializer, LabelSerializer, - ProjectReadSerializer, - TaskReadSerializer) -from cvat.apps.organizations.models import Invitation, Membership, Organization -from cvat.apps.organizations.serializers import (InvitationReadSerializer, - MembershipReadSerializer, - OrganizationReadSerializer) +from cvat.apps.engine.models import ( + CloudStorage, + Comment, + Issue, + Job, + Label, + Project, + ShapeType, + Task, + User, +) from cvat.apps.engine.rq_job_handler import RQJobMetaField +from cvat.apps.engine.serializers import ( + BasicUserSerializer, + CloudStorageReadSerializer, + CommentReadSerializer, + IssueReadSerializer, + JobReadSerializer, + LabelSerializer, + ProjectReadSerializer, + TaskReadSerializer, +) +from cvat.apps.organizations.models import Invitation, Membership, Organization +from cvat.apps.organizations.serializers import ( + InvitationReadSerializer, + MembershipReadSerializer, + OrganizationReadSerializer, +) from cvat.apps.webhooks.models import Webhook from cvat.apps.webhooks.serializers import WebhookReadSerializer from .cache import get_cache +from .const import WORKING_TIME_RESOLUTION, WORKING_TIME_SCOPE from .event import event_scope, record_server_event +from .utils import compute_working_time_per_ids def project_id(instance): @@ -160,9 +175,7 @@ def organization_slug(instance): def get_instance_diff(old_data, data): - ignore_related_fields = ( - "labels", - ) + ignore_related_fields = ("labels",) diff = {} for prop, value in data.items(): if prop in ignore_related_fields: @@ -178,7 +191,7 @@ def get_instance_diff(old_data, data): def _cleanup_fields(obj: dict[str, Any]) -> dict[str, Any]: - fields=( + fields = ( "slug", "id", "name", @@ -198,9 +211,7 @@ def _cleanup_fields(obj: dict[str, Any]) -> dict[str, Any]: "attributes", "key", ) - subfields=( - "url", - ) + subfields = ("url",) data = {} for k, v in obj.items(): @@ -214,11 +225,13 @@ def _cleanup_fields(obj: dict[str, Any]) -> dict[str, Any]: def _get_object_name(instance): - if isinstance(instance, Organization) or \ - isinstance(instance, Project) or \ - isinstance(instance, Task) or \ - isinstance(instance, Job) or \ - isinstance(instance, Label): + if ( + isinstance(instance, Organization) + or isinstance(instance, Project) + or isinstance(instance, Task) + or isinstance(instance, Job) + or isinstance(instance, Label) + ): return getattr(instance, "name", None) if isinstance(instance, User): @@ -250,9 +263,7 @@ def _get_object_name(instance): def get_serializer(instance): - context = { - "request": get_current_request() - } + context = {"request": get_current_request()} serializer = None for model, serializer_class in SERIALIZERS: @@ -261,6 +272,7 @@ def get_serializer(instance): return serializer + def get_serializer_without_url(instance): serializer = get_serializer(instance) if serializer: @@ -289,7 +301,7 @@ def handle_create(scope, instance, **kwargs): scope=scope, request_id=request_id(), on_commit=True, - obj_id=getattr(instance, 'id', None), + obj_id=getattr(instance, "id", None), obj_name=_get_object_name(instance), org_id=oid, org_slug=oslug, @@ -324,7 +336,7 @@ def handle_update(scope, instance, old_instance, **kwargs): request_id=request_id(), on_commit=True, obj_name=prop, - obj_id=getattr(instance, f'{prop}_id', None), + obj_id=getattr(instance, f"{prop}_id", None), obj_val=str(change["new_value"]), org_id=oid, org_slug=oslug, @@ -479,6 +491,7 @@ def filter_track(track): payload={"tracks": tracks}, ) + def handle_dataset_io( instance: Union[Project, Task, Job], action: str, @@ -487,7 +500,7 @@ def handle_dataset_io( cloud_storage_id: Optional[int], **payload_fields, ) -> None: - payload={"format": format_name, **payload_fields} + payload = {"format": format_name, **payload_fields} if cloud_storage_id: payload["cloud_storage"] = {"id": cloud_storage_id} @@ -506,6 +519,7 @@ def handle_dataset_io( payload=payload, ) + def handle_dataset_export( instance: Union[Project, Task, Job], *, @@ -513,8 +527,14 @@ def handle_dataset_export( cloud_storage_id: Optional[int], save_images: bool, ) -> None: - handle_dataset_io(instance, "export", - format_name=format_name, cloud_storage_id=cloud_storage_id, save_images=save_images) + handle_dataset_io( + instance, + "export", + format_name=format_name, + cloud_storage_id=cloud_storage_id, + save_images=save_images, + ) + def handle_dataset_import( instance: Union[Project, Task, Job], @@ -522,7 +542,10 @@ def handle_dataset_import( format_name: str, cloud_storage_id: Optional[int], ) -> None: - handle_dataset_io(instance, "import", format_name=format_name, cloud_storage_id=cloud_storage_id) + handle_dataset_io( + instance, "import", format_name=format_name, cloud_storage_id=cloud_storage_id + ) + def handle_function_call( function_id: str, @@ -544,6 +567,7 @@ def handle_function_call( }, ) + def handle_rq_exception(rq_job, exc_type, exc_value, tb): oid = rq_job.meta.get(RQJobMetaField.ORG_ID, None) oslug = rq_job.meta.get(RQJobMetaField.ORG_SLUG, None) @@ -557,7 +581,7 @@ def handle_rq_exception(rq_job, exc_type, exc_value, tb): payload = { "message": tb_strings[-1].rstrip("\n"), - "stack": ''.join(tb_strings), + "stack": "".join(tb_strings), } record_server_event( @@ -577,10 +601,11 @@ def handle_rq_exception(rq_job, exc_type, exc_value, tb): return False + def handle_viewset_exception(exc, context): response = exception_handler(exc, context) - IGNORED_EXCEPTION_CLASSES = (NotAuthenticated, ) + IGNORED_EXCEPTION_CLASSES = (NotAuthenticated,) if isinstance(exc, IGNORED_EXCEPTION_CLASSES): return response # the standard DRF exception handler only handle APIException, Http404 and PermissionDenied @@ -603,7 +628,7 @@ def handle_viewset_exception(exc, context): "method": request.method, }, "message": tb_strings[-1].rstrip("\n"), - "stack": ''.join(tb_strings), + "stack": "".join(tb_strings), "status_code": status_code, } @@ -619,53 +644,11 @@ def handle_viewset_exception(exc, context): return response + def handle_client_events_push(request, data: dict): - TIME_THRESHOLD = datetime.timedelta(seconds=100) - WORKING_TIME_SCOPE = 'send:working_time' - WORKING_TIME_RESOLUTION = datetime.timedelta(milliseconds=1) - COLLAPSED_EVENT_SCOPES = frozenset(("change:frame",)) org = request.iam_context["organization"] - def read_ids(event: dict) -> tuple[int | None, int | None, int | None]: - return event.get("job_id"), event.get("task_id"), event.get("project_id") - - def get_end_timestamp(event: dict) -> datetime.datetime: - if event["scope"] in COLLAPSED_EVENT_SCOPES: - return event["timestamp"] + datetime.timedelta(milliseconds=event["duration"]) - return event["timestamp"] - - if previous_event := data["previous_event"]: - previous_end_timestamp = get_end_timestamp(previous_event) - previous_ids = read_ids(previous_event) - elif data["events"]: - previous_end_timestamp = data["events"][0]["timestamp"] - previous_ids = read_ids(data["events"][0]) - - working_time_per_ids = {} - for event in data["events"]: - working_time = datetime.timedelta() - timestamp = event["timestamp"] - - if timestamp > previous_end_timestamp: - t_diff = timestamp - previous_end_timestamp - if t_diff < TIME_THRESHOLD: - working_time += t_diff - - previous_end_timestamp = timestamp - - end_timestamp = get_end_timestamp(event) - if end_timestamp > previous_end_timestamp: - working_time += end_timestamp - previous_end_timestamp - previous_end_timestamp = end_timestamp - - if previous_ids not in working_time_per_ids: - working_time_per_ids[previous_ids] = { - "value": datetime.timedelta(), - "timestamp": timestamp, - } - - working_time_per_ids[previous_ids]["value"] += working_time - previous_ids = read_ids(event) + working_time_per_ids = compute_working_time_per_ids(data) if data["events"]: common = { diff --git a/cvat/apps/events/permissions.py b/cvat/apps/events/permissions.py index a1b049cbbd4b..18d30f63ff65 100644 --- a/cvat/apps/events/permissions.py +++ b/cvat/apps/events/permissions.py @@ -4,21 +4,21 @@ # SPDX-License-Identifier: MIT from django.conf import settings - from rest_framework.exceptions import PermissionDenied from cvat.apps.iam.permissions import OpenPolicyAgentPermission, StrEnum from cvat.utils.http import make_requests_session + class EventsPermission(OpenPolicyAgentPermission): class Scopes(StrEnum): - SEND_EVENTS = 'send:events' - DUMP_EVENTS = 'dump:events' + SEND_EVENTS = "send:events" + DUMP_EVENTS = "dump:events" @classmethod def create(cls, request, view, obj, iam_context): permissions = [] - if view.basename == 'events': + if view.basename == "events": for scope in cls.get_scopes(request, view, obj): self = cls.create_base_perm(request, view, scope, iam_context, obj) permissions.append(self) @@ -27,19 +27,21 @@ def create(cls, request, view, obj, iam_context): def __init__(self, **kwargs): super().__init__(**kwargs) - self.url = settings.IAM_OPA_DATA_URL + '/events/allow' + self.url = settings.IAM_OPA_DATA_URL + "/events/allow" def filter(self, query_params): - url = self.url.replace('/allow', '/filter') + url = self.url.replace("/allow", "/filter") with make_requests_session() as session: - r = session.post(url, json=self.payload).json()['result'] + r = session.post(url, json=self.payload).json()["result"] filter_params = query_params.copy() for query in r: for attr, value in query.items(): if filter_params.get(attr, value) != value: - raise PermissionDenied(f"You don't have permission to view events with {attr}={filter_params.get(attr)}") + raise PermissionDenied( + f"You don't have permission to view events with {attr}={filter_params.get(attr)}" + ) else: filter_params[attr] = value return filter_params @@ -47,10 +49,12 @@ def filter(self, query_params): @staticmethod def get_scopes(request, view, obj): Scopes = __class__.Scopes - return [{ - ('create', 'POST'): Scopes.SEND_EVENTS, - ('list', 'GET'): Scopes.DUMP_EVENTS, - }[(view.action, request.method)]] + return [ + { + ("create", "POST"): Scopes.SEND_EVENTS, + ("list", "GET"): Scopes.DUMP_EVENTS, + }[(view.action, request.method)] + ] def get_resource(self): return None diff --git a/cvat/apps/events/rules/tests/generators/events_test.gen.rego.py b/cvat/apps/events/rules/tests/generators/events_test.gen.rego.py index dee2d4a68963..a345c8369f9e 100644 --- a/cvat/apps/events/rules/tests/generators/events_test.gen.rego.py +++ b/cvat/apps/events/rules/tests/generators/events_test.gen.rego.py @@ -83,13 +83,15 @@ def get_data(scope, context, ownership, privilege, membership, resource, same_or "scope": scope, "auth": { "user": {"id": random.randrange(0, 100), "privilege": privilege}, - "organization": { - "id": random.randrange(100, 200), - "owner": {"id": random.randrange(200, 300)}, - "user": {"role": membership}, - } - if context == "organization" - else None, + "organization": ( + { + "id": random.randrange(100, 200), + "owner": {"id": random.randrange(200, 300)}, + "user": {"role": membership}, + } + if context == "organization" + else None + ), }, "resource": resource, } diff --git a/cvat/apps/events/serializers.py b/cvat/apps/events/serializers.py index 9b70f17429c9..f634fef20b87 100644 --- a/cvat/apps/events/serializers.py +++ b/cvat/apps/events/serializers.py @@ -30,19 +30,40 @@ class EventSerializer(serializers.Serializer): class ClientEventsSerializer(serializers.Serializer): ALLOWED_SCOPES = { - 'client': frozenset(( - 'load:cvat', 'load:job', 'save:job','load:workspace', - 'upload:annotations', # TODO: remove in next releases - 'lock:object', # TODO: remove in next releases - 'change:attribute', # TODO: remove in next releases - 'change:label', # TODO: remove in next releases - 'send:exception', 'join:objects', 'change:frame', - 'draw:object', 'paste:object', 'copy:object', 'propagate:object', - 'drag:object', 'resize:object', 'delete:object', - 'merge:objects', 'split:objects', 'group:objects', 'slice:object', - 'zoom:image', 'fit:image', 'rotate:image', 'action:undo', 'action:redo', - 'debug:info', 'run:annotations_action', 'click:element', - )), + "client": frozenset( + ( + "load:cvat", + "load:job", + "save:job", + "load:workspace", + "upload:annotations", # TODO: remove in next releases + "lock:object", # TODO: remove in next releases + "change:attribute", # TODO: remove in next releases + "change:label", # TODO: remove in next releases + "send:exception", + "join:objects", + "change:frame", + "draw:object", + "paste:object", + "copy:object", + "propagate:object", + "drag:object", + "resize:object", + "delete:object", + "merge:objects", + "split:objects", + "group:objects", + "slice:object", + "zoom:image", + "fit:image", + "rotate:image", + "action:undo", + "action:redo", + "debug:info", + "run:annotations_action", + "click:element", + ) + ), } events = EventSerializer(many=True, default=[]) @@ -72,18 +93,24 @@ def to_internal_value(self, data): scope = event["scope"] source = event.get("source", "client") if scope not in ClientEventsSerializer.ALLOWED_SCOPES.get(source, []): - raise serializers.ValidationError({"scope": f"Event scope **{scope}** is not allowed from {source}"}) + raise serializers.ValidationError( + {"scope": f"Event scope **{scope}** is not allowed from {source}"} + ) try: payload = json.loads(event.get("payload", "{}")) except json.JSONDecodeError: - raise serializers.ValidationError({ "payload": "JSON payload is not valid in passed event" }) + raise serializers.ValidationError( + {"payload": "JSON payload is not valid in passed event"} + ) - event.update({ - "timestamp": event["timestamp"] + time_correction, - "source": source, - "payload": json.dumps(payload), - **(user_and_org_data if source == 'client' else {}) - }) + event.update( + { + "timestamp": event["timestamp"] + time_correction, + "source": source, + "payload": json.dumps(payload), + **(user_and_org_data if source == "client" else {}), + } + ) return data diff --git a/cvat/apps/events/tests/test_events.py b/cvat/apps/events/tests/test_events.py index 81b054171dce..3b9c4a6c832c 100644 --- a/cvat/apps/events/tests/test_events.py +++ b/cvat/apps/events/tests/test_events.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: MIT -import json import unittest from datetime import datetime, timedelta, timezone from typing import Optional @@ -10,15 +9,18 @@ from django.contrib.auth import get_user_model from django.test import RequestFactory +from cvat.apps.events.const import MAX_EVENT_DURATION, WORKING_TIME_RESOLUTION from cvat.apps.events.serializers import ClientEventsSerializer +from cvat.apps.events.utils import compute_working_time_per_ids, is_contained from cvat.apps.organizations.models import Organization + class WorkingTimeTestCase(unittest.TestCase): _START_TIMESTAMP = datetime(2024, 1, 1, 12) - _SHORT_GAP = ClientEventsSerializer._TIME_THRESHOLD - timedelta(milliseconds=1) - _SHORT_GAP_INT = _SHORT_GAP / ClientEventsSerializer._WORKING_TIME_RESOLUTION - _LONG_GAP = ClientEventsSerializer._TIME_THRESHOLD - _LONG_GAP_INT = _LONG_GAP / ClientEventsSerializer._WORKING_TIME_RESOLUTION + _SHORT_GAP = MAX_EVENT_DURATION - timedelta(milliseconds=1) + _SHORT_GAP_INT = _SHORT_GAP / WORKING_TIME_RESOLUTION + _LONG_GAP = MAX_EVENT_DURATION + _LONG_GAP_INT = _LONG_GAP / WORKING_TIME_RESOLUTION @staticmethod def _instant_event(timestamp: datetime) -> dict: @@ -33,16 +35,25 @@ def _compressed_event(timestamp: datetime, duration: timedelta) -> dict: return { "scope": "change:frame", "timestamp": timestamp.isoformat(), - "duration": duration // ClientEventsSerializer._WORKING_TIME_RESOLUTION, + "duration": duration // WORKING_TIME_RESOLUTION, } @staticmethod - def _working_time(event: dict) -> int: - payload = json.loads(event["payload"]) - return payload["working_time"] + def _get_actual_working_times(data: dict) -> list[int]: + data_copy = data.copy() + working_times = [] + for event in data["events"]: + data_copy["events"] = [event] + event_working_time = compute_working_time_per_ids(data_copy) + for working_time in event_working_time.values(): + working_times.append((working_time["value"] // WORKING_TIME_RESOLUTION)) + if data_copy["previous_event"] and is_contained(event, data_copy["previous_event"]): + continue + data_copy["previous_event"] = event + return working_times @staticmethod - def _deserialize(events: list[dict], previous_event: Optional[dict] = None) -> list[dict]: + def _deserialize(events: list[dict], previous_event: Optional[dict] = None) -> dict: request = RequestFactory().post("/api/events") request.user = get_user_model()(id=100, username="testuser", email="testuser@example.org") request.iam_context = { @@ -53,125 +64,156 @@ def _deserialize(events: list[dict], previous_event: Optional[dict] = None) -> l data={ "events": events, "previous_event": previous_event, - "timestamp": datetime.now(timezone.utc) + "timestamp": datetime.now(timezone.utc), }, context={"request": request}, ) s.is_valid(raise_exception=True) - return s.validated_data["events"] + return s.validated_data def test_instant(self): - events = self._deserialize([ - self._instant_event(self._START_TIMESTAMP), - ]) - self.assertEqual(self._working_time(events[0]), 0) + data = self._deserialize( + [ + self._instant_event(self._START_TIMESTAMP), + ] + ) + event_times = self._get_actual_working_times(data) + self.assertEqual(event_times[0], 0) def test_compressed(self): - events = self._deserialize([ - self._compressed_event(self._START_TIMESTAMP, self._LONG_GAP), - ]) - self.assertEqual(self._working_time(events[0]), self._LONG_GAP_INT) + data = self._deserialize( + [ + self._compressed_event(self._START_TIMESTAMP, self._LONG_GAP), + ] + ) + event_times = self._get_actual_working_times(data) + self.assertEqual(event_times[0], self._LONG_GAP_INT) def test_instants_with_short_gap(self): - events = self._deserialize([ - self._instant_event(self._START_TIMESTAMP), - self._instant_event(self._START_TIMESTAMP + self._SHORT_GAP), - ]) - self.assertEqual(self._working_time(events[0]), 0) - self.assertEqual(self._working_time(events[1]), self._SHORT_GAP_INT) + data = self._deserialize( + [ + self._instant_event(self._START_TIMESTAMP), + self._instant_event(self._START_TIMESTAMP + self._SHORT_GAP), + ] + ) + event_times = self._get_actual_working_times(data) + self.assertEqual(event_times[0], 0) + self.assertEqual(event_times[1], self._SHORT_GAP_INT) def test_instants_with_long_gap(self): - events = self._deserialize([ - self._instant_event(self._START_TIMESTAMP), - self._instant_event(self._START_TIMESTAMP + self._LONG_GAP), - ]) - self.assertEqual(self._working_time(events[0]), 0) - self.assertEqual(self._working_time(events[1]), 0) + data = self._deserialize( + [ + self._instant_event(self._START_TIMESTAMP), + self._instant_event(self._START_TIMESTAMP + self._LONG_GAP), + ] + ) + + event_times = self._get_actual_working_times(data) + self.assertEqual(event_times[0], 0) + self.assertEqual(event_times[1], 0) def test_compressed_with_short_gap(self): - events = self._deserialize([ - self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=1)), - self._compressed_event( - self._START_TIMESTAMP + timedelta(seconds=1) + self._SHORT_GAP, - timedelta(seconds=5) - ), - ]) - self.assertEqual(self._working_time(events[0]), 1000) - self.assertEqual(self._working_time(events[1]), self._SHORT_GAP_INT + 5000) + data = self._deserialize( + [ + self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=1)), + self._compressed_event( + self._START_TIMESTAMP + timedelta(seconds=1) + self._SHORT_GAP, + timedelta(seconds=5), + ), + ] + ) + + event_times = self._get_actual_working_times(data) + self.assertEqual(event_times[0], 1000) + self.assertEqual(event_times[1], self._SHORT_GAP_INT + 5000) def test_compressed_with_long_gap(self): - events = self._deserialize([ - self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=1)), - self._compressed_event( - self._START_TIMESTAMP + timedelta(seconds=1) + self._LONG_GAP, - timedelta(seconds=5) - ), - ]) - self.assertEqual(self._working_time(events[0]), 1000) - self.assertEqual(self._working_time(events[1]), 5000) + data = self._deserialize( + [ + self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=1)), + self._compressed_event( + self._START_TIMESTAMP + timedelta(seconds=1) + self._LONG_GAP, + timedelta(seconds=5), + ), + ] + ) + + event_times = self._get_actual_working_times(data) + self.assertEqual(event_times[0], 1000) + self.assertEqual(event_times[1], 5000) def test_compressed_contained(self): - events = self._deserialize([ - self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=5)), - self._compressed_event( - self._START_TIMESTAMP + timedelta(seconds=3), - timedelta(seconds=1) - ), - ]) - self.assertEqual(self._working_time(events[0]), 5000) - self.assertEqual(self._working_time(events[1]), 0) + data = self._deserialize( + [ + self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=5)), + self._compressed_event( + self._START_TIMESTAMP + timedelta(seconds=3), timedelta(seconds=1) + ), + ] + ) + + event_times = self._get_actual_working_times(data) + self.assertEqual(event_times[0], 5000) + self.assertEqual(event_times[1], 0) def test_compressed_overlapping(self): - events = self._deserialize([ - self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=5)), - self._compressed_event( - self._START_TIMESTAMP + timedelta(seconds=3), - timedelta(seconds=6) - ), - ]) - self.assertEqual(self._working_time(events[0]), 5000) - self.assertEqual(self._working_time(events[1]), 4000) + data = self._deserialize( + [ + self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=5)), + self._compressed_event( + self._START_TIMESTAMP + timedelta(seconds=3), timedelta(seconds=6) + ), + ] + ) + + event_times = self._get_actual_working_times(data) + self.assertEqual(event_times[0], 5000) + self.assertEqual(event_times[1], 4000) def test_instant_inside_compressed(self): - events = self._deserialize([ - self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=5)), - self._instant_event(self._START_TIMESTAMP + timedelta(seconds=3)), - self._instant_event(self._START_TIMESTAMP + timedelta(seconds=6)), - ]) - self.assertEqual(self._working_time(events[0]), 5000) - self.assertEqual(self._working_time(events[1]), 0) - self.assertEqual(self._working_time(events[2]), 1000) + data = self._deserialize( + [ + self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=5)), + self._instant_event(self._START_TIMESTAMP + timedelta(seconds=3)), + self._instant_event(self._START_TIMESTAMP + timedelta(seconds=6)), + ] + ) + + event_times = self._get_actual_working_times(data) + self.assertEqual(event_times[0], 5000) + self.assertEqual(event_times[1], 0) + self.assertEqual(event_times[2], 1000) def test_previous_instant_short_gap(self): - events = self._deserialize( + data = self._deserialize( [self._instant_event(self._START_TIMESTAMP + self._SHORT_GAP)], previous_event=self._instant_event(self._START_TIMESTAMP), ) - - self.assertEqual(self._working_time(events[0]), self._SHORT_GAP_INT) + event_times = self._get_actual_working_times(data) + self.assertEqual(event_times[0], self._SHORT_GAP_INT) def test_previous_instant_long_gap(self): - events = self._deserialize( + data = self._deserialize( [self._instant_event(self._START_TIMESTAMP + self._LONG_GAP)], previous_event=self._instant_event(self._START_TIMESTAMP), ) - - self.assertEqual(self._working_time(events[0]), 0) + event_times = self._get_actual_working_times(data) + self.assertEqual(event_times[0], 0) def test_previous_compressed_short_gap(self): - events = self._deserialize( + data = self._deserialize( [self._instant_event(self._START_TIMESTAMP + timedelta(seconds=1) + self._SHORT_GAP)], previous_event=self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=1)), ) - - self.assertEqual(self._working_time(events[0]), self._SHORT_GAP_INT) + event_times = self._get_actual_working_times(data) + self.assertEqual(event_times[0], self._SHORT_GAP_INT) def test_previous_compressed_long_gap(self): - events = self._deserialize( + data = self._deserialize( [self._instant_event(self._START_TIMESTAMP + timedelta(seconds=1) + self._LONG_GAP)], previous_event=self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=1)), ) - - self.assertEqual(self._working_time(events[0]), 0) + event_times = self._get_actual_working_times(data) + self.assertEqual(event_times[0], 0) diff --git a/cvat/apps/events/urls.py b/cvat/apps/events/urls.py index 832c86ac396b..cdb0d2032e68 100644 --- a/cvat/apps/events/urls.py +++ b/cvat/apps/events/urls.py @@ -1,4 +1,3 @@ - # Copyright (C) 2023 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -8,6 +7,6 @@ from . import views router = routers.DefaultRouter(trailing_slash=False) -router.register('events', views.EventsViewSet, basename='events') +router.register("events", views.EventsViewSet, basename="events") urlpatterns = router.urls diff --git a/cvat/apps/events/utils.py b/cvat/apps/events/utils.py index f5ef6dde4295..31c7f83c1791 100644 --- a/cvat/apps/events/utils.py +++ b/cvat/apps/events/utils.py @@ -2,13 +2,17 @@ # # SPDX-License-Identifier: MIT +import datetime + from .cache import clear_cache +from .const import COMPRESSED_EVENT_SCOPES, MAX_EVENT_DURATION + def _prepare_objects_to_delete(object_to_delete): - from cvat.apps.engine.models import Project, Task, Segment, Job, Issue, Comment + from cvat.apps.engine.models import Comment, Issue, Job, Project, Segment, Task relation_chain = (Project, Task, Segment, Job, Issue, Comment) - related_field_names = ('task_set', 'segment_set', 'job_set', 'issues', 'comments') + related_field_names = ("task_set", "segment_set", "job_set", "issues", "comments") field_names = tuple(m._meta.model_name for m in relation_chain) # Find object Model @@ -21,25 +25,21 @@ def _prepare_objects_to_delete(object_to_delete): # Fill filter param filter_params = { - f'{object_to_delete.__class__._meta.model_name}_id': object_to_delete.id, + f"{object_to_delete.__class__._meta.model_name}_id": object_to_delete.id, } # Fill prefetch prefetch = [] if index < len(relation_chain) - 1: - forward_prefetch = '__'.join(related_field_names[index:]) + forward_prefetch = "__".join(related_field_names[index:]) prefetch.append(forward_prefetch) if index > 0: - backward_prefetch = '__'.join(reversed(field_names[:index])) + backward_prefetch = "__".join(reversed(field_names[:index])) prefetch.append(backward_prefetch) # make queryset - objects = relation_chain[index].objects.filter( - **filter_params - ).prefetch_related( - *prefetch - ) + objects = relation_chain[index].objects.filter(**filter_params).prefetch_related(*prefetch) # list of objects which will be deleted with current object objects_to_delete = list(objects) @@ -51,9 +51,11 @@ def _prepare_objects_to_delete(object_to_delete): return objects_to_delete + def cache_deleted(method): def wrap(self, *args, **kwargs): from .signals import resource_delete + objects = _prepare_objects_to_delete(self) try: for obj in objects: @@ -62,4 +64,55 @@ def wrap(self, *args, **kwargs): method(self, *args, **kwargs) finally: clear_cache() + return wrap + + +def get_end_timestamp(event: dict) -> datetime.datetime: + if event["scope"] in COMPRESSED_EVENT_SCOPES: + return event["timestamp"] + datetime.timedelta(milliseconds=event["duration"]) + return event["timestamp"] + + +def is_contained(event1: dict, event2: dict) -> bool: + return event1["timestamp"] < get_end_timestamp(event2) + + +def compute_working_time_per_ids(data: dict) -> dict: + def read_ids(event: dict) -> tuple[int | None, int | None, int | None]: + return event.get("job_id"), event.get("task_id"), event.get("project_id") + + if previous_event := data["previous_event"]: + previous_end_timestamp = get_end_timestamp(previous_event) + previous_ids = read_ids(previous_event) + elif data["events"]: + previous_end_timestamp = data["events"][0]["timestamp"] + previous_ids = read_ids(data["events"][0]) + + working_time_per_ids = {} + for event in data["events"]: + working_time = datetime.timedelta() + timestamp = event["timestamp"] + + if timestamp > previous_end_timestamp: + t_diff = timestamp - previous_end_timestamp + if t_diff < MAX_EVENT_DURATION: + working_time += t_diff + + previous_end_timestamp = timestamp + + end_timestamp = get_end_timestamp(event) + if end_timestamp > previous_end_timestamp: + working_time += end_timestamp - previous_end_timestamp + previous_end_timestamp = end_timestamp + + if previous_ids not in working_time_per_ids: + working_time_per_ids[previous_ids] = { + "value": datetime.timedelta(), + "timestamp": timestamp, + } + + working_time_per_ids[previous_ids]["value"] += working_time + previous_ids = read_ids(event) + + return working_time_per_ids diff --git a/cvat/apps/events/views.py b/cvat/apps/events/views.py index 31914a829c3b..e910dabdc3be 100644 --- a/cvat/apps/events/views.py +++ b/cvat/apps/events/views.py @@ -4,8 +4,7 @@ from django.conf import settings from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import (OpenApiParameter, OpenApiResponse, - extend_schema) +from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema from rest_framework import status, viewsets from rest_framework.renderers import JSONRenderer from rest_framework.response import Response @@ -22,59 +21,114 @@ class EventsViewSet(viewsets.ViewSet): serializer_class = None - @extend_schema(summary='Log client events', - methods=['POST'], - description='Sends logs to the Clickhouse if it is connected', + @extend_schema( + summary="Log client events", + methods=["POST"], + description="Sends logs to the Clickhouse if it is connected", parameters=ORGANIZATION_OPEN_API_PARAMETERS, request=ClientEventsSerializer(), responses={ - '201': ClientEventsSerializer(), - }) + "201": ClientEventsSerializer(), + }, + ) def create(self, request): serializer = ClientEventsSerializer(data=request.data, context={"request": request}) serializer.is_valid(raise_exception=True) handle_client_events_push(request, serializer.validated_data) for event in serializer.validated_data["events"]: - message = JSONRenderer().render({ - **event, - 'timestamp': str(event["timestamp"].timestamp()) - }).decode('UTF-8') + message = ( + JSONRenderer() + .render({**event, "timestamp": str(event["timestamp"].timestamp())}) + .decode("UTF-8") + ) vlogger.info(message) return Response(serializer.validated_data, status=status.HTTP_201_CREATED) - @extend_schema(summary='Get an event log', - methods=['GET'], - description='The log is returned in the CSV format.', + @extend_schema( + summary="Get an event log", + methods=["GET"], + description="The log is returned in the CSV format.", parameters=[ - OpenApiParameter('org_id', location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False, - description="Filter events by organization ID"), - OpenApiParameter('project_id', location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False, - description="Filter events by project ID"), - OpenApiParameter('task_id', location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False, - description="Filter events by task ID"), - OpenApiParameter('job_id', location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False, - description="Filter events by job ID"), - OpenApiParameter('user_id', location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False, - description="Filter events by user ID"), - OpenApiParameter('from', location=OpenApiParameter.QUERY, type=OpenApiTypes.DATETIME, required=False, - description="Filter events after the datetime. If no 'from' or 'to' parameters are passed, the last 30 days will be set."), - OpenApiParameter('to', location=OpenApiParameter.QUERY, type=OpenApiTypes.DATETIME, required=False, - description="Filter events before the datetime. If no 'from' or 'to' parameters are passed, the last 30 days will be set."), - OpenApiParameter('filename', description='Desired output file name', - location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), - OpenApiParameter('action', location=OpenApiParameter.QUERY, - description='Used to start downloading process after annotation file had been created', - type=OpenApiTypes.STR, required=False, enum=['download']), - OpenApiParameter('query_id', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - description="ID of query request that need to check or download"), + OpenApiParameter( + "org_id", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + required=False, + description="Filter events by organization ID", + ), + OpenApiParameter( + "project_id", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + required=False, + description="Filter events by project ID", + ), + OpenApiParameter( + "task_id", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + required=False, + description="Filter events by task ID", + ), + OpenApiParameter( + "job_id", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + required=False, + description="Filter events by job ID", + ), + OpenApiParameter( + "user_id", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.INT, + required=False, + description="Filter events by user ID", + ), + OpenApiParameter( + "from", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.DATETIME, + required=False, + description="Filter events after the datetime. If no 'from' or 'to' parameters are passed, the last 30 days will be set.", + ), + OpenApiParameter( + "to", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.DATETIME, + required=False, + description="Filter events before the datetime. If no 'from' or 'to' parameters are passed, the last 30 days will be set.", + ), + OpenApiParameter( + "filename", + description="Desired output file name", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + required=False, + ), + OpenApiParameter( + "action", + location=OpenApiParameter.QUERY, + description="Used to start downloading process after annotation file had been created", + type=OpenApiTypes.STR, + required=False, + enum=["download"], + ), + OpenApiParameter( + "query_id", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + required=False, + description="ID of query request that need to check or download", + ), ], responses={ - '200': OpenApiResponse(description='Download of file started'), - '201': OpenApiResponse(description='CSV log file is ready for downloading'), - '202': OpenApiResponse(description='Creating a CSV log file has been started'), - }) + "200": OpenApiResponse(description="Download of file started"), + "201": OpenApiResponse(description="CSV log file is ready for downloading"), + "202": OpenApiResponse(description="Creating a CSV log file has been started"), + }, + ) def list(self, request): perm = EventsPermission.create_scope_list(request) filter_query = perm.filter(request.query_params) diff --git a/cvat/apps/health/apps.py b/cvat/apps/health/apps.py index a457048b87c9..ae38010ff7b2 100644 --- a/cvat/apps/health/apps.py +++ b/cvat/apps/health/apps.py @@ -3,12 +3,13 @@ # SPDX-License-Identifier: MIT from django.apps import AppConfig - from health_check.plugins import plugin_dir + class HealthConfig(AppConfig): - name = 'cvat.apps.health' + name = "cvat.apps.health" def ready(self): from .backends import OPAHealthCheck + plugin_dir.register(OPAHealthCheck) diff --git a/cvat/apps/health/backends.py b/cvat/apps/health/backends.py index 2f361117173a..0ba37cb23195 100644 --- a/cvat/apps/health/backends.py +++ b/cvat/apps/health/backends.py @@ -3,19 +3,18 @@ # SPDX-License-Identifier: MIT import requests - +from django.conf import settings from health_check.backends import BaseHealthCheckBackend from health_check.exceptions import HealthCheckException -from django.conf import settings - from cvat.utils.http import make_requests_session + class OPAHealthCheck(BaseHealthCheckBackend): critical_service = True def check_status(self): - opa_health_url = f'{settings.IAM_OPA_HOST}/health?bundles' + opa_health_url = f"{settings.IAM_OPA_HOST}/health?bundles" try: with make_requests_session() as session: response = session.get(opa_health_url) diff --git a/cvat/apps/health/management/commands/workerprobe.py b/cvat/apps/health/management/commands/workerprobe.py index fc8b6cf7077a..af9d663a1a29 100644 --- a/cvat/apps/health/management/commands/workerprobe.py +++ b/cvat/apps/health/management/commands/workerprobe.py @@ -1,10 +1,11 @@ import os import platform from datetime import datetime, timedelta -from django.core.management.base import BaseCommand, CommandError + +import django_rq from django.conf import settings +from django.core.management.base import BaseCommand, CommandError from rq.worker import Worker -import django_rq class Command(BaseCommand): @@ -20,13 +21,21 @@ def handle(self, *args, **options): raise CommandError(f"Queue {queue_name} is not defined") connection = django_rq.get_connection(queue_name) - workers = [w for w in Worker.all(connection) if queue_name in w.queue_names() and w.hostname == hostname] + workers = [ + w + for w in Worker.all(connection) + if queue_name in w.queue_names() and w.hostname == hostname + ] expected_workers = int(os.getenv("NUMPROCS", 1)) if len(workers) != expected_workers: - raise CommandError("Number of registered workers does not match the expected number, " \ - f"actual: {len(workers)}, expected: {expected_workers}") + raise CommandError( + "Number of registered workers does not match the expected number, " + f"actual: {len(workers)}, expected: {expected_workers}" + ) for worker in workers: if datetime.now() - worker.last_heartbeat > timedelta(seconds=worker.worker_ttl): - raise CommandError(f"It seems that worker {worker.name}, pid: {worker.pid} is dead") + raise CommandError( + f"It seems that worker {worker.name}, pid: {worker.pid} is dead" + ) diff --git a/cvat/apps/iam/adapters.py b/cvat/apps/iam/adapters.py index 703bec48743f..50ff2812c3a5 100644 --- a/cvat/apps/iam/adapters.py +++ b/cvat/apps/iam/adapters.py @@ -2,10 +2,10 @@ # # SPDX-License-Identifier: MIT -from django.http import HttpResponseRedirect +from allauth.account.adapter import DefaultAccountAdapter from django.conf import settings +from django.http import HttpResponseRedirect -from allauth.account.adapter import DefaultAccountAdapter class DefaultAccountAdapterEx(DefaultAccountAdapter): def respond_email_verification_sent(self, request, user): diff --git a/cvat/apps/iam/admin.py b/cvat/apps/iam/admin.py index 648e15dc2da4..bf6efafe9a34 100644 --- a/cvat/apps/iam/admin.py +++ b/cvat/apps/iam/admin.py @@ -4,8 +4,8 @@ # SPDX-License-Identifier: MIT from django.contrib import admin -from django.contrib.auth.models import Group, User from django.contrib.auth.admin import GroupAdmin, UserAdmin +from django.contrib.auth.models import Group, User from django.utils.translation import gettext_lazy as _ from cvat.apps.engine.models import Profile @@ -14,20 +14,27 @@ class ProfileInline(admin.StackedInline): model = Profile - fieldsets = ( - (None, {'fields': ('has_analytics_access', )}), - ) + fieldsets = ((None, {"fields": ("has_analytics_access",)}),) class CustomUserAdmin(UserAdmin): inlines = (ProfileInline,) list_display = ("username", "email", "first_name", "last_name", "is_active", "is_staff") fieldsets = ( - (None, {'fields': ('username', 'password')}), - (_('Personal info'), {'fields': ('first_name', 'last_name', 'email')}), - (_('Permissions'), {'fields': ('is_active', 'is_staff', 'is_superuser', - 'groups',)}), - (_('Important dates'), {'fields': ('last_login', 'date_joined')}), + (None, {"fields": ("username", "password")}), + (_("Personal info"), {"fields": ("first_name", "last_name", "email")}), + ( + _("Permissions"), + { + "fields": ( + "is_active", + "is_staff", + "is_superuser", + "groups", + ) + }, + ), + (_("Important dates"), {"fields": ("last_login", "date_joined")}), ) add_fieldsets = ( ( @@ -40,21 +47,17 @@ class CustomUserAdmin(UserAdmin): ) actions = ["user_activate", "user_deactivate"] - @admin.action( - permissions=["change"], description=_("Mark selected users as active") - ) + @admin.action(permissions=["change"], description=_("Mark selected users as active")) def user_activate(self, request, queryset): queryset.update(is_active=True) - @admin.action( - permissions=["change"], description=_("Mark selected users as not active") - ) + @admin.action(permissions=["change"], description=_("Mark selected users as not active")) def user_deactivate(self, request, queryset): queryset.update(is_active=False) class CustomGroupAdmin(GroupAdmin): - fieldsets = ((None, {'fields': ('name',)}),) + fieldsets = ((None, {"fields": ("name",)}),) admin.site.unregister(User) diff --git a/cvat/apps/iam/apps.py b/cvat/apps/iam/apps.py index 97bdc3ca05fd..00f051a75c08 100644 --- a/cvat/apps/iam/apps.py +++ b/cvat/apps/iam/apps.py @@ -5,9 +5,11 @@ from django.apps import AppConfig + class IAMConfig(AppConfig): - name = 'cvat.apps.iam' + name = "cvat.apps.iam" def ready(self): from .signals import register_signals + register_signals(self) diff --git a/cvat/apps/iam/authentication.py b/cvat/apps/iam/authentication.py index 412806380389..74ec6f5424b7 100644 --- a/cvat/apps/iam/authentication.py +++ b/cvat/apps/iam/authentication.py @@ -2,21 +2,23 @@ # # SPDX-License-Identifier: MIT +import hashlib + +from django.contrib.auth import get_user_model from django.core import signing +from furl import furl from rest_framework import exceptions from rest_framework.authentication import BaseAuthentication -from django.contrib.auth import get_user_model -from furl import furl -import hashlib + # Got implementation ideas in https://github.com/marcgibbons/drf_signed_auth class Signer: - QUERY_PARAM = 'sign' + QUERY_PARAM = "sign" MAX_AGE = 30 @classmethod def get_salt(cls, url): - normalized_url = furl(url).remove(cls.QUERY_PARAM).url.encode('utf-8') + normalized_url = furl(url).remove(cls.QUERY_PARAM).url.encode("utf-8") salt = hashlib.sha256(normalized_url).hexdigest() return salt @@ -24,10 +26,7 @@ def sign(self, user, url): """ Create a signature for a user object. """ - data = { - 'user_id': user.pk, - 'username': user.get_username() - } + data = {"user_id": user.pk, "username": user.get_username()} return signing.dumps(data, salt=self.get_salt(url)) @@ -36,24 +35,24 @@ def unsign(self, signature, url): Return a user object for a valid signature. """ User = get_user_model() - data = signing.loads(signature, salt=self.get_salt(url), - max_age=self.MAX_AGE) + data = signing.loads(signature, salt=self.get_salt(url), max_age=self.MAX_AGE) if not isinstance(data, dict): raise signing.BadSignature() try: - return User.objects.get(**{ - 'pk': data.get('user_id'), - User.USERNAME_FIELD: data.get('username') - }) + return User.objects.get( + **{"pk": data.get("user_id"), User.USERNAME_FIELD: data.get("username")} + ) except User.DoesNotExist: raise signing.BadSignature() + class SignatureAuthentication(BaseAuthentication): """ Authentication backend for signed URLs. """ + def authenticate(self, request): """ Returns authenticated user if URL signature is valid. @@ -66,10 +65,10 @@ def authenticate(self, request): try: user = signer.unsign(sign, request.build_absolute_uri()) except signing.SignatureExpired: - raise exceptions.AuthenticationFailed('This URL has expired.') + raise exceptions.AuthenticationFailed("This URL has expired.") except signing.BadSignature: - raise exceptions.AuthenticationFailed('Invalid signature.') + raise exceptions.AuthenticationFailed("Invalid signature.") if not user.is_active: - raise exceptions.AuthenticationFailed('User inactive or deleted.') + raise exceptions.AuthenticationFailed("User inactive or deleted.") return (user, None) diff --git a/cvat/apps/iam/filters.py b/cvat/apps/iam/filters.py index 6fd62d8d05e5..c99da171bac7 100644 --- a/cvat/apps/iam/filters.py +++ b/cvat/apps/iam/filters.py @@ -2,29 +2,29 @@ # # SPDX-License-Identifier: MIT -from rest_framework.filters import BaseFilterBackend -from django.db.models import Q from collections.abc import Iterable +from django.db.models import Q from drf_spectacular.utils import OpenApiParameter +from rest_framework.filters import BaseFilterBackend ORGANIZATION_OPEN_API_PARAMETERS = [ OpenApiParameter( - name='org', + name="org", type=str, required=False, location=OpenApiParameter.QUERY, description="Organization unique slug", ), OpenApiParameter( - name='org_id', + name="org_id", type=int, required=False, location=OpenApiParameter.QUERY, description="Organization identifier", ), OpenApiParameter( - name='X-Organization', + name="X-Organization", type=str, required=False, location=OpenApiParameter.HEADER, @@ -32,13 +32,14 @@ ), ] + class OrganizationFilterBackend(BaseFilterBackend): def _parameter_is_provided(self, request): for parameter in ORGANIZATION_OPEN_API_PARAMETERS: - if parameter.location == 'header' and parameter.name in request.headers: + if parameter.location == "header" and parameter.name in request.headers: return True - elif parameter.location == 'query' and parameter.name in request.query_params: + elif parameter.location == "query" and parameter.name in request.query_params: return True return False @@ -62,34 +63,35 @@ def _construct_filter_query(self, organization_fields, org_id): return Q() - def filter_queryset(self, request, queryset, view): # Filter works only for "list" requests and allows to return # only non-organization objects if org isn't specified if ( - view.detail or not view.iam_organization_field or + view.detail + or not view.iam_organization_field + or # FIXME: It should be handled in another way. For example, if we try to get information for a specific job # and org isn't specified, we need to return the full list of labels, issues, comments. # Allow crowdsourcing users to get labels/issues/comments related to specific job. # Crowdsourcing user always has worker group and isn't a member of an organization. ( - view.__class__.__name__ in ('LabelViewSet', 'IssueViewSet', 'CommentViewSet') and - request.query_params.get('job_id') and - request.iam_context.get('organization') is None and - request.iam_context.get('privilege') == 'worker' + view.__class__.__name__ in ("LabelViewSet", "IssueViewSet", "CommentViewSet") + and request.query_params.get("job_id") + and request.iam_context.get("organization") is None + and request.iam_context.get("privilege") == "worker" ) ): return queryset visibility = None - org = request.iam_context['organization'] + org = request.iam_context["organization"] if org: - visibility = {'organization': org.id} + visibility = {"organization": org.id} elif not org and self._parameter_is_provided(request): - visibility = {'organization': None} + visibility = {"organization": None} if visibility: org_id = visibility.pop("organization") @@ -108,15 +110,17 @@ def get_schema_operation_parameters(self, view): parameter_type = None if parameter.type == int: - parameter_type = 'integer' + parameter_type = "integer" elif parameter.type == str: - parameter_type = 'string' - - parameters.append({ - 'name': parameter.name, - 'in': parameter.location, - 'description': parameter.description, - 'schema': {'type': parameter_type} - }) + parameter_type = "string" + + parameters.append( + { + "name": parameter.name, + "in": parameter.location, + "description": parameter.description, + "schema": {"type": parameter_type}, + } + ) return parameters diff --git a/cvat/apps/iam/forms.py b/cvat/apps/iam/forms.py index c1668b924387..af619a563f38 100644 --- a/cvat/apps/iam/forms.py +++ b/cvat/apps/iam/forms.py @@ -2,22 +2,27 @@ # # SPDX-License-Identifier: MIT -from django.contrib.sites.shortcuts import get_current_site -from django.contrib.auth import get_user_model - +from allauth.account.adapter import get_adapter from allauth.account.forms import default_token_generator from allauth.account.utils import user_pk_to_url_str -from allauth.account.adapter import get_adapter from dj_rest_auth.forms import AllAuthPasswordResetForm +from django.contrib.auth import get_user_model +from django.contrib.sites.shortcuts import get_current_site UserModel = get_user_model() -class ResetPasswordFormEx(AllAuthPasswordResetForm): - def save(self, request=None, domain_override=None, - email_template_prefix='authentication/password_reset_key', - use_https=False, token_generator=default_token_generator, - extra_email_context=None, **kwargs): +class ResetPasswordFormEx(AllAuthPasswordResetForm): + def save( + self, + request=None, + domain_override=None, + email_template_prefix="authentication/password_reset_key", + use_https=False, + token_generator=default_token_generator, + extra_email_context=None, + **kwargs, + ): """ Generate a one-use only link for resetting password and send it to the user. @@ -33,16 +38,16 @@ def save(self, request=None, domain_override=None, for user in self.users: user_email = getattr(user, email_field_name) context = { - 'email': user_email, - 'domain': domain, - 'site_name': site_name, - 'uid': user_pk_to_url_str(user), - 'user': user, - 'token': token_generator.make_token(user), - 'protocol': 'https' if use_https else 'http', + "email": user_email, + "domain": domain, + "site_name": site_name, + "uid": user_pk_to_url_str(user), + "user": user, + "token": token_generator.make_token(user), + "protocol": "https" if use_https else "http", **(extra_email_context or {}), } get_adapter(request).send_mail(email_template_prefix, email, context) - return self.cleaned_data['email'] + return self.cleaned_data["email"] diff --git a/cvat/apps/iam/middleware.py b/cvat/apps/iam/middleware.py index f2f1a4bae2e0..c09c5eeb96b6 100644 --- a/cvat/apps/iam/middleware.py +++ b/cvat/apps/iam/middleware.py @@ -5,10 +5,10 @@ from datetime import timedelta from typing import Callable -from django.utils.functional import SimpleLazyObject -from rest_framework.exceptions import ValidationError, NotFound from django.conf import settings from django.http import HttpRequest, HttpResponse +from django.utils.functional import SimpleLazyObject +from rest_framework.exceptions import NotFound, ValidationError def get_organization(request): @@ -22,31 +22,32 @@ def get_organization(request): organization = None try: - org_slug = request.GET.get('org') - org_id = request.GET.get('org_id') - org_header = request.headers.get('X-Organization') + org_slug = request.GET.get("org") + org_id = request.GET.get("org_id") + org_header = request.headers.get("X-Organization") if org_id is not None and (org_slug is not None or org_header is not None): - raise ValidationError('You cannot specify "org_id" query parameter with ' - '"org" query parameter or "X-Organization" HTTP header at the same time.') + raise ValidationError( + 'You cannot specify "org_id" query parameter with ' + '"org" query parameter or "X-Organization" HTTP header at the same time.' + ) if org_slug is not None and org_header is not None and org_slug != org_header: - raise ValidationError('You cannot specify "org" query parameter and ' - '"X-Organization" HTTP header with different values.') + raise ValidationError( + 'You cannot specify "org" query parameter and ' + '"X-Organization" HTTP header with different values.' + ) org_slug = org_slug if org_slug is not None else org_header if org_slug: - organization = Organization.objects.select_related('owner').get(slug=org_slug) + organization = Organization.objects.select_related("owner").get(slug=org_slug) elif org_id: - organization = Organization.objects.select_related('owner').get(id=int(org_id)) + organization = Organization.objects.select_related("owner").get(id=int(org_id)) except Organization.DoesNotExist: - raise NotFound(f'{org_slug or org_id} organization does not exist.') + raise NotFound(f"{org_slug or org_id} organization does not exist.") - context = { - "organization": organization, - "privilege": getattr(privilege, 'name', None) - } + context = {"organization": organization, "privilege": getattr(privilege, "name", None)} return context @@ -62,6 +63,7 @@ def __call__(self, request): return self.get_response(request) + class SessionRefreshMiddleware: """ Implements behavior similar to SESSION_SAVE_EVERY_REQUEST=True, but instead of diff --git a/cvat/apps/iam/migrations/0001_remove_business_group.py b/cvat/apps/iam/migrations/0001_remove_business_group.py index 2bf1a56b4065..aa64d4a56d6d 100644 --- a/cvat/apps/iam/migrations/0001_remove_business_group.py +++ b/cvat/apps/iam/migrations/0001_remove_business_group.py @@ -2,13 +2,12 @@ from django.conf import settings from django.db import migrations - BUSINESS_GROUP_NAME = "business" USER_GROUP_NAME = "user" def delete_business_group(apps, schema_editor): - Group = apps.get_model('auth', 'Group') + Group = apps.get_model("auth", "Group") User = apps.get_model(settings.AUTH_USER_MODEL) if user_group := Group.objects.filter(name=USER_GROUP_NAME).first(): diff --git a/cvat/apps/iam/models.py b/cvat/apps/iam/models.py index b1220197cf2a..f7c3408e3d12 100644 --- a/cvat/apps/iam/models.py +++ b/cvat/apps/iam/models.py @@ -1,4 +1,3 @@ # Copyright (C) 2021-2022 Intel Corporation # # SPDX-License-Identifier: MIT - diff --git a/cvat/apps/iam/permissions.py b/cvat/apps/iam/permissions.py index d4925426724a..f13d6be377ce 100644 --- a/cvat/apps/iam/permissions.py +++ b/cvat/apps/iam/permissions.py @@ -44,21 +44,21 @@ def get_organization(request, obj): if obj: try: - organization_id = getattr(obj, 'organization_id') + organization_id = getattr(obj, "organization_id") except AttributeError as exc: # Skip initialization of organization for those objects that don't related with organization - view = request.parser_context.get('view') + view = request.parser_context.get("view") if view and view.basename in settings.OBJECTS_NOT_RELATED_WITH_ORG: - return request.iam_context['organization'] + return request.iam_context["organization"] raise exc try: - return Organization.objects.select_related('owner').get(id=organization_id) + return Organization.objects.select_related("owner").get(id=organization_id) except Organization.DoesNotExist: return None - return request.iam_context['organization'] + return request.iam_context["organization"] def get_membership(request, organization): @@ -66,21 +66,20 @@ def get_membership(request, organization): return None return Membership.objects.filter( - organization=organization, - user=request.user, - is_active=True + organization=organization, user=request.user, is_active=True ).first() -def build_iam_context(request, organization: Optional[Organization], membership: Optional[Membership]): +def build_iam_context( + request, organization: Optional[Organization], membership: Optional[Membership] +): return { - 'user_id': request.user.id, - 'group_name': request.iam_context['privilege'], - 'org_id': getattr(organization, 'id', None), - 'org_slug': getattr(organization, 'slug', None), - 'org_owner_id': getattr(organization.owner, 'id', None) - if organization else None, - 'org_role': getattr(membership, 'role', None), + "user_id": request.user.id, + "group_name": request.iam_context["privilege"], + "org_id": getattr(organization, "id", None), + "org_slug": getattr(organization, "slug", None), + "org_owner_id": getattr(organization.owner, "id", None) if organization else None, + "org_role": getattr(membership, "role", None), } @@ -103,23 +102,19 @@ class OpenPolicyAgentPermission(metaclass=ABCMeta): @classmethod @abstractmethod - def create(cls, request, view, obj, iam_context) -> Sequence[OpenPolicyAgentPermission]: - ... + def create(cls, request, view, obj, iam_context) -> Sequence[OpenPolicyAgentPermission]: ... @classmethod def create_base_perm(cls, request, view, scope, iam_context, obj=None, **kwargs): if not iam_context and request: iam_context = get_iam_context(request, obj) - return cls( - scope=scope, - obj=obj, - **iam_context, **kwargs) + return cls(scope=scope, obj=obj, **iam_context, **kwargs) @classmethod def create_scope_list(cls, request, iam_context=None): if not iam_context and request: iam_context = get_iam_context(request, None) - return cls(**iam_context, scope='list') + return cls(**iam_context, scope="list") def __init__(self, **kwargs): self.obj = None @@ -127,27 +122,31 @@ def __init__(self, **kwargs): setattr(self, name, val) self.payload = { - 'input': { - 'scope': self.scope, - 'auth': { - 'user': { - 'id': self.user_id, - 'privilege': self.group_name, + "input": { + "scope": self.scope, + "auth": { + "user": { + "id": self.user_id, + "privilege": self.group_name, }, - 'organization': { - 'id': self.org_id, - 'owner': { - 'id': self.org_owner_id, - }, - 'user': { - 'role': self.org_role, - }, - } if self.org_id is not None else None - } + "organization": ( + { + "id": self.org_id, + "owner": { + "id": self.org_owner_id, + }, + "user": { + "role": self.org_role, + }, + } + if self.org_id is not None + else None + ), + }, } } - self.payload['input']['resource'] = self.get_resource() + self.payload["input"]["resource"] = self.get_resource() @abstractmethod def get_resource(self): @@ -156,13 +155,13 @@ def get_resource(self): def check_access(self) -> PermissionResult: with make_requests_session() as session: response = session.post(self.url, json=self.payload) - output = response.json()['result'] + output = response.json()["result"] allow = False reasons = [] if isinstance(output, dict): - allow = output['allow'] - reasons = output.get('reasons', []) + allow = output["allow"] + reasons = output.get("reasons", []) elif isinstance(output, bool): allow = output else: @@ -171,21 +170,21 @@ def check_access(self) -> PermissionResult: return PermissionResult(allow=allow, reasons=reasons) def filter(self, queryset): - url = self.url.replace('/allow', '/filter') + url = self.url.replace("/allow", "/filter") with make_requests_session() as session: - r = session.post(url, json=self.payload).json()['result'] + r = session.post(url, json=self.payload).json()["result"] q_objects = [] ops_dict = { - '|': operator.or_, - '&': operator.and_, - '~': operator.not_, + "|": operator.or_, + "&": operator.and_, + "~": operator.not_, } for item in r: if isinstance(item, str): val1 = q_objects.pop() - if item == '~': + if item == "~": q_objects.append(ops_dict[item](val1)) else: val2 = q_objects.pop() @@ -211,7 +210,7 @@ def get_per_field_update_scopes(cls, request, scopes_per_field): request body fields are associated with different scopes. """ - assert request.method == 'PATCH' + assert request.method == "PATCH" # Even if no fields are modified, a PATCH request typically returns the # new state of the object, so we need to make sure the user has permissions @@ -226,7 +225,7 @@ def get_per_field_update_scopes(cls, request, scopes_per_field): return scopes -T = TypeVar('T', bound=Model) +T = TypeVar("T", bound=Model) def is_public_obj(obj: T) -> bool: @@ -257,22 +256,23 @@ def has_permission(self, request, view): if not view.detail: return self.check_permission(request, view, None) else: - return True # has_object_permission will be called later + return True # has_object_permission will be called later def has_object_permission(self, request, view, obj): return self.check_permission(request, view, obj) @staticmethod def is_metadata_request(request, view): - return request.method == 'OPTIONS' \ - or (request.method == 'POST' and view.action == 'metadata' and len(request.data) == 0) + return request.method == "OPTIONS" or ( + request.method == "POST" and view.action == "metadata" and len(request.data) == 0 + ) class IsAuthenticatedOrReadPublicResource(BasePermission): def has_object_permission(self, request, view, obj) -> bool: return bool( - (request.user and request.user.is_authenticated) or - (request.method == 'GET' and is_public_obj(obj)) + (request.user and request.user.is_authenticated) + or (request.method == "GET" and is_public_obj(obj)) ) diff --git a/cvat/apps/iam/rules/tests/generate_tests.py b/cvat/apps/iam/rules/tests/generate_tests.py index 729de6732eb2..92b4a0e699a9 100755 --- a/cvat/apps/iam/rules/tests/generate_tests.py +++ b/cvat/apps/iam/rules/tests/generate_tests.py @@ -10,11 +10,12 @@ from collections.abc import Sequence from concurrent.futures import ThreadPoolExecutor from functools import partial -from typing import Optional from pathlib import Path +from typing import Optional REPO_ROOT = Path(__file__).resolve().parents[5] + def create_arg_parser() -> ArgumentParser: parser = ArgumentParser(add_help=True) parser.add_argument( @@ -36,7 +37,7 @@ def parse_args(args: Optional[Sequence[str]] = None) -> Namespace: def call_generator(generator_path: Path, gen_params: Namespace) -> None: rules_dir = generator_path.parents[2] subprocess.check_call( - [sys.executable, generator_path.relative_to(rules_dir), 'tests/configs'], cwd=rules_dir + [sys.executable, generator_path.relative_to(rules_dir), "tests/configs"], cwd=rules_dir ) @@ -53,7 +54,7 @@ def main(args: Optional[Sequence[str]] = None) -> int: partial(call_generator, gen_params=args), generator_paths, ): - pass # consume all results in order to propagate exceptions + pass # consume all results in order to propagate exceptions if __name__ == "__main__": diff --git a/cvat/apps/iam/schema.py b/cvat/apps/iam/schema.py index 46f9e31052c1..7f54c5597a95 100644 --- a/cvat/apps/iam/schema.py +++ b/cvat/apps/iam/schema.py @@ -9,7 +9,6 @@ from drf_spectacular.authentication import SessionScheme, TokenScheme from drf_spectacular.extensions import OpenApiAuthenticationExtension from drf_spectacular.openapi import AutoSchema - from rest_framework import serializers @@ -18,35 +17,37 @@ class SignatureAuthenticationScheme(OpenApiAuthenticationExtension): Adds the signature auth method to schema """ - target_class = 'cvat.apps.iam.authentication.SignatureAuthentication' - name = 'signatureAuth' # name used in the schema + target_class = "cvat.apps.iam.authentication.SignatureAuthentication" + name = "signatureAuth" # name used in the schema def get_security_definition(self, auto_schema): return { - 'type': 'apiKey', - 'in': 'query', - 'name': 'sign', - 'description': 'Can be used to share URLs to private links', + "type": "apiKey", + "in": "query", + "name": "sign", + "description": "Can be used to share URLs to private links", } + class TokenAuthenticationScheme(TokenScheme): """ Adds the token auth method to schema. The description includes extra info comparing to what is generated by default. """ - name = 'tokenAuth' + name = "tokenAuth" priority = 0 match_subclasses = True def get_security_requirement(self, auto_schema): # These schemes must be used together - return {'sessionAuth': [], 'csrfAuth': [], self.name: []} + return {"sessionAuth": [], "csrfAuth": [], self.name: []} def get_security_definition(self, auto_schema): schema = super().get_security_definition(auto_schema) - schema['x-token-prefix'] = self.target.keyword - schema['description'] = textwrap.dedent(f""" + schema["x-token-prefix"] = self.target.keyword + schema["description"] = textwrap.dedent( + f""" To authenticate using a token (or API key), you need to have 3 components in a request: - the 'sessionid' cookie - the 'csrftoken' cookie or 'X-CSRFTOKEN' header @@ -54,16 +55,18 @@ def get_security_definition(self, auto_schema): You can obtain an API key (the token) from the server response on the basic auth request. - """) + """ + ) return schema + class CookieAuthenticationScheme(SessionScheme): """ This class adds csrftoken cookie into security sections. It must be used together with the 'sessionid' cookie. """ - name = ['sessionAuth', 'csrfAuth'] + name = ["sessionAuth", "csrfAuth"] priority = 0 def get_security_requirement(self, auto_schema): @@ -73,13 +76,14 @@ def get_security_requirement(self, auto_schema): def get_security_definition(self, auto_schema): sessionid_schema = super().get_security_definition(auto_schema) csrftoken_schema = { - 'type': 'apiKey', - 'in': 'cookie', - 'name': 'csrftoken', - 'description': 'Can be sent as a cookie or as the X-CSRFTOKEN header' + "type": "apiKey", + "in": "cookie", + "name": "csrftoken", + "description": "Can be sent as a cookie or as the X-CSRFTOKEN header", } return [sessionid_schema, csrftoken_schema] + class CustomAutoSchema(AutoSchema): def get_operation_id(self): # Change style of operation ids to [viewset _ action _ object] @@ -87,20 +91,20 @@ def get_operation_id(self): tokenized_path = self._tokenize_path() # replace dashes as they can be problematic later in code generation - tokenized_path = [t.replace('-', '_') for t in tokenized_path] + tokenized_path = [t.replace("-", "_") for t in tokenized_path] - if self.method == 'GET' and self._is_list_view(): - action = 'list' + if self.method == "GET" and self._is_list_view(): + action = "list" else: action = self.method_mapping[self.method.lower()] if not tokenized_path: - tokenized_path.append('root') + tokenized_path.append("root") - if re.search(r'', self.path_regex): - tokenized_path.append('formatted') + if re.search(r"", self.path_regex): + tokenized_path.append("formatted") - return '_'.join([tokenized_path[0]] + [action] + tokenized_path[1:]) + return "_".join([tokenized_path[0]] + [action] + tokenized_path[1:]) def _get_request_for_media_type(self, serializer, *args, **kwargs): # Enables support for required=False serializers in request body specification diff --git a/cvat/apps/iam/serializers.py b/cvat/apps/iam/serializers.py index 967b696a4f21..7de9919e3ab3 100644 --- a/cvat/apps/iam/serializers.py +++ b/cvat/apps/iam/serializers.py @@ -3,23 +3,21 @@ # # SPDX-License-Identifier: MIT -from dj_rest_auth.registration.serializers import RegisterSerializer -from dj_rest_auth.serializers import PasswordResetSerializer, LoginSerializer -from django.core.exceptions import ValidationError as DjangoValidationError -from rest_framework.exceptions import ValidationError -from rest_framework import serializers +from typing import Optional, Union + from allauth.account import app_settings as allauth_settings -from allauth.account.utils import filter_users_by_email from allauth.account.adapter import get_adapter -from allauth.account.utils import setup_user_email from allauth.account.models import EmailAddress - +from allauth.account.utils import filter_users_by_email, setup_user_email +from dj_rest_auth.registration.serializers import RegisterSerializer +from dj_rest_auth.serializers import LoginSerializer, PasswordResetSerializer from django.conf import settings from django.contrib.auth import get_user_model from django.contrib.auth.models import User - +from django.core.exceptions import ValidationError as DjangoValidationError from drf_spectacular.utils import extend_schema_field -from typing import Optional, Union +from rest_framework import serializers +from rest_framework.exceptions import ValidationError from cvat.apps.iam.forms import ResetPasswordFormEx from cvat.apps.iam.utils import get_dummy_user @@ -33,22 +31,30 @@ class RegisterSerializerEx(RegisterSerializer): @extend_schema_field(serializers.BooleanField) def get_email_verification_required(self, obj: Union[dict, User]) -> bool: - return allauth_settings.EMAIL_VERIFICATION == allauth_settings.EmailVerificationMethod.MANDATORY + return ( + allauth_settings.EMAIL_VERIFICATION + == allauth_settings.EmailVerificationMethod.MANDATORY + ) @extend_schema_field(serializers.CharField(allow_null=True)) def get_key(self, obj: Union[dict, User]) -> Optional[str]: key = None - if isinstance(obj, User) and allauth_settings.EMAIL_VERIFICATION != \ - allauth_settings.EmailVerificationMethod.MANDATORY: + if ( + isinstance(obj, User) + and allauth_settings.EMAIL_VERIFICATION + != allauth_settings.EmailVerificationMethod.MANDATORY + ): key = obj.auth_token.key return key def get_cleaned_data(self): data = super().get_cleaned_data() - data.update({ - 'first_name': self.validated_data.get('first_name', ''), - 'last_name': self.validated_data.get('last_name', ''), - }) + data.update( + { + "first_name": self.validated_data.get("first_name", ""), + "last_name": self.validated_data.get("last_name", ""), + } + ) return data @@ -57,7 +63,7 @@ def email_address_exists(email) -> bool: if EmailAddress.objects.filter(email__iexact=email).exists(): return True - if (email_field := allauth_settings.USER_MODEL_EMAIL_FIELD): + if email_field := allauth_settings.USER_MODEL_EMAIL_FIELD: users = get_user_model().objects return users.filter(**{email_field + "__iexact": email}).exists() return False @@ -68,7 +74,7 @@ def email_address_exists(email) -> bool: user = get_dummy_user(email) if not user: raise serializers.ValidationError( - ('A user is already registered with this e-mail address.'), + ("A user is already registered with this e-mail address."), ) return email @@ -84,11 +90,9 @@ def save(self, request): user = adapter.save_user(request, user, self, commit=False) if "password1" in self.cleaned_data: try: - adapter.clean_password(self.cleaned_data['password1'], user=user) + adapter.clean_password(self.cleaned_data["password1"], user=user) except DjangoValidationError as exc: - raise serializers.ValidationError( - detail=serializers.as_serializer_error(exc) - ) + raise serializers.ValidationError(detail=serializers.as_serializer_error(exc)) user.save() self.custom_signup(request, user) @@ -104,35 +108,42 @@ def password_reset_form_class(self): def get_email_options(self): domain = None - if hasattr(settings, 'UI_HOST') and settings.UI_HOST: + if hasattr(settings, "UI_HOST") and settings.UI_HOST: domain = settings.UI_HOST - if hasattr(settings, 'UI_PORT') and settings.UI_PORT: - domain += ':{}'.format(settings.UI_PORT) - return { - 'domain_override': domain - } + if hasattr(settings, "UI_PORT") and settings.UI_PORT: + domain += ":{}".format(settings.UI_PORT) + return {"domain_override": domain} + class LoginSerializerEx(LoginSerializer): def get_auth_user_using_allauth(self, username, email, password): def is_email_authentication(): - return settings.ACCOUNT_AUTHENTICATION_METHOD == allauth_settings.AuthenticationMethod.EMAIL + return ( + settings.ACCOUNT_AUTHENTICATION_METHOD + == allauth_settings.AuthenticationMethod.EMAIL + ) def is_username_authentication(): - return settings.ACCOUNT_AUTHENTICATION_METHOD == allauth_settings.AuthenticationMethod.USERNAME + return ( + settings.ACCOUNT_AUTHENTICATION_METHOD + == allauth_settings.AuthenticationMethod.USERNAME + ) # check that the server settings match the request if is_username_authentication() and not username and email: raise ValidationError( - 'Attempt to authenticate with email/password. ' - 'But username/password are used for authentication on the server. ' - 'Please check your server configuration ACCOUNT_AUTHENTICATION_METHOD.') + "Attempt to authenticate with email/password. " + "But username/password are used for authentication on the server. " + "Please check your server configuration ACCOUNT_AUTHENTICATION_METHOD." + ) if is_email_authentication() and not email and username: raise ValidationError( - 'Attempt to authenticate with username/password. ' - 'But email/password are used for authentication on the server. ' - 'Please check your server configuration ACCOUNT_AUTHENTICATION_METHOD.') + "Attempt to authenticate with username/password. " + "But email/password are used for authentication on the server. " + "Please check your server configuration ACCOUNT_AUTHENTICATION_METHOD." + ) # Authentication through email if settings.ACCOUNT_AUTHENTICATION_METHOD == allauth_settings.AuthenticationMethod.EMAIL: @@ -146,6 +157,6 @@ def is_username_authentication(): if email: users = filter_users_by_email(email) if not users or len(users) > 1: - raise ValidationError('Unable to login with provided credentials') + raise ValidationError("Unable to login with provided credentials") return self._validate_username_email(username, email, password) diff --git a/cvat/apps/iam/signals.py b/cvat/apps/iam/signals.py index 73f919a1a4a4..b8bbf643dab8 100644 --- a/cvat/apps/iam/signals.py +++ b/cvat/apps/iam/signals.py @@ -3,8 +3,8 @@ # SPDX-License-Identifier: MIT from django.conf import settings -from django.contrib.auth.models import User, Group -from django.db.models.signals import post_save, post_migrate +from django.contrib.auth.models import Group, User +from django.db.models.signals import post_migrate, post_save def register_groups(sender, **kwargs): @@ -12,7 +12,9 @@ def register_groups(sender, **kwargs): for role in settings.IAM_ROLES: Group.objects.get_or_create(name=role) -if settings.IAM_TYPE == 'BASIC': + +if settings.IAM_TYPE == "BASIC": + def create_user(sender, instance, created, **kwargs): from allauth.account import app_settings as allauth_settings from allauth.account.models import EmailAddress @@ -23,14 +25,16 @@ def create_user(sender, instance, created, **kwargs): # create and verify EmailAddress for superuser accounts if allauth_settings.EMAIL_REQUIRED: - EmailAddress.objects.get_or_create(user=instance, - email=instance.email, primary=True, verified=True) - else: # don't need to add default groups for superuser - if created and not getattr(instance, 'skip_group_assigning', None): + EmailAddress.objects.get_or_create( + user=instance, email=instance.email, primary=True, verified=True + ) + else: # don't need to add default groups for superuser + if created and not getattr(instance, "skip_group_assigning", None): db_group = Group.objects.get(name=settings.IAM_DEFAULT_ROLE) instance.groups.add(db_group) -elif settings.IAM_TYPE == 'LDAP': +elif settings.IAM_TYPE == "LDAP": + def create_user(sender, user=None, ldap_user=None, **kwargs): user_groups = [] for role in settings.IAM_ROLES: @@ -56,11 +60,12 @@ def create_user(sender, user=None, ldap_user=None, **kwargs): def register_signals(app_config): post_migrate.connect(register_groups, app_config) - if settings.IAM_TYPE == 'BASIC': + if settings.IAM_TYPE == "BASIC": # Add default groups and add admin rights to super users. post_save.connect(create_user, sender=User) - elif settings.IAM_TYPE == 'LDAP': + elif settings.IAM_TYPE == "LDAP": import django_auth_ldap.backend + # Map groups from LDAP to roles, convert a user to super user if he/she # has an admin group. django_auth_ldap.backend.populate_user.connect(create_user) diff --git a/cvat/apps/iam/tests/test_rest_api.py b/cvat/apps/iam/tests/test_rest_api.py index d3de9fd6f1df..db0745e999d3 100644 --- a/cvat/apps/iam/tests/test_rest_api.py +++ b/cvat/apps/iam/tests/test_rest_api.py @@ -3,25 +3,30 @@ # # SPDX-License-Identifier: MIT -from django.urls import reverse +from allauth.account.views import EmailVerificationSentView +from django.test import override_settings +from django.urls import path, re_path, reverse from rest_framework import status -from rest_framework.test import APITestCase, APIClient from rest_framework.authtoken.models import Token -from django.test import override_settings -from django.urls import path, re_path -from allauth.account.views import EmailVerificationSentView +from rest_framework.test import APIClient, APITestCase from cvat.apps.iam.urls import urlpatterns as iam_url_patterns from cvat.apps.iam.views import ConfirmEmailViewEx - urlpatterns = iam_url_patterns + [ - re_path(r'^account-confirm-email/(?P[-:\w]+)/$', ConfirmEmailViewEx.as_view(), - name='account_confirm_email'), - path('register/account-email-verification-sent', EmailVerificationSentView.as_view(), - name='account_email_verification_sent'), + re_path( + r"^account-confirm-email/(?P[-:\w]+)/$", + ConfirmEmailViewEx.as_view(), + name="account_confirm_email", + ), + path( + "register/account-email-verification-sent", + EmailVerificationSentView.as_view(), + name="account_email_verification_sent", + ), ] + class ForceLogin: def __init__(self, user, client): self.user = user @@ -29,7 +34,7 @@ def __init__(self, user, client): def __enter__(self): if self.user: - self.client.force_login(self.user, backend='django.contrib.auth.backends.ModelBackend') + self.client.force_login(self.user, backend="django.contrib.auth.backends.ModelBackend") return self @@ -37,57 +42,91 @@ def __exit__(self, exception_type, exception_value, traceback): if self.user: self.client.logout() + class UserRegisterAPITestCase(APITestCase): - user_data = {'first_name': 'test_first', 'last_name': 'test_last', 'username': 'test_username', - 'email': 'test_email@test.com', 'password1': '$Test357Test%', 'password2': '$Test357Test%', - 'confirmations': []} + user_data = { + "first_name": "test_first", + "last_name": "test_last", + "username": "test_username", + "email": "test_email@test.com", + "password1": "$Test357Test%", + "password2": "$Test357Test%", + "confirmations": [], + } def setUp(self): self.client = APIClient() def _run_api_v2_user_register(self, data): - url = reverse('rest_register') - response = self.client.post(url, data, format='json') + url = reverse("rest_register") + response = self.client.post(url, data, format="json") return response def _check_response(self, response, data): self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.data, data) - @override_settings(ACCOUNT_EMAIL_VERIFICATION='none') + @override_settings(ACCOUNT_EMAIL_VERIFICATION="none") def test_api_v2_user_register_with_email_verification_none(self): """ Ensure we can register a user and get auth token key when email verification is none """ response = self._run_api_v2_user_register(self.user_data) - user_token = Token.objects.get(user__username=response.data['username']) - self._check_response(response, {'first_name': 'test_first', 'last_name': 'test_last', - 'username': 'test_username', 'email': 'test_email@test.com', - 'email_verification_required': False, 'key': user_token.key}) + user_token = Token.objects.get(user__username=response.data["username"]) + self._check_response( + response, + { + "first_name": "test_first", + "last_name": "test_last", + "username": "test_username", + "email": "test_email@test.com", + "email_verification_required": False, + "key": user_token.key, + }, + ) # Since URLConf is executed before running the tests, so we have to manually configure the url patterns for # the tests and pass it using ROOT_URLCONF in the override settings decorator - @override_settings(ACCOUNT_EMAIL_VERIFICATION='optional', ROOT_URLCONF=__name__) + @override_settings(ACCOUNT_EMAIL_VERIFICATION="optional", ROOT_URLCONF=__name__) def test_api_v2_user_register_with_email_verification_optional(self): """ Ensure we can register a user and get auth token key when email verification is optional """ response = self._run_api_v2_user_register(self.user_data) - user_token = Token.objects.get(user__username=response.data['username']) - self._check_response(response, {'first_name': 'test_first', 'last_name': 'test_last', - 'username': 'test_username', 'email': 'test_email@test.com', - 'email_verification_required': False, 'key': user_token.key}) - - @override_settings(ACCOUNT_EMAIL_REQUIRED=True, ACCOUNT_EMAIL_VERIFICATION='mandatory', - EMAIL_BACKEND='django.core.mail.backends.console.EmailBackend', ROOT_URLCONF=__name__) + user_token = Token.objects.get(user__username=response.data["username"]) + self._check_response( + response, + { + "first_name": "test_first", + "last_name": "test_last", + "username": "test_username", + "email": "test_email@test.com", + "email_verification_required": False, + "key": user_token.key, + }, + ) + + @override_settings( + ACCOUNT_EMAIL_REQUIRED=True, + ACCOUNT_EMAIL_VERIFICATION="mandatory", + EMAIL_BACKEND="django.core.mail.backends.console.EmailBackend", + ROOT_URLCONF=__name__, + ) def test_register_account_with_email_verification_mandatory(self): """ Ensure we can register a user and it does not return auth token key when email verification is mandatory """ response = self._run_api_v2_user_register(self.user_data) - self._check_response(response, {'first_name': 'test_first', 'last_name': 'test_last', - 'username': 'test_username', 'email': 'test_email@test.com', - 'email_verification_required': True, 'key': None}) - + self._check_response( + response, + { + "first_name": "test_first", + "last_name": "test_last", + "username": "test_username", + "email": "test_email@test.com", + "email_verification_required": True, + "key": None, + }, + ) diff --git a/cvat/apps/iam/urls.py b/cvat/apps/iam/urls.py index 8b8135fc2d9a..8f66f48f22b1 100644 --- a/cvat/apps/iam/urls.py +++ b/cvat/apps/iam/urls.py @@ -3,46 +3,55 @@ # # SPDX-License-Identifier: MIT -from django.urls import path, re_path +from allauth.account import app_settings as allauth_settings +from dj_rest_auth.views import ( + LogoutView, + PasswordChangeView, + PasswordResetConfirmView, + PasswordResetView, +) from django.conf import settings +from django.urls import path, re_path from django.urls.conf import include -from dj_rest_auth.views import ( - LogoutView, PasswordChangeView, - PasswordResetView, PasswordResetConfirmView) -from allauth.account import app_settings as allauth_settings from cvat.apps.iam.views import ( - SigningView, RegisterViewEx, RulesView, - ConfirmEmailViewEx, LoginViewEx + ConfirmEmailViewEx, + LoginViewEx, + RegisterViewEx, + RulesView, + SigningView, ) -BASIC_LOGIN_PATH_NAME = 'rest_login' -BASIC_REGISTER_PATH_NAME = 'rest_register' +BASIC_LOGIN_PATH_NAME = "rest_login" +BASIC_REGISTER_PATH_NAME = "rest_register" urlpatterns = [ - path('login', LoginViewEx.as_view(), name=BASIC_LOGIN_PATH_NAME), - path('logout', LogoutView.as_view(), name='rest_logout'), - path('signing', SigningView.as_view(), name='signing'), - path('rules', RulesView.as_view(), name='rules'), + path("login", LoginViewEx.as_view(), name=BASIC_LOGIN_PATH_NAME), + path("logout", LogoutView.as_view(), name="rest_logout"), + path("signing", SigningView.as_view(), name="signing"), + path("rules", RulesView.as_view(), name="rules"), ] -if settings.IAM_TYPE == 'BASIC': +if settings.IAM_TYPE == "BASIC": urlpatterns += [ - path('register', RegisterViewEx.as_view(), name=BASIC_REGISTER_PATH_NAME), + path("register", RegisterViewEx.as_view(), name=BASIC_REGISTER_PATH_NAME), # password - path('password/reset', PasswordResetView.as_view(), - name='rest_password_reset'), - path('password/reset/confirm', PasswordResetConfirmView.as_view(), - name='rest_password_reset_confirm'), - path('password/change', PasswordChangeView.as_view(), - name='rest_password_change'), + path("password/reset", PasswordResetView.as_view(), name="rest_password_reset"), + path( + "password/reset/confirm", + PasswordResetConfirmView.as_view(), + name="rest_password_reset_confirm", + ), + path("password/change", PasswordChangeView.as_view(), name="rest_password_change"), ] - if allauth_settings.EMAIL_VERIFICATION != \ - allauth_settings.EmailVerificationMethod.NONE: + if allauth_settings.EMAIL_VERIFICATION != allauth_settings.EmailVerificationMethod.NONE: # emails urlpatterns += [ - re_path(r'^account-confirm-email/(?P[-:\w]+)/$', ConfirmEmailViewEx.as_view(), - name='account_confirm_email'), + re_path( + r"^account-confirm-email/(?P[-:\w]+)/$", + ConfirmEmailViewEx.as_view(), + name="account_confirm_email", + ), ] -urlpatterns = [path('auth/', include(urlpatterns))] +urlpatterns = [path("auth/", include(urlpatterns))] diff --git a/cvat/apps/iam/utils.py b/cvat/apps/iam/utils.py index 8095902769f3..9b911e48ea7c 100644 --- a/cvat/apps/iam/utils.py +++ b/cvat/apps/iam/utils.py @@ -1,37 +1,40 @@ -from pathlib import Path import functools import hashlib import importlib import io import tarfile +from pathlib import Path from django.conf import settings from django.contrib.sessions.backends.base import SessionBase _OPA_RULES_PATHS = { - Path(__file__).parent / 'rules', + Path(__file__).parent / "rules", } + @functools.lru_cache(maxsize=None) def get_opa_bundle() -> tuple[bytes, str]: bundle_file = io.BytesIO() - with tarfile.open(fileobj=bundle_file, mode='w:gz') as tar: + with tarfile.open(fileobj=bundle_file, mode="w:gz") as tar: for p in _OPA_RULES_PATHS: - for f in p.glob('*[!.gen].rego'): + for f in p.glob("*[!.gen].rego"): tar.add(name=f, arcname=f.relative_to(p.parent)) bundle = bundle_file.getvalue() etag = hashlib.blake2b(bundle).hexdigest() return bundle, etag + def add_opa_rules_path(path: Path) -> None: _OPA_RULES_PATHS.add(path) get_opa_bundle.cache_clear() + def get_dummy_user(email): - from allauth.account.models import EmailAddress from allauth.account import app_settings + from allauth.account.models import EmailAddress from allauth.account.utils import filter_users_by_email users = filter_users_by_email(email) @@ -40,13 +43,13 @@ def get_dummy_user(email): user = users[0] if user.has_usable_password(): return None - if app_settings.EMAIL_VERIFICATION == \ - app_settings.EmailVerificationMethod.MANDATORY: + if app_settings.EMAIL_VERIFICATION == app_settings.EmailVerificationMethod.MANDATORY: email = EmailAddress.objects.get_for_user(user, email) if email.verified: return None return user + def clean_up_sessions() -> None: SessionStore: type[SessionBase] = importlib.import_module(settings.SESSION_ENGINE).SessionStore SessionStore.clear_expired() diff --git a/cvat/apps/iam/views.py b/cvat/apps/iam/views.py index 928d170c3bc4..d9bf960e426c 100644 --- a/cvat/apps/iam/views.py +++ b/cvat/apps/iam/views.py @@ -5,49 +5,55 @@ import functools -from django.http import Http404, HttpResponseBadRequest, HttpResponseRedirect -from rest_framework import views, serializers -from rest_framework.exceptions import ValidationError -from rest_framework.permissions import AllowAny -from django.conf import settings -from django.http import HttpResponse -from django.views.decorators.http import etag as django_etag -from rest_framework.response import Response +from allauth.account import app_settings as allauth_settings +from allauth.account.utils import complete_signup, has_verified_email, send_email_confirmation +from allauth.account.views import ConfirmEmailView from dj_rest_auth.app_settings import api_settings as dj_rest_auth_settings from dj_rest_auth.registration.views import RegisterView from dj_rest_auth.utils import jwt_encode from dj_rest_auth.views import LoginView -from allauth.account import app_settings as allauth_settings -from allauth.account.views import ConfirmEmailView -from allauth.account.utils import complete_signup, has_verified_email, send_email_confirmation - -from furl import furl - -from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import OpenApiResponse, extend_schema, inline_serializer, extend_schema_view +from django.conf import settings +from django.http import Http404, HttpResponse, HttpResponseBadRequest, HttpResponseRedirect +from django.views.decorators.http import etag as django_etag from drf_spectacular.contrib.rest_auth import get_token_serializer_class +from drf_spectacular.types import OpenApiTypes +from drf_spectacular.utils import ( + OpenApiResponse, + extend_schema, + extend_schema_view, + inline_serializer, +) +from furl import furl +from rest_framework import serializers, views +from rest_framework.exceptions import ValidationError +from rest_framework.permissions import AllowAny +from rest_framework.response import Response from .authentication import Signer from .utils import get_opa_bundle -@extend_schema(tags=['auth']) -@extend_schema_view(post=extend_schema( - summary='This method signs URL for access to the server', - description='Signed URL contains a token which authenticates a user on the server.' - 'Signed URL is valid during 30 seconds since signing.', - request=inline_serializer( - name='Signing', - fields={ - 'url': serializers.CharField(), - } - ), - responses={'200': OpenApiResponse(response=OpenApiTypes.STR, description='text URL')})) + +@extend_schema(tags=["auth"]) +@extend_schema_view( + post=extend_schema( + summary="This method signs URL for access to the server", + description="Signed URL contains a token which authenticates a user on the server." + "Signed URL is valid during 30 seconds since signing.", + request=inline_serializer( + name="Signing", + fields={ + "url": serializers.CharField(), + }, + ), + responses={"200": OpenApiResponse(response=OpenApiTypes.STR, description="text URL")}, + ) +) class SigningView(views.APIView): def post(self, request): - url = request.data.get('url') + url = request.data.get("url") if not url: - raise ValidationError('Please provide `url` parameter') + raise ValidationError("Please provide `url` parameter") signer = Signer() url = self.request.build_absolute_uri(url) @@ -56,6 +62,7 @@ def post(self, request): url = furl(url).add({Signer.QUERY_PARAM: sign}).url return Response(url) + class LoginViewEx(LoginView): """ Check the credentials and return the REST Token @@ -68,6 +75,7 @@ class LoginViewEx(LoginView): Accept the following POST parameters: username, email, password Return the REST Framework Token Object's key. """ + @extend_schema(responses=get_token_serializer_class()) def post(self, request, *args, **kwargs): self.request = request @@ -76,9 +84,9 @@ def post(self, request, *args, **kwargs): self.serializer.is_valid(raise_exception=True) except ValidationError: user = self.serializer.get_auth_user( - self.serializer.data.get('username'), - self.serializer.data.get('email'), - self.serializer.data.get('password') + self.serializer.data.get("username"), + self.serializer.data.get("email"), + self.serializer.data.get("password"), ) if not user: raise @@ -90,13 +98,14 @@ def post(self, request, *args, **kwargs): # we cannot use redirect to ACCOUNT_EMAIL_VERIFICATION_SENT_REDIRECT_URL here # because redirect will make a POST request and we'll get a 404 code # (although in the browser request method will be displayed like GET) - return HttpResponseBadRequest('Unverified email') - except Exception: # nosec + return HttpResponseBadRequest("Unverified email") + except Exception: # nosec pass self.login() return self.get_response() + class RegisterViewEx(RegisterView): def get_response_data(self, user): serializer = self.get_serializer(user) @@ -117,20 +126,24 @@ def get_response_data(self, user): # Link to the issue: https://github.com/iMerica/dj-rest-auth/issues/604 def perform_create(self, serializer): user = serializer.save(self.request) - if allauth_settings.EMAIL_VERIFICATION != \ - allauth_settings.EmailVerificationMethod.MANDATORY: + if ( + allauth_settings.EMAIL_VERIFICATION + != allauth_settings.EmailVerificationMethod.MANDATORY + ): if dj_rest_auth_settings.USE_JWT: self.access_token, self.refresh_token = jwt_encode(user) elif self.token_model: dj_rest_auth_settings.TOKEN_CREATOR(self.token_model, user, serializer) complete_signup( - self.request._request, user, + self.request._request, + user, allauth_settings.EMAIL_VERIFICATION, None, ) return user + def _etag(etag_func): """ Decorator to support conditional retrieval (or change) @@ -138,6 +151,7 @@ def _etag(etag_func): It calls Django's original decorator but pass correct request object to it. Django's original decorator doesn't work with DRF request object. """ + def decorator(func): @functools.wraps(func) def wrapper(obj_self, request, *args, **kwargs): @@ -150,9 +164,12 @@ def patched_viewset_method(*_args, **_kwargs): return func(obj_self, drf_request, *args, **kwargs) return patched_viewset_method(wsgi_request, *args, **kwargs) + return wrapper + return decorator + class RulesView(views.APIView): serializer_class = None permission_classes = [AllowAny] @@ -161,10 +178,11 @@ class RulesView(views.APIView): @_etag(lambda request: get_opa_bundle()[1]) def get(self, request): - return HttpResponse(get_opa_bundle()[0], content_type='application/x-tar') + return HttpResponse(get_opa_bundle()[0], content_type="application/x-tar") + class ConfirmEmailViewEx(ConfirmEmailView): - template_name = 'account/email/email_confirmation_signup_message.html' + template_name = "account/email/email_confirmation_signup_message.html" def get(self, *args, **kwargs): try: diff --git a/cvat/apps/lambda_manager/apps.py b/cvat/apps/lambda_manager/apps.py index 1bbc515522ad..974e32dc74a4 100644 --- a/cvat/apps/lambda_manager/apps.py +++ b/cvat/apps/lambda_manager/apps.py @@ -7,8 +7,9 @@ class LambdaManagerConfig(AppConfig): - name = 'cvat.apps.lambda_manager' + name = "cvat.apps.lambda_manager" def ready(self) -> None: from cvat.apps.iam.permissions import load_app_permissions + load_app_permissions(self) diff --git a/cvat/apps/lambda_manager/models.py b/cvat/apps/lambda_manager/models.py index 47d732c41dd1..f6e684a1cc0f 100644 --- a/cvat/apps/lambda_manager/models.py +++ b/cvat/apps/lambda_manager/models.py @@ -5,6 +5,7 @@ import django.db.models as models + class FunctionKind(models.TextChoices): DETECTOR = "detector" INTERACTOR = "interactor" diff --git a/cvat/apps/lambda_manager/permissions.py b/cvat/apps/lambda_manager/permissions.py index 94800f0edd5d..a2192cdd4914 100644 --- a/cvat/apps/lambda_manager/permissions.py +++ b/cvat/apps/lambda_manager/permissions.py @@ -8,27 +8,28 @@ from cvat.apps.engine.permissions import JobPermission, TaskPermission from cvat.apps.iam.permissions import OpenPolicyAgentPermission, StrEnum + class LambdaPermission(OpenPolicyAgentPermission): class Scopes(StrEnum): - LIST = 'list' - VIEW = 'view' - CALL_ONLINE = 'call:online' - CALL_OFFLINE = 'call:offline' - LIST_OFFLINE = 'list:offline' + LIST = "list" + VIEW = "view" + CALL_ONLINE = "call:online" + CALL_OFFLINE = "call:offline" + LIST_OFFLINE = "list:offline" @classmethod def create(cls, request, view, obj, iam_context): permissions = [] - if view.basename == 'lambda_function' or view.basename == 'lambda_request': + if view.basename == "lambda_function" or view.basename == "lambda_request": scopes = cls.get_scopes(request, view, obj) for scope in scopes: self = cls.create_base_perm(request, view, scope, iam_context, obj) permissions.append(self) - if job_id := request.data.get('job'): + if job_id := request.data.get("job"): perm = JobPermission.create_scope_view_data(iam_context, job_id) permissions.append(perm) - elif task_id := request.data.get('task'): + elif task_id := request.data.get("task"): perm = TaskPermission.create_scope_view_data(iam_context, task_id) permissions.append(perm) @@ -36,20 +37,22 @@ def create(cls, request, view, obj, iam_context): def __init__(self, **kwargs): super().__init__(**kwargs) - self.url = settings.IAM_OPA_DATA_URL + '/lambda/allow' + self.url = settings.IAM_OPA_DATA_URL + "/lambda/allow" @staticmethod def get_scopes(request, view, obj): Scopes = __class__.Scopes - return [{ - ('lambda_function', 'list'): Scopes.LIST, - ('lambda_function', 'retrieve'): Scopes.VIEW, - ('lambda_function', 'call'): Scopes.CALL_ONLINE, - ('lambda_request', 'create'): Scopes.CALL_OFFLINE, - ('lambda_request', 'list'): Scopes.LIST_OFFLINE, - ('lambda_request', 'retrieve'): Scopes.CALL_OFFLINE, - ('lambda_request', 'destroy'): Scopes.CALL_OFFLINE, - }[(view.basename, view.action)]] + return [ + { + ("lambda_function", "list"): Scopes.LIST, + ("lambda_function", "retrieve"): Scopes.VIEW, + ("lambda_function", "call"): Scopes.CALL_ONLINE, + ("lambda_request", "create"): Scopes.CALL_OFFLINE, + ("lambda_request", "list"): Scopes.LIST_OFFLINE, + ("lambda_request", "retrieve"): Scopes.CALL_OFFLINE, + ("lambda_request", "destroy"): Scopes.CALL_OFFLINE, + }[(view.basename, view.action)] + ] def get_resource(self): return None diff --git a/cvat/apps/lambda_manager/rules/tests/generators/lambda_test.gen.rego.py b/cvat/apps/lambda_manager/rules/tests/generators/lambda_test.gen.rego.py index 94f694988a38..f506fda56a07 100644 --- a/cvat/apps/lambda_manager/rules/tests/generators/lambda_test.gen.rego.py +++ b/cvat/apps/lambda_manager/rules/tests/generators/lambda_test.gen.rego.py @@ -77,13 +77,15 @@ def get_data(scope, context, ownership, privilege, membership, resource): "scope": scope, "auth": { "user": {"id": random.randrange(0, 100), "privilege": privilege}, - "organization": { - "id": random.randrange(100, 200), - "owner": {"id": random.randrange(200, 300)}, - "user": {"role": membership}, - } - if context == "organization" - else None, + "organization": ( + { + "id": random.randrange(100, 200), + "owner": {"id": random.randrange(200, 300)}, + "user": {"role": membership}, + } + if context == "organization" + else None + ), }, "resource": resource, } diff --git a/cvat/apps/lambda_manager/serializers.py b/cvat/apps/lambda_manager/serializers.py index ab8809bd7cc8..8daf3a53642b 100644 --- a/cvat/apps/lambda_manager/serializers.py +++ b/cvat/apps/lambda_manager/serializers.py @@ -5,20 +5,25 @@ from drf_spectacular.utils import extend_schema_serializer from rest_framework import serializers + class SublabelMappingEntrySerializer(serializers.Serializer): name = serializers.CharField() attributes = serializers.DictField(child=serializers.CharField(), required=False) + class LabelMappingEntrySerializer(serializers.Serializer): name = serializers.CharField() attributes = serializers.DictField(child=serializers.CharField(), required=False) - sublabels = serializers.DictField(child=SublabelMappingEntrySerializer(), required=False, - help_text="Label mapping for from the model to the task sublabels within a parent label" + sublabels = serializers.DictField( + child=SublabelMappingEntrySerializer(), + required=False, + help_text="Label mapping for from the model to the task sublabels within a parent label", ) + @extend_schema_serializer( # The "Request" suffix is added by drf-spectacular automatically - component_name='FunctionCall' + component_name="FunctionCall" ) class FunctionCallRequestSerializer(serializers.Serializer): function = serializers.CharField(help_text="The name of the function to execute") @@ -26,13 +31,25 @@ class FunctionCallRequestSerializer(serializers.Serializer): job = serializers.IntegerField(required=False, help_text="The id of the job to be annotated") max_distance = serializers.IntegerField(required=False) threshold = serializers.FloatField(required=False) - cleanup = serializers.BooleanField(help_text="Whether existing annotations should be removed", default=False) - convMaskToPoly = serializers.BooleanField(required=False, source="conv_mask_to_poly", write_only=True, help_text="Deprecated; use conv_mask_to_poly instead") - conv_mask_to_poly = serializers.BooleanField(required=False, help_text="Convert mask shapes to polygons") - mapping = serializers.DictField(child=LabelMappingEntrySerializer(), required=False, - help_text="Label mapping from the model to the task labels" + cleanup = serializers.BooleanField( + help_text="Whether existing annotations should be removed", default=False + ) + convMaskToPoly = serializers.BooleanField( + required=False, + source="conv_mask_to_poly", + write_only=True, + help_text="Deprecated; use conv_mask_to_poly instead", + ) + conv_mask_to_poly = serializers.BooleanField( + required=False, help_text="Convert mask shapes to polygons" + ) + mapping = serializers.DictField( + child=LabelMappingEntrySerializer(), + required=False, + help_text="Label mapping from the model to the task labels", ) + class FunctionCallParamsSerializer(serializers.Serializer): id = serializers.CharField(allow_null=True, help_text="The name of the function") @@ -41,6 +58,7 @@ class FunctionCallParamsSerializer(serializers.Serializer): threshold = serializers.FloatField(allow_null=True) + class FunctionCallSerializer(serializers.Serializer): id = serializers.CharField(help_text="Request id") diff --git a/cvat/apps/lambda_manager/tests/test_lambda.py b/cvat/apps/lambda_manager/tests/test_lambda.py index f9292b278b45..38e812b25ff0 100644 --- a/cvat/apps/lambda_manager/tests/test_lambda.py +++ b/cvat/apps/lambda_manager/tests/test_lambda.py @@ -3,12 +3,12 @@ # # SPDX-License-Identifier: MIT +import json +import os from collections import Counter, OrderedDict from itertools import groupby from typing import Optional from unittest import mock, skip -import json -import os import requests from django.contrib.auth.models import Group, User @@ -16,16 +16,22 @@ from rest_framework import status from cvat.apps.engine.tests.utils import ( - ApiTestBase, filter_dict, ForceLogin, generate_image_file, get_paginated_collection + ApiTestBase, + ForceLogin, + filter_dict, + generate_image_file, + get_paginated_collection, ) -LAMBDA_ROOT_PATH = '/api/lambda' -LAMBDA_FUNCTIONS_PATH = f'{LAMBDA_ROOT_PATH}/functions' -LAMBDA_REQUESTS_PATH = f'{LAMBDA_ROOT_PATH}/requests' +LAMBDA_ROOT_PATH = "/api/lambda" +LAMBDA_FUNCTIONS_PATH = f"{LAMBDA_ROOT_PATH}/functions" +LAMBDA_REQUESTS_PATH = f"{LAMBDA_ROOT_PATH}/requests" id_function_detector = "test-openvino-omz-public-yolo-v3-tf" id_function_reid_with_response_data = "test-openvino-omz-intel-person-reidentification-retail-0300" -id_function_reid_with_no_response_data = "test-openvino-omz-intel-person-reidentification-retail-1234" +id_function_reid_with_no_response_data = ( + "test-openvino-omz-intel-person-reidentification-retail-1234" +) id_function_interactor = "test-openvino-dextr" id_function_tracker = "test-pth-foolwood-siammask" id_function_non_type = "test-model-has-non-type" @@ -36,29 +42,47 @@ id_function_state_error = "test-model-has-state-error" expected_keys_in_response_all_functions = ["id", "kind", "labels_v2", "description", "name"] -expected_keys_in_response_function_interactor = expected_keys_in_response_all_functions + ["min_pos_points", "startswith_box"] -expected_keys_in_response_requests = ["id", "function", "status", "progress", "enqueued", "started", "ended", "exc_info"] - -path = os.path.join(os.path.dirname(__file__), 'assets', 'tasks.json') +expected_keys_in_response_function_interactor = expected_keys_in_response_all_functions + [ + "min_pos_points", + "startswith_box", +] +expected_keys_in_response_requests = [ + "id", + "function", + "status", + "progress", + "enqueued", + "started", + "ended", + "exc_info", +] + +path = os.path.join(os.path.dirname(__file__), "assets", "tasks.json") with open(path) as f: tasks = json.load(f) # removed unnecessary data -path = os.path.join(os.path.dirname(__file__), 'assets', 'functions.json') +path = os.path.join(os.path.dirname(__file__), "assets", "functions.json") with open(path) as f: functions = json.load(f) + class _LambdaTestCaseBase(ApiTestBase): def setUp(self): super().setUp() self.client = self.client_class(raise_request_exception=False) - http_patcher = mock.patch('cvat.apps.lambda_manager.views.LambdaGateway._http', side_effect = self._get_data_from_lambda_manager_http) + http_patcher = mock.patch( + "cvat.apps.lambda_manager.views.LambdaGateway._http", + side_effect=self._get_data_from_lambda_manager_http, + ) self.addCleanup(http_patcher.stop) http_patcher.start() - invoke_patcher = mock.patch('cvat.apps.lambda_manager.views.LambdaGateway.invoke', side_effect = self._invoke_function) + invoke_patcher = mock.patch( + "cvat.apps.lambda_manager.views.LambdaGateway.invoke", side_effect=self._invoke_function + ) self.addCleanup(invoke_patcher.stop) invoke_patcher.start() @@ -72,13 +96,13 @@ def _get_data_from_lambda_manager_http(self, **kwargs): if func_id in [id_function_state_building, id_function_state_error]: r = requests.RequestException() r.response = HttpResponseServerError() - raise r # raise 500 Internal_Server error + raise r # raise 500 Internal_Server error return functions["positive"][func_id] else: r = requests.HTTPError() r.response = HttpResponseNotFound() - raise r # raise 404 Not Found error + raise r # raise 404 Not Found error def _invoke_function(self, func, payload): data = [] @@ -135,27 +159,32 @@ def _create_db_users(cls): (group_admin, _) = Group.objects.get_or_create(name="admin") (group_user, _) = Group.objects.get_or_create(name="user") - user_admin = User.objects.create_superuser(username="admin", email="", - password="admin") + user_admin = User.objects.create_superuser(username="admin", email="", password="admin") user_admin.groups.add(group_admin) - user_dummy = User.objects.create_user(username="user", password="user", - email="user@example.com") + user_dummy = User.objects.create_user( + username="user", password="user", email="user@example.com" + ) user_dummy.groups.add(group_user) cls.admin = user_admin cls.user = user_dummy - def _create_task(self, task_spec, data, *, owner=None, org_id=None): with ForceLogin(owner or self.admin, self.client): - response = self.client.post('/api/tasks', data=task_spec, format="json", - QUERY_STRING=f'org_id={org_id}' if org_id is not None else None) + response = self.client.post( + "/api/tasks", + data=task_spec, + format="json", + QUERY_STRING=f"org_id={org_id}" if org_id is not None else None, + ) assert response.status_code == status.HTTP_201_CREATED, response.status_code tid = response.data["id"] - response = self.client.post("/api/tasks/%s/data" % tid, + response = self.client.post( + "/api/tasks/%s/data" % tid, data=data, - QUERY_STRING=f'org_id={org_id}' if org_id is not None else None) + QUERY_STRING=f"org_id={org_id}" if org_id is not None else None, + ) assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code rq_id = response.json()["rq_id"] @@ -163,65 +192,72 @@ def _create_task(self, task_spec, data, *, owner=None, org_id=None): assert response.status_code == status.HTTP_200_OK, response.status_code assert response.json()["status"] == "finished", response.json().get("status") - response = self.client.get("/api/tasks/%s" % tid, - QUERY_STRING=f'org_id={org_id}' if org_id is not None else None) + response = self.client.get( + "/api/tasks/%s" % tid, + QUERY_STRING=f"org_id={org_id}" if org_id is not None else None, + ) task = response.data return task - - def _generate_task_images(self, count): # pylint: disable=no-self-use + def _generate_task_images(self, count): # pylint: disable=no-self-use images = { - "client_files[%d]" % i: generate_image_file("image_%d.jpg" % i) - for i in range(count) + "client_files[%d]" % i: generate_image_file("image_%d.jpg" % i) for i in range(count) } images["image_quality"] = 75 return images - @classmethod def setUpTestData(cls): cls._create_db_users() - def _get_request(self, path, user, *, org_id=None): with ForceLogin(user, self.client): - response = self.client.get(path, - QUERY_STRING=f'org_id={org_id}' if org_id is not None else '') + response = self.client.get( + path, QUERY_STRING=f"org_id={org_id}" if org_id is not None else "" + ) return response - def _delete_request(self, path, user, *, org_id=None): with ForceLogin(user, self.client): - response = self.client.delete(path, - QUERY_STRING=f'org_id={org_id}' if org_id is not None else '') + response = self.client.delete( + path, QUERY_STRING=f"org_id={org_id}" if org_id is not None else "" + ) return response - def _post_request(self, path, user, data, *, org_id=None): data = json.dumps(data) with ForceLogin(user, self.client): - response = self.client.post(path, data=data, content_type='application/json', - QUERY_STRING=f'org_id={org_id}' if org_id is not None else '') + response = self.client.post( + path, + data=data, + content_type="application/json", + QUERY_STRING=f"org_id={org_id}" if org_id is not None else "", + ) return response - def _patch_request(self, path, user, data, *, org_id=None): data = json.dumps(data) with ForceLogin(user, self.client): - response = self.client.patch(path, data=data, content_type='application/json', - QUERY_STRING=f'org_id={org_id}' if org_id is not None else '') + response = self.client.patch( + path, + data=data, + content_type="application/json", + QUERY_STRING=f"org_id={org_id}" if org_id is not None else "", + ) return response - def _put_request(self, path, user, data, *, org_id=None): data = json.dumps(data) with ForceLogin(user, self.client): - response = self.client.put(path, data=data, content_type='application/json', - QUERY_STRING=f'org_id={org_id}' if org_id is not None else '') + response = self.client.put( + path, + data=data, + content_type="application/json", + QUERY_STRING=f"org_id={org_id}" if org_id is not None else "", + ) return response - def _check_expected_keys_in_response_function(self, data): kind = data["kind"] if kind == "interactor": @@ -232,7 +268,7 @@ def _check_expected_keys_in_response_function(self, data): self.assertIn(key, data) def _delete_lambda_request(self, request_id: str, user: Optional[User] = None) -> None: - response = self._delete_request(f'{LAMBDA_REQUESTS_PATH}/{request_id}', user or self.admin) + response = self._delete_request(f"{LAMBDA_REQUESTS_PATH}/{request_id}", user or self.admin) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) @@ -261,8 +297,7 @@ def test_api_v2_lambda_functions_list(self): response = self._get_request(LAMBDA_FUNCTIONS_PATH, None) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - - @mock.patch('cvat.apps.lambda_manager.views.LambdaGateway._http', return_value = {}) + @mock.patch("cvat.apps.lambda_manager.views.LambdaGateway._http", return_value={}) def test_api_v2_lambda_functions_list_empty(self, mock_http): response = self._get_request(LAMBDA_FUNCTIONS_PATH, self.admin) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -275,10 +310,12 @@ def test_api_v2_lambda_functions_list_empty(self, mock_http): response = self._get_request(LAMBDA_FUNCTIONS_PATH, None) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - @mock.patch( - 'cvat.apps.lambda_manager.views.LambdaGateway._http', - return_value={**functions["negative"], id_function_detector: functions["positive"][id_function_detector]} + "cvat.apps.lambda_manager.views.LambdaGateway._http", + return_value={ + **functions["negative"], + id_function_detector: functions["positive"][id_function_detector], + }, ) def test_api_v2_lambda_functions_list_negative(self, mock_http): response = self._get_request(LAMBDA_FUNCTIONS_PATH, self.admin) @@ -289,11 +326,15 @@ def test_api_v2_lambda_functions_list_negative(self, mock_http): self.assertEqual(visible_ids, {id_function_detector}) def test_api_v2_lambda_functions_read(self): - ids_functions = [id_function_detector, id_function_interactor,\ - id_function_tracker, id_function_reid_with_response_data] + ids_functions = [ + id_function_detector, + id_function_interactor, + id_function_tracker, + id_function_reid_with_response_data, + ] for id_func in ids_functions: - path = f'{LAMBDA_FUNCTIONS_PATH}/{id_func}' + path = f"{LAMBDA_FUNCTIONS_PATH}/{id_func}" response = self._get_request(path, self.admin) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -306,32 +347,31 @@ def test_api_v2_lambda_functions_read(self): response = self._get_request(path, None) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - def test_api_v2_lambda_functions_read_wrong_id(self): id_wrong_function = "test-functions-wrong-id" - response = self._get_request(f'{LAMBDA_FUNCTIONS_PATH}/{id_wrong_function}', self.admin) + response = self._get_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_wrong_function}", self.admin) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - response = self._get_request(f'{LAMBDA_FUNCTIONS_PATH}/{id_wrong_function}', self.user) + response = self._get_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_wrong_function}", self.user) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - response = self._get_request(f'{LAMBDA_FUNCTIONS_PATH}/{id_wrong_function}', None) + response = self._get_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_wrong_function}", None) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - def test_api_v2_lambda_functions_read_negative(self): for id_func in [ - id_function_non_type, id_function_wrong_type, id_function_unknown_type, + id_function_non_type, + id_function_wrong_type, + id_function_unknown_type, id_function_non_unique_labels, ]: with mock.patch( - 'cvat.apps.lambda_manager.views.LambdaGateway._http', - return_value=functions["negative"][id_func] + "cvat.apps.lambda_manager.views.LambdaGateway._http", + return_value=functions["negative"][id_func], ): - response = self._get_request(f'{LAMBDA_FUNCTIONS_PATH}/{id_func}', self.admin) + response = self._get_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_func}", self.admin) self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) - @skip("Fail: add mock") def test_api_v2_lambda_requests_list(self): response = self._get_request(LAMBDA_REQUESTS_PATH, self.admin) @@ -347,7 +387,6 @@ def test_api_v2_lambda_requests_list(self): response = self._get_request(LAMBDA_REQUESTS_PATH, None) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - def test_api_v2_lambda_requests_list_empty(self): response = self._get_request(LAMBDA_REQUESTS_PATH, self.admin) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -360,7 +399,6 @@ def test_api_v2_lambda_requests_list_empty(self): response = self._get_request(LAMBDA_REQUESTS_PATH, None) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - def test_api_v2_lambda_requests_read(self): # create request data_main_task = { @@ -369,76 +407,78 @@ def test_api_v2_lambda_requests_read(self): "cleanup": True, "threshold": 55, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data_main_task) self.assertEqual(response.status_code, status.HTTP_200_OK) id_request = response.data["id"] - response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.admin) + response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.admin) self.assertEqual(response.status_code, status.HTTP_200_OK) for key in expected_keys_in_response_requests: self.assertIn(key, response.data) - response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.user) + response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.user) self.assertEqual(response.status_code, status.HTTP_200_OK) for key in expected_keys_in_response_requests: self.assertIn(key, response.data) - response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', None) + response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", None) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - def test_api_v2_lambda_requests_read_wrong_id(self): id_request = "cf343b95-afeb-475e-ab53-8d7e64991d30-wrong-id" - response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.admin) + response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.admin) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.user) + response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.user) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', None) + response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", None) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - def test_api_v2_lambda_requests_delete_finished_request(self): data = { "function": id_function_detector, "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) id_request = response.data["id"] - response = self._delete_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', None) + response = self._delete_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", None) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = self._delete_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.admin) + response = self._delete_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.admin) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) - response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.admin) + response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.admin) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) id_request = response.data["id"] - response = self._delete_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.user) + response = self._delete_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.user) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) - response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.user) + response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.user) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - @skip("Fail: add mock") def test_api_v2_lambda_requests_delete_not_finished_request(self): pass - def test_api_v2_lambda_requests_create(self): - ids_functions = [id_function_detector, id_function_interactor, id_function_tracker, \ - id_function_reid_with_response_data, id_function_detector, id_function_reid_with_no_response_data] + ids_functions = [ + id_function_detector, + id_function_interactor, + id_function_tracker, + id_function_reid_with_response_data, + id_function_detector, + id_function_reid_with_no_response_data, + ] for id_func in ids_functions: data_main_task = { @@ -447,7 +487,7 @@ def test_api_v2_lambda_requests_create(self): "cleanup": True, "threshold": 55, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } data_assigneed_to_user_task = { @@ -456,7 +496,7 @@ def test_api_v2_lambda_requests_create(self): "cleanup": False, "max_distance": 70, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } @@ -467,7 +507,9 @@ def test_api_v2_lambda_requests_create(self): self._delete_lambda_request(response.data["id"]) - response = self._post_request(LAMBDA_REQUESTS_PATH, self.user, data_assigneed_to_user_task) + response = self._post_request( + LAMBDA_REQUESTS_PATH, self.user, data_assigneed_to_user_task + ) self.assertEqual(response.status_code, status.HTTP_200_OK) for key in expected_keys_in_response_requests: self.assertIn(key, response.data) @@ -480,10 +522,11 @@ def test_api_v2_lambda_requests_create(self): response = self._post_request(LAMBDA_REQUESTS_PATH, None, data_main_task) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - def test_api_v2_lambda_requests_create_negative(self): for id_func in [ - id_function_non_type, id_function_wrong_type, id_function_unknown_type, + id_function_non_type, + id_function_wrong_type, + id_function_unknown_type, id_function_non_unique_labels, ]: data = { @@ -491,49 +534,45 @@ def test_api_v2_lambda_requests_create_negative(self): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } with mock.patch( - 'cvat.apps.lambda_manager.views.LambdaGateway._http', + "cvat.apps.lambda_manager.views.LambdaGateway._http", return_value=functions["negative"][id_func], ): response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) - def test_api_v2_lambda_requests_create_empty_data(self): data = {} response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - def test_api_v2_lambda_requests_create_without_function(self): data = { "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - def test_api_v2_lambda_requests_create_wrong_id_function(self): data = { "function": "test-requests-wrong-id", "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - @skip("Fail: add mock") def test_api_v2_lambda_requests_create_two_requests(self): data = { @@ -541,10 +580,10 @@ def test_api_v2_lambda_requests_create_two_requests(self): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } - request_id = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data).data['id'] + request_id = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data).data["id"] response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) self.assertEqual(response.status_code, status.HTTP_409_CONFLICT) @@ -569,7 +608,7 @@ def test_api_v2_lambda_requests_create_without_cleanup(self): "function": id_function_detector, "task": self.main_task["id"], "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -597,26 +636,24 @@ def test_api_v2_lambda_requests_create_without_task(self): "function": id_function_detector, "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - def test_api_v2_lambda_requests_create_wrong_id_task(self): data = { "function": id_function_detector, "task": 12345, "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - def test_api_v2_lambda_requests_create_is_not_ready(self): ids_functions = [id_function_state_building, id_function_state_error] @@ -626,14 +663,13 @@ def test_api_v2_lambda_requests_create_is_not_ready(self): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) - def test_api_v2_lambda_functions_create_detector(self): data_main_task = { "task": self.main_task["id"], @@ -641,7 +677,7 @@ def test_api_v2_lambda_functions_create_detector(self): "cleanup": True, "threshold": 0.55, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } data_assigneed_to_user_task = { @@ -649,122 +685,199 @@ def test_api_v2_lambda_functions_create_detector(self): "frame": 0, "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data_main_task) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data_main_task + ) self.assertEqual(response.status_code, status.HTTP_200_OK) - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.user, data_assigneed_to_user_task) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", + self.user, + data_assigneed_to_user_task, + ) self.assertEqual(response.status_code, status.HTTP_200_OK) - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", None, data_main_task) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", None, data_main_task + ) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - @skip("Fail: expected result != actual result") # TODO move test to test_api_v2_lambda_functions_create + @skip( + "Fail: expected result != actual result" + ) # TODO move test to test_api_v2_lambda_functions_create def test_api_v2_lambda_functions_create_user_assigned_to_no_user(self): data = { "task": self.main_task["id"], "frame": 0, "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.user, data) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.user, data + ) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - def test_api_v2_lambda_functions_create_interactor(self): data_main_task = { "task": self.main_task["id"], "frame": 0, "pos_points": [ - [3.45, 6.78], - [12.1, 12.1], + [3.45, 6.78], + [12.1, 12.1], [34.1, 41.0], - [43.01, 43.99], + [43.01, 43.99], ], "neg_points": [ - [3.25, 6.58], - [11.1, 11.0], - [35.5, 44.44], - [45.01, 45.99], - ], + [3.25, 6.58], + [11.1, 11.0], + [35.5, 44.44], + [45.01, 45.99], + ], } data_assigneed_to_user_task = { "task": self.assigneed_to_user_task["id"], "frame": 0, "threshold": 0.1, "pos_points": [ - [3.45, 6.78], - [12.1, 12.1], + [3.45, 6.78], + [12.1, 12.1], [34.1, 41.0], - [43.01, 43.99], + [43.01, 43.99], ], "neg_points": [ - [3.25, 6.58], - [11.1, 11.0], - [35.5, 44.44], - [45.01, 45.99], - ], + [3.25, 6.58], + [11.1, 11.0], + [35.5, 44.44], + [45.01, 45.99], + ], } - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_interactor}", self.admin, data_main_task) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_interactor}", self.admin, data_main_task + ) self.assertEqual(response.status_code, status.HTTP_200_OK) - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_interactor}", self.user, data_assigneed_to_user_task) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_interactor}", + self.user, + data_assigneed_to_user_task, + ) self.assertEqual(response.status_code, status.HTTP_200_OK) - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_interactor}", None, data_main_task) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_interactor}", None, data_main_task + ) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - def test_api_v2_lambda_functions_create_tracker(self): data_main_task = { "task": self.main_task["id"], "frame": 0, "shape": [ - 12.12, - 34.45, - 54.0, - 76.12, - ], + 12.12, + 34.45, + 54.0, + 76.12, + ], } data_assigneed_to_user_task = { "task": self.assigneed_to_user_task["id"], "frame": 0, "shape": [ - 12.12, - 34.45, - 54.0, - 76.12, - ], + 12.12, + 34.45, + 54.0, + 76.12, + ], } - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_tracker}", self.admin, data_main_task) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_tracker}", self.admin, data_main_task + ) self.assertEqual(response.status_code, status.HTTP_200_OK) - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_tracker}", self.user, data_assigneed_to_user_task) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_tracker}", self.user, data_assigneed_to_user_task + ) self.assertEqual(response.status_code, status.HTTP_200_OK) - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_tracker}", None, data_main_task) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_tracker}", None, data_main_task + ) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - def test_api_v2_lambda_functions_create_reid(self): data_main_task = { "task": self.main_task["id"], "frame0": 0, "frame1": 1, "boxes0": [ - OrderedDict([('attributes', []), ('frame', 0), ('group', None), ('id', 11258), ('label_id', 8), ('occluded', False), ('path_id', 0), ('points', [137.0, 129.0, 457.0, 676.0]), ('source', 'auto'), ('type', 'rectangle'), ('z_order', 0)]), - OrderedDict([('attributes', []), ('frame', 0), ('group', None), ('id', 11259), ('label_id', 8), ('occluded', False), ('path_id', 1), ('points', [1511.0, 224.0, 1537.0, 437.0]), ('source', 'auto'), ('type', 'rectangle'), ('z_order', 0)]), + OrderedDict( + [ + ("attributes", []), + ("frame", 0), + ("group", None), + ("id", 11258), + ("label_id", 8), + ("occluded", False), + ("path_id", 0), + ("points", [137.0, 129.0, 457.0, 676.0]), + ("source", "auto"), + ("type", "rectangle"), + ("z_order", 0), + ] + ), + OrderedDict( + [ + ("attributes", []), + ("frame", 0), + ("group", None), + ("id", 11259), + ("label_id", 8), + ("occluded", False), + ("path_id", 1), + ("points", [1511.0, 224.0, 1537.0, 437.0]), + ("source", "auto"), + ("type", "rectangle"), + ("z_order", 0), + ] + ), ], "boxes1": [ - OrderedDict([('attributes', []), ('frame', 1), ('group', None), ('id', 11260), ('label_id', 8), ('occluded', False), ('points', [1076.0, 199.0, 1218.0, 593.0]), ('source', 'auto'), ('type', 'rectangle'), ('z_order', 0)]), - OrderedDict([('attributes', []), ('frame', 1), ('group', None), ('id', 11261), ('label_id', 8), ('occluded', False), ('points', [924.0, 177.0, 1090.0, 615.0]), ('source', 'auto'), ('type', 'rectangle'), ('z_order', 0)]), + OrderedDict( + [ + ("attributes", []), + ("frame", 1), + ("group", None), + ("id", 11260), + ("label_id", 8), + ("occluded", False), + ("points", [1076.0, 199.0, 1218.0, 593.0]), + ("source", "auto"), + ("type", "rectangle"), + ("z_order", 0), + ] + ), + OrderedDict( + [ + ("attributes", []), + ("frame", 1), + ("group", None), + ("id", 11261), + ("label_id", 8), + ("occluded", False), + ("points", [924.0, 177.0, 1090.0, 615.0]), + ("source", "auto"), + ("type", "rectangle"), + ("z_order", 0), + ] + ), ], "threshold": 0.5, "max_distance": 55, @@ -774,63 +887,154 @@ def test_api_v2_lambda_functions_create_reid(self): "frame0": 0, "frame1": 1, "boxes0": [ - OrderedDict([('attributes', []), ('frame', 0), ('group', None), ('id', 11258), ('label_id', 8), ('occluded', False), ('path_id', 0), ('points', [137.0, 129.0, 457.0, 676.0]), ('source', 'auto'), ('type', 'rectangle'), ('z_order', 0)]), - OrderedDict([('attributes', []), ('frame', 0), ('group', None), ('id', 11259), ('label_id', 8), ('occluded', False), ('path_id', 1), ('points', [1511.0, 224.0, 1537.0, 437.0]), ('source', 'auto'), ('type', 'rectangle'), ('z_order', 0)]), + OrderedDict( + [ + ("attributes", []), + ("frame", 0), + ("group", None), + ("id", 11258), + ("label_id", 8), + ("occluded", False), + ("path_id", 0), + ("points", [137.0, 129.0, 457.0, 676.0]), + ("source", "auto"), + ("type", "rectangle"), + ("z_order", 0), + ] + ), + OrderedDict( + [ + ("attributes", []), + ("frame", 0), + ("group", None), + ("id", 11259), + ("label_id", 8), + ("occluded", False), + ("path_id", 1), + ("points", [1511.0, 224.0, 1537.0, 437.0]), + ("source", "auto"), + ("type", "rectangle"), + ("z_order", 0), + ] + ), ], "boxes1": [ - OrderedDict([('attributes', []), ('frame', 1), ('group', None), ('id', 11260), ('label_id', 8), ('occluded', False), ('points', [1076.0, 199.0, 1218.0, 593.0]), ('source', 'auto'), ('type', 'rectangle'), ('z_order', 0)]), - OrderedDict([('attributes', []), ('frame', 1), ('group', 0), ('id', 11398), ('label_id', 8), ('occluded', False), ('points', [184.3935546875, 211.5048828125, 331.64968722073354, 97.27792672028772, 445.87667560321825, 126.17873100983161, 454.13404825737416, 691.8087578194827, 180.26452189455085]), ('source', 'manual'), ('type', 'polygon'), ('z_order', 0)]), + OrderedDict( + [ + ("attributes", []), + ("frame", 1), + ("group", None), + ("id", 11260), + ("label_id", 8), + ("occluded", False), + ("points", [1076.0, 199.0, 1218.0, 593.0]), + ("source", "auto"), + ("type", "rectangle"), + ("z_order", 0), + ] + ), + OrderedDict( + [ + ("attributes", []), + ("frame", 1), + ("group", 0), + ("id", 11398), + ("label_id", 8), + ("occluded", False), + ( + "points", + [ + 184.3935546875, + 211.5048828125, + 331.64968722073354, + 97.27792672028772, + 445.87667560321825, + 126.17873100983161, + 454.13404825737416, + 691.8087578194827, + 180.26452189455085, + ], + ), + ("source", "manual"), + ("type", "polygon"), + ("z_order", 0), + ] + ), ], } - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_response_data}", self.admin, data_main_task) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_response_data}", + self.admin, + data_main_task, + ) self.assertEqual(response.status_code, status.HTTP_200_OK) - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_response_data}", self.user, data_assigneed_to_user_task) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_response_data}", + self.user, + data_assigneed_to_user_task, + ) self.assertEqual(response.status_code, status.HTTP_200_OK) - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_response_data}", None, data_main_task) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_response_data}", None, data_main_task + ) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_no_response_data}", self.admin, data_main_task) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_no_response_data}", + self.admin, + data_main_task, + ) self.assertEqual(response.status_code, status.HTTP_200_OK) - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_no_response_data}", self.user, data_assigneed_to_user_task) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_no_response_data}", + self.user, + data_assigneed_to_user_task, + ) self.assertEqual(response.status_code, status.HTTP_200_OK) - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_no_response_data}", None, data_main_task) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_no_response_data}", + None, + data_main_task, + ) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - def test_api_v2_lambda_functions_create_negative(self): data = { "task": self.main_task["id"], "frame": 0, "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } for id_func in [ - id_function_non_type, id_function_wrong_type, id_function_unknown_type, + id_function_non_type, + id_function_wrong_type, + id_function_unknown_type, id_function_non_unique_labels, ]: with mock.patch( - 'cvat.apps.lambda_manager.views.LambdaGateway._http', - return_value=functions["negative"][id_func] + "cvat.apps.lambda_manager.views.LambdaGateway._http", + return_value=functions["negative"][id_func], ): - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_func}", self.admin, data) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_func}", self.admin, data + ) self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) - def test_api_v2_lambda_functions_convert_mask_to_rle(self): data_main_task = { "function": id_function_detector, "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data_main_task) @@ -839,7 +1043,7 @@ def test_api_v2_lambda_functions_convert_mask_to_rle(self): request_status = "started" while request_status != "finished" and request_status != "failed": - response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.admin) + response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.admin) self.assertEqual(response.status_code, status.HTTP_200_OK) request_status = response.json().get("status") self.assertEqual(request_status, "finished") @@ -854,13 +1058,13 @@ def test_api_v2_lambda_functions_convert_mask_to_rle(self): # [1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0] -> [0, 2, 2, 2, 2, 2, 2] self.assertEqual(masks[0].get("points"), [0, 2, 2, 2, 2, 2, 2, 0, 0, 2, 3]) - def test_api_v2_lambda_functions_create_empty_data(self): data = {} - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data + ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - def test_api_v2_lambda_functions_create_detector_empty_mapping(self): data = { "task": self.main_task["id"], @@ -868,82 +1072,89 @@ def test_api_v2_lambda_functions_create_detector_empty_mapping(self): "cleanup": True, "mapping": {}, } - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data + ) self.assertEqual(response.status_code, status.HTTP_200_OK) - def test_api_v2_lambda_functions_create_detector_without_cleanup(self): data = { "task": self.main_task["id"], "frame": 0, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data + ) self.assertEqual(response.status_code, status.HTTP_200_OK) - def test_api_v2_lambda_functions_create_detector_without_mapping(self): data = { "task": self.main_task["id"], "frame": 0, "cleanup": True, } - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data + ) self.assertEqual(response.status_code, status.HTTP_200_OK) - def test_api_v2_lambda_functions_create_detector_without_task(self): data = { "frame": 0, "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data + ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - def test_api_v2_lambda_functions_create_detector_without_id_frame(self): data = { "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data + ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - def test_api_v2_lambda_functions_create_wrong_id_function(self): data = { "task": self.main_task["id"], "frame": 0, "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/test-functions-wrong-id", self.admin, data) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/test-functions-wrong-id", self.admin, data + ) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - def test_api_v2_lambda_functions_create_wrong_id_task(self): data = { "task": 12345, "frame": 0, "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data + ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - @skip("Fail: expected result != actual result, issue #2770") def test_api_v2_lambda_functions_create_detector_wrong_id_frame(self): data = { @@ -951,13 +1162,14 @@ def test_api_v2_lambda_functions_create_detector_wrong_id_frame(self): "frame": 12345, "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data + ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - @skip("Fail: add mock and expected result != actual result") def test_api_v2_lambda_functions_create_two_functions(self): data = { @@ -965,27 +1177,32 @@ def test_api_v2_lambda_functions_create_two_functions(self): "frame": 0, "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data + ) self.assertEqual(response.status_code, status.HTTP_409_CONFLICT) - def test_api_v2_lambda_functions_create_function_is_not_ready(self): data = { "task": self.main_task["id"], "frame": 0, "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_state_building}", self.admin, data) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_state_building}", self.admin, data + ) self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_state_error}", self.admin, data) + response = self._post_request( + f"{LAMBDA_FUNCTIONS_PATH}/{id_function_state_error}", self.admin, data + ) self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) @@ -1038,29 +1255,27 @@ def setUp(self): self.task = self._create_task( task_spec={ - 'name': 'test_task', - 'labels': [{'name': 'car'}], - 'segment_size': segment_size + "name": "test_task", + "labels": [{"name": "car"}], + "segment_size": segment_size, }, data=data, - owner=self.user + owner=self.user, ) self.task_rel_frame_range = range(len(range(start_frame, stop_frame, frame_step))) self.start_frame = start_frame self.frame_step = frame_step self.segment_size = segment_size - self.labels = get_paginated_collection(lambda page: - self._get_request( - f"/api/labels?task_id={self.task['id']}&page={page}&sort=id", - self.admin + self.labels = get_paginated_collection( + lambda page: self._get_request( + f"/api/labels?task_id={self.task['id']}&page={page}&sort=id", self.admin ) ) - self.jobs = get_paginated_collection(lambda page: - self._get_request( - f"/api/jobs?task_id={self.task['id']}&page={page}", - self.admin + self.jobs = get_paginated_collection( + lambda page: self._get_request( + f"/api/jobs?task_id={self.task['id']}&page={page}", self.admin ) ) @@ -1068,7 +1283,7 @@ def setUp(self): self.reid_function_id = id_function_reid_with_response_data self.common_request_data = { - "task": self.task['id'], + "task": self.task["id"], "cleanup": True, } @@ -1085,14 +1300,14 @@ def _run_offline_function(self, function_id, data, user): def _wait_request(self, request_id: str) -> str: request_status = "started" while request_status != "finished" and request_status != "failed": - response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{request_id}', self.admin) + response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{request_id}", self.admin) self.assertEqual(response.status_code, status.HTTP_200_OK) request_status = response.json().get("status") return request_status def _run_online_function(self, function_id, data, user): - response = self._post_request(f'{LAMBDA_FUNCTIONS_PATH}/{function_id}', user, data) + response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{function_id}", user, data) return response def test_can_run_offline_detector_function_on_whole_task(self): @@ -1108,40 +1323,39 @@ def test_can_run_offline_detector_function_on_whole_task(self): requested_frame_range = self.task_rel_frame_range self.assertEqual( - { - frame: 1 for frame in requested_frame_range - }, + {frame: 1 for frame in requested_frame_range}, { frame: len(list(group)) for frame, group in groupby(annotations["shapes"], key=lambda a: a["frame"]) - } + }, ) def test_can_run_offline_reid_function_on_whole_task(self): # Add starting shapes to be tracked on following frames requested_frame_range = self.task_rel_frame_range shape_template = { - 'attributes': [], - 'group': None, - 'label_id': self.labels[0]["id"], - 'occluded': False, - 'points': [0, 5, 5, 0], - 'source': 'manual', - 'type': 'rectangle', - 'z_order': 0, + "attributes": [], + "group": None, + "label_id": self.labels[0]["id"], + "occluded": False, + "points": [0, 5, 5, 0], + "source": "manual", + "type": "rectangle", + "z_order": 0, } - response = self._put_request(f'/api/tasks/{self.task["id"]}/annotations', self.admin, data={ - 'tags': [], - 'shapes': [ - { 'frame': frame, **shape_template } - for frame in requested_frame_range - ], - 'tracks': [] - }) + response = self._put_request( + f'/api/tasks/{self.task["id"]}/annotations', + self.admin, + data={ + "tags": [], + "shapes": [{"frame": frame, **shape_template} for frame in requested_frame_range], + "tracks": [], + }, + ) self.assertEqual(response.status_code, status.HTTP_200_OK) data = self.common_request_data.copy() - data["cleanup"] = False # cleanup is not compatible with reid + data["cleanup"] = False # cleanup is not compatible with reid self._run_offline_function(self.reid_function_id, data, self.user) response = self._get_request(f'/api/tasks/{self.task["id"]}/annotations', self.admin) @@ -1154,25 +1368,24 @@ def test_can_run_offline_reid_function_on_whole_task(self): [ # The single track will be split by job segments { - 'frame': job["start_frame"], - 'shapes': [ - { 'frame': frame, 'outside': frame > job["stop_frame"] } + "frame": job["start_frame"], + "shapes": [ + {"frame": frame, "outside": frame > job["stop_frame"]} for frame in requested_frame_range if frame in range(job["start_frame"], job["stop_frame"] + self.segment_size) - ] + ], } for job in sorted(self.jobs, key=lambda j: j["start_frame"]) ], [ { - 'frame': track['frame'], - 'shapes': [ - filter_dict(shape, keep=['frame', 'outside']) - for shape in track["shapes"] - ] + "frame": track["frame"], + "shapes": [ + filter_dict(shape, keep=["frame", "outside"]) for shape in track["shapes"] + ], } - for track in annotations['tracks'] - ] + for track in annotations["tracks"] + ], ) def test_can_run_offline_detector_function_on_whole_job(self): @@ -1190,13 +1403,11 @@ def test_can_run_offline_detector_function_on_whole_job(self): requested_frame_range = range(job["start_frame"], job["stop_frame"] + 1) self.assertEqual( - { - frame: 1 for frame in requested_frame_range - }, + {frame: 1 for frame in requested_frame_range}, { frame: len(list(group)) for frame, group in groupby(annotations["shapes"], key=lambda a: a["frame"]) - } + }, ) def test_can_run_offline_reid_function_on_whole_job(self): @@ -1205,27 +1416,28 @@ def test_can_run_offline_reid_function_on_whole_job(self): # Add starting shapes to be tracked on following frames shape_template = { - 'attributes': [], - 'group': None, - 'label_id': self.labels[0]["id"], - 'occluded': False, - 'points': [0, 5, 5, 0], - 'source': 'manual', - 'type': 'rectangle', - 'z_order': 0, + "attributes": [], + "group": None, + "label_id": self.labels[0]["id"], + "occluded": False, + "points": [0, 5, 5, 0], + "source": "manual", + "type": "rectangle", + "z_order": 0, } - response = self._put_request(f'/api/jobs/{job["id"]}/annotations', self.admin, data={ - 'tags': [], - 'shapes': [ - { 'frame': frame, **shape_template } - for frame in requested_frame_range - ], - 'tracks': [] - }) + response = self._put_request( + f'/api/jobs/{job["id"]}/annotations', + self.admin, + data={ + "tags": [], + "shapes": [{"frame": frame, **shape_template} for frame in requested_frame_range], + "tracks": [], + }, + ) self.assertEqual(response.status_code, status.HTTP_200_OK) data = self.common_request_data.copy() - data["cleanup"] = False # cleanup is not compatible with reid + data["cleanup"] = False # cleanup is not compatible with reid data["job"] = job["id"] self._run_offline_function(self.reid_function_id, data, self.user) @@ -1238,34 +1450,37 @@ def test_can_run_offline_reid_function_on_whole_job(self): self.assertEqual( [ { - 'frame': job["start_frame"], - 'shapes': [ - { 'frame': frame, 'outside': frame > job["stop_frame"] } + "frame": job["start_frame"], + "shapes": [ + {"frame": frame, "outside": frame > job["stop_frame"]} for frame in requested_frame_range if frame in range(job["start_frame"], job["stop_frame"] + self.segment_size) - ] + ], } ], [ { - 'frame': track['frame'], - 'shapes': [ - filter_dict(shape, keep=['frame', 'outside']) - for shape in track["shapes"] - ] + "frame": track["frame"], + "shapes": [ + filter_dict(shape, keep=["frame", "outside"]) for shape in track["shapes"] + ], } - for track in annotations['tracks'] - ] + for track in annotations["tracks"] + ], ) def test_can_run_offline_detector_function_on_whole_gt_job(self): requested_frame_range = self.task_rel_frame_range[::3] - response = self._post_request("/api/jobs", self.admin, data={ - "type": "ground_truth", - "task_id": self.task["id"], - "frame_selection_method": "manual", - "frames": list(requested_frame_range), - }) + response = self._post_request( + "/api/jobs", + self.admin, + data={ + "type": "ground_truth", + "task_id": self.task["id"], + "frame_selection_method": "manual", + "frames": list(requested_frame_range), + }, + ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) job = response.json() @@ -1281,49 +1496,54 @@ def test_can_run_offline_detector_function_on_whole_gt_job(self): self.assertEqual(len(annotations["tracks"]), 0) self.assertEqual( - { frame: 1 for frame in requested_frame_range }, - Counter(a["frame"] for a in annotations["shapes"]) + {frame: 1 for frame in requested_frame_range}, + Counter(a["frame"] for a in annotations["shapes"]), ) response = self._get_request(f'/api/tasks/{self.task["id"]}/annotations', self.admin) self.assertEqual(response.status_code, status.HTTP_200_OK) annotations = response.json() - self.assertEqual(annotations, {'version': 0, 'tags': [], 'shapes': [], 'tracks': []}) + self.assertEqual(annotations, {"version": 0, "tags": [], "shapes": [], "tracks": []}) def test_can_run_offline_reid_function_on_whole_gt_job(self): requested_frame_range = self.task_rel_frame_range[::3] - response = self._post_request("/api/jobs", self.admin, data={ - "type": "ground_truth", - "task_id": self.task["id"], - "frame_selection_method": "manual", - "frames": list(requested_frame_range), - }) + response = self._post_request( + "/api/jobs", + self.admin, + data={ + "type": "ground_truth", + "task_id": self.task["id"], + "frame_selection_method": "manual", + "frames": list(requested_frame_range), + }, + ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) job = response.json() # Add starting shapes to be tracked on following frames shape_template = { - 'attributes': [], - 'group': None, - 'label_id': self.labels[0]["id"], - 'occluded': False, - 'points': [0, 5, 5, 0], - 'source': 'manual', - 'type': 'rectangle', - 'z_order': 0, + "attributes": [], + "group": None, + "label_id": self.labels[0]["id"], + "occluded": False, + "points": [0, 5, 5, 0], + "source": "manual", + "type": "rectangle", + "z_order": 0, } - response = self._put_request(f'/api/jobs/{job["id"]}/annotations', self.admin, data={ - 'tags': [], - 'shapes': [ - { 'frame': frame, **shape_template } - for frame in requested_frame_range - ], - 'tracks': [] - }) + response = self._put_request( + f'/api/jobs/{job["id"]}/annotations', + self.admin, + data={ + "tags": [], + "shapes": [{"frame": frame, **shape_template} for frame in requested_frame_range], + "tracks": [], + }, + ) self.assertEqual(response.status_code, status.HTTP_200_OK) data = self.common_request_data.copy() - data["cleanup"] = False # cleanup is not compatible with reid + data["cleanup"] = False # cleanup is not compatible with reid data["job"] = job["id"] self._run_offline_function(self.reid_function_id, data, self.user) @@ -1336,38 +1556,41 @@ def test_can_run_offline_reid_function_on_whole_gt_job(self): self.assertEqual( [ { - 'frame': job["start_frame"], - 'shapes': [ - { 'frame': frame, 'outside': frame > job["stop_frame"] } + "frame": job["start_frame"], + "shapes": [ + {"frame": frame, "outside": frame > job["stop_frame"]} for frame in requested_frame_range if frame in range(job["start_frame"], job["stop_frame"] + self.segment_size) - ] + ], } ], [ { - 'frame': track['frame'], - 'shapes': [ - filter_dict(shape, keep=['frame', 'outside']) - for shape in track["shapes"] - ] + "frame": track["frame"], + "shapes": [ + filter_dict(shape, keep=["frame", "outside"]) for shape in track["shapes"] + ], } - for track in annotations['tracks'] - ] + for track in annotations["tracks"] + ], ) response = self._get_request(f'/api/tasks/{self.task["id"]}/annotations', self.admin) self.assertEqual(response.status_code, status.HTTP_200_OK) annotations = response.json() - self.assertEqual(annotations, {'version': 0, 'tags': [], 'shapes': [], 'tracks': []}) + self.assertEqual(annotations, {"version": 0, "tags": [], "shapes": [], "tracks": []}) def test_offline_function_run_on_task_does_not_affect_gt_job(self): - response = self._post_request("/api/jobs", self.admin, data={ - "type": "ground_truth", - "task_id": self.task["id"], - "frame_selection_method": "manual", - "frames": list(self.task_rel_frame_range[::3]), - }) + response = self._post_request( + "/api/jobs", + self.admin, + data={ + "type": "ground_truth", + "task_id": self.task["id"], + "frame_selection_method": "manual", + "frames": list(self.task_rel_frame_range[::3]), + }, + ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) job = response.json() @@ -1383,14 +1606,14 @@ def test_offline_function_run_on_task_does_not_affect_gt_job(self): requested_frame_range = self.task_rel_frame_range self.assertEqual( - { frame: 1 for frame in requested_frame_range }, - Counter(a["frame"] for a in annotations["shapes"]) + {frame: 1 for frame in requested_frame_range}, + Counter(a["frame"] for a in annotations["shapes"]), ) response = self._get_request(f'/api/jobs/{job["id"]}/annotations', self.admin) self.assertEqual(response.status_code, status.HTTP_200_OK) annotations = response.json() - self.assertEqual(annotations, {'version': 0, 'tags': [], 'shapes': [], 'tracks': []}) + self.assertEqual(annotations, {"version": 0, "tags": [], "shapes": [], "tracks": []}) def test_can_run_online_function_on_valid_task_frame(self): data = self.common_request_data.copy() @@ -1441,70 +1664,87 @@ class Issue4996_Cases(_LambdaTestCaseBase): # This requires to pass the job id in the call request. def _create_org(self, *, owner: int, members: dict[int, str] = None) -> dict: - org = self._post_request('/api/organizations', user=owner, data={ - "slug": "testorg", - "name": "test Org", - }) + org = self._post_request( + "/api/organizations", + user=owner, + data={ + "slug": "testorg", + "name": "test Org", + }, + ) assert org.status_code == status.HTTP_201_CREATED org = org.json() for uid, role in members.items(): - user = self._get_request('/api/users/self', user=uid) + user = self._get_request("/api/users/self", user=uid) assert user.status_code == status.HTTP_200_OK user = user.json() - invitation = self._post_request('/api/invitations', user=owner, data={ - 'email': user['email'], - 'role': role, - }, org_id=org['id']) + invitation = self._post_request( + "/api/invitations", + user=owner, + data={ + "email": user["email"], + "role": role, + }, + org_id=org["id"], + ) assert invitation.status_code == status.HTTP_201_CREATED return org - def _set_task_assignee(self, task: int, assignee: Optional[int], *, - org_id: Optional[int] = None): - response = self._patch_request(f'/api/tasks/{task}', user=self.admin, data={ - 'assignee_id': assignee, - }, org_id=org_id) + def _set_task_assignee( + self, task: int, assignee: Optional[int], *, org_id: Optional[int] = None + ): + response = self._patch_request( + f"/api/tasks/{task}", + user=self.admin, + data={ + "assignee_id": assignee, + }, + org_id=org_id, + ) assert response.status_code == status.HTTP_200_OK - def _set_job_assignee(self, job: int, assignee: Optional[int], *, - org_id: Optional[int] = None): - response = self._patch_request(f'/api/jobs/{job}', user=self.admin, data={ - 'assignee': assignee, - }, org_id=org_id) + def _set_job_assignee(self, job: int, assignee: Optional[int], *, org_id: Optional[int] = None): + response = self._patch_request( + f"/api/jobs/{job}", + user=self.admin, + data={ + "assignee": assignee, + }, + org_id=org_id, + ) assert response.status_code == status.HTTP_200_OK def setUp(self): super().setUp() - self.org = self._create_org(owner=self.admin, members={self.user: 'worker'}) + self.org = self._create_org(owner=self.admin, members={self.user: "worker"}) - task = self._create_task(task_spec={ - 'name': 'test_task', - 'labels': [{'name': 'car'}], - 'segment_size': 2 - }, + task = self._create_task( + task_spec={"name": "test_task", "labels": [{"name": "car"}], "segment_size": 2}, data=self._generate_task_images(6), owner=self.admin, - org_id=self.org['id'], + org_id=self.org["id"], ) self.task = task - jobs = get_paginated_collection(lambda page: - self._get_request( + jobs = get_paginated_collection( + lambda page: self._get_request( f"/api/jobs?task_id={self.task['id']}&page={page}", - self.admin, org_id=self.org['id'] + self.admin, + org_id=self.org["id"], ) ) self.job = jobs[1] self.common_request_data = { - "task": self.task['id'], + "task": self.task["id"], "frame": 0, "cleanup": True, "mapping": { - "car": { "name": "car" }, + "car": {"name": "car"}, }, } @@ -1512,75 +1752,70 @@ def setUp(self): def _get_valid_job_request_data(self): data = self.common_request_data.copy() - data.update({ - "job": self.job['id'], - "frame": 2 - }) + data.update({"job": self.job["id"], "frame": 2}) return data def _get_invalid_job_request_data(self): data = self.common_request_data.copy() - data.update({ - "job": self.job['id'], - "frame": 0 - }) + data.update({"job": self.job["id"], "frame": 0}) return data - def test_can_call_function_for_job_worker_in_org__deny_unassigned_worker_with_task_request(self): + def test_can_call_function_for_job_worker_in_org__deny_unassigned_worker_with_task_request( + self, + ): data = self.common_request_data.copy() with self.subTest(job=None, assignee=None): - response = self._post_request(self.function_url, self.user, data, - org_id=self.org['id']) + response = self._post_request(self.function_url, self.user, data, org_id=self.org["id"]) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_can_call_function_for_job_worker_in_org__deny_unassigned_worker_with_job_request(self): data = self._get_valid_job_request_data() - with self.subTest(job='defined', assignee=None): - response = self._post_request(self.function_url, self.user, data, - org_id=self.org['id']) + with self.subTest(job="defined", assignee=None): + response = self._post_request(self.function_url, self.user, data, org_id=self.org["id"]) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - def test_can_call_function_for_job_worker_in_org__allow_task_assigned_worker_with_task_request(self): - self._set_task_assignee(self.task['id'], self.user.id, org_id=self.org['id']) + def test_can_call_function_for_job_worker_in_org__allow_task_assigned_worker_with_task_request( + self, + ): + self._set_task_assignee(self.task["id"], self.user.id, org_id=self.org["id"]) data = self.common_request_data.copy() - with self.subTest(job=None, assignee='task'): - response = self._post_request(self.function_url, self.user, data, - org_id=self.org['id']) + with self.subTest(job=None, assignee="task"): + response = self._post_request(self.function_url, self.user, data, org_id=self.org["id"]) self.assertEqual(response.status_code, status.HTTP_200_OK) - def test_can_call_function_for_job_worker_in_org__deny_job_assigned_worker_with_task_request(self): - self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id']) + def test_can_call_function_for_job_worker_in_org__deny_job_assigned_worker_with_task_request( + self, + ): + self._set_job_assignee(self.job["id"], self.user.id, org_id=self.org["id"]) data = self.common_request_data.copy() - with self.subTest(job=None, assignee='job'): - response = self._post_request(self.function_url, self.user, data, - org_id=self.org['id']) + with self.subTest(job=None, assignee="job"): + response = self._post_request(self.function_url, self.user, data, org_id=self.org["id"]) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - def test_can_call_function_for_job_worker_in_org__allow_job_assigned_worker_with_job_request(self): - self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id']) + def test_can_call_function_for_job_worker_in_org__allow_job_assigned_worker_with_job_request( + self, + ): + self._set_job_assignee(self.job["id"], self.user.id, org_id=self.org["id"]) data = self._get_valid_job_request_data() - with self.subTest(job='defined', assignee='job'): - response = self._post_request(self.function_url, self.user, data, - org_id=self.org['id']) + with self.subTest(job="defined", assignee="job"): + response = self._post_request(self.function_url, self.user, data, org_id=self.org["id"]) self.assertEqual(response.status_code, status.HTTP_200_OK) def test_can_check_job_boundaries_in_function_call__fail_for_frame_outside_job(self): - self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id']) + self._set_job_assignee(self.job["id"], self.user.id, org_id=self.org["id"]) data = self._get_invalid_job_request_data() - with self.subTest(job='defined', frame='outside'): - response = self._post_request(self.function_url, self.user, data, - org_id=self.org['id']) + with self.subTest(job="defined", frame="outside"): + response = self._post_request(self.function_url, self.user, data, org_id=self.org["id"]) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) def test_can_check_job_boundaries_in_function_call__ok_for_frame_inside_job(self): - self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id']) + self._set_job_assignee(self.job["id"], self.user.id, org_id=self.org["id"]) data = self._get_valid_job_request_data() - with self.subTest(job='defined', frame='inside'): - response = self._post_request(self.function_url, self.user, data, - org_id=self.org['id']) + with self.subTest(job="defined", frame="inside"): + response = self._post_request(self.function_url, self.user, data, org_id=self.org["id"]) self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/cvat/apps/lambda_manager/urls.py b/cvat/apps/lambda_manager/urls.py index 6dae0edaca76..261592a9f469 100644 --- a/cvat/apps/lambda_manager/urls.py +++ b/cvat/apps/lambda_manager/urls.py @@ -12,9 +12,9 @@ # I want to "call" my functions. To do that need to map my call method to # POST (like get HTTP method is mapped to list(...)). One way is to implement # own CustomRouter. But it is simpler just patch the router instance here. -router.routes[2].mapping.update({'post': 'call'}) -router.register('functions', views.FunctionViewSet, basename='lambda_function') -router.register('requests', views.RequestViewSet, basename='lambda_request') +router.routes[2].mapping.update({"post": "call"}) +router.register("functions", views.FunctionViewSet, basename="lambda_function") +router.register("requests", views.RequestViewSet, basename="lambda_request") # GET /api/lambda/functions - get list of functions # GET /api/lambda/functions/ - get information about the function @@ -24,6 +24,4 @@ # GET /api/lambda/requests - get list of requests # GET /api/lambda/requests/ - get status of the request # DEL /api/lambda/requests/ - cancel a request (don't delete) -urlpatterns = [ - path('api/lambda/', include(router.urls)) -] +urlpatterns = [path("api/lambda/", include(router.urls))] diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index 559ef29813b5..465414e243a5 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -19,53 +19,75 @@ import numpy as np import requests import rq -from cvat.apps.events.handlers import handle_function_call -from cvat.apps.lambda_manager.signals import interactive_function_call_signal from django.conf import settings from django.core.exceptions import ObjectDoesNotExist, ValidationError from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import (OpenApiParameter, OpenApiResponse, - extend_schema, extend_schema_view, - inline_serializer) +from drf_spectacular.utils import ( + OpenApiParameter, + OpenApiResponse, + extend_schema, + extend_schema_view, + inline_serializer, +) from rest_framework import serializers, status, viewsets -from rest_framework.response import Response from rest_framework.request import Request +from rest_framework.response import Response import cvat.apps.dataset_manager as dm from cvat.apps.engine.frame_provider import TaskFrameProvider +from cvat.apps.engine.log import ServerLogManager from cvat.apps.engine.models import ( - Job, ShapeType, SourceType, Task, Label, RequestAction, RequestTarget + Job, + Label, + RequestAction, + RequestTarget, + ShapeType, + SourceType, + Task, ) from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField from cvat.apps.engine.serializers import LabeledDataSerializer +from cvat.apps.engine.utils import define_dependent_job, get_rq_job_meta, get_rq_lock_by_user +from cvat.apps.events.handlers import handle_function_call +from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS from cvat.apps.lambda_manager.models import FunctionKind from cvat.apps.lambda_manager.permissions import LambdaPermission from cvat.apps.lambda_manager.serializers import ( - FunctionCallRequestSerializer, FunctionCallSerializer + FunctionCallRequestSerializer, + FunctionCallSerializer, ) -from cvat.apps.engine.log import ServerLogManager -from cvat.apps.engine.utils import define_dependent_job, get_rq_job_meta, get_rq_lock_by_user +from cvat.apps.lambda_manager.signals import interactive_function_call_signal from cvat.utils.http import make_requests_session -from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS slogger = ServerLogManager(__name__) + class LambdaGateway: - NUCLIO_ROOT_URL = '/api/functions' - - def _http(self, method="get", scheme=None, host=None, port=None, - function_namespace=None, url=None, headers=None, data=None): - NUCLIO_GATEWAY = '{}://{}:{}'.format( - scheme or settings.NUCLIO['SCHEME'], - host or settings.NUCLIO['HOST'], - port or settings.NUCLIO['PORT']) - NUCLIO_FUNCTION_NAMESPACE = function_namespace or settings.NUCLIO['FUNCTION_NAMESPACE'] - NUCLIO_TIMEOUT = settings.NUCLIO['DEFAULT_TIMEOUT'] + NUCLIO_ROOT_URL = "/api/functions" + + def _http( + self, + method="get", + scheme=None, + host=None, + port=None, + function_namespace=None, + url=None, + headers=None, + data=None, + ): + NUCLIO_GATEWAY = "{}://{}:{}".format( + scheme or settings.NUCLIO["SCHEME"], + host or settings.NUCLIO["HOST"], + port or settings.NUCLIO["PORT"], + ) + NUCLIO_FUNCTION_NAMESPACE = function_namespace or settings.NUCLIO["FUNCTION_NAMESPACE"] + NUCLIO_TIMEOUT = settings.NUCLIO["DEFAULT_TIMEOUT"] extra_headers = { - 'x-nuclio-project-name': 'cvat', - 'x-nuclio-function-namespace': NUCLIO_FUNCTION_NAMESPACE, - 'x-nuclio-invoke-via': 'domain-name', - 'X-Nuclio-Invoke-Timeout': f"{NUCLIO_TIMEOUT}s", + "x-nuclio-project-name": "cvat", + "x-nuclio-function-namespace": NUCLIO_FUNCTION_NAMESPACE, + "x-nuclio-invoke-via": "domain-name", + "X-Nuclio-Invoke-Timeout": f"{NUCLIO_TIMEOUT}s", } if headers: extra_headers.update(headers) @@ -76,8 +98,9 @@ def _http(self, method="get", scheme=None, host=None, port=None, url = NUCLIO_GATEWAY with make_requests_session() as session: - reply = session.request(method, url, headers=extra_headers, - timeout=NUCLIO_TIMEOUT, json=data) + reply = session.request( + method, url, headers=extra_headers, timeout=NUCLIO_TIMEOUT, json=data + ) reply.raise_for_status() response = reply.json() @@ -92,32 +115,33 @@ def list(self): slogger.glob.error("Failed to parse lambda function metadata", exc_info=True) def get(self, func_id): - data = self._http(url=self.NUCLIO_ROOT_URL + '/' + func_id) + data = self._http(url=self.NUCLIO_ROOT_URL + "/" + func_id) response = LambdaFunction(self, data) return response def invoke(self, func, payload): invoke_method = { - 'dashboard': self._invoke_via_dashboard, - 'direct': self._invoke_directly, + "dashboard": self._invoke_via_dashboard, + "direct": self._invoke_directly, } - return invoke_method[settings.NUCLIO['INVOKE_METHOD']](func, payload) + return invoke_method[settings.NUCLIO["INVOKE_METHOD"]](func, payload) def _invoke_via_dashboard(self, func, payload): - return self._http(method="post", url='/api/function_invocations', - data=payload, headers={ - 'x-nuclio-function-name': func.id, - 'x-nuclio-path': '/' - }) + return self._http( + method="post", + url="/api/function_invocations", + data=payload, + headers={"x-nuclio-function-name": func.id, "x-nuclio-path": "/"}, + ) def _invoke_directly(self, func, payload): # host.docker.internal for Linux will work only with Docker 20.10+ - NUCLIO_TIMEOUT = settings.NUCLIO['DEFAULT_TIMEOUT'] - if os.path.exists('/.dockerenv'): # inside a docker container - url = f'http://host.docker.internal:{func.port}' + NUCLIO_TIMEOUT = settings.NUCLIO["DEFAULT_TIMEOUT"] + if os.path.exists("/.dockerenv"): # inside a docker container + url = f"http://host.docker.internal:{func.port}" else: - url = f'http://localhost:{func.port}' + url = f"http://localhost:{func.port}" with make_requests_session() as session: reply = session.post(url, timeout=NUCLIO_TIMEOUT, json=payload) @@ -126,105 +150,119 @@ def _invoke_directly(self, func, payload): return response + class InvalidFunctionMetadataError(Exception): pass + class LambdaFunction: FRAME_PARAMETERS = ( - ('frame', 'frame'), - ('frame0', 'start frame'), - ('frame1', 'end frame'), + ("frame", "frame"), + ("frame0", "start frame"), + ("frame1", "end frame"), ) def __init__(self, gateway, data): # ID of the function (e.g. omz.public.yolo-v3) - self.id = data['metadata']['name'] + self.id = data["metadata"]["name"] # type of the function (e.g. detector, interactor) - meta_anno = data['metadata']['annotations'] - kind = meta_anno.get('type') + meta_anno = data["metadata"]["annotations"] + kind = meta_anno.get("type") try: self.kind = FunctionKind(kind) except ValueError as e: raise InvalidFunctionMetadataError( - f"{self.id} lambda function has unknown type: {kind!r}") from e + f"{self.id} lambda function has unknown type: {kind!r}" + ) from e # dictionary of labels for the function (e.g. car, person) - spec = json.loads(meta_anno.get('spec') or '[]') + spec = json.loads(meta_anno.get("spec") or "[]") def parse_labels(spec): def parse_attributes(attrs_spec): - parsed_attributes = [{ - 'name': attr['name'], - 'input_type': attr['input_type'], - 'values': attr['values'], - } for attr in attrs_spec] - - if len(parsed_attributes) != len({attr['name'] for attr in attrs_spec}): + parsed_attributes = [ + { + "name": attr["name"], + "input_type": attr["input_type"], + "values": attr["values"], + } + for attr in attrs_spec + ] + + if len(parsed_attributes) != len({attr["name"] for attr in attrs_spec}): raise InvalidFunctionMetadataError( - f"{self.id} lambda function has non-unique attributes") + f"{self.id} lambda function has non-unique attributes" + ) return parsed_attributes parsed_labels = [] for label in spec: parsed_label = { - 'name': label['name'], - 'type': label.get('type', 'unknown'), - 'attributes': parse_attributes(label.get('attributes', [])) + "name": label["name"], + "type": label.get("type", "unknown"), + "attributes": parse_attributes(label.get("attributes", [])), } - if parsed_label['type'] == 'skeleton': - parsed_label.update({ - 'sublabels': parse_labels(label['sublabels']), - 'svg': label['svg'] - }) + if parsed_label["type"] == "skeleton": + parsed_label.update( + {"sublabels": parse_labels(label["sublabels"]), "svg": label["svg"]} + ) parsed_labels.append(parsed_label) - if len(parsed_labels) != len({label['name'] for label in spec}): + if len(parsed_labels) != len({label["name"] for label in spec}): raise InvalidFunctionMetadataError( - f"{self.id} lambda function has non-unique labels") + f"{self.id} lambda function has non-unique labels" + ) return parsed_labels self.labels = parse_labels(spec) # mapping of labels and corresponding supported attributes - self.func_attributes = {item['name']: item.get('attributes', []) for item in spec} + self.func_attributes = {item["name"]: item.get("attributes", []) for item in spec} for label, attributes in self.func_attributes.items(): - if len([attr['name'] for attr in attributes]) != len(set([attr['name'] for attr in attributes])): + if len([attr["name"] for attr in attributes]) != len( + set([attr["name"] for attr in attributes]) + ): raise InvalidFunctionMetadataError( - "`{}` lambda function has non-unique attributes for label {}".format(self.id, label)) + "`{}` lambda function has non-unique attributes for label {}".format( + self.id, label + ) + ) # description of the function - self.description = data['spec']['description'] + self.description = data["spec"]["description"] # http port to access the serverless function self.port = data["status"].get("httpPort") # display name for the function - self.name = meta_anno.get('name', self.id) - self.min_pos_points = int(meta_anno.get('min_pos_points', 1)) - self.min_neg_points = int(meta_anno.get('min_neg_points', -1)) - self.startswith_box = bool(meta_anno.get('startswith_box', False)) - self.startswith_box_optional = bool(meta_anno.get('startswith_box_optional', False)) - self.animated_gif = meta_anno.get('animated_gif', '') - self.version = int(meta_anno.get('version', '1')) - self.help_message = meta_anno.get('help_message', '') + self.name = meta_anno.get("name", self.id) + self.min_pos_points = int(meta_anno.get("min_pos_points", 1)) + self.min_neg_points = int(meta_anno.get("min_neg_points", -1)) + self.startswith_box = bool(meta_anno.get("startswith_box", False)) + self.startswith_box_optional = bool(meta_anno.get("startswith_box_optional", False)) + self.animated_gif = meta_anno.get("animated_gif", "") + self.version = int(meta_anno.get("version", "1")) + self.help_message = meta_anno.get("help_message", "") self.gateway = gateway def to_dict(self): response = { - 'id': self.id, - 'kind': str(self.kind), - 'labels_v2': self.labels, - 'description': self.description, - 'name': self.name, - 'version': self.version + "id": self.id, + "kind": str(self.kind), + "labels_v2": self.labels, + "description": self.description, + "name": self.name, + "version": self.version, } if self.kind is FunctionKind.INTERACTOR: - response.update({ - 'min_pos_points': self.min_pos_points, - 'min_neg_points': self.min_neg_points, - 'startswith_box': self.startswith_box, - 'startswith_box_optional': self.startswith_box_optional, - 'help_message': self.help_message, - 'animated_gif': self.animated_gif - }) + response.update( + { + "min_pos_points": self.min_pos_points, + "min_neg_points": self.min_neg_points, + "startswith_box": self.startswith_box, + "startswith_box_optional": self.startswith_box_optional, + "help_message": self.help_message, + "animated_gif": self.animated_gif, + } + ) return response @@ -235,62 +273,75 @@ def invoke( *, db_job: Optional[Job] = None, is_interactive: Optional[bool] = False, - request: Optional[Request] = None + request: Optional[Request] = None, ): if db_job is not None and db_job.get_task_id() != db_task.id: - raise ValidationError("Job task id does not match task id", - code=status.HTTP_400_BAD_REQUEST + raise ValidationError( + "Job task id does not match task id", code=status.HTTP_400_BAD_REQUEST ) payload = {} - data = {k: v for k,v in data.items() if v is not None} + data = {k: v for k, v in data.items() if v is not None} def mandatory_arg(name: str) -> Any: try: return data[name] except KeyError: raise ValidationError( - "`{}` lambda function was called without mandatory argument: {}" - .format(self.id, name), - code=status.HTTP_400_BAD_REQUEST) + "`{}` lambda function was called without mandatory argument: {}".format( + self.id, name + ), + code=status.HTTP_400_BAD_REQUEST, + ) threshold = data.get("threshold") if threshold: - payload.update({ "threshold": threshold }) + payload.update({"threshold": threshold}) mapping = data.get("mapping", {}) model_labels = self.labels task_labels = db_task.get_labels(prefetch=True) def labels_compatible(model_label: dict, task_label: Label) -> bool: - model_type = model_label['type'] + model_type = model_label["type"] db_type = task_label.type compatible_types = [[ShapeType.MASK, ShapeType.POLYGON]] - return model_type == db_type or \ - (db_type == 'any' and model_type != 'skeleton') or \ - (model_type == 'unknown' and db_type != 'skeleton') or \ - any([model_type in compatible and db_type in compatible for compatible in compatible_types]) + return ( + model_type == db_type + or (db_type == "any" and model_type != "skeleton") + or (model_type == "unknown" and db_type != "skeleton") + or any( + [ + model_type in compatible and db_type in compatible + for compatible in compatible_types + ] + ) + ) def make_default_mapping(model_labels, task_labels): mapping_by_default = {} for model_label in model_labels: for task_label in task_labels: - if task_label.name == model_label['name'] and labels_compatible(model_label, task_label): + if task_label.name == model_label["name"] and labels_compatible( + model_label, task_label + ): attributes_default_mapping = {} - for model_attr in model_label.get('attributes', {}): + for model_attr in model_label.get("attributes", {}): for db_attr in task_label.attributespec_set.all(): - if db_attr.name == model_attr['name']: - attributes_default_mapping[model_attr['name']] = db_attr.name + if db_attr.name == model_attr["name"]: + attributes_default_mapping[model_attr["name"]] = db_attr.name - mapping_by_default[model_label['name']] = { - 'name': task_label.name, - 'attributes': attributes_default_mapping, + mapping_by_default[model_label["name"]] = { + "name": task_label.name, + "attributes": attributes_default_mapping, } - if model_label['type'] == 'skeleton' and task_label.type == 'skeleton': - mapping_by_default[model_label['name']]['sublabels'] = make_default_mapping( - model_label['sublabels'], - task_label.sublabels.all(), + if model_label["type"] == "skeleton" and task_label.type == "skeleton": + mapping_by_default[model_label["name"]]["sublabels"] = ( + make_default_mapping( + model_label["sublabels"], + task_label.sublabels.all(), + ) ) return mapping_by_default @@ -298,39 +349,43 @@ def make_default_mapping(model_labels, task_labels): def update_mapping(_mapping, _model_labels, _db_labels): copy = deepcopy(_mapping) for model_label_name, mapping_item in copy.items(): - md_label = next(filter(lambda x: x['name'] == model_label_name, _model_labels)) - db_label = next(filter(lambda x: x.name == mapping_item['name'], _db_labels)) - mapping_item.setdefault('attributes', {}) - mapping_item['md_label'] = md_label - mapping_item['db_label'] = db_label - if md_label['type'] == 'skeleton' and db_label.type == 'skeleton': - mapping_item['sublabels'] = update_mapping( - mapping_item['sublabels'], - md_label['sublabels'], - db_label.sublabels.all() + md_label = next(filter(lambda x: x["name"] == model_label_name, _model_labels)) + db_label = next(filter(lambda x: x.name == mapping_item["name"], _db_labels)) + mapping_item.setdefault("attributes", {}) + mapping_item["md_label"] = md_label + mapping_item["db_label"] = db_label + if md_label["type"] == "skeleton" and db_label.type == "skeleton": + mapping_item["sublabels"] = update_mapping( + mapping_item["sublabels"], md_label["sublabels"], db_label.sublabels.all() ) return copy def validate_labels_mapping(_mapping, _model_labels, _db_labels): def validate_attributes_mapping(attributes_mapping, model_attributes, db_attributes): db_attr_names = [attr.name for attr in db_attributes] - model_attr_names = [attr['name'] for attr in model_attributes] + model_attr_names = [attr["name"] for attr in model_attributes] for model_attr in attributes_mapping: task_attr = attributes_mapping[model_attr] if model_attr not in model_attr_names: - raise ValidationError(f'Invalid mapping. Unknown model attribute "{model_attr}"') + raise ValidationError( + f'Invalid mapping. Unknown model attribute "{model_attr}"' + ) if task_attr not in db_attr_names: - raise ValidationError(f'Invalid mapping. Unknown db attribute "{task_attr}"') + raise ValidationError( + f'Invalid mapping. Unknown db attribute "{task_attr}"' + ) for model_label_name, mapping_item in _mapping.items(): - db_label_name = mapping_item['name'] + db_label_name = mapping_item["name"] md_label = None db_label = None try: - md_label = next(x for x in _model_labels if x['name'] == model_label_name) + md_label = next(x for x in _model_labels if x["name"] == model_label_name) except StopIteration: - raise ValidationError(f'Invalid mapping. Unknown model label "{model_label_name}"') + raise ValidationError( + f'Invalid mapping. Unknown model label "{model_label_name}"' + ) try: db_label = next(x for x in _db_labels if x.name == db_label_name) @@ -339,26 +394,24 @@ def validate_attributes_mapping(attributes_mapping, model_attributes, db_attribu if not labels_compatible(md_label, db_label): raise ValidationError( - f'Invalid mapping. Model label "{model_label_name}" and' + \ - f' database label "{db_label_name}" are not compatible' + f'Invalid mapping. Model label "{model_label_name}" and' + + f' database label "{db_label_name}" are not compatible' ) validate_attributes_mapping( - mapping_item.get('attributes', {}), - md_label['attributes'], - db_label.attributespec_set.all() + mapping_item.get("attributes", {}), + md_label["attributes"], + db_label.attributespec_set.all(), ) - if md_label['type'] == 'skeleton' and db_label.type == 'skeleton': - if 'sublabels' not in mapping_item: + if md_label["type"] == "skeleton" and db_label.type == "skeleton": + if "sublabels" not in mapping_item: raise ValidationError( f'Mapping for elements was not specified in skeleton "{model_label_name}" ' ) validate_labels_mapping( - mapping_item['sublabels'], - md_label['sublabels'], - db_label.sublabels.all() + mapping_item["sublabels"], md_label["sublabels"], db_label.sublabels.all() ) if not mapping: @@ -380,44 +433,46 @@ def validate_attributes_mapping(attributes_mapping, model_attributes, db_attribu abs_frame_id = data_start_frame + data[key] * step if not db_job.segment.contains_frame(abs_frame_id): - raise ValidationError(f"The {desc} is outside the job range", - code=status.HTTP_400_BAD_REQUEST) - + raise ValidationError( + f"The {desc} is outside the job range", code=status.HTTP_400_BAD_REQUEST + ) if self.kind == FunctionKind.DETECTOR: - payload.update({ - "image": self._get_image(db_task, mandatory_arg("frame")) - }) + payload.update({"image": self._get_image(db_task, mandatory_arg("frame"))}) elif self.kind == FunctionKind.INTERACTOR: - payload.update({ - "image": self._get_image(db_task, mandatory_arg("frame")), - "pos_points": mandatory_arg("pos_points"), - "neg_points": mandatory_arg("neg_points"), - "obj_bbox": data.get("obj_bbox", None) - }) + payload.update( + { + "image": self._get_image(db_task, mandatory_arg("frame")), + "pos_points": mandatory_arg("pos_points"), + "neg_points": mandatory_arg("neg_points"), + "obj_bbox": data.get("obj_bbox", None), + } + ) elif self.kind == FunctionKind.REID: - payload.update({ - "image0": self._get_image(db_task, mandatory_arg("frame0")), - "image1": self._get_image(db_task, mandatory_arg("frame1")), - "boxes0": mandatory_arg("boxes0"), - "boxes1": mandatory_arg("boxes1") - }) + payload.update( + { + "image0": self._get_image(db_task, mandatory_arg("frame0")), + "image1": self._get_image(db_task, mandatory_arg("frame1")), + "boxes0": mandatory_arg("boxes0"), + "boxes1": mandatory_arg("boxes1"), + } + ) max_distance = data.get("max_distance") if max_distance: - payload.update({ - "max_distance": max_distance - }) + payload.update({"max_distance": max_distance}) elif self.kind == FunctionKind.TRACKER: - payload.update({ - "image": self._get_image(db_task, mandatory_arg("frame")), - "shapes": data.get("shapes", []), - "states": data.get("states", []) - }) + payload.update( + { + "image": self._get_image(db_task, mandatory_arg("frame")), + "shapes": data.get("shapes", []), + "states": data.get("states", []), + } + ) else: raise ValidationError( - '`{}` lambda function has incorrect type: {}' - .format(self.id, self.kind), - code=status.HTTP_500_INTERNAL_SERVER_ERROR) + "`{}` lambda function has incorrect type: {}".format(self.id, self.kind), + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) if is_interactive and request: interactive_function_call_signal.send(sender=self, request=request) @@ -445,41 +500,38 @@ def check_attr_value(value, db_attr): def transform_attributes(input_attributes, attr_mapping, db_attributes): attributes = [] for attr in input_attributes: - if attr['name'] not in attr_mapping: + if attr["name"] not in attr_mapping: continue - db_attr_name = attr_mapping[attr['name']] - db_attr = next(filter(lambda x: x['name'] == db_attr_name, db_attributes), None) - if db_attr is not None and check_attr_value(attr['value'], db_attr): - attributes.append({ - 'name': db_attr['name'], - 'value': attr['value'] - }) + db_attr_name = attr_mapping[attr["name"]] + db_attr = next(filter(lambda x: x["name"] == db_attr_name, db_attributes), None) + if db_attr is not None and check_attr_value(attr["value"], db_attr): + attributes.append({"name": db_attr["name"], "value": attr["value"]}) return attributes if self.kind == FunctionKind.DETECTOR: for item in response: - item_label = item['label'] + item_label = item["label"] if item_label not in mapping: continue - db_label = mapping[item_label]['db_label'] - item['label'] = db_label.name - item['attributes'] = transform_attributes( - item.get('attributes', {}), - mapping[item_label]['attributes'], - db_label.attributespec_set.values() + db_label = mapping[item_label]["db_label"] + item["label"] = db_label.name + item["attributes"] = transform_attributes( + item.get("attributes", {}), + mapping[item_label]["attributes"], + db_label.attributespec_set.values(), ) - if 'elements' in item: - sublabels = mapping[item_label]['sublabels'] - item['elements'] = [x for x in item['elements'] if x['label'] in sublabels] - for element in item['elements']: - element_label = element['label'] - db_label = sublabels[element_label]['db_label'] - element['label'] = db_label.name - element['attributes'] = transform_attributes( - element.get('attributes', {}), - sublabels[element_label]['attributes'], - db_label.attributespec_set.values() + if "elements" in item: + sublabels = mapping[item_label]["sublabels"] + item["elements"] = [x for x in item["elements"] if x["label"] in sublabels] + for element in item["elements"]: + element_label = element["label"] + db_label = sublabels[element_label]["db_label"] + element["label"] = db_label.name + element["attributes"] = transform_attributes( + element.get("attributes", {}), + sublabels[element_label]["attributes"], + db_label.attributespec_set.values(), ) response_filtered.append(item) response = response_filtered @@ -490,7 +542,8 @@ def _get_image(self, db_task, frame): frame_provider = TaskFrameProvider(db_task) image = frame_provider.get_frame(frame) - return base64.b64encode(image.data.getvalue()).decode('utf-8') + return base64.b64encode(image.data.getvalue()).decode("utf-8") + class LambdaQueue: RESULT_TTL = timedelta(minutes=30) @@ -502,19 +555,29 @@ def _get_queue(self): def get_jobs(self): queue = self._get_queue() # Only failed jobs are not included in the list below. - job_ids = set(queue.get_job_ids() + - queue.started_job_registry.get_job_ids() + - queue.finished_job_registry.get_job_ids() + - queue.scheduled_job_registry.get_job_ids() + - queue.deferred_job_registry.get_job_ids()) + job_ids = set( + queue.get_job_ids() + + queue.started_job_registry.get_job_ids() + + queue.finished_job_registry.get_job_ids() + + queue.scheduled_job_registry.get_job_ids() + + queue.deferred_job_registry.get_job_ids() + ) jobs = queue.job_class.fetch_many(job_ids, queue.connection) return [LambdaJob(job) for job in jobs if job and job.meta.get("lambda")] - def enqueue(self, - lambda_func, threshold, task, mapping, cleanup, conv_mask_to_poly, max_distance, request, + def enqueue( + self, + lambda_func, + threshold, + task, + mapping, + cleanup, + conv_mask_to_poly, + max_distance, + request, *, - job: Optional[int] = None + job: Optional[int] = None, ) -> LambdaJob: queue = self._get_queue() rq_id = RQId(RequestAction.AUTOANNOTATE, RequestTarget.TASK, task).render() @@ -524,8 +587,10 @@ def enqueue(self, # protection. rq_job = queue.fetch_job(rq_id) - have_conflict = rq_job and \ - rq_job.get_status(refresh=False) not in {rq.job.JobStatus.FAILED, rq.job.JobStatus.FINISHED} + have_conflict = rq_job and rq_job.get_status(refresh=False) not in { + rq.job.JobStatus.FAILED, + rq.job.JobStatus.FINISHED, + } # There could be some jobs left over from before the current naming convention was adopted. # TODO: remove this check after a few releases. @@ -536,7 +601,8 @@ def enqueue(self, if have_conflict or have_legacy_conflict: raise ValidationError( "Only one running request is allowed for the same task #{}".format(task), - code=status.HTTP_409_CONFLICT) + code=status.HTTP_409_CONFLICT, + ) if rq_job: rq_job.delete() @@ -548,14 +614,13 @@ def enqueue(self, user_id = request.user.id with get_rq_lock_by_user(queue, user_id): - rq_job = queue.create_job(LambdaJob(None), + rq_job = queue.create_job( + LambdaJob(None), job_id=rq_id, meta={ **get_rq_job_meta( request, - db_obj=( - Job.objects.get(pk=job) if job else Task.objects.get(pk=task) - ), + db_obj=(Job.objects.get(pk=job) if job else Task.objects.get(pk=task)), ), RQJobMetaField.FUNCTION_ID: lambda_func.id, "lambda": True, @@ -568,7 +633,7 @@ def enqueue(self, "cleanup": cleanup, "conv_mask_to_poly": conv_mask_to_poly, "mapping": mapping, - "max_distance": max_distance + "max_distance": max_distance, }, depends_on=define_dependent_job(queue, user_id), result_ttl=self.RESULT_TTL.total_seconds(), @@ -583,36 +648,42 @@ def fetch_job(self, pk): queue = self._get_queue() rq_job = queue.fetch_job(pk) if rq_job is None or not rq_job.meta.get("lambda"): - raise ValidationError("{} lambda job is not found".format(pk), - code=status.HTTP_404_NOT_FOUND) + raise ValidationError( + "{} lambda job is not found".format(pk), code=status.HTTP_404_NOT_FOUND + ) return LambdaJob(rq_job) + class LambdaJob: def __init__(self, job): self.job = job def to_dict(self): lambda_func = self.job.kwargs.get("function") - dict_ = { + dict_ = { "id": self.job.id, "function": { "id": lambda_func.id if lambda_func else None, "threshold": self.job.kwargs.get("threshold"), "task": self.job.kwargs.get("task"), - **({ - "job": self.job.kwargs["job"], - } if self.job.kwargs.get("job") else {}) + **( + { + "job": self.job.kwargs["job"], + } + if self.job.kwargs.get("job") + else {} + ), }, "status": self.job.get_status(), - "progress": self.job.meta.get('progress', 0), + "progress": self.job.meta.get("progress", 0), "enqueued": self.job.enqueued_at, "started": self.job.started_at, "ended": self.job.ended_at, - "exc_info": self.job.exc_info + "exc_info": self.job.exc_info, } - if dict_['status'] == rq.job.JobStatus.DEFERRED: - dict_['status'] = rq.job.JobStatus.QUEUED.value + if dict_["status"] == rq.job.JobStatus.DEFERRED: + dict_["status"] = rq.job.JobStatus.QUEUED.value return dict_ @@ -659,7 +730,7 @@ def _call_detector( mapping: Optional[dict[str, str]], conv_mask_to_poly: bool, *, - db_job: Optional[Job] = None + db_job: Optional[Job] = None, ): class Results: def __init__(self, task_id, job_id: Optional[int] = None): @@ -700,15 +771,16 @@ def parse_anno(anno, labels): # Invalid label provided return None - attrs = [{ - 'spec_id': label['attributes'][attr['name']], - 'value': attr['value'] - } for attr in anno.get('attributes', []) if attr['name'] in label['attributes']] + attrs = [ + {"spec_id": label["attributes"][attr["name"]], "value": attr["value"]} + for attr in anno.get("attributes", []) + if attr["name"] in label["attributes"] + ] if anno["type"].lower() == "tag": return { "frame": frame, - "label_id": label['id'], + "label_id": label["id"], "source": "auto", "attributes": attrs, "group": None, @@ -716,14 +788,16 @@ def parse_anno(anno, labels): else: shape = { "frame": frame, - "label_id": label['id'], + "label_id": label["id"], "source": "auto", "attributes": attrs, "group": anno["group_id"] if "group_id" in anno else None, "type": anno["type"], "occluded": False, "outside": anno.get("outside", False), - "points": anno.get("mask", []) if anno["type"] == "mask" else anno.get("points", []), + "points": ( + anno.get("mask", []) if anno["type"] == "mask" else anno.get("points", []) + ), "z_order": 0, } @@ -741,7 +815,7 @@ def parse_anno(anno, labels): shape["points"] = rle if shape["type"] == "skeleton": - parsed_elements = [parse_anno(x, label['sublabels']) for x in anno["elements"]] + parsed_elements = [parse_anno(x, label["sublabels"]) for x in anno["elements"]] # find a center to set position of missing points center = [0, 0] @@ -753,25 +827,26 @@ def parse_anno(anno, labels): def _map(sublabel_body): try: - return next(filter( - lambda x: x['label_id'] == sublabel_body['id'], - parsed_elements) + return next( + filter( + lambda x: x["label_id"] == sublabel_body["id"], parsed_elements + ) ) except StopIteration: return { "frame": frame, - "label_id": sublabel_body['id'], + "label_id": sublabel_body["id"], "source": "auto", "attributes": [], "group": None, - "type": sublabel_body['type'], + "type": sublabel_body["type"], "occluded": False, "points": center, "outside": True, "z_order": 0, } - shape["elements"] = list(map(_map, label['sublabels'].values())) + shape["elements"] = list(map(_map, label["sublabels"].values())) if all(element["outside"] for element in shape["elements"]): return None @@ -785,10 +860,11 @@ def _map(sublabel_body): if frame in db_task.data.deleted_frames: continue - annotations = function.invoke(db_task, db_job=db_job, data={ - "frame": frame, "mapping": mapping, - "threshold": threshold - }) + annotations = function.invoke( + db_task, + db_job=db_job, + data={"frame": frame, "mapping": mapping, "threshold": threshold}, + ) progress = (frame + 1) / db_task.data.size if not cls._update_progress(progress): @@ -828,8 +904,7 @@ def _get_frame_set(cls, db_task: Task, db_job: Optional[Job]): data_start_frame = task_data.start_frame step = task_data.get_frame_step() frame_set = sorted( - (abs_id - data_start_frame) // step - for abs_id in db_job.segment.frame_set + (abs_id - data_start_frame) // step for abs_id in db_job.segment.frame_set ) else: frame_set = range(db_task.data.size) @@ -844,7 +919,7 @@ def _call_reid( threshold: float, max_distance: int, *, - db_job: Optional[Job] = None + db_job: Optional[Job] = None, ): if db_job: data = dm.task.get_job_data(db_job.id) @@ -872,10 +947,18 @@ def _call_reid( boxes1 = boxes_by_frame[frame1] if boxes0 and boxes1: - matching = function.invoke(db_task, db_job=db_job, data={ - "frame0": frame0, "frame1": frame1, - "boxes0": boxes0, "boxes1": boxes1, "threshold": threshold, - "max_distance": max_distance}) + matching = function.invoke( + db_task, + db_job=db_job, + data={ + "frame0": frame0, + "frame1": frame1, + "boxes0": boxes0, + "boxes1": boxes1, + "threshold": threshold, + "max_distance": max_distance, + }, + ) for idx0, idx1 in enumerate(matching): if idx1 >= 0: @@ -886,7 +969,6 @@ def _call_reid( if not LambdaJob._update_progress((i + 1) / len(frame_set)): break - for box in boxes_by_frame[frame_set[-1]]: if "path_id" not in box: path_id = len(paths) @@ -896,14 +978,16 @@ def _call_reid( tracks = [] for path_id in paths: box0 = paths[path_id][0] - tracks.append({ - "label_id": box0["label_id"], - "group": None, - "attributes": [], - "frame": box0["frame"], - "shapes": paths[path_id], - "source": str(SourceType.AUTO) - }) + tracks.append( + { + "label_id": box0["label_id"], + "group": None, + "attributes": [], + "frame": box0["frame"], + "shapes": paths[path_id], + "source": str(SourceType.AUTO), + } + ) for box in tracks[-1]["shapes"]: box.pop("id", None) @@ -936,8 +1020,8 @@ def _call_reid( def __call__(cls, function, task: int, cleanup: bool, **kwargs): # TODO: need logging db_job = None - if job := kwargs.get('job'): - db_job = Job.objects.select_related('segment', 'segment__task').get(pk=job) + if job := kwargs.get("job"): + db_job = Job.objects.select_related("segment", "segment__task").get(pk=job) db_task = db_job.segment.task else: db_task = Task.objects.get(pk=task) @@ -953,22 +1037,34 @@ def __call__(cls, function, task: int, cleanup: bool, **kwargs): def convert_labels(db_labels): labels = {} for label in db_labels: - labels[label.name] = {'id':label.id, 'attributes': {}, 'type': label.type} - if label.type == 'skeleton': - labels[label.name]['sublabels'] = convert_labels(label.sublabels.all()) + labels[label.name] = {"id": label.id, "attributes": {}, "type": label.type} + if label.type == "skeleton": + labels[label.name]["sublabels"] = convert_labels(label.sublabels.all()) for attr in label.attributespec_set.values(): - labels[label.name]['attributes'][attr['name']] = attr['id'] + labels[label.name]["attributes"][attr["name"]] = attr["id"] return labels labels = convert_labels(db_task.get_labels(prefetch=True)) if function.kind == FunctionKind.DETECTOR: - cls._call_detector(function, db_task, labels, - kwargs.get("threshold"), kwargs.get("mapping"), kwargs.get("conv_mask_to_poly"), - db_job=db_job) + cls._call_detector( + function, + db_task, + labels, + kwargs.get("threshold"), + kwargs.get("mapping"), + kwargs.get("conv_mask_to_poly"), + db_job=db_job, + ) elif function.kind == FunctionKind.REID: - cls._call_reid(function, db_task, - kwargs.get("threshold"), kwargs.get("max_distance"), db_job=db_job) + cls._call_reid( + function, + db_task, + kwargs.get("threshold"), + kwargs.get("max_distance"), + db_job=db_job, + ) + def return_response(success_code=status.HTTP_200_OK): def wrap_response(func): @@ -1000,23 +1096,28 @@ def func_wrapper(*args, **kwargs): return Response(data=data, status=status_code) return func_wrapper + return wrap_response -@extend_schema(tags=['lambda']) + +@extend_schema(tags=["lambda"]) @extend_schema_view( retrieve=extend_schema( - operation_id='lambda_retrieve_functions', - summary='Method returns the information about the function', + operation_id="lambda_retrieve_functions", + summary="Method returns the information about the function", responses={ - '200': OpenApiResponse(response=OpenApiTypes.OBJECT, description='Information about the function'), - }), + "200": OpenApiResponse( + response=OpenApiTypes.OBJECT, description="Information about the function" + ), + }, + ), list=extend_schema( - operation_id='lambda_list_functions', - summary='Method returns a list of functions') + operation_id="lambda_list_functions", summary="Method returns a list of functions" + ), ) class FunctionViewSet(viewsets.ViewSet): - lookup_value_regex = '[a-zA-Z0-9_.-]+' - lookup_field = 'func_id' + lookup_value_regex = "[a-zA-Z0-9_.-]+" + lookup_field = "func_id" iam_organization_field = None serializer_class = None @@ -1031,7 +1132,9 @@ def retrieve(self, request, func_id): gateway = LambdaGateway() return gateway.get(func_id).to_dict() - @extend_schema(description=textwrap.dedent("""\ + @extend_schema( + description=textwrap.dedent( + """\ Allows to execute a function for immediate computation. Intended for short-lived executions, useful for interactive calls. @@ -1039,44 +1142,51 @@ def retrieve(self, request, func_id): When executed for interactive annotation, the job id must be specified in the 'job' input field. The task id is not required in this case, but if it is specified, it must match the job task id. - """), - request=inline_serializer("OnlineFunctionCall", fields={ - "job": serializers.IntegerField(required=False), - "task": serializers.IntegerField(required=False), - }), - responses=OpenApiResponse(description="Returns function invocation results") + """ + ), + request=inline_serializer( + "OnlineFunctionCall", + fields={ + "job": serializers.IntegerField(required=False), + "task": serializers.IntegerField(required=False), + }, + ), + responses=OpenApiResponse(description="Returns function invocation results"), ) @return_response() def call(self, request, func_id): self.check_object_permissions(request, func_id) try: - job_id = request.data.get('job') + job_id = request.data.get("job") job = None if job_id is not None: job = Job.objects.get(id=job_id) task_id = job.get_task_id() else: - task_id = request.data['task'] + task_id = request.data["task"] db_task = Task.objects.get(pk=task_id) except (KeyError, ObjectDoesNotExist) as err: raise ValidationError( - '`{}` lambda function was run '.format(func_id) + - 'with wrong arguments ({})'.format(str(err)), - code=status.HTTP_400_BAD_REQUEST) + "`{}` lambda function was run ".format(func_id) + + "with wrong arguments ({})".format(str(err)), + code=status.HTTP_400_BAD_REQUEST, + ) gateway = LambdaGateway() lambda_func = gateway.get(func_id) response = lambda_func.invoke( db_task, - request.data, # TODO: better to add validation via serializer for these data + request.data, # TODO: better to add validation via serializer for these data db_job=job, is_interactive=True, - request=request + request=request, ) - handle_function_call(func_id, db_task, + handle_function_call( + func_id, + db_task, category="interactive", parameters={ param_name: param_value @@ -1088,41 +1198,44 @@ def call(self, request, func_id): return response -@extend_schema(tags=['lambda']) + +@extend_schema(tags=["lambda"]) @extend_schema_view( retrieve=extend_schema( - operation_id='lambda_retrieve_requests', - summary='Method returns the status of the request', + operation_id="lambda_retrieve_requests", + summary="Method returns the status of the request", parameters=[ - OpenApiParameter('id', location=OpenApiParameter.PATH, type=OpenApiTypes.STR, - description='Request id'), + OpenApiParameter( + "id", + location=OpenApiParameter.PATH, + type=OpenApiTypes.STR, + description="Request id", + ), ], - responses={ - '200': FunctionCallSerializer - } + responses={"200": FunctionCallSerializer}, ), list=extend_schema( - operation_id='lambda_list_requests', - summary='Method returns a list of requests', - responses={ - '200': FunctionCallSerializer(many=True) - } + operation_id="lambda_list_requests", + summary="Method returns a list of requests", + responses={"200": FunctionCallSerializer(many=True)}, ), create=extend_schema( parameters=ORGANIZATION_OPEN_API_PARAMETERS, - summary='Method calls the function', + summary="Method calls the function", request=FunctionCallRequestSerializer, - responses={ - '200': FunctionCallSerializer - } + responses={"200": FunctionCallSerializer}, ), destroy=extend_schema( - operation_id='lambda_delete_requests', - summary='Method cancels the request', + operation_id="lambda_delete_requests", + summary="Method cancels the request", parameters=[ - OpenApiParameter('id', location=OpenApiParameter.PATH, type=OpenApiTypes.STR, - description='Request id'), - ] + OpenApiParameter( + "id", + location=OpenApiParameter.PATH, + type=OpenApiTypes.STR, + description="Request id", + ), + ], ), ) class RequestViewSet(viewsets.ViewSet): @@ -1158,25 +1271,35 @@ def create(self, request): request_data = request_serializer.validated_data try: - function = request_data['function'] - threshold = request_data.get('threshold') - task = request_data['task'] - job = request_data.get('job', None) - cleanup = request_data.get('cleanup', False) - conv_mask_to_poly = request_data.get('conv_mask_to_poly', False) - mapping = request_data.get('mapping') - max_distance = request_data.get('max_distance') + function = request_data["function"] + threshold = request_data.get("threshold") + task = request_data["task"] + job = request_data.get("job", None) + cleanup = request_data.get("cleanup", False) + conv_mask_to_poly = request_data.get("conv_mask_to_poly", False) + mapping = request_data.get("mapping") + max_distance = request_data.get("max_distance") except KeyError as err: raise ValidationError( - '`{}` lambda function was run '.format(request_data.get('function', 'undefined')) + - 'with wrong arguments ({})'.format(str(err)), - code=status.HTTP_400_BAD_REQUEST) + "`{}` lambda function was run ".format(request_data.get("function", "undefined")) + + "with wrong arguments ({})".format(str(err)), + code=status.HTTP_400_BAD_REQUEST, + ) gateway = LambdaGateway() queue = LambdaQueue() lambda_func = gateway.get(function) - rq_job = queue.enqueue(lambda_func, threshold, task, - mapping, cleanup, conv_mask_to_poly, max_distance, request, job=job) + rq_job = queue.enqueue( + lambda_func, + threshold, + task, + mapping, + cleanup, + conv_mask_to_poly, + max_distance, + request, + job=job, + ) handle_function_call(function, job or task, category="batch") diff --git a/cvat/apps/log_viewer/apps.py b/cvat/apps/log_viewer/apps.py index 437c960e3929..a1806efc6462 100644 --- a/cvat/apps/log_viewer/apps.py +++ b/cvat/apps/log_viewer/apps.py @@ -6,8 +6,9 @@ class LogViewerConfig(AppConfig): - name = 'cvat.apps.log_viewer' + name = "cvat.apps.log_viewer" def ready(self) -> None: from cvat.apps.iam.permissions import load_app_permissions + load_app_permissions(self) diff --git a/cvat/apps/log_viewer/permissions.py b/cvat/apps/log_viewer/permissions.py index d25aa7fe275a..4ad996fb7e67 100644 --- a/cvat/apps/log_viewer/permissions.py +++ b/cvat/apps/log_viewer/permissions.py @@ -12,12 +12,12 @@ class LogViewerPermission(OpenPolicyAgentPermission): has_analytics_access: bool class Scopes(StrEnum): - VIEW = 'view' + VIEW = "view" @classmethod def create(cls, request, view, obj, iam_context): permissions = [] - if view.basename == 'analytics': + if view.basename == "analytics": for scope in cls.get_scopes(request, view, obj): self = cls.create_base_perm(request, view, scope, iam_context, obj) permissions.append(self) @@ -33,20 +33,22 @@ def create_base_perm(cls, request, view, scope, iam_context, obj=None, **kwargs) obj=obj, has_analytics_access=request.user.profile.has_analytics_access, **iam_context, - **kwargs + **kwargs, ) def __init__(self, has_analytics_access=False, **kwargs): super().__init__(**kwargs) - self.payload['input']['auth']['user']['has_analytics_access'] = has_analytics_access - self.url = settings.IAM_OPA_DATA_URL + '/analytics/allow' + self.payload["input"]["auth"]["user"]["has_analytics_access"] = has_analytics_access + self.url = settings.IAM_OPA_DATA_URL + "/analytics/allow" @staticmethod def get_scopes(request, view, obj): Scopes = __class__.Scopes - return [{ - 'list': Scopes.VIEW, - }[view.action]] + return [ + { + "list": Scopes.VIEW, + }[view.action] + ] def get_resource(self): return None diff --git a/cvat/apps/log_viewer/rules/tests/generators/analytics_test.gen.rego.py b/cvat/apps/log_viewer/rules/tests/generators/analytics_test.gen.rego.py index 95d566e4b93a..12d28193cd9f 100644 --- a/cvat/apps/log_viewer/rules/tests/generators/analytics_test.gen.rego.py +++ b/cvat/apps/log_viewer/rules/tests/generators/analytics_test.gen.rego.py @@ -62,9 +62,15 @@ def eval_rule(scope, context, ownership, privilege, membership, data, has_analyt ) ) rules = list(filter(lambda r: GROUPS.index(privilege) <= GROUPS.index(r["privilege"]), rules)) - rules = list(filter(lambda r: r["hasanalyticsaccess"] in ("na", str(has_analytics_access).lower()), rules)) + rules = list( + filter( + lambda r: r["hasanalyticsaccess"] in ("na", str(has_analytics_access).lower()), rules + ) + ) resource = data["resource"] - rules = list(filter(lambda r: not r["limit"] or eval(r["limit"], {"resource": resource}), rules)) + rules = list( + filter(lambda r: not r["limit"] or eval(r["limit"], {"resource": resource}), rules) + ) return bool(rules) @@ -78,13 +84,15 @@ def get_data(scope, context, ownership, privilege, membership, resource, has_ana "privilege": privilege, "has_analytics_access": has_analytics_access, }, - "organization": { - "id": random.randrange(100, 200), - "owner": {"id": random.randrange(200, 300)}, - "user": {"role": membership}, - } - if context == "organization" - else None, + "organization": ( + { + "id": random.randrange(100, 200), + "owner": {"id": random.randrange(200, 300)}, + "user": {"role": membership}, + } + if context == "organization" + else None + ), }, "resource": resource, } @@ -143,9 +151,15 @@ def gen_test_rego(name): if not is_valid(scope, context, ownership, privilege, membership, resource): continue - data = get_data(scope, context, ownership, privilege, membership, resource, has_analytics_access) - test_name = get_name(scope, context, ownership, privilege, membership, resource, has_analytics_access) - result = eval_rule(scope, context, ownership, privilege, membership, data, has_analytics_access) + data = get_data( + scope, context, ownership, privilege, membership, resource, has_analytics_access + ) + test_name = get_name( + scope, context, ownership, privilege, membership, resource, has_analytics_access + ) + result = eval_rule( + scope, context, ownership, privilege, membership, data, has_analytics_access + ) f.write( "{test_name} if {{\n {allow} with input as {data}\n}}\n\n".format( test_name=test_name, diff --git a/cvat/apps/log_viewer/urls.py b/cvat/apps/log_viewer/urls.py index 0de56682a37e..96e88e38c9bc 100644 --- a/cvat/apps/log_viewer/urls.py +++ b/cvat/apps/log_viewer/urls.py @@ -1,4 +1,3 @@ - # Copyright (C) 2018-2022 Intel Corporation # # SPDX-License-Identifier: MIT @@ -8,6 +7,6 @@ from . import views router = routers.DefaultRouter(trailing_slash=False) -router.register('analytics', views.LogViewerAccessViewSet, basename='analytics') +router.register("analytics", views.LogViewerAccessViewSet, basename="analytics") urlpatterns = router.urls diff --git a/cvat/apps/log_viewer/views.py b/cvat/apps/log_viewer/views.py index 9e52f546c634..362f2bb97ec3 100644 --- a/cvat/apps/log_viewer/views.py +++ b/cvat/apps/log_viewer/views.py @@ -4,11 +4,11 @@ from django.conf import settings from django.http import HttpResponsePermanentRedirect +from drf_spectacular.utils import extend_schema from rest_framework import status, viewsets from rest_framework.decorators import action from rest_framework.response import Response -from drf_spectacular.utils import extend_schema @extend_schema(exclude=True) class LogViewerAccessViewSet(viewsets.ViewSet): @@ -19,7 +19,7 @@ def list(self, request): # All log view requests are proxied by Traefik in production mode which is not available in debug mode, # In order not to duplicate settings, let's just redirect to the default page in debug mode - @action(detail=False, url_path='dashboards') + @action(detail=False, url_path="dashboards") def redirect(self, request): if settings.DEBUG: - return HttpResponsePermanentRedirect('http://localhost:3001/dashboards') + return HttpResponsePermanentRedirect("http://localhost:3001/dashboards") diff --git a/cvat/apps/organizations/__init__.py b/cvat/apps/organizations/__init__.py index b1220197cf2a..f7c3408e3d12 100644 --- a/cvat/apps/organizations/__init__.py +++ b/cvat/apps/organizations/__init__.py @@ -1,4 +1,3 @@ # Copyright (C) 2021-2022 Intel Corporation # # SPDX-License-Identifier: MIT - diff --git a/cvat/apps/organizations/admin.py b/cvat/apps/organizations/admin.py index 756100244743..33e711189a1f 100644 --- a/cvat/apps/organizations/admin.py +++ b/cvat/apps/organizations/admin.py @@ -2,27 +2,29 @@ # # SPDX-License-Identifier: MIT -from .models import Organization, Membership from django.contrib import admin +from .models import Membership, Organization + + class MembershipInline(admin.TabularInline): model = Membership extra = 0 radio_fields = { - 'role': admin.VERTICAL, + "role": admin.VERTICAL, } - autocomplete_fields = ('user', ) + autocomplete_fields = ("user",) + class OrganizationAdmin(admin.ModelAdmin): - search_fields = ('slug', 'name', 'owner__username') - list_display = ('id', 'slug', 'name') + search_fields = ("slug", "name", "owner__username") + list_display = ("id", "slug", "name") + + autocomplete_fields = ("owner",) - autocomplete_fields = ('owner', ) + inlines = [MembershipInline] - inlines = [ - MembershipInline - ] admin.site.register(Organization, OrganizationAdmin) diff --git a/cvat/apps/organizations/apps.py b/cvat/apps/organizations/apps.py index f73094af1723..ad654a0b8061 100644 --- a/cvat/apps/organizations/apps.py +++ b/cvat/apps/organizations/apps.py @@ -5,9 +5,11 @@ from django.apps import AppConfig + class OrganizationsConfig(AppConfig): - name = 'cvat.apps.organizations' + name = "cvat.apps.organizations" def ready(self) -> None: from cvat.apps.iam.permissions import load_app_permissions + load_app_permissions(self) diff --git a/cvat/apps/organizations/migrations/0001_initial.py b/cvat/apps/organizations/migrations/0001_initial.py index 1d2689d343d1..5d4887a15fb2 100644 --- a/cvat/apps/organizations/migrations/0001_initial.py +++ b/cvat/apps/organizations/migrations/0001_initial.py @@ -1,8 +1,8 @@ # Generated by Django 3.2.8 on 2021-10-26 14:52 +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion class Migration(migrations.Migration): @@ -15,46 +15,103 @@ class Migration(migrations.Migration): operations = [ migrations.CreateModel( - name='Organization', + name="Organization", fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('slug', models.SlugField(max_length=16, unique=True)), - ('name', models.CharField(blank=True, max_length=64)), - ('description', models.TextField(blank=True)), - ('created_date', models.DateTimeField(auto_now_add=True)), - ('updated_date', models.DateTimeField(auto_now=True)), - ('contact', models.JSONField(blank=True, default=dict)), - ('owner', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='+', to=settings.AUTH_USER_MODEL)), + ( + "id", + models.AutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("slug", models.SlugField(max_length=16, unique=True)), + ("name", models.CharField(blank=True, max_length=64)), + ("description", models.TextField(blank=True)), + ("created_date", models.DateTimeField(auto_now_add=True)), + ("updated_date", models.DateTimeField(auto_now=True)), + ("contact", models.JSONField(blank=True, default=dict)), + ( + "owner", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="+", + to=settings.AUTH_USER_MODEL, + ), + ), ], options={ - 'default_permissions': (), + "default_permissions": (), }, ), migrations.CreateModel( - name='Membership', + name="Membership", fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('is_active', models.BooleanField(default=False)), - ('joined_date', models.DateTimeField(null=True)), - ('role', models.CharField(choices=[('worker', 'Worker'), ('supervisor', 'Supervisor'), ('maintainer', 'Maintainer'), ('owner', 'Owner')], max_length=16)), - ('organization', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='members', to='organizations.organization')), - ('user', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='memberships', to=settings.AUTH_USER_MODEL)), + ( + "id", + models.AutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("is_active", models.BooleanField(default=False)), + ("joined_date", models.DateTimeField(null=True)), + ( + "role", + models.CharField( + choices=[ + ("worker", "Worker"), + ("supervisor", "Supervisor"), + ("maintainer", "Maintainer"), + ("owner", "Owner"), + ], + max_length=16, + ), + ), + ( + "organization", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="members", + to="organizations.organization", + ), + ), + ( + "user", + models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="memberships", + to=settings.AUTH_USER_MODEL, + ), + ), ], options={ - 'default_permissions': (), - 'unique_together': {('user', 'organization')}, + "default_permissions": (), + "unique_together": {("user", "organization")}, }, ), migrations.CreateModel( - name='Invitation', + name="Invitation", fields=[ - ('key', models.CharField(max_length=64, primary_key=True, serialize=False)), - ('created_date', models.DateTimeField(auto_now_add=True)), - ('membership', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to='organizations.membership')), - ('owner', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL)), + ("key", models.CharField(max_length=64, primary_key=True, serialize=False)), + ("created_date", models.DateTimeField(auto_now_add=True)), + ( + "membership", + models.OneToOneField( + on_delete=django.db.models.deletion.CASCADE, to="organizations.membership" + ), + ), + ( + "owner", + models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to=settings.AUTH_USER_MODEL, + ), + ), ], options={ - 'default_permissions': (), + "default_permissions": (), }, ), ] diff --git a/cvat/apps/organizations/models.py b/cvat/apps/organizations/models.py index 3da77bafbebf..d582459866f8 100644 --- a/cvat/apps/organizations/models.py +++ b/cvat/apps/organizations/models.py @@ -4,55 +4,62 @@ # SPDX-License-Identifier: MIT from datetime import timedelta -from django.conf import settings -from allauth.account.adapter import get_adapter -from django.contrib.sites.shortcuts import get_current_site -from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import extend_schema_field -from django.db import models +from allauth.account.adapter import get_adapter +from django.conf import settings from django.contrib.auth import get_user_model +from django.contrib.sites.shortcuts import get_current_site from django.core.exceptions import ImproperlyConfigured +from django.db import models from django.utils import timezone +from drf_spectacular.types import OpenApiTypes +from drf_spectacular.utils import extend_schema_field from cvat.apps.engine.models import TimestampedModel + class Organization(TimestampedModel): slug = models.SlugField(max_length=16, blank=False, unique=True) name = models.CharField(max_length=64, blank=True) description = models.TextField(blank=True) contact = models.JSONField(blank=True, default=dict) - owner = models.ForeignKey(get_user_model(), null=True, - blank=True, on_delete=models.SET_NULL, related_name='+') + owner = models.ForeignKey( + get_user_model(), null=True, blank=True, on_delete=models.SET_NULL, related_name="+" + ) def __str__(self): return self.slug + class Meta: default_permissions = () + class Membership(models.Model): - WORKER = 'worker' - SUPERVISOR = 'supervisor' - MAINTAINER = 'maintainer' - OWNER = 'owner' - - user = models.ForeignKey(get_user_model(), on_delete=models.CASCADE, - null=True, related_name='memberships') - organization = models.ForeignKey(Organization, on_delete=models.CASCADE, - related_name='members') + WORKER = "worker" + SUPERVISOR = "supervisor" + MAINTAINER = "maintainer" + OWNER = "owner" + + user = models.ForeignKey( + get_user_model(), on_delete=models.CASCADE, null=True, related_name="memberships" + ) + organization = models.ForeignKey(Organization, on_delete=models.CASCADE, related_name="members") is_active = models.BooleanField(default=False) joined_date = models.DateTimeField(null=True) - role = models.CharField(max_length=16, choices=[ - (WORKER, 'Worker'), - (SUPERVISOR, 'Supervisor'), - (MAINTAINER, 'Maintainer'), - (OWNER, 'Owner'), - ]) + role = models.CharField( + max_length=16, + choices=[ + (WORKER, "Worker"), + (SUPERVISOR, "Supervisor"), + (MAINTAINER, "Maintainer"), + (OWNER, "Owner"), + ], + ) class Meta: default_permissions = () - unique_together = ('user', 'organization') + unique_together = ("user", "organization") # Inspried by https://github.com/bee-keeper/django-invitations @@ -94,16 +101,16 @@ def send(self, request): site_name = current_site.name domain = current_site.domain context = { - 'email': target_email, - 'invitation_key': self.key, - 'domain': domain, - 'site_name': site_name, - 'invitation_owner': self.owner.get_username(), - 'organization_name': self.membership.organization.slug, - 'protocol': 'https' if request.is_secure() else 'http', + "email": target_email, + "invitation_key": self.key, + "domain": domain, + "site_name": site_name, + "invitation_owner": self.owner.get_username(), + "organization_name": self.membership.organization.slug, + "protocol": "https" if request.is_secure() else "http", } - get_adapter(request).send_mail('invitation/invitation', target_email, context) + get_adapter(request).send_mail("invitation/invitation", target_email, context) self.sent_date = timezone.now() self.save() diff --git a/cvat/apps/organizations/permissions.py b/cvat/apps/organizations/permissions.py index e45b05d978c3..1e18cf5e20c5 100644 --- a/cvat/apps/organizations/permissions.py +++ b/cvat/apps/organizations/permissions.py @@ -9,18 +9,19 @@ from .models import Membership + class OrganizationPermission(OpenPolicyAgentPermission): class Scopes(StrEnum): - LIST = 'list' - CREATE = 'create' - DELETE = 'delete' - UPDATE = 'update' - VIEW = 'view' + LIST = "list" + CREATE = "create" + DELETE = "delete" + UPDATE = "update" + VIEW = "view" @classmethod def create(cls, request, view, obj, iam_context): permissions = [] - if view.basename == 'organization': + if view.basename == "organization": for scope in cls.get_scopes(request, view, obj): self = cls.create_base_perm(request, view, scope, iam_context, obj) permissions.append(self) @@ -29,127 +30,116 @@ def create(cls, request, view, obj, iam_context): def __init__(self, **kwargs): super().__init__(**kwargs) - self.url = settings.IAM_OPA_DATA_URL + '/organizations/allow' + self.url = settings.IAM_OPA_DATA_URL + "/organizations/allow" @staticmethod def get_scopes(request, view, obj): Scopes = __class__.Scopes - return [{ - 'list': Scopes.LIST, - 'create': Scopes.CREATE, - 'destroy': Scopes.DELETE, - 'partial_update': Scopes.UPDATE, - 'retrieve': Scopes.VIEW, - }[view.action]] + return [ + { + "list": Scopes.LIST, + "create": Scopes.CREATE, + "destroy": Scopes.DELETE, + "partial_update": Scopes.UPDATE, + "retrieve": Scopes.VIEW, + }[view.action] + ] def get_resource(self): if self.obj: - membership = Membership.objects.filter( - organization=self.obj, user=self.user_id).first() + membership = Membership.objects.filter(organization=self.obj, user=self.user_id).first() return { - 'id': self.obj.id, - 'owner': { - 'id': getattr(self.obj.owner, 'id', None) - }, - 'user': { - 'role': membership.role if membership else None - } + "id": self.obj.id, + "owner": {"id": getattr(self.obj.owner, "id", None)}, + "user": {"role": membership.role if membership else None}, } elif self.scope.startswith(__class__.Scopes.CREATE.value): - return { - 'id': None, - 'owner': { - 'id': self.user_id - }, - 'user': { - 'role': 'owner' - } - } + return {"id": None, "owner": {"id": self.user_id}, "user": {"role": "owner"}} else: return None + class InvitationPermission(OpenPolicyAgentPermission): class Scopes(StrEnum): - LIST = 'list' - CREATE = 'create' - DELETE = 'delete' - ACCEPT = 'accept' - DECLINE = 'decline' - RESEND = 'resend' - VIEW = 'view' + LIST = "list" + CREATE = "create" + DELETE = "delete" + ACCEPT = "accept" + DECLINE = "decline" + RESEND = "resend" + VIEW = "view" @classmethod def create(cls, request, view, obj, iam_context): permissions = [] - if view.basename == 'invitation': + if view.basename == "invitation": for scope in cls.get_scopes(request, view, obj): - self = cls.create_base_perm(request, view, scope, iam_context, obj, - role=request.data.get('role')) + self = cls.create_base_perm( + request, view, scope, iam_context, obj, role=request.data.get("role") + ) permissions.append(self) return permissions def __init__(self, **kwargs): super().__init__(**kwargs) - self.role = kwargs.get('role') - self.url = settings.IAM_OPA_DATA_URL + '/invitations/allow' + self.role = kwargs.get("role") + self.url = settings.IAM_OPA_DATA_URL + "/invitations/allow" @staticmethod def get_scopes(request, view, obj): Scopes = __class__.Scopes - return [{ - 'list': Scopes.LIST, - 'create': Scopes.CREATE, - 'destroy': Scopes.DELETE, - 'partial_update': Scopes.ACCEPT if 'accepted' in - request.query_params else Scopes.RESEND, - 'retrieve': Scopes.VIEW, - 'accept': Scopes.ACCEPT, - 'decline': Scopes.DECLINE, - 'resend': Scopes.RESEND, - }[view.action]] + return [ + { + "list": Scopes.LIST, + "create": Scopes.CREATE, + "destroy": Scopes.DELETE, + "partial_update": ( + Scopes.ACCEPT if "accepted" in request.query_params else Scopes.RESEND + ), + "retrieve": Scopes.VIEW, + "accept": Scopes.ACCEPT, + "decline": Scopes.DECLINE, + "resend": Scopes.RESEND, + }[view.action] + ] def get_resource(self): data = None if self.obj: data = { - 'owner': { 'id': getattr(self.obj.owner, 'id', None) }, - 'invitee': { 'id': getattr(self.obj.membership.user, 'id', None) }, - 'role': self.obj.membership.role, - 'organization': { - 'id': self.obj.membership.organization.id - } + "owner": {"id": getattr(self.obj.owner, "id", None)}, + "invitee": {"id": getattr(self.obj.membership.user, "id", None)}, + "role": self.obj.membership.role, + "organization": {"id": self.obj.membership.organization.id}, } elif self.scope.startswith(__class__.Scopes.CREATE.value): data = { - 'owner': { 'id': self.user_id }, - 'invitee': { - 'id': None # unknown yet - }, - 'role': self.role, - 'organization': { - 'id': self.org_id - } if self.org_id is not None else None + "owner": {"id": self.user_id}, + "invitee": {"id": None}, # unknown yet + "role": self.role, + "organization": {"id": self.org_id} if self.org_id is not None else None, } return data + class MembershipPermission(OpenPolicyAgentPermission): class Scopes(StrEnum): - LIST = 'list' - UPDATE = 'change' - UPDATE_ROLE = 'change:role' - VIEW = 'view' - DELETE = 'delete' + LIST = "list" + UPDATE = "change" + UPDATE_ROLE = "change:role" + VIEW = "view" + DELETE = "delete" @classmethod def create(cls, request, view, obj, iam_context): permissions = [] - if view.basename == 'membership': + if view.basename == "membership": for scope in cls.get_scopes(request, view, obj): params = {} - if scope == 'change:role': - params['role'] = request.data.get('role') + if scope == "change:role": + params["role"] = request.data.get("role") self = cls.create_base_perm(request, view, scope, iam_context, obj, **params) permissions.append(self) @@ -158,7 +148,7 @@ def create(cls, request, view, obj, iam_context): def __init__(self, **kwargs): super().__init__(**kwargs) - self.url = settings.IAM_OPA_DATA_URL + '/memberships/allow' + self.url = settings.IAM_OPA_DATA_URL + "/memberships/allow" @staticmethod def get_scopes(request, view, obj): @@ -166,16 +156,21 @@ def get_scopes(request, view, obj): scopes = [] scope = { - 'list': Scopes.LIST, - 'partial_update': Scopes.UPDATE, - 'retrieve': Scopes.VIEW, - 'destroy': Scopes.DELETE, + "list": Scopes.LIST, + "partial_update": Scopes.UPDATE, + "retrieve": Scopes.VIEW, + "destroy": Scopes.DELETE, }[view.action] if scope == Scopes.UPDATE: - scopes.extend(__class__.get_per_field_update_scopes(request, { - 'role': Scopes.UPDATE_ROLE, - })) + scopes.extend( + __class__.get_per_field_update_scopes( + request, + { + "role": Scopes.UPDATE_ROLE, + }, + ) + ) else: scopes.append(scope) @@ -184,10 +179,10 @@ def get_scopes(request, view, obj): def get_resource(self): if self.obj: return { - 'role': self.obj.role, - 'is_active': self.obj.is_active, - 'user': { 'id': self.obj.user.id }, - 'organization': { 'id': self.obj.organization.id } + "role": self.obj.role, + "is_active": self.obj.is_active, + "user": {"id": self.obj.user.id}, + "organization": {"id": self.obj.organization.id}, } else: return None diff --git a/cvat/apps/organizations/rules/tests/generators/invitations_test.gen.rego.py b/cvat/apps/organizations/rules/tests/generators/invitations_test.gen.rego.py index bf7edec50713..39ff446d8eac 100644 --- a/cvat/apps/organizations/rules/tests/generators/invitations_test.gen.rego.py +++ b/cvat/apps/organizations/rules/tests/generators/invitations_test.gen.rego.py @@ -109,13 +109,15 @@ def get_data(scope, context, ownership, privilege, membership, resource, same_or "scope": scope, "auth": { "user": {"id": random.randrange(0, 100), "privilege": privilege}, - "organization": { - "id": random.randrange(100, 200), - "owner": {"id": random.randrange(200, 300)}, - "user": {"role": membership}, - } - if context == "organization" - else None, + "organization": ( + { + "id": random.randrange(100, 200), + "owner": {"id": random.randrange(200, 300)}, + "user": {"role": membership}, + } + if context == "organization" + else None + ), }, "resource": resource, } diff --git a/cvat/apps/organizations/rules/tests/generators/memberships_test.gen.rego.py b/cvat/apps/organizations/rules/tests/generators/memberships_test.gen.rego.py index c74a4a7c992b..09258163b2db 100644 --- a/cvat/apps/organizations/rules/tests/generators/memberships_test.gen.rego.py +++ b/cvat/apps/organizations/rules/tests/generators/memberships_test.gen.rego.py @@ -98,14 +98,14 @@ def eval_rule(scope, context, ownership, privilege, membership, data): return False if scope != "create" and not data["resource"]["is_active"]: - is_staff = membership == "owner" or membership == 'maintainer' + is_staff = membership == "owner" or membership == "maintainer" if is_staff: - if scope != 'view': + if scope != "view": if ORG_ROLES.index(membership) >= ORG_ROLES.index(resource["role"]): return False if GROUPS.index(privilege) > GROUPS.index("user"): return False - if resource["user"]['id'] == data["auth"]["user"]['id']: + if resource["user"]["id"] == data["auth"]["user"]["id"]: return False return True return False @@ -118,13 +118,15 @@ def get_data(scope, context, ownership, privilege, membership, resource, same_or "scope": scope, "auth": { "user": {"id": random.randrange(0, 100), "privilege": privilege}, - "organization": { - "id": random.randrange(100, 200), - "owner": {"id": random.randrange(200, 300)}, - "user": {"role": membership}, - } - if context == "organization" - else None, + "organization": ( + { + "id": random.randrange(100, 200), + "owner": {"id": random.randrange(200, 300)}, + "user": {"role": membership}, + } + if context == "organization" + else None + ), }, "resource": resource, } diff --git a/cvat/apps/organizations/rules/tests/generators/organizations_test.gen.rego.py b/cvat/apps/organizations/rules/tests/generators/organizations_test.gen.rego.py index d2a8a6fb653b..35f4fad15678 100644 --- a/cvat/apps/organizations/rules/tests/generators/organizations_test.gen.rego.py +++ b/cvat/apps/organizations/rules/tests/generators/organizations_test.gen.rego.py @@ -78,13 +78,15 @@ def get_data(scope, context, ownership, privilege, membership, resource): "scope": scope, "auth": { "user": {"id": random.randrange(0, 100), "privilege": privilege}, - "organization": { - "id": random.randrange(100, 200), - "owner": {"id": random.randrange(200, 300)}, - "user": {"role": membership}, - } - if context == "organization" - else None, + "organization": ( + { + "id": random.randrange(100, 200), + "owner": {"id": random.randrange(200, 300)}, + "user": {"role": membership}, + } + if context == "organization" + else None + ), }, "resource": {**resource, "owner": {"id": random.randrange(300, 400)}} if resource else None, } diff --git a/cvat/apps/organizations/serializers.py b/cvat/apps/organizations/serializers.py index 9cfb467aa3b9..6fe3a7a851a4 100644 --- a/cvat/apps/organizations/serializers.py +++ b/cvat/apps/organizations/serializers.py @@ -3,33 +3,46 @@ # # SPDX-License-Identifier: MIT -from attr.converters import to_bool -from django.contrib.auth import get_user_model from allauth.account.models import EmailAddress -from django.core.exceptions import ObjectDoesNotExist +from attr.converters import to_bool from django.conf import settings +from django.contrib.auth import get_user_model from django.contrib.auth.models import User +from django.core.exceptions import ObjectDoesNotExist from django.db import transaction - from rest_framework import serializers + from cvat.apps.engine.serializers import BasicUserSerializer from cvat.apps.iam.utils import get_dummy_user + from .models import Invitation, Membership, Organization + class OrganizationReadSerializer(serializers.ModelSerializer): owner = BasicUserSerializer(allow_null=True) + class Meta: model = Organization - fields = ['id', 'slug', 'name', 'description', 'created_date', - 'updated_date', 'contact', 'owner'] + fields = [ + "id", + "slug", + "name", + "description", + "created_date", + "updated_date", + "contact", + "owner", + ] read_only_fields = fields + class BasicOrganizationSerializer(serializers.ModelSerializer): class Meta: model = Organization - fields = ['id', 'slug'] + fields = ["id", "slug"] read_only_fields = fields + class OrganizationWriteSerializer(serializers.ModelSerializer): def to_representation(self, instance): serializer = OrganizationReadSerializer(instance, context=self.context) @@ -37,12 +50,12 @@ def to_representation(self, instance): class Meta: model = Organization - fields = ['slug', 'name', 'description', 'contact', 'owner'] + fields = ["slug", "name", "description", "contact", "owner"] # TODO: at the moment isn't possible to change the owner. It should # be a separate feature. Need to change it together with corresponding # Membership. Also such operation should be well protected. - read_only_fields = ['owner'] + read_only_fields = ["owner"] def create(self, validated_data): organization = super().create(validated_data) @@ -51,36 +64,47 @@ def create(self, validated_data): organization=organization, is_active=True, joined_date=organization.created_date, - role=Membership.OWNER) + role=Membership.OWNER, + ) return organization + class InvitationReadSerializer(serializers.ModelSerializer): - role = serializers.ChoiceField(Membership.role.field.choices, - source='membership.role') - user = BasicUserSerializer(source='membership.user') + role = serializers.ChoiceField(Membership.role.field.choices, source="membership.role") + user = BasicUserSerializer(source="membership.user") organization = serializers.PrimaryKeyRelatedField( - queryset=Organization.objects.all(), - source='membership.organization') - organization_info = BasicOrganizationSerializer(source='membership.organization') + queryset=Organization.objects.all(), source="membership.organization" + ) + organization_info = BasicOrganizationSerializer(source="membership.organization") owner = BasicUserSerializer(allow_null=True) class Meta: model = Invitation - fields = ['key', 'created_date', 'owner', 'role', 'user', 'organization', 'expired', 'organization_info'] + fields = [ + "key", + "created_date", + "owner", + "role", + "user", + "organization", + "expired", + "organization_info", + ] read_only_fields = fields extra_kwargs = { - 'expired': { - 'allow_null': True, + "expired": { + "allow_null": True, } } + class InvitationWriteSerializer(serializers.ModelSerializer): - role = serializers.ChoiceField(Membership.role.field.choices, - source='membership.role') - email = serializers.EmailField(source='membership.user.email') + role = serializers.ChoiceField(Membership.role.field.choices, source="membership.role") + email = serializers.EmailField(source="membership.user.email") organization = serializers.PrimaryKeyRelatedField( - source='membership.organization', read_only=True) + source="membership.organization", read_only=True + ) def to_representation(self, instance): serializer = InvitationReadSerializer(instance, context=self.context) @@ -88,34 +112,35 @@ def to_representation(self, instance): class Meta: model = Invitation - fields = ['key', 'created_date', 'owner', 'role', 'organization', 'email'] - read_only_fields = ['key', 'created_date', 'owner', 'organization'] + fields = ["key", "created_date", "owner", "role", "organization", "email"] + read_only_fields = ["key", "created_date", "owner", "organization"] @transaction.atomic def create(self, validated_data): - membership_data = validated_data.pop('membership') - organization = validated_data.pop('organization') + membership_data = validated_data.pop("membership") + organization = validated_data.pop("organization") try: - user = get_user_model().objects.get( - email__iexact=membership_data['user']['email']) - del membership_data['user'] + user = get_user_model().objects.get(email__iexact=membership_data["user"]["email"]) + del membership_data["user"] except ObjectDoesNotExist: - user_email = membership_data['user']['email'] + user_email = membership_data["user"]["email"] user = User.objects.create_user(username=user_email, email=user_email) user.set_unusable_password() # User.objects.create_user(...) normalizes passed email and user.email can be different from original user_email - email = EmailAddress.objects.create(user=user, email=user.email, primary=True, verified=False) + email = EmailAddress.objects.create( + user=user, email=user.email, primary=True, verified=False + ) user.save() email.save() - del membership_data['user'] + del membership_data["user"] membership, created = Membership.objects.get_or_create( - defaults=membership_data, - user=user, organization=organization) + defaults=membership_data, user=user, organization=organization + ) if not created: - raise serializers.ValidationError('The user is a member of ' - 'the organization already.') - invitation = Invitation.objects.create(**validated_data, - membership=membership) + raise serializers.ValidationError( + "The user is a member of " "the organization already." + ) + invitation = Invitation.objects.create(**validated_data, membership=membership) return invitation @@ -132,20 +157,21 @@ def save(self, request, **kwargs): return invitation + class MembershipReadSerializer(serializers.ModelSerializer): user = BasicUserSerializer() class Meta: model = Membership - fields = ['id', 'user', 'organization', 'is_active', 'joined_date', 'role', - 'invitation'] + fields = ["id", "user", "organization", "is_active", "joined_date", "role", "invitation"] read_only_fields = fields extra_kwargs = { - 'invitation': { - 'allow_null': True, # owner of an organization does not have an invitation + "invitation": { + "allow_null": True, # owner of an organization does not have an invitation } } + class MembershipWriteSerializer(serializers.ModelSerializer): def to_representation(self, instance): serializer = MembershipReadSerializer(instance, context=self.context) @@ -153,8 +179,9 @@ def to_representation(self, instance): class Meta: model = Membership - fields = ['id', 'user', 'organization', 'is_active', 'joined_date', 'role'] - read_only_fields = ['user', 'organization', 'is_active', 'joined_date'] + fields = ["id", "user", "organization", "is_active", "joined_date", "role"] + read_only_fields = ["user", "organization", "is_active", "joined_date"] + class AcceptInvitationReadSerializer(serializers.Serializer): organization_slug = serializers.CharField() diff --git a/cvat/apps/organizations/throttle.py b/cvat/apps/organizations/throttle.py index 438538b61d4a..342b9463170b 100644 --- a/cvat/apps/organizations/throttle.py +++ b/cvat/apps/organizations/throttle.py @@ -4,5 +4,6 @@ from rest_framework.throttling import UserRateThrottle + class ResendOrganizationInvitationThrottle(UserRateThrottle): - rate = '5/hour' + rate = "5/hour" diff --git a/cvat/apps/organizations/urls.py b/cvat/apps/organizations/urls.py index 068f72b0968d..4ec7fdc628bc 100644 --- a/cvat/apps/organizations/urls.py +++ b/cvat/apps/organizations/urls.py @@ -3,11 +3,12 @@ # SPDX-License-Identifier: MIT from rest_framework.routers import DefaultRouter + from .views import InvitationViewSet, MembershipViewSet, OrganizationViewSet router = DefaultRouter(trailing_slash=False) -router.register('organizations', OrganizationViewSet) -router.register('invitations', InvitationViewSet) -router.register('memberships', MembershipViewSet) +router.register("organizations", OrganizationViewSet) +router.register("invitations", InvitationViewSet) +router.register("memberships", MembershipViewSet) urlpatterns = router.urls diff --git a/cvat/apps/organizations/views.py b/cvat/apps/organizations/views.py index 11b92b29cad8..dbb1eeec9a9c 100644 --- a/cvat/apps/organizations/views.py +++ b/cvat/apps/organizations/views.py @@ -3,76 +3,87 @@ # # SPDX-License-Identifier: MIT -from django.utils.crypto import get_random_string -from django.db import transaction from django.core.exceptions import ImproperlyConfigured - -from rest_framework import mixins, viewsets, status -from rest_framework.permissions import SAFE_METHODS +from django.db import transaction +from django.utils.crypto import get_random_string +from drf_spectacular.utils import OpenApiResponse, extend_schema, extend_schema_view +from rest_framework import mixins, status, viewsets from rest_framework.decorators import action +from rest_framework.permissions import SAFE_METHODS from rest_framework.response import Response -from drf_spectacular.utils import OpenApiResponse, extend_schema, extend_schema_view - +from cvat.apps.engine.mixins import PartialUpdateModelMixin from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS from cvat.apps.organizations.permissions import ( - InvitationPermission, MembershipPermission, OrganizationPermission) + InvitationPermission, + MembershipPermission, + OrganizationPermission, +) from cvat.apps.organizations.throttle import ResendOrganizationInvitationThrottle -from cvat.apps.engine.mixins import PartialUpdateModelMixin from .models import Invitation, Membership, Organization - from .serializers import ( - InvitationReadSerializer, InvitationWriteSerializer, - MembershipReadSerializer, MembershipWriteSerializer, - OrganizationReadSerializer, OrganizationWriteSerializer, - AcceptInvitationReadSerializer) + AcceptInvitationReadSerializer, + InvitationReadSerializer, + InvitationWriteSerializer, + MembershipReadSerializer, + MembershipWriteSerializer, + OrganizationReadSerializer, + OrganizationWriteSerializer, +) + -@extend_schema(tags=['organizations']) +@extend_schema(tags=["organizations"]) @extend_schema_view( retrieve=extend_schema( - summary='Get organization details', + summary="Get organization details", responses={ - '200': OrganizationReadSerializer, - }), + "200": OrganizationReadSerializer, + }, + ), list=extend_schema( - summary='List organizations', + summary="List organizations", responses={ - '200': OrganizationReadSerializer(many=True), - }), + "200": OrganizationReadSerializer(many=True), + }, + ), partial_update=extend_schema( - summary='Update an organization', + summary="Update an organization", request=OrganizationWriteSerializer(partial=True), responses={ - '200': OrganizationReadSerializer, # check OrganizationWriteSerializer.to_representation - }), + "200": OrganizationReadSerializer, # check OrganizationWriteSerializer.to_representation + }, + ), create=extend_schema( - summary='Create an organization', + summary="Create an organization", request=OrganizationWriteSerializer, responses={ - '201': OrganizationReadSerializer, # check OrganizationWriteSerializer.to_representation - }), + "201": OrganizationReadSerializer, # check OrganizationWriteSerializer.to_representation + }, + ), destroy=extend_schema( - summary='Delete an organization', + summary="Delete an organization", responses={ - '204': OpenApiResponse(description='The organization has been deleted'), - }) + "204": OpenApiResponse(description="The organization has been deleted"), + }, + ), ) -class OrganizationViewSet(viewsets.GenericViewSet, - mixins.RetrieveModelMixin, - mixins.ListModelMixin, - mixins.CreateModelMixin, - mixins.DestroyModelMixin, - PartialUpdateModelMixin, - ): - queryset = Organization.objects.select_related('owner').all() - search_fields = ('name', 'owner', 'slug') - filter_fields = list(search_fields) + ['id'] +class OrganizationViewSet( + viewsets.GenericViewSet, + mixins.RetrieveModelMixin, + mixins.ListModelMixin, + mixins.CreateModelMixin, + mixins.DestroyModelMixin, + PartialUpdateModelMixin, +): + queryset = Organization.objects.select_related("owner").all() + search_fields = ("name", "owner", "slug") + filter_fields = list(search_fields) + ["id"] simple_filters = list(search_fields) - lookup_fields = {'owner': 'owner__username'} + lookup_fields = {"owner": "owner__username"} ordering_fields = list(filter_fields) - ordering = '-id' - http_method_names = ['get', 'post', 'patch', 'delete', 'head', 'options'] + ordering = "-id" + http_method_names = ["get", "post", "patch", "delete", "head", "options"] iam_organization_field = None def get_queryset(self): @@ -88,50 +99,60 @@ def get_serializer_class(self): return OrganizationWriteSerializer def perform_create(self, serializer): - extra_kwargs = { 'owner': self.request.user } - if not serializer.validated_data.get('name'): - extra_kwargs.update({ 'name': serializer.validated_data['slug'] }) + extra_kwargs = {"owner": self.request.user} + if not serializer.validated_data.get("name"): + extra_kwargs.update({"name": serializer.validated_data["slug"]}) serializer.save(**extra_kwargs) class Meta: model = Membership - fields = ("user", ) + fields = ("user",) -@extend_schema(tags=['memberships']) + +@extend_schema(tags=["memberships"]) @extend_schema_view( retrieve=extend_schema( - summary='Get membership details', + summary="Get membership details", responses={ - '200': MembershipReadSerializer, - }), + "200": MembershipReadSerializer, + }, + ), list=extend_schema( - summary='List memberships', + summary="List memberships", responses={ - '200': MembershipReadSerializer(many=True), - }), + "200": MembershipReadSerializer(many=True), + }, + ), partial_update=extend_schema( - summary='Update a membership', + summary="Update a membership", request=MembershipWriteSerializer(partial=True), responses={ - '200': MembershipReadSerializer, # check MembershipWriteSerializer.to_representation - }), + "200": MembershipReadSerializer, # check MembershipWriteSerializer.to_representation + }, + ), destroy=extend_schema( - summary='Delete a membership', + summary="Delete a membership", responses={ - '204': OpenApiResponse(description='The membership has been deleted'), - }) + "204": OpenApiResponse(description="The membership has been deleted"), + }, + ), ) -class MembershipViewSet(mixins.RetrieveModelMixin, mixins.DestroyModelMixin, - mixins.ListModelMixin, PartialUpdateModelMixin, viewsets.GenericViewSet): - queryset = Membership.objects.select_related('invitation', 'user').all() - ordering = '-id' - http_method_names = ['get', 'patch', 'delete', 'head', 'options'] - search_fields = ('user', 'role') - filter_fields = list(search_fields) + ['id'] +class MembershipViewSet( + mixins.RetrieveModelMixin, + mixins.DestroyModelMixin, + mixins.ListModelMixin, + PartialUpdateModelMixin, + viewsets.GenericViewSet, +): + queryset = Membership.objects.select_related("invitation", "user").all() + ordering = "-id" + http_method_names = ["get", "patch", "delete", "head", "options"] + search_fields = ("user", "role") + filter_fields = list(search_fields) + ["id"] simple_filters = list(search_fields) ordering_fields = list(filter_fields) - lookup_fields = {'user': 'user__username'} - iam_organization_field = 'organization' + lookup_fields = {"user": "user__username"} + iam_organization_field = "organization" def get_serializer_class(self): if self.request.method in SAFE_METHODS: @@ -142,86 +163,98 @@ def get_serializer_class(self): def get_queryset(self): queryset = super().get_queryset() - if self.action == 'list': + if self.action == "list": permission = MembershipPermission.create_scope_list(self.request) queryset = permission.filter(queryset) return queryset -@extend_schema(tags=['invitations']) + +@extend_schema(tags=["invitations"]) @extend_schema_view( retrieve=extend_schema( - summary='Get invitation details', + summary="Get invitation details", responses={ - '200': InvitationReadSerializer, - }), + "200": InvitationReadSerializer, + }, + ), list=extend_schema( - summary='List invitations', + summary="List invitations", responses={ - '200': InvitationReadSerializer(many=True), - }), + "200": InvitationReadSerializer(many=True), + }, + ), partial_update=extend_schema( - summary='Update an invitation', + summary="Update an invitation", request=InvitationWriteSerializer(partial=True), responses={ - '200': InvitationReadSerializer, # check InvitationWriteSerializer.to_representation - }), + "200": InvitationReadSerializer, # check InvitationWriteSerializer.to_representation + }, + ), create=extend_schema( - summary='Create an invitation', + summary="Create an invitation", request=InvitationWriteSerializer, parameters=ORGANIZATION_OPEN_API_PARAMETERS, responses={ - '201': InvitationReadSerializer, # check InvitationWriteSerializer.to_representation - }), + "201": InvitationReadSerializer, # check InvitationWriteSerializer.to_representation + }, + ), destroy=extend_schema( - summary='Delete an invitation', + summary="Delete an invitation", responses={ - '204': OpenApiResponse(description='The invitation has been deleted'), - }), + "204": OpenApiResponse(description="The invitation has been deleted"), + }, + ), accept=extend_schema( - operation_id='invitations_accept', + operation_id="invitations_accept", request=None, - summary='Accept an invitation', + summary="Accept an invitation", responses={ - '200': OpenApiResponse(response=AcceptInvitationReadSerializer, description='The invitation is accepted'), - '400': OpenApiResponse(description='The invitation is expired or already accepted'), - }), + "200": OpenApiResponse( + response=AcceptInvitationReadSerializer, description="The invitation is accepted" + ), + "400": OpenApiResponse(description="The invitation is expired or already accepted"), + }, + ), decline=extend_schema( - operation_id='invitations_decline', + operation_id="invitations_decline", request=None, - summary='Decline an invitation', + summary="Decline an invitation", responses={ - '204': OpenApiResponse(description='The invitation has been declined'), - }), + "204": OpenApiResponse(description="The invitation has been declined"), + }, + ), resend=extend_schema( - operation_id='invitations_resend', - summary='Resend an invitation', + operation_id="invitations_resend", + summary="Resend an invitation", request=None, responses={ - '204': OpenApiResponse(description='Invitation has been sent'), - '400': OpenApiResponse(description='The invitation is already accepted'), - }), + "204": OpenApiResponse(description="Invitation has been sent"), + "400": OpenApiResponse(description="The invitation is already accepted"), + }, + ), ) -class InvitationViewSet(viewsets.GenericViewSet, - mixins.RetrieveModelMixin, - mixins.ListModelMixin, - PartialUpdateModelMixin, - mixins.CreateModelMixin, - mixins.DestroyModelMixin, - ): +class InvitationViewSet( + viewsets.GenericViewSet, + mixins.RetrieveModelMixin, + mixins.ListModelMixin, + PartialUpdateModelMixin, + mixins.CreateModelMixin, + mixins.DestroyModelMixin, +): queryset = Invitation.objects.all() - http_method_names = ['get', 'post', 'patch', 'delete', 'head', 'options'] - iam_organization_field = 'membership__organization' + http_method_names = ["get", "post", "patch", "delete", "head", "options"] + iam_organization_field = "membership__organization" - search_fields = ('owner',) - filter_fields = list(search_fields) + ['user_id', 'accepted'] + search_fields = ("owner",) + filter_fields = list(search_fields) + ["user_id", "accepted"] simple_filters = list(search_fields) - ordering_fields = list(simple_filters) + ['created_date'] - ordering = '-created_date' + ordering_fields = list(simple_filters) + ["created_date"] + ordering = "-created_date" lookup_fields = { - 'owner': 'owner__username', - 'user_id': 'membership__user__id', - 'accepted': 'membership__is_active', + "owner": "owner__username", + "user_id": "membership__user__id", + "accepted": "membership__is_active", } def get_serializer_class(self): @@ -242,7 +275,10 @@ def create(self, request): try: self.perform_create(serializer) except ImproperlyConfigured: - return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR, data="Email backend is not configured.") + return Response( + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + data="Email backend is not configured.", + ) return Response(serializer.data, status=status.HTTP_201_CREATED) @@ -250,51 +286,75 @@ def perform_create(self, serializer): serializer.save( owner=self.request.user, key=get_random_string(length=64), - organization=self.request.iam_context['organization'], + organization=self.request.iam_context["organization"], request=self.request, ) def perform_update(self, serializer): - if 'accepted' in self.request.query_params: + if "accepted" in self.request.query_params: serializer.instance.accept() else: super().perform_update(serializer) @transaction.atomic - @action(detail=True, methods=['POST'], url_path='accept') + @action(detail=True, methods=["POST"], url_path="accept") def accept(self, request, pk): try: - invitation = self.get_object() # force to call check_object_permissions + invitation = self.get_object() # force to call check_object_permissions if invitation.expired: - return Response(status=status.HTTP_400_BAD_REQUEST, data="Your invitation is expired. Please contact organization owner to renew it.") + return Response( + status=status.HTTP_400_BAD_REQUEST, + data="Your invitation is expired. Please contact organization owner to renew it.", + ) if invitation.membership.is_active: - return Response(status=status.HTTP_400_BAD_REQUEST, data="Your invitation is already accepted.") + return Response( + status=status.HTTP_400_BAD_REQUEST, data="Your invitation is already accepted." + ) invitation.accept() - response_serializer = AcceptInvitationReadSerializer(data={'organization_slug': invitation.membership.organization.slug}) + response_serializer = AcceptInvitationReadSerializer( + data={"organization_slug": invitation.membership.organization.slug} + ) response_serializer.is_valid(raise_exception=True) return Response(status=status.HTTP_200_OK, data=response_serializer.data) except Invitation.DoesNotExist: - return Response(status=status.HTTP_404_NOT_FOUND, data="This invitation does not exist. Please contact organization owner.") + return Response( + status=status.HTTP_404_NOT_FOUND, + data="This invitation does not exist. Please contact organization owner.", + ) - @action(detail=True, methods=['POST'], url_path='resend', throttle_classes=[ResendOrganizationInvitationThrottle]) + @action( + detail=True, + methods=["POST"], + url_path="resend", + throttle_classes=[ResendOrganizationInvitationThrottle], + ) def resend(self, request, pk): try: - invitation = self.get_object() # force to call check_object_permissions + invitation = self.get_object() # force to call check_object_permissions if invitation.membership.is_active: - return Response(status=status.HTTP_400_BAD_REQUEST, data="This invitation is already accepted.") + return Response( + status=status.HTTP_400_BAD_REQUEST, data="This invitation is already accepted." + ) invitation.send(request) return Response(status=status.HTTP_204_NO_CONTENT) except Invitation.DoesNotExist: - return Response(status=status.HTTP_404_NOT_FOUND, data="This invitation does not exist.") + return Response( + status=status.HTTP_404_NOT_FOUND, data="This invitation does not exist." + ) except ImproperlyConfigured: - return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR, data="Email backend is not configured.") + return Response( + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + data="Email backend is not configured.", + ) - @action(detail=True, methods=['POST'], url_path='decline') + @action(detail=True, methods=["POST"], url_path="decline") def decline(self, request, pk): try: - invitation = self.get_object() # force to call check_object_permissions + invitation = self.get_object() # force to call check_object_permissions membership = invitation.membership membership.delete() return Response(status=status.HTTP_204_NO_CONTENT) except Invitation.DoesNotExist: - return Response(status=status.HTTP_404_NOT_FOUND, data="This invitation does not exist.") + return Response( + status=status.HTTP_404_NOT_FOUND, data="This invitation does not exist." + ) diff --git a/cvat/apps/quality_control/migrations/0006_rename_match_empty_frames_qualitysettings_empty_is_annotated.py b/cvat/apps/quality_control/migrations/0006_rename_match_empty_frames_qualitysettings_empty_is_annotated.py new file mode 100644 index 000000000000..ea2f74927309 --- /dev/null +++ b/cvat/apps/quality_control/migrations/0006_rename_match_empty_frames_qualitysettings_empty_is_annotated.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.15 on 2024-12-29 19:08 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("quality_control", "0005_qualitysettings_match_empty"), + ] + + operations = [ + migrations.RenameField( + model_name="qualitysettings", + old_name="match_empty_frames", + new_name="empty_is_annotated", + ), + ] diff --git a/cvat/apps/quality_control/models.py b/cvat/apps/quality_control/models.py index a5359e4fe944..c521ac276f31 100644 --- a/cvat/apps/quality_control/models.py +++ b/cvat/apps/quality_control/models.py @@ -235,7 +235,7 @@ class QualitySettings(models.Model): compare_attributes = models.BooleanField() - match_empty_frames = models.BooleanField(default=False) + empty_is_annotated = models.BooleanField(default=False) target_metric = models.CharField( max_length=32, diff --git a/cvat/apps/quality_control/quality_reports.py b/cvat/apps/quality_control/quality_reports.py index 25b5c962dc26..f757aeabc61a 100644 --- a/cvat/apps/quality_control/quality_reports.py +++ b/cvat/apps/quality_control/quality_reports.py @@ -215,10 +215,11 @@ class ComparisonParameters(_Serializable): panoptic_comparison: bool = True "Use only the visible part of the masks and polygons in comparisons" - match_empty_frames: bool = False + empty_is_annotated: bool = False """ - Consider unannotated (empty) frames as matching. If disabled, quality metrics, such as accuracy, - will be 0 if both GT and DS frames have no annotations. When enabled, they will be 1 instead. + Consider unannotated (empty) frames virtually annotated as "nothing". + If disabled, quality metrics, such as accuracy, will be 0 if both GT and DS frames + have no annotations. When enabled, they will be 1 instead. This will also add virtual annotations to empty frames in the comparison results. """ @@ -1977,15 +1978,20 @@ def _find_closest_unmatched_shape(shape: dm.Annotation): gt_label_idx = label_id_map[gt_ann.label] if gt_ann else self._UNMATCHED_IDX confusion_matrix[ds_label_idx, gt_label_idx] += 1 - if self.settings.match_empty_frames and not gt_item.annotations and not ds_item.annotations: + if self.settings.empty_is_annotated: # Add virtual annotations for empty frames - valid_labels_count = 1 - total_labels_count = 1 + if not gt_item.annotations and not ds_item.annotations: + valid_labels_count = 1 + total_labels_count = 1 - valid_shapes_count = 1 - total_shapes_count = 1 - ds_shapes_count = 1 - gt_shapes_count = 1 + valid_shapes_count = 1 + total_shapes_count = 1 + + if not ds_item.annotations: + ds_shapes_count = 1 + + if not gt_item.annotations: + gt_shapes_count = 1 self._frame_results[frame_id] = ComparisonReportFrameSummary( annotations=self._generate_frame_annotations_summary( @@ -2078,12 +2084,17 @@ def _generate_frame_annotations_summary( ) -> ComparisonReportAnnotationsSummary: summary = self._compute_annotations_summary(confusion_matrix, confusion_matrix_labels) - if self.settings.match_empty_frames and summary.total_count == 0: + if self.settings.empty_is_annotated: # Add virtual annotations for empty frames - summary.valid_count = 1 - summary.total_count = 1 - summary.ds_count = 1 - summary.gt_count = 1 + if not summary.total_count: + summary.valid_count = 1 + summary.total_count = 1 + + if not summary.ds_count: + summary.ds_count = 1 + + if not summary.gt_count: + summary.gt_count = 1 return summary @@ -2108,14 +2119,26 @@ def _generate_dataset_annotations_summary( ), ) mean_ious = [] - empty_frame_count = 0 + empty_gt_frames = set() + empty_ds_frames = set() confusion_matrix_labels, confusion_matrix, _ = self._make_zero_confusion_matrix() - for frame_result in frame_summaries.values(): + for frame_id, frame_result in frame_summaries.items(): confusion_matrix += frame_result.annotations.confusion_matrix.rows - if not np.any(frame_result.annotations.confusion_matrix.rows): - empty_frame_count += 1 + if self.settings.empty_is_annotated and not np.any( + frame_result.annotations.confusion_matrix.rows[ + np.triu_indices_from(frame_result.annotations.confusion_matrix.rows) + ] + ): + empty_ds_frames.add(frame_id) + + if self.settings.empty_is_annotated and not np.any( + frame_result.annotations.confusion_matrix.rows[ + np.tril_indices_from(frame_result.annotations.confusion_matrix.rows) + ] + ): + empty_gt_frames.add(frame_id) if annotation_components is None: annotation_components = deepcopy(frame_result.annotation_components) @@ -2128,13 +2151,13 @@ def _generate_dataset_annotations_summary( confusion_matrix, confusion_matrix_labels ) - if self.settings.match_empty_frames and empty_frame_count: + if self.settings.empty_is_annotated: # Add virtual annotations for empty frames, # they are not included in the confusion matrix - annotation_summary.valid_count += empty_frame_count - annotation_summary.total_count += empty_frame_count - annotation_summary.ds_count += empty_frame_count - annotation_summary.gt_count += empty_frame_count + annotation_summary.valid_count += len(empty_ds_frames & empty_gt_frames) + annotation_summary.total_count += len(empty_ds_frames | empty_gt_frames) + annotation_summary.ds_count += len(empty_ds_frames) + annotation_summary.gt_count += len(empty_gt_frames) # Cannot be computed in accumulate() annotation_components.shape.mean_iou = np.mean(mean_ious) diff --git a/cvat/apps/quality_control/serializers.py b/cvat/apps/quality_control/serializers.py index 6164abc12200..11a5e0d8b02e 100644 --- a/cvat/apps/quality_control/serializers.py +++ b/cvat/apps/quality_control/serializers.py @@ -92,7 +92,7 @@ class Meta: "object_visibility_threshold", "panoptic_comparison", "compare_attributes", - "match_empty_frames", + "empty_is_annotated", ) read_only_fields = ( "id", @@ -100,7 +100,7 @@ class Meta: ) extra_kwargs = {k: {"required": False} for k in fields} - extra_kwargs.setdefault("match_empty_frames", {}).setdefault("default", False) + extra_kwargs.setdefault("empty_is_annotated", {}).setdefault("default", False) for field_name, help_text in { "target_metric": "The primary metric used for quality estimation", @@ -166,9 +166,9 @@ class Meta: Use only the visible part of the masks and polygons in comparisons """, "compare_attributes": "Enables or disables annotation attribute comparison", - "match_empty_frames": """ - Count empty frames as matching. This affects target metrics like accuracy in cases - there are no annotations. If disabled, frames without annotations + "empty_is_annotated": """ + Consider empty frames annotated as "empty". This affects target metrics like + accuracy in cases there are no annotations. If disabled, frames without annotations are counted as not matching (accuracy is 0). If enabled, accuracy will be 1 instead. This will also add virtual annotations to empty frames in the comparison results. """, diff --git a/cvat/apps/webhooks/apps.py b/cvat/apps/webhooks/apps.py index ac193baed755..0b4cf34198f7 100644 --- a/cvat/apps/webhooks/apps.py +++ b/cvat/apps/webhooks/apps.py @@ -9,7 +9,8 @@ class WebhooksConfig(AppConfig): name = "cvat.apps.webhooks" def ready(self): - from . import signals # pylint: disable=unused-import - from cvat.apps.iam.permissions import load_app_permissions + load_app_permissions(self) + + from . import signals # pylint: disable=unused-import diff --git a/cvat/apps/webhooks/event_type.py b/cvat/apps/webhooks/event_type.py index 59cdb6cf99ed..ef98e5212824 100644 --- a/cvat/apps/webhooks/event_type.py +++ b/cvat/apps/webhooks/event_type.py @@ -47,7 +47,11 @@ class AllEvents: class ProjectEvents: webhook_type = WebhookTypeChoice.PROJECT - events = [*Events.select(["task", "job", "label", "issue", "comment"]), event_name("update", "project"), event_name("delete", "project")] + events = [ + *Events.select(["task", "job", "label", "issue", "comment"]), + event_name("update", "project"), + event_name("delete", "project"), + ] class OrganizationEvents: diff --git a/cvat/apps/webhooks/migrations/0001_initial.py b/cvat/apps/webhooks/migrations/0001_initial.py index fe8f296b0514..e3638bd6be97 100644 --- a/cvat/apps/webhooks/migrations/0001_initial.py +++ b/cvat/apps/webhooks/migrations/0001_initial.py @@ -1,9 +1,10 @@ # Generated by Django 3.2.15 on 2022-09-19 08:26 -import cvat.apps.webhooks.models +import django.db.models.deletion from django.conf import settings from django.db import migrations, models -import django.db.models.deletion + +import cvat.apps.webhooks.models class Migration(migrations.Migration): @@ -11,54 +12,120 @@ class Migration(migrations.Migration): initial = True dependencies = [ - ('engine', '0060_alter_label_parent'), + ("engine", "0060_alter_label_parent"), migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ('organizations', '0001_initial'), + ("organizations", "0001_initial"), ] operations = [ migrations.CreateModel( - name='Webhook', + name="Webhook", fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('target_url', models.URLField()), - ('description', models.CharField(blank=True, default='', max_length=128)), - ('events', models.CharField(default='', max_length=4096)), - ('type', models.CharField(choices=[('organization', 'ORGANIZATION'), ('project', 'PROJECT')], max_length=16)), - ('content_type', models.CharField(choices=[('application/json', 'JSON')], default=cvat.apps.webhooks.models.WebhookContentTypeChoice['JSON'], max_length=64)), - ('secret', models.CharField(blank=True, default='', max_length=64)), - ('is_active', models.BooleanField(default=True)), - ('enable_ssl', models.BooleanField(default=True)), - ('created_date', models.DateTimeField(auto_now_add=True)), - ('updated_date', models.DateTimeField(auto_now=True)), - ('organization', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='+', to='organizations.organization')), - ('owner', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='+', to=settings.AUTH_USER_MODEL)), - ('project', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='+', to='engine.project')), + ( + "id", + models.AutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("target_url", models.URLField()), + ("description", models.CharField(blank=True, default="", max_length=128)), + ("events", models.CharField(default="", max_length=4096)), + ( + "type", + models.CharField( + choices=[("organization", "ORGANIZATION"), ("project", "PROJECT")], + max_length=16, + ), + ), + ( + "content_type", + models.CharField( + choices=[("application/json", "JSON")], + default=cvat.apps.webhooks.models.WebhookContentTypeChoice["JSON"], + max_length=64, + ), + ), + ("secret", models.CharField(blank=True, default="", max_length=64)), + ("is_active", models.BooleanField(default=True)), + ("enable_ssl", models.BooleanField(default=True)), + ("created_date", models.DateTimeField(auto_now_add=True)), + ("updated_date", models.DateTimeField(auto_now=True)), + ( + "organization", + models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to="organizations.organization", + ), + ), + ( + "owner", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="+", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "project", + models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to="engine.project", + ), + ), ], options={ - 'default_permissions': (), + "default_permissions": (), }, ), migrations.CreateModel( - name='WebhookDelivery', + name="WebhookDelivery", fields=[ - ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('event', models.CharField(max_length=64)), - ('status_code', models.CharField(max_length=128, null=True)), - ('redelivery', models.BooleanField(default=False)), - ('created_date', models.DateTimeField(auto_now_add=True)), - ('updated_date', models.DateTimeField(auto_now=True)), - ('changed_fields', models.CharField(default='', max_length=4096)), - ('request', models.JSONField(default=dict)), - ('response', models.JSONField(default=dict)), - ('webhook', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='deliveries', to='webhooks.webhook')), + ( + "id", + models.AutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("event", models.CharField(max_length=64)), + ("status_code", models.CharField(max_length=128, null=True)), + ("redelivery", models.BooleanField(default=False)), + ("created_date", models.DateTimeField(auto_now_add=True)), + ("updated_date", models.DateTimeField(auto_now=True)), + ("changed_fields", models.CharField(default="", max_length=4096)), + ("request", models.JSONField(default=dict)), + ("response", models.JSONField(default=dict)), + ( + "webhook", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="deliveries", + to="webhooks.webhook", + ), + ), ], options={ - 'default_permissions': (), + "default_permissions": (), }, ), migrations.AddConstraint( - model_name='webhook', - constraint=models.CheckConstraint(check=models.Q(models.Q(('project_id__isnull', False), ('type', 'project')), models.Q(('organization_id__isnull', False), ('project_id__isnull', True), ('type', 'organization')), _connector='OR'), name='webhooks_project_or_organization'), + model_name="webhook", + constraint=models.CheckConstraint( + check=models.Q( + models.Q(("project_id__isnull", False), ("type", "project")), + models.Q( + ("organization_id__isnull", False), + ("project_id__isnull", True), + ("type", "organization"), + ), + _connector="OR", + ), + name="webhooks_project_or_organization", + ), ), ] diff --git a/cvat/apps/webhooks/migrations/0002_alter_webhookdelivery_status_code.py b/cvat/apps/webhooks/migrations/0002_alter_webhookdelivery_status_code.py index fd1a2397d249..0429b1445117 100644 --- a/cvat/apps/webhooks/migrations/0002_alter_webhookdelivery_status_code.py +++ b/cvat/apps/webhooks/migrations/0002_alter_webhookdelivery_status_code.py @@ -6,13 +6,77 @@ class Migration(migrations.Migration): dependencies = [ - ('webhooks', '0001_initial'), + ("webhooks", "0001_initial"), ] operations = [ migrations.AlterField( - model_name='webhookdelivery', - name='status_code', - field=models.IntegerField(choices=[('CONTINUE', 100), ('SWITCHING_PROTOCOLS', 101), ('PROCESSING', 102), ('OK', 200), ('CREATED', 201), ('ACCEPTED', 202), ('NON_AUTHORITATIVE_INFORMATION', 203), ('NO_CONTENT', 204), ('RESET_CONTENT', 205), ('PARTIAL_CONTENT', 206), ('MULTI_STATUS', 207), ('ALREADY_REPORTED', 208), ('IM_USED', 226), ('MULTIPLE_CHOICES', 300), ('MOVED_PERMANENTLY', 301), ('FOUND', 302), ('SEE_OTHER', 303), ('NOT_MODIFIED', 304), ('USE_PROXY', 305), ('TEMPORARY_REDIRECT', 307), ('PERMANENT_REDIRECT', 308), ('BAD_REQUEST', 400), ('UNAUTHORIZED', 401), ('PAYMENT_REQUIRED', 402), ('FORBIDDEN', 403), ('NOT_FOUND', 404), ('METHOD_NOT_ALLOWED', 405), ('NOT_ACCEPTABLE', 406), ('PROXY_AUTHENTICATION_REQUIRED', 407), ('REQUEST_TIMEOUT', 408), ('CONFLICT', 409), ('GONE', 410), ('LENGTH_REQUIRED', 411), ('PRECONDITION_FAILED', 412), ('REQUEST_ENTITY_TOO_LARGE', 413), ('REQUEST_URI_TOO_LONG', 414), ('UNSUPPORTED_MEDIA_TYPE', 415), ('REQUESTED_RANGE_NOT_SATISFIABLE', 416), ('EXPECTATION_FAILED', 417), ('MISDIRECTED_REQUEST', 421), ('UNPROCESSABLE_ENTITY', 422), ('LOCKED', 423), ('FAILED_DEPENDENCY', 424), ('UPGRADE_REQUIRED', 426), ('PRECONDITION_REQUIRED', 428), ('TOO_MANY_REQUESTS', 429), ('REQUEST_HEADER_FIELDS_TOO_LARGE', 431), ('UNAVAILABLE_FOR_LEGAL_REASONS', 451), ('INTERNAL_SERVER_ERROR', 500), ('NOT_IMPLEMENTED', 501), ('BAD_GATEWAY', 502), ('SERVICE_UNAVAILABLE', 503), ('GATEWAY_TIMEOUT', 504), ('HTTP_VERSION_NOT_SUPPORTED', 505), ('VARIANT_ALSO_NEGOTIATES', 506), ('INSUFFICIENT_STORAGE', 507), ('LOOP_DETECTED', 508), ('NOT_EXTENDED', 510), ('NETWORK_AUTHENTICATION_REQUIRED', 511)], default=None, null=True), + model_name="webhookdelivery", + name="status_code", + field=models.IntegerField( + choices=[ + ("CONTINUE", 100), + ("SWITCHING_PROTOCOLS", 101), + ("PROCESSING", 102), + ("OK", 200), + ("CREATED", 201), + ("ACCEPTED", 202), + ("NON_AUTHORITATIVE_INFORMATION", 203), + ("NO_CONTENT", 204), + ("RESET_CONTENT", 205), + ("PARTIAL_CONTENT", 206), + ("MULTI_STATUS", 207), + ("ALREADY_REPORTED", 208), + ("IM_USED", 226), + ("MULTIPLE_CHOICES", 300), + ("MOVED_PERMANENTLY", 301), + ("FOUND", 302), + ("SEE_OTHER", 303), + ("NOT_MODIFIED", 304), + ("USE_PROXY", 305), + ("TEMPORARY_REDIRECT", 307), + ("PERMANENT_REDIRECT", 308), + ("BAD_REQUEST", 400), + ("UNAUTHORIZED", 401), + ("PAYMENT_REQUIRED", 402), + ("FORBIDDEN", 403), + ("NOT_FOUND", 404), + ("METHOD_NOT_ALLOWED", 405), + ("NOT_ACCEPTABLE", 406), + ("PROXY_AUTHENTICATION_REQUIRED", 407), + ("REQUEST_TIMEOUT", 408), + ("CONFLICT", 409), + ("GONE", 410), + ("LENGTH_REQUIRED", 411), + ("PRECONDITION_FAILED", 412), + ("REQUEST_ENTITY_TOO_LARGE", 413), + ("REQUEST_URI_TOO_LONG", 414), + ("UNSUPPORTED_MEDIA_TYPE", 415), + ("REQUESTED_RANGE_NOT_SATISFIABLE", 416), + ("EXPECTATION_FAILED", 417), + ("MISDIRECTED_REQUEST", 421), + ("UNPROCESSABLE_ENTITY", 422), + ("LOCKED", 423), + ("FAILED_DEPENDENCY", 424), + ("UPGRADE_REQUIRED", 426), + ("PRECONDITION_REQUIRED", 428), + ("TOO_MANY_REQUESTS", 429), + ("REQUEST_HEADER_FIELDS_TOO_LARGE", 431), + ("UNAVAILABLE_FOR_LEGAL_REASONS", 451), + ("INTERNAL_SERVER_ERROR", 500), + ("NOT_IMPLEMENTED", 501), + ("BAD_GATEWAY", 502), + ("SERVICE_UNAVAILABLE", 503), + ("GATEWAY_TIMEOUT", 504), + ("HTTP_VERSION_NOT_SUPPORTED", 505), + ("VARIANT_ALSO_NEGOTIATES", 506), + ("INSUFFICIENT_STORAGE", 507), + ("LOOP_DETECTED", 508), + ("NOT_EXTENDED", 510), + ("NETWORK_AUTHENTICATION_REQUIRED", 511), + ], + default=None, + null=True, + ), ), ] diff --git a/cvat/apps/webhooks/migrations/0003_alter_webhookdelivery_status_code.py b/cvat/apps/webhooks/migrations/0003_alter_webhookdelivery_status_code.py index 676f03a2dc9b..234a4d685d58 100644 --- a/cvat/apps/webhooks/migrations/0003_alter_webhookdelivery_status_code.py +++ b/cvat/apps/webhooks/migrations/0003_alter_webhookdelivery_status_code.py @@ -6,13 +6,13 @@ class Migration(migrations.Migration): dependencies = [ - ('webhooks', '0002_alter_webhookdelivery_status_code'), + ("webhooks", "0002_alter_webhookdelivery_status_code"), ] operations = [ migrations.AlterField( - model_name='webhookdelivery', - name='status_code', + model_name="webhookdelivery", + name="status_code", field=models.PositiveIntegerField(default=None, null=True), ), ] diff --git a/cvat/apps/webhooks/migrations/0004_alter_webhook_target_url.py b/cvat/apps/webhooks/migrations/0004_alter_webhook_target_url.py index 00be6a309df2..f2f716f8cd88 100644 --- a/cvat/apps/webhooks/migrations/0004_alter_webhook_target_url.py +++ b/cvat/apps/webhooks/migrations/0004_alter_webhook_target_url.py @@ -6,13 +6,13 @@ class Migration(migrations.Migration): dependencies = [ - ('webhooks', '0003_alter_webhookdelivery_status_code'), + ("webhooks", "0003_alter_webhookdelivery_status_code"), ] operations = [ migrations.AlterField( - model_name='webhook', - name='target_url', + model_name="webhook", + name="target_url", field=models.URLField(max_length=8192), ), ] diff --git a/cvat/apps/webhooks/models.py b/cvat/apps/webhooks/models.py index 104faccd60a4..650cd814fae0 100644 --- a/cvat/apps/webhooks/models.py +++ b/cvat/apps/webhooks/models.py @@ -53,9 +53,7 @@ class Webhook(TimestampedModel): owner = models.ForeignKey( User, null=True, blank=True, on_delete=models.SET_NULL, related_name="+" ) - project = models.ForeignKey( - Project, null=True, on_delete=models.CASCADE, related_name="+" - ) + project = models.ForeignKey(Project, null=True, on_delete=models.CASCADE, related_name="+") organization = models.ForeignKey( Organization, null=True, on_delete=models.CASCADE, related_name="+" ) @@ -66,9 +64,7 @@ class Meta: models.CheckConstraint( name="webhooks_project_or_organization", check=( - models.Q( - type=WebhookTypeChoice.PROJECT.value, project_id__isnull=False - ) + models.Q(type=WebhookTypeChoice.PROJECT.value, project_id__isnull=False) | models.Q( type=WebhookTypeChoice.ORGANIZATION.value, project_id__isnull=True, @@ -80,9 +76,7 @@ class Meta: class WebhookDelivery(TimestampedModel): - webhook = models.ForeignKey( - Webhook, on_delete=models.CASCADE, related_name="deliveries" - ) + webhook = models.ForeignKey(Webhook, on_delete=models.CASCADE, related_name="deliveries") event = models.CharField(max_length=64) status_code = models.PositiveIntegerField(null=True, default=None) diff --git a/cvat/apps/webhooks/permissions.py b/cvat/apps/webhooks/permissions.py index e5d132c55de6..3ce72bd350a4 100644 --- a/cvat/apps/webhooks/permissions.py +++ b/cvat/apps/webhooks/permissions.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: MIT from django.conf import settings - from rest_framework.exceptions import ValidationError from cvat.apps.engine.models import Project @@ -13,27 +12,29 @@ from .models import WebhookTypeChoice + class WebhookPermission(OpenPolicyAgentPermission): class Scopes(StrEnum): - CREATE = 'create' - CREATE_IN_PROJECT = 'create@project' - CREATE_IN_ORG = 'create@organization' - DELETE = 'delete' - UPDATE = 'update' - LIST = 'list' - VIEW = 'view' + CREATE = "create" + CREATE_IN_PROJECT = "create@project" + CREATE_IN_ORG = "create@organization" + DELETE = "delete" + UPDATE = "update" + LIST = "list" + VIEW = "view" @classmethod def create(cls, request, view, obj, iam_context): permissions = [] - if view.basename == 'webhook': - project_id = request.data.get('project_id') + if view.basename == "webhook": + project_id = request.data.get("project_id") for scope in cls.get_scopes(request, view, obj): - self = cls.create_base_perm(request, view, scope, iam_context, obj, - project_id=project_id) + self = cls.create_base_perm( + request, view, scope, iam_context, obj, project_id=project_id + ) permissions.append(self) - owner = request.data.get('owner_id') or request.data.get('owner') + owner = request.data.get("owner_id") or request.data.get("owner") if owner: perm = UserPermission.create_scope_view(iam_context, owner) permissions.append(perm) @@ -46,29 +47,29 @@ def create(cls, request, view, obj, iam_context): def __init__(self, **kwargs): super().__init__(**kwargs) - self.url = settings.IAM_OPA_DATA_URL + '/webhooks/allow' + self.url = settings.IAM_OPA_DATA_URL + "/webhooks/allow" @staticmethod def get_scopes(request, view, obj): Scopes = __class__.Scopes scope = { - ('create', 'POST'): Scopes.CREATE, - ('destroy', 'DELETE'): Scopes.DELETE, - ('partial_update', 'PATCH'): Scopes.UPDATE, - ('update', 'PUT'): Scopes.UPDATE, - ('list', 'GET'): Scopes.LIST, - ('retrieve', 'GET'): Scopes.VIEW, - ('ping', 'POST'): Scopes.UPDATE, - ('deliveries', 'GET'): Scopes.VIEW, - ('retrieve_delivery', 'GET'): Scopes.VIEW, - ('redelivery', 'POST'): Scopes.UPDATE, + ("create", "POST"): Scopes.CREATE, + ("destroy", "DELETE"): Scopes.DELETE, + ("partial_update", "PATCH"): Scopes.UPDATE, + ("update", "PUT"): Scopes.UPDATE, + ("list", "GET"): Scopes.LIST, + ("retrieve", "GET"): Scopes.VIEW, + ("ping", "POST"): Scopes.UPDATE, + ("deliveries", "GET"): Scopes.VIEW, + ("retrieve_delivery", "GET"): Scopes.VIEW, + ("redelivery", "POST"): Scopes.UPDATE, }[(view.action, request.method)] scopes = [] if scope == Scopes.CREATE: - webhook_type = request.data.get('type') + webhook_type = request.data.get("type") if webhook_type in [m.value for m in WebhookTypeChoice]: - scope = Scopes(str(scope) + f'@{webhook_type}') + scope = Scopes(str(scope) + f"@{webhook_type}") scopes.append(scope) else: scopes.append(scope) @@ -80,42 +81,52 @@ def get_resource(self): if self.obj: data = { "id": self.obj.id, - "owner": {"id": getattr(self.obj.owner, 'id', None) }, - 'organization': { - "id": getattr(self.obj.organization, 'id', None) - }, - "project": None + "owner": {"id": getattr(self.obj.owner, "id", None)}, + "organization": {"id": getattr(self.obj.organization, "id", None)}, + "project": None, } - if self.obj.type == 'project' and getattr(self.obj, 'project', None): - data['project'] = { - 'owner': {'id': getattr(self.obj.project.owner, 'id', None)} - } + if self.obj.type == "project" and getattr(self.obj, "project", None): + data["project"] = {"owner": {"id": getattr(self.obj.project.owner, "id", None)}} elif self.scope in [ __class__.Scopes.CREATE, __class__.Scopes.CREATE_IN_PROJECT, - __class__.Scopes.CREATE_IN_ORG + __class__.Scopes.CREATE_IN_ORG, ]: project = None if self.project_id: try: project = Project.objects.get(id=self.project_id) except Project.DoesNotExist: - raise ValidationError(f"Could not find project with provided id: {self.project_id}") + raise ValidationError( + f"Could not find project with provided id: {self.project_id}" + ) data = { - 'id': None, - 'owner': self.user_id, - 'project': { - 'owner': { - 'id': project.owner.id, - } if project.owner else None, - } if project else None, - 'organization': { - 'id': self.org_id, - } if self.org_id is not None else None, - 'user': { - 'id': self.user_id, - } + "id": None, + "owner": self.user_id, + "project": ( + { + "owner": ( + { + "id": project.owner.id, + } + if project.owner + else None + ), + } + if project + else None + ), + "organization": ( + { + "id": self.org_id, + } + if self.org_id is not None + else None + ), + "user": { + "id": self.user_id, + }, } return data diff --git a/cvat/apps/webhooks/rules/tests/generators/webhooks_test.gen.rego.py b/cvat/apps/webhooks/rules/tests/generators/webhooks_test.gen.rego.py index 66417f3d096d..2913bb5a2a6a 100644 --- a/cvat/apps/webhooks/rules/tests/generators/webhooks_test.gen.rego.py +++ b/cvat/apps/webhooks/rules/tests/generators/webhooks_test.gen.rego.py @@ -125,13 +125,15 @@ def get_data(scope, context, ownership, privilege, membership, resource, same_or "scope": scope, "auth": { "user": {"id": random.randrange(0, 100), "privilege": privilege}, - "organization": { - "id": random.randrange(100, 200), - "owner": {"id": random.randrange(200, 300)}, - "user": {"role": membership}, - } - if context == "organization" - else None, + "organization": ( + { + "id": random.randrange(100, 200), + "owner": {"id": random.randrange(200, 300)}, + "user": {"role": membership}, + } + if context == "organization" + else None + ), }, "resource": resource, } diff --git a/cvat/apps/webhooks/serializers.py b/cvat/apps/webhooks/serializers.py index d2bb1f309105..bd540de55fbd 100644 --- a/cvat/apps/webhooks/serializers.py +++ b/cvat/apps/webhooks/serializers.py @@ -7,13 +7,8 @@ from cvat.apps.engine.models import Project from cvat.apps.engine.serializers import BasicUserSerializer, WriteOnceMixin -from .event_type import EventTypeChoice, ProjectEvents, OrganizationEvents -from .models import ( - Webhook, - WebhookContentTypeChoice, - WebhookTypeChoice, - WebhookDelivery, -) +from .event_type import EventTypeChoice, OrganizationEvents, ProjectEvents +from .models import Webhook, WebhookContentTypeChoice, WebhookDelivery, WebhookTypeChoice class EventTypeValidator: @@ -35,9 +30,7 @@ def __call__(self, attrs, serializer): webhook_type == WebhookTypeChoice.ORGANIZATION and not events.issubset(set(OrganizationEvents.events)) ): - raise serializers.ValidationError( - f"Invalid events list for {webhook_type} webhook" - ) + raise serializers.ValidationError(f"Invalid events list for {webhook_type} webhook") class EventTypesSerializer(serializers.MultipleChoiceField): @@ -67,9 +60,7 @@ class WebhookReadSerializer(serializers.ModelSerializer): type = serializers.ChoiceField(choices=WebhookTypeChoice.choices()) content_type = serializers.ChoiceField(choices=WebhookContentTypeChoice.choices()) - last_status = serializers.IntegerField( - source="deliveries.last.status_code", read_only=True - ) + last_status = serializers.IntegerField(source="deliveries.last.status_code", read_only=True) last_delivery_date = serializers.DateTimeField( source="deliveries.last.updated_date", read_only=True @@ -104,9 +95,7 @@ class Meta: class WebhookWriteSerializer(WriteOnceMixin, serializers.ModelSerializer): events = EventTypesSerializer(write_only=True) - project_id = serializers.IntegerField( - write_only=True, allow_null=True, required=False - ) + project_id = serializers.IntegerField(write_only=True, allow_null=True, required=False) def to_representation(self, instance): serializer = WebhookReadSerializer(instance, context=self.context) @@ -129,8 +118,8 @@ class Meta: validators = [EventTypeValidator()] def create(self, validated_data): - if (project_id := validated_data.get('project_id')) is not None: - validated_data['organization'] = Project.objects.get(pk=project_id).organization + if (project_id := validated_data.get("project_id")) is not None: + validated_data["organization"] = Project.objects.get(pk=project_id).organization db_webhook = Webhook.objects.create(**validated_data) return db_webhook diff --git a/cvat/apps/webhooks/signals.py b/cvat/apps/webhooks/signals.py index 3e17e8f3d8f6..6e08e35192dd 100644 --- a/cvat/apps/webhooks/signals.py +++ b/cvat/apps/webhooks/signals.py @@ -13,17 +13,21 @@ from django.conf import settings from django.core.exceptions import ObjectDoesNotExist from django.db import transaction -from django.db.models.signals import (post_delete, post_save, pre_delete, - pre_save) +from django.db.models.signals import post_delete, post_save, pre_delete, pre_save from django.dispatch import Signal, receiver from cvat.apps.engine.models import Comment, Issue, Job, Project, Task from cvat.apps.engine.serializers import BasicUserSerializer -from cvat.apps.events.handlers import (get_request, get_serializer, get_user, - get_instance_diff, organization_id, - project_id) +from cvat.apps.events.handlers import ( + get_instance_diff, + get_request, + get_serializer, + get_user, + organization_id, + project_id, +) from cvat.apps.organizations.models import Invitation, Membership, Organization -from cvat.utils.http import make_requests_session, PROXIES_FOR_UNTRUSTED_URLS +from cvat.utils.http import PROXIES_FOR_UNTRUSTED_URLS, make_requests_session from .event_type import EventTypeChoice, event_name from .models import Webhook, WebhookDelivery, WebhookTypeChoice @@ -34,6 +38,7 @@ signal_redelivery = Signal() signal_ping = Signal() + def send_webhook(webhook, payload, redelivery=False): headers = {} if webhook.secret: @@ -59,9 +64,7 @@ def send_webhook(webhook, payload, redelivery=False): proxies=PROXIES_FOR_UNTRUSTED_URLS, ) status_code = response.status_code - response_body = response.raw.read( - RESPONSE_SIZE_LIMIT + 1, decode_content=True - ) + response_body = response.raw.read(RESPONSE_SIZE_LIMIT + 1, decode_content=True) except requests.ConnectionError: status_code = HTTPStatus.BAD_GATEWAY except requests.Timeout: @@ -83,6 +86,7 @@ def send_webhook(webhook, payload, redelivery=False): return delivery + def add_to_queue(webhook, payload, redelivery=False): queue = django_rq.get_queue(settings.CVAT_QUEUES.WEBHOOKS.value) queue.enqueue_call(func=send_webhook, args=(webhook, payload, redelivery)) @@ -163,6 +167,7 @@ def pre_save_resource_event(sender, instance, **kwargs): old_serializer = get_serializer(instance=old_instance) instance._webhooks_old_data = old_serializer.data + @receiver(post_save, sender=Project, dispatch_uid=__name__ + ":project:post_save") @receiver(post_save, sender=Task, dispatch_uid=__name__ + ":task:post_save") @receiver(post_save, sender=Job, dispatch_uid=__name__ + ":job:post_save") @@ -196,10 +201,7 @@ def post_save_resource_event(sender, instance, **kwargs): if not created: if diff := get_instance_diff(old_data=old_data, data=serializer.data): - data["before_update"] = { - attr: value["old_value"] - for attr, value in diff.items() - } + data["before_update"] = {attr: value["old_value"] for attr, value in diff.items()} transaction.on_commit( lambda: batch_add_to_queue(selected_webhooks, data), @@ -250,7 +252,11 @@ def post_delete_resource_event(sender, instance, **kwargs): "sender": get_sender(instance), } - related_webhooks = [webhook for webhook in getattr(instance, "_related_webhooks", []) if webhook.id not in map(lambda a: a.id, filtered_webhooks)] + related_webhooks = [ + webhook + for webhook in getattr(instance, "_related_webhooks", []) + if webhook.id not in map(lambda a: a.id, filtered_webhooks) + ] transaction.on_commit( lambda: batch_add_to_queue(filtered_webhooks + related_webhooks, data), diff --git a/cvat/apps/webhooks/urls.py b/cvat/apps/webhooks/urls.py index c309df746f96..26f86fc2313e 100644 --- a/cvat/apps/webhooks/urls.py +++ b/cvat/apps/webhooks/urls.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: MIT from rest_framework.routers import DefaultRouter + from .views import WebhookViewSet router = DefaultRouter(trailing_slash=False) diff --git a/cvat/apps/webhooks/views.py b/cvat/apps/webhooks/views.py index 66529bc6a7bd..b4e059c528f6 100644 --- a/cvat/apps/webhooks/views.py +++ b/cvat/apps/webhooks/views.py @@ -2,9 +2,13 @@ # # SPDX-License-Identifier: MIT -from drf_spectacular.utils import (OpenApiParameter, OpenApiResponse, - OpenApiTypes, extend_schema, - extend_schema_view) +from drf_spectacular.utils import ( + OpenApiParameter, + OpenApiResponse, + OpenApiTypes, + extend_schema, + extend_schema_view, +) from rest_framework import status, viewsets from rest_framework.decorators import action from rest_framework.permissions import SAFE_METHODS @@ -16,8 +20,12 @@ from .event_type import AllEvents, OrganizationEvents, ProjectEvents from .models import Webhook, WebhookDelivery, WebhookTypeChoice from .permissions import WebhookPermission -from .serializers import (EventsSerializer, WebhookDeliveryReadSerializer, - WebhookReadSerializer, WebhookWriteSerializer) +from .serializers import ( + EventsSerializer, + WebhookDeliveryReadSerializer, + WebhookReadSerializer, + WebhookWriteSerializer, +) from .signals import signal_ping, signal_redelivery @@ -34,24 +42,18 @@ update=extend_schema( summary="Replace a webhook", request=WebhookWriteSerializer, - responses={ - "200": WebhookReadSerializer - }, # check WebhookWriteSerializer.to_representation + responses={"200": WebhookReadSerializer}, # check WebhookWriteSerializer.to_representation ), partial_update=extend_schema( summary="Update a webhook", request=WebhookWriteSerializer, - responses={ - "200": WebhookReadSerializer - }, # check WebhookWriteSerializer.to_representation + responses={"200": WebhookReadSerializer}, # check WebhookWriteSerializer.to_representation ), create=extend_schema( request=WebhookWriteSerializer, summary="Create a webhook", parameters=ORGANIZATION_OPEN_API_PARAMETERS, - responses={ - "201": WebhookReadSerializer - }, # check WebhookWriteSerializer.to_representation + responses={"201": WebhookReadSerializer}, # check WebhookWriteSerializer.to_representation ), destroy=extend_schema( summary="Delete a webhook", @@ -71,9 +73,7 @@ class WebhookViewSet(viewsets.ModelViewSet): iam_organization_field = "organization" def get_serializer_class(self): - if self.request.path.endswith("redelivery") or self.request.path.endswith( - "ping" - ): + if self.request.path.endswith("redelivery") or self.request.path.endswith("ping"): return None else: if self.request.method in SAFE_METHODS: @@ -109,7 +109,10 @@ def perform_create(self, serializer): ], responses={"200": OpenApiResponse(EventsSerializer)}, ) - @action(detail=False, methods=["GET"], serializer_class=EventsSerializer, + @action( + detail=False, + methods=["GET"], + serializer_class=EventsSerializer, permission_classes=[], ) def events(self, request): @@ -123,9 +126,7 @@ def events(self, request): events = OrganizationEvents if events is None: - return Response( - "Incorrect value of type parameter", status=status.HTTP_400_BAD_REQUEST - ) + return Response("Incorrect value of type parameter", status=status.HTTP_400_BAD_REQUEST) return Response(EventsSerializer().to_representation(events)) @@ -137,10 +138,8 @@ def events(self, request): ) @list_action(serializer_class=WebhookDeliveryReadSerializer) def deliveries(self, request, pk): - self.get_object() # force call of check_object_permissions() - queryset = WebhookDelivery.objects.filter(webhook_id=pk).order_by( - "-updated_date" - ) + self.get_object() # force call of check_object_permissions() + queryset = WebhookDelivery.objects.filter(webhook_id=pk).order_by("-updated_date") return make_paginated_response( queryset, viewset=self, serializer_type=self.serializer_class ) # from @action @@ -156,11 +155,9 @@ def deliveries(self, request, pk): serializer_class=WebhookDeliveryReadSerializer, ) def retrieve_delivery(self, request, pk, delivery_id): - self.get_object() # force call of check_object_permissions() + self.get_object() # force call of check_object_permissions() queryset = WebhookDelivery.objects.get(webhook_id=pk, id=delivery_id) - serializer = WebhookDeliveryReadSerializer( - queryset, context={"request": request} - ) + serializer = WebhookDeliveryReadSerializer(queryset, context={"request": request}) return Response(serializer.data) @extend_schema( @@ -184,15 +181,11 @@ def redelivery(self, request, pk, delivery_id): request=None, responses={"200": WebhookDeliveryReadSerializer}, ) - @action( - detail=True, methods=["POST"], serializer_class=WebhookDeliveryReadSerializer - ) + @action(detail=True, methods=["POST"], serializer_class=WebhookDeliveryReadSerializer) def ping(self, request, pk): - instance = self.get_object() # force call of check_object_permissions() + instance = self.get_object() # force call of check_object_permissions() serializer = WebhookReadSerializer(instance, context={"request": request}) delivery = signal_ping.send(sender=self, serializer=serializer)[0][1] - serializer = WebhookDeliveryReadSerializer( - delivery, context={"request": request} - ) + serializer = WebhookDeliveryReadSerializer(delivery, context={"request": request}) return Response(serializer.data) diff --git a/cvat/requirements/all.txt b/cvat/requirements/all.txt index 4e05dcc9e85f..482db32ecf87 100644 --- a/cvat/requirements/all.txt +++ b/cvat/requirements/all.txt @@ -8,5 +8,3 @@ -r development.txt -r production.txt -r testing.txt - -# The following packages are considered to be unsafe in a requirements file: diff --git a/cvat/requirements/base.in b/cvat/requirements/base.in index fd86b51f99dc..03d74579fb21 100644 --- a/cvat/requirements/base.in +++ b/cvat/requirements/base.in @@ -12,7 +12,7 @@ azure-storage-blob==12.13.0 boto3==1.17.61 clickhouse-connect==0.6.8 coreapi==2.3.3 -datumaro @ git+https://github.com/cvat-ai/datumaro.git@bf0374689df50599a34a4f220b9e5329aca695ce +datumaro @ git+https://github.com/cvat-ai/datumaro.git@08e77b216080555a57e12c01625be8c8201e3131 dj-pagination==2.5.0 # Despite direct indication allauth in requirements we should keep 'with_social' for dj-rest-auth # to avoid possible further versions conflicts (we use registration functionality) diff --git a/cvat/requirements/base.txt b/cvat/requirements/base.txt index fe4518b64e44..f531f125ebf6 100644 --- a/cvat/requirements/base.txt +++ b/cvat/requirements/base.txt @@ -1,4 +1,4 @@ -# SHA1:1bed6e1afea11473b164df79d7d166f419074359 +# SHA1:3e6349d9e5e095c5a1f196eca66b3e5ba8672458 # # This file is autogenerated by pip-compile-multi # To update, run: @@ -36,9 +36,9 @@ certifi==2024.12.14 # requests cffi==1.17.1 # via cryptography -charset-normalizer==3.4.0 +charset-normalizer==3.4.1 # via requests -click==8.1.7 +click==8.1.8 # via rq clickhouse-connect==0.6.8 # via -r cvat/requirements/base.in @@ -56,7 +56,7 @@ cryptography==44.0.0 # pyjwt cycler==0.12.1 # via matplotlib -datumaro @ git+https://github.com/cvat-ai/datumaro.git@bf0374689df50599a34a4f220b9e5329aca695ce +datumaro @ git+https://github.com/cvat-ai/datumaro.git@08e77b216080555a57e12c01625be8c8201e3131 # via -r cvat/requirements/base.in defusedxml==0.7.1 # via @@ -147,8 +147,10 @@ idna==3.10 # via requests importlib-metadata==8.5.0 # via clickhouse-connect -importlib-resources==6.4.5 - # via nibabel +importlib-resources==6.5.2 + # via + # matplotlib + # nibabel inflection==0.5.1 # via drf-spectacular isodate==0.7.2 @@ -157,7 +159,7 @@ isodate==0.7.2 # python3-saml itypes==1.2.0 # via coreapi -jinja2==3.1.4 +jinja2==3.1.5 # via coreschema jmespath==0.10.0 # via @@ -167,7 +169,7 @@ jsonschema==4.17.3 # via drf-spectacular kiwisolver==1.4.7 # via matplotlib -limits==3.14.1 +limits==4.0.0 # via python-logstash-async lxml==5.3.0 # via @@ -187,7 +189,7 @@ mmh3==5.0.1 # via pottery msrest==0.7.1 # via azure-storage-blob -networkx==3.4.2 +networkx==3.2.1 # via datumaro nibabel==5.3.2 # via datumaro @@ -195,7 +197,7 @@ oauthlib==3.2.2 # via requests-oauthlib orderedmultidict==1.0.1 # via furl -orjson==3.10.12 +orjson==3.10.13 # via datumaro packaging==24.2 # via @@ -213,7 +215,7 @@ pottery==3.0.0 # via -r cvat/requirements/base.in proto-plus==1.25.0 # via google-api-core -protobuf==5.29.1 +protobuf==5.29.2 # via # google-api-core # googleapis-common-protos @@ -240,7 +242,7 @@ pyjwt[crypto]==2.10.1 # via django-allauth pylogbeat==2.0.1 # via python-logstash-async -pyparsing==3.2.0 +pyparsing==3.2.1 # via matplotlib pyrsistent==0.20.0 # via jsonschema @@ -306,7 +308,7 @@ rq-scheduler==0.13.1 # via -r cvat/requirements/base.in rsa==4.9 # via google-auth -ruamel-yaml==0.18.6 +ruamel-yaml==0.18.10 # via datumaro ruamel-yaml-clib==0.2.12 # via ruamel-yaml @@ -354,6 +356,8 @@ xmlsec==1.3.14 # -r cvat/requirements/base.in # python3-saml zipp==3.21.0 - # via importlib-metadata + # via + # importlib-metadata + # importlib-resources zstandard==0.23.0 # via clickhouse-connect diff --git a/cvat/requirements/development.in b/cvat/requirements/development.in index ad5a5b6557ec..9c5e0662b52d 100644 --- a/cvat/requirements/development.in +++ b/cvat/requirements/development.in @@ -1,10 +1,6 @@ -r base.in -black>=24.1 django-extensions==3.0.8 django-silk==5.* -pylint-django==2.5.3 -pylint-plugin-utils==0.7 -pylint==2.14.5 rope==0.17.0 snakeviz==2.1.0 diff --git a/cvat/requirements/development.txt b/cvat/requirements/development.txt index b0c563374067..cc730b7916eb 100644 --- a/cvat/requirements/development.txt +++ b/cvat/requirements/development.txt @@ -1,4 +1,4 @@ -# SHA1:b71f4fe955f645187b7ccdf82b05f6a8d61eb3ab +# SHA1:cd8d0825dc4cfe37b22a489422105acba5483fe4 # # This file is autogenerated by pip-compile-multi # To update, run: @@ -6,61 +6,21 @@ # pip-compile-multi # -r base.txt -astroid==2.11.7 - # via pylint autopep8==2.3.1 # via django-silk -black==24.10.0 - # via -r cvat/requirements/development.in -dill==0.3.9 - # via pylint django-extensions==3.0.8 # via -r cvat/requirements/development.in django-silk==5.3.2 # via -r cvat/requirements/development.in gprof2dot==2024.6.6 # via django-silk -isort==5.13.2 - # via pylint -lazy-object-proxy==1.10.0 - # via astroid -mccabe==0.7.0 - # via pylint -mypy-extensions==1.0.0 - # via black -pathspec==0.12.1 - # via black -platformdirs==4.3.6 - # via - # black - # pylint pycodestyle==2.12.1 # via autopep8 -pylint==2.14.5 - # via - # -r cvat/requirements/development.in - # pylint-django - # pylint-plugin-utils -pylint-django==2.5.3 - # via -r cvat/requirements/development.in -pylint-plugin-utils==0.7 - # via - # -r cvat/requirements/development.in - # pylint-django rope==0.17.0 # via -r cvat/requirements/development.in snakeviz==2.1.0 # via -r cvat/requirements/development.in tomli==2.2.1 - # via - # autopep8 - # black - # pylint -tomlkit==0.13.2 - # via pylint + # via autopep8 tornado==6.4.2 # via snakeviz - -# The following packages are considered to be unsafe in a requirements file: -setuptools==75.6.0 - # via astroid diff --git a/cvat/requirements/production.txt b/cvat/requirements/production.txt index 155d626a6984..c65ede91ad59 100644 --- a/cvat/requirements/production.txt +++ b/cvat/requirements/production.txt @@ -6,7 +6,7 @@ # pip-compile-multi # -r base.txt -anyio==4.7.0 +anyio==4.8.0 # via watchfiles coverage==7.2.3 # via -r cvat/requirements/production.in diff --git a/cvat/requirements/testing.txt b/cvat/requirements/testing.txt index 90c8a13254c0..86ab66664526 100644 --- a/cvat/requirements/testing.txt +++ b/cvat/requirements/testing.txt @@ -14,5 +14,3 @@ lupa==1.14.1 # via fakeredis sortedcontainers==2.4.0 # via fakeredis - -# The following packages are considered to be unsafe in a requirements file: diff --git a/cvat/schema.yml b/cvat/schema.yml index 45f95346c769..4e0b3a687677 100644 --- a/cvat/schema.yml +++ b/cvat/schema.yml @@ -1,7 +1,7 @@ openapi: 3.0.3 info: title: CVAT REST API - version: 2.24.0 + version: 2.25.0 description: REST API for Computer Vision Annotation Tool (CVAT) termsOfService: https://www.google.com/policies/terms/ contact: @@ -9775,12 +9775,12 @@ components: compare_attributes: type: boolean description: Enables or disables annotation attribute comparison - match_empty_frames: + empty_is_annotated: type: boolean default: false description: | - Count empty frames as matching. This affects target metrics like accuracy in cases - there are no annotations. If disabled, frames without annotations + Consider empty frames annotated as "empty". This affects target metrics like + accuracy in cases there are no annotations. If disabled, frames without annotations are counted as not matching (accuracy is 0). If enabled, accuracy will be 1 instead. This will also add virtual annotations to empty frames in the comparison results. PatchedTaskValidationLayoutWriteRequest: @@ -10282,12 +10282,12 @@ components: compare_attributes: type: boolean description: Enables or disables annotation attribute comparison - match_empty_frames: + empty_is_annotated: type: boolean default: false description: | - Count empty frames as matching. This affects target metrics like accuracy in cases - there are no annotations. If disabled, frames without annotations + Consider empty frames annotated as "empty". This affects target metrics like + accuracy in cases there are no annotations. If disabled, frames without annotations are counted as not matching (accuracy is 0). If enabled, accuracy will be 1 instead. This will also add virtual annotations to empty frames in the comparison results. RegisterSerializerEx: diff --git a/cvat/settings/base.py b/cvat/settings/base.py index 0f6147dc4bf0..c73cb31eafa2 100644 --- a/cvat/settings/base.py +++ b/cvat/settings/base.py @@ -19,9 +19,9 @@ import os import sys import tempfile +import urllib from datetime import timedelta from enum import Enum -import urllib from attr.converters import to_bool from corsheaders.defaults import default_headers @@ -74,7 +74,7 @@ def generate_secret_key(): try: sys.path.append(BASE_DIR) - from keys.secret_key import SECRET_KEY # pylint: disable=unused-import + from keys.secret_key import SECRET_KEY # pylint: disable=unused-import except ModuleNotFoundError: generate_secret_key() from keys.secret_key import SECRET_KEY @@ -740,6 +740,7 @@ class CVAT_QUEUES(Enum): CVAT_CONCURRENT_CHUNK_PROCESSING = int(os.getenv('CVAT_CONCURRENT_CHUNK_PROCESSING', 1)) from cvat.rq_patching import update_started_job_registry_cleanup + update_started_job_registry_cleanup() CLOUD_DATA_DOWNLOADING_MAX_THREADS_NUMBER = 4 diff --git a/cvat/settings/email_settings.py b/cvat/settings/email_settings.py index d3f9621e09d4..f83f918339de 100644 --- a/cvat/settings/email_settings.py +++ b/cvat/settings/email_settings.py @@ -5,7 +5,6 @@ from cvat.settings.production import * - # https://github.com/pennersr/django-allauth ACCOUNT_AUTHENTICATION_METHOD = 'username_email' ACCOUNT_CONFIRM_EMAIL_ON_GET = True diff --git a/cvat/settings/testing.py b/cvat/settings/testing.py index 3cd47559fbd0..e0391e4c3b40 100644 --- a/cvat/settings/testing.py +++ b/cvat/settings/testing.py @@ -2,9 +2,10 @@ # # SPDX-License-Identifier: MIT -from .development import * import tempfile +from .development import * + DATABASES = { 'default': { 'ENGINE': 'django.db.backends.sqlite3', @@ -73,11 +74,13 @@ TEST_RUNNER = "cvat.settings.testing.PatchedDiscoverRunner" from django.test.runner import DiscoverRunner + + class PatchedDiscoverRunner(DiscoverRunner): def __init__(self, *args, **kwargs): # Used fakeredis for testing (don't affect production redis) - from fakeredis import FakeRedis, FakeStrictRedis import django_rq.queues + from fakeredis import FakeRedis, FakeStrictRedis simple_redis = FakeRedis() strict_redis = FakeStrictRedis() django_rq.queues.get_redis_connection = lambda _, strict: strict_redis \ diff --git a/cvat/urls.py b/cvat/urls.py index 08257a14b811..ca62b7cb03a3 100644 --- a/cvat/urls.py +++ b/cvat/urls.py @@ -20,7 +20,7 @@ from django.apps import apps from django.contrib import admin -from django.urls import path, include +from django.urls import include, path urlpatterns = [ path("admin/", admin.site.urls), diff --git a/cvat/utils/http.py b/cvat/utils/http.py index 2cb1b7498b32..ab8771aaa2ae 100644 --- a/cvat/utils/http.py +++ b/cvat/utils/http.py @@ -2,10 +2,9 @@ # # SPDX-License-Identifier: MIT -from django.conf import settings - import requests import requests.utils +from django.conf import settings from cvat import __version__ diff --git a/dev/requirements.txt b/dev/requirements.txt new file mode 100644 index 000000000000..4603689ae469 --- /dev/null +++ b/dev/requirements.txt @@ -0,0 +1,5 @@ +black==24.* +isort==5.* +pylint-django==2.5.3 +pylint-plugin-utils==0.7 +pylint==2.14.5 diff --git a/dev/update_version.py b/dev/update_version.py index ed8d08a40f42..7419a581ef4c 100755 --- a/dev/update_version.py +++ b/dev/update_version.py @@ -6,8 +6,8 @@ import sys from dataclasses import dataclass from pathlib import Path -from typing import Callable, Match, Pattern - +from re import Match, Pattern +from typing import Callable SUCCESS_CHAR = "\u2714" FAIL_CHAR = "\u2716" @@ -159,8 +159,8 @@ def apply(self, new_version: Version, *, verify_only: bool) -> bool: ), ReplacementRule( "cvat-cli/requirements/base.txt", - re.compile(r"^cvat-sdk~=[\d.]+$", re.M), - lambda v, m: f"cvat-sdk~={v.major}.{v.minor}.{v.patch}", + re.compile(r"^cvat-sdk==[\d.]+$", re.M), + lambda v, m: f"cvat-sdk=={v.major}.{v.minor}.{v.patch}", ), ] diff --git a/docker-compose.external_db.yml b/docker-compose.external_db.yml index decd1e9ed141..8112c59fd4f4 100644 --- a/docker-compose.external_db.yml +++ b/docker-compose.external_db.yml @@ -27,6 +27,7 @@ services: cvat_worker_import: *backend-settings cvat_worker_quality_reports: *backend-settings cvat_worker_webhooks: *backend-settings + cvat_worker_chunks: *backend-settings secrets: postgres_password: diff --git a/docker-compose.yml b/docker-compose.yml index b956bc6fcca5..5e591b29e0ec 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -81,7 +81,7 @@ services: cvat_server: container_name: cvat_server - image: cvat/server:${CVAT_VERSION:-v2.24.0} + image: cvat/server:${CVAT_VERSION:-v2.25.0} restart: always depends_on: <<: *backend-deps @@ -115,7 +115,7 @@ services: cvat_utils: container_name: cvat_utils - image: cvat/server:${CVAT_VERSION:-v2.24.0} + image: cvat/server:${CVAT_VERSION:-v2.25.0} restart: always depends_on: *backend-deps environment: @@ -132,7 +132,7 @@ services: cvat_worker_import: container_name: cvat_worker_import - image: cvat/server:${CVAT_VERSION:-v2.24.0} + image: cvat/server:${CVAT_VERSION:-v2.25.0} restart: always depends_on: *backend-deps environment: @@ -148,7 +148,7 @@ services: cvat_worker_export: container_name: cvat_worker_export - image: cvat/server:${CVAT_VERSION:-v2.24.0} + image: cvat/server:${CVAT_VERSION:-v2.25.0} restart: always depends_on: *backend-deps environment: @@ -164,7 +164,7 @@ services: cvat_worker_annotation: container_name: cvat_worker_annotation - image: cvat/server:${CVAT_VERSION:-v2.24.0} + image: cvat/server:${CVAT_VERSION:-v2.25.0} restart: always depends_on: *backend-deps environment: @@ -180,7 +180,7 @@ services: cvat_worker_webhooks: container_name: cvat_worker_webhooks - image: cvat/server:${CVAT_VERSION:-v2.24.0} + image: cvat/server:${CVAT_VERSION:-v2.25.0} restart: always depends_on: *backend-deps environment: @@ -196,7 +196,7 @@ services: cvat_worker_quality_reports: container_name: cvat_worker_quality_reports - image: cvat/server:${CVAT_VERSION:-v2.24.0} + image: cvat/server:${CVAT_VERSION:-v2.25.0} restart: always depends_on: *backend-deps environment: @@ -212,7 +212,7 @@ services: cvat_worker_analytics_reports: container_name: cvat_worker_analytics_reports - image: cvat/server:${CVAT_VERSION:-v2.24.0} + image: cvat/server:${CVAT_VERSION:-v2.25.0} restart: always depends_on: *backend-deps environment: @@ -228,7 +228,7 @@ services: cvat_worker_chunks: container_name: cvat_worker_chunks - image: cvat/server:${CVAT_VERSION:-v2.24.0} + image: cvat/server:${CVAT_VERSION:-v2.25.0} restart: always depends_on: *backend-deps environment: @@ -244,7 +244,7 @@ services: cvat_ui: container_name: cvat_ui - image: cvat/ui:${CVAT_VERSION:-v2.24.0} + image: cvat/ui:${CVAT_VERSION:-v2.25.0} restart: always depends_on: - cvat_server diff --git a/helm-chart/values.yaml b/helm-chart/values.yaml index e1138ca0a40c..041d088ba214 100644 --- a/helm-chart/values.yaml +++ b/helm-chart/values.yaml @@ -139,7 +139,7 @@ cvat: additionalVolumeMounts: [] replicas: 1 image: cvat/server - tag: v2.24.0 + tag: v2.25.0 imagePullPolicy: Always permissionFix: enabled: true @@ -162,7 +162,7 @@ cvat: frontend: replicas: 1 image: cvat/ui - tag: v2.24.0 + tag: v2.25.0 imagePullPolicy: Always labels: {} # test: test @@ -475,7 +475,7 @@ ingress: ## kubernetes.io/ingress.class: nginx ## annotations: {} - ## @param ingress.className IngressClass that will be be used to implement the Ingress (Kubernetes 1.18+) + ## @param ingress.className IngressClass that will be used to implement the Ingress (Kubernetes 1.18+) ## This is supported in Kubernetes 1.18+ and required if you have more than one IngressClass marked as the default for your cluster ## ref: https://kubernetes.io/blog/2020/04/02/improvements-to-the-ingress-api-in-kubernetes-1.18/ ## diff --git a/pyproject.toml b/pyproject.toml index 528bdc579fcc..b0c13a15766f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,17 +3,21 @@ profile = "black" forced_separate = ["tests"] line_length = 100 skip_gitignore = true # align tool behavior with Black +extend_skip=[ + # Correctly ordering the imports in serverless functions would + # require a pyproject.toml in every function; don't bother with it for now. + "serverless", + # Sorting the imports in this file causes test failures; + # TODO: fix them and remove this ignore. + "cvat/apps/dataset_manager/formats/registry.py", +] [tool.black] line-length = 100 target-version = ['py39'] extend-exclude = """ # TODO: get rid of these -^/cvat/apps/( - dataset_manager|dataset_repo|engine|events - |health|iam|lambda_manager|log_viewer - |organizations|webhooks -)/ +^/cvat/apps/(dataset_manager|engine)/ | ^/cvat/settings/ | ^/serverless/ | ^/utils/dataset_manifest/ diff --git a/rqscheduler.py b/rqscheduler.py index 5ae76e64a7f0..b6cebe80f285 100644 --- a/rqscheduler.py +++ b/rqscheduler.py @@ -4,10 +4,10 @@ # implementation. This is required for correct work with CVAT queue settings and # their access options such as login and password. +from rq_scheduler.scripts import rqscheduler + # Required to initialize Django settings correctly from cvat.asgi import application # pylint: disable=unused-import -from rq_scheduler.scripts import rqscheduler - if __name__ == "__main__": rqscheduler.main() diff --git a/site/build_docs.py b/site/build_docs.py index 2eca3a941330..a01c437ae64c 100755 --- a/site/build_docs.py +++ b/site/build_docs.py @@ -10,7 +10,7 @@ import subprocess import tempfile from pathlib import Path -from typing import Dict, Optional +from typing import Optional import git import toml @@ -98,7 +98,7 @@ def run_npm_install(): def run_hugo( destination_dir: os.PathLike, *, - extra_env_vars: Dict[str, str] = None, + extra_env_vars: dict[str, str] = None, executable: Optional[str] = "hugo", ): extra_kwargs = {} diff --git a/site/content/en/docs/api_sdk/cli/_index.md b/site/content/en/docs/api_sdk/cli/_index.md index ffa5be80676b..82bfad795fb0 100644 --- a/site/content/en/docs/api_sdk/cli/_index.md +++ b/site/content/en/docs/api_sdk/cli/_index.md @@ -29,6 +29,11 @@ The following subcommands are supported: - `backup` - back up a task - `auto-annotate` - automatically annotate a task using a local function +- Functions (Enterprise/Cloud only): + - `create-native` - create a function that can be powered by an agent + - `delete` - delete a function + - `run-agent` - process requests for a native function + ## Installation To install an [official release of CVAT CLI](https://pypi.org/project/cvat-cli/), use this command: @@ -316,3 +321,35 @@ see that command's examples for more information. ```bash cvat-cli project ls --json > list_of_projects.json ``` + +## Examples - functions + +**Note**: The functionality described in this section can only be used +with the CVAT Enterprise or CVAT Cloud. + +### Create + +- Create a function that uses a detection model from torchvision + and run an agent for it: + + ``` + cvat-cli function create-native "Faster R-CNN" \ + --function-module cvat_sdk.auto_annotation.functions.torchvision_detection \ + -p model_name=str:fasterrcnn_resnet50_fpn_v2 + cvat-cli function run-agent \ + --function-module cvat_sdk.auto_annotation.functions.torchvision_detection \ + -p model_name=str:fasterrcnn_resnet50_fpn_v2 + ``` + +These commands accept functions that implement the +{{< ilink "/docs/api_sdk/sdk/auto-annotation" "auto-annotation function interface" >}} +from the SDK, same as the `task auto-annotate` command. +See that command's examples for information on how to implement these functions +and specify them in the command line. + +### Delete + +- Delete functions with IDs 100 and 101: + ``` + cvat-cli function delete 100 101 + ``` diff --git a/site/content/en/docs/contributing/development-environment.md b/site/content/en/docs/contributing/development-environment.md index e54929609e48..31fb2f755c7a 100644 --- a/site/content/en/docs/contributing/development-environment.md +++ b/site/content/en/docs/contributing/development-environment.md @@ -80,7 +80,7 @@ description: 'Installing a development environment for different operating syste python3 -m venv .env . .env/bin/activate pip install -U pip wheel setuptools - pip install -r cvat/requirements/development.txt + pip install -r cvat/requirements/development.txt -r dev/requirements.txt ``` Note that the `.txt` files in the `cvat/requirements` directory diff --git a/site/content/en/docs/enterprise/social-accounts-configuration.md b/site/content/en/docs/enterprise/social-accounts-configuration.md index 83b7f463a27e..f9489d0a8cfb 100644 --- a/site/content/en/docs/enterprise/social-accounts-configuration.md +++ b/site/content/en/docs/enterprise/social-accounts-configuration.md @@ -19,19 +19,13 @@ such benefits as: Currently, we offer three options: -- Authentication with Github. -- Authentication with Google. -- Authentication with Amazon Cognito. +- [Authentication with Google](#authentication-with-google) +- [Authentication with GitHub](#authentication-with-github) +- [Authentication with Amazon Cognito](#authentication-with-amazon-cognito) With more to come soon. Stay tuned! -See: - -- [Enable authentication with a Google account](#enable-authentication-with-a-google-account) -- [Enable authentication with a GitHub account](#enable-authentication-with-a-github-account) -- [Enable authentication with an Amazon Cognito](#enable-authentication-with-an-amazon-cognito) - -## Enable authentication with a Google account +## Authentication with Google To enable authentication, do the following: @@ -72,7 +66,7 @@ To enable authentication, do the following: docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.override.yml up -d --build ``` -## Enable authentication with a GitHub account +## Authentication with GitHub There are 2 basic steps to enable GitHub account authentication. @@ -106,32 +100,72 @@ There are 2 basic steps to enable GitHub account authentication. > but don't forget to add required permissions. >
In the **Permission** > **Account permissions** > **Email addresses** must be set to **read-only**. -## Enable authentication with an Amazon Cognito - -To enable authentication, do the following: - -1. Create a user pool. For more information, - see [Amazon Cognito user pools](https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-user-identity-pools.html) -2. Fill in the name field, set the homepage URL (for example: `https://localhost:8080`), - and authentication callback URL (for example: `https://localhost:8080/api/auth/social/amazon-cognito/login/callback/`). -3. Create configuration file in CVAT: - - 1. Create the `auth_config.yml` file with the following content: - - ```yaml - --- - social_account: - enabled: true - amazon_cognito: - client_id: - client_secret: - domain: https://.auth.us-east-1.amazoncognito.com - ``` - - 2. Set `AUTH_CONFIG_PATH="` environment variable. - -3. In a terminal, run the following command: - - ```bash - docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.override.yml up -d --build - ``` +## Authentication with Amazon Cognito + +To enable authentication with Amazon Cognito for your CVAT instance, follow these steps: + +1. Create an **[Amazon Cognito pool](https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-user-identity-pools.html)** + (_Optional_) +1. Set up a new app client +1. Configure social authentication in CVAT + +Now, let’s dive deeper into how to accomplish these steps. + +### Amazon Cognito pool creation + +This step is optional and should only be performed if a user pool has not already been created. +To create a user pool, follow these instructions: +1. Go to the [AWS Management Console](https://console.aws.amazon.com/console/home) +1. Locate `Cognito` in the list of services +1. Click `Create user pool` +1. Fill in the required fields + +### App client creation + +To create a new app client, follow these steps: +1. Go to the details page of the created user pool +1. Find the `App clients` item in the menu on the left +1. Click `Create app client` +1. Fill out the form as shown bellow: + ![](/images/cognito_pool_1.png) + - `Application type`: `Traditional web application` + - `Application name`: Specify a desired name, or leave the autogenerated one + - `Return URL` (_optional_): Specify the CVAT redirect URL + (`:///api/auth/social/amazon-cognito/login/callback/`). + This setting can also be updated or specified later after the app client is created. +1. Navigate to the `Login pages` tab of the created app client +1. Check the parameters in the `Managed login pages configuration` section and edit them if needed: + ![](/images/cognito_pool_2.png) + - `Allowed callback URLs`: Must be set to the CVAT redirect URL + - `Identity providers`: Must be specified + - `OAuth grant types`: The `Authorization code grant` must be selected + - `OpenID Connect scopes`: `OpenID`, `Profile`, `Email` scopes must be selected + +### Setting up social authentication in CVAT + +To configure social authentication in CVAT, create a configuration file +(`auth_config.yml`) with the following content: + ```yaml + --- + social_account: + enabled: true + amazon_cognito: + client_id: + client_secret: + domain: or + https://.auth.us-east-1.amazoncognito.com + ``` +To find the `client_id` and `client_secret` values, navigate to the created app client page +and check the `App client information` section. To find `domain`, look for the `Domain` item in the list on the left. + +Once the configuration file is updated, several environment variables must be exported before running CVAT: + ```bash + export AUTH_CONFIG_PATH="" + export CVAT_HOST="" + # cvat_port is optional + export CVAT_BASE_URL="://${CVAT_HOST}:" + ``` + +Start the CVAT enterprise instance as usual. +That's it! On the CVAT login page, you should now see the option `Continue with Amazon Cognito`. +![](/images/login_page_with_amazon_cognito.png) diff --git a/site/content/en/docs/manual/advanced/analytics-and-monitoring/auto-qa.md b/site/content/en/docs/manual/advanced/analytics-and-monitoring/auto-qa.md index 21ebd2d99087..4a098c6545fa 100644 --- a/site/content/en/docs/manual/advanced/analytics-and-monitoring/auto-qa.md +++ b/site/content/en/docs/manual/advanced/analytics-and-monitoring/auto-qa.md @@ -385,7 +385,7 @@ Annotation quality settings have the following parameters: | - | - | | Min overlap threshold | Min overlap threshold used for the distinction between matched and unmatched shapes. Used to match all types of annotations. It corresponds to the Intersection over union (IoU) for spatial annotations, such as bounding boxes and masks. | | Low overlap threshold | Low overlap threshold used for the distinction between strong and weak matches. Only affects _Low overlap_ warnings. It's supposed that _Min similarity threshold_ <= _Low overlap threshold_. | -| Match empty frames | Consider frames matched if there are no annotations both on GT and regular job frames | +| Empty frames are annotated | Consider frames annotated as "empty" if there are no annotations on a frame. If a frame is empty in both GT and job annotations, it will be considered a matching annotation. | | _Point and Skeleton matching_ | | | - | - | diff --git a/site/content/en/docs/manual/advanced/formats/_index.md b/site/content/en/docs/manual/advanced/formats/_index.md index f4d30a45baa1..e8818e3742f6 100644 --- a/site/content/en/docs/manual/advanced/formats/_index.md +++ b/site/content/en/docs/manual/advanced/formats/_index.md @@ -23,34 +23,34 @@ The table below outlines the available formats for data export in CVAT. -| Format | Type | Computer Vision Task | Models | Shapes | Attributes | Video Tracks | -|------------------------------------------------------------------------------------------------------------------------------------|---------------|-------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------| -------------------- | ------------- | -| [CamVid 1.0](format-camvid) | .txt
.png | Semantic
Segmentation | U-Net, SegNet, DeepLab,
PSPNet, FCN, Mask R-CNN,
ICNet, ERFNet, HRNet,
V-Net, and others. | Polygons | Not supported | Not supported | -| [Cityscapes 1.0](format-cityscapes) | .txt
.png | Semantic
Segmentation | U-Net, SegNet, DeepLab,
PSPNet, FCN, ERFNet,
ICNet, Mask R-CNN, HRNet,
ENet, and others. | Polygons | Specific attributes | Not supported | -| [COCO 1.0](format-coco) | JSON | Detection, Semantic
Segmentation | YOLO (You Only Look Once),
Faster R-CNN, Mask R-CNN, SSD (Single Shot MultiBox Detector),
RetinaNet, EfficientDet, UNet,
DeepLabv3+, CenterNet, Cascade R-CNN, and others. | Bounding Boxes, Polygons | Specific attributes | Not supported | -| [COCO Keypoints 1.0](coco-keypoints) | .xml | Keypoints | OpenPose, PoseNet, AlphaPose,
SPM (Single Person Model),
Mask R-CNN with Keypoint Detection:, and others. | Skeletons | Specific attributes | Not supported | -| {{< ilink "/docs/manual/advanced/formats/format-cvat#cvat-for-image-export" "CVAT for images 1.1" >}} | .xml | Any in 2D except for Video Tracking | Any model that can decode the format. | Bounding Boxes, Polygons,
Polylines, Points, Cuboids,
Skeletons, Ellipses, Masks, Tags. | All attributes | Not supported | -| {{< ilink "/docs/manual/advanced/formats/format-cvat#cvat-for-videos-export" "CVAT for video 1.1" >}} | .xml | Any in 2D except for Classification | Any model that can decode the format. | Bounding Boxes, Polygons,
Polylines, Points, Cuboids,
Skeletons, Ellipses, Masks. | All attributes | Supported | -| [Datumaro 1.0](format-datumaro) | JSON | Any | Any model that can decode the format.
Main format in [Datumaro](https://github.com/openvinotoolkit/datumaro) framework | Bounding Boxes, Polygons,
Polylines, Points, Cuboids,
Skeletons, Ellipses, Masks, Tags. | All attributes | Supported | +| Format | Type | Computer Vision Task | Models | Shapes | Attributes | Video Tracks | +|-----------------------------------------------------------------------------------------------------------------------------|---------------|-------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------| -------------------- | ------------- | +| [CamVid 1.0](format-camvid) | .txt
.png | Semantic
Segmentation | U-Net, SegNet, DeepLab,
PSPNet, FCN, Mask R-CNN,
ICNet, ERFNet, HRNet,
V-Net, and others. | Polygons | Not supported | Not supported | +| [Cityscapes 1.0](format-cityscapes) | .txt
.png | Semantic
Segmentation | U-Net, SegNet, DeepLab,
PSPNet, FCN, ERFNet,
ICNet, Mask R-CNN, HRNet,
ENet, and others. | Polygons | Specific attributes | Not supported | +| [COCO 1.0](format-coco) | JSON | Detection, Semantic
Segmentation | YOLO (You Only Look Once),
Faster R-CNN, Mask R-CNN, SSD (Single Shot MultiBox Detector),
RetinaNet, EfficientDet, UNet,
DeepLabv3+, CenterNet, Cascade R-CNN, and others. | Bounding Boxes, Polygons | Specific attributes | Not supported | +| [COCO Keypoints 1.0](coco-keypoints) | .xml | Keypoints | OpenPose, PoseNet, AlphaPose,
SPM (Single Person Model),
Mask R-CNN with Keypoint Detection:, and others. | Skeletons | Specific attributes | Not supported | +| {{< ilink "/docs/manual/advanced/formats/format-cvat#cvat-for-image-export" "CVAT for images 1.1" >}} | .xml | Any in 2D except for Video Tracking | Any model that can decode the format. | Bounding Boxes, Polygons,
Polylines, Points, Cuboids,
Skeletons, Ellipses, Masks, Tags. | All attributes | Not supported | +| {{< ilink "/docs/manual/advanced/formats/format-cvat#cvat-for-videos-export" "CVAT for video 1.1" >}} | .xml | Any in 2D except for Classification | Any model that can decode the format. | Bounding Boxes, Polygons,
Polylines, Points, Cuboids,
Skeletons, Ellipses, Masks. | All attributes | Supported | +| [Datumaro 1.0](format-datumaro) | JSON | Any | Any model that can decode the format.
Main format in [Datumaro](https://github.com/openvinotoolkit/datumaro) framework | Bounding Boxes, Polygons,
Polylines, Points, Cuboids,
Skeletons, Ellipses, Masks, Tags. | All attributes | Supported | | [ICDAR](format-icdar)
Includes ICDAR Recognition 1.0,
ICDAR Detection 1.0,
and ICDAR Segmentation 1.0
descriptions. | .txt | Text recognition,
Text detection,
Text segmentation | EAST: Efficient and Accurate
Scene Text Detector, CRNN, Mask TextSpotter, TextSnake,
and others. | Tag, Bounding Boxes, Polygons | Specific attributes | Not supported | -| [ImageNet 1.0](format-imagenet) | .jpg
.txt | Semantic Segmentation,
Classification,
Detection | VGG (VGG16, VGG19), Inception, YOLO, Faster R-CNN , U-Net, and others | Tags | No attributes | Not supported | -| [KITTI 1.0](format-kitti) | .txt
.png | Semantic Segmentation, Detection, 3D | PointPillars, SECOND, AVOD, YOLO, DeepSORT, PWC-Net, ORB-SLAM, and others. | Bounding Boxes, Polygons | Specific attributes | Not supported | -| [LabelMe 3.0](format-labelme) | .xml | Compatibility,
Semantic Segmentation | U-Net, Mask R-CNN, Fast R-CNN,
Faster R-CNN, DeepLab, YOLO,
and others. | Bounding Boxes, Polygons | Supported (Polygons) | Not supported | -| [LFW 1.0](format-lfw) | .txt | Verification,
Face recognition | OpenFace, VGGFace & VGGFace2,
FaceNet, ArcFace,
and others. | Tags, Skeletons | Specific attributes | Not supported | -| [Market-1501 1.0](format-market1501) | .txt | Re-identification | Triplet Loss Networks,
Deep ReID models, and others. | Bounding Boxes | Specific attributes | Not supported | -| [MOT 1.0](format-mot) | .txt | Video Tracking,
Detection | SORT, MOT-Net, IOU Tracker,
and others. | Bounding Boxes | Specific attributes | Supported | -| [MOTS PNG 1.0](format-mots) | .png
.txt | Video Tracking,
Detection | SORT, MOT-Net, IOU Tracker,
and others. | Bounding Boxes, Masks | Specific attributes | Supported | -| [Open Images 1.0](format-openimages) | .csv | Detection,
Classification,
Semantic Segmentation | Faster R-CNN, YOLO, U-Net,
CornerNet, and others. | Bounding Boxes, Tags, Polygons | Specific attributes | Not supported | -| [PASCAL VOC 1.0](format-voc) | .xml | Classification, Detection | Faster R-CNN, SSD, YOLO,
AlexNet, and others. | Bounding Boxes, Tags, Polygons | Specific attributes | Not supported | -| [Segmentation Mask 1.0](format-smask) | .txt | Semantic Segmentation | Faster R-CNN, SSD, YOLO,
AlexNet, and others. | Polygons | No attributes | Not supported | -| [VGGFace2 1.0](format-vggface2) | .csv | Face recognition | VGGFace, ResNet, Inception,
and others. | Bounding Boxes, Points | No attributes | Not supported | -| [WIDER Face 1.0](format-widerface) | .txt | Detection | SSD (Single Shot MultiBox Detector), Faster R-CNN, YOLO,
and others. | Bounding Boxes, Tags | Specific attributes | Not supported | -| [YOLO 1.0](format-yolo) | .txt | Detection | YOLOv1, YOLOv2 (YOLO9000),
YOLOv3, YOLOv4, and others. | Bounding Boxes | No attributes | Not supported | -| [YOLOv8 Detection 1.0](format-yolov8) | .txt | Detection | YOLOv8 | Bounding Boxes | No attributes | Not supported | -| [YOLOv8 Segmentation 1.0](format-yolov8) | .txt | Instance Segmentation | YOLOv8 | Polygons, Masks | No attributes | Not supported | -| [YOLOv8 Pose 1.0](format-yolov8) | .txt | Keypoints | YOLOv8 | Skeletons | No attributes | Not supported | -| [YOLOv8 Oriented Bounding Boxes 1.0](format-yolov8) | .txt | Detection | YOLOv8 | Bounding Boxes | No attributes | Not supported | -| [YOLOv8 Classification 1.0](format-yolov8-classification) | .jpg | Classification | YOLOv8 | Tags | No attributes | Not supported | +| [ImageNet 1.0](format-imagenet) | .jpg
.txt | Semantic Segmentation,
Classification,
Detection | VGG (VGG16, VGG19), Inception, YOLO, Faster R-CNN , U-Net, and others | Tags | No attributes | Not supported | +| [KITTI 1.0](format-kitti) | .txt
.png | Semantic Segmentation, Detection, 3D | PointPillars, SECOND, AVOD, YOLO, DeepSORT, PWC-Net, ORB-SLAM, and others. | Bounding Boxes, Polygons | Specific attributes | Not supported | +| [LabelMe 3.0](format-labelme) | .xml | Compatibility,
Semantic Segmentation | U-Net, Mask R-CNN, Fast R-CNN,
Faster R-CNN, DeepLab, YOLO,
and others. | Bounding Boxes, Polygons | Supported (Polygons) | Not supported | +| [LFW 1.0](format-lfw) | .txt | Verification,
Face recognition | OpenFace, VGGFace & VGGFace2,
FaceNet, ArcFace,
and others. | Tags, Skeletons | Specific attributes | Not supported | +| [Market-1501 1.0](format-market1501) | .txt | Re-identification | Triplet Loss Networks,
Deep ReID models, and others. | Bounding Boxes | Specific attributes | Not supported | +| [MOT 1.0](format-mot) | .txt | Video Tracking,
Detection | SORT, MOT-Net, IOU Tracker,
and others. | Bounding Boxes | Specific attributes | Supported | +| [MOTS PNG 1.0](format-mots) | .png
.txt | Video Tracking,
Detection | SORT, MOT-Net, IOU Tracker,
and others. | Bounding Boxes, Masks | Specific attributes | Supported | +| [Open Images 1.0](format-openimages) | .csv | Detection,
Classification,
Semantic Segmentation | Faster R-CNN, YOLO, U-Net,
CornerNet, and others. | Bounding Boxes, Tags, Polygons | Specific attributes | Not supported | +| [PASCAL VOC 1.0](format-voc) | .xml | Classification, Detection | Faster R-CNN, SSD, YOLO,
AlexNet, and others. | Bounding Boxes, Tags, Polygons | Specific attributes | Not supported | +| [Segmentation Mask 1.0](format-smask) | .txt | Semantic Segmentation | Faster R-CNN, SSD, YOLO,
AlexNet, and others. | Polygons | No attributes | Not supported | +| [VGGFace2 1.0](format-vggface2) | .csv | Face recognition | VGGFace, ResNet, Inception,
and others. | Bounding Boxes, Points | No attributes | Not supported | +| [WIDER Face 1.0](format-widerface) | .txt | Detection | SSD (Single Shot MultiBox Detector), Faster R-CNN, YOLO,
and others. | Bounding Boxes, Tags | Specific attributes | Not supported | +| [YOLO 1.0](format-yolo) | .txt | Detection | YOLOv1, YOLOv2 (YOLO9000),
YOLOv3, YOLOv4, and others. | Bounding Boxes | No attributes | Not supported | +| [Ultralytics YOLO Detection 1.0](format-yolo-ultralytics) | .txt | Detection | YOLOv8 | Bounding Boxes | No attributes | Not supported | +| [Ultralytics YOLO Segmentation 1.0](format-yolo-ultralytics) | .txt | Instance Segmentation | YOLOv8 | Polygons, Masks | No attributes | Not supported | +| [Ultralytics YOLO Pose 1.0](format-yolo-ultralytics) | .txt | Keypoints | YOLOv8 | Skeletons | No attributes | Not supported | +| [Ultralytics YOLO Oriented Bounding Boxes 1.0](format-yolo-ultralytics) | .txt | Detection | YOLOv8 | Bounding Boxes | No attributes | Not supported | +| [Ultralytics YOLO Classification 1.0](format-yolo-ultralytics-classification) | .jpg | Classification | YOLOv8 | Tags | No attributes | Not supported | diff --git a/site/content/en/docs/manual/advanced/formats/format-yolov8-classification.md b/site/content/en/docs/manual/advanced/formats/format-yolo-ultralytics-classification.md similarity index 78% rename from site/content/en/docs/manual/advanced/formats/format-yolov8-classification.md rename to site/content/en/docs/manual/advanced/formats/format-yolo-ultralytics-classification.md index 8857c11518b3..734fd229a052 100644 --- a/site/content/en/docs/manual/advanced/formats/format-yolov8-classification.md +++ b/site/content/en/docs/manual/advanced/formats/format-yolo-ultralytics-classification.md @@ -1,16 +1,16 @@ --- -title: 'YOLOv8-Classification' -linkTitle: 'YOLOv8-Classification' +title: 'Ultralytics-YOLO-Classification' +linkTitle: 'Ultralytics-YOLO-Classification' weight: 7 -description: 'How to export and import data in YOLOv8 Classification format' +description: 'How to export and import data in Ultralytics YOLO Classification format' --- For more information, see: - [Format specification](https://docs.ultralytics.com/datasets/classify/) -- [Dataset examples](https://github.com/cvat-ai/datumaro/tree/develop/tests/assets/yolo_dataset/yolov8_classification) +- [Dataset examples](https://github.com/cvat-ai/datumaro/tree/develop/tests/assets/yolo_dataset/yolo_ultralytics_classification) -## YOLOv8 Classification export +## Ultralytics YOLO Classification export For export of images: diff --git a/site/content/en/docs/manual/advanced/formats/format-yolov8.md b/site/content/en/docs/manual/advanced/formats/format-yolo-ultralytics.md similarity index 73% rename from site/content/en/docs/manual/advanced/formats/format-yolov8.md rename to site/content/en/docs/manual/advanced/formats/format-yolo-ultralytics.md index 4d2975900ab8..4d99912de014 100644 --- a/site/content/en/docs/manual/advanced/formats/format-yolov8.md +++ b/site/content/en/docs/manual/advanced/formats/format-yolo-ultralytics.md @@ -1,24 +1,24 @@ --- -title: 'YOLOv8' -linkTitle: 'YOLOv8' +title: 'Ultralytics YOLO' +linkTitle: 'Ultralytics YOLO' weight: 7 -description: 'How to export and import data in YOLOv8 formats' +description: 'How to export and import data in Ultralytics YOLO formats' --- -YOLOv8 is a format family which consists of four formats: +Ultralytics YOLO is a format family which consists of four formats: - [Detection](https://docs.ultralytics.com/datasets/detect/) - [Oriented bounding Box](https://docs.ultralytics.com/datasets/obb/) - [Segmentation](https://docs.ultralytics.com/datasets/segment/) - [Pose](https://docs.ultralytics.com/datasets/pose/) Dataset examples: -- [Detection](https://github.com/cvat-ai/datumaro/tree/develop/tests/assets/yolo_dataset/yolov8_detection) -- [Oriented Bounding Boxes](https://github.com/cvat-ai/datumaro/tree/develop/tests/assets/yolo_dataset/yolov8_oriented_boxes) -- [Segmentation](https://github.com/cvat-ai/datumaro/tree/develop/tests/assets/yolo_dataset/yolov8_segmentation) -- [Pose](https://github.com/cvat-ai/datumaro/tree/develop/tests/assets/yolo_dataset/yolov8_pose) +- [Detection](https://github.com/cvat-ai/datumaro/tree/develop/tests/assets/yolo_dataset/yolo_ultralytics_detection) +- [Oriented Bounding Boxes](https://github.com/cvat-ai/datumaro/tree/develop/tests/assets/yolo_dataset/yolo_ultralytics_oriented_boxes) +- [Segmentation](https://github.com/cvat-ai/datumaro/tree/develop/tests/assets/yolo_dataset/yolo_ultralytics_segmentation) +- [Pose](https://github.com/cvat-ai/datumaro/tree/develop/tests/assets/yolo_dataset/yolo_ultralytics_pose) -## YOLOv8 export +## Ultralytics YOLO export For export of images: @@ -59,7 +59,7 @@ images//image2.jpg path: ./ # dataset root dir train: train.txt # train images (relative to 'path') -# YOLOv8 Pose specific field +# Ultralytics YOLO Pose specific field # First number is the number of points in a skeleton. # If there are several skeletons with different number of points, it is the greatest number of points # Second number defines the format of point info in annotation txt files @@ -75,7 +75,7 @@ names: # .txt: # content depends on format -# YOLOv8 Detection: +# Ultralytics YOLO Detection: # label_id - id from names field of data.yaml # cx, cy - relative coordinates of the bbox center # rw, rh - relative size of the bbox @@ -83,19 +83,19 @@ names: 1 0.3 0.8 0.1 0.3 2 0.7 0.2 0.3 0.1 -# YOLOv8 Oriented Bounding Boxes: +# Ultralytics YOLO Oriented Bounding Boxes: # xn, yn - relative coordinates of the n-th point # label_id x1 y1 x2 y2 x3 y3 x4 y4 1 0.3 0.8 0.1 0.3 0.4 0.5 0.7 0.5 2 0.7 0.2 0.3 0.1 0.4 0.5 0.5 0.6 -# YOLOv8 Segmentation: +# Ultralytics YOLO Segmentation: # xn, yn - relative coordinates of the n-th point # label_id x1 y1 x2 y2 x3 y3 ... 1 0.3 0.8 0.1 0.3 0.4 0.5 2 0.7 0.2 0.3 0.1 0.4 0.5 0.5 0.6 0.7 0.5 -# YOLOv8 Pose: +# Ultralytics YOLO Pose: # cx, cy - relative coordinates of the bbox center # rw, rh - relative size of the bbox # xn, yn - relative coordinates of the n-th point @@ -126,3 +126,24 @@ is named to correspond with its associated image file. For example, `frame_000001.txt` serves as the annotation for the `frame_000001.jpg` image. + +## Import + +Uploaded file: a zip archive of the same structure as above. + +For compatibility with other tools exporting in Ultralytics YOLO format +(e.g. [roboflow](https://roboflow.com/formats/yolov8-pytorch-txt)), +CVAT supports datasets with the inverted directory order of subset and "images" or "labels", +i.e. both `train/images/`, `images/train/` are valid inputs. +```bash +archive.zip/ + ├── train/ + │ ├── images/ # directory with images for train subset + │ │ ├── image1.jpg + │ │ ├── image2.jpg + │ │ └── ... + │ ├── labels/ # directory with annotations for train subset + │ │ ├── image1.txt + │ │ ├── image2.txt + │ │ └── ... +``` diff --git a/site/content/en/images/cognito_pool_1.png b/site/content/en/images/cognito_pool_1.png new file mode 100644 index 000000000000..7cfc8ac03521 Binary files /dev/null and b/site/content/en/images/cognito_pool_1.png differ diff --git a/site/content/en/images/cognito_pool_2.png b/site/content/en/images/cognito_pool_2.png new file mode 100644 index 000000000000..5e1a1b47dfe6 Binary files /dev/null and b/site/content/en/images/cognito_pool_2.png differ diff --git a/site/content/en/images/login_page_with_amazon_cognito.png b/site/content/en/images/login_page_with_amazon_cognito.png new file mode 100644 index 000000000000..b44b89bb5fe8 Binary files /dev/null and b/site/content/en/images/login_page_with_amazon_cognito.png differ diff --git a/site/process_sdk_docs.py b/site/process_sdk_docs.py index 03324aea691b..4fb911b69718 100755 --- a/site/process_sdk_docs.py +++ b/site/process_sdk_docs.py @@ -12,13 +12,13 @@ import sys import textwrap from glob import iglob -from typing import Callable, List +from typing import Callable from inflection import underscore class Processor: - _reference_files: List[str] + _reference_files: list[str] def __init__(self, *, input_dir: str, site_root: str) -> None: self._input_dir = input_dir @@ -29,7 +29,7 @@ def __init__(self, *, input_dir: str, site_root: str) -> None: self._templates_dir = osp.join(self._site_root, "templates") @staticmethod - def _copy_files(src_dir: str, glob_pattern: str, dst_dir: str) -> List[str]: + def _copy_files(src_dir: str, glob_pattern: str, dst_dir: str) -> list[str]: copied_files = [] for src_path in iglob(osp.join(src_dir, glob_pattern), recursive=True): @@ -140,7 +140,7 @@ def _fix_page_links_and_references(self): with open(p, "w") as f: f.write(contents) - def _process_non_code_blocks(self, text: str, handlers: List[Callable[[str], str]]) -> str: + def _process_non_code_blocks(self, text: str, handlers: list[Callable[[str], str]]) -> str: """ Allows to process Markdown documents with passed callbacks. Callbacks are only executed outside code blocks. diff --git a/site/requirements.txt b/site/requirements.txt index e240c7a0f90e..10db0c33a9b0 100644 --- a/site/requirements.txt +++ b/site/requirements.txt @@ -1,5 +1,4 @@ gitpython inflection >= 0.5.1 -isort>=5.10.1 packaging toml diff --git a/tests/cypress/e2e/actions_objects2/case_delete_frame.js b/tests/cypress/e2e/actions_objects2/case_delete_frame.js index 393fbc17b207..c0b3b34a0c99 100644 --- a/tests/cypress/e2e/actions_objects2/case_delete_frame.js +++ b/tests/cypress/e2e/actions_objects2/case_delete_frame.js @@ -40,12 +40,8 @@ context('Delete frame from job.', () => { cy.checkFrameNum(frame + 1); }); - it('Change deleted frame visability.', () => { - cy.openSettings(); - cy.get('.cvat-workspace-settings-show-deleted').within(() => { - cy.get('[type="checkbox"]').should('not.be.checked').check(); - }); - cy.closeSettings(); + it('Change deleted frame visibility.', () => { + cy.checkDeletedFrameVisibility(); }); it('Check previous frame available and deleted.', () => { diff --git a/tests/cypress/e2e/actions_projects_models/markdown_base_pipeline.js b/tests/cypress/e2e/actions_projects_models/markdown_base_pipeline.js index 51143c318794..639e57ad09f1 100644 --- a/tests/cypress/e2e/actions_projects_models/markdown_base_pipeline.js +++ b/tests/cypress/e2e/actions_projects_models/markdown_base_pipeline.js @@ -14,14 +14,14 @@ context('Basic markdown pipeline', () => { username: 'md_job_assignee', firstName: 'Firstname', lastName: 'Lastname', - emailAddr: 'md_job_assignee@local.local', + email: 'md_job_assignee@local.local', password: 'Fv5Df3#f55g', }, taskAssignee: { username: 'md_task_assignee', firstName: 'Firstname', lastName: 'Lastname', - emailAddr: 'md_task_assignee@local.local', + email: 'md_task_assignee@local.local', password: 'UfdU21!dds', }, notAssignee: { diff --git a/tests/cypress/e2e/actions_tasks/issue_2473_import_annotations_frames_dots_in_name.js b/tests/cypress/e2e/actions_tasks/issue_2473_import_annotations_frames_dots_in_name.js index 7398019d3903..3c02ba2eada7 100644 --- a/tests/cypress/e2e/actions_tasks/issue_2473_import_annotations_frames_dots_in_name.js +++ b/tests/cypress/e2e/actions_tasks/issue_2473_import_annotations_frames_dots_in_name.js @@ -33,7 +33,7 @@ context('Import annotations for frames with dots in name.', { browser: '!firefox secondY: 450, }; - const dumpType = 'YOLO'; + const dumpType = 'YOLO 1.1'; let annotationArchiveName = ''; function confirmUpdate(modalWindowClassName) { @@ -114,7 +114,7 @@ context('Import annotations for frames with dots in name.', { browser: '!firefox cy.interactMenu('Upload annotations'); cy.intercept('GET', '/api/jobs/**/annotations?**').as('uploadAnnotationsGet'); uploadAnnotation( - dumpType.split(' ')[0], + dumpType, annotationArchiveName, '.cvat-modal-content-load-job-annotation', ); diff --git a/tests/cypress/e2e/actions_users/registration_involved/case_28_review_pipeline_feature.js b/tests/cypress/e2e/actions_users/registration_involved/case_28_review_pipeline_feature.js index c77e00df3ad1..464464492832 100644 --- a/tests/cypress/e2e/actions_users/registration_involved/case_28_review_pipeline_feature.js +++ b/tests/cypress/e2e/actions_users/registration_involved/case_28_review_pipeline_feature.js @@ -12,14 +12,14 @@ context('Review pipeline feature', () => { username: 'annotator', firstName: 'Firstname', lastName: 'Lastname', - emailAddr: 'annotator@local.local', + email: 'annotator@local.local', password: 'UfdU21!dds', }, reviewer: { username: 'reviewer', firstName: 'Firstname', lastName: 'Lastname', - emailAddr: 'reviewer@local.local', + email: 'reviewer@local.local', password: 'Fv5Df3#f55g', }, }; diff --git a/tests/cypress/e2e/issues_prs2/issue_8785_update_job_metadata.js b/tests/cypress/e2e/issues_prs2/issue_8785_update_job_metadata.js new file mode 100644 index 000000000000..f34d3417c577 --- /dev/null +++ b/tests/cypress/e2e/issues_prs2/issue_8785_update_job_metadata.js @@ -0,0 +1,59 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +/// + +import { taskName } from '../../support/const'; + +context('The UI remains stable even when the metadata request fails.', () => { + const issueId = '8785'; + + function clickDeleteFrame() { + cy.get('.cvat-player-delete-frame').click(); + cy.get('.cvat-modal-delete-frame').within(() => { + cy.contains('button', 'Delete').click(); + }); + } + function clickSave() { + cy.get('button').contains('Save').click({ force: true }); + cy.get('button').contains('Save').trigger('mouseout'); + } + + before(() => { + cy.checkDeletedFrameVisibility(); + cy.openTaskJob(taskName); + cy.goToNextFrame(1); + }); + + describe(`Testing issue ${issueId}`, () => { + it('Crash on Save job. Save again.', () => { + const badResponse = { statusCode: 502, body: 'A horrible network error' }; + + cy.on('uncaught:exception', (err) => { + expect(err.code).to.equal(badResponse.statusCode); + expect(err.message).to.include(badResponse.body); + return false; + }); + + const routeMatcher = { + url: '/api/jobs/**/data/meta**', + method: 'PATCH', + times: 1, // cancels the intercept without retries + }; + + cy.intercept(routeMatcher, badResponse).as('patchError'); + + clickDeleteFrame(); + cy.get('.cvat-player-restore-frame').should('be.visible'); + + clickSave(); + cy.wait('@patchError').then((intercept) => { + expect(intercept.response.body).to.equal(badResponse.body); + expect(intercept.response.statusCode).to.equal(badResponse.statusCode); + }); + + cy.saveJob('PATCH', 200); + }); + }); +}); diff --git a/tests/cypress/support/commands.js b/tests/cypress/support/commands.js index 9941a9b0d5c3..a027c260e7b0 100644 --- a/tests/cypress/support/commands.js +++ b/tests/cypress/support/commands.js @@ -360,8 +360,12 @@ Cypress.Commands.add('headlessCreateUser', (userSpec) => { headers: { 'Content-type': 'application/json', }, + }).then((response) => { + expect(response.status).to.eq(201); + expect(response.body.username).to.eq(userSpec.username); + expect(response.body.email).to.eq(userSpec.email); + return cy.wrap(); }); - return cy.wrap(); }); Cypress.Commands.add('headlessLogout', () => { @@ -1683,6 +1687,14 @@ Cypress.Commands.add('hideTooltips', () => { }); }); +Cypress.Commands.add('checkDeletedFrameVisibility', () => { + cy.openSettings(); + cy.get('.cvat-workspace-settings-show-deleted').within(() => { + cy.get('[type="checkbox"]').should('not.be.checked').check(); + }); + cy.closeSettings(); +}); + Cypress.Commands.overwrite('visit', (orig, url, options) => { orig(url, options); cy.closeModalUnsupportedPlatform(); diff --git a/tests/python/pyproject.toml b/tests/python/pyproject.toml index ab4db6695977..6b5fba136a78 100644 --- a/tests/python/pyproject.toml +++ b/tests/python/pyproject.toml @@ -3,3 +3,4 @@ profile = "black" forced_separate = ["tests"] line_length = 100 skip_gitignore = true # align tool behavior with Black +known_first_party = ["shared", "rest_api", "sdk", "cli"] diff --git a/tests/python/requirements.txt b/tests/python/requirements.txt index d43d9b61d5df..dc21498a1ec0 100644 --- a/tests/python/requirements.txt +++ b/tests/python/requirements.txt @@ -8,7 +8,7 @@ deepdiff==7.0.1 boto3==1.17.61 Pillow==10.3.0 python-dateutil==2.8.2 -pyyaml==6.0.0 +pyyaml==6.0.2 numpy==2.0.0 # TODO: update pytest to 7.0.0 and pytest-timeout to 2.3.1 (better debug in vscode) \ No newline at end of file diff --git a/tests/python/rest_api/test_projects.py b/tests/python/rest_api/test_projects.py index d3d807d68088..7785454c8839 100644 --- a/tests/python/rest_api/test_projects.py +++ b/tests/python/rest_api/test_projects.py @@ -714,7 +714,7 @@ def test_can_import_dataset_in_org(self, admin_user: str): ("CVAT for images 1.1", "CVAT 1.1"), ("CVAT for video 1.1", "CVAT 1.1"), ("Datumaro 1.0", "Datumaro 1.0"), - ("YOLOv8 Pose 1.0", "YOLOv8 Pose 1.0"), + ("Ultralytics YOLO Pose 1.0", "Ultralytics YOLO Pose 1.0"), ), ) def test_can_export_and_import_dataset_with_skeletons( @@ -1078,10 +1078,10 @@ def _export_task(task_id: int, format_name: str) -> io.BytesIO: ("LFW 1.0", "{subset}/images/"), ("Cityscapes 1.0", "imgsFine/leftImg8bit/{subset}/"), ("Open Images V6 1.0", "images/{subset}/"), - ("YOLOv8 Detection 1.0", "images/{subset}/"), - ("YOLOv8 Oriented Bounding Boxes 1.0", "images/{subset}/"), - ("YOLOv8 Segmentation 1.0", "images/{subset}/"), - ("YOLOv8 Pose 1.0", "images/{subset}/"), + ("Ultralytics YOLO Detection 1.0", "images/{subset}/"), + ("Ultralytics YOLO Oriented Bounding Boxes 1.0", "images/{subset}/"), + ("Ultralytics YOLO Segmentation 1.0", "images/{subset}/"), + ("Ultralytics YOLO Pose 1.0", "images/{subset}/"), ], ) @pytest.mark.parametrize("api_version", (1, 2)) diff --git a/tests/python/rest_api/test_quality_control.py b/tests/python/rest_api/test_quality_control.py index d03675c9156e..56dd24bb0abb 100644 --- a/tests/python/rest_api/test_quality_control.py +++ b/tests/python/rest_api/test_quality_control.py @@ -1213,7 +1213,7 @@ def test_modified_task_produces_different_metrics( "compare_line_orientation", "panoptic_comparison", "point_size_base", - "match_empty_frames", + "empty_is_annotated", ], ) def test_settings_affect_metrics( @@ -1246,8 +1246,11 @@ def test_settings_affect_metrics( ) new_report = self.create_quality_report(admin_user, task_id) - if parameter == "match_empty_frames": + if parameter == "empty_is_annotated": assert new_report["summary"]["valid_count"] != old_report["summary"]["valid_count"] + assert new_report["summary"]["total_count"] != old_report["summary"]["total_count"] + assert new_report["summary"]["ds_count"] != old_report["summary"]["ds_count"] + assert new_report["summary"]["gt_count"] != old_report["summary"]["gt_count"] else: assert ( new_report["summary"]["conflict_count"] != old_report["summary"]["conflict_count"] diff --git a/tests/python/rest_api/test_tasks.py b/tests/python/rest_api/test_tasks.py index 15496cc31f73..a55bd1ded65b 100644 --- a/tests/python/rest_api/test_tasks.py +++ b/tests/python/rest_api/test_tasks.py @@ -45,7 +45,7 @@ from pytest_cases import fixture, fixture_ref, parametrize import shared.utils.s3 as s3 -from shared.fixtures.init import docker_exec_cvat, kube_exec_cvat +from shared.fixtures.init import container_exec_cvat from shared.utils.config import ( delete_method, get_method, @@ -947,7 +947,7 @@ def test_export_dataset_after_deleting_related_cloud_storage( [ ("Datumaro 1.0", "", "images/{subset}"), ("YOLO 1.1", "train", "obj_{subset}_data"), - ("YOLOv8 Detection 1.0", "train", "images/{subset}"), + ("Ultralytics YOLO Detection 1.0", "train", "images/{subset}"), ], ) @pytest.mark.parametrize("api_version", (1, 2)) @@ -5315,12 +5315,9 @@ def test_check_import_cache_after_previous_interrupted_upload(self, tasks_with_s number_of_files = 1 sleep(30) # wait when the cleaning job from rq worker will be started command = ["/bin/bash", "-c", f"ls data/tasks/{task_id}/tmp | wc -l"] - platform = request.config.getoption("--platform") - assert platform in ("kube", "local") - func = docker_exec_cvat if platform == "local" else kube_exec_cvat for _ in range(12): sleep(2) - result, _ = func(command) + result, _ = container_exec_cvat(request, command) number_of_files = int(result) if not number_of_files: break @@ -5422,10 +5419,10 @@ def test_can_import_datumaro_json(self, admin_user, tasks, dimension): "Open Images V6 1.0", "Datumaro 1.0", "Datumaro 3D 1.0", - "YOLOv8 Oriented Bounding Boxes 1.0", - "YOLOv8 Detection 1.0", - "YOLOv8 Pose 1.0", - "YOLOv8 Segmentation 1.0", + "Ultralytics YOLO Oriented Bounding Boxes 1.0", + "Ultralytics YOLO Detection 1.0", + "Ultralytics YOLO Pose 1.0", + "Ultralytics YOLO Segmentation 1.0", ], ) def test_check_import_error_on_wrong_file_structure(self, tasks_with_shapes, format_name): diff --git a/tests/python/shared/assets/cvat_db/data.json b/tests/python/shared/assets/cvat_db/data.json index 5b30d421cb5a..53863fa94fcc 100644 --- a/tests/python/shared/assets/cvat_db/data.json +++ b/tests/python/shared/assets/cvat_db/data.json @@ -18173,7 +18173,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18197,7 +18197,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18221,7 +18221,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18245,7 +18245,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18269,7 +18269,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18293,7 +18293,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18317,7 +18317,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18341,7 +18341,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18365,7 +18365,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18389,7 +18389,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18413,7 +18413,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18437,7 +18437,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18461,7 +18461,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18485,7 +18485,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18509,7 +18509,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18533,7 +18533,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18557,7 +18557,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18581,7 +18581,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18605,7 +18605,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18629,7 +18629,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18653,7 +18653,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18677,7 +18677,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18701,7 +18701,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 @@ -18725,7 +18725,7 @@ "object_visibility_threshold": 0.05, "panoptic_comparison": true, "compare_attributes": true, - "match_empty_frames": false, + "empty_is_annotated": false, "target_metric": "accuracy", "target_metric_threshold": 0.7, "max_validations_per_job": 0 diff --git a/tests/python/shared/assets/quality_settings.json b/tests/python/shared/assets/quality_settings.json index 7ddc589bc7bf..dc56352fc1ef 100644 --- a/tests/python/shared/assets/quality_settings.json +++ b/tests/python/shared/assets/quality_settings.json @@ -14,7 +14,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -35,7 +35,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -56,7 +56,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -77,7 +77,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -98,7 +98,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -119,7 +119,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -140,7 +140,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -161,7 +161,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -182,7 +182,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -203,7 +203,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -224,7 +224,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -245,7 +245,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -266,7 +266,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -287,7 +287,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -308,7 +308,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -329,7 +329,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -350,7 +350,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -371,7 +371,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -392,7 +392,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -413,7 +413,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -434,7 +434,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -455,7 +455,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -476,7 +476,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, @@ -497,7 +497,7 @@ "line_orientation_threshold": 0.1, "line_thickness": 0.01, "low_overlap_threshold": 0.8, - "match_empty_frames": false, + "empty_is_annotated": false, "max_validations_per_job": 0, "object_visibility_threshold": 0.05, "oks_sigma": 0.09, diff --git a/tests/python/shared/fixtures/init.py b/tests/python/shared/fixtures/init.py index 1f5d57ffc5d7..b0d5f8a84db0 100644 --- a/tests/python/shared/fixtures/init.py +++ b/tests/python/shared/fixtures/init.py @@ -171,6 +171,16 @@ def kube_exec_cvat(command: Union[list[str], str]): return _run(_command) +def container_exec_cvat(request: pytest.FixtureRequest, command: Union[list[str], str]): + platform = request.config.getoption("--platform") + if platform == "local": + return docker_exec_cvat(command) + elif platform == "kube": + return kube_exec_cvat(command) + else: + assert False, "unknown platform" + + def kube_exec_cvat_db(command): pod_name = _kube_get_db_pod_name() _run(["kubectl", "exec", pod_name, "--"] + command) diff --git a/tests/yarn.lock b/tests/yarn.lock index a5500151360a..4b82218c14a9 100644 --- a/tests/yarn.lock +++ b/tests/yarn.lock @@ -1146,9 +1146,9 @@ crc32-stream@^4.0.2: readable-stream "^3.4.0" cross-spawn@^7.0.0, cross-spawn@^7.0.3: - version "7.0.3" - resolved "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz" - integrity sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w== + version "7.0.6" + resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.6.tgz#8a58fe78f00dcd70c370451759dfbfaf03e8ee9f" + integrity sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA== dependencies: path-key "^3.1.0" shebang-command "^2.0.0" diff --git a/utils/dataset_manifest/__init__.py b/utils/dataset_manifest/__init__.py index 74fd25ede729..7efcfcb48406 100644 --- a/utils/dataset_manifest/__init__.py +++ b/utils/dataset_manifest/__init__.py @@ -1,4 +1,4 @@ # Copyright (C) 2021-2022 Intel Corporation # # SPDX-License-Identifier: MIT -from .core import VideoManifestManager, ImageManifestManager, is_manifest +from .core import ImageManifestManager, VideoManifestManager, is_manifest diff --git a/utils/dataset_manifest/core.py b/utils/dataset_manifest/core.py index 6a7c9d92f0d6..a855d170e86b 100644 --- a/utils/dataset_manifest/core.py +++ b/utils/dataset_manifest/core.py @@ -3,24 +3,24 @@ # # SPDX-License-Identifier: MIT -from enum import Enum -from io import StringIO -import av import json import os - -from abc import ABC, abstractmethod, abstractproperty, abstractstaticmethod +from abc import ABC, abstractmethod +from collections.abc import Iterator from contextlib import closing +from enum import Enum +from inspect import isgenerator +from io import StringIO from itertools import islice -from PIL import Image from json.decoder import JSONDecodeError -from inspect import isgenerator +from typing import Any, Callable, Optional, Union + +import av +from PIL import Image from .errors import InvalidManifestError, InvalidVideoError -from .utils import SortingMethod, md5_hash, rotate_image, sort from .types import NamedBytesIO - -from typing import Any, Dict, List, Union, Optional, Iterator, Tuple, Callable +from .utils import SortingMethod, md5_hash, rotate_image, sort class VideoStreamReader: @@ -78,7 +78,7 @@ def validate_key_frame(self, container, video_stream, key_frame): return False return True - def __iter__(self) -> Iterator[Union[int, Tuple[int, int, str]]]: + def __iter__(self) -> Iterator[Union[int, tuple[int, int, str]]]: """ Iterate over video frames and yield key frames or indexes. @@ -143,12 +143,12 @@ def __iter__(self) -> Iterator[Union[int, Tuple[int, int, str]]]: class DatasetImagesReader: def __init__(self, - sources: Union[List[str], Iterator[NamedBytesIO]], + sources: Union[list[str], Iterator[NamedBytesIO]], *, start: int = 0, step: int = 1, stop: Optional[int] = None, - meta: Optional[Dict[str, List[str]]] = None, + meta: Optional[dict[str, list[str]]] = None, sorting_method: SortingMethod = SortingMethod.PREDEFINED, use_image_hash: bool = False, **kwargs @@ -196,7 +196,7 @@ def step(self): def step(self, value): self._step = int(value) - def _get_img_properties(self, image: Union[str, NamedBytesIO]) -> Dict[str, Any]: + def _get_img_properties(self, image: Union[str, NamedBytesIO]) -> dict[str, Any]: img = Image.open(image, mode='r') if self._data_dir: img_name = os.path.relpath(image, self._data_dir) @@ -469,7 +469,8 @@ def __getitem__(self, item): def index(self): return self._index - @abstractproperty + @property + @abstractmethod def data(self): ... @@ -665,7 +666,7 @@ def emulate_hierarchical_structure( prefix: str = "", default_prefix: Optional[str] = None, start_index: Optional[int] = None, - ) -> Dict: + ) -> dict: if default_prefix and prefix and not (default_prefix.startswith(prefix) or prefix.startswith(default_prefix)): return { @@ -727,12 +728,12 @@ def emulate_hierarchical_structure( 'next': next_start_index, } - def reorder(self, reordered_images: List[str]) -> None: + def reorder(self, reordered_images: list[str]) -> None: """ The method takes a list of image names and reorders its content based on this new list. Due to the implementation of Honeypots, the reordered list of image names may contain duplicates. """ - unique_images: Dict[str, Any] = {} + unique_images: dict[str, Any] = {} for _, image_details in self: if image_details.full_name not in unique_images: unique_images[image_details.full_name] = image_details @@ -766,11 +767,13 @@ def _validate_type(self, _dict): if not _dict['type'] == self.TYPE: raise InvalidManifestError('Incorrect type field') - @abstractproperty + @property + @abstractmethod def validators(self): pass - @abstractstaticmethod + @staticmethod + @abstractmethod def _validate_first_item(_dict): pass diff --git a/utils/dataset_manifest/create.py b/utils/dataset_manifest/create.py index 64efaed60f2d..fa31300e058a 100755 --- a/utils/dataset_manifest/create.py +++ b/utils/dataset_manifest/create.py @@ -7,13 +7,14 @@ import argparse import os -import sys import re +import sys from glob import glob from tqdm import tqdm -from utils import detect_related_images, is_image, is_video, SortingMethod +from utils import SortingMethod, detect_related_images, is_image, is_video + def get_args(): parser = argparse.ArgumentParser() @@ -98,5 +99,5 @@ def main(): if __name__ == "__main__": base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(base_dir) - from dataset_manifest.core import VideoManifestManager, ImageManifestManager + from dataset_manifest.core import ImageManifestManager, VideoManifestManager main() diff --git a/utils/dataset_manifest/requirements.txt b/utils/dataset_manifest/requirements.txt index 6d3ed66aecb1..c073606622ed 100644 --- a/utils/dataset_manifest/requirements.txt +++ b/utils/dataset_manifest/requirements.txt @@ -13,7 +13,7 @@ numpy==1.22.4 # via opencv-python-headless opencv-python-headless==4.10.0.84 # via -r utils/dataset_manifest/requirements.in -pillow==11.0.0 +pillow==11.1.0 # via -r utils/dataset_manifest/requirements.in tqdm==4.67.1 # via -r utils/dataset_manifest/requirements.in diff --git a/utils/dataset_manifest/types.py b/utils/dataset_manifest/types.py index 8847eee457ba..5ddcce9ad5c9 100644 --- a/utils/dataset_manifest/types.py +++ b/utils/dataset_manifest/types.py @@ -5,6 +5,7 @@ from io import BytesIO from typing import Protocol + class Named(Protocol): filename: str diff --git a/utils/dataset_manifest/utils.py b/utils/dataset_manifest/utils.py index b4eee9686b71..9cb89ce5cd4d 100644 --- a/utils/dataset_manifest/utils.py +++ b/utils/dataset_manifest/utils.py @@ -2,15 +2,17 @@ # # SPDX-License-Identifier: MIT -import os -import re import hashlib import mimetypes +import os +import re +from enum import Enum +from random import shuffle + import cv2 as cv from av import VideoFrame -from enum import Enum from natsort import os_sorted -from random import shuffle + def rotate_image(image, angle): height, width = image.shape[:2] diff --git a/utils/dicom_converter/script.py b/utils/dicom_converter/script.py index 3fe7ef0be6dd..a201845965f3 100644 --- a/utils/dicom_converter/script.py +++ b/utils/dicom_converter/script.py @@ -3,17 +3,16 @@ # SPDX-License-Identifier: MIT -import os import argparse import logging +import os from glob import glob import numpy as np -from tqdm import tqdm from PIL import Image from pydicom import dcmread from pydicom.pixel_data_handlers.util import convert_color_space - +from tqdm import tqdm # Script configuration logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")