Skip to content

Commit

Permalink
Merge pull request #170 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.7.0 release
  • Loading branch information
ablaom authored Apr 6, 2022
2 parents 31f0255 + 32b502e commit ebf0983
Show file tree
Hide file tree
Showing 11 changed files with 223 additions and 102 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.3'
- '1.6'
- '1'
os:
- ubuntu-latest
Expand Down
26 changes: 3 additions & 23 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJTuning"
uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.6.16"
version = "0.7.0"

[deps]
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Expand All @@ -17,27 +17,7 @@ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
ComputationalResources = "0.3"
Distributions = "0.22,0.23,0.24, 0.25"
LatinHypercubeSampling = "1.7.2"
MLJBase = "0.18.19, 0.19"
MLJModelInterface = "0.4.1, 1.1.1"
MLJBase = "0.20"
ProgressMeter = "1.7.1"
RecipesBase = "0.8,0.9,1"
julia = "1.3"

[extras]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["CategoricalArrays", "DecisionTree", "Distances", "Distributions", "LinearAlgebra", "MLJModelInterface", "MultivariateStats", "NearestNeighbors", "ScientificTypes", "StableRNGs", "Statistics", "StatsBase", "Tables", "Test"]
julia = "1.6"
3 changes: 2 additions & 1 deletion src/MLJTuning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ export learning_curve!, learning_curve
import MLJBase
using MLJBase
import MLJBase: Bounded, Unbounded, DoublyUnbounded,
LeftUnbounded, RightUnbounded, _process_accel_settings, chunks
LeftUnbounded, RightUnbounded, _process_accel_settings, chunks,
restore, save
using RecipesBase
using Distributed
import Distributions
Expand Down
15 changes: 6 additions & 9 deletions src/learning_curves.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,6 @@ Other key-word options are documented at [`TunedModel`](@ref).
learning_curve(mach::Machine{<:Supervised}; kwargs...) =
learning_curve(mach.model, mach.args...; kwargs...)

# for backwards compatibility
function learning_curve!(mach::Machine{<:Supervised}; kwargs...)
Base.depwarn("`learning_curve!` is deprecated, use `learning_curve` instead. ",
Core.Typeof(learning_curve!).name.mt.name)
learning_curve(mach; kwargs...)
end

function learning_curve(model::Supervised, args...;
resolution=30,
resampling=Holdout(),
Expand Down Expand Up @@ -299,8 +292,12 @@ end

n_threads = Threads.nthreads()
if n_threads == 1
return _tuning_results(rngs, CPU1(),
tuned, rng_name, verbosity)
return _tuning_results(rngs,
CPU1(),
tuned,
rows,
rng_name,
verbosity)
end

old_rng = recursive_getproperty(tuned.model.model, rng_name)
Expand Down
7 changes: 7 additions & 0 deletions src/serialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""
    MLJModelInterface.save(::MLJTuning.EitherTunedModel, fitresult::Machine)

Return a serialization-safe version of a `TunedModel`'s `fitresult`.

The `fitresult` of a tuned model is itself a `Machine`; delegating to
`serializable` strips it of training data and other state unsuitable
for persistence.
"""
function MLJModelInterface.save(::MLJTuning.EitherTunedModel, fitresult::Machine)
    return serializable(fitresult)
end

"""
    MLJModelInterface.restore(::MLJTuning.EitherTunedModel, fitresult)

Reconstitute a deserialized `TunedModel` `fitresult` (a `Machine`) so
it is again usable for prediction.

Recursively restores the wrapped machine's own `fitresult` via
`restore`, mutates the machine in place, and returns it.
"""
function MLJModelInterface.restore(::MLJTuning.EitherTunedModel, fitresult)
    inner = restore(fitresult.model, fitresult.fitresult)
    fitresult.fitresult = inner
    return fitresult
end
9 changes: 6 additions & 3 deletions src/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ const EitherTunedModel{T,M} =
#todo update:
"""
tuned_model = TunedModel(; model=<model to be mutated>,
tuning=Grid(),
tuning=RandomSearch(),
resampling=Holdout(),
range=nothing,
measure=nothing,
Expand Down Expand Up @@ -173,7 +173,10 @@ plus other key/value pairs specific to the `tuning` strategy.
- `models`: Alternatively, an iterator of MLJ models to be explicitly
evaluated. These may have varying types.
- `tuning=Grid()`: tuning strategy to be applied (eg, `RandomSearch()`)
- `tuning=RandomSearch()`: tuning strategy to be applied (eg, `Grid()`). See
the [Tuning
Models](https://alan-turing-institute.github.io/MLJ.jl/dev/tuning_models/#Tuning-Models)
section of the MLJ manual for a complete list of options.
- `resampling=Holdout()`: resampling strategy (eg, `Holdout()`, `CV()`),
`StratifiedCV()`) to be applied in performance evaluations
Expand Down Expand Up @@ -253,7 +256,7 @@ function TunedModel(; model=nothing,
throw(ERR_NEED_EXPLICIT)
end
else
tuning === nothing && (tuning = Grid())
tuning === nothing && (tuning = RandomSearch())
end

# either a `model` is specified or we are in the case
Expand Down
36 changes: 36 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LatinHypercubeSampling = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
CategoricalArrays = "0.10"
ComputationalResources = "0.3"
DecisionTree = "0.10"
Distances = "0.10"
Distributions = "0.25"
MLJBase = "0.20"
MLJModelInterface = "1.3"
MultivariateStats = "0.9"
NearestNeighbors = "0.4"
ScientificTypes = "3.0"
StableRNGs = "1.0"
StatsBase = "0.33"
Tables = "1.6"
11 changes: 0 additions & 11 deletions test/learning_curves.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,16 +206,5 @@ end

end

@testset "deprecation of learning_curve!" begin
atom = KNNRegressor()
mach = machine(atom, X, y)
r = range(atom, :K, lower=1, upper=2)
@test_deprecated learning_curve!(mach;
range=r,
measure=LPLoss(),
verbosity=0)

end

end # module
true
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ end
@test include("learning_curves.jl")
end

@testset "Serialization" begin
@test include("serialization.jl")
end

# @testset "julia bug" begin
# @test include("julia_bug.jl")
# end
Expand Down
73 changes: 73 additions & 0 deletions test/serialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@

module TestSerialization

using Test
using MLJBase
using Serialization
using MLJTuning
using ..Models

# Any `Source` node carried by the machine must have been emptied by
# `serializable`.
function check_empty_sources(mach)
    for node in mach.args
        if node isa Source
            @test node == source()
        end
    end
end

# Training data and caches must be stripped from a serializable machine.
function check_stripped_data(mach)
    for field in (:old_rows, :data, :resampled_data, :cache)
        @test !isdefined(mach, field)
    end
end

# Checks common to every serialized machine: sources emptied, data
# stripped, state reset to -1, and all remaining fields untouched.
function check_generic(original, stripped)
    check_empty_sources(stripped)
    check_stripped_data(stripped)
    @test stripped.state == -1
    for field in (:frozen, :model, :old_model, :old_upstream_state, :fit_okay)
        @test getfield(original, field) == getfield(stripped, field)
    end
end


@testset "Test TunedModel" begin
    filename = "tuned_model.jls"
    X, y = make_regression(100)
    atom = DecisionTreeRegressor()
    tmodel = TunedModel(
        model=atom,
        tuning=Grid(),
        range=[range(atom, :min_samples_split, values=[2,3,4])],
    )
    mach = machine(tmodel, X, y)
    fit!(mach, rows=1:50, verbosity=0)

    # A serializable copy keeps a machine-valued fitresult and its report:
    smach = MLJBase.serializable(mach)
    @test smach.fitresult isa Machine
    @test smach.report == mach.report
    check_generic(mach, smach)

    # Round-trip through the Serialization stdlib, then restore:
    Serialization.serialize(filename, smach)
    smach = Serialization.deserialize(filename)
    MLJBase.restore!(smach)

    # The restored machine must predict exactly as the original:
    @test MLJBase.predict(smach, X) == MLJBase.predict(mach, X)
    @test fitted_params(smach) isa NamedTuple
    @test report(smach) == report(mach)

    rm(filename)

    # End to end
    MLJBase.save(filename, mach)
    smach = machine(filename)
    @test predict(smach, X) == predict(mach, X)

    rm(filename)

end

end

true
Loading

0 comments on commit ebf0983

Please sign in to comment.