Skip to content

Commit

Permalink
[fix] reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
ElleryQu committed Oct 23, 2023
1 parent 6abb2f7 commit 7149680
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 17 deletions.
2 changes: 1 addition & 1 deletion sml/tree/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ package(default_visibility = ["//visibility:public"])
py_library(
name = "tree",
srcs = ["tree.py"],
)
)
4 changes: 2 additions & 2 deletions sml/tree/emulations/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ py_binary(
name = "tree_emul",
srcs = ["tree_emul.py"],
deps = [
"//sml/tree:tree",
"//sml/tree",
"//sml/utils:emulation",
],
)
)
2 changes: 1 addition & 1 deletion sml/tree/emulations/tree_emul.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def proc_wrapper(max_depth=2, n_labels=3):
dt = sml_dtc(
max_depth=max_depth, criterion='gini', splitter='best', n_labels=n_labels
)

def proc(X, y):
dt_fit = dt.fit(X, y)
result = dt_fit.predict(X)
Expand Down
4 changes: 2 additions & 2 deletions sml/tree/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ py_test(
name = "tree_test",
srcs = ["tree_test.py"],
deps = [
"//sml/tree:tree",
"//sml/tree",
"//spu:init",
"//spu/utils:simulation",
],
)
)
2 changes: 1 addition & 1 deletion sml/tree/tests/tree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class UnitTests(unittest.TestCase):
def test_tree(self):
def proc_wrapper(max_depth=2, n_labels=3):
dt = sml_dtc("gini", "best", max_depth, n_labels)

def proc(X, y):
dt_fit = dt.fit(X, y)
result = dt_fit.predict(X)
Expand Down
20 changes: 10 additions & 10 deletions sml/tree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@

class DecisionTreeClassifier:
"""A decision tree classifier based on [GTree](https://arxiv.org/abs/2305.00645).
Adopting a MPC-based linear scan method (i.e. oblivious_array_access), GTree
designs a new GPU-friendly oblivious decision tree training protocol, which is
Adopting a MPC-based linear scan method (i.e. oblivious_array_access), GTree
designs a new GPU-friendly oblivious decision tree training protocol, which is
more efficient than the prior works. The current implementation supports the training
of decision tree with binary features (i.e. {0, 1}) and multi-class labels (i.e. {0, 1, 2, \dots}).
We provide a simple example to show how to use GTree to train a decision tree classifier
of decision tree with binary features (i.e. {0, 1}) and multi-class labels (i.e. {0, 1, 2, \dots}).
We provide a simple example to show how to use GTree to train a decision tree classifier
in sml/tree/emulations/tree_emul.py. For training, the memory and time complexity is around
O(n_samples * n_labels * n_features * 2 ** max_height).
Expand Down Expand Up @@ -71,18 +71,18 @@ def predict(self, X):
def oblivious_array_access(array, index):
'''
Extract elements from array according to index.
If array is 1D, then output [array[i] for i in index].
e.g.: array = [1, 2, 3, 4, 5], index = [0, 2, 4], output = [1, 3, 5].
If array is 2D, then output [array[i, index[i]] for i in range(len(array))].
e.g.: array = [[1, 2, 3], [4, 5, 6]], index = [0, 2], output = [[1], [6]].
'''
# (n_array)
count_array = jnp.arange(0, array.shape[-1])
# (n_array, n_index)
E = jnp.equal(index, count_array[:, jnp.newaxis])

assert len(array.shape) <= 2, "OAA protocol only supports 1D or 2D array."

# OAA basic case
Expand All @@ -101,7 +101,7 @@ def oblivious_array_access(array, index):
def oaa_elementwise(array, index_array):
'''
Extract elements from each sub-array of array according to index_array.
e.g. array = [[1, 2, 3], [4, 5, 6]], index_array = [0, 2], output = [[1, 3], [4, 6]].
'''
assert index_array.shape[0] == array.shape[0], "n_arrays must be equal to n_index."
Expand Down

0 comments on commit 7149680

Please sign in to comment.