-
Notifications
You must be signed in to change notification settings - Fork 108
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[add] GTree based Decision Tree (#367)
# Pull Request ## What problem does this PR solve? Issue Number: Fixed #212 ## Possible side effects? - Performance: **test**: ``` No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.) [2023-10-13 18:32:00.502] [info] [thread_pool.cc:30] Create a fixed thread pool with size 19 Accuracy in SKlearn: 0.96 Accuracy in SPU: 0.96 . ---------------------------------------------------------------------- Ran 1 test in 78.026s OK ``` **emulation**: Running time in SPU: 77.37s Accuracy in SKlearn: 0.96 Accuracy in SPU: 0.96 [2023-10-13 18:35:38,091] Shutdown multiprocess cluster... - Backward compatibility:
- Loading branch information
Showing
6 changed files
with
491 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright 2023 Ant Group Co., Ltd. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import time | ||
|
||
import jax.numpy as jnp | ||
from sklearn.datasets import load_iris | ||
from sklearn.tree import DecisionTreeClassifier | ||
|
||
import sml.utils.emulation as emulation | ||
from sml.tree.tree import DecisionTreeClassifier as sml_dtc | ||
|
||
MAX_DEPTH = 3 | ||
CONFIG_FILE = emulation.CLUSTER_ABY3_3PC | ||
|
||
|
||
def emul_tree(mode=emulation.Mode.MULTIPROCESS): | ||
def proc_wrapper(max_depth=2, n_labels=3): | ||
dt = sml_dtc( | ||
max_depth=max_depth, criterion='gini', splitter='best', n_labels=n_labels | ||
) | ||
|
||
def proc(X, y): | ||
dt_fit = dt.fit(X, y) | ||
result = dt_fit.predict(X) | ||
return result | ||
|
||
return proc | ||
|
||
def load_data(): | ||
iris = load_iris() | ||
iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) | ||
# sorted_features: n_samples * n_features_in | ||
n_samples, n_features_in = iris_data.shape | ||
n_labels = len(jnp.unique(iris_label)) | ||
sorted_features = jnp.sort(iris_data, axis=0) | ||
new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 | ||
new_features = jnp.greater_equal( | ||
iris_data[:, :], new_threshold[:, jnp.newaxis, :] | ||
) | ||
new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) | ||
|
||
X, y = new_features[:, ::3], iris_label[:] | ||
return X, y | ||
|
||
try: | ||
# bandwidth and latency only work for docker mode | ||
emulator = emulation.Emulator(CONFIG_FILE, mode, bandwidth=300, latency=20) | ||
emulator.up() | ||
|
||
# load mock data | ||
X, y = load_data() | ||
n_samples = y.shape[0] | ||
n_labels = jnp.unique(y).shape[0] | ||
|
||
# compare with sklearn | ||
clf = DecisionTreeClassifier( | ||
max_depth=MAX_DEPTH, criterion='gini', splitter='best', random_state=None | ||
) | ||
start = time.time() | ||
clf = clf.fit(X, y) | ||
score_plain = clf.score(X, y) | ||
end = time.time() | ||
print(f"Running time in SKlearn: {end - start:.2f}s") | ||
|
||
# mark these data to be protected in SPU | ||
X_spu, y_spu = emulator.seal(X, y) | ||
|
||
# run | ||
proc = proc_wrapper(MAX_DEPTH, n_labels) | ||
start = time.time() | ||
result = emulator.run(proc)(X_spu, y_spu) | ||
end = time.time() | ||
score_encrpted = jnp.sum((result == y)) / n_samples | ||
print(f"Running time in SPU: {end - start:.2f}s") | ||
|
||
# print acc | ||
print(f"Accuracy in SKlearn: {score_plain:.2f}") | ||
print(f"Accuracy in SPU: {score_encrpted:.2f}") | ||
|
||
finally: | ||
emulator.down() | ||
|
||
|
||
if __name__ == "__main__": | ||
emul_tree(emulation.Mode.MULTIPROCESS) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Copyright 2023 Ant Group Co., Ltd. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import unittest | ||
|
||
import jax.numpy as jnp | ||
from sklearn.datasets import load_iris | ||
from sklearn.tree import DecisionTreeClassifier | ||
|
||
import spu.spu_pb2 as spu_pb2 # type: ignore | ||
import spu.utils.simulation as spsim | ||
from sml.tree.tree import DecisionTreeClassifier as sml_dtc | ||
|
||
MAX_DEPTH = 3 | ||
|
||
|
||
class UnitTests(unittest.TestCase): | ||
def test_tree(self): | ||
def proc_wrapper(max_depth=2, n_labels=3): | ||
dt = sml_dtc("gini", "best", max_depth, n_labels) | ||
|
||
def proc(X, y): | ||
dt_fit = dt.fit(X, y) | ||
result = dt_fit.predict(X) | ||
return result | ||
|
||
return proc | ||
|
||
def load_data(): | ||
iris = load_iris() | ||
iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) | ||
# sorted_features: n_samples * n_features_in | ||
n_samples, n_features_in = iris_data.shape | ||
n_labels = len(jnp.unique(iris_label)) | ||
sorted_features = jnp.sort(iris_data, axis=0) | ||
new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 | ||
new_features = jnp.greater_equal( | ||
iris_data[:, :], new_threshold[:, jnp.newaxis, :] | ||
) | ||
new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) | ||
|
||
X, y = new_features[:, ::3], iris_label[:] | ||
return X, y | ||
|
||
# bandwidth and latency only work for docker mode | ||
sim = spsim.Simulator.simple( | ||
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 | ||
) | ||
|
||
# load mock data | ||
X, y = load_data() | ||
n_samples = y.shape[0] | ||
n_labels = jnp.unique(y).shape[0] | ||
|
||
# compare with sklearn | ||
clf = DecisionTreeClassifier( | ||
max_depth=MAX_DEPTH, criterion='gini', splitter='best', random_state=None | ||
) | ||
clf = clf.fit(X, y) | ||
score_plain = clf.score(X, y) | ||
|
||
# run | ||
proc = proc_wrapper(MAX_DEPTH, n_labels) | ||
result = spsim.sim_jax(sim, proc)(X, y) | ||
score_encrpted = jnp.sum((result == y)) / n_samples | ||
|
||
# print acc | ||
print(f"Accuracy in SKlearn: {score_plain:.2f}") | ||
print(f"Accuracy in SPU: {score_encrpted:.2f}") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Oops, something went wrong.