-
Notifications
You must be signed in to change notification settings - Fork 0
/
test4.py
86 lines (72 loc) · 2.34 KB
/
test4.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import numpy as np
from numba import njit
import matplotlib.pyplot as plt
from src.utils import AMP_points_runner, check_saved, bayes_optimal_runner, load_file
# Experiment constants.
# Of these, the ``__main__`` block reads only ``d`` (passed as "n_features"
# to the AMP runs) and ``deltas_large`` (one BO curve and one AMP curve per
# value). The others are kept for alternative experiment configurations.
n, d = 100, 1000
delta = delta_small = 0.5
delta_large = 3.0
deltas_large = [1.0, 3.0, 5.0]
# NOTE(review): presumably the mixture fraction passed as "percentage" in the
# disabled alternative settings — confirm.
eps = 0.1
# NOTE(review): not read by the experiments here; "alpha_pts" is hard-coded
# per experiment instead.
n_alpha_points = 15
def _load_or_compute(settings, runner):
    """Load the saved results described by ``settings``, computing them first if absent.

    ``check_saved(**settings)`` returns ``(file_exists, file_path)``. When no
    saved file exists, ``runner(**settings)`` is called to produce it. The
    path is then stored in ``settings["file_path"]`` (the dict is mutated in
    place) and the results are read back with ``load_file(**settings)``.

    Both the BO and the AMP experiments go through this one helper so that
    each attaches the resolved path to its *own* settings dict.

    Returns:
        Whatever ``load_file`` returns for these settings. For BO settings
        the caller unpacks ``(alphas, errors)``; for AMP settings it unpacks
        ``(alphas, error_means, error_stds)``.
    """
    file_exists, file_path = check_saved(**settings)
    if not file_exists:
        runner(**settings)
    settings["file_path"] = file_path
    return load_file(**settings)


if __name__ == "__main__":
    # Script: for every value in ``deltas_large``, load (or first compute) the
    # Bayes-optimal curve and the GAMP experimental points, then plot them
    # together on log-log axes.
    AMP_experimental_settings = [
        {
            "alpha_min": 0.5,
            "alpha_max": 100,
            "alpha_pts": 20,
            "repetitions": 10,
            "n_features": d,
            "delta": dl,
            # Alternative settings, currently disabled:
            # "delta_small": delta_small, "delta_large": dl, "percentage": eps,
            "experiment_type": "GAMP",
        }
        for dl in deltas_large
    ]
    BO_settings = [
        {
            "alpha_min": 0.01,
            "alpha_max": 100,
            "alpha_pts": 40,
            "delta": dl,
            # Alternative settings, currently disabled:
            # "delta_small": delta_small, "delta_large": dl, "percentage": eps,
            "experiment_type": "BO",
        }
        for dl in deltas_large
    ]

    # One (delta, BO results, AMP results) entry per value of deltas_large.
    results = []
    for amp_dict, bo_dict, dl in zip(
        AMP_experimental_settings, BO_settings, deltas_large
    ):
        print("Doing dl: ", dl)
        bo_results = _load_or_compute(bo_dict, bayes_optimal_runner)
        amp_results = _load_or_compute(amp_dict, AMP_points_runner)
        results.append((dl, bo_results, amp_results))

    for dl, (a_bo, e_bo), (a_amp, em_amp, es_amp) in results:
        (bo_line,) = plt.plot(a_bo, e_bo, label=f"BO, delta = {dl}")
        # AMP points are drawn as markers with error bars in the same colour
        # as the BO curve for this delta, so the pair is easy to match.
        plt.errorbar(
            a_amp,
            em_amp,
            yerr=es_amp,
            fmt=".",
            color=bo_line.get_color(),
            label=f"AMP, delta = {dl}",
        )

    plt.xscale("log")
    plt.yscale("log")
    plt.xlabel(r"$\alpha$")
    plt.ylabel("error")
    plt.legend()
    plt.show()