From 2ee3c7cece73c3911a60a4bfecfcf405744226b4 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Tue, 14 Jan 2025 18:28:32 +0000 Subject: [PATCH] wip --- modal/sandbox.py | 15 ++++++++++++++- modal_proto/api.proto | 3 +++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/modal/sandbox.py b/modal/sandbox.py index 819483c6a..fa901f423 100644 --- a/modal/sandbox.py +++ b/modal/sandbox.py @@ -60,6 +60,7 @@ class _Sandbox(_Object, type_prefix="sb"): _stdin: _StreamWriter _task_id: Optional[str] = None _tunnels: Optional[dict[int, Tunnel]] = None + _enable_memory_snapshot: bool = False @staticmethod def _new( @@ -83,6 +84,7 @@ def _new( unencrypted_ports: Sequence[int] = [], proxy: Optional[_Proxy] = None, _experimental_scheduler_placement: Optional[SchedulerPlacement] = None, + enable_memory_snapshot: bool = False, ) -> "_Sandbox": """mdmd:hidden""" @@ -178,6 +180,7 @@ async def _load(self: _Sandbox, resolver: Resolver, _existing_object_id: Optiona open_ports=api_pb2.PortSpecs(ports=open_ports), network_access=network_access, proxy_id=(proxy.object_id if proxy else None), + enable_memory_snapshot=enable_memory_snapshot, ) # Note - `resolver.app_id` will be `None` for app-less sandboxes @@ -225,6 +228,8 @@ async def create( unencrypted_ports: Sequence[int] = [], # Reference to a Modal Proxy to use in front of this Sandbox. proxy: Optional[_Proxy] = None, + # Enable memory snapshots. + enable_memory_snapshot: bool = False, _experimental_scheduler_placement: Optional[ SchedulerPlacement ] = None, # Experimental controls over fine-grained scheduling (alpha). @@ -262,7 +267,9 @@ async def create( unencrypted_ports=unencrypted_ports, proxy=proxy, _experimental_scheduler_placement=_experimental_scheduler_placement, + enable_memory_snapshot=enable_memory_snapshot, ) + obj._enable_memory_snapshot = enable_memory_snapshot app_id: Optional[str] = None app_client: Optional[_Client] = None @@ -531,8 +538,13 @@ async def exec( resp = await retry_transient_errors(self._client.stub.ContainerExec, req) by_line = bufsize == 1 return _ContainerProcess(resp.exec_id, self._client, stdout=stdout, stderr=stderr, text=text, by_line=by_line) - + async def snapshot(self) -> str: + if not self._enable_memory_snapshot: + raise ValueError( + "Memory snapshots are not supported for this sandbox. To enable memory snapshots, " + "set `enable_memory_snapshot=True` when creating the sandbox." + ) req = api_pb2.SandboxSnapshotRequest(sandbox_id=self.object_id) resp = await retry_transient_errors(self._client.stub.SandboxSnapshot, req) return resp.snapshot_id @@ -541,6 +553,7 @@ async def snapshot(self) -> str: async def from_snapshot(snapshot_id: str, app: Optional["modal.app._App"] = None): app_client: Optional[_Client] = None + # TODO: I think we should do this without passing in an app. It should be more similar to `Sandbox.from_id`. if app is not None: if app.app_id is None: raise ValueError( diff --git a/modal_proto/api.proto b/modal_proto/api.proto index 5b4c3b276..9b319a66a 100644 --- a/modal_proto/api.proto +++ b/modal_proto/api.proto @@ -2180,6 +2180,9 @@ message Sandbox { NetworkAccess network_access = 22; optional string proxy_id = 23; + + // Enable memory snapshots. + bool enable_memory_snapshot = 24; } message SandboxCreateRequest {