Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added shifter and updated 3d-encoder nonlinearity #230

Merged
merged 5 commits into from
Mar 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions neuralpredictors/layers/encoders/encoder3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,30 @@
else:
self.nonlinearity = core.nonlinearities[readout_nonlinearity]()

def forward(self, x, data_key=None):
def forward(self, x, data_key=None, pupil_center=None, trial_idx=None, shift=None, detach_core=False, **kwargs):

Check warning on line 15 in neuralpredictors/layers/encoders/encoder3d.py

View check run for this annotation

Codecov / codecov/patch

neuralpredictors/layers/encoders/encoder3d.py#L15

Added line #L15 was not covered by tests
out_core = self.core(x)
if detach_core:
out_core = out_core.detach()

Check warning on line 18 in neuralpredictors/layers/encoders/encoder3d.py

View check run for this annotation

Codecov / codecov/patch

neuralpredictors/layers/encoders/encoder3d.py#L18

Added line #L18 was not covered by tests

if self.shifter:
if pupil_center is None:
raise ValueError("pupil_center is not given")

Check warning on line 22 in neuralpredictors/layers/encoders/encoder3d.py

View check run for this annotation

Codecov / codecov/patch

neuralpredictors/layers/encoders/encoder3d.py#L22

Added line #L22 was not covered by tests
if shift is None:
time_points = x.shape[1]
pupil_center = pupil_center[:, :, -time_points:]
pupil_center = torch.transpose(pupil_center, 1, 2)
pupil_center = pupil_center.reshape(((-1,) + pupil_center.size()[2:]))
shift = self.shifter[data_key](pupil_center, trial_idx)

Check warning on line 28 in neuralpredictors/layers/encoders/encoder3d.py

View check run for this annotation

Codecov / codecov/patch

neuralpredictors/layers/encoders/encoder3d.py#L24-L28

Added lines #L24 - L28 were not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we rename shift into shifter? From how the word sounds, I would expect shift to be bool, but it actually is rather smth like a nn.Module I assume?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok just seen that in the readouts it's also called shift, so no strong opinion here from my side

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but I guess shift and shifter are also slightly semantically different

A I get it shift if the output of the shifter and theoretically it could be provided for the readout from somewhere else, not from the model.shifter (that the only reason why this parameter is here and I also tried to make it consistency with the 2d core)


out_core = torch.transpose(out_core, 1, 2)
# the expected readout is 2d whereas the core can output 3d matrices
# therefore, the first two dimensions (representing depth and batch size) are flattened and then passed
# through the readout
out_core = out_core.reshape(((-1,) + out_core.size()[2:]))
readout_out = self.readout(out_core, data_key=data_key, shift=shift, **kwargs)

Check warning on line 35 in neuralpredictors/layers/encoders/encoder3d.py

View check run for this annotation

Codecov / codecov/patch

neuralpredictors/layers/encoders/encoder3d.py#L35

Added line #L35 was not covered by tests

readout_out = self.readout(out_core)
out = self.nonlinearity(readout_out)
if self.nonlinearity_type == "elu":
out = self.nonlinearity_fn(readout_out + self.offset) + 1

Check warning on line 38 in neuralpredictors/layers/encoders/encoder3d.py

View check run for this annotation

Codecov / codecov/patch

neuralpredictors/layers/encoders/encoder3d.py#L38

Added line #L38 was not covered by tests
else:
out = self.nonlinearity_fn(readout_out)

Check warning on line 40 in neuralpredictors/layers/encoders/encoder3d.py

View check run for this annotation

Codecov / codecov/patch

neuralpredictors/layers/encoders/encoder3d.py#L40

Added line #L40 was not covered by tests
return out
Loading