From d6805bdfabc548122bcd2e0a0615613eac2edee1 Mon Sep 17 00:00:00 2001 From: yonghakim Date: Fri, 20 Dec 2024 21:39:10 +0900 Subject: [PATCH] return field cell for all 3 input polization cases - given pol by user, TE and TM. --- benchmarks/reti_meent_1Dc.py | 19 +- benchmarks/reti_meent_2D.py | 25 +- meent/on_jax/emsolver/field_distribution.py | 116 +++++---- meent/on_jax/emsolver/rcwa.py | 16 +- meent/on_jax/emsolver/transfer_method.py | 10 +- meent/on_numpy/emsolver/field_distribution.py | 109 ++++---- meent/on_numpy/emsolver/rcwa.py | 12 +- meent/on_numpy/emsolver/transfer_method.py | 6 +- meent/on_torch/emsolver/field_distribution.py | 235 +++++++++++++----- meent/on_torch/emsolver/rcwa.py | 21 +- meent/on_torch/emsolver/transfer_method.py | 13 +- tutorials/01-modeling-and-emsolver.ipynb | 87 +++++-- 12 files changed, 455 insertions(+), 214 deletions(-) diff --git a/benchmarks/reti_meent_1Dc.py b/benchmarks/reti_meent_1Dc.py index a9d6d58..bc83a83 100644 --- a/benchmarks/reti_meent_1Dc.py +++ b/benchmarks/reti_meent_1Dc.py @@ -32,21 +32,32 @@ def run_1dc(option, plot_figure=False): # Numpy mee = meent.call_mee(backend=0, **option) res_numpy = mee.conv_solve() - field_cell_numpy = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x) + field_cell_numpy = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x, set_field_input=(True, True, True)) # JAX mee = meent.call_mee(backend=1, **option) # JAX res_jax = mee.conv_solve() - field_cell_jax = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x) + field_cell_jax = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x, set_field_input=(True, True, True)) # Torch mee = meent.call_mee(backend=2, **option) # PyTorch res_torch = mee.conv_solve() - field_cell_torch = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x).numpy() + field_cell_torch = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x, set_field_input=(True, True, True)).numpy() bds = ['Numpy', 'JAX', 'Torch'] fields = [field_cell_numpy, field_cell_jax, field_cell_torch] + print('Field difference: given_pol to TE, given_pol to TM') + n1 = np.linalg.norm(field_cell_numpy[0] - field_cell_numpy[1]) + n2 = np.linalg.norm(field_cell_numpy[0] - field_cell_numpy[2]) + j1 = np.linalg.norm(field_cell_jax[0] - field_cell_jax[1]) + j2 = np.linalg.norm(field_cell_jax[0] - field_cell_jax[2]) + t1 = np.linalg.norm(field_cell_torch[0] - field_cell_torch[1]) + t2 = np.linalg.norm(field_cell_torch[0] - field_cell_torch[2]) + print(f'numpy: {n1}, {n2}') + print(f'jax: {j1}, {j2}') + print(f'torch: {t1}, {t2}') + print('Norm of (meent - reti) per backend') for i, res_t in enumerate([res_numpy, res_jax, res_torch]): reti_de_ri_te, reti_de_ti_te = np.array(top_refl_info_te.efficiency).T, np.array(top_tran_info_te.efficiency).T @@ -78,7 +89,7 @@ def run_1dc(option, plot_figure=False): ) for i_field in range(reti_field_cell.shape[-1]): - res_temp = np.linalg.norm(fields[i][i_field] - reti_field_cell[i_field]) + res_temp = np.linalg.norm(fields[i][0,:,:,:,i_field] - reti_field_cell[:,:,:,i_field]) print(f'field, {i_field+1}th: {res_temp}') if plot_figure: diff --git a/benchmarks/reti_meent_2D.py b/benchmarks/reti_meent_2D.py index 863c542..248d4db 100644 --- a/benchmarks/reti_meent_2D.py +++ b/benchmarks/reti_meent_2D.py @@ -34,21 +34,32 @@ def run_2d(option, case, plot_figure=False): # Numpy mee = meent.call_mee(backend=0, **option) res_numpy = mee.conv_solve() - field_cell_numpy = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x) + field_cell_numpy = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x, set_field_input=(True, True, True)) # JAX mee = meent.call_mee(backend=1, **option) # JAX res_jax = mee.conv_solve() - field_cell_jax = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x) + field_cell_jax = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x, set_field_input=(True, True, True)) # Torch mee = meent.call_mee(backend=2, **option) # PyTorch res_torch = mee.conv_solve() - field_cell_torch = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x).numpy() + field_cell_torch = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x, set_field_input=(True, True, True)).numpy() bds = ['Numpy', 'JAX', 'Torch'] fields = [field_cell_numpy, field_cell_jax, field_cell_torch] + print('Field difference: given_pol to TE, given_pol to TM') + n1 = np.linalg.norm(field_cell_numpy[0] - field_cell_numpy[1]) + n2 = np.linalg.norm(field_cell_numpy[0] - field_cell_numpy[2]) + j1 = np.linalg.norm(field_cell_jax[0] - field_cell_jax[1]) + j2 = np.linalg.norm(field_cell_jax[0] - field_cell_jax[2]) + t1 = np.linalg.norm(field_cell_torch[0] - field_cell_torch[1]) + t2 = np.linalg.norm(field_cell_torch[0] - field_cell_torch[2]) + print(f'numpy: {n1}, {n2}') + print(f'jax: {j1}, {j2}') + print(f'torch: {t1}, {t2}') + print('Norm of (meent - reti) per backend') for i, res_t in enumerate([res_numpy, res_jax, res_torch]): reti_de_ri_te, reti_de_ti_te = np.array(top_refl_info_te.efficiency).T, np.array(top_tran_info_te.efficiency).T @@ -85,7 +96,7 @@ def run_2d(option, case, plot_figure=False): ) for i_field in range(reti_field_cell.shape[-1]): - res_temp = np.linalg.norm(fields[i][i_field] - reti_field_cell[i_field]) + res_temp = np.linalg.norm(fields[i][0,:,:,:,i_field] - reti_field_cell[:,:,:,i_field]) print(f'field, {i_field+1}th: {res_temp}') if plot_figure: @@ -103,7 +114,7 @@ def run_2d(option, case, plot_figure=False): im = axes[ix, 4].imshow(r_data.imag, cmap='jet', aspect='auto') fig.colorbar(im, ax=axes[ix, 4], shrink=1) - n_data = fields[i][:, res_y//2, :, ix] + n_data = fields[i][0, :, res_y//2, :, ix] im = axes[ix, 1].imshow(abs(n_data) ** 2, cmap='jet', aspect='auto') fig.colorbar(im, ax=axes[ix, 1], shrink=1) @@ -161,7 +172,7 @@ def case_2d_1(plot_figure=False): def case_2d_2(plot_figure=False): factor = 1 option = {} - option['pol'] = 1 # 0: TE, 1: TM + option['pol'] = 0 # 0: TE, 1: TM option['n_top'] = 1 # n_incidence option['n_bot'] = 1 # n_transmission option['theta'] = 20 * np.pi / 180 @@ -254,7 +265,7 @@ def case_2d_4(plot_figure=False): def case_2d_5(plot_figure=False): factor = 1 option = {} - option['pol'] = 0 # 0: TE, 1: TM + option['pol'] = 0.5 # 0: TE, 1: TM option['n_top'] = 1 # n_incidence option['n_bot'] = 1 # n_transmission option['theta'] = 0 * np.pi / 180 diff --git a/meent/on_jax/emsolver/field_distribution.py b/meent/on_jax/emsolver/field_distribution.py index fa27598..2fdb016 100644 --- a/meent/on_jax/emsolver/field_distribution.py +++ b/meent/on_jax/emsolver/field_distribution.py @@ -6,7 +6,6 @@ def field_dist_1d(wavelength, kx, T1, layer_info_list, period, pol, res_x=20, res_y=1, res_z=20, type_complex=jnp.complex128): - k0 = 2 * jnp.pi / wavelength Kx = jnp.diag(kx) @@ -66,8 +65,8 @@ def field_dist_1d(wavelength, kx, T1, layer_info_list, period, pol, res_x=20, re # @partial(jax.jit, static_argnums=(5, 6, 10, 11, 12, 13)) def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period, - res_x=20, res_y=20, res_z=20, type_complex=jnp.complex128): - + res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False), + type_complex=jnp.complex128): k0 = 2 * jnp.pi / wavelength ff_x = len(kx) @@ -77,11 +76,13 @@ def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period, Kx = jnp.diag(jnp.tile(kx, ff_y).flatten()) Ky = jnp.diag(jnp.tile(ky.reshape((-1, 1)), ff_x).flatten()) - field_cell = jnp.zeros((res_z * len(layer_info_list), res_y, res_x, 6), dtype=type_complex) + field_cell = jnp.zeros((sum(set_field_input), res_z * len(layer_info_list), res_y, res_x, 6), dtype=type_complex) - T_layer = T1 + # T_layer = T1 + # T_layer = T1[set_field_input] + T_layer = T1[jnp.array(set_field_input)] - big_I = jnp.eye((len(T1))).astype(type_complex) + big_I = jnp.eye((len(T1[0]))).astype(type_complex) O = jnp.zeros((ff_xy, ff_xy), dtype=type_complex) # From the first layer @@ -106,20 +107,27 @@ def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period, # z_1d = np.arange(0, res_z, res_z).reshape((-1, 1, 1)) / res_z * d z_1d = jnp.linspace(0, res_z, res_z).reshape((-1, 1, 1)) / res_z * d - c1_plus = c[0 * ff_xy:1 * ff_xy] - c2_plus = c[1 * ff_xy:2 * ff_xy] - c1_minus = c[2 * ff_xy:3 * ff_xy] - c2_minus = c[3 * ff_xy:4 * ff_xy] + c1_plus = c[:, 0 * ff_xy:1 * ff_xy] + c2_plus = c[:, 1 * ff_xy:2 * ff_xy] + c1_minus = c[:, 2 * ff_xy:3 * ff_xy] + c2_minus = c[:, 3 * ff_xy:4 * ff_xy] big_Q1 = jnp.diag(q_1) big_Q2 = jnp.diag(q_2) - Sx = W_2 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) - Sy = V_11 @ (d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + V_12 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) - Ux = W_1 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) - Uy = V_21 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + V_22 @ (-d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) + Sx = W_2 @ (d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) + Sy = V_11 @ (d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + V_12 @ (d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) + Ux = W_1 @ (-d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) + Uy = V_21 @ (-d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + V_22 @ (-d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) + Sz = -1j * epz_conv_i @ (Kx @ Uy - Ky @ Ux) Uz = -1j * (Kx @ Sy - Ky @ Sx) @@ -138,17 +146,19 @@ def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period, inv_fourier = jnp.exp(-1j * x_2d) * jnp.exp(-1j * y_2d) inv_fourier = inv_fourier.reshape((res_y, res_x, -1)) - Ex = inv_fourier[:, :, None, :] @ Sx[:, None, None, :, :] - Ey = inv_fourier[:, :, None, :] @ Sy[:, None, None, :, :] - Ez = inv_fourier[:, :, None, :] @ Sz[:, None, None, :, :] - Hx = 1j * inv_fourier[:, :, None, :] @ Ux[:, None, None, :, :] - Hy = 1j * inv_fourier[:, :, None, :] @ Uy[:, None, None, :, :] - Hz = 1j * inv_fourier[:, :, None, :] @ Uz[:, None, None, :, :] + Ex = inv_fourier[:, :, None, :] @ Sx[:, :, None, None, :, :] + Ey = inv_fourier[:, :, None, :] @ Sy[:, :, None, None, :, :] + Ez = inv_fourier[:, :, None, :] @ Sz[:, :, None, None, :, :] + Hx = 1j * inv_fourier[:, :, None, :] @ Ux[:, :, None, None, :, :] + Hy = 1j * inv_fourier[:, :, None, :] @ Uy[:, :, None, None, :, :] + Hz = 1j * inv_fourier[:, :, None, :] @ Uz[:, :, None, None, :, :] val = jnp.concatenate( (Ex.squeeze(-1), Ey.squeeze(-1), Ez.squeeze(-1), Hx.squeeze(-1), Hy.squeeze(-1), Hz.squeeze(-1)), -1) - field_cell = field_cell.at[res_z * idx_layer:res_z * (idx_layer + 1)].set(val) + val = jnp.moveaxis(val, 1, 0) + + field_cell = field_cell.at[:, res_z * idx_layer:res_z * (idx_layer + 1)].set(val) T_layer = big_A_i @ big_X @ T_layer @@ -157,8 +167,8 @@ def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period, # @partial(jax.jit, static_argnums=(5, 6, 10, 11, 12, 13)) def field_dist_2d(wavelength, kx, ky, T1, layer_info_list, period, - res_x=20, res_y=20, res_z=20, type_complex=jnp.complex128): - + res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False), + type_complex=jnp.complex128): k0 = 2 * jnp.pi / wavelength ff_x = len(kx) @@ -168,15 +178,16 @@ def field_dist_2d(wavelength, kx, ky, T1, layer_info_list, period, Kx = jnp.diag(jnp.tile(kx, ff_y).flatten()) Ky = jnp.diag(jnp.tile(ky.reshape((-1, 1)), ff_x).flatten()) - field_cell = jnp.zeros((res_z * len(layer_info_list), res_y, res_x, 6), dtype=type_complex) + field_cell = jnp.zeros((sum(set_field_input), res_z * len(layer_info_list), res_y, res_x, 6), dtype=type_complex) - T_layer = T1 + # T_layer = T1 + # T_layer = T1[list(set_field_input)] + T_layer = T1[jnp.array(set_field_input)] - big_I = jnp.eye((len(T1))).astype(type_complex) + big_I = jnp.eye((len(T1[0]))).astype(type_complex) # From the first layer for idx_layer, (epz_conv_i, W, V, q, d, big_A_i, big_B) in enumerate(layer_info_list[::-1]): - W_11 = W[:ff_xy, :ff_xy] W_12 = W[:ff_xy, ff_xy:] W_21 = W[ff_xy:, :ff_xy] @@ -193,24 +204,32 @@ def field_dist_2d(wavelength, kx, ky, T1, layer_info_list, period, z_1d = jnp.linspace(0, res_z, res_z).reshape((-1, 1, 1)) / res_z * d # z_1d = np.arange(0, res_z, res_z).reshape((-1, 1, 1)) / res_z * d - c1_plus = c[0 * ff_xy:1 * ff_xy] - c2_plus = c[1 * ff_xy:2 * ff_xy] - c1_minus = c[2 * ff_xy:3 * ff_xy] - c2_minus = c[3 * ff_xy:4 * ff_xy] + c1_plus = c[:, 0 * ff_xy:1 * ff_xy] + c2_plus = c[:, 1 * ff_xy:2 * ff_xy] + c1_minus = c[:, 2 * ff_xy:3 * ff_xy] + c2_minus = c[:, 3 * ff_xy:4 * ff_xy] q1 = q[:len(q) // 2] q2 = q[len(q) // 2:] big_Q1 = jnp.diag(q1) big_Q2 = jnp.diag(q2) - Sx = W_11 @ (d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + W_12 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) - Sy = W_21 @ (d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + W_22 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) - Ux = V_11 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + V_12 @ (-d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) - Uy = V_21 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + V_22 @ (-d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) + Sx = W_11 @ (d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + W_12 @ (d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) + Sy = W_21 @ (d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + W_22 @ (d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) + Ux = V_11 @ (-d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + V_12 @ (-d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) + Uy = V_21 @ (-d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + V_22 @ (-d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) Sz = -1j * epz_conv_i @ (Kx @ Uy - Ky @ Ux) Uz = -1j * (Kx @ Sy - Ky @ Sx) @@ -232,17 +251,18 @@ def field_dist_2d(wavelength, kx, ky, T1, layer_info_list, period, inv_fourier = jnp.exp(-1j * x_2d) * jnp.exp(-1j * y_2d) inv_fourier = inv_fourier.reshape((res_y, res_x, -1)) - Ex = inv_fourier[:, :, None, :] @ Sx[:, None, None, :, :] - Ey = inv_fourier[:, :, None, :] @ Sy[:, None, None, :, :] - Ez = inv_fourier[:, :, None, :] @ Sz[:, None, None, :, :] - Hx = 1j * inv_fourier[:, :, None, :] @ Ux[:, None, None, :, :] - Hy = 1j * inv_fourier[:, :, None, :] @ Uy[:, None, None, :, :] - Hz = 1j * inv_fourier[:, :, None, :] @ Uz[:, None, None, :, :] + Ex = inv_fourier[:, :, None, :] @ Sx[:, :, None, None, :, :] + Ey = inv_fourier[:, :, None, :] @ Sy[:, :, None, None, :, :] + Ez = inv_fourier[:, :, None, :] @ Sz[:, :, None, None, :, :] + Hx = 1j * inv_fourier[:, :, None, :] @ Ux[:, :, None, None, :, :] + Hy = 1j * inv_fourier[:, :, None, :] @ Uy[:, :, None, None, :, :] + Hz = 1j * inv_fourier[:, :, None, :] @ Uz[:, :, None, None, :, :] val = jnp.concatenate( (Ex.squeeze(-1), Ey.squeeze(-1), Ez.squeeze(-1), Hx.squeeze(-1), Hy.squeeze(-1), Hz.squeeze(-1)), -1) + val = jnp.moveaxis(val, 1, 0) - field_cell = field_cell.at[res_z * idx_layer:res_z * (idx_layer + 1)].set(val) + field_cell = field_cell.at[:, res_z * idx_layer:res_z * (idx_layer + 1)].set(val) T_layer = big_A_i @ big_X @ T_layer diff --git a/meent/on_jax/emsolver/rcwa.py b/meent/on_jax/emsolver/rcwa.py index 8188d60..2839f68 100644 --- a/meent/on_jax/emsolver/rcwa.py +++ b/meent/on_jax/emsolver/rcwa.py @@ -260,7 +260,7 @@ def conv_solve(self, **kwargs): # return de_ri, de_ti @jax_device_set - def calculate_field(self, res_x=20, res_y=20, res_z=20): + def calculate_field(self, res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False)): kx, ky = self.get_kx_ky_vector(wavelength=self.wavelength) if self._grating_type_assigned == 0: @@ -269,10 +269,12 @@ def calculate_field(self, res_x=20, res_y=20, res_z=20): res_x=res_x, res_y=res_y, res_z=res_z, type_complex=self.type_complex) elif self._grating_type_assigned == 1: field_cell = field_dist_1d_conical(self.wavelength, kx, ky, self.T1, self.layer_info_list, self.period, - res_x=res_x, res_y=res_y, res_z=res_z, type_complex=self.type_complex) + res_x=res_x, res_y=res_y, res_z=res_z, set_field_input=set_field_input, + type_complex=self.type_complex) else: field_cell = field_dist_2d(self.wavelength, kx, ky, self.T1, self.layer_info_list, self.period, - res_x=res_x, res_y=res_y, res_z=res_z, type_complex=self.type_complex) + res_x=res_x, res_y=res_y, res_z=res_z, set_field_input=set_field_input, + type_complex=self.type_complex) return field_cell @@ -281,7 +283,7 @@ def field_plot(self, field_cell): @partial(jax.jit, static_argnums=(1, 2, 3)) @jax_device_set - def conv_solve_field(self, res_x=20, res_y=20, res_z=20, **kwargs): + def conv_solve_field(self, res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False), **kwargs): [setattr(self, k, v) for k, v in kwargs.items()] # needed for optimization if self.fourier_type == 1: @@ -289,13 +291,13 @@ def conv_solve_field(self, res_x=20, res_y=20, res_z=20, **kwargs): return None, None, None de_ri, de_ti, _, _ = self._conv_solve() - field_cell = self.calculate_field(res_x, res_y, res_z) + field_cell = self.calculate_field(res_x, res_y, res_z, set_field_input) return de_ri, de_ti, field_cell @jax_device_set - def conv_solve_field_no_jit(self, res_x=20, res_y=20, res_z=20): + def conv_solve_field_no_jit(self, res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False)): de_ri, de_ti, _, _ = self._conv_solve() - field_cell = self.calculate_field(res_x, res_y, res_z) + field_cell = self.calculate_field(res_x, res_y, res_z, set_field_input) return de_ri, de_ti, field_cell def run_ucell_vmap(self, ucell_list): diff --git a/meent/on_jax/emsolver/transfer_method.py b/meent/on_jax/emsolver/transfer_method.py index ca1d9f7..1399c25 100644 --- a/meent/on_jax/emsolver/transfer_method.py +++ b/meent/on_jax/emsolver/transfer_method.py @@ -355,7 +355,6 @@ def transfer_1d_conical_3(k0, W, V, q, d, varphi, big_F, big_G, big_T, type_comp def transfer_1d_conical_4(ff_x, ff_y, big_F, big_G, big_T, kz_top, kz_bot, psi, theta, n_top, n_bot, type_complex=jnp.complex128, use_pinv=False): - ff_xy = ff_x * ff_y Kz_top = jnp.diag(kz_top) @@ -365,7 +364,6 @@ def transfer_1d_conical_4(ff_x, ff_y, big_F, big_G, big_T, kz_top, kz_bot, psi, I = jnp.eye(ff_xy, dtype=type_complex) O = jnp.zeros((ff_xy, ff_xy), dtype=type_complex) - big_F_11 = big_F[:ff_xy, :ff_xy] big_F_12 = big_F[:ff_xy, ff_xy:] big_F_21 = big_F[ff_xy:, :ff_xy] @@ -475,7 +473,9 @@ def transfer_1d_conical_4(ff_x, ff_y, big_F, big_G, big_T, kz_top, kz_bot, psi, result = {'res': res, 'res_tm_inc': res_tm_inc, 'res_te_inc': res_te_inc} - return result, big_T1 + big_T1_all = jnp.stack((big_T1, big_T1_tetm[:, 0:1], big_T1_tetm[:, 1:2])) + + return result, big_T1_all def transfer_2d_1(kx, ky, n_top, n_bot, type_complex=jnp.complex128): @@ -718,6 +718,8 @@ def transfer_2d_4(ff_x, ff_y, big_F, big_G, big_T, kz_top, kz_bot, psi, theta, n result = {'res': res, 'res_tm_inc': res_tm_inc, 'res_te_inc': res_te_inc} - return result, big_T1 + big_T1_all = jnp.stack((big_T1, big_T1_tetm[:, 0:1], big_T1_tetm[:, 1:2])) + + return result, big_T1_all diff --git a/meent/on_numpy/emsolver/field_distribution.py b/meent/on_numpy/emsolver/field_distribution.py index aa37549..dc56779 100644 --- a/meent/on_numpy/emsolver/field_distribution.py +++ b/meent/on_numpy/emsolver/field_distribution.py @@ -59,7 +59,8 @@ def field_dist_1d(wavelength, kx, T1, layer_info_list, period, def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period, - res_x=20, res_y=20, res_z=20, type_complex=np.complex128): + res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False), + type_complex=np.complex128): k0 = 2 * np.pi / wavelength ff_x = len(kx) @@ -69,11 +70,13 @@ def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period, Kx = np.diag(np.tile(kx, ff_y).flatten()) Ky = np.diag(np.tile(ky.reshape((-1, 1)), ff_x).flatten()) - field_cell = np.zeros((res_z * len(layer_info_list), res_y, res_x, 6), dtype=type_complex) + # field_cell = np.zeros((res_z * len(layer_info_list), res_y, res_x, 6), dtype=type_complex) + field_cell = np.zeros((sum(set_field_input), res_z * len(layer_info_list), res_y, res_x, 6), + dtype=type_complex) + # T_layer = T1 + T_layer = T1[list(set_field_input)] - T_layer = T1 - - big_I = np.eye((len(T1))).astype(type_complex) + big_I = np.eye((len(T1[0]))).astype(type_complex) O = np.zeros((ff_xy, ff_xy), dtype=type_complex) # From the first layer @@ -98,20 +101,20 @@ def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period, # z_1d = np.arange(res_z).reshape((-1, 1, 1)) / res_z * d z_1d = np.linspace(0, res_z, res_z).reshape((-1, 1, 1)) / res_z * d - c1_plus = c[0 * ff_xy:1 * ff_xy] - c2_plus = c[1 * ff_xy:2 * ff_xy] - c1_minus = c[2 * ff_xy:3 * ff_xy] - c2_minus = c[3 * ff_xy:4 * ff_xy] + c1_plus = c[:, 0 * ff_xy:1 * ff_xy] + c2_plus = c[:, 1 * ff_xy:2 * ff_xy] + c1_minus = c[:, 2 * ff_xy:3 * ff_xy] + c2_minus = c[:, 3 * ff_xy:4 * ff_xy] big_Q1 = np.diag(q_1) big_Q2 = np.diag(q_2) - Sx = W_2 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) - Sy = V_11 @ (d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + V_12 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) - Ux = W_1 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) - Uy = V_21 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + V_22 @ (-d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) + Sx = W_2 @ (d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) + Sy = V_11 @ (d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + V_12 @ (d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) + Ux = W_1 @ (-d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) + Uy = V_21 @ (-d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + V_22 @ (-d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) Sz = -1j * epz_conv_i @ (Kx @ Uy - Ky @ Ux) Uz = -1j * (Kx @ Sy - Ky @ Sx) @@ -141,17 +144,18 @@ def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period, # Hy = -1j * exp_K[:, :, None, :] @ Uy[:, None, None, :, :] # Hz = -1j * exp_K[:, :, None, :] @ Uz[:, None, None, :, :] - Ex = inv_fourier[:, :, None, :] @ Sx[:, None, None, :, :] - Ey = inv_fourier[:, :, None, :] @ Sy[:, None, None, :, :] - Ez = inv_fourier[:, :, None, :] @ Sz[:, None, None, :, :] - Hx = 1j * inv_fourier[:, :, None, :] @ Ux[:, None, None, :, :] - Hy = 1j * inv_fourier[:, :, None, :] @ Uy[:, None, None, :, :] - Hz = 1j * inv_fourier[:, :, None, :] @ Uz[:, None, None, :, :] + Ex = inv_fourier[:, :, None, :] @ Sx[:, :, None, None, :, :] + Ey = inv_fourier[:, :, None, :] @ Sy[:, :, None, None, :, :] + Ez = inv_fourier[:, :, None, :] @ Sz[:, :, None, None, :, :] + Hx = 1j * inv_fourier[:, :, None, :] @ Ux[:, :, None, None, :, :] + Hy = 1j * inv_fourier[:, :, None, :] @ Uy[:, :, None, None, :, :] + Hz = 1j * inv_fourier[:, :, None, :] @ Uz[:, :, None, None, :, :] val = np.concatenate( (Ex.squeeze(-1), Ey.squeeze(-1), Ez.squeeze(-1), Hx.squeeze(-1), Hy.squeeze(-1), Hz.squeeze(-1)), -1) + val = np.moveaxis(val, 1, 0) - field_cell[res_z * idx_layer:res_z * (idx_layer + 1)] = val + field_cell[:, res_z * idx_layer:res_z * (idx_layer + 1)] = val T_layer = big_A_i @ big_X @ T_layer @@ -159,7 +163,8 @@ def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period, def field_dist_2d(wavelength, kx, ky, T1, layer_info_list, period, - res_x=20, res_y=20, res_z=20, type_complex=np.complex128): + res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False), + type_complex=np.complex128): k0 = 2 * np.pi / wavelength ff_x = len(kx) @@ -169,11 +174,13 @@ def field_dist_2d(wavelength, kx, ky, T1, layer_info_list, period, Kx = np.diag(np.tile(kx, ff_y).flatten()) Ky = np.diag(np.tile(ky.reshape((-1, 1)), ff_x).flatten()) - field_cell = np.zeros((res_z * len(layer_info_list), res_y, res_x, 6), dtype=type_complex) + # field_cell = np.zeros((res_z * len(layer_info_list), res_y, res_x, 6), dtype=type_complex) + field_cell = np.zeros((sum(set_field_input), res_z * len(layer_info_list), res_y, res_x, 6), + dtype=type_complex) + # T_layer = T1 + T_layer = T1[list(set_field_input)] - T_layer = T1 - - big_I = np.eye((len(T1))).astype(type_complex) + big_I = np.eye((len(T1[0]))).astype(type_complex) # From the first layer for idx_layer, (epz_conv_i, W, V, q, d, big_A_i, big_B) in enumerate(layer_info_list[::-1]): @@ -196,22 +203,30 @@ def field_dist_2d(wavelength, kx, ky, T1, layer_info_list, period, z_1d = np.linspace(0, res_z, res_z).reshape((-1, 1, 1)) / res_z * d # z_1d = np.arange(0, res_z, res_z).reshape((-1, 1, 1)) / res_z * d - c1_plus = c[0 * ff_xy:1 * ff_xy] - c2_plus = c[1 * ff_xy:2 * ff_xy] - c1_minus = c[2 * ff_xy:3 * ff_xy] - c2_minus = c[3 * ff_xy:4 * ff_xy] + c1_plus = c[:, 0 * ff_xy:1 * ff_xy] + c2_plus = c[:, 1 * ff_xy:2 * ff_xy] + c1_minus = c[:, 2 * ff_xy:3 * ff_xy] + c2_minus = c[:, 3 * ff_xy:4 * ff_xy] big_Q1 = np.diag(q_1) big_Q2 = np.diag(q_2) - Sx = W_11 @ (d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + W_12 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) - Sy = W_21 @ (d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + W_22 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) - Ux = V_11 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + V_12 @ (-d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) - Uy = V_21 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + V_22 @ (-d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) + Sx = W_11 @ (d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + W_12 @ (d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) + Sy = W_21 @ (d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + W_22 @ (d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) + Ux = V_11 @ (-d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + V_12 @ (-d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) + Uy = V_21 @ (-d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + V_22 @ (-d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) Sz = -1j * epz_conv_i @ (Kx @ Uy - Ky @ Ux) Uz = -1j * (Kx @ Sy - Ky @ Sx) @@ -233,17 +248,17 @@ def field_dist_2d(wavelength, kx, ky, T1, layer_info_list, period, inv_fourier = np.exp(-1j * x_2d) * np.exp(-1j * y_2d) inv_fourier = inv_fourier.reshape((res_y, res_x, -1)) - Ex = inv_fourier[:, :, None, :] @ Sx[:, None, None, :, :] - Ey = inv_fourier[:, :, None, :] @ Sy[:, None, None, :, :] - Ez = inv_fourier[:, :, None, :] @ Sz[:, None, None, :, :] - Hx = 1j * inv_fourier[:, :, None, :] @ Ux[:, None, None, :, :] - Hy = 1j * inv_fourier[:, :, None, :] @ Uy[:, None, None, :, :] - Hz = 1j * inv_fourier[:, :, None, :] @ Uz[:, None, None, :, :] + Ex = inv_fourier[:, :, None, :] @ Sx[:, :, None, None, :, :] + Ey = inv_fourier[:, :, None, :] @ Sy[:, :, None, None, :, :] + Ez = inv_fourier[:, :, None, :] @ Sz[:, :, None, None, :, :] + Hx = 1j * inv_fourier[:, :, None, :] @ Ux[:, :, None, None, :, :] + Hy = 1j * inv_fourier[:, :, None, :] @ Uy[:, :, None, None, :, :] + Hz = 1j * inv_fourier[:, :, None, :] @ Uz[:, :, None, None, :, :] val = np.concatenate( (Ex.squeeze(-1), Ey.squeeze(-1), Ez.squeeze(-1), Hx.squeeze(-1), Hy.squeeze(-1), Hz.squeeze(-1)), -1) - - field_cell[res_z * idx_layer:res_z * (idx_layer + 1)] = val + val = np.moveaxis(val, 1, 0) + field_cell[:, res_z * idx_layer:res_z * (idx_layer + 1)] = val T_layer = big_A_i @ big_X @ T_layer diff --git a/meent/on_numpy/emsolver/rcwa.py b/meent/on_numpy/emsolver/rcwa.py index 5447780..429e9b2 100644 --- a/meent/on_numpy/emsolver/rcwa.py +++ b/meent/on_numpy/emsolver/rcwa.py @@ -198,7 +198,7 @@ def conv_solve(self, **kwargs): return result - def calculate_field(self, res_x=20, res_y=20, res_z=20): + def calculate_field(self, res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False)): # TODO: change res_ to accept array of points. kx, ky = self.get_kx_ky_vector(wavelength=self.wavelength) @@ -209,16 +209,18 @@ def calculate_field(self, res_x=20, res_y=20, res_z=20): elif self._grating_type_assigned == 1: # TODO other bds field_cell = field_dist_1d_conical(self.wavelength, kx, ky, self.T1, self.layer_info_list, self.period, - res_x=res_x, res_y=res_y, res_z=res_z, type_complex=self.type_complex) + res_x=res_x, res_y=res_y, res_z=res_z, set_field_input=set_field_input, + type_complex=self.type_complex) else: field_cell = field_dist_2d(self.wavelength, kx, ky, self.T1, self.layer_info_list, self.period, - res_x=res_x, res_y=res_y, res_z=res_z, type_complex=self.type_complex) + res_x=res_x, res_y=res_y, res_z=res_z, set_field_input=set_field_input, + type_complex=self.type_complex) return field_cell - def conv_solve_field(self, res_x=20, res_y=20, res_z=20): + def conv_solve_field(self, res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False)): res = self.conv_solve() - field_cell = self.calculate_field(res_x, res_y, res_z) + field_cell = self.calculate_field(res_x, res_y, res_z, set_field_input) return res, field_cell def field_plot(self, field_cell): diff --git a/meent/on_numpy/emsolver/transfer_method.py b/meent/on_numpy/emsolver/transfer_method.py index 6b616c7..3b7027e 100644 --- a/meent/on_numpy/emsolver/transfer_method.py +++ b/meent/on_numpy/emsolver/transfer_method.py @@ -375,8 +375,9 @@ def transfer_1d_conical_4(ff_x, ff_y, big_F, big_G, big_T, kz_top, kz_bot, psi, 'de_ti_s': de_ti_s_tetm[1], 'de_ti_p': de_ti_p_tetm[1], 'de_ti': de_ti_tetm[1]} result = {'res': res, 'res_tm_inc': res_tm_inc, 'res_te_inc': res_te_inc} + big_T1_all = np.stack((big_T1, big_T1_tetm[:, 0:1], big_T1_tetm[:, 1:2])) - return result, big_T1 + return result, big_T1_all def transfer_2d_1(kx, ky, n_top, n_bot, type_complex=np.complex128): @@ -619,5 +620,6 @@ def transfer_2d_4(ff_x, ff_y, big_F, big_G, big_T, kz_top, kz_bot, psi, theta, n 'de_ti_s': de_ti_s_tetm[1], 'de_ti_p': de_ti_p_tetm[1], 'de_ti': de_ti_tetm[1]} result = {'res': res, 'res_tm_inc': res_tm_inc, 'res_te_inc': res_te_inc} + big_T1_all = np.stack((big_T1, big_T1_tetm[:, 0:1], big_T1_tetm[:, 1:2])) - return result, big_T1 + return result, big_T1_all \ No newline at end of file diff --git a/meent/on_torch/emsolver/field_distribution.py b/meent/on_torch/emsolver/field_distribution.py index 46ba0d3..f6afd5e 100644 --- a/meent/on_torch/emsolver/field_distribution.py +++ b/meent/on_torch/emsolver/field_distribution.py @@ -4,7 +4,6 @@ def field_dist_1d(wavelength, kx, T1, layer_info_list, period, pol, res_x=20, res_y=20, res_z=20, device='cpu', type_complex=torch.complex128, type_float=torch.float64): - k0 = 2 * torch.pi / wavelength Kx = torch.diag(kx) @@ -62,22 +61,25 @@ def field_dist_1d(wavelength, kx, T1, layer_info_list, period, def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period, - res_x=20, res_y=20, res_z=20, device='cpu', type_complex=torch.complex128, type_float=torch.float64): - + res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False), + device='cpu', type_complex=torch.complex128, type_float=torch.float64): k0 = 2 * torch.pi / wavelength ff_x = len(kx) ff_y = len(ky) ff_xy = ff_x * ff_y - Kx = torch.diag(torch.tile(kx, (ff_y, )).flatten()) - Ky = torch.diag(torch.tile(ky.reshape((-1, 1)), (ff_x, )).flatten()) + Kx = torch.diag(torch.tile(kx, (ff_y,)).flatten()) + Ky = torch.diag(torch.tile(ky.reshape((-1, 1)), (ff_x,)).flatten()) - field_cell = torch.zeros((res_z * len(layer_info_list), res_y, res_x, 6), device=device, dtype=type_complex) + # field_cell = torch.zeros((res_z * len(layer_info_list), res_y, res_x, 6), device=device, dtype=type_complex) + field_cell = torch.zeros((sum(set_field_input), res_z * len(layer_info_list), res_y, res_x, 6), device=device, + dtype=type_complex) - T_layer = T1 + # T_layer = T1 + T_layer = T1[list(set_field_input)] - big_I = torch.eye((len(T1)), device=device, dtype=type_complex) + big_I = torch.eye((len(T1[0])), device=device, dtype=type_complex) O = torch.zeros((ff_xy, ff_xy), device=device, dtype=type_complex) # From the first layer @@ -100,25 +102,35 @@ def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period, torch.cat([X_1, O], dim=1), torch.cat([O, X_2], dim=1)]) - c = torch.cat([big_I, big_B @ big_A_i @ big_X]) @ T_layer + c = torch.cat([big_I, big_B @ big_A_i @ big_X]) @ T_layer # z_1d = np.arange(0, res_z, res_z).reshape((-1, 1, 1)) / res_z * d z_1d = torch.linspace(0, res_z, res_z, device=device, dtype=type_complex).reshape((-1, 1, 1)) / res_z * d - c1_plus = c[0 * ff_xy:1 * ff_xy] - c2_plus = c[1 * ff_xy:2 * ff_xy] - c1_minus = c[2 * ff_xy:3 * ff_xy] - c2_minus = c[3 * ff_xy:4 * ff_xy] + # c1_plus = c[0 * ff_xy:1 * ff_xy] + # c2_plus = c[1 * ff_xy:2 * ff_xy] + # c1_minus = c[2 * ff_xy:3 * ff_xy] + # c2_minus = c[3 * ff_xy:4 * ff_xy] + c1_plus = c[:, 0 * ff_xy:1 * ff_xy] + c2_plus = c[:, 1 * ff_xy:2 * ff_xy] + c1_minus = c[:, 2 * ff_xy:3 * ff_xy] + c2_minus = c[:, 3 * ff_xy:4 * ff_xy] big_Q1 = torch.diag(q_1) big_Q2 = torch.diag(q_2) - Sx = W_2 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) - Sy = V_11 @ (d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + V_12 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) - Ux = W_1 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) - Uy = V_21 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + V_22 @ (-d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) + Sx = W_2 @ (d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) + Sy = V_11 @ (d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + V_12 @ (d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) + Ux = W_1 @ (-d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) + Uy = V_21 @ (-d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + V_22 @ (-d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) Sz = -1j * epz_conv_i @ (Kx @ Uy - Ky @ Ux) Uz = -1j * (Kx @ Sy - Ky @ Sx) @@ -130,7 +142,8 @@ def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period, # y_1d = np.arange(res_y-1, -1, -1).reshape((-1, 1, 1)) * period[1] / res_y # y_1d = torch.linspace(0, period[1], res_y, device=device, dtype=type_complex)[::-1].reshape((-1, 1, 1)) - y_1d = torch.flip(torch.linspace(0, period[1], res_y, device=device, dtype=type_complex), dims=(0,)).reshape((-1, 1, 1)) + y_1d = torch.flip(torch.linspace(0, period[1], res_y, device=device, dtype=type_complex), dims=(0,)).reshape( + (-1, 1, 1)) y_2d = torch.tile(y_1d, (1, res_x, 1)) y_2d = y_2d * ky * k0 y_2d = y_2d.reshape((res_y, res_x, len(ky), 1)) @@ -138,17 +151,18 @@ def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period, inv_fourier = torch.exp(-1j * x_2d) * torch.exp(-1j * y_2d) inv_fourier = inv_fourier.reshape((res_y, res_x, -1)) - Ex = inv_fourier[:, :, None, :] @ Sx[:, None, None, :, :] - Ey = inv_fourier[:, :, None, :] @ Sy[:, None, None, :, :] - Ez = inv_fourier[:, :, None, :] @ Sz[:, None, None, :, :] - Hx = 1j * inv_fourier[:, :, None, :] @ Ux[:, None, None, :, :] - Hy = 1j * inv_fourier[:, :, None, :] @ Uy[:, None, None, :, :] - Hz = 1j * inv_fourier[:, :, None, :] @ Uz[:, None, None, :, :] + Ex = inv_fourier[:, :, None, :] @ Sx[:, :, None, None, :, :] + Ey = inv_fourier[:, :, None, :] @ Sy[:, :, None, None, :, :] + Ez = inv_fourier[:, :, None, :] @ Sz[:, :, None, None, :, :] + Hx = 1j * inv_fourier[:, :, None, :] @ Ux[:, :, None, None, :, :] + Hy = 1j * inv_fourier[:, :, None, :] @ Uy[:, :, None, None, :, :] + Hz = 1j * inv_fourier[:, :, None, :] @ Uz[:, :, None, None, :, :] val = torch.cat( (Ex.squeeze(-1), Ey.squeeze(-1), Ez.squeeze(-1), Hx.squeeze(-1), Hy.squeeze(-1), Hz.squeeze(-1)), -1) + val = torch.moveaxis(val, 1, 0) - field_cell[res_z * idx_layer:res_z * (idx_layer + 1)] = val + field_cell[:, res_z * idx_layer:res_z * (idx_layer + 1)] = val T_layer = big_A_i @ big_X @ T_layer @@ -156,26 +170,115 @@ def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period, def field_dist_2d(wavelength, kx, ky, T1, layer_info_list, period, - res_x=20, res_y=20, res_z=20, device='cpu', type_complex=torch.complex128): - + res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False), + device='cpu', type_complex=torch.complex128): k0 = 2 * torch.pi / wavelength ff_x = len(kx) ff_y = len(ky) ff_xy = ff_x * ff_y - Kx = torch.diag(torch.tile(kx, (ff_y, )).flatten()) - Ky = torch.diag(torch.tile(ky.reshape((-1, 1)), (ff_x, )).flatten()) - - field_cell = torch.zeros((res_z * len(layer_info_list), res_y, res_x, 6), device=device, dtype=type_complex) - - T_layer = T1 - - big_I = torch.eye((len(T1)), device=device, dtype=type_complex) + Kx = torch.diag(torch.tile(kx, (ff_y,)).flatten()) + Ky = torch.diag(torch.tile(ky.reshape((-1, 1)), (ff_x,)).flatten()) + + # import time + # t0 = time.time() + # field_cell0 = torch.zeros((res_z * len(layer_info_list), res_y, res_x, 6), device=device, dtype=type_complex) + # + # T_layer = T1[0] + # big_I = torch.eye((len(T1[0])), device=device, dtype=type_complex) + # + # # From the first layer + # for idx_layer, (epz_conv_i, W, V, q, d, big_A_i, big_B) in enumerate(layer_info_list[::-1]): + # + # W_11 = W[:ff_xy, :ff_xy] + # W_12 = W[:ff_xy, ff_xy:] + # W_21 = W[ff_xy:, :ff_xy] + # W_22 = W[ff_xy:, ff_xy:] + # + # V_11 = V[:ff_xy, :ff_xy] + # V_12 = V[:ff_xy, ff_xy:] + # V_21 = V[ff_xy:, :ff_xy] + # V_22 = V[ff_xy:, ff_xy:] + # + # q_1 = q[:ff_xy] + # q_2 = q[ff_xy:] + # + # big_X = torch.diag(torch.exp(-k0 * q * d)) + # + # c = torch.cat([big_I, big_B @ big_A_i @ big_X]) @ T_layer + # + # z_1d = torch.linspace(0, res_z, res_z, device=device, dtype=type_complex).reshape((-1, 1, 1)) / res_z * d + # # z_1d = np.arange(0, res_z, res_z).reshape((-1, 1, 1)) / res_z * d + # + # c1_plus0 = c[0 * ff_xy:1 * ff_xy] + # c2_plus0 = c[1 * ff_xy:2 * ff_xy] + # c1_minus0 = c[2 * ff_xy:3 * ff_xy] + # c2_minus0 = c[3 * ff_xy:4 * ff_xy] + # + # big_Q1 = torch.diag(q_1) + # big_Q2 = torch.diag(q_2) + # + # Sx0 = W_11 @ (d_exp(-k0 * big_Q1 * z_1d) @ c1_plus0 + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus0) \ + # + W_12 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus0 + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus0) + # Sy0 = W_21 @ (d_exp(-k0 * big_Q1 * z_1d) @ c1_plus0 + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus0) \ + # + W_22 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus0 + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus0) + # + # Ux0 = V_11 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus0 + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus0) \ + # + V_12 @ (-d_exp(-k0 * big_Q2 * z_1d) @ c2_plus0 + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus0) + # Uy0 = V_21 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus0 + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus0) \ + # + V_22 @ (-d_exp(-k0 * big_Q2 * z_1d) @ c2_plus0 + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus0) + # + # Sz0 = -1j * epz_conv_i @ (Kx @ Uy0 - Ky @ Ux0) + # Uz0 = -1j * (Kx @ Sy0 - Ky @ Sx0) + # + # # x_1d = np.arange(res_x).reshape((1, -1, 1)) * period[0] / res_x + # x_1d = torch.linspace(0, period[0], res_x, device=device, dtype=type_complex).reshape((1, -1, 1)) + # + # # y_1d = np.arange(res_y-1, -1, -1).reshape((-1, 1, 1)) * period[1] / res_y + # # y_1d = torch.linspace(0, period[1], res_y, device=device, dtype=type_complex)[::-1].reshape((-1, 1, 1)) + # y_1d = torch.flip(torch.linspace(0, period[1], res_y, device=device, dtype=type_complex), dims=(0,)).reshape((-1, 1, 1)) + # + # x_2d = torch.tile(x_1d, (res_y, 1, 1)) + # x_2d = x_2d * kx * k0 + # x_2d = x_2d.reshape((res_y, res_x, 1, len(kx))) + # + # y_2d = torch.tile(y_1d, (1, res_x, 1)) + # y_2d = y_2d * ky * k0 + # y_2d = y_2d.reshape((res_y, res_x, len(ky), 1)) + # + # inv_fourier = torch.exp(-1j * x_2d) * torch.exp(-1j * y_2d) + # inv_fourier = inv_fourier.reshape((res_y, res_x, -1)) + # # (20, 50, 1, 63) (50, 1, 1, 63, 1); 50 20 50 1 1 + # Ex0 = inv_fourier[:, :, None, :] @ Sx0[:, None, None, :, :] + # Ey0 = inv_fourier[:, :, None, :] @ Sy0[:, None, None, :, :] + # Ez0 = inv_fourier[:, :, None, :] @ Sz0[:, None, None, :, :] + # Hx0 = 1j * inv_fourier[:, :, None, :] @ Ux0[:, None, None, :, :] + # Hy0 = 1j * inv_fourier[:, :, None, :] @ Uy0[:, None, None, :, :] + # Hz0 = 1j * inv_fourier[:, :, None, :] @ Uz0[:, None, None, :, :] + # + # val = torch.cat( + # (Ex0.squeeze(-1), Ey0.squeeze(-1), Ez0.squeeze(-1), Hx0.squeeze(-1), Hy0.squeeze(-1), Hz0.squeeze(-1)), -1) + # + # field_cell0[res_z * idx_layer:res_z * (idx_layer + 1)] = val + # + # T_layer = big_A_i @ big_X @ T_layer + # + # print('original:', time.time() - t0) + + # t0 = time.time() + + # set_field_input = [True, False, False] + # set_field_input = [True, False, True] + # set_field_input = [True, True, True] + field_cell = torch.zeros((sum(set_field_input), res_z * len(layer_info_list), res_y, res_x, 6), device=device, + dtype=type_complex) + T_layer = T1[list(set_field_input)] + + big_I = torch.eye((len(T1[0])), device=device, dtype=type_complex) # From the first layer for idx_layer, (epz_conv_i, W, V, q, d, big_A_i, big_B) in enumerate(layer_info_list[::-1]): - W_11 = W[:ff_xy, :ff_xy] W_12 = W[:ff_xy, ff_xy:] W_21 = W[ff_xy:, :ff_xy] @@ -191,28 +294,35 @@ def field_dist_2d(wavelength, kx, ky, T1, layer_info_list, period, big_X = torch.diag(torch.exp(-k0 * q * d)) - c = torch.cat([big_I, big_B @ big_A_i @ big_X]) @ T_layer + c = torch.cat([big_I, big_B @ big_A_i @ big_X]) @ T_layer z_1d = torch.linspace(0, res_z, res_z, device=device, dtype=type_complex).reshape((-1, 1, 1)) / res_z * d # z_1d = np.arange(0, res_z, res_z).reshape((-1, 1, 1)) / res_z * d - c1_plus = c[0 * ff_xy:1 * ff_xy] - c2_plus = c[1 * ff_xy:2 * ff_xy] - c1_minus = c[2 * ff_xy:3 * ff_xy] - c2_minus = c[3 * ff_xy:4 * ff_xy] + c1_plus = c[:, 0 * ff_xy:1 * ff_xy] + c2_plus = c[:, 1 * ff_xy:2 * ff_xy] + c1_minus = c[:, 2 * ff_xy:3 * ff_xy] + c2_minus = c[:, 3 * ff_xy:4 * ff_xy] big_Q1 = torch.diag(q_1) big_Q2 = torch.diag(q_2) - Sx = W_11 @ (d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + W_12 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) - Sy = W_21 @ (d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + W_22 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) - - Ux = V_11 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + V_12 @ (-d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) - Uy = V_21 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \ - + V_22 @ (-d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus) + Sx = W_11 @ (d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + W_12 @ (d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) + Sy = W_21 @ (d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + W_22 @ (d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) + Ux = V_11 @ (-d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + V_12 @ (-d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) + Uy = V_21 @ (-d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus + + d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \ + + V_22 @ (-d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus + + d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus) Sz = -1j * epz_conv_i @ (Kx @ Uy - Ky @ Ux) Uz = -1j * (Kx @ Sy - Ky @ Sx) @@ -222,7 +332,8 @@ def field_dist_2d(wavelength, kx, ky, T1, layer_info_list, period, # y_1d = np.arange(res_y-1, -1, -1).reshape((-1, 1, 1)) * period[1] / res_y # y_1d = torch.linspace(0, period[1], res_y, device=device, dtype=type_complex)[::-1].reshape((-1, 1, 1)) - y_1d = torch.flip(torch.linspace(0, period[1], res_y, device=device, dtype=type_complex), dims=(0,)).reshape((-1, 1, 1)) + y_1d = torch.flip(torch.linspace(0, period[1], res_y, device=device, dtype=type_complex), dims=(0,)).reshape( + (-1, 1, 1)) x_2d = torch.tile(x_1d, (res_y, 1, 1)) x_2d = x_2d * kx * k0 @@ -235,20 +346,22 @@ def field_dist_2d(wavelength, kx, ky, T1, layer_info_list, period, inv_fourier = torch.exp(-1j * x_2d) * torch.exp(-1j * y_2d) inv_fourier = inv_fourier.reshape((res_y, res_x, -1)) - Ex = inv_fourier[:, :, None, :] @ Sx[:, None, None, :, :] - Ey = inv_fourier[:, :, None, :] @ Sy[:, None, None, :, :] - Ez = inv_fourier[:, :, None, :] @ Sz[:, None, None, :, :] - Hx = 1j * inv_fourier[:, :, None, :] @ Ux[:, None, None, :, :] - Hy = 1j * inv_fourier[:, :, None, :] @ Uy[:, None, None, :, :] - Hz = 1j * inv_fourier[:, :, None, :] @ Uz[:, None, None, :, :] + Ex = inv_fourier[:, :, None, :] @ Sx[:, :, None, None, :, :] + Ey = inv_fourier[:, :, None, :] @ Sy[:, :, None, None, :, :] + Ez = inv_fourier[:, :, None, :] @ Sz[:, :, None, None, :, :] + Hx = 1j * inv_fourier[:, :, None, :] @ Ux[:, :, None, None, :, :] + Hy = 1j * inv_fourier[:, :, None, :] @ Uy[:, :, None, None, :, :] + Hz = 1j * inv_fourier[:, :, None, :] @ Uz[:, :, None, None, :, :] val = torch.cat( (Ex.squeeze(-1), Ey.squeeze(-1), Ez.squeeze(-1), Hx.squeeze(-1), Hy.squeeze(-1), Hz.squeeze(-1)), -1) - - field_cell[res_z * idx_layer:res_z * (idx_layer + 1)] = val + val = torch.moveaxis(val, 1, 0) + field_cell[:, res_z * idx_layer:res_z * (idx_layer + 1)] = val T_layer = big_A_i @ big_X @ T_layer + # print('new:', time.time() - t0) + return field_cell diff --git a/meent/on_torch/emsolver/rcwa.py b/meent/on_torch/emsolver/rcwa.py index ea01ed6..41f0ae8 100644 --- a/meent/on_torch/emsolver/rcwa.py +++ b/meent/on_torch/emsolver/rcwa.py @@ -54,7 +54,7 @@ def __init__(self, period=(1., 1.), wavelength=1., ucell=None, - thickness=(0., ), + thickness=(0.,), backend=2, pol=0., fto=(0, 0), @@ -196,28 +196,31 @@ def conv_solve(self, **kwargs): return result - def calculate_field(self, res_x=20, res_y=20, res_z=20): + def calculate_field(self, res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False)): kx, ky = self.get_kx_ky_vector(wavelength=self.wavelength) if self._grating_type_assigned == 0: res_y = 1 field_cell = field_dist_1d(self.wavelength, kx, self.T1, self.layer_info_list, self.period, self.pol, - res_x=res_x, res_y=res_y, res_z=res_z, device=self.device, type_complex=self.type_complex) + res_x=res_x, res_y=res_y, res_z=res_z, device=self.device, + type_complex=self.type_complex) elif self._grating_type_assigned == 1: field_cell = field_dist_1d_conical(self.wavelength, kx, ky, self.T1, self.layer_info_list, self.period, - res_x=res_x, res_y=res_y, res_z=res_z, device=self.device, type_complex=self.type_complex) + res_x=res_x, res_y=res_y, res_z=res_z, set_field_input=set_field_input, + device=self.device, type_complex=self.type_complex) else: field_cell = field_dist_2d(self.wavelength, kx, ky, self.T1, self.layer_info_list, self.period, - res_x=res_x, res_y=res_y, res_z=res_z, device=self.device, type_complex=self.type_complex) + res_x=res_x, res_y=res_y, res_z=res_z, set_field_input=set_field_input, + device=self.device, type_complex=self.type_complex) return field_cell - def conv_solve_field(self, res_x=20, res_y=20, res_z=20, **kwargs): + def conv_solve_field(self, res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False), **kwargs): [setattr(self, k, v) for k, v in kwargs.items()] # needed for optimization - de_ri, de_ti = self.conv_solve() - field_cell = self.calculate_field(res_x, res_y, res_z) - return de_ri, de_ti, field_cell + res = self.conv_solve() + field_cell = self.calculate_field(res_x, res_y, res_z, set_field_input) + return res, field_cell def field_plot(self, field_cell): field_plot(field_cell, self.pol) diff --git a/meent/on_torch/emsolver/transfer_method.py b/meent/on_torch/emsolver/transfer_method.py index f03ff38..cbca227 100644 --- a/meent/on_torch/emsolver/transfer_method.py +++ b/meent/on_torch/emsolver/transfer_method.py @@ -377,7 +377,8 @@ def transfer_1d_conical_4(ff_x, ff_y, big_F, big_G, big_T, kz_top, kz_bot, psi, R_p = final_RT[ff_xy: 2 * ff_xy, :].reshape((ff_y, ff_x)) big_T1 = final_RT[2 * ff_xy:, :] - big_T_tetm = big_T.clone().detach() + # big_T_tetm = big_T.clone().detach() + big_T_tetm = big_T.clone() big_T = big_T @ big_T1 T_s = big_T[:ff_xy, :].reshape((ff_y, ff_x)) @@ -447,8 +448,9 @@ def transfer_1d_conical_4(ff_x, ff_y, big_F, big_G, big_T, kz_top, kz_bot, psi, 'de_ti_s': de_ti_s_tetm[1], 'de_ti_p': de_ti_p_tetm[1], 'de_ti': de_ti_tetm[1]} result = {'res': res, 'res_tm_inc': res_tm_inc, 'res_te_inc': res_te_inc} + big_T1_all = torch.stack((big_T1, big_T1_tetm[:, 0:1], big_T1_tetm[:, 1:2])) - return result, big_T1 + return result, big_T1_all def transfer_2d_1(kx, ky, n_top, n_bot, device=torch.device('cpu'), type_complex=torch.complex128): @@ -642,7 +644,8 @@ def transfer_2d_4(ff_x, ff_y, big_F, big_G, big_T, kz_top, kz_bot, psi, theta, n R_p = final_RT[ff_xy: 2 * ff_xy, :].reshape((ff_y, ff_x)) big_T1 = final_RT[2 * ff_xy:, :] - big_T_tetm = big_T.clone().detach() + # big_T_tetm = big_T.clone().detach() + big_T_tetm = big_T.clone() big_T = big_T @ big_T1 T_s = big_T[:ff_xy, :].reshape((ff_y, ff_x)) @@ -712,5 +715,7 @@ def transfer_2d_4(ff_x, ff_y, big_F, big_G, big_T, kz_top, kz_bot, psi, theta, n 'de_ti_s': de_ti_s_tetm[1], 'de_ti_p': de_ti_p_tetm[1], 'de_ti': de_ti_tetm[1]} result = {'res': res, 'res_tm_inc': res_tm_inc, 'res_te_inc': res_te_inc} + # big_T1_all = [big_T1, big_T1_tetm[:, 0:1], big_T1_tetm[:, 1:2]] + big_T1_all = torch.stack((big_T1, big_T1_tetm[:, 0:1], big_T1_tetm[:, 1:2])) - return result, big_T1 + return result, big_T1_all diff --git a/tutorials/01-modeling-and-emsolver.ipynb b/tutorials/01-modeling-and-emsolver.ipynb index 250efce..06fe62f 100644 --- a/tutorials/01-modeling-and-emsolver.ipynb +++ b/tutorials/01-modeling-and-emsolver.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -24,7 +24,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -122,7 +122,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -148,7 +148,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -169,7 +169,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -296,7 +296,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -757,7 +757,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### 2D" + "### General\n", + "In general cases, the fields can be individually simulated with 3 different polarized incidence simultaneously (with additional calculation time) as RCWA does.\n" ] }, { @@ -791,8 +792,62 @@ "\n", "mee = meent.call_mee(backend=0, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, phi=phi,\n", " fto=fto, wavelength=wavelength, period=period, ucell=ucell_2d_m, \n", - " thickness=thickness, type_complex=type_complex)\n", - "result, field_cell = mee.conv_solve_field(res_z=100, res_y=100, res_x=100)\n" + " thickness=thickness, type_complex=type_complex)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "result, field_cell = mee.conv_solve_field(res_z=100, res_y=100, res_x=100, set_field_input=(True, True, True))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`set_field_input`is the switch to select which polarization source to use to recover the field.\n", + "\n", + "* The 1st element: recover the field generated from user set polarization\n", + "* The 2nd element: recover the field generated from pure TE polarized incidence\n", + "* The 3rd element: recover the field generated from pure TM polarized incidence" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 400, 100, 100, 6)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "field_cell.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The first axis is 3 where each subarrays are the set of fields generated from user set pol, TE and TM cases.\n", + "\n", + "The second axis is Z-direction.\n", + "\n", + "The third axis is Y-direction\n", + "\n", + "The fourth is X-direction.\n", + "\n", + "The last represents the type of the field (Ex, Ey, Ez, Hx, Hy, Hz)" ] }, { @@ -804,7 +859,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -833,7 +888,7 @@ "title = ['2D Ex', '2D Ey', '2D Ez', '2D Hx', '2D Hy', '2D Hz', ]\n", "\n", "for ix in range(3):\n", - " val = abs(field_cell[:, 0, :, ix]) ** 2\n", + " val = abs(field_cell[0, :, 0, :, ix]) ** 2\n", " im = axes[ix].imshow(val, cmap='jet', aspect='auto')\n", " # plt.clim(0, 2) # identical to caxis([-4,4]) in MATLAB\n", " fig.colorbar(im, ax=axes[ix], shrink=1)\n", @@ -843,7 +898,7 @@ "\n", "fig, axes = plt.subplots(1, 3, figsize=(10, 2))\n", "for ix in range(3, 6, 1):\n", - " val = abs(field_cell[:, 0, :, ix]) ** 2\n", + " val = abs(field_cell[0, :, 0, :, ix]) ** 2\n", " im = axes[ix-3].imshow(val, cmap='jet', aspect='auto')\n", " # plt.clim(0, 2) # identical to caxis([-4,4]) in MATLAB\n", " fig.colorbar(im, ax=axes[ix-3], shrink=1)\n", @@ -861,7 +916,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -890,7 +945,7 @@ "title = ['2D Ex', '2D Ey', '2D Ez', '2D Hx', '2D Hy', '2D Hz', ]\n", "\n", "for ix in range(3):\n", - " val = abs(field_cell[0, :, :, ix]) ** 2\n", + " val = abs(field_cell[0, 0, :, :, ix]) ** 2\n", " im = axes[ix].imshow(val, cmap='jet', aspect='auto')\n", " # plt.clim(0, 2) # identical to caxis([-4,4]) in MATLAB\n", " fig.colorbar(im, ax=axes[ix], shrink=1)\n", @@ -900,7 +955,7 @@ "\n", "fig, axes = plt.subplots(1, 3, figsize=(10, 2))\n", "for ix in range(3, 6, 1):\n", - " val = abs(field_cell[0, :, :, ix]) ** 2\n", + " val = abs(field_cell[0, 0, :, :, ix]) ** 2\n", " im = axes[ix-3].imshow(val, cmap='jet', aspect='auto')\n", " # plt.clim(0, 2) # identical to caxis([-4,4]) in MATLAB\n", " fig.colorbar(im, ax=axes[ix-3], shrink=1)\n",