Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MPC Quantized Machine Learning- Jacobi SVD #952

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sml/decomposition/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,8 @@ py_library(
name = "nmf",
srcs = ["nmf.py"],
)

py_library(
name = "jacobi_svd",
srcs = ["jacobi_svd.py"],
)
9 changes: 9 additions & 0 deletions sml/decomposition/emulations/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ py_binary(
],
)

py_binary(
name = "jacobi_svd_emul",
srcs = ["jacobi_svd_emul.py"],
deps = [
"//sml/decomposition:jacobi_svd",
"//sml/utils:emulation",
],
)

filegroup(
name = "conf",
srcs = [
Expand Down
91 changes: 91 additions & 0 deletions sml/decomposition/emulations/jacobi_svd_emul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import os
import sys
import jax.numpy as jnp
import jax.random as random
import jax.lax as lax
import numpy as np
from scipy.linalg import svd as scipy_svd

sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))

import sml.utils.emulation as emulation

def generate_symmetric_matrix(n, seed=0):
A = random.normal(random.PRNGKey(seed), (n, n))
S = (A + A.T) / 2
return S

def jacobi_rotation(A, p, q):
tau = (A[q, q] - A[p, p]) / (2 * A[p, q])
t = jnp.sign(tau) / (jnp.abs(tau) + jnp.sqrt(1 + tau**2))
c = 1 / jnp.sqrt(1 + t**2)
s = t * c
return c, s

def apply_jacobi_rotation_A(A, c, s, p, q):
A_new = A.copy()
A = A.at[p, :].set(c * A_new[p, :] - s * A_new[q, :])
A = A.at[q, :].set(s * A_new[p, :] + c * A_new[q, :])
A_new = A.copy()
A = A.at[:, p].set(c * A_new[:, p] - s * A_new[:, q])
A = A.at[:, q].set(s * A_new[:, p] + c * A_new[:, q])
return A

def jacobi_svd(A, tol=1e-10, max_iter=5):
n = A.shape[0]
A = jnp.array(A)

def body_fun(i, val):
A, max_off_diag = val
mask = jnp.abs(A - jnp.diagonal(A)) > tol
for p in range(n):
for q in range(p + 1, n):
A = lax.cond(
mask[p, q],
lambda A: apply_jacobi_rotation_A(A, *jacobi_rotation(A, p, q), p, q),
lambda A: A,
A
)
max_off_diag = lax.cond(
mask[p, q],
lambda x: jnp.maximum(x, jnp.abs(A[p, q])),
lambda x: x,
max_off_diag
)
return A, max_off_diag

max_off_diag = jnp.inf
A, _, = lax.fori_loop(0, max_iter, body_fun, (A, max_off_diag))

singular_values = jnp.abs(jnp.diag(A))
idx = jnp.argsort(-singular_values)
singular_values = singular_values[idx]

return singular_values

def emul_jacobi_svd(mode=emulation.Mode.MULTIPROCESS):
print("Start Jacobi SVD emulation.")

def proc_transform(A):
singular_values = jacobi_svd(A)
return singular_values

try:
emulator = emulation.Emulator(
emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20
)
emulator.up()

A = generate_symmetric_matrix(10)
A_spu = emulator.seal(A)
singular_values = emulator.run(proc_transform)(A_spu)

_, singular_values_scipy, _ = scipy_svd(np.array(A), full_matrices=False)

np.testing.assert_allclose(np.sort(singular_values), np.sort(singular_values_scipy), atol=1e-3)

finally:
emulator.down()

if __name__ == "__main__":
emul_jacobi_svd(emulation.Mode.MULTIPROCESS)
82 changes: 82 additions & 0 deletions sml/decomposition/jacobi_svd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import jax.numpy as jnp
import jax.random as random
import jax.lax as lax
from jax import jit, vmap
import numpy as np
import time

@jit
def jacobi_rotation(A, p, q):
tau = (A[q, q] - A[p, p]) / (2 * A[p, q])
t = jnp.sign(tau) / (jnp.abs(tau) + jnp.sqrt(1 + tau**2))
c = 1 / jnp.sqrt(1 + t**2)
s = t * c
return c, s

@jit
def apply_jacobi_rotation_A(A, c, s, p, q):
A_new = A.copy()
A = A.at[p, :].set(c * A_new[p, :] - s * A_new[q, :])
A = A.at[q, :].set(s * A_new[p, :] + c * A_new[q, :])
A_new = A.copy()
A = A.at[:, p].set(c * A_new[:, p] - s * A_new[:, q])
A = A.at[:, q].set(s * A_new[:, p] + c * A_new[:, q])
return A

@jit
def jacobi_svd(A, tol=1e-10, max_iter=5):
n = A.shape[0]
A = jnp.array(A)

def body_fun(i, val):
A, max_off_diag, iterations = val
mask = jnp.abs(A - jnp.diagonal(A)) > tol
for p in range(n):
for q in range(p + 1, n):
A = lax.cond(
mask[p, q],
lambda A: apply_jacobi_rotation_A(A, *jacobi_rotation(A, p, q), p, q),
lambda A: A,
A
)
max_off_diag = lax.cond(
mask[p, q],
lambda x: jnp.maximum(x, jnp.abs(A[p, q])),
lambda x: x,
max_off_diag
)
return A, max_off_diag, iterations

max_off_diag = jnp.inf
iterations = 0
A, _, final_iterations = lax.fori_loop(0, max_iter, body_fun, (A, max_off_diag, iterations))

singular_values = jnp.abs(jnp.diag(A))
idx = jnp.argsort(-singular_values)
singular_values = singular_values[idx]

return singular_values

def generate_symmetric_matrix(n, seed=0):
A = random.normal(random.PRNGKey(seed), (n, n))
S = (A + A.T) / 2
return S

n = 10

A_jax = generate_symmetric_matrix(n)

start_time = time.time()
singular_values = jacobi_svd(A_jax)
end_time = time.time()

elapsed_time = end_time - start_time
print(f"Run Time: {elapsed_time:.6f} s")

print("Singular Values Jacobi_svd:")
print(singular_values)

A_np = np.array(A_jax)
_, Sigma, _ = np.linalg.svd(A_np)
print("Sigma:")
print(Sigma)
Loading