Skip to content

Commit

Permalink
Merge pull request #597 from BoyuanFeng/feat/pt2-compile-loss-fn
Browse files Browse the repository at this point in the history
Enable torch.compile for loss_fn
  • Loading branch information
priyakasimbeg authored Dec 8, 2023
2 parents 64d1a85 + a134d08 commit a7b9650
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ def train_once(
]
eager_backend_workloads = ['librispeech_deepspeech']
aot_eager_backend_workloads = []
loss_compilation_workloads = [
'fastmri', 'librispeech_deepspeech', 'ogbg', 'wmt'
]
base_workload = workloads.get_base_workload_name(workload_name)
if base_workload in compile_error_workloads:
logging.warning(
Expand All @@ -247,6 +250,8 @@ def train_once(
else:
logging.info('Performing `torch.compile`.')
model_params = torch.compile(model_params)
if base_workload in loss_compilation_workloads:
workload.loss_fn = torch.compile(workload.loss_fn)
logging.info('Initializing optimizer.')
with profiler.profile('Initializing optimizer'):
optimizer_state = init_optimizer_state(workload,
Expand Down

0 comments on commit a7b9650

Please sign in to comment.