Skip to content

Commit

Permalink
Merge pull request #79 from kc-ml2/DEV/main
Browse files Browse the repository at this point in the history
Dev/main
  • Loading branch information
yonghakim authored Sep 19, 2024
2 parents 1affa35 + cb9a618 commit 310f0b6
Show file tree
Hide file tree
Showing 60 changed files with 4,244 additions and 3,620 deletions.
58 changes: 0 additions & 58 deletions QA/1D_grating_in_2D_pattern.py

This file was deleted.

87 changes: 87 additions & 0 deletions QA/1d_pattern_in_1dc_and_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# This demo shows a case with 1D grating and TM polarization.
# If phi is set to 'None', this will use 1D TETM formulation (without azimuthal rotation, phi == 0)
# But if phi is set to '0', then the simulation will be taken for 1D conical or 2D case which is general but slower.

import numpy as np
from time import time

from meent import call_mee


def compare():
backend = 0
pol = 1 # 0: TE, 1: TM

n_top = 1 # n_incidence
n_bot = 1 # n_transmission

theta = 1E-10 # angle of incidence in radian

wavelength = 300 # wavelength
thickness = [460, 22]
period = [700, 700]
fto = [100, 0]

ucell_1d = np.array([
[
[1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
],
[
[1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
],
])
ucell_2d = np.array([
[
[1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
[1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
],
[
[1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
[1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
],
])

mee = call_mee(backend=backend, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, fto=fto,
wavelength=wavelength, period=period, thickness=thickness)

# 1D
mee.phi = None # which is default
mee.ucell = ucell_1d

t0_1d = time()
res = mee.conv_solve().res
t1_1d = time()
de_ri1, de_ti1 = res.de_ri, res.de_ti
print('1D (de_ri, de_ti): ', de_ri1, de_ti1)

# 1D conical
mee.phi = 0
t0_1dc = time()
res = mee.conv_solve().res
t1_1dc = time()
de_ri1c, de_ti1c = res.de_ri, res.de_ti
print('1Dc (de_ri, de_ti): ', de_ri1c, de_ti1c)

# 2D
mee.phi = 0
t0_2d = time()
mee.ucell = ucell_2d
res = mee.conv_solve().res
t1_2d = time()
de_ri2, de_ti2 = res.de_ri, res.de_ti
print('2D (de_ri, de_ti): ', de_ri2, de_ti2)

print('time for 1D formulation: ', t1_1d-t0_1d, 's')
print('time for 1Dc formulation: ', t1_1dc-t0_1dc, 's')
print('time for 2D formulation: ', t1_2d-t0_2d, 's')
print('Simulation Difference between 1D and 1Dc formulation: ',
np.linalg.norm(de_ri1 - de_ri1c), np.linalg.norm(de_ti1 - de_ti1c))
print('Simulation Difference between 1D and 2D formulation: ',
np.linalg.norm(de_ri1 - de_ri2), np.linalg.norm(de_ti1 - de_ti2))

print('Simulation Difference between 1Dc and 2D formulation: ',
np.linalg.norm(de_ri1c - de_ri2), np.linalg.norm(de_ti1c - de_ti2))


if __name__ == '__main__':
compare()
23 changes: 17 additions & 6 deletions QA/autograd_complex_ucell.py → QA/autodiff_raster1.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch

import meent
from meent.on_torch.optimizer.loss import LossDeflector


type_complex = 0
device = 0
Expand Down Expand Up @@ -48,7 +48,19 @@

pois = ['ucell', 'thickness'] # Parameter Of Interests
forward = jmee.conv_solve
loss_fn = LossDeflector(x_order=0, y_order=0)


class Loss:
def __call__(self, meent_result, *args, **kwargs):
res_psi, res_te, res_ti = meent_result.res, meent_result.res_te_inc, meent_result.res_tm_inc
de_ti = res_psi.de_ti
center = [a // 2 for a in de_ti.shape]
res = de_ti[center[0], center[1]+1]

return res


loss_fn = Loss()

# case 1: Gradient
grad_j = jmee.grad(pois, forward, loss_fn)
Expand All @@ -58,7 +70,7 @@
print('thickness gradient:')
print(grad_j['thickness'])

optimizer = optax.sgd(learning_rate=1e-2)
optimizer = optax.sgd(learning_rate=1E2)
t0 = time.time()
res_j = jmee.fit(pois, forward, loss_fn, optimizer, iteration=iteration)
print('Time JAX', time.time() - t0)
Expand All @@ -74,7 +86,6 @@
thickness=thickness, type_complex=type_complex, device=device)

forward = tmee.conv_solve
loss_fn = LossDeflector(x_order=0) # predefined in meent

grad_t = tmee.grad(pois, forward, loss_fn)
print('ucell gradient:')
Expand All @@ -83,7 +94,7 @@
print(grad_t['thickness'])

opt_torch = torch.optim.SGD
opt_options = {'lr': 1E-2}
opt_options = {'lr': 1E2}

t0 = time.time()
res_t = tmee.fit(pois, forward, loss_fn, opt_torch, opt_options, iteration=iteration)
Expand All @@ -102,6 +113,6 @@

print('End')

# Note that the gradient in JAX is conjugated.
# Note that the gradient in JAX is conjugation of PyTorch's.
# https://github.com/google/jax/issues/4891
# https://pytorch.org/docs/stable/notes/autograd.html#autograd-for-complex-numbers
172 changes: 172 additions & 0 deletions QA/autodiff_raster2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import jax
import torch

import jax.numpy as jnp
import numpy as np

from time import time

from meent import call_mee


def load_setting():
pol = 1 # 0: TE, 1: TM

n_top = 1 # n_incidence
n_bot = 1 # n_transmission

theta = 0 * np.pi / 180
phi = 0 * np.pi / 180

wavelength = 900

fto = [5, 5]

period = [1000, 1000]
thickness = [1120]

ucell = np.array([[[2.58941352 + 0.47745679j, 4.17771602 + 0.88991205j,
2.04255624 + 2.23670125j, 2.50478974 + 2.05242759j,
3.32747593 + 2.3854387j],
[2.80118605 + 0.53053715j, 4.46498861 + 0.10812571j,
3.99377545 + 1.0441131j, 3.10728537 + 0.6637353j,
4.74697849 + 0.62841253j],
[3.80944424 + 2.25899274j, 3.70371553 + 1.32586402j,
3.8011133 + 1.49939415j, 3.14797238 + 2.91158289j,
4.3085404 + 2.44344691j],
[2.22510179 + 2.86017146j, 2.36613053 + 2.82270351j,
4.5087168 + 0.2035904j, 3.15559949 + 2.55311298j,
4.29394604 + 0.98362617j],
[3.31324163 + 2.77590131j, 2.11744834 + 1.65894674j,
3.59347907 + 1.28895345j, 3.85713467 + 1.90714056j,
2.93805426 + 2.63385392j]]])
ucell = ucell.real

type_complex = 0
device = 0

setting = {'pol': pol, 'n_top': n_top, 'n_bot': n_bot, 'theta': theta, 'phi': phi, 'fto': fto,
'wavelength': wavelength, 'period': period, 'ucell': ucell, 'thickness': thickness, 'device': device,
'type_complex': type_complex}

return setting


def optimize_jax(setting):
ucell = setting['ucell']

mee = call_mee(backend=1, **setting)

@jax.jit
def grad_loss(ucell):
mee.ucell = ucell
res = mee.conv_solve().res
de_ri, de_ti = res.de_ri, res.de_ti

loss = de_ti[de_ti.shape[0] // 2, de_ti.shape[1] // 2]

return loss

def grad_numerical(ucell, delta):
grad_arr = jnp.zeros(ucell.shape, dtype=ucell.dtype)

@jax.jit
def compute(ucell):
mee.ucell = ucell
result = mee.conv_solve()
de_ti = result.res.de_ti
loss = de_ti[de_ti.shape[0] // 2, de_ti.shape[1] // 2]

return loss

for layer in range(ucell.shape[0]):
for r in range(ucell.shape[1]):
for c in range(ucell.shape[2]):
ucell_delta_m = ucell.copy()
ucell_delta_m[layer, r, c] -= delta
mee.ucell = ucell_delta_m
de_ti_delta_m = compute(ucell_delta_m, )

ucell_delta_p = ucell.copy()
ucell_delta_p[layer, r, c] += delta
mee.ucell = ucell_delta_p
de_ti_delta_p = compute(ucell_delta_p, )

grad_numeric = (de_ti_delta_p - de_ti_delta_m) / (2 * delta)
grad_arr = grad_arr.at[layer, r, c].set(grad_numeric)

return grad_arr

jax.grad(grad_loss)(ucell) # Dry run for jit compilation. This is to make time comparison fair.
t0 = time()
grad_ad = jax.grad(grad_loss)(ucell)
t_ad = time() - t0
print('JAX grad_ad:\n', grad_ad)
t0 = time()
grad_nume = grad_numerical(ucell, 1E-6)
t_nume = time() - t0
print('JAX grad_numeric:\n', grad_nume)
print('JAX norm of difference: ', jnp.linalg.norm(grad_nume - grad_ad) / grad_nume.size)
return t_ad, t_nume


def optimize_torch(setting):
mee = call_mee(backend=2, **setting)

mee.ucell.requires_grad = True

t0 = time()
res = mee.conv_solve().res
de_ri, de_ti = res.de_ri, res.de_ti

loss = de_ti[de_ti.shape[0] // 2, de_ti.shape[1] // 2]

loss.backward()
grad_ad = mee.ucell.grad
t_ad = time() - t0

def grad_numerical(ucell, delta):
ucell.requires_grad = False
grad_arr = torch.zeros(ucell.shape, dtype=ucell.dtype)

for layer in range(ucell.shape[0]):
for r in range(ucell.shape[1]):
for c in range(ucell.shape[2]):
ucell_delta_m = ucell.clone().detach()
ucell_delta_m[layer, r, c] -= delta
mee.ucell = ucell_delta_m
res = mee.conv_solve().res
de_ri_delta_m, de_ti_delta_m = res.de_ri, res.de_ti

ucell_delta_p = ucell.clone().detach()
ucell_delta_p[layer, r, c] += delta
mee.ucell = ucell_delta_p
res = mee.conv_solve().res
de_ri_delta_p, de_ti_delta_p = res.de_ri, res.de_ti

cy, cx = np.array(de_ti_delta_p.shape) // 2
grad_numeric = (de_ti_delta_p[cy, cx] - de_ti_delta_m[cy, cx]) / (2 * delta)
grad_arr[layer, r, c] = grad_numeric

return grad_arr

t0 = time()
grad_nume = grad_numerical(mee.ucell, 1E-6)
t_nume = time() - t0

print('Torch grad_ad:\n', grad_ad)
print('Torch grad_numeric:\n', grad_nume)
print('torch.norm: ', torch.linalg.norm(grad_nume - grad_ad) / grad_nume.numel())
return t_ad, t_nume


if __name__ == '__main__':
setting = load_setting()

print('JaxMeent')
j_t_ad, j_t_nume = optimize_jax(setting)
print('TorchMeent')
t_t_ad, t_t_nume = optimize_torch(setting)

print(f'Time for Backprop, JAX, AD: {j_t_ad} s, Numerical: {j_t_nume} s')
print(f'Time for Backprop, Torch, AD: {t_t_ad} s, Numerical: {t_t_nume} s')
Loading

0 comments on commit 310f0b6

Please sign in to comment.