diff --git a/README.md b/README.md index 2a0e993a..4473022d 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ With Open-Sora, our goal is to foster innovation, creativity, and inclusivity wi 🔥 You can experience Open-Sora on our [🤗 Gradio application on Hugging Face](https://huggingface.co/spaces/hpcai-tech/open-sora). More samples and corresponding prompts are available in our [Gallery](https://hpcaitech.github.io/Open-Sora/). | **4s 720×1280** | **4s 720×1280** | **4s 720×1280** | -| ---------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- | +|------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------| | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/7895aab6-ed23-488c-8486-091480c26327) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/20f07c7b-182b-4562-bbee-f1df74c86c9a) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/3d897e0d-dc21-453a-b911-b3bda838acc2) | | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/644bf938-96ce-44aa-b797-b3c0b513d64c) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/272d88ac-4b4a-484d-a665-8d07431671d0) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/ebbac621-c34e-4bb4-9543-1c34f8989764) | | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/a1e3a1a3-4abd-45f5-8df2-6cced69da4ca) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/d6ce9c13-28e1-4dff-9644-cc01f5f11926) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/561978f8-f1b0-4f4d-ae7b-45bec9001b4a) | @@ -55,16 +55,16 @@ With Open-Sora, our goal is to foster innovation, creativity, and inclusivity wi OpenSora 1.1 Demo | **2s 240×426** | **2s 240×426** | -| ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | +|-------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------| | [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/c31ebc52-de39-4a4e-9b1e-9211d45e05b2) | [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/c31ebc52-de39-4a4e-9b1e-9211d45e05b2) | | [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/f7ce4aaa-528f-40a8-be7a-72e61eaacbbd) | [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/5d58d71e-1fda-4d90-9ad3-5f2f7b75c6a9) | | **2s 426×240** | **4s 480×854** | -| 
---------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | +|------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------| | [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/34ecb4a0-4eef-4286-ad4c-8e3a87e5a9fd) | [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/c1619333-25d7-42ba-a91c-18dbc1870b18) | | **16s 320×320** | **16s 224×448** | **2s 426×240** | -| ------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------- | +|--------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------| | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/3cab536e-9b43-4b33-8da8-a0f9cf842ff2) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/9fb0b9e0-c6f4-4935-b29e-4cac10b373c4) | [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/3e892ad2-9543-4049-b005-643a4c1bf3bf) | @@ -73,7 +73,7 @@ With Open-Sora, our goal is to foster innovation, creativity, and inclusivity wi OpenSora 1.0 Demo | **2s 512×512** | **2s 512×512** | **2s 512×512** | -| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------- | +|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------| | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/de1963d3-b43b-4e68-a670-bb821ebb6f80) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/13f8338f-3d42-4b71-8142-d234fbd746cc) | 
[](https://github.com/hpcaitech/Open-Sora/assets/99191637/fa6a65a6-e32a-4d64-9a9e-eabb0ebb8c16) | | A serene night scene in a forested area. [...] The video is a time-lapse, capturing the transition from day to night, with the lake and forest serving as a constant backdrop. | A soaring drone footage captures the majestic beauty of a coastal cliff, [...] The water gently laps at the rock base and the greenery that clings to the top of the cliff. | The majestic beauty of a waterfall cascading down a cliff into a serene lake. [...] The camera angle provides a bird's eye view of the waterfall. | | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/64232f84-1b36-4750-a6c0-3e610fa9aa94) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/983a1965-a374-41a7-a76b-c07941a6c1e9) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/ec10c879-9767-4c31-865f-2e8d6cf11e65) | @@ -190,6 +190,12 @@ pip install -r requirements/requirements-cu121.txt # the default installation is for inference only pip install -v . # for development mode, `pip install -v -e .` + +# install the latest tensornvme to use async checkpoint saving +pip install git+https://github.com/hpcaitech/TensorNVMe.git + +# install the latest colossalai to use the latest features +pip install git+https://github.com/hpcaitech/ColossalAI.git ``` (Optional, recommended for fast speed, especially for training) To enable `layernorm_kernel` and `flash_attn`, you need to install `apex` and `flash-attn` with the following commands. @@ -224,7 +230,7 @@ docker run -ti --gpus all -v .:/workspace/Open-Sora opensora ### Open-Sora 1.2 Model Weights | Model | Model Size | Data | #iterations | Batch Size | URL | -| --------- | ---------- | ---- | ----------- | ---------- | ------------------------------------------------------------- | +|-----------|------------|------|-------------|------------|---------------------------------------------------------------| | Diffusion | 1.1B | 30M | 70k | Dynamic | [:link:](https://huggingface.co/hpcai-tech/OpenSora-STDiT-v3) | | VAE | 384M | 3M | 1M | 8 | [:link:](https://huggingface.co/hpcai-tech/OpenSora-VAE-v1.2) | @@ -238,7 +244,7 @@ See our **[report 1.2](docs/report_03.md)** for more infomation. Weight will be View more | Resolution | Model Size | Data | #iterations | Batch Size | URL | -| ------------------ | ---------- | -------------------------- | ----------- | ------------------------------------------------- | -------------------------------------------------------------------- | +|--------------------|------------|----------------------------|-------------|---------------------------------------------------|----------------------------------------------------------------------| | mainly 144p & 240p | 700M | 10M videos + 2M images | 100k | [dynamic](/configs/opensora-v1-1/train/stage2.py) | [:link:](https://huggingface.co/hpcai-tech/OpenSora-STDiT-v2-stage2) | | 144p to 720p | 700M | 500K HQ videos + 1M images | 4k | [dynamic](/configs/opensora-v1-1/train/stage3.py) | [:link:](https://huggingface.co/hpcai-tech/OpenSora-STDiT-v2-stage3) | @@ -254,7 +260,7 @@ See our **[report 1.1](docs/report_02.md)** for more infomation. 
View more | Resolution | Model Size | Data | #iterations | Batch Size | GPU days (H800) | URL | -| ---------- | ---------- | ------ | ----------- | ---------- | --------------- | --------------------------------------------------------------------------------------------- | +|------------|------------|--------|-------------|------------|-----------------|-----------------------------------------------------------------------------------------------| | 16×512×512 | 700M | 20K HQ | 20k | 2×64 | 35 | [:link:](https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-HQ-16x512x512.pth) | | 16×256×256 | 700M | 20K HQ | 24k | 8×64 | 45 | [:link:](https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-HQ-16x256x256.pth) | | 16×256×256 | 700M | 366K | 80k | 8×64 | 117 | [:link:](https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-16x256x256.pth) | @@ -303,7 +309,7 @@ The easiest way to generate a video is to input a text prompt and click the "**G Then, you can choose the **resolution**, **duration**, and **aspect ratio** of the generated video. Different resolution and video length will affect the video generation speed. On a 80G H100 GPU, the generation speed (with `num_sampling_step=30`) and peak memory usage is: | | Image | 2s | 4s | 8s | 16s | -| ---- | ------- | -------- | --------- | --------- | --------- | +|------|---------|----------|-----------|-----------|-----------| | 360p | 3s, 24G | 18s, 27G | 31s, 27G | 62s, 28G | 121s, 33G | | 480p | 2s, 24G | 29s, 31G | 55s, 30G | 108s, 32G | 219s, 36G | | 720p | 6s, 27G | 68s, 41G | 130s, 39G | 260s, 45G | 547s, 67G | @@ -446,6 +452,9 @@ Also check out the [datasets](docs/datasets.md) we use. The training process is same as Open-Sora 1.1. ```bash +# If you use async checkpoint saving, and you want to validate the integrity of checkpoints, you can use the following command +# Then there will be a `async_file_io.log` in checkpoint directory. If the number of lines of the log file is not equal to the number of checkpoints (.safetensors files), there may be some errors. +export TENSORNVME_DEBUG=1 # one node torchrun --standalone --nproc_per_node 8 scripts/train.py \ configs/opensora-v1-2/train/stage1.py --data-path YOUR_CSV_PATH --ckpt-path YOUR_PRETRAINED_CKPT @@ -510,7 +519,7 @@ We support evaluation based on: All the evaluation code is released in `eval` folder. Check the [README](/eval/README.md) for more details. Our [report](/docs/report_03.md#evaluation) also provides more information about the evaluation during training. The following table shows Open-Sora 1.2 greatly improves Open-Sora 1.0. 
| Model | Total Score | Quality Score | Semantic Score | -| -------------- | ----------- | ------------- | -------------- | +|----------------|-------------|---------------|----------------| | Open-Sora V1.0 | 75.91% | 78.81% | 64.28% | | Open-Sora V1.2 | 79.23% | 80.71% | 73.30% | diff --git a/configs/opensora-v1-2/train/stage1.py b/configs/opensora-v1-2/train/stage1.py index e1f1e7c0..cbaa20c8 100644 --- a/configs/opensora-v1-2/train/stage1.py +++ b/configs/opensora-v1-2/train/stage1.py @@ -108,3 +108,6 @@ ema_decay = 0.99 adam_eps = 1e-15 warmup_steps = 1000 + +cache_pin_memory = True +pin_memory_cache_pre_alloc_numels = [(290 + 20) * 1024**2] * (2 * 8 + 4) diff --git a/opensora/datasets/dataloader.py b/opensora/datasets/dataloader.py index 60d0b24a..97df0a16 100644 --- a/opensora/datasets/dataloader.py +++ b/opensora/datasets/dataloader.py @@ -1,17 +1,216 @@ import collections +import functools +import queue import random +import threading from typing import Optional import numpy as np import torch +import torch.multiprocessing as multiprocessing +from torch._utils import ExceptionWrapper from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import _get_default_group -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, _utils +from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL +from torch.utils.data.dataloader import ( + IterDataPipe, + MapDataPipe, + _BaseDataLoaderIter, + _MultiProcessingDataLoaderIter, + _sharding_worker_init_fn, + _SingleProcessDataLoaderIter, +) from .datasets import BatchFeatureDataset, VariableVideoTextDataset, VideoTextDataset +from .pin_memory_cache import PinMemoryCache from .sampler import BatchDistributedSampler, StatefulDistributedSampler, VariableVideoBatchSampler +def _pin_memory_loop( + in_queue, out_queue, device_id, done_event, device, pin_memory_cache: PinMemoryCache, pin_memory_key: str +): + # This setting is thread local, and prevents the copy in pin_memory from + # consuming all CPU cores. + torch.set_num_threads(1) + + if device == "cuda": + torch.cuda.set_device(device_id) + elif device == "xpu": + torch.xpu.set_device(device_id) # type: ignore[attr-defined] + elif device == torch._C._get_privateuse1_backend_name(): + custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name()) + custom_device_mod.set_device(device_id) + + def do_one_step(): + try: + r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + return + idx, data = r + if not done_event.is_set() and not isinstance(data, ExceptionWrapper): + try: + assert isinstance(data, dict) + if pin_memory_key in data: + val = data[pin_memory_key] + pin_memory_value = pin_memory_cache.get(val) + pin_memory_value.copy_(val) + data[pin_memory_key] = pin_memory_value + except Exception: + data = ExceptionWrapper(where=f"in pin memory thread for device {device_id}") + r = (idx, data) + while not done_event.is_set(): + try: + out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL) + break + except queue.Full: + continue + + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. 
+ while not done_event.is_set(): + # Make sure that we don't preserve any object from one iteration + # to the next + do_one_step() + + +class _MultiProcessingDataLoaderIterForVideo(_MultiProcessingDataLoaderIter): + pin_memory_key: str = "video" + + def __init__(self, loader): + _BaseDataLoaderIter.__init__(self, loader) + self.pin_memory_cache = PinMemoryCache() + + self._prefetch_factor = loader.prefetch_factor + + assert self._num_workers > 0 + assert self._prefetch_factor > 0 + + if loader.multiprocessing_context is None: + multiprocessing_context = multiprocessing + else: + multiprocessing_context = loader.multiprocessing_context + + self._worker_init_fn = loader.worker_init_fn + + # Adds forward compatibilities so classic DataLoader can work with DataPipes: + # Additional worker init function will take care of sharding in MP and Distributed + if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): + self._worker_init_fn = functools.partial( + _sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank + ) + + # No certainty which module multiprocessing_context is + self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] + self._worker_pids_set = False + self._shutdown = False + self._workers_done_event = multiprocessing_context.Event() + + self._index_queues = [] + self._workers = [] + for i in range(self._num_workers): + # No certainty which module multiprocessing_context is + index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] + # Need to `cancel_join_thread` here! + # See sections (2) and (3b) above. + index_queue.cancel_join_thread() + w = multiprocessing_context.Process( + target=_utils.worker._worker_loop, + args=( + self._dataset_kind, + self._dataset, + index_queue, + self._worker_result_queue, + self._workers_done_event, + self._auto_collation, + self._collate_fn, + self._drop_last, + self._base_seed, + self._worker_init_fn, + i, + self._num_workers, + self._persistent_workers, + self._shared_seed, + ), + ) + w.daemon = True + # NB: Process.start() actually take some time as it needs to + # start a process and pass the arguments over via a pipe. + # Therefore, we only add a worker to self._workers list after + # it started, so that we do not call .join() if program dies + # before it starts, and __del__ tries to join but will get: + # AssertionError: can only join a started process. + w.start() + self._index_queues.append(index_queue) + self._workers.append(w) + + if self._pin_memory: + self._pin_memory_thread_done_event = threading.Event() + + # Queue is not type-annotated + self._data_queue = queue.Queue() # type: ignore[var-annotated] + if self._pin_memory_device == "xpu": + current_device = torch.xpu.current_device() # type: ignore[attr-defined] + elif self._pin_memory_device == torch._C._get_privateuse1_backend_name(): + custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name()) + current_device = custom_device_mod.current_device() + else: + current_device = torch.cuda.current_device() # choose cuda for default + pin_memory_thread = threading.Thread( + target=_pin_memory_loop, + args=( + self._worker_result_queue, + self._data_queue, + current_device, + self._pin_memory_thread_done_event, + self._pin_memory_device, + self.pin_memory_cache, + self.pin_memory_key, + ), + ) + pin_memory_thread.daemon = True + pin_memory_thread.start() + # Similar to workers (see comment above), we only register + # pin_memory_thread once it is started. 
+ self._pin_memory_thread = pin_memory_thread + else: + self._data_queue = self._worker_result_queue # type: ignore[assignment] + + # In some rare cases, persistent workers (daemonic processes) + # would be terminated before `__del__` of iterator is invoked + # when main process exits + # It would cause failure when pin_memory_thread tries to read + # corrupted data from worker_result_queue + # atexit is used to shutdown thread and child processes in the + # right sequence before main process exits + if self._persistent_workers and self._pin_memory: + import atexit + + for w in self._workers: + atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w) + + # .pid can be None only before process is spawned (not the case, so ignore) + _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc] + _utils.signal_handling._set_SIGCHLD_handler() + self._worker_pids_set = True + self._reset(loader, first_iter=True) + + def remove_cache(self, output_tensor: torch.Tensor): + self.pin_memory_cache.remove(output_tensor) + + def get_cache_info(self) -> str: + return str(self.pin_memory_cache) + + +class DataloaderForVideo(DataLoader): + def _get_iterator(self) -> "_BaseDataLoaderIter": + if self.num_workers == 0: + return _SingleProcessDataLoaderIter(self) + else: + self.check_worker_number_rationality() + return _MultiProcessingDataLoaderIterForVideo(self) + + # Deterministic dataloader def get_seed_worker(seed): def seed_worker(worker_id): @@ -35,6 +234,7 @@ def prepare_dataloader( bucket_config=None, num_bucket_build_workers=1, prefetch_factor=None, + cache_pin_memory=False, **kwargs, ): _kwargs = kwargs.copy() @@ -50,8 +250,9 @@ def prepare_dataloader( verbose=True, num_bucket_build_workers=num_bucket_build_workers, ) + dl_cls = DataloaderForVideo if cache_pin_memory else DataLoader return ( - DataLoader( + dl_cls( dataset, batch_sampler=batch_sampler, worker_init_fn=get_seed_worker(seed), @@ -71,8 +272,9 @@ def prepare_dataloader( rank=process_group.rank(), shuffle=shuffle, ) + dl_cls = DataloaderForVideo if cache_pin_memory else DataLoader return ( - DataLoader( + dl_cls( dataset, batch_size=batch_size, sampler=sampler, @@ -137,7 +339,7 @@ def collate_fn_batch(batch): """ # filter out None batch = [x for x in batch if x is not None] - + res = torch.utils.data.default_collate(batch) # squeeze the first dimension, which is due to torch.stack() in default_collate() diff --git a/opensora/datasets/pin_memory_cache.py b/opensora/datasets/pin_memory_cache.py new file mode 100644 index 00000000..3f6e7559 --- /dev/null +++ b/opensora/datasets/pin_memory_cache.py @@ -0,0 +1,76 @@ +import threading +from typing import Dict, List, Optional + +import torch + + +class PinMemoryCache: + force_dtype: Optional[torch.dtype] = None + min_cache_numel: int = 0 + pre_alloc_numels: List[int] = [] + + def __init__(self): + self.cache: Dict[int, torch.Tensor] = {} + self.output_to_cache: Dict[int, int] = {} + self.cache_to_output: Dict[int, int] = {} + self.lock = threading.Lock() + self.total_cnt = 0 + self.hit_cnt = 0 + + if len(self.pre_alloc_numels) > 0 and self.force_dtype is not None: + for n in self.pre_alloc_numels: + cache_tensor = torch.empty(n, dtype=self.force_dtype, device="cpu", pin_memory=True) + with self.lock: + self.cache[id(cache_tensor)] = cache_tensor + + def get(self, tensor: torch.Tensor) -> torch.Tensor: + """Receive a cpu tensor and return the corresponding pinned tensor. Note that this only manage memory allocation, doesn't copy content. 
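+        The caller is expected to copy data into the returned tensor (e.g. via
+        ``copy_``) and to call ``remove`` on it once it is no longer needed so
+        that the underlying pinned buffer can be reused.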
+ + Args: + tensor (torch.Tensor): The tensor to be pinned. + + Returns: + torch.Tensor: The pinned tensor. + """ + self.total_cnt += 1 + with self.lock: + # find free cache + for cache_id, cache_tensor in self.cache.items(): + if cache_id not in self.cache_to_output and cache_tensor.numel() >= tensor.numel(): + target_cache_tensor = cache_tensor[: tensor.numel()].view(tensor.shape) + out_id = id(target_cache_tensor) + self.output_to_cache[out_id] = cache_id + self.cache_to_output[cache_id] = out_id + self.hit_cnt += 1 + return target_cache_tensor + # no free cache, create a new one + dtype = self.force_dtype if self.force_dtype is not None else tensor.dtype + cache_numel = max(tensor.numel(), self.min_cache_numel) + cache_tensor = torch.empty(cache_numel, dtype=dtype, device="cpu", pin_memory=True) + target_cache_tensor = cache_tensor[: tensor.numel()].view(tensor.shape) + out_id = id(target_cache_tensor) + with self.lock: + self.cache[id(cache_tensor)] = cache_tensor + self.output_to_cache[out_id] = id(cache_tensor) + self.cache_to_output[id(cache_tensor)] = out_id + return target_cache_tensor + + def remove(self, output_tensor: torch.Tensor) -> None: + """Release corresponding cache tensor. + + Args: + output_tensor (torch.Tensor): The tensor to be released. + """ + out_id = id(output_tensor) + with self.lock: + if out_id not in self.output_to_cache: + raise ValueError("Tensor not found in cache.") + cache_id = self.output_to_cache.pop(out_id) + del self.cache_to_output[cache_id] + + def __str__(self): + with self.lock: + num_cached = len(self.cache) + num_used = len(self.output_to_cache) + total_cache_size = sum([v.numel() * v.element_size() for v in self.cache.values()]) + return f"PinMemoryCache(num_cached={num_cached}, num_used={num_used}, total_cache_size={total_cache_size / 1024**3:.2f} GB, hit rate={self.hit_cnt / self.total_cnt:.2f})" diff --git a/opensora/models/text_encoder/t5.py b/opensora/models/text_encoder/t5.py index d67328a8..6378c347 100644 --- a/opensora/models/text_encoder/t5.py +++ b/opensora/models/text_encoder/t5.py @@ -170,14 +170,8 @@ def shardformer_t5(self): from opensora.utils.misc import requires_grad shard_config = ShardConfig( - tensor_parallel_process_group=None, - pipeline_stage_manager=None, enable_tensor_parallelism=False, - enable_fused_normalization=False, - enable_flash_attention=False, enable_jit_fused=True, - enable_sequence_parallelism=False, - enable_sequence_overlap=False, ) shard_former = ShardFormer(shard_config=shard_config) optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy()) diff --git a/opensora/utils/ckpt_utils.py b/opensora/utils/ckpt_utils.py index d730981c..d0607ff3 100644 --- a/opensora/utils/ckpt_utils.py +++ b/opensora/utils/ckpt_utils.py @@ -2,13 +2,16 @@ import json import operator import os -from typing import Tuple +from typing import Dict, Optional, Tuple import torch import torch.distributed as dist import torch.nn as nn from colossalai.booster import Booster from colossalai.checkpoint_io import GeneralCheckpointIO +from colossalai.utils.safetensors import save as async_save +from safetensors.torch import load_file +from tensornvme.async_file_io import AsyncFileWriter from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from torchvision.datasets.utils import download_url @@ -150,10 +153,19 @@ def load_from_sharded_state_dict(model, ckpt_path, model_name="model", strict=Fa ckpt_io.load_model(model, os.path.join(ckpt_path, model_name), strict=strict) -def 
model_sharding(model: torch.nn.Module): +def model_sharding(model: torch.nn.Module, device: torch.device = None): + """ + Sharding the model parameters across multiple GPUs. + + Args: + model (torch.nn.Module): The model to shard. + device (torch.device): The device to shard the model to. + """ global_rank = dist.get_rank() world_size = dist.get_world_size() for _, param in model.named_parameters(): + if device is None: + device = param.device padding_size = (world_size - param.numel() % world_size) % world_size if padding_size > 0: padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) @@ -161,18 +173,34 @@ def model_sharding(model: torch.nn.Module): padding_param = param.data.view(-1) splited_params = padding_param.split(padding_param.numel() // world_size) splited_params = splited_params[global_rank] - param.data = splited_params + param.data = splited_params.to(device) -def model_gathering(model: torch.nn.Module, model_shape_dict: dict): +def model_gathering(model: torch.nn.Module, model_shape_dict: dict, pinned_state_dict: dict) -> None: + """ + Gather the model parameters from multiple GPUs. + + Args: + model (torch.nn.Module): The model to gather. + model_shape_dict (dict): The shape of the model parameters. + device (torch.device): The device to gather the model to. + """ global_rank = dist.get_rank() global_size = dist.get_world_size() + params = set() for name, param in model.named_parameters(): + params.add(name) all_params = [torch.empty_like(param.data) for _ in range(global_size)] dist.all_gather(all_params, param.data, group=dist.group.WORLD) if int(global_rank) == 0: all_params = torch.cat(all_params) - param.data = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name]) + gathered_param = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name]) + pinned_state_dict[name].copy_(gathered_param) + if int(global_rank) == 0: + for k, v in model.state_dict(keep_vars=True).items(): + if k not in params: + pinned_state_dict[k].copy_(v) + dist.barrier() @@ -195,6 +223,7 @@ def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model", stri get_logger().info("Unexpected keys: %s", unexpected_keys) elif ckpt_path.endswith(".safetensors"): from safetensors.torch import load_file + state_dict = load_file(ckpt_path) missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) print(f"Missing keys: {missing_keys}") @@ -223,76 +252,130 @@ def save_json(data, file_path: str): # save and load for training -def save( - booster: Booster, - save_dir: str, - model: nn.Module = None, - ema: nn.Module = None, - optimizer: Optimizer = None, - lr_scheduler: _LRScheduler = None, - sampler=None, - epoch: int = None, - step: int = None, - global_step: int = None, - batch_size: int = None, -): - save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{global_step}") - os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) - - if model is not None: - booster.save_model(model, os.path.join(save_dir, "model"), shard=True) - if optimizer is not None: - booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096) - if lr_scheduler is not None: - booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) - if dist.get_rank() == 0: - running_states = { - "epoch": epoch, - "step": step, - "global_step": global_step, - "batch_size": batch_size, - } - save_json(running_states, os.path.join(save_dir, "running_states.json")) - +def 
_prepare_ema_pinned_state_dict(model: nn.Module, ema_shape_dict: dict): + ema_pinned_state_dict = dict() + for name, p in model.named_parameters(): + ema_pinned_state_dict[name] = torch.empty(ema_shape_dict[name], pin_memory=True, device="cpu", dtype=p.dtype) + sd = model.state_dict(keep_vars=True) + # handle buffers + for k, v in sd.items(): + if k not in ema_pinned_state_dict: + ema_pinned_state_dict[k] = torch.empty(v.shape, pin_memory=True, device="cpu", dtype=v.dtype) + + return ema_pinned_state_dict + + +class CheckpointIO: + def __init__(self, n_write_entries: int = 32): + self.n_write_entries = n_write_entries + self.writer: Optional[AsyncFileWriter] = None + self.pinned_state_dict: Optional[Dict[str, torch.Tensor]] = None + + def _sync_io(self): + if self.writer is not None: + self.writer = None + + def __del__(self): + self._sync_io() + + def _prepare_pinned_state_dict(self, ema: nn.Module, ema_shape_dict: dict): + if self.pinned_state_dict is None and dist.get_rank() == 0: + self.pinned_state_dict = _prepare_ema_pinned_state_dict(ema, ema_shape_dict) + + def save( + self, + booster: Booster, + save_dir: str, + model: nn.Module = None, + ema: nn.Module = None, + optimizer: Optimizer = None, + lr_scheduler: _LRScheduler = None, + sampler=None, + epoch: int = None, + step: int = None, + global_step: int = None, + batch_size: int = None, + ema_shape_dict: dict = None, + async_io: bool = True, + ): + self._sync_io() + save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{global_step}") + os.environ["TENSORNVME_DEBUG_LOG"] = os.path.join(save_dir, "async_file_io.log") + os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) + + if model is not None: + booster.save_model( + model, + os.path.join(save_dir, "model"), + shard=True, + use_safetensors=True, + size_per_shard=4096, + use_async=async_io, + ) + if optimizer is not None: + booster.save_optimizer( + optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096, use_async=async_io + ) + if lr_scheduler is not None: + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) if ema is not None: - torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt")) - + self._prepare_pinned_state_dict(ema, ema_shape_dict) + model_gathering(ema, ema_shape_dict, self.pinned_state_dict) + if dist.get_rank() == 0: + running_states = { + "epoch": epoch, + "step": step, + "global_step": global_step, + "batch_size": batch_size, + } + save_json(running_states, os.path.join(save_dir, "running_states.json")) + + if ema is not None: + if async_io: + self.writer = async_save(os.path.join(save_dir, "ema.safetensors"), self.pinned_state_dict) + else: + torch.save(self.pinned_state_dict, os.path.join(save_dir, "ema.pt")) + + if sampler is not None: + # only for VariableVideoBatchSampler + torch.save(sampler.state_dict(step), os.path.join(save_dir, "sampler")) + dist.barrier() + return save_dir + + def load( + self, + booster: Booster, + load_dir: str, + model: nn.Module = None, + ema: nn.Module = None, + optimizer: Optimizer = None, + lr_scheduler: _LRScheduler = None, + sampler=None, + ) -> Tuple[int, int, int]: + assert os.path.exists(load_dir), f"Checkpoint directory {load_dir} does not exist" + assert os.path.exists(os.path.join(load_dir, "running_states.json")), "running_states.json does not exist" + running_states = load_json(os.path.join(load_dir, "running_states.json")) + if model is not None: + booster.load_model(model, os.path.join(load_dir, "model")) + if ema is not None: + # ema is not 
boosted, so we don't use booster.load_model + if os.path.exists(os.path.join(load_dir, "ema.safetensors")): + ema_state_dict = load_file(os.path.join(load_dir, "ema.safetensors")) + else: + ema_state_dict = torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")) + ema.load_state_dict( + ema_state_dict, + strict=False, + ) + if optimizer is not None: + booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) + if lr_scheduler is not None: + booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) if sampler is not None: - # only for VariableVideoBatchSampler - torch.save(sampler.state_dict(step), os.path.join(save_dir, "sampler")) - dist.barrier() - return save_dir - - -def load( - booster: Booster, - load_dir: str, - model: nn.Module = None, - ema: nn.Module = None, - optimizer: Optimizer = None, - lr_scheduler: _LRScheduler = None, - sampler=None, -) -> Tuple[int, int, int]: - assert os.path.exists(load_dir), f"Checkpoint directory {load_dir} does not exist" - assert os.path.exists(os.path.join(load_dir, "running_states.json")), "running_states.json does not exist" - running_states = load_json(os.path.join(load_dir, "running_states.json")) - if model is not None: - booster.load_model(model, os.path.join(load_dir, "model")) - if ema is not None: - # ema is not boosted, so we don't use booster.load_model - ema.load_state_dict( - torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")), - strict=False, - ) - if optimizer is not None: - booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) - if lr_scheduler is not None: - booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) - if sampler is not None: - sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler"))) - dist.barrier() + sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler"))) + dist.barrier() - return ( - running_states["epoch"], - running_states["step"], - ) + return ( + running_states["epoch"], + running_states["step"], + ) diff --git a/opensora/utils/train_utils.py b/opensora/utils/train_utils.py index 95d0011b..f015b44c 100644 --- a/opensora/utils/train_utils.py +++ b/opensora/utils/train_utils.py @@ -61,7 +61,7 @@ def update_ema( else: if param.data.dtype != torch.float32: param_id = id(param) - master_param = optimizer._param_store.working_to_master_param[param_id] + master_param = optimizer.get_working_to_master_map()[param_id] param_data = master_param.data else: param_data = param.data diff --git a/scripts/train.py b/scripts/train.py index 110f2f84..e41857f6 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,4 +1,5 @@ import os +import subprocess from contextlib import nullcontext from copy import deepcopy from datetime import timedelta @@ -16,8 +17,9 @@ from opensora.acceleration.checkpoint import set_grad_checkpoint from opensora.acceleration.parallel_states import get_data_parallel_group from opensora.datasets.dataloader import prepare_dataloader +from opensora.datasets.pin_memory_cache import PinMemoryCache from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module -from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save +from opensora.utils.ckpt_utils import CheckpointIO, model_sharding, record_model_param_shape from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config from opensora.utils.lr_scheduler import LinearWarmupLR from opensora.utils.misc 
import ( @@ -46,12 +48,16 @@ def main(): cfg_dtype = cfg.get("dtype", "bf16") assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}" dtype = to_torch_dtype(cfg.get("dtype", "bf16")) + checkpoint_io = CheckpointIO() # == colossalai init distributed training == # NOTE: A very large timeout is set to avoid some processes exit early dist.init_process_group(backend="nccl", timeout=timedelta(hours=24)) torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) set_seed(cfg.get("seed", 1024)) + PinMemoryCache.force_dtype = dtype + pin_memory_cache_pre_alloc_numels = cfg.get("pin_memory_cache_pre_alloc_numels", []) + PinMemoryCache.pre_alloc_numels = pin_memory_cache_pre_alloc_numels coordinator = DistCoordinator() device = get_current_device() @@ -92,6 +98,7 @@ def main(): logger.info("Dataset contains %s samples.", len(dataset)) # == build dataloader == + cache_pin_memory = cfg.get("cache_pin_memory", False) dataloader_args = dict( dataset=dataset, batch_size=cfg.get("batch_size", None), @@ -102,6 +109,7 @@ def main(): pin_memory=True, process_group=get_data_parallel_group(), prefetch_factor=cfg.get("prefetch_factor", None), + cache_pin_memory=cache_pin_memory, ) dataloader, sampler = prepare_dataloader( bucket_config=cfg.get("bucket_config", None), @@ -157,11 +165,10 @@ def main(): ) # == build ema for diffusion model == - ema = deepcopy(model).to(torch.float32).to(device) + ema = deepcopy(model).cpu().to(torch.float32) requires_grad(ema, False) ema_shape_dict = record_model_param_shape(ema) ema.eval() - update_ema(ema, model, decay=0, sharded=False) # == setup loss function, build scheduler == scheduler = build_module(cfg.scheduler, SCHEDULERS) @@ -213,7 +220,7 @@ def main(): # == resume == if cfg.get("load", None) is not None: logger.info("Loading checkpoint") - ret = load( + ret = checkpoint_io.load( booster, cfg.load, model=model, @@ -226,7 +233,7 @@ def main(): start_epoch, start_step = ret logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step) - model_sharding(ema) + model_sharding(ema, device=device) # ======================================================= # 5. 
training loop @@ -241,6 +248,8 @@ def main(): "backward", "update_ema", "reduce_loss", + "optim", + "ckpt", ] for key in timer_keys: if record_time: @@ -262,9 +271,12 @@ def main(): total=num_steps_per_epoch, ) as pbar: for step, batch in pbar: + # if cache_pin_memory: + # print(f"==debug== rank{dist.get_rank()} {dataloader_iter.get_cache_info()}") timer_list = [] with timers["move_data"] as move_data_t: - x = batch.pop("video").to(device, dtype) # [B, C, T, H, W] + pinned_video = batch.pop("video") + x = pinned_video.to(device, dtype, non_blocking=True) # [B, C, T, H, W] y = batch.pop("text") if record_time: timer_list.append(move_data_t) @@ -303,6 +315,9 @@ def main(): if isinstance(v, torch.Tensor): model_args[k] = v.to(device, dtype) + if cache_pin_memory: + dataloader_iter.remove_cache(pinned_video) + # == diffusion loss computation == with timers["diffusion"] as loss_t: loss_dict = scheduler.training_losses(model, x, model_args, mask=mask) @@ -313,6 +328,10 @@ def main(): with timers["backward"] as backward_t: loss = loss_dict["loss"].mean() booster.backward(loss=loss, optimizer=optimizer) + if record_time: + timer_list.append(backward_t) + + with timers["optim"] as optim_t: optimizer.step() optimizer.zero_grad() @@ -320,7 +339,7 @@ def main(): if lr_scheduler is not None: lr_scheduler.step() if record_time: - timer_list.append(backward_t) + timer_list.append(optim_t) # == update EMA == with timers["update_ema"] as ema_t: @@ -372,32 +391,42 @@ def main(): running_loss = 0.0 log_step = 0 + # == uncomment to clear ram cache == + # if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0 and coordinator.is_master(): + # subprocess.run("sync && sudo sh -c \"echo 3 > /proc/sys/vm/drop_caches\"", shell=True) + # == checkpoint saving == ckpt_every = cfg.get("ckpt_every", 0) - if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0: - model_gathering(ema, ema_shape_dict) - save_dir = save( - booster, - exp_dir, - model=model, - ema=ema, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - sampler=sampler, - epoch=epoch, - step=step + 1, - global_step=global_step + 1, - batch_size=cfg.get("batch_size", None), - ) - if dist.get_rank() == 0: - model_sharding(ema) - logger.info( - "Saved checkpoint at epoch %s, step %s, global_step %s to %s", - epoch, - step + 1, - global_step + 1, - save_dir, - ) + with timers["ckpt"] as ckpt_t: + if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0: + save_dir = checkpoint_io.save( + booster, + exp_dir, + model=model, + ema=ema, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + sampler=sampler, + epoch=epoch, + step=step + 1, + global_step=global_step + 1, + batch_size=cfg.get("batch_size", None), + ema_shape_dict=ema_shape_dict, + async_io=True, + ) + logger.info( + "Saved checkpoint at epoch %s, step %s, global_step %s to %s", + epoch, + step + 1, + global_step + 1, + save_dir, + ) + if record_time: + timer_list.append(ckpt_t) + # uncomment below 3 lines to benchmark checkpoint + # if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0: + # booster.checkpoint_io._sync_io() + # checkpoint_io._sync_io() if record_time: log_str = f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | " for timer in timer_list:
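For reference, below is a minimal sketch of the allocation/release cycle introduced by `PinMemoryCache` and `DataloaderForVideo` in this patch. The dtype, tensor shape, and pre-allocation sizes are illustrative only (stage1.py pre-allocates 20 buffers of `(290 + 20) * 1024**2` elements each), and a CUDA device is assumed for the pinned host-to-device copy.

```python
import torch

from opensora.datasets.pin_memory_cache import PinMemoryCache

# Class-level knobs must be set before the cache (or a DataloaderForVideo
# iterator, which creates one internally) is instantiated. Small sizes are
# used here so the sketch stays lightweight.
PinMemoryCache.force_dtype = torch.bfloat16
PinMemoryCache.pre_alloc_numels = [8 * 1024**2] * 2

cache = PinMemoryCache()

# A CPU batch as it would come out of a dataloader worker.
video = torch.randn(1, 3, 8, 256, 256, dtype=torch.bfloat16)

pinned = cache.get(video)   # borrow a pinned buffer; no data is copied yet
pinned.copy_(video)         # what _pin_memory_loop does for the "video" entry
x = pinned.to("cuda", non_blocking=True)  # overlapped host-to-device transfer

torch.cuda.synchronize()    # ensure the async copy has finished before recycling
cache.remove(pinned)        # release the slot, as remove_cache does in train.py
print(cache)                # number of buffers, total pinned size, hit rate
```

In training, the same cycle is driven by the pin-memory thread (`get` followed by `copy_`) and by `dataloader_iter.remove_cache(pinned_video)` in the train loop once the batch has been moved to the GPU.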