Skip to content

Commit

Permalink
added dummy-classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
behrica committed Jan 29, 2024
1 parent 168eacd commit aaf6dca
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
unreleased
- added dummy classifier


0.7.2
- fixed verify ns

Expand Down
37 changes: 36 additions & 1 deletion src/scicloj/metamorph/ml/classification.clj
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
(ns scicloj.metamorph.ml.classification
(:require [tech.v3.dataset :as ds]
[tech.v3.datatype.pprint :as dtype-pp]))
[tech.v3.dataset.modelling :as ds-mod]
[tech.v3.datatype.pprint :as dtype-pp]
[scicloj.metamorph.ml :as ml]))


(defn- safe-inc
[item]
Expand Down Expand Up @@ -63,6 +66,9 @@
(confusion-map->ds conf-matrix-map :all)))





#_(defn confusion-ds
[model test-ds]
(let [predictions (ml/predict model test-ds)
Expand All @@ -73,3 +79,32 @@
;; REPL scratch form: `comment` ignores its body, so this call is not evaluated on load.
(comment
(confusion-map [:a :b :c :a] [:a :c :c :a] :all))


(defn- get-majority-class
  "Returns the most frequent value found in the first inference-target
  column of `target-ds`, or nil when that column is missing or empty.

  When several values share the highest count, which one of them is
  returned is unspecified."
  [target-ds]
  (let [target-column-name (first
                            (ds-mod/inference-target-column-names target-ds))]
    ;; `frequencies` yields a {value count} map. Its entries are ordered by
    ;; count (`val`) and the key of the largest one is returned.
    ;; The earlier `(sort-by :second)` looked the keyword :second up in each
    ;; entry, which is always nil, so the entries were never actually sorted.
    (some->> (get target-ds target-column-name)
             frequencies
             seq            ; nil for an empty column -> short-circuits to nil
             (sort-by val)
             last
             key)))


;; Registers a baseline classifier that never looks at the feature values.
;; What it predicts is selected by the `:dummy-strategy` option:
;;   :majority-class - always the most frequent label seen during training
;;                     (also used when `:dummy-strategy` is absent)
;;   :fixed-class    - always the value given in the `:fixed-class` option
;;   :random-class   - a label picked at random (via `rand-nth`) from the
;;                     distinct labels seen during training
;; Any other strategy value raises an IllegalArgumentException at predict time.
(ml/define-model! :metamorph.ml/dummy-classifier
  ;; train: remember the majority label and the set of distinct labels
  (fn [_feature-ds target-ds _options]
    (let [target-column-name (first
                              (ds-mod/inference-target-column-names target-ds))]
      {:majority-class  (get-majority-class target-ds)
       ;; realised into a vector so `rand-nth` does an indexed lookup
       ;; instead of walking a lazy seq for every predicted row
       :distinct-labels (-> target-ds (get target-column-name) distinct vec)}))

  ;; predict: write one dummy label per row into the target column
  (fn [feature-ds _thawed-model {:keys [options model-data] :as model}]
    (let [target-column-name (-> model :target-columns first)
          ;; produce exactly one label per row rather than an infinite seq
          n-rows             (ds/row-count feature-ds)
          strategy           (or (:dummy-strategy options) :majority-class)
          dummy-labels       (case strategy
                               :majority-class (repeat n-rows (:majority-class model-data))
                               :fixed-class    (repeat n-rows (:fixed-class options))
                               :random-class   (repeatedly n-rows
                                                           (fn [] (rand-nth (:distinct-labels model-data))))
                               ;; explicit default: a bare `case` would fail with
                               ;; an uninformative "No matching clause" message
                               (throw (IllegalArgumentException.
                                       (str "Unknown :dummy-strategy " (pr-str strategy)
                                            "; expected one of :majority-class, :fixed-class, :random-class"))))]
      (ds/add-or-update-column feature-ds target-column-name dummy-labels)))
  {})
42 changes: 40 additions & 2 deletions test/scicloj/metamorph/classification_test.clj
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
(ns scicloj.metamorph.classification-test
(:require [scicloj.metamorph.ml.classification :refer :all]
[clojure.test :refer :all]))
(:require [scicloj.metamorph.ml.classification :refer [confusion-map]]
[clojure.test :refer :all]
[scicloj.metamorph.ml :as ml]
[tech.v3.dataset :as ds]
[scicloj.metamorph.ml.toydata :as toydata]))


(deftest test-normalized
Expand All @@ -13,3 +16,38 @@
(confusion-map [:a :b :c :a] [:a :c :c :a])
{:a {:a 1.0}
:c {:b 0.5 :c 0.5}})))


;; With :fixed-class every predicted label must equal the configured value.
;; `deftest` takes no argument vector, so the stray `[]` after the name was dropped.
(deftest dummy-classification-fixed-label
  (let [ds         (toydata/iris-ds)
        model      (ml/train ds {:model-type     :metamorph.ml/dummy-classifier
                                 :dummy-strategy :fixed-class
                                 :fixed-class    0})
        prediction (ml/predict ds model)]
    (is (= (:species prediction) (repeat 150 0)))))


;; With :majority-class every predicted label must be the same single value.
;; `deftest` takes no argument vector, so the stray `[]` after the name was dropped.
(deftest dummy-classification-majority
  (let [ds         (toydata/breast-cancer-ds)
        model      (ml/train ds {:model-type     :metamorph.ml/dummy-classifier
                                 :dummy-strategy :majority-class})
        prediction (ml/predict ds model)]
    ;; NOTE(review): the expected label 0 presumes 0 is the most frequent
    ;; value of :class in the breast-cancer toy dataset - confirm against the data.
    (is (= (:class prediction) (repeat 569 0)))))



;; With :random-class the predictions should cover both labels of the dataset.
;; `deftest` takes no argument vector, so the stray `[]` after the name was dropped.
(deftest dummy-classification-random
  (let [ds         (toydata/breast-cancer-ds)
        model      (ml/train ds {:model-type     :metamorph.ml/dummy-classifier
                                 :dummy-strategy :random-class})
        prediction (ml/predict ds model)]
    ;; Compare as sets: `distinct` keeps first-appearance order, which is
    ;; random here, so an ordered comparison against [0 1] would fail
    ;; whenever the first predicted label happens to be 1.
    (is (= #{0 1} (set (:class prediction))))))

0 comments on commit aaf6dca

Please sign in to comment.