diff --git a/Project.toml b/Project.toml index f5f4cf6..2a6c764 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJTuning" uuid = "03970b2e-30c4-11ea-3135-d1576263f10f" authors = ["Anthony D. Blaom "] -version = "0.8.6" +version = "0.8.7" [deps] ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" @@ -18,7 +18,7 @@ StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc" ComputationalResources = "0.3" Distributions = "0.22,0.23,0.24, 0.25" LatinHypercubeSampling = "1.7.2" -MLJBase = "1.3" +MLJBase = "1.4" ProgressMeter = "1.7.1" RecipesBase = "0.8,0.9,1" StatisticalMeasuresBase = "0.1.1" diff --git a/src/tuned_models.jl b/src/tuned_models.jl index a77361e..70a35a9 100644 --- a/src/tuned_models.jl +++ b/src/tuned_models.jl @@ -905,6 +905,10 @@ function MLJBase.feature_importances(::EitherTunedModel, fitresult, report) return MLJBase.feature_importances(fitresult) end + + + + ## METADATA MLJBase.is_wrapper(::Type{<:EitherTunedModel}) = true @@ -912,10 +916,8 @@ MLJBase.supports_weights(::Type{<:EitherTunedModel{<:Any,M,L}}) where {M,L} = MLJBase.supports_weights(M) MLJBase.supports_class_weights(::Type{<:EitherTunedModel{<:Any,M,L}}) where {M,L} = MLJBase.supports_class_weights(M) -MLJBase.load_path(::Type{<:ProbabilisticTunedModel}) = - "MLJTuning.ProbabilisticTunedModel" -MLJBase.load_path(::Type{<:DeterministicTunedModel}) = - "MLJTuning.DeterministicTunedModel" +MLJBase.load_path(::Type{<:EitherTunedModel}) = + "MLJTuning.TunedModel" MLJBase.package_name(::Type{<:EitherTunedModel}) = "MLJTuning" MLJBase.package_uuid(::Type{<:EitherTunedModel}) = "03970b2e-30c4-11ea-3135-d1576263f10f" @@ -928,3 +930,4 @@ MLJBase.input_scitype(::Type{<:EitherTunedModel{T,M,L}}) where {T,M,L} = MLJBase.input_scitype(M) MLJBase.target_scitype(::Type{<:EitherTunedModel{T,M,L}}) where {T,M,L} = MLJBase.target_scitype(M) +MLJBase.constructor(::Type{<:EitherTunedModel}) = TunedModel diff --git a/test/tuned_models.jl b/test/tuned_models.jl index 9826dae..f99f799 100644 --- a/test/tuned_models.jl +++ b/test/tuned_models.jl @@ -69,6 +69,8 @@ end TunedModel(first(r), last(r), range=r, measure=l2), ) tm = @test_logs TunedModel(model=first(r), range=r, measure=l2) + @test MLJBase.constructor(tm) == TunedModel + @test MLJBase.load_path(tm) == "MLJTuning.TunedModel" @test tm.tuning isa RandomSearch @test input_scitype(tm) == Table(Continuous)