-
Notifications
You must be signed in to change notification settings - Fork 12
/
selection_heuristics.jl
71 lines (56 loc) · 2.24 KB
/
selection_heuristics.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""
Exception thrown by the generic fallback `losses(heuristic, history)` when a
selection heuristic has no specialized `losses` method.
"""
const ERR_LOSSES = ArgumentError(
    "Tuning selection heuristic does not support losses() function. ")

"""
    SelectionHeuristic

Abstract supertype of tuning selection heuristics (e.g. `NaiveSelection`).
"""
abstract type SelectionHeuristic end

"""
    losses(heuristic, history)

Generic fallback: always throws `ERR_LOSSES` (an `ArgumentError`). Heuristics
that can score history entries provide their own, more specific method.
"""
function losses(heuristic, history)
    throw(ERR_LOSSES)
end
## HELPERS
"""
    measure_adjusted_weights(weights, measures)

Return the per-measure weight vector used to combine a history entry's
aggregated measurements into a single penalty.

- If `weights` is `nothing`, only the first measure counts: the result is
  `signature(measures[1])` followed by `length(measures) - 1` zeros.
- Otherwise throw a `DimensionMismatch` unless `weights` and `measures` have
  equal length, then return `signature.(measures) .* weights`.

`signature` is defined elsewhere; per the `NaiveSelection` docstring it is
expected to be `-1` for scores, `+1` for losses and `0` for other measures
(TODO confirm against its definition).
"""
measure_adjusted_weights(::Nothing, measures) =
    vcat([signature(measures[1])], zeros(length(measures) - 1))

function measure_adjusted_weights(weights, measures)
    # Check before calling `signature`, so a mismatch is reported clearly.
    length(weights) == length(measures) ||
        throw(DimensionMismatch(
            "`NaiveSelection` heuristic is being applied to a list of "*
            "measures whose length ($(length(measures))) differs from "*
            "that of the specified `weights` ($(length(weights))). "))
    return signature.(measures) .* weights
end
## NAIVE SELECTION (OPTIMIZE AGGREGATED MEASUREMENT)
"""
    NaiveSelection(; weights=nothing)

Construct a common selection heuristic for use with `TunedModel` instances
which only considers measurements aggregated over all samples (folds)
in resampling.

For each entry in the tuning history, one defines a penalty equal to
the evaluations of the `measure` specified in the `TunedModel`
instance, aggregated over all samples, and multiplied by `-1` if `measure`
is a `:score`, and `+1` if it is a loss. The heuristic declares as
"best" (optimal) the model whose corresponding entry has the lowest
penalty.

If `measure` is a vector, then the first element is used, unless
per-measure `weights` are explicitly specified. Weights associated
with measures that are neither `:loss` nor `:score` are reset to zero.

When given, `weights` must be a vector of non-negative numbers; otherwise
the constructor throws an error. Its length is checked against the number
of measures when the heuristic is applied.
"""
struct NaiveSelection <: SelectionHeuristic
    weights::Union{Nothing, Vector{Real}}
end

function NaiveSelection(; weights=nothing)
    # Validate every vector-like `weights`, not only `Vector`s. The field's
    # conversion to `Vector{Real}` also accepts other `AbstractVector`s
    # (e.g. the range `-1:1`), which previously bypassed this check.
    if weights isa AbstractVector
        all(>=(0), weights) ||
            error("`weights` must be non-negative. ")
    end
    return NaiveSelection(weights)
end
"""
    losses(heuristic::NaiveSelection, history)

Return one penalty per entry of `history`: the inner product of the adjusted
weights (from `measure_adjusted_weights`) with that entry's `measurement`
vector. The list of measures is read from the first entry only; presumably
all entries share it (TODO confirm).
"""
function losses(heuristic::NaiveSelection, history)
    measures = first(history).measure
    adjusted = measure_adjusted_weights(heuristic.weights, measures)
    penalties = collect(adjusted' * entry.measurement for entry in history)
    return penalties
end
"""
    best(heuristic::NaiveSelection, history)

Return the entry of `history` with the smallest penalty, where penalties
are computed by `losses(heuristic, history)`. Ties go to the earliest entry.

Entries whose penalty is `NaN` or `missing` are ignored. If every penalty
is `NaN`/`missing`, the plain `argmin` over all entries is used instead.
"""
function best(heuristic::NaiveSelection, history)
    penalties = losses(heuristic, history)
    # `argmin` orders `NaN` (and `missing`) before every number, so without
    # this filter an entry whose evaluation yielded `NaN` would be "best".
    candidates = findall(p -> !(ismissing(p) || isnan(p)), penalties)
    best_index = if isempty(candidates)
        argmin(penalties)
    else
        # `candidates` is increasing, so ties still resolve to the earliest.
        candidates[argmin(penalties[candidates])]
    end
    return history[best_index]
end
# Declare `NaiveSelection` supported for any value of the first argument
# (presumably the tuning strategy — confirm against `supports_heuristic`'s
# generic definition).
function MLJTuning.supports_heuristic(::Any, ::NaiveSelection)
    return true
end