diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py index fbf8229f5..a3c8ac4c2 100644 --- a/basicsr/models/base_model.py +++ b/basicsr/models/base_model.py @@ -299,7 +299,7 @@ def load_network(self, net, load_path, strict=True, param_key='params'): """ logger = get_root_logger() net = self.get_bare_model(net) - load_net = torch.load(load_path, map_location=lambda storage, loc: storage) + load_net = torch.load(load_path, map_location=lambda storage, loc: storage, weights_only=False) if param_key is not None: if param_key not in load_net and 'params' in load_net: param_key = 'params'