RuntimeError: self and mat2 must have the same dtype, but got Float and BFloat16 when training with torch_compile #35382

Open
2 of 4 tasks
umarbutler opened this issue Dec 21, 2024 · 2 comments

@umarbutler
Contributor

System Info

  • transformers version: 4.48.0.dev0
  • Platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
  • Python version: 3.12.5
  • Huggingface_hub version: 0.25.1
  • Safetensors version: 0.4.5
  • Accelerate version: 0.34.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA GeForce RTX 4090

Who can help?

@ArthurZucker @muellerzr @SunMarc

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

When I try to continue pretraining ModernBERT on an MLM objective with the torch_compile flag of my TrainingArguments set to True, I get the error below:

  0%|                                                                                   | 0/1223301 [00:00<?, ?it/s]
/home/dev/.venv/lib/python3.12/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  param_schemas = callee.param_schemas()
/home/dev/.venv/lib/python3.12/site-packages/onnxscript/converter.py:820: FutureWarning: 'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
  param_schemas = callee.param_schemas()
W1221 15:33:48.046000 14779 torch/_dynamo/variables/tensor.py:776] [4/0] Graph break from `Tensor.item()`, consider setting:
W1221 15:33:48.046000 14779 torch/_dynamo/variables/tensor.py:776] [4/0]     torch._dynamo.config.capture_scalar_outputs = True
W1221 15:33:48.046000 14779 torch/_dynamo/variables/tensor.py:776] [4/0] or:
W1221 15:33:48.046000 14779 torch/_dynamo/variables/tensor.py:776] [4/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W1221 15:33:48.046000 14779 torch/_dynamo/variables/tensor.py:776] [4/0] to include these operations in the captured graph.
W1221 15:33:48.046000 14779 torch/_dynamo/variables/tensor.py:776] [4/0]
W1221 15:33:48.046000 14779 torch/_dynamo/variables/tensor.py:776] [4/0] Graph break: from user code at:
W1221 15:33:48.046000 14779 torch/_dynamo/variables/tensor.py:776] [4/0]   File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 711, in torch_dynamo_resume_in__unpad_modernbert_input_at_710
W1221 15:33:48.046000 14779 torch/_dynamo/variables/tensor.py:776] [4/0]     max_seqlen_in_batch = int(seqlens_in_batch.max().item())
W1221 15:33:48.046000 14779 torch/_dynamo/variables/tensor.py:776] [4/0]
W1221 15:33:48.046000 14779 torch/_dynamo/variables/tensor.py:776] [4/0]
Traceback (most recent call last):
  File "/home/dev/encoder/scripts/train/train_modernbert.py", line 206, in <module>
    trainer.train(
  File "/home/dev/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2163, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2523, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3668, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3722, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/accelerate/utils/operations.py", line 820, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/accelerate/utils/operations.py", line 808, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 1023, in forward
    @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
  File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 1055, in torch_dynamo_resume_in_forward_at_1055
    input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 913, in forward
    layer_outputs = encoder_layer(
                    ^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 519, in forward
    def forward(
  File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/modernbert/modeling_modernbert.py", line 529, in torch_dynamo_resume_in_forward_at_529
    attn_outputs = self.attn(
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1100, in forward
    return compiled_fn(full_args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 308, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 124, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
                            ^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 98, in g
    return f(*args)
           ^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1525, in forward
    fw_outs = call_func_at_runtime_with_args(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 124, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
                            ^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 488, in wrapper
    return compiled_fn(runtime_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 667, in inner_fn
    outs = compiled_fn(args)
           ^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 1478, in __call__
    return self.current_callable(inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/.venv/lib/python3.12/site-packages/torch/_inductor/utils.py", line 1977, in run
    return model(new_inputs)
           ^^^^^^^^^^^^^^^^^
  File "/tmp/torchinductor_/y7/cy7xv2rrzhznbq3e2wurnq5pmygfytvnpovxlh5bugtoa3ebwy6f.py", line 277, in call
    extern_kernels.addmm(buf9, buf7, reinterpret_tensor(buf8, (1152, 768), (1, 1152), 0), alpha=1, beta=1, out=buf10)
RuntimeError: self and mat2 must have the same dtype, but got Float and BFloat16

This does not occur when finetuning for a classification task.

I am using bfloat16 mixed precision.
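
For readers trying to reproduce this, the combination of settings described above might look roughly like the sketch below. The original script (train_modernbert.py) is not shown in this issue, so the dictionary, the output path, and the helper name are all hypothetical; only `bf16` and `torch_compile` come from the report.

```python
# Hypothetical reconstruction of the relevant settings; the original
# training script is not included in this issue.
training_kwargs = {
    "output_dir": "modernbert-mlm",  # placeholder path, an assumption
    "bf16": True,                    # bfloat16 mixed precision, as reported
    "torch_compile": True,           # the flag that triggers the error
}


def build_training_arguments():
    """Build TrainingArguments from the settings above."""
    # Imported lazily so the dict can be inspected without transformers installed.
    from transformers import TrainingArguments

    return TrainingArguments(**training_kwargs)
```

Setting `"torch_compile": False` in this sketch corresponds to the configuration that avoids the error.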

Expected behavior

The training works.

@umarbutler umarbutler added the bug label Dec 21, 2024
@ArthurZucker
Collaborator

cc @tomaarsen could you have a look? 🤗

@tomaarsen tomaarsen self-assigned this Dec 28, 2024
@tomaarsen
Member

Hello!

By default, ModernBERT uses Flash Attention 2 (if installed and training on CUDA) for efficient training. However, FA2 isn't currently fully compatible with torch compilation. This is one of the reasons that we trained with only specific components compiled (controllable via reference_compile in the model config) rather than the full model.

If you use a different attention implementation like "eager" or "sdpa", it will work again:

from transformers import AutoModel

model = AutoModel.from_pretrained("answerdotai/ModernBERT-base", attn_implementation="sdpa")

However, this is significantly slower than just running FA2 outright. In my tests, it feels like it's recompiling at every step or something. A training script of mine ran in ~90 minutes with attn_implementation="flash_attention_2" & bf16=True, and is estimated to take ~35 hours with attn_implementation="sdpa", torch_compile=True & bf16=True.
For context, just sdpa & bf16=True without torch_compile=True is estimated to take about 3.5 hours. The main reason this is much slower than FA2 is the unpacking that it uses.
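
To put those numbers side by side, here is a small arithmetic sketch. The three timings are the approximate figures quoted above (one completed run and two estimates), so the ratios are rough; the dictionary keys are just labels chosen for this example.

```python
# Approximate wall-clock times quoted above, in hours.
timings_hours = {
    "flash_attention_2 + bf16": 1.5,        # ~90 minutes (completed run)
    "sdpa + bf16": 3.5,                     # estimate
    "sdpa + bf16 + torch_compile": 35.0,    # estimate
}

baseline = timings_hours["flash_attention_2 + bf16"]

# Slowdown of each configuration relative to the FA2 run.
slowdowns = {name: hours / baseline for name, hours in timings_hours.items()}

for name, factor in slowdowns.items():
    print(f"{name}: {factor:.1f}x the FA2 run time")
```

By these figures, sdpa alone is a bit over 2x slower than FA2, while adding torch_compile on top of sdpa makes it about 10x slower again.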

In short: avoid torch_compile=True when training ModernBERT, and install flash_attn if you can.
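
As a rough illustration of that advice, the sketch below picks an attention implementation based on whether flash_attn is importable and whether compilation is requested. The helper names `pick_attn_implementation` and `load_modernbert` are made up for this example, and the check only tests importability: it does not verify that a CUDA device is present, which FA2 also needs.

```python
import importlib.util


def pick_attn_implementation(torch_compile: bool) -> str:
    """Choose an attn_implementation string following the advice in this thread."""
    # Only checks that the package can be imported, not that CUDA is available.
    has_flash_attn = importlib.util.find_spec("flash_attn") is not None
    if has_flash_attn and not torch_compile:
        return "flash_attention_2"
    # Fall back to sdpa when FA2 is missing or when compiling the full model.
    return "sdpa"


def load_modernbert(torch_compile: bool = False):
    """Load ModernBERT for masked language modelling with the chosen backend."""
    # Imported lazily so the helper above works without transformers installed.
    from transformers import AutoModelForMaskedLM

    return AutoModelForMaskedLM.from_pretrained(
        "answerdotai/ModernBERT-base",
        attn_implementation=pick_attn_implementation(torch_compile),
    )
```

Calling `load_modernbert()` with the default `torch_compile=False` follows the recommended path and uses FA2 whenever it is installed.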

- Tom Aarsen
