Skip to content

Commit

Permalink
fix: update mypy config and fix type errors which appeared as a result
Browse files Browse the repository at this point in the history
  • Loading branch information
chisholm authored and jkglasbrenner committed Jul 20, 2023
1 parent 81335b5 commit b1208d1
Show file tree
Hide file tree
Showing 24 changed files with 119 additions and 84 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ honor_noqa = true
[tool.mypy]
python_version = "3.9"
platform = "linux"
mypy_path = "src,task-plugins"
explicit_package_bases = true
namespace_packages = true
show_column_numbers = true
follow_imports = "normal"
Expand Down
2 changes: 1 addition & 1 deletion src/dioptra/mlflow_plugins/dioptra_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def _log_workflow_artifact(run_id, workflow_filepath):
client.log_artifact(run_id=run_id, local_path=workflow_filepath)


def _set_dioptra_tags(run_id):
def _set_dioptra_tags(run_id: str) -> None:
job: Optional[Dict[str, Any]] = DioptraDatabaseClient().get_active_job()

if job is None:
Expand Down
8 changes: 6 additions & 2 deletions src/dioptra/mlflow_plugins/dioptra_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# https://creativecommons.org/licenses/by/4.0/legalcode
import datetime
import os
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, cast

import structlog
from flask import Flask
Expand Down Expand Up @@ -152,7 +152,11 @@ def delete_job(self, job_id: str) -> None:
raise

def get_experiment_by_id(self, experiment_id: int) -> Optional[Experiment]:
return Experiment.query.get(experiment_id)
result = Experiment.query.get(experiment_id)
# 1.4 SQLAlchemy type stubs don't seem to have a return type annotation for the
# get() method, so mypy assumes Any. Latest SQLAlchemy code uses Optional[Any].
# https://github.com/sqlalchemy/sqlalchemy/blob/d9d0ffd96c632750be9adcb03a207d75aecaa80f/lib/sqlalchemy/orm/query.py#L1048
return cast(Optional[Experiment], result)

def create_experiment(self, experiment_name: str, experiment_id: int) -> None:
timestamp = datetime.datetime.now()
Expand Down
79 changes: 37 additions & 42 deletions src/dioptra/pyplugs/_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
Optional,
TypeVar,
Union,
cast,
overload,
)

Expand Down Expand Up @@ -96,7 +97,7 @@
from typing_extensions import Protocol # type: ignore

if TYPE_CHECKING:
from prefect.core.task import TaskMetaclass as Task
from prefect.tasks.core.function import FunctionTask


# Structural subtyping
Expand All @@ -108,22 +109,10 @@ def __call__(self, *args, **kwargs) -> Any:


# Type aliases
F = TypeVar("F", bound=NoutPlugin)
T = TypeVar("T")
Plugin = Callable[..., Any]


# Only expose decorated functions to the outside
__all__ = []


def expose(func: Callable[..., T]) -> Callable[..., T]:
"""Add function to __all__ so it will be exposed at the top level"""
__all__.append(func.__name__)

return func


class PluginInfo(NamedTuple):
"""Information about one plug-in"""

Expand All @@ -142,7 +131,7 @@ class PluginInfo(NamedTuple):


@overload
def register(func: None, *, sort_value: float) -> Callable[[Plugin], Plugin]:
def register(*, sort_value: float) -> Callable[[Plugin], Plugin]:
"""Signature for using decorator with parameters"""
... # pragma: nocover

Expand All @@ -153,10 +142,7 @@ def register(func: Plugin) -> Plugin:
... # pragma: nocover


@expose
def register(
_func: Optional[Plugin] = None, *, sort_value: float = 0
) -> Callable[..., Any]:
def register(_func=None, *, sort_value=0):
"""Decorator for registering a new plug-in"""

def decorator_register(func: Callable[..., T]) -> Callable[..., T]:
Expand Down Expand Up @@ -188,25 +174,26 @@ def decorator_register(func: Callable[..., T]) -> Callable[..., T]:
return decorator_register(_func)


@expose
def task_nout(nout: int) -> Callable[[F], F]:
def decorator(func: F) -> F:
func._task_nout = nout
def task_nout(nout: int) -> Callable[[Plugin], NoutPlugin]:
def decorator(func: Plugin) -> NoutPlugin:
# We're just assigning an attribute, and we need mypy to let us
# do that. So we just force a type change, to a callable type which
# includes an attribute.
nout_func = cast(NoutPlugin, func)
nout_func._task_nout = nout

return func
return nout_func

return decorator


@expose
def names(package: str) -> List[str]:
"""List all plug-ins in one package"""
_import_all(package)

return sorted(_PLUGINS[package].keys(), key=lambda p: info(package, p).sort_value)


@expose
def funcs(package: str, plugin: str) -> List[str]:
"""List all functions in one plug-in"""
_import(package, plugin)
Expand All @@ -215,7 +202,6 @@ def funcs(package: str, plugin: str) -> List[str]:
return list(plugin_info.keys())


@expose
def info(package: str, plugin: str, func: Optional[str] = None) -> PluginInfo:
"""Get information about a plug-in"""
_import(package, plugin)
Expand All @@ -241,7 +227,6 @@ def info(package: str, plugin: str, func: Optional[str] = None) -> PluginInfo:
) from exc


@expose
def exists(package: str, plugin: str) -> bool:
"""Check if a given plugin exists"""
if package in _PLUGINS and plugin in _PLUGINS[package]:
Expand All @@ -257,13 +242,11 @@ def exists(package: str, plugin: str) -> bool:
return package in _PLUGINS and plugin in _PLUGINS[package]


@expose
def get(package: str, plugin: str, func: Optional[str] = None) -> Plugin:
"""Get a given plugin"""
return info(package, plugin, func).func


@expose
def call(
package: str, plugin: str, func: Optional[str] = None, *args: Any, **kwargs: Any
) -> Any:
Expand All @@ -273,17 +256,15 @@ def call(
return plugin_func(*args, **kwargs)


@expose
@require_package("prefect", exc_type=PrefectDependencyError)
def get_task(package: str, plugin: str, func: Optional[str] = None) -> Task:
def get_task(package: str, plugin: str, func: Optional[str] = None) -> FunctionTask:
"""Get a given plugin wrapped as a prefect task"""
plugin_func: Union[Plugin, NoutPlugin] = info(package, plugin, func).func
nout: Optional[int] = getattr(plugin_func, "_task_nout", None)

return task(plugin_func, nout=nout) # type: ignore
return task(nout=nout)(plugin_func)


@expose
@require_package("prefect", exc_type=PrefectDependencyError)
def call_task(
package: str, plugin: str, func: Optional[str] = None, *args: Any, **kwargs: Any
Expand Down Expand Up @@ -340,51 +321,65 @@ def _import_all(package: str) -> None:
pass # Don't let errors in one plugin, affect the others


@expose
def names_factory(package: str) -> Callable[[], List[str]]:
"""Create a names() function for one package"""
return functools.partial(names, package)


@expose
def funcs_factory(package: str) -> Callable[[str], List[str]]:
"""Create a funcs() function for one package"""
return functools.partial(funcs, package)


@expose
def info_factory(package: str) -> Callable[[str, Optional[str]], PluginInfo]:
"""Create a info() function for one package"""
return functools.partial(info, package)


@expose
def exists_factory(package: str) -> Callable[[str], bool]:
"""Create an exists() function for one package"""
return functools.partial(exists, package)


@expose
def get_factory(package: str) -> Callable[[str, Optional[str]], Plugin]:
"""Create a get() function for one package"""
return functools.partial(get, package)


@expose
def call_factory(package: str) -> Callable[..., Any]:
"""Create a call() function for one package"""
return functools.partial(call, package)


@expose
@require_package("prefect", exc_type=PrefectDependencyError)
def get_task_factory(package: str) -> Callable[[str, Optional[str]], Task]:
def get_task_factory(package: str) -> Callable[[str, Optional[str]], FunctionTask]:
"""Create a get_task() function for one package"""
return functools.partial(get_task, package)


@expose
@require_package("prefect", exc_type=PrefectDependencyError)
def call_task_factory(package: str) -> Callable[..., Any]:
"""Create a call_task() function for one package"""
return functools.partial(call_task, package)


__all__ = [
"register",
"task_nout",
"names",
"funcs",
"info",
"exists",
"get",
"call",
"get_task",
"call_task",
"names_factory",
"funcs_factory",
"info_factory",
"exists_factory",
"get_factory",
"call_factory",
"get_task_factory",
"call_task_factory",
]
6 changes: 3 additions & 3 deletions src/dioptra/restapi/experiment/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def delete(self, experimentId: int) -> Response:
experiment_id=experimentId
)

return jsonify(dict(status="Success", id=id)) # type: ignore
return jsonify(dict(status="Success", id=id))

@accepts(schema=ExperimentUpdateSchema, api=api)
@responds(schema=ExperimentSchema, api=api)
Expand Down Expand Up @@ -202,13 +202,13 @@ def delete(self, experimentName: str) -> Response:
)

if experiment is None:
return jsonify(dict(status="Success", id=[])) # type: ignore
return jsonify(dict(status="Success", id=[]))

id: List[int] = self._experiment_service.delete_experiment(
experiment_id=experiment.experiment_id
)

return jsonify(dict(status="Success", id=id)) # type: ignore
return jsonify(dict(status="Success", id=id))

@accepts(schema=ExperimentUpdateSchema, api=api)
@responds(schema=ExperimentSchema, api=api)
Expand Down
2 changes: 1 addition & 1 deletion src/dioptra/restapi/experiment/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class ExperimentRegistrationForm(FlaskForm):

name = StringField("Name of Experiment", validators=[InputRequired()])

def validate_name(self, field):
def validate_name(self, field: StringField) -> None:
"""Validates that the experiment does not exist in the registry.
Args:
Expand Down
4 changes: 2 additions & 2 deletions src/dioptra/restapi/experiment/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from marshmallow import Schema, fields, post_dump, pre_dump

from dioptra.restapi.utils import slugify
from dioptra.restapi.utils import ParametersSchema, slugify

from .model import (
Experiment,
Expand Down Expand Up @@ -113,7 +113,7 @@ def serialize_object(
return self.__model__(**data)


ExperimentRegistrationSchema = [
ExperimentRegistrationSchema: list[ParametersSchema] = [
dict(
name="name",
type=str,
Expand Down
4 changes: 2 additions & 2 deletions src/dioptra/restapi/job/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class JobForm(FlaskForm):
"MLproject file and its associated entry point scripts.",
)

def validate_experiment_name(self, field):
def validate_experiment_name(self, field: StringField) -> None:
"""Validates that the experiment is registered and not deleted.
Args:
Expand All @@ -186,7 +186,7 @@ def validate_experiment_name(self, field):
"Please check spelling and resubmit."
)

def validate_queue(self, field):
def validate_queue(self, field: StringField) -> None:
"""Validates that the queue is registered, active and not deleted.
Args:
Expand Down
4 changes: 2 additions & 2 deletions src/dioptra/restapi/job/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from marshmallow import Schema, fields, post_dump, post_load, pre_dump, validate
from werkzeug.datastructures import FileStorage

from dioptra.restapi.utils import slugify
from dioptra.restapi.utils import ParametersSchema, slugify

from .model import Job, JobForm, JobFormData

Expand Down Expand Up @@ -227,7 +227,7 @@ def serialize_object(
return self.__model__(**data)


job_submit_form_schema = [
job_submit_form_schema: list[ParametersSchema] = [
dict(
name="experiment_name",
type=str,
Expand Down
13 changes: 11 additions & 2 deletions src/dioptra/restapi/job/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
from werkzeug.utils import secure_filename

from dioptra.restapi.app import db
from dioptra.restapi.experiment.errors import ExperimentDoesNotExistError
from dioptra.restapi.experiment.service import ExperimentService
from dioptra.restapi.queue.errors import QueueDoesNotExistError
from dioptra.restapi.queue.service import QueueService
from dioptra.restapi.shared.rq.service import RQService
from dioptra.restapi.shared.s3.service import S3Service
Expand Down Expand Up @@ -92,13 +94,20 @@ def extract_data_from_form(self, job_form: JobForm, **kwargs) -> JobFormData:

job_form_data: JobFormData = self._job_form_schema.dump(job_form)

experiment: Experiment = self._experiment_service.get_by_name(
experiment: Optional[Experiment] = self._experiment_service.get_by_name(
job_form_data["experiment_name"], log=log
)
queue: Queue = self._queue_service.get_unlocked_by_name(

if experiment is None:
raise ExperimentDoesNotExistError

queue: Optional[Queue] = self._queue_service.get_unlocked_by_name(
job_form_data["queue"], log=log
)

if queue is None:
raise QueueDoesNotExistError

job_form_data["experiment_id"] = experiment.experiment_id
job_form_data["queue_id"] = queue.queue_id

Expand Down
Loading

0 comments on commit b1208d1

Please sign in to comment.