Skip to content

Commit

Permalink
Merge pull request #461 from mlcommons/juhan/fastmri
Browse files Browse the repository at this point in the history
torch.compile FastMRI
  • Loading branch information
znado authored Jul 27, 2023
2 parents 630131f + 6a18963 commit 5577b32
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
3 changes: 2 additions & 1 deletion baselines/adamw/pytorch/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def init_optimizer_state(workload: spec.Workload,
betas=(1.0 - hyperparameters.one_minus_beta1,
hyperparameters.beta2),
eps=1e-8,
weight_decay=hyperparameters.weight_decay),
weight_decay=hyperparameters.weight_decay,
fused=False),
}

def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
Expand Down
4 changes: 1 addition & 3 deletions submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,7 @@ def train_once(
model_params, model_state = workload.init_model_fn(
model_init_rng, dropout_rate, aux_dropout_rate)
if FLAGS.framework == 'pytorch' and FLAGS.torch_compile:
compile_error_workloads = [
'fastmri', 'ogbg', 'librispeech_deepspeech', 'wmt'
]
compile_error_workloads = ['ogbg', 'librispeech_deepspeech', 'wmt']
eager_backend_workloads = ['librispeech_conformer']
aot_eager_backend_workloads = ['criteo1tb']
if FLAGS.workload in compile_error_workloads:
Expand Down

0 comments on commit 5577b32

Please sign in to comment.