Skip to content

Commit

Permalink
Cache3 (#29)
Browse files Browse the repository at this point in the history
*  support caching train / predict
* added ppmap support
* added ppmap grain options
* added metrics-and-options-keep-fn
* move all eval handler to ns scicloj.metamorph.ml.evaluation-handler
  • Loading branch information
behrica authored Dec 27, 2024
1 parent e7e1dd7 commit a6fc51c
Show file tree
Hide file tree
Showing 15 changed files with 576 additions and 298 deletions.
3 changes: 2 additions & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
"ghcr.io/devcontainers-contrib/features/bash-command:1": {
"command": "apt-get update && apt-get install -y rlwrap"
},
"ghcr.io/rocker-org/devcontainer-features/quarto-cli:1": {}
"ghcr.io/rocker-org/devcontainer-features/quarto-cli:1": {},
"ghcr.io/itsmechlark/features/redis-server:1": {}


},
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ bigdata.old/
zeros1G.bin

measures.csv
dump.rdb
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
unreleased
* allow parameters for :fastmath/ols (fixes #27)
* added optional redis caching for train / predict
* added metrics-and-model-keep-fn
* added option :ppmap with :ppmap-grain-size 10 to ml/eval-pipelines
* added more evaluation-handler fns suitable for model-spec search
* breaking: move all eval handler to ns scicloj.metamorph.ml.evaluation-handler
* added :probability-distributin to ml/eval-pipelines result


0.10.4
* added :target-datatypes in train result and clarified expected 'shape' of prediction
Expand Down
23 changes: 12 additions & 11 deletions deps.edn
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
scicloj/metamorph {:mvn/version "0.2.4"}
pppmap/pppmap {:mvn/version "1.0.0"}
scicloj/tablecloth {:mvn/version "7.029.2"}
metosin/malli {:mvn/version "0.16.0"}
metosin/malli {:mvn/version "0.17.0"}
generateme/fastmath {:mvn/version "3.0.0-alpha2"}
it.unimi.dsi/fastutil {:mvn/version "8.5.15"}
org.scicloj/tableplot {:mvn/version "1-alpha13"}
aerial.hanami/aerial.hanami {:mvn/version "0.20.0"
;; we only need hanami templating
:exclusions [org.clojure/clojurescript
Expand All @@ -23,7 +22,10 @@
cljsjs/vega-embed
cljsjs/vega-tooltip]}
;; needed by hanami
org.clojure/data.json {:mvn/version "0.2.6"}}
org.clojure/data.json {:mvn/version "0.2.6"}

com.taoensso/nippy {:mvn/version "3.4.2"}
}

:paths ["src" "resources"]

Expand All @@ -46,12 +48,10 @@

:dev
{:jvm-opts ["-Djava.awt.headless=true"]
:extra-deps {
io.github.nextjournal/clerk {:mvn/version "0.17.1102"}
:extra-deps {io.github.nextjournal/clerk {:mvn/version "0.17.1102"}
org.scicloj/clay {:mvn/version "2-beta23"}
scicloj/scicloj.ml.smile {:mvn/version "7.4.1"}


org.scicloj/scicloj.ml.smile {:mvn/version "7.4.4"}
org.scicloj/tableplot {:mvn/version "1-alpha13"}
datacraft-sciences/confuse {:mvn/version "0.1.1"}
ch.qos.logback/logback-classic {:mvn/version "1.5.6"}
criterium/criterium {:mvn/version "0.4.6"}
Expand All @@ -62,19 +62,20 @@
:extra-paths ["test"]
:extra-deps {com.clojure-goes-fast/clj-memory-meter {:mvn/version "0.3.0"}
lambdaisland/kaocha {:mvn/version "1.88.1376"}
scicloj/scicloj.ml.smile {:mvn/version "7.4.1"}
org.scicloj/scicloj.ml.smile {:mvn/version "7.4.4"}
datacraft-sciences/confuse {:mvn/version "0.1.1"}
ch.qos.logback/logback-classic {:mvn/version "1.5.6"}
org.mapdb/mapdb {:mvn/version "3.1.0"}
}}
:runner {:main-opts ["-m" "kaocha.runner"]}

:exp {:jvm-opts ["-Djdk.attach.allowAttachSelf" "-Xmx8G" "--add-opens=java.base/java.io=ALL-UNNAMED"]
:extra-paths ["exp"]
:extra-paths ["exp" "test"]
:extra-deps {com.clojure-goes-fast/clj-memory-meter {:mvn/version "0.3.0"}
ch.qos.logback/logback-classic {:mvn/version "1.5.6"}
criterium/criterium {:mvn/version "0.4.6"}
}}
org.scicloj/scicloj.ml.smile {:mvn/version "7.4.4"}
com.taoensso/carmine {:mvn/version "3.4.1"}}}

:smoke-test {:jvm-opts ["-Djdk.attach.allowAttachSelf" "-Xmx1G" "--add-opens=java.base/java.io=ALL-UNNAMED"]
:extra-paths ["exp"]
Expand Down
138 changes: 138 additions & 0 deletions exp/cache.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
(ns cache
(:require
[clojure.pprint :as pp]
[scicloj.metamorph.core :as mm]
[scicloj.metamorph.ml :as ml]
[scicloj.metamorph.ml.evaluation-handler :as eval-handler]
[scicloj.metamorph.ml.gridsearch :as gs]
[scicloj.metamorph.ml.loss :as loss]
[scicloj.ml.smile.classification]
[tablecloth.api :as tc]
[tablecloth.pipeline :as tc-mm]
[taoensso.carmine :as car]
[tech.v3.dataset :as ds]
[tech.v3.dataset.metamorph :as mds]
[tech.v3.dataset.modelling :as ds-mod]
[tablecloth.column.api :as tcc]))

(def titanic-train
(->
(ds/->dataset "https://github.com/scicloj/metamorph-examples/raw/main/data/titanic/train.csv"
{:key-fn keyword})))

(def titanic-test
(->
(ds/->dataset "https://github.com/scicloj/metamorph-examples/raw/main/data/titanic/test.csv"
{:key-fn keyword})
(tc/add-column :Survived 0)))


(defn pipe-fn [options]
(mm/pipeline
(mds/select-columns [:Pclass :Survived :Embarked :Sex])


(tc-mm/add-or-replace-column :Survived (fn [ds] (map #(case %
1 "yes"
0 "no")
(:Survived ds))))
(mds/categorical->number [:Survived :Sex :Embarked])
(tc-mm/replace-missing )
(mds/set-inference-target :Survived)

;; (fn [ctx]
;; (assoc ctx :options (dissoc options
;; :cache-opts))
;; )
{:metamorph/id :model}
(ml/model options)))

(defonce my-conn-pool (car/connection-pool {}))
(def my-conn-spec {:uri "redis://localhost:6379"})
(def my-wcar-opts {:pool my-conn-pool, :spec my-conn-spec})


;(ns-unmap *ns* 'cache-map)
(defonce cache-map (atom {}))


;; (reset! ml/kv-cache {:use-cache true
;; :get-fn (fn [key] (car/wcar my-wcar-opts (car/get key)))
;; :set-fn (fn [key value] (car/wcar my-wcar-opts (car/set key value)))})

(reset! ml/train-predict-cache {:use-cache true
:get-fn (fn [key] (get @cache-map key))
:set-fn (fn [key value] (swap! cache-map assoc key value))})


(defn pipe-fns [model-type hyper-params n]
(->>
(map
#(pipe-fn
(assoc %
:model-type model-type))
(gs/sobol-gridsearch hyper-params))
(take n))
)
(def n 100)
(def all-piep-fns
(concat
(pipe-fns :smile.classification/decision-tree
(ml/hyperparameters :smile.classification/decision-tree)
n)
(pipe-fns :smile.classification/logistic-regression
(ml/hyperparameters :smile.classification/logistic-regression)
n)

(pipe-fns :smile.classification/ada-boost
(ml/hyperparameters :smile.classification/ada-boost)
n)
(pipe-fns :smile.classification/random-forest
(ml/hyperparameters :smile.classification/random-forest)

n))
)

(time

(let [eval-result
(ml/evaluate-pipelines
all-piep-fns
[{:train
titanic-train
:test titanic-test}]
loss/classification-accuracy
:accuracy
{:return-best-pipeline-only false
:return-best-crossvalidation-only false
:evaluation-handler-fn (fn [result]
(def result result)

(eval-handler/metrics-and-options-keep-fn result))})]
(def eval-result eval-result)
(pp/pprint
(-> eval-result

first
first
(#(hash-map :options (get-in % [:fit-ctx :model :options])
:train-accuracy (get-in % [:train-transform :metric])
:test-accuracy (get-in % [:test-transform :metric])))))))


(def datasets
(map
(fn [result]
(tc/dataset
(merge (-> result :test-transform (select-keys [:metric]))
(-> result :fit-ctx :model :options)))
)
(-> eval-result flatten)))

(def metrices
(apply tc/concat datasets))

(-> metrices
(tc/group-by :model-type)
(tc/aggregate (fn [ds]
(tcc/mean (:metric ds)))))
Loading

0 comments on commit a6fc51c

Please sign in to comment.