Skip to content

Commit

Permalink
Merge pull request #26 from alan-turing-institute/n-remaining
Browse files Browse the repository at this point in the history
Add n_remaining argument to models!
  • Loading branch information
ablaom authored Mar 26, 2020
2 parents 7c0bac0 + 780376e commit 862bc50
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 45 deletions.
86 changes: 62 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -514,42 +514,80 @@ used instead.
The fallback is

```julia
MLJTuning.default_n(::TuningStrategy, range) = 10
default_n(tuning::TuningStrategy, range) = DEFAULT_N
```

where `DEFAULT_N` is a global constant. Do `using MLJTuning;
MLJTuning.DEFAULT_N` to see check the current value.

### Implementation example: Search through an explicit list

The most rudimentary tuning strategy just evaluates every model in a
specified list of models sharing a common type, such lists
constituting the only kind of supported range. (In this special case
`range` is an arbitrary iterator of models, which are `Probabilistic`
or `Deterministic`, according to the type of the prototype `model`,
which is otherwise ignored.) The fallback implementations for `setup`,
`result`, `best` and `report_history` suffice. In particular, there
is not distinction between `range` and `state` in this case.
### Implementation example: Search through an explicit list

Here's the complete implementation:
The most rudimentary tuning strategy just evaluates every model
generated by some iterator, such iterators constituting the only kind
of supported range. The models generated must all have a common type
and, in th implementation below, the type information is conveyed by
the specified prototype `model` (which is otherwise ignored). The
fallback implementations for `result`, `best` and `report_history`
suffice.

```julia

import MLJBase

mutable struct Explicit <: TuningStrategy end

mutable struct ExplicitState{R,N}
range::R
next::Union{Nothing,N} # to hold output of `iterate(range)`
end

ExplicitState(r::R, ::Nothing) where R = ExplicitState{R,Nothing}(r,nothing)
ExplictState(r::R, n::N) where {R,N} = ExplicitState{R,Union{Nothing,N}}(r,n)

function MLJTuning.setup(tuning::Explicit, model, range, verbosity)
next = iterate(range)
return ExplicitState(range, next)
end

# models! returns all available models in the range at once:
MLJTuning.models!(tuning::Explicit, model, history::Nothing,
state, verbosity) = state
MLJTuning.models!(tuning::Explicit, model, history,
state, verbosity) = state[length(history) + 1:end]

function MLJTuning.default_n(tuning::Explicit, range)
try
length(range)
catch MethodError
10
end
function MLJTuning.models!(tuning::Explicit,
model,
history,
state,
n_remaining,
verbosity)

range, next = state.range, state.next

next === nothing && return nothing

m, s = next
models = [m, ]

next = iterate(range, s)

i = 1 # current length of `models`
while i < n_remaining
next === nothing && break
m, s = next
push!(models, m)
i += 1
next = iterate(range, s)
end

state.next = next

return models

end

function default_n(tuning::Explicit, range)
try
length(range)
catch MethodError
DEFAULT_N
end
end

```

For slightly less trivial example, see
Expand Down
48 changes: 45 additions & 3 deletions src/strategies/explicit.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,54 @@
mutable struct Explicit <: TuningStrategy end
mutable struct Explicit <: TuningStrategy end

mutable struct ExplicitState{R,N}
range::R # a model-generating iterator
next::Union{Nothing,N} # to hold output of `iterate(range)`
end

ExplicitState(r::R, ::Nothing) where R = ExplicitState{R,Nothing}(r,nothing)
ExplictState(r::R, n::N) where {R,N} = ExplicitState{R,Union{Nothing,N}}(r,n)

function MLJTuning.setup(tuning::Explicit, model, range, verbosity)
next = iterate(range)
return ExplicitState(range, next)
end

# models! returns all available models in the range at once:
function MLJTuning.models!(tuning::Explicit,
model,
history,
state,
n_remaining,
verbosity)
history === nothing && return state
return state[length(history) + 1:end]

range, next = state.range, state.next

next === nothing && return nothing

m, s = next
models = [m, ]

next = iterate(range, s)

i = 1 # current length of `models`
while i < n_remaining
next === nothing && break
m, s = next
push!(models, m)
i += 1
next = iterate(range, s)
end

state.next = next

return models

end

function default_n(tuning::Explicit, range)
try
length(range)
catch MethodError
DEFAULT_N
end
end
12 changes: 7 additions & 5 deletions src/strategies/grid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,13 @@ function setup(tuning::Grid, model, user_range, verbosity)

end

MLJTuning.models!(tuning::Grid, model, history::Nothing,
state, verbosity) = state.models
MLJTuning.models!(tuning::Grid, model, history,
state, verbosity) =
state.models[length(history) + 1:end]
MLJTuning.models!(tuning::Grid,
model,
history,
state,
n_remaining,
verbosity) =
state.models[_length(history) + 1:end]

function tuning_report(tuning::Grid, history, state)

Expand Down
7 changes: 6 additions & 1 deletion src/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,12 @@ function build(history,
j = _length(history)
models_exhausted = false
while j < n && !models_exhausted
metamodels = models!(tuning, model, history, state, verbosity)
metamodels = models!(tuning,
model,
history,
state,
n - j,
verbosity)
Δj = _length(metamodels)
Δj == 0 && (models_exhausted = true)
shortfall = n - Δj
Expand Down
9 changes: 2 additions & 7 deletions src/tuning_strategy_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@ end
tuning_report(tuning::TuningStrategy, history, state) = (history=history,)

# for declaring the default number of models to evaluate:
function default_n(tuning::TuningStrategy, range)
try
length(range)
catch MethodError
DEFAULT_N
end
end
default_n(tuning::TuningStrategy, range) = DEFAULT_N


24 changes: 19 additions & 5 deletions test/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ y = 2*x1 .+ 5*x2 .- 3*x3 .+ 0.4*rand(N);
m(K) = KNNRegressor(K=K)
r = [m(K) for K in 2:13]

# TODO: replace the above with the line below and fix post an issue on
# the failure (a bug in Distributed, I reckon):
# r = (m(K) for K in 2:13)

@testset "constructor" begin
@test_throws ErrorException TunedModel(model=first(r), tuning=Explicit(),
measure=rms)
Expand Down Expand Up @@ -96,20 +100,30 @@ end)

annotate(model) = (model, params(model)[1])

_length(x) = length(x)
_length(::Nothing) = 0
function MLJTuning.models!(tuning::MockExplicit,
model,
history,
state,
n_remaining,
verbosity)
history === nothing && return annotate.(state)
return annotate.(state)[length(history) + 1:end]
return annotate.(state)[_length(history) + 1:end]
end

MLJTuning.result(tuning::MockExplicit, history, state, e, metadata) =
(measure=e.measure, measurement=e.measurement, K=metadata)
end

@test MockExplicit == MockExplicit
function default_n(tuning::Explicit, range)
try
length(range)
catch MethodError
DEFAULT_N
end

end

end

@testset_accelerated("passing of model metadata", accel,
(exclude=[CPUThreads],), begin
Expand All @@ -119,7 +133,7 @@ end
fitresult, meta_state, report = fit(tm, 0, X, y);
history, _, state = meta_state;
for (m, r) in history
#@test m.K == r.K
@test m.K == r.K
end
end)

Expand Down

0 comments on commit 862bc50

Please sign in to comment.