-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcall_ep_all1.m
243 lines (214 loc) · 9.38 KB
/
call_ep_all1.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
%call_ep_all1
%version 16/10/2019
%
%Event prediction (EP)
%*********************
% This is AP's name for a form of RL where there is no explicit
% reward. EP error (directly akin to reward pred error) is used to train the
% prediction of the n+1th event by the nth event.
%
%This set of Matlab scripts/fxns simulates event prediction learning for tasks where a sequence of
%events occur and a prediction of the upcoming events is required.
%The simulation can use a variety of models,
%such as:
%Classic Q learning
%
%This pgm may call the following functions
%
%evpred1pt<n>.m : this function runs event prediction (EP) learning
% <n> denotes version number
%ballscorer.m : this function scores performance and summarises it
% along with graphs of performance
%
%the next 3 functions are used only when computing the transitions from
%random variables (tasksetup=1 below), and so are not currently used
%trans_prob.m : this function sets up the transition probabilities that are required
%maketrans.m : this function uses the m matrix to set values for use with
% [0,1] uniform random values to control transitions with
% the right probabilities
%choose_first.m : this function chooses the first event of the sequence at random
%
%it also requires the following files to read in trial sequences
%balls_info.xlsx : An Excel file with a quasi-random sequence of balls
%
%at the moment you need to change settings within code in areas marked with
%@@@ -- use search to locate them
%
%to do
%
%a) save the data from subjects to files for later scoring
%b) vary the task that is being simulated (with choice options)
%c) add numerous random response sequences not just one from file, using
% rand numbers, as tasksetup=3
%d) add code to use transition probabilities here (tasksetup=1)
%e) add choice of other models eg Bayesian
%f) create a version to fit simulated data to real performance
%g) add the option for repeated trials where there is no prediction between
% end of one trial and the next
%--- start from a clean workspace and a reproducible random stream ---
clear variables;
clc;
rng('default'); %default generator and seed, so every run gives the same results
%@@@initialise key task features
tasksetup = 2;   %1=generated by transition probabilities between events; 2= sequences read from file
ne        = 3;   %number of events
lenseq    = 270; %length of the event sequence
%seqname='triplets'; %one of the sequences for tasksetup=1
%@@@initialise simulation features
nsubjects = 30;
%checksubs=[2 4]; % this will display various checks on simulated subjects 2 and 4
checksubs = nsubjects + 1; %no subject has this number, so no subject checks are displayed
wantshow  = 0;             %0 (noshow) or 1 (show) trial by trial during learning
%@@@simulation settings, via a structure array
%PREDformat options: 'sm_choice' (choice made using softmax) or
%'nochoice' (no response choices are made)
simsettings = struct( ...
    'RLformat',   'classic', ...   %uses RPE to adjust weights
    'PREDformat', 'sm_choice', ...
    'dolearn',    1);
%simulation parameters
param = struct( ...
    'alpha',  0.5, ...
    'beta',   8, ...
    'basewt', 0.1); %basewt is a trivial parameter setting the size of the initial random weights
%setting up the task sequence in various ways
if tasksetup==1
%create a transitions matrix to generate probabilities of transition from one event to another
%to be added in here
elseif tasksetup==2
%read the fixed event sequence from a file
%1=blue; 2=red; 3=green
mycolour=zeros(lenseq,nsubjects);
[numdata,txtdata,balldata]=xlsread('\balls_info.xlsx');
[nr, nc]=size(numdata);
%read in color of ball seen (same for all subjects)
mycolour(strcmp(txtdata(2:nr+1,2),{'blue'}),:)=1;
mycolour(strcmp(txtdata(2:nr+1,2),{'red'}),:)=2;
mycolour(strcmp(txtdata(2:nr+1,2),{'green'}),:)=3;
%mycolour(lenseq+1,1)=1; %just so we have a next event for the final trial
end;
%preallocate the per-subject output arrays (one column / final slice per subject)
evseq       = zeros(lenseq, nsubjects);       %event seen on each trial
pred_choice = zeros(lenseq, nsubjects);       %index of the largest entry of the choice vector on each trial
pred_error  = zeros([ne, lenseq, nsubjects]); %ne prediction-error values per trial, as returned by evpred1pt2
act_trans   = zeros([ne, ne, nsubjects]);     %counts of actual transitions: row = previous event, column = current event
n_of_ev     = zeros(ne, nsubjects);           %number of occurrences of each event
%trial-by-trial display of learning is selected via evpred1pt2's last argument;
%this choice does not change between subjects, so decide it once
if wantshow
    learnshow='showl';
else
    learnshow='noshowl';
end
for k=1:nsubjects
    disp(['Simulating subject ' num2str(k)]);
    %set up the random small weights from the prediction units to the
    %predictions (ne x ne, scaled by param.basewt)
    pred_wt=param.basewt.*rand(ne,ne);
    %now run the sequence
    for i=1:lenseq
        %no learning possible on final trial as the next event is not seen
        simsettings.dolearn=double(i<lenseq);
        if tasksetup==2
            %column k of mycolour holds the sequence seen by subject k
            old_ev=mycolour(i,k);
            if i<lenseq
                next_ev=mycolour(i+1,k);
            else
                next_ev=1; %arbitrarily set next event to be 1 (dolearn is 0 on this trial)
            end
        else
            %no other sequence source exists yet; stop instead of using an
            %undefined or stale old_ev
            error('call_ep_all1:noSequence', ...
                'No event sequence is available for tasksetup=%g.', tasksetup);
        end
        %activate the prediction unit which codes predictions for the current
        %event, then call the event prediction learning function
        pred_u=zeros(1,ne);
        pred_u(old_ev)=1;
        [fxncheck, pred_wt, delta, pred_ch] = evpred1pt2(ne,old_ev,next_ev,pred_u,pred_wt,param,simsettings,'noshowp',learnshow);
        %record sequence (should match mycolour(:,k))
        evseq(i,k)=old_ev;
        %and record numbers of actual transitions (previous event -> current event)
        if i>1
            act_trans(evseq(i-1,k),evseq(i,k),k)=act_trans(evseq(i-1,k),evseq(i,k),k)+1;
        end
        %and record prediction errors and choices
        pred_error(:,i,k)=delta;
        switch simsettings.PREDformat
            case 'nochoice'
                %pred_ch here should return -1 for every choice
            otherwise
                %the chosen event is the position of the largest entry, which
                %this script expects to be exactly 1
                [maxval, pred_choice(i,k)]=max(pred_ch);
                if maxval ~=1
                    error('call_ep_all1:badChoice', ...
                        'Subject %d, trial %d: largest entry of pred_ch is %g, expected 1.', k, i, maxval);
                end
        end
    %end of event sequence
    end
    %now optionally test the simulation for selected subjects
    if any(checksubs==k)
        disp(['Checking subject #' num2str(k)]);
        %first test the event predictions
        %here all events are tested but may need adjustment if we have sequences in
        %trials
        disp(' ');
        disp('Predictions for next event at the end of training')
        %show the predictions for each event with no learning: dolearn is zero
        %and the updated weights are not written back to pred_wt
        simsettings.dolearn=0;
        for i=1:ne
            pred_u=zeros(1,ne);
            pred_u(i)=1; %activate the prediction unit for event i
            nextev=mod(i,ne)+1; %nominal next event: i+1, wrapping round to 1 after event ne
            [fxncheck, ~, ~, ~] = evpred1pt2(ne,i,nextev,pred_u,pred_wt,param,simsettings,'showp','noshowl');
        end
        %do some housekeeping checks on the sequences
        for i=1:ne
            n_of_ev(i,k)=sum(evseq(:,k)==i);
        end
        %showing sequence details (with file-based sequences, tasksetup=2,
        %every subject sees the same sequence, so these counts are identical)
        disp(' ');
        disp(['Number of events= ' num2str(n_of_ev(:,k)')]);
        disp(' ');
        %and finally we will display the weights; pcolor does not colour the
        %last row and column of its input, so pad with zeros to show every cell
        zz=[pred_wt, zeros(ne,1); zeros(1,ne+1)];
        figure;
        pcolor(zz)
        title({'Colour map of prediction weight matrix',['For subject #' num2str(k)]})
        %and compare it against this subject's actual transition counts
        %may need adjustment when we have trials-based sequences
        zz=[act_trans(:,:,k), zeros(ne,1); zeros(1,ne+1)];
        figure;
        pcolor(zz)
        title({'Colour map of actual transition matrix',['For subject #' num2str(k)]})
    %end of check on specific subjects
    end
%end of subjects loop
end
%call up ball scorer
wantsubprint=0; %1=prints results to screen subject by subject, 0 doesn't
ls_type=1; %1=lose shift; 2=lose shift and use current
[mean_accuracy, mean_winstay, mean_lsh, mean_lshuc]=ballscorer(nsubjects, ne, lenseq, pred_choice, mycolour, wantsubprint, ls_type);
%some summary accuracy stats, one value per ball position
%NOTE(review): this assumes each ballscorer output has ball position along
%dimension 3 (at least ne entries) -- confirm against ballscorer.m.
%Average explicitly over dimensions 1 and 2: mean(mean(X)) without a dim
%argument picks the first non-singleton dimension, so it would also collapse
%the ball dimension whenever dim 1 or dim 2 has length 1 (e.g. nsubjects=1).
x1=mean(mean(mean_accuracy,1),2);
x2=mean(mean(mean_winstay,1),2);
x3=mean(mean(mean_lsh,1),2);
x4=mean(mean(mean_lshuc,1),2);
maccbyball=reshape(x1(1,1,1:ne),ne,1);  %mean accuracy by ball
mwsbyball=reshape(x2(1,1,1:ne),ne,1);   %mean win stay by ball
mlsbyball=reshape(x3(1,1,1:ne),ne,1);   %mean lose shift by ball
mlsucbyball=reshape(x4(1,1,1:ne),ne,1); %mean lose shift use current by ball
disp('Mean accuracy by ball position=');
disp(maccbyball)
disp('Mean win-stay by ball position=');
disp(mwsbyball)
disp('Mean lose-shift by ball position=');
disp(mlsbyball)
disp('Mean lose-shift use current by ball position=');
disp(mlsucbyball)