Skip to content

Commit

Permalink
Merge pull request #41 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.4.3 release
  • Loading branch information
ablaom authored Jun 3, 2024
2 parents 97122c3 + 615b95a commit 439e368
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJEnsembles"
uuid = "50ed68f4-41fd-4504-931a-ed422449fee0"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.4.2"
version = "0.4.3"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand All @@ -21,7 +21,7 @@ CategoricalArrays = "0.8, 0.9, 0.10"
CategoricalDistributions = "0.1.2"
ComputationalResources = "0.3"
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
MLJModelInterface = "0.4.1, 1.1"
MLJModelInterface = "1.10"
ProgressMeter = "1.1"
ScientificTypesBase = "2,3"
StatisticalMeasuresBase = "0.1"
Expand Down
8 changes: 2 additions & 6 deletions src/ensembles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ _reducer(p, q) = vcat(p, q)
_reducer(p::Tuple, q::Tuple) = (vcat(p[1], q[1]), vcat(p[2], q[2]))



# # ENSEMBLE MODEL TYPES

mutable struct DeterministicEnsembleModel{Atom<:Deterministic} <: Deterministic
Expand Down Expand Up @@ -638,11 +637,8 @@ end

# Note: input and target traits are inherited from atom

MMI.load_path(::Type{<:ProbabilisticEnsembleModel}) =
"MLJ.ProbabilisticEnsembleModel"
MMI.load_path(::Type{<:DeterministicEnsembleModel}) =
"MLJ.DeterministicEnsembleModel"

MMI.load_path(::Type{<:EitherEnsembleModel}) = "MLJEnsembles.EnsembleModel"
MMI.constructor(::Type{<:EitherEnsembleModel}) = EnsembleModel
MMI.is_wrapper(::Type{<:EitherEnsembleModel}) = true
MMI.supports_weights(::Type{<:EitherEnsembleModel{Atom}}) where Atom =
MMI.supports_weights(Atom)
Expand Down
3 changes: 3 additions & 0 deletions test/ensembles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ X = MLJEnsembles.table(ones(5,3))
y = categorical(collect("asdfa"))
train, test = partition(1:length(y), 0.8);
ensemble_model = EnsembleModel(model=atom)
@test constructor(ensemble_model) == EnsembleModel
@test load_path(ensemble_model) == "MLJEnsembles.EnsembleModel"
@test package_name(ensemble_model) == "MLJEnsembles"
ensemble_model.n = 10
fitresult, cache, report = MLJEnsembles.fit(ensemble_model, 0, X, y)
predict(ensemble_model, fitresult, MLJEnsembles.selectrows(X, test))
Expand Down

0 comments on commit 439e368

Please sign in to comment.