Skip to content

Commit

Permalink
Merge pull request #83 from kc-ml2/DEV/main
Browse files Browse the repository at this point in the history
return field cell for all 3 input polization cases
  • Loading branch information
yonghakim authored Dec 20, 2024
2 parents 310f0b6 + d6805bd commit efb9696
Show file tree
Hide file tree
Showing 12 changed files with 455 additions and 214 deletions.
19 changes: 15 additions & 4 deletions benchmarks/reti_meent_1Dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,32 @@ def run_1dc(option, plot_figure=False):
# Numpy
mee = meent.call_mee(backend=0, **option)
res_numpy = mee.conv_solve()
field_cell_numpy = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x)
field_cell_numpy = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x, set_field_input=(True, True, True))

# JAX
mee = meent.call_mee(backend=1, **option) # JAX
res_jax = mee.conv_solve()
field_cell_jax = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x)
field_cell_jax = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x, set_field_input=(True, True, True))

# Torch
mee = meent.call_mee(backend=2, **option) # PyTorch
res_torch = mee.conv_solve()
field_cell_torch = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x).numpy()
field_cell_torch = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x, set_field_input=(True, True, True)).numpy()

bds = ['Numpy', 'JAX', 'Torch']
fields = [field_cell_numpy, field_cell_jax, field_cell_torch]

print('Field difference: given_pol to TE, given_pol to TM')
n1 = np.linalg.norm(field_cell_numpy[0] - field_cell_numpy[1])
n2 = np.linalg.norm(field_cell_numpy[0] - field_cell_numpy[2])
j1 = np.linalg.norm(field_cell_jax[0] - field_cell_jax[1])
j2 = np.linalg.norm(field_cell_jax[0] - field_cell_jax[2])
t1 = np.linalg.norm(field_cell_torch[0] - field_cell_torch[1])
t2 = np.linalg.norm(field_cell_torch[0] - field_cell_torch[2])
print(f'numpy: {n1}, {n2}')
print(f'jax: {j1}, {j2}')
print(f'torch: {t1}, {t2}')

print('Norm of (meent - reti) per backend')
for i, res_t in enumerate([res_numpy, res_jax, res_torch]):
reti_de_ri_te, reti_de_ti_te = np.array(top_refl_info_te.efficiency).T, np.array(top_tran_info_te.efficiency).T
Expand Down Expand Up @@ -78,7 +89,7 @@ def run_1dc(option, plot_figure=False):
)

for i_field in range(reti_field_cell.shape[-1]):
res_temp = np.linalg.norm(fields[i][i_field] - reti_field_cell[i_field])
res_temp = np.linalg.norm(fields[i][0,:,:,:,i_field] - reti_field_cell[:,:,:,i_field])
print(f'field, {i_field+1}th: {res_temp}')

if plot_figure:
Expand Down
25 changes: 18 additions & 7 deletions benchmarks/reti_meent_2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,32 @@ def run_2d(option, case, plot_figure=False):
# Numpy
mee = meent.call_mee(backend=0, **option)
res_numpy = mee.conv_solve()
field_cell_numpy = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x)
field_cell_numpy = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x, set_field_input=(True, True, True))

# JAX
mee = meent.call_mee(backend=1, **option) # JAX
res_jax = mee.conv_solve()
field_cell_jax = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x)
field_cell_jax = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x, set_field_input=(True, True, True))

# Torch
mee = meent.call_mee(backend=2, **option) # PyTorch
res_torch = mee.conv_solve()
field_cell_torch = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x).numpy()
field_cell_torch = mee.calculate_field(res_z=res_z, res_y=res_y, res_x=res_x, set_field_input=(True, True, True)).numpy()

bds = ['Numpy', 'JAX', 'Torch']
fields = [field_cell_numpy, field_cell_jax, field_cell_torch]

print('Field difference: given_pol to TE, given_pol to TM')
n1 = np.linalg.norm(field_cell_numpy[0] - field_cell_numpy[1])
n2 = np.linalg.norm(field_cell_numpy[0] - field_cell_numpy[2])
j1 = np.linalg.norm(field_cell_jax[0] - field_cell_jax[1])
j2 = np.linalg.norm(field_cell_jax[0] - field_cell_jax[2])
t1 = np.linalg.norm(field_cell_torch[0] - field_cell_torch[1])
t2 = np.linalg.norm(field_cell_torch[0] - field_cell_torch[2])
print(f'numpy: {n1}, {n2}')
print(f'jax: {j1}, {j2}')
print(f'torch: {t1}, {t2}')

print('Norm of (meent - reti) per backend')
for i, res_t in enumerate([res_numpy, res_jax, res_torch]):
reti_de_ri_te, reti_de_ti_te = np.array(top_refl_info_te.efficiency).T, np.array(top_tran_info_te.efficiency).T
Expand Down Expand Up @@ -85,7 +96,7 @@ def run_2d(option, case, plot_figure=False):
)

for i_field in range(reti_field_cell.shape[-1]):
res_temp = np.linalg.norm(fields[i][i_field] - reti_field_cell[i_field])
res_temp = np.linalg.norm(fields[i][0,:,:,:,i_field] - reti_field_cell[:,:,:,i_field])
print(f'field, {i_field+1}th: {res_temp}')

if plot_figure:
Expand All @@ -103,7 +114,7 @@ def run_2d(option, case, plot_figure=False):
im = axes[ix, 4].imshow(r_data.imag, cmap='jet', aspect='auto')
fig.colorbar(im, ax=axes[ix, 4], shrink=1)

n_data = fields[i][:, res_y//2, :, ix]
n_data = fields[i][0, :, res_y//2, :, ix]

im = axes[ix, 1].imshow(abs(n_data) ** 2, cmap='jet', aspect='auto')
fig.colorbar(im, ax=axes[ix, 1], shrink=1)
Expand Down Expand Up @@ -161,7 +172,7 @@ def case_2d_1(plot_figure=False):
def case_2d_2(plot_figure=False):
factor = 1
option = {}
option['pol'] = 1 # 0: TE, 1: TM
option['pol'] = 0 # 0: TE, 1: TM
option['n_top'] = 1 # n_incidence
option['n_bot'] = 1 # n_transmission
option['theta'] = 20 * np.pi / 180
Expand Down Expand Up @@ -254,7 +265,7 @@ def case_2d_4(plot_figure=False):
def case_2d_5(plot_figure=False):
factor = 1
option = {}
option['pol'] = 0 # 0: TE, 1: TM
option['pol'] = 0.5 # 0: TE, 1: TM
option['n_top'] = 1 # n_incidence
option['n_bot'] = 1 # n_transmission
option['theta'] = 0 * np.pi / 180
Expand Down
116 changes: 68 additions & 48 deletions meent/on_jax/emsolver/field_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

def field_dist_1d(wavelength, kx, T1, layer_info_list, period, pol, res_x=20, res_y=1, res_z=20,
type_complex=jnp.complex128):

k0 = 2 * jnp.pi / wavelength
Kx = jnp.diag(kx)

Expand Down Expand Up @@ -66,8 +65,8 @@ def field_dist_1d(wavelength, kx, T1, layer_info_list, period, pol, res_x=20, re

# @partial(jax.jit, static_argnums=(5, 6, 10, 11, 12, 13))
def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period,
res_x=20, res_y=20, res_z=20, type_complex=jnp.complex128):

res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False),
type_complex=jnp.complex128):
k0 = 2 * jnp.pi / wavelength

ff_x = len(kx)
Expand All @@ -77,11 +76,13 @@ def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period,
Kx = jnp.diag(jnp.tile(kx, ff_y).flatten())
Ky = jnp.diag(jnp.tile(ky.reshape((-1, 1)), ff_x).flatten())

field_cell = jnp.zeros((res_z * len(layer_info_list), res_y, res_x, 6), dtype=type_complex)
field_cell = jnp.zeros((sum(set_field_input), res_z * len(layer_info_list), res_y, res_x, 6), dtype=type_complex)

T_layer = T1
# T_layer = T1
# T_layer = T1[set_field_input]
T_layer = T1[jnp.array(set_field_input)]

big_I = jnp.eye((len(T1))).astype(type_complex)
big_I = jnp.eye((len(T1[0]))).astype(type_complex)
O = jnp.zeros((ff_xy, ff_xy), dtype=type_complex)

# From the first layer
Expand All @@ -106,20 +107,27 @@ def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period,
# z_1d = np.arange(0, res_z, res_z).reshape((-1, 1, 1)) / res_z * d
z_1d = jnp.linspace(0, res_z, res_z).reshape((-1, 1, 1)) / res_z * d

c1_plus = c[0 * ff_xy:1 * ff_xy]
c2_plus = c[1 * ff_xy:2 * ff_xy]
c1_minus = c[2 * ff_xy:3 * ff_xy]
c2_minus = c[3 * ff_xy:4 * ff_xy]
c1_plus = c[:, 0 * ff_xy:1 * ff_xy]
c2_plus = c[:, 1 * ff_xy:2 * ff_xy]
c1_minus = c[:, 2 * ff_xy:3 * ff_xy]
c2_minus = c[:, 3 * ff_xy:4 * ff_xy]

big_Q1 = jnp.diag(q_1)
big_Q2 = jnp.diag(q_2)

Sx = W_2 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus)
Sy = V_11 @ (d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \
+ V_12 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus)
Ux = W_1 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus)
Uy = V_21 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \
+ V_22 @ (-d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus)
Sx = W_2 @ (d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus
+ d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus)
Sy = V_11 @ (d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus
+ d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \
+ V_12 @ (d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus
+ d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus)
Ux = W_1 @ (-d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus
+ d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus)
Uy = V_21 @ (-d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus
+ d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \
+ V_22 @ (-d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus
+ d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus)

Sz = -1j * epz_conv_i @ (Kx @ Uy - Ky @ Ux)
Uz = -1j * (Kx @ Sy - Ky @ Sx)

Expand All @@ -138,17 +146,19 @@ def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period,
inv_fourier = jnp.exp(-1j * x_2d) * jnp.exp(-1j * y_2d)
inv_fourier = inv_fourier.reshape((res_y, res_x, -1))

Ex = inv_fourier[:, :, None, :] @ Sx[:, None, None, :, :]
Ey = inv_fourier[:, :, None, :] @ Sy[:, None, None, :, :]
Ez = inv_fourier[:, :, None, :] @ Sz[:, None, None, :, :]
Hx = 1j * inv_fourier[:, :, None, :] @ Ux[:, None, None, :, :]
Hy = 1j * inv_fourier[:, :, None, :] @ Uy[:, None, None, :, :]
Hz = 1j * inv_fourier[:, :, None, :] @ Uz[:, None, None, :, :]
Ex = inv_fourier[:, :, None, :] @ Sx[:, :, None, None, :, :]
Ey = inv_fourier[:, :, None, :] @ Sy[:, :, None, None, :, :]
Ez = inv_fourier[:, :, None, :] @ Sz[:, :, None, None, :, :]
Hx = 1j * inv_fourier[:, :, None, :] @ Ux[:, :, None, None, :, :]
Hy = 1j * inv_fourier[:, :, None, :] @ Uy[:, :, None, None, :, :]
Hz = 1j * inv_fourier[:, :, None, :] @ Uz[:, :, None, None, :, :]

val = jnp.concatenate(
(Ex.squeeze(-1), Ey.squeeze(-1), Ez.squeeze(-1), Hx.squeeze(-1), Hy.squeeze(-1), Hz.squeeze(-1)), -1)

field_cell = field_cell.at[res_z * idx_layer:res_z * (idx_layer + 1)].set(val)
val = jnp.moveaxis(val, 1, 0)

field_cell = field_cell.at[:, res_z * idx_layer:res_z * (idx_layer + 1)].set(val)

T_layer = big_A_i @ big_X @ T_layer

Expand All @@ -157,8 +167,8 @@ def field_dist_1d_conical(wavelength, kx, ky, T1, layer_info_list, period,

# @partial(jax.jit, static_argnums=(5, 6, 10, 11, 12, 13))
def field_dist_2d(wavelength, kx, ky, T1, layer_info_list, period,
res_x=20, res_y=20, res_z=20, type_complex=jnp.complex128):

res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False),
type_complex=jnp.complex128):
k0 = 2 * jnp.pi / wavelength

ff_x = len(kx)
Expand All @@ -168,15 +178,16 @@ def field_dist_2d(wavelength, kx, ky, T1, layer_info_list, period,
Kx = jnp.diag(jnp.tile(kx, ff_y).flatten())
Ky = jnp.diag(jnp.tile(ky.reshape((-1, 1)), ff_x).flatten())

field_cell = jnp.zeros((res_z * len(layer_info_list), res_y, res_x, 6), dtype=type_complex)
field_cell = jnp.zeros((sum(set_field_input), res_z * len(layer_info_list), res_y, res_x, 6), dtype=type_complex)

T_layer = T1
# T_layer = T1
# T_layer = T1[list(set_field_input)]
T_layer = T1[jnp.array(set_field_input)]

big_I = jnp.eye((len(T1))).astype(type_complex)
big_I = jnp.eye((len(T1[0]))).astype(type_complex)

# From the first layer
for idx_layer, (epz_conv_i, W, V, q, d, big_A_i, big_B) in enumerate(layer_info_list[::-1]):

W_11 = W[:ff_xy, :ff_xy]
W_12 = W[:ff_xy, ff_xy:]
W_21 = W[ff_xy:, :ff_xy]
Expand All @@ -193,24 +204,32 @@ def field_dist_2d(wavelength, kx, ky, T1, layer_info_list, period,
z_1d = jnp.linspace(0, res_z, res_z).reshape((-1, 1, 1)) / res_z * d
# z_1d = np.arange(0, res_z, res_z).reshape((-1, 1, 1)) / res_z * d

c1_plus = c[0 * ff_xy:1 * ff_xy]
c2_plus = c[1 * ff_xy:2 * ff_xy]
c1_minus = c[2 * ff_xy:3 * ff_xy]
c2_minus = c[3 * ff_xy:4 * ff_xy]
c1_plus = c[:, 0 * ff_xy:1 * ff_xy]
c2_plus = c[:, 1 * ff_xy:2 * ff_xy]
c1_minus = c[:, 2 * ff_xy:3 * ff_xy]
c2_minus = c[:, 3 * ff_xy:4 * ff_xy]

q1 = q[:len(q) // 2]
q2 = q[len(q) // 2:]
big_Q1 = jnp.diag(q1)
big_Q2 = jnp.diag(q2)

Sx = W_11 @ (d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \
+ W_12 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus)
Sy = W_21 @ (d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \
+ W_22 @ (d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus)
Ux = V_11 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \
+ V_12 @ (-d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus)
Uy = V_21 @ (-d_exp(-k0 * big_Q1 * z_1d) @ c1_plus + d_exp(k0 * big_Q1 * (z_1d - d)) @ c1_minus) \
+ V_22 @ (-d_exp(-k0 * big_Q2 * z_1d) @ c2_plus + d_exp(k0 * big_Q2 * (z_1d - d)) @ c2_minus)
Sx = W_11 @ (d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus
+ d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \
+ W_12 @ (d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus
+ d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus)
Sy = W_21 @ (d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus
+ d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \
+ W_22 @ (d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus
+ d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus)
Ux = V_11 @ (-d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus
+ d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \
+ V_12 @ (-d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus
+ d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus)
Uy = V_21 @ (-d_exp(-k0 * big_Q1 * z_1d)[:, None, :, :] @ c1_plus
+ d_exp(k0 * big_Q1 * (z_1d - d))[:, None, :, :] @ c1_minus) \
+ V_22 @ (-d_exp(-k0 * big_Q2 * z_1d)[:, None, :, :] @ c2_plus
+ d_exp(k0 * big_Q2 * (z_1d - d))[:, None, :, :] @ c2_minus)

Sz = -1j * epz_conv_i @ (Kx @ Uy - Ky @ Ux)
Uz = -1j * (Kx @ Sy - Ky @ Sx)
Expand All @@ -232,17 +251,18 @@ def field_dist_2d(wavelength, kx, ky, T1, layer_info_list, period,
inv_fourier = jnp.exp(-1j * x_2d) * jnp.exp(-1j * y_2d)
inv_fourier = inv_fourier.reshape((res_y, res_x, -1))

Ex = inv_fourier[:, :, None, :] @ Sx[:, None, None, :, :]
Ey = inv_fourier[:, :, None, :] @ Sy[:, None, None, :, :]
Ez = inv_fourier[:, :, None, :] @ Sz[:, None, None, :, :]
Hx = 1j * inv_fourier[:, :, None, :] @ Ux[:, None, None, :, :]
Hy = 1j * inv_fourier[:, :, None, :] @ Uy[:, None, None, :, :]
Hz = 1j * inv_fourier[:, :, None, :] @ Uz[:, None, None, :, :]
Ex = inv_fourier[:, :, None, :] @ Sx[:, :, None, None, :, :]
Ey = inv_fourier[:, :, None, :] @ Sy[:, :, None, None, :, :]
Ez = inv_fourier[:, :, None, :] @ Sz[:, :, None, None, :, :]
Hx = 1j * inv_fourier[:, :, None, :] @ Ux[:, :, None, None, :, :]
Hy = 1j * inv_fourier[:, :, None, :] @ Uy[:, :, None, None, :, :]
Hz = 1j * inv_fourier[:, :, None, :] @ Uz[:, :, None, None, :, :]

val = jnp.concatenate(
(Ex.squeeze(-1), Ey.squeeze(-1), Ez.squeeze(-1), Hx.squeeze(-1), Hy.squeeze(-1), Hz.squeeze(-1)), -1)
val = jnp.moveaxis(val, 1, 0)

field_cell = field_cell.at[res_z * idx_layer:res_z * (idx_layer + 1)].set(val)
field_cell = field_cell.at[:, res_z * idx_layer:res_z * (idx_layer + 1)].set(val)

T_layer = big_A_i @ big_X @ T_layer

Expand Down
16 changes: 9 additions & 7 deletions meent/on_jax/emsolver/rcwa.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def conv_solve(self, **kwargs):
# return de_ri, de_ti

@jax_device_set
def calculate_field(self, res_x=20, res_y=20, res_z=20):
def calculate_field(self, res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False)):
kx, ky = self.get_kx_ky_vector(wavelength=self.wavelength)

if self._grating_type_assigned == 0:
Expand All @@ -269,10 +269,12 @@ def calculate_field(self, res_x=20, res_y=20, res_z=20):
res_x=res_x, res_y=res_y, res_z=res_z, type_complex=self.type_complex)
elif self._grating_type_assigned == 1:
field_cell = field_dist_1d_conical(self.wavelength, kx, ky, self.T1, self.layer_info_list, self.period,
res_x=res_x, res_y=res_y, res_z=res_z, type_complex=self.type_complex)
res_x=res_x, res_y=res_y, res_z=res_z, set_field_input=set_field_input,
type_complex=self.type_complex)
else:
field_cell = field_dist_2d(self.wavelength, kx, ky, self.T1, self.layer_info_list, self.period,
res_x=res_x, res_y=res_y, res_z=res_z, type_complex=self.type_complex)
res_x=res_x, res_y=res_y, res_z=res_z, set_field_input=set_field_input,
type_complex=self.type_complex)

return field_cell

Expand All @@ -281,21 +283,21 @@ def field_plot(self, field_cell):

@partial(jax.jit, static_argnums=(1, 2, 3))
@jax_device_set
def conv_solve_field(self, res_x=20, res_y=20, res_z=20, **kwargs):
def conv_solve_field(self, res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False), **kwargs):
[setattr(self, k, v) for k, v in kwargs.items()] # needed for optimization

if self.fourier_type == 1:
print('CFT (fourier_type=1) is not supported with JAX jit-compilation. Use conv_solve_field_no_jit.')
return None, None, None

de_ri, de_ti, _, _ = self._conv_solve()
field_cell = self.calculate_field(res_x, res_y, res_z)
field_cell = self.calculate_field(res_x, res_y, res_z, set_field_input)
return de_ri, de_ti, field_cell

@jax_device_set
def conv_solve_field_no_jit(self, res_x=20, res_y=20, res_z=20):
def conv_solve_field_no_jit(self, res_x=20, res_y=20, res_z=20, set_field_input=(True, False, False)):
de_ri, de_ti, _, _ = self._conv_solve()
field_cell = self.calculate_field(res_x, res_y, res_z)
field_cell = self.calculate_field(res_x, res_y, res_z, set_field_input)
return de_ri, de_ti, field_cell

def run_ucell_vmap(self, ucell_list):
Expand Down
Loading

0 comments on commit efb9696

Please sign in to comment.