Skip to content

Commit

Permalink
Merge pull request #230 from pollytur/encoder3d_shifter_fix
Browse files Browse the repository at this point in the history
Added shifter and updated 3d-encoder nonlinearity
  • Loading branch information
MaxFBurg authored Mar 8, 2024
2 parents a57dc39 + 584c646 commit 1dda093
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions neuralpredictors/layers/encoders/encoder3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,30 @@ def __init__(self, core, readout, readout_nonlinearity, elu_xshift, elu_yshift):
else:
self.nonlinearity = core.nonlinearities[readout_nonlinearity]()

def forward(self, x, data_key=None):
def forward(self, x, data_key=None, pupil_center=None, trial_idx=None, shift=None, detach_core=False, **kwargs):
out_core = self.core(x)
if detach_core:
out_core = out_core.detach()

if self.shifter:
if pupil_center is None:
raise ValueError("pupil_center is not given")
if shift is None:
time_points = x.shape[1]
pupil_center = pupil_center[:, :, -time_points:]
pupil_center = torch.transpose(pupil_center, 1, 2)
pupil_center = pupil_center.reshape(((-1,) + pupil_center.size()[2:]))
shift = self.shifter[data_key](pupil_center, trial_idx)

out_core = torch.transpose(out_core, 1, 2)
# the expected readout is 2d whereas the core can output 3d matrices
# therefore, the first two dimensions (representing depth and batch size) are flattened and then passed
# through the readout
out_core = out_core.reshape(((-1,) + out_core.size()[2:]))
readout_out = self.readout(out_core, data_key=data_key, shift=shift, **kwargs)

readout_out = self.readout(out_core)
out = self.nonlinearity(readout_out)
if self.nonlinearity_type == "elu":
out = self.nonlinearity_fn(readout_out + self.offset) + 1
else:
out = self.nonlinearity_fn(readout_out)
return out

0 comments on commit 1dda093

Please sign in to comment.