-
Notifications
You must be signed in to change notification settings - Fork 0
/
forwardModelDemo_Summer2021.asv
435 lines (356 loc) · 16 KB
/
forwardModelDemo_Summer2021.asv
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
% forwardModelDemo: Illustrate the principle of inverted encoding models
%
% Implementation for 3D motion perception
% by Bas Rokers ([email protected])
% Based on demo of forward model for estimating
% tuning functions (TFs)
% js: 9.9.2014, ([email protected])
% Relevant papers:
% Brouwer & Heeger, 2009 - http://www.jneurosci.org/content/jneuro/29/44/13992.full.pdf
% Kok, Brouwer et al, 2013 - http://www.jneurosci.org/content/jneuro/33/41/16275.full.pdf
% Sprague, Ester, Serences, 2014 - https://www.sciencedirect.com/science/article/pii/S096098221400935X
% Requirements:
% Matlab deep learning toolbox (plotconfusion)
% Circular statistics toolbox - https://github.com/circstat/circstat-matlab
% Environment setup: start from MATLAB's default search path, then put the
% Circular Statistics toolbox (circ_mean, circ_vmpdf) on the path.
restoredefaultpath
circstat_path = '~/Documents/GitHub/circstat-matlab';
if ~exist(circstat_path,'dir')
    % Stop early with install instructions when the toolbox folder is missing
    error('Please install the Circular statistics toolbox - https://github.com/circstat/circstat-matlab')
end
addpath(genpath(circstat_path)); % Add Circular Statistics toolbox (including subfolders)
set(0, 'DefaultLineLineWidth', 2); % Set figure linewidth default
% Start from a clean session (this also removes circstat_path from the workspace)
clear all;
close all;
%% Section 0: Parameters
% Encoding model
nChans = 8; % number of channels in your model (does not have to match # stim features).
basis = 'cosine'; % one of: 'delta', 'cosine', 'von_mises'
% Data
nDirections = 8; % number of stimulus directions in your study
nRepeats = 20; % number repeats of each motion direction for the simulation
nTrials = nRepeats * nDirections; % number of trials in your experiment.
nFolds = 250; % number of cross-validation iterations (one random train/test split per iteration)

%% Section 1: Build encoding model
% Design the basis functions used to estimate the channel weights in each voxel.
% bf is nChans x nDirections: bf(c,d) is the response of channel c to
% stimulus direction d.
xs = linspace(0, 360-360/nDirections, nDirections); % sampled stimulus directions (deg)
channel_peaks = linspace(0, 360-360/nChans, nChans); % preferred direction of each channel (deg)
switch basis
    case 'delta'
        % NOTE: eye(nChans) is only nChans x nDirections when nChans == nDirections
        bf = eye(nChans); % delta functions
    case 'cosine'
        exponent = 5;
        % column of peaks minus row of directions -> nChans x nDirections
        bf = cosd(channel_peaks(:) - xs).^exponent; % cosine
    case 'von_mises'
        kappa = 5; % concentration parameter, like 1/sigma
        bf = zeros(nChans, nDirections);
        for ii = 1:nChans
            % Centre each channel on its own peak (channel_peaks), not on the
            % ii-th stimulus direction, so nChans may differ from nDirections.
            bf(ii,:) = circ_vmpdf(deg2rad(xs), deg2rad(channel_peaks(ii)), kappa); % von mises
        end
    otherwise
        error('Unknown basis function')
end
% Lowell wants to explore bimodal gaussians
% Lowell: Can you provide some motivation?
bf = max(0,bf); % rectify (odd cosine exponents leave negative lobes)
% bf = bf./max(bf); % norm to unit height
%% Plot model
% One curve per channel (row of bf). The 0 deg column is appended again at
% 360 deg so every curve closes the circle.
figure(1); hold on
ph = plot([xs, 360], [bf, bf(:,1)]);
title('Encoding model')
xlabel({' ','Motion Direction (deg)'})
ylabel({' ','Channel Response',' '});
xlim([0 360]);
ylim([0 1]);
set(gca, 'XTick', 0:45:360, 'YTick', 0:.25:1);
%% Section 2: Load data
% TAFKAP_Decode(data, params);
% data is nTrials x nVoxels
sub = 'sub-0201';
ses = {'01'}; %,'02'};
% Named runIdx (not "run") so MATLAB's built-in run() function is not shadowed.
runIdx = 1:10; %[[1:10]'];%,[1:10]'];
roi = {'V2'};
% roi = {'V1','V2','hMT','IPS0'};
%roiname = {'V1','V2','V3','V3A','hV4','LO','hMT','MST','IPS'};
% roiname = {'V1','V2','V3','V3A','V3B','hV4','LO1','LO2','hMT','MST','IPS0','IPS1','IPS2','IPS3','IPS4','IPS5','VO1','VO2','SPL1','PHC1','PHC2','FEF'};
projectDir = '~/Dropbox (RVL)/MRI/Decoding/';
% DATA is indexed further down as DATA{roiIndex}; its rows are averaged in
% pairs of TRs there. Presumably rows are time points and columns voxels -
% TODO confirm against loadmydata.
DATA = loadmydata(sub,ses,runIdx,projectDir,roi);
gg = [5:-1:1 8:-1:6 4:8 1:3]'; % true trial design matrix across all runs %lh
% One copy of the per-run design per loaded run.
% NOTE: g is reassigned inside the ROI loop further down, so this labelling
% is currently not the one used for decoding.
g = repmat(gg,numel(runIdx),1);
% This demo contains V1, MT and IPS data for one participant (br)
% 3D motion stimuli are presented in 8 labeled directions, where
% 1: rightward, 3: towards, 5:leftward, and 7: away.
% Intermediate values reflect intermediate directions moving directly
% towards or away from one of the eyes
% Heads up TAFKAP does:
% normalize per TR over voxels. Check if that matters
% Also, make sure to pull TAFKAP again
% load('workspace.mat') % fft detrended, z-scored data
% load('workspace_pct.mat'); % percent BOLD change data
% load('workspace_pct_poly.mat'); % percent BOLD change data, 1st order polynomial detrend
% load('workspace_poly_z.mat'); % polynomial detrend, z-scored
% load('workspace_sub-205-poly-0_z.mat'); % sub-205, polynomial detrend, z-scored
% load('workspace_sub-205-poly-3_z_roi-all.mat'); % sub-205, polynomial detrend, z-scored
% % load data directly
% for ii = 1 %1:20
% % load fmriprepped data
% nii{ii} = niftiread('~/Dropbox (RVL)/MRI/Decoding/derivatives/fmriprep/sub-0201/ses-01/func/sub-0201_ses-01_task-3dmotion_run-1_space-T1w_desc-preproc_bold.nii.gz');
%
% % apply mask
% mask = niftiread('~/Dropbox (RVL)/MRI/Decoding/derivatives/fmriprep/sub-0201/ses-01/anat/rois/sub-0201_space-T1w_downsampled_V1.nii.gz');
% mask = repmat(mask,[1,1,1,250]);
%
% nii_masked{ii} = nii{ii}(mask);
%
% % extract timeseries in mask
% masked_nii{ii} = cosmo_slice
%
%
% % detrend, normalize and average
% end
% _pct data has some outliers (>500% signal change) in V1 that affect the
% results
% Contains rois (roi names), new_p (parameters),
% and masked_ds (trials x voxels dataset organized by roi)
% Main analysis loop: for every ROI, fit the forward (encoding) model on a
% random training split, invert it on the held-out trials, accumulate the
% reconstructed channel responses over all iterations, and plot the results.
for whichRoi = 1:length(roi)
    % whichRoi = 1; % V1 - 1, hMT - 2, IPS - 3

    %% Prepare data and labels for this ROI
    % take average of every 2 TRs (rows 1+2, 3+4, ...; a trailing unpaired row is dropped)
    roiData = DATA{whichRoi};
    data = (roiData(1:2:end-1,:) + roiData(2:2:end,:)) ./ 2;
    % average every 8 datapoints (if wanted)
    % data = squeeze(mean(reshape(data,15,8,[])));
    % data = masked_ds{whichRoi}.samples;
    g = repmat(1:8,1,15)'; % direction label (1-8) assigned to each row of data
    % g = round(new_p.stimval./22.5); % ground truth, convert back to labels 1-8
    % block = new_p.runNs; % block/scan indices
    % data = DATA{whichRoi};
    if numel(g) ~= size(data,1)
        % Logical indexing with a shorter mask silently ignores the extra
        % rows, so flag the mismatch instead of decoding misaligned data.
        warning('forwardModelDemo:labelMismatch', ...
            'Number of labels (%d) does not match number of data rows (%d) in ROI %s.', ...
            numel(g), size(data,1), roi{whichRoi});
    end
    % zero-meaning seems to result in weird behavior, rank-deficiency errors
    % and strange channel response reconstruction
    % at least for non-delta basis functions
    % BR thinks it has something to do with:
    % If the channels peak exactly on or halfway between the stimuli you sampled,
    % then the problem is underspecified, as the two channels on either side
    % of each sampled stimulus value make the same prediction for the response at that stimulus value.

    % Inspect data
    figure(2)
    imagesc(data)
    title('Measured voxel response')
    xlabel('Voxel #')
    ylabel('Trial #')
    % % Old (UW data)
    % load('myData_br_V1.mat')
    % load('myData_br_MT.mat') % contains trials x nVoxels
    % % load('myData_br_IPS.mat')
    %
    % % data is trials x voxels
    % data = [traindat; testdat]; % reconstitute. Data was originally 50/50 split
    % g = [Truth; Truth]; % presented direction on each trial
    % block = sort(repmat((1:nRepeats)', nDirections, 1));

    %% STEP 3: Hold data out and solve for channel weights: 'Forward model'
    % Accumulators over all iterations; the three stay row-aligned:
    %   chan                 - reconstructed channel responses (test trials x nChans)
    %   chan_tstg            - presented direction label of each of those trials
    %   total_pred_direction - decoded direction (deg) of each of those trials
    chan = []; chan_tstg = []; total_pred_direction = [];
    peaks_rad = deg2rad(channel_peaks); % channel peaks in radians (1 x nChans)
    for ff = 1:nFolds
        fprintf('Computing iteration %d out of %d\n', ff, nFolds);
        % stratify by scan and motion direction
        % sometimes shows weird reconstructed channel responses
        % c.training = (block ~= (rem(ff,20)+1)); % set training data
        % c.test = ~c.training; % data from training scans (all but one scan)

        % Random 80/20 hold-out split, redrawn on every iteration
        c = cvpartition(g, 'Holdout', 0.2); % stratify by motion direction, but not scan
        trn = data(c.training,:); % training data
        tst = data(c.test,:); % test data
        trng = g(c.training); % trial labels for training data.
        tstg = g(c.test); % trial labels for test data.

        % Design matrix of predicted channel responses for the training trials.
        % Define presented directions
        % presented_dir = [0 80 90 100 180 260 270 280]; % 24.0063% in MT, delta function
        % presented_dir = [0 70 90 110 180 250 270 290]; % 24.6938% in MT
        % presented_dir = [0 60 90 120 180 240 270 300]; % 28.3938% in MT
        % presented_dir = [0 55 90 125 180 235 270 305]; % 29.325% in MT
        % presented_dir = [0 50 90 130 180 230 270 310]; % 29.175% in MT
        % presented_dir = linspace(0, 360-360/nChans, nChans); % 29.2188% in MT
        % presented_dir = [0 30 90 150 180 210 270 330]; % 28.8813% in MT - better l/r decoding, worse t/a
        %
        % Retinal motion channel responses:
        % rows: observations (trials), columns: predicted response of each channel
        X = bf(:, trng)';
        % Or use channel responses closer to world motion Oblique
        % trajectories are closer to toward/away trajectories, so would
        % produce more similar channel responses
        % will not work correctly for delta basis function, as it will just
        % rescale relative amplitude of basis functions
        % X(ii,:) = fshift(bf(1,:),presented_dir(trng(ii))*4/180); % per-trial alternative

        % use a GLM to compute weight of each channel in each
        % voxel, based on data from the training set.
        w = X\trn; % channel weights matrix (nChans x nVoxels) - or inv(X'*X)*X'*trn;
        % Optional - regularize (shrink weights)
        % TODO

        % then invert and apply to test data ...
        % basically, you solved for the weights using the
        % training data, now you have some observed data from the test set, and
        % you want to infer the channel response profile (tuning function) that
        % best maps the known selectivity of each voxel (w) to the test data
        x = (w'\tst')'; % reconstructed channel responses - or (inv(w*w')*w*tst')'

        % stack up the channel responses from each iteration
        %chan((ff-1)*nDirections+1:ff*nDirections,:) = x;
        chan = [chan; x];
        chan_tstg = [chan_tstg; tstg];

        % Decoded direction per test trial: circular mean of the channel
        % peaks weighted by the reconstructed channel responses. circ_mean
        % is told to work along dimension 2 so each trial (row) yields one
        % angle; its default dimension is 1.
        pred_direction = mod(rad2deg(circ_mean(repmat(peaks_rad, size(x,1), 1), x, 2)), 360);
        total_pred_direction = [total_pred_direction; pred_direction];
    end
    % From here on w, x, tstg and pred_direction hold the values of the LAST
    % iteration only; chan, chan_tstg and total_pred_direction cover all.

    %% Visualize Results %%
    nplot = ceil(sqrt(length(roi))); % subplot grid is nplot x nplot, one panel per ROI

    %% Plot channel weights (w) for a sample voxel
    figure(3)
    subplot(nplot,nplot,whichRoi)
    whichVoxel = min(10, size(w,2)); % Pick a sample voxel (clamped to the ROI size)
    bar(w(:,whichVoxel))
    xlabel('Channel (deg)')
    ylabel('Weight')
    xlim([.5 nChans+.5])
    set(gca, 'xtick', 1:nChans, 'xticklabel', channel_peaks)
    title(sprintf('%s: channel weights (voxel %d)', roi{whichRoi}, whichVoxel))

    %% Plot combined channel response for test data
    figure(5)
    subplot(nplot,nplot,whichRoi)
    hold on
    % Mean reconstructed response per presented direction.
    % rows: presented directions, columns: channels
    meanchan = grpstats(chan, chan_tstg);
    % One line per presented direction across channels; channel 1 is
    % repeated at 360 deg. Transposed so plot() always draws columns.
    ph = plot([channel_peaks 360]', [meanchan meanchan(:,1)]', 'o-');
    % plot([0 360],[mean(mean(chan)) mean(mean(chan))],'k:'); % mean response
    % format figure
    if whichRoi == length(roi)
        lh = legend(cellfun(@num2str,num2cell(xs), 'UniformOutput',false));
        set(lh, 'Location', 'northeastoutside')
        title(lh,'Presented direction') % only valid where the legend exists
    end
    title([roi(whichRoi) 'Mean Reconstructed channel response'])
    xlabel('Direction Channel (deg)')
    ylabel('Estimated Response')
    xticks(0:45:360)
    xlim([0 360])
    % % Compute average reconstructed channel response
    % for ii = 1:nDirections
    % mean_chan(ii,:) = mean(chan(ii:nDirections:nTrials,:));
    % end
    %
    % figure; hold on
    % plot([xs 360],[mean_chan mean_chan(:,1)])
    % title('Mean channel response (averaged over hold-outs)')
    % xlabel('Direction (deg)')
    % ylabel('Channel response')
    % xticks([0:45:360])
    % xlim([0 360])
    % lh = legend(cellfun(@num2str,num2cell(xs), 'UniformOutput',false));
    % set(lh, 'Location', 'northeastoutside')
    % title(lh,'Presented direction')

    %% Plot combined channel response for test data (last iteration)
    % Response as a function of presented direction, grouped by stimulus
    % x - reconstructed channel response
    % tstg - presented direction label
    figure(9)
    hold on
    for ii = 1:length(tstg)
        % basis functions weighted by this trial's channel responses, summed over channels
        plot(xs, sum(x(ii,:)'.*bf))
    end
    legend(num2str(tstg))
    xlabel('Direction (deg)')
    ylabel('Combined Channel Response')

    %% Predictions based on weighted sum of channel responses (last iteration)
    figure(20)
    subplot(nplot,nplot,whichRoi)
    scatter(reshape(xs(tstg),[],1), pred_direction(:))
    lsline
    xlabel('Presented direction label')
    ylabel('Predicted direction label')
    xlim([0 360])
    ylim([0 360])
    title(roi(whichRoi))

    %% Make predictions (based on correlation between observed and predicted
    % reconstruction)
    %% Plot confusion matrix (presented vs decoded) over iterations
    % Arg-max decoder: index of the strongest channel on each trial.
    % This is wrong as it predominantly produces estimates that align with
    % the peak of a channel. It is kept for the accuracy printout at the end.
    [~, pred] = max(chan, [], 2);
    % Instead either
    % (1) compute a weighted circular sum of the channel weights and its basis
    % function peak
    % (2) generate predicted channel responses for stimuli on 0:360 in 1 deg
    % increments, and pick the stimulus direction where the resulting curve
    % correlates best with the observed combined weighted channel response
    % figure; hold on
    % plotconfusion(categorical(g),categorical(pred'),'Test')
    % or alternatively, fit von mises to chan and read off mean
    % or do it yourself

    % Option (1): snap every decoded angle to the nearest presented
    % direction and express it as a label in 1..nDirections, the same
    % coding as chan_tstg (label k <-> (k-1)*360/nDirections deg).
    binWidth = 360/nDirections;
    pred_label = mod(round(total_pred_direction ./ binWidth), nDirections) + 1;
    % Fixed class order guarantees an nDirections x nDirections matrix even
    % when a direction is never decoded. rows: presented, columns: decoded
    conmat = confusionmat(chan_tstg, pred_label, 'Order', (1:nDirections)');
    % Percentage of trials of each presented direction (rows sum to 100)
    conmat = 100 .* conmat ./ max(sum(conmat,2), 1);
    % imagesc draws columns along x and rows along y; transpose so that
    % x = presented and y = decoded, as the axis labels say.
    conmat = conmat';
    conmat = [conmat; conmat(1,:)]; % wrap matrix
    conmat = [conmat, conmat(:,1)];
    figure(6)
    subplot(nplot,nplot,whichRoi)
    imagesc(conmat);
    axis xy % first direction at the bottom-left
    axis tight
    % clim = [0 .5]; % upper, lower limits
    % imagesc(conmat, clim);
    xlabel('Presented direction')
    ylabel('Decoded direction')
    arrows = {char(8594), char(8599), char(8593), char(8598), char(8592), char(8601), char(8595), char(8600), char(8594)};
    xticks(1:9)
    xticklabels(arrows)
    yticks(1:9)
    yticklabels(arrows)
    title(roi(whichRoi))
    if whichRoi == length(roi)
        cb = colorbar;
        cb.Label.String = 'Classification performance (%)';
    end
    % % Todo: Make blue/white/red colorbar, with white = chance performance
    % % Or maybe red -> blue with transparency?
    % T = [255, 0, 0 %// red
    % 255, 255, 255 %// white
    % 255, 255, 255 %// white
    % 0, 0, 255]./255; %// blue again -> note that this means values between 161 and 255 will be indistinguishable
    % x = [0; 50; 100; 255];
    % mymap = interp1(x/255,T,linspace(0,1,255));
    %
    % colormap(mymap)
    % % use caxis

    %% Extra stuff follows below
    %% Plot average tuning function for each cross-validated test stimulus
    % Shift the rows (tf on each trial) so that the channel corresponding to the stimulus on
    % each trial is in the middle column
    centerIdx = floor(nChans/2) + 1; % column plotted at 0 deg (5 when nChans = 8)
    schan = zeros(size(chan)); % rebuilt per ROI so no rows from a previous ROI linger
    for ii = 1:size(chan,1)
        schan(ii,:) = circshift(chan(ii,:), centerIdx - chan_tstg(ii), 2); % center on 0 deg channel
    end
    figure(7), hold on
    plot(-180:360/nChans:180, [mean(schan,1), mean(schan(:,1))], 'o-')
    % plot([-180 180],[mean(mean(schan)) mean(mean(schan))],'k:'); % mean response
    xlabel('Distance from Preferred Direction (deg)')
    ylabel('Estimated Response')
    title('Mean Tuning/Channel Response Function')
    xlim([-180 180])
    xticks(-180:45:180)
    if whichRoi == length(roi)
        legend(roi)
    end
    %% Or as a matrix
    % figure
    % imagesc(meanchan)
    % xlabel('Channels');
    % ylabel('Stimuli');
    % colorbar
    % axis xy

    %% How well did we do on the test set?
    % Accuracy of the arg-max decoder (pred); the confusion matrix above
    % uses the circular-mean decoder instead.
    disp(['Overall accuracy: ' num2str(100.*mean(chan_tstg == pred)) '%'])
end