Skip to content

Commit

Permalink
Enable time-dependent non-evolved (prescribed) core_profiles.
Browse files Browse the repository at this point in the history
1. Also took the opportunity to combine various utility functions for core_profile boundary conditions, dataclass updates, initial and prescribed evolving condition updates, all into a single renamed update_core_profiles module.
2. Updated the sim tests to work for both .nc and legacy .h5 reference files
3. Included a new config variable which can disable the prescribed evolution. Can be useful e.g. if a user wants to initialize density scaled to a Greenwald fraction, and keep density fixed even if Ip is evolving.

PiperOrigin-RevId: 625751237
  • Loading branch information
jcitrin authored and Torax team committed Apr 17, 2024
1 parent db36c2c commit 41ceb3b
Show file tree
Hide file tree
Showing 30 changed files with 427 additions and 290 deletions.
105 changes: 0 additions & 105 deletions torax/boundary_conditions.py

This file was deleted.

5 changes: 5 additions & 0 deletions torax/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,11 @@ class Numerics:
current_eq: bool = False
# Solve the density equation (n evolves over time)
dens_eq: bool = False
# Enable time-dependent prescribed profiles.
# This option is provided to allow initialization of density profiles scaled
# to a Greenwald fraction, and freeze this density even if the current is time
# evolving. Otherwise the density will evolve to always maintain that GW frac.
enable_prescribed_profile_evolution: bool = True

# q-profile correction factor. Used only in ad-hoc circular geometry model
q_correction_factor: float = 1.38
Expand Down
6 changes: 6 additions & 0 deletions torax/config_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,12 @@ class DynamicNumerics:
# location if n != neped
largeValue_n: float

# Enable time-dependent prescribed profiles.
# This option is provided to allow initialization of density profiles scaled
# to a Greenwald fraction, and freeze this density even if the current is time
# evolving. Otherwise the density will evolve to always maintain that GW frac.
enable_prescribed_profile_evolution: bool


@chex.dataclass(frozen=True)
class DynamicExponentialFormulaConfigSlice:
Expand Down
4 changes: 2 additions & 2 deletions torax/fvm/residual_and_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from torax import geometry
from torax import jax_utils
from torax import state
from torax import update_state
from torax import update_core_profiles
from torax.fvm import block_1d_coeffs
from torax.fvm import cell_variable
from torax.fvm import discrete_system
Expand Down Expand Up @@ -252,7 +252,7 @@ def theta_method_block_residual(
x_new_guess = fvm_conversions.vec_to_cell_variable_tuple(
x_new_guess_vec, core_profiles_t_plus_dt, evolving_names
)
core_profiles_t_plus_dt = update_state.update_core_profiles(
core_profiles_t_plus_dt = update_core_profiles.update_evolving_core_profiles(
core_profiles_t_plus_dt,
x_new_guess,
evolving_names,
Expand Down
10 changes: 5 additions & 5 deletions torax/fvm/tests/fvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from torax import config_slice
from torax import fvm
from torax import geometry
from torax import initial_states
from torax import update_core_profiles
from torax.fvm import implicit_solve_block
from torax.fvm import residual_and_loss
from torax.sources import source_config
Expand Down Expand Up @@ -380,7 +380,7 @@ def test_nonlinear_solve_block_loss_minimum(
dynamic_config_slice = config_slice.build_dynamic_config_slice(config)
static_config_slice = config_slice.build_static_config_slice(config)
source_models = source_models_lib.SourceModels()
core_profiles = initial_states.initial_core_profiles(
core_profiles = update_core_profiles.initial_core_profiles(
dynamic_config_slice, static_config_slice, geo, source_models)
evolving_names = tuple(['temp_ion'])
explicit_source_profiles = source_models_lib.build_source_profiles(
Expand Down Expand Up @@ -495,7 +495,7 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self):
config,
)
source_models = source_models_lib.SourceModels()
initial_core_profiles = initial_states.initial_core_profiles(
initial_core_profiles = update_core_profiles.initial_core_profiles(
dynamic_config_slice, static_config_slice, geo, source_models
)
explicit_source_profiles = source_models_lib.build_source_profiles(
Expand Down Expand Up @@ -619,7 +619,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
config,
)
source_models = source_models_lib.SourceModels()
initial_core_profiles = initial_states.initial_core_profiles(
initial_core_profiles = update_core_profiles.initial_core_profiles(
dynamic_config_slice, static_config_slice_theta0, geo, source_models
)
explicit_source_profiles = source_models_lib.build_source_profiles(
Expand Down Expand Up @@ -652,7 +652,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
right_face_grad_constraint=None,
right_face_constraint=initial_right_boundary,
)
core_profiles_t_plus_dt = initial_states.initial_core_profiles(
core_profiles_t_plus_dt = update_core_profiles.initial_core_profiles(
dynamic_config_slice, static_config_slice_theta0, geo
)
core_profiles_t_plus_dt = dataclasses.replace(
Expand Down
51 changes: 38 additions & 13 deletions torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,15 @@
from absl import logging
import jax
import jax.numpy as jnp
from torax import boundary_conditions
from torax import calc_coeffs
from torax import config as config_lib
from torax import config_slice
from torax import fvm
from torax import geometry
from torax import initial_states
from torax import jax_utils
from torax import physics
from torax import state
from torax import update_core_profiles
from torax.sources import source_models as source_models_lib
from torax.sources import source_profiles as source_profiles_lib
from torax.spectators import spectator as spectator_lib
Expand Down Expand Up @@ -77,6 +76,7 @@ class CoeffsCallback:
Attributes:
core_profiles_t: The core plasma profiles at the start of the time step.
core_profiles_t_plus_dt: Core plasma profiles at the end of the time step.
evolving_names: The names of the evolving variables.
geo: See the docstring for `stepper.Stepper`.
static_config_slice: See the docstring for `stepper.Stepper`.
Expand All @@ -88,6 +88,7 @@ class CoeffsCallback:
def __init__(
self,
core_profiles_t: state.CoreProfiles,
core_profiles_t_plus_dt: state.CoreProfiles,
evolving_names: tuple[str, ...],
geo: geometry.Geometry,
static_config_slice: config_slice.StaticConfigSlice,
Expand All @@ -96,6 +97,7 @@ def __init__(
source_models: source_models_lib.SourceModels,
):
self.core_profiles_t = core_profiles_t
self.core_profiles_t_plus_dt = core_profiles_t_plus_dt
self.evolving_names = evolving_names
self.geo = geo
self.static_config_slice = static_config_slice
Expand All @@ -113,10 +115,14 @@ def __call__(
explicit_call: bool = False,
):
replace = {k: v for k, v in zip(self.evolving_names, x)}
# TODO( b/326579003) revisit due to prescribed profiles
core_profiles = config_lib.recursive_replace(
self.core_profiles_t, **replace
)
if explicit_call:
core_profiles = config_lib.recursive_replace(
self.core_profiles_t, **replace
)
else:
core_profiles = config_lib.recursive_replace(
self.core_profiles_t_plus_dt, **replace
)
# update ion density in core_profiles if ne is being evolved.
# Necessary for consistency in iterative nonlinear solutions
if 'ne' in self.evolving_names:
Expand Down Expand Up @@ -300,7 +306,10 @@ def __call__(
# PDE system.
# TODO( b/326579003)
core_profiles_t_plus_dt = provide_core_profiles_t_plus_dt(
core_profiles_t, dynamic_config_slice_t_plus_dt, geo
core_profiles_t=core_profiles_t,
dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt,
static_config_slice=static_config_slice,
geo=geo,
)

stepper_iterations = 0
Expand Down Expand Up @@ -361,7 +370,10 @@ def body_fun(
input_state.t + dt,
)
core_profiles_t_plus_dt = provide_core_profiles_t_plus_dt(
core_profiles_t, dynamic_config_slice_t_plus_dt, geo
core_profiles_t=core_profiles_t,
dynamic_config_slice_t_plus_dt=dynamic_config_slice_t_plus_dt,
static_config_slice=static_config_slice,
geo=geo,
)
core_profiles, core_sources, core_transport, stepper_error_state = (
self._stepper_fn(
Expand Down Expand Up @@ -426,7 +438,7 @@ def get_initial_state(
source_models: source_models_lib.SourceModels,
) -> state.ToraxSimState:
"""Returns the initial state to be used by run_simulation()."""
initial_core_profiles = initial_states.initial_core_profiles(
initial_core_profiles = update_core_profiles.initial_core_profiles(
dynamic_config_slice, static_config_slice, geo, source_models
)
return state.ToraxSimState(
Expand Down Expand Up @@ -1083,26 +1095,39 @@ def update_psidot(
def provide_core_profiles_t_plus_dt(
core_profiles_t: state.CoreProfiles,
dynamic_config_slice_t_plus_dt: config_slice.DynamicConfigSlice,
static_config_slice: config_slice.StaticConfigSlice,
geo: geometry.Geometry,
) -> state.CoreProfiles:
"""Provides state at t_plus_dt with new boundary conditions and prescribed profiles."""
updated_boundary_conditions = boundary_conditions.compute_boundary_conditions(
dynamic_config_slice_t_plus_dt,
geo,
updated_boundary_conditions = (
update_core_profiles.compute_boundary_conditions(
dynamic_config_slice_t_plus_dt,
geo,
)
)
updated_values = update_core_profiles.update_prescribed_core_profiles(
core_profiles=core_profiles_t,
dynamic_config_slice=dynamic_config_slice_t_plus_dt,
static_config_slice=static_config_slice,
geo=geo,
)
temp_ion = dataclasses.replace(
core_profiles_t.temp_ion,
value=updated_values['temp_ion'],
**updated_boundary_conditions['temp_ion'],
)
temp_el = dataclasses.replace(
core_profiles_t.temp_el,
value=updated_values['temp_el'],
**updated_boundary_conditions['temp_el'],
)
psi = dataclasses.replace(
core_profiles_t.psi, **updated_boundary_conditions['psi']
)
ne = dataclasses.replace(
core_profiles_t.ne, **updated_boundary_conditions['ne']
core_profiles_t.ne,
value=updated_values['ne'],
**updated_boundary_conditions['ne'],
)
ni = dataclasses.replace(
core_profiles_t.ni,
Expand Down
4 changes: 2 additions & 2 deletions torax/sources/tests/bootstrap_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torax import config as config_lib
from torax import config_slice
from torax import geometry
from torax import initial_states
from torax import update_core_profiles
from torax.sources import bootstrap_current_source
from torax.sources import source as source_lib
from torax.sources import source_config
Expand All @@ -47,7 +47,7 @@ def test_source_value(self):
source = bootstrap_current_source.BootstrapCurrentSource()
config = config_lib.Config()
geo = geometry.build_circular_geometry(config)
core_profiles = initial_states.initial_core_profiles(
core_profiles = update_core_profiles.initial_core_profiles(
dynamic_config_slice=config_slice.build_dynamic_config_slice(config),
static_config_slice=config_slice.build_static_config_slice(config),
geo=geo,
Expand Down
4 changes: 2 additions & 2 deletions torax/sources/tests/fusion_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np
from torax import config_slice
from torax import constants
from torax import initial_states
from torax import update_core_profiles
from torax.sources import fusion_heat_source
from torax.sources import source
from torax.sources import source_config
Expand Down Expand Up @@ -61,7 +61,7 @@ def test_calc_fusion(
geo = references.geo
nref = config.nref

core_profiles = initial_states.initial_core_profiles(
core_profiles = update_core_profiles.initial_core_profiles(
dynamic_config_slice=config_slice.build_dynamic_config_slice(config),
static_config_slice=config_slice.build_static_config_slice(config),
geo=geo,
Expand Down
6 changes: 3 additions & 3 deletions torax/sources/tests/qei_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from torax import config as config_lib
from torax import config_slice
from torax import geometry
from torax import initial_states
from torax import update_core_profiles
from torax.sources import qei_source
from torax.sources import source as source_lib
from torax.sources import source_config
Expand Down Expand Up @@ -48,7 +48,7 @@ def test_source_value(self):
source = qei_source.QeiSource()
config = config_lib.Config()
geo = geometry.build_circular_geometry(config)
core_profiles = initial_states.initial_core_profiles(
core_profiles = update_core_profiles.initial_core_profiles(
dynamic_config_slice=config_slice.build_dynamic_config_slice(config),
static_config_slice=config_slice.build_static_config_slice(config),
geo=geo,
Expand All @@ -70,7 +70,7 @@ def test_invalid_source_types_raise_errors(self):
source = qei_source.QeiSource()
config = config_lib.Config()
geo = geometry.build_circular_geometry(config)
core_profiles = initial_states.initial_core_profiles(
core_profiles = update_core_profiles.initial_core_profiles(
dynamic_config_slice=config_slice.build_dynamic_config_slice(config),
static_config_slice=config_slice.build_static_config_slice(config),
geo=geo,
Expand Down
Loading

0 comments on commit 41ceb3b

Please sign in to comment.