Skip to content

Commit

Permalink
[add] GTree based Decision Tree (#367)
Browse files Browse the repository at this point in the history
# Pull Request

## What problem does this PR solve?

Issue Number: Fixed #212 

## Possible side effects?

- Performance:

**test**:

```
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[2023-10-13 18:32:00.502] [info] [thread_pool.cc:30] Create a fixed thread pool with size 19
Accuracy in SKlearn: 0.96
Accuracy in SPU: 0.96
.
----------------------------------------------------------------------
Ran 1 test in 78.026s

OK
```

**emulation**:

Running time in SPU: 77.37s
Accuracy in SKlearn: 0.96
Accuracy in SPU: 0.96
[2023-10-13 18:35:38,091] Shutdown multiprocess cluster...

- Backward compatibility:
  • Loading branch information
ElleryQu authored Oct 24, 2023
1 parent 85b9387 commit 7e51ddb
Show file tree
Hide file tree
Showing 6 changed files with 491 additions and 0 deletions.
9 changes: 9 additions & 0 deletions sml/tree/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_library")

package(default_visibility = ["//visibility:public"])

py_library(
name = "tree",
srcs = ["tree.py"],
)
13 changes: 13 additions & 0 deletions sml/tree/emulations/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_binary")

package(default_visibility = ["//visibility:public"])

py_binary(
name = "tree_emul",
srcs = ["tree_emul.py"],
deps = [
"//sml/tree",
"//sml/utils:emulation",
],
)
97 changes: 97 additions & 0 deletions sml/tree/emulations/tree_emul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import jax.numpy as jnp
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

import sml.utils.emulation as emulation
from sml.tree.tree import DecisionTreeClassifier as sml_dtc

MAX_DEPTH = 3
CONFIG_FILE = emulation.CLUSTER_ABY3_3PC


def emul_tree(mode=emulation.Mode.MULTIPROCESS):
def proc_wrapper(max_depth=2, n_labels=3):
dt = sml_dtc(
max_depth=max_depth, criterion='gini', splitter='best', n_labels=n_labels
)

def proc(X, y):
dt_fit = dt.fit(X, y)
result = dt_fit.predict(X)
return result

return proc

def load_data():
iris = load_iris()
iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target)
# sorted_features: n_samples * n_features_in
n_samples, n_features_in = iris_data.shape
n_labels = len(jnp.unique(iris_label))
sorted_features = jnp.sort(iris_data, axis=0)
new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2
new_features = jnp.greater_equal(
iris_data[:, :], new_threshold[:, jnp.newaxis, :]
)
new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1)

X, y = new_features[:, ::3], iris_label[:]
return X, y

try:
# bandwidth and latency only work for docker mode
emulator = emulation.Emulator(CONFIG_FILE, mode, bandwidth=300, latency=20)
emulator.up()

# load mock data
X, y = load_data()
n_samples = y.shape[0]
n_labels = jnp.unique(y).shape[0]

# compare with sklearn
clf = DecisionTreeClassifier(
max_depth=MAX_DEPTH, criterion='gini', splitter='best', random_state=None
)
start = time.time()
clf = clf.fit(X, y)
score_plain = clf.score(X, y)
end = time.time()
print(f"Running time in SKlearn: {end - start:.2f}s")

# mark these data to be protected in SPU
X_spu, y_spu = emulator.seal(X, y)

# run
proc = proc_wrapper(MAX_DEPTH, n_labels)
start = time.time()
result = emulator.run(proc)(X_spu, y_spu)
end = time.time()
score_encrpted = jnp.sum((result == y)) / n_samples
print(f"Running time in SPU: {end - start:.2f}s")

# print acc
print(f"Accuracy in SKlearn: {score_plain:.2f}")
print(f"Accuracy in SPU: {score_encrpted:.2f}")

finally:
emulator.down()


if __name__ == "__main__":
emul_tree(emulation.Mode.MULTIPROCESS)
14 changes: 14 additions & 0 deletions sml/tree/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_test")

package(default_visibility = ["//visibility:public"])

py_test(
name = "tree_test",
srcs = ["tree_test.py"],
deps = [
"//sml/tree",
"//spu:init",
"//spu/utils:simulation",
],
)
85 changes: 85 additions & 0 deletions sml/tree/tests/tree_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

import jax.numpy as jnp
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

import spu.spu_pb2 as spu_pb2 # type: ignore
import spu.utils.simulation as spsim
from sml.tree.tree import DecisionTreeClassifier as sml_dtc

MAX_DEPTH = 3


class UnitTests(unittest.TestCase):
def test_tree(self):
def proc_wrapper(max_depth=2, n_labels=3):
dt = sml_dtc("gini", "best", max_depth, n_labels)

def proc(X, y):
dt_fit = dt.fit(X, y)
result = dt_fit.predict(X)
return result

return proc

def load_data():
iris = load_iris()
iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target)
# sorted_features: n_samples * n_features_in
n_samples, n_features_in = iris_data.shape
n_labels = len(jnp.unique(iris_label))
sorted_features = jnp.sort(iris_data, axis=0)
new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2
new_features = jnp.greater_equal(
iris_data[:, :], new_threshold[:, jnp.newaxis, :]
)
new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1)

X, y = new_features[:, ::3], iris_label[:]
return X, y

# bandwidth and latency only work for docker mode
sim = spsim.Simulator.simple(
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
)

# load mock data
X, y = load_data()
n_samples = y.shape[0]
n_labels = jnp.unique(y).shape[0]

# compare with sklearn
clf = DecisionTreeClassifier(
max_depth=MAX_DEPTH, criterion='gini', splitter='best', random_state=None
)
clf = clf.fit(X, y)
score_plain = clf.score(X, y)

# run
proc = proc_wrapper(MAX_DEPTH, n_labels)
result = spsim.sim_jax(sim, proc)(X, y)
score_encrpted = jnp.sum((result == y)) / n_samples

# print acc
print(f"Accuracy in SKlearn: {score_plain:.2f}")
print(f"Accuracy in SPU: {score_encrpted:.2f}")


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 7e51ddb

Please sign in to comment.