Skip to content

Commit

Permalink
Merge pull request #77 from kc-ml2/DEV/main
Browse files Browse the repository at this point in the history
Dev/main
  • Loading branch information
yonghakim authored Aug 8, 2024
2 parents e7a2980 + 6e54c5b commit 22366a9
Show file tree
Hide file tree
Showing 83 changed files with 6,910 additions and 8,499 deletions.
58 changes: 58 additions & 0 deletions QA/1D_grating_in_2D_pattern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import numpy as np

from meent.main import call_mee


def test():
backend = 0
pol = 1 # 0: TE, 1: TM

n_top = 1 # n_incidence
n_bot = 1 # n_transmission

theta = 1E-10 # angle of incidence in radian
phi = 0 # azimuth angle in radian

wavelength = 300 # wavelength
thickness = [460, 22]
period = [700, 700]
fto = [10, 0]

# 1D
ucell = np.array([
[
[1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
],
[
[1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
],
])

AA = call_mee(backend=backend, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, phi=phi,
fto=fto, wavelength=wavelength, period=period, ucell=ucell, thickness=thickness)
de_ri, de_ti = AA.conv_solve()
print('1D', de_ri.sum(), de_ti.sum())

# 2D case

ucell = np.array([
[
[1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
[1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
[1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
],
[
[1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
[1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
[1, 1, 1, 3.48, 3.48, 3.48, 1, 1, 1, 1],
],
])

AA = call_mee(backend=backend, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, phi=phi,
fto=fto, wavelength=wavelength, period=period, ucell=ucell, thickness=thickness)
de_ri, de_ti = AA.conv_solve()
print('2D', de_ri.sum(), de_ti.sum())


if __name__ == '__main__':
test()
107 changes: 107 additions & 0 deletions QA/autograd_complex_ucell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import time

import jax
import optax
import numpy as np
import jax.numpy as jnp

import torch

import meent
from meent.on_torch.optimizer.loss import LossDeflector

type_complex = 0
device = 0
n_top = 1 # n_incidence
n_bot = 1 # n_transmission
theta = 0/180 * np.pi # angle of incidence
phi = 0/180 * np.pi # angle of rotation
wavelength = 900

pol = 0 # 0: TE, 1: TM
iteration = 20

fto = [5, 5]
period = [1000, 1000] # length of the unit cell. Here it's 1D.
thickness = [500] # thickness of each layer, from top to bottom.

ucell = np.array([[[2.58941352 + 0.47745679j, 4.17771602 + 0.88991205j,
2.04255624 + 2.23670125j, 2.50478974 + 2.05242759j,
3.32747593 + 2.3854387j],
[2.80118605 + 0.53053715j, 4.46498861 + 0.10812571j,
3.99377545 + 1.0441131j, 3.10728537 + 0.6637353j,
4.74697849 + 0.62841253j],
[3.80944424 + 2.25899274j, 3.70371553 + 1.32586402j,
3.8011133 + 1.49939415j, 3.14797238 + 2.91158289j,
4.3085404 + 2.44344691j],
[2.22510179 + 2.86017146j, 2.36613053 + 2.82270351j,
4.5087168 + 0.2035904j, 3.15559949 + 2.55311298j,
4.29394604 + 0.98362617j],
[3.31324163 + 2.77590131j, 2.11744834 + 1.65894674j,
3.59347907 + 1.28895345j, 3.85713467 + 1.90714056j,
2.93805426 + 2.63385392j]]])

# JAX Meent
jmee = meent.call_mee(backend=1, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, phi=phi,
fto=fto, wavelength=wavelength, period=period, ucell=ucell,
thickness=thickness, type_complex=type_complex, device=device)

pois = ['ucell', 'thickness'] # Parameter Of Interests
forward = jmee.conv_solve
loss_fn = LossDeflector(x_order=0, y_order=0)

# case 1: Gradient
grad_j = jmee.grad(pois, forward, loss_fn)

print('ucell gradient:')
print(grad_j['ucell'])
print('thickness gradient:')
print(grad_j['thickness'])

optimizer = optax.sgd(learning_rate=1e-2)
t0 = time.time()
res_j = jmee.fit(pois, forward, loss_fn, optimizer, iteration=iteration)
print('Time JAX', time.time() - t0)

print('ucell final:')
print(res_j['ucell'])
print('thickness final:')
print(res_j['thickness'])

# Torch Meent
tmee = meent.call_mee(backend=2, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, phi=phi,
fto=fto, wavelength=wavelength, period=period, ucell=ucell,
thickness=thickness, type_complex=type_complex, device=device)

forward = tmee.conv_solve
loss_fn = LossDeflector(x_order=0) # predefined in meent

grad_t = tmee.grad(pois, forward, loss_fn)
print('ucell gradient:')
print(grad_t['ucell'])
print('thickness gradient:')
print(grad_t['thickness'])

opt_torch = torch.optim.SGD
opt_options = {'lr': 1E-2}

t0 = time.time()
res_t = tmee.fit(pois, forward, loss_fn, opt_torch, opt_options, iteration=iteration)
print('Time Torch: ', time.time() - t0)
print('ucell final:')
print(res_t[0])
print('thickness final:')
print(res_t[1])

print('\n=============Difference between JaxMeent and TorchMeent==============================\n')
print('initial ucell gradient difference', np.linalg.norm(grad_j['ucell'].conj() - grad_t['ucell'].detach().numpy()))
print('initial thickness gradient difference', np.linalg.norm(grad_j['thickness'].conj() - grad_t['thickness'].detach().numpy()))

print('final ucell difference', np.linalg.norm(res_j['ucell'] - res_t[0].detach().numpy()))
print('final thickness difference', np.linalg.norm(res_j['thickness'] - res_t[1].detach().numpy()))

print('End')

# Note that the gradient in JAX is conjugated.
# https://github.com/google/jax/issues/4891
# https://pytorch.org/docs/stable/notes/autograd.html#autograd-for-complex-numbers
126 changes: 51 additions & 75 deletions QA/auto-grad_numerical-grad.py → QA/autograd_raster.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import warnings
import jax
import jax.numpy as jnp
import torch

import numpy as np

from copy import deepcopy
Expand All @@ -8,19 +13,18 @@
def load_setting():
pol = 1 # 0: TE, 1: TM

n_I = 1 # n_incidence
n_II = 1 # n_transmission
n_top = 1 # n_incidence
n_bot = 1 # n_transmission

theta = 0 * np.pi / 180
phi = 0 * np.pi / 180
psi = 0 * np.pi / 180 if pol else 90 * np.pi / 180

wavelength = 900

fourier_order = [2, 2]
fto = [2, 2]

# case 1
grating_type = 2
period = [1000, 1000]
thickness = [1120., 400, 300]

Expand All @@ -41,69 +45,46 @@ def load_setting():
]
)

# # case 2
# grating_type = 2
# period = [100, 100]
# thickness = [1120.]
# ucell = np.array([
# [
# [0, 0, 0, 1, 0, 1, 1, 1, 1, 1, ],
# [0, 0, 0, 1, 0, 1, 1, 1, 1, 1, ],
# [0, 0, 0, 1, 0, 1, 1, 1, 1, 1, ],
# [0, 0, 0, 1, 0, 1, 1, 1, 1, 1, ],
# [0, 0, 0, 1, 0, 1, 1, 1, 1, 1, ],
# [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ],
# [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ],
# [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ],
# [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ],
# [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ],
# ],
# ]) * 8 + 1.
#
# # case 3
# grating_type = 0 # grating type: 0 for 1D grating without rotation (phi == 0)
# thickness = [500, 1000] # thickness of each layer, from top to bottom.
# ucell = np.array([
# [[0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ]],
# [[1, 1, 1, 1, 0, 1, 1, 1, 1, 1, ]],
# ]) * 4 + 1 # refractive index
#
# # case 4
# grating_type = 2
#
# thickness, period = [1120.], [1000, 1000]
# ucell = np.array(
# [
# [
# [3.5, 1.2, 1.5, 1.2, 3.3],
# [3.1, 1.5, 1.5, 1.4, 3.1],
# ],
# ]
# )
# Case 4
thickness = [1120]

ucell = np.array([[[2.58941352 + 0.47745679j, 4.17771602 + 0.88991205j,
2.04255624 + 2.23670125j, 2.50478974 + 2.05242759j,
3.32747593 + 2.3854387j],
[2.80118605 + 0.53053715j, 4.46498861 + 0.10812571j,
3.99377545 + 1.0441131j, 3.10728537 + 0.6637353j,
4.74697849 + 0.62841253j],
[3.80944424 + 2.25899274j, 3.70371553 + 1.32586402j,
3.8011133 + 1.49939415j, 3.14797238 + 2.91158289j,
4.3085404 + 2.44344691j],
[2.22510179 + 2.86017146j, 2.36613053 + 2.82270351j,
4.5087168 + 0.2035904j, 3.15559949 + 2.55311298j,
4.29394604 + 0.98362617j],
[3.31324163 + 2.77590131j, 2.11744834 + 1.65894674j,
3.59347907 + 1.28895345j, 3.85713467 + 1.90714056j,
2.93805426 + 2.63385392j]]])
ucell = ucell.real

type_complex = 0
device = 0
return grating_type, pol, n_I, n_II, theta, phi, psi, wavelength, thickness, period, fourier_order, \
type_complex, device, ucell

return pol, n_top, n_bot, theta, phi, psi, wavelength, thickness, period, fto, type_complex, device, ucell

def optimize_jax():
import jax
import jax.numpy as jnp

grating_type, pol, n_I, n_II, theta, phi, psi, wavelength, thickness, period, fourier_order, \
type_complex, device, ucell = load_setting()
def optimize_jax(setting):
pol, n_top, n_bot, theta, phi, psi, wavelength, thickness, period, fto, \
type_complex, device, ucell = setting

mee = call_mee(backend=1, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi,
fourier_order=fourier_order, wavelength=wavelength, period=period, ucell=ucell,
mee = call_mee(backend=1, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, phi=phi,
fto=fto, wavelength=wavelength, period=period, ucell=ucell,
thickness=thickness, device=device,
type_complex=type_complex, perturbation=1E-10)
type_complex=type_complex)
ucell = mee.ucell

@jax.grad
def grad_loss(ucell):
mee.ucell = ucell
de_ri, de_ti, _, _, _ = mee._conv_solve()
# de_ri, de_ti, _, _, _ = mee._conv_solve()
de_ri, de_ti = mee.conv_solve()
try:
loss = de_ti[de_ti.shape[0] // 2, de_ti.shape[1] // 2]
except:
Expand All @@ -117,10 +98,12 @@ def grad_numerical(ucell, delta):
for c in range(ucell.shape[2]):
ucell_delta_m = ucell.at[layer, r, c].set(ucell[layer, r, c] - delta)
mee.ucell = ucell_delta_m
de_ri_delta_m, de_ti_delta_m, _, _, _ = mee._conv_solve()
# de_ri_delta_m, de_ti_delta_m, _, _, _ = mee._conv_solve()
de_ri_delta_m, de_ti_delta_m = mee.conv_solve()
ucell_delta_p = ucell.at[layer, r, c].set(ucell[layer, r, c] + delta)
mee.ucell = ucell_delta_p
de_ri_delta_p, de_ti_delta_p, _, _, _ = mee._conv_solve()
# de_ri_delta_p, de_ti_delta_p, _, _, _ = mee._conv_solve()
de_ri_delta_p, de_ti_delta_p = mee.conv_solve()
try:
grad_numeric = \
(de_ti_delta_p[de_ti_delta_p.shape[0] // 2, de_ti_delta_p.shape[1] // 2]
Expand All @@ -140,18 +123,17 @@ def grad_numerical(ucell, delta):
print('JAX norm: ', jnp.linalg.norm(grad_nume - grad_ad) / grad_nume.size)


def optimize_torch():
def optimize_torch(setting):
"""
out of date.
Will be updated.
"""
import torch

grating_type, pol, n_I, n_II, theta, phi, psi, wavelength, thickness, period, fourier_order, \
type_complex, device, ucell = load_setting()
pol, n_top, n_bot, theta, phi, psi, wavelength, thickness, period, fto, \
type_complex, device, ucell = setting

tmee = call_mee(backend=2, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi,
fourier_order=fourier_order, wavelength=wavelength, period=period, ucell=ucell,
tmee = call_mee(backend=2, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, phi=phi,
fto=fto, wavelength=wavelength, period=period, ucell=ucell,
thickness=thickness, device=device,
type_complex=type_complex, )
tmee.ucell.requires_grad = True
Expand Down Expand Up @@ -200,16 +182,10 @@ def grad_numerical(ucell, delta):


if __name__ == '__main__':
try:
print('JaxMeent')
optimize_jax()
except Exception as e:
print('JaxMeent has problem. Do you have JAX?')
print(e)
setting = load_setting()

try:
print('TorchMeent')
optimize_torch()
except Exception as e:
print('TorchMeent has problem. Do you have PyTorch?')
print(e)
print('JaxMeent')
optimize_jax(setting)

print('TorchMeent')
optimize_torch(setting)
Loading

0 comments on commit 22366a9

Please sign in to comment.