Deepspeech pytorch timing gap and torch compile issues #488

Closed
priyakasimbeg opened this issue Aug 15, 2023 · 4 comments · Fixed by #597
Labels
⏰ Timing gap Significant difference (>= 10%) between pytorch and jax workloads

Comments

@priyakasimbeg
Contributor

priyakasimbeg commented Aug 15, 2023

Backward-pass compilation for DeepSpeech in PyTorch is unsupported/broken under torch.compile.
The DeepSpeech workload is currently 13% slower in PyTorch than in JAX.

Description

Currently DeepSpeech works with the torch.compile backend option 'eager' but breaks with 'aot_eager' (see the sketch after this list).
The goals of this bug are to:

  1. work with PyTorch contributors to determine whether we can enable full torch.compile on this workload, and
  2. reduce the timing gap with JAX.
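
For context, a minimal sketch of the two backend options being compared. This uses a placeholder LSTM module rather than the actual DeepSpeech model (which lives in algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py); 'eager' and 'aot_eager' are standard torch.compile debugging backends:

```python
import torch

# Placeholder stand-in for the workload's DeepSpeech model; the real module is
# defined in librispeech_pytorch/models.py.
model = torch.nn.LSTM(input_size=512, hidden_size=512, batch_first=True)

# Works on this workload: Dynamo captures the graph but executes it eagerly.
compiled_eager = torch.compile(model, backend='eager')

# Breaks on this workload: AOTAutograd additionally traces the forward and
# backward graphs, which is where compilation currently fails.
compiled_aot_eager = torch.compile(model, backend='aot_eager')
```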

Steps to reproduce

torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 \
  submission_runner.py \
  --framework=pytorch \
  --workload=librispeech_deepspeech \
  --submission_path=baselines/adamw/pytorch/submission.py \
  --tuning_search_space=baselines/adamw/tuning_search_space.json \
  --data_dir=/data/librispeech \
  --num_tuning_trials=1 \
  --experiment_dir=/experiment_runs \
  --experiment_name=timing_pytorch_2_preliminary_after_pytorch_fixes/adamw \
  --overwrite=True \
  --save_checkpoints=False \
  --max_global_steps=8000 \
  --librispeech_tokenizer_vocab_path=/data/librispeech/spm_model.vocab \
  --torch_compile=true
@priyakasimbeg priyakasimbeg changed the title Deepspeech torch compile Deepspeech pytorch timing gap and torch compile Aug 15, 2023
@priyakasimbeg priyakasimbeg changed the title Deepspeech pytorch timing gap and torch compile Deepspeech pytorch timing gap and torch compile issues Aug 15, 2023
@priyakasimbeg priyakasimbeg added the 🚀 Launch Blocker Issues that are blocking launch of benchmark label Aug 17, 2023
@priyakasimbeg
Contributor Author

priyakasimbeg commented Aug 22, 2023

It looks like torch.compile breaks in the Dynamo tracing step on 2.1.0.dev20230820+cu118.

Traceback

torch._dynamo.exc.TorchRuntimeError: Failed running call_function <function pack_padded_sequence at 0x7feca75669d0>(*(FakeTensor(..., device='cuda:0', size=(32, 500, 512), grad_fn=<CloneBackward0>), FakeTensor(..., size=(32,))), **{'batch_first': True, 'enforce_sorted': False}):
'lengths' argument should be a 1D CPU int64 tensor, but got 1D meta Long tensor

from user code:
   File "/algorithmic-efficiency/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py", line 285, in <resume in forward>
    packed_inputs = torch.nn.utils.rnn.pack_padded_sequence(

Full error logs are in the regression test.
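
For reference, a standalone sketch (not code from the repo) of the constraint the tracer hits: pack_padded_sequence requires lengths to be a 1-D int64 tensor on the CPU, but under Dynamo's FakeTensor tracing the lengths argument shows up as a meta tensor and fails that check:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Shapes mirror the traceback above: (batch=32, time=500, features=512).
inputs = torch.randn(32, 500, 512)
lengths = torch.randint(1, 501, (32,), dtype=torch.int64)  # 1-D int64, on CPU

# Succeeds: lengths satisfies the "1D CPU int64 tensor" requirement.
packed = pack_padded_sequence(
    inputs, lengths, batch_first=True, enforce_sorted=False)

# Moving lengths off the CPU (e.g. lengths.cuda()) raises the same class of
# error as the traceback: "'lengths' argument should be a 1D CPU int64 tensor".
```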

Filed a separate bug to track the specific torch compile issue: #498

@priyakasimbeg priyakasimbeg added the ⏰ Timing gap Significant difference (>= 10%) between pytorch and jax workloads label Aug 22, 2023
@pomonam
Contributor

pomonam commented Aug 22, 2023

Related: #483

@priyakasimbeg priyakasimbeg removed the 🚀 Launch Blocker Issues that are blocking launch of benchmark label Aug 31, 2023
@priyakasimbeg
Contributor Author

Current status after enabling the eager backend for DeepSpeech: PyTorch is 12% slower than JAX on this workload.

@priyakasimbeg
Contributor Author

Also resolved in #597
