Skip to content

Commit

Permalink
Merge pull request #9 from kc-ml2/DEV
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
yonghakim authored Jan 9, 2023
2 parents ea00b63 + 66d387a commit 0882c3d
Show file tree
Hide file tree
Showing 47 changed files with 6,240 additions and 1,439 deletions.
2 changes: 1 addition & 1 deletion JLAB/solver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
import numpy as np

from meent.on_numpy.rcwa import RCWALight as RCWA
from meent.on_numpy.rcwa import RCWANumpy as RCWA
from meent.on_numpy.convolution_matrix import to_conv_mat, find_nk_index


Expand Down
2 changes: 1 addition & 1 deletion benchmarks/interface/Reticolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def run_acs_loop_wavelength(self, pattern, deflected_angle, wls=None, n_si='SILI
if wls is None:
wls = self.wavelength
else:
self.wavelength = wls # TODO: handle better.
self.wavelength = wls

if type(n_si) == str and n_si.upper() == 'SILICON':
n_si = find_nk_index(n_si, self.mat_table, self.wavelength)
Expand Down
129 changes: 129 additions & 0 deletions examples/JAX/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import time
import jax
from jax import jit
import numpy as np
import jax.numpy as jnp

import torch

def jit_vs_nonjit():
ucell = jnp.zeros((1, 100, 100))

res = jnp.zeros(ucell.shape, dtype='complex')

@jit
def assign(arr, index, value):
arr = arr.at[index].set(value)
return arr

assign_index = (0, 0, 0)
assign_value = 3 ** 2

t0 = time.time()
arr = res.at[assign_index].set(assign_value)
print('at set 1: ', time.time() - t0)

t0 = time.time()
arr = res.at[assign_index].set(assign_value)
print('at set 2: ', time.time() - t0)

t0 = time.time()
arr = res.at[assign_index].set(assign_value)
print('at set 3: ', time.time() - t0)

t0 = time.time()
arr = assign(res, assign_index, assign_value).block_until_ready()
print('assign 1: ', time.time() - t0)

t0 = time.time()
arr = assign(res, assign_index, assign_value).block_until_ready()
print('assign 2: ', time.time() - t0)


t0 = time.time()
arr = assign(res, assign_index, assign_value)
print('assign 3: ', time.time() - t0)


for i in range(1):
# res = assign(res, assign_index, assign_value)
arr = res.at[assign_index].set(assign_value)
print(time.time() - t0)
t0 = time.time()
for i in range(100):
# res = assign(res, assign_index, assign_value)
arr = res.at[assign_index].set(assign_value)
print(time.time() - t0)

t0 = time.time()
for i in range(1):
arr = assign(res, assign_index, assign_value)
# arr = res.at[tuple(assign_index)].set(assign_value)
print(time.time() - t0)

t0 = time.time()
for i in range(100):
arr = assign(res, assign_index, assign_value).block_until_ready()
# arr = res.at[tuple(assign_index)].set(assign_value)
print(time.time() - t0)

# Result

# at set 1: 0.03652310371398926
# at set 2: 0.0010008811950683594
# at set 3: 0.0007517337799072266
# assign 1: 0.016371965408325195
# assign 2: 4.601478576660156e-05
# assign 3: 3.0994415283203125e-05

# at set 1: 0.0009369850158691406
# at set 2 to 102: 0.06914997100830078
# assign 1: 5.412101745605469e-05
# assign 2 to 102: 0.0008990764617919922


def test():
ss = 4000
aa = np.arange(ss*ss).reshape((ss, ss))
bb = torch.Tensor(aa)
itera = 1000

for _ in range(itera):
t0 = time.time()
np.linalg.eig(aa)
print(time.time() - t0)

print('jax')
for _ in range(itera):
t0 = time.time()
jnp.linalg.eig(aa)
print(time.time() - t0)

print('jit')
t0 = time.time()
eig = jax.jit(jnp.linalg.eig)
eig(aa)
print(time.time() - t0)

for _ in range(itera-1):
t0 = time.time()
eig(aa)
print(time.time() - t0)

print('torch')
for _ in range(itera):
t0 = time.time()
torch.linalg.eig(bb)
print(time.time()-t0)



if __name__ == '__main__':
# Global flag to set a specific platform, must be used at startup.
jax.config.update('jax_platform_name', 'cpu')

x = jnp.square(2)
print(repr(x.device_buffer.device())) # CpuDevice(id=0)

# jit_vs_nonjit()
test()
80 changes: 41 additions & 39 deletions examples/ex2_field_distribution.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,41 @@
import time
import numpy as np

from meent.rcwa import call_solver

grating_type = 0 # 0: 1D, 1: 1D conical, 2:2D.
pol = 1 # 0: TE, 1: TM

n_I = 1.45 # n_incidence
n_II = 1 # n_transmission

theta = 0 # in degree, notation from Moharam paper
phi = 0 # in degree, notation from Moharam paper
psi = 0 if pol else 90 # in degree, notation from Moharam paper

wls = np.linspace(900, 900, 1) # wavelength

if grating_type in (0, 1):
def_angle = 60
period = abs(wls / np.sin(def_angle / 180 * np.pi))
# period = [2000]
fourier_order = 5
patterns = [[3.48, 1, 1]] # n_ridge, n_groove, fill_factor
#
# else:
# period = [700, 700]
# fourier_order = 2
# patterns = [[3.48, 1, [0.3, 1]], [3.48, 1, [0.3, 1]]] # n_ridge, n_groove, fill_factor[x, y]

thickness = [325]

t0 = time.perf_counter()
solver = call_solver(mode=0, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi,
fourier_order=fourier_order, wls=wls, period=period, patterns=patterns, thickness=thickness)

a, b = solver.loop_wavelength_fill_factor()
# solver.plot()

print('wall time: ', time.perf_counter() - t0)
# import time
# import numpy as np
#
# from meent.rcwa import call_solver
#
# grating_type = 0 # 0: 1D, 1: 1D conical, 2:2D.
# pol = 1 # 0: TE, 1: TM
#
# n_I = 1.45 # n_incidence
# n_II = 1 # n_transmission
#
# theta = 0 # in degree, notation from Moharam paper
# phi = 0 # in degree, notation from Moharam paper
# psi = 0 if pol else 90 # in degree, notation from Moharam paper
#
# wls = np.linspace(900, 900, 1) # wavelength
#
# if grating_type in (0, 1):
# def_angle = 60
# period = abs(wls / np.sin(def_angle / 180 * np.pi))
# # period = [2000]
# fourier_order = 5
# patterns = [[3.48, 1, 1]] # n_ridge, n_groove, fill_factor
# #
# # else:
# # period = [700, 700]
# # fourier_order = 2
# # patterns = [[3.48, 1, [0.3, 1]], [3.48, 1, [0.3, 1]]] # n_ridge, n_groove, fill_factor[x, y]
#
# thickness = [325]
#
# t0 = time.perf_counter()
# solver = call_solver(mode=0, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi,
# fourier_order=fourier_order, wls=wls, period=period, patterns=patterns, thickness=thickness)
#
# a, b = solver.run_ucell()
# solver.calculate_field()
#
# # solver.plot()
#
# print('wall time: ', time.perf_counter() - t0)
Loading

0 comments on commit 0882c3d

Please sign in to comment.