Is there a way to offload training memory to DRAM (using FSDP2?) for training Llama3-8B with torchtitan? #620
Comments
Maybe try to decrease the batch size from 2 to 1 first?
Hi @0781532 -
1. Move from selective op activation checkpointing to 'full'; that should let you fit batch size 2, and potentially higher, without the need to offload.
2. We have some experimental work that offloads the skeletal activations that remain even after running with full checkpointing; that can further reduce GPU memory with generally minor impact on perf (roughly 10% reduction in GPU memory, ~2% slower).
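For reference, the config posted below already exposes this knob: switching `mode` in the `[activation_checkpoint]` section from 'selective' to 'full'. Under the hood, full checkpointing amounts to wrapping every transformer block so its activations are recomputed during backward instead of kept in GPU memory. A rough sketch using the generic PyTorch wrapper, not torchtitan's exact code; `block_cls` is a placeholder for the model's transformer-block class:

```python
# Rough sketch of 'full' activation checkpointing: wrap whole transformer blocks so
# their activations are recomputed in backward rather than stored.
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

def apply_full_ac(model, block_cls: type) -> None:
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=checkpoint_wrapper,
        check_fn=lambda m: isinstance(m, block_cls),  # wrap whole blocks, not individual ops
    )
```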
Hi @lessw2020, thank you for your quick reply. I still get the same OOM error with your suggestions (batch_size = 1, or mode = 'full').
Do you have any other solution for this issue? Or could you support a training config for a smaller Llama model such as Llama-3.2-1B or 3B?
Error message:
============================================================
Hi @0781532 - based on the above error ("optimizers.step()"), I think you are making it through the first forward pass and then OOMing during the backward pass (due to the init of the optimizer states).
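As an aside, not from the thread: one way to confirm exactly where the allocation fails is PyTorch's memory snapshot tooling. A minimal sketch, assuming the private but long-standing `torch.cuda.memory` snapshot APIs; `model`, `batch`, and `optimizer` are placeholders:

```python
# Record allocator history around one training step to see whether the OOM hits in
# forward, backward, or optimizer.step(). Uses private torch.cuda.memory APIs.
import torch

def debug_one_step(model, batch, optimizer):
    torch.cuda.memory._record_memory_history(max_entries=100_000)
    try:
        loss = model(batch).float().mean()   # placeholder loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    except torch.OutOfMemoryError:
        # Load the dump at https://pytorch.org/memory_viz to see the failing allocation.
        torch.cuda.memory._dump_snapshot("oom_snapshot.pickle")
        raise
```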
Hi @lessw2020, I really need to run training for Llama-3-8B. Could you please provide a training config or a branch that offloads the optimizer states to CPU? Thank you very much!
Hi @0781532 - I tested CPU offloading earlier and found it has some issues. However, Wei has temporarily resolved these in the PR above, or here:
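For context, FSDP2 exposes CPU offloading through an offload policy passed to `fully_shard`. A minimal sketch of that usage, not the exact wiring in the PR above; the import path assumes a recent PyTorch (on 2.4/2.5 the same names live under `torch.distributed._composable.fsdp`), and `block_cls` is a placeholder:

```python
# Minimal FSDP2 CPU-offload sketch: shard each transformer block, then the root module,
# with an offload policy so sharded state is kept in (pinned) DRAM.
import torch
from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard

def shard_with_cpu_offload(model: torch.nn.Module, block_cls: type) -> torch.nn.Module:
    offload = CPUOffloadPolicy(pin_memory=True)
    mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    for module in model.modules():
        if isinstance(module, block_cls):
            fully_shard(module, mp_policy=mp, offload_policy=offload)
    fully_shard(model, mp_policy=mp, offload_policy=offload)
    return model
```

As I understand the policy, sharded parameters, gradients, and optimizer states then live in DRAM and the optimizer step runs on the CPU copies, which is where the GPU memory savings come from, at the cost of extra host-device copies per step.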
Just sharing for completeness: if you want to run Llama3-8B with AdamW, you cannot do so on 2 GPUs with 48 GB memory each without some kind of offloading.
Either you can use a lighter-weight optimizer, or you have to offload something (e.g. optimizer states) to CPU.
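To put rough numbers on that claim (my own back-of-envelope arithmetic, assuming fp32 parameters, gradients, and AdamW states, evenly sharded across the two GPUs, and ignoring activations and temporaries):

```python
# Back-of-envelope persistent-memory estimate for Llama3-8B + AdamW under FSDP,
# assuming fp32 params/grads/optimizer states and even sharding across 2 GPUs.
n_params = 8.0e9                 # Llama3-8B, approximate parameter count
fp32 = 4                         # bytes per element
params = n_params * fp32
grads = n_params * fp32
adamw = 2 * n_params * fp32      # exp_avg + exp_avg_sq
total_gib = (params + grads + adamw) / 2**30   # ~119 GiB for the whole job
per_gpu_gib = total_gib / 2                    # ~60 GiB per GPU, before activations
print(f"~{per_gpu_gib:.0f} GiB per GPU > 48 GiB available")
```

Which is consistent with needing either a lighter optimizer or CPU offloading on this hardware.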
Dear @lessw2020 and @awgu, thank you very much for your quick help! I am now successfully running Llama3-8B with your "cpu_offload" commit and batch_size=4. The VRAM usage on each GPU is around 18 GB (2x A6000 Ada 48 GB in total). I am still trying higher batch_size values to find the limit of my hardware for this training case. I have some further questions about the config settings:
If yes, does that mean we are applying PP + TP + FSDP at the same time?
I read some references saying that only NVIDIA GPUs such as the A5000, A6000, A6000 Ada or higher work. Thanks,
Hi @0781532, re 3 - tensor parallel is applied at the PyTorch level and is implemented independently of any hardware. It does rely on NCCL for communication, but that is also a software aspect. Thus, you can certainly run TP on your hardware, or really on most reasonably new GPUs, as there is no direct hardware dependence. Hope this helps!
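A quick way to sanity-check the software prerequisites TP actually depends on, using standard torch calls (nothing here is specific to torchtitan):

```python
# Sanity checks for tensor parallel prerequisites: a CUDA device, NCCL, and bf16 support.
import torch
import torch.distributed as dist

print(torch.cuda.get_device_name(0))        # the installed GPU; no special model required for TP
print(torch.cuda.get_device_capability(0))  # compute capability
print(dist.is_nccl_available())             # NCCL is the only communication dependency
print(torch.cuda.is_bf16_supported())       # relevant for bf16 mixed precision
```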
resolve pytorch#620
Add config: `--training.enable_cpu_offload`
Command: `CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh`
For the non-PP case: [screenshot](https://github.com/user-attachments/assets/8692f8a6-c0f3-460e-8eb6-7f7195bed370)
For the PP case: [cpu offload + PP screenshot](https://github.com/user-attachments/assets/73e40861-47e2-4845-a41c-4bfea2860109)
I am training Llama3-8B using 2 RTX A6000 Ada 48 GB GPUs, but got OOM. Is there a way to offload training memory to DRAM (using FSDP2?) for training Llama3-8B with torchtitan?
Thanks!
Error message:
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 112.00 MiB. GPU 0 has a total capacity of 47.48 GiB of which 92.81 MiB is free. Including non-PyTorch memory, this process has 46.71 GiB memory in use. Of the allocated memory 45.56 GiB is allocated by PyTorch, and 448.27 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Here is my training config:
# torchtitan Config.toml
# NOTE: this toml config is a preset for 64 A100 GPUs.
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"
[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100
[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"
[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
[optimizer]
name = "AdamW"
lr = 3e-4
[training]
batch_size = 2 #1
seq_len = 256 #512 #8192
warmup_steps = 200 # lr scheduler warm up
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_replicate_degree = 1 #1
data_parallel_shard_degree = -1 #-1
tensor_parallel_degree = 2 #1
compile = true
dataset = "c4"
[experimental]
pipeline_parallel_degree = 1 #1
enable_async_tensor_parallel = true
[checkpoint]
enable_checkpoint = false #false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "bfloat16" #32
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
[activation_checkpoint]
mode = 'selective' # ['none', 'selective', 'full']
selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
[float8]
enable_float8_linear = true
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true