Skip to content

Commit

Permalink
Added information logging print in transfer_weights
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Oct 1, 2024
1 parent 909fbcb commit c783ce8
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions pytorch_toolbelt/utils/torch_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Common functions to marshal data to/from PyTorch
"""

import collections
import dataclasses
import functools
Expand Down Expand Up @@ -282,6 +283,8 @@ def transfer_weights(model: nn.Module, model_state_dict: collections.OrderedDict
"""
existing_model_state_dict = model.state_dict()

loaded_layers = 0

for name, value in model_state_dict.items():
if name not in existing_model_state_dict:
logger.debug(
Expand All @@ -308,9 +311,16 @@ def transfer_weights(model: nn.Module, model_state_dict: collections.OrderedDict

try:
model.load_state_dict(collections.OrderedDict([(name, value)]), strict=False)
loaded_layers += 1
except Exception as e:
logger.debug(f"transfer_weights skipped loading weights for key {name}, because of error: {e}")

percentage_of_layers_from_checkpoint = loaded_layers / len(model_state_dict) * 100
percentage_of_layers_in_model = loaded_layers / len(existing_model_state_dict) * 100
logger.info(
f"Transferred {percentage_of_layers_from_checkpoint:.2f}% of layers from checkpoint to model, filling {percentage_of_layers_in_model:.2f}% of model layers"
)


def resize_like(x: Tensor, target: Tensor, mode: str = "bilinear", align_corners: Union[bool, None] = True) -> Tensor:
"""
Expand Down

0 comments on commit c783ce8

Please sign in to comment.