diff --git a/main.py b/main.py index 94e6aec..4ec8c6c 100644 --- a/main.py +++ b/main.py @@ -299,7 +299,8 @@ def main(): optimizer.load_state_dict(checkpoint['optimizer']) print("=> checkpoint state loaded.") - model = torch.nn.DataParallel(model) + if cuda: + model = torch.nn.DataParallel(model) # Data loading code print("=> creating data loaders ... ")