new dist_cp save planner to fix issue that each rank needs to download all checkpoint files (#3271)

* a

* a

* a

* a

* lint

* lint

---------

Co-authored-by: Mihir Patel <[email protected]>
bigning and mvpatel2000 authored May 13, 2024
1 parent 986a394 commit beb5a35
Showing 2 changed files with 80 additions and 1 deletion.
75 changes: 75 additions & 0 deletions composer/trainer/mosaic_fsdp_utils.py
@@ -717,3 +717,78 @@ def set_optimizer_state_dict(

    _verify_state_dict({}, optim_state_dict, info)
    _load_optim_state_dict(model, optimizers, optim_state_dict, info)


# torch2.3 patch to fix https://github.com/pytorch/pytorch/issues/125740
from torch.distributed.checkpoint.default_planner import (
    create_default_global_save_plan,
    DefaultSavePlanner,
    _validate_global_plan,
)
import dataclasses
from collections import defaultdict, ChainMap
from typing import Dict, List, Set, Tuple, TYPE_CHECKING

from torch.distributed.checkpoint.planner import SavePlan, WriteItem
from torch.distributed.checkpoint.metadata import MetadataIndex, Metadata

def dedup_save_plans(all_plans: List[SavePlan]) -> List[SavePlan]:  # noqa: D103
    write_item_to_plan_indices: Dict[MetadataIndex, Set[int]] = defaultdict(set)
    write_item_idx_to_write_item: Dict[MetadataIndex, WriteItem] = {}
    for plan_idx, plan in enumerate(all_plans):
        for write_item in plan.items:
            # map each write item to the plans that contain it
            write_item_to_plan_indices[write_item.index].add(plan_idx)
            write_item_idx_to_write_item[write_item.index] = write_item

    # assign each duplicated item to a single plan and remove it from the other plans
    to_remove: List[Set] = [set() for _ in range(len(all_plans))]
    plan_to_size = [0] * len(all_plans)
    for write_item_idx, plan_indices in write_item_to_plan_indices.items():
        # this line is the fix: always pick the lowest plan index so that
        # duplicated tensors stay on the same rank
        select_plan_idx = min(plan_indices, key=lambda plan_idx: plan_idx)

        write_item = write_item_idx_to_write_item[write_item_idx]
        # essentially ignores the storage size of anything that is not a tensor, since
        # we don't know how much storage it represents
        plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1

        plan_indices.remove(select_plan_idx)
        for plan_idx in plan_indices:
            to_remove[plan_idx].add(write_item_idx)

    for plan_idx, remove_set in enumerate(to_remove):
        new_items = [
            write_item
            for write_item in all_plans[plan_idx].items
            if write_item.index not in remove_set
        ]
        all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)

    return all_plans


class SavePlannerWithDedupFix(DefaultSavePlanner):  # noqa: D101
    def create_global_plan(
        self, all_plans: List[SavePlan],
    ) -> Tuple[List[SavePlan], Metadata]:
        all_plans = dedup_save_plans(all_plans)

        global_plan, metadata = create_default_global_save_plan(all_plans)

        if self.flatten_state_dict:
            # dict union via | is not available on Python 3.8 or older, so the
            # reduce-based merge below cannot be used:
            # merged_mappings = reduce(
            #     lambda x, y: x | y, (p.planner_data for p in global_plan)
            # )
            planner_data_dict = [p.planner_data for p in global_plan]
            merged_mappings = dict(ChainMap(*planner_data_dict))
            metadata = dataclasses.replace(metadata, planner_data=merged_mappings)

        if not _validate_global_plan(global_plan, metadata):
            raise ValueError('Failed to validate global plan')

        self.global_plan = global_plan
        self.metadata = metadata

        return self.global_plan, self.metadata
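
For illustration, here is a minimal sketch (not part of the commit) of how dedup_save_plans assigns a write item that shows up in several ranks' plans to the plan with the lowest index, so only one rank writes (and later needs to read) that shard. It assumes torch >= 2.3; the parameter names 'shared_param' and 'rank1_param' are made up for the example.

from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.planner import SavePlan, WriteItem, WriteItemType

# 'shared_param' appears in both ranks' plans; 'rank1_param' only in rank 1's plan.
shared = WriteItem(index=MetadataIndex('shared_param'), type=WriteItemType.BYTE_IO)
rank1_only = WriteItem(index=MetadataIndex('rank1_param'), type=WriteItemType.BYTE_IO)

plans = [SavePlan(items=[shared]), SavePlan(items=[shared, rank1_only])]
deduped = dedup_save_plans(plans)

# After dedup, the duplicated item is kept only by the plan with the lowest index.
assert [w.index.fqn for w in deduped[0].items] == ['shared_param']
assert [w.index.fqn for w in deduped[1].items] == ['rank1_param']
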
6 changes: 5 additions & 1 deletion composer/utils/checkpoint.py
@@ -1119,10 +1119,14 @@ def _save_checkpoint(

         if expect_file:
             if version.parse(torch.__version__) >= version.parse('2.3.0'):
+                save_planner = state.fsdp_config['save_planner']
+                if save_planner is None:
+                    from composer.trainer.mosaic_fsdp_utils import SavePlannerWithDedupFix
+                    save_planner = SavePlannerWithDedupFix()
                 dist_cp.save(
                     state_dict=state_dict,
                     storage_writer=dist_cp.FileSystemWriter(dirname),
-                    planner=state.fsdp_config['save_planner'],
+                    planner=save_planner,
                     process_group=process_group,
                 )
             else:
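With this change, saves on torch >= 2.3.0 default to SavePlannerWithDedupFix whenever no custom planner is configured. A rough sketch of how a custom planner could still be supplied through the trainer's FSDP config (the 'save_planner' key is taken from the diff above; the model, dataloader, and other Trainer arguments are illustrative placeholders, not part of this commit):

from composer import Trainer
from composer.trainer.mosaic_fsdp_utils import SavePlannerWithDedupFix

trainer = Trainer(
    model=my_composer_model,         # a ComposerModel built elsewhere (hypothetical)
    train_dataloader=my_dataloader,  # hypothetical
    max_duration='1ep',
    save_folder='checkpoints',
    fsdp_config={
        'state_dict_type': 'sharded',
        # any dist_cp SavePlanner works here; leave unset to get the new default
        'save_planner': SavePlannerWithDedupFix(),
    },
)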
