From de1d88515d0b8d1d4f1d53d6fbbfc6f1b5bf8b3c Mon Sep 17 00:00:00 2001 From: Elias Freider Date: Thu, 16 Jan 2025 11:43:09 +0100 Subject: [PATCH] Fixes issue where `modal run` of a method on a class triggered deprecation warning [CLI-304] (#2761) --- modal/app.py | 2 +- modal/cli/import_refs.py | 179 ++++++++++++++++++++++++++------------ modal/cli/launch.py | 10 ++- modal/cli/run.py | 135 ++++++++++++++++------------ modal/cls.py | 16 ++++ modal/partial_function.py | 3 + tasks.py | 1 + test/cli_imports_test.py | 132 +++++++++++++++++++++++----- test/cli_test.py | 6 +- 9 files changed, 344 insertions(+), 140 deletions(-) diff --git a/modal/app.py b/modal/app.py index e0fa4c046..bb28fb28f 100644 --- a/modal/app.py +++ b/modal/app.py @@ -507,7 +507,7 @@ def registered_functions(self) -> dict[str, _Function]: return self._functions @property - def registered_classes(self) -> dict[str, _Function]: + def registered_classes(self) -> dict[str, _Cls]: """All modal.Cls objects registered on the app.""" return self._classes diff --git a/modal/cli/import_refs.py b/modal/cli/import_refs.py index 7479c4fbe..13a7a2d6f 100644 --- a/modal/cli/import_refs.py +++ b/modal/cli/import_refs.py @@ -9,8 +9,11 @@ import dataclasses import importlib +import importlib.util import inspect import sys +import types +from dataclasses import dataclass from pathlib import Path from typing import Any, Optional, Union @@ -19,6 +22,7 @@ from rich.markdown import Markdown from modal.app import App, LocalEntrypoint +from modal.cls import Cls from modal.exception import InvalidError, _CliUserExecutionError from modal.functions import Function @@ -26,6 +30,12 @@ @dataclasses.dataclass class ImportRef: file_or_module: str + + # object_path is a .-delimited path to the object to execute, or a parent from which to infer the object + # e.g. + # function or local_entrypoint in module scope + # app in module scope [+ method name] + # app [+ function/entrypoint on that app] object_path: Optional[str] @@ -62,11 +72,14 @@ def import_file_or_module(file_or_module: str): sys.path.insert(0, str(full_path.parent)) module_name = inspect.getmodulename(file_or_module) + assert module_name is not None # Import the module - see https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly spec = importlib.util.spec_from_file_location(module_name, file_or_module) + assert spec is not None module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module try: + assert spec.loader spec.loader.exec_module(module) except Exception as exc: raise _CliUserExecutionError(str(full_path)) from exc @@ -79,23 +92,39 @@ def import_file_or_module(file_or_module: str): return module -def get_by_object_path(obj: Any, obj_path: str) -> Optional[Any]: +@dataclass +class MethodReference: + """This helps with deferring method reference until after the class gets instantiated by the CLI""" + + cls: Cls + method_name: str + + +def get_by_object_path(obj: Any, obj_path: str) -> Union[Function, LocalEntrypoint, MethodReference, App, None]: # Try to evaluate a `.`-delimited object path in a Modal context # With the caveat that some object names can actually have `.` in their name (lifecycled methods' tags) # Note: this is eager, so no backtracking is performed in case an # earlier match fails at some later point in the path expansion prefix = "" - for segment in obj_path.split("."): + obj_path_segments = obj_path.split(".") + for i, segment in enumerate(obj_path_segments): attr = prefix + segment + if isinstance(obj, App): + if attr in obj.registered_entrypoints: + # local entrypoints can't be accessed via getattr + obj = obj.registered_entrypoints[attr] + continue + if isinstance(obj, Cls): + remaining_segments = obj_path_segments[i:] + remaining_path = ".".join(remaining_segments) + if len(remaining_segments) > 1: + raise ValueError(f"{obj._get_name()} is a class, but {remaining_path} is not a method reference") + # TODO: add method inference here? + return MethodReference(obj, remaining_path) + try: - if isinstance(obj, App): - if attr in obj.registered_entrypoints: - # local entrypoints are not on stub blueprint - obj = obj.registered_entrypoints[attr] - continue obj = getattr(obj, attr) - except Exception: prefix = f"{prefix}{segment}." else: @@ -109,54 +138,62 @@ def get_by_object_path(obj: Any, obj_path: str) -> Optional[Any]: def _infer_function_or_help( app: App, module, accept_local_entrypoint: bool, accept_webhook: bool -) -> Union[Function, LocalEntrypoint]: - function_choices = set(app.registered_functions) - if not accept_webhook: - function_choices -= set(app.registered_web_endpoints) - if accept_local_entrypoint: - function_choices |= set(app.registered_entrypoints.keys()) +) -> Union[Function, LocalEntrypoint, MethodReference]: + """Using only an app - automatically infer a single "runnable" for a `modal run` invocation - sorted_function_choices = sorted(function_choices) - registered_functions_str = "\n".join(sorted_function_choices) + If a single runnable can't be determined, show CLI help indicating valid choices. + """ filtered_local_entrypoints = [ - name - for name, entrypoint in app.registered_entrypoints.items() + entrypoint + for entrypoint in app.registered_entrypoints.values() if entrypoint.info.module_name == module.__name__ ] - if accept_local_entrypoint and len(filtered_local_entrypoints) == 1: - # If there is just a single local entrypoint in the target module, use - # that regardless of other functions. - function_name = list(filtered_local_entrypoints)[0] - elif accept_local_entrypoint and len(app.registered_entrypoints) == 1: - # Otherwise, if there is just a single local entrypoint in the stub as a whole, - # use that one. - function_name = list(app.registered_entrypoints.keys())[0] - elif len(function_choices) == 1: - function_name = sorted_function_choices[0] - elif len(function_choices) == 0: + if accept_local_entrypoint: + if len(filtered_local_entrypoints) == 1: + # If there is just a single local entrypoint in the target module, use + # that regardless of other functions. + return filtered_local_entrypoints[0] + elif len(app.registered_entrypoints) == 1: + # Otherwise, if there is just a single local entrypoint in the app as a whole, + # use that one. + return list(app.registered_entrypoints.values())[0] + + # TODO: refactor registered_functions to only contain function services, not class services + function_choices: dict[str, Union[Function, LocalEntrypoint, MethodReference]] = { + name: f for name, f in app.registered_functions.items() if not name.endswith(".*") + } + for cls_name, cls in app.registered_classes.items(): + for method_name in cls._get_method_names(): + function_choices[f"{cls_name}.{method_name}"] = MethodReference(cls, method_name) + + if not accept_webhook: + for web_endpoint_name in app.registered_web_endpoints: + function_choices.pop(web_endpoint_name, None) + + if accept_local_entrypoint: + function_choices.update(app.registered_entrypoints) + + if len(function_choices) == 1: + return list(function_choices.values())[0] + + if len(function_choices) == 0: if app.registered_web_endpoints: err_msg = "Modal app has only web endpoints. Use `modal serve` instead of `modal run`." else: err_msg = "Modal app has no registered functions. Nothing to run." raise click.UsageError(err_msg) - else: - help_text = f"""You need to specify a Modal function or local entrypoint to run, e.g. + + # there are multiple choices - we can't determine which one to use: + registered_functions_str = "\n".join(sorted(function_choices)) + help_text = f"""You need to specify a Modal function or local entrypoint to run, e.g. modal run app.py::my_function [...args] Registered functions and local entrypoints on the selected app are: {registered_functions_str} """ - raise click.UsageError(help_text) - - if function_name in app.registered_entrypoints: - # entrypoint is in entrypoint registry, for now - return app.registered_entrypoints[function_name] - - function = app.registered_functions[function_name] - assert isinstance(function, Function) - return function + raise click.UsageError(help_text) def _show_no_auto_detectable_app(app_ref: ImportRef) -> None: @@ -223,30 +260,60 @@ def foo(): error_console.print(guidance_msg) -def import_function( - func_ref: str, base_cmd: str, accept_local_entrypoint=True, accept_webhook=False -) -> Union[Function, LocalEntrypoint]: +def _import_object(func_ref, base_cmd): import_ref = parse_import_ref(func_ref) - module = import_file_or_module(import_ref.file_or_module) - app_or_function = get_by_object_path(module, import_ref.object_path or DEFAULT_APP_NAME) + app_function_or_method_ref = get_by_object_path(module, import_ref.object_path or DEFAULT_APP_NAME) - if app_or_function is None: + if app_function_or_method_ref is None: _show_function_ref_help(import_ref, base_cmd) - sys.exit(1) + raise SystemExit(1) + + return app_function_or_method_ref, module + - if isinstance(app_or_function, App): +def _infer_runnable( + partial_obj: Union[App, Function, MethodReference, LocalEntrypoint], + module: types.ModuleType, + accept_local_entrypoint: bool = True, + accept_webhook: bool = False, +) -> tuple[App, Union[Function, MethodReference, LocalEntrypoint]]: + if isinstance(partial_obj, App): # infer function or display help for how to select one - app = app_or_function + app = partial_obj function_handle = _infer_function_or_help(app, module, accept_local_entrypoint, accept_webhook) - return function_handle - elif isinstance(app_or_function, Function): - return app_or_function - elif isinstance(app_or_function, LocalEntrypoint): + return app, function_handle + elif isinstance(partial_obj, Function): + return partial_obj.app, partial_obj + elif isinstance(partial_obj, MethodReference): + return partial_obj.cls._get_app(), partial_obj + elif isinstance(partial_obj, LocalEntrypoint): if not accept_local_entrypoint: raise click.UsageError( - f"{func_ref} is not a Modal Function (a Modal local_entrypoint can't be used in this context)" + f"{partial_obj.info.function_name} is not a Modal Function " + f"(a Modal local_entrypoint can't be used in this context)" ) - return app_or_function + return partial_obj.app, partial_obj else: - raise click.UsageError(f"{app_or_function} is not a Modal entity (should be an App or Function)") + raise click.UsageError( + f"{partial_obj} is not a Modal entity (should be an App, Local entrypoint, " "Function or Class/Method)" + ) + + +def import_and_infer( + func_ref: str, base_cmd: str, accept_local_entrypoint=True, accept_webhook=False +) -> tuple[App, Union[Function, LocalEntrypoint, MethodReference]]: + """Takes a function ref string and returns something "runnable" + + The function ref can leave out partial information (apart from the file name) as + long as the runnable is uniquely identifiable by the provided information. + + When there are multiple runnables within the provided ref, the following rules should + be followed: + + 1. if there is a single local_entrypoint, that one is used + 2. if there is a single {function, class} that one is used + 3. if there is a single method (within a class) that one is used + """ + app_function_or_method_ref, module = _import_object(func_ref, base_cmd) + return _infer_runnable(app_function_or_method_ref, module, accept_local_entrypoint, accept_webhook) diff --git a/modal/cli/launch.py b/modal/cli/launch.py index 12eea123d..334c1d516 100644 --- a/modal/cli/launch.py +++ b/modal/cli/launch.py @@ -8,11 +8,11 @@ from typer import Typer -from ..app import App +from ..app import LocalEntrypoint from ..exception import _CliUserExecutionError from ..output import enable_output from ..runner import run_app -from .import_refs import import_function +from .import_refs import import_and_infer launch_cli = Typer( name="launch", @@ -29,8 +29,10 @@ def _launch_program(name: str, filename: str, detach: bool, args: dict[str, Any] os.environ["MODAL_LAUNCH_ARGS"] = json.dumps(args) program_path = str(Path(__file__).parent / "programs" / filename) - entrypoint = import_function(program_path, "modal launch") - app: App = entrypoint.app + app, entrypoint = import_and_infer(program_path, "modal launch") + if not isinstance(entrypoint, LocalEntrypoint): + raise ValueError(f"{program_path} has no single local_entrypoint") + app.set_description(f"modal launch {name}") # `launch/` scripts must have a `local_entrypoint()` with no args, for simplicity here. diff --git a/modal/cli/run.py b/modal/cli/run.py index 05b80593f..2261602c0 100644 --- a/modal/cli/run.py +++ b/modal/cli/run.py @@ -7,7 +7,6 @@ import shlex import sys import time -import typing from functools import partial from typing import Any, Callable, Optional, get_type_hints @@ -15,7 +14,6 @@ import typer from typing_extensions import TypedDict -from .. import Cls from ..app import App, LocalEntrypoint from ..config import config from ..environments import ensure_env @@ -26,7 +24,7 @@ from ..runner import deploy_app, interactive_shell, run_app from ..serving import serve_app from ..volume import Volume -from .import_refs import import_app, import_function +from .import_refs import MethodReference, import_and_infer, import_app from .utils import ENV_OPTION, ENV_OPTION_HELP, is_tty, stream_app_logs @@ -145,39 +143,7 @@ def _write_local_result(result_path: str, res: Any): fid.write(res) -def _get_click_command_for_function(app: App, function_tag): - function = app.registered_functions.get(function_tag) - if not function or (isinstance(function, Function) and function.info.user_cls is not None): - # This is either a function_tag for a class method function (e.g MyClass.foo) or a function tag for a - # class service function (MyClass.*) - class_name, method_name = function_tag.rsplit(".", 1) - if not function: - function = app.registered_functions.get(f"{class_name}.*") - assert isinstance(function, Function) - function = typing.cast(Function, function) - if function.is_generator: - raise InvalidError("`modal run` is not supported for generator functions") - - signature: dict[str, ParameterMetadata] - cls: Optional[Cls] = None - if function.info.user_cls is not None: - cls = typing.cast(Cls, app.registered_classes[class_name]) - cls_signature = _get_signature(function.info.user_cls) - if method_name == "*": - method_names = list(cls._get_partial_functions().keys()) - if len(method_names) == 1: - method_name = method_names[0] - else: - class_name = function.info.user_cls.__name__ - raise click.UsageError( - f"Please specify a specific method of {class_name} to run, e.g. `modal run foo.py::MyClass.bar`" # noqa: E501 - ) - fun_signature = _get_signature(getattr(cls, method_name).info.raw_f, is_method=True) - signature = dict(**cls_signature, **fun_signature) # Pool all arguments - # TODO(erikbern): assert there's no overlap? - else: - signature = _get_signature(function.info.raw_f) - +def _make_click_function(app, inner: Callable[[dict[str, Any]], Any]): @click.pass_context def f(ctx, **kwargs): show_progress: bool = ctx.obj["show_progress"] @@ -188,21 +154,65 @@ def f(ctx, **kwargs): environment_name=ctx.obj["env"], interactive=ctx.obj["interactive"], ): - if cls is None: - res = function.remote(**kwargs) - else: - # unpool class and method arguments - # TODO(erikbern): this code is a bit hacky - cls_kwargs = {k: kwargs[k] for k in cls_signature} - fun_kwargs = {k: kwargs[k] for k in fun_signature} - - instance = cls(**cls_kwargs) - method: Function = getattr(instance, method_name) - res = method.remote(**fun_kwargs) + res = inner(kwargs) if result_path := ctx.obj["result_path"]: _write_local_result(result_path, res) + return f + + +def _get_click_command_for_function(app: App, function: Function): + if function.is_generator: + raise InvalidError("`modal run` is not supported for generator functions") + + signature: dict[str, ParameterMetadata] = _get_signature(function.info.raw_f) + + def _inner(click_kwargs): + return function.remote(**click_kwargs) + + f = _make_click_function(app, _inner) + + with_click_options = _add_click_options(f, signature) + return click.command(with_click_options) + + +def _get_click_command_for_cls(app: App, method_ref: MethodReference): + signature: dict[str, ParameterMetadata] + cls = method_ref.cls + method_name = method_ref.method_name + + cls_signature = _get_signature(cls._get_user_cls()) + partial_functions = cls._get_partial_functions() + + if method_name in ("*", ""): + # auto infer method name - not sure if we have to support this... + method_names = list(partial_functions.keys()) + if len(method_names) == 1: + method_name = method_names[0] + else: + raise click.UsageError( + f"Please specify a specific method of {cls._get_name()} to run, " + f"e.g. `modal run foo.py::MyClass.bar`" # noqa: E501 + ) + + partial_function = partial_functions[method_name] + fun_signature = _get_signature(partial_function._get_raw_f(), is_method=True) + + # TODO(erikbern): assert there's no overlap? + signature = dict(**cls_signature, **fun_signature) # Pool all arguments + + def _inner(click_kwargs): + # unpool class and method arguments + # TODO(erikbern): this code is a bit hacky + cls_kwargs = {k: click_kwargs[k] for k in cls_signature} + fun_kwargs = {k: click_kwargs[k] for k in fun_signature} + + instance = cls(**cls_kwargs) + method: Function = getattr(instance, method_name) + return method.remote(**fun_kwargs) + + f = _make_click_function(app, _inner) with_click_options = _add_click_options(f, signature) return click.command(with_click_options) @@ -249,16 +259,20 @@ def get_command(self, ctx, func_ref): # needs to be handled here, and not in the `run` logic below ctx.ensure_object(dict) ctx.obj["env"] = ensure_env(ctx.params["env"]) - function_or_entrypoint = import_function(func_ref, accept_local_entrypoint=True, base_cmd="modal run") - app: App = function_or_entrypoint.app + + app, imported_object = import_and_infer(func_ref, accept_local_entrypoint=True, base_cmd="modal run") + if app.description is None: app.set_description(_get_clean_app_description(func_ref)) - if isinstance(function_or_entrypoint, LocalEntrypoint): - click_command = _get_click_command_for_local_entrypoint(app, function_or_entrypoint) - else: - tag = function_or_entrypoint.info.get_tag() - click_command = _get_click_command_for_function(app, tag) + if isinstance(imported_object, LocalEntrypoint): + click_command = _get_click_command_for_local_entrypoint(app, imported_object) + elif isinstance(imported_object, Function): + click_command = _get_click_command_for_function(app, imported_object) + elif isinstance(imported_object, MethodReference): + click_command = _get_click_command_for_cls(app, imported_object) + else: + raise ValueError(f"{imported_object} is neither function, local entrypoint or class/method") return click_command @@ -464,11 +478,18 @@ def shell( exec(container_id=container_or_function, command=shlex.split(cmd), pty=pty) return - function = import_function( + original_app, function_or_method_ref = import_and_infer( container_or_function, accept_local_entrypoint=False, accept_webhook=True, base_cmd="modal shell" ) - assert isinstance(function, Function) - function_spec: _FunctionSpec = function.spec + function_spec: _FunctionSpec + if isinstance(function_or_method_ref, MethodReference): + class_service_function = function_or_method_ref.cls._get_class_service_function() + function_spec = class_service_function.spec + elif isinstance(function_or_method_ref, Function): + function_spec = function_or_method_ref.spec + else: + raise ValueError("Referenced entity is neither a function nor a class/method.") + start_shell = partial( interactive_shell, image=function_spec.image, diff --git a/modal/cls.py b/modal/cls.py index b14c5d402..6ee1332dd 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -397,6 +397,22 @@ def _get_partial_functions(self) -> dict[str, _PartialFunction]: raise AttributeError("You can only get the partial functions of a local Cls instance") return _find_partial_methods_for_user_cls(self._user_cls, _PartialFunctionFlags.all()) + def _get_app(self) -> "modal.app._App": + return self._app + + def _get_user_cls(self) -> type: + return self._user_cls + + def _get_name(self) -> str: + return self._name + + def _get_class_service_function(self) -> "modal.functions._Function": + return self._class_service_function + + def _get_method_names(self) -> Collection[str]: + # returns method names for a *local* class only for now (used by cli) + return self._method_functions.keys() + def _hydrate_metadata(self, metadata: Message): assert isinstance(metadata, api_pb2.ClassHandleMetadata) if ( diff --git a/modal/partial_function.py b/modal/partial_function.py index 0838c385f..da6bd6460 100644 --- a/modal/partial_function.py +++ b/modal/partial_function.py @@ -89,6 +89,9 @@ def __init__( self.force_build = force_build self.build_timeout = build_timeout + def _get_raw_f(self) -> Callable[P, ReturnType]: + return self.raw_f + def __get__(self, obj, objtype=None) -> _Function[P, ReturnType, OriginalReturnType]: k = self.raw_f.__name__ if obj: # accessing the method on an instance of a class, e.g. `MyClass().fun`` diff --git a/tasks.py b/tasks.py index 8eab5daea..c57bd39e4 100644 --- a/tasks.py +++ b/tasks.py @@ -153,6 +153,7 @@ def type_check(ctx): "modal/io_streams.py", "modal/image.py", "modal/file_io.py", + "modal/cli/import_refs.py", ] ctx.run(f"pyright {' '.join(pyright_allowlist)}", pty=True) diff --git a/test/cli_imports_test.py b/test/cli_imports_test.py index f4007a874..a500469a9 100644 --- a/test/cli_imports_test.py +++ b/test/cli_imports_test.py @@ -1,15 +1,20 @@ # Copyright Modal Labs 2023 import pytest +import sys -from modal._utils.async_utils import synchronizer -from modal.app import _App, _LocalEntrypoint +import click + +from modal import web_endpoint +from modal.app import App, LocalEntrypoint from modal.cli.import_refs import ( - DEFAULT_APP_NAME, + MethodReference, + _import_object, + _infer_runnable, get_by_object_path, import_file_or_module, - parse_import_ref, ) from modal.exception import InvalidError +from modal.partial_function import asgi_app, method # Some helper vars for import_stub tests: local_entrypoint_src = """ @@ -86,28 +91,113 @@ def func(): ["dir_structure", "ref", "expected_object_type"], [ # # file syntax - (empty_dir_with_python_file, "mod.py", _App), - (empty_dir_with_python_file, "mod.py::app", _App), - (empty_dir_with_python_file, "mod.py::other_app", _App), - (dir_containing_python_package, "pack/file.py", _App), - (dir_containing_python_package, "pack/sub/subfile.py", _App), - (dir_containing_python_package, "dir/sub/subfile.py", _App), + (empty_dir_with_python_file, "mod.py", App), + (empty_dir_with_python_file, "mod.py::app", App), + (empty_dir_with_python_file, "mod.py::other_app", App), + (dir_containing_python_package, "pack/file.py", App), + (dir_containing_python_package, "pack/sub/subfile.py", App), + (dir_containing_python_package, "dir/sub/subfile.py", App), # # python module syntax - (empty_dir_with_python_file, "mod", _App), - (empty_dir_with_python_file, "mod::app", _App), - (empty_dir_with_python_file, "mod::other_app", _App), - (dir_containing_python_package, "pack.mod", _App), - (dir_containing_python_package, "pack.mod::other_app", _App), - (dir_containing_python_package, "pack/local.py::app.main", _LocalEntrypoint), + (empty_dir_with_python_file, "mod", App), + (empty_dir_with_python_file, "mod::app", App), + (empty_dir_with_python_file, "mod::other_app", App), + (dir_containing_python_package, "pack.mod", App), + (dir_containing_python_package, "pack.mod::other_app", App), + (dir_containing_python_package, "pack/local.py::app.main", LocalEntrypoint), ], ) def test_import_object(dir_structure, ref, expected_object_type, mock_dir): with mock_dir(dir_structure): - import_ref = parse_import_ref(ref) - module = import_file_or_module(import_ref.file_or_module) - imported_object = get_by_object_path(module, import_ref.object_path or DEFAULT_APP_NAME) - _translated_obj = synchronizer._translate_in(imported_object) - assert isinstance(_translated_obj, expected_object_type) + obj, _ = _import_object(ref, base_cmd="modal some_command") + assert isinstance(obj, expected_object_type) + + +app_with_one_web_function = App() + + +@app_with_one_web_function.function() +@web_endpoint() +def web1(): + pass + + +app_with_one_function_one_web_endpoint = App() + + +@app_with_one_function_one_web_endpoint.function() +def f1(): + pass + + +@app_with_one_function_one_web_endpoint.function() +@web_endpoint() +def web2(): + pass + + +app_with_one_web_method = App() + + +@app_with_one_web_method.cls() +class C1: + @asgi_app() + def web_3(self): + pass + + +app_with_one_web_method_one_method = App() + + +@app_with_one_web_method_one_method.cls() +class C2: + @asgi_app() + def web_4(self): + pass + + @method() + def f2(self): + pass + + +app_with_local_entrypoint_and_function = App() + + +@app_with_local_entrypoint_and_function.local_entrypoint() +def le_1(): + pass + + +@app_with_local_entrypoint_and_function.function() +def f3(): + pass + + +def test_infer_object(): + this_module = sys.modules[__name__] + with pytest.raises(click.ClickException, match="web endpoint"): + _infer_runnable(app_with_one_web_function, this_module, accept_webhook=False) + + _, runnable = _infer_runnable(app_with_one_web_function, this_module, accept_webhook=True) + assert runnable == web1 + + _, runnable = _infer_runnable(app_with_one_function_one_web_endpoint, this_module, accept_webhook=False) + assert runnable == f1 + + with pytest.raises(click.UsageError, match="(?s)You need to specify.*\nf1\nweb2\n"): + _, runnable = _infer_runnable(app_with_one_function_one_web_endpoint, this_module, accept_webhook=True) + assert runnable == f1 + + with pytest.raises(click.UsageError, match="web endpoint"): + _, runnable = _infer_runnable(app_with_one_web_method, this_module, accept_webhook=False) + + _, runnable = _infer_runnable(app_with_one_web_method, this_module, accept_webhook=True) + assert runnable == MethodReference(C1, "web_3") # type: ignore + + _, runnable = _infer_runnable(app_with_local_entrypoint_and_function, this_module, accept_local_entrypoint=True) + assert runnable == le_1 + + _, runnable = _infer_runnable(app_with_local_entrypoint_and_function, this_module, accept_local_entrypoint=False) + assert runnable == f3 def test_import_package_and_module_names(monkeypatch, supports_dir): diff --git a/test/cli_test.py b/test/cli_test.py index 6a6fcc5c3..83d8ea98b 100644 --- a/test/cli_test.py +++ b/test/cli_test.py @@ -314,7 +314,7 @@ def test_run_parse_args_entrypoint(servicer, set_env_client, test_dir): assert "Unable to generate command line interface for app entrypoint." in str(res.exception) -def test_run_parse_args_function(servicer, set_env_client, test_dir): +def test_run_parse_args_function(servicer, set_env_client, test_dir, recwarn): app_file = test_dir / "supports" / "app_run_tests" / "cli_args.py" res = _run(["run", app_file.as_posix()], expected_exit_code=2, expected_stderr=None) assert "You need to specify a Modal function or local entrypoint to run" in res.stderr @@ -334,6 +334,10 @@ def print_type(i): res = _run(args) assert expected in res.stdout + if len(recwarn): + print("Unexpected warnings:", [str(w) for w in recwarn]) + assert len(recwarn) == 0 + def test_run_user_script_exception(servicer, set_env_client, test_dir): app_file = test_dir / "supports" / "app_run_tests" / "raises_error.py"