From 76cc9da8970b1f1c388da04837e7126642480531 Mon Sep 17 00:00:00 2001 From: salazar1117 <70342834+salazar1117@users.noreply.github.com> Date: Wed, 25 Oct 2023 16:25:48 +0800 Subject: [PATCH] gaussian (#327) # Pull Request ## What problem does this PR solve? Issue Number: Fixed #255 ## Possible side effects? - Performance: - Backward compatibility: --- sml/gaussian_process/BUILD.bazel | 19 ++ sml/gaussian_process/_gpc.py | 318 ++++++++++++++++++++ sml/gaussian_process/emulations/BUILD.bazel | 13 + sml/gaussian_process/emulations/gpc_emul.py | 49 +++ sml/gaussian_process/kernels.py | 15 + sml/gaussian_process/ovo_ovr.py | 109 +++++++ sml/gaussian_process/tests/BUILD.bazel | 14 + sml/gaussian_process/tests/gpc_test.py | 43 +++ 8 files changed, 580 insertions(+) create mode 100644 sml/gaussian_process/_gpc.py create mode 100644 sml/gaussian_process/emulations/gpc_emul.py create mode 100644 sml/gaussian_process/kernels.py create mode 100644 sml/gaussian_process/ovo_ovr.py create mode 100644 sml/gaussian_process/tests/gpc_test.py diff --git a/sml/gaussian_process/BUILD.bazel b/sml/gaussian_process/BUILD.bazel index 7832e732..9b700075 100644 --- a/sml/gaussian_process/BUILD.bazel +++ b/sml/gaussian_process/BUILD.bazel @@ -11,3 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +load("@rules_python//python:defs.bzl", "py_library") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "_gpc", + srcs = ["_gpc.py"], +) + +py_library( + name = "kernels", + srcs = ["kernels.py"], +) + +py_library( + name = "ovo_ovr", + srcs = ["ovo_ovr.py"], +) diff --git a/sml/gaussian_process/_gpc.py b/sml/gaussian_process/_gpc.py new file mode 100644 index 00000000..66e4ff00 --- /dev/null +++ b/sml/gaussian_process/_gpc.py @@ -0,0 +1,318 @@ +import os +import sys + +import jax +import jax.numpy as jnp +from jax import grad +from jax.lax.linalg import cholesky +from jax.scipy.linalg import cho_solve, solve +from jax.scipy.special import erf, expit + +sys.path.append(os.path.join(os.path.dirname(__file__), "./")) +from kernels import RBF +from ovo_ovr import OneVsRestClassifier + +LAMBDAS = jnp.array([0.41, 0.4, 0.37, 0.44, 0.39])[:, jnp.newaxis] +COEFS = jnp.array( + [-1854.8214151, 3516.89893646, 221.29346712, 128.12323805, -2010.49422654] +)[:, jnp.newaxis] + + +class _BinaryGaussianProcessClassifierLaplace: + def __init__( + self, + kernel=None, + *, + poss="sigmoid", + max_iter_predict=100, + ): + self.kernel = kernel + self.max_iter_predict = max_iter_predict + self.poss = poss + + def fit(self, X, y): + self._check_kernal() + + self.X_train_ = jnp.asarray(X) + + if self.poss == "sigmoid": + self.approx_func = expit + else: + raise ValueError( + f"Unsupported prior-likelihood function {self.poss}." + "Please try the default dunction sigmoid." + ) + + self.y_train = y + + K = self.kernel_(self.X_train_) + self.f_ = self._posterior_mode(K) + return self.f_ + + # def log_and_grad(self, f, y_train): + # _tmp = lambda f, y_train: jnp.sum(self.approx_func(y_train*f)) + # return grad(_tmp)(f, y_train)/self.approx_func(y_train*f) + + # def log_and_2grads_and_negtive(self, f, y_train): + # _tmp = lambda f, y: jnp.sum(self.log_and_grad(f, y)) + # return -grad(_tmp)(f, y_train) + + # def log_and_3grads(self, f, y_train): + # _tmp = lambda f, y_train: jnp.sum(-self.log_and_2grads_and_negtive(f, y_train)) + # return grad(_tmp)(f, y_train) + + def predict(self, Xll): + X = jnp.asarray(Xll) + K_star = self.kernel_(self.X_train_, X) + # f_star = K_star.T.dot(self.log_and_grad(self.f_, self.y_train)) + f_star = K_star.T.dot(self.y_train - self.approx_func(self.f_)) + + return jnp.where(f_star > 0, 1, 0) + + def predict_proba(self, Xll): + X = jnp.asarray(Xll) + K = self.kernel_(self.X_train_) + # Based on Algorithm 3.2 of GPML + + # W = self.log_and_2grads_and_negtive(self.f_, self.y_train) + pi = self.approx_func(self.f_) + W = pi * (1 - pi) + + W_sqr = jnp.sqrt(W) + W_sqr_K = W_sqr[:, jnp.newaxis] * K + B = jnp.eye(W.shape[0]) + W_sqr_K * W_sqr + L = cholesky(B) + + K_star = self.kernel_(self.X_train_, X) + # f_star = K_star.T.dot(self.log_and_grad(self.f_, self.y_train)) + f_star = K_star.T.dot(self.y_train - pi) + v = solve(L, W_sqr[:, jnp.newaxis] * K_star) + var_f_star = jnp.diag(self.kernel_(X)) - jnp.einsum("ij,ij->j", v, v) + + alpha = 1 / (2 * var_f_star) + gamma = LAMBDAS * f_star + integrals = ( + jnp.sqrt(jnp.pi / alpha) + * erf(gamma * jnp.sqrt(alpha / (alpha + LAMBDAS**2))) + / (2 * jnp.sqrt(var_f_star * 2 * jnp.pi)) + ) + pi_star = (COEFS * integrals).sum(axis=0) + 0.5 * COEFS.sum() + + return jnp.vstack((1 - pi_star, pi_star)).T + + def _posterior_mode(self, K): + # Based on Algorithm 3.1 of GPML + f = jnp.zeros_like( + self.y_train, dtype=jnp.float32 + ) # a warning is triggered if float64 is used + + for _ in range(self.max_iter_predict): + # W = self.log_and_2grads_and_negtive(f, self.y_train) + pi = self.approx_func(f) + W = pi * (1 - pi) + W_sqr = jnp.sqrt(W) + W_sqr_K = W_sqr[:, jnp.newaxis] * K + + B = jnp.eye(W.shape[0]) + W_sqr_K * W_sqr + L = cholesky(B) + # b = W * f + self.log_and_grad(f, self.y_train) + b = W * f + (self.y_train - pi) + a = b - jnp.dot( + W_sqr[:, jnp.newaxis] * cho_solve((L, True), jnp.eye(W.shape[0])), + W_sqr_K.dot(b), + ) + f = K.dot(a) + + self.f_cached = f # for warm-start + return f + + def _check_kernal(self): + if self.kernel is None: # Use an RBF kernel as default + self.kernel_ = RBF() + else: + self.kernel_ = self.kernel + + +class GaussianProcessClassifier: + """Gaussian process classification (GPC) based on Laplace approximation. + + The implementation is based on Algorithm 3.1, 3.2, and 5.1 from [RW2006]_. + + Internally, the Laplace approximation is used for approximating the + non-Gaussian posterior by a Gaussian. + + Currently, the implementation is restricted to using the logistic link + function. For multi-class classification, several binary one-versus rest + classifiers are fitted. Note that this class thus does not implement + a true multi-class Laplace approximation. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + kernel : kernel instance, default=None + The kernel specifying the covariance function of the GP. If None is + passed, the kernel "1.0 * RBF(1.0)" is used as default. Note that + the kernel's hyperparameters are optimized during fitting. Also kernel + cannot be a `CompoundKernel`. + + max_iter_predict : int, default=100 + The maximum number of iterations in Newton's method for approximating + the posterior during predict. Smaller values will reduce computation + time at the cost of worse results. + + multi_class : 'one_vs_rest', default='one_vs_rest' + Specifies how multi-class classification problems are handled. + One binary Gaussian process classifier is fitted for each class, which + is trained to separate this class from the rest. + + poss : "sigmoid", allable or None, default="sigmoid", the predefined + likelihood function which computes the possibility of the predict output + w.r.t. f value. + + Attributes + ---------- + base_estimator_ : ``Estimator`` instance + The estimator instance that defines the likelihood function + using the observed data. + + kernel_ : kernel instance + The kernel used for prediction. In case of binary classification, + the structure of the kernel is the same as the one passed as parameter + but with optimized hyperparameters. In case of multi-class + classification, a CompoundKernel is returned which consists of the + different kernels used in the one-versus-rest classifiers. + + n_classes_ : int + The number of classes in the training data + + References + ---------- + .. [RW2006] `Carl E. Rasmussen and Christopher K.I. Williams, + "Gaussian Processes for Machine Learning", + MIT Press 2006 `_ + + Examples + -------- + >>> from sklearn.datasets import load_iris + >>> from sklearn.gaussian_process import GaussianProcessClassifier + >>> from sklearn.gaussian_process.kernels import RBF + >>> X, y = load_iris(return_X_y=True) + >>> kernel = 1.0 * RBF(1.0) + >>> gpc = GaussianProcessClassifier(kernel=kernel, + ... random_state=0).fit(X, y, 3) + >>> gpc.predict_proba(X[:2,:]) + array([[0.83548752, 0.03228706, 0.13222543], + [0.79064206, 0.06525643, 0.14410151]]) + """ + + def __init__( + self, + kernel=None, + *, + poss="sigmoid", + max_iter_predict=100, + multi_class="one_vs_rest", + ): + self.kernel = kernel + self.max_iter_predict = max_iter_predict + self.multi_class = multi_class + self.poss = poss + + def fit(self, X, y, n_classes_): + """Fit Gaussian process classification model. + + Parameters + ---------- + X : jax numpy array (n_samples, n_features) of object + Feature vectors of training data. + + y : jax numpy array (n_samples,) Target values, + must be preprocessed to 0, 1, 2, ... + + n_classes_ : The number of classes in the training data + + Returns + ------- + self : object + Returns an instance of self. + """ + self.n_classes_ = n_classes_ + self.y_train = jnp.array(y) + + if self.n_classes_ == 1: + raise ValueError( + "GaussianProcessClassifier requires 2 or more " + "distinct classes; got 1 class (only class %s " + "is present)" % self.n_classes_[0] + ) + + self.base_estimator_ = _BinaryGaussianProcessClassifierLaplace( + kernel=self.kernel, + max_iter_predict=self.max_iter_predict, + poss=self.poss, + ) + + if self.n_classes_ > 2: + if self.multi_class == "one_vs_rest": + self.base_estimator_ = OneVsRestClassifier( + self.base_estimator_, self.n_classes_ + ) + elif self.multi_class == "one_vs_one": + raise ValueError("one_vs_one classifier is not supported") + else: + raise ValueError("Unknown multi-class mode %s" % self.multi_class) + + self.X = jnp.array(X) + self.base_estimator_.fit(self.X, self.y_train) + + def predict(self, X): + """Perform classification on an array of test vectors X. + + Parameters + ---------- + X : jax numpy array (n_samples, n_features) of object + Query points where the GP is evaluated for classification. + + Returns + ------- + C : jax numpy array (n_samples,) + Predicted target values for X. + """ + self.check_is_fitted() + return self.base_estimator_.predict(X) + + def predict_proba(self, X): + """Return probability estimates for the test vector X. + + Parameters + ---------- + X : jax numpy array (n_samples, n_features) of object + Query points where the GP is evaluated for classification. + + Returns + ------- + C : jax numpy array (n_samples, n_classes) + Returns the probability of the samples for each class in + the model. The columns correspond to the classes in sorted + order. + """ + self.check_is_fitted() + return self.base_estimator_.predict_proba(X) + + def check_is_fitted(self): + """Perform is_fitted validation for estimator. + + Checks if the estimator is fitted by verifying the presence of + fitted attribute self.n_classes_ and otherwise + raises a NotFittedError with the given message. + + Raises + ------ + Exception + If the attribute is not found. + """ + try: + self.n_classes_ + except: + raise Exception('Model is not fitted yet') diff --git a/sml/gaussian_process/emulations/BUILD.bazel b/sml/gaussian_process/emulations/BUILD.bazel index 7832e732..bf907a4f 100644 --- a/sml/gaussian_process/emulations/BUILD.bazel +++ b/sml/gaussian_process/emulations/BUILD.bazel @@ -11,3 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +load("@rules_python//python:defs.bzl", "py_binary") + +package(default_visibility = ["//visibility:public"]) + +py_test( + name = "gpc_emul", + srcs = ["gpc_emul.py"], + deps = [ + "//sml/gaussian_process:_gpc", + "//spu/utils:emulation", + ], +) diff --git a/sml/gaussian_process/emulations/gpc_emul.py b/sml/gaussian_process/emulations/gpc_emul.py new file mode 100644 index 00000000..9e727d76 --- /dev/null +++ b/sml/gaussian_process/emulations/gpc_emul.py @@ -0,0 +1,49 @@ +import os +import sys + +import jax.numpy as jnp +from sklearn.datasets import load_iris + +# Add the library directory to the path +sys.path.append(os.path.join(os.path.dirname(__file__), "../../../")) +import sml.utils.emulation as emulation +from sml.gaussian_process._gpc import GaussianProcessClassifier + + +def emul_gpc(mode: emulation.Mode.MULTIPROCESS): + def proc(x, y): + model = GaussianProcessClassifier(max_iter_predict=10) + model.fit(x, y) + + pred = model.predict(x) + return pred + + try: + # bandwidth and latency only work for docker mode + emulator = emulation.Emulator( + emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20 + ) + emulator.up() + + # load data + x, y = load_iris(return_X_y=True) + x = x[45:55, :] + y = y[45:55] + + # mark these data to be protected in SPU + x, y = emulator.seal(x, y) + result = emulator.run(proc)(x, y) + print("Accuracy: ", jnp.sum(result == y) / len(y)) + + finally: + emulator.down() + + +if __name__ == "__main__": + emul_gpc(emulation.Mode.MULTIPROCESS) + + finally: + emulator.down() + +if __name__ == "__main__": + emul_gpc(emulation.Mode.MULTIPROCESS) diff --git a/sml/gaussian_process/kernels.py b/sml/gaussian_process/kernels.py new file mode 100644 index 00000000..40e9c59c --- /dev/null +++ b/sml/gaussian_process/kernels.py @@ -0,0 +1,15 @@ +import jax.numpy as jnp + + +class RBF: + def __init__(self, length_scale=1.0): + self.length_scale = length_scale + + def __call__(self, X, Y=None): + if Y == None: + Y = X + K = jnp.zeros((X.shape[0], Y.shape[0]), dtype=jnp.float32) + for i in range(X.shape[0]): + for j in range(Y.shape[0]): + K = K.at[i, j].set((jnp.sum((X[i] - Y[j]) ** 2))) + return jnp.exp(-K / (2 * self.length_scale)) diff --git a/sml/gaussian_process/ovo_ovr.py b/sml/gaussian_process/ovo_ovr.py new file mode 100644 index 00000000..d7c99b5a --- /dev/null +++ b/sml/gaussian_process/ovo_ovr.py @@ -0,0 +1,109 @@ +import jax.numpy as jnp +from jax import vmap +from jax.scipy.special import erf, expit +from kernels import RBF +from jax.lax.linalg import cholesky +from jax.scipy.linalg import cho_solve, solve + +LAMBDAS = jnp.array([0.41, 0.4, 0.37, 0.44, 0.39])[:, jnp.newaxis] +COEFS = jnp.array( + [-1854.8214151, 3516.89893646, 221.29346712, 128.12323805, -2010.49422654] +)[:, jnp.newaxis] + +class OneVsRestClassifier: + def __init__(self, estimator, n_classes): + self.estimator = estimator + self.classes_ = n_classes + + def fit(self, X, y): + + self.estimator.approx_func = expit + self.estimator.X_train_ = jnp.array(X) + + if self.estimator.kernel is None: # Use an RBF kernel as default + self.estimator.kernel_ = RBF() + else: + self.estimator.kernel_ = self.estimator.kernel + + self.K = self.estimator.kernel_(X) + + self.y_binary = jnp.array( + [jnp.where(y == i, 0, 1) for i in range(self.classes_)] + ) + + self.fs_ = vmap(self._posterior_mode, in_axes=(None, 0, None))(X, self.y_binary, self.K) + self.f_cached = self.fs_ + + def predict(self, X_test): + X = jnp.array(X_test) + K_star = self.estimator.kernel_(self.estimator.X_train_, X) + diag_K_Xtest = jnp.diag(self.estimator.kernel_(X)) + maxima = vmap(self.predict_proba_oneclass, in_axes=(0, None, 0, None, None, None))( + self.y_binary, X_test, self.fs_, self.K, K_star, diag_K_Xtest + ) + return maxima.argmax(axis=0) + + def predict_proba(self, X_test): + X = jnp.array(X_test) + K_star = self.estimator.kernel_(self.estimator.X_train_, X) + diag_K_Xtest = jnp.diag(self.estimator.kernel_(X)) + maxima = vmap(self.predict_proba_oneclass, in_axes=(0, None, 0, None, None, None))( + self.y_binary, X_test, self.fs_, self.K, K_star, diag_K_Xtest + ) + maxima = maxima / jnp.sum(maxima, axis=0) + return maxima.T + + def predict_proba_oneclass(self, y_binary, X_test, f_, K, K_star, diag_K_Xtest): + X = jnp.asarray(X_test) + # K = self.estimator.kernel_(self.estimator.X_train_) + + pi = self.estimator.approx_func(f_) + W = pi * (1 - pi) + + W_sqr = jnp.sqrt(W) + W_sqr_K = W_sqr[:, jnp.newaxis] * K + B = jnp.eye(W.shape[0]) + W_sqr_K * W_sqr + L = cholesky(B) + + # K_star = self.estimator.kernel_(self.estimator.X_train_, X) + # f_star = K_star.T.dot(self.log_and_grad(self.f_, self.y_train)) + f_star = K_star.T.dot(y_binary - pi) + v = solve(L, W_sqr[:, jnp.newaxis] * K_star) + var_f_star = diag_K_Xtest - jnp.einsum("ij,ij->j", v, v) + + alpha = 1 / (2 * var_f_star) + gamma = LAMBDAS * f_star + integrals = ( + jnp.sqrt(jnp.pi / alpha) + * erf(gamma * jnp.sqrt(alpha / (alpha + LAMBDAS**2))) + / (2 * jnp.sqrt(var_f_star * 2 * jnp.pi)) + ) + pi_star = (COEFS * integrals).sum(axis=0) + 0.5 * COEFS.sum() + + return 1 - pi_star + + def _posterior_mode(self, X, y_binary, K): + # K = self.estimator.kernel_(X) + # Based on Algorithm 3.1 of GPML + f = jnp.zeros_like( + y_binary, dtype=jnp.float32 + ) # a warning is triggered if float64 is used + + for _ in range(self.estimator.max_iter_predict): + # W = self.log_and_2grads_and_negtive(f, self.y_train) + pi = self.estimator.approx_func(f) + W = pi * (1 - pi) + W_sqr = jnp.sqrt(W) + W_sqr_K = W_sqr[:, jnp.newaxis] * K + + B = jnp.eye(W.shape[0]) + W_sqr_K * W_sqr + L = cholesky(B) + # b = W * f + self.log_and_grad(f, self.y_train) + b = W * f + (y_binary - pi) + a = b - jnp.dot( + W_sqr[:, jnp.newaxis] * cho_solve((L, True), jnp.eye(W.shape[0])), + W_sqr_K.dot(b), + ) + f = K.dot(a) + + return f diff --git a/sml/gaussian_process/tests/BUILD.bazel b/sml/gaussian_process/tests/BUILD.bazel index 7832e732..52e9a0a3 100644 --- a/sml/gaussian_process/tests/BUILD.bazel +++ b/sml/gaussian_process/tests/BUILD.bazel @@ -11,3 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +load("@rules_python//python:defs.bzl", "py_test") + +package(default_visibility = ["//visibility:public"]) + +py_test( + name = "gpc_test", + srcs = ["gpc_test.py"], + deps = [ + "//sml/gaussian_process:_gpc", + "//spu:init", + "//spu/utils:simulation", + ], +) diff --git a/sml/gaussian_process/tests/gpc_test.py b/sml/gaussian_process/tests/gpc_test.py new file mode 100644 index 00000000..a854252b --- /dev/null +++ b/sml/gaussian_process/tests/gpc_test.py @@ -0,0 +1,43 @@ +import os +import sys +import unittest + +import jax +import jax.numpy as jnp +from sklearn.datasets import load_iris + +import spu.spu_pb2 as spu_pb2 +import spu.utils.simulation as spsim + +# Add the library directory to the path +sys.path.append(os.path.join(os.path.dirname(__file__), "../../../")) +from sml.gaussian_process._gpc import GaussianProcessClassifier + + +class UnitTests(unittest.TestCase): + def test_gpc(self): + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + # Test GaussianProcessClassifier + def proc(x, y): + model = GaussianProcessClassifier(max_iter_predict=10) + model.fit(x, y, 3) + + pred = model.predict(x) + return pred + + # Create dataset + x, y = load_iris(return_X_y=True) + # x = x[45:55, :] + # y = y[45:55] + + # Run + result = spsim.sim_jax(sim, proc)(x, y) + print(result) + print(y) + print("Accuracy: ", jnp.sum(result == y) / len(y)) + + +unittest.main()