
Added the restriction on factorized masks to be positive using absolute value (as in Ecker 2018) #232

Merged 8 commits on Mar 8, 2024
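The substance of the change: the spatial mask can now be made positive by taking its absolute value, as an alternative to the existing constrain_pos clamp. A minimal sketch (values illustrative, not from the PR) of how the two strategies differ under autograd:

import torch

# Toy spatial-mask parameter with negative entries.
spatial = torch.tensor([-0.5, 0.2, -0.1], requires_grad=True)

# positive_spatial-style: torch.abs is applied in the forward pass, so
# negative entries remain trainable (the gradient of |x| is sign(x)).
torch.abs(spatial).sum().backward()
print(spatial.grad)  # tensor([-1.,  1., -1.])

# constrain_pos-style: the parameter data is clamped in place, outside
# autograd, so negative entries are zeroed outright.
spatial.data.clamp_min_(0)
print(spatial)  # tensor([0.0000, 0.2000, 0.0000], requires_grad=True)

The absolute value keeps gradient flow through negative mask entries (the constraint used in Ecker 2018), whereas the in-place clamp discards them; if both flags are set, only the clamp takes effect, which is what the warning added in this PR points out.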
33 changes: 31 additions & 2 deletions neuralpredictors/layers/readouts/factorized.py
@@ -1,3 +1,5 @@
import warnings

import numpy as np
import torch
from torch import nn as nn
@@ -19,12 +21,31 @@
init_noise=1e-3,
constrain_pos=False,
positive_weights=False,
positive_spatial=False,
shared_features=None,
mean_activity=None,
spatial_and_feature_reg_weight=None,
gamma_readout=None,  # deprecated, use feature_reg_weight instead
gamma_readout=None,
**kwargs,
):
"""

Args:
in_shape: shape of the input as (batch, channels, height, width); the batch dimension can be arbitrary
outdims: number of neurons to predict
bias: if True, a bias term is used
normalize: if True, the spatial mask is normalized by its L2 norm
init_noise: std of the noise used for readout initialization
constrain_pos: if True, negative values in the spatial mask and feature readout are clamped to 0
positive_weights: if True, negative values in the feature readout are set to 0
positive_spatial: if True, the spatial mask is constrained to be positive by taking its absolute value
shared_features: if set, uses a copy of the features from somewhere else
mean_activity: the mean used for readout initialization
spatial_and_feature_reg_weight: Lagrange multiplier (constant) for the L1 penalty;
the larger the value, the stronger the penalty
gamma_readout: deprecated, use spatial_and_feature_reg_weight instead
**kwargs:
"""

super().__init__()

@@ -33,6 +54,12 @@
self.outdims = outdims
self.positive_weights = positive_weights
self.constrain_pos = constrain_pos
self.positive_spatial = positive_spatial

if positive_spatial and constrain_pos:
warnings.warn(
"If both positive_spatial and constrain_pos are True, "
"only constrain_pos will take effect"
)
self.init_noise = init_noise
self.normalize = normalize
self.mean_activity = mean_activity
@@ -50,7 +77,7 @@
else:
self.register_parameter("bias", None)

self.initialize(mean_activity)
self.initialize()


@property
def shared_features(self):
@@ -84,6 +111,8 @@
weight = self.spatial
if self.constrain_pos:
weight.data.clamp_min_(0)
elif self.positive_spatial:
weight = torch.abs(weight)

return weight

def regularizer(self, reduction="sum", average=None):
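For context, a hypothetical usage sketch of the new option. The class name FullFactorized2d is an assumption about what factorized.py exports; the constructor arguments follow the docstring above.

import torch
from neuralpredictors.layers.readouts.factorized import FullFactorized2d  # assumed class name

# Build a readout whose spatial mask is made positive via torch.abs.
readout = FullFactorized2d(
    in_shape=(32, 64, 18, 32),  # (batch, channels, height, width)
    outdims=100,                # number of neurons to predict
    bias=True,
    positive_spatial=True,      # new flag added in this PR
)

x = torch.randn(32, 64, 18, 32)  # e.g. the output of a model core
responses = readout(x)           # expected shape: (32, 100)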