diff --git a/sml/tree/BUILD.bazel b/sml/tree/BUILD.bazel index 7832e732..439ac5ea 100644 --- a/sml/tree/BUILD.bazel +++ b/sml/tree/BUILD.bazel @@ -11,3 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +load("@rules_python//python:defs.bzl", "py_library") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "tree", + srcs = ["tree.py"], +) diff --git a/sml/tree/emulations/BUILD.bazel b/sml/tree/emulations/BUILD.bazel index 7832e732..480fd3a9 100644 --- a/sml/tree/emulations/BUILD.bazel +++ b/sml/tree/emulations/BUILD.bazel @@ -11,3 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +load("@rules_python//python:defs.bzl", "py_binary") + +package(default_visibility = ["//visibility:public"]) + +py_binary( + name = "tree_emul", + srcs = ["tree_emul.py"], + deps = [ + "//sml/tree", + "//sml/utils:emulation", + ], +) diff --git a/sml/tree/emulations/tree_emul.py b/sml/tree/emulations/tree_emul.py new file mode 100644 index 00000000..d2290bc4 --- /dev/null +++ b/sml/tree/emulations/tree_emul.py @@ -0,0 +1,97 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import jax.numpy as jnp +from sklearn.datasets import load_iris +from sklearn.tree import DecisionTreeClassifier + +import sml.utils.emulation as emulation +from sml.tree.tree import DecisionTreeClassifier as sml_dtc + +MAX_DEPTH = 3 +CONFIG_FILE = emulation.CLUSTER_ABY3_3PC + + +def emul_tree(mode=emulation.Mode.MULTIPROCESS): + def proc_wrapper(max_depth=2, n_labels=3): + dt = sml_dtc( + max_depth=max_depth, criterion='gini', splitter='best', n_labels=n_labels + ) + + def proc(X, y): + dt_fit = dt.fit(X, y) + result = dt_fit.predict(X) + return result + + return proc + + def load_data(): + iris = load_iris() + iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) + # sorted_features: n_samples * n_features_in + n_samples, n_features_in = iris_data.shape + n_labels = len(jnp.unique(iris_label)) + sorted_features = jnp.sort(iris_data, axis=0) + new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 + new_features = jnp.greater_equal( + iris_data[:, :], new_threshold[:, jnp.newaxis, :] + ) + new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) + + X, y = new_features[:, ::3], iris_label[:] + return X, y + + try: + # bandwidth and latency only work for docker mode + emulator = emulation.Emulator(CONFIG_FILE, mode, bandwidth=300, latency=20) + emulator.up() + + # load mock data + X, y = load_data() + n_samples = y.shape[0] + n_labels = jnp.unique(y).shape[0] + + # compare with sklearn + clf = DecisionTreeClassifier( + max_depth=MAX_DEPTH, criterion='gini', splitter='best', random_state=None + ) + start = time.time() + clf = clf.fit(X, y) + score_plain = clf.score(X, y) + end = time.time() + print(f"Running time in SKlearn: {end - start:.2f}s") + + # mark these data to be protected in SPU + X_spu, y_spu = emulator.seal(X, y) + + # run + proc = proc_wrapper(MAX_DEPTH, n_labels) + start = time.time() + result = emulator.run(proc)(X_spu, y_spu) + end = time.time() + score_encrpted = jnp.sum((result == y)) / n_samples + print(f"Running time in SPU: {end - start:.2f}s") + + # print acc + print(f"Accuracy in SKlearn: {score_plain:.2f}") + print(f"Accuracy in SPU: {score_encrpted:.2f}") + + finally: + emulator.down() + + +if __name__ == "__main__": + emul_tree(emulation.Mode.MULTIPROCESS) diff --git a/sml/tree/tests/BUILD.bazel b/sml/tree/tests/BUILD.bazel index 7832e732..2b60bdb1 100644 --- a/sml/tree/tests/BUILD.bazel +++ b/sml/tree/tests/BUILD.bazel @@ -11,3 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +load("@rules_python//python:defs.bzl", "py_test") + +package(default_visibility = ["//visibility:public"]) + +py_test( + name = "tree_test", + srcs = ["tree_test.py"], + deps = [ + "//sml/tree", + "//spu:init", + "//spu/utils:simulation", + ], +) diff --git a/sml/tree/tests/tree_test.py b/sml/tree/tests/tree_test.py new file mode 100644 index 00000000..470065f9 --- /dev/null +++ b/sml/tree/tests/tree_test.py @@ -0,0 +1,85 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import jax.numpy as jnp +from sklearn.datasets import load_iris +from sklearn.tree import DecisionTreeClassifier + +import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.utils.simulation as spsim +from sml.tree.tree import DecisionTreeClassifier as sml_dtc + +MAX_DEPTH = 3 + + +class UnitTests(unittest.TestCase): + def test_tree(self): + def proc_wrapper(max_depth=2, n_labels=3): + dt = sml_dtc("gini", "best", max_depth, n_labels) + + def proc(X, y): + dt_fit = dt.fit(X, y) + result = dt_fit.predict(X) + return result + + return proc + + def load_data(): + iris = load_iris() + iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target) + # sorted_features: n_samples * n_features_in + n_samples, n_features_in = iris_data.shape + n_labels = len(jnp.unique(iris_label)) + sorted_features = jnp.sort(iris_data, axis=0) + new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2 + new_features = jnp.greater_equal( + iris_data[:, :], new_threshold[:, jnp.newaxis, :] + ) + new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1) + + X, y = new_features[:, ::3], iris_label[:] + return X, y + + # bandwidth and latency only work for docker mode + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + # load mock data + X, y = load_data() + n_samples = y.shape[0] + n_labels = jnp.unique(y).shape[0] + + # compare with sklearn + clf = DecisionTreeClassifier( + max_depth=MAX_DEPTH, criterion='gini', splitter='best', random_state=None + ) + clf = clf.fit(X, y) + score_plain = clf.score(X, y) + + # run + proc = proc_wrapper(MAX_DEPTH, n_labels) + result = spsim.sim_jax(sim, proc)(X, y) + score_encrpted = jnp.sum((result == y)) / n_samples + + # print acc + print(f"Accuracy in SKlearn: {score_plain:.2f}") + print(f"Accuracy in SPU: {score_encrpted:.2f}") + + +if __name__ == "__main__": + unittest.main() diff --git a/sml/tree/tree.py b/sml/tree/tree.py new file mode 100644 index 00000000..a7806ddf --- /dev/null +++ b/sml/tree/tree.py @@ -0,0 +1,273 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import jax.numpy as jnp + + +class DecisionTreeClassifier: + """A decision tree classifier based on [GTree](https://arxiv.org/abs/2305.00645). + + Adopting a MPC-based linear scan method (i.e. oblivious_array_access), GTree + designs a new GPU-friendly oblivious decision tree training protocol, which is + more efficient than the prior works. The current implementation supports the training + of decision tree with binary features (i.e. {0, 1}) and multi-class labels (i.e. {0, 1, 2, \dots}). + + We provide a simple example to show how to use GTree to train a decision tree classifier + in sml/tree/emulations/tree_emul.py. For training, the memory and time complexity is around + O(n_samples * n_labels * n_features * 2 ** max_height). + + Parameters + ---------- + criterion : {"gini", "entropy", "log_loss"}, default="gini" + The function to measure the quality of a split. Supported criteria are + "gini" for the Gini impurity and "log_loss" and "entropy" both for the + Shannon information gain, see :ref:`tree_mathematical_formulation`. + + splitter : {"best", "random"}, default="best" + The strategy used to choose the split at each node. Supported + strategies are "best" to choose the best split and "random" to choose + the best random split. + + max_depth : int, default=None + The maximum depth of the tree. If None, then nodes are expanded until + all leaves are pure or until all leaves contain less than + min_samples_split samples. + + n_labels: int, the max number of labels. + """ + + def __init__(self, criterion, splitter, max_depth, n_labels): + assert criterion == "gini", "criteria other than gini is not supported." + assert splitter == "best", "splitter other than best is not supported." + self.max_depth = max_depth + self.n_labels = n_labels + + def fit(self, X, y): + self.T, self.F = odtt(X, y, self.max_depth, self.n_labels) + return self + + def predict(self, X): + assert self.T != None, "the model has not been trained yet." + return odti(X, self.T, self.max_depth) + + +''' +The protocols of GTree. +''' + + +def oblivious_array_access(array, index): + ''' + Extract elements from array according to index. + + If array is 1D, then output [array[i] for i in index]. + e.g.: array = [1, 2, 3, 4, 5], index = [0, 2, 4], output = [1, 3, 5]. + + If array is 2D, then output [[array[j, i] for i in index] for j in range(array.shape[0])]. + e.g. array = [[1, 2, 3], [4, 5, 6]], index_array = [0, 2], output = [[1, 3], [4, 6]]. + ''' + # (n_array) + count_array = jnp.arange(0, array.shape[-1]) + # (n_array, n_index) + E = jnp.equal(index, count_array[:, jnp.newaxis]) + + assert len(array.shape) <= 2, "OAA protocol only supports 1D or 2D array." + + # OAA basic case + if len(array.shape) == 1: + # (n_array, n_index) + O = array[:, jnp.newaxis] * E # select shares + zu = jnp.sum(O, axis=0) + # OAA vectorization variant + elif len(array.shape) == 2: + # (n_arrays, n_array, n_index) + O = array[:, :, jnp.newaxis] * E[jnp.newaxis, :, :] # select shares + zu = jnp.sum(O, axis=1) + return zu + + +def oaa_elementwise(array, index_array): + ''' + Given index_array, output [array[i, index[i]] for i in range(len(array))]. + + e.g.: array = [[1, 2, 3], [4, 5, 6]], index = [0, 2], output = [1, 6]. + ''' + assert index_array.shape[0] == array.shape[0], "n_arrays must be equal to n_index." + assert len(array.shape) == 2, "OAAE protocol only supports 2D array." + count_array = jnp.arange(0, array.shape[-1]) + # (n_array, n_index) + E = jnp.equal(index_array[:, jnp.newaxis], count_array) + if len(array.shape) == 2: + O = array * E + zu = jnp.sum(O, axis=1) + return zu + + +# def oblivious_learning(X, y, T, F, M, Cn, h): +def oblivious_learning(X, y, T, F, M, h, Cn, n_labels): + '''partition the data and count the number of data samples. + + params: + D: data samples, which is splitted into X, y. X: (n_samples, n_features), y: (n_samples, 1). + T: tree structure reprensenting split features. (total_nodes) + F: tree structure reprensenting node types. (total_nodes) + 0 for internal, 1 for leaf, 2 for dummy. + M: which leave node does D[i] belongs to (for level h-1). (n_samples) + Cn: statical information of the data samples. (n_leaves, n_labels+1, 2*n_features) + h: int, current depth of the tree. + ''' + # line 1-5, partition the datas into new leaves. + n_d, n_f = X.shape + n_h = 2**h + if h != 0: + Tval = oaa(T, M) + Dval = oaae(X, Tval) + M = 2 * M + Dval + 1 + + # (n_leaves) + LCidx = jnp.arange(0, n_h) + isLeaf = jnp.equal(F[n_h - 1 : 2 * n_h - 1], jnp.ones(n_h)) + # (n_samples, n_leaves) + LCF = jnp.equal(M[:, jnp.newaxis] - n_h + 1, LCidx) + LCF = LCF * isLeaf + # (n_samples, n_leaves, n_labels, 2 * n_features) + Cd = jnp.zeros((n_d, n_h, n_labels + 1, 2 * n_f)) + Cd = Cd.at[:, :, 0, 0::2].set(jnp.tile((1 - X)[:, jnp.newaxis, :], (1, n_h, 1))) + Cd = Cd.at[:, :, 0, 1::2].set(jnp.tile((X)[:, jnp.newaxis, :], (1, n_h, 1))) + for i in range(n_labels): + Cd = Cd.at[:, :, i + 1, 0::2].set( + jnp.tile( + ((1 - X) * (i == y)[:, jnp.newaxis])[:, jnp.newaxis, :], (1, n_h, 1) + ) + ) + Cd = Cd.at[:, :, i + 1, 1::2].set( + jnp.tile(((X) * (i == y)[:, jnp.newaxis])[:, jnp.newaxis, :], (1, n_h, 1)) + ) + Cd = Cd * LCF[:, :, jnp.newaxis, jnp.newaxis] + # (n_leaves, n_labels+1, 2*n_features) + new_Cn = jnp.sum(Cd, axis=0) + + if h != 0: + Cn = Cn.repeat(2, axis=0) + new_Cn = new_Cn[:, :, :] + Cn[:, :, :] * (1 - isLeaf[:, jnp.newaxis, jnp.newaxis]) + + return new_Cn, M + + +def oblivious_heuristic_computation(Cn, gamma, F, h, n_labels): + '''Compute gini index, find the best feature, and update F. + + params: + Cn: statical information of the data samples. (n_leaves, n_labels+1, 2*n_features) + gamma: gamma[n][i] indicates if feature si has been assigned at node n. (n_leaves, n_features) + F: tree structure reprensenting node types. (total_nodes) + 0 for internal, 1 for leaf, 2 for dummy. + h: int, current depth of the tree. + n_labels: int, number of labels. + ''' + n_leaves = Cn.shape[0] + n_features = gamma.shape[1] + Ds0 = Cn[:, 0, 0::2] + Ds1 = Cn[:, 0, 1::2] + D = Ds0 + Ds1 + Q = D * Ds0 * Ds1 + P = jnp.zeros(gamma.shape) + for i in range(n_labels): + P = P - Ds1 * (Cn[:, i + 1, 0::2] ** 2) - Ds0 * (Cn[:, i + 1, 1::2] ** 2) + gini = Q / (Q + P + 1) + gini = gini * gamma + # (n_leaves) + SD = jnp.argmax(gini, axis=1) + index = jnp.arange(0, n_features) + gamma = gamma * jnp.not_equal(index[jnp.newaxis, :], SD[:, jnp.newaxis]) + new_gamma = jnp.zeros((n_leaves * 2, n_features)) + new_gamma = new_gamma.at[0::2, :].set(gamma) + new_gamma = new_gamma.at[1::2, :].set(gamma) + + # # modification. + psi = jnp.zeros((n_leaves, n_labels)) + for i in range(n_labels): + psi = psi.at[:, i].set(Cn[:, i + 1, 0] + Cn[:, i + 1, 1]) + total = jnp.sum(psi, axis=1) + psi = total[:, jnp.newaxis] - psi + psi = jnp.prod(psi, axis=1) + F = F.at[2**h - 1 : 2 ** (h + 1) - 1].set( + jnp.equal(psi * F[2**h - 1 : 2 ** (h + 1) - 1], 0) + ) + F = F.at[2 ** (h + 1) - 1 : 2 ** (h + 2) - 1 : 2].set( + 2 - jnp.equal(F[2**h - 1 : 2 ** (h + 1) - 1], 0) + ) + F = F.at[2 ** (h + 1) : 2 ** (h + 2) - 1 : 2].set( + F[2 ** (h + 1) - 1 : 2 ** (h + 2) - 1 : 2] + ) + return SD, new_gamma, F + + +def oblivious_node_split(SD, T, F, Cn, h, max_depth): + '''Convert each node into its internal node and generates new leaves at the next level.''' + + T = T.at[2**h - 1 : 2 ** (h + 1) - 1].set(SD) + return T, Cn + + +def oblivious_DT_training(X, y, max_depth, n_labels): + n_samples, n_features = X.shape + T = jnp.zeros((2 ** (max_depth + 1) - 1)) + F = jnp.ones((2**max_depth - 1)) + M = jnp.zeros(n_samples) + gamma = jnp.ones((1, n_features)) + Cn = jnp.zeros((1, n_labels + 1, 2 * n_features)) + + h = 0 + while h < max_depth: + Cn, M = ol(X, y, T, F, M, h, Cn, n_labels) + + SD, gamma, F = ohc(Cn, gamma, F, h, n_labels) + + T, Cn = ons(SD, T, F, Cn, h, max_depth) + + h += 1 + + n_leaves = 2**h + psi = jnp.zeros((n_leaves, n_labels)) + for i in range(2 ** (h - 1)): + t1 = oaa(Cn[i, 1:], 2 * SD[i : i + 1]).squeeze() + t2 = oaa(Cn[i, 1:], 2 * SD[i : i + 1] + 1).squeeze() + psi = psi.at[2 * i, :].set(t1) + psi = psi.at[2 * i + 1, :].set(t2) + T = T.at[n_leaves - 1 :].set(jnp.argmax(psi, axis=1)) + return T, F + + +def oblivious_DT_inference(X, T, max_height): + n_samples, n_features = X.shape + Tidx = jnp.zeros((n_samples)) + i = 0 + while i < max_height: + Tval = oaa(T, Tidx) + Dval = oaae(X, Tval) + Tidx = Tidx * 2 + Dval + 1 + i += 1 + Tval = oaa(T, Tidx) + return Tval + + +oaa = oblivious_array_access +oaae = oaa_elementwise +ol = oblivious_learning +ohc = oblivious_heuristic_computation +ons = oblivious_node_split +odtt = oblivious_DT_training +odti = oblivious_DT_inference