Is there a way to offload training memory to DRAM (using FSDP2?) for training Llama3-8B with torchtitan? #620

Closed
0781532 opened this issue Oct 15, 2024 · 9 comments · Fixed by #624 · May be fixed by #622
Labels
question Further information is requested

Comments

@0781532

0781532 commented Oct 15, 2024

I am training Llama3-8B on 2 RTX A6000 Ada 48 GB GPUs, but I get an OOM error. Is there a way to offload training memory to DRAM (using FSDP2?) when training Llama3-8B with torchtitan?

Thanks!

***Error message:
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 112.00 MiB. GPU 0 has a total capacity of 47.48 GiB of which 92.81 MiB is free. Including non-PyTorch memory, this process has 46.71 GiB memory in use. Of the allocated memory 45.56 GiB is allocated by PyTorch, and 448.27 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

***Here is my training config:

torchtitan Config.toml

NOTE: this toml config is a preset for 64 A100 GPUs.

[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
batch_size = 2 #1
seq_len = 256 #512 #8192
warmup_steps = 200 # lr scheduler warm up
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_replicate_degree = 1 #1
data_parallel_shard_degree = -1 #-1
tensor_parallel_degree = 2 #1
compile = true
dataset = "c4"

[experimental]
pipeline_parallel_degree = 1 #1
enable_async_tensor_parallel = true

[checkpoint]
enable_checkpoint = false #false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "bfloat16" #32
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective' # ['none', 'selective', 'full']
selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy

[float8]
enable_float8_linear = true
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true

@awgu
Contributor

awgu commented Oct 15, 2024

Maybe try to decrease the batch size from 2 to 1 first?

@lessw2020
Contributor

Hi @0781532 -
As @awgu noted, a first step might be to decrease the batch size to 1. As an additional step, or if you are concerned about convergence due to the smaller batch size:

1 - Move from selective op activation checkpointing to 'full'; that should let you fit batch size 2, and potentially higher, without the need to offload.
i.e.

[activation_checkpoint]
mode = 'full' # ['none', 'selective', 'full']

2 - We have some experimental work that offloads the skeletal activations that remain even after running with full checkpointing, which can further reduce GPU memory with generally minor impact on perf (roughly a 10% reduction in GPU memory for about 2% slower perf).
If a lower batch size and/or full AC doesn't resolve this, I can make a PR to let you try the offloading for AC.
Please let us know how it goes!
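
(For illustration only, this is not the experimental torchtitan work mentioned in point 2: PyTorch's generic saved-tensor hook torch.autograd.graph.save_on_cpu shows the same idea of parking activations in CPU memory during forward and restoring them on demand in backward. The toy model below is a placeholder, not the Llama model.)

import torch
import torch.nn as nn

# Toy stand-in; torchtitan's Llama transformer is the real target.
model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096)).cuda()
x = torch.randn(8, 4096, device="cuda", requires_grad=True)

# Tensors saved for backward are copied to pinned CPU memory during the
# forward pass and copied back to the GPU on demand during backward.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    loss = model(x).sum()
loss.backward()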

@tianyu-l tianyu-l added the question Further information is requested label Oct 15, 2024
@0781532
Author

0781532 commented Oct 16, 2024

Hi @lessw2020 ,

Thank you for your quick reply.

I still get the same OOM error when using your suggestions (batch_size = 1, or mode = 'full').

***Do you have any other solution for this issue?
"2 - We have some experimental work that offloads the skeletal activations that remain even after running with full checkpointing, which can further reduce GPU memory with generally minor impact on perf (roughly a 10% reduction in GPU memory for about 2% slower perf)."

***Or do you provide a training config for a smaller Llama model such as Llama-3.2-1B or 3B?


Error message:

============================================================
train.py FAILED

Failures:
<NO_OTHER_FAILURES>

Root Cause (first observed failure):
[0]:
time : 2024-10-16_11:13:21
host : trx50-TRX50-AERO-D
rank : 1 (local_rank: 1)
exitcode : 1 (pid: 238568)
error_file: /tmp/torchelastic_ln6u9u9e/none__669dvmg/attempt_0/1/error.json
traceback : Traceback (most recent call last):
File "/home/trx50/.virtualenvs/torchtitan/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/init.py", line 355, in wrapper
return f(*args, **kwargs)
File "/home/trx50/torchtitan/train.py", line 318, in main
optimizers.step()
File "/home/trx50/torchtitan/torchtitan/optimizer.py", line 51, in step
optimizer.step()
File "/home/trx50/.virtualenvs/torchtitan/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 140, in wrapper
return func.__get__(opt, opt.__class__)(*args, **kwargs)
File "/home/trx50/.virtualenvs/torchtitan/lib/python3.10/site-packages/torch/optim/optimizer.py", line 487, in wrapper
out = func(*args, **kwargs)
File "/home/trx50/.virtualenvs/torchtitan/lib/python3.10/site-packages/torch/optim/optimizer.py", line 91, in _use_grad
ret = func(self, *args, **kwargs)
File "/home/trx50/.virtualenvs/torchtitan/lib/python3.10/site-packages/torch/optim/adamw.py", line 209, in step
has_complex = self._init_group(
File "/home/trx50/.virtualenvs/torchtitan/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 654, in _fn
return fn(*args, **kwargs)
File "/home/trx50/.virtualenvs/torchtitan/lib/python3.10/site-packages/torch/optim/adamw.py", line 152, in _init_group
state["exp_avg_sq"] = torch.zeros_like(
File "/home/trx50/.virtualenvs/torchtitan/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
return disable_fn(*args, **kwargs)
File "/home/trx50/.virtualenvs/torchtitan/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 654, in _fn
return fn(*args, **kwargs)
File "/home/trx50/.virtualenvs/torchtitan/lib/python3.10/site-packages/torch/distributed/tensor/_api.py", line 340, in torch_dispatch
return DTensor._op_dispatcher.dispatch(
File "/home/trx50/.virtualenvs/torchtitan/lib/python3.10/site-packages/torch/distributed/tensor/_dispatch.py", line 215, in dispatch
local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
File "/home/trx50/.virtualenvs/torchtitan/lib/python3.10/site-packages/torch/_ops.py", line 723, in call
return self._op(*args, **kwargs)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 112.00 MiB. GPU 1 has a total capacity of 47.50 GiB of which 21.62 MiB is free. Including non-PyTorch memory, this process has 47.46 GiB memory in use. Of the allocated memory 46.16 GiB is allocated by PyTorch, and 600.21 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Training config:

torchtitan Config.toml

NOTE: this toml config is a preset for 64 A100 GPUs.

[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"
[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100
[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"
[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
[optimizer]
name = "AdamW"
lr = 3e-4
[training]
batch_size = 2 #1
seq_len = 256 #512 #8192
warmup_steps = 200 # lr scheduler warm up
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_replicate_degree = 1 #1
data_parallel_shard_degree = -1 #-1
tensor_parallel_degree = 2 #1
compile = true
dataset = "c4"
[experimental]
pipeline_parallel_degree = 1 #1
enable_async_tensor_parallel = true
[checkpoint]
enable_checkpoint = false #false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "bfloat16" #32
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
[activation_checkpoint]
mode = 'full' # ['none', 'selective', 'full']
selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
[float8]
enable_float8_linear = true
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true

@lessw2020
Contributor

Hi @0781532 - based on the above error ("optimizers.step()"), I think you are making it through the first forward and backward pass and then OOMing at the first optimizer step (when the optimizer states are initialized).
We can either offload the optimizer states to CPU, or, since you also asked about 1B or 3B size models:
I made a quick branch to add an approximately 1B (1.397B) model, based largely on the Llama 3.2 1B structure:
https://github.com/pytorch/torchtitan/tree/lessw2020/add_32_1B
You can run that and see if it gets you training, especially if you don't expressly need to train at 8B size.
Please let me know and we can build the 1B out further from there:
(screenshot: llama1B model config)

@0781532
Author

0781532 commented Oct 16, 2024

Hi @lessw2020 ,

I really need to train Llama-3-8B. Could you please provide a training config or a branch that offloads the optimizer states to CPU?

Thank you very much!

@lessw2020
Contributor

Hi @0781532 - I tested CPU offloading earlier and found it has some issues. However, Wei has temporarily resolved these in the PR above, also linked here:
https://github.com/pytorch/torchtitan/pull/622/files
This PR is hardcoded to run only with CPU offloading.
Could you give it a spin and confirm if you are now able to train 8B on your setup?
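
For reference, and independent of that PR's exact changes, FSDP2 exposes CPU offload through an offload policy passed to fully_shard. Below is a minimal sketch, assuming a recent PyTorch (2.5-era) where fully_shard and CPUOffloadPolicy live under torch.distributed._composable.fsdp; the model and mesh are illustrative stand-ins, not torchtitan's actual wiring:

# Run under torchrun with 2 ranks, e.g. torchrun --nproc_per_node=2 sketch.py
import os
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard, CPUOffloadPolicy

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (2,))   # 1D FSDP mesh over the 2 GPUs
model = nn.Transformer()                # stand-in for the Llama model, left on CPU

# Keep the sharded parameters, gradients, and optimizer states in pinned CPU
# memory; FSDP2 streams them to the GPU around forward/backward.
offload = CPUOffloadPolicy(pin_memory=True)
for layer in list(model.encoder.layers) + list(model.decoder.layers):
    fully_shard(layer, mesh=mesh, offload_policy=offload)
fully_shard(model, mesh=mesh, offload_policy=offload)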

@awgu
Contributor

awgu commented Oct 16, 2024

Just sharing for completeness: if you want to run Llama3-8B with AdamW, you cannot do so on 2 GPUs with 48 GB memory each without some kind of offloading.

8B * 4 bytes * 3 / 2 GPUs = 48 GB / GPU, where the * 3 comes from sharded parameters, sharded exp_avgs, and sharded exp_avg_sqs. In other words, you cannot even fit the parameters and optimizer states.

Either you can use a lighter-weight optimizer, or you have to offload something (e.g. optimizer states) to CPU.
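
The same back-of-envelope arithmetic as a small script (the parameter count is approximate):

# Sharded fp32 parameters plus AdamW exp_avg and exp_avg_sq, ignoring
# gradients, activations, and the bf16 compute copy of the parameters.
params = 8.0e9       # Llama3-8B parameter count (approximate)
fp32_bytes = 4
num_tensors = 3      # parameters, exp_avg, exp_avg_sq
num_gpus = 2

per_gpu_gb = params * fp32_bytes * num_tensors / num_gpus / 1e9
print(f"{per_gpu_gb:.0f} GB per GPU")   # 48 GB, i.e. essentially the whole A6000 Ada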

@0781532
Author

0781532 commented Oct 16, 2024

Dear @lessw2020 and @awgu ,

Thank you very much for your quick help!

I am now successfully running Llama3-8B with your "cpu_offload commit" with batch_size=4.

The VRAM usage on each GPU is around 18 GB (across the 2 A6000 Ada 48 GB GPUs).
The DRAM usage is around ~155 GB (from offloading to CPU).

I am still trying higher batch_size values to check the limits of my hardware for this training case.


I have some further questions about the config settings:

  1. tensor_parallel_degree = 1:
    Does this mean I am running with a combination of TP + FSDP2 for my training,
    or
    did I actually not apply tensor parallel across the 2 GPUs, and instead only offload the optimizer states to CPU (system RAM) using FSDP2?

  2. [experimental]
    pipeline_parallel_degree = 1
    Did we apply any kind of "pipeline_parallel" in the training process?
    If yes, does that mean we are applying PP + TP + FSDP at the same time?

  3. Can a GeForce RTX 4090 24GB work for this case using the tensor parallel technique (assuming the VRAM requirement stays under 24 GB per GPU during training)?
    I have read some references saying that only NVIDIA workstation/data center GPUs like the A5000, A6000, or A6000 Ada (or higher) work.

Thanks,
0781532

@lessw2020
Contributor

Hi @0781532,
First, great to hear you are up and running with the cpu_offloading!
Regarding your questions:
A degree of 1 in the config file means "not applied" (i.e., anything * 1 = anything).
So right now you are running with 1D FSDP2 only - no tensor parallel and no pipeline parallel.
Neither of these would give you any added gain until you are training at a larger scale or a larger model size.

Re 3 - tensor parallel is applied at the PyTorch level and is implemented independently of any specific hardware. It does rely on NCCL for communication, but that is also a software aspect. Thus, you can certainly run TP on your hardware, or really on most reasonably new GPUs, as there is no direct hardware dependency.

Hope this helps!
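
To illustrate the point that TP is a software-level transform, here is a minimal sketch of column-sharding a single linear layer with PyTorch's tensor parallel API; the layer and sizes are placeholders, not torchtitan's parallelize plan:

# torchrun --nproc_per_node=2 tp_sketch.py
import os
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
tp_mesh = init_device_mesh("cuda", (2,))   # any NCCL-capable GPUs work, incl. RTX 4090
linear = nn.Linear(1024, 4096).cuda()

# Column-shard the weight across the 2 ranks: each GPU holds a 2048x1024
# weight shard and computes half of the output features.
parallelize_module(linear, tp_mesh, ColwiseParallel())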

mori360 added a commit to mori360/torchtitan that referenced this issue Nov 26, 2024
resolve pytorch#620 
Add config: `--training.enable_cpu_offload`

Command: `CONFIG_FILE="./train_configs/llama3_8b.toml"
./run_llama_train.sh`

For the non-PP case:
(screenshot)

For the PP case:
(screenshot: cpu offload + pp)