Skip to content

Commit

Permalink
Standardization of function argument order
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 626329477
  • Loading branch information
jcitrin authored and Torax team committed Apr 19, 2024
1 parent e1c50e3 commit 8219645
Show file tree
Hide file tree
Showing 36 changed files with 692 additions and 648 deletions.
98 changes: 49 additions & 49 deletions torax/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@


def calculate_pereverzev_flux(
core_profiles: state.CoreProfiles,
geo: geometry.Geometry,
dynamic_config_slice: config_slice.DynamicConfigSlice,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
"""Adds Pereverzev-Corrigan flux to diffusion terms."""

Expand Down Expand Up @@ -134,32 +134,30 @@ def calculate_pereverzev_flux(


def calc_coeffs(
core_profiles: state.CoreProfiles,
evolving_names: tuple[str, ...],
geo: geometry.Geometry,
dynamic_config_slice: config_slice.DynamicConfigSlice,
static_config_slice: config_slice.StaticConfigSlice,
dynamic_config_slice: config_slice.DynamicConfigSlice,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
transport_model: transport_model_lib.TransportModel,
explicit_source_profiles: source_profiles_lib.SourceProfiles,
source_models: source_models_lib.SourceModels,
evolving_names: tuple[str, ...],
use_pereverzev: bool = False,
explicit_call: bool = False,
) -> block_1d_coeffs.Block1DCoeffs:
"""Calculates Block1DCoeffs for the time step described by `core_profiles`.
Args:
static_config_slice: General input parameters which are fixed through a
simulation run, and if changed, would trigger a recompile.
dynamic_config_slice: General input parameters that can change from time
step to time step or simulation run to run, and do so without triggering a
recompile.
geo: Geometry describing the torus.
core_profiles: Core plasma profiles for this time step during this iteration
of the solver. Depending on the type of stepper being used, this may or
may not be equal to the original plasma profiles at the beginning of the
time step.
evolving_names: The names of the evolving variables in the order that their
coefficients should be written to `coeffs`.
geo: Geometry describing the torus.
dynamic_config_slice: General input parameters that can change from time
step to time step or simulation run to run, and do so without triggering a
recompile.
static_config_slice: General input parameters which are fixed through a
simulation run, and if changed, would trigger a recompile.
transport_model: A TransportModel subclass, calculates transport coeffs.
explicit_source_profiles: Precomputed explicit source profiles. These
profiles either do not depend on the core profiles or depend on the
Expand All @@ -169,6 +167,8 @@ def calc_coeffs(
source_models: All TORAX source/sink functions that generate the explicit
and implicit source profiles used as terms for the core profiles
equations.
evolving_names: The names of the evolving variables in the order that their
coefficients should be written to `coeffs`.
use_pereverzev: Toggle whether to calculate Pereverzev terms
explicit_call: If True, indicates that calc_coeffs is being called for the
explicit component of the PDE. Then calculates a reduced Block1DCoeffs if
Expand All @@ -183,59 +183,57 @@ def calc_coeffs(
# explicit components of the PDE, only return a cheaper reduced Block1DCoeffs
if explicit_call and static_config_slice.solver.theta_imp == 1.0:
return _calc_coeffs_reduced(
geo,
core_profiles,
evolving_names,
geo,
)
else:
return _calc_coeffs_full(
core_profiles,
evolving_names,
geo,
dynamic_config_slice,
static_config_slice,
dynamic_config_slice,
geo,
core_profiles,
transport_model,
explicit_source_profiles,
source_models,
evolving_names,
use_pereverzev,
)


@functools.partial(
jax_utils.jit,
static_argnames=[
'transport_model',
'static_config_slice',
'evolving_names',
'transport_model',
'source_models',
'evolving_names',
],
)
def _calc_coeffs_full(
core_profiles: state.CoreProfiles,
evolving_names: tuple[str, ...],
geo: geometry.Geometry,
dynamic_config_slice: config_slice.DynamicConfigSlice,
static_config_slice: config_slice.StaticConfigSlice,
dynamic_config_slice: config_slice.DynamicConfigSlice,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
transport_model: transport_model_lib.TransportModel,
explicit_source_profiles: source_profiles_lib.SourceProfiles,
source_models: source_models_lib.SourceModels,
evolving_names: tuple[str, ...],
use_pereverzev: bool = False,
) -> block_1d_coeffs.Block1DCoeffs:
"""Calculates Block1DCoeffs for the time step described by `core_profiles`.
Args:
static_config_slice: General input parameters which are fixed through a
simulation run, and if changed, would trigger a recompile.
dynamic_config_slice: General input parameters that can change from time
step to time step or simulation run to run, and do so without triggering a
recompile.
geo: Geometry describing the torus.
core_profiles: Core plasma profiles for this time step during this iteration
of the solver. Depending on the type of stepper being used, this may or
may not be equal to the original plasma profiles at the beginning of the
time step.
evolving_names: The names of the evolving variables in the order that their
coefficients should be written to `coeffs`.
geo: Geometry describing the torus.
dynamic_config_slice: General input parameters that can change from time
step to time step or simulation run to run, and do so without triggering a
recompile.
static_config_slice: General input parameters which are fixed through a
simulation run, and if changed, would trigger a recompile.
transport_model: A TransportModel subclass, calculates transport coeffs.
explicit_source_profiles: Precomputed explicit source profiles. These
profiles either do not depend on the core profiles or depend on the
Expand All @@ -245,6 +243,8 @@ def _calc_coeffs_full(
source_models: All TORAX source/sink functions that generate the explicit
and implicit source profiles used as terms for the core profiles
equations.
evolving_names: The names of the evolving variables in the order that their
coefficients should be written to `coeffs`.
use_pereverzev: Toggle whether to calculate Pereverzev terms
Returns:
Expand Down Expand Up @@ -319,13 +319,13 @@ def _calc_coeffs_full(

# fill source vector based on both original and updated core profiles
source_psi = source_models_lib.sum_sources_psi(
geo,
source_models,
implicit_source_profiles,
geo,
) + source_models_lib.sum_sources_psi(
geo,
source_models,
explicit_source_profiles,
geo,
)

true_ne_face = core_profiles.ne.face_value() * dynamic_config_slice.nref
Expand Down Expand Up @@ -501,13 +501,13 @@ def _calc_coeffs_full(

# density source vector based both on original and updated core profiles
source_ne = source_models_lib.sum_sources_ne(
source_models,
explicit_source_profiles,
geo,
) + source_models_lib.sum_sources_ne(
explicit_source_profiles,
source_models,
implicit_source_profiles,
) + source_models_lib.sum_sources_ne(
geo,
implicit_source_profiles,
source_models,
)

if full_v_face_el is not None:
Expand Down Expand Up @@ -562,7 +562,7 @@ def _calc_coeffs_full(
) = jax.lax.cond(
use_pereverzev,
lambda: calculate_pereverzev_flux(
core_profiles, geo, dynamic_config_slice
dynamic_config_slice, geo, core_profiles,
),
lambda: tuple([jnp.zeros_like(geo.r_face)] * 6),
)
Expand Down Expand Up @@ -598,23 +598,23 @@ def _calc_coeffs_full(
source_mat_ee = jnp.zeros_like(geo.r)

source_i = source_models_lib.sum_sources_temp_ion(
source_models,
explicit_source_profiles,
geo,
) + source_models_lib.sum_sources_temp_ion(
explicit_source_profiles,
source_models,
implicit_source_profiles,
) + source_models_lib.sum_sources_temp_ion(
geo,
implicit_source_profiles,
source_models,
)

source_e = source_models_lib.sum_sources_temp_el(
source_models,
explicit_source_profiles,
geo,
) + source_models_lib.sum_sources_temp_el(
explicit_source_profiles,
source_models,
implicit_source_profiles,
) + source_models_lib.sum_sources_temp_el(
geo,
implicit_source_profiles,
source_models,
)

# Add the Qei effects.
Expand Down Expand Up @@ -727,9 +727,9 @@ def _calc_coeffs_full(
],
)
def _calc_coeffs_reduced(
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
evolving_names: tuple[str, ...],
geo: geometry.Geometry,
) -> block_1d_coeffs.Block1DCoeffs:
"""Calculates only the transient_in_cell terms in Block1DCoeffs."""

Expand Down
Loading

0 comments on commit 8219645

Please sign in to comment.