Skip to content

Commit

Permalink
Allow to set tensors parameter to control what is initialized (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenjaminMidtvedt authored May 13, 2024
1 parent 0013dff commit 71d7250
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 28 deletions.
5 changes: 1 addition & 4 deletions deeplay/initializers/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,5 @@ def __init__(
self.weight = weight
self.bias = bias

def initialize_weight(self, tensor):
def initialize_tensor(self, tensor, name):
tensor.data.fill_(self.weight)

def initialize_bias(self, tensor):
tensor.data.fill_(self.bias)
16 changes: 6 additions & 10 deletions deeplay/initializers/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,11 @@ class Initializer:
def __init__(self, targets):
self.targets = targets

def initialize(self, module):
def initialize(self, module, tensors=("weight", "bias")):
if isinstance(module, self.targets):
if hasattr(module, "weight") and module.weight is not None:
self.initialize_weight(module.weight)
if hasattr(module, "bias") and module.bias is not None:
self.initialize_bias(module.bias)
for tensor in tensors:
if hasattr(module, tensor) and getattr(module, tensor) is not None:
self.initialize_tensor(getattr(module, tensor), name=tensor)

def initialize_weight(self, tensor):
pass

def initialize_bias(self, tensor):
pass
def initialize_tensor(self, tensor, name):
raise NotImplementedError
15 changes: 11 additions & 4 deletions deeplay/initializers/kaiming.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,20 @@ def __init__(
targets: Tuple[Type[nn.Module], ...] = _kaiming_default_targets,
mode: str = "fan_out",
nonlinearity: str = "relu",
fill_bias: bool = True,
bias: float = 0.0,
):
super().__init__(targets)
self.mode = mode
self.nonlinearity = nonlinearity
self.fill_bias = fill_bias
self.bias = bias

def initialize_weight(self, tensor):
nn.init.kaiming_normal_(tensor, mode=self.mode, nonlinearity=self.nonlinearity)
def initialize_tensor(self, tensor, name):

def initialize_bias(self, tensor):
tensor.data.fill_(0.0)
if name == "bias" and self.fill_bias:
tensor.data.fill_(self.bias)
else:
nn.init.kaiming_normal_(
tensor, mode=self.mode, nonlinearity=self.nonlinearity
)
7 changes: 2 additions & 5 deletions deeplay/initializers/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,5 @@ def __init__(
self.mean = mean
self.std = std

def initialize_bias(self, tensor):
tensor.data.fill_(self.mean)

def initialize_weight(self, tensor):
tensor.data.normal_(self.mean, self.std)
def initialize_tensor(self, tensor, name):
tensor.data.normal_(mean=self.mean, std=self.std)
14 changes: 9 additions & 5 deletions deeplay/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,16 +1023,20 @@ def log_tensor(self, name: str, tensor: torch.Tensor):
"""
self.logs[name] = tensor

def initialize(self, initializer):
def initialize(
self, initializer, tensors: Union[str, Tuple[str, ...]] = ("weight", "bias")
):
if isinstance(tensors, str):
tensors = (tensors,)
for module in self.modules():
if isinstance(module, DeeplayModule):
module._initialize_after_build(initializer)
module._initialize_after_build(initializer, tensors)
else:
initializer.initialize(module)
initializer.initialize(module, tensors)

@after_build
def _initialize_after_build(self, initializer):
initializer.initialize(self)
def _initialize_after_build(self, initializer, tensors: Tuple[str, ...]):
initializer.initialize(self, tensors)

@after_build
def _validate_after_build(self):
Expand Down

0 comments on commit 71d7250

Please sign in to comment.