Skip to content

Commit

Permalink
Merge pull request #209 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.8.2 release
  • Loading branch information
ablaom authored Mar 7, 2024
2 parents 60ad344 + c081bd0 commit 39d6cb4
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJTuning"
uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.8.1"
version = "0.8.2"

[deps]
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Expand Down
49 changes: 40 additions & 9 deletions src/tuned_models.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
## TYPES AND CONSTRUCTOR
# TYPES AND CONSTRUCTOR

const ERR_SPECIFY_MODEL = ArgumentError(
"You need to specify `model=...`, unless `tuning=Explicit()`. ")
Expand Down Expand Up @@ -687,7 +687,7 @@ function finalize(tuned_model,
history,
state,
verbosity,
rm,
resampling_machine,
data...)
model = tuned_model.model
tuning = tuned_model.tuning
Expand All @@ -713,7 +713,7 @@ function finalize(tuned_model,
end

report = merge(report1, tuning_report(tuning, history, state))
meta_state = (history, deepcopy(tuned_model), model_buffer, state, rm)
meta_state = (history, deepcopy(tuned_model), model_buffer, state, resampling_machine)

return fitresult, meta_state, report
end
Expand Down Expand Up @@ -749,9 +749,16 @@ function MLJBase.fit(tuned_model::EitherTunedModel{T,M},
history, state = build!(nothing, n, tuning, model, model_buffer, state,
verbosity, acceleration, resampling_machine)

rm = resampling_machine
return finalize(tuned_model, model_buffer,
history, state, verbosity, rm, data...)

return finalize(
tuned_model,
model_buffer,
history,
state,
verbosity,
resampling_machine,
data...,
)

end

Expand Down Expand Up @@ -784,9 +791,15 @@ function MLJBase.update(tuned_model::EitherTunedModel,
history, state = build!(history, n!, tuning, model, model_buffer, state,
verbosity, acceleration, resampling_machine)

rm = resampling_machine
return finalize(tuned_model, model_buffer,
history, state, verbosity, rm, data...)
return finalize(
tuned_model,
model_buffer,
history,
state,
verbosity,
resampling_machine,
data...,
)
else
return fit(tuned_model, verbosity, data...)
end
Expand All @@ -806,6 +819,24 @@ function MLJBase.fitted_params(tuned_model::EitherTunedModel, fitresult)
end


## FORWARD SERIALIZATION METHODS FROM ATOMIC MODEL

const ERR_SERIALIZATION = ErrorException(
"Attempting to serialize a `TunedModel` instance whose best model has not "*
"been trained. It appears as if it was trained with `train_best=false`. "*
"Try re-training using `train_best=true`. "
)

# `fitresult` is `machine(best_model, data...)`, trained iff `train_best` hyperparameter
# is `true`.
function MLJBase.save(tmodel::EitherTunedModel, fitresult)
MLJBase.age(fitresult) > 0 || throw(ERR_SERIALIZATION)
return MLJBase.serializable(fitresult)
end
MLJBase.restore(tmodel::EitherTunedModel, serializable_fitresult) =
MLJBase.restore!(serializable_fitresult)


## SUPPORT FOR MLJ ITERATION API

MLJBase.iteration_parameter(::Type{<:EitherTunedModel}) = :n
Expand Down
49 changes: 49 additions & 0 deletions test/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -406,4 +406,53 @@ end
MLJBase.Tables.istable(mach.cache[end].fitresult.machine.data[1])
end

# define a supervised model with ephemeral `fitresult`, but which overcomes this by
# overloading `save`/`restore`:
thing = []
struct EphemeralRegressor <: Deterministic end
function MLJBase.fit(::EphemeralRegressor, verbosity, X, y)
# if I serialize/deserialized `thing` then `id` below changes:
id = objectid(thing)
fitresult = (thing, id, mean(y))
return fitresult, nothing, NamedTuple()
end
function MLJBase.predict(::EphemeralRegressor, fitresult, X)
thing, id, μ = fitresult
return id == objectid(thing) ? fill(μ, nrows(X)) :
throw(ErrorException("dead fitresult"))
end
function MLJBase.save(::EphemeralRegressor, fitresult)
thing, _, μ = fitresult
return (thing, μ)
end
function MLJBase.restore(::EphemeralRegressor, serialized_fitresult)
thing, μ = serialized_fitresult
id = objectid(thing)
return (thing, id, μ)
end

@testset "save and restore" begin
# https://github.com/JuliaAI/MLJTuning.jl/issues/207
X, y = (; x = rand(10)), fill(42.0, 3)
tmodel = TunedModel(
models=fill(EphemeralRegressor(), 2),
measure=l2,
resampling=Holdout(),
train_best=false,
)
mach = machine(tmodel, X, y)
fit!(mach, verbosity=0)
io = IOBuffer()
@test_throws MLJTuning.ERR_SERIALIZATION MLJBase.save(io, mach)
close(io)
tmodel.train_best = true
fit!(mach, verbosity=0)
io = IOBuffer()
@test_logs MLJBase.save(io, mach)
seekstart(io)
mach2 = machine(io)
close(io)
@test MLJBase.predict(mach2, (; x = rand(2))) fill(42.0, 2)
end

true

0 comments on commit 39d6cb4

Please sign in to comment.