diff --git a/neuralpredictors/layers/readouts/factorized.py b/neuralpredictors/layers/readouts/factorized.py index 47a77338..4ffb5151 100644 --- a/neuralpredictors/layers/readouts/factorized.py +++ b/neuralpredictors/layers/readouts/factorized.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np import torch from torch import nn as nn @@ -19,12 +21,31 @@ def __init__( init_noise=1e-3, constrain_pos=False, positive_weights=False, + positive_spatial=False, shared_features=None, mean_activity=None, spatial_and_feature_reg_weight=None, - gamma_readout=None, # depricated, use feature_reg_weight instead + gamma_readout=None, **kwargs, ): + """ + + Args: + in_shape: batch, channels, height, width (batch could be arbitrary) + outdims: number of neurons to predict + bias: if True, bias is used + normalize: if True, normalizes the spatial mask using l2 norm + init_noise: the std for readout initialisation + constrain_pos: if True, negative values in the spatial mask and feature readout are clamped to 0 + positive_weights: if True, negative values in the feature readout are turned into 0 + positive_spatial: if True, spatial readout mask values are restricted to be positive by taking the absolute values + shared_features: if True, uses a copy of the features from somewhere else + mean_activity: the mean for readout initialisation + spatial_and_feature_reg_weight: lagrange multiplier (constant) for L1 penalty, + the bigger the number, the stronger the penalty + gamma_readout: depricated, use spatial_and_feature_reg_weight instead + **kwargs: + """ super().__init__() @@ -33,6 +54,12 @@ def __init__( self.outdims = outdims self.positive_weights = positive_weights self.constrain_pos = constrain_pos + self.positive_spatial = positive_spatial + if positive_spatial and constrain_pos: + warnings.warn( + f"If both positive_spatial and constrain_pos are True, " + f"only constrain_pos will effectively take place" + ) self.init_noise = init_noise self.normalize = normalize self.mean_activity = mean_activity @@ -50,7 +77,7 @@ def __init__( else: self.register_parameter("bias", None) - self.initialize(mean_activity) + self.initialize() @property def shared_features(self): @@ -84,6 +111,8 @@ def normalized_spatial(self): weight = self.spatial if self.constrain_pos: weight.data.clamp_min_(0) + elif self.positive_spatial: + weight = torch.abs(weight) return weight def regularizer(self, reduction="sum", average=None):