-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#147 addition of new discount function model
- Loading branch information
Ben Vincent
committed
Nov 18, 2016
1 parent
1e3bec2
commit 96a88ce
Showing
14 changed files
with
443 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
classdef DF_ExponentialPower < DiscountFunction | ||
%DF_ExponentialPower The classic 1-parameter discount function | ||
|
||
properties (Dependent) | ||
|
||
end | ||
|
||
methods (Access = public) | ||
|
||
function obj = DF_ExponentialPower(varargin) | ||
obj = obj@DiscountFunction(); | ||
|
||
obj.theta.k = Stochastic('k'); | ||
obj.theta.tau = Stochastic('tau'); | ||
|
||
% Input parsing ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
p = inputParser; | ||
p.StructExpand = false; | ||
p.addParameter('samples',struct(), @isstruct) | ||
p.parse(varargin{:}); | ||
|
||
fieldnames = fields(p.Results.samples); | ||
% Add any provided samples | ||
for n = 1:numel(fieldnames) | ||
obj.theta.(fieldnames{n}).addSamples( p.Results.samples.(fieldnames{n}) ); | ||
end | ||
% ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
end | ||
|
||
|
||
function plot(obj) | ||
x = [1:365]; | ||
|
||
% don't plot if we've been given NaN's | ||
if any(isnan(obj.theta.k.samples)) | ||
warning('Not plotting due to NaN''s') | ||
return | ||
end | ||
|
||
% TODO | ||
discountFraction = obj.eval(x, 'nExamples', 100); | ||
|
||
try | ||
plot(x, discountFraction, '-', 'Color',[0.5 0.5 0.5 0.1]) | ||
catch | ||
% backward compatability | ||
plot(x, discountFraction, '-', 'Color',[0.5 0.5 0.5]) | ||
end | ||
|
||
xlabel('delay $D^B$', 'interpreter','latex') | ||
ylabel('discount factor', 'interpreter','latex') | ||
set(gca,'Xlim', [0 max(x)]) | ||
box off | ||
axis square | ||
|
||
% ~~~~~~~~~~~~~ | ||
obj.data.plot() | ||
% ~~~~~~~~~~~~~ | ||
end | ||
|
||
|
||
|
||
|
||
% function discountFraction = eval(obj, x, varargin) | ||
% % evaluate the discount fraction : | ||
% % - at the delays (x.delays) | ||
% % - given the onj.parameters | ||
% | ||
% p = inputParser; | ||
% p.addRequired('x', @isnumeric); | ||
% p.addParameter('nExamples', [], @isscalar); | ||
% p.parse(x, varargin{:}); | ||
% | ||
% n_samples_requested = p.Results.nExamples; | ||
% n_samples_got = numel(obj.theta.k.samples); | ||
% n_samples_to_get = min([n_samples_requested n_samples_got]); | ||
% if ~isempty(n_samples_requested) | ||
% % shuffle the deck and pick the top nExamples | ||
% shuffledExamples = randperm(n_samples_to_get); | ||
% ExamplesToPlot = shuffledExamples([1:n_samples_to_get]); | ||
% else | ||
% ExamplesToPlot = 1:n_samples_to_get; | ||
% end | ||
% | ||
% if verLessThan('matlab','9.1') | ||
% discountFraction = (bsxfun(@times,... | ||
% exp( - obj.theta.k.samples(ExamplesToPlot)),... | ||
% x) ); | ||
% else | ||
% % use new array broadcasting in 2016b | ||
% discountFraction = exp( - obj.theta.k.samples(ExamplesToPlot) .* x ); | ||
% end | ||
% end | ||
|
||
end | ||
|
||
methods (Static, Access = protected) | ||
|
||
function y = function_evaluation(x, theta, ExamplesToPlot) | ||
k = theta.k.samples(ExamplesToPlot); | ||
tau = theta.tau.samples(ExamplesToPlot); | ||
if verLessThan('matlab','9.1') | ||
y = (bsxfun(@times, exp(-k), x) ); | ||
else | ||
% use new array broadcasting in 2016b | ||
y = exp( - k .* x.^tau ); | ||
end | ||
end | ||
|
||
end | ||
|
||
end |
113 changes: 113 additions & 0 deletions
113
ddToolbox/ModelClasses/bens_new_model/ExponentialPower.m
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
classdef (Abstract) ExponentialPower < Parametric | ||
|
||
properties (Access = private) | ||
getDiscountRate % function handle | ||
end | ||
|
||
methods (Access = public) | ||
|
||
function obj = ExponentialPower(data, varargin) | ||
obj = obj@Parametric(data, varargin{:}); | ||
|
||
obj.dfClass = @DF_ExponentialPower; | ||
|
||
% Create variables | ||
obj.varList.participantLevel = {'k','tau','alpha','epsilon'}; | ||
obj.varList.monitored = {'k','tau','alpha','epsilon', 'Rpostpred', 'P', 'VA', 'VB'}; | ||
|
||
%% Plotting | ||
obj.plotFuncs.clusterPlotFunc = @plotExpPowerclusters; | ||
|
||
end | ||
|
||
function conditionalDiscountRates(obj, reward, plotFlag) | ||
error('Not applicable to this model') | ||
end | ||
|
||
function conditionalDiscountRates_GroupLevel(obj, reward, plotFlag) | ||
error('Not applicable to this model') | ||
end | ||
|
||
function experimentPlot(obj) | ||
|
||
names = obj.data.getIDnames('all'); | ||
|
||
for ind = 1:numel(names) | ||
fh = figure('Name', ['participant: ' names{ind}]); | ||
latex_fig(12, 10, 3) | ||
|
||
%% Set up psychometric function | ||
psycho = PsychometricFunction('samples', obj.coda.getSamplesAtIndex(ind,{'alpha','epsilon'})); | ||
|
||
%% plot bivariate distribution of alpha, epsilon | ||
subplot(1,4,1) | ||
samples = obj.coda.getSamplesAtIndex(ind,{'alpha','epsilon'}); | ||
mcmc.BivariateDistribution(... | ||
samples.epsilon(:),... | ||
samples.alpha(:),... | ||
'xLabel','error rate, $\epsilon$',... | ||
'ylabel','comparison accuity, $\alpha$',... | ||
'pointEstimateType',obj.pointEstimateType,... | ||
'plotStyle', 'hist',... | ||
'axisSquare', true); | ||
|
||
%% Plot the psychometric function | ||
subplot(1,4,2) | ||
psycho.plot() | ||
|
||
%% Set up discount function | ||
ksamples = obj.coda.getSamplesAtIndex(ind,{'k','tau'}); | ||
% don't plot if we don't have any samples. This is expected | ||
% to happen if we are currently looking at the group-level | ||
% unobserved participant and we are analysing a model | ||
% without group level inferences (ie the mixed or separate | ||
% models) | ||
discountFunction = DF_ExponentialPower('samples', ksamples ); | ||
% add data: TODO: streamline this on object creation ~~~~~ | ||
% NOTE: we don't have data for group-level | ||
data_struct = obj.data.getExperimentData(ind); | ||
data_object = DataFile(data_struct); | ||
discountFunction.data = data_object; | ||
% ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
% TODO: this checking needs to be implemented in a | ||
% smoother, more robust way | ||
if ~isempty(ksamples) || ~any(isnan(ksamples)) | ||
%% plot distribution of (k, tau) | ||
subplot(1,4,3) | ||
%discountFunction.plotParameters() | ||
samples = obj.coda.getSamplesAtIndex(ind,{'k','tau'}); | ||
mcmc.BivariateDistribution(... | ||
samples.k(:),... | ||
samples.tau(:),... | ||
'xLabel','discount rate, $k$',... | ||
'ylabel','time exponent, $\tau$',... | ||
'pointEstimateType',obj.pointEstimateType,... | ||
'plotStyle', 'hist',... | ||
'axisSquare', true); | ||
|
||
%% plot discount function | ||
subplot(1,4,4) | ||
discountFunction.plot() | ||
end | ||
|
||
|
||
if obj.shouldExportPlots | ||
myExport(obj.savePath, 'expt',... | ||
'prefix', names{ind},... | ||
'suffix', obj.modelFilename,... | ||
'formats', {'png'}); | ||
end | ||
|
||
close(fh) | ||
end | ||
end | ||
|
||
end | ||
|
||
|
||
methods (Abstract) | ||
initialiseChainValues | ||
end | ||
|
||
end |
36 changes: 36 additions & 0 deletions
36
ddToolbox/ModelClasses/bens_new_model/ModelHierarchicalExpPower.m
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
classdef ModelHierarchicalExpPower < ExponentialPower | ||
%ModelHierarchical A model to estimate the log discount rate, according to the 1-parameter hyperbolic discount function. | ||
% All parameters are estimated hierarchically. | ||
|
||
methods (Access = public) | ||
|
||
function obj = ModelHierarchicalExp1(data, varargin) | ||
obj = obj@ExponentialPower(data, varargin{:}); | ||
obj.modelFilename = 'hierarchicalExpPower'; | ||
obj = obj.addUnobservedParticipant('GROUP'); | ||
|
||
% MUST CALL THIS METHOD AT THE END OF ALL MODEL-SUBCLASS CONSTRUCTORS | ||
obj = obj.conductInference(); | ||
end | ||
|
||
end | ||
|
||
|
||
methods | ||
|
||
function initialParams = initialiseChainValues(obj, nchains) | ||
% Generate initial values of the root nodes | ||
for chain = 1:nchains | ||
initialParams(chain).groupKmu = normrnd(0.001,0.1); | ||
initialParams(chain).groupKsigma = rand*5; | ||
initialParams(chain).groupW = rand; | ||
initialParams(chain).groupALPHAmu = rand*10; | ||
initialParams(chain).groupALPHAsigma = rand*5; | ||
|
||
% TODO: prior over group-level tau parameters | ||
end | ||
end | ||
|
||
end | ||
|
||
end |
29 changes: 29 additions & 0 deletions
29
ddToolbox/ModelClasses/bens_new_model/ModelMixedExpPower.m
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
classdef ModelMixedExpPower < ExponentialPower | ||
%ModelMixedExp1 | ||
|
||
methods (Access = public) | ||
function obj = ModelMixedExp1(data, varargin) | ||
obj = obj@ExponentialPower(data, varargin{:}); | ||
obj.modelFilename = 'mixedExpPower'; | ||
obj = obj.addUnobservedParticipant('GROUP'); | ||
|
||
% MUST CALL THIS METHOD AT THE END OF ALL MODEL-SUBCLASS CONSTRUCTORS | ||
obj = obj.conductInference(); | ||
end | ||
end | ||
|
||
methods | ||
|
||
function initialParams = initialiseChainValues(obj, nchains) | ||
% Generate initial values of the root nodes | ||
nExperimentFiles = obj.data.getNExperimentFiles(); | ||
for chain = 1:nchains | ||
initialParams(chain).groupW = rand; | ||
initialParams(chain).groupALPHAmu = rand*100; | ||
initialParams(chain).groupALPHAsigma = rand*100; | ||
end | ||
end | ||
|
||
end | ||
|
||
end |
32 changes: 32 additions & 0 deletions
32
ddToolbox/ModelClasses/bens_new_model/ModelSeparateExpPower.m
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
classdef ModelSeparateExpPower < ExponentialPower | ||
|
||
|
||
methods (Access = public) | ||
|
||
function obj = ModelSeparateExpPower(data, varargin) | ||
obj = obj@ExponentialPower(data, varargin{:}); | ||
obj.modelFilename = 'separateExpPower'; | ||
|
||
% MUST CALL THIS METHOD AT THE END OF ALL MODEL-SUBCLASS CONSTRUCTORS | ||
obj = obj.conductInference(); | ||
end | ||
|
||
end | ||
|
||
|
||
methods | ||
|
||
function initialParams = initialiseChainValues(obj, nchains) | ||
% Generate initial values of the root nodes | ||
nExperimentFiles = obj.data.getNExperimentFiles(); | ||
for chain = 1:nchains | ||
initialParams(chain).k = unifrnd(0, 0.5, [nExperimentFiles,1]); | ||
initialParams(chain).tau = unifrnd(0.01, 2, [nExperimentFiles,1]); | ||
initialParams(chain).epsilon = 0.1 + rand([nExperimentFiles,1])/10; | ||
initialParams(chain).alpha = abs(normrnd(0.01,10,[nExperimentFiles,1])); | ||
end | ||
end | ||
|
||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Ben's new discounting function | ||
|
||
As far as I know this new discount function is novel. The discount fraction is described by | ||
|
||
discount fraction = exp(-k.delay^tau) |
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Oops, something went wrong.