diff --git a/hanlp/common/torch_component.py b/hanlp/common/torch_component.py index 6d8e07f6a..4ec9ddba8 100644 --- a/hanlp/common/torch_component.py +++ b/hanlp/common/torch_component.py @@ -97,7 +97,7 @@ def load_weights(self, save_dir, filename='model.pt', **kwargs): save_dir = get_resource(save_dir) filename = os.path.join(save_dir, filename) # flash(f'Loading model: {filename} [blink]...[/blink][/yellow]') - self.model_.load_state_dict(torch.load(filename, map_location='cpu'), strict=False) + self.model_.load_state_dict(torch.load(filename, map_location='cpu', weights_only=True), strict=False) # flash('') def save_config(self, save_dir, filename='config.json'):