diff --git a/CHANGELOG.md b/CHANGELOG.md index b16fc50..0f92ed7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +unreleased +- added dummy classifier + + 0.7.2 - fixed verify ns diff --git a/src/scicloj/metamorph/ml/classification.clj b/src/scicloj/metamorph/ml/classification.clj index ddb273c..6d84de7 100644 --- a/src/scicloj/metamorph/ml/classification.clj +++ b/src/scicloj/metamorph/ml/classification.clj @@ -1,6 +1,9 @@ (ns scicloj.metamorph.ml.classification (:require [tech.v3.dataset :as ds] - [tech.v3.datatype.pprint :as dtype-pp])) + [tech.v3.dataset.modelling :as ds-mod] + [tech.v3.datatype.pprint :as dtype-pp] + [scicloj.metamorph.ml :as ml])) + (defn- safe-inc [item] @@ -63,6 +66,9 @@ (confusion-map->ds conf-matrix-map :all))) + + + #_(defn confusion-ds [model test-ds] (let [predictions (ml/predict model test-ds) @@ -73,3 +79,32 @@ (comment (confusion-map [:a :b :c :a] [:a :c :c :a] :all)) + +(defn- get-majority-class [target-ds] + (let [target-column-name (first + (ds-mod/inference-target-column-names target-ds))] + (->> + (-> target-ds (get target-column-name) frequencies) + (sort-by :second) + reverse + first + first))) + + +(ml/define-model! :metamorph.ml/dummy-classifier + (fn [feature-ds target-ds options] + (let [target-column-name (first + (ds-mod/inference-target-column-names target-ds))] + {:majority-class (get-majority-class target-ds) + :distinct-labels (-> target-ds (get target-column-name) distinct)})) + + (fn [feature-ds thawed-model {:keys [options model-data] :as model}] + (let [ target-column-name (-> model :target-columns first) + dummy-labels (case (:dummy-strategy options) + :majority-class (repeat (:majority-class model-data)) + :fixed-class (repeat (:fixed-class options)) + :random-class (repeatedly (fn [] (rand-nth (:distinct-labels model-data)))))] + + + (ds/add-or-update-column feature-ds target-column-name dummy-labels))) + {}) diff --git a/test/scicloj/metamorph/classification_test.clj b/test/scicloj/metamorph/classification_test.clj index 94a4d41..5aa3b9f 100644 --- a/test/scicloj/metamorph/classification_test.clj +++ b/test/scicloj/metamorph/classification_test.clj @@ -1,6 +1,9 @@ (ns scicloj.metamorph.classification-test - (:require [scicloj.metamorph.ml.classification :refer :all] - [clojure.test :refer :all])) + (:require [scicloj.metamorph.ml.classification :refer [confusion-map]] + [clojure.test :refer :all] + [scicloj.metamorph.ml :as ml] + [tech.v3.dataset :as ds] + [scicloj.metamorph.ml.toydata :as toydata])) (deftest test-normalized @@ -13,3 +16,38 @@ (confusion-map [:a :b :c :a] [:a :c :c :a]) {:a {:a 1.0} :c {:b 0.5 :c 0.5}}))) + + +(deftest dummy-classification-fixed-label [] + (let [ds (toydata/iris-ds) + model (ml/train ds {:model-type :metamorph.ml/dummy-classifier + :dummy-strategy :fixed-class + :fixed-class 0}) + + prediction (ml/predict ds model)] + + (is (= (:species prediction) (repeat 150 0))))) + + +(deftest dummy-classification-majority [] + (let [ds (toydata/breast-cancer-ds) + model (ml/train ds {:model-type :metamorph.ml/dummy-classifier + :dummy-strategy :majority-class}) + + + prediction (ml/predict ds model)] + + (is (= (:class prediction) (repeat 569 0))))) + + + +(deftest dummy-classification-random [] + (let [ds (toydata/breast-cancer-ds) + model (ml/train ds {:model-type :metamorph.ml/dummy-classifier + :dummy-strategy :random-class}) + + + prediction (ml/predict ds model)] + (def prediction prediction) + (is (= [0 1] (-> prediction :class distinct))))) +