From b8cf10ea3c8f05bd1f64315047001186308c9494 Mon Sep 17 00:00:00 2001 From: Bing Liu Date: Mon, 30 Dec 2024 10:49:47 +0800 Subject: [PATCH] MPC Quantized Machine Learning- Jacobi SVD MPC Quantized Machine Learning- Jacobi SVD --- sml/decomposition/BUILD.bazel | 5 + sml/decomposition/emulations/BUILD.bazel | 9 ++ .../emulations/jacobi_svd_emul.py | 91 +++++++++++++++++++ sml/decomposition/jacobi_svd.py | 82 +++++++++++++++++ 4 files changed, 187 insertions(+) create mode 100644 sml/decomposition/emulations/jacobi_svd_emul.py create mode 100644 sml/decomposition/jacobi_svd.py diff --git a/sml/decomposition/BUILD.bazel b/sml/decomposition/BUILD.bazel index c80a6751..9318d181 100644 --- a/sml/decomposition/BUILD.bazel +++ b/sml/decomposition/BUILD.bazel @@ -26,3 +26,8 @@ py_library( name = "nmf", srcs = ["nmf.py"], ) + +py_library( + name = "jacobi_svd", + srcs = ["jacobi_svd.py"], +) diff --git a/sml/decomposition/emulations/BUILD.bazel b/sml/decomposition/emulations/BUILD.bazel index 100d6e40..9f4c4ab9 100644 --- a/sml/decomposition/emulations/BUILD.bazel +++ b/sml/decomposition/emulations/BUILD.bazel @@ -45,6 +45,15 @@ py_binary( ], ) +py_binary( + name = "jacobi_svd_emul", + srcs = ["jacobi_svd_emul.py"], + deps = [ + "//sml/decomposition:jacobi_svd", + "//sml/utils:emulation", + ], +) + filegroup( name = "conf", srcs = [ diff --git a/sml/decomposition/emulations/jacobi_svd_emul.py b/sml/decomposition/emulations/jacobi_svd_emul.py new file mode 100644 index 00000000..7d974f98 --- /dev/null +++ b/sml/decomposition/emulations/jacobi_svd_emul.py @@ -0,0 +1,91 @@ +import os +import sys +import jax.numpy as jnp +import jax.random as random +import jax.lax as lax +import numpy as np +from scipy.linalg import svd as scipy_svd + +sys.path.append(os.path.join(os.path.dirname(__file__), '../../../')) + +import sml.utils.emulation as emulation + +def generate_symmetric_matrix(n, seed=0): + A = random.normal(random.PRNGKey(seed), (n, n)) + S = (A + A.T) / 2 + return S + +def jacobi_rotation(A, p, q): + tau = (A[q, q] - A[p, p]) / (2 * A[p, q]) + t = jnp.sign(tau) / (jnp.abs(tau) + jnp.sqrt(1 + tau**2)) + c = 1 / jnp.sqrt(1 + t**2) + s = t * c + return c, s + +def apply_jacobi_rotation_A(A, c, s, p, q): + A_new = A.copy() + A = A.at[p, :].set(c * A_new[p, :] - s * A_new[q, :]) + A = A.at[q, :].set(s * A_new[p, :] + c * A_new[q, :]) + A_new = A.copy() + A = A.at[:, p].set(c * A_new[:, p] - s * A_new[:, q]) + A = A.at[:, q].set(s * A_new[:, p] + c * A_new[:, q]) + return A + +def jacobi_svd(A, tol=1e-10, max_iter=5): + n = A.shape[0] + A = jnp.array(A) + + def body_fun(i, val): + A, max_off_diag = val + mask = jnp.abs(A - jnp.diagonal(A)) > tol + for p in range(n): + for q in range(p + 1, n): + A = lax.cond( + mask[p, q], + lambda A: apply_jacobi_rotation_A(A, *jacobi_rotation(A, p, q), p, q), + lambda A: A, + A + ) + max_off_diag = lax.cond( + mask[p, q], + lambda x: jnp.maximum(x, jnp.abs(A[p, q])), + lambda x: x, + max_off_diag + ) + return A, max_off_diag + + max_off_diag = jnp.inf + A, _, = lax.fori_loop(0, max_iter, body_fun, (A, max_off_diag)) + + singular_values = jnp.abs(jnp.diag(A)) + idx = jnp.argsort(-singular_values) + singular_values = singular_values[idx] + + return singular_values + +def emul_jacobi_svd(mode=emulation.Mode.MULTIPROCESS): + print("Start Jacobi SVD emulation.") + + def proc_transform(A): + singular_values = jacobi_svd(A) + return singular_values + + try: + emulator = emulation.Emulator( + emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20 + ) + emulator.up() + + A = generate_symmetric_matrix(10) + A_spu = emulator.seal(A) + singular_values = emulator.run(proc_transform)(A_spu) + + _, singular_values_scipy, _ = scipy_svd(np.array(A), full_matrices=False) + + np.testing.assert_allclose(np.sort(singular_values), np.sort(singular_values_scipy), atol=1e-3) + + finally: + emulator.down() + +if __name__ == "__main__": + emul_jacobi_svd(emulation.Mode.MULTIPROCESS) diff --git a/sml/decomposition/jacobi_svd.py b/sml/decomposition/jacobi_svd.py new file mode 100644 index 00000000..bec148cb --- /dev/null +++ b/sml/decomposition/jacobi_svd.py @@ -0,0 +1,82 @@ +import jax.numpy as jnp +import jax.random as random +import jax.lax as lax +from jax import jit, vmap +import numpy as np +import time + +@jit +def jacobi_rotation(A, p, q): + tau = (A[q, q] - A[p, p]) / (2 * A[p, q]) + t = jnp.sign(tau) / (jnp.abs(tau) + jnp.sqrt(1 + tau**2)) + c = 1 / jnp.sqrt(1 + t**2) + s = t * c + return c, s + +@jit +def apply_jacobi_rotation_A(A, c, s, p, q): + A_new = A.copy() + A = A.at[p, :].set(c * A_new[p, :] - s * A_new[q, :]) + A = A.at[q, :].set(s * A_new[p, :] + c * A_new[q, :]) + A_new = A.copy() + A = A.at[:, p].set(c * A_new[:, p] - s * A_new[:, q]) + A = A.at[:, q].set(s * A_new[:, p] + c * A_new[:, q]) + return A + +@jit +def jacobi_svd(A, tol=1e-10, max_iter=5): + n = A.shape[0] + A = jnp.array(A) + + def body_fun(i, val): + A, max_off_diag, iterations = val + mask = jnp.abs(A - jnp.diagonal(A)) > tol + for p in range(n): + for q in range(p + 1, n): + A = lax.cond( + mask[p, q], + lambda A: apply_jacobi_rotation_A(A, *jacobi_rotation(A, p, q), p, q), + lambda A: A, + A + ) + max_off_diag = lax.cond( + mask[p, q], + lambda x: jnp.maximum(x, jnp.abs(A[p, q])), + lambda x: x, + max_off_diag + ) + return A, max_off_diag, iterations + + max_off_diag = jnp.inf + iterations = 0 + A, _, final_iterations = lax.fori_loop(0, max_iter, body_fun, (A, max_off_diag, iterations)) + + singular_values = jnp.abs(jnp.diag(A)) + idx = jnp.argsort(-singular_values) + singular_values = singular_values[idx] + + return singular_values + +def generate_symmetric_matrix(n, seed=0): + A = random.normal(random.PRNGKey(seed), (n, n)) + S = (A + A.T) / 2 + return S + +n = 10 + +A_jax = generate_symmetric_matrix(n) + +start_time = time.time() +singular_values = jacobi_svd(A_jax) +end_time = time.time() + +elapsed_time = end_time - start_time +print(f"Run Time: {elapsed_time:.6f} s") + +print("Singular Values Jacobi_svd:") +print(singular_values) + +A_np = np.array(A_jax) +_, Sigma, _ = np.linalg.svd(A_np) +print("Sigma:") +print(Sigma) \ No newline at end of file