-
Notifications
You must be signed in to change notification settings - Fork 3
/
active_search.py
451 lines (356 loc) · 13.6 KB
/
active_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
import numpy as np
import scipy.special as sc
import scipy as sp
import pystan
import pickle
from enum import Enum
from matplotlib import pyplot as plt
"""
Implementation of InfoGain, EPMV, MCMV methods. See main() for
usage.
"""
class ActiveSearcher():
    """
    search object for InfoGain, EPMV, and MCMV methods.
    Callable methods:
    - initialize: initializes search object
    - getQuery: actively selects next search pair
    - getEstimate: produces user point estimate
    """

    # Stan model: Bernoulli-logit observation model for paired comparisons,
    # with a uniform prior on the user point W inside a D-dimensional
    # hypercube. Compiled once in __init__ and cached to disk.
    my_model = """
    data {
        int<lower=0> D; // space dimension
        int<lower=0> M; // number of measurements so far
        real k; // logistic noise parameter (scale)
        vector[2] bounds; // hypercube bounds [lower,upper]
        int y[M]; // measurement outcomes
        vector[D] A[M]; // hyperplane directions
        vector[M] tau; // hyperplane offsets
    }
    parameters {
        vector<lower=bounds[1],upper=bounds[2]>[D] W; // the user point
    }
    transformed parameters {
        vector[M] z;
        for (i in 1:M)
            z[i] = dot_product(A[i], W) - tau[i];
    }
    model {
        // prior
        W ~ uniform(bounds[1],bounds[2]);
        // linking observations
        y ~ bernoulli_logit(k * z);
    }
    """

    def __init__(self):
        """loads a cached compiled Stan model, or compiles and caches one"""
        # Stan compilation is slow, so the compiled model is pickled to
        # 'model.pkl' and reused across runs.
        # FIX: the original used a bare `except:` (which also swallows
        # SystemExit/KeyboardInterrupt) and leaked the handle opened by
        # pickle.load(open(...)); use `with` and catch only cache failures.
        try:
            with open('model.pkl', 'rb') as f:
                self.sm = pickle.load(f)
        except (OSError, pickle.UnpicklingError, EOFError, AttributeError):
            self.sm = pystan.StanModel(model_code=self.my_model)
            with open('model.pkl', 'wb') as f:
                pickle.dump(self.sm, f)

    def initialize(self, embedding, k, normalization, method,
                   bounds=np.array([-1, 1]), Nchains=4, Nsamples=4000,
                   pair_sample_rate=10**-3, plotting=False, plot_pause=0.5,
                   scale_to_embedding=False, ref=None, lambda_pen_MCMV=1,
                   lambda_pen_EPMV=None):
        """
        arguments:
            embedding: np.array - an N x d embedding of points
            k: noise constant value
            normalization: model normalization scheme (KNormalizationType)
            method: pair selection method (AdaptType)
        optional arguments:
            bounds: hypercube lower and upper bounds [lb, ub]
            Nchains: number of sampling chains
            Nsamples: number of posterior samples
            pair_sample_rate: downsample rate for pair selection
            plotting settings:
                plotting: plotting flag (bool)
                plot_pause: pause time between plots, in seconds
                scale_to_embedding: if True, scale plot to embedding
                ref: np.array - d x 1 user point vector
            lambda_pen_MCMV: lambda penalty for MCMV method
            lambda_pen_EPMV: lambda penalty for EPMV method
                (defaults to sqrt(d) when None)
        """
        # NOTE: the mutable default `bounds` is stored but never mutated
        # below, so the shared-default pitfall does not apply here.
        self.embedding = embedding
        self.k = k
        self.method = method
        self.normalization = normalization
        self.bounds = bounds
        self.Nchains = Nchains
        self.Nsamples = Nsamples
        # Stan discards the first half of each chain as warmup, so
        # iter = 2 * Nsamples / Nchains yields Nsamples posterior draws.
        Niter = int(2*Nsamples/Nchains)
        assert Niter >= 1000
        self.Niter = Niter
        self.N = embedding.shape[0]
        # number of candidate pairs scored per query, downsampled from
        # all N-choose-2 possible pairs
        self.Npairs = int(pair_sample_rate * sp.special.comb(self.N, 2))
        self.D = embedding.shape[1]
        self.oracle_queries_made = []
        self.mu_W = np.zeros(self.D)  # posterior-mean user point estimate
        self.A = []      # hyperplane directions of past measurements
        self.tau = []    # hyperplane offsets of past measurements
        self.y_vec = []  # past measurement outcomes
        self.plotting = plotting
        self.plot_pause = plot_pause
        self.scale_to_embedding = scale_to_embedding
        self.ref = ref
        self.lambda_pen_MCMV = lambda_pen_MCMV
        if lambda_pen_EPMV is None:
            self.lambda_pen_EPMV = np.sqrt(self.D)
        else:
            self.lambda_pen_EPMV = lambda_pen_EPMV

    def getQuery(self, oracle):
        """
        selects a pair for searching and queries the oracle with it
        arguments:
            oracle: function accepting a query pair p = (i, j) of item
                indices and returning a dict with key 'y', where y=1
                selects p[0] and y=0 selects p[1]
        outputs:
            query: tuple (i, j) of the selected pair
            oracle_output: output of oracle function
        """
        # given measurements 0..i, get posterior samples
        if not self.A:
            # no measurements yet: sample directly from the uniform prior
            W_samples = np.random.uniform(
                self.bounds[0], self.bounds[1], (self.Nsamples, self.D))
        else:
            data_gen = {'D': self.D, 'k': self.k, 'M': len(self.A),
                        'A': self.A,
                        'tau': self.tau,
                        'y': self.y_vec,
                        'bounds': self.bounds}
            # get posterior samples
            # num_samples = iter * chains / 2, unless warmup is changed
            fit = self.sm.sampling(data=data_gen, iter=self.Niter,
                                   chains=self.Nchains, init=0)
            W_samples = fit.extract()['W']
            if W_samples.ndim < 2:
                # 1-D user points come back as a flat vector; restore 2-D
                W_samples = W_samples[:, np.newaxis]
        assert W_samples.shape == (self.Nsamples, self.D)
        self.mu_W = np.mean(W_samples, 0)

        # generate and evaluate a batch of proposal pairs
        Pairs = self.get_random_pairs(self.N, self.Npairs)
        if self.method == AdaptType.INFOGAIN:
            # pick the pair with maximal estimated mutual information
            value = np.zeros((self.Npairs,))
            for j in range(self.Npairs):
                p = Pairs[j]
                (A_emb, tau_emb) = pair2hyperplane(
                    p, self.embedding, self.normalization)
                value[j] = self.evaluate_pair(
                    A_emb, tau_emb, W_samples, self.k)
            p = Pairs[np.argmax(value)]
        elif self.method == AdaptType.MCMV:
            Wcov = np.cov(W_samples, rowvar=False)
            value = np.zeros((self.Npairs,))
            for j in range(self.Npairs):
                p = Pairs[j]
                (A_emb, tau_emb) = pair2hyperplane(
                    p, self.embedding, self.normalization)
                # posterior variance along the hyperplane normal
                varest = np.dot(A_emb, Wcov).dot(A_emb)
                # distance from posterior mean to the hyperplane
                distmu = np.abs(
                    (np.dot(A_emb, self.mu_W) - tau_emb)
                    / np.linalg.norm(A_emb)
                )
                # choose highest variance, but smallest distance to mean
                value[j] = self.k * \
                    np.sqrt(varest) - self.lambda_pen_MCMV * distmu
            p = Pairs[np.argmax(value)]
        elif self.method == AdaptType.EPMV:
            Wcov = np.cov(W_samples, rowvar=False)
            value = np.zeros((self.Npairs,))
            for j in range(self.Npairs):
                p = Pairs[j]
                (A_emb, tau_emb) = pair2hyperplane(
                    p, self.embedding, self.normalization)
                assert np.dot(A_emb, W_samples.T).size == self.Nsamples
                varest = np.dot(A_emb, Wcov).dot(A_emb)
                # p1: posterior predictive probability that p[0] is chosen
                p1 = np.mean(sp.special.expit(
                    self.k*(np.dot(A_emb, W_samples.T) - tau_emb)
                ))
                assert p1.size == 1
                # trade off predictive variance against query ambiguity
                value[j] = (
                    self.k * np.sqrt(varest)
                    - self.lambda_pen_EPMV * np.abs(p1 - 0.5))
            p = Pairs[np.argmax(value)]
        else:  # random pair method
            p = Pairs[0]

        # record the selected hyperplane and the oracle's answer
        (A_sel, tau_sel) = pair2hyperplane(
            p, self.embedding, self.normalization)
        self.A.append(A_sel)
        self.tau = np.append(self.tau, tau_sel)
        oracle_out = oracle(p)
        y = oracle_out['y']
        self.y_vec.append(y)
        self.oracle_queries_made.append(p)

        # for plotting during experiment
        if self.plotting:
            # diagnostic: samples lying on the p[0] side of the hyperplane
            Nsplit = 0
            Isplit = []
            for j in range(1, W_samples.shape[0]):
                z = np.dot(A_sel, W_samples[j, :]) - tau_sel
                if z > 0:
                    Isplit.append(j)
                    Nsplit += 1
            plt.figure(189)
            plt.clf()
            if self.scale_to_embedding:
                ax_min = np.min(self.embedding[:, 0])
                ax_max = np.max(self.embedding[:, 0])
            else:
                ax_min = self.bounds[0]
                ax_max = self.bounds[1]
            if self.D == 1:
                # 1-D embedding: plot against a dummy y axis
                y_samples = np.zeros(self.Nsamples)
                y_p0 = 0
                y_p1 = 0
                ay_min = -1
                ay_max = 1
                y_ref = 0
            else:
                y_samples = W_samples[:, 1]
                y_p0 = self.embedding[p[0], 1]
                y_p1 = self.embedding[p[1], 1]
                if self.scale_to_embedding:
                    ay_min = np.min(self.embedding[:, 1])
                    ay_max = np.max(self.embedding[:, 1])
                else:
                    ay_min = self.bounds[0]
                    ay_max = self.bounds[1]
                if self.ref is not None:
                    y_ref = self.ref[1]
            plt.axis([ax_min, ax_max, ay_min, ay_max])
            plt.plot(W_samples[:, 0], y_samples, 'y.')
            plt.plot(W_samples[Isplit, 0], y_samples[Isplit], 'r.')
            if self.ref is not None:
                plt.plot(self.ref[0], y_ref, 'go')
            plt.plot(self.embedding[p[0], 0], y_p0, 'bo')
            plt.plot(self.embedding[p[1], 0], y_p1, 'bo')
            plt.ion()
            plt.pause(self.plot_pause)  # for observation
        return p, oracle_out

    def getEstimate(self):
        """
        returns estimate of user point (posterior mean) as d x 1 np.array
        """
        return self.mu_W

    def evaluate_pair(self, a, tau, W_samples, k):
        """
        estimates mutual information of input pair; larger is better
        NOTE: each row of W_samples is a posterior sample of the user point
        """
        Lik = self.likelihood_vec(a, tau, W_samples, k)
        Ftilde = np.mean(Lik)
        # I(y; W) = H(mean likelihood) - mean(H(likelihood))
        mutual_info = self.binary_entropy(Ftilde) - np.mean(
            self.binary_entropy(Lik))
        return mutual_info

    def likelihood_vec(self, a, tau, W, k):
        """
        per-sample probability of outcome y=1 under the logistic model,
        e.g. a: (d,), tau: scalar, W: (Nsamples, d) -> (Nsamples,)
        """
        z = np.dot(W, a) - tau  # broadcasting
        return sp.special.expit(k * z)

    def binary_entropy(self, x):
        """entropy, in bits, of a Bernoulli(x) outcome; vectorized in x"""
        # xlogy/xlog1py return 0 (not nan) at x = 0 and x = 1
        return -(sc.xlogy(x, x) + sc.xlog1py(1 - x, -x))/np.log(2)

    def get_random_pairs(self, N, M):
        """draws M random index pairs (i, j) with i != j from N items"""
        # oversample by 1.5x, then drop self-pairs; the assertion guards the
        # (unlikely) case that too many self-pairs were drawn
        indices = np.random.choice(N, (int(1.5*M), 2))
        indices = [(i[0], i[1]) for i in indices if i[0] != i[1]]
        assert len(indices) >= M
        return indices[0:M]
class AdaptType(Enum):
    """enumerates pair-selection (experiment) methods"""
    RANDOM = 0
    INFOGAIN = 1
    MCMV = 2
    EPMV = 3
    ACTRANKQ = 4
class KNormalizationType(Enum):
    """enumerates noise constant normalization types"""
    CONSTANT = 0
    NORMALIZED = 1
    DECAYING = 2
class NoiseModel(Enum):
    """enumerates oracle noise model types"""
    BT = 0
    NONE = 1
def pair2hyperplane(p, embedding, normalization, slice_point=None):
    """
    converts an item pair into hyperplane weights and bias
    arguments:
        p: tuple (i, j) of item indices
        embedding: np.array - an N x d embedding of points
        normalization: a KNormalizationType value
        slice_point: optional point the hyperplane should pass through;
            when None, the bisecting hyperplane of the pair is used
    returns:
        (A_emb, tau_emb): hyperplane direction and offset
    """
    xi = embedding[p[0], :]
    xj = embedding[p[1], :]
    A_emb = 2*(xi - xj)
    if slice_point is not None:
        tau_emb = np.dot(A_emb, slice_point)
    else:
        tau_emb = (np.linalg.norm(xi)**2
                   - np.linalg.norm(xj)**2)
    if normalization == KNormalizationType.NORMALIZED:
        # unit-normal hyperplane
        A_mag = np.linalg.norm(A_emb)
        A_emb = A_emb / A_mag
        tau_emb = tau_emb / A_mag
    elif normalization == KNormalizationType.DECAYING:
        # shrink weights exponentially with hyperplane magnitude
        decay = np.exp(-np.linalg.norm(A_emb))
        A_emb = A_emb * decay
        tau_emb = tau_emb * decay
    # KNormalizationType.CONSTANT leaves the hyperplane unscaled
    return (A_emb, tau_emb)
def main():
    """
    Example usage of ActiveSearcher:
    - generates a random embedding of items
    - defines search parameters
    - defines an oracle
    - initializes the search object
    - asks paired comparison queries
    - reports the user point estimate
    """
    N = 100          # number of items
    d = 2            # embedding dimension
    max_query = 100  # number of queries to ask
    embedding = np.random.randn(N, d)  # generate embedding

    # search configuration: noise constant, its normalization scheme,
    # the oracle's noise model, and the pair-selection method
    k = 10
    k_normalization = KNormalizationType.NORMALIZED
    noise_model = NoiseModel.NONE
    method = AdaptType.MCMV

    def oracle(p):
        # answers a paired comparison; y=1 means item p[0] was selected
        # (`ref` is bound later, before the first call)
        (a, tau) = pair2hyperplane(p, embedding, k_normalization)
        z = np.dot(a, ref) - tau
        if noise_model == NoiseModel.BT:
            y = int(np.random.binomial(1, sp.special.expit(k * z)))
        else:
            y = int(z > 0)
        return {'y': y, 'z': z, 'a': a, 'tau': tau}

    print("Search points: ")
    print(embedding)

    # define user point prior support and draw the hidden reference point
    bounds = [-1, 1]
    ref = np.random.uniform(bounds[0], bounds[1], (d, 1))
    print("Reference point: ")
    print(ref)

    # construct and initialize searcher
    searcher = ActiveSearcher()
    searcher.initialize(embedding, k, k_normalization, method,
                        pair_sample_rate=10**-3, plotting=False, ref=ref,
                        scale_to_embedding=True)

    # query loop: ask up to max_query paired comparisons
    queries_made = 0
    for _ in range(max_query):
        query, response = searcher.getQuery(oracle)
        if query is None:
            break
        queries_made += 1
        print('# queries made: {} / {}'.format(queries_made, max_query))

    # report the user point estimate and its error
    user_estimate = searcher.getEstimate()
    print("Estimated user point: ")
    print(user_estimate)
    err = user_estimate - ref.squeeze()
    print("Error: ")
    print(err)
# script entry point: run the example search
if __name__ == '__main__':
    main()