diff --git a/JLAB/solver.py b/JLAB/solver.py index 25a1b7c..6f2f463 100644 --- a/JLAB/solver.py +++ b/JLAB/solver.py @@ -1,7 +1,7 @@ import time import numpy as np -from meent.on_numpy.rcwa import RCWALight as RCWA +from meent.on_numpy.rcwa import RCWANumpy as RCWA from meent.on_numpy.convolution_matrix import to_conv_mat, find_nk_index diff --git a/benchmarks/interface/Reticolo.py b/benchmarks/interface/Reticolo.py index 855c5ae..0be7608 100644 --- a/benchmarks/interface/Reticolo.py +++ b/benchmarks/interface/Reticolo.py @@ -80,7 +80,7 @@ def run_acs_loop_wavelength(self, pattern, deflected_angle, wls=None, n_si='SILI if wls is None: wls = self.wavelength else: - self.wavelength = wls # TODO: handle better. + self.wavelength = wls if type(n_si) == str and n_si.upper() == 'SILICON': n_si = find_nk_index(n_si, self.mat_table, self.wavelength) diff --git a/examples/JAX/benchmark.py b/examples/JAX/benchmark.py new file mode 100644 index 0000000..0ce4c30 --- /dev/null +++ b/examples/JAX/benchmark.py @@ -0,0 +1,129 @@ +import time +import jax +from jax import jit +import numpy as np +import jax.numpy as jnp + +import torch + +def jit_vs_nonjit(): + ucell = jnp.zeros((1, 100, 100)) + + res = jnp.zeros(ucell.shape, dtype='complex') + + @jit + def assign(arr, index, value): + arr = arr.at[index].set(value) + return arr + + assign_index = (0, 0, 0) + assign_value = 3 ** 2 + + t0 = time.time() + arr = res.at[assign_index].set(assign_value) + print('at set 1: ', time.time() - t0) + + t0 = time.time() + arr = res.at[assign_index].set(assign_value) + print('at set 2: ', time.time() - t0) + + t0 = time.time() + arr = res.at[assign_index].set(assign_value) + print('at set 3: ', time.time() - t0) + + t0 = time.time() + arr = assign(res, assign_index, assign_value).block_until_ready() + print('assign 1: ', time.time() - t0) + + t0 = time.time() + arr = assign(res, assign_index, assign_value).block_until_ready() + print('assign 2: ', time.time() - t0) + + + t0 = time.time() + arr = assign(res, assign_index, assign_value) + print('assign 3: ', time.time() - t0) + + + for i in range(1): + # res = assign(res, assign_index, assign_value) + arr = res.at[assign_index].set(assign_value) + print(time.time() - t0) + t0 = time.time() + for i in range(100): + # res = assign(res, assign_index, assign_value) + arr = res.at[assign_index].set(assign_value) + print(time.time() - t0) + + t0 = time.time() + for i in range(1): + arr = assign(res, assign_index, assign_value) + # arr = res.at[tuple(assign_index)].set(assign_value) + print(time.time() - t0) + + t0 = time.time() + for i in range(100): + arr = assign(res, assign_index, assign_value).block_until_ready() + # arr = res.at[tuple(assign_index)].set(assign_value) + print(time.time() - t0) + + # Result + + # at set 1: 0.03652310371398926 + # at set 2: 0.0010008811950683594 + # at set 3: 0.0007517337799072266 + # assign 1: 0.016371965408325195 + # assign 2: 4.601478576660156e-05 + # assign 3: 3.0994415283203125e-05 + + # at set 1: 0.0009369850158691406 + # at set 2 to 102: 0.06914997100830078 + # assign 1: 5.412101745605469e-05 + # assign 2 to 102: 0.0008990764617919922 + + +def test(): + ss = 4000 + aa = np.arange(ss*ss).reshape((ss, ss)) + bb = torch.Tensor(aa) + itera = 1000 + + for _ in range(itera): + t0 = time.time() + np.linalg.eig(aa) + print(time.time() - t0) + + print('jax') + for _ in range(itera): + t0 = time.time() + jnp.linalg.eig(aa) + print(time.time() - t0) + + print('jit') + t0 = time.time() + eig = jax.jit(jnp.linalg.eig) + eig(aa) + print(time.time() - t0) + + for _ in range(itera-1): + t0 = time.time() + eig(aa) + print(time.time() - t0) + + print('torch') + for _ in range(itera): + t0 = time.time() + torch.linalg.eig(bb) + print(time.time()-t0) + + + +if __name__ == '__main__': + # Global flag to set a specific platform, must be used at startup. + jax.config.update('jax_platform_name', 'cpu') + + x = jnp.square(2) + print(repr(x.device_buffer.device())) # CpuDevice(id=0) + + # jit_vs_nonjit() + test() \ No newline at end of file diff --git a/examples/ex2_field_distribution.py b/examples/ex2_field_distribution.py index 59d3655..a648e35 100644 --- a/examples/ex2_field_distribution.py +++ b/examples/ex2_field_distribution.py @@ -1,39 +1,41 @@ -import time -import numpy as np - -from meent.rcwa import call_solver - -grating_type = 0 # 0: 1D, 1: 1D conical, 2:2D. -pol = 1 # 0: TE, 1: TM - -n_I = 1.45 # n_incidence -n_II = 1 # n_transmission - -theta = 0 # in degree, notation from Moharam paper -phi = 0 # in degree, notation from Moharam paper -psi = 0 if pol else 90 # in degree, notation from Moharam paper - -wls = np.linspace(900, 900, 1) # wavelength - -if grating_type in (0, 1): - def_angle = 60 - period = abs(wls / np.sin(def_angle / 180 * np.pi)) - # period = [2000] - fourier_order = 5 - patterns = [[3.48, 1, 1]] # n_ridge, n_groove, fill_factor -# -# else: -# period = [700, 700] -# fourier_order = 2 -# patterns = [[3.48, 1, [0.3, 1]], [3.48, 1, [0.3, 1]]] # n_ridge, n_groove, fill_factor[x, y] - -thickness = [325] - -t0 = time.perf_counter() -solver = call_solver(mode=0, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, - fourier_order=fourier_order, wls=wls, period=period, patterns=patterns, thickness=thickness) - -a, b = solver.loop_wavelength_fill_factor() -# solver.plot() - -print('wall time: ', time.perf_counter() - t0) +# import time +# import numpy as np +# +# from meent.rcwa import call_solver +# +# grating_type = 0 # 0: 1D, 1: 1D conical, 2:2D. +# pol = 1 # 0: TE, 1: TM +# +# n_I = 1.45 # n_incidence +# n_II = 1 # n_transmission +# +# theta = 0 # in degree, notation from Moharam paper +# phi = 0 # in degree, notation from Moharam paper +# psi = 0 if pol else 90 # in degree, notation from Moharam paper +# +# wls = np.linspace(900, 900, 1) # wavelength +# +# if grating_type in (0, 1): +# def_angle = 60 +# period = abs(wls / np.sin(def_angle / 180 * np.pi)) +# # period = [2000] +# fourier_order = 5 +# patterns = [[3.48, 1, 1]] # n_ridge, n_groove, fill_factor +# # +# # else: +# # period = [700, 700] +# # fourier_order = 2 +# # patterns = [[3.48, 1, [0.3, 1]], [3.48, 1, [0.3, 1]]] # n_ridge, n_groove, fill_factor[x, y] +# +# thickness = [325] +# +# t0 = time.perf_counter() +# solver = call_solver(mode=0, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, +# fourier_order=fourier_order, wls=wls, period=period, patterns=patterns, thickness=thickness) +# +# a, b = solver.run_ucell() +# solver.calculate_field() +# +# # solver.plot() +# +# print('wall time: ', time.perf_counter() - t0) diff --git a/examples/ex2_ucell.py b/examples/ex2_ucell.py index 8717242..f1645dc 100644 --- a/examples/ex2_ucell.py +++ b/examples/ex2_ucell.py @@ -1,98 +1,110 @@ import time import matplotlib.pyplot as plt +import jax.numpy as jnp import numpy as np - from meent.rcwa import call_solver, sweep_wavelength +import jax +import torch + +from ex2_ucell_functions import get_cond_numpy, get_cond_jax, get_cond_torch -grating_type = 2 # 0: 1D, 1: 1D conical, 2:2D. +# common +# grating_type = 1 # 0: 1D, 1: 1D conical, 2:2D. pol = 1 # 0: TE, 1: TM n_I = 1 # n_incidence n_II = 1 # n_transmission -theta = 0.1 # in degree, notation from Moharam paper -phi = 0 # in degree, notation from Moharam paper -psi = 0 if pol else 90 # in degree, notation from Moharam paper - -wavelength = np.array([900]) # wavelength - -if grating_type in (0, 1): - period = [1000] - fourier_order = 5 - - ucell = np.array([ - - [ - [ - 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, - ], - ], - [ - [ - 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, - ], - ], - ]) - -else: - period = [1000, 1000] - fourier_order = 15 - - ucell = np.array([ - - [ - [1, 0, 0, 1, 1, 1, 1, 1, 1, 1,], - [0, 0, 0, 1, 1, 0, 1, 1, 1, 1,], - [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], - [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], - [1, 1, 0, 1, 1, 1, 1, 1, 1, 1,], - [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], - [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], - [0, 1, 0, 1, 1, 1, 1, 1, 1, 1,], - [1, 0, 0, 1, 1, 1, 1, 1, 1, 1,], - [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], - ], - # [ - # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], - # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], - # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], - # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], - # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], - # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], - # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], - # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], - # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], - # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], - # ], - ]) - - # ucell = np.array([ - # - # [ - # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], - # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], - # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], - # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], - # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], - # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], - # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], - # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], - # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], - # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], - # ], - # ]) +theta = 0 +phi = 0 +psi = 0 if pol else 90 + +wavelength = 900 thickness = [500] +ucell_materials = [1, 3.48] +mode_options = {0: 'numpy', 1: 'JAX', 2: 'Torch', 3: 'numpy_integ', 4: 'JAX_integ',} +n_iter = 1 -ucell_materials = [1, 3.48] -# ucell_materials = [3.48, 1] -t0 = time.time() -AA = call_solver(mode=0, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, - fourier_order=fourier_order, wavelength=wavelength, period=period, ucell=ucell, ucell_materials=ucell_materials, - thickness=thickness) -de_ri, de_ti = AA.run_ucell() -print(de_ri, de_ti) -print(time.time()-t0) -res = AA.calculate_field() + +def run_test(grating_type, mode_key, dtype, device): + + if mode_key == 0: + device = None + + if dtype == 0: + type_complex = np.complex128 + else: + type_complex = np.complex64 + period, fourier_order, ucell = get_cond_numpy(grating_type) + + elif mode_key == 1: + # JAX + if device == 0: + jax.config.update('jax_platform_name', 'cpu') + else: + jax.config.update('jax_platform_name', 'gpu') + + if dtype == 0: + type_complex = jnp.complex128 + else: + type_complex = jnp.complex64 + period, fourier_order, ucell = get_cond_jax(grating_type) + + else: + # Torch + if device == 0: + device = torch.device('cpu') + else: + device = torch.device('cuda') + + if dtype == 0: + type_complex = torch.complex128 + else: + type_complex = torch.complex64 + period, fourier_order, ucell = get_cond_torch(grating_type) + + AA = call_solver(mode=mode_key, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, + fourier_order=fourier_order, wavelength=wavelength, period=period, ucell=ucell, + ucell_materials=ucell_materials, + thickness=thickness, device=device, type_complex=type_complex, ) + + for i in range(n_iter): + t0 = time.time() + de_ri, de_ti = AA.run_ucell() + print(f'run_cell: {i}: ', time.time()-t0) + + resolution = (20, 20, 20) + for i in range(1): + t0 = time.time() + AA.calculate_field(resolution=resolution, plot=False) + print(f'cal_field: {i}', time.time() - t0) + + return de_ri, de_ti + + +def run_loop(): + for grating_type in [0,1,2]: + for bd in [0,1,2]: + for dtype in [0,1]: + for device in [0]: + run_test(grating_type, bd, dtype, device) + + try: + print(f'grating:{grating_type}, backend:{bd}, dtype:{dtype}, dev:{device}') + run_test(grating_type, bd, dtype, device) + except Exception as e: + print(e) + + +def run_assert(): + + for grating_type in [0,1,2]: + for bd in [0,1,2]: + print(run_test(grating_type, bd, 0, 0)) + + +if __name__ == '__main__': + run_assert() diff --git a/examples/ex2_ucell_functions.py b/examples/ex2_ucell_functions.py new file mode 100644 index 0000000..3439c2e --- /dev/null +++ b/examples/ex2_ucell_functions.py @@ -0,0 +1,252 @@ +import numpy as np +import jax.numpy as jnp + +import torch + +fourier_order = 1 +def get_cond_numpy(grating_type): + + if grating_type in [0, 1]: + + period = [1000] + + ucell = np.array([ + + [ + [ + 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, + ], + ], + # [ + # [ + # 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, + # ], + # ], + ]) + else: + period = [1000, 1000] + + # ucell = torch.tensor([ + ucell = np.array([ + # [ + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0,], + # [0, 0, 0, 1, 1, 0, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [1, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # ], + + # [ + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ], + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ], + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ], + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # ], + + [ + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + ], + ]) + + # ucell = np.array([ + # + # [ + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # ], + # ]) + return period, fourier_order, ucell + + +def get_cond_jax(grating_type): + + if grating_type in [0, 1]: + + period = [1000] + + ucell = jnp.array([ + + [ + [ + 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, + ], + ], + # [ + # [ + # 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, + # ], + # ], + ]) + else: + # period = torch.tensor([1000, 1000]) + period = [1000, 1000] + + # ucell = torch.tensor([ + ucell = jnp.array([ + # [ + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0,], + # [0, 0, 0, 1, 1, 0, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [1, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # ], + + # [ + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ], + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ], + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ], + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # ], + + [ + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + ], + ]) + + # ucell = np.array([ + # + # [ + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # ], + # ]) + return period, fourier_order, ucell + + +def get_cond_torch(grating_type): + + if grating_type in [0, 1]: + + period = [1000] + + ucell = torch.tensor([ + + [ + [ + 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, + ], + ], + # [ + # [ + # 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, + # ], + # ], + ]) + else: + # period = torch.tensor([1000, 1000]) + period = [1000, 1000] + + ucell = torch.tensor([ + # [ + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0,], + # [0, 0, 0, 1, 1, 0, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [1, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # ], + + # [ + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ], + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ], + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ], + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # ], + + [ + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + ], + ]) + + # ucell = np.array([ + # + # [ + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # ], + # ]) + return period, fourier_order, ucell + diff --git a/examples/ex2_ucell_torch.py b/examples/ex2_ucell_torch.py new file mode 100644 index 0000000..d83bb10 --- /dev/null +++ b/examples/ex2_ucell_torch.py @@ -0,0 +1,151 @@ +import time + +import matplotlib.pyplot as plt +import jax.numpy as jnp +import numpy as np +from meent.rcwa import call_solver, sweep_wavelength +import jax +import torch + +# JAX +jax.config.update('jax_platform_name', 'cpu') +# jax.config.update('jax_platform_name', 'gpu') + +# Torch +device = torch.device('cuda') +# device = torch.device('cpu') + +type_complex = torch.complex128 +type_complex = torch.complex64 +# type_complex = np.complex64 +# type_complex = jnp.complex64 + +# common +grating_type = 2 # 0: 1D, 1: 1D conical, 2:2D. +pol = 1 # 0: TE, 1: TM + +n_I = 1 # n_incidence +n_II = 1 # n_transmission + +theta = 0.1 +phi = 0 +psi = 0 if pol else 90 +# wavelength = np.array([900]) +wavelength = 900 # TODO: in numpy mode, np.array([900]) and 900 shows different result. final result of array shows 1E-14 order but 900 only shows 0 + + +if grating_type in (0, 1): + period = [1000] + fourier_order = 20 + + ucell = np.array([ + + [ + [ + 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, + ], + ], + # [ + # [ + # 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, + # ], + # ], + ]) + +else: + # period = torch.tensor([1000, 1000]) + period = [1000, 1000] + fourier_order = 15 + fourier_order = 20 + + # ucell = torch.tensor([ + ucell = np.array([ + # [ + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0,], + # [0, 0, 0, 1, 1, 0, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [1, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # ], + + # [ + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ], + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ], + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ], + # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ], + # ], + + [ + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, ], + ], + ]) + + # ucell = np.array([ + # + # [ + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [0, 0, 0, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1,], + # ], + # ]) + + +thickness = [500] + + +ucell_materials = [1, 3.48] + +mode_options = {0: 'numpy', 1: 'JAX', 2: 'Torch', 3: 'numpy_integ', 4: 'JAX_integ',} + +mode_key = 2 + +n_iter = 1 + +print(mode_options[mode_key]) + +AA = call_solver(mode=mode_key, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, + fourier_order=fourier_order, wavelength=wavelength, period=period, ucell=ucell, + ucell_materials=ucell_materials, + thickness=thickness, device=device, type_complex=type_complex, ) + +for i in range(n_iter): + t0 = time.time() + de_ri, de_ti = AA.run_ucell() + print(f'run {i}: ', time.time()-t0) + + +resolution = (20, 20, 20) +for i in range(1): + t0 = time.time() + AA.calculate_field(resolution=resolution, plot=True) + print(time.time() - t0) + + +# print(de_ri, de_ti) diff --git a/examples/ex3_ucell_materials.py b/examples/ex3_ucell_materials.py new file mode 100644 index 0000000..97b730d --- /dev/null +++ b/examples/ex3_ucell_materials.py @@ -0,0 +1,57 @@ +import time +import numpy as np + +from meent.on_numpy.rcwa import RCWANumpy as RCWA +from meent.rcwa import call_solver, sweep_wavelength + + +grating_type = 2 # 0: 1D, 1: 1D conical, 2:2D. +pol = 1 # 0: TE, 1: TM + +n_I = 1 # n_incidence +n_II = 1 # n_transmission + +theta = 0 # in degree, notation from Moharam paper +phi = 0 # in degree, notation from Moharam paper +psi = 0 if pol else 90 # in degree, notation from Moharam paper + +wls = np.linspace(900, 900, 1) # wavelength + +if grating_type in (0, 1): + period = [1400] + fourier_order = 20 + +else: + period = [700, 700] + fourier_order = 3 + +thickness = [460, 660] + +ucell = np.array([ + + [ + [1, 1, 1, 0, 0, 0, 1, 1, 1, 1], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 1], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 1], + ], + [ + [1, 1, 1, 0, 0, 0, 1, 1, 1, 1], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 1], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 1], + ], +]) + +ucell_materials = ['SILICON', 1] + +AA = call_solver(mode=0, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, + fourier_order=fourier_order, wls=wls, period=period, ucell=ucell, ucell_materials=ucell_materials, + thickness=thickness) +de_ri, de_ti = AA.run_ucell() +print(de_ri, de_ti) + +wls = np.linspace(500, 1000, 100) + +a, b = sweep_wavelength(wls, mode=0, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, + fourier_order=fourier_order, period=period, ucell=ucell, ucell_materials=ucell_materials, thickness=thickness) + +print(a) diff --git a/examples/ex5_ucell-DEV_FFT.py b/examples/ex5_ucell-DEV_FFT.py deleted file mode 100644 index f1c1b1c..0000000 --- a/examples/ex5_ucell-DEV_FFT.py +++ /dev/null @@ -1,59 +0,0 @@ -import time - -import matplotlib.pyplot as plt -import numpy as np - -from meent.rcwa import call_solver, sweep_wavelength - - -grating_type = 2 # 0: 1D, 1: 1D conical, 2:2D. -pol = 1 # 0: TE, 1: TM - -n_I = 1 # n_incidence -n_II = 1 # n_transmission - -theta = 0 # in degree, notation from Moharam paper -phi = 0 # in degree, notation from Moharam paper -psi = 0 if pol else 90 # in degree, notation from Moharam paper - -wavelength = np.linspace(900, 900, 1) # wavelength - -if grating_type in (0, 1): - period = [1400] - fourier_order = 3 - -else: - period = [700, 700] - fourier_order = 3 - -thickness = [460, 660] - -ucell = np.array([ - - [ - [1, 1, 1, 0, 0, 0, 0, 1, 1, 0], - [1, 1, 1, 0, 0, 0, 1, 1, 1, 0], - [1, 1, 1, 0, 0, 0, 1, 1, 1, 1], - ], - # [ - # [1, 1, 1, 0, 0, 0, 1, 1, 1, 1], - # [1, 1, 1, 0, 0, 0, 1, 1, 1, 1], - # [1, 1, 1, 0, 0, 0, 1, 1, 1, 1], - # ], -]) - -ucell_materials = [1, 1.5] - -AA = call_solver(mode=0, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, - fourier_order=fourier_order, wavelength=wavelength, period=period, ucell=ucell, ucell_materials=ucell_materials, - thickness=thickness) -de_ri, de_ti = AA.run_ucell() -print(de_ri, de_ti) - -wavelength_array = np.linspace(500, 1000, 100) - -a, b = sweep_wavelength(wavelength_array, mode=0, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, - fourier_order=fourier_order, period=period, ucell=ucell, ucell_materials=ucell_materials, thickness=thickness) - -plt.plot(wavelength_array, a.sum((1, 2)), wavelength_array, b.sum((1, 2))) -plt.show() diff --git a/examples/ex6_ucell_dev.py b/examples/ex6_ucell_dev.py new file mode 100644 index 0000000..c4a61c6 --- /dev/null +++ b/examples/ex6_ucell_dev.py @@ -0,0 +1,77 @@ +import numpy as np +# import jax.numpy as jnp +import pandas as pd +import matplotlib.pyplot as plt +import time + +from meent.rcwa import call_solver + + +n_I = 1 +n_si = 3.48 +n_air = 1 +n_II = 1 +theta = 1E-10 +phi = 1E-10 + +fourier_order = 2 + +period = [700, 700] +wls = np.array([900.]) +pol = 1 +psi = 0 if pol else 90 # in degree, notation from Moharam paper + +thickness = [1120] + +pattern = np.array([n_si, n_si, n_si, n_air, n_air, n_air, n_air, n_air, n_air, n_air, ]) + +N = len(pattern) +dx = period[0]/N +grid = np.arange(1, N+1)*dx + +textures = [n_I, [grid, pattern], n_II] + +profile = np.array([[0, *thickness, 0], [1, 2, 3]]) +grating_type = 2 + +pattern = np.array([1, 1, 1, 0, 0, 0, 0, 0, 0, 0, ]) +ucell_materials = [n_air, n_si] + +ucell = np.array([[pattern]]) + +from meent.on_numpy.convolution_matrix import cell_compression + +cell_comp, x, y = cell_compression(ucell[0]) + +ucell = np.zeros((1, 10, 10)) + +base = np.meshgrid(0, np.arange(0, 10, 1), np.arange(0, 10, 1)) + +obj1 = np.meshgrid(0, np.arange(0, 10, 1), np.arange(3, 5, 1)) + +obj_list = [base, obj1] +mat_list = [1, 'si'] + +from meent.on_numpy.convolution_matrix import put_permittivity_in_ucell_object, \ + read_material_table, put_permittivity_in_ucell, to_conv_mat + +# a = put_permittivity_in_ucell_new(ucell, mat_list, obj_list, None, wls) + +# --- Run --- + +solver = call_solver(mode=0, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, + fourier_order=fourier_order, wls=wls, period=period, ucell=ucell, ucell_materials=ucell_materials, + thickness=thickness) +t0 = time.time() +# de_ri, de_ti = meent_t.run_ucell() + +mat_table = read_material_table() + +solver.ucell = put_permittivity_in_ucell_object(ucell, mat_list, obj_list, mat_table, wls) + +e_conv_all = to_conv_mat(solver.ucell, solver.fourier_order) +o_e_conv_all = to_conv_mat(1 / solver.ucell, solver.fourier_order) + +de_ri, de_ti = solver.solve(solver.wavelength, e_conv_all, o_e_conv_all) + +print(de_ri) diff --git a/examples/optimization/optimize_test.py b/examples/optimization/optimize_test.py index 407eb62..9c496d5 100644 --- a/examples/optimization/optimize_test.py +++ b/examples/optimization/optimize_test.py @@ -27,7 +27,7 @@ def get_difference(self): if __name__ == '__main__': - + t0 = time.time() aa = jnp.array(1100, dtype='float32') cc = jnp.array(1E-4, dtype='float32') # OK @@ -60,7 +60,7 @@ def loss(thick): phi = 20 psi = 0 if pol else 90 - wls = jnp.linspace(500, 2300, 1) + wls = jnp.linspace(900, 2300, 1) fourier_order = 10 # Ground Truth @@ -68,18 +68,19 @@ def loss(thick): thickness = jnp.array([1120]) cell = jnp.array([[[3.48 ** 2, 3.48 ** 2, 3.48 ** 2, 1, 1, 1, 1, 1, 1, 1]]]) ground_truth = call_solver(mode=1, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, - fourier_order=fourier_order, wls=wls, period=period, patterns=cell, thickness=thickness) + fourier_order=fourier_order, wavelength=wls, period=period, patterns=cell, thickness=thickness) # Test thickness = jnp.array([thick]) test = call_solver(mode=1, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, - fourier_order=fourier_order, wls=wls, period=period, patterns=cell, thickness=thickness) + fourier_order=fourier_order, wavelength=wls, period=period, patterns=cell, thickness=thickness) a, b = ground_truth.jax_test() c, d = test.jax_test() - gap = jnp.linalg.norm(test.spectrum_r - ground_truth.spectrum_r) + + gap = jnp.linalg.norm(a - c) # print('gap:', gap.primal) return gap @@ -128,5 +129,5 @@ def mingd(x): minimum = minimums[arglist] print("The minimum is {} the argmin is {}".format(minimum, argmin)) - + print(time.time() - t0) print('end') diff --git a/meent/integ/__init__.py b/meent/integ/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/meent/integ/_base.py b/meent/integ/_base.py new file mode 100644 index 0000000..879280c --- /dev/null +++ b/meent/integ/_base.py @@ -0,0 +1,220 @@ +from .scattering_method import scattering_1d_1, scattering_1d_2, scattering_1d_3, scattering_2d_1, scattering_2d_wv,\ + scattering_2d_2, scattering_2d_3 +from .transfer_method import transfer_1d_1, transfer_1d_2, transfer_1d_3, transfer_1d_conical_1, transfer_1d_conical_2,\ + transfer_1d_conical_3, transfer_2d_1, transfer_2d_wv, transfer_2d_2, transfer_2d_3 + +import meent.integ.backend.meentpy as ee + + +class _BaseRCWA: + + def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., fourier_order=10, + period=0.7, wavelength=900, pol=0, + patterns=None, ucell=None, ucell_materials=None, thickness=None, algo='TMM'): + + self.grating_type = grating_type # 1D=0, 1D_conical=1, 2D=2 + self.n_I = n_I + self.n_II = n_II + + self.theta = theta * ee.pi / 180 + self.phi = phi * ee.pi / 180 + self.psi = psi * ee.pi / 180 # TODO: integrate psi and pol + + self.pol = pol # TE 0, TM 1 + if self.pol == 0: # TE + self.psi = 90 * ee.pi / 180 + elif self.pol == 1: # TM + self.psi = 0 * ee.pi / 180 + else: + print('not implemented yet') + raise ValueError + + self.fourier_order = fourier_order + self.ff = 2 * self.fourier_order + 1 + + self.period = period + + self.wavelength = wavelength + + self.patterns = patterns + self.ucell = ucell + self.ucell_materials = ucell_materials + self.thickness = thickness + + self.algo = algo + + self.layer_info_list = [] + self.T1 = None + + def solve_1d(self, wl, E_conv_all, o_E_conv_all): + + fourier_indices = ee.arange(-self.fourier_order, self.fourier_order + 1) + + delta_i0 = ee.zeros(self.ff) + + # delta_i0[self.fourier_order] = 1 + delta_i0 = ee.assign(delta_i0, self.fourier_order, 1) + + k0 = 2 * ee.pi / wl + + if self.algo == 'TMM': + kx_vector, Kx, k_I_z, k_II_z, f, YZ_I, g, inc_term, T \ + = transfer_1d_1(self.ff, self.pol, k0, self.n_I, self.n_II, + self.theta, delta_i0, self.fourier_order, fourier_indices, wl, self.period) + elif self.algo == 'SMM': + Kx, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ + = scattering_1d_1(k0, self.n_I, self.n_II, self.theta, self.phi, fourier_indices, self.period, + self.pol, wl=wl) + else: + raise ValueError + + # From the last layer + for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): + + if self.pol == 0: + E_conv_i = None + A = Kx ** 2 - E_conv + eigenvalues, W = ee.linalg.eig(A) + q = eigenvalues ** 0.5 + + Q = ee.diag(q) + V = W @ Q + + elif self.pol == 1: + E_conv_i = ee.linalg.inv(E_conv) + B = Kx @ E_conv_i @ Kx - ee.eye(E_conv.shape[0]) + o_E_conv_i = ee.linalg.inv(o_E_conv) + + eigenvalues, W = ee.linalg.eig(o_E_conv_i @ B) + q = eigenvalues ** 0.5 + + Q = ee.diag(q) + V = o_E_conv @ W @ Q + + else: + raise ValueError + + if self.algo == 'TMM': + X, f, g, T, a_i, b = transfer_1d_2(k0, q, d, W, V, f, g, self.fourier_order, T) + + layer_info = [E_conv_i, q, W, X, a_i, b, d] + self.layer_info_list.append(layer_info) + + elif self.algo == 'SMM': + A, B, S_dict, Sg = scattering_1d_2(W, Wg, V, Vg, d, k0, Q, Sg) + else: + raise ValueError + + if self.algo == 'TMM': + de_ri, de_ti, T1 = transfer_1d_3(g, YZ_I, f, delta_i0, inc_term, T, k_I_z, k0, self.n_I, self.n_II, + self.theta, self.pol, k_II_z) + self.T1 = T1 + + elif self.algo == 'SMM': + de_ri, de_ti = scattering_1d_3(Wt, Wg, Vt, Vg, Sg, self.ff, Wr, self.fourier_order, Kzr, Kzt, + self.n_I, self.n_II, self.theta, self.pol) + else: + raise ValueError + + return de_ri, de_ti + + # TODO: scattering method + def solve_1d_conical(self, wl, e_conv_all, o_e_conv_all): + + fourier_indices = ee.arange(-self.fourier_order, self.fourier_order + 1) + + delta_i0 = ee.zeros(self.ff) + delta_i0[self.fourier_order] = 1 + + k0 = 2 * ee.pi / wl + + if self.algo == 'TMM': + Kx, ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ + = transfer_1d_conical_1(self.ff, k0, self.n_I, self.n_II, self.period, fourier_indices, self.theta, self.phi, wl) + elif self.algo == 'SMM': + print('SMM for 1D conical is not implemented') + return ee.nan, ee.nan + else: + raise ValueError + + for e_conv, o_e_conv, d in zip(e_conv_all[::-1], o_e_conv_all[::-1], self.thickness[::-1]): + e_conv_i = ee.linalg.inv(e_conv) + o_e_conv_i = ee.linalg.inv(o_e_conv) + + if self.algo == 'TMM': + big_F, big_G, big_T = transfer_1d_conical_2(k0, Kx, ky, e_conv, e_conv_i, o_e_conv_i, self.ff, d, + varphi, big_F, big_G, big_T) + elif self.algo == 'SMM': + raise ValueError + else: + raise ValueError + + if self.algo == 'TMM': + de_ri, de_ti = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z) + elif self.algo == 'SMM': + raise ValueError + else: + raise ValueError + + return de_ri, de_ti + + def solve_2d(self, wl, E_conv_all, o_E_conv_all): + + fourier_indices = ee.arange(-self.fourier_order, self.fourier_order + 1) + + delta_i0 = ee.zeros((self.ff ** 2, 1)) + # delta_i0[self.ff ** 2 // 2, 0] = 1 + + assign_index = [self.ff ** 2 // 2, 0] + delta_i0 = ee.assign(delta_i0, assign_index, 1) + + I = ee.eye(self.ff ** 2) + O = ee.zeros((self.ff ** 2, self.ff ** 2)) + + center = self.ff ** 2 + + k0 = 2 * ee.pi / wl + + if self.algo == 'TMM': + kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ + = transfer_2d_1(self.ff, k0, self.n_I, self.n_II, self.period, fourier_indices, self.theta, self.phi, wl) + elif self.algo == 'SMM': + Kx, Ky, kz_inc, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ + = scattering_2d_1(self.n_I, self.n_II, self.theta, self.phi, k0, self.period, self.fourier_order) + else: + raise ValueError + + # From the last layer + for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): + E_conv_i = ee.linalg.inv(E_conv) + o_E_conv_i = ee.linalg.inv(o_E_conv) + + if self.algo == 'TMM': # TODO: MERGE W V part + W, V, q = transfer_2d_wv(self.ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, center) + + big_X, big_F, big_G, big_T, big_A_i, big_B, \ + W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 \ + = transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T) + + layer_info = [E_conv_i, q, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d] + self.layer_info_list.append(layer_info) + + elif self.algo == 'SMM': + W, V, LAMBDA = scattering_2d_wv(self.ff, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) + A, B, Sl_dict, Sg_matrix, Sg = scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, LAMBDA) + else: + raise ValueError + + if self.algo == 'TMM': + de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z) + self.T1 = big_T1 + + elif self.algo == 'SMM': + de_ri, de_ti = scattering_2d_3(Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, self.n_I, + self.pol, self.theta, self.phi, self.fourier_order, self.ff) + else: + raise ValueError + + return de_ri.reshape((self.ff, self.ff)).real, de_ti.reshape((self.ff, self.ff)).real diff --git a/meent/integ/backend/__init__.py b/meent/integ/backend/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/meent/integ/backend/be_jax.py b/meent/integ/backend/be_jax.py new file mode 100644 index 0000000..3a76e94 --- /dev/null +++ b/meent/integ/backend/be_jax.py @@ -0,0 +1,148 @@ + +import numpy + +import jax +import jax.numpy as np +from jax import jit + +backend = 'jax' + + +linspace = np.linspace +pi = np.pi +arange = np.arange +zeros = np.zeros +diag = np.diag + +linalg = np.linalg +eye = np.eye +nan = np.nan +interp = np.interp +loadtxt = numpy.loadtxt +array = np.array +roll = np.roll +exp = np.exp +vstack = np.vstack +ones = np.ones +repeat = np.repeat +tile = np.tile + +sin = np.sin +cos = np.cos +block = np.block + +real = np.real +imag = np.imag +conj = np.conj + +nonzero = np.nonzero + +arctan = np.arctan + +hstack = np.hstack +eig = np.linalg.eig + +from jax import jit + + +# @jit +def _assign_row_all(arr, index, value): + row, col = arr.shape + return arr.at[:, index].set(value) + + +@jit +def _assign_col_all(arr, index, value): + row, col = arr.shape + return arr.at[index, np.arange(col)].set(value) + + +@jit +def _assign(arr, index, value): + return arr.at[index].set(value) + + +def assign(arr, index, value, row_all=False, col_all=False): + if type(index) == list: + index = tuple(index) + + if row_all: + arr = _assign_row_all(arr, index, value) + elif col_all: + arr = _assign_col_all(arr, index, value) + else: + arr = _assign(arr, index, value) + return arr + + +def assign1(arr, index, value, row_all=False, col_all=False): + if type(index) == list: + index = tuple(index) + + if row_all: + arr = arr.at[:, index].set(value) + elif col_all: + arr = arr.at[index, :].set(value) + else: + arr = arr.at[index].set(value) + return arr + + + +# import numpy +# +# import jax +# import jax.numpy as np +# from jax import jit +# +# backend = 'jax' + +# +# linspace = jit(np.linspace) +# pi = np.pi +# arange = jit(np.arange) +# zeros = np.zeros +# diag = jit(np.diag) +# +# linalg = np.linalg +# eye = jit(np.eye) +# nan = np.nan +# interp = jit(np.interp) +# loadtxt = numpy.loadtxt +# array = jit(np.array) +# roll = jit(np.roll) +# exp = jit(np.exp) +# vstack = jit(np.vstack) +# ones = jit(np.ones) +# repeat = jit(np.repeat) +# tile = jit(np.tile) +# +# sin = jit(np.sin) +# cos = jit(np.cos) +# block = jit(np.block) +# +# real = jit(np.real) +# imag = jit(np.imag) +# conj = jit(np.conj) +# +# nonzero = jit(np.nonzero) +# +# arctan = jit(np.arctan) +# +# hstack = jit(np.hstack) +# eig = np.linalg.eig +# +# +# def assign(arr, index, value, row_all=False, col_all=False): +# if type(index) == list: +# index = tuple(index) +# +# if row_all: +# arr = arr.at[:, index].set(value) +# elif col_all: +# arr = arr.at[index, :].set(value) +# else: +# arr = arr.at[index].set(value) +# return arr + + diff --git a/meent/integ/backend/be_numpy.py b/meent/integ/backend/be_numpy.py new file mode 100644 index 0000000..2982649 --- /dev/null +++ b/meent/integ/backend/be_numpy.py @@ -0,0 +1,51 @@ +import numpy as np + + +backend = 'numpy' +linspace = np.linspace +pi = np.pi +arange = np.arange +zeros = np.zeros +diag = np.diag + +linalg = np.linalg +eye = np.eye +nan = np.nan +interp = np.interp +loadtxt = np.loadtxt +array = np.array +roll = np.roll +exp = np.exp +vstack = np.vstack +ones = np.ones +repeat = np.repeat +tile = np.tile + +sin = np.sin +cos = np.cos +block = np.block + +real = np.real +imag = np.imag +conj = np.conj + +nonzero = np.nonzero + +arctan = np.arctan + +hstack = np.hstack + +eig = np.linalg.eig + + +def assign(arr, index, value, row_all=False, col_all=False): + if type(index) == list: + index = tuple(index) + + if row_all: + arr[:, index] = value + elif col_all: + arr[index, :] = value + else: + arr[index] = value + return arr diff --git a/meent/integ/backend/be_tensorflow.py b/meent/integ/backend/be_tensorflow.py new file mode 100644 index 0000000..2735f47 --- /dev/null +++ b/meent/integ/backend/be_tensorflow.py @@ -0,0 +1,76 @@ +import tensorflow as tf +import numpy as np + + +backend = 'numpy' +linspace = tf.linspace +pi = tf.constant(np.pi) + +arange = tf.experimental.numpy.arange + +# zeros = tf.zeros +zeros = tf.experimental.numpy.zeros + +# diag = tf.linalg.diag +diag = tf.experimental.numpy.diag + + +linalg = tf.linalg +eye = tf.eye + +# nan = tf.nan +nan = tf.constant(np.nan) + +interp = None + +loadtxt = np.loadtxt + +array = tf.experimental.numpy.array + +# roll = tf.roll +roll = tf.experimental.numpy.roll + + +exp = tf.exp + +# vstack = tf.vstack +vstack = tf.experimental.numpy.vstack + + +ones = tf.ones +repeat = tf.repeat +tile = tf.tile + +sin = tf.sin +cos = tf.cos +block = None + +real = tf.math.real +imag = tf.math.imag +conj = tf.math.conj + +nonzero = tf.experimental.numpy.nonzero + +# arctan = tf.math.atan +arctan = tf.experimental.numpy.arctan + +hstack = tf.experimental.numpy.hstack + +eig = tf.linalg.eig + + +def assign(arr, index, value, row_all=False, col_all=False): + if type(index) == list: + index = tuple(index) + + if row_all: + arr[:, index] = value + elif col_all: + arr[index, :] = value + else: + arr[index] = value + return arr + +# https://stackoverflow.com/questions/38420288/how-to-implement-element-wise-1d-interpolation-in-tensorflow +def interp(): + pass \ No newline at end of file diff --git a/meent/integ/backend/be_torch.py b/meent/integ/backend/be_torch.py new file mode 100644 index 0000000..42a0b0f --- /dev/null +++ b/meent/integ/backend/be_torch.py @@ -0,0 +1,7 @@ +import torch + + +class BackendTorch: + backend = 'torch' + torch.device('cuda') + eig = torch.linalg.eig diff --git a/meent/integ/backend/meentpy.py b/meent/integ/backend/meentpy.py new file mode 100644 index 0000000..4f7d990 --- /dev/null +++ b/meent/integ/backend/meentpy.py @@ -0,0 +1,49 @@ +import meent.integ.backend + +if meent.integ.backend.mode == 2: + from meent.integ.backend.be_numpy import * + print(33) + # import numpy as np + # + # backend = 'numpy' + # linspace = np.linspace + # pi = np.pi + # arange = np.arange + # zeros = np.zeros + # diag = np.diag + # + # linalg = np.linalg + # eye = np.eye + # nan = np.nan + # interp = np.interp + # loadtxt = np.loadtxt + # array = np.array + # roll = np.roll + # exp = np.exp + # vstack = np.vstack + # ones = np.ones + # repeat = np.repeat + # tile = np.tile + # + # sin = np.sin + # cos = np.cos + # block = np.block + # + # real = np.real + # imag = np.imag + # conj = np.conj + # + # nonzero = np.nonzero + # + # arctan = np.arctan + # + # hstack = np.hstack + # + # eig = np.linalg.eig + + + + +elif meent.integ.backend.mode == 3: + from meent.integ.backend.be_jax import * + print(23) diff --git a/meent/integ/convolution_matrix.py b/meent/integ/convolution_matrix.py new file mode 100644 index 0000000..948a43b --- /dev/null +++ b/meent/integ/convolution_matrix.py @@ -0,0 +1,292 @@ +import meent.integ.backend.meentpy as ee + +from os import walk +from scipy.io import loadmat +from pathlib import Path + + +def put_permittivity_in_ucell(ucell, mat_list, mat_table, wl): + + res = ee.zeros(ucell.shape, dtype='complex') + + for z in range(ucell.shape[0]): + for y in range(ucell.shape[1]): + for x in range(ucell.shape[2]): + material = mat_list[ucell[z, y, x]] + if type(material) == str: + # res[z, y, x] = find_nk_index(material, mat_table, wavelength) ** 2 + assign_index = [z, y, x] + assign_value = find_nk_index(material, mat_table, wl) ** 2 + res = ee.assign(res, assign_index, assign_value) + + else: + # res[z, y, x] = material ** 2 + assign_index = [z, y, x] + assign_value = material ** 2 + + res = ee.assign(res, assign_index, assign_value) + + return res + + +def put_permittivity_in_ucell_object(ucell_size, mat_list, obj_list, mat_table, wl): + # TODO: under development + res = ee.zeros(ucell_size, dtype='complex') + + for material, obj_index in zip(mat_list, obj_list): + if type(material) == str: + res[obj_index] = find_nk_index(material, mat_table, wl) ** 2 + else: + res[obj_index] = material ** 2 + + return res + + +def find_nk_index(material, mat_table, wl): + if material[-6:] == '__real': + material = material[:-6] + n_only = True + else: + n_only = False + + mat_data = mat_table[material.upper()] + + n_index = ee.interp(wl, mat_data[:, 0], mat_data[:, 1]) + + if n_only: + return n_index + + k_index = ee.interp(wl, mat_data[:, 0], mat_data[:, 2]) + nk = n_index + 1j * k_index + + return nk + + +def read_material_table(nk_path=None): + mat_table = {} + + if nk_path is None: + nk_path = str(Path(__file__).resolve().parent.parent) + '/nk_data' + + full_path_list, name_list, _ = [], [], [] + for (dirpath, dirnames, filenames) in walk(nk_path): + full_path_list.extend([f'{dirpath}/{filename}' for filename in filenames]) + name_list.extend(filenames) + for path, name in zip(full_path_list, name_list): + if name[-3:] == 'txt': + data = ee.loadtxt(path, skiprows=1) + mat_table[name[:-4].upper()] = data + + elif name[-3:] == 'mat': + data = loadmat(path) + data = ee.array([data['WL'], data['n'], data['k']])[:, :, 0].T + mat_table[name[:-4].upper()] = data + return mat_table + + +def cell_compression(cell): + # find discontinuities in x + step_y, step_x = 1. / ee.array(cell.shape) + x = [] + y = [] + cell_x = [] + cell_xy = [] + + cell_next = ee.roll(cell, -1, axis=1) + + for col in range(cell.shape[1]): + if not (cell[:, col] == cell_next[:, col]).all() or (col == cell.shape[1] - 1): + x.append(step_x * (col + 1)) + cell_x.append(cell[:, col]) + + cell_x = ee.array(cell_x).T + cell_x_next = ee.roll(cell_x, -1, axis=0) + + for row in range(cell_x.shape[0]): + if not (cell_x[row, :] == cell_x_next[row, :]).all() or (row == cell_x.shape[0] - 1): + y.append(step_y * (row + 1)) + cell_xy.append(cell_x[row, :]) + + x = ee.array(x).reshape((-1, 1)) + y = ee.array(y).reshape((-1, 1)) + cell_comp = ee.array(cell_xy) + + return cell_comp, x, y + + +def fft_piecewise_constant(cell, fourier_order): + if cell.shape[0] == 1: + fourier_order = [0, fourier_order] + else: + fourier_order = [fourier_order, fourier_order] + cell, x, y = cell_compression(cell) + + # X axis + cell_next_x = ee.roll(cell, -1, axis=1) + cell_diff_x = cell_next_x - cell + + modes = ee.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1) + + f_coeffs_x = cell_diff_x @ ee.exp(-1j * 2 * ee.pi * x @ modes[None, :]) + c = f_coeffs_x.shape[1] // 2 + + x_next = ee.vstack((ee.roll(x, -1, axis=0)[:-1], 1)) - x + + # f_coeffs_x[:, c] = (cell @ ee.vstack((x[0], x_next[:-1]))).flatten() + + assign_index = [ee.arange(len(f_coeffs_x)), ee.array([c])] + assign_value = (cell @ ee.vstack((x[0], x_next[:-1]))).flatten() + + f_coeffs_x = ee.assign(f_coeffs_x, assign_index, assign_value) + + mask = ee.ones(f_coeffs_x.shape[1], dtype=bool) + # mask[c] = False + + mask = ee.assign(mask, c, False) + + + # f_coeffs_x[:, mask] /= (1j * 2 * ee.pi * modes[mask]) + + assign_index = mask + assign_value = f_coeffs_x[:, mask] / (1j * 2 * ee.pi * modes[mask]) + + f_coeffs_x = ee.assign(f_coeffs_x, assign_index, assign_value, row_all=True) + + # Y axis + f_coeffs_x_next_y = ee.roll(f_coeffs_x, -1, axis=0) + f_coeffs_x_diff_y = f_coeffs_x_next_y - f_coeffs_x + + modes = ee.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1) + + f_coeffs_xy = f_coeffs_x_diff_y.T @ ee.exp(-1j * 2 * ee.pi * y @ modes[None, :]) + c = f_coeffs_xy.shape[1] // 2 + + y_next = ee.vstack((ee.roll(y, -1, axis=0)[:-1], 1)) - y + + # f_coeffs_xy[:, c] = f_coeffs_x.T @ ee.vstack((y[0], y_next[:-1])).flatten() + + assign_value = f_coeffs_x.T @ ee.vstack((y[0], y_next[:-1])).flatten() + f_coeffs_xy = ee.assign(f_coeffs_xy, c, assign_value, row_all=True) + + if c: + mask = ee.ones(f_coeffs_xy.shape[1], dtype=bool) + # mask[c] = False + mask = ee.assign(mask, c, False) + + # f_coeffs_xy[:, mask] /= (1j * 2 * ee.pi * modes[mask]) + + assign_value = f_coeffs_xy[:, mask] / (1j * 2 * ee.pi * modes[mask]) + f_coeffs_xy = ee.assign(f_coeffs_xy, mask, assign_value, row_all=True) + + return f_coeffs_xy.T + + +def to_conv_mat(pmt, fourier_order): + + if len(pmt.shape) == 2: + print('shape is 2') + raise ValueError + ff = 2 * fourier_order + 1 + + if pmt.shape[1] == 1: # 1D + + res = ee.zeros((pmt.shape[0], ff, ff)).astype('complex') + + for i, layer in enumerate(pmt): + # f_coeffs = fft_piecewise_constant(layer, fourier_order) + # A = ee.roll(circulant(f_coeffs.flatten()), (f_coeffs.size + 1) // 2, 0) + # res[i] = A[:2 * fourier_order + 1, :2 * fourier_order + 1] + f_coeffs = fft_piecewise_constant(layer, fourier_order) + + center = f_coeffs.shape[1] // 2 + + conv_idx = ee.arange(-ff + 1, ff, 1) + conv_idx = circulant(conv_idx) + + e_conv = f_coeffs[0, center + conv_idx] + # res = res.at[i].set(e_conv) + res = ee.assign(res, i, e_conv) + + else: # 2D + # attention on the order of axis (Z Y X) + + # TODO: separate fourier order + res = ee.zeros((pmt.shape[0], ff ** 2, ff ** 2)).astype('complex') + + for i, layer in enumerate(pmt): + # pmtvy_fft = fft_piecewise_constant(layer, fourier_order) + # + # center = ee.array(pmtvy_fft.shape) // 2 + # + # conv_idx = ee.arange(-ff + 1, ff, 1) + # conv_idx = circulant(conv_idx)[ff - 1:, :ff] + # + # conv_i = ee.repeat(conv_idx, ff, axis=1) + # conv_i = ee.repeat(conv_i, [ff] * ff, axis=0) + # conv_j = ee.tile(conv_idx, (ff, ff)) + # res[i] = pmtvy_fft[center[0] + conv_i, center[1] + conv_j] + f_coeffs = fft_piecewise_constant(layer, fourier_order) + + center = ee.array(f_coeffs.shape) // 2 + + conv_idx = ee.arange(-ff + 1, ff, 1) + + conv_idx = circulant(conv_idx) + + conv_i = ee.repeat(conv_idx, ff, axis=1) + conv_i = ee.repeat(conv_i, ff, axis=0) + conv_j = ee.tile(conv_idx, (ff, ff)) + + # res = res.at[i].set(f_coeffs[center[0] + conv_i, center[1] + conv_j]) + assign_value = f_coeffs[center[0] + conv_i, center[1] + conv_j] + res = ee.assign(res, i, assign_value) + # import matplotlib.pyplot as plt + # + # plt.figure() + # plt.imshow(abs(res[0]), cmap='jet') + # plt.colorbar() + # plt.show() + # + return res + + +# def draw_fill_factor(patterns_fill_factor, grating_type, resolution=1000, mode=0): +# +# # res in Z X Y +# if grating_type == 2: +# res = ee.zeros((len(patterns_fill_factor), resolution, resolution), dtype='complex') +# else: +# res = ee.zeros((len(patterns_fill_factor), 1, resolution), dtype='complex') +# +# if grating_type in (0, 1): # TODO: handle this by len(fill_factor) +# # fill_factor is not exactly implemented. +# for i, (n_ridge, n_groove, fill_factor) in enumerate(patterns_fill_factor): +# permittivity = ee.ones((1, resolution), dtype='complex') +# cut = int(resolution * fill_factor) +# permittivity[0, :cut] *= n_ridge ** 2 +# permittivity[0, cut:] *= n_groove ** 2 +# res[i, 0] = permittivity +# else: # 2D +# for i, (n_ridge, n_groove, fill_factor) in enumerate(patterns_fill_factor): +# fill_factor = ee.array(fill_factor) +# permittivity = ee.ones((resolution, resolution), dtype='complex') +# cut = (resolution * fill_factor) # TODO: need parenthesis? +# permittivity *= n_groove ** 2 +# permittivity[:int(cut[1]), :int(cut[0])] *= n_ridge ** 2 +# res[i] = permittivity +# +# return res + +def circulant(c): + + center = ee.array(c.shape) // 2 + circ = ee.zeros((center[0] + 1, center[0] + 1), dtype='int32') + + for r in range(center[0]+1): + idx = ee.arange(r, r - center - 1, -1) + + # circ = circ.at[r].set(c[center + idx]) + assign_value = c[center + idx] + circ = ee.assign(circ, r, assign_value) + + return circ diff --git a/meent/integ/field_distribution.py b/meent/integ/field_distribution.py new file mode 100644 index 0000000..7d983a1 --- /dev/null +++ b/meent/integ/field_distribution.py @@ -0,0 +1,203 @@ +import meent as ee +import matplotlib.pyplot as plt + +from scipy.linalg import expm + + +def field_distribution(grating_type, *args, **kwargs): + if grating_type == 0: + res = field_dist_1d(*args, **kwargs) + else: + res = field_dist_2d(*args, **kwargs) + return res + + +def field_dist_1d(wavelength, n_I, theta, fourier_order, T1, layer_info_list, period, pol, resolution=(100, 1, 100)): + + k0 = 2 * ee.pi / wavelength + fourier_indices = ee.arange(-fourier_order, fourier_order + 1) + + kx_vector = k0 * (n_I * ee.sin(theta) - fourier_indices * (wavelength / period[0])).astype('complex') + Kx = ee.diag(kx_vector / k0) + + resolution_z, resolution_y, resolution_x = resolution + + field_cell = ee.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 3), dtype='complex') + + T_layer = T1 + + # From the first layer + for idx_layer, (E_conv_i, q, W, X, a_i, b, d) in enumerate(layer_info_list[::-1]): + + c1 = T_layer[:, None] + c2 = b @ a_i @ X @ T_layer[:, None] + + Q = ee.diag(q) + + if pol == 0: + V = W @ Q + + else: + V = E_conv_i @ W @ Q + EKx = E_conv_i @ Kx + + for k in range(resolution_z): + z = k / resolution_z * d + + if pol == 0: # TE + Sy = W @ (expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + Ux = V @ (-expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + f_here = (-1j) * Kx @ Sy + + for j in range(resolution_y): + for i in range(resolution_x): + x = i * period[0] / resolution_x + + Ey = Sy.T @ ee.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Hx = -1j * Ux.T @ ee.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Hz = f_here.T @ ee.exp(-1j * kx_vector.reshape((-1, 1)) * x) + + field_cell[resolution_z * idx_layer + k, j, i] = Ey, Hx, Hz + else: # TM + Uy = W @ (expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + Sx = V @ (-expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + + f_here = (-1j) * EKx @ Uy + + for j in range(resolution_y): + for i in range(resolution_x): + x = i * period[0] / resolution_x + + Hy = Uy.T @ ee.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Ex = 1j * Sx.T @ ee.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Ez = f_here.T @ ee.exp(-1j * kx_vector.reshape((-1, 1)) * x) + + field_cell[resolution_z * idx_layer + k, j, i] = Hy, Ex, Ez + + T_layer = a_i @ X @ T_layer + + return field_cell + + +def field_dist_2d(wavelength, n_I, theta, phi, fourier_order, T1, layer_info_list, period, resolution=(100, 100, 100)): + k0 = 2 * ee.pi / wavelength + fourier_indices = ee.arange(-fourier_order, fourier_order + 1) + ff = 2 * fourier_order + 1 + + kx_vector = k0 * (n_I * ee.sin(theta) * ee.cos(phi) - fourier_indices * ( + wavelength / period[0])).astype('complex') + ky_vector = k0 * (n_I * ee.sin(theta) * ee.sin(phi) - fourier_indices * ( + wavelength / period[1])).astype('complex') + + Kx = ee.diag(ee.tile(kx_vector, ff).flatten()) / k0 + Ky = ee.diag(ee.tile(ky_vector.reshape((-1, 1)), ff).flatten()) / k0 + + resolution_z, resolution_y, resolution_x = resolution + field_cell = ee.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 6), dtype='complex') + + T_layer = T1 + + big_I = ee.eye((len(T1))) + + # From the first layer + for idx_layer, (E_conv_i, q, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d)\ + in enumerate(layer_info_list[::-1]): + + c = ee.block([[big_I], [big_B @ big_A_i @ big_X]]) @ T_layer + + ff = len(c) // 4 + + c1_plus = c[0*ff:1*ff] + c2_plus = c[1*ff:2*ff] + c1_minus = c[2*ff:3*ff] + c2_minus = c[3*ff:4*ff] + + q1 = q[:len(q)//2] + q2 = q[len(q)//2:] + big_Q1 = ee.diag(q1) + big_Q2 = ee.diag(q2) + + for k in range(resolution_z): + z = k / resolution_z * d + + Sx = W_11 @ (expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z-d)) @ c1_minus) \ + + W_12 @ (expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z-d)) @ c2_minus) + + Sy = W_21 @ (expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z-d)) @ c1_minus) \ + + W_22 @ (expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z-d)) @ c2_minus) + + Ux = V_11 @ (-expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z-d)) @ c1_minus) \ + + V_12 @ (-expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z-d)) @ c2_minus) + + Uy = V_21 @ (-expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z-d)) @ c1_minus) \ + + V_22 @ (-expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z-d)) @ c2_minus) + + Sz = -1j * E_conv_i @ (Kx @ Uy - Ky @ Ux) + + Uz = -1j * (Kx @ Sy - Ky @ Sx) + + for j in range(resolution_y): + y = j * period[1] / resolution_y + + for i in range(resolution_x): + x = i * period[0] / resolution_x + + exp_K = ee.exp(-1j * kx_vector.reshape((1, -1)) * x) * ee.exp(-1j * ky_vector.reshape((-1, 1)) * y) + exp_K = exp_K.flatten() + + Ex = Sx.T @ exp_K + Ey = Sy.T @ exp_K + Ez = Sz.T @ exp_K + + Hx = -1j * Ux.T @ exp_K + Hy = -1j * Uy.T @ exp_K + Hz = -1j * Uz.T @ exp_K + + field_cell[resolution_z * idx_layer + k, j, i] = [Ex, Ey, Ez, Hx, Hy, Hz] + + T_layer = big_A_i @ big_X @ T_layer + + return field_cell + + +def field_plot_zx(field_cell, pol=0, plot_indices=(1, 1, 1, 1, 1, 1), y_slice=0, z_slice=-1, zx=True, yx=True): + + if field_cell.shape[-1] == 6: # 2D grating + title = ['2D Ex', '2D Ey', '2D Ez', '2D Hx', '2D Hy', '2D Hz', ] + else: # 1D grating + if pol == 0: # TE + title = ['1D Ey', '1D Hx', '1D Hz', ] + else: # TM + title = ['1D Hy', '1D Ex', '1D Ez', ] + + if zx: + for idx in range(len(title)): + if plot_indices[idx]: + plt.imshow((abs(field_cell[:, y_slice, :, idx]) ** 2), cmap='jet', aspect='auto') + # plt.clim(0, 2) # identical to caxis([-4,4]) in MATLAB + plt.colorbar() + plt.title(title[idx]) + plt.show() + if yx: + for idx in range(len(title)): + if plot_indices[idx]: + plt.imshow((abs(field_cell[z_slice, :, :, idx]) ** 2), cmap='jet', aspect='auto') + plt.clim(0, 3.5) # identical to caxis([-4,4]) in MATLAB + plt.colorbar() + plt.title(title[idx]) + plt.show() + + # for idx in range(len(title)): + # if plot_indices[idx]: + # plt.imshow((abs(field_cell[0, :, :, idx]) ** 2), cmap='jet', aspect='auto') + # # plt.clim(0, 1.3) # identical to caxis([-4,4]) in MATLAB + # plt.colorbar() + # plt.title(title[idx]) + # plt.show() + # for idx in range(len(title)): + # if plot_indices[idx]: + # plt.imshow((abs(field_cell[-1, :, :, idx]) ** 2), cmap='jet', aspect='auto') + # # plt.clim(0, 3.2) # identical to caxis([-4,4]) in MATLAB + # plt.colorbar() + # plt.title(title[idx]) + # plt.show() diff --git a/meent/integ/rcwa.py b/meent/integ/rcwa.py new file mode 100644 index 0000000..24ed55e --- /dev/null +++ b/meent/integ/rcwa.py @@ -0,0 +1,63 @@ +import time + +from ._base import _BaseRCWA +from .convolution_matrix import to_conv_mat, put_permittivity_in_ucell, read_material_table +from .field_distribution import field_dist_1d, field_dist_2d, field_plot_zx + + +class RCWAInteg(_BaseRCWA): + + def __init__(self, mode=0, grating_type=0, n_I=1., n_II=1., theta=0, phi=0, psi=0, fourier_order=40, period=(100,), + wavelength=900, pol=0, patterns=None, ucell=None, ucell_materials=None, thickness=None, algo='TMM', + *args, **kwargs): + + super().__init__(grating_type, n_I, n_II, theta, phi, psi, fourier_order, period, wavelength, pol, patterns, ucell, ucell_materials, + thickness, algo) + + self.mode = mode + self.spectrum_r, self.spectrum_t = None, None + # self.init_spectrum_array() + self.mat_table = read_material_table() + + def solve(self, wavelength, e_conv_all, o_e_conv_all): + + # TODO: !handle uniform layer + + if self.grating_type == 0: + de_ri, de_ti = self.solve_1d(wavelength, e_conv_all, o_e_conv_all) + elif self.grating_type == 1: + de_ri, de_ti = self.solve_1d_conical(wavelength, e_conv_all, o_e_conv_all) + elif self.grating_type == 2: + de_ri, de_ti = self.solve_2d(wavelength, e_conv_all, o_e_conv_all) + else: + raise ValueError + + return de_ri.real, de_ti.real + + def run_ucell(self): + + ucell = put_permittivity_in_ucell(self.ucell, self.ucell_materials, self.mat_table, self.wavelength) + + e_conv_all = to_conv_mat(ucell, self.fourier_order) + o_e_conv_all = to_conv_mat(1 / ucell, self.fourier_order) + + de_ri, de_ti = self.solve(self.wavelength, e_conv_all, o_e_conv_all) + + return de_ri, de_ti + + def calculate_field(self, resolution=None, plot=True): + + if self.grating_type == 0: + resolution = [100, 1, 100] if not resolution else resolution + field_cell = field_dist_1d(self.wavelength, self.n_I, self.theta, self.fourier_order, self.T1, + self.layer_info_list, self.period, self.pol, resolution=resolution) + else: + resolution = [100, 100, 100] if not resolution else resolution + field_cell = field_dist_2d(self.wavelength, self.n_I, self.theta, self.phi, self.fourier_order, self.T1, + self.layer_info_list, self.period, resolution=resolution) + + if plot: + field_plot_zx(field_cell, self.pol) + + return field_cell + diff --git a/meent/integ/scattering_method.py b/meent/integ/scattering_method.py new file mode 100644 index 0000000..aa7fc33 --- /dev/null +++ b/meent/integ/scattering_method.py @@ -0,0 +1,183 @@ +""" +currently SMM is not supported +""" + +# many codes for scattering matrix method are from here: +# https://github.com/zhaonat/Rigorous-Coupled-Wave-Analysis +# also refer our fork https://github.com/yonghakim/zhaonat-rcwa + +from .smm_util import * + + +def scattering_1d_1(k0, n_I, n_II, theta, phi, fourier_indices, period, pol, wl=None): + + kx_vector = (n_I * np.sin(theta) * np.cos(phi) - fourier_indices * ( + 2 * np.pi / k0 / period[0])).astype('complex') + Kx = np.diag(kx_vector) + + # scattering matrix needed for 'gap medium' + Wg, Vg, Kzg = homogeneous_1D(Kx, 1, wl=wl, comment='Gap') + + # reflection medium + Wr, Vr, Kzr = homogeneous_1D(Kx, n_I, pol=pol, wl=wl, comment='Refl') + + # transmission medium; + Wt, Vt, Kzt = homogeneous_1D(Kx, n_II, pol=pol, wl=wl, comment='Tran') + + # S matrices for the reflection region + Ar, Br = A_B_matrices_half_space(Vr, Vg) # make sure this order is right + _, Sg = S_RT(Ar, Br, ref_mode=True) # scatter matrix for the reflection region + + return Kx, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg + + +def scattering_1d_2(W, Wg, V, Vg, d, k0, LAMBDA, Sg): + # calculating A and B matrices for scattering matrix + # define S matrix for the GRATING REGION + A, B = A_B_matrices(W, Wg, V, Vg) + _, S_dict = S_layer(A, B, d, k0, LAMBDA) + _, Sg = RedhefferStar(Sg, S_dict) + + return A, B, S_dict, Sg + + +def scattering_1d_3(Wt, Wg, Vt, Vg, Sg, ff, Wr, fourier_order, Kzr, Kzt, n_I, n_II, theta, pol): + # define S matrices for the Transmission region + At, Bt = A_B_matrices_half_space(Vt, Vg) # make sure this order is right + _, St_dict = S_RT(At, Bt, ref_mode=False) # scatter matrix for the reflection region + _, Sg = RedhefferStar(Sg, St_dict) + + k_inc = n_I * np.array([np.sin(theta), 0, np.cos(theta)]) + + c_inc = np.zeros((ff, 1)) # only need one set... + c_inc[fourier_order] = 1 + c_inc = np.linalg.inv(Wr) @ c_inc + # COMPUTE FIELDS: similar idea but more complex for RCWA since you have individual modes each contributing + reflected = Wr @ Sg['S11'] @ c_inc + transmitted = Wt @ Sg['S21'] @ c_inc + + # reflected is already ry or Ey + rsq = np.square(np.abs(reflected)) + tsq = np.square(np.abs(transmitted)) + + # compute final reflectivity + if pol == 0: + de_ri = np.real(Kzr) @ rsq / np.real(k_inc[2]) + de_ti = np.real(Kzt) @ tsq / np.real(k_inc[2]) + elif pol == 1: + de_ri = np.real(Kzr)@rsq/np.real(k_inc[2]) / n_I**2 + de_ti = np.real(Kzt)@tsq/np.real(k_inc[2]) * n_I**2 / n_II**4 + else: + raise ValueError + + return de_ri.flatten(), de_ti.flatten() + + +def scattering_2d_1(n_I, n_II, theta, phi, k0, period, fourier_order): + kx_inc = n_I * np.sin(theta) * np.cos(phi) + ky_inc = n_I * np.sin(theta) * np.sin(phi) + kz_inc = np.sqrt(n_I ** 2 * 1 - kx_inc ** 2 - ky_inc ** 2) + + Kx, Ky = K_matrix_cubic_2D(kx_inc, ky_inc, k0, period[0], period[1], fourier_order, fourier_order) + + # specify gap media (this is an LHI so no eigenvalue problem should be solved + e_h = 1 + Wg, Vg, Kzg = homogeneous_module(Kx, Ky, e_h) + + # ================= Working on the Reflection Side =========== ## + e_r = n_I ** 2 + Wr, Vr, Kzr = homogeneous_module(Kx, Ky, e_r) + + # ========= Working on the Transmission Side==============## + e_t = n_II ** 2 + Wt, Vt, Kzt = homogeneous_module(Kx, Ky, e_t) + + # calculating A and B matrices for scattering matrix + Ar, Br = A_B_matrices_half_space(Vr, Vg) + + # s_ref is a matrix, Sr_dict is a dictionary + _, Sr_dict = S_RT(Ar, Br, ref_mode=True) # scatter matrix for the reflection region + Sg = Sr_dict + + return Kx, Ky, kz_inc, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg + + +def scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, LAMBDA): + + A, B = A_B_matrices(W, Wg, V, Vg) + _, Sl_dict = S_layer(A, B, d, k0, LAMBDA) + Sg_matrix, Sg = RedhefferStar(Sg, Sl_dict) + + return A, B, Sl_dict, Sg_matrix, Sg + + +def scattering_2d_3(Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, n_I, pol, theta, + phi, fourier_order, ff): + normal_vector = np.array([0, 0, 1]) # positive z points down; + # amplitude of the te vs tm modes (which are decoupled) + + if pol == 0: + pte = 1 + ptm = 0 + elif pol == 1: + pte = 0 + ptm = 1 + else: + raise ValueError + + M = N = fourier_order + NM = ff ** 2 + + # get At, Bt + # since transmission is the same as gap, order does not matter + At, Bt = A_B_matrices_half_space(Vt, Vg) + _, ST_dict = S_RT(At, Bt, ref_mode=False) + + # update global scattering matrix + Sg_matrix, Sg = RedhefferStar(Sg, ST_dict) + + # finally CONVERT THE GLOBAL SCATTERING MATRIX BACK TO A MATRIX + + K_inc_vector = n_I * np.array([np.sin(theta) * np.cos(phi), np.sin(theta) * np.sin(phi), np.cos(theta)]) + + _, e_src, _ = initial_conditions(K_inc_vector, theta, normal_vector, pte, ptm, N, M) + + c_inc = np.linalg.inv(Wr) @ e_src + # COMPUTE FIELDS: similar idea but more complex for RCWA since you have individual modes each contributing + reflected = Wr @ Sg['S11'] @ c_inc + transmitted = Wt @ Sg['S21'] @ c_inc + + rx = reflected[0:NM, :] # rx is the Ex component. + ry = reflected[NM:, :] + tx = transmitted[0:NM, :] + ty = transmitted[NM:, :] + + rz = np.linalg.inv(Kzr) @ (Kx @ rx + Ky @ ry) + tz = np.linalg.inv(Kzt) @ (Kx @ tx + Ky @ ty) + + rsq = np.square(np.abs(rx)) + np.square(np.abs(ry)) + np.square(np.abs(rz)) + tsq = np.square(np.abs(tx)) + np.square(np.abs(ty)) + np.square(np.abs(tz)) + + de_ri = np.real(Kzr)@rsq/np.real(K_inc_vector[2]) # real because we only want propagating components + de_ti = np.real(Kzt)@tsq/np.real(K_inc_vector[2]) + + return de_ri, de_ti + + +def scattering_2d_wv(ff, Kx, Ky, E_conv, oneover_E_conv, oneover_E_conv_i, E_i, mu_conv=None): + # ------------------------- + # W and V from SMM method. + NM = ff ** 2 + if mu_conv is None: + mu_conv = np.identity(NM) + + P, Q, _ = P_Q_kz(Kx, Ky, E_conv, mu_conv, oneover_E_conv, oneover_E_conv_i, E_i) + GAMMA = P @ Q + + Lambda, W = np.linalg.eig(GAMMA) # LAMBDa is effectively refractive index + LAMBDA = np.diag(Lambda) + LAMBDA = np.sqrt(LAMBDA.astype('complex')) + + V = Q @ W @ np.linalg.inv(LAMBDA) + + return W, V, LAMBDA diff --git a/meent/integ/smm_util.py b/meent/integ/smm_util.py new file mode 100644 index 0000000..9fab4f7 --- /dev/null +++ b/meent/integ/smm_util.py @@ -0,0 +1,335 @@ +""" +currently SMM is not supported +""" +# many codes for scattering matrix method are from here: +# https://github.com/zhaonat/Rigorous-Coupled-Wave-Analysis +# also refer our fork https://github.com/yonghakim/zhaonat-rcwa + +import numpy as np +from numpy.linalg import inv, pinv +# TODO: try pseudo-inverse? +from scipy.linalg import block_diag +# TODO: ok by jax? + + +def A_B_matrices_half_space(V_layer, Vg): + + I = np.eye(len(Vg)) + a = I + inv(Vg) @ V_layer + b = I - inv(Vg) @ V_layer + + return a, b + + +def A_B_matrices(W_layer, Wg, V_layer, Vg): + """ + single function to output the a and b matrices needed for the scatter matrices + :param W_layer: gap + :param Wg: + :param V_layer: gap + :param Vg: + :return: + """ + W_i = inv(W_layer) + V_i = inv(V_layer) + + a = W_i @ Wg + V_i @ Vg + b = W_i @ Wg - V_i @ Vg + + return a, b + + +def S_layer(A, B, d, k0, modes): + """ + function to create scatter matrix in the ith layer of the uniform layer structure + we assume that gap layers are used so we need only one A and one B + :param A: function A = + :param B: function B + :param k0 #free -space wavevector magnitude (normalization constant) in Si Units + :param Li #length of ith layer (in Si units) + :param modes, eigenvalue matrix + :return: S (4x4 scatter matrix) and Sdict, which contains the 2x2 block matrix as a dictionary + """ + + # sign convention (EMLAB is exp(-1i*k\dot r)) + X = np.diag(np.exp(-np.diag(modes)*d*k0)) + # TODO: Check + # TODO: expm + + A_i = inv(A) + term_i = inv(A - X @ B @ A_i @ X @ B) + + S11 = term_i @ (X @ B @ A_i @ X @ A - B) + S12 = term_i @ X @ (A - B @ A_i @ B) + S22 = S11 + S21 = S12 + + S_dict = {'S11': S11, 'S22': S22, 'S12': S12, 'S21': S21} + S = np.block([[S11, S12], [S21, S22]]) + return S, S_dict + + +def S_RT(A, B, ref_mode): + + A_i = inv(A) + + S11 = -A_i @ B + S12 = 2 * A_i + S21 = 0.5*(A - B @ A_i @ B) + S22 = B @ A_i + + if ref_mode: + S_dict = {'S11': S11, 'S22': S22, 'S12': S12, 'S21': S21} + S = np.block([[S11, S12], [S21, S22]]) + else: + S_dict = {'S11': S22, 'S22': S11, 'S12': S21, 'S21': S12} + S = np.block([[S22, S21], [S12, S11]]) + return S, S_dict + + +def homogeneous_module(Kx, Ky, e_r, m_r=1, perturbation=1E-16, wl=None, comment=None): + """ + homogeneous layer is much simpler to do, so we will create an isolated module to deal with it + :return: + """ + assert type(Kx) == np.ndarray, 'not np.array' + assert type(Ky) == np.ndarray, 'not np.array' + + N = len(Kx) + I = np.identity(N) + + P = (e_r**-1)*np.block([[Kx*Ky, e_r*m_r*I-Kx**2], [Ky**2-m_r*e_r*I, -Ky*Kx]]) + Q = (e_r/m_r)*P + + diag = np.diag(Q) + idx = np.nonzero(diag == 0)[0] + if len(idx): + # Adding pertub* to Q and pertub to Kz. + # TODO: check why this works. + # TODO: make imaginary part sign consistent + Q[idx, idx] = np.conj(perturbation) + print(wl, comment, 'non-invertible Q: adding perturbation') + # print(Q.diagonal()) + + W = np.eye(N*2) + Kz2 = (m_r*e_r*I-Kx**2-Ky**2).astype('complex') # arg is +kz^2 + # arg = -(m_r*e_r*I-Kx**2-Ky**2) # arg is +kz^2 + # Kz = np.conj(np.sqrt(arg)) # conjugate enforces the negative sign convention (we also have to conjugate er and mur if they are complex) + + Kz = np.sqrt(Kz2) # conjugate enforces the negative sign convention (we also have to conjugate er and mur if they are complex) + Kz = np.conj(Kz) # TODO: conjugate? + + diag = np.diag(Kz) + idx = np.nonzero(diag == 0)[0] + if len(idx): + Kz[idx, idx] = perturbation + print(wl, comment, 'non-invertible Kz: adding perturbation') + # print(Kz.diagonal()) + + eigenvalues = block_diag(1j*Kz, 1j*Kz) # determining the modes of ex, ey... so it appears eigenvalue order MATTERS... + V = Q @ np.linalg.inv(eigenvalues) # eigenvalue order is arbitrary (hard to compare with matlab + + + # V = -1j*Q + + return W, V, Kz + + +def homogeneous_1D(Kx, n_index, m_r=1, pol=None, perturbation=1E-20*(1+1j), wl=None, comment=None): + """ + efficient homogeneous 1D module + :param Kx: + :param e_r: + :param m_r: + :return: + """ + + e_r = n_index ** 2 + + I = np.identity(len(Kx)) + + W = I + Q = (1 / m_r) * (e_r * m_r * I - Kx ** 2) + # Q = Kx**2 - e_r * I + + diag = np.diag(Q) + idx = np.nonzero(diag == 0)[0] + if len(idx): + # Adding pertub* to Q and pertub to Kz. + # TODO: check why this works. + # TODO: make imaginary part sign consistent + Q[idx, idx] = np.conj(perturbation) + print(wl, comment, 'non-invertible Q: adding perturbation') + # print(Q.diagonal()) + + Kz = np.sqrt(m_r*e_r*I-Kx**2) + Kz = np.conj(Kz) # TODO: conjugate? + + # TODO: check Singular or ill-conditioned; spread this to whole code + # invertible check + diag = np.diag(Kz) + idx = np.nonzero(diag == 0)[0] + if len(idx): + Kz[idx, idx] = perturbation + print(wl, comment, 'non-invertible Kz: adding perturbation') + # print(Kz.diagonal()) + + # TODO: why this works... + if pol: # 0: TE, 1: TM + Kz = Kz * (n_index ** 2) + + eigenvalues = -1j*Kz # determining the modes of ex, ey... so it appears eigenvalue order MATTERS... + V = Q @ np.linalg.inv(eigenvalues) # eigenvalue order is arbitrary (hard to compare with matlab + + return W, V, Kz + + +def K_matrix_cubic_2D(beta_x, beta_y, k0, a_x, a_y, N_p, N_q): + # K_i = beta_i - pT1i - q T2i - r*T3i + # but here we apply it only for cubic and tegragonal geometries in 2D + """ + :param beta_x: input k_x,inc/k0 + :param beta_y: k_y,inc/k0; #already normalized...k0 is needed to normalize the 2*pi*lambda/a + however such normalization can cause singular matrices in the homogeneous module (specifically with eigenvalues) + :param T1:reciprocal lattice vector 1 + :param T2: + :param T3: + :return: + """ + # (indexing follows (1,1), (1,2), ..., (1,N), (2,1),(2,2),(2,3)...(M,N) ROW MAJOR + # but in the cubic case, k_x only depends on p and k_y only depends on q + k_x = beta_x - 2*np.pi*np.arange(-N_p, N_p+1)/(k0*a_x) + k_y = beta_y - 2*np.pi*np.arange(-N_q, N_q+1)/(k0*a_y) + + kx, ky = np.meshgrid(k_x, k_y) + Kx = np.diag(kx.flatten()) + Ky = np.diag(ky.flatten()) + + return Kx, Ky + + +def P_Q_kz(Kx, Ky, e_conv, mu_conv, oneover_E_conv, oneover_E_conv_i, E_i): + ''' + r is for relative so do not put epsilon_0 or mu_0 here + :param Kx: NM x NM matrix + :param Ky: + :param e_conv: (NM x NM) conv matrix + :param mu_r: + :return: + ''' + argument = e_conv - Kx ** 2 - Ky ** 2 + Kz = np.conj(np.sqrt(argument.astype('complex'))) + # Kz = np.sqrt(argument.astype('complex')) # TODO: conjugate? + + # TODO: confirm whether oneonver_E_conv is indeed not used + # TODO: Check sign of P and Q + P = np.block([ + [Kx @ E_i @ Ky, -Kx @ E_i @ Kx + mu_conv], + [Ky @ E_i @ Ky - mu_conv, -Ky @ E_i @ Kx] + ]) + + Q = np.block([ + [Kx @ inv(mu_conv) @ Ky, -Kx @ inv(mu_conv) @ Kx + e_conv], + [-oneover_E_conv_i + Ky @ inv(mu_conv) @ Ky, -Ky @ inv(mu_conv) @ Kx] + ]) + + return P, Q, Kz + + +def delta_vector(P, Q): + ''' + create a vector with a 1 corresponding to the 0th order + #input P = 2*(num_ord_specified)+1 + ''' + fourier_grid = np.zeros((P,Q)) + fourier_grid[int(P/2), int(Q/2)] = 1 + # vector = np.zeros((P*Q,)); + # + # #the index of the (0,0) element requires a conversion using sub2ind + # index = int(P/2)*P + int(Q/2); + vector = fourier_grid.flatten() + return np.matrix(np.reshape(vector, (1,len(vector)))) + + +def initial_conditions(K_inc_vector, theta, normal_vector, pte, ptm, P, Q): + """ + :param K_inc_vector: whether it's normalized or not is not important... + :param theta: angle of incience + :param normal_vector: pointing into z direction + :param pte: te polarization amplitude + :param ptm: tm polarization amplitude + :return: + calculates the incident E field, cinc, and the polarization fro the initial condition vectors + """ + # ate -> unit vector holding the out of plane direction of TE + # atm -> unit vector holding the out of plane direction of TM + # what are the out of plane components...(Ey and Hy) + # normal_vector = [0,0,-1]; i.e. waves propagate down into the -z direction + # cinc = Wr^-1@[Ex_inc, Ey_inc]; + + if theta != 0: + ate_vector = np.cross(K_inc_vector, normal_vector) + ate_vector = ate_vector / (np.linalg.norm(ate_vector)) + else: + ate_vector = np.array([0, 1, 0]) + + atm_vector = np.cross(ate_vector, K_inc_vector) + atm_vector = atm_vector / (np.linalg.norm(atm_vector)) + + polarization = pte * ate_vector + ptm * atm_vector # total E_field incident which is a 3 component vector (ex, ey, ez) + E_inc = polarization + # go from mode coefficients to FIELDS + delta = delta_vector(2*P+1, 2*Q+1) + + # c_inc; #remember we ultimately solve for [Ex, Ey, Hx, Hy]. + e_src = np.hstack((polarization[0]*delta, polarization[1]*delta)) + e_src = np.matrix(e_src).T # mode amplitudes of Ex, and Ey + + return E_inc, e_src, polarization + + +def RedhefferStar(SA, SB): # SA and SB are both 2x2 block matrices; + """ + RedhefferStar for arbitrarily sized 2x2 block matrices for RCWA + :param SA: dictionary containing the four sub-blocks + :param SB: dictionary containing the four sub-blocks, + keys are 'S11', 'S12', 'S21', 'S22' + :return: + """ + + assert type(SA) == dict, 'not dict' + assert type(SB) == dict, 'not dict' + + # once we break every thing like this, we should still have matrices + SA_11, SA_12, SA_21, SA_22 = SA['S11'], SA['S12'], SA['S21'], SA['S22'] + SB_11, SB_12, SB_21, SB_22 = SB['S11'], SB['S12'], SB['S21'], SB['S22'] + N = len(SA_11) # SA_11 should be square so length is fine + + I = np.eye(N) + D_i = inv(I - SB_11 @ SA_22) + F_i = inv(I - SA_22 @ SB_11) + + SAB_11 = SA_11 + SA_12 @ D_i @ SB_11 @ SA_21 + SAB_12 = SA_12 @ D_i @ SB_12 + SAB_21 = SB_21 @ F_i @ SA_21 + SAB_22 = SB_22 + SB_21 @ F_i @ SA_22 @ SB_12 + + SAB = np.block([[SAB_11, SAB_12], [SAB_21, SAB_22]]) + SAB_dict = {'S11': SAB_11, 'S22': SAB_22, 'S12': SAB_12, 'S21': SAB_21} + + return SAB, SAB_dict + + +def construct_global_scatter(scatter_list): + """ + this function assumes an RCWA implementation where all the scatter matrices are stored in a list + and the global scatter matrix is constructed at the end + :param scatter_list: list of scatter matrices of the form [Sr, S1, S2, ... , SN, ST] + :return: + """ + Sr = scatter_list[0] + Sg = Sr + for i in range(1, len(scatter_list)): + Sg = RedhefferStar(Sg, scatter_list[i]) + return Sg + diff --git a/meent/integ/transfer_method.py b/meent/integ/transfer_method.py new file mode 100644 index 0000000..a72c720 --- /dev/null +++ b/meent/integ/transfer_method.py @@ -0,0 +1,402 @@ +import meent.integ.backend.meentpy as ee + + +def transfer_1d_1(ff, polarization, k0, n_I, n_II, theta, delta_i0, fourier_order, fourier_indices, wavelength, period): + + kx_vector = k0 * (n_I * ee.sin(theta) - fourier_indices * (wavelength / period[0])).astype('complex') + + k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2) ** 0.5 + k_II_z = (k0 ** 2 * n_II ** 2 - kx_vector ** 2) ** 0.5 + + k_I_z = k_I_z.conjugate() + k_II_z = k_II_z.conjugate() + + Kx = ee.diag(kx_vector / k0) + + f = ee.eye(ff) + + if polarization == 0: # TE + Y_I = ee.diag(k_I_z / k0) + Y_II = ee.diag(k_II_z / k0) + + YZ_I = Y_I + g = 1j * Y_II + inc_term = 1j * n_I * ee.cos(theta) * delta_i0 + + elif polarization == 1: # TM + Z_I = ee.diag(k_I_z / (k0 * n_I ** 2)) + Z_II = ee.diag(k_II_z / (k0 * n_II ** 2)) + + YZ_I = Z_I + g = 1j * Z_II + inc_term = 1j * delta_i0 * ee.cos(theta) / n_I + + else: + raise ValueError + + T = ee.eye(2 * fourier_order + 1) + + return kx_vector, Kx, k_I_z, k_II_z, f, YZ_I, g, inc_term, T + + +def transfer_1d_2(k0, q, d, W, V, f, g, fourier_order, T): + X = ee.diag(ee.exp(-k0 * q * d)) + + W_i = ee.linalg.inv(W) + V_i = ee.linalg.inv(V) + + a = 0.5 * (W_i @ f + V_i @ g) + b = 0.5 * (W_i @ f - V_i @ g) + + a_i = ee.linalg.inv(a) + + f = W @ (ee.eye(2 * fourier_order + 1) + X @ b @ a_i @ X) + g = V @ (ee.eye(2 * fourier_order + 1) - X @ b @ a_i @ X) + T = T @ a_i @ X + + return X, f, g, T, a_i, b + + +def transfer_1d_3(g1, YZ_I, f1, delta_i0, inc_term, T, k_I_z, k0, n_I, n_II, theta, polarization, k_II_z): + T1 = ee.linalg.inv(g1 + 1j * YZ_I @ f1) @ (1j * YZ_I @ delta_i0 + inc_term) + R = f1 @ T1 - delta_i0 + T = T @ T1 + + de_ri = ee.real(R * ee.conj(R) * k_I_z / (k0 * n_I * ee.cos(theta))) + if polarization == 0: + # de_ti = T * ee.conj(T) * ee.real(k_II_z / (k0 * n_I * ee.cos(theta))) + de_ti = ee.real(T * ee.conj(T) * k_II_z / (k0 * n_I * ee.cos(theta))) + elif polarization == 1: + # de_ti = T * ee.conj(T) * ee.real(k_II_z / n_II ** 2) / (k0 * ee.cos(theta) / n_I) + de_ti = ee.real(T * ee.conj(T) * k_II_z / n_II ** 2) / (k0 * ee.cos(theta) / n_I) + else: + raise ValueError + + return de_ri, de_ti, T1 + + +def transfer_2d_1(ff, k0, n_I, n_II, period, fourier_indices, theta, phi, wavelength, perturbation=1E-20 * (1 + 1j)): + I = ee.eye(ff ** 2) + O = ee.zeros((ff ** 2, ff ** 2)) + + kx_vector = k0 * (n_I * ee.sin(theta) * ee.cos(phi) - fourier_indices * ( + wavelength / period[0])).astype('complex') + ky_vector = k0 * (n_I * ee.sin(theta) * ee.sin(phi) - fourier_indices * ( + wavelength / period[1])).astype('complex') + + Kx = ee.diag(ee.tile(kx_vector, ff).flatten()) / k0 + Ky = ee.diag(ee.tile(ky_vector.reshape((-1, 1)), ff).flatten()) / k0 + + k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 + k_II_z = (k0 ** 2 * n_II ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 + + k_I_z = k_I_z.flatten().conjugate() + k_II_z = k_II_z.flatten().conjugate() + + idx = ee.nonzero(kx_vector == 0)[0] + if len(idx): + # TODO: need imaginary part? + # TODO: make imaginary part sign consistent + kx_vector[idx] = perturbation + print(wavelength, 'varphi divide by 0: adding perturbation') + + varphi = ee.arctan(ky_vector.reshape((-1, 1)) / kx_vector).flatten() + + Y_I = ee.diag(k_I_z / k0) + Y_II = ee.diag(k_II_z / k0) + + Z_I = ee.diag(k_I_z / (k0 * n_I ** 2)) + Z_II = ee.diag(k_II_z / (k0 * n_II ** 2)) + + big_F = ee.block([[I, O], [O, 1j * Z_II]]) + big_G = ee.block([[1j * Y_II, O], [O, I]]) + + big_T = ee.eye(ff ** 2 * 2) + + return kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T + + +def transfer_2d_wv(ff, Kx, E_i, Ky, o_E_conv_i, E_conv, center): + + I = ee.eye(ff ** 2) + + B = Kx @ E_i @ Kx - I + D = Ky @ E_i @ Ky - I + + S2_from_S = ee.block( + [ + [Ky ** 2 + B @ o_E_conv_i, Kx @ (E_i @ Ky @ E_conv - Ky)], + [Ky @ (E_i @ Kx @ o_E_conv_i - Kx), Kx ** 2 + D @ E_conv] + ]) + + eigenvalues, W = ee.linalg.eig(S2_from_S) + + q = eigenvalues ** 0.5 + + Q = ee.diag(q) + Q_i = ee.linalg.inv(Q) + U1_from_S = ee.block( + [ + [-Kx @ Ky, Kx ** 2 - E_conv], + [o_E_conv_i - Ky ** 2, Ky @ Kx] + ] + ) + V = U1_from_S @ W @ Q_i + + return W, V, q + + +def transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T): + + q1 = q[:center] + q2 = q[center:] + + W_11 = W[:center, :center] + W_12 = W[:center, center:] + W_21 = W[center:, :center] + W_22 = W[center:, center:] + + V_11 = V[:center, :center] + V_12 = V[:center, center:] + V_21 = V[center:, :center] + V_22 = V[center:, center:] + + X_1 = ee.diag(ee.exp(-k0 * q1 * d)) + X_2 = ee.diag(ee.exp(-k0 * q2 * d)) + + F_c = ee.diag(ee.cos(varphi)) + F_s = ee.diag(ee.sin(varphi)) + + W_ss = F_c @ W_21 - F_s @ W_11 + W_sp = F_c @ W_22 - F_s @ W_12 + W_ps = F_c @ W_11 + F_s @ W_21 + W_pp = F_c @ W_12 + F_s @ W_22 + + V_ss = F_c @ V_11 + F_s @ V_21 + V_sp = F_c @ V_12 + F_s @ V_22 + V_ps = F_c @ V_21 - F_s @ V_11 + V_pp = F_c @ V_22 - F_s @ V_12 + + big_I = ee.eye(2 * (len(I))) + big_X = ee.block([[X_1, O], [O, X_2]]) + big_W = ee.block([[W_ss, W_sp], [W_ps, W_pp]]) + big_V = ee.block([[V_ss, V_sp], [V_ps, V_pp]]) + + big_W_i = ee.linalg.inv(big_W) + big_V_i = ee.linalg.inv(big_V) + + big_A = 0.5 * (big_W_i @ big_F + big_V_i @ big_G) + big_B = 0.5 * (big_W_i @ big_F - big_V_i @ big_G) + + big_A_i = ee.linalg.inv(big_A) + + big_F = big_W @ (big_I + big_X @ big_B @ big_A_i @ big_X) + big_G = big_V @ (big_I - big_X @ big_B @ big_A_i @ big_X) + + big_T = big_T @ big_A_i @ big_X + + return big_X, big_F, big_G, big_T, big_A_i, big_B, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 + + +def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z): + I = ee.eye(ff ** 2) + O = ee.zeros((ff ** 2, ff ** 2)) + + big_F_11 = big_F[:center, :center] + big_F_12 = big_F[:center, center:] + big_F_21 = big_F[center:, :center] + big_F_22 = big_F[center:, center:] + + big_G_11 = big_G[:center, :center] + big_G_12 = big_G[:center, center:] + big_G_21 = big_G[center:, :center] + big_G_22 = big_G[center:, center:] + + # Final Equation in form of AX=B + final_A = ee.block( + [ + [I, O, -big_F_11, -big_F_12], + [O, -1j * Z_I, -big_F_21, -big_F_22], + [-1j * Y_I, O, -big_G_11, -big_G_12], + [O, I, -big_G_21, -big_G_22], + ] + ) + + final_B = ee.block( + [ + [-ee.sin(psi) * delta_i0], + [-ee.cos(psi) * ee.cos(theta) * delta_i0], + [-1j * ee.sin(psi) * n_I * ee.cos(theta) * delta_i0], + [1j * n_I * ee.cos(psi) * delta_i0] + ] + ) + + final_RT = ee.linalg.inv(final_A) @ final_B + + R_s = final_RT[:ff ** 2, :].flatten() + R_p = final_RT[ff ** 2:2 * ff ** 2, :].flatten() + + big_T1 = final_RT[2 * ff ** 2:, :] + big_T = big_T @ big_T1 + + T_s = big_T[:ff ** 2, :].flatten() + T_p = big_T[ff ** 2:, :].flatten() + + de_ri = R_s * ee.conj(R_s) * ee.real(k_I_z / (k0 * n_I * ee.cos(theta))) \ + + R_p * ee.conj(R_p) * ee.real((k_I_z / n_I ** 2) / (k0 * n_I * ee.cos(theta))) + + de_ti = T_s * ee.conj(T_s) * ee.real(k_II_z / (k0 * n_I * ee.cos(theta))) \ + + T_p * ee.conj(T_p) * ee.real((k_II_z / n_II ** 2) / (k0 * n_I * ee.cos(theta))) + + return de_ri.real, de_ti.real, big_T1 + + +def transfer_1d_conical_1(ff, k0, n_I, n_II, period, fourier_indices, theta, phi, wavelength, perturbation=1E-20 * (1 + 1j)): + I = ee.eye(ff) + O = ee.zeros((ff, ff)) + + kx_vector = k0 * (n_I * ee.sin(theta) * ee.cos(phi) - fourier_indices * (wavelength / period[0])).astype('complex') + ky = k0 * n_I * ee.sin(theta) * ee.sin(phi) + + k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky ** 2) ** 0.5 + k_II_z = (k0 ** 2 * n_II ** 2 - kx_vector ** 2 - ky ** 2) ** 0.5 + + k_I_z = k_I_z.conjugate() + k_II_z = k_II_z.conjugate() + + idx = ee.nonzero(kx_vector == 0)[0] + if len(idx): + # TODO: need imaginary part? + # TODO: make imaginary part sign consistent + kx_vector[idx] = perturbation # TODO: test + print(wavelength, 'varphi divide by 0: adding perturbation') + + varphi = ee.arctan(ky / kx_vector) + + Y_I = ee.diag(k_I_z / k0) + Y_II = ee.diag(k_II_z / k0) + + Z_I = ee.diag(k_I_z / (k0 * n_I ** 2)) + Z_II = ee.diag(k_II_z / (k0 * n_II ** 2)) + + Kx = ee.diag(kx_vector / k0) + + big_F = ee.block([[I, O], [O, 1j * Z_II]]) + big_G = ee.block([[1j * Y_II, O], [O, I]]) + + big_T = ee.eye(2 * ff) + + return Kx, ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T + + +def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_i, oneover_E_conv_i, ff, d, varphi, big_F, big_G, big_T): + + I = ee.eye(ff) + O = ee.zeros((ff, ff)) + + A = Kx ** 2 - E_conv + B = Kx @ E_i @ Kx - I + A_i = ee.linalg.inv(A) + B_i = ee.linalg.inv(B) + + to_decompose_W_1 = ky ** 2 * I + A + to_decompose_W_2 = ky ** 2 * I + B @ oneover_E_conv_i + + eigenvalues_1, W_1 = ee.linalg.eig(to_decompose_W_1) + eigenvalues_2, W_2 = ee.linalg.eig(to_decompose_W_2) + + q_1 = eigenvalues_1 ** 0.5 + q_2 = eigenvalues_2 ** 0.5 + + Q_1 = ee.diag(q_1) + Q_2 = ee.diag(q_2) + + V_11 = A_i @ W_1 @ Q_1 + V_12 = (ky / k0) * A_i @ Kx @ W_2 + V_21 = (ky / k0) * B_i @ Kx @ E_i @ W_1 + V_22 = B_i @ W_2 @ Q_2 + + X_1 = ee.diag(ee.exp(-k0 * q_1 * d)) + X_2 = ee.diag(ee.exp(-k0 * q_2 * d)) + + F_c = ee.diag(ee.cos(varphi)) + F_s = ee.diag(ee.sin(varphi)) + + V_ss = F_c @ V_11 + V_sp = F_c @ V_12 - F_s @ W_2 + W_ss = F_c @ W_1 + F_s @ V_21 + W_sp = F_s @ V_22 + W_ps = F_s @ V_11 + W_pp = F_c @ W_2 + F_s @ V_12 + V_ps = F_c @ V_21 - F_s @ W_1 + V_pp = F_c @ V_22 + + big_I = ee.eye(2 * (len(I))) + big_X = ee.block([[X_1, O], [O, X_2]]) + big_W = ee.block([[V_ss, V_sp], [W_ps, W_pp]]) + big_V = ee.block([[W_ss, W_sp], [V_ps, V_pp]]) + + big_W_i = ee.linalg.inv(big_W) + big_V_i = ee.linalg.inv(big_V) + + big_A = 0.5 * (big_W_i @ big_F + big_V_i @ big_G) + big_B = 0.5 * (big_W_i @ big_F - big_V_i @ big_G) + + big_A_i = ee.linalg.inv(big_A) + + big_F = big_W @ (big_I + big_X @ big_B @ big_A_i @ big_X) + big_G = big_V @ (big_I - big_X @ big_B @ big_A_i @ big_X) + + big_T = big_T @ big_A_i @ big_X + + return big_F, big_G, big_T + + +def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z): + I = ee.eye(ff) + O = ee.zeros((ff, ff)) + + big_F_11 = big_F[:ff, :ff] + big_F_12 = big_F[:ff, ff:] + big_F_21 = big_F[ff:, :ff] + big_F_22 = big_F[ff:, ff:] + + big_G_11 = big_G[:ff, :ff] + big_G_12 = big_G[:ff, ff:] + big_G_21 = big_G[ff:, :ff] + big_G_22 = big_G[ff:, ff:] + + # Final Equation in form of AX=B + final_A = ee.block( + [ + [I, O, -big_F_11, -big_F_12], + [O, -1j * Z_I, -big_F_21, -big_F_22], + [-1j * Y_I, O, -big_G_11, -big_G_12], + [O, I, -big_G_21, -big_G_22], + ] + ) + + final_B = ee.hstack([ + [-ee.sin(psi) * delta_i0], + [-ee.cos(psi) * ee.cos(theta) * delta_i0], + [-1j * ee.sin(psi) * n_I * ee.cos(theta) * delta_i0], + [1j * n_I * ee.cos(psi) * delta_i0] + ]).T + + final_X = ee.linalg.inv(final_A) @ final_B + + R_s = final_X[:ff, :].flatten() + R_p = final_X[ff:2 * ff, :].flatten() + + big_T = big_T @ final_X[2 * ff:, :] + T_s = big_T[:ff, :].flatten() + T_p = big_T[ff:, :].flatten() + + de_ri = R_s * ee.conj(R_s) * ee.real(k_I_z / (k0 * n_I * ee.cos(theta))) \ + + R_p * ee.conj(R_p) * ee.real((k_I_z / n_I ** 2) / (k0 * n_I * ee.cos(theta))) + + de_ti = T_s * ee.conj(T_s) * ee.real(k_II_z / (k0 * n_I * ee.cos(theta))) \ + + T_p * ee.conj(T_p) * ee.real((k_II_z / n_II ** 2) / (k0 * n_I * ee.cos(theta))) + + return de_ri.real, de_ti.real + diff --git a/meent/on_jax/_base.py b/meent/on_jax/_base.py index 6e3f5ae..2ba2a2f 100644 --- a/meent/on_jax/_base.py +++ b/meent/on_jax/_base.py @@ -1,110 +1,39 @@ -import matplotlib.pyplot as plt - -# from .scattering_method import * -# from .transfer_method import * - -from .scattering_method import scattering_1d_1, scattering_1d_2, scattering_1d_3, scattering_2d_1, scattering_2d_wv, scattering_2d_2, scattering_2d_3 -from .transfer_method import transfer_1d_1, transfer_1d_2, transfer_1d_3, transfer_1d_conical_1, transfer_1d_conical_2, transfer_1d_conical_3, transfer_2d_1, transfer_2d_wv, transfer_2d_2, transfer_2d_3 +from copy import copy, deepcopy +from functools import partial +import jax import jax.numpy as jnp +from .scattering_method import scattering_1d_1, scattering_1d_2, scattering_1d_3, scattering_2d_1, scattering_2d_wv, \ + scattering_2d_2, scattering_2d_3 +from .transfer_method import transfer_1d_1, transfer_1d_2, transfer_1d_3, transfer_1d_conical_1, transfer_1d_conical_2, \ + transfer_1d_conical_3, transfer_2d_1, transfer_2d_wv, transfer_2d_2, transfer_2d_3 -class Base: - def __init__(self, grating_type, mode=0): - self.grating_type = grating_type - self.wls = None - self.fourier_order = None - self.spectrum_r = None - self.spectrum_t = None - self.mode = mode - - def init_spectrum_array(self): - if self.grating_type in (0, 1): - self.spectrum_r = jnp.zeros((len(self.wls), 2 * self.fourier_order + 1)) - self.spectrum_t = jnp.zeros((len(self.wls), 2 * self.fourier_order + 1)) - elif self.grating_type == 2: - self.spectrum_r = jnp.zeros((len(self.wls), 2 * self.fourier_order + 1, 2 * self.fourier_order + 1)) - self.spectrum_t = jnp.zeros((len(self.wls), 2 * self.fourier_order + 1, 2 * self.fourier_order + 1)) - else: - raise ValueError +import meent.on_jax.jitted as ee - def save_spectrum_array(self, de_ri, de_ti, i): - de_ri = jnp.array(de_ri) - de_ti = jnp.array(de_ti) - - if not de_ri.shape: - # 1D or may be not; there is a case that reticolo returns single value - c = self.spectrum_r.shape[1] // 2 - self.spectrum_r = self.spectrum_r.at[i, c].set(de_ri) - - elif len(de_ri.shape) == 1 or de_ri.shape[1] == 1: # 1D - de_ri = de_ri.flatten() - c = self.spectrum_r.shape[1] // 2 - l = de_ri.shape[0] // 2 - if len(de_ri) % 2: - idx = jnp.arange(c - l, c + l + 1) - self.spectrum_r = self.spectrum_r.at[i, idx].set(de_ri) - else: - idx = jnp.arange(c - l, c + l) - self.spectrum_r = self.spectrum_r.at[i, idx].set(de_ri) - else: - print('no code') - raise ValueError - - if not de_ti.shape: # 1D - c = self.spectrum_t.shape[1] // 2 - self.spectrum_t = self.spectrum_t.at[i, c].set(de_ti) - - elif len(de_ti.shape) == 1 or de_ti.shape[1] == 1: # 1D - de_ti = de_ti.flatten() - c = self.spectrum_t.shape[1] // 2 - l = de_ti.shape[0] // 2 - if len(de_ti) % 2: - idx = jnp.arange(c - l, c + l + 1) - self.spectrum_t = self.spectrum_t.at[i, idx].set(de_ti) - else: - idx = jnp.arange(c - l, c + l) - self.spectrum_t = self.spectrum_t.at[i, idx].set(de_ti) - else: - print('no code') - raise ValueError - - def plot(self, title=None, marker=None): - if self.grating_type == 0: - plt.plot(self.wls, self.spectrum_r.sum(axis=1), marker=marker) - plt.plot(self.wls, self.spectrum_t.sum(axis=1), marker=marker) - elif self.grating_type == 1: - plt.plot(self.wls, self.spectrum_r.sum(axis=1), marker=marker) - plt.plot(self.wls, self.spectrum_t.sum(axis=1), marker=marker) - elif self.grating_type == 2: - plt.plot(self.wls, self.spectrum_r.sum(axis=(1, 2)), marker=marker) - plt.plot(self.wls, self.spectrum_t.sum(axis=(1, 2)), marker=marker) - else: - raise ValueError - plt.title(title) - plt.show() - - -class _BaseRCWA(Base): +class _BaseRCWA: def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., fourier_order=10, - period=0.7, wls=jnp.linspace(0.5, 2.3, 400), pol=0, - patterns=None, thickness=None, algo='TMM', mode=0): - super().__init__(grating_type) + period=0.7, wavelength=ee.linspace(0.5, 2.3, 400), pol=0, + patterns=None, ucell=None, ucell_materials=None, thickness=None, algo='TMM', perturbation=1E-10, + device='cpu', type_complex=jnp.complex128): + + self.device = device + self.type_complex = type_complex self.grating_type = grating_type # 1D=0, 1D_conical=1, 2D=2 self.n_I = n_I self.n_II = n_II - self.theta = theta * jnp.pi / 180 - self.phi = phi * jnp.pi / 180 - self.psi = psi * jnp.pi / 180 # TODO: integrate psi and pol + self.theta = theta * ee.pi / 180 + self.phi = phi * ee.pi / 180 + self.psi = psi * ee.pi / 180 # TODO: integrate psi and pol self.pol = pol # TE 0, TM 1 if self.pol == 0: # TE - self.psi = 90 * jnp.pi / 180 + self.psi = 90 * ee.pi / 180 elif self.pol == 1: # TM - self.psi = 0 * jnp.pi / 180 + self.psi = 0 * ee.pi / 180 else: print('not implemented yet') raise ValueError @@ -112,31 +41,59 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., four self.fourier_order = fourier_order self.ff = 2 * self.fourier_order + 1 - self.period = period + self.period = deepcopy(period) - self.wls = wls + self.wavelength = wavelength - self.patterns = [[3.8, 1, 0.3]] if patterns is None else patterns - self.thickness = [1120] if thickness is None else thickness + self.patterns = patterns + self.ucell = deepcopy(ucell) + self.ucell_materials = ucell_materials + self.thickness = deepcopy(thickness) self.algo = algo + self.perturbation = perturbation + + self.layer_info_list = [] + self.T1 = None - self.init_spectrum_array() + self.kx_vector = None - def solve_1d(self, wl, E_conv_all, oneover_E_conv_all): + def get_kx_vector(self): + k0 = 2 * jnp.pi / self.wavelength fourier_indices = jnp.arange(-self.fourier_order, self.fourier_order + 1) + if self.grating_type == 0: + kx_vector = k0 * (self.n_I * jnp.sin(self.theta) - fourier_indices * (self.wavelength / self.period[0]) + ).astype(self.type_complex) + else: + kx_vector = k0 * (self.n_I * jnp.sin(self.theta) * jnp.cos(self.phi) - fourier_indices * ( + self.wavelength / self.period[0])).astype(self.type_complex) - delta_i0 = jnp.zeros(self.ff) + idx = jnp.nonzero(kx_vector == 0)[0] + if len(idx): + # TODO: need imaginary part? make imaginary part sign consistent + kx_vector = kx_vector.at[idx].set(self.perturbation) + print('varphi divide by 0: adding perturbation') + + self.kx_vector = kx_vector + + @partial(jax.jit, static_argnums=(0,)) + def solve_1d(self, wl, E_conv_all, o_E_conv_all): + + self.layer_info_list = [] + self.T1 = None + + fourier_indices = ee.arange(-self.fourier_order, self.fourier_order + 1) + + delta_i0 = ee.zeros(self.ff, dtype=self.type_complex) delta_i0 = delta_i0.at[self.fourier_order].set(1) - k0 = 2 * jnp.pi / wl + k0 = 2 * ee.pi / wl - # -------------------------------------------------------------------- if self.algo == 'TMM': - Kx, k_I_z, k_II_z, Kx, f, YZ_I, g, inc_term, T \ - = transfer_1d_1(self.ff, self.pol, k0, self.n_I, self.n_II, - self.theta, delta_i0, self.fourier_order, fourier_indices, wl, self.period) + kx_vector, Kx, k_I_z, k_II_z, Kx, f, YZ_I, g, inc_term, T \ + = transfer_1d_1(self.ff, self.pol, k0, self.n_I, self.n_II, self.kx_vector, + self.theta, delta_i0, self.fourier_order, type_complex=self.type_complex) elif self.algo == 'SMM': Kx, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ = scattering_1d_1(k0, self.n_I, self.n_II, self.theta, self.phi, fourier_indices, self.period, @@ -144,133 +101,174 @@ def solve_1d(self, wl, E_conv_all, oneover_E_conv_all): else: raise ValueError - # -------------------------------------------------------------------- - for E_conv, oneover_E_conv, d in zip(E_conv_all[::-1], oneover_E_conv_all[::-1], self.thickness[::-1]): + # From the last layer + for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): + if self.pol == 0: + E_conv_i = None A = Kx ** 2 - E_conv - eigenvalues, W = jnp.linalg.eig(A) + eigenvalues, W = ee.eig(A, type_complex=self.type_complex) q = eigenvalues ** 0.5 - Q = jnp.diag(q) + Q = ee.diag(q) V = W @ Q elif self.pol == 1: - E_i = jnp.linalg.inv(E_conv) - B = Kx @ E_i @ Kx - jnp.eye(E_conv.shape[0]) - oneover_E_conv_i = jnp.linalg.inv(oneover_E_conv) + E_conv_i = ee.inv(E_conv) + B = Kx @ E_conv_i @ Kx - ee.eye(E_conv.shape[0]).astype(self.type_complex) + o_E_conv_i = ee.inv(o_E_conv) - eigenvalues, W = jnp.linalg.eig(oneover_E_conv_i @ B) + eigenvalues, W = ee.eig(o_E_conv_i @ B, type_complex=self.type_complex) q = eigenvalues ** 0.5 - Q = jnp.diag(q) - V = oneover_E_conv @ W @ Q + Q = ee.diag(q) + V = o_E_conv @ W @ Q else: raise ValueError - # -------------------------------------------------------------------- + if self.algo == 'TMM': - f, g, T = transfer_1d_2(k0, q, d, W, V, f, g, self.fourier_order, T) + X, f, g, T, a_i, b = transfer_1d_2(k0, q, d, W, V, f, g, self.fourier_order, T, + type_complex=self.type_complex) + + layer_info = [E_conv_i, q, W, X, a_i, b, d] + self.layer_info_list.append(layer_info) + elif self.algo == 'SMM': A, B, S_dict, Sg = scattering_1d_2(W, Wg, V, Vg, d, k0, Q, Sg) else: raise ValueError if self.algo == 'TMM': - de_ri, de_ti = transfer_1d_3(g, YZ_I, f, delta_i0, inc_term, T, k_I_z, k0, self.n_I, self.n_II, - self.theta, self.pol, k_II_z) + de_ri, de_ti, T1 = transfer_1d_3(g, YZ_I, f, delta_i0, inc_term, T, k_I_z, k0, self.n_I, self.n_II, + self.theta, self.pol, k_II_z) + self.T1 = T1 + elif self.algo == 'SMM': de_ri, de_ti = scattering_1d_3(Wt, Wg, Vt, Vg, Sg, self.ff, Wr, self.fourier_order, Kzr, Kzt, self.n_I, self.n_II, self.theta, self.pol) else: raise ValueError - return de_ri, de_ti + return de_ri, de_ti, self.layer_info_list, self.T1 # TODO: scattering method - def solve_1d_conical(self, wl, e_conv_all, o_e_conv_all): + @partial(jax.jit, static_argnums=(0,)) + def solve_1d_conical(self, wl, E_conv_all, o_E_conv_all): - fourier_indices = jnp.arange(-self.fourier_order, self.fourier_order + 1) + self.layer_info_list = [] + self.T1 = None + + # fourier_indices = ee.arange(-self.fourier_order, self.fourier_order + 1) - delta_i0 = jnp.zeros(self.ff) + delta_i0 = ee.zeros(self.ff, dtype=self.type_complex) delta_i0 = delta_i0.at[self.fourier_order].set(1) - k0 = 2 * jnp.pi / wl + k0 = 2 * ee.pi / wl if self.algo == 'TMM': Kx, ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ - = transfer_1d_conical_1(self.ff, k0, self.n_I, self.n_II, self.period, fourier_indices, self.theta, self.phi, wl) + = transfer_1d_conical_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.theta, self.phi, + type_complex=self.type_complex) elif self.algo == 'SMM': print('SMM for 1D conical is not implemented') - return jnp.nan, jnp.nan + return ee.nan, ee.nan else: raise ValueError - for e_conv, o_e_conv, d in zip(e_conv_all[::-1], o_e_conv_all[::-1], self.thickness[::-1]): - e_conv_i = jnp.linalg.inv(e_conv) - o_e_conv_i = jnp.linalg.inv(o_e_conv) + for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): + E_conv_i = ee.inv(E_conv) + o_E_conv_i = ee.inv(o_E_conv) if self.algo == 'TMM': - big_F, big_G, big_T = transfer_1d_conical_2(k0, Kx, ky, e_conv, e_conv_i, o_e_conv_i, self.ff, d, - varphi, big_F, big_G, big_T) + # big_F, big_G, big_T\ + big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2 \ + = transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, self.ff, d, + varphi, big_F, big_G, big_T, + type_complex=self.type_complex) + + layer_info = [E_conv_i, q_1, q_2, W_1, W_2, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d] + self.layer_info_list.append(layer_info) + elif self.algo == 'SMM': raise ValueError else: raise ValueError if self.algo == 'TMM': - de_ri, de_ti = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, - delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z) + de_ri, de_ti, big_T1 = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, + type_complex=self.type_complex) + self.T1 = big_T1 + elif self.algo == 'SMM': raise ValueError else: raise ValueError - return de_ri, de_ti + return de_ri, de_ti, self.layer_info_list, self.T1 - def solve_2d(self, wl, E_conv_all, oneover_E_conv_all): + @partial(jax.jit, static_argnums=(0,)) + def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): - fourier_indices = jnp.arange(-self.fourier_order, self.fourier_order + 1) + self.layer_info_list = [] + self.T1 = None - delta_i0 = jnp.zeros((self.ff ** 2, 1)) + fourier_indices = ee.arange(-self.fourier_order, self.fourier_order + 1) + delta_i0 = ee.zeros((self.ff ** 2, 1), dtype=self.type_complex) delta_i0 = delta_i0.at[self.ff ** 2 // 2, 0].set(1) - I = jnp.eye(self.ff ** 2) - O = jnp.zeros((self.ff ** 2, self.ff ** 2)) + I = ee.eye(self.ff ** 2).astype(self.type_complex) + O = ee.zeros((self.ff ** 2, self.ff ** 2), dtype=self.type_complex) center = self.ff ** 2 - k0 = 2 * jnp.pi / wl + k0 = 2 * ee.pi / wavelength if self.algo == 'TMM': - Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ - = transfer_2d_1(self.ff, k0, self.n_I, self.n_II, self.period, fourier_indices, self.theta, self.phi, wl) + kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ + = transfer_2d_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.period, fourier_indices, + self.theta, self.phi, wavelength, type_complex=self.type_complex) elif self.algo == 'SMM': Kx, Ky, kz_inc, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ = scattering_2d_1(self.n_I, self.n_II, self.theta, self.phi, k0, self.period, self.fourier_order) else: raise ValueError - for E_conv, oneover_E_conv, d in zip(E_conv_all[::-1], oneover_E_conv_all[::-1], self.thickness[::-1]): - E_i = jnp.linalg.inv(E_conv) - oneover_E_conv_i = jnp.linalg.inv(oneover_E_conv) + # From the last layer + for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): + E_conv_i = ee.inv(E_conv) + o_E_conv_i = ee.inv(o_E_conv) if self.algo == 'TMM': # TODO: MERGE W V part - W, V, LAMBDA, Lambda = transfer_2d_wv(self.ff, Kx, E_i, Ky, oneover_E_conv_i, E_conv, center) - big_F, big_G, big_T = transfer_2d_2(k0, d, W, V, center, Lambda, varphi, I, O, big_F, big_G, big_T) + W, V, q = transfer_2d_wv(self.ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=self.type_complex) + + big_X, big_F, big_G, big_T, big_A_i, big_B, \ + W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 \ + = transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T, + type_complex=self.type_complex) + + layer_info = [E_conv_i, q, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d] + self.layer_info_list.append(layer_info) + elif self.algo == 'SMM': - W, V, LAMBDA = scattering_2d_wv(self.ff, Kx, Ky, E_conv, oneover_E_conv, oneover_E_conv_i, E_i) + W, V, LAMBDA = scattering_2d_wv(self.ff, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) A, B, Sl_dict, Sg_matrix, Sg = scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, LAMBDA) else: raise ValueError if self.algo == 'TMM': - de_ri, de_ti = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, - delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z) + de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, + type_complex=self.type_complex) + self.T1 = big_T1 + elif self.algo == 'SMM': de_ri, de_ti = scattering_2d_3(Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, self.n_I, self.pol, self.theta, self.phi, self.fourier_order, self.ff) else: raise ValueError - return de_ri.reshape((self.ff, self.ff)).real, de_ti.reshape((self.ff, self.ff)).real + return de_ri.reshape((self.ff, self.ff)).real, de_ti.reshape( + (self.ff, self.ff)).real, self.layer_info_list, self.T1 diff --git a/meent/on_jax/convolution_matrix.py b/meent/on_jax/convolution_matrix.py index 7bad0d2..e457184 100644 --- a/meent/on_jax/convolution_matrix.py +++ b/meent/on_jax/convolution_matrix.py @@ -1,106 +1,241 @@ -import copy +import time +from functools import partial + +import numpy as np +import jax import jax.numpy as jnp -from scipy.io import loadmat +import meent.on_jax.jitted as ee + +from os import walk +from scipy.io import loadmat from pathlib import Path -# from jax.scipy.linalg import circulant # hope this is supported -def put_n_ridge_in_pattern(pattern_all, wl): +# @jax.jit +def put_permittivity_in_ucell(ucell, mat_list, mat_table, wl, type_complex=jnp.complex128): + + res = ee.zeros(ucell.shape, dtype=type_complex) + + for z in range(ucell.shape[0]): + for y in range(ucell.shape[1]): + for x in range(ucell.shape[2]): + material = mat_list[ucell[z, y, x]] + assign_index = (z, y, x) + + if type(material) == str: + assign_value = find_nk_index(material, mat_table, wl) ** 2 + else: + assign_value = material ** 2 + res = ee.assign(res, assign_index, assign_value) + + return res + + +def put_permittivity_in_ucell_object(ucell_size, mat_list, obj_list, mat_table, wl, + type_complex=jnp.complex128): + # TODO: under development + res = ee.zeros(ucell_size, dtype=type_complex) + + for material, obj_index in zip(mat_list, obj_list): + if type(material) == str: + res[obj_index] = find_nk_index(material, mat_table, wl) ** 2 + else: + res[obj_index] = material ** 2 + + return res + + +def find_nk_index(material, mat_table, wl): + if material[-6:] == '__real': + material = material[:-6] + n_only = True + else: + n_only = False + + mat_data = mat_table[material.upper()] - pattern_all = copy.deepcopy(pattern_all) + n_index = ee.interp(wl, mat_data[:, 0], mat_data[:, 1]) - for i, (n_ridge, n_groove, pattern) in enumerate(pattern_all): + if n_only: + return n_index - if type(n_ridge) == str: - material = n_ridge - n_ridge = find_n_index(material, wl) - pattern_all[i][0] = n_ridge - return pattern_all + k_index = ee.interp(wl, mat_data[:, 0], mat_data[:, 2]) + nk = n_index + 1j * k_index + return nk -def find_n_index(material, wl): - # TODO: where put this to? - nk_path = str(Path(__file__).resolve().parent.parent.parent) + '/nk_data/p_Si.mat' # TODO: organize - mat_si = loadmat(nk_path) +def read_material_table(nk_path=None): mat_table = {} - mat_table['SILICON'] = mat_si - mat_property = mat_table[material.upper()] - n_index = jnp.interp(wl, mat_property['WL'].flatten(), mat_property['n'].flatten()) + if nk_path is None: + nk_path = str(Path(__file__).resolve().parent.parent) + '/nk_data' + + full_path_list, name_list, _ = [], [], [] + for (dirpath, dirnames, filenames) in walk(nk_path): + full_path_list.extend([f'{dirpath}/{filename}' for filename in filenames]) + name_list.extend(filenames) + for path, name in zip(full_path_list, name_list): + if name[-3:] == 'txt': + data = ee.loadtxt(path, skiprows=1) + mat_table[name[:-4].upper()] = data + + elif name[-3:] == 'mat': + data = loadmat(path) + data = ee.array([data['WL'], data['n'], data['k']])[:, :, 0].T + mat_table[name[:-4].upper()] = data + return mat_table + + +# can't jit +def cell_compression(cell, type_complex=jnp.complex128): + + if type_complex == jnp.complex128: + type_float = jnp.float64 + else: + type_float = jnp.float32 + + # find discontinuities in x + step_y, step_x = 1. / ee.array(cell.shape, dtype=type_float) + x = [] + y = [] + cell_x = [] + cell_xy = [] + + cell_next = ee.roll(cell, -1, axis=1) + + for col in range(cell.shape[1]): + if not (cell[:, col] == cell_next[:, col]).all() or (col == cell.shape[1] - 1): + + x.append(step_x * (col + 1)) + cell_x.append(cell[:, col]) + + cell_x = ee.array(cell_x).T + cell_x_next = ee.roll(cell_x, -1, axis=0) + + for row in range(cell_x.shape[0]): + if not (cell_x[row, :] == cell_x_next[row, :]).all() or (row == cell_x.shape[0] - 1): + y.append(step_y * (row + 1)) + cell_xy.append(cell_x[row, :]) + + x = ee.array(x).reshape((-1, 1)) + y = ee.array(y).reshape((-1, 1)) + cell_comp = ee.array(cell_xy) + + return cell_comp, x, y + + +# @partial(jax.jit, static_argnums=(1,2 )) +def fft_piecewise_constant(cell, fourier_order, type_complex=jnp.complex128): + + if cell.shape[0] == 1: + fourier_order = [0, fourier_order] + else: + fourier_order = [fourier_order, fourier_order] + cell, x, y = cell_compression(cell, type_complex=type_complex) + + # X axis + cell_next_x = ee.roll(cell, -1, axis=1) + cell_diff_x = cell_next_x - cell + + modes = ee.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1) + + f_coeffs_x = cell_diff_x @ ee.exp(-1j * 2 * ee.pi * x @ modes[None, :]).astype(type_complex) + c = f_coeffs_x.shape[1] // 2 + + x_next = ee.vstack((ee.roll(x, -1, axis=0)[:-1], 1)) - x + + assign_index = (ee.arange(len(f_coeffs_x)), ee.array([c])) + assign_value = (cell @ ee.vstack((x[0], x_next[:-1]))).flatten().astype(type_complex) + + f_coeffs_x = ee.assign(f_coeffs_x, assign_index, assign_value) + # f_coeffs_x = f_coeffs_x.at[assign_index].set(assign_value) + + mask_int = ee.hstack([ee.arange(c), ee.arange(c+1, f_coeffs_x.shape[1])]) + + assign_index = mask_int + + assign_value = f_coeffs_x[:, mask_int] / (1j * 2 * ee.pi * modes[mask_int]) - return n_index + f_coeffs_x = ee.assign(f_coeffs_x, assign_index, assign_value, row_all=True) + # f_coeffs_x = f_coeffs_x.at[:, assign_index].set(assign_value) + # Y axis + f_coeffs_x_next_y = ee.roll(f_coeffs_x, -1, axis=0) + f_coeffs_x_diff_y = f_coeffs_x_next_y - f_coeffs_x -def fill_factor_to_ucell(patterns_fill_factor, wl, grating_type): + modes = ee.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1) - pattern_fill_factor = put_n_ridge_in_pattern(patterns_fill_factor, wl) - ucell = draw_fill_factor(pattern_fill_factor, grating_type) + f_coeffs_xy = f_coeffs_x_diff_y.T @ ee.exp(-1j * 2 * ee.pi * y @ modes[None, :]).astype(type_complex) + c = f_coeffs_xy.shape[1] // 2 - return ucell + y_next = ee.vstack((ee.roll(y, -1, axis=0)[:-1], 1)) - y + assign_index = [c] + assign_value = f_coeffs_x.T @ ee.vstack((y[0], y_next[:-1])).astype(type_complex) + f_coeffs_xy = ee.assign(f_coeffs_xy, assign_index, assign_value, row_all=True) + # f_coeffs_xy = f_coeffs_xy.at[:, assign_index].set(assign_value) + + + if c: + mask_int = ee.hstack([ee.arange(c), ee.arange(c + 1, f_coeffs_x.shape[1])]) + + assign_index = mask_int + assign_value = f_coeffs_xy[:, mask_int] / (1j * 2 * ee.pi * modes[mask_int]) + + f_coeffs_xy = ee.assign(f_coeffs_xy, assign_index, assign_value, row_all=True) + # f_coeffs_xy = f_coeffs_xy.at[:, assign_index].set(assign_value) + + return f_coeffs_xy.T + + +# @partial(jax.jit, static_argnums=(1, )) +def to_conv_mat(pmt, fourier_order, type_complex=jnp.complex128): -def to_conv_mat(pmt, fourier_order): - # FFT scaling: https://kr.mathworks.com/matlabcentral/answers/15770-scaling-the-fft-and-the-ifft?s_tid=srchtitle if len(pmt.shape) == 2: print('shape is 2') raise ValueError ff = 2 * fourier_order + 1 - # if len(pmt.shape)==2 or pmt.shape[1] == 1: # 1D - if pmt.shape[1] == 1: # 1D # TODO: confirm this handles all cases - res = jnp.zeros((pmt.shape[0], 2 * fourier_order + 1, 2 * fourier_order + 1)).astype('complex') + if pmt.shape[1] == 1: # 1D - # extend array for FFT - minimum_pattern_size = (4 * fourier_order + 1) * pmt.shape[2] - # TODO: what is theoretical minimum? - # TODO: can be a scalability issue - if pmt.shape[2] < minimum_pattern_size: - n = minimum_pattern_size // pmt.shape[2] - pmt = jnp.repeat(pmt, n + 1, axis=2) + res = ee.zeros((pmt.shape[0], ff, ff)).astype(type_complex) + + for i, layer in enumerate(pmt): + f_coeffs = fft_piecewise_constant(layer, fourier_order, type_complex=type_complex) - for i, pmtvy in enumerate(pmt): - pmtvy_fft = jnp.fft.fftshift(jnp.fft.fftn(pmtvy / pmtvy.size)) - center = pmtvy_fft.shape[1] // 2 + center = f_coeffs.shape[1] // 2 - conv_idx = jnp.arange(ff - 1, -ff, -1) + conv_idx = ee.arange(-ff + 1, ff, 1) conv_idx = circulant(conv_idx) - res = res.at[i].set(pmtvy_fft[1, center + conv_idx]) + + e_conv = f_coeffs[0, center + conv_idx] + # res = res.at[i].set(e_conv) + res = ee.assign(res, i, e_conv) else: # 2D - # attention on the order of axis. - # Here X Y Z. Cell Drawing in CAD is Y X Z. Here is Z Y X + # attention on the order of axis (Z Y X) - # TODO: separate fourier order - res = jnp.zeros((pmt.shape[0], ff ** 2, ff ** 2)).astype('complex') + res = ee.zeros((pmt.shape[0], ff ** 2, ff ** 2)).astype(type_complex) - # extend array - # TODO: run test - minimum_pattern_size = ff ** 2 - # TODO: what is theoretical minimum? - # TODO: can be a scalability issue + for i, layer in enumerate(pmt): - if pmt.shape[1] < minimum_pattern_size: - n = minimum_pattern_size // pmt.shape[1] - pmt = jnp.repeat(pmt, n + 1, axis=1) - if pmt.shape[2] < minimum_pattern_size: - n = minimum_pattern_size // pmt.shape[2] - pmt = jnp.repeat(pmt, n + 1, axis=2) + f_coeffs = fft_piecewise_constant(layer, fourier_order, type_complex=type_complex) - for i, layer in enumerate(pmt): - pmtvy_fft = jnp.fft.fftshift(jnp.fft.fft2(layer / layer.size)) + center = ee.array(f_coeffs.shape) // 2 - center = jnp.array(pmtvy_fft.shape) // 2 + conv_idx = ee.arange(-ff + 1, ff, 1) - conv_idx = jnp.arange(ff - 1, -ff, -1) conv_idx = circulant(conv_idx) - conv_i = jnp.repeat(conv_idx, ff, axis=1) - conv_i = jnp.repeat(conv_i, ff, axis=0) - conv_j = jnp.tile(conv_idx, (ff, ff)) + conv_i = ee.repeat(conv_idx, ff, 1) + conv_i = ee.repeat(conv_i, ff, axis=0) + conv_j = ee.tile(conv_idx, (ff, ff)) - res = res.at[i].set(pmtvy_fft[center[0] + conv_i, center[1] + conv_j]) + # res = res.at[i].set(f_coeffs[center[0] + conv_i, center[1] + conv_j]) + assign_value = f_coeffs[center[0] + conv_i, center[1] + conv_j] + res = ee.assign(res, i, assign_value) # import matplotlib.pyplot as plt # @@ -108,56 +243,21 @@ def to_conv_mat(pmt, fourier_order): # plt.imshow(abs(res[0]), cmap='jet') # plt.colorbar() # plt.show() - # - return res - - -def draw_fill_factor(patterns_fill_factor, grating_type, resolution=1000): - - # res in Z X Y - if grating_type == 2: - res = jnp.zeros((len(patterns_fill_factor), resolution, resolution), dtype='complex') - else: - res = jnp.zeros((len(patterns_fill_factor), 1, resolution), dtype='complex') - - if grating_type in (0, 1): # TODO: handle this by len(fill_factor) - # fill_factor is not exactly implemented. - for i, (n_ridge, n_groove, fill_factor) in enumerate(patterns_fill_factor): - permittivity = jnp.ones((1, resolution), dtype='complex') - cut = int(resolution * fill_factor) - - cut_idx = jnp.arange(cut) - permittivity *= n_groove ** 2 - - permittivity = permittivity.at[0, cut_idx].set(n_ridge ** 2) - res = res.at[i].set(permittivity) - - else: # 2D - for i, (n_ridge, n_groove, fill_factor) in enumerate(patterns_fill_factor): - fill_factor = jnp.array(fill_factor) - permittivity = jnp.ones((resolution, resolution), dtype='complex') - cut = (resolution * fill_factor) # TODO: need parenthesis? - cut_idx_row = jnp.arange(int(cut[1])) - cut_idx_column = jnp.arange(int(cut[0])) - - permittivity *= n_groove ** 2 - - rows, cols = jnp.meshgrid(cut_idx_row, cut_idx_column, indexing='ij') - - permittivity = permittivity.at[rows, cols].set(n_ridge ** 2) - res = res.at[i].set(permittivity) - + # print('conv time: ', time.time() - t0) return res def circulant(c): - center = jnp.array(c.shape) // 2 - circ = jnp.zeros((center[0] + 1, center[0] + 1), dtype='int32') + center = c.shape[0] // 2 + # circ = ee.zeros((center[0] + 1, center[0] + 1), dtype='int32') + circ = ee.zeros((center + 1, center + 1), int) - for r in range(center[0]+1): - idx = jnp.arange(r, r - center - 1, -1) + for r in range(center+1): + idx = ee.arange(r, r - center - 1, -1) - circ = circ.at[r].set(c[center + idx]) + # circ = circ.at[r].set(c[center + idx]) + assign_value = c[center + idx] + circ = ee.assign(circ, r, assign_value) return circ diff --git a/meent/on_jax/field_distribution.py b/meent/on_jax/field_distribution.py new file mode 100644 index 0000000..9a326d5 --- /dev/null +++ b/meent/on_jax/field_distribution.py @@ -0,0 +1,332 @@ +import time +from functools import partial + +import jax +import jax.numpy as jnp +import numpy as np +import matplotlib.pyplot as plt +import meent.on_jax.jitted as ee + + +def field_distribution(grating_type, *args, **kwargs): + if grating_type == 0: + res = field_dist_1d(*args, **kwargs) + elif grating_type == 1: + res = field_dist_1d_conical(*args, **kwargs) + else: + res = field_dist_2d(*args, **kwargs) + return res + + +def field_dist_1d(wavelength, kx_vector, n_I, theta, fourier_order, T1, layer_info_list, period, pol, resolution=(100, 1, 100), + type_complex=jnp.complex128): + + k0 = 2 * ee.pi / wavelength + # fourier_indices = ee.arange(-fourier_order, fourier_order + 1) + # kx_vector = k0 * (n_I * ee.sin(theta) - fourier_indices * (wavelength / period[0])).astype(type_complex) + + Kx = ee.diag(kx_vector / k0) + + resolution_z, resolution_y, resolution_x = resolution + + # Here use numpy array due to slow assignment speed in JAX + field_cell = np.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 3), dtype=type_complex) + + T_layer = T1 + + # From the first layer + for idx_layer, (E_conv_i, q, W, X, a_i, b, d) in enumerate(layer_info_list[::-1]): + + c1 = T_layer[:, None] + c2 = b @ a_i @ X @ T_layer[:, None] + + Q = ee.diag(q) + + if pol == 0: + V = W @ Q + EKx = None + + else: + V = E_conv_i @ W @ Q + EKx = E_conv_i @ Kx + + for k in range(resolution_z): + z = k / resolution_z * d + + A, B, C = z_loop_1d(pol, k0, Kx, W, V, Q, c1, c2, d, z, EKx) + for j in range(resolution_y): + for i in range(resolution_x): + res = x_loop_1d(pol, resolution_x, period, i, A, B, C, kx_vector) + field_cell[resolution_z * idx_layer + k, j, i] = res + + T_layer = a_i @ X @ T_layer + + return field_cell + + +def field_dist_1d_conical(wavelength, kx_vector, n_I, theta, phi, fourier_order, T1, layer_info_list, period, + resolution=(100, 100, 100), type_complex=jnp.complex128): + + k0 = 2 * ee.pi / wavelength + # fourier_indices = ee.arange(-fourier_order, fourier_order + 1) + + # kx_vector = k0 * (n_I * ee.sin(theta) * ee.cos(phi) - fourier_indices * ( + # wavelength / period[0])).astype(type_complex) + ky = k0 * n_I * ee.sin(theta) * ee.sin(phi) + + Kx = ee.diag(kx_vector / k0) + + resolution_z, resolution_y, resolution_x = resolution + + # Here use numpy array due to slow assignment speed in JAX + field_cell = np.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 6), dtype=type_complex) + + T_layer = T1 + + big_I = ee.eye((len(T1))).astype(type_complex) + + # From the first layer + for idx_layer, [E_conv_i, q_1, q_2, W_1, W_2, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d] \ + in enumerate(layer_info_list[::-1]): + + c = ee.block([[big_I], [big_B @ big_A_i @ big_X]]) @ T_layer + + for k in range(resolution_z): + Sx, Sy, Ux, Uy, Sz, Uz = z_loop_1d_conical(k, c, k0, Kx, ky, resolution_z, E_conv_i, q_1, q_2, W_1, W_2, V_11, V_12, V_21, V_22, d) + + for j in range(resolution_y): + for i in range(resolution_x): + val = x_loop_1d_conical(period, resolution_x, kx_vector, Sx, Sy, Sz, Ux, Uy, Uz, i) + field_cell[resolution_z * idx_layer + k, j, i] = val + T_layer = big_A_i @ big_X @ T_layer + + return field_cell + + +def field_dist_2d(wavelength, kx_vector, n_I, theta, phi, fourier_order, T1, layer_info_list, period, resolution=(10, 10, 10), + type_complex=jnp.complex128): + + k0 = 2 * ee.pi / wavelength + fourier_indices = ee.arange(-fourier_order, fourier_order + 1) + ff = 2 * fourier_order + 1 + + # kx_vector = k0 * (n_I * ee.sin(theta) * ee.cos(phi) - fourier_indices * ( + # wavelength / period[0])).astype(type_complex) + ky_vector = k0 * (n_I * ee.sin(theta) * ee.sin(phi) - fourier_indices * ( + wavelength / period[1])).astype(type_complex) + + Kx = ee.diag(ee.tile(kx_vector, ff).flatten()) / k0 + Ky = ee.diag(ee.tile(ky_vector.reshape((-1, 1)), ff).flatten()) / k0 + + resolution_z, resolution_y, resolution_x = resolution + + # Here use numpy array due to slow assignment speed in JAX + field_cell = np.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 6), dtype=type_complex) + + T_layer = T1 + + big_I = ee.eye((len(T1))).astype(type_complex) + + # From the first layer + for idx_layer, (E_conv_i, q, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d)\ + in enumerate(layer_info_list[::-1]): + + c = ee.block([[big_I], [big_B @ big_A_i @ big_X]]) @ T_layer + + for k in range(resolution_z): + Sx, Sy, Ux, Uy, Sz, Uz = z_loop_2d(k, c, k0, Kx, Ky, resolution_z, E_conv_i, q, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22, d) + for j in range(resolution_y): + y = j * period[1] / resolution_y + for i in range(resolution_x): + val = x_loop_2d(period, resolution_x, kx_vector, ky_vector, Sx, Sy, Sz, Ux, Uy, Uz, y, i) + field_cell[resolution_z * idx_layer + k, j, i] = val + T_layer = big_A_i @ big_X @ T_layer + + return field_cell + + +@partial(jax.jit, static_argnums=(0,)) +def z_loop_1d(pol, k0, Kx, W, V, Q, c1, c2, d, z, EKx): + + if pol == 0: # TE + Sy = W @ (expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + Ux = V @ (-expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + C = (-1j) * Kx @ Sy + + return Sy, Ux, C + + else: # TM + Uy = W @ (expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + Sx = V @ (-expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + + C = (-1j) * EKx @ Uy # there is a better option for convergence + + return Uy, Sx, C + + +@partial(jax.jit, static_argnums=(0,)) +def x_loop_1d(pol, resolution_x, period, i, A, B, C, kx_vector): + + if pol == 0: # TE + Sy, Ux = A, B + x = i * period[0] / resolution_x + + Ey = Sy.T @ ee.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Hx = -1j * Ux.T @ ee.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Hz = C.T @ ee.exp(-1j * kx_vector.reshape((-1, 1)) * x) + + # field_cell = field_cell.at[resolution_z * idx_layer + k, j, i].set([Ey[0, 0], Hx[0, 0], Hz[0, 0]]) + res = [Ey[0, 0], Hx[0, 0], Hz[0, 0]] + + else: # TM + Uy, Sx = A, B + x = i * period[0] / resolution_x + + Hy = Uy.T @ ee.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Ex = 1j * Sx.T @ ee.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Ez = C.T @ ee.exp(-1j * kx_vector.reshape((-1, 1)) * x) + + res = [Hy[0, 0], Ex[0, 0], Ez[0, 0]] + + return res + + +@jax.jit +def z_loop_1d_conical(k, c, k0, Kx, ky, resolution_z, E_conv_i, q_1, q_2, W_1, W_2, V_11, V_12, V_21, V_22, d): + + z = k / resolution_z * d + + ff = len(c) // 4 + + c1_plus = c[0 * ff:1 * ff] + c2_plus = c[1 * ff:2 * ff] + c1_minus = c[2 * ff:3 * ff] + c2_minus = c[3 * ff:4 * ff] + + big_Q1 = ee.diag(q_1) + big_Q2 = ee.diag(q_2) + + Sx = W_2 @ (expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z - d)) @ c2_minus) + + Sy = V_11 @ (expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z - d)) @ c1_minus) \ + + V_12 @ (expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z - d)) @ c2_minus) + + Ux = W_1 @ (-expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z - d)) @ c1_minus) + + Uy = V_21 @ (-expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z - d)) @ c1_minus) \ + + V_22 @ (-expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z - d)) @ c2_minus) + + Sz = -1j * E_conv_i @ (Kx @ Uy - ky * Ux) + + Uz = -1j * (Kx @ Sy - ky * Sx) + + return Sx, Sy, Ux, Uy, Sz, Uz + + +@jax.jit +def x_loop_1d_conical(period, resolution_x, kx_vector, Sx, Sy, Sz, Ux, Uy, Uz, i): + + x = i * period[0] / resolution_x + + exp_K = ee.exp(-1j * kx_vector.reshape((-1, 1)) * x) + # exp_K = exp_K.flatten() + + Ex = Sx.T @ exp_K + Ey = Sy.T @ exp_K + Ez = Sz.T @ exp_K + + Hx = -1j * Ux.T @ exp_K + Hy = -1j * Uy.T @ exp_K + Hz = -1j * Uz.T @ exp_K + + res = [Ex[0, 0], Ey[0, 0], Ez[0, 0], Hx[0, 0], Hy[0, 0], Hz[0, 0]] + return res + + +@jax.jit +def z_loop_2d(k, c, k0, Kx, Ky, resolution_z, E_conv_i, q, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22, d): + + z = k / resolution_z * d + + ff = len(c) // 4 + + c1_plus = c[0 * ff:1 * ff] + c2_plus = c[1 * ff:2 * ff] + c1_minus = c[2 * ff:3 * ff] + c2_minus = c[3 * ff:4 * ff] + + q1 = q[:len(q) // 2] + q2 = q[len(q) // 2:] + big_Q1 = ee.diag(q1) + big_Q2 = ee.diag(q2) + + Sx = W_11 @ (expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z - d)) @ c1_minus) \ + + W_12 @ (expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z - d)) @ c2_minus) + + Sy = W_21 @ (expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z - d)) @ c1_minus) \ + + W_22 @ (expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z - d)) @ c2_minus) + + Ux = V_11 @ (-expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z - d)) @ c1_minus) \ + + V_12 @ (-expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z - d)) @ c2_minus) + + Uy = V_21 @ (-expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z - d)) @ c1_minus) \ + + V_22 @ (-expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z - d)) @ c2_minus) + + Sz = -1j * E_conv_i @ (Kx @ Uy - Ky @ Ux) + + Uz = -1j * (Kx @ Sy - Ky @ Sx) + + return Sx, Sy, Ux, Uy, Sz, Uz + + +@jax.jit +def x_loop_2d(period, resolution_x, kx_vector, ky_vector, Sx, Sy, Sz, Ux, Uy, Uz, y, i): + + x = i * period[0] / resolution_x + + exp_K = ee.exp(-1j * kx_vector.reshape((1, -1)) * x) * ee.exp(-1j * ky_vector.reshape((-1, 1)) * y) + exp_K = exp_K.flatten() + + Ex = Sx.T @ exp_K + Ey = Sy.T @ exp_K + Ez = Sz.T @ exp_K + + Hx = -1j * Ux.T @ exp_K + Hy = -1j * Uy.T @ exp_K + Hz = -1j * Uz.T @ exp_K + + res = [Ex[0], Ey[0], Ez[0], Hx[0], Hy[0], Hz[0]] + + return res + + +def field_plot(field_cell, pol=0, plot_indices=(1, 1, 1, 1, 1, 1), y_slice=0, z_slice=-1, zx=True, yx=True): + + if field_cell.shape[-1] == 6: # 2D grating + title = ['2D Ex', '2D Ey', '2D Ez', '2D Hx', '2D Hy', '2D Hz', ] + else: # 1D grating + if pol == 0: # TE + title = ['1D Ey', '1D Hx', '1D Hz', ] + else: # TM + title = ['1D Hy', '1D Ex', '1D Ez', ] + + if zx: + for idx in range(len(title)): + if plot_indices[idx]: + plt.imshow((abs(field_cell[:, y_slice, :, idx]) ** 2), cmap='jet', aspect='auto') + # plt.clim(0, 2) # identical to caxis([-4,4]) in MATLAB + plt.colorbar() + plt.title(title[idx]) + plt.show() + if yx: + for idx in range(len(title)): + if plot_indices[idx]: + plt.imshow((abs(field_cell[z_slice, :, :, idx]) ** 2), cmap='jet', aspect='auto') + # plt.clim(0, 3.5) # identical to caxis([-4,4]) in MATLAB + plt.colorbar() + plt.title(title[idx]) + plt.show() + + +def expm(x): + return ee.diag(ee.exp(ee.diag(x))) diff --git a/meent/on_jax/jitted.py b/meent/on_jax/jitted.py new file mode 100644 index 0000000..cff0c03 --- /dev/null +++ b/meent/on_jax/jitted.py @@ -0,0 +1,90 @@ +import numpy as np + +import jax +import jax.numpy as jnp +from functools import partial + +from jax.experimental import host_callback + +loadtxt = np.loadtxt + + +backend = 'jax' + +pi = jnp.pi + +diag = jax.jit(jnp.diag) + + +inv = jax.jit(jnp.linalg.inv) + +nan = jnp.nan +interp = jax.jit(jnp.interp) + +exp = jax.jit(jnp.exp) +vstack = jax.jit(jnp.vstack) + +sin = jax.jit(jnp.sin) +cos = jax.jit(jnp.cos) +block = jax.jit(jnp.block) + +real = jax.jit(jnp.real) +imag = jax.jit(jnp.imag) +conj = jax.jit(jnp.conj) + + +arctan = jax.jit(jnp.arctan) + +hstack = jax.jit(jnp.hstack) + + +array = partial(jax.jit, static_argnums=(1, ))(jnp.array) + +roll = partial(jax.jit, static_argnums=(2,))(jnp.roll) +arange = partial(jax.jit, static_argnums=(0, 1, 2))(jnp.arange) + +ones = partial(jax.jit, static_argnums=(0, 1))(jnp.ones) +zeros = partial(jax.jit, static_argnums=(0, 1))(jnp.zeros) + +repeat = partial(jax.jit, static_argnums=(1, 2, ))(jnp.repeat) +tile = partial(jax.jit, static_argnums=(1, ))(jnp.tile) + +linspace = partial(jax.jit, static_argnums=(0, 1, 2))(jnp.linspace) + +eye = partial(jax.jit, static_argnums=(0, ))(jnp.eye) +nonzero = partial(jax.jit, static_argnums=(0, ))(jnp.nonzero) + + +@partial(jax.jit, static_argnums=(3, 4)) +def assign(arr, index, value, row_all=False, col_all=False): + if type(index) == list: + index = tuple(index) + + if row_all: + + print('assign_new') + # # coord = jnp.array([[r,c] for c in index for r in range(arr.shape[0])]).T + # coord1 = jnp.array([[[r, c] for c in index] for r in range(arr.shape[0])]) # TODO: remove loop + # coord = tuple(jnp.moveaxis(coord1, -1, 0)) + # + # arr = arr.at[coord].set(value) + arr = arr.at[:, index].set(value) + elif col_all: + arr = arr.at[index, :].set(value) + else: + arr = arr.at[index].set(value) + return arr + + +@partial(jax.jit, static_argnums=(1, )) +def eig(matrix: jnp.ndarray, type_complex=jnp.complex128) -> tuple[jnp.ndarray, jnp.ndarray]: + """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs.""" + eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], type_complex) + eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, type_complex) + return host_callback.call( + # We force this computation to be performed on the cpu by jit-ing and + # explicitly specifying the device. + jax.jit(jnp.linalg.eig, device=jax.devices('cpu')[0]), + matrix.astype(type_complex), + result_shape=(eigenvalues_shape, eigenvectors_shape), + ) diff --git a/meent/on_jax/rcwa.py b/meent/on_jax/rcwa.py index af288c9..51f5029 100644 --- a/meent/on_jax/rcwa.py +++ b/meent/on_jax/rcwa.py @@ -1,162 +1,84 @@ import time +from functools import partial + +import jax import jax.numpy as jnp +import numpy as np from ._base import _BaseRCWA -from .convolution_matrix import to_conv_mat, find_n_index, fill_factor_to_ucell +from .convolution_matrix import to_conv_mat, put_permittivity_in_ucell, read_material_table +from .field_distribution import field_dist_1d, field_dist_1d_conical, field_dist_2d, field_plot -class RCWAOpt(_BaseRCWA): +class RCWAJax(_BaseRCWA): def __init__(self, mode=0, grating_type=0, n_I=1., n_II=1., theta=0, phi=0, psi=0, fourier_order=40, period=(100,), - wls=jnp.linspace(900, 900, 1), pol=0, patterns=None, thickness=None, algo='TMM'): + wavelength=900, pol=0, patterns=None, ucell=None, ucell_materials=None, + thickness=None, algo='TMM', perturbation=1E-10, + device='cpu', type_complex=np.complex128): + + super().__init__(grating_type, n_I, n_II, theta, phi, psi, fourier_order, period, wavelength, pol, patterns, + ucell, ucell_materials, + thickness, algo, perturbation, device, type_complex) - super().__init__(grating_type, n_I, n_II, theta, phi, psi, fourier_order, period, wls, pol, patterns, - thickness, algo) + self.device = device self.mode = mode - self.spectrum_r, self.spectrum_t = None, None - self.init_spectrum_array() + self.type_complex = type_complex + + self.mat_table = read_material_table() + self.layer_info_list = [] - def solve(self, wl, e_conv_all, o_e_conv_all): + def solve(self, wavelength, e_conv_all, o_e_conv_all): - # TODO: !handle uniform layer + self.get_kx_vector() if self.grating_type == 0: - de_ri, de_ti = self.solve_1d(wl, e_conv_all, o_e_conv_all) + de_ri, de_ti, layer_info_list, T1 = self.solve_1d(wavelength, e_conv_all, o_e_conv_all) elif self.grating_type == 1: - de_ri, de_ti = self.solve_1d_conical(wl, e_conv_all, o_e_conv_all) + de_ri, de_ti, layer_info_list, T1 = self.solve_1d_conical(wavelength, e_conv_all, o_e_conv_all) elif self.grating_type == 2: - de_ri, de_ti = self.solve_2d(wl, e_conv_all, o_e_conv_all) + de_ri, de_ti, layer_info_list, T1 = self.solve_2d(wavelength, e_conv_all, o_e_conv_all) else: raise ValueError - return de_ri.real, de_ti.real + self.layer_info_list = layer_info_list + self.T1 = T1 - def loop_wavelength_fill_factor(self, wavelength_array=None): + return de_ri.real, de_ti.real - if wavelength_array is not None: - self.wls = wavelength_array - self.init_spectrum_array() + def run_ucell(self): - for i, wl in enumerate(self.wls): + ucell = put_permittivity_in_ucell(self.ucell, self.ucell_materials, self.mat_table, self.wavelength, + type_complex=self.type_complex) - ucell = fill_factor_to_ucell(self.patterns, wl, self.grating_type) - e_conv_all = to_conv_mat(ucell, self.fourier_order) - o_e_conv_all = to_conv_mat(1 / ucell, self.fourier_order) + E_conv_all = to_conv_mat(ucell, self.fourier_order, type_complex=self.type_complex) + o_E_conv_all = to_conv_mat(1 / ucell, self.fourier_order, type_complex=self.type_complex) - de_ri, de_ti = self.solve(wl, e_conv_all, o_e_conv_all) + de_ri, de_ti = self.solve(self.wavelength, E_conv_all, o_E_conv_all) - self.spectrum_r = self.spectrum_r.at[i].set(de_ri) - self.spectrum_t = self.spectrum_t.at[i].set(de_ti) + return de_ri, de_ti - return self.spectrum_r, self.spectrum_t + def calculate_field(self, resolution=None, plot=True): - def loop_wavelength_ucell(self): - # si = [[z_begin, z_end], [y_begin, y_end], [x_begin, x_end]] if self.grating_type == 0: - cell = jnp.ones((2, 1, 10)) - si = [3.48, 0, 1, 0, 1, 0, 3] - ox = [3.48, 1, 2, 0, 1, 0, 3] + resolution = [100, 1, 100] if not resolution else resolution + field_cell = field_dist_1d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.fourier_order, self.T1, + self.layer_info_list, self.period, self.pol, resolution=resolution, + type_complex=self.type_complex) elif self.grating_type == 1: - cell = jnp.ones((2, 1, 10)) - si = [3.48, 0, 1, 0, 1, 0, 3] - ox = [3.48, 1, 2, 0, 1, 0, 3] - elif self.grating_type == 2: - cell = jnp.ones((2, 10, 10)) - si = [3.48, 0, 1, 0, 10, 0, 3] - ox = [3.48, 1, 2, 0, 10, 0, 3] - else: - raise ValueError + resolution = [100, 1, 100] if not resolution else resolution + field_cell = field_dist_1d_conical(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, self.fourier_order, self.T1, + self.layer_info_list, self.period, resolution=resolution, + type_complex=self.type_complex) - for i, wl in enumerate(self.wls): - for material, z_begin, z_end, y_begin, y_end, x_begin, x_end in [si, ox]: - n_index = find_n_index(material, wl) if type(material) == str else material - cell = cell.at[z_begin:z_end, y_begin:y_end, x_begin:x_end].set(n_index**2) - - e_conv_all = to_conv_mat(cell, self.fourier_order) - o_e_conv_all = to_conv_mat(1 / cell, self.fourier_order) - - de_ri, de_ti = self.solve(wl, e_conv_all, o_e_conv_all) - - self.spectrum_r = self.spectrum_r.at[i].set(de_ri) - self.spectrum_t = self.spectrum_t.at[i].set(de_ti) - - return self.spectrum_r, self.spectrum_t - - def jax_test(self): - # TODO - # # Z Y X - # # si = [[z_begin, z_end], [y_begin, y_end], [x_begin, x_end]] - # if self.grating_type == 0: - # cell = np.ones((2, 1, 10)) - # si = [3.48, 0, 1, 0, 1, 0, 3] - # ox = [3.48, 1, 2, 0, 1, 0, 3] - # elif self.grating_type == 1: - # pass - # elif self.grating_type == 2: - # cell = np.ones((2, 10, 10)) - # si = [3.48, 0, 1, 0, 10, 0, 3] - # ox = [3.48, 1, 2, 0, 10, 0, 3] - # else: - # raise ValueError - # - # for i, wl in enumerate(self.wls): - # for material, z_begin, z_end, y_begin, y_end, x_begin, x_end in [si, ox]: - # if material is str: - # n_index = find_n_index(material, wl) - # else: - # n_index = material - # cell[z_begin:z_end, y_begin:y_end, x_begin:x_end] = n_index ** 2 - - for i, wl in enumerate(self.wls): - e_conv_all = to_conv_mat(self.patterns, self.fourier_order) - oneover_e_conv_all = to_conv_mat(1 / self.patterns, self.fourier_order) - - de_ri, de_ti = self.solve(wl, e_conv_all, oneover_e_conv_all) - - self.spectrum_r = de_ri - self.spectrum_t = de_ti - - return self.spectrum_r, self.spectrum_t + else: + resolution = [10, 10, 10] if not resolution else resolution + field_cell = field_dist_2d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, self.fourier_order, self.T1, + self.layer_info_list, self.period, resolution=resolution, + type_complex=self.type_complex) + if plot: + field_plot(field_cell, self.pol) + return field_cell if __name__ == '__main__': - grating_type = 0 - pol = 0 - - n_I = 1 - n_II = 1 - - theta = 0 - phi = 0 - psi = 0 if pol else 90 - - wls = jnp.linspace(500, 1300, 100) - # wavelength = np.linspace(600, 800, 3) - - if grating_type in (0, 1): - period = [700] - patterns = [[3.48, 1, 0], [3.48, 1, 0]] # n_ridge, n_groove, fill_factor - fourier_order = 40 - - elif grating_type == 2: - period = [700, 700] - patterns = [[3.48, 1, [0.3, 1]], [3.48, 1, [0.3, 1]]] # n_ridge, n_groove, fill_factor[x, y] - fourier_order = 2 - else: - raise ValueError - - thickness = [460, 660] - - mode = 0 # 0: speed mode; 1: backprop mode; - - AA = RCWAOpt(grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, - fourier_order=fourier_order, wls=wls, period=period, patterns=patterns, thickness=thickness, mode=mode) - t0 = time.perf_counter() - - a, b = AA.loop_wavelength_fill_factor() - AA.plot() - - print(time.perf_counter() - t0) - - # AA.loop_wavelength_ucell() - # AA.plot() - # print('end') + pass diff --git a/meent/on_jax/transfer_method.py b/meent/on_jax/transfer_method.py index fe3cde0..18c299c 100644 --- a/meent/on_jax/transfer_method.py +++ b/meent/on_jax/transfer_method.py @@ -1,11 +1,16 @@ +# import jax.numpy as ee +from functools import partial + +import jax import jax.numpy as jnp -# from .convolution_matrix import * +import meent.on_jax.jitted as ee -def transfer_1d_1(ff, polarization, k0, n_I, n_II, theta, delta_i0, fourier_order,fourier_indices, wl, period): +def transfer_1d_1(ff, polarization, k0, n_I, n_II, kx_vector, theta, delta_i0, fourier_order, + type_complex=jnp.complex128): - kx_vector = k0 * (n_I * jnp.sin(theta) - fourier_indices * (wl / period[0])).astype('complex') + # kx_vector = k0 * (n_I * ee.sin(theta) - fourier_indices * (wavelength / period[0])).astype(type_complex) k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2) ** 0.5 k_II_z = (k0 ** 2 * n_II ** 2 - kx_vector ** 2) ** 0.5 @@ -13,215 +18,188 @@ def transfer_1d_1(ff, polarization, k0, n_I, n_II, theta, delta_i0, fourier_orde k_I_z = k_I_z.conjugate() k_II_z = k_II_z.conjugate() - Kx = jnp.diag(kx_vector / k0) + Kx = ee.diag(kx_vector / k0) - f = jnp.eye(ff) + f = ee.eye(ff).astype(type_complex) if polarization == 0: # TE - Y_I = jnp.diag(k_I_z / k0) - Y_II = jnp.diag(k_II_z / k0) + Y_I = ee.diag(k_I_z / k0) + Y_II = ee.diag(k_II_z / k0) YZ_I = Y_I g = 1j * Y_II - inc_term = 1j * n_I * jnp.cos(theta) * delta_i0 + inc_term = 1j * n_I * ee.cos(theta) * delta_i0 elif polarization == 1: # TM - Z_I = jnp.diag(k_I_z / (k0 * n_I ** 2)) - Z_II = jnp.diag(k_II_z / (k0 * n_II ** 2)) + Z_I = ee.diag(k_I_z / (k0 * n_I ** 2)) + Z_II = ee.diag(k_II_z / (k0 * n_II ** 2)) YZ_I = Z_I g = 1j * Z_II - inc_term = 1j * delta_i0 * jnp.cos(theta) / n_I + inc_term = 1j * delta_i0 * ee.cos(theta) / n_I else: raise ValueError - T = jnp.eye(2 * fourier_order + 1) + T = ee.eye(2 * fourier_order + 1).astype(type_complex) + + return kx_vector, Kx, k_I_z, k_II_z, Kx, f, YZ_I, g, inc_term, T - return Kx, k_I_z, k_II_z, Kx, f, YZ_I, g, inc_term, T +def transfer_1d_2(k0, q, d, W, V, f, g, fourier_order, T, type_complex=jnp.complex128): -def transfer_1d_2(k0, q, d, W, V, f, g, fourier_order, T): - X = jnp.diag(jnp.exp(-k0 * q * d)) - # TODO: expm + X = ee.diag(ee.exp(-k0 * q * d)) - W_i = jnp.linalg.inv(W) - V_i = jnp.linalg.inv(V) + W_i = ee.inv(W) + V_i = ee.inv(V) a = 0.5 * (W_i @ f + V_i @ g) b = 0.5 * (W_i @ f - V_i @ g) - a_i = jnp.linalg.inv(a) + a_i = ee.inv(a) - f = W @ (jnp.eye(2 * fourier_order + 1) + X @ b @ a_i @ X) - g = V @ (jnp.eye(2 * fourier_order + 1) - X @ b @ a_i @ X) + f = W @ (ee.eye(2 * fourier_order + 1).astype(type_complex) + X @ b @ a_i @ X) + g = V @ (ee.eye(2 * fourier_order + 1).astype(type_complex) - X @ b @ a_i @ X) T = T @ a_i @ X - return f, g, T + return X, f, g, T, a_i, b def transfer_1d_3(g, YZ_I, f, delta_i0, inc_term, T, k_I_z, k0, n_I, n_II, theta, polarization, k_II_z): - Tl = jnp.linalg.inv(g + 1j * YZ_I @ f) @ (1j * YZ_I @ delta_i0 + inc_term) - R = f @ Tl - delta_i0 - T = T @ Tl - de_ri = jnp.real(R * jnp.conj(R) * k_I_z / (k0 * n_I * jnp.cos(theta))) + T1 = ee.inv(g + 1j * YZ_I @ f) @ (1j * YZ_I @ delta_i0 + inc_term) + R = f @ T1 - delta_i0 + T = T @ T1 + + de_ri = ee.real(R * ee.conj(R) * k_I_z / (k0 * n_I * ee.cos(theta))) if polarization == 0: - # de_ti = T * jnp.conj(T) * jnp.real(k_II_z / (k0 * n_I * jnp.cos(theta))) - de_ti = jnp.real(T * jnp.conj(T) * k_II_z / (k0 * n_I * jnp.cos(theta))) + # de_ti = T * ee.conj(T) * ee.real(k_II_z / (k0 * n_I * ee.cos(theta))) + de_ti = ee.real(T * ee.conj(T) * k_II_z / (k0 * n_I * ee.cos(theta))) elif polarization == 1: - # de_ti = T * jnp.conj(T) * jnp.real(k_II_z / n_II ** 2) / (k0 * jnp.cos(theta) / n_I) - de_ti = jnp.real(T * jnp.conj(T) * k_II_z / n_II ** 2) / (k0 * jnp.cos(theta) / n_I) + # de_ti = T * ee.conj(T) * ee.real(k_II_z / n_II ** 2) / (k0 * ee.cos(theta) / n_I) + de_ti = ee.real(T * ee.conj(T) * k_II_z / n_II ** 2) / (k0 * ee.cos(theta) / n_I) else: raise ValueError - return de_ri, de_ti - - -def transfer_2d_1(ff, k0, n_I, n_II, period, fourier_indices, theta, phi, wl, perturbation=1E-20*(1+1j)): - I = jnp.eye(ff ** 2) - O = jnp.zeros((ff ** 2, ff ** 2)) - - kx_vector = k0 * (n_I * jnp.sin(theta) * jnp.cos(phi) - fourier_indices * ( - wl / period[0])).astype('complex') - ky_vector = k0 * (n_I * jnp.sin(theta) * jnp.sin(phi) - fourier_indices * ( - wl / period[1])).astype('complex') - - Kx = jnp.diag(jnp.tile(kx_vector, ff).flatten()) / k0 - Ky = jnp.diag(jnp.tile(ky_vector.reshape((-1, 1)), ff).flatten()) / k0 - - k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 - k_II_z = (k0 ** 2 * n_II ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 + return de_ri, de_ti, T1 - k_I_z = k_I_z.flatten().conjugate() - k_II_z = k_II_z.flatten().conjugate() - idx = jnp.nonzero(kx_vector == 0)[0] - if len(idx): - # TODO: need imaginary part? - # TODO: make imaginary part sign consistent - kx_vector = kx_vector.at[idx].set(perturbation) - print(wl, 'varphi divide by 0: adding perturbation') +def transfer_1d_conical_1(ff, k0, n_I, n_II, kx_vector, theta, phi, type_complex=jnp.complex128): - varphi = jnp.arctan(ky_vector.reshape((-1, 1)) / kx_vector).flatten() + I = ee.eye(ff).astype(type_complex) + O = ee.zeros((ff, ff)).astype(type_complex) - Y_I = jnp.diag(k_I_z / k0) - Y_II = jnp.diag(k_II_z / k0) + # kx_vector = k0 * (n_I * ee.sin(theta) * ee.cos(phi) - fourier_indices * (wavelength / period[0]) + # ).astype(type_complex) - Z_I = jnp.diag(k_I_z / (k0 * n_I ** 2)) - Z_II = jnp.diag(k_II_z / (k0 * n_II ** 2)) + ky = k0 * n_I * ee.sin(theta) * ee.sin(phi) - big_F = jnp.block([[I, O], [O, 1j * Z_II]]) - big_G = jnp.block([[1j * Y_II, O], [O, I]]) + k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky ** 2) ** 0.5 + k_II_z = (k0 ** 2 * n_II ** 2 - kx_vector ** 2 - ky ** 2) ** 0.5 - big_T = jnp.eye(ff ** 2 * 2) + k_I_z = k_I_z.conjugate() + k_II_z = k_II_z.conjugate() - return Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T + Kx = ee.diag(kx_vector / k0) + varphi = ee.arctan(ky / kx_vector) -def transfer_2d_wv(ff, Kx, E_i, Ky, oneover_E_conv_i, E_conv, center): + Y_I = ee.diag(k_I_z / k0) + Y_II = ee.diag(k_II_z / k0) - I = jnp.eye(ff ** 2) - O = jnp.zeros((ff ** 2, ff ** 2)) + Z_I = ee.diag(k_I_z / (k0 * n_I ** 2)) + Z_II = ee.diag(k_II_z / (k0 * n_II ** 2)) - B = Kx @ E_i @ Kx - I - D = Ky @ E_i @ Ky - I + big_F = ee.block([[I, O], [O, 1j * Z_II]]) + big_G = ee.block([[1j * Y_II, O], [O, I]]) - S2_from_S = jnp.block( - [ - [Ky ** 2 + B @ oneover_E_conv_i, Kx @ (E_i @ Ky @ E_conv - Ky)], - [Ky @ (E_i @ Kx @ oneover_E_conv_i - Kx), Kx ** 2 + D @ E_conv] - ]) + big_T = ee.eye(2 * ff).astype(type_complex) - eigenvalues, W = jnp.linalg.eig(S2_from_S) + return Kx, ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T - Lambda = eigenvalues ** 0.5 - # Lambda_1 = Lambda[:center] - # Lambda_2 = Lambda[center:] +def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varphi, big_F, big_G, big_T, + type_complex=jnp.complex128): - LAMBDA = jnp.diag(Lambda) - LAMBDA_i = jnp.linalg.inv(LAMBDA) - U1_from_S = jnp.block( - [ - [-Kx @ Ky, Kx ** 2 - E_conv], - [oneover_E_conv_i - Ky ** 2, Ky @ Kx] # TODO Check x y order - ] - ) - V = U1_from_S @ W @ LAMBDA_i + I = ee.eye(ff).astype(type_complex) + O = ee.zeros((ff, ff)).astype(type_complex) - return W, V, LAMBDA, Lambda + A = Kx ** 2 - E_conv + B = Kx @ E_conv_i @ Kx - I + A_i = ee.inv(A) + B_i = ee.inv(B) + to_decompose_W_1 = ky ** 2 * I + A + to_decompose_W_2 = ky ** 2 * I + B @ o_E_conv_i -def transfer_2d_2(k0, d, W, V, center, Lambda, varphi, I, O, big_F, big_G, big_T): + eigenvalues_1, W_1 = ee.eig(to_decompose_W_1, type_complex=type_complex) + eigenvalues_2, W_2 = ee.eig(to_decompose_W_2, type_complex=type_complex) - Lambda_1 = Lambda[:center] - Lambda_2 = Lambda[center:] + q_1 = eigenvalues_1 ** 0.5 + q_2 = eigenvalues_2 ** 0.5 - W_11 = W[:center, :center] - W_12 = W[:center, center:] - W_21 = W[center:, :center] - W_22 = W[center:, center:] + Q_1 = ee.diag(q_1) + Q_2 = ee.diag(q_2) - V_11 = V[:center, :center] - V_12 = V[:center, center:] - V_21 = V[center:, :center] - V_22 = V[center:, center:] + V_11 = A_i @ W_1 @ Q_1 + V_12 = (ky / k0) * A_i @ Kx @ W_2 + V_21 = (ky / k0) * B_i @ Kx @ E_conv_i @ W_1 + V_22 = B_i @ W_2 @ Q_2 - X_1 = jnp.diag(jnp.exp(-k0 * Lambda_1 * d)) - X_2 = jnp.diag(jnp.exp(-k0 * Lambda_2 * d)) - # TODO: expm + X_1 = ee.diag(ee.exp(-k0 * q_1 * d)) + X_2 = ee.diag(ee.exp(-k0 * q_2 * d)) - F_c = jnp.diag(jnp.cos(varphi)) - F_s = jnp.diag(jnp.sin(varphi)) + F_c = ee.diag(ee.cos(varphi)) + F_s = ee.diag(ee.sin(varphi)) - W_ss = F_c @ W_21 - F_s @ W_11 - W_sp = F_c @ W_22 - F_s @ W_12 - W_ps = F_c @ W_11 + F_s @ W_21 - W_pp = F_c @ W_12 + F_s @ W_22 - - V_ss = F_c @ V_11 + F_s @ V_21 - V_sp = F_c @ V_12 + F_s @ V_22 - V_ps = F_c @ V_21 - F_s @ V_11 - V_pp = F_c @ V_22 - F_s @ V_12 + V_ss = F_c @ V_11 + V_sp = F_c @ V_12 - F_s @ W_2 + W_ss = F_c @ W_1 + F_s @ V_21 + W_sp = F_s @ V_22 + W_ps = F_s @ V_11 + W_pp = F_c @ W_2 + F_s @ V_12 + V_ps = F_c @ V_21 - F_s @ W_1 + V_pp = F_c @ V_22 - big_I = jnp.eye(2 * (len(I))) - big_X = jnp.block([[X_1, O], [O, X_2]]) - big_W = jnp.block([[W_ss, W_sp], [W_ps, W_pp]]) - big_V = jnp.block([[V_ss, V_sp], [V_ps, V_pp]]) + big_I = ee.eye(2 * (len(I))).astype(type_complex) + big_X = ee.block([[X_1, O], [O, X_2]]) + big_W = ee.block([[V_ss, V_sp], [W_ps, W_pp]]) + big_V = ee.block([[W_ss, W_sp], [V_ps, V_pp]]) - big_W_i = jnp.linalg.inv(big_W) - big_V_i = jnp.linalg.inv(big_V) + big_W_i = ee.inv(big_W) + big_V_i = ee.inv(big_V) big_A = 0.5 * (big_W_i @ big_F + big_V_i @ big_G) big_B = 0.5 * (big_W_i @ big_F - big_V_i @ big_G) - big_A_i = jnp.linalg.inv(big_A) + big_A_i = ee.inv(big_A) big_F = big_W @ (big_I + big_X @ big_B @ big_A_i @ big_X) big_G = big_V @ (big_I - big_X @ big_B @ big_A_i @ big_X) big_T = big_T @ big_A_i @ big_X - return big_F, big_G, big_T + return big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2 -def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z): - I = jnp.eye(ff ** 2) - O = jnp.zeros((ff ** 2, ff ** 2)) +def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, + type_complex=jnp.complex128): - big_F_11 = big_F[:center, :center] - big_F_12 = big_F[:center, center:] - big_F_21 = big_F[center:, :center] - big_F_22 = big_F[center:, center:] + I = ee.eye(ff).astype(type_complex) + O = ee.zeros((ff, ff), dtype=type_complex) - big_G_11 = big_G[:center, :center] - big_G_12 = big_G[:center, center:] - big_G_21 = big_G[center:, :center] - big_G_22 = big_G[center:, center:] + big_F_11 = big_F[:ff, :ff] + big_F_12 = big_F[:ff, ff:] + big_F_21 = big_F[ff:, :ff] + big_F_22 = big_F[ff:, ff:] + + big_G_11 = big_G[:ff, :ff] + big_G_12 = big_G[:ff, ff:] + big_G_21 = big_G[ff:, :ff] + big_G_22 = big_G[ff:, ff:] # Final Equation in form of AX=B - final_A = jnp.block( + final_A = ee.block( [ [I, O, -big_F_11, -big_F_12], [O, -1j * Z_I, -big_F_21, -big_F_22], @@ -230,164 +208,171 @@ def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i ] ) - final_B = jnp.block( - [ - [-jnp.sin(psi) * delta_i0], - [-jnp.cos(psi) * jnp.cos(theta) * delta_i0], - [-1j * jnp.sin(psi) * n_I * jnp.cos(theta) * delta_i0], - [1j * n_I * jnp.cos(psi) * delta_i0] - ] - ) + # tODO: correct? + final_B = ee.hstack([ + [-ee.sin(psi) * delta_i0], + [-ee.cos(psi) * ee.cos(theta) * delta_i0], + [-1j * ee.sin(psi) * n_I * ee.cos(theta) * delta_i0], + [1j * n_I * ee.cos(psi) * delta_i0] + ]).T - final_X = jnp.linalg.inv(final_A) @ final_B + final_RT = ee.inv(final_A) @ final_B - R_s = final_X[:ff ** 2, :].flatten() - R_p = final_X[ff ** 2:2 * ff ** 2, :].flatten() + R_s = final_RT[:ff, :].flatten() + R_p = final_RT[ff:2 * ff, :].flatten() - big_T = big_T @ final_X[2 * ff ** 2:, :] - T_s = big_T[:ff ** 2, :].flatten() - T_p = big_T[ff ** 2:, :].flatten() + big_T1 = final_RT[2 * ff:, :] + big_T = big_T @ big_T1 + + T_s = big_T[:ff, :].flatten() + T_p = big_T[ff:, :].flatten() - de_ri = R_s * jnp.conj(R_s) * jnp.real(k_I_z / (k0 * n_I * jnp.cos(theta))) \ - + R_p * jnp.conj(R_p) * jnp.real((k_I_z / n_I ** 2) / (k0 * n_I * jnp.cos(theta))) + de_ri = R_s * ee.conj(R_s) * ee.real(k_I_z / (k0 * n_I * ee.cos(theta))) \ + + R_p * ee.conj(R_p) * ee.real((k_I_z / n_I ** 2) / (k0 * n_I * ee.cos(theta))) - de_ti = T_s * jnp.conj(T_s) * jnp.real(k_II_z / (k0 * n_I * jnp.cos(theta))) \ - + T_p * jnp.conj(T_p) * jnp.real((k_II_z / n_II ** 2) / (k0 * n_I * jnp.cos(theta))) + de_ti = T_s * ee.conj(T_s) * ee.real(k_II_z / (k0 * n_I * ee.cos(theta))) \ + + T_p * ee.conj(T_p) * ee.real((k_II_z / n_II ** 2) / (k0 * n_I * ee.cos(theta))) - # Aa = de_ri.sum() - # Aaa = de_ti.sum() - # - # if Aa + Aaa != 1: - # # TODO: no problem? or should be handled? - # print(1) - # wavelength = 1463.6363636363637 - # deri = 350 - # - # wavelength = 1978.9715332727274 - # deri = 558 + return de_ri.real, de_ti.real, big_T1 - return de_ri.real, de_ti.real +def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, phi, wavelength, + type_complex=jnp.complex128): -def transfer_1d_conical_1(ff, k0, n_I, n_II, period, fourier_indices, theta, phi, wl, perturbation=1E-20*(1+1j)): - I = jnp.eye(ff) - O = jnp.zeros((ff, ff)) + I = ee.eye(ff ** 2).astype(type_complex) + O = ee.zeros((ff ** 2, ff ** 2), dtype=type_complex) - kx_vector = k0 * (n_I * jnp.sin(theta) * jnp.cos(phi) - fourier_indices * (wl / period[0])).astype( - 'complex') - ky = k0 * n_I * jnp.sin(theta) * jnp.sin(phi) + # kx_vector = k0 * (n_I * ee.sin(theta) * ee.cos(phi) - fourier_indices * ( + # wavelength / period[0])).astype(type_complex) - k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky ** 2) ** 0.5 - k_II_z = (k0 ** 2 * n_II ** 2 - kx_vector ** 2 - ky ** 2) ** 0.5 + ky_vector = k0 * (n_I * ee.sin(theta) * ee.sin(phi) - fourier_indices * ( + wavelength / period[1])).astype(type_complex) - k_I_z = k_I_z.conjugate() - k_II_z = k_II_z.conjugate() + k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 + k_II_z = (k0 ** 2 * n_II ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 - idx = jnp.nonzero(kx_vector == 0)[0] - if len(idx): - # TODO: need imaginary part? - # TODO: make imaginary part sign consistent - kx_vector = kx_vector.at[idx].set(perturbation) - print(wl, 'varphi divide by 0: adding perturbation') + k_I_z = k_I_z.flatten().conjugate() + k_II_z = k_II_z.flatten().conjugate() - varphi = jnp.arctan(ky / kx_vector) + Kx = ee.diag(ee.tile(kx_vector, ff).flatten()) / k0 + Ky = ee.diag(ee.tile(ky_vector.reshape((-1, 1)), ff).flatten()) / k0 - Y_I = jnp.diag(k_I_z / k0) - Y_II = jnp.diag(k_II_z / k0) + varphi = ee.arctan(ky_vector.reshape((-1, 1)) / kx_vector).flatten() - Z_I = jnp.diag(k_I_z / (k0 * n_I ** 2)) - Z_II = jnp.diag(k_II_z / (k0 * n_II ** 2)) + Y_I = ee.diag(k_I_z / k0) + Y_II = ee.diag(k_II_z / k0) - Kx = jnp.diag(kx_vector / k0) + Z_I = ee.diag(k_I_z / (k0 * n_I ** 2)) + Z_II = ee.diag(k_II_z / (k0 * n_II ** 2)) - big_F = jnp.block([[I, O], [O, 1j * Z_II]]) - big_G = jnp.block([[1j * Y_II, O], [O, I]]) + big_F = ee.block([[I, O], [O, 1j * Z_II]]) + big_G = ee.block([[1j * Y_II, O], [O, I]]) - big_T = jnp.eye(2 * ff) + big_T = ee.eye(ff ** 2 * 2).astype(type_complex) - return Kx, ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T + return kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T -def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_i, oneover_E_conv_i, ff, d, varphi, big_F, big_G, big_T): +def transfer_2d_wv(ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=jnp.complex128): - I = jnp.eye(ff) - O = jnp.zeros((ff, ff)) + I = ee.eye(ff ** 2).astype(type_complex) - A = Kx ** 2 - E_conv - B = Kx @ E_i @ Kx - I - A_i = jnp.linalg.inv(A) - B_i = jnp.linalg.inv(B) + B = Kx @ E_conv_i @ Kx - I + D = Ky @ E_conv_i @ Ky - I - to_decompose_W_1 = ky ** 2 * I + A - to_decompose_W_2 = ky ** 2 * I + B @ oneover_E_conv_i + S2_from_S = ee.block( + [ + [Ky ** 2 + B @ o_E_conv_i, Kx @ (E_conv_i @ Ky @ E_conv - Ky)], + [Ky @ (E_conv_i @ Kx @ o_E_conv_i - Kx), Kx ** 2 + D @ E_conv] + ]) - # TODO: using eigh? - eigenvalues_1, W_1 = jnp.linalg.eig(to_decompose_W_1) - eigenvalues_2, W_2 = jnp.linalg.eig(to_decompose_W_2) + eigenvalues, W = ee.eig(S2_from_S, type_complex=type_complex) - q_1 = eigenvalues_1 ** 0.5 - q_2 = eigenvalues_2 ** 0.5 + q = eigenvalues ** 0.5 - Q_1 = jnp.diag(q_1) - Q_2 = jnp.diag(q_2) + Q = ee.diag(q) + Q_i = ee.inv(Q) + U1_from_S = ee.block( + [ + [-Kx @ Ky, Kx ** 2 - E_conv], + [o_E_conv_i - Ky ** 2, Ky @ Kx] + ] + ) + V = U1_from_S @ W @ Q_i - V_11 = A_i @ W_1 @ Q_1 - V_12 = (ky / k0) * A_i @ Kx @ W_2 - V_21 = (ky / k0) * B_i @ Kx @ E_i @ W_1 - V_22 = B_i @ W_2 @ Q_2 + return W, V, q - X_1 = jnp.diag(jnp.exp(-k0 * q_1 * d)) - X_2 = jnp.diag(jnp.exp(-k0 * q_2 * d)) - F_c = jnp.diag(jnp.cos(varphi)) - F_s = jnp.diag(jnp.sin(varphi)) +def transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T, type_complex=jnp.complex128): - V_ss = F_c @ V_11 - V_sp = F_c @ V_12 - F_s @ W_2 - W_ss = F_c @ W_1 + F_s @ V_21 - W_sp = F_s @ V_22 - W_ps = F_s @ V_11 - W_pp = F_c @ W_2 + F_s @ V_12 - V_ps = F_c @ V_21 - F_s @ W_1 - V_pp = F_c @ V_22 + q1 = q[:center] + q2 = q[center:] - big_I = jnp.eye(2 * (len(I))) - big_X = jnp.block([[X_1, O], [O, X_2]]) - big_W = jnp.block([[V_ss, V_sp], [W_ps, W_pp]]) - big_V = jnp.block([[W_ss, W_sp], [V_ps, V_pp]]) + W_11 = W[:center, :center] + W_12 = W[:center, center:] + W_21 = W[center:, :center] + W_22 = W[center:, center:] - big_W_i = jnp.linalg.inv(big_W) - big_V_i = jnp.linalg.inv(big_V) + V_11 = V[:center, :center] + V_12 = V[:center, center:] + V_21 = V[center:, :center] + V_22 = V[center:, center:] + + X_1 = ee.diag(ee.exp(-k0 * q1 * d)) + X_2 = ee.diag(ee.exp(-k0 * q2 * d)) + + F_c = ee.diag(ee.cos(varphi)) + F_s = ee.diag(ee.sin(varphi)) + + W_ss = F_c @ W_21 - F_s @ W_11 + W_sp = F_c @ W_22 - F_s @ W_12 + W_ps = F_c @ W_11 + F_s @ W_21 + W_pp = F_c @ W_12 + F_s @ W_22 + + V_ss = F_c @ V_11 + F_s @ V_21 + V_sp = F_c @ V_12 + F_s @ V_22 + V_ps = F_c @ V_21 - F_s @ V_11 + V_pp = F_c @ V_22 - F_s @ V_12 + + big_I = ee.eye(2 * (len(I))).astype(type_complex) + big_X = ee.block([[X_1, O], [O, X_2]]) + big_W = ee.block([[W_ss, W_sp], [W_ps, W_pp]]) + big_V = ee.block([[V_ss, V_sp], [V_ps, V_pp]]) + + big_W_i = ee.inv(big_W) + big_V_i = ee.inv(big_V) big_A = 0.5 * (big_W_i @ big_F + big_V_i @ big_G) big_B = 0.5 * (big_W_i @ big_F - big_V_i @ big_G) - big_A_i = jnp.linalg.inv(big_A) + big_A_i = ee.inv(big_A) big_F = big_W @ (big_I + big_X @ big_B @ big_A_i @ big_X) big_G = big_V @ (big_I - big_X @ big_B @ big_A_i @ big_X) big_T = big_T @ big_A_i @ big_X - return big_F, big_G, big_T + return big_X, big_F, big_G, big_T, big_A_i, big_B, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 -def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z): - I = jnp.eye(ff) - O = jnp.zeros((ff, ff)) +def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, + type_complex=jnp.complex128): - big_F_11 = big_F[:ff, :ff] - big_F_12 = big_F[:ff, ff:] - big_F_21 = big_F[ff:, :ff] - big_F_22 = big_F[ff:, ff:] + I = ee.eye(ff ** 2).astype(type_complex) + O = ee.zeros((ff ** 2, ff ** 2), dtype=type_complex) - big_G_11 = big_G[:ff, :ff] - big_G_12 = big_G[:ff, ff:] - big_G_21 = big_G[ff:, :ff] - big_G_22 = big_G[ff:, ff:] + big_F_11 = big_F[:center, :center] + big_F_12 = big_F[:center, center:] + big_F_21 = big_F[center:, :center] + big_F_22 = big_F[center:, center:] + + big_G_11 = big_G[:center, :center] + big_G_12 = big_G[:center, center:] + big_G_21 = big_G[center:, :center] + big_G_22 = big_G[center:, center:] # Final Equation in form of AX=B - final_A = jnp.block( + final_A = ee.block( [ [I, O, -big_F_11, -big_F_12], [O, -1j * Z_I, -big_F_21, -big_F_22], @@ -396,27 +381,30 @@ def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i ] ) - final_B = jnp.hstack([ - [-jnp.sin(psi) * delta_i0], - [-jnp.cos(psi) * jnp.cos(theta) * delta_i0], - [-1j * jnp.sin(psi) * n_I * jnp.cos(theta) * delta_i0], - [1j * n_I * jnp.cos(psi) * delta_i0] - ]).T + final_B = ee.block( + [ + [-ee.sin(psi) * delta_i0], + [-ee.cos(psi) * ee.cos(theta) * delta_i0], + [-1j * ee.sin(psi) * n_I * ee.cos(theta) * delta_i0], + [1j * n_I * ee.cos(psi) * delta_i0] + ] + ) - final_X = jnp.linalg.inv(final_A) @ final_B + final_RT = ee.inv(final_A) @ final_B - R_s = final_X[:ff, :].flatten() - R_p = final_X[ff:2 * ff, :].flatten() + R_s = final_RT[:ff ** 2, :].flatten() + R_p = final_RT[ff ** 2:2 * ff ** 2, :].flatten() - big_T = big_T @ final_X[2 * ff:, :] - T_s = big_T[:ff, :].flatten() - T_p = big_T[ff:, :].flatten() + big_T1 = final_RT[2 * ff ** 2:, :] + big_T = big_T @ big_T1 - de_ri = R_s * jnp.conj(R_s) * jnp.real(k_I_z / (k0 * n_I * jnp.cos(theta))) \ - + R_p * jnp.conj(R_p) * jnp.real((k_I_z / n_I ** 2) / (k0 * n_I * jnp.cos(theta))) + T_s = big_T[:ff ** 2, :].flatten() + T_p = big_T[ff ** 2:, :].flatten() - de_ti = T_s * jnp.conj(T_s) * jnp.real(k_II_z / (k0 * n_I * jnp.cos(theta))) \ - + T_p * jnp.conj(T_p) * jnp.real((k_II_z / n_II ** 2) / (k0 * n_I * jnp.cos(theta))) + de_ri = R_s * ee.conj(R_s) * ee.real(k_I_z / (k0 * n_I * ee.cos(theta))) \ + + R_p * ee.conj(R_p) * ee.real((k_I_z / n_I ** 2) / (k0 * n_I * ee.cos(theta))) - return de_ri.real, de_ti.real + de_ti = T_s * ee.conj(T_s) * ee.real(k_II_z / (k0 * n_I * ee.cos(theta))) \ + + T_p * ee.conj(T_p) * ee.real((k_II_z / n_II ** 2) / (k0 * n_I * ee.cos(theta))) + return de_ri.real, de_ti.real, big_T1 diff --git a/meent/on_numpy/_base.py b/meent/on_numpy/_base.py index d35bd92..1b0c027 100644 --- a/meent/on_numpy/_base.py +++ b/meent/on_numpy/_base.py @@ -1,90 +1,21 @@ -import scipy - import numpy as np -import matplotlib.pyplot as plt -from .scattering_method import scattering_1d_1, scattering_1d_2, scattering_1d_3, scattering_2d_1, scattering_2d_wv,\ +from copy import deepcopy + +from .scattering_method import scattering_1d_1, scattering_1d_2, scattering_1d_3, scattering_2d_1, scattering_2d_wv, \ scattering_2d_2, scattering_2d_3 -from .transfer_method import transfer_1d_1, transfer_1d_2, transfer_1d_3, transfer_1d_conical_1, transfer_1d_conical_2,\ +from .transfer_method import transfer_1d_1, transfer_1d_2, transfer_1d_3, transfer_1d_conical_1, transfer_1d_conical_2, \ transfer_1d_conical_3, transfer_2d_1, transfer_2d_wv, transfer_2d_2, transfer_2d_3 -# class Base: -# def __init__(self, grating_type): -# self.grating_type = grating_type -# self.wavelength = None -# self.fourier_order = None -# self.spectrum_r = None -# self.spectrum_t = None -# -# def init_spectrum_array(self): -# if self.grating_type in (0, 1): -# self.spectrum_r = np.zeros((len(self.wavelength), 2 * self.fourier_order + 1)) -# self.spectrum_t = np.zeros((len(self.wavelength), 2 * self.fourier_order + 1)) -# elif self.grating_type == 2: -# self.spectrum_r = np.zeros((len(self.wavelength), 2 * self.fourier_order + 1, 2 * self.fourier_order + 1)) -# self.spectrum_t = np.zeros((len(self.wavelength), 2 * self.fourier_order + 1, 2 * self.fourier_order + 1)) -# else: -# raise ValueError -# -# def save_spectrum_array(self, de_ri, de_ti, i): -# de_ri = np.array(de_ri) -# de_ti = np.array(de_ti) -# -# if not de_ri.shape: -# # 1D or may be not; there is a case that reticolo returns single value -# c = self.spectrum_r.shape[1] // 2 -# self.spectrum_r[i][c] = de_ri -# -# elif len(de_ri.shape) == 1 or de_ri.shape[1] == 1: -# de_ri = de_ri.flatten() -# c = self.spectrum_r.shape[1] // 2 -# l = de_ri.shape[0] // 2 -# if len(de_ri) % 2: -# self.spectrum_r[i][c - l:c + l + 1] = de_ri -# else: -# self.spectrum_r[i][c - l:c + l] = de_ri -# -# else: -# print('no code') -# raise ValueError -# -# if not de_ti.shape: # 1D -# c = self.spectrum_t.shape[1] // 2 -# self.spectrum_t[i][c] = de_ti -# -# elif len(de_ti.shape) == 1 or de_ti.shape[1] == 1: # 1D -# de_ti = de_ti.flatten() -# c = self.spectrum_t.shape[1] // 2 -# l = de_ti.shape[0] // 2 -# if len(de_ti) % 2: -# self.spectrum_t[i][c - l:c + l + 1] = de_ti -# else: -# self.spectrum_t[i][c - l:c + l] = de_ti -# -# else: -# print('no code') -# raise ValueError -# -# def plot(self, title=None, marker=None): -# if self.grating_type in (0, 1): -# plt.plot(self.wavelength, self.spectrum_r.sum(axis=1), marker=marker) -# plt.plot(self.wavelength, self.spectrum_t.sum(axis=1), marker=marker) -# elif self.grating_type == 2: -# plt.plot(self.wavelength, self.spectrum_r.sum(axis=(1, 2)), marker=marker) -# plt.plot(self.wavelength, self.spectrum_t.sum(axis=(1, 2)), marker=marker) -# else: -# raise ValueError -# plt.title(title) -# plt.show() - -# class _BaseRCWA(Base): - class _BaseRCWA: def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., fourier_order=10, period=0.7, wavelength=np.linspace(0.5, 2.3, 400), pol=0, - patterns=None, ucell=None, ucell_materials=None, thickness=None, algo='TMM'): - # super().__init__(grating_type) + patterns=None, ucell=None, ucell_materials=None, thickness=None, algo='TMM', perturbation=1E-10, + device='cpu', type_complex=np.complex128): + + self.device = device + self.type_complex = type_complex self.grating_type = grating_type # 1D=0, 1D_conical=1, 2D=2 self.n_I = n_I @@ -106,35 +37,60 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., four self.fourier_order = fourier_order self.ff = 2 * self.fourier_order + 1 - self.period = period + self.period = deepcopy(period) self.wavelength = wavelength self.patterns = patterns - self.ucell = ucell + self.ucell = deepcopy(ucell) self.ucell_materials = ucell_materials - self.thickness = thickness + self.thickness = deepcopy(thickness) self.algo = algo - - # self.init_spectrum_array() + self.perturbation = perturbation self.layer_info_list = [] self.T1 = None + self.kx_vector = None + + def get_kx_vector(self): + + k0 = 2 * np.pi / self.wavelength + fourier_indices = np.arange(-self.fourier_order, self.fourier_order + 1) + if self.grating_type == 0: + kx_vector = k0 * (self.n_I * np.sin(self.theta) - fourier_indices * (self.wavelength / self.period[0]) + ).astype(self.type_complex) + else: + kx_vector = k0 * (self.n_I * np.sin(self.theta) * np.cos(self.phi) - fourier_indices * ( + self.wavelength / self.period[0]) + ).astype(self.type_complex) + + idx = np.nonzero(kx_vector == 0)[0] + if len(idx): + # TODO: need imaginary part? + # TODO: make imaginary part sign consistent + kx_vector[idx] = self.perturbation + print('varphi divide by 0: adding perturbation') + + self.kx_vector = kx_vector + def solve_1d(self, wl, E_conv_all, o_E_conv_all): + self.layer_info_list = [] + self.T1 = None + fourier_indices = np.arange(-self.fourier_order, self.fourier_order + 1) - delta_i0 = np.zeros(self.ff) + delta_i0 = np.zeros(self.ff, dtype=self.type_complex) delta_i0[self.fourier_order] = 1 k0 = 2 * np.pi / wl if self.algo == 'TMM': kx_vector, Kx, k_I_z, k_II_z, f, YZ_I, g, inc_term, T \ - = transfer_1d_1(self.ff, self.pol, k0, self.n_I, self.n_II, - self.theta, delta_i0, self.fourier_order, fourier_indices, wl, self.period) + = transfer_1d_1(self.ff, self.pol, k0, self.n_I, self.n_II, self.kx_vector, + self.theta, delta_i0, self.fourier_order, type_complex=self.type_complex) elif self.algo == 'SMM': Kx, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ = scattering_1d_1(k0, self.n_I, self.n_II, self.theta, self.phi, fourier_indices, self.period, @@ -142,8 +98,14 @@ def solve_1d(self, wl, E_conv_all, o_E_conv_all): else: raise ValueError + count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) + # From the last layer - for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): + for layer_index in range(count)[::-1]: + + E_conv = E_conv_all[layer_index] + o_E_conv = o_E_conv_all[layer_index] + d = self.thickness[layer_index] if self.pol == 0: E_conv_i = None @@ -156,7 +118,7 @@ def solve_1d(self, wl, E_conv_all, o_E_conv_all): elif self.pol == 1: E_conv_i = np.linalg.inv(E_conv) - B = Kx @ E_conv_i @ Kx - np.eye(E_conv.shape[0]) + B = Kx @ E_conv_i @ Kx - np.eye(E_conv.shape[0], dtype=self.type_complex) o_E_conv_i = np.linalg.inv(o_E_conv) eigenvalues, W = np.linalg.eig(o_E_conv_i @ B) @@ -169,7 +131,8 @@ def solve_1d(self, wl, E_conv_all, o_E_conv_all): raise ValueError if self.algo == 'TMM': - X, f, g, T, a_i, b = transfer_1d_2(k0, q, d, W, V, f, g, self.fourier_order, T) + X, f, g, T, a_i, b = transfer_1d_2(k0, q, d, W, V, f, g, self.fourier_order, T, + type_complex=self.type_complex) layer_info = [E_conv_i, q, W, X, a_i, b, d] self.layer_info_list.append(layer_info) @@ -181,7 +144,7 @@ def solve_1d(self, wl, E_conv_all, o_E_conv_all): if self.algo == 'TMM': de_ri, de_ti, T1 = transfer_1d_3(g, YZ_I, f, delta_i0, inc_term, T, k_I_z, k0, self.n_I, self.n_II, - self.theta, self.pol, k_II_z) + self.theta, self.pol, k_II_z) self.T1 = T1 elif self.algo == 'SMM': @@ -193,39 +156,59 @@ def solve_1d(self, wl, E_conv_all, o_E_conv_all): return de_ri, de_ti # TODO: scattering method - def solve_1d_conical(self, wl, e_conv_all, o_e_conv_all): + def solve_1d_conical(self, wl, E_conv_all, o_E_conv_all): - fourier_indices = np.arange(-self.fourier_order, self.fourier_order + 1) + self.layer_info_list = [] + self.T1 = None + + # fourier_indices = np.arange(-self.fourier_order, self.fourier_order + 1) - delta_i0 = np.zeros(self.ff) + delta_i0 = np.zeros(self.ff, dtype=self.type_complex) delta_i0[self.fourier_order] = 1 k0 = 2 * np.pi / wl if self.algo == 'TMM': Kx, ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ - = transfer_1d_conical_1(self.ff, k0, self.n_I, self.n_II, self.period, fourier_indices, self.theta, self.phi, wl) + = transfer_1d_conical_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.theta, self.phi, + type_complex=self.type_complex) elif self.algo == 'SMM': print('SMM for 1D conical is not implemented') return np.nan, np.nan else: raise ValueError - for e_conv, o_e_conv, d in zip(e_conv_all[::-1], o_e_conv_all[::-1], self.thickness[::-1]): - e_conv_i = np.linalg.inv(e_conv) - o_e_conv_i = np.linalg.inv(o_e_conv) + count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) + + # From the last layer + for layer_index in range(count)[::-1]: + + E_conv = E_conv_all[layer_index] + o_E_conv = o_E_conv_all[layer_index] + d = self.thickness[layer_index] + + E_conv_i = np.linalg.inv(E_conv) + o_E_conv_i = np.linalg.inv(o_E_conv) if self.algo == 'TMM': - big_F, big_G, big_T = transfer_1d_conical_2(k0, Kx, ky, e_conv, e_conv_i, o_e_conv_i, self.ff, d, - varphi, big_F, big_G, big_T) + big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2 \ + = transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, self.ff, d, + varphi, big_F, big_G, big_T, + type_complex=self.type_complex) + layer_info = [E_conv_i, q_1, q_2, W_1, W_2, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d] + self.layer_info_list.append(layer_info) + elif self.algo == 'SMM': raise ValueError else: raise ValueError if self.algo == 'TMM': - de_ri, de_ti = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, - delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z) + de_ri, de_ti, big_T1 = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, + type_complex=self.type_complex) + self.T1 = big_T1 + elif self.algo == 'SMM': raise ValueError else: @@ -235,13 +218,16 @@ def solve_1d_conical(self, wl, e_conv_all, o_e_conv_all): def solve_2d(self, wl, E_conv_all, o_E_conv_all): + self.layer_info_list = [] + self.T1 = None + fourier_indices = np.arange(-self.fourier_order, self.fourier_order + 1) - delta_i0 = np.zeros((self.ff ** 2, 1)) + delta_i0 = np.zeros((self.ff ** 2, 1), dtype=self.type_complex) delta_i0[self.ff ** 2 // 2, 0] = 1 - I = np.eye(self.ff ** 2) - O = np.zeros((self.ff ** 2, self.ff ** 2)) + I = np.eye(self.ff ** 2, dtype=self.type_complex) + O = np.zeros((self.ff ** 2, self.ff ** 2), dtype=self.type_complex) center = self.ff ** 2 @@ -249,24 +235,33 @@ def solve_2d(self, wl, E_conv_all, o_E_conv_all): if self.algo == 'TMM': kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ - = transfer_2d_1(self.ff, k0, self.n_I, self.n_II, self.period, fourier_indices, self.theta, self.phi, wl) + = transfer_2d_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.period, fourier_indices, + self.theta, self.phi, wl, type_complex=self.type_complex) elif self.algo == 'SMM': Kx, Ky, kz_inc, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ = scattering_2d_1(self.n_I, self.n_II, self.theta, self.phi, k0, self.period, self.fourier_order) else: raise ValueError + count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) + # From the last layer - for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): + for layer_index in range(count)[::-1]: + + E_conv = E_conv_all[layer_index] + o_E_conv = o_E_conv_all[layer_index] + d = self.thickness[layer_index] + E_conv_i = np.linalg.inv(E_conv) o_E_conv_i = np.linalg.inv(o_E_conv) if self.algo == 'TMM': # TODO: MERGE W V part - W, V, q = transfer_2d_wv(self.ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, center) + W, V, q = transfer_2d_wv(self.ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=self.type_complex) big_X, big_F, big_G, big_T, big_A_i, big_B, \ W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 \ - = transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T) + = transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T, + type_complex=self.type_complex) layer_info = [E_conv_i, q, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d] self.layer_info_list.append(layer_info) @@ -279,7 +274,8 @@ def solve_2d(self, wl, E_conv_all, o_E_conv_all): if self.algo == 'TMM': de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, - delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z) + delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, + type_complex=self.type_complex) self.T1 = big_T1 elif self.algo == 'SMM': diff --git a/meent/on_numpy/convolution_matrix.py b/meent/on_numpy/convolution_matrix.py index ba5e382..4ba40da 100644 --- a/meent/on_numpy/convolution_matrix.py +++ b/meent/on_numpy/convolution_matrix.py @@ -1,54 +1,17 @@ import copy +import time + import numpy as np from os import walk from scipy.io import loadmat -from scipy.linalg import circulant +from scipy.linalg import circulant as circulant_scipy from pathlib import Path -# def put_n_ridge_in_pattern_fill_factor(pattern_all, mat_table, wl): -# -# pattern_all = copy.deepcopy(pattern_all) -# -# for i, (n_ridge, n_groove, pattern) in enumerate(pattern_all): -# -# if type(n_ridge) == str: -# material = n_ridge -# n_ridge = find_nk_index(material, mat_table, wl) -# pattern_all[i][0] = n_ridge -# return pattern_all - - -# def get_material_index_in_ucell(ucell_comp, mat_list): -# -# res = [[[] for _ in mat_list] for _ in ucell_comp] -# -# for z, ucell_xy in enumerate(ucell_comp): -# for y in range(ucell_xy.shape[0]): -# for x in range(ucell_xy.shape[1]): -# res[z][ucell_xy[y, x]].append([y, x]) -# return res - - -# def put_permittivity_in_ucell_object_comps(ucell, mat_list, obj_list, mat_table, wl): -# -# res = np.zeros(ucell.shape, dtype='complex') -# -# for obj_xy in obj_list: -# for material, obj_index in zip(mat_list, obj_xy): -# obj_index = np.array(obj_index).T -# if type(material) == str: -# res[obj_index[0], obj_index[1]] = find_nk_index(material, mat_table, wl) ** 2 -# else: -# res[obj_index[0], obj_index[1]] = material ** 2 -# -# return res - - -def put_permittivity_in_ucell(ucell, mat_list, mat_table, wl): - - res = np.zeros(ucell.shape, dtype='complex') +def put_permittivity_in_ucell(ucell, mat_list, mat_table, wl, type_complex=np.complex128): + + res = np.zeros(ucell.shape, dtype=type_complex) for z in range(ucell.shape[0]): for y in range(ucell.shape[1]): @@ -62,9 +25,10 @@ def put_permittivity_in_ucell(ucell, mat_list, mat_table, wl): return res -def put_permittivity_in_ucell_object(ucell_size, mat_list, obj_list, mat_table, wl): +def put_permittivity_in_ucell_object(ucell_size, mat_list, obj_list, mat_table, wl, + type_complex=np.complex128): # TODO: under development - res = np.zeros(ucell_size, dtype='complex') + res = np.zeros(ucell_size, dtype=type_complex) for material, obj_index in zip(mat_list, obj_list): if type(material) == str: @@ -117,16 +81,15 @@ def read_material_table(nk_path=None): return mat_table -# def fill_factor_to_ucell(patterns_fill_factor, wl, grating_type, mat_table): -# pattern_fill_factor = put_n_ridge_in_pattern_fill_factor(patterns_fill_factor, mat_table, wl) -# ucell = draw_fill_factor(pattern_fill_factor, grating_type) -# -# return ucell +def cell_compression(cell, type_complex=np.complex128): + if type_complex == np.complex128: + type_float = np.float64 + else: + type_float = np.float32 -def cell_compression(cell): # find discontinuities in x - step_y, step_x = 1. / np.array(cell.shape) + step_y, step_x = 1. / np.array(cell.shape, dtype=type_float) x = [] y = [] cell_x = [] @@ -154,12 +117,13 @@ def cell_compression(cell): return cell_comp, x, y -def fft_piecewise_constant(cell, fourier_order): +def fft_piecewise_constant(cell, fourier_order, type_complex=np.complex128): + if cell.shape[0] == 1: fourier_order = [0, fourier_order] else: fourier_order = [fourier_order, fourier_order] - cell, x, y = cell_compression(cell) + cell, x, y = cell_compression(cell, type_complex=type_complex) # X axis cell_next_x = np.roll(cell, -1, axis=1) @@ -167,9 +131,10 @@ def fft_piecewise_constant(cell, fourier_order): modes = np.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1) - f_coeffs_x = cell_diff_x @ np.exp(-1j * 2 * np.pi * x @ modes[None, :]) + f_coeffs_x = cell_diff_x @ np.exp(-1j * 2 * np.pi * x @ modes[None, :], dtype=type_complex) c = f_coeffs_x.shape[1] // 2 + # x_next = np.vstack(np.roll(x, -1, axis=0)[:-1]) - x x_next = np.vstack((np.roll(x, -1, axis=0)[:-1], 1)) - x f_coeffs_x[:, c] = (cell @ np.vstack((x[0], x_next[:-1]))).flatten() @@ -183,20 +148,22 @@ def fft_piecewise_constant(cell, fourier_order): modes = np.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1) - f_coeffs_xy = f_coeffs_x_diff_y.T @ np.exp(-1j * 2 * np.pi * y @ modes[None, :]) + f_coeffs_xy = f_coeffs_x_diff_y.T @ np.exp(-1j * 2 * np.pi * y @ modes[None, :], dtype=type_complex) c = f_coeffs_xy.shape[1] // 2 y_next = np.vstack((np.roll(y, -1, axis=0)[:-1], 1)) - y f_coeffs_xy[:, c] = f_coeffs_x.T @ np.vstack((y[0], y_next[:-1])).flatten() - mask = np.ones(f_coeffs_xy.shape[1], dtype=bool) - mask[c] = False - f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes[mask]) + + if c: + mask = np.ones(f_coeffs_xy.shape[1], dtype=bool) + mask[c] = False + f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes[mask]) return f_coeffs_xy.T -def to_conv_mat(pmt, fourier_order): +def to_conv_mat(pmt, fourier_order, type_complex=np.complex128): if len(pmt.shape) == 2: print('shape is 2') @@ -205,26 +172,30 @@ def to_conv_mat(pmt, fourier_order): if pmt.shape[1] == 1: # 1D - res = np.zeros((pmt.shape[0], ff, ff)).astype('complex') + res = np.zeros((pmt.shape[0], ff, ff)).astype(type_complex) for i, layer in enumerate(pmt): - f_coeffs = fft_piecewise_constant(layer, fourier_order) - A = np.roll(circulant(f_coeffs.flatten()), (f_coeffs.size + 1) // 2, 0) - res[i] = A[:2 * fourier_order + 1, :2 * fourier_order + 1] + f_coeffs = fft_piecewise_constant(layer, fourier_order, type_complex=type_complex) + + center = f_coeffs.shape[1] // 2 + conv_idx = np.arange(-ff + 1, ff, 1, dtype=int) + conv_idx = circulant(conv_idx) + e_conv = f_coeffs[0, center + conv_idx] + res[i] = e_conv else: # 2D # attention on the order of axis (Z Y X) - # TODO: separate fourier order - res = np.zeros((pmt.shape[0], ff ** 2, ff ** 2)).astype('complex') + res = np.zeros((pmt.shape[0], ff ** 2, ff ** 2)).astype(type_complex) for i, layer in enumerate(pmt): - pmtvy_fft = fft_piecewise_constant(layer, fourier_order) + pmtvy_fft = fft_piecewise_constant(layer, fourier_order, type_complex=type_complex) center = np.array(pmtvy_fft.shape) // 2 conv_idx = np.arange(-ff + 1, ff, 1) - conv_idx = circulant(conv_idx)[ff - 1:, :ff] + # conv_idx = circulant(conv_idx)[ff - 1:, :ff] + conv_idx = circulant(conv_idx) conv_i = np.repeat(conv_idx, ff, axis=1) conv_i = np.repeat(conv_i, [ff] * ff, axis=0) @@ -238,32 +209,18 @@ def to_conv_mat(pmt, fourier_order): # plt.colorbar() # plt.show() # - # return res - - -# def draw_fill_factor(patterns_fill_factor, grating_type, resolution=1000, mode=0): -# -# # res in Z X Y -# if grating_type == 2: -# res = np.zeros((len(patterns_fill_factor), resolution, resolution), dtype='complex') -# else: -# res = np.zeros((len(patterns_fill_factor), 1, resolution), dtype='complex') -# -# if grating_type in (0, 1): # TODO: handle this by len(fill_factor) -# # fill_factor is not exactly implemented. -# for i, (n_ridge, n_groove, fill_factor) in enumerate(patterns_fill_factor): -# permittivity = np.ones((1, resolution), dtype='complex') -# cut = int(resolution * fill_factor) -# permittivity[0, :cut] *= n_ridge ** 2 -# permittivity[0, cut:] *= n_groove ** 2 -# res[i, 0] = permittivity -# else: # 2D -# for i, (n_ridge, n_groove, fill_factor) in enumerate(patterns_fill_factor): -# fill_factor = np.array(fill_factor) -# permittivity = np.ones((resolution, resolution), dtype='complex') -# cut = (resolution * fill_factor) # TODO: need parenthesis? -# permittivity *= n_groove ** 2 -# permittivity[:int(cut[1]), :int(cut[0])] *= n_ridge ** 2 -# res[i] = permittivity -# -# return res + return res + + +def circulant(c): + + center = c.shape[0] // 2 + circ = np.zeros((center + 1, center + 1), dtype=int) + + for r in range(center+1): + idx = np.arange(r, r - center - 1, -1, dtype=int) + + assign_value = c[center + idx] + circ[r] = assign_value + + return circ diff --git a/meent/on_numpy/field_distribution.py b/meent/on_numpy/field_distribution.py index 4467342..eed903f 100644 --- a/meent/on_numpy/field_distribution.py +++ b/meent/on_numpy/field_distribution.py @@ -1,28 +1,31 @@ +import time + import numpy as np import matplotlib.pyplot as plt -from scipy.linalg import expm - def field_distribution(grating_type, *args, **kwargs): if grating_type == 0: res = field_dist_1d(*args, **kwargs) + elif grating_type == 1: + res = field_dist_1d_conical(*args, **kwargs) else: res = field_dist_2d(*args, **kwargs) return res -def field_dist_1d(wavelength, n_I, theta, fourier_order, T1, layer_info_list, period, pol, resolution=(100, 1, 100)): +def field_dist_1d(wavelength, n_I, theta, fourier_order, T1, layer_info_list, period, pol, resolution=(100, 1, 100), + type_complex=np.complex128): k0 = 2 * np.pi / wavelength fourier_indices = np.arange(-fourier_order, fourier_order + 1) - kx_vector = k0 * (n_I * np.sin(theta) - fourier_indices * (wavelength / period[0])).astype('complex') + kx_vector = k0 * (n_I * np.sin(theta) - fourier_indices * (wavelength / period[0])).astype(type_complex) Kx = np.diag(kx_vector / k0) resolution_z, resolution_y, resolution_x = resolution - field_cell = np.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 3), dtype='complex') + field_cell = np.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 3), dtype=type_complex) T_layer = T1 @@ -57,12 +60,12 @@ def field_dist_1d(wavelength, n_I, theta, fourier_order, T1, layer_info_list, pe Hx = -1j * Ux.T @ np.exp(-1j * kx_vector.reshape((-1, 1)) * x) Hz = f_here.T @ np.exp(-1j * kx_vector.reshape((-1, 1)) * x) - field_cell[resolution_z * idx_layer + k, j, i] = Ey, Hx, Hz + field_cell[resolution_z * idx_layer + k, j, i] = [Ey[0, 0], Hx[0, 0], Hz[0, 0]] else: # TM Uy = W @ (expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) Sx = V @ (-expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) - f_here = (-1j) * EKx @ Uy + f_here = (-1j) * EKx @ Uy # there is a better option for convergence for j in range(resolution_y): for i in range(resolution_x): @@ -72,32 +75,108 @@ def field_dist_1d(wavelength, n_I, theta, fourier_order, T1, layer_info_list, pe Ex = 1j * Sx.T @ np.exp(-1j * kx_vector.reshape((-1, 1)) * x) Ez = f_here.T @ np.exp(-1j * kx_vector.reshape((-1, 1)) * x) - field_cell[resolution_z * idx_layer + k, j, i] = Hy, Ex, Ez + field_cell[resolution_z * idx_layer + k, j, i] = [Hy[0, 0], Ex[0, 0], Ez[0, 0]] T_layer = a_i @ X @ T_layer return field_cell -def field_dist_2d(wavelength, n_I, theta, phi, fourier_order, T1, layer_info_list, period, resolution=(100, 100, 100)): +def field_dist_1d_conical(wavelength, n_I, theta, phi, fourier_order, T1, layer_info_list, period, + resolution=(100, 100, 100), type_complex=np.complex128): + + k0 = 2 * np.pi / wavelength + fourier_indices = np.arange(-fourier_order, fourier_order + 1) + + kx_vector = k0 * (n_I * np.sin(theta) * np.cos(phi) - fourier_indices * ( + wavelength / period[0])).astype(type_complex) + ky = k0 * n_I * np.sin(theta) * np.sin(phi) + + Kx = np.diag(kx_vector / k0) + + resolution_z, resolution_y, resolution_x = resolution + field_cell = np.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 6), dtype=type_complex) + + T_layer = T1 + + big_I = np.eye((len(T1)), dtype=type_complex) + + # From the first layer + for idx_layer, [E_conv_i, q_1, q_2, W_1, W_2, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d] \ + in enumerate(layer_info_list[::-1]): + + c = np.block([[big_I], [big_B @ big_A_i @ big_X]]) @ T_layer + + cut = len(c) // 4 + + c1_plus = c[0*cut:1*cut] + c2_plus = c[1*cut:2*cut] + c1_minus = c[2*cut:3*cut] + c2_minus = c[3*cut:4*cut] + + big_Q1 = np.diag(q_1) + big_Q2 = np.diag(q_2) + + for k in range(resolution_z): + z = k / resolution_z * d + + Sx = W_2 @ (expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z-d)) @ c2_minus) + + Sy = V_11 @ (expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z-d)) @ c1_minus) \ + + V_12 @ (expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z-d)) @ c2_minus) + + Ux = W_1 @ (-expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z-d)) @ c1_minus) + + Uy = V_21 @ (-expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z-d)) @ c1_minus) \ + + V_22 @ (-expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z-d)) @ c2_minus) + + Sz = -1j * E_conv_i @ (Kx @ Uy - ky * Ux) + + Uz = -1j * (Kx @ Sy - ky * Sx) + + for j in range(resolution_y): + for i in range(resolution_x): + x = i * period[0] / resolution_x + + exp_K = np.exp(-1j*kx_vector.reshape((-1, 1)) * x) + # exp_K = exp_K.flatten() + + Ex = Sx.T @ exp_K + Ey = Sy.T @ exp_K + Ez = Sz.T @ exp_K + + Hx = -1j * Ux.T @ exp_K + Hy = -1j * Uy.T @ exp_K + Hz = -1j * Uz.T @ exp_K + + field_cell[resolution_z * idx_layer + k, j, i] = [Ex[0, 0], Ey[0, 0], Ez[0, 0], Hx[0, 0], Hy[0, 0], Hz[0, 0]] + + T_layer = big_A_i @ big_X @ T_layer + + return field_cell + + +def field_dist_2d(wavelength, n_I, theta, phi, fourier_order, T1, layer_info_list, period, resolution=(100, 100, 100), + type_complex=np.complex128): + k0 = 2 * np.pi / wavelength fourier_indices = np.arange(-fourier_order, fourier_order + 1) ff = 2 * fourier_order + 1 kx_vector = k0 * (n_I * np.sin(theta) * np.cos(phi) - fourier_indices * ( - wavelength / period[0])).astype('complex') + wavelength / period[0])).astype(type_complex) ky_vector = k0 * (n_I * np.sin(theta) * np.sin(phi) - fourier_indices * ( - wavelength / period[1])).astype('complex') + wavelength / period[1])).astype(type_complex) Kx = np.diag(np.tile(kx_vector, ff).flatten()) / k0 Ky = np.diag(np.tile(ky_vector.reshape((-1, 1)), ff).flatten()) / k0 resolution_z, resolution_y, resolution_x = resolution - field_cell = np.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 6), dtype='complex') + field_cell = np.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 6), dtype=type_complex) T_layer = T1 - big_I = np.eye((len(T1))) + big_I = np.eye((len(T1)), dtype=type_complex) # From the first layer for idx_layer, (E_conv_i, q, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d)\ @@ -153,14 +232,13 @@ def field_dist_2d(wavelength, n_I, theta, phi, fourier_order, T1, layer_info_lis Hy = -1j * Uy.T @ exp_K Hz = -1j * Uz.T @ exp_K - field_cell[resolution_z * idx_layer + k, j, i] = [Ex, Ey, Ez, Hx, Hy, Hz] - + field_cell[resolution_z * idx_layer + k, j, i] = [Ex[0], Ey[0], Ez[0], Hx[0], Hy[0], Hz[0]] T_layer = big_A_i @ big_X @ T_layer return field_cell -def field_plot_zx(field_cell, pol=0, plot_indices=(1, 1, 1, 1, 1, 1), y_slice=0, z_slice=-1, zx=True, yx=True): +def field_plot_zx(field_cell, pol=0, plot_indices=(1, 1, 1, 1, 1, 1), y_slice=0, z_slice=-1, zx=True, yx=False): if field_cell.shape[-1] == 6: # 2D grating title = ['2D Ex', '2D Ey', '2D Ez', '2D Hx', '2D Hy', '2D Hz', ] @@ -182,22 +260,12 @@ def field_plot_zx(field_cell, pol=0, plot_indices=(1, 1, 1, 1, 1, 1), y_slice=0, for idx in range(len(title)): if plot_indices[idx]: plt.imshow((abs(field_cell[z_slice, :, :, idx]) ** 2), cmap='jet', aspect='auto') - plt.clim(0, 3.5) # identical to caxis([-4,4]) in MATLAB + # plt.clim(0, 3.5) # identical to caxis([-4,4]) in MATLAB plt.colorbar() plt.title(title[idx]) plt.show() - # for idx in range(len(title)): - # if plot_indices[idx]: - # plt.imshow((abs(field_cell[0, :, :, idx]) ** 2), cmap='jet', aspect='auto') - # # plt.clim(0, 1.3) # identical to caxis([-4,4]) in MATLAB - # plt.colorbar() - # plt.title(title[idx]) - # plt.show() - # for idx in range(len(title)): - # if plot_indices[idx]: - # plt.imshow((abs(field_cell[-1, :, :, idx]) ** 2), cmap='jet', aspect='auto') - # # plt.clim(0, 3.2) # identical to caxis([-4,4]) in MATLAB - # plt.colorbar() - # plt.title(title[idx]) - # plt.show() + +def expm(x): + return np.diag(np.exp(np.diag(x))) + diff --git a/meent/on_numpy/rcwa.py b/meent/on_numpy/rcwa.py index 226e799..3b70c50 100644 --- a/meent/on_numpy/rcwa.py +++ b/meent/on_numpy/rcwa.py @@ -3,24 +3,29 @@ from ._base import _BaseRCWA from .convolution_matrix import to_conv_mat, put_permittivity_in_ucell, read_material_table -from .field_distribution import field_dist_1d, field_dist_2d, field_plot_zx +from .field_distribution import field_dist_1d, field_dist_2d, field_plot_zx, field_dist_1d_conical -class RCWALight(_BaseRCWA): +class RCWANumpy(_BaseRCWA): def __init__(self, mode=0, grating_type=0, n_I=1., n_II=1., theta=0, phi=0, psi=0, fourier_order=40, period=(100,), - wavelength=np.linspace(900, 900, 1), pol=0, patterns=None, ucell=None, ucell_materials=None, thickness=None, algo='TMM'): + wavelength=900, pol=0, patterns=None, ucell=None, ucell_materials=None, + thickness=None, algo='TMM', perturbation=1E-10, + device='cpu', type_complex=np.complex128): - super().__init__(grating_type, n_I, n_II, theta, phi, psi, fourier_order, period, wavelength, pol, patterns, ucell, ucell_materials, - thickness, algo) + super().__init__(grating_type, n_I, n_II, theta, phi, psi, fourier_order, period, wavelength, pol, patterns, + ucell, ucell_materials, + thickness, algo, perturbation, device, type_complex) + self.device = 'cpu' self.mode = mode - self.spectrum_r, self.spectrum_t = None, None - # self.init_spectrum_array() + self.type_complex = type_complex + self.mat_table = read_material_table() + self.layer_info_list = [] def solve(self, wavelength, e_conv_all, o_e_conv_all): - # TODO: !handle uniform layer + self.get_kx_vector() if self.grating_type == 0: de_ri, de_ti = self.solve_1d(wavelength, e_conv_all, o_e_conv_all) @@ -33,51 +38,15 @@ def solve(self, wavelength, e_conv_all, o_e_conv_all): return de_ri.real, de_ti.real - # def loop_wavelength_fill_factor_(self, wavelength_array=None): - # - # if wavelength_array is not None: - # self.wavelength = wavelength_array - # self.init_spectrum_array() - # - # for i, wl in enumerate(self.wavelength): - # - # ucell = fill_factor_to_ucell(self.patterns, wl, self.grating_type, self.mat_table) - # e_conv_all = to_conv_mat(ucell, self.fourier_order) - # o_e_conv_all = to_conv_mat(1 / ucell, self.fourier_order) - # - # de_ri, de_ti = self.solve(wl, e_conv_all, o_e_conv_all) - # self.spectrum_r[i] = de_ri - # self.spectrum_t[i] = de_ti - # - # return self.spectrum_r, self.spectrum_t - # - # def loop_wavelength_ucell_(self, wavelength_array=None): - # - # if wavelength_array is not None: - # self.wavelength = wavelength_array - # self.init_spectrum_array() - # - # for i, wl in enumerate(self.wavelength): - # - # ucell = put_permittivity_in_ucell(self.ucell, wl, self.grating_type, self.mat_table) - # e_conv_all = to_conv_mat(ucell, self.fourier_order) - # o_e_conv_all = to_conv_mat(1 / ucell, self.fourier_order) - # - # de_ri, de_ti = self.solve(wl, e_conv_all, o_e_conv_all) - # - # self.spectrum_r[i] = de_ri - # self.spectrum_t[i] = de_ti - # - # return self.spectrum_r, self.spectrum_t - def run_ucell(self): - ucell = put_permittivity_in_ucell(self.ucell, self.ucell_materials, self.mat_table, self.wavelength) + ucell = put_permittivity_in_ucell(self.ucell, self.ucell_materials, self.mat_table, self.wavelength, + type_complex=self.type_complex) - e_conv_all = to_conv_mat(ucell, self.fourier_order) - o_e_conv_all = to_conv_mat(1 / ucell, self.fourier_order) + E_conv_all = to_conv_mat(ucell, self.fourier_order, type_complex=self.type_complex) + o_E_conv_all = to_conv_mat(1 / ucell, self.fourier_order, type_complex=self.type_complex) - de_ri, de_ti = self.solve(self.wavelength, e_conv_all, o_e_conv_all) + de_ri, de_ti = self.solve(self.wavelength, E_conv_all, o_E_conv_all) return de_ri, de_ti @@ -86,14 +55,23 @@ def calculate_field(self, resolution=None, plot=True): if self.grating_type == 0: resolution = [100, 1, 100] if not resolution else resolution field_cell = field_dist_1d(self.wavelength, self.n_I, self.theta, self.fourier_order, self.T1, - self.layer_info_list, self.period, self.pol, resolution=resolution) + self.layer_info_list, self.period, self.pol, resolution=resolution, + type_complex=self.type_complex) + elif self.grating_type == 1: + resolution = [100, 1, 100] if not resolution else resolution + field_cell = field_dist_1d_conical(self.wavelength, self.n_I, self.theta, self.phi, self.fourier_order, + self.T1, + self.layer_info_list, self.period, resolution=resolution, + type_complex=self.type_complex) + else: resolution = [100, 100, 100] if not resolution else resolution + t0 = time.time() field_cell = field_dist_2d(self.wavelength, self.n_I, self.theta, self.phi, self.fourier_order, self.T1, - self.layer_info_list, self.period, resolution=resolution) - + self.layer_info_list, self.period, resolution=resolution, + type_complex=self.type_complex) + print(time.time() - t0) if plot: field_plot_zx(field_cell, self.pol) return field_cell - diff --git a/meent/on_numpy/transfer_method.py b/meent/on_numpy/transfer_method.py index 15bbcc4..ccbe091 100644 --- a/meent/on_numpy/transfer_method.py +++ b/meent/on_numpy/transfer_method.py @@ -1,9 +1,10 @@ import numpy as np -def transfer_1d_1(ff, polarization, k0, n_I, n_II, theta, delta_i0, fourier_order, fourier_indices, wavelength, period): +def transfer_1d_1(ff, polarization, k0, n_I, n_II, kx_vector, theta, delta_i0, fourier_order, + type_complex=np.complex128): - kx_vector = k0 * (n_I * np.sin(theta) - fourier_indices * (wavelength / period[0])).astype('complex') + # kx_vector = k0 * (n_I * np.sin(theta) - fourier_indices * (wavelength / period[0])).astype(type_complex) k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2) ** 0.5 k_II_z = (k0 ** 2 * n_II ** 2 - kx_vector ** 2) ** 0.5 @@ -13,7 +14,7 @@ def transfer_1d_1(ff, polarization, k0, n_I, n_II, theta, delta_i0, fourier_orde Kx = np.diag(kx_vector / k0) - f = np.eye(ff) + f = np.eye(ff, dtype=type_complex) if polarization == 0: # TE Y_I = np.diag(k_I_z / k0) @@ -34,12 +35,13 @@ def transfer_1d_1(ff, polarization, k0, n_I, n_II, theta, delta_i0, fourier_orde else: raise ValueError - T = np.eye(2 * fourier_order + 1) + T = np.eye(2 * fourier_order + 1, dtype=type_complex) return kx_vector, Kx, k_I_z, k_II_z, f, YZ_I, g, inc_term, T -def transfer_1d_2(k0, q, d, W, V, f, g, fourier_order, T): +def transfer_1d_2(k0, q, d, W, V, f, g, fourier_order, T, type_complex=np.complex128): + X = np.diag(np.exp(-k0 * q * d)) W_i = np.linalg.inv(W) @@ -50,14 +52,15 @@ def transfer_1d_2(k0, q, d, W, V, f, g, fourier_order, T): a_i = np.linalg.inv(a) - f = W @ (np.eye(2 * fourier_order + 1) + X @ b @ a_i @ X) - g = V @ (np.eye(2 * fourier_order + 1) - X @ b @ a_i @ X) + f = W @ (np.eye(2 * fourier_order + 1, dtype=type_complex) + X @ b @ a_i @ X) + g = V @ (np.eye(2 * fourier_order + 1, dtype=type_complex) - X @ b @ a_i @ X) T = T @ a_i @ X return X, f, g, T, a_i, b def transfer_1d_3(g1, YZ_I, f1, delta_i0, inc_term, T, k_I_z, k0, n_I, n_II, theta, polarization, k_II_z): + T1 = np.linalg.inv(g1 + 1j * YZ_I @ f1) @ (1j * YZ_I @ delta_i0 + inc_term) R = f1 @ T1 - delta_i0 T = T @ T1 @@ -75,32 +78,25 @@ def transfer_1d_3(g1, YZ_I, f1, delta_i0, inc_term, T, k_I_z, k0, n_I, n_II, the return de_ri, de_ti, T1 -def transfer_2d_1(ff, k0, n_I, n_II, period, fourier_indices, theta, phi, wavelength, perturbation=1E-20 * (1 + 1j)): - I = np.eye(ff ** 2) - O = np.zeros((ff ** 2, ff ** 2)) +def transfer_1d_conical_1(ff, k0, n_I, n_II, kx_vector, theta, phi, type_complex=np.complex128): - kx_vector = k0 * (n_I * np.sin(theta) * np.cos(phi) - fourier_indices * ( - wavelength / period[0])).astype('complex') - ky_vector = k0 * (n_I * np.sin(theta) * np.sin(phi) - fourier_indices * ( - wavelength / period[1])).astype('complex') + I = np.eye(ff, dtype=type_complex) + O = np.zeros((ff, ff), dtype=type_complex) - Kx = np.diag(np.tile(kx_vector, ff).flatten()) / k0 - Ky = np.diag(np.tile(ky_vector.reshape((-1, 1)), ff).flatten()) / k0 + # kx_vector = k0 * (n_I * np.sin(theta) * np.cos(phi) - fourier_indices * (wavelength / period[0]) + # ).astype(type_complex) - k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 - k_II_z = (k0 ** 2 * n_II ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 + ky = k0 * n_I * np.sin(theta) * np.sin(phi) - k_I_z = k_I_z.flatten().conjugate() - k_II_z = k_II_z.flatten().conjugate() + k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky ** 2) ** 0.5 + k_II_z = (k0 ** 2 * n_II ** 2 - kx_vector ** 2 - ky ** 2) ** 0.5 - idx = np.nonzero(kx_vector == 0)[0] - if len(idx): - # TODO: need imaginary part? - # TODO: make imaginary part sign consistent - kx_vector[idx] = perturbation - print(wavelength, 'varphi divide by 0: adding perturbation') + k_I_z = k_I_z.conjugate() + k_II_z = k_II_z.conjugate() - varphi = np.arctan(ky_vector.reshape((-1, 1)) / kx_vector).flatten() + Kx = np.diag(kx_vector / k0) + + varphi = np.arctan(ky / kx_vector) Y_I = np.diag(k_I_z / k0) Y_II = np.diag(k_II_z / k0) @@ -111,76 +107,58 @@ def transfer_2d_1(ff, k0, n_I, n_II, period, fourier_indices, theta, phi, wavele big_F = np.block([[I, O], [O, 1j * Z_II]]) big_G = np.block([[1j * Y_II, O], [O, I]]) - big_T = np.eye(ff ** 2 * 2) - - return kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T - + big_T = np.eye(2 * ff, dtype=type_complex) -def transfer_2d_wv(ff, Kx, E_i, Ky, o_E_conv_i, E_conv, center): - - I = np.eye(ff ** 2) - - B = Kx @ E_i @ Kx - I - D = Ky @ E_i @ Ky - I + return Kx, ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T - S2_from_S = np.block( - [ - [Ky ** 2 + B @ o_E_conv_i, Kx @ (E_i @ Ky @ E_conv - Ky)], - [Ky @ (E_i @ Kx @ o_E_conv_i - Kx), Kx ** 2 + D @ E_conv] - ]) - eigenvalues, W = np.linalg.eig(S2_from_S) +def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varphi, big_F, big_G, big_T, + type_complex=np.complex128): - q = eigenvalues ** 0.5 + I = np.eye(ff, dtype=type_complex) + O = np.zeros((ff, ff), dtype=type_complex) - Q = np.diag(q) - Q_i = np.linalg.inv(Q) - U1_from_S = np.block( - [ - [-Kx @ Ky, Kx ** 2 - E_conv], - [o_E_conv_i - Ky ** 2, Ky @ Kx] - ] - ) - V = U1_from_S @ W @ Q_i - - return W, V, q + A = Kx ** 2 - E_conv + B = Kx @ E_conv_i @ Kx - I + A_i = np.linalg.inv(A) + B_i = np.linalg.inv(B) + to_decompose_W_1 = ky ** 2 * I + A + to_decompose_W_2 = ky ** 2 * I + B @ o_E_conv_i -def transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T): + eigenvalues_1, W_1 = np.linalg.eig(to_decompose_W_1) + eigenvalues_2, W_2 = np.linalg.eig(to_decompose_W_2) - q1 = q[:center] - q2 = q[center:] + q_1 = eigenvalues_1 ** 0.5 + q_2 = eigenvalues_2 ** 0.5 - W_11 = W[:center, :center] - W_12 = W[:center, center:] - W_21 = W[center:, :center] - W_22 = W[center:, center:] + Q_1 = np.diag(q_1) + Q_2 = np.diag(q_2) - V_11 = V[:center, :center] - V_12 = V[:center, center:] - V_21 = V[center:, :center] - V_22 = V[center:, center:] + V_11 = A_i @ W_1 @ Q_1 + V_12 = (ky / k0) * A_i @ Kx @ W_2 + V_21 = (ky / k0) * B_i @ Kx @ E_conv_i @ W_1 + V_22 = B_i @ W_2 @ Q_2 - X_1 = np.diag(np.exp(-k0 * q1 * d)) - X_2 = np.diag(np.exp(-k0 * q2 * d)) + X_1 = np.diag(np.exp(-k0 * q_1 * d)) + X_2 = np.diag(np.exp(-k0 * q_2 * d)) F_c = np.diag(np.cos(varphi)) F_s = np.diag(np.sin(varphi)) - W_ss = F_c @ W_21 - F_s @ W_11 - W_sp = F_c @ W_22 - F_s @ W_12 - W_ps = F_c @ W_11 + F_s @ W_21 - W_pp = F_c @ W_12 + F_s @ W_22 - - V_ss = F_c @ V_11 + F_s @ V_21 - V_sp = F_c @ V_12 + F_s @ V_22 - V_ps = F_c @ V_21 - F_s @ V_11 - V_pp = F_c @ V_22 - F_s @ V_12 + V_ss = F_c @ V_11 + V_sp = F_c @ V_12 - F_s @ W_2 + W_ss = F_c @ W_1 + F_s @ V_21 + W_sp = F_s @ V_22 + W_ps = F_s @ V_11 + W_pp = F_c @ W_2 + F_s @ V_12 + V_ps = F_c @ V_21 - F_s @ W_1 + V_pp = F_c @ V_22 - big_I = np.eye(2 * (len(I))) + big_I = np.eye(2 * (len(I)), dtype=type_complex) big_X = np.block([[X_1, O], [O, X_2]]) - big_W = np.block([[W_ss, W_sp], [W_ps, W_pp]]) - big_V = np.block([[V_ss, V_sp], [V_ps, V_pp]]) + big_W = np.block([[V_ss, V_sp], [W_ps, W_pp]]) + big_V = np.block([[W_ss, W_sp], [V_ps, V_pp]]) big_W_i = np.linalg.inv(big_W) big_V_i = np.linalg.inv(big_V) @@ -195,22 +173,24 @@ def transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T): big_T = big_T @ big_A_i @ big_X - return big_X, big_F, big_G, big_T, big_A_i, big_B, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 + return big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2 -def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z): - I = np.eye(ff ** 2) - O = np.zeros((ff ** 2, ff ** 2)) +def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, + type_complex=np.complex128): - big_F_11 = big_F[:center, :center] - big_F_12 = big_F[:center, center:] - big_F_21 = big_F[center:, :center] - big_F_22 = big_F[center:, center:] + I = np.eye(ff, dtype=type_complex) + O = np.zeros((ff, ff), dtype=type_complex) - big_G_11 = big_G[:center, :center] - big_G_12 = big_G[:center, center:] - big_G_21 = big_G[center:, :center] - big_G_22 = big_G[center:, center:] + big_F_11 = big_F[:ff, :ff] + big_F_12 = big_F[:ff, ff:] + big_F_21 = big_F[ff:, :ff] + big_F_22 = big_F[ff:, ff:] + + big_G_11 = big_G[:ff, :ff] + big_G_12 = big_G[:ff, ff:] + big_G_21 = big_G[ff:, :ff] + big_G_22 = big_G[ff:, ff:] # Final Equation in form of AX=B final_A = np.block( @@ -222,25 +202,23 @@ def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i ] ) - final_B = np.block( - [ - [-np.sin(psi) * delta_i0], - [-np.cos(psi) * np.cos(theta) * delta_i0], - [-1j * np.sin(psi) * n_I * np.cos(theta) * delta_i0], - [1j * n_I * np.cos(psi) * delta_i0] - ] - ) + final_B = np.hstack([ + [-np.sin(psi) * delta_i0], + [-np.cos(psi) * np.cos(theta) * delta_i0], + [-1j * np.sin(psi) * n_I * np.cos(theta) * delta_i0], + [1j * n_I * np.cos(psi) * delta_i0] + ]).T final_RT = np.linalg.inv(final_A) @ final_B - R_s = final_RT[:ff ** 2, :].flatten() - R_p = final_RT[ff ** 2:2 * ff ** 2, :].flatten() + R_s = final_RT[:ff, :].flatten() + R_p = final_RT[ff:2 * ff, :].flatten() - big_T1 = final_RT[2 * ff ** 2:, :] + big_T1 = final_RT[2 * ff:, :] big_T = big_T @ big_T1 - T_s = big_T[:ff ** 2, :].flatten() - T_p = big_T[ff ** 2:, :].flatten() + T_s = big_T[:ff, :].flatten() + T_p = big_T[ff:, :].flatten() de_ri = R_s * np.conj(R_s) * np.real(k_I_z / (k0 * n_I * np.cos(theta))) \ + R_p * np.conj(R_p) * np.real((k_I_z / n_I ** 2) / (k0 * n_I * np.cos(theta))) @@ -251,27 +229,28 @@ def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i return de_ri.real, de_ti.real, big_T1 -def transfer_1d_conical_1(ff, k0, n_I, n_II, period, fourier_indices, theta, phi, wavelength, perturbation=1E-20 * (1 + 1j)): - I = np.eye(ff) - O = np.zeros((ff, ff)) +def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, phi, wavelength, + type_complex=np.complex128): - kx_vector = k0 * (n_I * np.sin(theta) * np.cos(phi) - fourier_indices * (wavelength / period[0])).astype('complex') - ky = k0 * n_I * np.sin(theta) * np.sin(phi) + I = np.eye(ff ** 2, dtype=type_complex) + O = np.zeros((ff ** 2, ff ** 2), dtype=type_complex) - k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky ** 2) ** 0.5 - k_II_z = (k0 ** 2 * n_II ** 2 - kx_vector ** 2 - ky ** 2) ** 0.5 + # kx_vector = k0 * (n_I * np.sin(theta) * np.cos(phi) - fourier_indices * ( + # wavelength / period[0])).astype(type_complex) - k_I_z = k_I_z.conjugate() - k_II_z = k_II_z.conjugate() + ky_vector = k0 * (n_I * np.sin(theta) * np.sin(phi) - fourier_indices * ( + wavelength / period[1])).astype(type_complex) - idx = np.nonzero(kx_vector == 0)[0] - if len(idx): - # TODO: need imaginary part? - # TODO: make imaginary part sign consistent - kx_vector[idx] = perturbation # TODO: test - print(wavelength, 'varphi divide by 0: adding perturbation') + k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 + k_II_z = (k0 ** 2 * n_II ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 - varphi = np.arctan(ky / kx_vector) + k_I_z = k_I_z.flatten().conjugate() + k_II_z = k_II_z.flatten().conjugate() + + Kx = np.diag(np.tile(kx_vector, ff).flatten()) / k0 + Ky = np.diag(np.tile(ky_vector.reshape((-1, 1)), ff).flatten()) / k0 + + varphi = np.arctan(ky_vector.reshape((-1, 1)) / kx_vector).flatten() Y_I = np.diag(k_I_z / k0) Y_II = np.diag(k_II_z / k0) @@ -279,62 +258,79 @@ def transfer_1d_conical_1(ff, k0, n_I, n_II, period, fourier_indices, theta, phi Z_I = np.diag(k_I_z / (k0 * n_I ** 2)) Z_II = np.diag(k_II_z / (k0 * n_II ** 2)) - Kx = np.diag(kx_vector / k0) - big_F = np.block([[I, O], [O, 1j * Z_II]]) big_G = np.block([[1j * Y_II, O], [O, I]]) - big_T = np.eye(2 * ff) + big_T = np.eye(ff ** 2 * 2, dtype=type_complex) - return Kx, ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T + return kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T -def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_i, oneover_E_conv_i, ff, d, varphi, big_F, big_G, big_T): +def transfer_2d_wv(ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=np.complex128): - I = np.eye(ff) - O = np.zeros((ff, ff)) + I = np.eye(ff ** 2, dtype=type_complex) - A = Kx ** 2 - E_conv - B = Kx @ E_i @ Kx - I - A_i = np.linalg.inv(A) - B_i = np.linalg.inv(B) + B = Kx @ E_conv_i @ Kx - I + D = Ky @ E_conv_i @ Ky - I - to_decompose_W_1 = ky ** 2 * I + A - to_decompose_W_2 = ky ** 2 * I + B @ oneover_E_conv_i + S2_from_S = np.block( + [ + [Ky ** 2 + B @ o_E_conv_i, Kx @ (E_conv_i @ Ky @ E_conv - Ky)], + [Ky @ (E_conv_i @ Kx @ o_E_conv_i - Kx), Kx ** 2 + D @ E_conv] + ]) - eigenvalues_1, W_1 = np.linalg.eig(to_decompose_W_1) - eigenvalues_2, W_2 = np.linalg.eig(to_decompose_W_2) + eigenvalues, W = np.linalg.eig(S2_from_S) - q_1 = eigenvalues_1 ** 0.5 - q_2 = eigenvalues_2 ** 0.5 + q = eigenvalues ** 0.5 - Q_1 = np.diag(q_1) - Q_2 = np.diag(q_2) + Q = np.diag(q) + Q_i = np.linalg.inv(Q) + U1_from_S = np.block( + [ + [-Kx @ Ky, Kx ** 2 - E_conv], + [o_E_conv_i - Ky ** 2, Ky @ Kx] + ] + ) + V = U1_from_S @ W @ Q_i - V_11 = A_i @ W_1 @ Q_1 - V_12 = (ky / k0) * A_i @ Kx @ W_2 - V_21 = (ky / k0) * B_i @ Kx @ E_i @ W_1 - V_22 = B_i @ W_2 @ Q_2 + return W, V, q - X_1 = np.diag(np.exp(-k0 * q_1 * d)) - X_2 = np.diag(np.exp(-k0 * q_2 * d)) + +def transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T, type_complex=np.complex128): + + q1 = q[:center] + q2 = q[center:] + + W_11 = W[:center, :center] + W_12 = W[:center, center:] + W_21 = W[center:, :center] + W_22 = W[center:, center:] + + V_11 = V[:center, :center] + V_12 = V[:center, center:] + V_21 = V[center:, :center] + V_22 = V[center:, center:] + + X_1 = np.diag(np.exp(-k0 * q1 * d)) + X_2 = np.diag(np.exp(-k0 * q2 * d)) F_c = np.diag(np.cos(varphi)) F_s = np.diag(np.sin(varphi)) - V_ss = F_c @ V_11 - V_sp = F_c @ V_12 - F_s @ W_2 - W_ss = F_c @ W_1 + F_s @ V_21 - W_sp = F_s @ V_22 - W_ps = F_s @ V_11 - W_pp = F_c @ W_2 + F_s @ V_12 - V_ps = F_c @ V_21 - F_s @ W_1 - V_pp = F_c @ V_22 + W_ss = F_c @ W_21 - F_s @ W_11 + W_sp = F_c @ W_22 - F_s @ W_12 + W_ps = F_c @ W_11 + F_s @ W_21 + W_pp = F_c @ W_12 + F_s @ W_22 + + V_ss = F_c @ V_11 + F_s @ V_21 + V_sp = F_c @ V_12 + F_s @ V_22 + V_ps = F_c @ V_21 - F_s @ V_11 + V_pp = F_c @ V_22 - F_s @ V_12 - big_I = np.eye(2 * (len(I))) + big_I = np.eye(2 * (len(I)), dtype=type_complex) big_X = np.block([[X_1, O], [O, X_2]]) - big_W = np.block([[V_ss, V_sp], [W_ps, W_pp]]) - big_V = np.block([[W_ss, W_sp], [V_ps, V_pp]]) + big_W = np.block([[W_ss, W_sp], [W_ps, W_pp]]) + big_V = np.block([[V_ss, V_sp], [V_ps, V_pp]]) big_W_i = np.linalg.inv(big_W) big_V_i = np.linalg.inv(big_V) @@ -349,22 +345,24 @@ def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_i, oneover_E_conv_i, ff, d, varp big_T = big_T @ big_A_i @ big_X - return big_F, big_G, big_T + return big_X, big_F, big_G, big_T, big_A_i, big_B, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 -def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z): - I = np.eye(ff) - O = np.zeros((ff, ff)) +def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, + type_complex=np.complex128): - big_F_11 = big_F[:ff, :ff] - big_F_12 = big_F[:ff, ff:] - big_F_21 = big_F[ff:, :ff] - big_F_22 = big_F[ff:, ff:] + I = np.eye(ff ** 2, dtype=type_complex) + O = np.zeros((ff ** 2, ff ** 2), dtype=type_complex) - big_G_11 = big_G[:ff, :ff] - big_G_12 = big_G[:ff, ff:] - big_G_21 = big_G[ff:, :ff] - big_G_22 = big_G[ff:, ff:] + big_F_11 = big_F[:center, :center] + big_F_12 = big_F[:center, center:] + big_F_21 = big_F[center:, :center] + big_F_22 = big_F[center:, center:] + + big_G_11 = big_G[:center, :center] + big_G_12 = big_G[:center, center:] + big_G_21 = big_G[center:, :center] + big_G_22 = big_G[center:, center:] # Final Equation in form of AX=B final_A = np.block( @@ -376,27 +374,31 @@ def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i ] ) - final_B = np.hstack([ - [-np.sin(psi) * delta_i0], - [-np.cos(psi) * np.cos(theta) * delta_i0], - [-1j * np.sin(psi) * n_I * np.cos(theta) * delta_i0], - [1j * n_I * np.cos(psi) * delta_i0] - ]).T + final_B = np.block( + [ + [-np.sin(psi) * delta_i0], + [-np.cos(psi) * np.cos(theta) * delta_i0], + [-1j * np.sin(psi) * n_I * np.cos(theta) * delta_i0], + [1j * n_I * np.cos(psi) * delta_i0] + ] + ) - final_X = np.linalg.inv(final_A) @ final_B + final_RT = np.linalg.inv(final_A) @ final_B - R_s = final_X[:ff, :].flatten() - R_p = final_X[ff:2 * ff, :].flatten() + R_s = final_RT[:ff ** 2, :].flatten() + R_p = final_RT[ff ** 2:2 * ff ** 2, :].flatten() - big_T = big_T @ final_X[2 * ff:, :] - T_s = big_T[:ff, :].flatten() - T_p = big_T[ff:, :].flatten() + big_T1 = final_RT[2 * ff ** 2:, :] + big_T = big_T @ big_T1 + + T_s = big_T[:ff ** 2, :].flatten() + T_p = big_T[ff ** 2:, :].flatten() de_ri = R_s * np.conj(R_s) * np.real(k_I_z / (k0 * n_I * np.cos(theta))) \ - + R_p * np.conj(R_p) * np.real((k_I_z / n_I ** 2) / (k0 * n_I * np.cos(theta))) + + R_p * np.conj(R_p) * np.real((k_I_z / n_I ** 2) / (k0 * n_I * np.cos(theta))) de_ti = T_s * np.conj(T_s) * np.real(k_II_z / (k0 * n_I * np.cos(theta))) \ - + T_p * np.conj(T_p) * np.real((k_II_z / n_II ** 2) / (k0 * n_I * np.cos(theta))) + + T_p * np.conj(T_p) * np.real((k_II_z / n_II ** 2) / (k0 * n_I * np.cos(theta))) - return de_ri.real, de_ti.real + return de_ri.real, de_ti.real, big_T1 diff --git a/meent/on_torch/__init__.py b/meent/on_torch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/meent/on_torch/_base.py b/meent/on_torch/_base.py new file mode 100644 index 0000000..fb48a3b --- /dev/null +++ b/meent/on_torch/_base.py @@ -0,0 +1,292 @@ +from copy import deepcopy + +import numpy as np +import torch + +from .scattering_method import scattering_1d_1, scattering_1d_2, scattering_1d_3, scattering_2d_1, scattering_2d_wv, \ + scattering_2d_2, scattering_2d_3 +from .transfer_method import transfer_1d_1, transfer_1d_2, transfer_1d_3, transfer_1d_conical_1, transfer_1d_conical_2, \ + transfer_1d_conical_3, transfer_2d_1, transfer_2d_wv, transfer_2d_2, transfer_2d_3 + + +class _BaseRCWA: + def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., fourier_order=10, + period=0.7, wavelength=np.linspace(0.5, 2.3, 400), pol=0, + patterns=None, ucell=None, ucell_materials=None, thickness=None, algo='TMM', perturbation=1E-10, + device='cpu', type_complex=torch.complex128): + + self.device = device + self.type_complex = type_complex + + # common + self.grating_type = grating_type # 1D=0, 1D_conical=1, 2D=2 + self.n_I = n_I + self.n_II = n_II + + self.theta = torch.tensor(theta * np.pi / 180) + self.phi = torch.tensor(phi * np.pi / 180) + self.psi = torch.tensor(psi * np.pi / 180) # TODO: integrate psi and pol + + self.pol = pol # TE 0, TM 1 + if self.pol == 0: # TE + self.psi = torch.tensor(90 * np.pi / 180, device=self.device) + elif self.pol == 1: # TM + self.psi = torch.tensor(0 * np.pi / 180, device=self.device) + else: + print('not implemented yet') + raise ValueError + + self.fourier_order = fourier_order + self.ff = 2 * self.fourier_order + 1 + + self.period = deepcopy(period) + + self.wavelength = wavelength + + self.patterns = patterns + self.ucell = deepcopy(ucell) + self.ucell_materials = ucell_materials + self.thickness = deepcopy(thickness) + + self.algo = algo + self.perturbation = perturbation + + self.layer_info_list = [] + self.T1 = None + + self.kx_vector = None + + def get_kx_vector(self): + + k0 = 2 * np.pi / self.wavelength + fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) + if self.grating_type == 0: + kx_vector = k0 * (self.n_I * torch.sin(self.theta) - fourier_indices * (self.wavelength / self.period[0]) + ).type(self.type_complex) + else: + kx_vector = k0 * (self.n_I * torch.sin(self.theta) * torch.cos(self.phi) - fourier_indices * ( + self.wavelength / self.period[0])).type(self.type_complex) + + idx = torch.nonzero(kx_vector == 0) + if len(idx): + # TODO: need imaginary part? + # TODO: make imaginary part sign consistent + kx_vector[idx] = self.perturbation + print('varphi divide by 0: adding perturbation') + + self.kx_vector = kx_vector + + def solve_1d(self, wl, E_conv_all, o_E_conv_all): + + self.layer_info_list = [] + self.T1 = None + + fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) + + delta_i0 = torch.zeros(self.ff, device=self.device, dtype=self.type_complex) + delta_i0[self.fourier_order] = 1 + + k0 = 2 * np.pi / wl + + if self.algo == 'TMM': + kx_vector, Kx, k_I_z, k_II_z, f, YZ_I, g, inc_term, T \ + = transfer_1d_1(self.ff, self.pol, k0, self.n_I, self.n_II, self.kx_vector, + self.theta, delta_i0, self.fourier_order, + device=self.device, type_complex=self.type_complex) + elif self.algo == 'SMM': + Kx, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ + = scattering_1d_1(k0, self.n_I, self.n_II, self.theta, self.phi, fourier_indices, self.period, + self.pol, wl=wl) + else: + raise ValueError + + count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) + + # From the last layer + for layer_index in range(count)[::-1]: + + E_conv = E_conv_all[layer_index] + o_E_conv = o_E_conv_all[layer_index] + d = self.thickness[layer_index] + + if self.pol == 0: + E_conv_i = None + A = Kx ** 2 - E_conv + eigenvalues, W = torch.linalg.eig(A) + q = eigenvalues ** 0.5 + + Q = torch.diag(q) + V = W @ Q + + elif self.pol == 1: + E_conv_i = torch.linalg.inv(E_conv) + B = Kx @ E_conv_i @ Kx - torch.eye(E_conv.shape[0], device=self.device, dtype=self.type_complex) + o_E_conv_i = torch.linalg.inv(o_E_conv) + + eigenvalues, W = torch.linalg.eig(o_E_conv_i @ B) + q = eigenvalues ** 0.5 + + Q = torch.diag(q) + V = o_E_conv @ W @ Q + + else: + raise ValueError + + if self.algo == 'TMM': + X, f, g, T, a_i, b = transfer_1d_2(k0, q, d, W, V, f, g, self.fourier_order, T, + device=self.device, type_complex=self.type_complex) + + layer_info = [E_conv_i, q, W, X, a_i, b, d] + self.layer_info_list.append(layer_info) + + elif self.algo == 'SMM': + A, B, S_dict, Sg = scattering_1d_2(W, Wg, V, Vg, d, k0, Q, Sg) + else: + raise ValueError + + if self.algo == 'TMM': + de_ri, de_ti, T1 = transfer_1d_3(g, YZ_I, f, delta_i0, inc_term, T, k_I_z, k0, self.n_I, self.n_II, + self.theta, self.pol, k_II_z) + self.T1 = T1 + + elif self.algo == 'SMM': + de_ri, de_ti = scattering_1d_3(Wt, Wg, Vt, Vg, Sg, self.ff, Wr, self.fourier_order, Kzr, Kzt, + self.n_I, self.n_II, self.theta, self.pol) + else: + raise ValueError + + return de_ri, de_ti + + # TODO: scattering method + def solve_1d_conical(self, wl, E_conv_all, o_E_conv_all): + + self.layer_info_list = [] + self.T1 = None + + fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) + + delta_i0 = torch.zeros(self.ff, device=self.device, dtype=self.type_complex) + delta_i0[self.fourier_order] = 1 + + k0 = 2 * np.pi / wl + + if self.algo == 'TMM': + Kx, ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ + = transfer_1d_conical_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.theta, self.phi, + device=self.device, type_complex=self.type_complex) + elif self.algo == 'SMM': + print('SMM for 1D conical is not implemented') + return np.nan, np.nan + else: + raise ValueError + + count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) + + # From the last layer + for layer_index in range(count)[::-1]: + + E_conv = E_conv_all[layer_index] + o_E_conv = o_E_conv_all[layer_index] + d = self.thickness[layer_index] + + # for e_conv, o_e_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): + E_conv_i = torch.linalg.inv(E_conv) + o_E_conv_i = torch.linalg.inv(o_E_conv) + + if self.algo == 'TMM': + big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2\ + = transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, self.ff, d, + varphi, big_F, big_G, big_T, + device=self.device, type_complex=self.type_complex) + + layer_info = [E_conv_i, q_1, q_2, W_1, W_2, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d] + self.layer_info_list.append(layer_info) + + elif self.algo == 'SMM': + raise ValueError + else: + raise ValueError + + if self.algo == 'TMM': + de_ri, de_ti, big_T1 = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, + device=self.device, type_complex=self.type_complex) + self.T1 = big_T1 + + elif self.algo == 'SMM': + raise ValueError + else: + raise ValueError + + return de_ri, de_ti + + def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): + + self.layer_info_list = [] + self.T1 = None + + fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) + + delta_i0 = torch.zeros((self.ff ** 2, 1), device=self.device, dtype=self.type_complex) + delta_i0[self.ff ** 2 // 2, 0] = 1 + + I = torch.eye(self.ff ** 2, device=self.device, dtype=self.type_complex) + O = torch.zeros((self.ff ** 2, self.ff ** 2), device=self.device, dtype=self.type_complex) + + center = self.ff ** 2 + + k0 = 2 * np.pi / wavelength + + if self.algo == 'TMM': + kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ + = transfer_2d_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.period, fourier_indices, + self.theta, self.phi, wavelength, device=self.device, type_complex=self.type_complex) + elif self.algo == 'SMM': + Kx, Ky, kz_inc, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ + = scattering_2d_1(self.n_I, self.n_II, self.theta, self.phi, k0, self.period, self.fourier_order) + else: + raise ValueError + + count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) + + # From the last layer + for layer_index in range(count)[::-1]: + + E_conv = E_conv_all[layer_index] + o_E_conv = o_E_conv_all[layer_index] + d = self.thickness[layer_index] + + E_conv_i = torch.linalg.inv(E_conv) + o_E_conv_i = torch.linalg.inv(o_E_conv) + + if self.algo == 'TMM': # TODO: MERGE W V part + W, V, q = transfer_2d_wv(self.ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, + device=self.device, type_complex=self.type_complex) + + big_X, big_F, big_G, big_T, big_A_i, big_B, \ + W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 \ + = transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T, device=self.device, + type_complex=self.type_complex) + + layer_info = [E_conv_i, q, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d] + self.layer_info_list.append(layer_info) + + elif self.algo == 'SMM': + W, V, LAMBDA = scattering_2d_wv(self.ff, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) + A, B, Sl_dict, Sg_matrix, Sg = scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, LAMBDA) + else: + raise ValueError + + if self.algo == 'TMM': + de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, device=self.device, + type_complex=self.type_complex) + self.T1 = big_T1 + + elif self.algo == 'SMM': + de_ri, de_ti = scattering_2d_3(Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, self.n_I, + self.pol, self.theta, self.phi, self.fourier_order, self.ff) + else: + raise ValueError + + return de_ri.reshape((self.ff, self.ff)).real, de_ti.reshape((self.ff, self.ff)).real diff --git a/meent/on_torch/convolution_matrix.py b/meent/on_torch/convolution_matrix.py new file mode 100644 index 0000000..1981a76 --- /dev/null +++ b/meent/on_torch/convolution_matrix.py @@ -0,0 +1,221 @@ +import torch +import numpy as np + +from os import walk +from scipy.io import loadmat +from pathlib import Path + + +def put_permittivity_in_ucell(ucell, mat_list, mat_table, wl, device=torch.device('cpu'), type_complex=torch.complex128): + + res = torch.zeros(ucell.shape, device=device).type(type_complex) + + for z in range(ucell.shape[0]): + for y in range(ucell.shape[1]): + for x in range(ucell.shape[2]): + material = mat_list[ucell[z, y, x]] + if type(material) == str: + res[z, y, x] = find_nk_index(material, mat_table, wl) ** 2 + else: + res[z, y, x] = material ** 2 + + return res + + +def put_permittivity_in_ucell_object(ucell_size, mat_list, obj_list, mat_table, wl, device=torch.device('cpu'), + type_complex=torch.complex128): + # TODO: under development + res = torch.zeros(ucell_size, device=device).type(type_complex) + + for material, obj_index in zip(mat_list, obj_list): + if type(material) == str: + res[obj_index] = find_nk_index(material, mat_table, wl) ** 2 + else: + res[obj_index] = material ** 2 + + return res + + +def find_nk_index(material, mat_table, wl): + if material[-6:] == '__real': + material = material[:-6] + n_only = True + else: + n_only = False + + mat_data = mat_table[material.upper()] + + n_index = np.interp(wl, mat_data[:, 0], mat_data[:, 1]) + + if n_only: + return n_index + + k_index = np.interp(wl, mat_data[:, 0], mat_data[:, 2]) + nk = n_index + 1j * k_index + + return nk + + +def read_material_table(nk_path=None): + mat_table = {} + + if nk_path is None: + nk_path = str(Path(__file__).resolve().parent.parent) + '/nk_data' + + full_path_list, name_list, _ = [], [], [] + for (dirpath, dirnames, filenames) in walk(nk_path): + full_path_list.extend([f'{dirpath}/{filename}' for filename in filenames]) + name_list.extend(filenames) + for path, name in zip(full_path_list, name_list): + if name[-3:] == 'txt': + data = np.loadtxt(path, skiprows=1) + mat_table[name[:-4].upper()] = data + + elif name[-3:] == 'mat': + data = loadmat(path) + data = np.array([data['WL'], data['n'], data['k']])[:, :, 0].T + mat_table[name[:-4].upper()] = data + return mat_table + + +def cell_compression(cell, device=torch.device('cpu'), type_complex=torch.complex128): + + if type_complex == torch.complex128: + type_float = torch.float64 + else: + type_float = torch.float32 + + # find discontinuities in x + step_y, step_x = 1. / torch.tensor(cell.shape, device=device, dtype=type_float) + x = [] + y = [] + cell_x = [] + cell_xy = [] + + cell_next = torch.roll(cell, -1, dims=1) + + for col in range(cell.shape[1]): + if not (cell[:, col] == cell_next[:, col]).all() or (col == cell.shape[1] - 1): + x.append(step_x * (col + 1)) + cell_x.append(cell[:, col].reshape((1, -1))) + # cell_xa = torch.cat(cell_x, dim=0) + # cell_xaa = torch.cat(cell_x, dim=1) + cell_x = torch.cat(cell_x, dim=0).T + cell_x_next = torch.roll(cell_x, -1, dims=0) + + for row in range(cell_x.shape[0]): + if not (cell_x[row, :] == cell_x_next[row, :]).all() or (row == cell_x.shape[0] - 1): + y.append(step_y * (row + 1)) + cell_xy.append(cell_x[row, :].reshape((1, -1))) + + x = torch.tensor(x, device=device).reshape((-1, 1)).type(type_complex) + y = torch.tensor(y, device=device).reshape((-1, 1)).type(type_complex) + cell_comp = torch.cat(cell_xy, dim=0) + + return cell_comp, x, y + + +def fft_piecewise_constant(cell, fourier_order, device=torch.device('cpu'), type_complex=torch.complex128): + if cell.shape[0] == 1: + fourier_order = [0, fourier_order] + else: + fourier_order = [fourier_order, fourier_order] + cell, x, y = cell_compression(cell, device=device, type_complex=type_complex) + + # X axis + cell_next_x = torch.roll(cell, -1, dims=1) + cell_diff_x = cell_next_x - cell + + modes = torch.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1, device=device).type(type_complex) + + f_coeffs_x = cell_diff_x @ torch.exp(-1j * 2 * np.pi * x @ modes[None, :]).type(type_complex) + c = f_coeffs_x.shape[1] // 2 + + x_next = torch.vstack((torch.roll(x, -1, dims=0)[:-1], torch.tensor([1], device=device))) - x + + f_coeffs_x[:, c] = (cell @ torch.vstack((x[0], x_next[:-1]))).flatten() + mask = torch.ones(f_coeffs_x.shape[1], device=device).type(torch.bool) + mask[c] = False + f_coeffs_x[:, mask] /= (1j * 2 * np.pi * modes[mask]) + + # Y axis + f_coeffs_x_next_y = torch.roll(f_coeffs_x, -1, dims=0) + f_coeffs_x_diff_y = f_coeffs_x_next_y - f_coeffs_x + + modes = torch.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1, device=device).type(type_complex) + + f_coeffs_xy = f_coeffs_x_diff_y.T @ torch.exp(-1j * 2 * np.pi * y @ modes[None, :]) + c = f_coeffs_xy.shape[1] // 2 + + y_next = torch.vstack((torch.roll(y, -1, dims=0)[:-1], torch.tensor([1], device=device))) - y + + f_coeffs_xy[:, c] = f_coeffs_x.T @ torch.vstack((y[0], y_next[:-1])).flatten() + + if c: + mask = torch.ones(f_coeffs_xy.shape[1], device=device).type(torch.bool) + mask[c] = False + f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes[mask]) + + return f_coeffs_xy.T + + +def to_conv_mat(pmt, fourier_order, device=torch.device('cpu'), type_complex=torch.complex128): + + if len(pmt.shape) == 2: + print('shape is 2') + raise ValueError + ff = 2 * fourier_order + 1 + + if pmt.shape[1] == 1: # 1D + + res = torch.zeros((pmt.shape[0], ff, ff), device=device).type(type_complex) + + for i, layer in enumerate(pmt): + f_coeffs = fft_piecewise_constant(layer, fourier_order, device=device, type_complex=type_complex) + + center = f_coeffs.shape[1] // 2 + conv_idx = torch.arange(-ff + 1, ff, 1, device=device).type(torch.long) + conv_idx = circulant(conv_idx, device) + e_conv = f_coeffs[0, center + conv_idx] + res[i] = e_conv + + else: # 2D + # attention on the order of axis (Z Y X) + + res = torch.zeros((pmt.shape[0], ff ** 2, ff ** 2), device=device).type(type_complex) + + for i, layer in enumerate(pmt): + pmtvy_fft = fft_piecewise_constant(layer, fourier_order, device=device, type_complex=type_complex) + + center = torch.div(torch.tensor(pmtvy_fft.shape, device=device), 2, rounding_mode='trunc') + + conv_idx = torch.arange(-ff + 1, ff, 1, device=device).type(torch.long) + conv_idx = circulant(conv_idx, device) + + conv_i = conv_idx.repeat_interleave(ff, dim=1).type(torch.long) + conv_i = conv_i.repeat_interleave(ff, dim=0) + conv_j = conv_idx.repeat(ff, ff).type(torch.long) + res[i] = pmtvy_fft[center[0] + conv_i, center[1] + conv_j] + + # import matplotlib.pyplot as plt + # + # plt.figure() + # plt.imshow(abs(res[0]), cmap='jet') + # plt.colorbar() + # plt.show() + # + return res + + +def circulant(c, device=torch.device('cpu')): + + center = c.shape[0] // 2 + circ = torch.zeros((center + 1, center + 1), device=device).type(torch.long) + + for r in range(center+1): + idx = torch.arange(r, r - center - 1, -1, device=device) + + assign_value = c[center + idx] + circ[r] = assign_value + + return circ diff --git a/meent/on_torch/field_distribution.py b/meent/on_torch/field_distribution.py new file mode 100644 index 0000000..05d0235 --- /dev/null +++ b/meent/on_torch/field_distribution.py @@ -0,0 +1,274 @@ +import time + +import torch +import numpy as np +import matplotlib.pyplot as plt + + +def field_distribution(grating_type, *args, **kwargs): + if grating_type == 0: + res = field_dist_1d(*args, **kwargs) + elif grating_type == 1: + res = field_dist_1d_conical(*args, **kwargs) + else: + res = field_dist_2d(*args, **kwargs) + return res + + +def field_dist_1d(wavelength, n_I, theta, fourier_order, T1, layer_info_list, period, pol, resolution=(100, 1, 100), + device='cpu', type_complex=torch.complex128): + + k0 = 2 * np.pi / wavelength + fourier_indices = torch.arange(-fourier_order, fourier_order + 1, device=device) + + kx_vector = k0 * (n_I * np.sin(theta) - fourier_indices * (wavelength / period[0])).type(type_complex) + Kx = torch.diag(kx_vector / k0) + + resolution_z, resolution_y, resolution_x = resolution + + field_cell = torch.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 3)).type(type_complex) + + T_layer = T1 + + # From the first layer + for idx_layer, (E_conv_i, q, W, X, a_i, b, d) in enumerate(layer_info_list[::-1]): + + c1 = T_layer[:, None] + c2 = b @ a_i @ X @ T_layer[:, None] + + Q = torch.diag(q) + + if pol == 0: + V = W @ Q + + else: + V = E_conv_i @ W @ Q + EKx = E_conv_i @ Kx + + for k in range(resolution_z): + z = k / resolution_z * d + + if pol == 0: # TE + Sy = W @ (expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + Ux = V @ (-expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + + f_here = (-1j) * Kx @ Sy + + for j in range(resolution_y): + for i in range(resolution_x): + x = i * period[0] / resolution_x + + Ey = Sy.T @ torch.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Hx = -1j * Ux.T @ torch.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Hz = f_here.T @ torch.exp(-1j * kx_vector.reshape((-1, 1)) * x) + + field_cell[resolution_z * idx_layer + k, j, i] = torch.tensor([Ey, Hx, Hz]) + else: # TM + Uy = W @ (expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + Sx = V @ (-expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + + f_here = (-1j) * EKx @ Uy # there is a better option for convergence + + for j in range(resolution_y): + for i in range(resolution_x): + x = i * period[0] / resolution_x + + Hy = Uy.T @ torch.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Ex = 1j * Sx.T @ torch.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Ez = f_here.T @ torch.exp(-1j * kx_vector.reshape((-1, 1)) * x) + + field_cell[resolution_z * idx_layer + k, j, i] = torch.tensor([Hy, Ex, Ez]) + + T_layer = a_i @ X @ T_layer + + return field_cell + + +def field_dist_1d_conical(wavelength, n_I, theta, phi, fourier_order, T1, layer_info_list, period, resolution=(100, 1, 100), + device='cpu', type_complex=torch.complex128): + + k0 = 2 * np.pi / wavelength + fourier_indices = torch.arange(-fourier_order, fourier_order + 1, device=device) + + kx_vector = k0 * (n_I * torch.sin(theta) * torch.cos(phi) - fourier_indices * ( + wavelength / period[0])).type(type_complex) + ky = k0 * n_I * torch.sin(theta) * torch.sin(phi) + + Kx = torch.diag(kx_vector / k0) + + resolution_z, resolution_y, resolution_x = resolution + field_cell = torch.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 6)).type(type_complex) + + T_layer = T1 + + big_I = torch.eye((len(T1)), device=device, dtype=type_complex) + + # From the first layer + for idx_layer, [E_conv_i, q_1, q_2, W_1, W_2, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d] \ + in enumerate(layer_info_list[::-1]): + + c = torch.cat([big_I, big_B @ big_A_i @ big_X]) @ T_layer + + cut = len(c) // 4 + + c1_plus = c[0*cut:1*cut] + c2_plus = c[1*cut:2*cut] + c1_minus = c[2*cut:3*cut] + c2_minus = c[3*cut:4*cut] + + big_Q1 = torch.diag(q_1) + big_Q2 = torch.diag(q_2) + + for k in range(resolution_z): + z = k / resolution_z * d + + Sx = W_2 @ (expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z-d)) @ c2_minus) + + Sy = V_11 @ (expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z-d)) @ c1_minus) \ + + V_12 @ (expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z-d)) @ c2_minus) + + Ux = W_1 @ (-expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z-d)) @ c1_minus) + + Uy = V_21 @ (-expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z-d)) @ c1_minus) \ + + V_22 @ (-expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z-d)) @ c2_minus) + + Sz = -1j * E_conv_i @ (Kx @ Uy - ky * Ux) + + Uz = -1j * (Kx @ Sy - ky * Sx) + + for j in range(resolution_y): + for i in range(resolution_x): + x = i * period[0] / resolution_x + + exp_K = torch.exp(-1j*kx_vector.reshape((-1, 1)) * x) + + Ex = Sx @ exp_K + Ey = Sy @ exp_K + Ez = Sz @ exp_K + + Hx = -1j * Ux @ exp_K + Hy = -1j * Uy @ exp_K + Hz = -1j * Uz @ exp_K + + field_cell[resolution_z * idx_layer + k, j, i] = torch.tensor([Ex, Ey, Ez, Hx, Hy, Hz]) + + T_layer = big_A_i @ big_X @ T_layer + + return field_cell + + +def field_dist_2d(wavelength, n_I, theta, phi, fourier_order, T1, layer_info_list, period, resolution=(100, 100, 100), + device='cpu', type_complex=torch.complex128): + + k0 = 2 * np.pi / wavelength + fourier_indices = torch.arange(-fourier_order, fourier_order + 1, device=device) + ff = 2 * fourier_order + 1 + + kx_vector = k0 * (n_I * np.sin(theta) * np.cos(phi) - fourier_indices * ( + wavelength / period[0])).type(type_complex) + ky_vector = k0 * (n_I * np.sin(theta) * np.sin(phi) - fourier_indices * ( + wavelength / period[1])).type(type_complex) + + Kx = torch.diag(kx_vector.tile(ff).flatten() / k0) + Ky = torch.diag(ky_vector.reshape((-1, 1)).tile(ff).flatten() / k0) + + resolution_z, resolution_y, resolution_x = resolution + field_cell = torch.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 6)).type(type_complex) + + T_layer = T1 + + big_I = torch.eye((len(T1)), device=device, dtype=type_complex) + + # From the first layer + for idx_layer, (E_conv_i, q, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d)\ + in enumerate(layer_info_list[::-1]): + + c = torch.cat([big_I, big_B @ big_A_i @ big_X]) @ T_layer + + cut = len(c) // 4 + + c1_plus = c[0*cut:1*cut] + c2_plus = c[1*cut:2*cut] + c1_minus = c[2*cut:3*cut] + c2_minus = c[3*cut:4*cut] + + q_1 = q[:len(q)//2] + q_2 = q[len(q)//2:] + big_Q1 = torch.diag(q_1) + big_Q2 = torch.diag(q_2) + + for k in range(resolution_z): + z = k / resolution_z * d + + Sx = W_11 @ (expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z-d)) @ c1_minus) \ + + W_12 @ (expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z-d)) @ c2_minus) + + Sy = W_21 @ (expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z-d)) @ c1_minus) \ + + W_22 @ (expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z-d)) @ c2_minus) + + Ux = V_11 @ (-expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z-d)) @ c1_minus) \ + + V_12 @ (-expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z-d)) @ c2_minus) + + Uy = V_21 @ (-expm(-k0 * big_Q1 * z) @ c1_plus + expm(k0 * big_Q1 * (z-d)) @ c1_minus) \ + + V_22 @ (-expm(-k0 * big_Q2 * z) @ c2_plus + expm(k0 * big_Q2 * (z-d)) @ c2_minus) + + Sz = -1j * E_conv_i @ (Kx @ Uy - Ky @ Ux) + + Uz = -1j * (Kx @ Sy - Ky @ Sx) + + for j in range(resolution_y): + y = j * period[1] / resolution_y + + for i in range(resolution_x): + + x = i * period[0] / resolution_x + + exp_K = torch.exp(-1j*kx_vector.reshape((1, -1)) * x) * torch.exp(-1j*ky_vector.reshape((-1, 1)) * y) + exp_K = exp_K.flatten() + + Ex = Sx.T @ exp_K + Ey = Sy.T @ exp_K + Ez = Sz.T @ exp_K + + Hx = -1j * Ux.T @ exp_K + Hy = -1j * Uy.T @ exp_K + Hz = -1j * Uz.T @ exp_K + + field_cell[resolution_z * idx_layer + k, j, i] = torch.tensor([Ex, Ey, Ez, Hx, Hy, Hz]) + + T_layer = big_A_i @ big_X @ T_layer + + return field_cell + + +def field_plot_zx(field_cell, pol=0, plot_indices=(1, 1, 1, 1, 1, 1), y_slice=0, z_slice=-1, zx=True, yx=True): + + if field_cell.shape[-1] == 6: # 2D grating + title = ['2D Ex', '2D Ey', '2D Ez', '2D Hx', '2D Hy', '2D Hz', ] + else: # 1D grating + if pol == 0: # TE + title = ['1D Ey', '1D Hx', '1D Hz', ] + else: # TM + title = ['1D Hy', '1D Ex', '1D Ez', ] + + if zx: + for idx in range(len(title)): + if plot_indices[idx]: + plt.imshow((abs(field_cell[:, y_slice, :, idx]) ** 2), cmap='jet', aspect='auto') + # plt.clim(0, 2) # identical to caxis([-4,4]) in MATLAB + plt.colorbar() + plt.title(title[idx]) + plt.show() + if yx: + for idx in range(len(title)): + if plot_indices[idx]: + plt.imshow((abs(field_cell[z_slice, :, :, idx]) ** 2), cmap='jet', aspect='auto') + plt.clim(0, 3.5) # identical to caxis([-4,4]) in MATLAB + plt.colorbar() + plt.title(title[idx]) + plt.show() + + +def expm(x): + return torch.diag(torch.exp(torch.diag(x))) + diff --git a/meent/on_torch/rcwa.py b/meent/on_torch/rcwa.py new file mode 100644 index 0000000..36e6313 --- /dev/null +++ b/meent/on_torch/rcwa.py @@ -0,0 +1,78 @@ +import time +import numpy as np +import torch + +from ._base import _BaseRCWA +from .convolution_matrix import to_conv_mat, put_permittivity_in_ucell, read_material_table +from .field_distribution import field_dist_1d, field_dist_2d, field_plot_zx, field_dist_1d_conical + + +class RCWATorch(_BaseRCWA): + def __init__(self, mode=0, grating_type=0, n_I=1., n_II=1., theta=0, phi=0, psi=0, fourier_order=40, period=(100,), + wavelength=900, pol=0, patterns=None, ucell=None, ucell_materials=None, + thickness=None, algo='TMM', perturbation=1E-10, + device='cpu', type_complex=torch.complex128): + + super().__init__(grating_type, n_I, n_II, theta, phi, psi, fourier_order, period, wavelength, pol, patterns, + ucell, ucell_materials, + thickness, algo, perturbation, device, type_complex) + + self.device = device + self.mode = mode + self.type_complex = type_complex + + self.mat_table = read_material_table() + self.layer_info_list = [] + + def solve(self, wavelength, e_conv_all, o_e_conv_all): + + # TODO: !handle uniform layer + + self.get_kx_vector() + + if self.grating_type == 0: + de_ri, de_ti = self.solve_1d(wavelength, e_conv_all, o_e_conv_all) + elif self.grating_type == 1: + de_ri, de_ti = self.solve_1d_conical(wavelength, e_conv_all, o_e_conv_all) + elif self.grating_type == 2: + de_ri, de_ti = self.solve_2d(wavelength, e_conv_all, o_e_conv_all) + else: + raise ValueError + + return de_ri.real, de_ti.real + + def run_ucell(self): + + ucell = put_permittivity_in_ucell(self.ucell, self.ucell_materials, self.mat_table, self.wavelength, + self.device, type_complex=self.type_complex) + + E_conv_all = to_conv_mat(ucell, self.fourier_order, self.device, type_complex=self.type_complex) + o_E_conv_all = to_conv_mat(1 / ucell, self.fourier_order, self.device, type_complex=self.type_complex) + + de_ri, de_ti = self.solve(self.wavelength, E_conv_all, o_E_conv_all) + + return de_ri, de_ti + + def calculate_field(self, resolution=None, plot=True): + + if self.grating_type == 0: + resolution = [100, 1, 100] if not resolution else resolution + field_cell = field_dist_1d(self.wavelength, self.n_I, self.theta, self.fourier_order, self.T1, + self.layer_info_list, self.period, self.pol, resolution=resolution, + device=self.device, type_complex=self.type_complex) + elif self.grating_type == 1: + resolution = [100, 1, 100] if not resolution else resolution + field_cell = field_dist_1d_conical(self.wavelength, self.n_I, self.theta, self.phi, self.fourier_order, + self.T1, self.layer_info_list, self.period, resolution=resolution, + device=self.device, type_complex=self.type_complex) + + else: + resolution = [100, 100, 100] if not resolution else resolution + field_cell = field_dist_2d(self.wavelength, self.n_I, self.theta, self.phi, self.fourier_order, self.T1, + self.layer_info_list, self.period, resolution=resolution, + device=self.device, type_complex=self.type_complex) + + if plot: + field_plot_zx(field_cell, self.pol) + + return field_cell diff --git a/meent/on_torch/scattering_method.py b/meent/on_torch/scattering_method.py new file mode 100644 index 0000000..aa7fc33 --- /dev/null +++ b/meent/on_torch/scattering_method.py @@ -0,0 +1,183 @@ +""" +currently SMM is not supported +""" + +# many codes for scattering matrix method are from here: +# https://github.com/zhaonat/Rigorous-Coupled-Wave-Analysis +# also refer our fork https://github.com/yonghakim/zhaonat-rcwa + +from .smm_util import * + + +def scattering_1d_1(k0, n_I, n_II, theta, phi, fourier_indices, period, pol, wl=None): + + kx_vector = (n_I * np.sin(theta) * np.cos(phi) - fourier_indices * ( + 2 * np.pi / k0 / period[0])).astype('complex') + Kx = np.diag(kx_vector) + + # scattering matrix needed for 'gap medium' + Wg, Vg, Kzg = homogeneous_1D(Kx, 1, wl=wl, comment='Gap') + + # reflection medium + Wr, Vr, Kzr = homogeneous_1D(Kx, n_I, pol=pol, wl=wl, comment='Refl') + + # transmission medium; + Wt, Vt, Kzt = homogeneous_1D(Kx, n_II, pol=pol, wl=wl, comment='Tran') + + # S matrices for the reflection region + Ar, Br = A_B_matrices_half_space(Vr, Vg) # make sure this order is right + _, Sg = S_RT(Ar, Br, ref_mode=True) # scatter matrix for the reflection region + + return Kx, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg + + +def scattering_1d_2(W, Wg, V, Vg, d, k0, LAMBDA, Sg): + # calculating A and B matrices for scattering matrix + # define S matrix for the GRATING REGION + A, B = A_B_matrices(W, Wg, V, Vg) + _, S_dict = S_layer(A, B, d, k0, LAMBDA) + _, Sg = RedhefferStar(Sg, S_dict) + + return A, B, S_dict, Sg + + +def scattering_1d_3(Wt, Wg, Vt, Vg, Sg, ff, Wr, fourier_order, Kzr, Kzt, n_I, n_II, theta, pol): + # define S matrices for the Transmission region + At, Bt = A_B_matrices_half_space(Vt, Vg) # make sure this order is right + _, St_dict = S_RT(At, Bt, ref_mode=False) # scatter matrix for the reflection region + _, Sg = RedhefferStar(Sg, St_dict) + + k_inc = n_I * np.array([np.sin(theta), 0, np.cos(theta)]) + + c_inc = np.zeros((ff, 1)) # only need one set... + c_inc[fourier_order] = 1 + c_inc = np.linalg.inv(Wr) @ c_inc + # COMPUTE FIELDS: similar idea but more complex for RCWA since you have individual modes each contributing + reflected = Wr @ Sg['S11'] @ c_inc + transmitted = Wt @ Sg['S21'] @ c_inc + + # reflected is already ry or Ey + rsq = np.square(np.abs(reflected)) + tsq = np.square(np.abs(transmitted)) + + # compute final reflectivity + if pol == 0: + de_ri = np.real(Kzr) @ rsq / np.real(k_inc[2]) + de_ti = np.real(Kzt) @ tsq / np.real(k_inc[2]) + elif pol == 1: + de_ri = np.real(Kzr)@rsq/np.real(k_inc[2]) / n_I**2 + de_ti = np.real(Kzt)@tsq/np.real(k_inc[2]) * n_I**2 / n_II**4 + else: + raise ValueError + + return de_ri.flatten(), de_ti.flatten() + + +def scattering_2d_1(n_I, n_II, theta, phi, k0, period, fourier_order): + kx_inc = n_I * np.sin(theta) * np.cos(phi) + ky_inc = n_I * np.sin(theta) * np.sin(phi) + kz_inc = np.sqrt(n_I ** 2 * 1 - kx_inc ** 2 - ky_inc ** 2) + + Kx, Ky = K_matrix_cubic_2D(kx_inc, ky_inc, k0, period[0], period[1], fourier_order, fourier_order) + + # specify gap media (this is an LHI so no eigenvalue problem should be solved + e_h = 1 + Wg, Vg, Kzg = homogeneous_module(Kx, Ky, e_h) + + # ================= Working on the Reflection Side =========== ## + e_r = n_I ** 2 + Wr, Vr, Kzr = homogeneous_module(Kx, Ky, e_r) + + # ========= Working on the Transmission Side==============## + e_t = n_II ** 2 + Wt, Vt, Kzt = homogeneous_module(Kx, Ky, e_t) + + # calculating A and B matrices for scattering matrix + Ar, Br = A_B_matrices_half_space(Vr, Vg) + + # s_ref is a matrix, Sr_dict is a dictionary + _, Sr_dict = S_RT(Ar, Br, ref_mode=True) # scatter matrix for the reflection region + Sg = Sr_dict + + return Kx, Ky, kz_inc, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg + + +def scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, LAMBDA): + + A, B = A_B_matrices(W, Wg, V, Vg) + _, Sl_dict = S_layer(A, B, d, k0, LAMBDA) + Sg_matrix, Sg = RedhefferStar(Sg, Sl_dict) + + return A, B, Sl_dict, Sg_matrix, Sg + + +def scattering_2d_3(Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, n_I, pol, theta, + phi, fourier_order, ff): + normal_vector = np.array([0, 0, 1]) # positive z points down; + # amplitude of the te vs tm modes (which are decoupled) + + if pol == 0: + pte = 1 + ptm = 0 + elif pol == 1: + pte = 0 + ptm = 1 + else: + raise ValueError + + M = N = fourier_order + NM = ff ** 2 + + # get At, Bt + # since transmission is the same as gap, order does not matter + At, Bt = A_B_matrices_half_space(Vt, Vg) + _, ST_dict = S_RT(At, Bt, ref_mode=False) + + # update global scattering matrix + Sg_matrix, Sg = RedhefferStar(Sg, ST_dict) + + # finally CONVERT THE GLOBAL SCATTERING MATRIX BACK TO A MATRIX + + K_inc_vector = n_I * np.array([np.sin(theta) * np.cos(phi), np.sin(theta) * np.sin(phi), np.cos(theta)]) + + _, e_src, _ = initial_conditions(K_inc_vector, theta, normal_vector, pte, ptm, N, M) + + c_inc = np.linalg.inv(Wr) @ e_src + # COMPUTE FIELDS: similar idea but more complex for RCWA since you have individual modes each contributing + reflected = Wr @ Sg['S11'] @ c_inc + transmitted = Wt @ Sg['S21'] @ c_inc + + rx = reflected[0:NM, :] # rx is the Ex component. + ry = reflected[NM:, :] + tx = transmitted[0:NM, :] + ty = transmitted[NM:, :] + + rz = np.linalg.inv(Kzr) @ (Kx @ rx + Ky @ ry) + tz = np.linalg.inv(Kzt) @ (Kx @ tx + Ky @ ty) + + rsq = np.square(np.abs(rx)) + np.square(np.abs(ry)) + np.square(np.abs(rz)) + tsq = np.square(np.abs(tx)) + np.square(np.abs(ty)) + np.square(np.abs(tz)) + + de_ri = np.real(Kzr)@rsq/np.real(K_inc_vector[2]) # real because we only want propagating components + de_ti = np.real(Kzt)@tsq/np.real(K_inc_vector[2]) + + return de_ri, de_ti + + +def scattering_2d_wv(ff, Kx, Ky, E_conv, oneover_E_conv, oneover_E_conv_i, E_i, mu_conv=None): + # ------------------------- + # W and V from SMM method. + NM = ff ** 2 + if mu_conv is None: + mu_conv = np.identity(NM) + + P, Q, _ = P_Q_kz(Kx, Ky, E_conv, mu_conv, oneover_E_conv, oneover_E_conv_i, E_i) + GAMMA = P @ Q + + Lambda, W = np.linalg.eig(GAMMA) # LAMBDa is effectively refractive index + LAMBDA = np.diag(Lambda) + LAMBDA = np.sqrt(LAMBDA.astype('complex')) + + V = Q @ W @ np.linalg.inv(LAMBDA) + + return W, V, LAMBDA diff --git a/meent/on_torch/smm_util.py b/meent/on_torch/smm_util.py new file mode 100644 index 0000000..9fab4f7 --- /dev/null +++ b/meent/on_torch/smm_util.py @@ -0,0 +1,335 @@ +""" +currently SMM is not supported +""" +# many codes for scattering matrix method are from here: +# https://github.com/zhaonat/Rigorous-Coupled-Wave-Analysis +# also refer our fork https://github.com/yonghakim/zhaonat-rcwa + +import numpy as np +from numpy.linalg import inv, pinv +# TODO: try pseudo-inverse? +from scipy.linalg import block_diag +# TODO: ok by jax? + + +def A_B_matrices_half_space(V_layer, Vg): + + I = np.eye(len(Vg)) + a = I + inv(Vg) @ V_layer + b = I - inv(Vg) @ V_layer + + return a, b + + +def A_B_matrices(W_layer, Wg, V_layer, Vg): + """ + single function to output the a and b matrices needed for the scatter matrices + :param W_layer: gap + :param Wg: + :param V_layer: gap + :param Vg: + :return: + """ + W_i = inv(W_layer) + V_i = inv(V_layer) + + a = W_i @ Wg + V_i @ Vg + b = W_i @ Wg - V_i @ Vg + + return a, b + + +def S_layer(A, B, d, k0, modes): + """ + function to create scatter matrix in the ith layer of the uniform layer structure + we assume that gap layers are used so we need only one A and one B + :param A: function A = + :param B: function B + :param k0 #free -space wavevector magnitude (normalization constant) in Si Units + :param Li #length of ith layer (in Si units) + :param modes, eigenvalue matrix + :return: S (4x4 scatter matrix) and Sdict, which contains the 2x2 block matrix as a dictionary + """ + + # sign convention (EMLAB is exp(-1i*k\dot r)) + X = np.diag(np.exp(-np.diag(modes)*d*k0)) + # TODO: Check + # TODO: expm + + A_i = inv(A) + term_i = inv(A - X @ B @ A_i @ X @ B) + + S11 = term_i @ (X @ B @ A_i @ X @ A - B) + S12 = term_i @ X @ (A - B @ A_i @ B) + S22 = S11 + S21 = S12 + + S_dict = {'S11': S11, 'S22': S22, 'S12': S12, 'S21': S21} + S = np.block([[S11, S12], [S21, S22]]) + return S, S_dict + + +def S_RT(A, B, ref_mode): + + A_i = inv(A) + + S11 = -A_i @ B + S12 = 2 * A_i + S21 = 0.5*(A - B @ A_i @ B) + S22 = B @ A_i + + if ref_mode: + S_dict = {'S11': S11, 'S22': S22, 'S12': S12, 'S21': S21} + S = np.block([[S11, S12], [S21, S22]]) + else: + S_dict = {'S11': S22, 'S22': S11, 'S12': S21, 'S21': S12} + S = np.block([[S22, S21], [S12, S11]]) + return S, S_dict + + +def homogeneous_module(Kx, Ky, e_r, m_r=1, perturbation=1E-16, wl=None, comment=None): + """ + homogeneous layer is much simpler to do, so we will create an isolated module to deal with it + :return: + """ + assert type(Kx) == np.ndarray, 'not np.array' + assert type(Ky) == np.ndarray, 'not np.array' + + N = len(Kx) + I = np.identity(N) + + P = (e_r**-1)*np.block([[Kx*Ky, e_r*m_r*I-Kx**2], [Ky**2-m_r*e_r*I, -Ky*Kx]]) + Q = (e_r/m_r)*P + + diag = np.diag(Q) + idx = np.nonzero(diag == 0)[0] + if len(idx): + # Adding pertub* to Q and pertub to Kz. + # TODO: check why this works. + # TODO: make imaginary part sign consistent + Q[idx, idx] = np.conj(perturbation) + print(wl, comment, 'non-invertible Q: adding perturbation') + # print(Q.diagonal()) + + W = np.eye(N*2) + Kz2 = (m_r*e_r*I-Kx**2-Ky**2).astype('complex') # arg is +kz^2 + # arg = -(m_r*e_r*I-Kx**2-Ky**2) # arg is +kz^2 + # Kz = np.conj(np.sqrt(arg)) # conjugate enforces the negative sign convention (we also have to conjugate er and mur if they are complex) + + Kz = np.sqrt(Kz2) # conjugate enforces the negative sign convention (we also have to conjugate er and mur if they are complex) + Kz = np.conj(Kz) # TODO: conjugate? + + diag = np.diag(Kz) + idx = np.nonzero(diag == 0)[0] + if len(idx): + Kz[idx, idx] = perturbation + print(wl, comment, 'non-invertible Kz: adding perturbation') + # print(Kz.diagonal()) + + eigenvalues = block_diag(1j*Kz, 1j*Kz) # determining the modes of ex, ey... so it appears eigenvalue order MATTERS... + V = Q @ np.linalg.inv(eigenvalues) # eigenvalue order is arbitrary (hard to compare with matlab + + + # V = -1j*Q + + return W, V, Kz + + +def homogeneous_1D(Kx, n_index, m_r=1, pol=None, perturbation=1E-20*(1+1j), wl=None, comment=None): + """ + efficient homogeneous 1D module + :param Kx: + :param e_r: + :param m_r: + :return: + """ + + e_r = n_index ** 2 + + I = np.identity(len(Kx)) + + W = I + Q = (1 / m_r) * (e_r * m_r * I - Kx ** 2) + # Q = Kx**2 - e_r * I + + diag = np.diag(Q) + idx = np.nonzero(diag == 0)[0] + if len(idx): + # Adding pertub* to Q and pertub to Kz. + # TODO: check why this works. + # TODO: make imaginary part sign consistent + Q[idx, idx] = np.conj(perturbation) + print(wl, comment, 'non-invertible Q: adding perturbation') + # print(Q.diagonal()) + + Kz = np.sqrt(m_r*e_r*I-Kx**2) + Kz = np.conj(Kz) # TODO: conjugate? + + # TODO: check Singular or ill-conditioned; spread this to whole code + # invertible check + diag = np.diag(Kz) + idx = np.nonzero(diag == 0)[0] + if len(idx): + Kz[idx, idx] = perturbation + print(wl, comment, 'non-invertible Kz: adding perturbation') + # print(Kz.diagonal()) + + # TODO: why this works... + if pol: # 0: TE, 1: TM + Kz = Kz * (n_index ** 2) + + eigenvalues = -1j*Kz # determining the modes of ex, ey... so it appears eigenvalue order MATTERS... + V = Q @ np.linalg.inv(eigenvalues) # eigenvalue order is arbitrary (hard to compare with matlab + + return W, V, Kz + + +def K_matrix_cubic_2D(beta_x, beta_y, k0, a_x, a_y, N_p, N_q): + # K_i = beta_i - pT1i - q T2i - r*T3i + # but here we apply it only for cubic and tegragonal geometries in 2D + """ + :param beta_x: input k_x,inc/k0 + :param beta_y: k_y,inc/k0; #already normalized...k0 is needed to normalize the 2*pi*lambda/a + however such normalization can cause singular matrices in the homogeneous module (specifically with eigenvalues) + :param T1:reciprocal lattice vector 1 + :param T2: + :param T3: + :return: + """ + # (indexing follows (1,1), (1,2), ..., (1,N), (2,1),(2,2),(2,3)...(M,N) ROW MAJOR + # but in the cubic case, k_x only depends on p and k_y only depends on q + k_x = beta_x - 2*np.pi*np.arange(-N_p, N_p+1)/(k0*a_x) + k_y = beta_y - 2*np.pi*np.arange(-N_q, N_q+1)/(k0*a_y) + + kx, ky = np.meshgrid(k_x, k_y) + Kx = np.diag(kx.flatten()) + Ky = np.diag(ky.flatten()) + + return Kx, Ky + + +def P_Q_kz(Kx, Ky, e_conv, mu_conv, oneover_E_conv, oneover_E_conv_i, E_i): + ''' + r is for relative so do not put epsilon_0 or mu_0 here + :param Kx: NM x NM matrix + :param Ky: + :param e_conv: (NM x NM) conv matrix + :param mu_r: + :return: + ''' + argument = e_conv - Kx ** 2 - Ky ** 2 + Kz = np.conj(np.sqrt(argument.astype('complex'))) + # Kz = np.sqrt(argument.astype('complex')) # TODO: conjugate? + + # TODO: confirm whether oneonver_E_conv is indeed not used + # TODO: Check sign of P and Q + P = np.block([ + [Kx @ E_i @ Ky, -Kx @ E_i @ Kx + mu_conv], + [Ky @ E_i @ Ky - mu_conv, -Ky @ E_i @ Kx] + ]) + + Q = np.block([ + [Kx @ inv(mu_conv) @ Ky, -Kx @ inv(mu_conv) @ Kx + e_conv], + [-oneover_E_conv_i + Ky @ inv(mu_conv) @ Ky, -Ky @ inv(mu_conv) @ Kx] + ]) + + return P, Q, Kz + + +def delta_vector(P, Q): + ''' + create a vector with a 1 corresponding to the 0th order + #input P = 2*(num_ord_specified)+1 + ''' + fourier_grid = np.zeros((P,Q)) + fourier_grid[int(P/2), int(Q/2)] = 1 + # vector = np.zeros((P*Q,)); + # + # #the index of the (0,0) element requires a conversion using sub2ind + # index = int(P/2)*P + int(Q/2); + vector = fourier_grid.flatten() + return np.matrix(np.reshape(vector, (1,len(vector)))) + + +def initial_conditions(K_inc_vector, theta, normal_vector, pte, ptm, P, Q): + """ + :param K_inc_vector: whether it's normalized or not is not important... + :param theta: angle of incience + :param normal_vector: pointing into z direction + :param pte: te polarization amplitude + :param ptm: tm polarization amplitude + :return: + calculates the incident E field, cinc, and the polarization fro the initial condition vectors + """ + # ate -> unit vector holding the out of plane direction of TE + # atm -> unit vector holding the out of plane direction of TM + # what are the out of plane components...(Ey and Hy) + # normal_vector = [0,0,-1]; i.e. waves propagate down into the -z direction + # cinc = Wr^-1@[Ex_inc, Ey_inc]; + + if theta != 0: + ate_vector = np.cross(K_inc_vector, normal_vector) + ate_vector = ate_vector / (np.linalg.norm(ate_vector)) + else: + ate_vector = np.array([0, 1, 0]) + + atm_vector = np.cross(ate_vector, K_inc_vector) + atm_vector = atm_vector / (np.linalg.norm(atm_vector)) + + polarization = pte * ate_vector + ptm * atm_vector # total E_field incident which is a 3 component vector (ex, ey, ez) + E_inc = polarization + # go from mode coefficients to FIELDS + delta = delta_vector(2*P+1, 2*Q+1) + + # c_inc; #remember we ultimately solve for [Ex, Ey, Hx, Hy]. + e_src = np.hstack((polarization[0]*delta, polarization[1]*delta)) + e_src = np.matrix(e_src).T # mode amplitudes of Ex, and Ey + + return E_inc, e_src, polarization + + +def RedhefferStar(SA, SB): # SA and SB are both 2x2 block matrices; + """ + RedhefferStar for arbitrarily sized 2x2 block matrices for RCWA + :param SA: dictionary containing the four sub-blocks + :param SB: dictionary containing the four sub-blocks, + keys are 'S11', 'S12', 'S21', 'S22' + :return: + """ + + assert type(SA) == dict, 'not dict' + assert type(SB) == dict, 'not dict' + + # once we break every thing like this, we should still have matrices + SA_11, SA_12, SA_21, SA_22 = SA['S11'], SA['S12'], SA['S21'], SA['S22'] + SB_11, SB_12, SB_21, SB_22 = SB['S11'], SB['S12'], SB['S21'], SB['S22'] + N = len(SA_11) # SA_11 should be square so length is fine + + I = np.eye(N) + D_i = inv(I - SB_11 @ SA_22) + F_i = inv(I - SA_22 @ SB_11) + + SAB_11 = SA_11 + SA_12 @ D_i @ SB_11 @ SA_21 + SAB_12 = SA_12 @ D_i @ SB_12 + SAB_21 = SB_21 @ F_i @ SA_21 + SAB_22 = SB_22 + SB_21 @ F_i @ SA_22 @ SB_12 + + SAB = np.block([[SAB_11, SAB_12], [SAB_21, SAB_22]]) + SAB_dict = {'S11': SAB_11, 'S22': SAB_22, 'S12': SAB_12, 'S21': SAB_21} + + return SAB, SAB_dict + + +def construct_global_scatter(scatter_list): + """ + this function assumes an RCWA implementation where all the scatter matrices are stored in a list + and the global scatter matrix is constructed at the end + :param scatter_list: list of scatter matrices of the form [Sr, S1, S2, ... , SN, ST] + :return: + """ + Sr = scatter_list[0] + Sg = Sr + for i in range(1, len(scatter_list)): + Sg = RedhefferStar(Sg, scatter_list[i]) + return Sg + diff --git a/meent/on_torch/transfer_method.py b/meent/on_torch/transfer_method.py new file mode 100644 index 0000000..eec8894 --- /dev/null +++ b/meent/on_torch/transfer_method.py @@ -0,0 +1,445 @@ +import numpy as np +import torch + + +def transfer_1d_1(ff, polarization, k0, n_I, n_II, kx_vector, theta, delta_i0, fourier_order, + device='cpu', type_complex=torch.complex128): + + # kx_vector = k0 * (n_I * torch.sin(theta) - fourier_indices * (wavelength / period[0])).type(type_complex) + + k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2) ** 0.5 + k_II_z = (k0 ** 2 * n_II ** 2 - kx_vector ** 2) ** 0.5 + + k_I_z = torch.conj(k_I_z) + k_II_z = torch.conj(k_II_z) + + Kx = torch.diag(kx_vector / k0) + + f = torch.eye(ff, device=device, dtype=type_complex) + + if polarization == 0: # TE + Y_I = torch.diag(k_I_z / k0) + Y_II = torch.diag(k_II_z / k0) + + YZ_I = Y_I + g = 1j * Y_II + inc_term = 1j * n_I * torch.cos(theta) * delta_i0 + + elif polarization == 1: # TM + Z_I = torch.diag(k_I_z / (k0 * n_I ** 2)) + Z_II = torch.diag(k_II_z / (k0 * n_II ** 2)) + + YZ_I = Z_I + g = 1j * Z_II + inc_term = 1j * delta_i0 * torch.cos(theta) / n_I + + else: + raise ValueError + + T = torch.eye(2 * fourier_order + 1, device=device, dtype=type_complex) + + return kx_vector, Kx, k_I_z, k_II_z, f, YZ_I, g, inc_term, T + + +def transfer_1d_2(k0, q, d, W, V, f, g, fourier_order, T, device='cpu', type_complex=torch.complex128): + + X = torch.diag(torch.exp(-k0 * q * d)) + + W_i = torch.linalg.inv(W) + V_i = torch.linalg.inv(V) + + a = 0.5 * (W_i @ f + V_i @ g) + b = 0.5 * (W_i @ f - V_i @ g) + + a_i = torch.linalg.inv(a) + + f = W @ (torch.eye(2 * fourier_order + 1, device=device, dtype=type_complex) + X @ b @ a_i @ X) + g = V @ (torch.eye(2 * fourier_order + 1, device=device, dtype=type_complex) - X @ b @ a_i @ X) + T = T @ a_i @ X + + return X, f, g, T, a_i, b + + +def transfer_1d_3(g1, YZ_I, f1, delta_i0, inc_term, T, k_I_z, k0, n_I, n_II, theta, polarization, k_II_z): + + T1 = torch.linalg.inv(g1 + 1j * YZ_I @ f1) @ (1j * YZ_I @ delta_i0 + inc_term) + R = f1 @ T1 - delta_i0 + T = T @ T1 + + de_ri = torch.real(R * torch.conj(R) * k_I_z / (k0 * n_I * torch.cos(theta))) + if polarization == 0: + # de_ti = T * np.conj(T) * np.real(k_II_z / (k0 * n_I * np.cos(theta))) + de_ti = torch.real(T * torch.conj(T) * k_II_z / (k0 * n_I * torch.cos(theta))) + elif polarization == 1: + # de_ti = T * np.conj(T) * np.real(k_II_z / n_II ** 2) / (k0 * np.cos(theta) / n_I) + de_ti = torch.real(T * torch.conj(T) * k_II_z / n_II ** 2) / (k0 * torch.cos(theta) / n_I) + else: + raise ValueError + + return de_ri, de_ti, T1 + + +def transfer_1d_conical_1(ff, k0, n_I, n_II, kx_vector, theta, phi, device='cpu', type_complex=torch.complex128): + + I = torch.eye(ff, device=device, dtype=type_complex) + O = torch.zeros((ff, ff), device=device, dtype=type_complex) + + # kx_vector = k0 * (n_I * torch.sin(theta) * torch.cos(phi) - fourier_indices * ( + # wavelength / period[0])).type(type_complex) + + ky = k0 * n_I * torch.sin(theta) * torch.sin(phi) + + k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky ** 2) ** 0.5 + k_II_z = (k0 ** 2 * n_II ** 2 - kx_vector ** 2 - ky ** 2) ** 0.5 + + k_I_z = torch.conj(k_I_z.flatten()) + k_II_z = torch.conj(k_II_z.flatten()) + + Kx = torch.diag(kx_vector / k0) + + varphi = torch.arctan(ky / kx_vector) + + Y_I = torch.diag(k_I_z / k0) + Y_II = torch.diag(k_II_z / k0) + + Z_I = torch.diag(k_I_z / (k0 * n_I ** 2)) + Z_II = torch.diag(k_II_z / (k0 * n_II ** 2)) + + big_F = torch.cat( + [ + torch.cat([I, O], dim=1), + torch.cat([O, 1j * Z_II], dim=1), + ] + ) + + big_G = torch.cat( + [ + torch.cat([1j * Y_II, O], dim=1), + torch.cat([O, I], dim=1), + ] + ) + + big_T = torch.eye(ff * 2, device=device, dtype=type_complex) + + return Kx, ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T + + +def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_i, o_E_conv_i, ff, d, varphi, big_F, big_G, big_T, + device='cpu', type_complex=torch.complex128): + + I = torch.eye(ff, device=device, dtype=type_complex) + O = torch.zeros((ff, ff), device=device, dtype=type_complex) + + A = Kx ** 2 - E_conv + B = Kx @ E_i @ Kx - I + A_i = torch.linalg.inv(A) + B_i = torch.linalg.inv(B) + + to_decompose_W_1 = ky ** 2 * I + A + to_decompose_W_2 = ky ** 2 * I + B @ o_E_conv_i + + eigenvalues_1, W_1 = torch.linalg.eig(to_decompose_W_1) + eigenvalues_2, W_2 = torch.linalg.eig(to_decompose_W_2) + + q_1 = eigenvalues_1 ** 0.5 + q_2 = eigenvalues_2 ** 0.5 + + Q_1 = torch.diag(q_1) + Q_2 = torch.diag(q_2) + + V_11 = A_i @ W_1 @ Q_1 + V_12 = (ky / k0) * A_i @ Kx @ W_2 + V_21 = (ky / k0) * B_i @ Kx @ E_i @ W_1 + V_22 = B_i @ W_2 @ Q_2 + + X_1 = torch.diag(torch.exp(-k0 * q_1 * d)) + X_2 = torch.diag(torch.exp(-k0 * q_2 * d)) + + F_c = torch.diag(torch.cos(varphi)) + F_s = torch.diag(torch.sin(varphi)) + + V_ss = F_c @ V_11 + V_sp = F_c @ V_12 - F_s @ W_2 + W_ss = F_c @ W_1 + F_s @ V_21 + W_sp = F_s @ V_22 + W_ps = F_s @ V_11 + W_pp = F_c @ W_2 + F_s @ V_12 + V_ps = F_c @ V_21 - F_s @ W_1 + V_pp = F_c @ V_22 + + big_I = torch.eye(2 * (len(I)), device=device, dtype=type_complex) + + big_X = torch.cat([ + torch.cat([X_1, O], dim=1), + torch.cat([O, X_2], dim=1)]) + + big_W = torch.cat([ + torch.cat([V_ss, V_sp], dim=1), + torch.cat([W_ps, W_pp], dim=1)]) + + big_V = torch.cat([ + torch.cat([W_ss, W_sp], dim=1), + torch.cat([V_ps, V_pp], dim=1)]) + + big_W_i = torch.linalg.inv(big_W) + big_V_i = torch.linalg.inv(big_V) + + + big_A = 0.5 * (big_W_i @ big_F + big_V_i @ big_G) + big_B = 0.5 * (big_W_i @ big_F - big_V_i @ big_G) + + big_A_i = torch.linalg.inv(big_A) + + big_F = big_W @ (big_I + big_X @ big_B @ big_A_i @ big_X) + big_G = big_V @ (big_I - big_X @ big_B @ big_A_i @ big_X) + + big_T = big_T @ big_A_i @ big_X + + return big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2 + + +def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, + device='cpu', type_complex=torch.complex128): + + I = torch.eye(ff, device=device, dtype=type_complex) + O = torch.zeros((ff, ff), device=device, dtype=type_complex) + + big_F_11 = big_F[:ff, :ff] + big_F_12 = big_F[:ff, ff:] + big_F_21 = big_F[ff:, :ff] + big_F_22 = big_F[ff:, ff:] + + big_G_11 = big_G[:ff, :ff] + big_G_12 = big_G[:ff, ff:] + big_G_21 = big_G[ff:, :ff] + big_G_22 = big_G[ff:, ff:] + + final_A = torch.cat( + [ + torch.cat([I, O, -big_F_11, -big_F_12], dim=1), + torch.cat([O, -1j * Z_I, -big_F_21, -big_F_22], dim=1), + torch.cat([-1j * Y_I, O, -big_G_11, -big_G_12], dim=1), + torch.cat([O, I, -big_G_21, -big_G_22], dim=1), + ] + ) + + final_B = torch.cat([ + -torch.sin(psi) * delta_i0, + -torch.cos(psi) * torch.cos(theta) * delta_i0, + -1j * torch.sin(psi) * n_I * torch.cos(theta) * delta_i0, + 1j * n_I * torch.cos(psi) * delta_i0 + ]) + + final_RT = torch.linalg.inv(final_A) @ final_B + + R_s = final_RT[:ff].flatten() + R_p = final_RT[ff:2 * ff].flatten() + + big_T1 = final_RT[2 * ff:] + big_T = big_T @ big_T1 + + T_s = big_T[:ff].flatten() + T_p = big_T[ff:].flatten() + + de_ri = R_s * torch.conj(R_s) * torch.real(k_I_z / (k0 * n_I * torch.cos(theta))) \ + + R_p * torch.conj(R_p) * torch.real((k_I_z / n_I ** 2) / (k0 * n_I * torch.cos(theta))) + + de_ti = T_s * torch.conj(T_s) * torch.real(k_II_z / (k0 * n_I * torch.cos(theta))) \ + + T_p * torch.conj(T_p) * torch.real((k_II_z / n_II ** 2) / (k0 * n_I * torch.cos(theta))) + + return de_ri.real, de_ti.real, big_T1 + + + +def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, phi, wavelength, + device='cpu', type_complex=torch.complex128): + + I = torch.eye(ff ** 2, device=device, dtype=type_complex) + O = torch.zeros((ff ** 2, ff ** 2), device=device, dtype=type_complex) + + # kx_vector = k0 * (n_I * torch.sin(theta) * torch.cos(phi) - fourier_indices * ( + # wavelength / period[0])).type(type_complex) + ky_vector = k0 * (n_I * torch.sin(theta) * torch.sin(phi) - fourier_indices * ( + wavelength / period[1])).type(type_complex) + + k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 + k_II_z = (k0 ** 2 * n_II ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 + + k_I_z = torch.conj(k_I_z.flatten()) + k_II_z = torch.conj(k_II_z.flatten()) + + Kx = torch.diag(kx_vector.tile(ff).flatten() / k0) + Ky = torch.diag(ky_vector.reshape((-1, 1)).tile(ff).flatten() / k0) + + varphi = torch.arctan(ky_vector.reshape((-1, 1)) / kx_vector).flatten() + + Y_I = torch.diag(k_I_z / k0) + Y_II = torch.diag(k_II_z / k0) + + Z_I = torch.diag(k_I_z / (k0 * n_I ** 2)) + Z_II = torch.diag(k_II_z / (k0 * n_II ** 2)) + + big_F = torch.cat( + [ + torch.cat([I, O], dim=1), + torch.cat([O, 1j * Z_II], dim=1), + ] + ) + + big_G = torch.cat( + [ + torch.cat([1j * Y_II, O], dim=1), + torch.cat([O, I], dim=1), + ] + ) + + big_T = torch.eye(ff ** 2 * 2, device=device, dtype=type_complex) + + return kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T + + +def transfer_2d_wv(ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, device='cpu', type_complex=torch.complex128): + + I = torch.eye(ff ** 2, device=device, dtype=type_complex) + + B = Kx @ E_conv_i @ Kx - I + D = Ky @ E_conv_i @ Ky - I + + S2_from_S = torch.cat( + [ + torch.cat([Ky ** 2 + B @ o_E_conv_i, Kx @ (E_conv_i @ Ky @ E_conv - Ky)], dim=1), + torch.cat([Ky @ (E_conv_i @ Kx @ o_E_conv_i - Kx), Kx ** 2 + D @ E_conv], dim=1) + ]) + + eigenvalues, W = torch.linalg.eig(S2_from_S) + + q = eigenvalues ** 0.5 + + Q = torch.diag(q) + Q_i = torch.linalg.inv(Q) + U1_from_S = torch.cat( + [ + torch.cat([-Kx @ Ky, Kx ** 2 - E_conv], dim=1), + torch.cat([o_E_conv_i - Ky ** 2, Ky @ Kx], dim=1) + ] + ) + V = U1_from_S @ W @ Q_i + + return W, V, q + + +def transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T, device='cpu', + type_complex=torch.complex128): + + q1 = q[:center] + q2 = q[center:] + + W_11 = W[:center, :center] + W_12 = W[:center, center:] + W_21 = W[center:, :center] + W_22 = W[center:, center:] + + V_11 = V[:center, :center] + V_12 = V[:center, center:] + V_21 = V[center:, :center] + V_22 = V[center:, center:] + + X_1 = torch.diag(torch.exp(-k0 * q1 * d)) + X_2 = torch.diag(torch.exp(-k0 * q2 * d)) + + F_c = torch.diag(torch.cos(varphi)) + F_s = torch.diag(torch.sin(varphi)) + + W_ss = F_c @ W_21 - F_s @ W_11 + W_sp = F_c @ W_22 - F_s @ W_12 + W_ps = F_c @ W_11 + F_s @ W_21 + W_pp = F_c @ W_12 + F_s @ W_22 + + V_ss = F_c @ V_11 + F_s @ V_21 + V_sp = F_c @ V_12 + F_s @ V_22 + V_ps = F_c @ V_21 - F_s @ V_11 + V_pp = F_c @ V_22 - F_s @ V_12 + + big_I = torch.eye(2 * (len(I)), device=device, dtype=type_complex) + + big_X = torch.cat([ + torch.cat([X_1, O], dim=1), + torch.cat([O, X_2], dim=1)]) + + big_W = torch.cat([ + torch.cat([W_ss, W_sp], dim=1), + torch.cat([W_ps, W_pp], dim=1)]) + + big_V = torch.cat([ + torch.cat([V_ss, V_sp], dim=1), + torch.cat([V_ps, V_pp], dim=1)]) + + big_W_i = torch.linalg.inv(big_W) + big_V_i = torch.linalg.inv(big_V) + + big_A = 0.5 * (big_W_i @ big_F + big_V_i @ big_G) + big_B = 0.5 * (big_W_i @ big_F - big_V_i @ big_G) + + big_A_i = torch.linalg.inv(big_A) + + big_F = big_W @ (big_I + big_X @ big_B @ big_A_i @ big_X) + big_G = big_V @ (big_I - big_X @ big_B @ big_A_i @ big_X) + + big_T = big_T @ big_A_i @ big_X + + return big_X, big_F, big_G, big_T, big_A_i, big_B, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 + + +def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, + device='cpu', type_complex=torch.complex128): + + I = torch.eye(ff ** 2, device=device, dtype=type_complex) + O = torch.zeros((ff ** 2, ff ** 2), device=device, dtype=type_complex) + + big_F_11 = big_F[:center, :center] + big_F_12 = big_F[:center, center:] + big_F_21 = big_F[center:, :center] + big_F_22 = big_F[center:, center:] + + big_G_11 = big_G[:center, :center] + big_G_12 = big_G[:center, center:] + big_G_21 = big_G[center:, :center] + big_G_22 = big_G[center:, center:] + + # Final Equation in form of AX=B + final_A = torch.cat( + [ + torch.cat([I, O, -big_F_11, -big_F_12], dim=1), + torch.cat([O, -1j * Z_I, -big_F_21, -big_F_22], dim=1), + torch.cat([-1j * Y_I, O, -big_G_11, -big_G_12], dim=1), + torch.cat([O, I, -big_G_21, -big_G_22], dim=1), + ] + ) + + final_B = torch.cat( + [ + torch.cat([-torch.sin(psi) * delta_i0], dim=1), + torch.cat([-torch.cos(psi) * torch.cos(theta) * delta_i0], dim=1), + torch.cat([-1j * torch.sin(psi) * n_I * torch.cos(theta) * delta_i0], dim=1), + torch.cat([1j * n_I * torch.cos(psi) * delta_i0], dim=1), + ] + ) + + final_RT = torch.linalg.inv(final_A) @ final_B + + R_s = final_RT[:ff ** 2, :].flatten() + R_p = final_RT[ff ** 2:2 * ff ** 2, :].flatten() + + big_T1 = final_RT[2 * ff ** 2:, :] + big_T = big_T @ big_T1 + + T_s = big_T[:ff ** 2, :].flatten() + T_p = big_T[ff ** 2:, :].flatten() + + de_ri = R_s * torch.conj(R_s) * torch.real(k_I_z / (k0 * n_I * torch.cos(theta))) \ + + R_p * torch.conj(R_p) * torch.real((k_I_z / n_I ** 2) / (k0 * n_I * torch.cos(theta))) + + de_ti = T_s * torch.conj(T_s) * torch.real(k_II_z / (k0 * n_I * torch.cos(theta))) \ + + T_p * torch.conj(T_p) * torch.real((k_II_z / n_II ** 2) / (k0 * n_I * torch.cos(theta))) + + return de_ri.real, de_ti.real, big_T1 diff --git a/meent/rcwa.py b/meent/rcwa.py index 3ea11da..f36cd0b 100644 --- a/meent/rcwa.py +++ b/meent/rcwa.py @@ -1,6 +1,10 @@ import numpy as np +import meent.integ.backend +import jax +from functools import partial +# @partial(jax.jit, static_argnums=(0, 1)) def call_solver(mode=0, *args, **kwargs): """ decide backend and return RCWA solver instance @@ -15,11 +19,25 @@ def call_solver(mode=0, *args, **kwargs): """ if mode == 0: - from meent.on_numpy.rcwa import RCWALight - RCWA = RCWALight(mode, *args, **kwargs) + from meent.on_numpy.rcwa import RCWANumpy + RCWA = RCWANumpy(mode, *args, **kwargs) elif mode == 1: - from meent.on_jax.rcwa import RCWAOpt - RCWA = RCWAOpt(mode, *args, **kwargs) + from meent.on_jax.rcwa import RCWAJax + RCWA = RCWAJax(mode, *args, **kwargs) + elif mode == 2: + from meent.on_torch.rcwa import RCWATorch + RCWA = RCWATorch(mode, *args, **kwargs) + + elif mode == 3: + meent.integ.backend.mode = 2 + from meent.integ.rcwa import RCWAInteg + RCWA = RCWAInteg(mode, *args, **kwargs) + elif mode == 4: + meent.integ.backend.mode = 3 + + from meent.integ.rcwa import RCWAInteg + RCWA = RCWAInteg(mode, *args, **kwargs) + else: raise ValueError diff --git a/setup.py b/setup.py index 1dc829f..e384428 100644 --- a/setup.py +++ b/setup.py @@ -2,14 +2,13 @@ setup( name='meent', - version='0.5.0', + version='0.6.0', url='https://github.com/kc-ml2/meent', author='KC ML2', author_email='yongha@kc-ml2.com', packages=['meent'] + find_packages(include=['meent.*']), install_requires=[ 'numpy==1.23.3', - 'scipy==1.9.1', 'jax==0.3.21', 'matplotlib==3.5.3', ], diff --git a/temp_field_dist.txt b/temp_field_dist.txt deleted file mode 100644 index 75c3d9e..0000000 --- a/temp_field_dist.txt +++ /dev/null @@ -1,126 +0,0 @@ -order 40 -q - -array([3.35866826e+01-3.27595522e-15j, 3.35618294e+01-3.20833711e-15j, - 3.22939696e+01+2.83140241e-16j, 3.22699870e+01+4.62951674e-16j, - 3.12640578e+01-2.23703358e-16j, 3.11452435e+01+5.95205307e-16j, - 3.10298359e+01+2.12145950e-15j, 3.10453122e+01+2.08520308e-15j, - 2.98340366e+01+8.43029606e-16j, 2.98134544e+01+6.76119196e-16j, - 2.86557306e+01+2.34237107e-15j, 2.85928550e+01+1.74650649e-15j, - 2.83771775e+01+2.93517588e-15j, 2.84985638e+01+1.71502901e-15j, - 2.74802888e+01+2.09731195e-15j, 2.73742124e+01+1.04684814e-15j, - 2.63110374e+01-5.90070891e-16j, 2.61958871e+01+9.28920644e-16j, - 2.58601637e+01+2.77396667e-15j, 2.57621181e+01+4.58473373e-15j, - 2.49582118e+01-5.04410417e-15j, 2.51335596e+01-2.68942318e-15j, - 2.39511297e+01-3.59585328e-16j, 2.38562244e+01+1.33552819e-15j, - 2.34371788e+01+1.83101795e-17j, 2.30442357e+01+5.24875651e-18j, - 2.26237203e+01+8.59663554e-17j, 2.27518378e+01-3.44715778e-17j, - 2.16119621e+01+9.31231455e-16j, 2.15638873e+01+1.18218640e-15j, - 2.09508424e+01-1.42362381e-15j, 2.07897253e+01+1.33434731e-15j, - 2.06141254e+01+2.99040019e-16j, 2.03541292e+01-2.14529863e-15j, - 1.89939163e+01-7.30857645e-16j, 1.94229577e+01+5.08844243e-16j, - 1.95570191e+01-1.71637706e-15j, 1.84885560e+01-1.81761123e-16j, - 1.78841770e+01+2.95323399e-16j, 1.81993910e+01-5.91437968e-16j, - 1.72054436e+01+1.26638751e-16j, 1.69214946e+01-1.66963428e-16j, - 1.65293449e+01-6.17985257e-16j, 1.60390106e+01-2.62487201e-15j, - 1.47447916e+01+9.63804000e-16j, 1.42713423e+01-6.82928035e-19j, - 1.38685600e+01-2.69827710e-16j, 1.54143071e+01+6.37178518e-16j, - 1.54080828e+01+1.28406775e-16j, 1.35641204e+01+2.03274658e-15j, - 1.29215904e+01+1.92923984e-15j, 1.25807543e+01-1.55038882e-15j, - 1.17361383e+01-4.24126027e-16j, 1.22088279e+01+1.29393030e-15j, - 1.04118323e+01+5.76707512e-16j, 1.10047498e+01+3.42165178e-15j, - 1.10764038e+01-5.74102490e-16j, 9.90976887e+00-3.29404508e-15j, - 9.22957800e+00-2.45584562e-15j, 9.42418699e+00+3.40855707e-15j, - 8.57153949e+00-1.59322612e-15j, 4.62907166e-15+3.18234903e+00j, - 8.13360785e+00+1.06923528e-15j, 6.71435841e+00-2.03588569e-15j, - 7.37449444e+00-2.14663810e-15j, 6.41529531e+00+4.32345651e-15j, - 8.91443270e-15-2.10162471e+00j, 5.51221781e+00-2.76072402e-15j, - 7.76814765e+00-9.32531495e-16j, 4.66209909e+00-7.35179706e-15j, - 5.98992969e+00+2.36920928e-16j, 4.11776889e+00+2.04064649e-15j, - 2.91490688e+00+1.08001393e-15j, 1.60967557e+00-6.41374469e-17j, - 3.62170886e+00-5.84792686e-15j, 4.83763255e+00+4.06046540e-15j, - 2.63858194e-14+5.88962851e-01j, 2.69505995e+00+1.45337289e-14j, - 2.07465320e+00+5.60751887e-15j, 9.70786423e-15-1.03101218e+00j, - 5.78261535e-01+3.37551899e-14j]) - -order 1 -array([3.01304238e-17+1.63591178j, 3.26978326e-17-0.87726555j, - 9.64129996e-18+0.92524245j]) -np.linalg.norm(q) -Out[2]: 2.0740963740142733 -q.sum() -Out[3]: (7.246955634614228e-17+1.6838886719521762j) - -order 2 -[8.65420170e-17-2.45595885e+00j 8.43925441e-01+6.12866345e-17j - 7.43393357e-01+5.98495157e-17j 1.40563854e-16-9.66315913e-01j - 5.26799094e-17+9.86035555e-01j] -3.0335804352073232 -(1.587318797883209-2.4362392068003387j) - -order 3 -[6.50167859e-17+3.00273901e+00j 1.90396653e+00-6.68538359e-20j - 1.70501940e+00-4.39357506e-17j 6.36924729e-01-1.14308215e-16j - 4.36941188e-01-4.31237120e-16j 1.90516737e-16+9.89167818e-01j - 3.73115648e-17+1.21900387e+00j] -4.313888665063014 -(4.68285185060247+5.210910700247273j) - -order 4 -[2.06532285e-16+3.14629384e+00j 2.75261845e+00+1.56321608e-16j - 2.59305433e+00+1.84015116e-16j 8.13194726e-17+1.67659734e+00j - 1.80647695e+00-2.04768593e-16j 1.57338276e+00-2.05251081e-16j - 9.02491518e-17+1.00778400e+00j 6.30240157e-17-4.15283378e-01j - 6.00310596e-01-1.27748617e-16j] -5.85647817360063 -(9.325843082614357+5.415391805547799j) - -5 -array([[ 4.33012702+0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j], - [ 0. +0.j, 3.46410162+0.j, 0. +0.j, - 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j], - [ 0. +0.j, 0. +0.j, 2.59807621+0.j, - 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j], - [ 0. +0.j, 0. +0.j, 0. +0.j, - 1.73205081+0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j], - [ 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0.8660254 +0.j, 0. +0.j, - 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j], - [ 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j], - [ 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j, 0. +0.j, - -0.8660254 +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j], - [ 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, -1.73205081+0.j, 0. +0.j, - 0. +0.j, 0. +0.j], - [ 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j, -2.59807621+0.j, - 0. +0.j, 0. +0.j], - [ 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j, 0. +0.j, - -3.46410162+0.j, 0. +0.j], - [ 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, 0. +0.j, 0. +0.j, - 0. +0.j, -4.33012702+0.j]]) - - - -