diff --git a/Project.toml b/Project.toml index 80203bc0..550d6897 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJBase" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" authors = ["Anthony D. Blaom "] -version = "1.6" +version = "1.7.0" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" diff --git a/src/resampling.jl b/src/resampling.jl index 7f1eb970..0f85bf60 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -31,10 +31,6 @@ const ERR_INVALID_OPERATION = ArgumentError( _ambiguous_operation(model, measure) = "`$measure` does not support a `model` with "* "`prediction_type(model) == :$(prediction_type(model))`. " -err_ambiguous_operation(model, measure) = ArgumentError( - _ambiguous_operation(model, measure)* - "\nUnable to infer an appropriate operation for `$measure`. "* - "Explicitly specify `operation=...` or `operations=...`. ") err_incompatible_prediction_types(model, measure) = ArgumentError( _ambiguous_operation(model, measure)* "If your model is truly making probabilistic predictions, try explicitly "* @@ -65,11 +61,37 @@ ERR_MEASURES_DETERMINISTIC(measure) = ArgumentError( "and so is not supported by `$measure`. "*LOG_AVOID ) -# ================================================================== -## MODEL TYPES THAT CAN BE EVALUATED +err_ambiguous_operation(model, measure) = ArgumentError( + _ambiguous_operation(model, measure)* + "\nUnable to infer an appropriate operation for `$measure`. "* + "Explicitly specify `operation=...` or `operations=...`. "* + "Possible value(s) are: $PREDICT_OPERATIONS_STRING. " +) + +const ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError( + """ -# not exported: -const Measurable = Union{Supervised, Annotator} + The `prediction_type` of your model needs to be one of: `:deterministic`, + `:probabilistic`, or `:interval`. Does your model implement one of these operations: + $PREDICT_OPERATIONS_STRING? 
If so, you can try explicitly specifying `operation=...` + or `operations=...` (and consider posting an issue to have the model review its + definition of `MLJModelInterface.prediction_type`). Otherwise, performance + evaluation is not supported. + + """ +) + +const ERR_NEED_TARGET = ArgumentError( + """ + + To evaluate a model's performance you must provide a target variable `y`, as in + `evaluate(model, X, y; options...)` or + + mach = machine(model, X, y) + evaluate!(mach; options...) + + """ +) # ================================================================== ## RESAMPLING STRATEGIES @@ -987,7 +1009,7 @@ function _actual_operations(operation::Nothing, throw(err_ambiguous_operation(model, m)) end else - throw(err_ambiguous_operation(model, m)) + throw(ERR_UNSUPPORTED_PREDICTION_TYPE) end end end @@ -1137,7 +1159,7 @@ See also [`evaluate`](@ref), [`PerformanceEvaluation`](@ref), """ function evaluate!( - mach::Machine{<:Measurable}; + mach::Machine; resampling=CV(), measures=nothing, measure=measures, @@ -1160,6 +1182,8 @@ function evaluate!( # weights, measures, operations, and dispatches a # strategy-specific `evaluate!` + length(mach.args) > 1 || throw(ERR_NEED_TARGET) + repeats > 0 || error("Need `repeats > 0`. ") if resampling isa TrainTestPairs @@ -1235,7 +1259,7 @@ Returns a [`PerformanceEvaluation`](@ref) object. See also [`evaluate!`](@ref). """ -evaluate(model::Measurable, args...; cache=true, kwargs...) = +evaluate(model::Model, args...; cache=true, kwargs...) = evaluate!(machine(model, args...; cache=cache); kwargs...) 
# ------------------------------------------------------------------- diff --git a/test/resampling.jl b/test/resampling.jl index fbf26777..62812eec 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -25,6 +25,8 @@ end struct DummyInterval <: Interval end dummy_interval=DummyInterval() +struct GoofyTransformer <: Unsupervised end + dummy_measure_det(yhat, y) = 42 API.@trait( typeof(dummy_measure_det), @@ -115,6 +117,12 @@ API.@trait( MLJBase.err_ambiguous_operation(dummy_interval, LogLoss()), MLJBase._actual_operations(nothing, [LogLoss(), ], dummy_interval, 1)) + + # model does not have a valid `prediction_type`: + @test_throws( + MLJBase.ERR_UNSUPPORTED_PREDICTION_TYPE, + MLJBase._actual_operations(nothing, [LogLoss(),], GoofyTransformer(), 0), + ) end @everywhere begin @@ -935,7 +943,29 @@ end end end -# DUMMY LOGGER + +# # TRANSFORMER WITH PREDICT + +struct PredictingTransformer <:Unsupervised end +MLJBase.fit(::PredictingTransformer, verbosity, X, y) = (mean(y), nothing, nothing) +MLJBase.fit(::PredictingTransformer, verbosity, X) = (nothing, nothing, nothing) +MLJBase.predict(::PredictingTransformer, fitresult, X) = fill(fitresult, nrows(X)) +MLJBase.predict(::PredictingTransformer, ::Nothing, X) = nothing +MLJBase.prediction_type(::Type{<:PredictingTransformer}) = :deterministic + +@testset "`Unsupervised` model with a predict" begin + X = rand(10) + y = fill(42.0, 10) + e = evaluate(PredictingTransformer(), X, y, resampling=Holdout(), measure=l2) + @test e.measurement[1] ≈ 0 + @test_throws( + MLJBase.ERR_NEED_TARGET, + evaluate(PredictingTransformer(), X, measure=l2), + ) +end + + +# # DUMMY LOGGER struct DummyLogger end