PP-related issues #771

Open
4 of 6 tasks
tianyu-l opened this issue Jan 3, 2025 · 1 comment
Labels: bug (Something isn't working), release_blocking (Issues that are blocking the milestone / release completion)

tianyu-l commented Jan 3, 2025

I found the issues below while debugging FSDP + CP + PP loss convergence. I used a seed checkpoint on the debug model, with at most 8 GPUs.

  • Numerics mismatch for FSDP vs. FSDP + PP (e.g. FSDP2 vs. FSDP2 + PP2), and for single-GPU vs. PP-only. The mismatch may be due to a different accumulation order of gradients over the microbatches (see the sketch after this list); it is likely not a real issue.
  • The default number of microbatches is the PP degree; it should instead be the number of PipelineStages. (fixed in "fix num_microbatches input for PP" #781)
  • Numerics mismatch for FSDP + PP with no AC vs. FSDP + PP with full AC (gone after "[Pipelining] Fix PP grad scaling" pytorch#144352)
  • FSDP + PP + training.deterministic: NaN within 5 steps (gone after "[Pipelining] Fix PP grad scaling" pytorch#144352)
  • FSDP + PP + CP has numerical issues compared with FSDP + PP or FSDP + CP, which affect loss convergence.
  • FSDP + PP + CP + training.mixed_precision_param = "float32": see the error log below.
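The accumulation-order hypothesis in the first bullet can be sanity-checked in isolation. Below is a minimal, hypothetical sketch (not torchtitan code; the tensors only stand in for per-microbatch gradients) showing that summing the same gradients in a different order can change the result at the last-bit level, which is enough to explain a small FSDP vs. FSDP + PP numerics gap without any real bug.

```python
# Minimal sketch, assuming PP's numerics gap comes from accumulating the same
# per-microbatch gradients in a different order (NOT torchtitan code).
# Floating-point addition is not associative, so the two sums below can
# disagree in the last few bits.
import torch

torch.manual_seed(0)
# Pretend these are the per-microbatch gradients of one parameter.
microbatch_grads = [torch.randn(4096) for _ in range(8)]

g_forward = torch.zeros(4096)
for g in microbatch_grads:            # accumulate in microbatch order 0..7
    g_forward += g

g_reverse = torch.zeros(4096)
for g in reversed(microbatch_grads):  # same gradients, reversed order
    g_reverse += g

# Mathematically identical, numerically only close:
print(torch.equal(g_forward, g_reverse))           # may be False
print((g_forward - g_reverse).abs().max().item())  # tiny, last-bit level
```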
Error log:

Traceback (most recent call last):
  File "/home/lty/local/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
  File "/home/lty/local/torchtitan/train.py", line 287, in main
    pp_schedule.step(input_ids)
  File "/home/lty/local/pytorch/torch/distributed/pipelining/schedules.py", line 503, in step
    self._step_microbatches(args_split, kwargs_split, targets_split, losses)
  File "/home/lty/local/pytorch/torch/distributed/pipelining/schedules.py", line 671, in _step_microbatches
    self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
  File "/home/lty/local/pytorch/torch/distributed/pipelining/schedules.py", line 473, in _initialize_stage
    self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs)
  File "/home/lty/local/pytorch/torch/distributed/pipelining/stage.py", line 1421, in _prepare_forward_infra
    outputs = self._shape_inference(args, kwargs)
  File "/home/lty/local/pytorch/torch/distributed/pipelining/stage.py", line 1362, in _shape_inference
    outputs = self.submod(*args, **kwargs)
  File "/home/lty/local/pytorch/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lty/local/pytorch/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
  File "/home/lty/local/pytorch/torch/nn/modules/module.py", line 1793, in inner
    result = forward_call(*args, **kwargs)
  File "/home/lty/local/torchtitan/torchtitan/models/llama/model.py", line 442, in forward
    h = layer(h, self.freqs_cis)
  File "/home/lty/local/pytorch/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lty/local/pytorch/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
  File "/home/lty/local/pytorch/torch/nn/modules/module.py", line 1793, in inner
    result = forward_call(*args, **kwargs)
  File "/home/lty/local/torchtitan/torchtitan/models/llama/model.py", line 323, in forward
    h = x + self.attention(self.attention_norm(x), freqs_cis)
  File "/home/lty/local/pytorch/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lty/local/pytorch/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lty/local/torchtitan/torchtitan/models/llama/model.py", line 209, in forward
    output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
  File "/home/lty/local/pytorch/torch/distributed/tensor/experimental/_attention.py", line 907, in inner_fn
    output = target_fn(*args, **kwargs)
  File "/home/lty/local/pytorch/torch/_compile.py", line 32, in inner
    return disable_fn(*args, **kwargs)
  File "/home/lty/local/pytorch/torch/_dynamo/eval_frame.py", line 751, in _fn
    return fn(*args, **kwargs)
  File "/home/lty/local/pytorch/torch/distributed/tensor/_api.py", line 343, in __torch_dispatch__
    return DTensor._op_dispatcher.dispatch(
  File "/home/lty/local/pytorch/torch/distributed/tensor/_dispatch.py", line 164, in dispatch
    return self._custom_op_handlers[op_call](op_call, args, kwargs)  # type: ignore[operator]
  File "/home/lty/local/pytorch/torch/distributed/tensor/experimental/_attention.py", line 555, in _sdpa_handler
    local_results = _scaled_dot_product_ring_efficient_attention(
  File "/home/lty/local/pytorch/torch/distributed/tensor/experimental/_attention.py", line 239, in _scaled_dot_product_ring_efficient_attention
    raise NotImplementedError("compute_log_sumexp must be set")
NotImplementedError: compute_log_sumexp must be set
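For the last bullet, the traceback ends in the CP ring-attention handler for the efficient-attention backend. One possible (unconfirmed) reading is that with mixed_precision_param = "float32" the attention inputs stay in fp32, which the flash backend does not support, so SDPA falls back to the efficient-attention path that the CP handler then rejects because compute_log_sumexp is not set. The hedged sketch below only demonstrates the backend/dtype behavior on a recent GPU; it is not a reproduction of the CP failure.

```python
# Hedged sketch: shows which SDPA backends accept fp32 vs. bf16 inputs on a
# recent GPU. Background for the traceback above, not a repro of the CP error.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

def try_backend(backend, dtype):
    q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=dtype)
    try:
        with sdpa_kernel(backend):  # restrict SDPA to a single backend
            F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return "ok"
    except RuntimeError:
        return "no available kernel"

if torch.cuda.is_available():
    print("flash,     bf16:", try_backend(SDPBackend.FLASH_ATTENTION, torch.bfloat16))
    print("flash,     fp32:", try_backend(SDPBackend.FLASH_ATTENTION, torch.float32))
    print("efficient, fp32:", try_backend(SDPBackend.EFFICIENT_ATTENTION, torch.float32))
```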
tianyu-l added this to the torchtitan v1.0.0 release milestone on Jan 3, 2025
tianyu-l added the bug and release_blocking labels on Jan 3, 2025
wconstab commented Jan 7, 2025

1/6 update: we checked off the first two boxes (PP vs. single GPU, and PP vs. FSDP) after running experiments for 300 steps. For the first 10-20 steps there can be a small gap between the PP and non-PP runs, but overall the loss curves look quite consistent with each other.

[image: loss curves for the PP and non-PP runs over 300 steps]

By comparison, we see much more significant gaps between the PP + CP + FSDP runs and either the PP + FSDP or the CP + FSDP runs, so we need to focus further investigation there.
