From 89a89b8ffac2e8f844c9836eb75aa7eb48de67ad Mon Sep 17 00:00:00 2001 From: pomonam Date: Thu, 27 Jul 2023 03:24:20 -0400 Subject: [PATCH 1/2] Initial Commit --- baselines/adamw/pytorch/submission.py | 3 ++- submission_runner.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/baselines/adamw/pytorch/submission.py b/baselines/adamw/pytorch/submission.py index b6e4b214f..6a086ff2d 100644 --- a/baselines/adamw/pytorch/submission.py +++ b/baselines/adamw/pytorch/submission.py @@ -32,7 +32,8 @@ def init_optimizer_state(workload: spec.Workload, betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2), eps=1e-8, - weight_decay=hyperparameters.weight_decay), + weight_decay=hyperparameters.weight_decay, + fused=False), } def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): diff --git a/submission_runner.py b/submission_runner.py index 669b776be..d4e72c8cb 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -192,7 +192,7 @@ def train_once( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: compile_error_workloads = [ - 'fastmri', 'ogbg', 'librispeech_deepspeech', 'wmt' + 'ogbg', 'librispeech_deepspeech', 'wmt' ] eager_backend_workloads = ['librispeech_conformer'] aot_eager_backend_workloads = ['criteo1tb'] From 6a18963aaae33c89ab806fea0d2be164e07a9118 Mon Sep 17 00:00:00 2001 From: pomonam Date: Thu, 27 Jul 2023 12:15:28 -0400 Subject: [PATCH 2/2] Linting fix --- submission_runner.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index d4e72c8cb..1850c598e 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -191,9 +191,7 @@ def train_once( model_params, model_state = workload.init_model_fn( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: - compile_error_workloads = [ - 'ogbg', 'librispeech_deepspeech', 'wmt' - ] + compile_error_workloads = ['ogbg', 'librispeech_deepspeech', 'wmt'] eager_backend_workloads = ['librispeech_conformer'] aot_eager_backend_workloads = ['criteo1tb'] if FLAGS.workload in compile_error_workloads: