Skip to content

Commit

Permalink
refactor: refactor resource import workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
keithmanville committed Oct 30, 2024
1 parent ace4574 commit b4372ce
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 35 deletions.
4 changes: 2 additions & 2 deletions src/dioptra/restapi/v1/plugin_parameter_types/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,8 +571,8 @@ def get(
group_id: int,
error_if_not_found: bool = False,
**kwargs,
) -> models.PluginTaskParameterType | None:
"""Fetch a list of plugin parameter types by their names.
) -> list[models.PluginTaskParameterType]:
"""Fetch a list of builtin plugin parameter types.
Args:
group_id: The the group id of the plugin parameter type.
Expand Down
2 changes: 1 addition & 1 deletion src/dioptra/restapi/v1/plugins/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ def create(
self,
filename: str,
contents: str,
description: str,
description: str | None,
tasks: list[dict[str, Any]],
plugin_id: int,
commit: bool = True,
Expand Down
1 change: 0 additions & 1 deletion src/dioptra/restapi/v1/workflows/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ def post(self):
git_url=parsed_form.get("git_url", None),
archive_file=request.files.get("archiveFile", None),
config_path=parsed_form["config_path"],
read_only=parsed_form["read_only"],
resolve_name_conflicts_strategy=parsed_form[
"resolve_name_conflicts_strategy"
],
Expand Down
8 changes: 4 additions & 4 deletions src/dioptra/restapi/v1/workflows/lib/clone_git_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from urllib.parse import urlparse


def clone_git_repository(url: str, dir: Path):
def clone_git_repository(url: str, dir: Path) -> str:
parsed_url = urlparse(url)
git_branch = parsed_url.fragment or None
git_paths = parsed_url.params or None
git_url = parsed_url._replace(fragment="", params="").geturl()

git_sparse_args = ["--filter=blob:none", "--no-checkout", "--depth=1"]
git_branch_args = ["-b", git_branch] if git_branch else []
clone_cmd = ["git", "clone", *git_sparse_args, *git_branch_args, git_url, dir]
clone_cmd = ["git", "clone", *git_sparse_args, *git_branch_args, git_url, str(dir)]
clone_result = subprocess.run(clone_cmd, capture_output=True, text=True)

if clone_result.returncode != 0:
Expand Down Expand Up @@ -45,9 +45,9 @@ def clone_git_repository(url: str, dir: Path):
hash_result = subprocess.run(hash_cmd, cwd=dir, capture_output=True, text=True)

if hash_result.returncode != 0:
raise subprocess.CalledProcessError
raise subprocess.CalledProcessError(hash_result.returncode, hash_result.stderr)

return hash
return str(hash)


if __name__ == "__main__":
Expand Down
5 changes: 0 additions & 5 deletions src/dioptra/restapi/v1/workflows/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,6 @@ class ResourceImportSchema(Schema):
metdata=dict(description="The path to the toml configuration file."),
load_default="dioptra.toml",
)
readOnly = fields.Bool(
attribute="read_only",
metadata=dict(description="Whether imported resources should be readonly."),
load_default=False,
)
resolveNameConflictsStrategy = fields.Enum(
ResourceImportResolveNameConflictsStrategy,
attribute="resolve_name_conflicts_strategy",
Expand Down
86 changes: 64 additions & 22 deletions src/dioptra/restapi/v1/workflows/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,24 @@
from werkzeug.datastructures import FileStorage

from dioptra.restapi.db import db
from dioptra.restapi.v1.entrypoints.service import EntrypointService
from dioptra.restapi.errors import DioptraError
from dioptra.restapi.v1.entrypoints.service import (
EntrypointIdService,
EntrypointNameService,
EntrypointService,
)
from dioptra.restapi.v1.plugin_parameter_types.service import (
BuiltinPluginParameterTypeService,
PluginParameterTypeIdService,
PluginParameterTypeNameService,
PluginParameterTypeService,
)
from dioptra.restapi.v1.plugins.service import PluginIdFileService, PluginService
from dioptra.restapi.v1.plugins.service import (
PluginIdFileService,
PluginIdService,
PluginNameService,
PluginService,
)
from dioptra.sdk.utilities.paths import set_cwd

from .lib import clone_git_repository, package_job_files, views
Expand Down Expand Up @@ -99,29 +111,48 @@ class ResourceImportService(object):
def __init__(
self,
plugin_service: PluginService,
plugin_id_service: PluginIdService,
plugin_name_service: PluginNameService,
plugin_id_file_service: PluginIdFileService,
plugin_parameter_type_service: PluginParameterTypeService,
plugin_parameter_type_id_service: PluginParameterTypeIdService,
plugin_parameter_type_name_service: PluginParameterTypeNameService,
builtin_plugin_parameter_type_service: BuiltinPluginParameterTypeService,
entrypoint_service: EntrypointService,
entrypoint_id_service: EntrypointIdService,
entrypoint_name_service: EntrypointNameService,
) -> None:
"""Initialize the resource import service.
All arguments are provided via dependency injection.
Args:
plugin_service: A PluginService object,
plugin_name_service: A PluginNameService object.
plugin_id_service: A PluginIdService object.
plugin_id_file_service: A PluginIdFileService object.
plugin_parameter_type_service: A PluginParameterTypeService object.
builtin_plugin_parameter_type_service: A BuiltinPluginParameterTypeService object.
entrypoint_service: A EntrypointService object.
plugin_parameter_type_id_service: A PluginParameterTypeIdService object.
plugin_parameter_type_name_service: A PluginParameterTypeNameService object.
builtin_plugin_parameter_type_service: A BuiltinPluginParameterTypeService
object.
entrypoint_service: An EntrypointService object.
entrypoint_id_service: An EntrypointIdService object.
entrypoint_name_service: An EntrypointNameService object.
"""
self._plugin_service = plugin_service
self._plugin_id_service = plugin_id_service
self._plugin_name_service = plugin_name_service
self._plugin_id_file_service = plugin_id_file_service
self._plugin_parameter_type_service = plugin_parameter_type_service
self._plugin_parameter_type_id_service = plugin_parameter_type_id_service
self._plugin_parameter_type_name_service = plugin_parameter_type_name_service
self._builtin_plugin_parameter_type_service = (
builtin_plugin_parameter_type_service
)
self._entrypoint_service = entrypoint_service
self._entrypoint_id_service = entrypoint_id_service
self._entrypoint_name_service = entrypoint_name_service

def import_resources(
self,
Expand All @@ -130,7 +161,6 @@ def import_resources(
git_url: str | None,
archive_file: FileStorage | None,
config_path: str,
read_only: bool,
resolve_name_conflicts_strategy: str,
**kwargs,
) -> dict[str, Any]:
Expand All @@ -141,7 +171,6 @@ def import_resources(
source_type: The source to import from (either "upload" or "git")
git_url: The url to the git repository if source_type is "git"
archive_file: The contents of the upload if source_type is "upload"
read_only: Whether to apply a readonly lock to all imported resources
resolve_name_conflicts_strategy: The strategy for resolving name conflicts.
Either "fail" or "overwrite"
Expand All @@ -163,11 +192,16 @@ def import_resources(
bytes = archive_file.stream.read()
with tarfile.open(fileobj=BytesIO(bytes), mode="r:*") as tar:
tar.extractall(path=working_dir, filter="data")
hash = sha256(bytes).hexdigest()
hash = str(sha256(bytes).hexdigest())
elif source_type == ResourceImportSourceTypes.GIT:
hash = clone_git_repository(git_url, working_dir)

config = toml.load(working_dir / config_path)
try:
config = toml.load(working_dir / config_path)
except Exception as e:
raise DioptraError(
f"Failed to load resource import config from {config_path}."
) from e

# validate the config file
with open(
Expand All @@ -192,16 +226,16 @@ def import_resources(
"message": "successfully imported",
"hash": hash,
"resources": {
"plugins": [plugin.name for plugin in plugins],
"plugin_param_types": [param_type.name for param_type in param_types],
"entrypoints": [entrypoint.name for entrypoint in entrypoints],
"plugins": list(plugins.keys()),
"plugin_param_types": list(param_types.keys()),
"entrypoints": list(entrypoints.keys()),
},
}

def _register_plugin_param_types(
self,
group_id: int,
param_types_config: dict[str, Any],
param_types_config: list[dict[str, Any]],
overwrite: bool,
log: BoundLogger,
) -> dict[str, Any]:
Expand All @@ -212,19 +246,22 @@ def _register_plugin_param_types(
param_type["name"], group_id=group_id, log=log
)
if existing:
self._plugin_parameter_type_service.delete(
plugin_parameter_type_id=existing["id"],
self._plugin_parameter_type_id_service.delete(
plugin_parameter_type_id=existing.resource_id,
log=log,
)

param_type["name"] = self._plugin_parameter_type_service.create(
param_type_dict = self._plugin_parameter_type_service.create(
name=param_type["name"],
description=param_type.get("description", None),
structure=param_type.get("structure", None),
group_id=group_id,
commit=False,
log=log,
)["plugin_task_parameter_type"]
)
param_types[param_type["name"]] = param_type_dict[
"plugin_task_parameter_type"
]

db.session.flush()

Expand All @@ -233,11 +270,12 @@ def _register_plugin_param_types(
def _register_plugins(
self,
group_id: int,
plugins_config: dict[str, Any],
plugins_config: list[dict[str, Any]],
param_types: Any,
overwrite: bool,
log: BoundLogger,
):
param_types = param_types.copy()
builtin_param_types = self._builtin_plugin_parameter_type_service.get(
group_id=group_id, error_if_not_found=False, log=log
)
Expand All @@ -253,7 +291,7 @@ def _register_plugins(
)
if existing:
self._plugin_id_service.delete(
plugin_id=existing["id"],
plugin_id=existing.resource_id,
log=log,
)

Expand Down Expand Up @@ -289,7 +327,7 @@ def _register_plugins(
def _register_entrypoints(
self,
group_id: int,
entrypoints_config: dict[str, Any],
entrypoints_config: list[dict[str, Any]],
plugins,
overwrite: bool,
log: BoundLogger,
Expand All @@ -301,7 +339,7 @@ def _register_entrypoints(
entrypoint["name"], group_id=group_id, log=log
)
if existing is not None:
self.entrypoint_id_service.delete(
self._entrypoint_id_service.delete(
entrypoint_id=existing.resource_id
)

Expand All @@ -328,13 +366,17 @@ def _register_entrypoints(
commit=False,
log=log,
)
entrypoints[entrypoint_dict["name"]] = entrypoint_dict["entrypoint"]
entrypoints[entrypoint_dict["entry_point"].name] = entrypoint_dict[
"entry_point"
]

db.session.flush()

return entrypoints

def _build_tasks(self, tasks_config, param_types):
def _build_tasks(
self, tasks_config: list[dict[str, Any]], param_types: list[dict[str, str]]
) -> dict[str, list]:
tasks = defaultdict(list)
for task in tasks_config:
tasks[task["filename"]].append(
Expand Down

0 comments on commit b4372ce

Please sign in to comment.