diff --git a/README.md b/README.md
index 2a0e993a..4473022d 100644
--- a/README.md
+++ b/README.md
@@ -46,7 +46,7 @@ With Open-Sora, our goal is to foster innovation, creativity, and inclusivity wi
🔥 You can experience Open-Sora on our [🤗 Gradio application on Hugging Face](https://huggingface.co/spaces/hpcai-tech/open-sora). More samples and corresponding prompts are available in our [Gallery](https://hpcaitech.github.io/Open-Sora/).
| **4s 720×1280** | **4s 720×1280** | **4s 720×1280** |
-| ---------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- |
+|------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------|
| [](https://github.com/hpcaitech/Open-Sora/assets/99191637/7895aab6-ed23-488c-8486-091480c26327) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/20f07c7b-182b-4562-bbee-f1df74c86c9a) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/3d897e0d-dc21-453a-b911-b3bda838acc2) |
| [](https://github.com/hpcaitech/Open-Sora/assets/99191637/644bf938-96ce-44aa-b797-b3c0b513d64c) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/272d88ac-4b4a-484d-a665-8d07431671d0) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/ebbac621-c34e-4bb4-9543-1c34f8989764) |
| [](https://github.com/hpcaitech/Open-Sora/assets/99191637/a1e3a1a3-4abd-45f5-8df2-6cced69da4ca) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/d6ce9c13-28e1-4dff-9644-cc01f5f11926) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/561978f8-f1b0-4f4d-ae7b-45bec9001b4a) |
@@ -55,16 +55,16 @@ With Open-Sora, our goal is to foster innovation, creativity, and inclusivity wi
OpenSora 1.1 Demo
| **2s 240×426** | **2s 240×426** |
-| ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
+|-------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/c31ebc52-de39-4a4e-9b1e-9211d45e05b2) | [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/c31ebc52-de39-4a4e-9b1e-9211d45e05b2) |
| [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/f7ce4aaa-528f-40a8-be7a-72e61eaacbbd) | [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/5d58d71e-1fda-4d90-9ad3-5f2f7b75c6a9) |
| **2s 426×240** | **4s 480×854** |
-| ---------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
+|------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/34ecb4a0-4eef-4286-ad4c-8e3a87e5a9fd) | [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/c1619333-25d7-42ba-a91c-18dbc1870b18) |
| **16s 320×320** | **16s 224×448** | **2s 426×240** |
-| ------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------- |
+|--------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------|
| [](https://github.com/hpcaitech/Open-Sora/assets/99191637/3cab536e-9b43-4b33-8da8-a0f9cf842ff2) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/9fb0b9e0-c6f4-4935-b29e-4cac10b373c4) | [](https://github.com/hpcaitech/Open-Sora-dev/assets/99191637/3e892ad2-9543-4049-b005-643a4c1bf3bf) |
@@ -73,7 +73,7 @@ With Open-Sora, our goal is to foster innovation, creativity, and inclusivity wi
OpenSora 1.0 Demo
| **2s 512×512** | **2s 512×512** | **2s 512×512** |
-| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------- |
+|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------|
| [](https://github.com/hpcaitech/Open-Sora/assets/99191637/de1963d3-b43b-4e68-a670-bb821ebb6f80) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/13f8338f-3d42-4b71-8142-d234fbd746cc) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/fa6a65a6-e32a-4d64-9a9e-eabb0ebb8c16) |
| A serene night scene in a forested area. [...] The video is a time-lapse, capturing the transition from day to night, with the lake and forest serving as a constant backdrop. | A soaring drone footage captures the majestic beauty of a coastal cliff, [...] The water gently laps at the rock base and the greenery that clings to the top of the cliff. | The majestic beauty of a waterfall cascading down a cliff into a serene lake. [...] The camera angle provides a bird's eye view of the waterfall. |
| [](https://github.com/hpcaitech/Open-Sora/assets/99191637/64232f84-1b36-4750-a6c0-3e610fa9aa94) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/983a1965-a374-41a7-a76b-c07941a6c1e9) | [](https://github.com/hpcaitech/Open-Sora/assets/99191637/ec10c879-9767-4c31-865f-2e8d6cf11e65) |
@@ -190,6 +190,12 @@ pip install -r requirements/requirements-cu121.txt
# the default installation is for inference only
pip install -v . # for development mode, `pip install -v -e .`
+
+# install the latest tensornvme to use async checkpoint saving
+pip install git+https://github.com/hpcaitech/TensorNVMe.git
+
+# install the latest colossalai to use the latest features
+pip install git+https://github.com/hpcaitech/ColossalAI.git
```
(Optional, but recommended for faster speed, especially for training) To enable `layernorm_kernel` and `flash_attn`, you need to install `apex` and `flash-attn` with the following commands.
@@ -224,7 +230,7 @@ docker run -ti --gpus all -v .:/workspace/Open-Sora opensora
### Open-Sora 1.2 Model Weights
| Model | Model Size | Data | #iterations | Batch Size | URL |
-| --------- | ---------- | ---- | ----------- | ---------- | ------------------------------------------------------------- |
+|-----------|------------|------|-------------|------------|---------------------------------------------------------------|
| Diffusion | 1.1B | 30M | 70k | Dynamic | [:link:](https://huggingface.co/hpcai-tech/OpenSora-STDiT-v3) |
| VAE | 384M | 3M | 1M | 8 | [:link:](https://huggingface.co/hpcai-tech/OpenSora-VAE-v1.2) |
@@ -238,7 +244,7 @@ See our **[report 1.2](docs/report_03.md)** for more infomation. Weight will be
View more
| Resolution | Model Size | Data | #iterations | Batch Size | URL |
-| ------------------ | ---------- | -------------------------- | ----------- | ------------------------------------------------- | -------------------------------------------------------------------- |
+|--------------------|------------|----------------------------|-------------|---------------------------------------------------|----------------------------------------------------------------------|
| mainly 144p & 240p | 700M | 10M videos + 2M images | 100k | [dynamic](/configs/opensora-v1-1/train/stage2.py) | [:link:](https://huggingface.co/hpcai-tech/OpenSora-STDiT-v2-stage2) |
| 144p to 720p | 700M | 500K HQ videos + 1M images | 4k | [dynamic](/configs/opensora-v1-1/train/stage3.py) | [:link:](https://huggingface.co/hpcai-tech/OpenSora-STDiT-v2-stage3) |
@@ -254,7 +260,7 @@ See our **[report 1.1](docs/report_02.md)** for more infomation.
View more
| Resolution | Model Size | Data | #iterations | Batch Size | GPU days (H800) | URL |
-| ---------- | ---------- | ------ | ----------- | ---------- | --------------- | --------------------------------------------------------------------------------------------- |
+|------------|------------|--------|-------------|------------|-----------------|-----------------------------------------------------------------------------------------------|
| 16×512×512 | 700M | 20K HQ | 20k | 2×64 | 35 | [:link:](https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-HQ-16x512x512.pth) |
| 16×256×256 | 700M | 20K HQ | 24k | 8×64 | 45 | [:link:](https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-HQ-16x256x256.pth) |
| 16×256×256 | 700M | 366K | 80k | 8×64 | 117 | [:link:](https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-16x256x256.pth) |
@@ -303,7 +309,7 @@ The easiest way to generate a video is to input a text prompt and click the "**G
Then, you can choose the **resolution**, **duration**, and **aspect ratio** of the generated video. Different resolutions and video lengths affect the generation speed. On an 80 GB H100 GPU, the generation speed (with `num_sampling_step=30`) and peak memory usage are:
| | Image | 2s | 4s | 8s | 16s |
-| ---- | ------- | -------- | --------- | --------- | --------- |
+|------|---------|----------|-----------|-----------|-----------|
| 360p | 3s, 24G | 18s, 27G | 31s, 27G | 62s, 28G | 121s, 33G |
| 480p | 2s, 24G | 29s, 31G | 55s, 30G | 108s, 32G | 219s, 36G |
| 720p | 6s, 27G | 68s, 41G | 130s, 39G | 260s, 45G | 547s, 67G |
@@ -446,6 +452,9 @@ Also check out the [datasets](docs/datasets.md) we use.
The training process is the same as for Open-Sora 1.1.
```bash
+# If you use async checkpoint saving and want to validate the integrity of the saved checkpoints, set the following environment variable.
+# An `async_file_io.log` file will then be written to each checkpoint directory. If the number of lines in the log does not equal the number of checkpoint (.safetensors) files, some writes may have failed.
+export TENSORNVME_DEBUG=1
# one node
torchrun --standalone --nproc_per_node 8 scripts/train.py \
configs/opensora-v1-2/train/stage1.py --data-path YOUR_CSV_PATH --ckpt-path YOUR_PRETRAINED_CKPT
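For reference, a minimal sketch of the integrity check described in the comments above, assuming `TENSORNVME_DEBUG=1` was exported before training and a checkpoint directory named like `epoch0-global_step1000` (both the output path and the directory name are illustrative):

```python
# Per the note above, the line count of async_file_io.log should match the number of
# .safetensors files written for that checkpoint; a mismatch may indicate failed writes.
import glob
import os

ckpt_dir = "outputs/epoch0-global_step1000"  # hypothetical checkpoint directory
with open(os.path.join(ckpt_dir, "async_file_io.log")) as f:
    num_logged = sum(1 for _ in f)
num_files = len(glob.glob(os.path.join(ckpt_dir, "**", "*.safetensors"), recursive=True))
print(f"logged writes: {num_logged}, safetensors files: {num_files}")
if num_logged != num_files:
    print("WARNING: mismatch - some async checkpoint writes may not have completed")
```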
@@ -510,7 +519,7 @@ We support evaluation based on:
All the evaluation code is released in the `eval` folder. Check the [README](/eval/README.md) for more details. Our [report](/docs/report_03.md#evaluation) also provides more information about evaluation during training. The following table shows that Open-Sora 1.2 greatly improves over Open-Sora 1.0.
| Model | Total Score | Quality Score | Semantic Score |
-| -------------- | ----------- | ------------- | -------------- |
+|----------------|-------------|---------------|----------------|
| Open-Sora V1.0 | 75.91% | 78.81% | 64.28% |
| Open-Sora V1.2 | 79.23% | 80.71% | 73.30% |
diff --git a/configs/opensora-v1-2/train/stage1.py b/configs/opensora-v1-2/train/stage1.py
index e1f1e7c0..cbaa20c8 100644
--- a/configs/opensora-v1-2/train/stage1.py
+++ b/configs/opensora-v1-2/train/stage1.py
@@ -108,3 +108,6 @@
ema_decay = 0.99
adam_eps = 1e-15
warmup_steps = 1000
+
+cache_pin_memory = True
+pin_memory_cache_pre_alloc_numels = [(290 + 20) * 1024**2] * (2 * 8 + 4)
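For context, a back-of-the-envelope for the pre-allocation above: the `(290 + 20) * 1024**2` element count and the `2 * 8 + 4` slot count are taken directly from the config, while the reading that they correspond to the largest per-sample video tensor plus headroom and to the dataloader's prefetch depth is an assumption, not something stated in this patch.

```python
# Rough sizing of the pinned-memory pool configured above (illustrative arithmetic only).
slot_numel = (290 + 20) * 1024**2   # elements per pinned buffer (payload + headroom)
num_slots = 2 * 8 + 4               # 20 buffers
bytes_per_elem = 2                  # bf16/fp16 training dtype -> PinMemoryCache.force_dtype
total_gib = slot_numel * num_slots * bytes_per_elem / 1024**3
print(f"{total_gib:.1f} GiB of pinned host memory")  # ~12.1 GiB reserved per training process
```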
diff --git a/opensora/datasets/dataloader.py b/opensora/datasets/dataloader.py
index 60d0b24a..97df0a16 100644
--- a/opensora/datasets/dataloader.py
+++ b/opensora/datasets/dataloader.py
@@ -1,17 +1,216 @@
import collections
+import functools
+import queue
import random
+import threading
from typing import Optional
import numpy as np
import torch
+import torch.multiprocessing as multiprocessing
+from torch._utils import ExceptionWrapper
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, _utils
+from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
+from torch.utils.data.dataloader import (
+ IterDataPipe,
+ MapDataPipe,
+ _BaseDataLoaderIter,
+ _MultiProcessingDataLoaderIter,
+ _sharding_worker_init_fn,
+ _SingleProcessDataLoaderIter,
+)
from .datasets import BatchFeatureDataset, VariableVideoTextDataset, VideoTextDataset
+from .pin_memory_cache import PinMemoryCache
from .sampler import BatchDistributedSampler, StatefulDistributedSampler, VariableVideoBatchSampler
+def _pin_memory_loop(
+ in_queue, out_queue, device_id, done_event, device, pin_memory_cache: PinMemoryCache, pin_memory_key: str
+):
+ # This setting is thread local, and prevents the copy in pin_memory from
+ # consuming all CPU cores.
+ torch.set_num_threads(1)
+
+ if device == "cuda":
+ torch.cuda.set_device(device_id)
+ elif device == "xpu":
+ torch.xpu.set_device(device_id) # type: ignore[attr-defined]
+ elif device == torch._C._get_privateuse1_backend_name():
+ custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
+ custom_device_mod.set_device(device_id)
+
+ def do_one_step():
+ try:
+ r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+ except queue.Empty:
+ return
+ idx, data = r
+ if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
+ try:
+ assert isinstance(data, dict)
+ if pin_memory_key in data:
+ val = data[pin_memory_key]
+ pin_memory_value = pin_memory_cache.get(val)
+ pin_memory_value.copy_(val)
+ data[pin_memory_key] = pin_memory_value
+ except Exception:
+ data = ExceptionWrapper(where=f"in pin memory thread for device {device_id}")
+ r = (idx, data)
+ while not done_event.is_set():
+ try:
+ out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
+ break
+ except queue.Full:
+ continue
+
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
+ # logic of this function.
+ while not done_event.is_set():
+ # Make sure that we don't preserve any object from one iteration
+ # to the next
+ do_one_step()
+
+
+class _MultiProcessingDataLoaderIterForVideo(_MultiProcessingDataLoaderIter):
+ pin_memory_key: str = "video"
+
+ def __init__(self, loader):
+ _BaseDataLoaderIter.__init__(self, loader)
+ self.pin_memory_cache = PinMemoryCache()
+
+ self._prefetch_factor = loader.prefetch_factor
+
+ assert self._num_workers > 0
+ assert self._prefetch_factor > 0
+
+ if loader.multiprocessing_context is None:
+ multiprocessing_context = multiprocessing
+ else:
+ multiprocessing_context = loader.multiprocessing_context
+
+ self._worker_init_fn = loader.worker_init_fn
+
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
+ # Additional worker init function will take care of sharding in MP and Distributed
+ if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
+ self._worker_init_fn = functools.partial(
+ _sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank
+ )
+
+ # No certainty which module multiprocessing_context is
+ self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
+ self._worker_pids_set = False
+ self._shutdown = False
+ self._workers_done_event = multiprocessing_context.Event()
+
+ self._index_queues = []
+ self._workers = []
+ for i in range(self._num_workers):
+ # No certainty which module multiprocessing_context is
+ index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
+ # Need to `cancel_join_thread` here!
+ # See sections (2) and (3b) above.
+ index_queue.cancel_join_thread()
+ w = multiprocessing_context.Process(
+ target=_utils.worker._worker_loop,
+ args=(
+ self._dataset_kind,
+ self._dataset,
+ index_queue,
+ self._worker_result_queue,
+ self._workers_done_event,
+ self._auto_collation,
+ self._collate_fn,
+ self._drop_last,
+ self._base_seed,
+ self._worker_init_fn,
+ i,
+ self._num_workers,
+ self._persistent_workers,
+ self._shared_seed,
+ ),
+ )
+ w.daemon = True
+ # NB: Process.start() actually take some time as it needs to
+ # start a process and pass the arguments over via a pipe.
+ # Therefore, we only add a worker to self._workers list after
+ # it started, so that we do not call .join() if program dies
+ # before it starts, and __del__ tries to join but will get:
+ # AssertionError: can only join a started process.
+ w.start()
+ self._index_queues.append(index_queue)
+ self._workers.append(w)
+
+ if self._pin_memory:
+ self._pin_memory_thread_done_event = threading.Event()
+
+ # Queue is not type-annotated
+ self._data_queue = queue.Queue() # type: ignore[var-annotated]
+ if self._pin_memory_device == "xpu":
+ current_device = torch.xpu.current_device() # type: ignore[attr-defined]
+ elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
+ custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
+ current_device = custom_device_mod.current_device()
+ else:
+ current_device = torch.cuda.current_device() # choose cuda for default
+ pin_memory_thread = threading.Thread(
+ target=_pin_memory_loop,
+ args=(
+ self._worker_result_queue,
+ self._data_queue,
+ current_device,
+ self._pin_memory_thread_done_event,
+ self._pin_memory_device,
+ self.pin_memory_cache,
+ self.pin_memory_key,
+ ),
+ )
+ pin_memory_thread.daemon = True
+ pin_memory_thread.start()
+ # Similar to workers (see comment above), we only register
+ # pin_memory_thread once it is started.
+ self._pin_memory_thread = pin_memory_thread
+ else:
+ self._data_queue = self._worker_result_queue # type: ignore[assignment]
+
+ # In some rare cases, persistent workers (daemonic processes)
+ # would be terminated before `__del__` of iterator is invoked
+ # when main process exits
+ # It would cause failure when pin_memory_thread tries to read
+ # corrupted data from worker_result_queue
+ # atexit is used to shutdown thread and child processes in the
+ # right sequence before main process exits
+ if self._persistent_workers and self._pin_memory:
+ import atexit
+
+ for w in self._workers:
+ atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
+
+ # .pid can be None only before process is spawned (not the case, so ignore)
+ _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
+ _utils.signal_handling._set_SIGCHLD_handler()
+ self._worker_pids_set = True
+ self._reset(loader, first_iter=True)
+
+ def remove_cache(self, output_tensor: torch.Tensor):
+ self.pin_memory_cache.remove(output_tensor)
+
+ def get_cache_info(self) -> str:
+ return str(self.pin_memory_cache)
+
+
+class DataloaderForVideo(DataLoader):
+ def _get_iterator(self) -> "_BaseDataLoaderIter":
+ if self.num_workers == 0:
+ return _SingleProcessDataLoaderIter(self)
+ else:
+ self.check_worker_number_rationality()
+ return _MultiProcessingDataLoaderIterForVideo(self)
+
+
# Deterministic dataloader
def get_seed_worker(seed):
def seed_worker(worker_id):
@@ -35,6 +234,7 @@ def prepare_dataloader(
bucket_config=None,
num_bucket_build_workers=1,
prefetch_factor=None,
+ cache_pin_memory=False,
**kwargs,
):
_kwargs = kwargs.copy()
@@ -50,8 +250,9 @@ def prepare_dataloader(
verbose=True,
num_bucket_build_workers=num_bucket_build_workers,
)
+ dl_cls = DataloaderForVideo if cache_pin_memory else DataLoader
return (
- DataLoader(
+ dl_cls(
dataset,
batch_sampler=batch_sampler,
worker_init_fn=get_seed_worker(seed),
@@ -71,8 +272,9 @@ def prepare_dataloader(
rank=process_group.rank(),
shuffle=shuffle,
)
+ dl_cls = DataloaderForVideo if cache_pin_memory else DataLoader
return (
- DataLoader(
+ dl_cls(
dataset,
batch_size=batch_size,
sampler=sampler,
@@ -137,7 +339,7 @@ def collate_fn_batch(batch):
"""
# filter out None
batch = [x for x in batch if x is not None]
-
+
res = torch.utils.data.default_collate(batch)
# squeeze the first dimension, which is due to torch.stack() in default_collate()
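To make the intended calling pattern concrete, here is a minimal consumption sketch for `DataloaderForVideo`. It mirrors the `scripts/train.py` changes later in this diff; the function and variable names are illustrative.

```python
import torch

def run_epoch(dataloader, device: torch.device, dtype: torch.dtype):
    # The iterator (not the DataLoader) owns the PinMemoryCache, so keep a handle to it.
    dataloader_iter = iter(dataloader)
    for batch in dataloader_iter:
        pinned_video = batch.pop("video")                      # view into a pinned cache slot
        x = pinned_video.to(device, dtype, non_blocking=True)  # async host-to-device copy
        # ... run the forward/backward pass on x and the remaining batch entries ...
        dataloader_iter.remove_cache(pinned_video)             # free the slot for reuse
```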
diff --git a/opensora/datasets/pin_memory_cache.py b/opensora/datasets/pin_memory_cache.py
new file mode 100644
index 00000000..3f6e7559
--- /dev/null
+++ b/opensora/datasets/pin_memory_cache.py
@@ -0,0 +1,76 @@
+import threading
+from typing import Dict, List, Optional
+
+import torch
+
+
+class PinMemoryCache:
+ force_dtype: Optional[torch.dtype] = None
+ min_cache_numel: int = 0
+ pre_alloc_numels: List[int] = []
+
+ def __init__(self):
+ self.cache: Dict[int, torch.Tensor] = {}
+ self.output_to_cache: Dict[int, int] = {}
+ self.cache_to_output: Dict[int, int] = {}
+ self.lock = threading.Lock()
+ self.total_cnt = 0
+ self.hit_cnt = 0
+
+ if len(self.pre_alloc_numels) > 0 and self.force_dtype is not None:
+ for n in self.pre_alloc_numels:
+ cache_tensor = torch.empty(n, dtype=self.force_dtype, device="cpu", pin_memory=True)
+ with self.lock:
+ self.cache[id(cache_tensor)] = cache_tensor
+
+ def get(self, tensor: torch.Tensor) -> torch.Tensor:
+ """Receive a cpu tensor and return the corresponding pinned tensor. Note that this only manage memory allocation, doesn't copy content.
+
+ Args:
+ tensor (torch.Tensor): The tensor to be pinned.
+
+ Returns:
+ torch.Tensor: The pinned tensor.
+ """
+ self.total_cnt += 1
+ with self.lock:
+ # find free cache
+ for cache_id, cache_tensor in self.cache.items():
+ if cache_id not in self.cache_to_output and cache_tensor.numel() >= tensor.numel():
+ target_cache_tensor = cache_tensor[: tensor.numel()].view(tensor.shape)
+ out_id = id(target_cache_tensor)
+ self.output_to_cache[out_id] = cache_id
+ self.cache_to_output[cache_id] = out_id
+ self.hit_cnt += 1
+ return target_cache_tensor
+ # no free cache, create a new one
+ dtype = self.force_dtype if self.force_dtype is not None else tensor.dtype
+ cache_numel = max(tensor.numel(), self.min_cache_numel)
+ cache_tensor = torch.empty(cache_numel, dtype=dtype, device="cpu", pin_memory=True)
+ target_cache_tensor = cache_tensor[: tensor.numel()].view(tensor.shape)
+ out_id = id(target_cache_tensor)
+ with self.lock:
+ self.cache[id(cache_tensor)] = cache_tensor
+ self.output_to_cache[out_id] = id(cache_tensor)
+ self.cache_to_output[id(cache_tensor)] = out_id
+ return target_cache_tensor
+
+ def remove(self, output_tensor: torch.Tensor) -> None:
+ """Release corresponding cache tensor.
+
+ Args:
+ output_tensor (torch.Tensor): The tensor to be released.
+ """
+ out_id = id(output_tensor)
+ with self.lock:
+ if out_id not in self.output_to_cache:
+ raise ValueError("Tensor not found in cache.")
+ cache_id = self.output_to_cache.pop(out_id)
+ del self.cache_to_output[cache_id]
+
+ def __str__(self):
+ with self.lock:
+ num_cached = len(self.cache)
+ num_used = len(self.output_to_cache)
+ total_cache_size = sum([v.numel() * v.element_size() for v in self.cache.values()])
+ return f"PinMemoryCache(num_cached={num_cached}, num_used={num_used}, total_cache_size={total_cache_size / 1024**3:.2f} GB, hit rate={self.hit_cnt / self.total_cnt:.2f})"
diff --git a/opensora/models/text_encoder/t5.py b/opensora/models/text_encoder/t5.py
index d67328a8..6378c347 100644
--- a/opensora/models/text_encoder/t5.py
+++ b/opensora/models/text_encoder/t5.py
@@ -170,14 +170,8 @@ def shardformer_t5(self):
from opensora.utils.misc import requires_grad
shard_config = ShardConfig(
- tensor_parallel_process_group=None,
- pipeline_stage_manager=None,
enable_tensor_parallelism=False,
- enable_fused_normalization=False,
- enable_flash_attention=False,
enable_jit_fused=True,
- enable_sequence_parallelism=False,
- enable_sequence_overlap=False,
)
shard_former = ShardFormer(shard_config=shard_config)
optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy())
diff --git a/opensora/utils/ckpt_utils.py b/opensora/utils/ckpt_utils.py
index d730981c..d0607ff3 100644
--- a/opensora/utils/ckpt_utils.py
+++ b/opensora/utils/ckpt_utils.py
@@ -2,13 +2,16 @@
import json
import operator
import os
-from typing import Tuple
+from typing import Dict, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.checkpoint_io import GeneralCheckpointIO
+from colossalai.utils.safetensors import save as async_save
+from safetensors.torch import load_file
+from tensornvme.async_file_io import AsyncFileWriter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torchvision.datasets.utils import download_url
@@ -150,10 +153,19 @@ def load_from_sharded_state_dict(model, ckpt_path, model_name="model", strict=Fa
ckpt_io.load_model(model, os.path.join(ckpt_path, model_name), strict=strict)
-def model_sharding(model: torch.nn.Module):
+def model_sharding(model: torch.nn.Module, device: torch.device = None):
+ """
+    Shard the model parameters across multiple GPUs, keeping only the local slice on each rank.
+
+    Args:
+        model (torch.nn.Module): The model to shard.
+        device (torch.device): The device to place each local shard on. Defaults to the parameter's current device.
+ """
global_rank = dist.get_rank()
world_size = dist.get_world_size()
for _, param in model.named_parameters():
+ if device is None:
+ device = param.device
padding_size = (world_size - param.numel() % world_size) % world_size
if padding_size > 0:
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
@@ -161,18 +173,34 @@ def model_sharding(model: torch.nn.Module):
padding_param = param.data.view(-1)
splited_params = padding_param.split(padding_param.numel() // world_size)
splited_params = splited_params[global_rank]
- param.data = splited_params
+ param.data = splited_params.to(device)
-def model_gathering(model: torch.nn.Module, model_shape_dict: dict):
+def model_gathering(model: torch.nn.Module, model_shape_dict: dict, pinned_state_dict: dict) -> None:
+ """
+    Gather the sharded model parameters from all ranks into pinned CPU buffers on rank 0.
+
+    Args:
+        model (torch.nn.Module): The sharded model to gather.
+        model_shape_dict (dict): The original (unsharded) shape of each parameter.
+        pinned_state_dict (dict): Pre-allocated pinned CPU tensors that receive the gathered parameters and buffers on rank 0.
+ """
global_rank = dist.get_rank()
global_size = dist.get_world_size()
+ params = set()
for name, param in model.named_parameters():
+ params.add(name)
all_params = [torch.empty_like(param.data) for _ in range(global_size)]
dist.all_gather(all_params, param.data, group=dist.group.WORLD)
if int(global_rank) == 0:
all_params = torch.cat(all_params)
- param.data = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name])
+ gathered_param = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name])
+ pinned_state_dict[name].copy_(gathered_param)
+ if int(global_rank) == 0:
+ for k, v in model.state_dict(keep_vars=True).items():
+ if k not in params:
+ pinned_state_dict[k].copy_(v)
+
dist.barrier()
@@ -195,6 +223,7 @@ def load_checkpoint(model, ckpt_path, save_as_pt=False, model_name="model", stri
get_logger().info("Unexpected keys: %s", unexpected_keys)
elif ckpt_path.endswith(".safetensors"):
from safetensors.torch import load_file
+
state_dict = load_file(ckpt_path)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
print(f"Missing keys: {missing_keys}")
@@ -223,76 +252,130 @@ def save_json(data, file_path: str):
# save and load for training
-def save(
- booster: Booster,
- save_dir: str,
- model: nn.Module = None,
- ema: nn.Module = None,
- optimizer: Optimizer = None,
- lr_scheduler: _LRScheduler = None,
- sampler=None,
- epoch: int = None,
- step: int = None,
- global_step: int = None,
- batch_size: int = None,
-):
- save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{global_step}")
- os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
-
- if model is not None:
- booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
- if optimizer is not None:
- booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
- if lr_scheduler is not None:
- booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
- if dist.get_rank() == 0:
- running_states = {
- "epoch": epoch,
- "step": step,
- "global_step": global_step,
- "batch_size": batch_size,
- }
- save_json(running_states, os.path.join(save_dir, "running_states.json"))
-
+def _prepare_ema_pinned_state_dict(model: nn.Module, ema_shape_dict: dict):
+ ema_pinned_state_dict = dict()
+ for name, p in model.named_parameters():
+ ema_pinned_state_dict[name] = torch.empty(ema_shape_dict[name], pin_memory=True, device="cpu", dtype=p.dtype)
+ sd = model.state_dict(keep_vars=True)
+ # handle buffers
+ for k, v in sd.items():
+ if k not in ema_pinned_state_dict:
+ ema_pinned_state_dict[k] = torch.empty(v.shape, pin_memory=True, device="cpu", dtype=v.dtype)
+
+ return ema_pinned_state_dict
+
+
+class CheckpointIO:
+ def __init__(self, n_write_entries: int = 32):
+ self.n_write_entries = n_write_entries
+ self.writer: Optional[AsyncFileWriter] = None
+ self.pinned_state_dict: Optional[Dict[str, torch.Tensor]] = None
+
+ def _sync_io(self):
+ if self.writer is not None:
+ self.writer = None
+
+ def __del__(self):
+ self._sync_io()
+
+ def _prepare_pinned_state_dict(self, ema: nn.Module, ema_shape_dict: dict):
+ if self.pinned_state_dict is None and dist.get_rank() == 0:
+ self.pinned_state_dict = _prepare_ema_pinned_state_dict(ema, ema_shape_dict)
+
+ def save(
+ self,
+ booster: Booster,
+ save_dir: str,
+ model: nn.Module = None,
+ ema: nn.Module = None,
+ optimizer: Optimizer = None,
+ lr_scheduler: _LRScheduler = None,
+ sampler=None,
+ epoch: int = None,
+ step: int = None,
+ global_step: int = None,
+ batch_size: int = None,
+ ema_shape_dict: dict = None,
+ async_io: bool = True,
+ ):
+ self._sync_io()
+ save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{global_step}")
+ os.environ["TENSORNVME_DEBUG_LOG"] = os.path.join(save_dir, "async_file_io.log")
+ os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
+
+ if model is not None:
+ booster.save_model(
+ model,
+ os.path.join(save_dir, "model"),
+ shard=True,
+ use_safetensors=True,
+ size_per_shard=4096,
+ use_async=async_io,
+ )
+ if optimizer is not None:
+ booster.save_optimizer(
+ optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096, use_async=async_io
+ )
+ if lr_scheduler is not None:
+ booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
if ema is not None:
- torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt"))
-
+ self._prepare_pinned_state_dict(ema, ema_shape_dict)
+ model_gathering(ema, ema_shape_dict, self.pinned_state_dict)
+ if dist.get_rank() == 0:
+ running_states = {
+ "epoch": epoch,
+ "step": step,
+ "global_step": global_step,
+ "batch_size": batch_size,
+ }
+ save_json(running_states, os.path.join(save_dir, "running_states.json"))
+
+ if ema is not None:
+ if async_io:
+ self.writer = async_save(os.path.join(save_dir, "ema.safetensors"), self.pinned_state_dict)
+ else:
+ torch.save(self.pinned_state_dict, os.path.join(save_dir, "ema.pt"))
+
+ if sampler is not None:
+ # only for VariableVideoBatchSampler
+ torch.save(sampler.state_dict(step), os.path.join(save_dir, "sampler"))
+ dist.barrier()
+ return save_dir
+
+ def load(
+ self,
+ booster: Booster,
+ load_dir: str,
+ model: nn.Module = None,
+ ema: nn.Module = None,
+ optimizer: Optimizer = None,
+ lr_scheduler: _LRScheduler = None,
+ sampler=None,
+ ) -> Tuple[int, int, int]:
+ assert os.path.exists(load_dir), f"Checkpoint directory {load_dir} does not exist"
+ assert os.path.exists(os.path.join(load_dir, "running_states.json")), "running_states.json does not exist"
+ running_states = load_json(os.path.join(load_dir, "running_states.json"))
+ if model is not None:
+ booster.load_model(model, os.path.join(load_dir, "model"))
+ if ema is not None:
+ # ema is not boosted, so we don't use booster.load_model
+ if os.path.exists(os.path.join(load_dir, "ema.safetensors")):
+ ema_state_dict = load_file(os.path.join(load_dir, "ema.safetensors"))
+ else:
+ ema_state_dict = torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu"))
+ ema.load_state_dict(
+ ema_state_dict,
+ strict=False,
+ )
+ if optimizer is not None:
+ booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
+ if lr_scheduler is not None:
+ booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
if sampler is not None:
- # only for VariableVideoBatchSampler
- torch.save(sampler.state_dict(step), os.path.join(save_dir, "sampler"))
- dist.barrier()
- return save_dir
-
-
-def load(
- booster: Booster,
- load_dir: str,
- model: nn.Module = None,
- ema: nn.Module = None,
- optimizer: Optimizer = None,
- lr_scheduler: _LRScheduler = None,
- sampler=None,
-) -> Tuple[int, int, int]:
- assert os.path.exists(load_dir), f"Checkpoint directory {load_dir} does not exist"
- assert os.path.exists(os.path.join(load_dir, "running_states.json")), "running_states.json does not exist"
- running_states = load_json(os.path.join(load_dir, "running_states.json"))
- if model is not None:
- booster.load_model(model, os.path.join(load_dir, "model"))
- if ema is not None:
- # ema is not boosted, so we don't use booster.load_model
- ema.load_state_dict(
- torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")),
- strict=False,
- )
- if optimizer is not None:
- booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
- if lr_scheduler is not None:
- booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
- if sampler is not None:
- sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler")))
- dist.barrier()
+ sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler")))
+ dist.barrier()
- return (
- running_states["epoch"],
- running_states["step"],
- )
+ return (
+ running_states["epoch"],
+ running_states["step"],
+ )
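For orientation, a condensed sketch of how the new `CheckpointIO` replaces the old module-level `save`/`load` helpers. It follows the `scripts/train.py` changes later in this diff; the wrapper function and its arguments are illustrative, and the surrounding objects (booster, model, EMA, ...) are assumed to be set up as in that script.

```python
def save_and_resume(checkpoint_io, booster, exp_dir, model, ema, optimizer, lr_scheduler,
                    sampler, epoch, step, global_step, batch_size, ema_shape_dict):
    # ema_shape_dict must be recorded *before* model_sharding(ema): rank 0 uses it to
    # pre-allocate pinned CPU buffers of the full parameter shapes for gathering.
    save_dir = checkpoint_io.save(
        booster, exp_dir, model=model, ema=ema, optimizer=optimizer,
        lr_scheduler=lr_scheduler, sampler=sampler, epoch=epoch, step=step,
        global_step=global_step, batch_size=batch_size, ema_shape_dict=ema_shape_dict,
        async_io=True,  # write model/optimizer/EMA shards in the background via tensornvme
    )
    # load() prefers ema.safetensors when present and falls back to the legacy ema.pt.
    return checkpoint_io.load(
        booster, save_dir, model=model, ema=ema, optimizer=optimizer,
        lr_scheduler=lr_scheduler, sampler=sampler,
    )
```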
diff --git a/opensora/utils/train_utils.py b/opensora/utils/train_utils.py
index 95d0011b..f015b44c 100644
--- a/opensora/utils/train_utils.py
+++ b/opensora/utils/train_utils.py
@@ -61,7 +61,7 @@ def update_ema(
else:
if param.data.dtype != torch.float32:
param_id = id(param)
- master_param = optimizer._param_store.working_to_master_param[param_id]
+ master_param = optimizer.get_working_to_master_map()[param_id]
param_data = master_param.data
else:
param_data = param.data
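The change above swaps the private `_param_store` lookup for ColossalAI's public `get_working_to_master_map()` accessor. If older ColossalAI releases still need to be supported, a guarded lookup such as the sketch below would work; the helper name is illustrative and not part of this patch.

```python
def working_to_master_map(optimizer):
    # Newer ColossalAI exposes a public accessor; older releases only provide the
    # private _param_store attribute that this patch removes from update_ema().
    if hasattr(optimizer, "get_working_to_master_map"):
        return optimizer.get_working_to_master_map()
    return optimizer._param_store.working_to_master_param
```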
diff --git a/scripts/train.py b/scripts/train.py
index 110f2f84..e41857f6 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -1,4 +1,5 @@
import os
+import subprocess
from contextlib import nullcontext
from copy import deepcopy
from datetime import timedelta
@@ -16,8 +17,9 @@
from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets.dataloader import prepare_dataloader
+from opensora.datasets.pin_memory_cache import PinMemoryCache
from opensora.registry import DATASETS, MODELS, SCHEDULERS, build_module
-from opensora.utils.ckpt_utils import load, model_gathering, model_sharding, record_model_param_shape, save
+from opensora.utils.ckpt_utils import CheckpointIO, model_sharding, record_model_param_shape
from opensora.utils.config_utils import define_experiment_workspace, parse_configs, save_training_config
from opensora.utils.lr_scheduler import LinearWarmupLR
from opensora.utils.misc import (
@@ -46,12 +48,16 @@ def main():
cfg_dtype = cfg.get("dtype", "bf16")
assert cfg_dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg_dtype}"
dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
+ checkpoint_io = CheckpointIO()
# == colossalai init distributed training ==
# NOTE: A very large timeout is set to avoid some processes exit early
dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
set_seed(cfg.get("seed", 1024))
+ PinMemoryCache.force_dtype = dtype
+ pin_memory_cache_pre_alloc_numels = cfg.get("pin_memory_cache_pre_alloc_numels", [])
+ PinMemoryCache.pre_alloc_numels = pin_memory_cache_pre_alloc_numels
coordinator = DistCoordinator()
device = get_current_device()
@@ -92,6 +98,7 @@ def main():
logger.info("Dataset contains %s samples.", len(dataset))
# == build dataloader ==
+ cache_pin_memory = cfg.get("cache_pin_memory", False)
dataloader_args = dict(
dataset=dataset,
batch_size=cfg.get("batch_size", None),
@@ -102,6 +109,7 @@ def main():
pin_memory=True,
process_group=get_data_parallel_group(),
prefetch_factor=cfg.get("prefetch_factor", None),
+ cache_pin_memory=cache_pin_memory,
)
dataloader, sampler = prepare_dataloader(
bucket_config=cfg.get("bucket_config", None),
@@ -157,11 +165,10 @@ def main():
)
# == build ema for diffusion model ==
- ema = deepcopy(model).to(torch.float32).to(device)
+ ema = deepcopy(model).cpu().to(torch.float32)
requires_grad(ema, False)
ema_shape_dict = record_model_param_shape(ema)
ema.eval()
- update_ema(ema, model, decay=0, sharded=False)
# == setup loss function, build scheduler ==
scheduler = build_module(cfg.scheduler, SCHEDULERS)
@@ -213,7 +220,7 @@ def main():
# == resume ==
if cfg.get("load", None) is not None:
logger.info("Loading checkpoint")
- ret = load(
+ ret = checkpoint_io.load(
booster,
cfg.load,
model=model,
@@ -226,7 +233,7 @@ def main():
start_epoch, start_step = ret
logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, start_epoch, start_step)
- model_sharding(ema)
+ model_sharding(ema, device=device)
# =======================================================
# 5. training loop
@@ -241,6 +248,8 @@ def main():
"backward",
"update_ema",
"reduce_loss",
+ "optim",
+ "ckpt",
]
for key in timer_keys:
if record_time:
@@ -262,9 +271,12 @@ def main():
total=num_steps_per_epoch,
) as pbar:
for step, batch in pbar:
+ # if cache_pin_memory:
+ # print(f"==debug== rank{dist.get_rank()} {dataloader_iter.get_cache_info()}")
timer_list = []
with timers["move_data"] as move_data_t:
- x = batch.pop("video").to(device, dtype) # [B, C, T, H, W]
+ pinned_video = batch.pop("video")
+ x = pinned_video.to(device, dtype, non_blocking=True) # [B, C, T, H, W]
y = batch.pop("text")
if record_time:
timer_list.append(move_data_t)
@@ -303,6 +315,9 @@ def main():
if isinstance(v, torch.Tensor):
model_args[k] = v.to(device, dtype)
+ if cache_pin_memory:
+ dataloader_iter.remove_cache(pinned_video)
+
# == diffusion loss computation ==
with timers["diffusion"] as loss_t:
loss_dict = scheduler.training_losses(model, x, model_args, mask=mask)
@@ -313,6 +328,10 @@ def main():
with timers["backward"] as backward_t:
loss = loss_dict["loss"].mean()
booster.backward(loss=loss, optimizer=optimizer)
+ if record_time:
+ timer_list.append(backward_t)
+
+ with timers["optim"] as optim_t:
optimizer.step()
optimizer.zero_grad()
@@ -320,7 +339,7 @@ def main():
if lr_scheduler is not None:
lr_scheduler.step()
if record_time:
- timer_list.append(backward_t)
+ timer_list.append(optim_t)
# == update EMA ==
with timers["update_ema"] as ema_t:
@@ -372,32 +391,42 @@ def main():
running_loss = 0.0
log_step = 0
+ # == uncomment to clear ram cache ==
+ # if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0 and coordinator.is_master():
+ # subprocess.run("sync && sudo sh -c \"echo 3 > /proc/sys/vm/drop_caches\"", shell=True)
+
# == checkpoint saving ==
ckpt_every = cfg.get("ckpt_every", 0)
- if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
- model_gathering(ema, ema_shape_dict)
- save_dir = save(
- booster,
- exp_dir,
- model=model,
- ema=ema,
- optimizer=optimizer,
- lr_scheduler=lr_scheduler,
- sampler=sampler,
- epoch=epoch,
- step=step + 1,
- global_step=global_step + 1,
- batch_size=cfg.get("batch_size", None),
- )
- if dist.get_rank() == 0:
- model_sharding(ema)
- logger.info(
- "Saved checkpoint at epoch %s, step %s, global_step %s to %s",
- epoch,
- step + 1,
- global_step + 1,
- save_dir,
- )
+ with timers["ckpt"] as ckpt_t:
+ if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
+ save_dir = checkpoint_io.save(
+ booster,
+ exp_dir,
+ model=model,
+ ema=ema,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ sampler=sampler,
+ epoch=epoch,
+ step=step + 1,
+ global_step=global_step + 1,
+ batch_size=cfg.get("batch_size", None),
+ ema_shape_dict=ema_shape_dict,
+ async_io=True,
+ )
+ logger.info(
+ "Saved checkpoint at epoch %s, step %s, global_step %s to %s",
+ epoch,
+ step + 1,
+ global_step + 1,
+ save_dir,
+ )
+ if record_time:
+ timer_list.append(ckpt_t)
+ # uncomment below 3 lines to benchmark checkpoint
+ # if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
+ # booster.checkpoint_io._sync_io()
+ # checkpoint_io._sync_io()
if record_time:
log_str = f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | "
for timer in timer_list: