diff --git a/README.md b/README.md index 2b24f64..8f5f7e7 100644 --- a/README.md +++ b/README.md @@ -514,42 +514,80 @@ used instead. The fallback is ```julia -MLJTuning.default_n(::TuningStrategy, range) = 10 +default_n(tuning::TuningStrategy, range) = DEFAULT_N ``` +where `DEFAULT_N` is a global constant. Do `using MLJTuning; +MLJTuning.DEFAULT_N` to check the current value. -### Implementation example: Search through an explicit list -The most rudimentary tuning strategy just evaluates every model in a -specified list of models sharing a common type, such lists -constituting the only kind of supported range. (In this special case -`range` is an arbitrary iterator of models, which are `Probabilistic` -or `Deterministic`, according to the type of the prototype `model`, -which is otherwise ignored.) The fallback implementations for `setup`, -`result`, `best` and `report_history` suffice. In particular, there -is not distinction between `range` and `state` in this case. +### Implementation example: Search through an explicit list -Here's the complete implementation: +The most rudimentary tuning strategy just evaluates every model +generated by some iterator, such iterators constituting the only kind +of supported range. The models generated must all have a common type +and, in the implementation below, the type information is conveyed by +the specified prototype `model` (which is otherwise ignored). The +fallback implementations for `result`, `best` and `report_history` +suffice. ```julia -import MLJBase - mutable struct Explicit <: TuningStrategy end +mutable struct ExplicitState{R,N} + range::R + next::Union{Nothing,N} # to hold output of `iterate(range)` +end + +ExplicitState(r::R, ::Nothing) where R = ExplicitState{R,Nothing}(r,nothing) +ExplicitState(r::R, n::N) where {R,N} = ExplicitState{R,Union{Nothing,N}}(r,n) + +function MLJTuning.setup(tuning::Explicit, model, range, verbosity) + next = iterate(range) + return ExplicitState(range, next) +end + # models! 
returns all available models in the range at once: -MLJTuning.models!(tuning::Explicit, model, history::Nothing, - state, verbosity) = state -MLJTuning.models!(tuning::Explicit, model, history, - state, verbosity) = state[length(history) + 1:end] - -function MLJTuning.default_n(tuning::Explicit, range) - try - length(range) - catch MethodError - 10 - end +function MLJTuning.models!(tuning::Explicit, + model, + history, + state, + n_remaining, + verbosity) + + range, next = state.range, state.next + + next === nothing && return nothing + + m, s = next + models = [m, ] + + next = iterate(range, s) + + i = 1 # current length of `models` + while i < n_remaining + next === nothing && break + m, s = next + push!(models, m) + i += 1 + next = iterate(range, s) + end + + state.next = next + + return models + +end + +function default_n(tuning::Explicit, range) + try + length(range) + catch MethodError + DEFAULT_N + end end + ``` For slightly less trivial example, see diff --git a/src/strategies/explicit.jl b/src/strategies/explicit.jl index 23c721e..2cfe913 100644 --- a/src/strategies/explicit.jl +++ b/src/strategies/explicit.jl @@ -1,12 +1,54 @@ -mutable struct Explicit <: TuningStrategy end +mutable struct Explicit <: TuningStrategy end + +mutable struct ExplicitState{R,N} + range::R # a model-generating iterator + next::Union{Nothing,N} # to hold output of `iterate(range)` +end + +ExplicitState(r::R, ::Nothing) where R = ExplicitState{R,Nothing}(r,nothing) +ExplicitState(r::R, n::N) where {R,N} = ExplicitState{R,Union{Nothing,N}}(r,n) + +function MLJTuning.setup(tuning::Explicit, model, range, verbosity) + next = iterate(range) + return ExplicitState(range, next) +end # models! 
returns all available models in the range at once: function MLJTuning.models!(tuning::Explicit, model, history, state, + n_remaining, verbosity) - history === nothing && return state - return state[length(history) + 1:end] + + range, next = state.range, state.next + + next === nothing && return nothing + + m, s = next + models = [m, ] + + next = iterate(range, s) + + i = 1 # current length of `models` + while i < n_remaining + next === nothing && break + m, s = next + push!(models, m) + i += 1 + next = iterate(range, s) + end + + state.next = next + + return models + end +function default_n(tuning::Explicit, range) + try + length(range) + catch MethodError + DEFAULT_N + end +end diff --git a/src/strategies/grid.jl b/src/strategies/grid.jl index bb84ac1..5c31266 100644 --- a/src/strategies/grid.jl +++ b/src/strategies/grid.jl @@ -104,11 +104,13 @@ function setup(tuning::Grid, model, user_range, verbosity) end -MLJTuning.models!(tuning::Grid, model, history::Nothing, - state, verbosity) = state.models -MLJTuning.models!(tuning::Grid, model, history, - state, verbosity) = - state.models[length(history) + 1:end] +MLJTuning.models!(tuning::Grid, + model, + history, + state, + n_remaining, + verbosity) = + state.models[_length(history) + 1:end] function tuning_report(tuning::Grid, history, state) diff --git a/src/tuned_models.jl b/src/tuned_models.jl index b737044..8372031 100644 --- a/src/tuned_models.jl +++ b/src/tuned_models.jl @@ -276,7 +276,12 @@ function build(history, j = _length(history) models_exhausted = false while j < n && !models_exhausted - metamodels = models!(tuning, model, history, state, verbosity) + metamodels = models!(tuning, + model, + history, + state, + n - j, + verbosity) Δj = _length(metamodels) Δj == 0 && (models_exhausted = true) shortfall = n - Δj diff --git a/src/tuning_strategy_interface.jl b/src/tuning_strategy_interface.jl index 20ef123..4c58b18 100644 --- a/src/tuning_strategy_interface.jl +++ b/src/tuning_strategy_interface.jl @@ -29,11 
+29,6 @@ end tuning_report(tuning::TuningStrategy, history, state) = (history=history,) # for declaring the default number of models to evaluate: -function default_n(tuning::TuningStrategy, range) - try - length(range) - catch MethodError - DEFAULT_N - end -end +default_n(tuning::TuningStrategy, range) = DEFAULT_N + diff --git a/test/tuned_models.jl b/test/tuned_models.jl index e1148e6..18e4340 100644 --- a/test/tuned_models.jl +++ b/test/tuned_models.jl @@ -23,6 +23,10 @@ y = 2*x1 .+ 5*x2 .- 3*x3 .+ 0.4*rand(N); m(K) = KNNRegressor(K=K) r = [m(K) for K in 2:13] +# TODO: replace the above with the line below and fix post an issue on +# the failure (a bug in Distributed, I reckon): +# r = (m(K) for K in 2:13) + @testset "constructor" begin @test_throws ErrorException TunedModel(model=first(r), tuning=Explicit(), measure=rms) @@ -96,20 +100,30 @@ end) annotate(model) = (model, params(model)[1]) + _length(x) = length(x) + _length(::Nothing) = 0 function MLJTuning.models!(tuning::MockExplicit, model, history, state, + n_remaining, verbosity) - history === nothing && return annotate.(state) - return annotate.(state)[length(history) + 1:end] + return annotate.(state)[_length(history) + 1:end] end MLJTuning.result(tuning::MockExplicit, history, state, e, metadata) = (measure=e.measure, measurement=e.measurement, K=metadata) -end -@test MockExplicit == MockExplicit + function default_n(tuning::Explicit, range) + try + length(range) + catch MethodError + DEFAULT_N + end + + end + +end @testset_accelerated("passing of model metadata", accel, (exclude=[CPUThreads],), begin @@ -119,7 +133,7 @@ end fitresult, meta_state, report = fit(tm, 0, X, y); history, _, state = meta_state; for (m, r) in history - #@test m.K == r.K + @test m.K == r.K end end)