Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

For a 0.8.1 release #204

Merged
merged 8 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .github/codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
coverage:
status:
project:
default:
threshold: 0.5%
patch:
default:
target: 80%
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJTuning"
uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.8.0"
version = "0.8.1"

[deps]
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Expand Down
2 changes: 1 addition & 1 deletion src/plotrecipes.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@recipe function f(mach::MLJBase.Machine{<:EitherTunedModel})
rep = report(mach)
measurement = string(typeof(rep.best_history_entry.measure[1]))
measurement = repr(rep.best_history_entry.measure[1])

Check warning on line 3 in src/plotrecipes.jl

View check run for this annotation

Codecov / codecov/patch

src/plotrecipes.jl#L3

Added line #L3 was not covered by tests
r = rep.plotting
z = r.measurements
X = r.parameter_values
Expand Down
30 changes: 24 additions & 6 deletions src/strategies/explicit.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
const WARN_INCONSISTENT_PREDICTION_TYPE =
"Not all models to be evaluated have the same prediction type, and this may "*
"cause problems for some measures. For example, a probabilistic metric "*
"like `log_loss` cannot be applied to a model making point (deterministic) "*
"predictions. Inspect the prediction type with "*
"`prediction_type(model)`. "

mutable struct Explicit <: TuningStrategy end

struct ExplicitState{R, N}
range::R # a model-generating iterator
next::N # to hold output of `iterate(range)`
next::N # to hold output of `iterate(range)`
prediction_type::Symbol
user_warned::Bool
end

ExplictState(r::R, n::N) where {R,N} = ExplicitState{R, Union{Nothing, N}}(r, n)

function MLJTuning.setup(tuning::Explicit, model, range, n, verbosity)
next = iterate(range)
return ExplicitState(range, next)
return ExplicitState(range, next, MLJBase.prediction_type(model), false)
end

# models! returns as many models as possible but no more than `n_remaining`:
Expand All @@ -20,11 +27,21 @@ function MLJTuning.models(tuning::Explicit,
n_remaining,
verbosity)

range, next = state.range, state.next
range, next, prediction_type, user_warned =
state.range, state.next, state.prediction_type, state.user_warned

function check(m)
if !user_warned && verbosity > -1 && MLJBase.prediction_type(m) != prediction_type
@warn WARN_INCONSISTENT_PREDICTION_TYPE
user_warned = true
end
end

next === nothing && return nothing, state

m, s = next
check(m)

models = Any[m, ] # types not known until run-time

next = iterate(range, s)
Expand All @@ -33,12 +50,13 @@ function MLJTuning.models(tuning::Explicit,
while i < n_remaining
next === nothing && break
m, s = next
check(m)
push!(models, m)
i += 1
next = iterate(range, s)
end

new_state = ExplicitState(range, next)
new_state = ExplicitState(range, next, prediction_type, user_warned)

return models, new_state

Expand Down
4 changes: 3 additions & 1 deletion src/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,11 @@ function event!(metamodel,
state)
model = _first(metamodel)
metadata = _last(metamodel)
force = typeof(resampling_machine.model.model) !=
typeof(model)
resampling_machine.model.model = model
verb = (verbosity >= 2 ? verbosity - 3 : verbosity - 1)
fit!(resampling_machine, verbosity=verb)
fit!(resampling_machine; verbosity=verb, force)
E = evaluate(resampling_machine)
entry0 = (model = model,
measure = E.measure,
Expand Down
39 changes: 39 additions & 0 deletions test/strategies/explicit.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
good = KNNClassifier(K=2)
bad = KNNClassifier(K=10)
ugly = ConstantClassifier()
evil = DeterministicConstantClassifier()

r = [good, bad, ugly]

Expand Down Expand Up @@ -44,4 +45,42 @@ X, y = make_blobs(rng=rng)
@test_throws ArgumentError TunedModel(; models=[dcc, dcc])
end

r = [good, bad, evil, ugly]

@testset "inconsistent prediction types" begin
# case where different predictions types is actually okay (but still
# a warning is issued):
tmodel = TunedModel(
models=r,
resampling = Holdout(),
measure=accuracy,
)
@test_logs(
(:warn, MLJTuning.WARN_INCONSISTENT_PREDICTION_TYPE),
MLJBase.fit(tmodel, 0, X, y),
);

# verbosity = -1 suppresses the warning:
@test_logs(
MLJBase.fit(tmodel, -1, X, y),
);

# case where there really is a problem with different prediction types:
tmodel = TunedModel(
models=r,
resampling = Holdout(),
measure=log_loss,
)
@test_logs(
(:warn, MLJTuning.WARN_INCONSISTENT_PREDICTION_TYPE),
(:error,),
(:info,),
(:info,),
@test_throws(
ArgumentError, # indicates the problem is with incompatible measure
MLJBase.fit(tmodel, 0, X, y),
)
)
end

true
2 changes: 1 addition & 1 deletion test/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ results = [(evaluate(model, X, y,
tm = TunedModel(
models=r,
resampling=CV(nfolds=2),
measures=cross_entropy
measures=cross_entropy,
)
@test_logs((:error, r"Problem"),
(:info, r""),
Expand Down
Loading