From 71d725071a788e2c14bb9a4d919bf609fd10535d Mon Sep 17 00:00:00 2001 From: BenjaminMidtvedt <41636530+BenjaminMidtvedt@users.noreply.github.com> Date: Mon, 13 May 2024 22:10:25 +0200 Subject: [PATCH] Allow to set tensors parameter to control what is initialized (#104) --- deeplay/initializers/constant.py | 5 +---- deeplay/initializers/initializer.py | 16 ++++++---------- deeplay/initializers/kaiming.py | 15 +++++++++++---- deeplay/initializers/normal.py | 7 ++----- deeplay/module.py | 14 +++++++++----- 5 files changed, 29 insertions(+), 28 deletions(-) diff --git a/deeplay/initializers/constant.py b/deeplay/initializers/constant.py index 7ead6a99..ae5359d9 100644 --- a/deeplay/initializers/constant.py +++ b/deeplay/initializers/constant.py @@ -21,8 +21,5 @@ def __init__( self.weight = weight self.bias = bias - def initialize_weight(self, tensor): + def initialize_tensor(self, tensor, name): tensor.data.fill_(self.weight) - - def initialize_bias(self, tensor): - tensor.data.fill_(self.bias) diff --git a/deeplay/initializers/initializer.py b/deeplay/initializers/initializer.py index e42a974c..de8d2cc1 100644 --- a/deeplay/initializers/initializer.py +++ b/deeplay/initializers/initializer.py @@ -3,15 +3,11 @@ class Initializer: def __init__(self, targets): self.targets = targets - def initialize(self, module): + def initialize(self, module, tensors=("weight", "bias")): if isinstance(module, self.targets): - if hasattr(module, "weight") and module.weight is not None: - self.initialize_weight(module.weight) - if hasattr(module, "bias") and module.bias is not None: - self.initialize_bias(module.bias) + for tensor in tensors: + if hasattr(module, tensor) and getattr(module, tensor) is not None: + self.initialize_tensor(getattr(module, tensor), name=tensor) - def initialize_weight(self, tensor): - pass - - def initialize_bias(self, tensor): - pass + def initialize_tensor(self, tensor, name): + raise NotImplementedError diff --git a/deeplay/initializers/kaiming.py b/deeplay/initializers/kaiming.py index 9c006dd2..95368388 100644 --- a/deeplay/initializers/kaiming.py +++ b/deeplay/initializers/kaiming.py @@ -24,13 +24,20 @@ def __init__( targets: Tuple[Type[nn.Module], ...] = _kaiming_default_targets, mode: str = "fan_out", nonlinearity: str = "relu", + fill_bias: bool = True, + bias: float = 0.0, ): super().__init__(targets) self.mode = mode self.nonlinearity = nonlinearity + self.fill_bias = fill_bias + self.bias = bias - def initialize_weight(self, tensor): - nn.init.kaiming_normal_(tensor, mode=self.mode, nonlinearity=self.nonlinearity) + def initialize_tensor(self, tensor, name): - def initialize_bias(self, tensor): - tensor.data.fill_(0.0) + if name == "bias" and self.fill_bias: + tensor.data.fill_(self.bias) + else: + nn.init.kaiming_normal_( + tensor, mode=self.mode, nonlinearity=self.nonlinearity + ) diff --git a/deeplay/initializers/normal.py b/deeplay/initializers/normal.py index 928bc516..e160c318 100644 --- a/deeplay/initializers/normal.py +++ b/deeplay/initializers/normal.py @@ -28,8 +28,5 @@ def __init__( self.mean = mean self.std = std - def initialize_bias(self, tensor): - tensor.data.fill_(self.mean) - - def initialize_weight(self, tensor): - tensor.data.normal_(self.mean, self.std) + def initialize_tensor(self, tensor, name): + tensor.data.normal_(mean=self.mean, std=self.std) diff --git a/deeplay/module.py b/deeplay/module.py index 641b88ba..450f3e12 100644 --- a/deeplay/module.py +++ b/deeplay/module.py @@ -1023,16 +1023,20 @@ def log_tensor(self, name: str, tensor: torch.Tensor): """ self.logs[name] = tensor - def initialize(self, initializer): + def initialize( + self, initializer, tensors: Union[str, Tuple[str, ...]] = ("weight", "bias") + ): + if isinstance(tensors, str): + tensors = (tensors,) for module in self.modules(): if isinstance(module, DeeplayModule): - module._initialize_after_build(initializer) + module._initialize_after_build(initializer, tensors) else: - initializer.initialize(module) + initializer.initialize(module, tensors) @after_build - def _initialize_after_build(self, initializer): - initializer.initialize(self) + def _initialize_after_build(self, initializer, tensors: Tuple[str, ...]): + initializer.initialize(self, tensors) @after_build def _validate_after_build(self):