forked from hannahsheahan/context_magnitude
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy paththeoretical_performance.py
105 lines (86 loc) · 4.83 KB
/
theoretical_performance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
Script for calculating and plotting optimal performance under two different policies post-lesion:
1. a local policy that uses the local context mean when judging whether number A > B
2. a global policy that uses the global number mean across all contexts when judging whether A > B
Sheahan, H.*, Luyckx, F.*, Nelli, S., Taupe, C., & Summerfield, C. (2021). Neural
state space alignment for magnitude generalisation in humans and recurrent networks.
Neuron (in press)
Author: Hannah Sheahan, [email protected]
Date: 07/04/2020
Notes: N/A
Issues: N/A
"""
# ---------------------------------------------------------------------------- #
import analysis_helpers as anh
import constants as const
import statistics as stats
import numpy as np
import matplotlib.pyplot as plt
import os
# ---------------------------------------------------------------------------- #
def simulate_theoretical_policies(localxranges=None):
    """Calculate the theoretical performance of an agent under a global and a local policy.

    The agent sees only the current number ``a``, drawn uniformly from a
    context's inclusive range. It must decide whether ``a > b``, where ``b`` is
    drawn uniformly from the same range excluding ``a`` (xA never equals xB).
    The agent compares ``a`` with a reference mean. If ``a`` is at or below the
    mean it bets that ``b`` is larger; otherwise it bets that ``a`` is larger.

    - 'local' policy: the reference is the mean of the current context's range.
    - 'global' policy: the reference is the mean of all numbers pooled across
      every context range (numbers shared by several ranges are counted once
      per range).

    By default the context ranges are the same as in the human and network
    relative magnitude task (the full, low and high ranges from ``constants``).

    Args:
        localxranges: optional list of inclusive ``[low, high]`` integer pairs,
            one per context. When None, the three task ranges from
            ``constants`` are used.

    Returns:
        Tuple ``(numberdiffs, globalnumberdiffs, perf)``. Each is a dict keyed
        by policy name ('global', 'local') that maps context index to a list
        with one entry per number ``a`` in that context's range (ascending):
          - numberdiffs: ``|a - mean of the context's own range|`` (the same
            for both policies; this is the "context distance").
          - globalnumberdiffs: ``|a - pooled global mean|``.
          - perf: probability of a correct response given ``a``.

    Raises:
        ValueError: if a range contains fewer than two numbers, because then
            there is no valid ``b != a`` to compare against.
    """
    print('Simulating theoretical agent performance...')
    if localxranges is None:
        # Ranges of primary targets displayed in each context.
        localxranges = [[const.FULLR_LLIM, const.FULLR_ULIM],
                        [const.LOWR_LLIM, const.LOWR_ULIM],
                        [const.HIGHR_LLIM, const.HIGHR_ULIM]]
    for low, high in localxranges:
        if high - low < 1:
            raise ValueError(
                'each context range needs at least two numbers, got [{}, {}]'.format(low, high))
    # Average over the ranges actually simulated, so the mean and the dict
    # sizes stay consistent with localxranges.
    ncontexts = len(localxranges)

    # Pool every number from every context for the global mean.
    globalxrange = [i for low, high in localxranges for i in range(low, high + 1)]
    globalmean = stats.mean(globalxrange)  # should be 8.5 for the task's default ranges

    # Record performance as a function of distance between current number and context (or global) mean.
    policies = ['global', 'local']
    numberdiffs = {policy: {r: [] for r in range(ncontexts)} for policy in policies}
    globalnumberdiffs = {policy: {r: [] for r in range(ncontexts)} for policy in policies}
    perf = {policy: {r: [] for r in range(ncontexts)} for policy in policies}

    for policy in policies:
        print('Testing policy: ' + policy)
        Ptotal = 0
        for whichrange, (xmin, xmax) in enumerate(localxranges):
            # Possible values for xA and xB
            xvalues = list(range(xmin, xmax + 1))
            localmean = stats.mean(xvalues)
            xmean = localmean if policy == 'local' else globalmean
            Na = len(xvalues)
            Nb = Na - 1  # xA never equals xB
            P_a = 1 / Na  # uniform distribution for sampling a from xA
            Pcorrect = 0
            for a in xvalues:
                # Exactly (a - xmin) of the Nb possible b values are smaller than a.
                P_agreaterB = (a - xmin) / Nb
                # At or below the reference mean, bet that b is the larger number.
                Pcorrect_a = (1 - P_agreaterB) if a <= xmean else P_agreaterB
                Pcorrect += P_a * Pcorrect_a
                # Distance of a to its own context mean, and to the global mean.
                numberdiffs[policy][whichrange].append(abs(a - localmean))
                globalnumberdiffs[policy][whichrange].append(abs(a - globalmean))
                perf[policy][whichrange].append(Pcorrect_a)
            print(('{:.2f}% correct for range {}, under policy ' + policy).format(Pcorrect * 100, whichrange))
            Ptotal += Pcorrect
        Ptotal /= ncontexts
        print('Mean performance across all {} ranges with '.format(ncontexts) + policy
              + ' policy: {:.2f}%'.format(Ptotal * 100))
        print('\n')
    return numberdiffs, globalnumberdiffs, perf
# ---------------------------------------------------------------------------- #
def plot_theoretical_predictions(ax, numberdiffs, globalnumberdiffs, perf, whichpolicy):
    """Plot performance under one policy as a function of each number's distance to its context mean.

    One line is drawn per context, coloured with ``const.CONTEXT_COLOURS``.
    The y-axis is fixed to [0.27, 1.03].

    Args:
        ax: matplotlib Axes to draw on.
        numberdiffs: dict policy -> context index -> list of distances of each
            number to its context mean (the x values).
        globalnumberdiffs: accepted for interface compatibility; not used here.
        perf: dict policy -> context index -> list of per-number accuracies
            (the y values), aligned with ``numberdiffs``.
        whichpolicy: panel index 0-4, matching lesion frequencies
            0.0, 0.1, 0.2, 0.3, 0.4. Index 0 plots the 'global' policy and
            1-4 plot the 'local' policy.

    Returns:
        List of the line handles returned by ``ax.plot``, one per context.
    """
    # One entry per lesion frequency: 0.0, 0.1, 0.2, 0.3, 0.4
    policies = ['global', 'local', 'local', 'local', 'local']
    policy = policies[whichpolicy]
    handles = []
    # Iterate over the contexts present in the simulated data.
    for whichrange in range(len(numberdiffs[policy])):
        # NOTE(review): anh.performance_mean presumably averages perf over equal
        # distances and returns (mean_perf, unique_distances) — confirm in analysis_helpers.
        context_perf, context_numberdiffs = anh.performance_mean(numberdiffs[policy][whichrange], perf[policy][whichrange])
        h, = ax.plot(context_numberdiffs, context_perf, color=const.CONTEXT_COLOURS[whichrange])
        handles.append(h)
    ax.set_ylim([0.27, 1.03])
    return handles
# ---------------------------------------------------------------------------- #
if __name__ == '__main__':
    # Run the simulation when executed as a script. It prints per-range and
    # mean accuracies for each policy; the returned dicts are not used here.
    simulate_theoretical_policies()