From c783ce8d9a926678d82a102384fde566a7d60401 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Tue, 1 Oct 2024 14:50:32 +0300 Subject: [PATCH] Added information logging print in transfer_weights --- pytorch_toolbelt/utils/torch_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pytorch_toolbelt/utils/torch_utils.py b/pytorch_toolbelt/utils/torch_utils.py index f45f0c31b..f13d1a1ce 100644 --- a/pytorch_toolbelt/utils/torch_utils.py +++ b/pytorch_toolbelt/utils/torch_utils.py @@ -1,6 +1,7 @@ """Common functions to marshal data to/from PyTorch """ + import collections import dataclasses import functools @@ -282,6 +283,8 @@ def transfer_weights(model: nn.Module, model_state_dict: collections.OrderedDict """ existing_model_state_dict = model.state_dict() + loaded_layers = 0 + for name, value in model_state_dict.items(): if name not in existing_model_state_dict: logger.debug( @@ -308,9 +311,16 @@ def transfer_weights(model: nn.Module, model_state_dict: collections.OrderedDict try: model.load_state_dict(collections.OrderedDict([(name, value)]), strict=False) + loaded_layers += 1 except Exception as e: logger.debug(f"transfer_weights skipped loading weights for key {name}, because of error: {e}") + percentage_of_layers_from_checkpoint = loaded_layers / len(model_state_dict) * 100 + percentage_of_layers_in_model = loaded_layers / len(existing_model_state_dict) * 100 + logger.info( + f"Transferred {percentage_of_layers_from_checkpoint:.2f}% of layers from checkpoint to model, filling {percentage_of_layers_in_model:.2f}% of model layers" + ) + def resize_like(x: Tensor, target: Tensor, mode: str = "bilinear", align_corners: Union[bool, None] = True) -> Tensor: """