Fix actor pool in python 3.11, add better scaling down logic #760

Merged 2 commits on Oct 10, 2024
3 changes: 3 additions & 0 deletions config/data/openwebtext_source.yaml
@@ -4,3 +4,6 @@ validation_urls:
   - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
 cache_dir: "gs://levanter-data/tokenized/openwebtext/"
 tokenizer: "gpt2"
+cache_options:
+  batch_size: 1024
+  num_shard_groups: 64
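For context on the two new keys, here is a minimal sketch of the options object this YAML block plausibly populates. The field names mirror the YAML keys, but the class shape is an assumption for illustration; Levanter's actual `CacheOptions` may differ.

```python
from dataclasses import dataclass
from typing import Optional

# Illustrative stand-in for the options configured above; the real
# Levanter class may use different names and defaults.
@dataclass
class CacheOptions:
    batch_size: int = 128                   # examples tokenized per write batch
    num_shard_groups: Optional[int] = None  # input shards are grouped to bound writer parallelism

# What the YAML above would deserialize to:
options = CacheOptions(batch_size=1024, num_shard_groups=64)
```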
3 changes: 0 additions & 3 deletions src/levanter/store/cache.py
@@ -1061,11 +1061,8 @@ def _write_batches(writer: ShardedCacheWriter, shard_totals, batches, finished_s


 def _fetch_batches(batches) -> tuple[dict[str, int], list[PreparedBatch]]:
-    time_in = time.time()
     shards_for_batches, payloads_for_batches = zip(*batches)
     payloads_for_batches = ray.get(list(payloads_for_batches))
-    time_out = time.time()
-    logger.info(f"Fetched {len(batches)} batches in {time_out - time_in} seconds")

     shard_row_totals: dict[str, int] = {}
     for shard, payload in zip(shards_for_batches, payloads_for_batches):
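The surviving body of `_fetch_batches` leans on a standard Ray idiom: collect many `ObjectRef`s, then resolve them with a single blocking `ray.get`, which waits once for all payloads instead of round-tripping per ref. A standalone sketch of that idiom (illustrative task, not Levanter code):

```python
import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def load_payload(i: int) -> list[int]:
    return [i] * 4

# One ray.get over the whole list blocks once for all results.
refs = [load_payload.remote(i) for i in range(3)]
payloads = ray.get(refs)
assert payloads == [[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]
```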
48 changes: 37 additions & 11 deletions src/levanter/utils/actor_pool.py
@@ -15,6 +15,11 @@
 # https://github.com/ray-project/ray/blob/1bab09bf842edee51c3778be4cfb16f8b900d764/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py


+def _wrap_ray_future(ray_future):
+    # work around https://github.com/ray-project/ray/issues/45895#issuecomment-2165164129
+    return asyncio.wrap_future(ray_future.future())
+
+
 class AutoScalingActorPool:
     """Utility class to operate on a dynamically scaling pool of actors."""

@@ -37,6 +42,7 @@ def __init__(
         self._actor_locations: Dict[ray.actor.ActorHandle, str] = {}
         self._tasks_waiting_for_actor: list[asyncio.Future] = []
         self._next_task_id = 0
+        self._scale_down_task: Optional[asyncio.Task] = None

         self._scale_up(self._min_size)
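Why the new helper: on Python 3.11, `await`-ing a Ray `ObjectRef` directly inside an event loop can break (see the Ray issue linked in the comment). `ObjectRef.future()` yields a `concurrent.futures.Future`, which `asyncio.wrap_future` adapts into an awaitable bound to the running loop. A minimal demonstration of the workaround in isolation:

```python
import asyncio
import ray

@ray.remote
def ping() -> str:
    return "pong"

async def main() -> None:
    ref = ping.remote()
    # Instead of `await ref`, route through concurrent.futures:
    result = await asyncio.wrap_future(ref.future())
    assert result == "pong"

ray.init(ignore_reinit_error=True)
asyncio.run(main())
```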

@@ -45,14 +51,17 @@ def num_pending_tasks(self):
         return len(self._tasks_waiting_for_actor)

     def _scale_up(self, num_actors: int):
+        if self._scale_down_task and not self._scale_down_task.done():
+            self._scale_down_task.cancel()
+
         for _ in range(num_actors):
             try:
                 actor = self._create_actor_fn()
                 ready_ref = actor.get_location.remote()
                 self._pending_actors[ready_ref] = actor

                 async def wait_for_ready(actor, ready_ref):
-                    loc = await ready_ref
+                    loc = await _wrap_ray_future(ready_ref)
                     # pending -> floating
                     if ready_ref not in self._pending_actors:
                         logger.info("Actor was cancelled before it was ready.")
@@ -67,8 +76,8 @@ async def wait_for_ready(actor, ready_ref):
             except Exception as e:
                 logger.error("Failed to create actor.", exc_info=e)

-    def _scale_down(self, num_actors: int):
-        for _ in range(num_actors):
+    def _scale_down(self, target_num_actors: int):
+        while len(self._idle_actors) + len(self._pending_actors) > target_num_actors:
             if self._pending_actors:
                 actor = self._pending_actors.popitem()[1]
                 # let it die through gc
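The signature change is part of the behavioral fix: `_scale_down` now takes a *target* pool size and shrinks until it gets there, rather than removing a fixed count, so repeated or racing calls cannot over-shrink the pool. A toy sketch of the difference, with plain lists standing in for the pool's actor maps:

```python
# Plain lists stand in for self._pending_actors / self._idle_actors.
pending, idle = ["p1"], ["a1", "a2", "a3"]

def scale_down(target_num_actors: int) -> None:
    # Shrink *to* a target, draining pending actors first, as the pool does.
    while len(idle) + len(pending) > target_num_actors:
        (pending or idle).pop()

scale_down(2)
scale_down(2)  # idempotent: already at target, nothing removed
assert len(idle) + len(pending) == 2
```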
@@ -102,10 +111,20 @@ def _adjust_pool_size(self):
                 f" {self._max_size}"
             )
             self._scale_up(min(self._max_size - num_busy_actors, num_pending_tasks))
+
+        # Schedule scale down if idle
         elif num_pending_tasks == 0 and num_nonworking_actors > self._min_size:
-            return  # never scale down. too many issues
-            logger.info(f"Scaling down due to no pending tasks. Current pool size: {total_actors}")
-            self._scale_down(num_nonworking_actors - self._min_size)
+            if self._scale_down_task is None or self._scale_down_task.done():
+                self._scale_down_task = asyncio.create_task(self._schedule_scale_down())
+
+    async def _schedule_scale_down(self):
+        try:
+            await asyncio.sleep(10)
+            if self.num_pending_tasks == 0:
+                logger.info("Scaling down due to no pending tasks.")
+                self._scale_down(self._min_size)
+        except asyncio.CancelledError:
+            logger.info("Scale down task was cancelled due to new activity.")

     def _get_object_location(self, obj_ref: ray.ObjectRef) -> Optional[str]:
         """Get the location of the given object reference."""
@@ -153,10 +172,11 @@ def _assign_task_to_actor(self, actor, fn, value):
         # floating -> busy
         ray_future = fn(actor, value)
         self._busy_actors[ray_future] = actor
+        if self._scale_down_task and not self._scale_down_task.done():
+            self._scale_down_task.cancel()
         self._adjust_pool_size()

-        # return ray_future
-        return asyncio.ensure_future(self._wrap_ray_future(ray_future))
+        return asyncio.ensure_future(self._set_up_actor_return_on_finished(ray_future))

     async def _enqueue_pending_task(self, fn, obj_ref, value, actor_future):
         actor = await actor_future
@@ -181,10 +201,11 @@ def _maybe_start_pending_task(self, actor):
                 assigned = False
         return assigned

-    async def _wrap_ray_future(self, ray_future):
-        await asyncio.wait([ray_future])
+    async def _set_up_actor_return_on_finished(self, ray_future):
+        future = _wrap_ray_future(ray_future)
+        await asyncio.wait([future])
         self._on_task_done(ray_future)
-        return await ray_future
+        return await future

     def _on_task_done(self, ray_future):
         actor = self._busy_actors.pop(ray_future)
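The renamed coroutine is a completion hook: it waits for the task, returns the actor to the pool via `_on_task_done` *before* the caller resumes, then passes the result or exception through. `asyncio.wait` observes completion without raising the future's exception, and re-awaiting a finished future simply yields its result or re-raises, so the bookkeeping runs even on failure. The shape of the pattern, reduced to its essentials (hypothetical names):

```python
import asyncio
from typing import Callable

async def finish_then_release(future: asyncio.Future, release: Callable[[], None]):
    # asyncio.wait never raises the future's exception itself,
    # so `release` runs even if the task failed.
    await asyncio.wait([future])
    release()             # e.g. return the actor to the idle pool
    return await future   # a done future yields its result, or re-raises
```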
@@ -218,6 +239,11 @@ def push(self, actor: "ray.actor.ActorHandle"):
         self._actor_locations[actor] = location
         self._maybe_start_pending_task(actor)

+    def __del__(self):
+        if self._scale_down_task and not self._scale_down_task.done():
+            self._scale_down_task.cancel()
+        # just let ray kill the actors naturally
+

 class PoolWorkerBase(ABC):
     def get_location(self) -> str:
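Putting the pieces together, here is a hypothetical end-to-end use of the pool. The constructor arguments and the `submit` entry point are inferred from the internals shown in this diff, not taken from it, so treat the exact names as assumptions:

```python
import asyncio
import ray

@ray.remote
class Tokenizer(PoolWorkerBase):  # workers inherit get_location from PoolWorkerBase
    def tokenize(self, text: str) -> list[str]:
        return text.split()

async def main() -> None:
    # `submit` and the keyword names are assumptions for illustration.
    pool = AutoScalingActorPool(lambda: Tokenizer.remote(), min_size=1, max_size=8)
    tokens = await pool.submit(lambda actor, doc: actor.tokenize.remote(doc), "hello world")
    print(tokens)  # ["hello", "world"]
```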