Skip to content

Commit

Permalink
refactor: Remove Flask-Injector as a dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
chisholm committed Jul 27, 2023
1 parent b1208d1 commit 3ee65ff
Show file tree
Hide file tree
Showing 9 changed files with 292 additions and 18 deletions.
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,9 @@ dependencies = [
"boto3>=1.16.0",
"Click>=8.0.0,<9",
"entrypoints>=0.3",
"Flask>=2.0.0,<2.2.0",
"Flask>=2.0.0",
"flask-accepts>=0.17.0",
"Flask-Cors>=3.0.1",
"Flask-Injector>=0.14.0",
"Flask-Migrate>=2.5.0",
"flask-restx>=0.5.1",
"Flask-SQLAlchemy>=2.4.0",
Expand Down
2 changes: 1 addition & 1 deletion src/dioptra/mlflow_plugins/dioptra_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
class DioptraDatabaseClient(object):
@property
def app(self) -> Flask:
app: Flask = create_app(env=self.restapi_env)
app, _ = create_app(env=self.restapi_env)
return app

@property
Expand Down
17 changes: 10 additions & 7 deletions src/dioptra/restapi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,16 @@
import structlog
from flask import Flask, jsonify
from flask_cors import CORS
from flask_injector import FlaskInjector
from flask_migrate import Migrate
from flask_restx import Api
from flask_sqlalchemy import SQLAlchemy
from flask_wtf import CSRFProtect
from injector import Injector
from sqlalchemy import MetaData
from structlog.stdlib import BoundLogger

from dioptra.restapi.utils import setup_injection

from .__version__ import __version__ as API_VERSION

LOGGER: BoundLogger = structlog.stdlib.get_logger()
Expand All @@ -57,7 +59,9 @@
migrate: Migrate = Migrate()


def create_app(env: Optional[str] = None, inject_dependencies: bool = True):
def create_app(
env: Optional[str] = None, inject_dependencies: bool = True
) -> tuple[Flask, Api]:
"""Creates and configures a fresh instance of the Dioptra REST API.
Args:
Expand Down Expand Up @@ -113,9 +117,8 @@ def health():
log = LOGGER.new(request_id=str(uuid.uuid4())) # noqa: F841
return jsonify("healthy")

if not inject_dependencies:
return app

FlaskInjector(app=app, modules=modules)
if inject_dependencies:
injector = Injector(modules)
setup_injection(api, injector)

return app
return app, api
3 changes: 2 additions & 1 deletion src/dioptra/restapi/experiment/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@

from typing import Any, Callable, List

from flask_injector import request
from injector import Binder, Module, provider
from mlflow.tracking import MlflowClient

from dioptra.restapi.shared.request_scope import request

from .schema import ExperimentRegistrationFormSchema


Expand Down
2 changes: 1 addition & 1 deletion src/dioptra/restapi/job/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@

from boto3.session import Session
from botocore.client import BaseClient
from flask_injector import request
from injector import Binder, Module, provider
from redis import Redis

from dioptra.restapi.shared.request_scope import request
from dioptra.restapi.shared.rq.service import RQService

from .schema import JobFormSchema
Expand Down
136 changes: 136 additions & 0 deletions src/dioptra/restapi/shared/request_scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Code below is copied from Flask-Injector library. It has the following
# license:
#
# Copyright (c) 2012, Alec Thomas
# Copyright (c) 2015 Smarkets Limited <[email protected]>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# - Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# - Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# - Neither the name of SwapOff.org nor the names of its contributors may
# be used to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Tell flake8 to ignore this file. It is mostly just copy-pastes from
# Flask-Injector.
# flake8: noqa

from typing import Any, Dict

import flask
from injector import Injector, Provider, Scope, ScopeDecorator
from werkzeug.local import Local, LocalManager


class CachedProviderWrapper(Provider):
def __init__(self, old_provider: Provider) -> None:
self._old_provider = old_provider
self._cache = {} # type: Dict[int, Any]

def get(self, injector: Injector) -> Any:
key = id(injector)
try:
return self._cache[key]
except KeyError:
instance = self._cache[key] = self._old_provider.get(injector)
return instance


class RequestScope(Scope):
"""A scope whose object lifetime is tied to a request.
@request
class Session:
pass
"""

# We don't want to assign here, just provide type hints
if False:
_local_manager = None # type: LocalManager
_locals = None # type: Any

def cleanup(self) -> None:
self._local_manager.cleanup()

def prepare(self) -> None:
self._locals.scope = {}

def configure(self) -> None:
self._locals = Local()
self._local_manager = LocalManager([self._locals])
self.prepare()

def get(self, key: Any, provider: Provider) -> Any:
try:
return self._locals.scope[key]
except KeyError:
new_provider = self._locals.scope[key] = CachedProviderWrapper(provider)
return new_provider


request = ScopeDecorator(RequestScope)


def set_request_scope_callbacks(app: flask.Flask, injector: Injector) -> None:
"""
Set callbacks to enable request scoping behavior: initialize at the
beginning of request handling, and cleanup at the end.
Args:
app: A Flask app
injector: An injector, used to get the RequestScope object
"""

def reset_request_scope_before(*args: Any, **kwargs: Any) -> None:
injector.get(RequestScope).prepare()

def global_reset_request_scope_after(*args: Any, **kwargs: Any) -> None:
blueprint = flask.request.blueprint
# If current blueprint has teardown_request_funcs associated with it we know there may be
# a some teardown request handlers we need to inject into, so we can't reset the scope just yet.
# We'll leave it to blueprint_reset_request_scope_after to do the job which we know will run
# later and we know it'll run after any teardown_request handlers we may want to inject into.
if blueprint is None or blueprint not in app.teardown_request_funcs:
injector.get(RequestScope).cleanup()

def blueprint_reset_request_scope_after(*args: Any, **kwargs: Any) -> None:
# If we got here we truly know this is the last teardown handler, which means we can reset the
# scope unconditionally.
injector.get(RequestScope).cleanup()

app.before_request_funcs.setdefault(None, []).insert(0, reset_request_scope_before)
# We're accessing Flask internals here as the app.teardown_request decorator appends to a list of
# handlers but Flask itself reverses the list when it executes them. To allow injecting request-scoped
# dependencies into teardown_request handlers we need to run our teardown_request handler after them.
# Also see https://github.com/alecthomas/flask_injector/issues/42 where it was reported.
# Secondly, we need to handle blueprints. Flask first executes non-blueprint teardown handlers in
# reverse order and only then executes blueprint-associated teardown handlers in reverse order,
# which means we can't just set on non-blueprint teardown handler, but we need to set both.
# In non-blueprint teardown handler we check if a blueprint handler will run – if so, we do nothing
# there and leave it to the blueprint teardown handler.
#
# We need the None key to be present in the dictionary so that the dictionary iteration always yields
# None as well. We *always* have to set the global teardown request.
app.teardown_request_funcs.setdefault(None, []).insert(
0, global_reset_request_scope_after
)
for bp, functions in app.teardown_request_funcs.items():
if bp is not None:
functions.insert(0, blueprint_reset_request_scope_after)
137 changes: 135 additions & 2 deletions src/dioptra/restapi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,17 @@
"""
from __future__ import annotations

from typing import List
import functools
from typing import Any, Callable, List, Protocol, Type, cast

from flask_restx import Namespace
from flask.views import View
from flask_restx import Api, Namespace
from flask_restx.reqparse import RequestParser
from injector import Injector
from typing_extensions import TypedDict

from dioptra.restapi.shared.request_scope import set_request_scope_callbacks


class ParametersSchema(TypedDict, total=False):
"""A schema of the parameters that can be passed to the |RequestParser|."""
Expand Down Expand Up @@ -79,3 +84,131 @@ def slugify(text: str) -> str:
"""

return text.lower().strip().replace(" ", "-")


def _simple_function_wrap(
func: Callable[..., Any], injector: Injector
) -> Callable[..., Any]:
"""
Wrap func such that it is called with dependency injection, using the given
injector.
Args:
func: A function
injector: An injector
Returns:
A function which calls func with dependency injection
"""

@functools.wraps(
# view functions can have some other attribute junk which might be
# important...?
func,
assigned=functools.WRAPPER_ASSIGNMENTS
+ ("methods", "required_methods", "provide_automatic_options"),
)
def wrapper(*args, **kwargs):
return injector.call_with_injection(func, args=args, kwargs=kwargs)

return wrapper


class _ClassBasedViewFunction(Protocol):
"""
We distinguish a class-based view function from other view functions
by looking for a "view_class" attribute on the function.
"""

view_class: Type[View]

def __call__(self, *args, **kwargs) -> Any:
...


def _new_class_view_function(
func: _ClassBasedViewFunction, injector: Injector, api: Api
) -> _ClassBasedViewFunction:
"""
Create a view function which supports injection, based on the given
class-based view function. "Wrapping" func won't work here, in the sense
that our view function can't delegate to func. The original view function
does not support dependency-injected view object creation, so it is
unusable. So we create a brand new one (@wrap'd, so it has the look of
func at least), which does dependency-injected view creation.
Args:
func: The old class-based view function
injector: An injector
api: The flask_restx Api instance
Returns:
A new view function
"""

view_obj = None

# Honoring init_every_request is simple enough to do, so why not.
# It was added in Flask 2.2.0; it behaved as though True, previously.
init_every_request = getattr(func.view_class, "init_every_request", True)

if not init_every_request:
view_obj = injector.create_object(
func.view_class, additional_kwargs={"api": api}
)

@functools.wraps(
func,
assigned=functools.WRAPPER_ASSIGNMENTS
+ ("view_class", "methods", "provide_automatic_options"),
)
def new_view_func(*args, **kwargs):
nonlocal view_obj
if init_every_request:
view_obj = injector.create_object(
func.view_class, additional_kwargs={"api": api}
)

return view_obj.dispatch_request(*args, **kwargs)

new_view_func = api.output(new_view_func)

# "func" must have a view_class attribute since it is the
# prerequisite for calling this function. @wraps copied
# that over to new_view_func, so it must have the attribute
# too. But mypy can't see that. I don't know another way
# to satisfy mypy.
new_view_func = cast(_ClassBasedViewFunction, new_view_func)

return new_view_func


def setup_injection(api: Api, injector: Injector) -> None:
"""
Fixup the given flask app such that view functions support dependency
injection.
Args:
api: A flask_restx Api object. This contains the flask app, and is
also necessary to make restx views (resources) work with
dependency injection.
injector: An injector
"""

new_view_func: Callable[..., Any]

for key, func in api.app.view_functions.items():
if hasattr(func, "view_class"):
new_view_func = _new_class_view_function(func, injector, api)
else:
new_view_func = _simple_function_wrap(func, injector)

api.app.view_functions[key] = new_view_func

set_request_scope_callbacks(api.app, injector)

# Uncomment to see more detailed logging regarding dependency injection
# in debug mode.
# if api.app.debug:
# injector_logger = logging.getLogger("injector")
# injector_logger.setLevel(logging.DEBUG)
Loading

0 comments on commit 3ee65ff

Please sign in to comment.