This repository has been archived by the owner on Nov 3, 2022. It is now read-only.

Swish activation saves beta as weight even if it is not trainable #483

Open · wants to merge 1 commit into master
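For context, a minimal sketch of the behavior this patch addresses, assuming the pre-patch `keras_contrib.layers.Swish` (the manual `build` call and the shapes are illustrative):

```python
from keras_contrib.layers import Swish

# Pre-patch layer: beta lives in a bare K.variable and is only attached
# to the layer's weights when trainable=True.
layer = Swish(beta=1.0, trainable=False)
layer.build(input_shape=(None, 4))

# Because the variable is never registered as a layer weight, a
# non-trainable beta is invisible to weight saving and loading:
print(layer.weights)  # [] -- beta is not tracked, so it is never saved
```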
37 changes: 19 additions & 18 deletions keras_contrib/layers/advanced_activations/swish.py
```diff
@@ -1,5 +1,7 @@
-from keras import backend as K
 from keras.layers import Layer
+from keras import backend as K
+from keras.layers import InputSpec
+from keras.initializers import Constant
 
 
 class Swish(Layer):
@@ -14,7 +16,7 @@ class Swish(Layer):
         Same shape as the input.
 
     # Arguments
-        beta: float >= 0. Scaling factor
+        initial_beta: float >= 0. Scaling factor
             if set to 1 and trainable set to False (default),
             Swish equals the SiLU activation (Elfwing et al., 2017)
         trainable: whether to learn the scaling factor during training or not
@@ -24,29 +26,28 @@ class Swish(Layer):
         - [Sigmoid-weighted linear units for neural network function
           approximation in reinforcement learning](https://arxiv.org/abs/1702.03118)
     """
-
-    def __init__(self, beta=1.0, trainable=False, **kwargs):
+    def __init__(self, trainable=True, initial_beta=1., **kwargs):
+        """Swish activation function with a trainable parameter referred
+        to as 'beta' in https://arxiv.org/abs/1710.05941"""
         super(Swish, self).__init__(**kwargs)
         self.supports_masking = True
-        self.beta = beta
         self.trainable = trainable
+        self.initial_beta = initial_beta
+        self.beta_initializer = Constant(value=self.initial_beta)
+        self.__name__ = 'swish'
 
     def build(self, input_shape):
-        self.scaling_factor = K.variable(self.beta,
-                                         dtype=K.floatx(),
-                                         name='scaling_factor')
-        if self.trainable:
-            self._trainable_weights.append(self.scaling_factor)
-        super(Swish, self).build(input_shape)
+        self.beta = self.add_weight(shape=[1], name='beta',
+                                    initializer=self.beta_initializer,
+                                    trainable=self.trainable)
+        self.input_spec = InputSpec(ndim=len(input_shape))
+        self.built = True
 
-    def call(self, inputs, mask=None):
-        return inputs * K.sigmoid(self.scaling_factor * inputs)
+    def call(self, inputs):
+        return inputs * K.sigmoid(self.beta * inputs)
 
     def get_config(self):
-        config = {'beta': self.get_weights()[0] if self.trainable else self.beta,
-                  'trainable': self.trainable}
+        config = {'trainable': self.trainable,
+                  'initial_beta': self.initial_beta}
         base_config = super(Swish, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
 
     def compute_output_shape(self, input_shape):
         return input_shape
```
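With the patched layer, `add_weight` registers beta whether or not it is trainable, so it is saved with the model, which is the point of this change. A usage sketch (the model, shapes, and filename are illustrative and assume standalone Keras 2.x with h5py available):

```python
from keras.models import Sequential
from keras.layers import Dense
from keras_contrib.layers import Swish

model = Sequential([
    Dense(8, input_shape=(4,)),
    Swish(trainable=False, initial_beta=1.0),
])

swish = model.layers[-1]
# add_weight tracks beta regardless of trainability; with
# trainable=False it lands in the non-trainable bucket but is
# still part of layer.weights and of any saved checkpoint.
print(swish.weights)                # [<beta variable>]
print(swish.non_trainable_weights)  # beta appears here

model.save_weights('swish_weights.h5')  # beta is included

# get_config now stores the scalar initial_beta instead of the live
# weight, so the layer round-trips cleanly through its config:
restored = Swish.from_config(swish.get_config())
```

Because `get_config` stores the scalar `initial_beta` rather than reading back the live weight, serialization no longer depends on whether the layer has been built or is trainable.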