Skip to content

Commit

Permalink
Merge pull request #988 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 1.7 release
  • Loading branch information
ablaom authored Jul 19, 2024
2 parents 0849be7 + d65ed1f commit 4e8c087
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "1.6"
version = "1.7.0"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
46 changes: 35 additions & 11 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ const ERR_INVALID_OPERATION = ArgumentError(
_ambiguous_operation(model, measure) =
"`$measure` does not support a `model` with "*
"`prediction_type(model) == :$(prediction_type(model))`. "
err_ambiguous_operation(model, measure) = ArgumentError(
_ambiguous_operation(model, measure)*
"\nUnable to infer an appropriate operation for `$measure`. "*
"Explicitly specify `operation=...` or `operations=...`. ")
err_incompatible_prediction_types(model, measure) = ArgumentError(
_ambiguous_operation(model, measure)*
"If your model is truly making probabilistic predictions, try explicitly "*
Expand Down Expand Up @@ -65,11 +61,37 @@ ERR_MEASURES_DETERMINISTIC(measure) = ArgumentError(
"and so is not supported by `$measure`. "*LOG_AVOID
)

# ==================================================================
## MODEL TYPES THAT CAN BE EVALUATED
err_ambiguous_operation(model, measure) = ArgumentError(
_ambiguous_operation(model, measure)*
"\nUnable to infer an appropriate operation for `$measure`. "*
"Explicitly specify `operation=...` or `operations=...`. "*
"Possible value(s) are: $PREDICT_OPERATIONS_STRING. "
)

const ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError(
"""
# not exported:
const Measurable = Union{Supervised, Annotator}
The `prediction_type` of your model needs to be one of: `:deterministic`,
`:probabilistic`, or `:interval`. Does your model implement one of these operations:
$PREDICT_OPERATIONS_STRING? If so, you can try explicitly specifying `operation=...`
or `operations=...` (and consider posting an issue to have the model review it's
definition of `MLJModelInterface.prediction_type`). Otherwise, performance
evaluation is not supported.
"""
)

const ERR_NEED_TARGET = ArgumentError(
"""
To evaluate a model's performance you must provide a target variable `y`, as in
`evaluate(model, X, y; options...)` or
mach = machine(model, X, y)
evaluate!(mach; options...)
"""
)

# ==================================================================
## RESAMPLING STRATEGIES
Expand Down Expand Up @@ -987,7 +1009,7 @@ function _actual_operations(operation::Nothing,
throw(err_ambiguous_operation(model, m))
end
else
throw(err_ambiguous_operation(model, m))
throw(ERR_UNSUPPORTED_PREDICTION_TYPE)
end
end
end
Expand Down Expand Up @@ -1137,7 +1159,7 @@ See also [`evaluate`](@ref), [`PerformanceEvaluation`](@ref),
"""
function evaluate!(
mach::Machine{<:Measurable};
mach::Machine;
resampling=CV(),
measures=nothing,
measure=measures,
Expand All @@ -1160,6 +1182,8 @@ function evaluate!(
# weights, measures, operations, and dispatches a
# strategy-specific `evaluate!`

length(mach.args) > 1 || throw(ERR_NEED_TARGET)

repeats > 0 || error("Need `repeats > 0`. ")

if resampling isa TrainTestPairs
Expand Down Expand Up @@ -1235,7 +1259,7 @@ Returns a [`PerformanceEvaluation`](@ref) object.
See also [`evaluate!`](@ref).
"""
evaluate(model::Measurable, args...; cache=true, kwargs...) =
evaluate(model::Model, args...; cache=true, kwargs...) =
evaluate!(machine(model, args...; cache=cache); kwargs...)

# -------------------------------------------------------------------
Expand Down
32 changes: 31 additions & 1 deletion test/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ end
struct DummyInterval <: Interval end
dummy_interval=DummyInterval()

struct GoofyTransformer <: Unsupervised end

dummy_measure_det(yhat, y) = 42
API.@trait(
typeof(dummy_measure_det),
Expand Down Expand Up @@ -115,6 +117,12 @@ API.@trait(
MLJBase.err_ambiguous_operation(dummy_interval, LogLoss()),
MLJBase._actual_operations(nothing,
[LogLoss(), ], dummy_interval, 1))

# model does not have a valid `prediction_type`:
@test_throws(
MLJBase.ERR_UNSUPPORTED_PREDICTION_TYPE,
MLJBase._actual_operations(nothing, [LogLoss(),], GoofyTransformer(), 0),
)
end

@everywhere begin
Expand Down Expand Up @@ -935,7 +943,29 @@ end
end
end

# DUMMY LOGGER

# # TRANSFORMER WITH PREDICT

struct PredictingTransformer <:Unsupervised end
MLJBase.fit(::PredictingTransformer, verbosity, X, y) = (mean(y), nothing, nothing)
MLJBase.fit(::PredictingTransformer, verbosity, X) = (nothing, nothing, nothing)
MLJBase.predict(::PredictingTransformer, fitresult, X) = fill(fitresult, nrows(X))
MLJBase.predict(::PredictingTransformer, ::Nothing, X) = nothing
MLJBase.prediction_type(::Type{<:PredictingTransformer}) = :deterministic

@testset "`Unsupervised` model with a predict" begin
X = rand(10)
y = fill(42.0, 10)
e = evaluate(PredictingTransformer(), X, y, resampling=Holdout(), measure=l2)
@test e.measurement[1] 0
@test_throws(
MLJBase.ERR_NEED_TARGET,
evaluate(PredictingTransformer(), X, measure=l2),
)
end


# # DUMMY LOGGER

struct DummyLogger end

Expand Down

0 comments on commit 4e8c087

Please sign in to comment.