Skip to content

Commit

Permalink
Fixes a bunch of static typing issues in modal.object
Browse files Browse the repository at this point in the history
* Splits object.py into _object.py with implemenation and object.py with synchronicity wrappers
* Fixes a lot of type errors that surfaced as a result
  • Loading branch information
freider committed Jan 17, 2025
1 parent 978c18e commit 53f2afa
Show file tree
Hide file tree
Showing 28 changed files with 126 additions and 107 deletions.
2 changes: 1 addition & 1 deletion modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@
from ._runtime.execution_context import _set_current_context_ids

if TYPE_CHECKING:
import modal._object
import modal._runtime.container_io_manager
import modal.object


class DaemonizedThreadPool:
Expand Down
89 changes: 50 additions & 39 deletions modal/_object.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
# Copyright Modal Labs 2022
import typing
import uuid
from collections.abc import Awaitable, Hashable, Sequence
from functools import wraps
from typing import Callable, ClassVar, Optional, TypeVar
from typing import Callable, ClassVar, Optional

from google.protobuf.message import Message
from typing_extensions import Self

from modal._utils.async_utils import aclosing

from ._resolver import Resolver
from ._utils.async_utils import synchronize_api
from .client import _Client
from .config import config, logger
from .exception import ExecutionError, InvalidError

O = TypeVar("O", bound="_Object")

_BLOCKING_O = synchronize_api(O)

EPHEMERAL_OBJECT_HEARTBEAT_SLEEP: int = 300


Expand All @@ -35,17 +32,17 @@ class _Object:
_prefix_to_type: ClassVar[dict[str, type]] = {}

# For constructors
_load: Optional[Callable[[O, Resolver, Optional[str]], Awaitable[None]]]
_preload: Optional[Callable[[O, Resolver, Optional[str]], Awaitable[None]]]
_load: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]]
_preload: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]]
_rep: str
_is_another_app: bool
_hydrate_lazily: bool
_deps: Optional[Callable[..., list["_Object"]]]
_deps: Optional[Callable[..., Sequence["_Object"]]]
_deduplication_key: Optional[Callable[[], Awaitable[Hashable]]] = None

# For hydrated objects
_object_id: str
_client: _Client
_object_id: Optional[str]
_client: Optional[_Client]
_is_hydrated: bool
_is_rehydrated: bool

Expand All @@ -62,11 +59,11 @@ def __init__(self, *args, **kwargs):
def _init(
self,
rep: str,
load: Optional[Callable[[O, Resolver, Optional[str]], Awaitable[None]]] = None,
load: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]] = None,
is_another_app: bool = False,
preload: Optional[Callable[[O, Resolver, Optional[str]], Awaitable[None]]] = None,
preload: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]] = None,
hydrate_lazily: bool = False,
deps: Optional[Callable[..., list["_Object"]]] = None,
deps: Optional[Callable[..., Sequence["_Object"]]] = None,
deduplication_key: Optional[Callable[[], Awaitable[Hashable]]] = None,
):
self._local_uuid = str(uuid.uuid4())
Expand Down Expand Up @@ -101,7 +98,7 @@ def _initialize_from_other(self, other):
self._client = other._client

def _hydrate(self, object_id: str, client: _Client, metadata: Optional[Message]):
assert isinstance(object_id, str)
assert isinstance(object_id, str) and self._type_prefix is not None
if not object_id.startswith(self._type_prefix):
raise ExecutionError(
f"Can not hydrate {type(self)}:"
Expand All @@ -121,12 +118,12 @@ def _get_metadata(self) -> Optional[Message]:
# return the necessary metadata from this handle to be able to re-hydrate in another context if one is needed
# used to provide a handle's handle_metadata for serializing/pickling a live handle
# the object_id is already provided by other means
return
return None

def _validate_is_hydrated(self: O):
def _validate_is_hydrated(self):
if not self._is_hydrated:
object_type = self.__class__.__name__.strip("_")
if hasattr(self, "_app") and getattr(self._app, "_running_app", "") is None:
if hasattr(self, "_app") and getattr(self._app, "_running_app", "") is None: # type: ignore
# The most common cause of this error: e.g., user called a Function without using App.run()
reason = ", because the App it is defined on is not running"
else:
Expand All @@ -136,7 +133,7 @@ def _validate_is_hydrated(self: O):
f"{object_type} has not been hydrated with the metadata it needs to run on Modal{reason}."
)

def clone(self: O) -> O:
def clone(self) -> Self:
"""mdmd:hidden Clone a given hydrated object."""

# Object to clone must already be hydrated, otherwise from_loader is more suitable.
Expand All @@ -148,10 +145,10 @@ def clone(self: O) -> O:
@classmethod
def _from_loader(
cls,
load: Callable[[O, Resolver, Optional[str]], Awaitable[None]],
load: Callable[[Self, Resolver, Optional[str]], Awaitable[None]],
rep: str,
is_another_app: bool = False,
preload: Optional[Callable[[O, Resolver, Optional[str]], Awaitable[None]]] = None,
preload: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]] = None,
hydrate_lazily: bool = False,
deps: Optional[Callable[..., Sequence["_Object"]]] = None,
deduplication_key: Optional[Callable[[], Awaitable[Hashable]]] = None,
Expand All @@ -161,32 +158,36 @@ def _from_loader(
obj._init(rep, load, is_another_app, preload, hydrate_lazily, deps, deduplication_key)
return obj

@classmethod
def _get_type_from_id(cls: type[O], object_id: str) -> type[O]:
@staticmethod
def _get_type_from_id(object_id: str) -> type["_Object"]:
parts = object_id.split("-")
if len(parts) != 2:
raise InvalidError(f"Object id {object_id} has no dash in it")
prefix = parts[0]
if prefix not in cls._prefix_to_type:
if prefix not in _Object._prefix_to_type:
raise InvalidError(f"Object prefix {prefix} does not correspond to a type")
return cls._prefix_to_type[prefix]
return _Object._prefix_to_type[prefix]

@classmethod
def _is_id_type(cls: type[O], object_id) -> bool:
def _is_id_type(cls, object_id) -> bool:
return cls._get_type_from_id(object_id) == cls

@classmethod
def _new_hydrated(
cls: type[O], object_id: str, client: _Client, handle_metadata: Optional[Message], is_another_app: bool = False
) -> O:
cls, object_id: str, client: _Client, handle_metadata: Optional[Message], is_another_app: bool = False
) -> Self:
obj_cls: type[Self]
if cls._type_prefix is not None:
# This is called directly on a subclass, e.g. Secret.from_id
# validate the id matching the expected id type of the Object subclass
if not object_id.startswith(cls._type_prefix + "-"):
raise InvalidError(f"Object {object_id} does not start with {cls._type_prefix}")

obj_cls = cls
else:
# This is called on the base class, e.g. Handle.from_id
obj_cls = cls._get_type_from_id(object_id)
# this means the method is used directly on _Object
# typically during deserialization of objects
obj_cls = typing.cast(type[Self], cls._get_type_from_id(object_id))

# Instantiate provider
obj = _Object.__new__(obj_cls)
Expand All @@ -196,8 +197,8 @@ def _new_hydrated(

return obj

def _hydrate_from_other(self, other: O):
self._hydrate(other._object_id, other._client, other._get_metadata())
def _hydrate_from_other(self, other: Self):
self._hydrate(other.object_id, other.client, other._get_metadata())

def __repr__(self):
return self._rep
Expand All @@ -210,29 +211,42 @@ def local_uuid(self):
@property
def object_id(self) -> str:
"""mdmd:hidden"""
if self._object_id is None:
raise AttributeError(f"Attempting to get object_id of unhydrated {self}")
return self._object_id

@property
def client(self) -> _Client:
"""mdmd:hidden"""
if self._client is None:
raise AttributeError(f"Attempting to get client of unhydrated {self}")
return self._client

@property
def is_hydrated(self) -> bool:
"""mdmd:hidden"""
return self._is_hydrated

@property
def deps(self) -> Callable[..., list["_Object"]]:
def deps(self) -> Callable[..., Sequence["_Object"]]:
"""mdmd:hidden"""
return self._deps if self._deps is not None else lambda: []

def default_deps(*args, **kwargs) -> Sequence["_Object"]:
return []

return self._deps if self._deps is not None else default_deps

async def resolve(self, client: Optional[_Client] = None):
"""mdmd:hidden"""
if self._is_hydrated:
# memory snapshots capture references which must be rehydrated
# on restore to handle staleness.
if self._client._snapshotted and not self._is_rehydrated:
if self.client._snapshotted and not self._is_rehydrated:
logger.debug(f"rehydrating {self} after snapshot")
self._is_hydrated = False # un-hydrate and re-resolve
c = client if client is not None else await _Client.from_env()
resolver = Resolver(c)
await resolver.load(self)
await resolver.load(typing.cast(_Object, self))
self._is_rehydrated = True
logger.debug(f"rehydrated {self} with client {id(c)}")
return
Expand All @@ -245,9 +259,6 @@ async def resolve(self, client: Optional[_Client] = None):
await resolver.load(self)


Object = synchronize_api(_Object, target_module=__name__)


def live_method(method):
@wraps(method)
async def wrapped(self, *args, **kwargs):
Expand Down
12 changes: 7 additions & 5 deletions modal/_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
if TYPE_CHECKING:
from rich.tree import Tree

from modal.object import _Object
import modal._object


class StatusRow:
Expand All @@ -33,7 +33,7 @@ def message(self, message):
self._spinner.update(text=message)

def finish(self, message):
if self._step_node is not None:
if self._step_node is not None and self._spinner is not None:
from ._output import OutputManager

self._spinner.update(text=message)
Expand Down Expand Up @@ -89,7 +89,7 @@ async def preload(self, obj, existing_object_id: Optional[str]):
if obj._preload is not None:
await obj._preload(obj, self, existing_object_id)

async def load(self, obj: "_Object", existing_object_id: Optional[str] = None):
async def load(self, obj: "modal._object._Object", existing_object_id: Optional[str] = None):
if obj._is_hydrated and obj._is_another_app:
# No need to reload this, it won't typically change
if obj.local_uuid not in self._local_uuid_to_future:
Expand Down Expand Up @@ -124,6 +124,8 @@ async def loader():
await TaskContext.gather(*[self.load(dep) for dep in obj.deps()])

# Load the object itself
if not obj._load:
raise Exception(f"Object {obj} has no loader function")
try:
await obj._load(obj, self, existing_object_id)
except GRPCError as exc:
Expand Down Expand Up @@ -154,8 +156,8 @@ async def loader():
# TODO(elias): print original exception/trace rather than the Resolver-internal trace
return await cached_future

def objects(self) -> list["_Object"]:
unique_objects: dict[str, "_Object"] = {}
def objects(self) -> list["modal._object._Object"]:
unique_objects: dict[str, "modal._object._Object"] = {}
for fut in self._local_uuid_to_future.values():
if not fut.done():
# this will raise an exception if not all loads have been awaited, but that *should* never happen
Expand Down
14 changes: 7 additions & 7 deletions modal/_runtime/user_code_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import typing
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Sequence

import modal._object
import modal._runtime.container_io_manager
import modal.cls
import modal.object
from modal import Function
from modal._utils.async_utils import synchronizer
from modal._utils.function_utils import LocalFunctionError, is_async as get_is_async, is_global_object
Expand Down Expand Up @@ -41,7 +41,7 @@ class Service(metaclass=ABCMeta):

user_cls_instance: Any
app: Optional["modal.app._App"]
code_deps: Optional[list["modal.object._Object"]]
code_deps: Optional[Sequence["modal._object._Object"]]

@abstractmethod
def get_finalized_functions(
Expand Down Expand Up @@ -94,7 +94,7 @@ def construct_webhook_callable(
class ImportedFunction(Service):
user_cls_instance: Any
app: Optional["modal.app._App"]
code_deps: Optional[list["modal.object._Object"]]
code_deps: Optional[Sequence["modal._object._Object"]]

_user_defined_callable: Callable[..., Any]

Expand Down Expand Up @@ -137,7 +137,7 @@ def get_finalized_functions(
class ImportedClass(Service):
user_cls_instance: Any
app: Optional["modal.app._App"]
code_deps: Optional[list["modal.object._Object"]]
code_deps: Optional[Sequence["modal._object._Object"]]

_partial_functions: dict[str, "modal.partial_function._PartialFunction"]

Expand Down Expand Up @@ -229,7 +229,7 @@ def import_single_function_service(
"""
user_defined_callable: Callable
function: Optional[_Function] = None
code_deps: Optional[list["modal.object._Object"]] = None
code_deps: Optional[Sequence["modal._object._Object"]] = None
active_app: Optional[modal.app._App] = None

if ser_fun is not None:
Expand Down Expand Up @@ -306,7 +306,7 @@ def import_class_service(
See import_function.
"""
active_app: Optional["modal.app._App"]
code_deps: Optional[list["modal.object._Object"]]
code_deps: Optional[Sequence["modal._object._Object"]]
cls: typing.Union[type, modal.cls.Cls]

if function_def.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED:
Expand Down
7 changes: 4 additions & 3 deletions modal/_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from modal._utils.async_utils import synchronizer
from modal_proto import api_pb2

from ._object import _Object
from ._vendor import cloudpickle
from .config import logger
from .exception import DeserializationError, ExecutionError, InvalidError
from .object import Object, _Object
from .object import Object

PICKLE_PROTOCOL = 4 # Support older Python versions.

Expand Down Expand Up @@ -48,8 +49,8 @@ def persistent_id(self, obj):
return ("sync", (impl_object.__class__, attributes))
else:
return
if not obj.object_id:
raise InvalidError(f"Can't serialize object {obj} which hasn't been created.")
if not obj.is_hydrated:
raise InvalidError(f"Can't serialize object {obj} which hasn't been hydrated.")
return (obj.object_id, flag, obj._get_metadata())


Expand Down
2 changes: 1 addition & 1 deletion modal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from modal_proto import api_pb2

from ._ipython import is_notebook
from ._object import _get_environment_name, _Object
from ._utils.async_utils import synchronize_api
from ._utils.deprecation import deprecation_error, deprecation_warning, renamed_parameter
from ._utils.function_utils import FunctionInfo, is_global_object, is_method_fn
Expand All @@ -36,7 +37,6 @@
from .image import _Image
from .mount import _Mount
from .network_file_system import _NetworkFileSystem
from .object import _get_environment_name, _Object
from .partial_function import (
PartialFunction,
_find_partial_methods_for_user_cls,
Expand Down
2 changes: 1 addition & 1 deletion modal/cli/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from rich.text import Text
from typer import Argument

from modal._object import _get_environment_name
from modal._utils.async_utils import synchronizer
from modal._utils.deprecation import deprecation_warning
from modal.client import _Client
from modal.environments import ensure_env
from modal.object import _get_environment_name
from modal_proto import api_pb2

from .utils import ENV_OPTION, display_table, get_app_id_from_name, stream_app_logs, timestamp_to_local
Expand Down
Loading

0 comments on commit 53f2afa

Please sign in to comment.