Skip to content

Commit

Permalink
Fixes a bunch of static typing issues in modal.object [CLI-312] (#2771)
Browse files Browse the repository at this point in the history
* Splits object.py into _object.py with implemenation and object.py with synchronicity wrappers
* Fixes a lot of type errors that surfaced as a result
  • Loading branch information
freider authored Jan 17, 2025
1 parent dc4e923 commit 690aa85
Show file tree
Hide file tree
Showing 28 changed files with 353 additions and 333 deletions.
2 changes: 1 addition & 1 deletion modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@
from ._runtime.execution_context import _set_current_context_ids

if TYPE_CHECKING:
import modal._object
import modal._runtime.container_io_manager
import modal.object


class DaemonizedThreadPool:
Expand Down
279 changes: 279 additions & 0 deletions modal/_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
# Copyright Modal Labs 2022
import typing
import uuid
from collections.abc import Awaitable, Hashable, Sequence
from functools import wraps
from typing import Callable, ClassVar, Optional

from google.protobuf.message import Message
from typing_extensions import Self

from modal._utils.async_utils import aclosing

from ._resolver import Resolver
from .client import _Client
from .config import config, logger
from .exception import ExecutionError, InvalidError

EPHEMERAL_OBJECT_HEARTBEAT_SLEEP: int = 300


def _get_environment_name(environment_name: Optional[str] = None, resolver: Optional[Resolver] = None) -> Optional[str]:
if environment_name:
return environment_name
elif resolver and resolver.environment_name:
return resolver.environment_name
else:
return config.get("environment")


class _Object:
_type_prefix: ClassVar[Optional[str]] = None
_prefix_to_type: ClassVar[dict[str, type]] = {}

# For constructors
_load: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]]
_preload: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]]
_rep: str
_is_another_app: bool
_hydrate_lazily: bool
_deps: Optional[Callable[..., Sequence["_Object"]]]
_deduplication_key: Optional[Callable[[], Awaitable[Hashable]]] = None

# For hydrated objects
_object_id: Optional[str]
_client: Optional[_Client]
_is_hydrated: bool
_is_rehydrated: bool

@classmethod
def __init_subclass__(cls, type_prefix: Optional[str] = None):
super().__init_subclass__()
if type_prefix is not None:
cls._type_prefix = type_prefix
cls._prefix_to_type[type_prefix] = cls

def __init__(self, *args, **kwargs):
raise InvalidError(f"Class {type(self).__name__} has no constructor. Use class constructor methods instead.")

def _init(
self,
rep: str,
load: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]] = None,
is_another_app: bool = False,
preload: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]] = None,
hydrate_lazily: bool = False,
deps: Optional[Callable[..., Sequence["_Object"]]] = None,
deduplication_key: Optional[Callable[[], Awaitable[Hashable]]] = None,
):
self._local_uuid = str(uuid.uuid4())
self._load = load
self._preload = preload
self._rep = rep
self._is_another_app = is_another_app
self._hydrate_lazily = hydrate_lazily
self._deps = deps
self._deduplication_key = deduplication_key

self._object_id = None
self._client = None
self._is_hydrated = False
self._is_rehydrated = False

self._initialize_from_empty()

def _unhydrate(self):
self._object_id = None
self._client = None
self._is_hydrated = False

def _initialize_from_empty(self):
# default implementation, can be overriden in subclasses
pass

def _initialize_from_other(self, other):
# default implementation, can be overriden in subclasses
self._object_id = other._object_id
self._is_hydrated = other._is_hydrated
self._client = other._client

def _hydrate(self, object_id: str, client: _Client, metadata: Optional[Message]):
assert isinstance(object_id, str) and self._type_prefix is not None
if not object_id.startswith(self._type_prefix):
raise ExecutionError(
f"Can not hydrate {type(self)}:"
f" it has type prefix {self._type_prefix}"
f" but the object_id starts with {object_id[:3]}"
)
self._object_id = object_id
self._client = client
self._hydrate_metadata(metadata)
self._is_hydrated = True

def _hydrate_metadata(self, metadata: Optional[Message]):
# override this is subclasses that need additional data (other than an object_id) for a functioning Handle
pass

def _get_metadata(self) -> Optional[Message]:
# return the necessary metadata from this handle to be able to re-hydrate in another context if one is needed
# used to provide a handle's handle_metadata for serializing/pickling a live handle
# the object_id is already provided by other means
return None

def _validate_is_hydrated(self):
if not self._is_hydrated:
object_type = self.__class__.__name__.strip("_")
if hasattr(self, "_app") and getattr(self._app, "_running_app", "") is None: # type: ignore
# The most common cause of this error: e.g., user called a Function without using App.run()
reason = ", because the App it is defined on is not running"
else:
# Technically possible, but with an ambiguous cause.
reason = ""
raise ExecutionError(
f"{object_type} has not been hydrated with the metadata it needs to run on Modal{reason}."
)

def clone(self) -> Self:
"""mdmd:hidden Clone a given hydrated object."""

# Object to clone must already be hydrated, otherwise from_loader is more suitable.
self._validate_is_hydrated()
obj = type(self).__new__(type(self))
obj._initialize_from_other(self)
return obj

@classmethod
def _from_loader(
cls,
load: Callable[[Self, Resolver, Optional[str]], Awaitable[None]],
rep: str,
is_another_app: bool = False,
preload: Optional[Callable[[Self, Resolver, Optional[str]], Awaitable[None]]] = None,
hydrate_lazily: bool = False,
deps: Optional[Callable[..., Sequence["_Object"]]] = None,
deduplication_key: Optional[Callable[[], Awaitable[Hashable]]] = None,
):
# TODO(erikbern): flip the order of the two first arguments
obj = _Object.__new__(cls)
obj._init(rep, load, is_another_app, preload, hydrate_lazily, deps, deduplication_key)
return obj

@staticmethod
def _get_type_from_id(object_id: str) -> type["_Object"]:
parts = object_id.split("-")
if len(parts) != 2:
raise InvalidError(f"Object id {object_id} has no dash in it")
prefix = parts[0]
if prefix not in _Object._prefix_to_type:
raise InvalidError(f"Object prefix {prefix} does not correspond to a type")
return _Object._prefix_to_type[prefix]

@classmethod
def _is_id_type(cls, object_id) -> bool:
return cls._get_type_from_id(object_id) == cls

@classmethod
def _new_hydrated(
cls, object_id: str, client: _Client, handle_metadata: Optional[Message], is_another_app: bool = False
) -> Self:
obj_cls: type[Self]
if cls._type_prefix is not None:
# This is called directly on a subclass, e.g. Secret.from_id
# validate the id matching the expected id type of the Object subclass
if not object_id.startswith(cls._type_prefix + "-"):
raise InvalidError(f"Object {object_id} does not start with {cls._type_prefix}")

obj_cls = cls
else:
# this means the method is used directly on _Object
# typically during deserialization of objects
obj_cls = typing.cast(type[Self], cls._get_type_from_id(object_id))

# Instantiate provider
obj = _Object.__new__(obj_cls)
rep = f"Object({object_id})" # TODO(erikbern): dumb
obj._init(rep, is_another_app=is_another_app)
obj._hydrate(object_id, client, handle_metadata)

return obj

def _hydrate_from_other(self, other: Self):
self._hydrate(other.object_id, other.client, other._get_metadata())

def __repr__(self):
return self._rep

@property
def local_uuid(self):
"""mdmd:hidden"""
return self._local_uuid

@property
def object_id(self) -> str:
"""mdmd:hidden"""
if self._object_id is None:
raise AttributeError(f"Attempting to get object_id of unhydrated {self}")
return self._object_id

@property
def client(self) -> _Client:
"""mdmd:hidden"""
if self._client is None:
raise AttributeError(f"Attempting to get client of unhydrated {self}")
return self._client

@property
def is_hydrated(self) -> bool:
"""mdmd:hidden"""
return self._is_hydrated

@property
def deps(self) -> Callable[..., Sequence["_Object"]]:
"""mdmd:hidden"""

def default_deps(*args, **kwargs) -> Sequence["_Object"]:
return []

return self._deps if self._deps is not None else default_deps

async def resolve(self, client: Optional[_Client] = None):
"""mdmd:hidden"""
if self._is_hydrated:
# memory snapshots capture references which must be rehydrated
# on restore to handle staleness.
if self.client._snapshotted and not self._is_rehydrated:
logger.debug(f"rehydrating {self} after snapshot")
self._is_hydrated = False # un-hydrate and re-resolve
c = client if client is not None else await _Client.from_env()
resolver = Resolver(c)
await resolver.load(typing.cast(_Object, self))
self._is_rehydrated = True
logger.debug(f"rehydrated {self} with client {id(c)}")
return
elif not self._hydrate_lazily:
self._validate_is_hydrated()
else:
# TODO: this client and/or resolver can't be changed by a caller to X.from_name()
c = client if client is not None else await _Client.from_env()
resolver = Resolver(c)
await resolver.load(self)


def live_method(method):
@wraps(method)
async def wrapped(self, *args, **kwargs):
await self.resolve()
return await method(self, *args, **kwargs)

return wrapped


def live_method_gen(method):
@wraps(method)
async def wrapped(self, *args, **kwargs):
await self.resolve()
async with aclosing(method(self, *args, **kwargs)) as stream:
async for item in stream:
yield item

return wrapped
12 changes: 7 additions & 5 deletions modal/_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
if TYPE_CHECKING:
from rich.tree import Tree

from modal.object import _Object
import modal._object


class StatusRow:
Expand All @@ -33,7 +33,7 @@ def message(self, message):
self._spinner.update(text=message)

def finish(self, message):
if self._step_node is not None:
if self._step_node is not None and self._spinner is not None:
from ._output import OutputManager

self._spinner.update(text=message)
Expand Down Expand Up @@ -89,7 +89,7 @@ async def preload(self, obj, existing_object_id: Optional[str]):
if obj._preload is not None:
await obj._preload(obj, self, existing_object_id)

async def load(self, obj: "_Object", existing_object_id: Optional[str] = None):
async def load(self, obj: "modal._object._Object", existing_object_id: Optional[str] = None):
if obj._is_hydrated and obj._is_another_app:
# No need to reload this, it won't typically change
if obj.local_uuid not in self._local_uuid_to_future:
Expand Down Expand Up @@ -124,6 +124,8 @@ async def loader():
await TaskContext.gather(*[self.load(dep) for dep in obj.deps()])

# Load the object itself
if not obj._load:
raise Exception(f"Object {obj} has no loader function")
try:
await obj._load(obj, self, existing_object_id)
except GRPCError as exc:
Expand Down Expand Up @@ -154,8 +156,8 @@ async def loader():
# TODO(elias): print original exception/trace rather than the Resolver-internal trace
return await cached_future

def objects(self) -> list["_Object"]:
unique_objects: dict[str, "_Object"] = {}
def objects(self) -> list["modal._object._Object"]:
unique_objects: dict[str, "modal._object._Object"] = {}
for fut in self._local_uuid_to_future.values():
if not fut.done():
# this will raise an exception if not all loads have been awaited, but that *should* never happen
Expand Down
14 changes: 7 additions & 7 deletions modal/_runtime/user_code_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import typing
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Sequence

import modal._object
import modal._runtime.container_io_manager
import modal.cls
import modal.object
from modal import Function
from modal._utils.async_utils import synchronizer
from modal._utils.function_utils import LocalFunctionError, is_async as get_is_async, is_global_object
Expand Down Expand Up @@ -41,7 +41,7 @@ class Service(metaclass=ABCMeta):

user_cls_instance: Any
app: Optional["modal.app._App"]
code_deps: Optional[list["modal.object._Object"]]
code_deps: Optional[Sequence["modal._object._Object"]]

@abstractmethod
def get_finalized_functions(
Expand Down Expand Up @@ -94,7 +94,7 @@ def construct_webhook_callable(
class ImportedFunction(Service):
user_cls_instance: Any
app: Optional["modal.app._App"]
code_deps: Optional[list["modal.object._Object"]]
code_deps: Optional[Sequence["modal._object._Object"]]

_user_defined_callable: Callable[..., Any]

Expand Down Expand Up @@ -137,7 +137,7 @@ def get_finalized_functions(
class ImportedClass(Service):
user_cls_instance: Any
app: Optional["modal.app._App"]
code_deps: Optional[list["modal.object._Object"]]
code_deps: Optional[Sequence["modal._object._Object"]]

_partial_functions: dict[str, "modal.partial_function._PartialFunction"]

Expand Down Expand Up @@ -229,7 +229,7 @@ def import_single_function_service(
"""
user_defined_callable: Callable
function: Optional[_Function] = None
code_deps: Optional[list["modal.object._Object"]] = None
code_deps: Optional[Sequence["modal._object._Object"]] = None
active_app: Optional[modal.app._App] = None

if ser_fun is not None:
Expand Down Expand Up @@ -306,7 +306,7 @@ def import_class_service(
See import_function.
"""
active_app: Optional["modal.app._App"]
code_deps: Optional[list["modal.object._Object"]]
code_deps: Optional[Sequence["modal._object._Object"]]
cls: typing.Union[type, modal.cls.Cls]

if function_def.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED:
Expand Down
Loading

0 comments on commit 690aa85

Please sign in to comment.