From 8219645ca499f913d82e05914a4dfbca6a7550e6 Mon Sep 17 00:00:00 2001 From: Jonathan Citrin Date: Fri, 19 Apr 2024 04:54:51 -0700 Subject: [PATCH] Standardization of function argument order PiperOrigin-RevId: 626329477 --- torax/calc_coeffs.py | 98 ++++---- torax/core_profile_setters.py | 62 ++--- torax/fvm/block_1d_coeffs.py | 4 +- torax/fvm/discrete_system.py | 4 +- torax/fvm/implicit_solve_block.py | 8 +- torax/fvm/newton_raphson_solve_block.py | 78 +++--- torax/fvm/optimizer_solve_block.py | 74 +++--- torax/fvm/residual_and_loss.py | 190 +++++++------- torax/fvm/tests/fvm.py | 126 +++++----- torax/physics.py | 12 +- torax/sim.py | 246 +++++++++++-------- torax/sources/bootstrap_current_source.py | 8 +- torax/sources/external_current_source.py | 12 +- torax/sources/formulas.py | 8 +- torax/sources/qei_source.py | 6 +- torax/sources/source.py | 4 +- torax/sources/source_models.py | 39 +-- torax/sources/tests/formulas.py | 8 +- torax/sources/tests/fusion_heat_source.py | 2 +- torax/sources/tests/qei_source.py | 8 +- torax/sources/tests/source.py | 12 +- torax/sources/tests/source_models.py | 20 +- torax/sources/tests/test_lib.py | 8 +- torax/state.py | 11 +- torax/stepper/linear_theta_method.py | 32 +-- torax/stepper/nonlinear_theta_method.py | 74 +++--- torax/stepper/predictor_corrector_method.py | 24 +- torax/stepper/stepper.py | 80 +++--- torax/tests/boundary_conditions.py | 4 +- torax/tests/physics.py | 2 +- torax/tests/sim_output_source_profiles.py | 10 +- torax/tests/sim_time_dependence.py | 20 +- torax/tests/state.py | 18 +- torax/tests/test_lib/explicit_stepper.py | 18 +- torax/transport_model/qlknn_wrapper.py | 8 +- torax/transport_model/tests/qlknn_wrapper.py | 2 +- 36 files changed, 692 insertions(+), 648 deletions(-) diff --git a/torax/calc_coeffs.py b/torax/calc_coeffs.py index f16e64c9..122a1933 100644 --- a/torax/calc_coeffs.py +++ b/torax/calc_coeffs.py @@ -34,9 +34,9 @@ def calculate_pereverzev_flux( - core_profiles: state.CoreProfiles, - geo: geometry.Geometry, dynamic_config_slice: config_slice.DynamicConfigSlice, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: """Adds Pereverzev-Corrigan flux to diffusion terms.""" @@ -134,32 +134,30 @@ def calculate_pereverzev_flux( def calc_coeffs( - core_profiles: state.CoreProfiles, - evolving_names: tuple[str, ...], - geo: geometry.Geometry, - dynamic_config_slice: config_slice.DynamicConfigSlice, static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice: config_slice.DynamicConfigSlice, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, transport_model: transport_model_lib.TransportModel, explicit_source_profiles: source_profiles_lib.SourceProfiles, source_models: source_models_lib.SourceModels, + evolving_names: tuple[str, ...], use_pereverzev: bool = False, explicit_call: bool = False, ) -> block_1d_coeffs.Block1DCoeffs: """Calculates Block1DCoeffs for the time step described by `core_profiles`. Args: + static_config_slice: General input parameters which are fixed through a + simulation run, and if changed, would trigger a recompile. + dynamic_config_slice: General input parameters that can change from time + step to time step or simulation run to run, and do so without triggering a + recompile. + geo: Geometry describing the torus. core_profiles: Core plasma profiles for this time step during this iteration of the solver. Depending on the type of stepper being used, this may or may not be equal to the original plasma profiles at the beginning of the time step. - evolving_names: The names of the evolving variables in the order that their - coefficients should be written to `coeffs`. - geo: Geometry describing the torus. - dynamic_config_slice: General input parameters that can change from time - step to time step or simulation run to run, and do so without triggering a - recompile. - static_config_slice: General input parameters which are fixed through a - simulation run, and if changed, would trigger a recompile. transport_model: A TransportModel subclass, calculates transport coeffs. explicit_source_profiles: Precomputed explicit source profiles. These profiles either do not depend on the core profiles or depend on the @@ -169,6 +167,8 @@ def calc_coeffs( source_models: All TORAX source/sink functions that generate the explicit and implicit source profiles used as terms for the core profiles equations. + evolving_names: The names of the evolving variables in the order that their + coefficients should be written to `coeffs`. use_pereverzev: Toggle whether to calculate Pereverzev terms explicit_call: If True, indicates that calc_coeffs is being called for the explicit component of the PDE. Then calculates a reduced Block1DCoeffs if @@ -183,20 +183,20 @@ def calc_coeffs( # explicit components of the PDE, only return a cheaper reduced Block1DCoeffs if explicit_call and static_config_slice.solver.theta_imp == 1.0: return _calc_coeffs_reduced( + geo, core_profiles, evolving_names, - geo, ) else: return _calc_coeffs_full( - core_profiles, - evolving_names, - geo, - dynamic_config_slice, static_config_slice, + dynamic_config_slice, + geo, + core_profiles, transport_model, explicit_source_profiles, source_models, + evolving_names, use_pereverzev, ) @@ -204,38 +204,36 @@ def calc_coeffs( @functools.partial( jax_utils.jit, static_argnames=[ - 'transport_model', 'static_config_slice', - 'evolving_names', + 'transport_model', 'source_models', + 'evolving_names', ], ) def _calc_coeffs_full( - core_profiles: state.CoreProfiles, - evolving_names: tuple[str, ...], - geo: geometry.Geometry, - dynamic_config_slice: config_slice.DynamicConfigSlice, static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice: config_slice.DynamicConfigSlice, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, transport_model: transport_model_lib.TransportModel, explicit_source_profiles: source_profiles_lib.SourceProfiles, source_models: source_models_lib.SourceModels, + evolving_names: tuple[str, ...], use_pereverzev: bool = False, ) -> block_1d_coeffs.Block1DCoeffs: """Calculates Block1DCoeffs for the time step described by `core_profiles`. Args: + static_config_slice: General input parameters which are fixed through a + simulation run, and if changed, would trigger a recompile. + dynamic_config_slice: General input parameters that can change from time + step to time step or simulation run to run, and do so without triggering a + recompile. + geo: Geometry describing the torus. core_profiles: Core plasma profiles for this time step during this iteration of the solver. Depending on the type of stepper being used, this may or may not be equal to the original plasma profiles at the beginning of the time step. - evolving_names: The names of the evolving variables in the order that their - coefficients should be written to `coeffs`. - geo: Geometry describing the torus. - dynamic_config_slice: General input parameters that can change from time - step to time step or simulation run to run, and do so without triggering a - recompile. - static_config_slice: General input parameters which are fixed through a - simulation run, and if changed, would trigger a recompile. transport_model: A TransportModel subclass, calculates transport coeffs. explicit_source_profiles: Precomputed explicit source profiles. These profiles either do not depend on the core profiles or depend on the @@ -245,6 +243,8 @@ def _calc_coeffs_full( source_models: All TORAX source/sink functions that generate the explicit and implicit source profiles used as terms for the core profiles equations. + evolving_names: The names of the evolving variables in the order that their + coefficients should be written to `coeffs`. use_pereverzev: Toggle whether to calculate Pereverzev terms Returns: @@ -319,13 +319,13 @@ def _calc_coeffs_full( # fill source vector based on both original and updated core profiles source_psi = source_models_lib.sum_sources_psi( + geo, source_models, implicit_source_profiles, - geo, ) + source_models_lib.sum_sources_psi( + geo, source_models, explicit_source_profiles, - geo, ) true_ne_face = core_profiles.ne.face_value() * dynamic_config_slice.nref @@ -501,13 +501,13 @@ def _calc_coeffs_full( # density source vector based both on original and updated core profiles source_ne = source_models_lib.sum_sources_ne( - source_models, - explicit_source_profiles, geo, - ) + source_models_lib.sum_sources_ne( + explicit_source_profiles, source_models, - implicit_source_profiles, + ) + source_models_lib.sum_sources_ne( geo, + implicit_source_profiles, + source_models, ) if full_v_face_el is not None: @@ -562,7 +562,7 @@ def _calc_coeffs_full( ) = jax.lax.cond( use_pereverzev, lambda: calculate_pereverzev_flux( - core_profiles, geo, dynamic_config_slice + dynamic_config_slice, geo, core_profiles, ), lambda: tuple([jnp.zeros_like(geo.r_face)] * 6), ) @@ -598,23 +598,23 @@ def _calc_coeffs_full( source_mat_ee = jnp.zeros_like(geo.r) source_i = source_models_lib.sum_sources_temp_ion( - source_models, - explicit_source_profiles, geo, - ) + source_models_lib.sum_sources_temp_ion( + explicit_source_profiles, source_models, - implicit_source_profiles, + ) + source_models_lib.sum_sources_temp_ion( geo, + implicit_source_profiles, + source_models, ) source_e = source_models_lib.sum_sources_temp_el( - source_models, - explicit_source_profiles, geo, - ) + source_models_lib.sum_sources_temp_el( + explicit_source_profiles, source_models, - implicit_source_profiles, + ) + source_models_lib.sum_sources_temp_el( geo, + implicit_source_profiles, + source_models, ) # Add the Qei effects. @@ -727,9 +727,9 @@ def _calc_coeffs_full( ], ) def _calc_coeffs_reduced( + geo: geometry.Geometry, core_profiles: state.CoreProfiles, evolving_names: tuple[str, ...], - geo: geometry.Geometry, ) -> block_1d_coeffs.Block1DCoeffs: """Calculates only the transient_in_cell terms in Block1DCoeffs.""" diff --git a/torax/core_profile_setters.py b/torax/core_profile_setters.py index 816b38b9..ab218692 100644 --- a/torax/core_profile_setters.py +++ b/torax/core_profile_setters.py @@ -39,8 +39,8 @@ def _updated_ti( - dynamic_config_slice: config_slice.DynamicConfigSlice, static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice: config_slice.DynamicConfigSlice, geo: Geometry, ) -> fvm.CellVariable: """Updated ion temp. Used upon initialization and if temp_ion=False.""" @@ -70,8 +70,8 @@ def _updated_ti( def _updated_te( - dynamic_config_slice: config_slice.DynamicConfigSlice, static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice: config_slice.DynamicConfigSlice, geo: Geometry, ) -> fvm.CellVariable: """Updated electron temp. Used upon initialization and if temp_el=False.""" @@ -101,8 +101,8 @@ def _updated_te( def _updated_dens( - dynamic_config_slice: config_slice.DynamicConfigSlice, static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice: config_slice.DynamicConfigSlice, geo: Geometry, ) -> tuple[fvm.CellVariable, fvm.CellVariable]: """Updated particle density. Used upon initialization and if dens_eq=False.""" @@ -250,27 +250,27 @@ def _prescribe_currents_no_bootstrap( def _prescribe_currents_with_bootstrap( dynamic_config_slice: config_slice.DynamicConfigSlice, geo: Geometry, - source_models: source_models_lib.SourceModels, temp_ion: fvm.CellVariable, temp_el: fvm.CellVariable, ne: fvm.CellVariable, ni: fvm.CellVariable, jtot_face: jax.Array, psi: fvm.CellVariable, + source_models: source_models_lib.SourceModels, ) -> state.Currents: """Creates the initial Currents. Args: dynamic_config_slice: General configuration parameters at t_initial. geo: Geometry of the tokamak. - source_models: All TORAX source/sink functions. If not provided, uses the - default sources. temp_ion: Ion temperature. temp_el: Electron temperature. ne: Electron density. ni: Main ion density. jtot_face: Total current density on face grid. psi: Poloidal flux. + source_models: All TORAX source/sink functions. If not provided, uses the + default sources. Returns: currents: Plasma currents @@ -304,9 +304,9 @@ def _prescribe_currents_with_bootstrap( # form of external current on face grid jext_source = source_models.jext jext_face, jext = jext_source.get_value( - source_type=dynamic_config_slice.sources[jext_source.name].source_type, dynamic_config_slice=dynamic_config_slice, geo=geo, + source_type=dynamic_config_slice.sources[jext_source.name].source_type, ) # construct prescribed current formula on grid. @@ -349,25 +349,25 @@ def _prescribe_currents_with_bootstrap( def _calculate_currents_from_psi( dynamic_config_slice: config_slice.DynamicConfigSlice, geo: Geometry, - source_models: source_models_lib.SourceModels, temp_ion: fvm.CellVariable, temp_el: fvm.CellVariable, ne: fvm.CellVariable, ni: fvm.CellVariable, psi: fvm.CellVariable, + source_models: source_models_lib.SourceModels, ) -> state.Currents: """Creates the initial Currents using psi to calculate jtot. Args: dynamic_config_slice: General configuration parameters at t_initial. geo: Geometry of the tokamak. - source_models: All TORAX source/sink functions. If not provided, uses the - default sources. temp_ion: Ion temperature. temp_el: Electron temperature. ne: Electron density. ni: Main ion density. psi: Poloidal flux. + source_models: All TORAX source/sink functions. If not provided, uses the + default sources. Returns: currents: Plasma currents @@ -406,9 +406,9 @@ def _calculate_currents_from_psi( # form of external current on face grid jext_source = source_models.jext jext_face, jext = jext_source.get_value( - source_type=dynamic_config_slice.sources[jext_source.name].source_type, dynamic_config_slice=dynamic_config_slice, geo=geo, + source_type=dynamic_config_slice.sources[jext_source.name].source_type, ) johm = jtot - jext - bootstrap_profile.j_bootstrap @@ -495,16 +495,16 @@ def _update_psi_from_j( def initial_core_profiles( - dynamic_config_slice: config_slice.DynamicConfigSlice, static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice: config_slice.DynamicConfigSlice, geo: Geometry, source_models: source_models_lib.SourceModels | None = None, ) -> state.CoreProfiles: """Calculates the initial core profiles. Args: - dynamic_config_slice: Dynamic configuration parameters at t=t_initial. static_config_slice: Static simulation configuration parameters. + dynamic_config_slice: Dynamic configuration parameters at t=t_initial. geo: Torus geometry. source_models: All models for TORAX sources/sinks. If not provided, uses the default source_models. @@ -523,9 +523,9 @@ def initial_core_profiles( # To set initial values and compute the boundary conditions, we need to handle # potentially time-varying inputs from the users. # The default time in build_dynamic_config_slice is t_initial - temp_ion = _updated_ti(dynamic_config_slice, static_config_slice, geo) - temp_el = _updated_te(dynamic_config_slice, static_config_slice, geo) - ne, ni = _updated_dens(dynamic_config_slice, static_config_slice, geo) + temp_ion = _updated_ti(static_config_slice, dynamic_config_slice, geo) + temp_el = _updated_te(static_config_slice, dynamic_config_slice, geo) + ne, ni = _updated_dens(static_config_slice, dynamic_config_slice, geo) # set up initial psi profile based on current profile if ( @@ -550,13 +550,13 @@ def initial_core_profiles( currents = _prescribe_currents_with_bootstrap( dynamic_config_slice=dynamic_config_slice, geo=geo, - source_models=source_models, temp_ion=temp_ion, temp_el=temp_el, ne=ne, ni=ni, jtot_face=currents_no_bootstrap.jtot_face, psi=psi_no_bootstrap, + source_models=source_models, ) psi = _update_psi_from_j( @@ -567,8 +567,8 @@ def initial_core_profiles( q_face, _ = physics.calc_q_from_jtot_psi( geo=geo, - jtot_face=currents.jtot_face, psi=psi, + jtot_face=currents.jtot_face, q_correction_factor=dynamic_config_slice.numerics.q_correction_factor, ) s_face = physics.calc_s_from_psi(geo, psi) @@ -594,8 +594,8 @@ def initial_core_profiles( ) q_face, _ = physics.calc_q_from_jtot_psi( geo=geo, - jtot_face=geo.jtot_face, psi=psi, + jtot_face=geo.jtot_face, q_correction_factor=dynamic_config_slice.numerics.q_correction_factor, ) s_face = physics.calc_s_from_psi(geo, psi) @@ -636,7 +636,7 @@ def initial_core_profiles( psidot = dataclasses.replace( psidot, value=source_models_lib.calc_psidot( - source_models, dynamic_config_slice, geo, core_profiles + dynamic_config_slice, geo, core_profiles, source_models, ), ) @@ -654,20 +654,20 @@ def initial_core_profiles( def updated_prescribed_core_profiles( - core_profiles: state.CoreProfiles, - dynamic_config_slice: config_slice.DynamicConfigSlice, static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice: config_slice.DynamicConfigSlice, geo: Geometry, + core_profiles: state.CoreProfiles, ) -> dict[str, jax.Array]: """Updates core profiles which are not being evolved by PDE. Uses same functions as for profile initialization. Args: - core_profiles: Core profiles dataclass to be updated - dynamic_config_slice: Dynamic configuration parameters at t=t_initial. static_config_slice: Static simulation configuration parameters. + dynamic_config_slice: Dynamic configuration parameters at t=t_initial. geo: Torus geometry. + core_profiles: Core profiles dataclass to be updated Returns: Updated core profiles. @@ -680,21 +680,21 @@ def updated_prescribed_core_profiles( not static_config_slice.ion_heat_eq and dynamic_config_slice.numerics.enable_prescribed_profile_evolution ): - temp_ion = _updated_ti(dynamic_config_slice, static_config_slice, geo).value + temp_ion = _updated_ti(static_config_slice, dynamic_config_slice, geo).value else: temp_ion = core_profiles.temp_ion.value if ( not static_config_slice.el_heat_eq and dynamic_config_slice.numerics.enable_prescribed_profile_evolution ): - temp_el = _updated_te(dynamic_config_slice, static_config_slice, geo).value + temp_el = _updated_te(static_config_slice, dynamic_config_slice, geo).value else: temp_el = core_profiles.temp_el.value if ( not static_config_slice.dens_eq and dynamic_config_slice.numerics.enable_prescribed_profile_evolution ): - ne, _ = _updated_dens(dynamic_config_slice, static_config_slice, geo) + ne, _ = _updated_dens(static_config_slice, dynamic_config_slice, geo) ne = ne.value else: ne = core_profiles.ne.value @@ -703,18 +703,18 @@ def updated_prescribed_core_profiles( def update_evolving_core_profiles( - core_profiles: state.CoreProfiles, x_new: tuple[fvm.cell_variable.CellVariable, ...], - evolving_names: tuple[str, ...], dynamic_config_slice: config_slice.DynamicConfigSlice, + core_profiles: state.CoreProfiles, + evolving_names: tuple[str, ...], ) -> state.CoreProfiles: """Returns the new core profiles after updating the evolving variables. Args: - core_profiles: The old set of core plasma profiles. x_new: The new values of the evolving variables. - evolving_names: The names of the evolving variables. dynamic_config_slice: The dynamic config slice. + core_profiles: The old set of core plasma profiles. + evolving_names: The names of the evolving variables. """ def get_update(x_new, var): diff --git a/torax/fvm/block_1d_coeffs.py b/torax/fvm/block_1d_coeffs.py index a79b07a8..fa58a956 100644 --- a/torax/fvm/block_1d_coeffs.py +++ b/torax/fvm/block_1d_coeffs.py @@ -103,8 +103,8 @@ class Block1DCoeffsCallback(Protocol): def __call__( self, - x: tuple[cell_variable.CellVariable, ...], dynamic_config_slice: config_slice.DynamicConfigSlice, + x: tuple[cell_variable.CellVariable, ...], allow_pereverzev: bool = False, explicit_call: bool = False, ) -> Block1DCoeffs: @@ -124,10 +124,10 @@ def __call__( final output x_new. Args: - x: The state. dynamic_config_slice: Runtime configuration parameters. These values are potentially time-dependent and should correspond to the time step of the state x. + x: The state. allow_pereverzev: If True, then the coeffs are being called for an initial guess based on a linear step as opposed to just passing the iniitial state. This is a special case which may lead to the pereverzev-corrigan diff --git a/torax/fvm/discrete_system.py b/torax/fvm/discrete_system.py index 77a6a1bf..05980c09 100644 --- a/torax/fvm/discrete_system.py +++ b/torax/fvm/discrete_system.py @@ -39,8 +39,8 @@ def calc_c( - coeffs: Block1DCoeffs, x: tuple[cell_variable.CellVariable, ...], + coeffs: Block1DCoeffs, convection_dirichlet_mode: str = 'ghost', convection_neumann_mode: str = 'ghost', ) -> tuple[jax.Array, jax.Array]: @@ -50,9 +50,9 @@ def calc_c( more detail. Args: - coeffs: Coefficients defining the differential equation. x: Tuple containing CellVariables for each channel. This function uses only their shape and their boundary conditions, not their values. + coeffs: Coefficients defining the differential equation. convection_dirichlet_mode: See docstring of the `convection_terms` function, `dirichlet_mode` argument. convection_neumann_mode: See docstring of the `convection_terms` function, diff --git a/torax/fvm/implicit_solve_block.py b/torax/fvm/implicit_solve_block.py index 6df43c2a..650598a7 100644 --- a/torax/fvm/implicit_solve_block.py +++ b/torax/fvm/implicit_solve_block.py @@ -36,9 +36,9 @@ ], ) def implicit_solve_block( + dt: jax.Array, x_old: tuple[cell_variable.CellVariable, ...], x_new_guess: tuple[cell_variable.CellVariable, ...], - dt: jax.Array, coeffs_old: block_1d_coeffs.Block1DCoeffs, coeffs_new: block_1d_coeffs.Block1DCoeffs, theta_imp: float = 1.0, @@ -53,9 +53,9 @@ def implicit_solve_block( to obtain the coefficients for a particular problem. Args: + dt: Discrete time step. x_old: Tuple containing CellVariables for each channel with their values at x_new_guess: Tuple containing initial guess for x_new. - dt: Discrete time step. coeffs_old: Coefficients defining the equation, computed for time t. coeffs_new: Coefficients defining the equation, computed for time t+dt. theta_imp: Coefficient in [0, 1] determining which solution method to use. @@ -86,9 +86,9 @@ def implicit_solve_block( lhs_mat, lhs_vec, rhs_mat, rhs_vec = ( residual_and_loss.theta_method_matrix_equation( - x_new_guess=x_new_guess, - x_old=x_old, dt=dt, + x_old=x_old, + x_new_guess=x_new_guess, coeffs_old=coeffs_old, coeffs_new=coeffs_new, theta_imp=theta_imp, diff --git a/torax/fvm/newton_raphson_solve_block.py b/torax/fvm/newton_raphson_solve_block.py index bccd6973..2cd9a6e2 100644 --- a/torax/fvm/newton_raphson_solve_block.py +++ b/torax/fvm/newton_raphson_solve_block.py @@ -89,18 +89,18 @@ def _log_iterations( def newton_raphson_solve_block( - x_old: tuple[cell_variable.CellVariable, ...], - core_profiles_t_plus_dt: state_module.CoreProfiles, - evolving_names: tuple[str, ...], dt: jax.Array, - coeffs_callback: Block1DCoeffsCallback, + static_config_slice: config_slice.StaticConfigSlice, dynamic_config_slice_t: config_slice.DynamicConfigSlice, dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, - static_config_slice: config_slice.StaticConfigSlice, geo: geometry.Geometry, + x_old: tuple[cell_variable.CellVariable, ...], + core_profiles_t_plus_dt: state_module.CoreProfiles, transport_model: transport_model_lib.TransportModel, - source_models: source_models_lib.SourceModels, explicit_source_profiles: source_profiles.SourceProfiles, + source_models: source_models_lib.SourceModels, + coeffs_callback: Block1DCoeffsCallback, + evolving_names: tuple[str, ...], log_iterations: bool = False, initial_guess_mode: InitialGuessMode = INITIAL_GUESS_MODE, maxiter: int = MAXITER, @@ -133,29 +133,29 @@ def newton_raphson_solve_block( either a warning or recalculation with a lower dt. Args: - x_old: Tuple containing CellVariables for each channel with their values at - the start of the time step. - core_profiles_t_plus_dt: Core plasma profiles which contain all available - prescribed quantities at the end of the time step. This includes evolving - boundary conditions and prescribed time-dependent profiles that are not - being evolved by the PDE system. - evolving_names: The names of variables within the core profiles that should - evolve. dt: Discrete time step. - coeffs_callback: Calculates diffusion, convection etc. coefficients given a - core_profiles. Repeatedly called by the iterative optimizer. + static_config_slice: Static runtime configuration. Changes to these config + params will trigger recompilation. dynamic_config_slice_t: Runtime configuration for time t (the start time of the step). These config params can change from step to step without triggering a recompilation. dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt. - static_config_slice: Static runtime configuration. Changes to these config - params will trigger recompilation. geo: Geometry object. + x_old: Tuple containing CellVariables for each channel with their values at + the start of the time step. + core_profiles_t_plus_dt: Core plasma profiles which contain all available + prescribed quantities at the end of the time step. This includes evolving + boundary conditions and prescribed time-dependent profiles that are not + being evolved by the PDE system. transport_model: Turbulent transport model callable. - source_models: Collection of source callables to generate source PDE - coefficients. explicit_source_profiles: Pre-calculated sources implemented as explicit sources in the PDE. + source_models: Collection of source callables to generate source PDE + coefficients. + coeffs_callback: Calculates diffusion, convection etc. coefficients given a + core_profiles. Repeatedly called by the iterative optimizer. + evolving_names: The names of variables within the core profiles that should + evolve. log_iterations: If true, output diagnostic information from within iteration loop. initial_guess_mode: chooses the initial_guess for the iterative method, @@ -183,7 +183,7 @@ def newton_raphson_solve_block( # pyformat: enable coeffs_old = coeffs_callback( - x_old, dynamic_config_slice_t, explicit_call=True + dynamic_config_slice_t, x_old, explicit_call=True ) match initial_guess_mode: @@ -194,8 +194,8 @@ def newton_raphson_solve_block( # if set by config, needed if stiff transport models (e.g. qlknn) # are used. coeffs_exp_linear = coeffs_callback( - x_old, dynamic_config_slice_t, + x_old, allow_pereverzev=True, explicit_call=True, ) @@ -210,19 +210,21 @@ def newton_raphson_solve_block( # this is jitted. ( source_models_lib.build_all_zero_profiles( - source_models, dynamic_config_slice_t, geo + dynamic_config_slice_t, + geo, + source_models, ), state_module.CoreTransport.zeros(geo), ), ) init_x_new, _ = predictor_corrector_method.predictor_corrector_method( - init_val=init_val, - x_old=x_old, dt=dt, + static_config_slice=static_config_slice, + dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + x_old=x_old, + init_val=init_val, coeffs_exp=coeffs_exp_linear, coeffs_callback=coeffs_callback, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, - static_config_slice=static_config_slice, ) init_x_new_vec = fvm_conversions.cell_variable_tuple_to_vec(init_x_new) case InitialGuessMode.X_OLD: @@ -239,30 +241,30 @@ def newton_raphson_solve_block( residual_fun = functools.partial( residual_and_loss.theta_method_block_residual, dt=dt, + static_config_slice=static_config_slice, + dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + geo=geo, x_old=x_old, core_profiles_t_plus_dt=core_profiles_t_plus_dt, - geo=geo, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, - static_config_slice=static_config_slice, - evolving_names=evolving_names, - coeffs_old=coeffs_old, transport_model=transport_model, - source_models=source_models, explicit_source_profiles=explicit_source_profiles, + source_models=source_models, + coeffs_old=coeffs_old, + evolving_names=evolving_names, ) jacobian_fun = functools.partial( residual_and_loss.theta_method_block_jacobian, dt=dt, + static_config_slice=static_config_slice, + dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + geo=geo, x_old=x_old, core_profiles_t_plus_dt=core_profiles_t_plus_dt, - geo=geo, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, - static_config_slice=static_config_slice, evolving_names=evolving_names, - coeffs_old=coeffs_old, transport_model=transport_model, - source_models=source_models, explicit_source_profiles=explicit_source_profiles, + source_models=source_models, + coeffs_old=coeffs_old, ) cond_fun = functools.partial(cond, tol=tol, tau_min=tau_min, maxiter=maxiter) diff --git a/torax/fvm/optimizer_solve_block.py b/torax/fvm/optimizer_solve_block.py index dd9aea85..f5a349f8 100644 --- a/torax/fvm/optimizer_solve_block.py +++ b/torax/fvm/optimizer_solve_block.py @@ -44,18 +44,18 @@ def optimizer_solve_block( - x_old: tuple[cell_variable.CellVariable, ...], - core_profiles_t_plus_dt: state.CoreProfiles, - evolving_names: tuple[str, ...], dt: jax.Array, - coeffs_callback: Block1DCoeffsCallback, + static_config_slice: config_slice.StaticConfigSlice, dynamic_config_slice_t: config_slice.DynamicConfigSlice, dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, - static_config_slice: config_slice.StaticConfigSlice, geo: geometry.Geometry, + x_old: tuple[cell_variable.CellVariable, ...], + core_profiles_t_plus_dt: state.CoreProfiles, transport_model: transport_model_lib.TransportModel, - source_models: source_models_lib.SourceModels, explicit_source_profiles: source_profiles.SourceProfiles, + source_models: source_models_lib.SourceModels, + coeffs_callback: Block1DCoeffsCallback, + evolving_names: tuple[str, ...], initial_guess_mode: InitialGuessMode = INITIAL_GUESS_MODE, maxiter=MAXITER, tol=TOL, @@ -71,21 +71,7 @@ def optimizer_solve_block( between two sides of the equation describing a theta method update. Args: - x_old: Tuple containing CellVariables for each channel with their values at - the start of the time step. - core_profiles_t_plus_dt: Core plasma profiles which contain all available - prescribed quantities at the end of the time step. This includes evolving - boundary conditions and prescribed time-dependent profiles that are not - being evolved by the PDE system. - evolving_names: The names of variables within the core profiles that should - evolve. dt: Discrete time step. - coeffs_callback: Calculates diffusion, convection etc. coefficients given a - core_profiles. Repeatedly called by the iterative optimizer. - dynamic_config_slice_t: Runtime configuration for time t (the start time of - the step). These config params can change from step to step without - triggering a recompilation. - dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt. static_config_slice: Static runtime configuration. Changes to these config params will trigger recompilation. A key parameter in static_config slice is theta_imp, a coefficient in [0, 1] determining which solution method to @@ -94,12 +80,26 @@ def optimizer_solve_block( solution methods: theta_imp = 1: Backward Euler implicit method (default). theta_imp = 0.5: Crank-Nicolson. theta_imp = 0: Forward Euler explicit method. + dynamic_config_slice_t: Runtime configuration for time t (the start time of + the step). These config params can change from step to step without + triggering a recompilation. + dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt. geo: Geometry object used to initialize auxiliary outputs. + x_old: Tuple containing CellVariables for each channel with their values at + the start of the time step. + core_profiles_t_plus_dt: Core plasma profiles which contain all available + prescribed quantities at the end of the time step. This includes evolving + boundary conditions and prescribed time-dependent profiles that are not + being evolved by the PDE system. transport_model: Turbulent transport model callable. - source_models: Collection of source callables to generate source PDE - coefficients. explicit_source_profiles: Pre-calculated sources implemented as explicit sources in the PDE. + source_models: Collection of source callables to generate source PDE + coefficients. + coeffs_callback: Calculates diffusion, convection etc. coefficients given a + core_profiles. Repeatedly called by the iterative optimizer. + evolving_names: The names of variables within the core profiles that should + evolve. initial_guess_mode: Chooses the initial_guess for the iterative method, either x_old or linear step. When taking the linear step, it is also recommended to use Pereverzev-Corrigan terms if the transport use @@ -116,7 +116,7 @@ def optimizer_solve_block( # pyformat: enable coeffs_old = coeffs_callback( - x_old, dynamic_config_slice_t, explicit_call=True + dynamic_config_slice_t, x_old, explicit_call=True ) match initial_guess_mode: @@ -127,8 +127,8 @@ def optimizer_solve_block( # if set by config, needed if stiff transport models (e.g. qlknn) # are used. coeffs_exp_linear = coeffs_callback( - x_old, dynamic_config_slice_t, + x_old, allow_pereverzev=True, explicit_call=True, ) @@ -142,19 +142,21 @@ def optimizer_solve_block( # this is jitted. ( source_models_lib.build_all_zero_profiles( - source_models, dynamic_config_slice_t, geo + dynamic_config_slice_t, + geo, + source_models, ), state.CoreTransport.zeros(geo), ), ) init_x_new, _ = predictor_corrector_method.predictor_corrector_method( + dt=dt, + static_config_slice=static_config_slice, + dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, init_val=init_val, x_old=x_old, - dt=dt, coeffs_exp=coeffs_exp_linear, coeffs_callback=coeffs_callback, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, - static_config_slice=static_config_slice, ) init_x_new_vec = fvm_conversions.cell_variable_tuple_to_vec(init_x_new) case InitialGuessMode.X_OLD: @@ -166,18 +168,18 @@ def optimizer_solve_block( # Advance jaxopt_solver by one timestep x_new_vec, final_loss, aux_output = residual_and_loss.jaxopt_solver( - init_x_new_vec=init_x_new_vec, + dt=dt, + static_config_slice=static_config_slice, + dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + geo=geo, x_old=x_old, + init_x_new_vec=init_x_new_vec, core_profiles_t_plus_dt=core_profiles_t_plus_dt, - geo=geo, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, - static_config_slice=static_config_slice, - dt=dt, - evolving_names=evolving_names, - coeffs_old=coeffs_old, transport_model=transport_model, - source_models=source_models, explicit_source_profiles=explicit_source_profiles, + source_models=source_models, + coeffs_old=coeffs_old, + evolving_names=evolving_names, maxiter=maxiter, tol=tol, ) diff --git a/torax/fvm/residual_and_loss.py b/torax/fvm/residual_and_loss.py index 4a97e73f..63834b4a 100644 --- a/torax/fvm/residual_and_loss.py +++ b/torax/fvm/residual_and_loss.py @@ -52,9 +52,9 @@ ], ) def theta_method_matrix_equation( - x_new_guess: tuple[cell_variable.CellVariable, ...], - x_old: tuple[cell_variable.CellVariable, ...], dt: jax.Array, + x_old: tuple[cell_variable.CellVariable, ...], + x_new_guess: tuple[cell_variable.CellVariable, ...], coeffs_old: Block1DCoeffs, coeffs_new: Block1DCoeffs, theta_imp: float = 1.0, @@ -110,9 +110,9 @@ def theta_method_matrix_equation( ``` Args: - x_new_guess: Current guess of x_new defined as a tuple of CellVariables. - x_old: The starting x defined as a tuple of CellVariables. dt: Time step duration. + x_old: The starting x defined as a tuple of CellVariables. + x_new_guess: Current guess of x_new defined as a tuple of CellVariables. coeffs_old: The coefficients calculated at x_old. coeffs_new: The coefficients calculated at x_new. theta_imp: Coefficient on implicit term of theta method. @@ -156,8 +156,8 @@ def theta_method_matrix_equation( right_transient = jnp.diag(jnp.squeeze(tc_in_old / tc_in_new)) c_mat_new, c_new = discrete_system.calc_c( - coeffs_new, x_new_guess, + coeffs_new, convection_dirichlet_mode, convection_neumann_mode, ) @@ -175,8 +175,8 @@ def theta_method_matrix_equation( msg='tc_out_old*tc_in_new unexpectedly < eps', ) c_mat_old, c_old = discrete_system.calc_c( - coeffs_old, x_old, + coeffs_old, convection_dirichlet_mode, convection_neumann_mode, ) @@ -194,39 +194,31 @@ def theta_method_matrix_equation( jax_utils.jit, static_argnames=[ 'static_config_slice', - 'evolving_names', 'transport_model', 'source_models', + 'evolving_names', ], ) def theta_method_block_residual( x_new_guess_vec: jax.Array, + dt: jax.Array, + static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + geo: geometry.Geometry, x_old: tuple[cell_variable.CellVariable, ...], core_profiles_t_plus_dt: state.CoreProfiles, - evolving_names: tuple[str, ...], - geo: geometry.Geometry, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, - static_config_slice: config_slice.StaticConfigSlice, - dt: jax.Array, - coeffs_old: Block1DCoeffs, transport_model: transport_model_lib.TransportModel, - source_models: source_models_lib.SourceModels, explicit_source_profiles: source_profiles.SourceProfiles, + source_models: source_models_lib.SourceModels, + coeffs_old: Block1DCoeffs, + evolving_names: tuple[str, ...], ) -> tuple[jax.Array, AuxiliaryOutput]: """Residual of theta-method equation for core profiles at next time-step. Args: x_new_guess_vec: Flattened array of current guess of x_new for all evolving core profiles. - x_old: The starting x defined as a tuple of CellVariables. - core_profiles_t_plus_dt: Core plasma profiles which contain all available - prescribed quantities at the end of the time step. This includes evolving - boundary conditions and prescribed time-dependent profiles that are not - being evolved by the PDE system. - evolving_names: The names of variables within the core profiles that should - evolve. - geo: Geometry object. - dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt. + dt: Time step duration. static_config_slice: Static runtime configuration. Changes to these config params will trigger recompilation. A key parameter in static_config slice is theta_imp, a coefficient in [0, 1] determining which solution method to @@ -235,13 +227,21 @@ def theta_method_block_residual( solution methods: theta_imp = 1: Backward Euler implicit method (default). theta_imp = 0.5: Crank-Nicolson. theta_imp = 0: Forward Euler explicit method. - dt: Time step duration. - coeffs_old: The coefficients calculated at x_old. + dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt. + geo: Geometry object. + x_old: The starting x defined as a tuple of CellVariables. + core_profiles_t_plus_dt: Core plasma profiles which contain all available + prescribed quantities at the end of the time step. This includes evolving + boundary conditions and prescribed time-dependent profiles that are not + being evolved by the PDE system. transport_model: Turbulent transport model callable. - source_models: Collection of source callables to generate source PDE - coefficients. explicit_source_profiles: Pre-calculated sources implemented as explicit sources in the PDE. + source_models: Collection of source callables to generate source PDE + coefficients. + coeffs_old: The coefficients calculated at x_old. + evolving_names: The names of variables within the core profiles that should + evolve. Returns: residual: Vector residual between LHS and RHS of the theta method equation. @@ -253,27 +253,27 @@ def theta_method_block_residual( x_new_guess_vec, core_profiles_t_plus_dt, evolving_names ) core_profiles_t_plus_dt = core_profile_setters.update_evolving_core_profiles( - core_profiles_t_plus_dt, x_new_guess, - evolving_names, dynamic_config_slice_t_plus_dt, + core_profiles_t_plus_dt, + evolving_names, ) coeffs_new = calc_coeffs.calc_coeffs( - core_profiles=core_profiles_t_plus_dt, - evolving_names=evolving_names, - geo=geo, - dynamic_config_slice=dynamic_config_slice_t_plus_dt, static_config_slice=static_config_slice, + dynamic_config_slice=dynamic_config_slice_t_plus_dt, + geo=geo, + core_profiles=core_profiles_t_plus_dt, transport_model=transport_model, explicit_source_profiles=explicit_source_profiles, source_models=source_models, + evolving_names=evolving_names, use_pereverzev=False, ) lhs_mat, lhs_vec, rhs_mat, rhs_vec = theta_method_matrix_equation( - x_new_guess=x_new_guess, - x_old=x_old, dt=dt, + x_old=x_old, + x_new_guess=x_new_guess, coeffs_old=coeffs_old, coeffs_new=coeffs_new, theta_imp=static_config_slice.solver.theta_imp, @@ -295,9 +295,9 @@ def theta_method_block_residual( theta_method_block_jacobian, static_argnames=[ 'static_config_slice', - 'evolving_names', 'transport_model', 'source_models', + 'evolving_names', ], ) @@ -306,39 +306,31 @@ def theta_method_block_residual( jax_utils.jit, static_argnames=[ 'static_config_slice', - 'evolving_names', 'transport_model', 'source_models', + 'evolving_names', ], ) def theta_method_block_loss( x_new_guess_vec: jax.Array, + dt: jax.Array, + static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + geo: geometry.Geometry, x_old: tuple[cell_variable.CellVariable, ...], core_profiles_t_plus_dt: state.CoreProfiles, - evolving_names: tuple[str, ...], - geo: geometry.Geometry, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, - static_config_slice: config_slice.StaticConfigSlice, - dt: jax.Array, - coeffs_old: Block1DCoeffs, transport_model: transport_model_lib.TransportModel, - source_models: source_models_lib.SourceModels, explicit_source_profiles: source_profiles.SourceProfiles, + source_models: source_models_lib.SourceModels, + coeffs_old: Block1DCoeffs, + evolving_names: tuple[str, ...], ) -> tuple[jax.Array, AuxiliaryOutput]: """Loss for the optimizer method of nonlinear solution. Args: x_new_guess_vec: Flattened array of current guess of x_new for all evolving core profiles. - x_old: The starting x defined as a tuple of CellVariables. - core_profiles_t_plus_dt: Core plasma profiles which contain all available - prescribed quantities at the end of the time step. This includes evolving - boundary conditions and prescribed time-dependent profiles that are not - being evolved by the PDE system. - evolving_names: The names of variables within the core profiles that should - evolve. - geo: geometry object - dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt. + dt: Time step duration. static_config_slice: Static runtime configuration. Changes to these config params will trigger recompilation. A key parameter in static_config slice is theta_imp, a coefficient in [0, 1] determining which solution method to @@ -347,31 +339,39 @@ def theta_method_block_loss( solution methods: theta_imp = 1: Backward Euler implicit method (default). theta_imp = 0.5: Crank-Nicolson. theta_imp = 0: Forward Euler explicit method. - dt: Time step duration. - coeffs_old: The coefficients calculated at x_old. + dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt. + geo: geometry object + x_old: The starting x defined as a tuple of CellVariables. + core_profiles_t_plus_dt: Core plasma profiles which contain all available + prescribed quantities at the end of the time step. This includes evolving + boundary conditions and prescribed time-dependent profiles that are not + being evolved by the PDE system. transport_model: turbulent transport model callable - source_models: Collection of source callables to generate source PDE - coefficients. explicit_source_profiles: pre-calculated sources implemented as explicit sources in the PDE + source_models: Collection of source callables to generate source PDE + coefficients. + coeffs_old: The coefficients calculated at x_old. + evolving_names: The names of variables within the core profiles that should + evolve. Returns: loss: mean squared loss of theta method residual. """ residual, aux_output = theta_method_block_residual( - x_new_guess_vec=x_new_guess_vec, + dt=dt, + static_config_slice=static_config_slice, + dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + geo=geo, x_old=x_old, + x_new_guess_vec=x_new_guess_vec, core_profiles_t_plus_dt=core_profiles_t_plus_dt, - evolving_names=evolving_names, - geo=geo, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, - static_config_slice=static_config_slice, - dt=dt, - coeffs_old=coeffs_old, transport_model=transport_model, - source_models=source_models, explicit_source_profiles=explicit_source_profiles, + source_models=source_models, + coeffs_old=coeffs_old, + evolving_names=evolving_names, ) loss = jnp.mean(jnp.square(residual)) return loss, aux_output @@ -381,41 +381,31 @@ def theta_method_block_loss( jax_utils.jit, static_argnames=[ 'static_config_slice', - 'evolving_names', 'transport_model', 'source_models', + 'evolving_names', ], ) def jaxopt_solver( - init_x_new_vec: jax.Array, + dt: jax.Array, + static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + geo: geometry.Geometry, x_old: tuple[cell_variable.CellVariable, ...], + init_x_new_vec: jax.Array, core_profiles_t_plus_dt: state.CoreProfiles, - evolving_names: tuple[str, ...], - geo: geometry.Geometry, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, - static_config_slice: config_slice.StaticConfigSlice, - dt: jax.Array, - coeffs_old: Block1DCoeffs, transport_model: transport_model_lib.TransportModel, - source_models: source_models_lib.SourceModels, explicit_source_profiles: source_profiles.SourceProfiles, + source_models: source_models_lib.SourceModels, + coeffs_old: Block1DCoeffs, + evolving_names: tuple[str, ...], maxiter: int, tol: float, ) -> tuple[jax.Array, float, AuxiliaryOutput]: """Advances jaxopt solver by one timestep. Args: - init_x_new_vec: Flattened array of initial guess of x_new for all evolving - core profiles. - x_old: The starting x defined as a tuple of CellVariables. - core_profiles_t_plus_dt: Core plasma profiles which contain all available - prescribed quantities at the end of the time step. This includes evolving - boundary conditions and prescribed time-dependent profiles that are not - being evolved by the PDE system. - evolving_names: The names of variables within the core profiles that should - evolve. - geo: geometry object. - dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt. + dt: Time step duration. static_config_slice: Static runtime configuration. Changes to these config params will trigger recompilation. A key parameter in static_config slice is theta_imp, a coefficient in [0, 1] determining which solution method to @@ -424,13 +414,23 @@ def jaxopt_solver( solution methods: theta_imp = 1: Backward Euler implicit method (default). theta_imp = 0.5: Crank-Nicolson. theta_imp = 0: Forward Euler explicit method. - dt: Time step duration. - coeffs_old: The coefficients calculated at x_old. + dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt. + geo: geometry object. + x_old: The starting x defined as a tuple of CellVariables. + init_x_new_vec: Flattened array of initial guess of x_new for all evolving + core profiles. + core_profiles_t_plus_dt: Core plasma profiles which contain all available + prescribed quantities at the end of the time step. This includes evolving + boundary conditions and prescribed time-dependent profiles that are not + being evolved by the PDE system. transport_model: turbulent transport model callable. - source_models: Collection of source callables to generate source PDE - coefficients. explicit_source_profiles: pre-calculated sources implemented as explicit sources in the PDE. + source_models: Collection of source callables to generate source PDE + coefficients. + coeffs_old: The coefficients calculated at x_old. + evolving_names: The names of variables within the core profiles that should + evolve. maxiter: maximum number of iterations of jaxopt solver. tol: tolerance for jaxopt solver convergence. @@ -442,16 +442,16 @@ def jaxopt_solver( loss = functools.partial( theta_method_block_loss, dt=dt, + static_config_slice=static_config_slice, + dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + geo=geo, x_old=x_old, core_profiles_t_plus_dt=core_profiles_t_plus_dt, - geo=geo, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, - static_config_slice=static_config_slice, - evolving_names=evolving_names, - coeffs_old=coeffs_old, transport_model=transport_model, - source_models=source_models, explicit_source_profiles=explicit_source_profiles, + source_models=source_models, + coeffs_old=coeffs_old, + evolving_names=evolving_names, ) solver = jaxopt.LBFGS(fun=loss, maxiter=maxiter, tol=tol, has_aux=True) solver_output = solver.run(init_x_new_vec) diff --git a/torax/fvm/tests/fvm.py b/torax/fvm/tests/fvm.py index 6a5992d9..aeafd5d8 100644 --- a/torax/fvm/tests/fvm.py +++ b/torax/fvm/tests/fvm.py @@ -215,9 +215,9 @@ def test_leftward_convection(self, num_cells, theta_imp, time_steps): ) for _ in range(time_steps): x = implicit_solve_block.implicit_solve_block( + dt=dt, x_old=x, x_new_guess=x, - dt=dt, coeffs_old=coeffs, # Assume no time-dependent params. coeffs_new=coeffs, @@ -381,7 +381,7 @@ def test_nonlinear_solve_block_loss_minimum( static_config_slice = config_slice.build_static_config_slice(config) source_models = source_models_lib.SourceModels() core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice, static_config_slice, geo, source_models + static_config_slice, dynamic_config_slice, geo, source_models ) evolving_names = tuple(['temp_ion']) explicit_source_profiles = source_models_lib.build_source_profiles( @@ -393,14 +393,14 @@ def test_nonlinear_solve_block_loss_minimum( ) transport_model = transport_model_factory.construct(config) coeffs = calc_coeffs.calc_coeffs( - core_profiles=core_profiles, - evolving_names=evolving_names, - geo=geo, - dynamic_config_slice=dynamic_config_slice, static_config_slice=static_config_slice, + dynamic_config_slice=dynamic_config_slice, + geo=geo, + core_profiles=core_profiles, transport_model=transport_model, explicit_source_profiles=explicit_source_profiles, source_models=source_models, + evolving_names=evolving_names, use_pereverzev=False, ) # dt well under the explicit stability limit for dx=1 and chi=1 @@ -410,12 +410,12 @@ def test_nonlinear_solve_block_loss_minimum( for _ in range(time_steps): x_old = copy.deepcopy(x_new) x_new = implicit_solve_block.implicit_solve_block( + dt=dt, x_old=x_old, x_new_guess=x_new, coeffs_old=coeffs, # Assume no time-dependent params. coeffs_new=coeffs, - dt=dt, theta_imp=theta_imp, ) @@ -424,33 +424,33 @@ def test_nonlinear_solve_block_loss_minimum( # solution as the minimum with approximately zero residual # core_profiles_t_plus_dt is not updated since coeffs stay constant here loss, _ = residual_and_loss.theta_method_block_loss( - x_new_guess_vec=jnp.concatenate([var.value for var in x_new]), + dt=dt, + static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice_t_plus_dt=dynamic_config_slice, + geo=geo, x_old=x_old, + x_new_guess_vec=jnp.concatenate([var.value for var in x_new]), core_profiles_t_plus_dt=core_profiles, - evolving_names=evolving_names, - geo=geo, - dynamic_config_slice_t_plus_dt=dynamic_config_slice, - static_config_slice=config_slice.build_static_config_slice(config), - dt=dt, - coeffs_old=coeffs, transport_model=transport_model, - source_models=source_models, explicit_source_profiles=explicit_source_profiles, + source_models=source_models, + coeffs_old=coeffs, + evolving_names=evolving_names, ) residual, _ = residual_and_loss.theta_method_block_residual( + dt=dt, + static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice_t_plus_dt=dynamic_config_slice, + geo=geo, x_new_guess_vec=jnp.concatenate([var.value for var in x_new]), x_old=x_old, core_profiles_t_plus_dt=core_profiles, - evolving_names=evolving_names, - geo=geo, - dynamic_config_slice_t_plus_dt=dynamic_config_slice, - static_config_slice=config_slice.build_static_config_slice(config), - dt=dt, - coeffs_old=coeffs, transport_model=transport_model, - source_models=source_models, explicit_source_profiles=explicit_source_profiles, + source_models=source_models, + coeffs_old=coeffs, + evolving_names=evolving_names, ) np.testing.assert_allclose(loss, 0.0, atol=1e-7) @@ -489,21 +489,21 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): ), ), ) - geo = geometry.build_circular_geometry(config) - dynamic_config_slice = config_slice.build_dynamic_config_slice(config) static_config_slice = config_slice.build_static_config_slice(config) + dynamic_config_slice = config_slice.build_dynamic_config_slice(config) + geo = geometry.build_circular_geometry(config) transport_model = transport_model_factory.construct( config, ) source_models = source_models_lib.SourceModels() initial_core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice, static_config_slice, geo, source_models + static_config_slice, dynamic_config_slice, geo, source_models ) explicit_source_profiles = source_models_lib.build_source_profiles( - source_models=source_models, dynamic_config_slice=dynamic_config_slice, geo=geo, core_profiles=initial_core_profiles, + source_models=source_models, explicit=True, ) @@ -511,14 +511,14 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): evolving_names = tuple(['temp_ion']) coeffs = calc_coeffs.calc_coeffs( - core_profiles=initial_core_profiles, - evolving_names=evolving_names, - geo=geo, - dynamic_config_slice=dynamic_config_slice, static_config_slice=static_config_slice, + dynamic_config_slice=dynamic_config_slice, + geo=geo, + core_profiles=initial_core_profiles, transport_model=transport_model, explicit_source_profiles=explicit_source_profiles, source_models=source_models, + evolving_names=evolving_names, use_pereverzev=False, ) initial_right_boundary = jnp.array(0.0) @@ -531,12 +531,12 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): # Run with different theta_imp values. for theta_imp in [0.0, 0.5, 1.0]: x_new = implicit_solve_block.implicit_solve_block( + dt=dt, x_old=(x_0,), x_new_guess=(x_0,), coeffs_old=coeffs, # Assume no time-dependent params. coeffs_new=coeffs, - dt=dt, theta_imp=theta_imp, ) # No matter what theta_imp is used, the x_new will be all 0s because there @@ -549,12 +549,12 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): x_1 = dataclasses.replace(x_0, right_face_constraint=final_right_boundary) # However, the explicit terms (when theta_imp = 0), should still be all 0. x_new = implicit_solve_block.implicit_solve_block( + dt=dt, x_old=(x_0,), x_new_guess=(x_1,), coeffs_old=coeffs, # Assume no time-dependent params. coeffs_new=coeffs, - dt=dt, theta_imp=0.0, ) np.testing.assert_allclose(x_new[0].value, 0.0) @@ -564,12 +564,12 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): ) # And when theta_imp is > 0, the values should be > 0. x_new = implicit_solve_block.implicit_solve_block( + dt=dt, x_old=(x_0,), x_new_guess=(x_1,), coeffs_old=coeffs, # Assume no time-dependent params. coeffs_new=coeffs, - dt=dt, theta_imp=0.5, ) self.assertGreater(x_new[0].value.min(), 0.0) @@ -621,13 +621,13 @@ def test_theta_residual_uses_updated_boundary_conditions(self): ) source_models = source_models_lib.SourceModels() initial_core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice, static_config_slice_theta0, geo, source_models + static_config_slice_theta0, dynamic_config_slice, geo, source_models ) explicit_source_profiles = source_models_lib.build_source_profiles( - source_models=source_models, dynamic_config_slice=dynamic_config_slice, geo=geo, core_profiles=initial_core_profiles, + source_models=source_models, explicit=True, ) @@ -635,14 +635,14 @@ def test_theta_residual_uses_updated_boundary_conditions(self): evolving_names = tuple(['temp_ion']) coeffs_old = calc_coeffs.calc_coeffs( - core_profiles=initial_core_profiles, - evolving_names=evolving_names, - geo=geo, - dynamic_config_slice=dynamic_config_slice, static_config_slice=static_config_slice_theta05, + dynamic_config_slice=dynamic_config_slice, + geo=geo, + core_profiles=initial_core_profiles, transport_model=transport_model, explicit_source_profiles=explicit_source_profiles, source_models=source_models, + evolving_names=evolving_names, use_pereverzev=False, ) @@ -654,7 +654,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self): right_face_constraint=initial_right_boundary, ) core_profiles_t_plus_dt = core_profile_setters.initial_core_profiles( - dynamic_config_slice, static_config_slice_theta0, geo + static_config_slice_theta0, dynamic_config_slice, geo ) core_profiles_t_plus_dt = dataclasses.replace( core_profiles_t_plus_dt, @@ -666,18 +666,18 @@ def test_theta_residual_uses_updated_boundary_conditions(self): # with diffusive transport and zero transport, then the state will stay # at all 0, and the residual should be 0. residual, _ = residual_and_loss.theta_method_block_residual( - x_new_guess_vec=x_0.value, + dt=dt, + static_config_slice=static_config_slice_theta05, + dynamic_config_slice_t_plus_dt=dynamic_config_slice, + geo=geo, x_old=(x_0,), + x_new_guess_vec=x_0.value, core_profiles_t_plus_dt=core_profiles_t_plus_dt, - evolving_names=evolving_names, - geo=geo, - dynamic_config_slice_t_plus_dt=dynamic_config_slice, - static_config_slice=static_config_slice_theta05, - dt=dt, - coeffs_old=coeffs_old, transport_model=transport_model, - source_models=source_models, explicit_source_profiles=explicit_source_profiles, + source_models=source_models, + coeffs_old=coeffs_old, + evolving_names=evolving_names, ) np.testing.assert_allclose(residual, 0.0) with self.subTest('updated_boundary_conditions'): @@ -686,8 +686,12 @@ def test_theta_residual_uses_updated_boundary_conditions(self): # residual would still be 0. final_right_boundary = jnp.array(1.0) residual, _ = residual_and_loss.theta_method_block_residual( - x_new_guess_vec=x_0.value, + dt=dt, + static_config_slice=static_config_slice_theta0, + dynamic_config_slice_t_plus_dt=dynamic_config_slice, + geo=geo, x_old=(x_0,), + x_new_guess_vec=x_0.value, core_profiles_t_plus_dt=dataclasses.replace( core_profiles_t_plus_dt, temp_ion=dataclasses.replace( @@ -695,19 +699,18 @@ def test_theta_residual_uses_updated_boundary_conditions(self): ), ), evolving_names=evolving_names, - geo=geo, - dynamic_config_slice_t_plus_dt=dynamic_config_slice, - static_config_slice=static_config_slice_theta0, - dt=dt, - coeffs_old=coeffs_old, transport_model=transport_model, - source_models=source_models, explicit_source_profiles=explicit_source_profiles, + source_models=source_models, + coeffs_old=coeffs_old, ) np.testing.assert_allclose(residual, 0.0) # But when theta_imp > 0, the residual should be non-zero. residual, _ = residual_and_loss.theta_method_block_residual( - x_new_guess_vec=x_0.value, + dt=dt, + static_config_slice=static_config_slice_theta05, + dynamic_config_slice_t_plus_dt=dynamic_config_slice, + geo=geo, x_old=(x_0,), core_profiles_t_plus_dt=dataclasses.replace( core_profiles_t_plus_dt, @@ -715,15 +718,12 @@ def test_theta_residual_uses_updated_boundary_conditions(self): x_0, right_face_constraint=final_right_boundary ), ), - evolving_names=evolving_names, - dt=dt, - geo=geo, - dynamic_config_slice_t_plus_dt=dynamic_config_slice, - static_config_slice=static_config_slice_theta05, - coeffs_old=coeffs_old, + x_new_guess_vec=x_0.value, transport_model=transport_model, - source_models=source_models, explicit_source_profiles=explicit_source_profiles, + source_models=source_models, + coeffs_old=coeffs_old, + evolving_names=evolving_names, ) self.assertGreater(jnp.abs(jnp.sum(residual)), 0.0) diff --git a/torax/physics.py b/torax/physics.py index 5c143cf7..d4ba2da2 100644 --- a/torax/physics.py +++ b/torax/physics.py @@ -58,10 +58,10 @@ def update_jtot_q_face_s_face( core_profiles.psi, ) q_face, _ = calc_q_from_jtot_psi( - geo, - jtot_face, - core_profiles.psi, - q_correction_factor, + geo=geo, + psi=core_profiles.psi, + jtot_face=jtot_face, + q_correction_factor=q_correction_factor, ) s_face = calc_s_from_psi( geo, @@ -137,8 +137,8 @@ def internal_boundary( def calc_q_from_jtot_psi( geo: Geometry, - jtot_face: jax.Array, psi: cell_variable.CellVariable, + jtot_face: jax.Array, q_correction_factor: float, ) -> tuple[jnp.ndarray, jnp.ndarray]: """Calculates q given jtot and psi. @@ -149,8 +149,8 @@ def calc_q_from_jtot_psi( Args: geo: Magnetic geometry. - jtot_face: Total toroidal current density on face grid. psi: Poloidal flux. + jtot_face: Total toroidal current density on face grid. q_correction_factor: ad-hoc fix for non-physical circular geometry model such that q(r=a) = 3 for standard ITER parameters; diff --git a/torax/sim.py b/torax/sim.py index 69f40105..40f4e1ed 100644 --- a/torax/sim.py +++ b/torax/sim.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Functionality for running the heat + density simulation. +"""Functionality for running simulations. This includes the `run_simulation` main loop, logging functionality, and functionality for translating between our particular physics @@ -22,6 +22,29 @@ jax compilation off and on. Compilation is on by default. Turning compilation off can sometimes help with debugging (e.g. by making it easier to print error messages in context). + +Throughout TORAX, we maintain the following canonical argument order passed to +the various functions. For each individual case only a subset of these are +passed, but the order should be maintained. Individual elements in +CANONICAL_ORDER are substrings of full argument names which may appear in +practice, e.g. "dynamic_config_slice_t_plus_dt", "coeffs_callback". + +CANONICAL_ORDER = [ + "dt" + "source_type", + "static_config_slice", + "dynamic_config_slice", + "geo", + "x_old", + "state", + "core_profiles", + "step", + "transport_model", + "source_profiles", + "source_models", + "coeffs", + "evolving_names", +] """ from __future__ import annotations @@ -75,40 +98,40 @@ class CoeffsCallback: """Implements fvm.Block1DCoeffsCallback using calc_coeffs. Attributes: + static_config_slice: See the docstring for `stepper.Stepper`. + geo: See the docstring for `stepper.Stepper`. core_profiles_t: The core plasma profiles at the start of the time step. core_profiles_t_plus_dt: Core plasma profiles at the end of the time step. - evolving_names: The names of the evolving variables. - geo: See the docstring for `stepper.Stepper`. - static_config_slice: See the docstring for `stepper.Stepper`. transport_model: See the docstring for `stepper.Stepper`. explicit_source_profiles: See the docstring for `stepper.Stepper`. source_models: See the docstring for `stepper.Stepper`. + evolving_names: The names of the evolving variables. """ def __init__( self, + static_config_slice: config_slice.StaticConfigSlice, + geo: geometry.Geometry, core_profiles_t: state.CoreProfiles, core_profiles_t_plus_dt: state.CoreProfiles, - evolving_names: tuple[str, ...], - geo: geometry.Geometry, - static_config_slice: config_slice.StaticConfigSlice, transport_model: transport_model_lib.TransportModel, explicit_source_profiles: source_profiles_lib.SourceProfiles, source_models: source_models_lib.SourceModels, + evolving_names: tuple[str, ...], ): + self.static_config_slice = static_config_slice + self.geo = geo self.core_profiles_t = core_profiles_t self.core_profiles_t_plus_dt = core_profiles_t_plus_dt - self.evolving_names = evolving_names - self.geo = geo - self.static_config_slice = static_config_slice self.transport_model = transport_model self.explicit_source_profiles = explicit_source_profiles self.source_models = source_models + self.evolving_names = evolving_names def __call__( self, - x: tuple[fvm.CellVariable, ...], dynamic_config_slice: config_slice.DynamicConfigSlice, + x: tuple[fvm.CellVariable, ...], allow_pereverzev: bool = False, # Checks if reduced calc_coeffs for explicit terms when theta_imp=1 # should be called @@ -142,14 +165,14 @@ def __call__( use_pereverzev = False return calc_coeffs.calc_coeffs( - core_profiles=core_profiles, - evolving_names=self.evolving_names, - geo=self.geo, - dynamic_config_slice=dynamic_config_slice, static_config_slice=self.static_config_slice, + dynamic_config_slice=dynamic_config_slice, + geo=self.geo, + core_profiles=core_profiles, transport_model=self.transport_model, explicit_source_profiles=self.explicit_source_profiles, source_models=self.source_models, + evolving_names=self.evolving_names, use_pereverzev=use_pereverzev, explicit_call=explicit_call, ) @@ -169,11 +192,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) x = tuple([self.core_profiles_t[name] for name in self.evolving_names]) self.frozen_coeffs = super().__call__( - x, dynamic_config_slice, allow_pereverzev=False, explicit_call=False + dynamic_config_slice, x, allow_pereverzev=False, explicit_call=False ) def __call__( - self, x, dynamic_config_slice, allow_pereverzev=False, explicit_call=False + self, dynamic_config_slice, x, allow_pereverzev=False, explicit_call=False ): return self.frozen_coeffs @@ -221,27 +244,27 @@ def stepper(self) -> stepper_lib.Stepper: def __call__( self, - input_state: state.ToraxSimState, - geo: geometry.Geometry, - dynamic_config_slice_provider: config_slice.DynamicConfigSliceProvider, static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice_provider: config_slice.DynamicConfigSliceProvider, + geo: geometry.Geometry, + input_state: state.ToraxSimState, explicit_source_profiles: source_profiles_lib.SourceProfiles, ) -> state.ToraxSimState: """Advances the simulation state one time step. Args: - input_state: State at the start of the time step, including the core - profiles which are being evolved. - geo: The geometry of the torus during this time step of the simulation. - While the geometry may change, any changes to the grid size can trigger - recompilation of the stepper (if it is jitted) or an error (assuming it - is JAX-compiled and lowered). + static_config_slice: Static parameters that, if they change, should + trigger a recompilation of the SimulationStepFn. dynamic_config_slice_provider: Object that returns a set of runtime parameters which may change from time step to time step or simulation run to run. If these config parameters change, it does NOT trigger a JAX recompilation. - static_config_slice: Static parameters that, if they change, should - trigger a recompilation of the SimulationStepFn. + geo: The geometry of the torus during this time step of the simulation. + While the geometry may change, any changes to the grid size can trigger + recompilation of the stepper (if it is jitted) or an error (assuming it + is JAX-compiled and lowered). + input_state: State at the start of the time step, including the core + profiles which are being evolved. explicit_source_profiles: Explicit source profiles computed based on the core profiles at the start of the time step. @@ -249,7 +272,7 @@ def __call__( ToraxSimState containing: - the core profiles at the end of the time step. - time and time step calculator state info. - - extra auxiliary outputs useful for internal inspection. + - core_sources and core_transport at the end of the time step. - stepper_error_state: 0 if solver converged with fine tolerance for this step 1 if solver did not converge for this step (was above coarse tol) @@ -305,10 +328,10 @@ def __call__( # conditions and time-dependent prescribed profiles not directly solved by # PDE system. core_profiles_t_plus_dt = provide_core_profiles_t_plus_dt( - core_profiles_t=core_profiles_t, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, static_config_slice=static_config_slice, + dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, geo=geo, + core_profiles_t=core_profiles_t, ) stepper_iterations = 0 @@ -317,13 +340,13 @@ def __call__( # step with large dt) we apply the adaptive time step routine if requested. core_profiles, core_sources, core_transport, stepper_error_state = ( self._stepper_fn( - core_profiles_t=core_profiles_t, - core_profiles_t_plus_dt=core_profiles_t_plus_dt, - geo=geo, + dt=dt, + static_config_slice=static_config_slice, dynamic_config_slice_t=dynamic_config_slice_t, dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, - static_config_slice=static_config_slice, - dt=dt, + geo=geo, + core_profiles_t=core_profiles_t, + core_profiles_t_plus_dt=core_profiles_t_plus_dt, explicit_source_profiles=explicit_source_profiles, ) ) @@ -333,8 +356,8 @@ def __call__( t=input_state.t + dt, dt=dt, core_profiles=core_profiles, - core_sources=core_sources, core_transport=core_transport, + core_sources=core_sources, stepper_iterations=stepper_iterations, time_step_calculator_state=time_step_calculator_state, stepper_error_state=stepper_error_state, @@ -376,13 +399,13 @@ def body_fun( ) core_profiles, core_sources, core_transport, stepper_error_state = ( self._stepper_fn( - core_profiles_t=core_profiles_t, - core_profiles_t_plus_dt=core_profiles_t_plus_dt, - geo=geo, + dt=dt, + static_config_slice=static_config_slice, dynamic_config_slice_t=dynamic_config_slice_t, dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, - static_config_slice=static_config_slice, - dt=dt, + geo=geo, + core_profiles_t=core_profiles_t, + core_profiles_t_plus_dt=core_profiles_t_plus_dt, explicit_source_profiles=explicit_source_profiles, ) ) @@ -392,8 +415,8 @@ def body_fun( dt=dt, stepper_iterations=updated_output.stepper_iterations + 1, core_profiles=core_profiles, - core_sources=core_sources, core_transport=core_transport, + core_sources=core_sources, stepper_error_state=stepper_error_state, ) @@ -430,15 +453,15 @@ def body_fun( def get_initial_state( - dynamic_config_slice: config_slice.DynamicConfigSlice, static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice: config_slice.DynamicConfigSlice, geo: geometry.Geometry, - time_step_calculator: ts.TimeStepCalculator, source_models: source_models_lib.SourceModels, + time_step_calculator: ts.TimeStepCalculator, ) -> state.ToraxSimState: """Returns the initial state to be used by run_simulation().""" initial_core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice, static_config_slice, geo, source_models + static_config_slice, dynamic_config_slice, geo, source_models ) return state.ToraxSimState( t=jnp.array(dynamic_config_slice.numerics.t_initial), @@ -538,20 +561,20 @@ class Sim: def __init__( self, - time_step_calculator: ts.TimeStepCalculator, - initial_state: state.ToraxSimState, - geometry_provider: GeometryProvider, - dynamic_config_slice_provider: config_slice.DynamicConfigSliceProvider, static_config_slice: config_slice.StaticConfigSlice, - stepper: stepper_lib.Stepper | None = None, + dynamic_config_slice_provider: config_slice.DynamicConfigSliceProvider, + geometry_provider: GeometryProvider, + initial_state: state.ToraxSimState, + time_step_calculator: ts.TimeStepCalculator, transport_model: transport_model_lib.TransportModel | None = None, + stepper: stepper_lib.Stepper | None = None, step_fn: SimulationStepFn | None = None, ): - self._time_step_calculator = time_step_calculator - self._initial_state = initial_state - self._geometry_provider = geometry_provider - self._dynamic_config_slice_provider = dynamic_config_slice_provider self._static_config_slice = static_config_slice + self._dynamic_config_slice_provider = dynamic_config_slice_provider + self._geometry_provider = geometry_provider + self._initial_state = initial_state + self._time_step_calculator = time_step_calculator if step_fn is None: if stepper is None or transport_model is None: raise ValueError( @@ -652,12 +675,12 @@ def run( if spectator is not None: spectator.reset() return run_simulation( - initial_state=self.initial_state, - step_fn=self.step_fn, - geometry_provider=self.geometry_provider, - dynamic_config_slice_provider=self.dynamic_config_slice_provider, static_config_slice=self.static_config_slice, + dynamic_config_slice_provider=self.dynamic_config_slice_provider, + geometry_provider=self.geometry_provider, + initial_state=self.initial_state, time_step_calculator=self.time_step_calculator, + step_fn=self.step_fn, log_timestep_info=log_timestep_info, spectator=spectator, ) @@ -722,31 +745,31 @@ def build_sim_from_config( config.numerics.t_initial ) initial_state = get_initial_state( - dynamic_config_slice=dynamic_config_slice, static_config_slice=static_config_slice, + dynamic_config_slice=dynamic_config_slice, geo=geo, - time_step_calculator=time_step_calculator, source_models=stepper.source_models, + time_step_calculator=time_step_calculator, ) return Sim( - time_step_calculator=time_step_calculator, - initial_state=initial_state, - geometry_provider=ConstantGeometryProvider(geo), - dynamic_config_slice_provider=dynamic_config_slice_provider, static_config_slice=static_config_slice, - stepper=stepper, + dynamic_config_slice_provider=dynamic_config_slice_provider, + geometry_provider=ConstantGeometryProvider(geo), + initial_state=initial_state, + time_step_calculator=time_step_calculator, transport_model=transport_model, + stepper=stepper, ) def run_simulation( - initial_state: state.ToraxSimState, - step_fn: SimulationStepFn, - geometry_provider: GeometryProvider, - dynamic_config_slice_provider: config_slice.DynamicConfigSliceProvider, static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice_provider: config_slice.DynamicConfigSliceProvider, + geometry_provider: GeometryProvider, + initial_state: state.ToraxSimState, time_step_calculator: ts.TimeStepCalculator, + step_fn: SimulationStepFn, log_timestep_info: bool = False, spectator: spectator_lib.Spectator | None = None, ) -> tuple[state.ToraxSimState, ...]: @@ -766,13 +789,16 @@ def run_simulation( history. Args: - initial_state: The starting state of the simulation. This includes both the - state variables which the stepper.Stepper will evolve (like ion temp, psi, - etc.) as well as other states that need to be be tracked, like time. - step_fn: Callable which takes in ToraxSimState and outputs the ToraxSimState - after one timestep. Note that step_fn determines dt (how long the timestep - is). The state_history that run_simulation() outputs comes from these - ToraxSimState objects. + static_config_slice: A static set of arguments to provide to the step_fn. If + step_fn is JAX-compiled, then these params are "compile-time constant" + meaning that they are considered static to the compiled function. If they + change (i.e. the same step_fn is called again with a different + static_config_slice), then the step_fn will be recompiled. JAX determines + if recompilation is necessary via the hash of the static_config_slice. + dynamic_config_slice_provider: Provides a DynamicConfigSlice to use as input + for each time step. See static_config_slice and the config_slice module + docstring for config_slice to understand why we need the dynamic and + static config slices and what they control. geometry_provider: Provides the geometry of the torus for each time step based on the ToraxSimState at the start of the time step. The geometry may change from time step to time step, so the sim needs a function to provide @@ -781,18 +807,15 @@ def run_simulation( a time step and returns the Geometry for that time step. For most use cases, only the time will be relevant from the ToraxSimState (in order to support time-dependent geometries). - dynamic_config_slice_provider: Provides a DynamicConfigSlice to use as input - for each time step. See static_config_slice and the config_slice module - docstring for config_slice to understand why we need the dynamic and - static config slices and what they control. - static_config_slice: A static set of arguments to provide to the step_fn. If - step_fn is JAX-compiled, then these params are "compile-time constant" - meaning that they are considered static to the compiled function. If they - change (i.e. the same step_fn is called again with a different - static_config_slice), then the step_fn will be recompiled. JAX determines - if recompilation is necessary via the hash of the static_config_slice. + initial_state: The starting state of the simulation. This includes both the + state variables which the stepper.Stepper will evolve (like ion temp, psi, + etc.) as well as other states that need to be be tracked, like time. time_step_calculator: TimeStepCalculator determining policy for stepping through time. + step_fn: Callable which takes in ToraxSimState and outputs the ToraxSimState + after one timestep. Note that step_fn determines dt (how long the timestep + is). The state_history that run_simulation() outputs comes from these + ToraxSimState objects. log_timestep_info: If True, logs basic timestep info, like time, dt, on every step. spectator: Object which can "spectate" values as the simulation runs. See @@ -829,11 +852,11 @@ def run_simulation( # before starting the run-loop. The explicit source profiles will be computed # inside the loop and will be merged with these implicit source profiles. initial_state.core_sources = _get_initial_source_profiles( - source_models=step_fn.stepper.source_models, static_config_slice=static_config_slice, dynamic_config_slice=dynamic_config_slice, geo=geo, core_profiles=initial_state.core_profiles, + source_models=step_fn.stepper.source_models, ) if spectator is not None: # Because of the updates we apply to the core sources during the next @@ -868,10 +891,10 @@ def run_simulation( # DynamicSourceConfigSlice. All implicit sources will have their profiles # set to 0. explicit_source_profiles = source_models_lib.build_source_profiles( - source_models=step_fn.stepper.source_models, dynamic_config_slice=dynamic_config_slice, geo=geo, core_profiles=sim_state.core_profiles, + source_models=step_fn.stepper.source_models, explicit=True, ) @@ -880,9 +903,9 @@ def run_simulation( # profiles at this time step's t. We can merge those "implicit" source # profiles with the explicit ones computed here. sim_state.core_sources = _merge_source_profiles( - source_models=step_fn.stepper.source_models, explicit_source_profiles=explicit_source_profiles, implicit_source_profiles=sim_state.core_sources, + source_models=step_fn.stepper.source_models, qei_core_profiles=sim_state.core_profiles, ) # Make sure to "spectate" the state after the source profiles have been @@ -894,10 +917,10 @@ def run_simulation( # Now prep the spectator for the following time step. spectator.before_step() sim_state = step_fn( - sim_state, - geo, - dynamic_config_slice_provider, static_config_slice, + dynamic_config_slice_provider, + geo, + sim_state, explicit_source_profiles, ) stepper_error_state = sim_state.stepper_error_state @@ -915,16 +938,16 @@ def run_simulation( # profiles computed based on the final state. logging.info("Updating last step's source profiles.") explicit_source_profiles = source_models_lib.build_source_profiles( - source_models=step_fn.stepper.source_models, dynamic_config_slice=dynamic_config_slice, geo=geo, core_profiles=sim_state.core_profiles, + source_models=step_fn.stepper.source_models, explicit=True, ) sim_state.core_sources = _merge_source_profiles( - source_models=step_fn.stepper.source_models, explicit_source_profiles=explicit_source_profiles, implicit_source_profiles=sim_state.core_sources, + source_models=step_fn.stepper.source_models, qei_core_profiles=sim_state.core_profiles, ) if spectator is not None: @@ -1028,10 +1051,10 @@ def _update_spectator( def update_current_distribution( - source_models: source_models_lib.SourceModels, dynamic_config_slice: config_slice.DynamicConfigSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, + source_models: source_models_lib.SourceModels, ) -> state.CoreProfiles: """Update bootstrap current based on the new core_profiles.""" @@ -1068,17 +1091,20 @@ def update_current_distribution( def update_psidot( - source_models: source_models_lib.SourceModels, dynamic_config_slice: config_slice.DynamicConfigSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, + source_models: source_models_lib.SourceModels, ) -> state.CoreProfiles: """Update psidot based on new core_profiles.""" psidot = dataclasses.replace( core_profiles.psidot, value=source_models_lib.calc_psidot( - source_models, dynamic_config_slice, geo, core_profiles + dynamic_config_slice, + geo, + core_profiles, + source_models, ), ) @@ -1090,10 +1116,10 @@ def update_psidot( def provide_core_profiles_t_plus_dt( - core_profiles_t: state.CoreProfiles, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, geo: geometry.Geometry, + core_profiles_t: state.CoreProfiles, ) -> state.CoreProfiles: """Provides state at t_plus_dt with new boundary conditions and prescribed profiles.""" updated_boundary_conditions = ( @@ -1103,10 +1129,10 @@ def provide_core_profiles_t_plus_dt( ) ) updated_values = core_profile_setters.updated_prescribed_core_profiles( - core_profiles=core_profiles_t, - dynamic_config_slice=dynamic_config_slice_t_plus_dt, static_config_slice=static_config_slice, + dynamic_config_slice=dynamic_config_slice_t_plus_dt, geo=geo, + core_profiles=core_profiles_t, ) temp_ion = dataclasses.replace( core_profiles_t.temp_ion, @@ -1141,11 +1167,11 @@ def provide_core_profiles_t_plus_dt( def _get_initial_source_profiles( - source_models: source_models_lib.SourceModels, static_config_slice: config_slice.StaticConfigSlice, dynamic_config_slice: config_slice.DynamicConfigSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, + source_models: source_models_lib.SourceModels, ) -> source_profiles_lib.SourceProfiles: """Returns the "implicit" profiles for the initial state in run_simulation(). @@ -1161,7 +1187,6 @@ def _get_initial_source_profiles( core profiles. Args: - source_models: Source models used to compute core source profiles. static_config_slice: Config parameters which, when they change, trigger recompilations. They should not change within a single run of the sim. dynamic_config_slice: Runtime parameters which may change from time step to @@ -1169,22 +1194,25 @@ def _get_initial_source_profiles( geo: The geometry of the torus during this time step of the simulation. core_profiles: Core profiles that may evolve throughout the course of a simulation. These values here are, of course, only the original states. + source_models: Source models used to compute core source profiles. Returns: SourceProfiles from implicit source models based on the core profiles from the starting state. """ implicit_profiles = source_models_lib.build_source_profiles( - source_models=source_models, dynamic_config_slice=dynamic_config_slice, geo=geo, core_profiles=core_profiles, + source_models=source_models, explicit=False, ) qei = source_models.qei_source.get_qei( - dynamic_config_slice.sources[source_models.qei_source.name].source_type, - dynamic_config_slice=dynamic_config_slice, static_config_slice=static_config_slice, + source_type=dynamic_config_slice.sources[ + source_models.qei_source.name + ].source_type, + dynamic_config_slice=dynamic_config_slice, geo=geo, core_profiles=core_profiles, ) @@ -1196,9 +1224,9 @@ def _get_initial_source_profiles( # in our tests, jitting this function actually slightly slows down runs, so this # is left as pure python. def _merge_source_profiles( - source_models: source_models_lib.SourceModels, explicit_source_profiles: source_profiles_lib.SourceProfiles, implicit_source_profiles: source_profiles_lib.SourceProfiles, + source_models: source_models_lib.SourceModels, qei_core_profiles: state.CoreProfiles, ) -> source_profiles_lib.SourceProfiles: """Returns a SourceProfiles that merges the input profiles. @@ -1212,7 +1240,6 @@ def _merge_source_profiles( SourceProfiles that includes both. Args: - source_models: Source models used to compute the profiles given. explicit_source_profiles: Profiles from explicit source models. This SourceProfiles dict will include keys for both the explicit and implicit sources, but only the explicit sources will have non-zero profiles. See @@ -1221,6 +1248,7 @@ def _merge_source_profiles( SourceProfiles dict will include keys for both the explicit and implicit sources, but only the implicit sources will have non-zero profiles. See source.py and source_config.py for more info on explicit vs. implicit. + source_models: Source models used to compute the profiles given. qei_core_profiles: The core profiles used to compute the Qei source. Returns: diff --git a/torax/sources/bootstrap_current_source.py b/torax/sources/bootstrap_current_source.py index fbee48e2..43f6c302 100644 --- a/torax/sources/bootstrap_current_source.py +++ b/torax/sources/bootstrap_current_source.py @@ -102,10 +102,10 @@ def calc_neoclassical( # We don't store q_cell in the evolving core profiles, so we need to # recalculate it. q_face, _ = physics.calc_q_from_jtot_psi( - geo, - jtot_face, - psi, - dynamic_config_slice.numerics.q_correction_factor, + geo=geo, + psi=psi, + jtot_face=jtot_face, + q_correction_factor=dynamic_config_slice.numerics.q_correction_factor, ) nuestar = ( 6.921e-18 diff --git a/torax/sources/external_current_source.py b/torax/sources/external_current_source.py index 54316048..d3036787 100644 --- a/torax/sources/external_current_source.py +++ b/torax/sources/external_current_source.py @@ -44,14 +44,14 @@ def calculate_Iext( # pylint: disable=invalid-name def calculate_jext_face( - geo: geometry.Geometry, dynamic_config_slice: config_slice.DynamicConfigSlice, + geo: geometry.Geometry, ) -> jnp.ndarray: """Calculates the external current density profiles. Args: + dynamic_config_slice: Parameter configuration at present timestep. geo: Tokamak geometry. - dynamic_config_slice: Parameter configuration at present timesteap. Returns: External current density profile along the face grid. @@ -72,14 +72,14 @@ def calculate_jext_face( def calculate_jext_hires( - geo: geometry.Geometry, dynamic_config_slice: config_slice.DynamicConfigSlice, + geo: geometry.Geometry, ) -> jnp.ndarray: """Calculates the external current density profile along the hires grid. Args: + dynamic_config_slice: Parameter configuration at present timestep. geo: Tokamak geometry. - dynamic_config_slice: Parameter configuration at present timesteap. Returns: External current density profile along the hires cell grid. @@ -138,8 +138,8 @@ def get_value( lambda _0, _1, _2: source.ProfileType.FACE.get_zero_profile(geo) ), formula=lambda dcs, g, _: calculate_jext_face( - g, dcs, + g, ), output_shape=source.ProfileType.FACE.get_profile_shape(geo), ) @@ -161,8 +161,8 @@ def jext_hires( # There is no model for this source. model_func=(lambda _0, _1, _2: jnp.zeros_like(geo.r_hires_norm)), formula=lambda dcs, g, _: calculate_jext_hires( - g, dcs, + g, ), output_shape=geo.r_hires_norm.shape, ) diff --git a/torax/sources/formulas.py b/torax/sources/formulas.py index 48328895..2f89ada5 100644 --- a/torax/sources/formulas.py +++ b/torax/sources/formulas.py @@ -29,10 +29,10 @@ def exponential_profile( + geo: geometry.Geometry, c1: float, c2: float, total: float, - geo: geometry.Geometry, use_normalized_r: bool = False, ) -> jnp.ndarray: """Returns an exponential profile on the cell grid. @@ -45,10 +45,10 @@ def exponential_profile( The formula can use the normalized r and r_face if specified. Args: + geo: Geometry constants of torus. c1: Constant. See description above. c2: Constant. See description above. total: Constant. See description above. - geo: Geometry constants of torus. use_normalized_r: If True, uses r_norm and r_face_norm to calculate the profile. @@ -65,10 +65,10 @@ def exponential_profile( def gaussian_profile( + geo: geometry.Geometry, c1: float, c2: float, total: float, - geo: geometry.Geometry, use_normalized_r: bool = False, ) -> jnp.ndarray: """Returns a gaussian profile on the cell grid. @@ -81,10 +81,10 @@ def gaussian_profile( The formula can use the normalized r and r_face if specified. Args: + geo: Geometry constants of torus. c1: Constant. See description above. c2: Constant. See description above. total: Constant. See description above. - geo: Geometry constants of torus. use_normalized_r: If True, uses r_norm and r_face_norm to calculate the profile. diff --git a/torax/sources/qei_source.py b/torax/sources/qei_source.py index ee946060..c0fdabd3 100644 --- a/torax/sources/qei_source.py +++ b/torax/sources/qei_source.py @@ -61,8 +61,8 @@ class QeiSource(source.Source): def get_qei( self, source_type: int, - dynamic_config_slice: config_slice.DynamicConfigSlice, static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice: config_slice.DynamicConfigSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> source_profiles.QeiInfo: @@ -71,7 +71,7 @@ def get_qei( return jax.lax.cond( source_type == source_config.SourceType.MODEL_BASED.value, lambda: _model_based_qei( - dynamic_config_slice, static_config_slice, geo, core_profiles + static_config_slice, dynamic_config_slice, geo, core_profiles ), lambda: source_profiles.QeiInfo.zeros(geo), ) @@ -95,8 +95,8 @@ def get_source_profile_for_affected_core_profile( def _model_based_qei( - dynamic_config_slice: config_slice.DynamicConfigSlice, static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice: config_slice.DynamicConfigSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> source_profiles.QeiInfo: diff --git a/torax/sources/source.py b/torax/sources/source.py index 6c221598..c08a4142 100644 --- a/torax/sources/source.py +++ b/torax/sources/source.py @@ -219,10 +219,10 @@ def get_value( else self.formula ) return get_source_profiles( - source_type=source_type, dynamic_config_slice=dynamic_config_slice, geo=geo, core_profiles=core_profiles, + source_type=source_type, model_func=model_func, formula=formula, output_shape=output_shape, @@ -373,10 +373,10 @@ def get_value( dynamic_config_slice, geo, core_profiles ) profile = super().get_value( - source_type=source_type, dynamic_config_slice=dynamic_config_slice, geo=geo, core_profiles=core_profiles, + source_type=source_type, ) assert isinstance(profile, jnp.ndarray) chex.assert_rank(profile, 1) diff --git a/torax/sources/source_models.py b/torax/sources/source_models.py index f15bee7d..51481032 100644 --- a/torax/sources/source_models.py +++ b/torax/sources/source_models.py @@ -45,22 +45,22 @@ ], ) def build_source_profiles( - source_models: SourceModels, dynamic_config_slice: config_slice.DynamicConfigSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, + source_models: SourceModels, explicit: bool, ) -> source_profiles.SourceProfiles: """Builds explicit or implicit source profiles. Args: - source_models: Functions computing profiles for all TORAX sources/sinks. dynamic_config_slice: Input config for this time step. Can change from time step to time step. geo: Geometry of the torus. core_profiles: Core plasma profiles, either at the start of the time step (if explicit) or the live profiles being evolved during the time step (if implicit). + source_models: Functions computing profiles for all TORAX sources/sinks. explicit: If True, this function should return profiles for all explicit sources. All implicit sources should be set to 0. And same vice versa. @@ -337,9 +337,9 @@ def _build_temp_ion_el_profiles( def sum_sources_psi( + geo: geometry.Geometry, source_models: SourceModels, source_profile: source_profiles.SourceProfiles, - geo: geometry.Geometry, ) -> jnp.ndarray: """Computes psi source values for sim.calc_coeffs.""" total = ( @@ -358,9 +358,9 @@ def sum_sources_psi( def sum_sources_ne( - source_models: SourceModels, - source_profile: source_profiles.SourceProfiles, geo: geometry.Geometry, + source_profile: source_profiles.SourceProfiles, + source_models: SourceModels, ) -> jnp.ndarray: """Computes ne source values for sim.calc_coeffs.""" total = jnp.zeros_like(geo.r) @@ -374,9 +374,9 @@ def sum_sources_ne( def sum_sources_temp_ion( - source_models: SourceModels, - source_profile: source_profiles.SourceProfiles, geo: geometry.Geometry, + source_profile: source_profiles.SourceProfiles, + source_models: SourceModels, ) -> jnp.ndarray: """Computes temp_ion source values for sim.calc_coeffs.""" total = jnp.zeros_like(geo.r) @@ -390,9 +390,9 @@ def sum_sources_temp_ion( def sum_sources_temp_el( - source_models: SourceModels, - source_profile: source_profiles.SourceProfiles, geo: geometry.Geometry, + source_profile: source_profiles.SourceProfiles, + source_models: SourceModels, ) -> jnp.ndarray: """Computes temp_el source values for sim.calc_coeffs.""" total = jnp.zeros_like(geo.r) @@ -406,10 +406,10 @@ def sum_sources_temp_el( def calc_and_sum_sources_psi( - source_models: SourceModels, dynamic_config_slice: config_slice.DynamicConfigSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, + source_models: SourceModels, ) -> tuple[jnp.ndarray, jnp.ndarray]: """Computes sum of psi sources for psi_dot calculation.""" @@ -447,10 +447,10 @@ def calc_and_sum_sources_psi( ], ) def calc_psidot( - source_models: SourceModels, dynamic_config_slice: config_slice.DynamicConfigSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, + source_models: SourceModels, ) -> jnp.ndarray: r"""Calculates psidot (loop voltage). Used for the Ohmic electron heat source. @@ -462,10 +462,10 @@ def calc_psidot( (but abridged) formulation as in sim.calc_coeffs and fvm._calc_c is used here Args: - source_models: All TORAX source/sinks. dynamic_config_slice: Simulation configuration at this timestep geo: Torus geometry core_profiles: Core plasma profiles. + source_models: All TORAX source/sinks. Returns: psidot: on cell grid @@ -473,10 +473,10 @@ def calc_psidot( consts = constants.CONSTANTS psi_sources, sigma = calc_and_sum_sources_psi( - source_models, dynamic_config_slice, geo, core_profiles, + source_models, ) toc_psi = ( 1.0 @@ -500,10 +500,10 @@ def calc_psidot( # OhmicHeatSource is a special case and defined here to avoid circular # dependencies, since it depends on the psi sources def _ohmic_heat_model( - source_models: SourceModels, dynamic_config_slice: config_slice.DynamicConfigSlice, geo: geometry.Geometry, core_profiles: state.CoreProfiles, + source_models: SourceModels, ) -> jnp.ndarray: """Returns the Ohmic source for electron heat equation.""" jtot, _ = physics.calc_jtot_from_psi( @@ -511,7 +511,12 @@ def _ohmic_heat_model( core_profiles.psi, ) - psidot = calc_psidot(source_models, dynamic_config_slice, geo, core_profiles) + psidot = calc_psidot( + dynamic_config_slice, + geo, + core_profiles, + source_models, + ) pohm = jtot * psidot / (2 * jnp.pi * geo.Rmaj) return pohm @@ -554,10 +559,10 @@ def _model_func( core_profiles: state.CoreProfiles, ) -> jnp.ndarray: return _ohmic_heat_model( - source_models=self.source_models, dynamic_config_slice=dynamic_config_slice, geo=geo, core_profiles=core_profiles, + source_models=self.source_models, ) # Must use object.__setattr__ instead of simply doing @@ -852,9 +857,9 @@ def all_sources(self) -> dict[str, source_lib.Source]: def build_all_zero_profiles( - source_models: SourceModels, dynamic_config_slice: config_slice.DynamicConfigSlice, geo: geometry.Geometry, + source_models: SourceModels, ) -> source_profiles.SourceProfiles: """Returns a SourceProfiles object with all zero profiles.""" profiles = { diff --git a/torax/sources/tests/formulas.py b/torax/sources/tests/formulas.py index e96623bb..724e8bc7 100644 --- a/torax/sources/tests/formulas.py +++ b/torax/sources/tests/formulas.py @@ -198,14 +198,14 @@ def _run_sim_and_check( ): """Runs sim with new dynamic config and checks the profiles vs. expected.""" torax_outputs = sim_lib.run_simulation( - initial_state=sim.initial_state, - step_fn=sim.step_fn, - geometry_provider=sim.geometry_provider, + static_config_slice=sim.static_config_slice, dynamic_config_slice_provider=( config_slice.TimeDependentDynamicConfigSliceProvider(config) ), - static_config_slice=sim.static_config_slice, + geometry_provider=sim.geometry_provider, + initial_state=sim.initial_state, time_step_calculator=sim.time_step_calculator, + step_fn=sim.step_fn, ) state_history, _, _ = state_lib.build_history_from_states(torax_outputs) t = state_lib.build_time_history_from_states(torax_outputs) diff --git a/torax/sources/tests/fusion_heat_source.py b/torax/sources/tests/fusion_heat_source.py index 2046dae9..b5f09f69 100644 --- a/torax/sources/tests/fusion_heat_source.py +++ b/torax/sources/tests/fusion_heat_source.py @@ -62,8 +62,8 @@ def test_calc_fusion( nref = config.nref core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=config_slice.build_dynamic_config_slice(config), geo=geo, source_models=source_models_lib.SourceModels(), ) diff --git a/torax/sources/tests/qei_source.py b/torax/sources/tests/qei_source.py index 7eb06586..6b297234 100644 --- a/torax/sources/tests/qei_source.py +++ b/torax/sources/tests/qei_source.py @@ -49,8 +49,8 @@ def test_source_value(self): config = config_lib.Config() geo = geometry.build_circular_geometry(config) core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=config_slice.build_dynamic_config_slice(config), geo=geo, source_models=source_models_lib.SourceModels(qei_source=source), ) @@ -59,8 +59,8 @@ def test_source_value(self): static_slice = config_slice.build_static_config_slice(config) qei = source.get_qei( dynamic_slice.sources[source.name].source_type, - dynamic_slice, static_slice, + dynamic_slice, geo, core_profiles, ) @@ -71,8 +71,8 @@ def test_invalid_source_types_raise_errors(self): config = config_lib.Config() geo = geometry.build_circular_geometry(config) core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=config_slice.build_dynamic_config_slice(config), geo=geo, source_models=source_models_lib.SourceModels(qei_source=source), ) @@ -83,8 +83,8 @@ def test_invalid_source_types_raise_errors(self): with self.assertRaises(jax.interpreters.xla.xe.XlaRuntimeError): source.get_qei( unsupported_type.value, - dynamic_slice, static_slice, + dynamic_slice, geo, core_profiles, ) diff --git a/torax/sources/tests/source.py b/torax/sources/tests/source.py index 1762a56d..8312b551 100644 --- a/torax/sources/tests/source.py +++ b/torax/sources/tests/source.py @@ -43,8 +43,8 @@ def test_zero_profile_works_by_default(self): ) geo = geometry.build_circular_geometry(config) core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=config_slice.build_dynamic_config_slice(config), geo=geo, source_models=source_models_lib.SourceModels( additional_sources=[source] @@ -113,8 +113,8 @@ def test_defaults_output_zeros(self): ) geo = geometry.build_circular_geometry(config) core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=config_slice.build_dynamic_config_slice(config), geo=geo, source_models=source_models_lib.SourceModels( additional_sources=[source] @@ -167,8 +167,8 @@ def test_overriding_default_formula(self): ) geo = geometry.build_circular_geometry(config) core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=config_slice.build_dynamic_config_slice(config), geo=geo, source_models=source_models_lib.SourceModels( additional_sources=[source] @@ -202,8 +202,8 @@ def test_overriding_model(self): ) geo = geometry.build_circular_geometry(config) core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=config_slice.build_dynamic_config_slice(config), geo=geo, source_models=source_models_lib.SourceModels( additional_sources=[source] @@ -266,8 +266,8 @@ def test_custom_formula(self): ) geo = geometry.build_circular_geometry(config) core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=config_slice.build_dynamic_config_slice(config), geo=geo, # defaults are enough for this. source_models=source_models_lib.SourceModels(), @@ -295,8 +295,8 @@ def test_multiple_profiles_raises_error(self): ) geo = geometry.build_circular_geometry(config) core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=config_slice.build_dynamic_config_slice(config), geo=geo, # defaults are enough for this. source_models=source_models_lib.SourceModels(), diff --git a/torax/sources/tests/source_models.py b/torax/sources/tests/source_models.py index b053f72b..4736bd08 100644 --- a/torax/sources/tests/source_models.py +++ b/torax/sources/tests/source_models.py @@ -47,16 +47,16 @@ def test_computing_source_profiles_works_with_all_defaults(self): geo = torax.build_circular_geometry(config) source_models = source_models_lib.SourceModels() core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=config_slice.build_dynamic_config_slice(config), geo=geo, source_models=source_models, ) _ = source_models_lib.build_source_profiles( - source_models, dynamic_config_slice, geo, core_profiles, explicit=True + dynamic_config_slice, geo, core_profiles, source_models, explicit=True ) _ = source_models_lib.build_source_profiles( - source_models, dynamic_config_slice, geo, core_profiles, explicit=False + dynamic_config_slice, geo, core_profiles, source_models, explicit=False ) def test_summed_temp_ion_profiles_dont_change_when_jitting(self): @@ -84,11 +84,11 @@ def test_summed_temp_ion_profiles_dont_change_when_jitting(self): ) with self.subTest('without_jit'): summed_temp_ion = source_models_lib.sum_sources_temp_ion( - source_models, profiles, geo + geo, profiles, source_models ) np.testing.assert_allclose(summed_temp_ion, ones * 4 * geo.vpr) summed_temp_el = source_models_lib.sum_sources_temp_el( - source_models, profiles, geo + geo, profiles, source_models ) np.testing.assert_allclose(summed_temp_el, ones * 11 * geo.vpr) @@ -97,13 +97,13 @@ def test_summed_temp_ion_profiles_dont_change_when_jitting(self): source_models_lib.sum_sources_temp_ion, static_argnames=['source_models'], ) - jitted_temp_ion = sum_temp_ion(source_models, profiles, geo) + jitted_temp_ion = sum_temp_ion(geo, profiles, source_models) np.testing.assert_allclose(jitted_temp_ion, ones * 4 * geo.vpr) sum_temp_el = jax.jit( source_models_lib.sum_sources_temp_el, static_argnames=['source_models'], ) - jitted_temp_el = sum_temp_el(source_models, profiles, geo) + jitted_temp_el = sum_temp_el(geo, profiles, source_models) np.testing.assert_allclose(jitted_temp_el, ones * 11 * geo.vpr) def test_custom_source_profiles_dont_change_when_jitted(self): @@ -162,17 +162,17 @@ def foo_formula(unused_dcs, geo: geometry.Geometry, unused_state): def compute_and_sum_profiles(): profiles = source_models_lib.build_source_profiles( - source_models=source_models, dynamic_config_slice=dynamic_config_slice, geo=geo, core_profiles=core_profiles, + source_models=source_models, # Configs set sources to implicit by default, so set this to False to # calculate the custom source's profile. explicit=False, ) - ne = source_models_lib.sum_sources_ne(source_models, profiles, geo) + ne = source_models_lib.sum_sources_ne(geo, profiles, source_models) temp_el = source_models_lib.sum_sources_temp_el( - source_models, profiles, geo + geo, profiles, source_models ) return (ne, temp_el) diff --git a/torax/sources/tests/test_lib.py b/torax/sources/tests/test_lib.py index 20a0bc2d..e500b48f 100644 --- a/torax/sources/tests/test_lib.py +++ b/torax/sources/tests/test_lib.py @@ -104,8 +104,8 @@ def test_source_value(self): source_models = source_models_lib.SourceModels() geo = geometry.build_circular_geometry(config) core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=config_slice.build_dynamic_config_slice(config), geo=geo, source_models=source_models, ) @@ -123,8 +123,8 @@ def test_invalid_source_types_raise_errors(self): config = config_lib.Config() geo = geometry.build_circular_geometry(config) core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=config_slice.build_dynamic_config_slice(config), geo=geo, # only need default sources here. source_models=source_models_lib.SourceModels(), @@ -158,8 +158,8 @@ def test_source_value(self): config = config_lib.Config() geo = geometry.build_circular_geometry(config) core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=config_slice.build_dynamic_config_slice(config), geo=geo, # only need default sources here. source_models=source_models_lib.SourceModels(), @@ -178,8 +178,8 @@ def test_invalid_source_types_raise_errors(self): config = config_lib.Config() geo = geometry.build_circular_geometry(config) core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice=config_slice.build_dynamic_config_slice(config), static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=config_slice.build_dynamic_config_slice(config), geo=geo, # only need default sources here. source_models=source_models_lib.SourceModels(), diff --git a/torax/state.py b/torax/state.py index 6c57f81d..f4886a47 100644 --- a/torax/state.py +++ b/torax/state.py @@ -201,16 +201,18 @@ def zeros(cls, geo: geometry.Geometry) -> CoreTransport: class ToraxSimState: """Full simulator state. - The simulation stepping in sim.py evolves the "mesh state" which includes all + The simulation stepping in sim.py evolves core_profiles which includes all the attributes the simulation is advancing. But beyond those, there are - additional stateful elements which evolve on each simulation step. + additional stateful elements which may evolve on each simulation step, such + as sources and transport. - This class includes both the mesh state and these additional elements. + This class includes both core_profiles and these additional elements. Attributes: t: time coordinate dt: timestep interval core_profiles: Core plasma profiles at time t. + core_transport: Core plasma transport coefficients computed at time t. core_sources: Profiles for all sources/sinks. For any state-dependent source models, the profiles in this dataclass are computed based on the core profiles at time t, almost. When running `sim.run_simulation()`, any @@ -221,7 +223,6 @@ class ToraxSimState: at time t, but is not guaranteed to be. In case exact source profiles are required for each time step, they must be recomputed manually after running `run_simulation()`. - core_transport: Core plasma transport coefficients computed at time t. stepper_iterations: number of stepper iterations carried out in previous step, i.e. the number of times dt was reduced when using the adaptive dt method. @@ -236,8 +237,8 @@ class ToraxSimState: # Profiles evolved or calculated by the simulation. core_profiles: CoreProfiles - core_sources: source_profiles.SourceProfiles core_transport: CoreTransport + core_sources: source_profiles.SourceProfiles # Other "side" states used for logging and feeding to other components of # TORAX. diff --git a/torax/stepper/linear_theta_method.py b/torax/stepper/linear_theta_method.py index 95cba3da..680942a6 100644 --- a/torax/stepper/linear_theta_method.py +++ b/torax/stepper/linear_theta_method.py @@ -42,15 +42,15 @@ def __init__( def _x_new( self, - core_profiles_t: state.CoreProfiles, - core_profiles_t_plus_dt: state.CoreProfiles, - evolving_names: tuple[str, ...], - geo: geometry.Geometry, + dt: jax.Array, + static_config_slice: config_slice.StaticConfigSlice, dynamic_config_slice_t: config_slice.DynamicConfigSlice, dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, - static_config_slice: config_slice.StaticConfigSlice, - dt: jax.Array, + geo: geometry.Geometry, + core_profiles_t: state.CoreProfiles, + core_profiles_t_plus_dt: state.CoreProfiles, explicit_source_profiles: source_profiles.SourceProfiles, + evolving_names: tuple[str, ...], ) -> tuple[ tuple[fvm.CellVariable, ...], source_profiles.SourceProfiles, @@ -66,20 +66,20 @@ def _x_new( # Instantiate coeffs_callback class coeffs_callback = self.callback_class( + static_config_slice=static_config_slice, + geo=geo, core_profiles_t=core_profiles_t, core_profiles_t_plus_dt=core_profiles_t_plus_dt, - evolving_names=evolving_names, - geo=geo, - static_config_slice=static_config_slice, transport_model=self.transport_model, explicit_source_profiles=explicit_source_profiles, source_models=self.source_models, + evolving_names=evolving_names, ) # Compute the explicit coeffs based on the core profiles at time t and all # runtime parameters at time t. coeffs_exp = coeffs_callback( - x_old, dynamic_config_slice_t, allow_pereverzev=True, explicit_call=True + dynamic_config_slice_t, x_old, allow_pereverzev=True, explicit_call=True ) # Calculate x_new with the predictor corrector method. Reverts to a @@ -93,7 +93,9 @@ def _x_new( x_new_init, ( source_models_lib.build_all_zero_profiles( - self.source_models, dynamic_config_slice_t, geo + dynamic_config_slice_t, + geo, + self.source_models, ), state.CoreTransport.zeros(geo), ), @@ -101,13 +103,13 @@ def _x_new( x_new, (core_sources, core_transport) = ( predictor_corrector_method.predictor_corrector_method( - init_val=init_val, - x_old=x_old, dt=dt, + static_config_slice=static_config_slice, + dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, + x_old=x_old, + init_val=init_val, coeffs_exp=coeffs_exp, coeffs_callback=coeffs_callback, - dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, - static_config_slice=static_config_slice, ) ) diff --git a/torax/stepper/nonlinear_theta_method.py b/torax/stepper/nonlinear_theta_method.py index f5f7a100..3f3b13bd 100644 --- a/torax/stepper/nonlinear_theta_method.py +++ b/torax/stepper/nonlinear_theta_method.py @@ -57,15 +57,15 @@ def __init__( def _x_new( self, - core_profiles_t: state.CoreProfiles, - core_profiles_t_plus_dt: state.CoreProfiles, - evolving_names: tuple[str, ...], - geo: geometry.Geometry, + dt: jax.Array, + static_config_slice: config_slice.StaticConfigSlice, dynamic_config_slice_t: config_slice.DynamicConfigSlice, dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, - static_config_slice: config_slice.StaticConfigSlice, - dt: jax.Array, + geo: geometry.Geometry, + core_profiles_t: state.CoreProfiles, + core_profiles_t_plus_dt: state.CoreProfiles, explicit_source_profiles: source_profiles.SourceProfiles, + evolving_names: tuple[str, ...], ) -> tuple[ tuple[fvm.CellVariable, ...], source_profiles.SourceProfiles, @@ -75,26 +75,26 @@ def _x_new( """See Stepper._x_new docstring.""" coeffs_callback = self.callback_class( + static_config_slice=static_config_slice, + geo=geo, core_profiles_t=core_profiles_t, core_profiles_t_plus_dt=core_profiles_t_plus_dt, - evolving_names=evolving_names, - geo=geo, - static_config_slice=static_config_slice, transport_model=self.transport_model, explicit_source_profiles=explicit_source_profiles, source_models=self.source_models, + evolving_names=evolving_names, ) x_new, core_sources, core_transport, error = self._x_new_helper( + dt=dt, + static_config_slice=static_config_slice, dynamic_config_slice_t=dynamic_config_slice_t, dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, - static_config_slice=static_config_slice, + geo=geo, core_profiles_t=core_profiles_t, core_profiles_t_plus_dt=core_profiles_t_plus_dt, - evolving_names=evolving_names, - geo=geo, explicit_source_profiles=explicit_source_profiles, - dt=dt, coeffs_callback=coeffs_callback, + evolving_names=evolving_names, ) return x_new, core_sources, core_transport, error @@ -102,16 +102,16 @@ def _x_new( @abc.abstractmethod def _x_new_helper( self, + dt: jax.Array, + static_config_slice: config_slice.StaticConfigSlice, dynamic_config_slice_t: config_slice.DynamicConfigSlice, dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, - static_config_slice: config_slice.StaticConfigSlice, + geo: geometry.Geometry, core_profiles_t: state.CoreProfiles, core_profiles_t_plus_dt: state.CoreProfiles, - evolving_names: tuple[str, ...], - geo: geometry.Geometry, explicit_source_profiles: source_profiles.SourceProfiles, - dt: jax.Array, coeffs_callback: sim.CoeffsCallback, + evolving_names: tuple[str, ...], ) -> tuple[ tuple[fvm.CellVariable, ...], source_profiles.SourceProfiles, @@ -155,16 +155,16 @@ def __init__( def _x_new_helper( self, + dt: jax.Array, + static_config_slice: config_slice.StaticConfigSlice, dynamic_config_slice_t: config_slice.DynamicConfigSlice, dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, - static_config_slice: config_slice.StaticConfigSlice, + geo: geometry.Geometry, core_profiles_t: state.CoreProfiles, core_profiles_t_plus_dt: state.CoreProfiles, - evolving_names: tuple[str, ...], - geo: geometry.Geometry, explicit_source_profiles: source_profiles.SourceProfiles, - dt: jax.Array, coeffs_callback: sim.CoeffsCallback, + evolving_names: tuple[str, ...], ) -> tuple[ tuple[fvm.CellVariable, ...], source_profiles.SourceProfiles, @@ -175,18 +175,18 @@ def _x_new_helper( # Unpack the outputs of the optimizer_solve_block. x_new, error, (core_sources, core_transport) = ( optimizer_solve_block.optimizer_solve_block( - x_old=tuple([core_profiles_t[name] for name in evolving_names]), - core_profiles_t_plus_dt=core_profiles_t_plus_dt, - evolving_names=evolving_names, dt=dt, - coeffs_callback=coeffs_callback, + static_config_slice=static_config_slice, dynamic_config_slice_t=dynamic_config_slice_t, dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, - static_config_slice=static_config_slice, geo=geo, + x_old=tuple([core_profiles_t[name] for name in evolving_names]), + core_profiles_t_plus_dt=core_profiles_t_plus_dt, transport_model=self.transport_model, - source_models=self.source_models, explicit_source_profiles=explicit_source_profiles, + source_models=self.source_models, + coeffs_callback=coeffs_callback, + evolving_names=evolving_names, initial_guess_mode=self.initial_guess_mode, maxiter=self.maxiter, tol=self.tol, @@ -237,16 +237,16 @@ def __init__( def _x_new_helper( self, + dt: jax.Array, + static_config_slice: config_slice.StaticConfigSlice, dynamic_config_slice_t: config_slice.DynamicConfigSlice, dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, - static_config_slice: config_slice.StaticConfigSlice, + geo: geometry.Geometry, core_profiles_t: state.CoreProfiles, core_profiles_t_plus_dt: state.CoreProfiles, - evolving_names: tuple[str, ...], - geo: geometry.Geometry, explicit_source_profiles: source_profiles.SourceProfiles, - dt: jax.Array, coeffs_callback: sim.CoeffsCallback, + evolving_names: tuple[str, ...], ) -> tuple[ tuple[fvm.CellVariable, ...], source_profiles.SourceProfiles, @@ -260,18 +260,18 @@ def _x_new_helper( # Unpack the outputs of the optimizer_solve_block. x_new, error, (core_sources, core_transport) = ( newton_raphson_solve_block.newton_raphson_solve_block( - x_old=tuple([core_profiles_t[name] for name in evolving_names]), - core_profiles_t_plus_dt=core_profiles_t_plus_dt, - evolving_names=evolving_names, dt=dt, - coeffs_callback=coeffs_callback, + static_config_slice=static_config_slice, dynamic_config_slice_t=dynamic_config_slice_t, dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, - static_config_slice=static_config_slice, geo=geo, + x_old=tuple([core_profiles_t[name] for name in evolving_names]), + core_profiles_t_plus_dt=core_profiles_t_plus_dt, transport_model=self.transport_model, - source_models=self.source_models, explicit_source_profiles=explicit_source_profiles, + source_models=self.source_models, + coeffs_callback=coeffs_callback, + evolving_names=evolving_names, log_iterations=dynamic_config_slice_t.solver.log_iterations, initial_guess_mode=self.initial_guess_mode, maxiter=self.maxiter, diff --git a/torax/stepper/predictor_corrector_method.py b/torax/stepper/predictor_corrector_method.py index 67529dbb..fe745938 100644 --- a/torax/stepper/predictor_corrector_method.py +++ b/torax/stepper/predictor_corrector_method.py @@ -28,29 +28,29 @@ def predictor_corrector_method( + dt: jax.Array, + static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, + x_old: tuple[fvm.CellVariable, ...], init_val: tuple[ tuple[fvm.CellVariable, ...], chex.ArrayTree ], - x_old: tuple[fvm.CellVariable, ...], - dt: jax.Array, coeffs_exp: fvm.block_1d_coeffs.Block1DCoeffs, coeffs_callback: fvm.block_1d_coeffs.Block1DCoeffsCallback, - dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, - static_config_slice: config_slice.StaticConfigSlice, ) -> tuple[tuple[fvm.CellVariable, ...], Any]: """Predictor-corrector method. Args: - init_val: Initial guess for the predictor corrector output. + dt: current timestep + static_config_slice: General input parameters which are fixed through a + simulation run, and if changed, would trigger a recompile. + dynamic_config_slice_t_plus_dt: Dynamic config parameters corresponding to + the next time step, needed for the implicit PDE coefficients x_old: Tuple of CellVariables correspond to the evolving core profiles at time t. - dt: current timestep + init_val: Initial guess for the predictor corrector output. coeffs_exp: Block1DCoeffs PDE coefficients at beginning of timestep coeffs_callback: coefficient callback function - dynamic_config_slice_t_plus_dt: Dynamic config parameters corresponding to - the next time step, needed for the implicit PDE coefficients - static_config_slice: General input parameters which are fixed through a - simulation run, and if changed, would trigger a recompile. Returns: x_new: Solution of evolving core profile state variables @@ -64,15 +64,15 @@ def loop_body(i, val): # pylint: disable=unused-argument x_new_guess = val[0] coeffs_new = coeffs_callback( - x_new_guess, dynamic_config_slice_t_plus_dt, + x_new_guess, allow_pereverzev=True, ) x_new = implicit_solve_block.implicit_solve_block( + dt=dt, x_old=x_old, x_new_guess=x_new_guess, - dt=dt, coeffs_old=coeffs_exp, coeffs_new=coeffs_new, theta_imp=static_config_slice.solver.theta_imp, diff --git a/torax/stepper/stepper.py b/torax/stepper/stepper.py index 9eaa48ea..72f17cb6 100644 --- a/torax/stepper/stepper.py +++ b/torax/stepper/stepper.py @@ -54,13 +54,13 @@ def __init__( def __call__( self, - core_profiles_t: state.CoreProfiles, - core_profiles_t_plus_dt: state.CoreProfiles, - geo: geometry.Geometry, + dt: jax.Array, + static_config_slice: config_slice.StaticConfigSlice, dynamic_config_slice_t: config_slice.DynamicConfigSlice, dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, - static_config_slice: config_slice.StaticConfigSlice, - dt: jax.Array, + geo: geometry.Geometry, + core_profiles_t: state.CoreProfiles, + core_profiles_t_plus_dt: state.CoreProfiles, explicit_source_profiles: source_profiles.SourceProfiles, ) -> tuple[ state.CoreProfiles, @@ -71,21 +71,21 @@ def __call__( """Applies a time step update. Args: - core_profiles_t: Core plasma profiles at the beginning of the time step. - core_profiles_t_plus_dt: Core plasma profiles which contain all available - prescribed quantities at the end of the time step. This includes - evolving boundary conditions and prescribed time-dependent profiles that - are not being evolved by the PDE system. - geo: Geometry of the torus. + dt: Time step duration. + static_config_slice: Input params that trigger recompilation when they + change. These don't have to be JAX-friendly types and can be used in + control-flow logic. dynamic_config_slice_t: Runtime configuration for time t (the start time of the step). These config params can change from step to step without triggering a recompilation. dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt, used for implicit calculations in the solver. - static_config_slice: Input params that trigger recompilation when they - change. These don't have to be JAX-friendly types and can be used in - control-flow logic. - dt: Time step duration. + geo: Geometry of the torus. + core_profiles_t: Core plasma profiles at the beginning of the time step. + core_profiles_t_plus_dt: Core plasma profiles which contain all available + prescribed quantities at the end of the time step. This includes + evolving boundary conditions and prescribed time-dependent profiles that + are not being evolved by the PDE system. explicit_source_profiles: Source profiles of all explicit sources (as configured by the input config). All implicit source's profiles will be set to 0 in this object. These explicit source profiles were calculated @@ -125,15 +125,15 @@ def __call__( # Don't call solver functions on an empty list if evolving_names: x_new, core_sources, core_transport, error = self._x_new( - core_profiles_t=core_profiles_t, - core_profiles_t_plus_dt=core_profiles_t_plus_dt, - evolving_names=evolving_names, - geo=geo, + dt=dt, + static_config_slice=static_config_slice, dynamic_config_slice_t=dynamic_config_slice_t, dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt, - static_config_slice=static_config_slice, - dt=dt, + geo=geo, + core_profiles_t=core_profiles_t, + core_profiles_t_plus_dt=core_profiles_t_plus_dt, explicit_source_profiles=explicit_source_profiles, + evolving_names=evolving_names, ) else: x_new = tuple() @@ -147,10 +147,10 @@ def __call__( core_profiles_t_plus_dt = ( core_profile_setters.update_evolving_core_profiles( - core_profiles_t_plus_dt, x_new, - evolving_names, dynamic_config_slice_t_plus_dt, + core_profiles_t_plus_dt, + evolving_names, ) ) @@ -163,15 +163,15 @@ def __call__( def _x_new( self, - core_profiles_t: state.CoreProfiles, - core_profiles_t_plus_dt: state.CoreProfiles, - evolving_names: tuple[str, ...], - geo: geometry.Geometry, + dt: jax.Array, + static_config_slice: config_slice.StaticConfigSlice, dynamic_config_slice_t: config_slice.DynamicConfigSlice, dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, - static_config_slice: config_slice.StaticConfigSlice, - dt: jax.Array, + geo: geometry.Geometry, + core_profiles_t: state.CoreProfiles, + core_profiles_t_plus_dt: state.CoreProfiles, explicit_source_profiles: source_profiles.SourceProfiles, + evolving_names: tuple[str, ...], ) -> tuple[ tuple[fvm.CellVariable, ...], source_profiles.SourceProfiles, @@ -184,23 +184,23 @@ def _x_new( will work, or implement a different `__call__`. Args: - core_profiles_t: Core plasma profiles at the beginning of the time step. - core_profiles_t_plus_dt: Core plasma profiles which contain all available - prescribed quantities at the end of the time step. This includes - evolving boundary conditions and prescribed time-dependent profiles that - are not being evolved by the PDE system. - evolving_names: The names of core_profiles variables that should evolve. - geo: Geometry of the torus. + dt: Time step duration. + static_config_slice: Input params that trigger recompilation when they + change. These don't have to be JAX-friendly types and can be used in + control-flow logic. dynamic_config_slice_t: Runtime configuration for time t (the start time of the step). These config params can change from step to step without triggering a recompilation. dynamic_config_slice_t_plus_dt: Runtime configuration for time t + dt, used for implicit calculations in the solver. - static_config_slice: Input params that trigger recompilation when they - change. These don't have to be JAX-friendly types and can be used in - control-flow logic. - dt: Time step duration. + geo: Geometry of the torus. + core_profiles_t: Core plasma profiles at the beginning of the time step. + core_profiles_t_plus_dt: Core plasma profiles which contain all available + prescribed quantities at the end of the time step. This includes + evolving boundary conditions and prescribed time-dependent profiles that + are not being evolved by the PDE system. explicit_source_profiles: see the docstring of __call__ + evolving_names: The names of core_profiles variables that should evolve. Returns: x_new: The values of the evolving variables at time t + dt. diff --git a/torax/tests/boundary_conditions.py b/torax/tests/boundary_conditions.py index eea37a19..db166de4 100644 --- a/torax/tests/boundary_conditions.py +++ b/torax/tests/boundary_conditions.py @@ -49,7 +49,9 @@ def test_setting_boundary_conditions(self): config ) core_profiles = core_profile_setters.initial_core_profiles( - initial_dynamic_config_slice, static_config_slice, geo + static_config_slice, + initial_dynamic_config_slice, + geo, ) dynamic_config_slice = config_slice.build_dynamic_config_slice(config, 0.5) diff --git a/torax/tests/physics.py b/torax/tests/physics.py index 4acb9300..e374af44 100644 --- a/torax/tests/physics.py +++ b/torax/tests/physics.py @@ -51,8 +51,8 @@ def test_calc_q_from_psi( q_face_jax, q_cell_jax = physics.calc_q_from_jtot_psi( geo, - jtot, references.psi, + jtot, dynamic_config_slice.numerics.q_correction_factor, ) diff --git a/torax/tests/sim_output_source_profiles.py b/torax/tests/sim_output_source_profiles.py index 410a5f47..01f707c7 100644 --- a/torax/tests/sim_output_source_profiles.py +++ b/torax/tests/sim_output_source_profiles.py @@ -166,8 +166,8 @@ def custom_source_formula(dynamic_config, geo, unused_state): sim_states = sim_lib.run_simulation( initial_state=sim_lib.get_initial_state( - dynamic_config_slice=initial_dcs, static_config_slice=static_config_slice, + dynamic_config_slice=initial_dcs, geo=geo, time_step_calculator=time_stepper, source_models=source_models, @@ -275,10 +275,10 @@ def stepper(self): def __call__( self, - input_state: state_module.ToraxSimState, - geo: geometry.Geometry, - dynamic_config_slice_provider: config_slice.DynamicConfigSliceProvider, static_config_slice: config_slice.StaticConfigSlice, + dynamic_config_slice_provider: config_slice.DynamicConfigSliceProvider, + geo: geometry.Geometry, + input_state: state_module.ToraxSimState, explicit_source_profiles: source_profiles_lib.SourceProfiles, ) -> state_module.ToraxSimState: dt, ts_state = self._time_step_calculator.next_dt( @@ -296,10 +296,10 @@ def __call__( time_step_calculator_state=ts_state, # The returned source profiles include only the implicit sources. core_sources=source_models_lib.build_source_profiles( - source_models=self.stepper.source_models, dynamic_config_slice=dynamic_config_slice_provider(new_t), geo=geo, core_profiles=input_state.core_profiles, # no state evolution. + source_models=self.stepper.source_models, explicit=False, ), ) diff --git a/torax/tests/sim_time_dependence.py b/torax/tests/sim_time_dependence.py index b2aa4dd1..4264ea73 100644 --- a/torax/tests/sim_time_dependence.py +++ b/torax/tests/sim_time_dependence.py @@ -82,17 +82,17 @@ def test_time_dependent_params_update_in_adaptive_dt( config.numerics.t_initial ) input_state = sim_lib.get_initial_state( - dynamic_config_slice=initial_dynamic_config_slice, static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice=initial_dynamic_config_slice, geo=geo, time_step_calculator=time_calculator, source_models=source_models, ) output_state = sim_step_fn( - input_state=input_state, - geo=geo, - dynamic_config_slice_provider=dynamic_config_slice_provider, static_config_slice=config_slice.build_static_config_slice(config), + dynamic_config_slice_provider=dynamic_config_slice_provider, + geo=geo, + input_state=input_state, explicit_source_profiles=source_models_lib.build_source_profiles( source_models=source_models, dynamic_config_slice=initial_dynamic_config_slice, @@ -140,13 +140,13 @@ def __init__( def __call__( self, - core_profiles_t: state.CoreProfiles, - core_profiles_t_plus_dt: state.CoreProfiles, - geo: geometry.Geometry, + dt: jax.Array, + static_config_slice: config_slice.StaticConfigSlice, dynamic_config_slice_t: config_slice.DynamicConfigSlice, dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, - static_config_slice: config_slice.StaticConfigSlice, - dt: jax.Array, + geo: geometry.Geometry, + core_profiles_t: state.CoreProfiles, + core_profiles_t_plus_dt: state.CoreProfiles, explicit_source_profiles: source_profiles.SourceProfiles, ) -> tuple[ state.CoreProfiles, @@ -162,9 +162,9 @@ def __call__( ) # Use Qei as a hacky way to extract what the combined value was. core_sources = source_models_lib.build_all_zero_profiles( - source_models=self.source_models, dynamic_config_slice=dynamic_config_slice_t, geo=geo, + source_models=self.source_models, ) core_sources = dataclasses.replace( core_sources, diff --git a/torax/tests/state.py b/torax/tests/state.py index 87c19143..b2d76ad4 100644 --- a/torax/tests/state.py +++ b/torax/tests/state.py @@ -45,8 +45,8 @@ def make_hist(config, geo): def scan_f(counter: jax.Array, _) -> tuple[jax.Array, state.CoreProfiles]: core_profiles = core_profile_setters.initial_core_profiles( - config_slice.build_dynamic_config_slice(config), config_slice.build_static_config_slice(config), + config_slice.build_dynamic_config_slice(config), geo, ) # Make one variable in the history track the value of the counter @@ -83,8 +83,8 @@ def test_sanity_check( """Make sure State.sanity_check can be called.""" references = references_getter() basic_core_profiles = core_profile_setters.initial_core_profiles( - config_slice.build_dynamic_config_slice(references.config), config_slice.build_static_config_slice(references.config), + config_slice.build_dynamic_config_slice(references.config), references.geo, ) basic_core_profiles.sanity_check() @@ -151,8 +151,8 @@ def test_initial_boundary_condition_from_time_dependent_params(self): ), ) core_profiles = core_profile_setters.initial_core_profiles( - config_slice.build_dynamic_config_slice(config), config_slice.build_static_config_slice(config), + config_slice.build_dynamic_config_slice(config), geometry.build_circular_geometry(config), ) np.testing.assert_allclose( @@ -208,23 +208,23 @@ def test_initial_psi_from_j( ) geo = geo_builder(config1) core_profiles1 = core_profile_setters.initial_core_profiles( - config_slice.build_dynamic_config_slice(config1), config_slice.build_static_config_slice(config1), + config_slice.build_dynamic_config_slice(config1), geo=geo, ) core_profiles2 = core_profile_setters.initial_core_profiles( - config_slice.build_dynamic_config_slice(config2), config_slice.build_static_config_slice(config2), + config_slice.build_dynamic_config_slice(config2), geo=geo, ) core_profiles3 = core_profile_setters.initial_core_profiles( - config_slice.build_dynamic_config_slice(config3), config_slice.build_static_config_slice(config3), + config_slice.build_dynamic_config_slice(config3), geo=geo, ) core_profiles3_helper = core_profile_setters.initial_core_profiles( - config_slice.build_dynamic_config_slice(config3_helper), config_slice.build_static_config_slice(config3_helper), + config_slice.build_dynamic_config_slice(config3_helper), geo=geo, ) @@ -300,13 +300,13 @@ def test_initial_psi_from_geo_noop_circular(self): initial_psi_from_j=True, ) core_profiles1 = core_profile_setters.initial_core_profiles( - config_slice.build_dynamic_config_slice(config1), config_slice.build_static_config_slice(config1), + config_slice.build_dynamic_config_slice(config1), geometry.build_circular_geometry(config1), ) core_profiles2 = core_profile_setters.initial_core_profiles( - config_slice.build_dynamic_config_slice(config2), config_slice.build_static_config_slice(config2), + config_slice.build_dynamic_config_slice(config2), geometry.build_circular_geometry(config2), ) np.testing.assert_allclose( diff --git a/torax/tests/test_lib/explicit_stepper.py b/torax/tests/test_lib/explicit_stepper.py index 79d57228..49d2acf4 100644 --- a/torax/tests/test_lib/explicit_stepper.py +++ b/torax/tests/test_lib/explicit_stepper.py @@ -49,13 +49,13 @@ class ExplicitStepper(stepper_lib.Stepper): def __call__( self, - core_profiles_t: state.CoreProfiles, - core_profiles_t_plus_dt: state.CoreProfiles, - geo: geometry.Geometry, + dt: jax.Array, + static_config_slice: config_slice.StaticConfigSlice, dynamic_config_slice_t: config_slice.DynamicConfigSlice, dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice, - static_config_slice: config_slice.StaticConfigSlice, - dt: jax.Array, + geo: geometry.Geometry, + core_profiles_t: state.CoreProfiles, + core_profiles_t_plus_dt: state.CoreProfiles, explicit_source_profiles: source_profiles.SourceProfiles, ) -> tuple[ state.CoreProfiles, @@ -101,7 +101,9 @@ def __call__( # Source term c += source_models.sum_sources_temp_ion( - self.source_models, explicit_source_profiles, geo + geo, + explicit_source_profiles, + self.source_models, ) temp_ion_new = ( @@ -123,8 +125,8 @@ def __call__( q_face, _ = physics.calc_q_from_jtot_psi( geo=geo, - jtot_face=core_profiles_t.currents.jtot, psi=core_profiles_t.psi, + jtot_face=core_profiles_t.currents.jtot, q_correction_factor=dynamic_config_slice_t.numerics.q_correction_factor, ) s_face = physics.calc_s_from_psi(geo, core_profiles_t.psi) @@ -142,9 +144,9 @@ def __call__( s_face=s_face, ), source_models.build_all_zero_profiles( - source_models=self.source_models, dynamic_config_slice=dynamic_config_slice_t, geo=geo, + source_models=self.source_models, ), state.CoreTransport.zeros(geo), error, diff --git a/torax/transport_model/qlknn_wrapper.py b/torax/transport_model/qlknn_wrapper.py index 9b77780c..38b560af 100644 --- a/torax/transport_model/qlknn_wrapper.py +++ b/torax/transport_model/qlknn_wrapper.py @@ -309,10 +309,10 @@ def _combined( # always stable at r=0 due to the zero gradient boundary conditions. q, _ = physics.calc_q_from_jtot_psi( - geo, - core_profiles.currents.jtot_face, - core_profiles.psi, - runtime_config_inputs.q_correction_factor, + geo=geo, + psi=core_profiles.psi, + jtot_face=core_profiles.currents.jtot_face, + q_correction_factor=runtime_config_inputs.q_correction_factor, ) smag = physics.calc_s_from_psi( geo, diff --git a/torax/transport_model/tests/qlknn_wrapper.py b/torax/transport_model/tests/qlknn_wrapper.py index c03439f1..4036ea89 100644 --- a/torax/transport_model/tests/qlknn_wrapper.py +++ b/torax/transport_model/tests/qlknn_wrapper.py @@ -39,7 +39,7 @@ def test_qlknn_wrapper_cache_works(self): dynamic_config_slice = config_slice.build_dynamic_config_slice(config) static_config_slice = config_slice.build_static_config_slice(config) core_profiles = core_profile_setters.initial_core_profiles( - dynamic_config_slice, static_config_slice, geo + static_config_slice, dynamic_config_slice, geo ) qlknn_jitted(dynamic_config_slice, geo, core_profiles) # The call should be cached. If there was an error, the cache size would be