diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index 496ba56af..b5bcf7d47 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -50,8 +50,8 @@ from ._runtime.execution_context import _set_current_context_ids if TYPE_CHECKING: + import modal._object import modal._runtime.container_io_manager - import modal.object class DaemonizedThreadPool: diff --git a/modal/_object.py b/modal/_object.py new file mode 100644 index 000000000..ff5dd89a4 --- /dev/null +++ b/modal/_object.py @@ -0,0 +1,279 @@ +# Copyright Modal Labs 2022 +import typing +import uuid +from collections.abc import Awaitable, Hashable, Sequence +from functools import wraps +from typing import Callable, ClassVar, Optional + +from google.protobuf.message import Message +from typing_extensions import Self + +from modal._utils.async_utils import aclosing + +from ._resolver import Resolver +from .client import _Client +from .config import config, logger +from .exception import ExecutionError, InvalidError + +EPHEMERAL_OBJECT_HEARTBEAT_SLEEP: int = 300 + + +def _get_environment_name(environment_name: Optional[str] = None, resolver: Optional[Resolver] = None) -> Optional[str]: + if environment_name: + return environment_name + elif resolver and resolver.environment_name: + return resolver.environment_name + else: + return config.get("environment") + + +class _Object: + _type_prefix: ClassVar[Optional[str]] = None + _prefix_to_type: ClassVar[dict[str, type]] = {} + + # For constructors + _load: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]] + _preload: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]] + _rep: str + _is_another_app: bool + _hydrate_lazily: bool + _deps: Optional[Callable[..., Sequence["_Object"]]] + _deduplication_key: Optional[Callable[[], Awaitable[Hashable]]] = None + + # For hydrated objects + _object_id: Optional[str] + _client: Optional[_Client] + _is_hydrated: bool + _is_rehydrated: bool + + @classmethod + def __init_subclass__(cls, type_prefix: Optional[str] = None): + super().__init_subclass__() + if type_prefix is not None: + cls._type_prefix = type_prefix + cls._prefix_to_type[type_prefix] = cls + + def __init__(self, *args, **kwargs): + raise InvalidError(f"Class {type(self).__name__} has no constructor. Use class constructor methods instead.") + + def _init( + self, + rep: str, + load: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]] = None, + is_another_app: bool = False, + preload: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]] = None, + hydrate_lazily: bool = False, + deps: Optional[Callable[..., Sequence["_Object"]]] = None, + deduplication_key: Optional[Callable[[], Awaitable[Hashable]]] = None, + ): + self._local_uuid = str(uuid.uuid4()) + self._load = load + self._preload = preload + self._rep = rep + self._is_another_app = is_another_app + self._hydrate_lazily = hydrate_lazily + self._deps = deps + self._deduplication_key = deduplication_key + + self._object_id = None + self._client = None + self._is_hydrated = False + self._is_rehydrated = False + + self._initialize_from_empty() + + def _unhydrate(self): + self._object_id = None + self._client = None + self._is_hydrated = False + + def _initialize_from_empty(self): + # default implementation, can be overriden in subclasses + pass + + def _initialize_from_other(self, other): + # default implementation, can be overriden in subclasses + self._object_id = other._object_id + self._is_hydrated = other._is_hydrated + self._client = other._client + + def _hydrate(self, object_id: str, client: _Client, metadata: Optional[Message]): + assert isinstance(object_id, str) and self._type_prefix is not None + if not object_id.startswith(self._type_prefix): + raise ExecutionError( + f"Can not hydrate {type(self)}:" + f" it has type prefix {self._type_prefix}" + f" but the object_id starts with {object_id[:3]}" + ) + self._object_id = object_id + self._client = client + self._hydrate_metadata(metadata) + self._is_hydrated = True + + def _hydrate_metadata(self, metadata: Optional[Message]): + # override this is subclasses that need additional data (other than an object_id) for a functioning Handle + pass + + def _get_metadata(self) -> Optional[Message]: + # return the necessary metadata from this handle to be able to re-hydrate in another context if one is needed + # used to provide a handle's handle_metadata for serializing/pickling a live handle + # the object_id is already provided by other means + return None + + def _validate_is_hydrated(self): + if not self._is_hydrated: + object_type = self.__class__.__name__.strip("_") + if hasattr(self, "_app") and getattr(self._app, "_running_app", "") is None: # type: ignore + # The most common cause of this error: e.g., user called a Function without using App.run() + reason = ", because the App it is defined on is not running" + else: + # Technically possible, but with an ambiguous cause. + reason = "" + raise ExecutionError( + f"{object_type} has not been hydrated with the metadata it needs to run on Modal{reason}." + ) + + def clone(self) -> Self: + """mdmd:hidden Clone a given hydrated object.""" + + # Object to clone must already be hydrated, otherwise from_loader is more suitable. + self._validate_is_hydrated() + obj = type(self).__new__(type(self)) + obj._initialize_from_other(self) + return obj + + @classmethod + def _from_loader( + cls, + load: Callable[[Self, Resolver, Optional[str]], Awaitable[None]], + rep: str, + is_another_app: bool = False, + preload: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]] = None, + hydrate_lazily: bool = False, + deps: Optional[Callable[..., Sequence["_Object"]]] = None, + deduplication_key: Optional[Callable[[], Awaitable[Hashable]]] = None, + ): + # TODO(erikbern): flip the order of the two first arguments + obj = _Object.__new__(cls) + obj._init(rep, load, is_another_app, preload, hydrate_lazily, deps, deduplication_key) + return obj + + @staticmethod + def _get_type_from_id(object_id: str) -> type["_Object"]: + parts = object_id.split("-") + if len(parts) != 2: + raise InvalidError(f"Object id {object_id} has no dash in it") + prefix = parts[0] + if prefix not in _Object._prefix_to_type: + raise InvalidError(f"Object prefix {prefix} does not correspond to a type") + return _Object._prefix_to_type[prefix] + + @classmethod + def _is_id_type(cls, object_id) -> bool: + return cls._get_type_from_id(object_id) == cls + + @classmethod + def _new_hydrated( + cls, object_id: str, client: _Client, handle_metadata: Optional[Message], is_another_app: bool = False + ) -> Self: + obj_cls: type[Self] + if cls._type_prefix is not None: + # This is called directly on a subclass, e.g. Secret.from_id + # validate the id matching the expected id type of the Object subclass + if not object_id.startswith(cls._type_prefix + "-"): + raise InvalidError(f"Object {object_id} does not start with {cls._type_prefix}") + + obj_cls = cls + else: + # this means the method is used directly on _Object + # typically during deserialization of objects + obj_cls = typing.cast(type[Self], cls._get_type_from_id(object_id)) + + # Instantiate provider + obj = _Object.__new__(obj_cls) + rep = f"Object({object_id})" # TODO(erikbern): dumb + obj._init(rep, is_another_app=is_another_app) + obj._hydrate(object_id, client, handle_metadata) + + return obj + + def _hydrate_from_other(self, other: Self): + self._hydrate(other.object_id, other.client, other._get_metadata()) + + def __repr__(self): + return self._rep + + @property + def local_uuid(self): + """mdmd:hidden""" + return self._local_uuid + + @property + def object_id(self) -> str: + """mdmd:hidden""" + if self._object_id is None: + raise AttributeError(f"Attempting to get object_id of unhydrated {self}") + return self._object_id + + @property + def client(self) -> _Client: + """mdmd:hidden""" + if self._client is None: + raise AttributeError(f"Attempting to get client of unhydrated {self}") + return self._client + + @property + def is_hydrated(self) -> bool: + """mdmd:hidden""" + return self._is_hydrated + + @property + def deps(self) -> Callable[..., Sequence["_Object"]]: + """mdmd:hidden""" + + def default_deps(*args, **kwargs) -> Sequence["_Object"]: + return [] + + return self._deps if self._deps is not None else default_deps + + async def resolve(self, client: Optional[_Client] = None): + """mdmd:hidden""" + if self._is_hydrated: + # memory snapshots capture references which must be rehydrated + # on restore to handle staleness. + if self.client._snapshotted and not self._is_rehydrated: + logger.debug(f"rehydrating {self} after snapshot") + self._is_hydrated = False # un-hydrate and re-resolve + c = client if client is not None else await _Client.from_env() + resolver = Resolver(c) + await resolver.load(typing.cast(_Object, self)) + self._is_rehydrated = True + logger.debug(f"rehydrated {self} with client {id(c)}") + return + elif not self._hydrate_lazily: + self._validate_is_hydrated() + else: + # TODO: this client and/or resolver can't be changed by a caller to X.from_name() + c = client if client is not None else await _Client.from_env() + resolver = Resolver(c) + await resolver.load(self) + + +def live_method(method): + @wraps(method) + async def wrapped(self, *args, **kwargs): + await self.resolve() + return await method(self, *args, **kwargs) + + return wrapped + + +def live_method_gen(method): + @wraps(method) + async def wrapped(self, *args, **kwargs): + await self.resolve() + async with aclosing(method(self, *args, **kwargs)) as stream: + async for item in stream: + yield item + + return wrapped diff --git a/modal/_resolver.py b/modal/_resolver.py index 850edb910..37bd4918a 100644 --- a/modal/_resolver.py +++ b/modal/_resolver.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from rich.tree import Tree - from modal.object import _Object + import modal._object class StatusRow: @@ -33,7 +33,7 @@ def message(self, message): self._spinner.update(text=message) def finish(self, message): - if self._step_node is not None: + if self._step_node is not None and self._spinner is not None: from ._output import OutputManager self._spinner.update(text=message) @@ -89,7 +89,7 @@ async def preload(self, obj, existing_object_id: Optional[str]): if obj._preload is not None: await obj._preload(obj, self, existing_object_id) - async def load(self, obj: "_Object", existing_object_id: Optional[str] = None): + async def load(self, obj: "modal._object._Object", existing_object_id: Optional[str] = None): if obj._is_hydrated and obj._is_another_app: # No need to reload this, it won't typically change if obj.local_uuid not in self._local_uuid_to_future: @@ -124,6 +124,8 @@ async def loader(): await TaskContext.gather(*[self.load(dep) for dep in obj.deps()]) # Load the object itself + if not obj._load: + raise Exception(f"Object {obj} has no loader function") try: await obj._load(obj, self, existing_object_id) except GRPCError as exc: @@ -154,8 +156,8 @@ async def loader(): # TODO(elias): print original exception/trace rather than the Resolver-internal trace return await cached_future - def objects(self) -> list["_Object"]: - unique_objects: dict[str, "_Object"] = {} + def objects(self) -> list["modal._object._Object"]: + unique_objects: dict[str, "modal._object._Object"] = {} for fut in self._local_uuid_to_future.values(): if not fut.done(): # this will raise an exception if not all loads have been awaited, but that *should* never happen diff --git a/modal/_runtime/user_code_imports.py b/modal/_runtime/user_code_imports.py index f5b8beecf..c9696c782 100644 --- a/modal/_runtime/user_code_imports.py +++ b/modal/_runtime/user_code_imports.py @@ -3,11 +3,11 @@ import typing from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Sequence +import modal._object import modal._runtime.container_io_manager import modal.cls -import modal.object from modal import Function from modal._utils.async_utils import synchronizer from modal._utils.function_utils import LocalFunctionError, is_async as get_is_async, is_global_object @@ -41,7 +41,7 @@ class Service(metaclass=ABCMeta): user_cls_instance: Any app: Optional["modal.app._App"] - code_deps: Optional[list["modal.object._Object"]] + code_deps: Optional[Sequence["modal._object._Object"]] @abstractmethod def get_finalized_functions( @@ -94,7 +94,7 @@ def construct_webhook_callable( class ImportedFunction(Service): user_cls_instance: Any app: Optional["modal.app._App"] - code_deps: Optional[list["modal.object._Object"]] + code_deps: Optional[Sequence["modal._object._Object"]] _user_defined_callable: Callable[..., Any] @@ -137,7 +137,7 @@ def get_finalized_functions( class ImportedClass(Service): user_cls_instance: Any app: Optional["modal.app._App"] - code_deps: Optional[list["modal.object._Object"]] + code_deps: Optional[Sequence["modal._object._Object"]] _partial_functions: dict[str, "modal.partial_function._PartialFunction"] @@ -229,7 +229,7 @@ def import_single_function_service( """ user_defined_callable: Callable function: Optional[_Function] = None - code_deps: Optional[list["modal.object._Object"]] = None + code_deps: Optional[Sequence["modal._object._Object"]] = None active_app: Optional[modal.app._App] = None if ser_fun is not None: @@ -306,7 +306,7 @@ def import_class_service( See import_function. """ active_app: Optional["modal.app._App"] - code_deps: Optional[list["modal.object._Object"]] + code_deps: Optional[Sequence["modal._object._Object"]] cls: typing.Union[type, modal.cls.Cls] if function_def.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED: diff --git a/modal/_serialization.py b/modal/_serialization.py index 814d11583..b5bf31372 100644 --- a/modal/_serialization.py +++ b/modal/_serialization.py @@ -8,10 +8,11 @@ from modal._utils.async_utils import synchronizer from modal_proto import api_pb2 +from ._object import _Object from ._vendor import cloudpickle from .config import logger from .exception import DeserializationError, ExecutionError, InvalidError -from .object import Object, _Object +from .object import Object PICKLE_PROTOCOL = 4 # Support older Python versions. @@ -48,8 +49,8 @@ def persistent_id(self, obj): return ("sync", (impl_object.__class__, attributes)) else: return - if not obj.object_id: - raise InvalidError(f"Can't serialize object {obj} which hasn't been created.") + if not obj.is_hydrated: + raise InvalidError(f"Can't serialize object {obj} which hasn't been hydrated.") return (obj.object_id, flag, obj._get_metadata()) diff --git a/modal/app.py b/modal/app.py index bb28fb28f..84e0abc5e 100644 --- a/modal/app.py +++ b/modal/app.py @@ -21,6 +21,7 @@ from modal_proto import api_pb2 from ._ipython import is_notebook +from ._object import _get_environment_name, _Object from ._utils.async_utils import synchronize_api from ._utils.deprecation import deprecation_error, deprecation_warning, renamed_parameter from ._utils.function_utils import FunctionInfo, is_global_object, is_method_fn @@ -36,7 +37,6 @@ from .image import _Image from .mount import _Mount from .network_file_system import _NetworkFileSystem -from .object import _get_environment_name, _Object from .partial_function import ( PartialFunction, _find_partial_methods_for_user_cls, diff --git a/modal/cli/app.py b/modal/cli/app.py index 5cdd2ada1..b0d23c770 100644 --- a/modal/cli/app.py +++ b/modal/cli/app.py @@ -9,11 +9,11 @@ from rich.text import Text from typer import Argument +from modal._object import _get_environment_name from modal._utils.async_utils import synchronizer from modal._utils.deprecation import deprecation_warning from modal.client import _Client from modal.environments import ensure_env -from modal.object import _get_environment_name from modal_proto import api_pb2 from .utils import ENV_OPTION, display_table, get_app_id_from_name, stream_app_logs, timestamp_to_local diff --git a/modal/cli/container.py b/modal/cli/container.py index 96d1b5173..792d03c34 100644 --- a/modal/cli/container.py +++ b/modal/cli/container.py @@ -4,6 +4,7 @@ import typer from rich.text import Text +from modal._object import _get_environment_name from modal._pty import get_pty_info from modal._utils.async_utils import synchronizer from modal._utils.grpc_utils import retry_transient_errors @@ -12,7 +13,6 @@ from modal.config import config from modal.container_process import _ContainerProcess from modal.environments import ensure_env -from modal.object import _get_environment_name from modal.stream_type import StreamType from modal_proto import api_pb2 diff --git a/modal/client.py b/modal/client.py index fd09a34dd..c17eacef8 100644 --- a/modal/client.py +++ b/modal/client.py @@ -73,6 +73,7 @@ class _Client: _cancellation_context: TaskContext _cancellation_context_event_loop: asyncio.AbstractEventLoop = None _stub: Optional[api_grpc.ModalClientStub] + _snapshotted: bool def __init__( self, diff --git a/modal/cls.py b/modal/cls.py index 6ee1332dd..5f20700bb 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -11,6 +11,7 @@ from modal._utils.function_utils import CLASS_PARAM_TYPE_MAP from modal_proto import api_pb2 +from ._object import _get_environment_name, _Object from ._resolver import Resolver from ._resources import convert_fn_config_to_resources_config from ._serialization import check_valid_cls_constructor_arg @@ -23,7 +24,6 @@ from .exception import ExecutionError, InvalidError, NotFoundError, VersionError from .functions import _Function, _parse_retries from .gpu import GPU_T -from .object import _get_environment_name, _Object from .partial_function import ( _find_callables_for_obj, _find_partial_methods_for_user_cls, diff --git a/modal/dict.py b/modal/dict.py index 36e4f998c..55c8002a9 100644 --- a/modal/dict.py +++ b/modal/dict.py @@ -7,6 +7,7 @@ from modal_proto import api_pb2 +from ._object import EPHEMERAL_OBJECT_HEARTBEAT_SLEEP, _get_environment_name, _Object, live_method, live_method_gen from ._resolver import Resolver from ._serialization import deserialize, serialize from ._utils.async_utils import TaskContext, synchronize_api @@ -16,7 +17,6 @@ from .client import _Client from .config import logger from .exception import RequestSizeError -from .object import EPHEMERAL_OBJECT_HEARTBEAT_SLEEP, _get_environment_name, _Object, live_method, live_method_gen def _serialize_dict(data): diff --git a/modal/environments.py b/modal/environments.py index 44f499538..ec65f2c36 100644 --- a/modal/environments.py +++ b/modal/environments.py @@ -8,6 +8,7 @@ from modal_proto import api_pb2 +from ._object import _Object from ._resolver import Resolver from ._utils.async_utils import synchronize_api, synchronizer from ._utils.deprecation import renamed_parameter @@ -15,7 +16,6 @@ from ._utils.name_utils import check_object_name from .client import _Client from .config import config, logger -from .object import _Object @dataclass(frozen=True) diff --git a/modal/functions.py b/modal/functions.py index d145c0014..08db1c19c 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -26,6 +26,7 @@ from modal_proto.modal_api_grpc import ModalClientModal from ._location import parse_cloud_provider +from ._object import _get_environment_name, _Object, live_method, live_method_gen from ._pty import get_pty_info from ._resolver import Resolver from ._resources import convert_fn_config_to_resources_config @@ -71,7 +72,6 @@ from .image import _Image from .mount import _get_client_mount, _Mount, get_auto_mounts from .network_file_system import _NetworkFileSystem, network_file_system_mount_protos -from .object import _get_environment_name, _Object, live_method, live_method_gen from .output import _get_output_manager from .parallel_map import ( _for_each_async, @@ -388,7 +388,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type _is_method: bool _spec: Optional[_FunctionSpec] = None _tag: str - _raw_f: Callable[..., Any] + _raw_f: Optional[Callable[..., Any]] # this is set to None for a "class service [function]" _build_args: dict _is_generator: Optional[bool] = None @@ -474,7 +474,7 @@ def from_args( _experimental_buffer_containers: Optional[int] = None, _experimental_proxy_ip: Optional[str] = None, _experimental_custom_scaling_factor: Optional[float] = None, - ) -> None: + ) -> "_Function": """mdmd:hidden""" # Needed to avoid circular imports from .partial_function import _find_partial_methods_for_user_cls, _PartialFunctionFlags @@ -573,7 +573,7 @@ def from_args( ) image = _Image._from_args( base_images={"base": image}, - build_function=snapshot_function, + build_function=snapshot_function, # type: ignore # TODO: separate functions.py and _functions.py force_build=image.force_build or pf.force_build, ) @@ -962,7 +962,7 @@ async def _load(param_bound_func: _Function, resolver: Resolver, existing_object f"The {identity} has not been hydrated with the metadata it needs to run on Modal{reason}." ) - assert parent._client.stub + assert parent._client and parent._client.stub if can_use_parent: # We can end up here if parent wasn't hydrated when class was instantiated, but has been since. @@ -983,9 +983,9 @@ async def _load(param_bound_func: _Function, resolver: Resolver, existing_object else: serialized_params = serialize((args, kwargs)) environment_name = _get_environment_name(None, resolver) - assert parent is not None + assert parent is not None and parent.is_hydrated req = api_pb2.FunctionBindParamsRequest( - function_id=parent._object_id, + function_id=parent.object_id, serialized_params=serialized_params, function_options=options, environment_name=environment_name @@ -1032,11 +1032,10 @@ async def keep_warm(self, warm_pool_size: int) -> None: """ ) ) - assert self._client and self._client.stub request = api_pb2.FunctionUpdateSchedulingParamsRequest( - function_id=self._object_id, warm_pool_size_override=warm_pool_size + function_id=self.object_id, warm_pool_size_override=warm_pool_size ) - await retry_transient_errors(self._client.stub.FunctionUpdateSchedulingParams, request) + await retry_transient_errors(self.client.stub.FunctionUpdateSchedulingParams, request) @classmethod @renamed_parameter((2024, 12, 18), "tag", "name") @@ -1142,7 +1141,7 @@ def get_build_def(self) -> str: """mdmd:hidden""" # Plaintext source and arg definition for the function, so it's part of the image # hash. We can't use the cloudpickle hash because it's not very stable. - assert hasattr(self, "_raw_f") and hasattr(self, "_build_args") + assert hasattr(self, "_raw_f") and hasattr(self, "_build_args") and self._raw_f is not None return f"{inspect.getsource(self._raw_f)}\n{repr(self._build_args)}" # Live handle methods @@ -1248,7 +1247,7 @@ async def _map( _map_invocation( self, # type: ignore input_queue, - self._client, + self.client, order_outputs, return_exceptions, count_update_callback, @@ -1266,7 +1265,7 @@ async def _call_function(self, args, kwargs) -> ReturnType: self, args, kwargs, - client=self._client, + client=self.client, function_call_invocation_type=function_call_invocation_type, ) @@ -1276,7 +1275,7 @@ async def _call_function_nowait( self, args, kwargs, function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType" ) -> _Invocation: return await _Invocation.create( - self, args, kwargs, client=self._client, function_call_invocation_type=function_call_invocation_type + self, args, kwargs, client=self.client, function_call_invocation_type=function_call_invocation_type ) @warn_if_generator_is_not_consumed() @@ -1287,7 +1286,7 @@ async def _call_generator(self, args, kwargs): self, args, kwargs, - client=self._client, + client=self.client, function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC_LEGACY, ) async for res in invocation.run_generator(): @@ -1303,7 +1302,7 @@ async def _call_generator_nowait(self, args, kwargs): self, args, kwargs, - client=self._client, + client=self.client, function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC_LEGACY, ) @@ -1452,14 +1451,14 @@ async def spawn(self, *args: P.args, **kwargs: P.kwargs) -> "_FunctionCall[Retur def get_raw_f(self) -> Callable[..., Any]: """Return the inner Python object wrapped by this Modal Function.""" + assert self._raw_f is not None return self._raw_f @live_method async def get_current_stats(self) -> FunctionStats: """Return a `FunctionStats` object describing the current function's queue and runner counts.""" - assert self._client.stub resp = await retry_transient_errors( - self._client.stub.FunctionGetCurrentStats, + self.client.stub.FunctionGetCurrentStats, api_pb2.FunctionGetCurrentStatsRequest(function_id=self.object_id), total_timeout=10.0, ) @@ -1491,8 +1490,7 @@ class _FunctionCall(typing.Generic[ReturnType], _Object, type_prefix="fc"): _is_generator: bool = False def _invocation(self): - assert self._client.stub - return _Invocation(self._client.stub, self.object_id, self._client) + return _Invocation(self.client.stub, self.object_id, self.client) async def get(self, timeout: Optional[float] = None) -> ReturnType: """Get the result of the function call. diff --git a/modal/image.py b/modal/image.py index 2406bba9c..24fbd988d 100644 --- a/modal/image.py +++ b/modal/image.py @@ -26,6 +26,7 @@ from modal_proto import api_pb2 +from ._object import _Object, live_method_gen from ._resolver import Resolver from ._serialization import serialize from ._utils.async_utils import synchronize_api @@ -46,7 +47,6 @@ from .gpu import GPU_T, parse_gpu_config from .mount import _Mount, python_standalone_mount_name from .network_file_system import _NetworkFileSystem -from .object import _Object, live_method_gen from .output import _get_output_manager from .scheduler_placement import SchedulerPlacement from .secret import _Secret @@ -2047,11 +2047,11 @@ def imports(self): try: yield except Exception as exc: - if self.object_id is None: - # Might be initialized later + if not self.is_hydrated: + # Might be hydrated later self.inside_exceptions.append(exc) elif env_image_id == self.object_id: - # Image is already initialized (we can remove this case later + # Image is already hydrated (we can remove this case later # when we don't hydrate objects so early) raise if not isinstance(exc, ImportError): @@ -2066,9 +2066,9 @@ async def _logs(self) -> typing.AsyncGenerator[str, None]: last_entry_id: str = "" request = api_pb2.ImageJoinStreamingRequest( - image_id=self._object_id, timeout=55, last_entry_id=last_entry_id, include_logs_for_finished=True + image_id=self.object_id, timeout=55, last_entry_id=last_entry_id, include_logs_for_finished=True ) - async for response in self._client.stub.ImageJoinStreaming.unary_stream(request): + async for response in self.client.stub.ImageJoinStreaming.unary_stream(request): if response.result.status: return if response.entry_id: diff --git a/modal/mount.py b/modal/mount.py index 1be44d806..5b741bf3d 100644 --- a/modal/mount.py +++ b/modal/mount.py @@ -20,6 +20,7 @@ from modal_proto import api_pb2 from modal_version import __version__ +from ._object import _get_environment_name, _Object from ._resolver import Resolver from ._utils.async_utils import aclosing, async_map, synchronize_api from ._utils.blob_utils import FileUploadSpec, blob_upload_file, get_file_upload_spec_from_path @@ -31,7 +32,6 @@ from .config import config, logger from .exception import InvalidError, ModuleNotMountable from .file_pattern_matcher import FilePatternMatcher -from .object import _get_environment_name, _Object ROOT_DIR: PurePosixPath = PurePosixPath("/root") MOUNT_PUT_FILE_CLIENT_TIMEOUT = 10 * 60 # 10 min max for transferring files diff --git a/modal/network_file_system.py b/modal/network_file_system.py index f0f4106c4..95b4ed75f 100644 --- a/modal/network_file_system.py +++ b/modal/network_file_system.py @@ -12,6 +12,13 @@ import modal from modal_proto import api_pb2 +from ._object import ( + EPHEMERAL_OBJECT_HEARTBEAT_SLEEP, + _get_environment_name, + _Object, + live_method, + live_method_gen, +) from ._resolver import Resolver from ._utils.async_utils import TaskContext, aclosing, async_map, sync_or_async_iter, synchronize_api from ._utils.blob_utils import LARGE_FILE_LIMIT, blob_iter, blob_upload_file @@ -21,13 +28,6 @@ from ._utils.name_utils import check_object_name from .client import _Client from .exception import InvalidError -from .object import ( - EPHEMERAL_OBJECT_HEARTBEAT_SLEEP, - _get_environment_name, - _Object, - live_method, - live_method_gen, -) from .volume import FileEntry NETWORK_FILE_SYSTEM_PUT_FILE_CLIENT_TIMEOUT = ( diff --git a/modal/object.py b/modal/object.py index a71a3e2a6..6ec6638d1 100644 --- a/modal/object.py +++ b/modal/object.py @@ -1,268 +1,5 @@ -# Copyright Modal Labs 2022 -import uuid -from collections.abc import Awaitable, Hashable, Sequence -from functools import wraps -from typing import Callable, ClassVar, Optional, TypeVar - -from google.protobuf.message import Message - -from modal._utils.async_utils import aclosing - -from ._resolver import Resolver +# Copyright Modal Labs 2025 +from ._object import _Object from ._utils.async_utils import synchronize_api -from .client import _Client -from .config import config, logger -from .exception import ExecutionError, InvalidError - -O = TypeVar("O", bound="_Object") - -_BLOCKING_O = synchronize_api(O) - -EPHEMERAL_OBJECT_HEARTBEAT_SLEEP: int = 300 - - -def _get_environment_name(environment_name: Optional[str] = None, resolver: Optional[Resolver] = None) -> Optional[str]: - if environment_name: - return environment_name - elif resolver and resolver.environment_name: - return resolver.environment_name - else: - return config.get("environment") - - -class _Object: - _type_prefix: ClassVar[Optional[str]] = None - _prefix_to_type: ClassVar[dict[str, type]] = {} - - # For constructors - _load: Optional[Callable[[O, Resolver, Optional[str]], Awaitable[None]]] - _preload: Optional[Callable[[O, Resolver, Optional[str]], Awaitable[None]]] - _rep: str - _is_another_app: bool - _hydrate_lazily: bool - _deps: Optional[Callable[..., list["_Object"]]] - _deduplication_key: Optional[Callable[[], Awaitable[Hashable]]] = None - - # For hydrated objects - _object_id: str - _client: _Client - _is_hydrated: bool - _is_rehydrated: bool - - @classmethod - def __init_subclass__(cls, type_prefix: Optional[str] = None): - super().__init_subclass__() - if type_prefix is not None: - cls._type_prefix = type_prefix - cls._prefix_to_type[type_prefix] = cls - - def __init__(self, *args, **kwargs): - raise InvalidError(f"Class {type(self).__name__} has no constructor. Use class constructor methods instead.") - - def _init( - self, - rep: str, - load: Optional[Callable[[O, Resolver, Optional[str]], Awaitable[None]]] = None, - is_another_app: bool = False, - preload: Optional[Callable[[O, Resolver, Optional[str]], Awaitable[None]]] = None, - hydrate_lazily: bool = False, - deps: Optional[Callable[..., list["_Object"]]] = None, - deduplication_key: Optional[Callable[[], Awaitable[Hashable]]] = None, - ): - self._local_uuid = str(uuid.uuid4()) - self._load = load - self._preload = preload - self._rep = rep - self._is_another_app = is_another_app - self._hydrate_lazily = hydrate_lazily - self._deps = deps - self._deduplication_key = deduplication_key - - self._object_id = None - self._client = None - self._is_hydrated = False - self._is_rehydrated = False - - self._initialize_from_empty() - - def _unhydrate(self): - self._object_id = None - self._client = None - self._is_hydrated = False - - def _initialize_from_empty(self): - # default implementation, can be overriden in subclasses - pass - - def _initialize_from_other(self, other): - # default implementation, can be overriden in subclasses - self._object_id = other._object_id - self._is_hydrated = other._is_hydrated - self._client = other._client - - def _hydrate(self, object_id: str, client: _Client, metadata: Optional[Message]): - assert isinstance(object_id, str) - if not object_id.startswith(self._type_prefix): - raise ExecutionError( - f"Can not hydrate {type(self)}:" - f" it has type prefix {self._type_prefix}" - f" but the object_id starts with {object_id[:3]}" - ) - self._object_id = object_id - self._client = client - self._hydrate_metadata(metadata) - self._is_hydrated = True - - def _hydrate_metadata(self, metadata: Optional[Message]): - # override this is subclasses that need additional data (other than an object_id) for a functioning Handle - pass - - def _get_metadata(self) -> Optional[Message]: - # return the necessary metadata from this handle to be able to re-hydrate in another context if one is needed - # used to provide a handle's handle_metadata for serializing/pickling a live handle - # the object_id is already provided by other means - return - - def _validate_is_hydrated(self: O): - if not self._is_hydrated: - object_type = self.__class__.__name__.strip("_") - if hasattr(self, "_app") and getattr(self._app, "_running_app", "") is None: - # The most common cause of this error: e.g., user called a Function without using App.run() - reason = ", because the App it is defined on is not running" - else: - # Technically possible, but with an ambiguous cause. - reason = "" - raise ExecutionError( - f"{object_type} has not been hydrated with the metadata it needs to run on Modal{reason}." - ) - - def clone(self: O) -> O: - """mdmd:hidden Clone a given hydrated object.""" - - # Object to clone must already be hydrated, otherwise from_loader is more suitable. - self._validate_is_hydrated() - obj = type(self).__new__(type(self)) - obj._initialize_from_other(self) - return obj - - @classmethod - def _from_loader( - cls, - load: Callable[[O, Resolver, Optional[str]], Awaitable[None]], - rep: str, - is_another_app: bool = False, - preload: Optional[Callable[[O, Resolver, Optional[str]], Awaitable[None]]] = None, - hydrate_lazily: bool = False, - deps: Optional[Callable[..., Sequence["_Object"]]] = None, - deduplication_key: Optional[Callable[[], Awaitable[Hashable]]] = None, - ): - # TODO(erikbern): flip the order of the two first arguments - obj = _Object.__new__(cls) - obj._init(rep, load, is_another_app, preload, hydrate_lazily, deps, deduplication_key) - return obj - - @classmethod - def _get_type_from_id(cls: type[O], object_id: str) -> type[O]: - parts = object_id.split("-") - if len(parts) != 2: - raise InvalidError(f"Object id {object_id} has no dash in it") - prefix = parts[0] - if prefix not in cls._prefix_to_type: - raise InvalidError(f"Object prefix {prefix} does not correspond to a type") - return cls._prefix_to_type[prefix] - - @classmethod - def _is_id_type(cls: type[O], object_id) -> bool: - return cls._get_type_from_id(object_id) == cls - - @classmethod - def _new_hydrated( - cls: type[O], object_id: str, client: _Client, handle_metadata: Optional[Message], is_another_app: bool = False - ) -> O: - if cls._type_prefix is not None: - # This is called directly on a subclass, e.g. Secret.from_id - if not object_id.startswith(cls._type_prefix + "-"): - raise InvalidError(f"Object {object_id} does not start with {cls._type_prefix}") - obj_cls = cls - else: - # This is called on the base class, e.g. Handle.from_id - obj_cls = cls._get_type_from_id(object_id) - - # Instantiate provider - obj = _Object.__new__(obj_cls) - rep = f"Object({object_id})" # TODO(erikbern): dumb - obj._init(rep, is_another_app=is_another_app) - obj._hydrate(object_id, client, handle_metadata) - - return obj - - def _hydrate_from_other(self, other: O): - self._hydrate(other._object_id, other._client, other._get_metadata()) - - def __repr__(self): - return self._rep - - @property - def local_uuid(self): - """mdmd:hidden""" - return self._local_uuid - - @property - def object_id(self) -> str: - """mdmd:hidden""" - return self._object_id - - @property - def is_hydrated(self) -> bool: - """mdmd:hidden""" - return self._is_hydrated - - @property - def deps(self) -> Callable[..., list["_Object"]]: - """mdmd:hidden""" - return self._deps if self._deps is not None else lambda: [] - - async def resolve(self, client: Optional[_Client] = None): - """mdmd:hidden""" - if self._is_hydrated: - # memory snapshots capture references which must be rehydrated - # on restore to handle staleness. - if self._client._snapshotted and not self._is_rehydrated: - logger.debug(f"rehydrating {self} after snapshot") - self._is_hydrated = False # un-hydrate and re-resolve - c = client if client is not None else await _Client.from_env() - resolver = Resolver(c) - await resolver.load(self) - self._is_rehydrated = True - logger.debug(f"rehydrated {self} with client {id(c)}") - return - elif not self._hydrate_lazily: - self._validate_is_hydrated() - else: - # TODO: this client and/or resolver can't be changed by a caller to X.from_name() - c = client if client is not None else await _Client.from_env() - resolver = Resolver(c) - await resolver.load(self) - Object = synchronize_api(_Object, target_module=__name__) - - -def live_method(method): - @wraps(method) - async def wrapped(self, *args, **kwargs): - await self.resolve() - return await method(self, *args, **kwargs) - - return wrapped - - -def live_method_gen(method): - @wraps(method) - async def wrapped(self, *args, **kwargs): - await self.resolve() - async with aclosing(method(self, *args, **kwargs)) as stream: - async for item in stream: - yield item - - return wrapped diff --git a/modal/proxy.py b/modal/proxy.py index 84e582762..f86062cfd 100644 --- a/modal/proxy.py +++ b/modal/proxy.py @@ -3,9 +3,9 @@ from modal_proto import api_pb2 +from ._object import _get_environment_name, _Object from ._resolver import Resolver from ._utils.async_utils import synchronize_api -from .object import _get_environment_name, _Object class _Proxy(_Object, type_prefix="pr"): diff --git a/modal/queue.py b/modal/queue.py index 86b2093ec..fd21484f2 100644 --- a/modal/queue.py +++ b/modal/queue.py @@ -10,6 +10,7 @@ from modal_proto import api_pb2 +from ._object import EPHEMERAL_OBJECT_HEARTBEAT_SLEEP, _get_environment_name, _Object, live_method, live_method_gen from ._resolver import Resolver from ._serialization import deserialize, serialize from ._utils.async_utils import TaskContext, synchronize_api, warn_if_generator_is_not_consumed @@ -18,7 +19,6 @@ from ._utils.name_utils import check_object_name from .client import _Client from .exception import InvalidError, RequestSizeError -from .object import EPHEMERAL_OBJECT_HEARTBEAT_SLEEP, _get_environment_name, _Object, live_method, live_method_gen class _Queue(_Object, type_prefix="qu"): diff --git a/modal/runner.py b/modal/runner.py index 2cf1eaac4..4d350ef27 100644 --- a/modal/runner.py +++ b/modal/runner.py @@ -14,6 +14,7 @@ import modal_proto.api_pb2 from modal_proto import api_pb2 +from ._object import _get_environment_name, _Object from ._pty import get_pty_info from ._resolver import Resolver from ._runtime.execution_context import is_local @@ -28,7 +29,6 @@ from .environments import _get_environment_cached from .exception import InteractiveTimeoutError, InvalidError, RemoteError, _CliUserExecutionError from .functions import _Function -from .object import _get_environment_name, _Object from .output import _get_output_manager, enable_output from .running_app import RunningApp, running_app_from_layout from .sandbox import _Sandbox @@ -155,7 +155,7 @@ async def _preload(tag, obj): # this is to ensure that directly referenced functions from the global scope has # ids associated with them when they are serialized into other functions await resolver.preload(obj, existing_object_id) - if obj.object_id is not None: + if obj.is_hydrated: tag_to_object_id[tag] = obj.object_id await TaskContext.gather(*(_preload(tag, obj) for tag, obj in indexed_objects.items())) diff --git a/modal/sandbox.py b/modal/sandbox.py index d00ca83e4..04f17c1fb 100644 --- a/modal/sandbox.py +++ b/modal/sandbox.py @@ -16,6 +16,7 @@ from modal_proto import api_pb2 from ._location import parse_cloud_provider +from ._object import _get_environment_name, _Object from ._resolver import Resolver from ._resources import convert_fn_config_to_resources_config from ._utils.async_utils import synchronize_api @@ -32,7 +33,6 @@ from .io_streams import StreamReader, StreamWriter, _StreamReader, _StreamWriter from .mount import _Mount from .network_file_system import _NetworkFileSystem, network_file_system_mount_protos -from .object import _get_environment_name, _Object from .proxy import _Proxy from .scheduler_placement import SchedulerPlacement from .secret import _Secret diff --git a/modal/secret.py b/modal/secret.py index 3f6364031..dab08c3fd 100644 --- a/modal/secret.py +++ b/modal/secret.py @@ -6,6 +6,7 @@ from modal_proto import api_pb2 +from ._object import _get_environment_name, _Object from ._resolver import Resolver from ._runtime.execution_context import is_local from ._utils.async_utils import synchronize_api @@ -14,7 +15,6 @@ from ._utils.name_utils import check_object_name from .client import _Client from .exception import InvalidError, NotFoundError -from .object import _get_environment_name, _Object ENV_DICT_WRONG_TYPE_ERR = "the env_dict argument to Secret has to be a dict[str, Union[str, None]]" diff --git a/modal/volume.py b/modal/volume.py index 127af5ead..f9972f02b 100644 --- a/modal/volume.py +++ b/modal/volume.py @@ -27,6 +27,7 @@ from modal.exception import VolumeUploadTimeoutError from modal_proto import api_pb2 +from ._object import EPHEMERAL_OBJECT_HEARTBEAT_SLEEP, _get_environment_name, _Object, live_method, live_method_gen from ._resolver import Resolver from ._utils.async_utils import TaskContext, aclosing, async_map, asyncnullcontext, synchronize_api from ._utils.blob_utils import ( @@ -41,7 +42,6 @@ from ._utils.name_utils import check_object_name from .client import _Client from .config import logger -from .object import EPHEMERAL_OBJECT_HEARTBEAT_SLEEP, _get_environment_name, _Object, live_method, live_method_gen # Max duration for uploading to volumes files # As a guide, files >40GiB will take >10 minutes to upload. diff --git a/pyproject.toml b/pyproject.toml index 930846086..9f3da6a33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" [tool.mypy] -python_version = "3.9" +python_version = "3.11" exclude = "build" ignore_missing_imports = true check_untyped_defs = true diff --git a/test/function_serialization_test.py b/test/function_serialization_test.py index 196f24c6c..feb4798c6 100644 --- a/test/function_serialization_test.py +++ b/test/function_serialization_test.py @@ -13,7 +13,9 @@ async def test_serialize_deserialize_function(servicer, client): def foo(): 2 * foo.remote() - assert foo.object_id is None + assert not foo.is_hydrated + with pytest.raises(Exception): + foo.object_id # noqa with app.run(client=client): object_id = foo.object_id diff --git a/test/function_test.py b/test/function_test.py index 97b9f4332..b200404dc 100644 --- a/test/function_test.py +++ b/test/function_test.py @@ -756,7 +756,7 @@ def test_serialize_deserialize_function_handle(servicer, client): def my_handle(): pass - with pytest.raises(InvalidError, match="hasn't been created"): + with pytest.raises(InvalidError, match="hasn't been hydrated"): serialize(my_handle) # handle is not "live" yet! should not be serializable yet with app.run(client=client): diff --git a/test/object_test.py b/test/object_test.py index cb4c92f8f..0ecea8890 100644 --- a/test/object_test.py +++ b/test/object_test.py @@ -2,9 +2,9 @@ import pytest from modal import Secret +from modal._object import _Object from modal.dict import _Dict from modal.exception import InvalidError -from modal.object import _Object from modal.queue import _Queue diff --git a/test/resolver_test.py b/test/resolver_test.py index 1721c2500..69acf6187 100644 --- a/test/resolver_test.py +++ b/test/resolver_test.py @@ -4,8 +4,8 @@ import time from typing import Optional +from modal._object import _Object from modal._resolver import Resolver -from modal.object import _Object @pytest.mark.flaky(max_runs=2)