Skip to content

Commit

Permalink
disable torch.compile for mps, give clearer error messages, fix #2244
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed Jun 3, 2024
1 parent 0dd755a commit 80f7c3d
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,13 +238,24 @@ def initialize(self):
def _do_i_compile(self):
# new default: compile is enabled!

# compile does not work on mps
if self.device == torch.device('mps'):
if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'):
self.print_to_log_file("INFO: torch.compile disabled because of unsupported mps device")
return False

# CPU compile crashes for 2D models. Not sure if we even want to support CPU compile!? Better disable
if self.device == torch.device('cpu'):
if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'):
self.print_to_log_file("INFO: torch.compile disabled because device is CPU")
return False

# default torch.compile doesn't work on windows because there are apparently no triton wheels for it
# https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2
if os.name == 'nt':
if 'nnUNet_compile' in os.environ.keys() and os.environ['nnUNet_compile'].lower() in ('true', '1', 't'):
self.print_to_log_file("INFO: torch.compile disabled because Windows is not natively supported. If "
"you know what you are doing, check https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2")
return False

if 'nnUNet_compile' not in os.environ.keys():
Expand Down

0 comments on commit 80f7c3d

Please sign in to comment.