From 1391c9b6c28ea5dac28bc678361b31a6ed5d2e4b Mon Sep 17 00:00:00 2001 From: Ian Goodfellow Date: Tue, 16 Jul 2024 17:31:29 -0700 Subject: [PATCH] Add `source_models` to model_func interface. This allows cleaning up OhmicHeatSource, whose `mode_func` technically doesn't fit the interface due to the `self` argument. OhmicHeatSource had to use __post_init__ to install model_func, which won't be allowed anymore once the Sources are immutable. PiperOrigin-RevId: 653037490 --- torax/sources/bremsstrahlung_heat_sink.py | 2 + torax/sources/electron_density_sources.py | 4 ++ torax/sources/external_current_source.py | 19 ++++-- torax/sources/formulas.py | 7 ++- torax/sources/fusion_heat_source.py | 6 +- torax/sources/generic_ion_el_heat_source.py | 4 +- torax/sources/source.py | 33 +++++++--- torax/sources/source_models.py | 69 ++++++++++----------- torax/sources/tests/source.py | 12 ++-- torax/sources/tests/source_models.py | 1 + torax/tests/sim_custom_sources.py | 1 + torax/tests/sim_output_source_profiles.py | 1 + 12 files changed, 99 insertions(+), 60 deletions(-) diff --git a/torax/sources/bremsstrahlung_heat_sink.py b/torax/sources/bremsstrahlung_heat_sink.py index 8ec47bdb..0d01aa51 100644 --- a/torax/sources/bremsstrahlung_heat_sink.py +++ b/torax/sources/bremsstrahlung_heat_sink.py @@ -29,6 +29,7 @@ from torax.config import runtime_params_slice from torax.sources import runtime_params as runtime_params_lib from torax.sources import source +from torax.sources import source_models @chex.dataclass(frozen=True) @@ -113,6 +114,7 @@ def bremsstrahlung_model_func( dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, + unused_model_func: source_models.SourceModels | None, ) -> jax.Array: """Model function for the Bremsstrahlung heat sink.""" assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams) diff --git a/torax/sources/electron_density_sources.py b/torax/sources/electron_density_sources.py index 12c44e6d..ae594beb 100644 --- a/torax/sources/electron_density_sources.py +++ b/torax/sources/electron_density_sources.py @@ -27,6 +27,7 @@ from torax.sources import formulas from torax.sources import runtime_params as runtime_params_lib from torax.sources import source +from torax.sources import source_models # pylint: disable=invalid-name @@ -64,6 +65,7 @@ def _calc_puff_source( dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state.CoreProfiles | None = None, + unused_source_models: source_models.SourceModels | None = None, ) -> jax.Array: """Calculates external source term for n from puffs.""" assert isinstance(dynamic_source_runtime_params, DynamicGasPuffRuntimeParams) @@ -122,6 +124,7 @@ def _calc_nbi_source( dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state.CoreProfiles | None = None, + unused_source_models: source_models.SourceModels | None = None, ) -> jax.Array: """Calculates external source term for n from SBI.""" assert isinstance( @@ -184,6 +187,7 @@ def _calc_pellet_source( dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state.CoreProfiles | None = None, + unused_source_models: source_models.SourceModels | None = None, ) -> jax.Array: """Calculates external source term for n from pellets.""" assert isinstance(dynamic_source_runtime_params, DynamicPelletRuntimeParams) diff --git a/torax/sources/external_current_source.py b/torax/sources/external_current_source.py index 26504db5..4b175c1b 100644 --- a/torax/sources/external_current_source.py +++ b/torax/sources/external_current_source.py @@ -17,6 +17,7 @@ from __future__ import annotations import dataclasses +from typing import Optional import chex import jax @@ -87,11 +88,13 @@ def __post_init__(self): _trapz = integrate.trapezoid -def _calculate_jext_face( +# pytype bug: does not treat 'source_models.SourceModels' as a forward reference +def _calculate_jext_face( # pytype: disable=name-error dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state.CoreProfiles | None = None, + unused_source_models: Optional['source_models.SourceModels'] = None, ) -> jax.Array: """Calculates the external current density profiles. @@ -123,11 +126,13 @@ def _calculate_jext_face( return jext_face -def _calculate_jext_hires( +# pytype bug: does not treat 'source_models.SourceModels' as a forward reference +def _calculate_jext_hires( # pytype: disable=name-error dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state.CoreProfiles | None = None, + unused_source_models: Optional['source_models.SourceModels'] = None, ) -> jax.Array: """Calculates the external current density profile along the hires grid. @@ -212,10 +217,13 @@ def get_value( core_profiles=core_profiles, # There is no model implementation. model_func=( - lambda _0, _1, _2, _3: source.ProfileType.FACE.get_zero_profile(geo) + lambda _0, _1, _2, _3, _4: source.ProfileType.FACE.get_zero_profile( + geo + ) ), formula=self.formula, output_shape=source.ProfileType.FACE.get_profile_shape(geo), + source_models=getattr(self, 'source_models', None), ) return profile, geometry.face_to_cell(profile) @@ -234,9 +242,12 @@ def jext_hires( geo=geo, core_profiles=None, # There is no model for this source. - model_func=(lambda _0, _1, _2, _3: jnp.zeros_like(geo.r_hires_norm)), + model_func=( + lambda _0, _1, _2, _3, _4: jnp.zeros_like(geo.r_hires_norm) + ), formula=self.hires_formula, output_shape=geo.r_hires_norm.shape, + source_models=getattr(self, 'source_models', None), ) def get_source_profile_for_affected_core_profile( diff --git a/torax/sources/formulas.py b/torax/sources/formulas.py index ddddcfb8..69a7cedd 100644 --- a/torax/sources/formulas.py +++ b/torax/sources/formulas.py @@ -15,6 +15,7 @@ """Prescribed formulas for computing source profiles.""" import dataclasses +from typing import Optional import jax from jax import numpy as jnp from torax import geometry @@ -116,12 +117,13 @@ def gaussian_profile( class Exponential: """Callable class providing an exponential profile.""" - def __call__( + def __call__( # pytype: disable=name-error self, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state.CoreProfiles | None, + unused_source_models: Optional['source_models.SourceModels'] = None, ) -> jax.Array: exp_config = dynamic_source_runtime_params.formula assert isinstance(exp_config, formula_config.DynamicExponential) @@ -138,12 +140,13 @@ def __call__( class Gaussian: """Callable class providing a gaussian profile.""" - def __call__( + def __call__( # pytype: disable=name-error self, dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state.CoreProfiles | None, + unused_source_models: Optional['source_models.SourceModels'] = None, ) -> jax.Array: gaussian_config = dynamic_source_runtime_params.formula assert isinstance(gaussian_config, formula_config.DynamicGaussian) diff --git a/torax/sources/fusion_heat_source.py b/torax/sources/fusion_heat_source.py index 40ef550f..fcfa0df9 100644 --- a/torax/sources/fusion_heat_source.py +++ b/torax/sources/fusion_heat_source.py @@ -17,6 +17,7 @@ from __future__ import annotations import dataclasses +from typing import Optional import jax from jax import numpy as jnp @@ -120,12 +121,15 @@ def calc_fusion( return Ptot, Pfus_i, Pfus_e -def fusion_heat_model_func( +# pytype bug: does not treat 'source_models.SourceModels' as forward reference +def fusion_heat_model_func( # pytype: disable=name-error dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, + unused_source_models: Optional['source_models.SourceModels'], ) -> jax.Array: + """Model function for fusion heating.""" del dynamic_source_runtime_params # Unused. # pylint: disable=invalid-name _, Pfus_i, Pfus_e = calc_fusion( diff --git a/torax/sources/generic_ion_el_heat_source.py b/torax/sources/generic_ion_el_heat_source.py index ed5d456a..2c114655 100644 --- a/torax/sources/generic_ion_el_heat_source.py +++ b/torax/sources/generic_ion_el_heat_source.py @@ -17,6 +17,7 @@ from __future__ import annotations import dataclasses +from typing import Optional import chex import jax @@ -101,11 +102,12 @@ def calc_generic_heat_source( return source_ion, source_el -def _default_formula( +def _default_formula( # pytype: disable=name-error dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, core_profiles: state.CoreProfiles, + unused_source_models: Optional['source_models.SourceModels'], ) -> jax.Array: """Returns the default formula-based ion/electron heat source profile.""" del dynamic_runtime_params_slice, core_profiles # Unused. diff --git a/torax/sources/source.py b/torax/sources/source.py index 3d4d70a8..70f64e85 100644 --- a/torax/sources/source.py +++ b/torax/sources/source.py @@ -27,7 +27,11 @@ import enum import types import typing -from typing import Any, Callable, Protocol +from typing import Any, Callable, Optional, Protocol + +# We use Optional here because | doesn't work with string name types. +# We use string name 'source_models.SourceModels' in this file to avoid +# circular imports. import chex import jax @@ -40,12 +44,14 @@ # Sources implement these functions to be able to provide source profiles. -SourceProfileFunction = Callable[ +# pytype bug: 'source_models.SourceModels' not treated as forward reference +SourceProfileFunction = Callable[ # pytype: disable=name-error [ # Arguments runtime_params_slice.DynamicRuntimeParamsSlice, # General config params runtime_params_lib.DynamicRuntimeParams, # Source-specific params. geometry.Geometry, state.CoreProfiles | None, + Optional['source_models.SourceModels'], ], # Returns a JAX array, tuple of arrays, or mapping of arrays. chex.ArrayTree, @@ -207,12 +213,12 @@ def get_value( self.check_mode(dynamic_source_runtime_params.mode) output_shape = self.output_shape_getter(geo) model_func = ( - (lambda _0, _1, _2, _3: jnp.zeros(output_shape)) + (lambda _0, _1, _2, _3, _4: jnp.zeros(output_shape)) if self.model_func is None else self.model_func ) formula = ( - (lambda _0, _1, _2, _3: jnp.zeros(output_shape)) + (lambda _0, _1, _2, _3, _4: jnp.zeros(output_shape)) if self.formula is None else self.formula ) @@ -224,6 +230,7 @@ def get_value( model_func=model_func, formula=formula, output_shape=output_shape, + source_models=getattr(self, 'source_models', None), ) def get_source_profile_for_affected_core_profile( @@ -287,7 +294,7 @@ class SingleProfileSource(Source): .. code-block:: python # Define an electron-density source with a Gaussian profile. - my_custom_source = source.SingleProfileSource( + my_custom_source_builder = source.SingleProfileSourceBuilder( supported_modes=( runtime_params_lib.Mode.ZERO, runtime_params_lib.Mode.FORMULA_BASED, @@ -297,7 +304,7 @@ class SingleProfileSource(Source): ) # Define its runtime parameters (this could be done in the constructor as # well). - my_custom_source.runtime_params = runtime_params_lib.RuntimeParams( + my_custom_source_builder.runtime_params = runtime_params_lib.RuntimeParams( mode=runtime_params_lib.Mode.FORMULA_BASED, formula=formula_config.Gaussian( total=1.0, @@ -305,9 +312,9 @@ class SingleProfileSource(Source): c2=3.0, ), ) - all_torax_sources = source_models_lib.SourceModels( - sources={ - 'my_custom_source': my_custom_source, + all_torax_sources_builder = source_models_lib.SourceModelsBuilder( + sources_builder={ + 'my_custom_source': my_custom_source_builder, } ) @@ -345,6 +352,7 @@ def _my_foo_model( dynamic_source_runtime_params, geo, core_profiles, + source_models, ) -> jax.Array: assert isinstance(dynamic_source_runtime_params, DynamicFooRuntimeParams) # implement your foo model. @@ -435,7 +443,8 @@ def get_zero_profile(self, geo: geometry.Geometry) -> jax.Array: return jnp.zeros(self.get_profile_shape(geo)) -def get_source_profiles( +# pytype bug: 'source_models.SourceModels' not treated as a forward ref +def get_source_profiles( # pytype: disable=name-error dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, @@ -443,6 +452,7 @@ def get_source_profiles( model_func: SourceProfileFunction, formula: SourceProfileFunction, output_shape: tuple[int, ...], + source_models: Optional['source_models.SourceModels'], ) -> jax.Array: """Returns source profiles requested by the runtime_params_lib. @@ -460,6 +470,7 @@ def get_source_profiles( model_func: Model function. formula: Formula implementation. output_shape: Expected shape of the outut array. + source_models: The SourceModels if the Source `links_back` Returns: Output array of a profile or concatenated/stacked profiles. @@ -474,6 +485,7 @@ def get_source_profiles( dynamic_source_runtime_params, geo, core_profiles, + source_models, ), zeros, ) @@ -484,6 +496,7 @@ def get_source_profiles( dynamic_source_runtime_params, geo, core_profiles, + source_models, ), zeros, ) diff --git a/torax/sources/source_models.py b/torax/sources/source_models.py index b9eb044a..b0ab33f8 100644 --- a/torax/sources/source_models.py +++ b/torax/sources/source_models.py @@ -521,6 +521,35 @@ def calc_psidot( return psidot +def ohmic_model_func( + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + source_models: SourceModels | None = None, +) -> jax.Array: + """Returns the Ohmic source for electron heat equation.""" + del dynamic_source_runtime_params + + if source_models is None: + raise TypeError('source_models is a required argument for ohmic_model_func') + + jtot, _ = physics.calc_jtot_from_psi( + geo, + core_profiles.psi, + ) + + psidot = calc_psidot( + dynamic_runtime_params_slice, + geo, + core_profiles, + source_models, + ) + + pohm = jtot * psidot / (2 * jnp.pi * geo.Rmaj) + return pohm + + # OhmicHeatSource is a special case and defined here to avoid circular # dependencies, since it depends on the psi sources @dataclasses.dataclass(kw_only=True) @@ -563,48 +592,16 @@ class OhmicHeatSource(source_lib.SingleProfileSource): ) ) - # The model function is fixed to self._model_func because that is the only + # The model function is fixed to ohmic_model_func because that is the only # supported implementation of this source. # However, since this is a param in the parent dataclass, we need to (a) - # remove the parameter from the init args and (b) set it to the correct - # function in __post_init__(). - # - # We cannot simply define a function `def model_func()` and use that because - # that definition would be overridden in this classes dataclass constructor. - # Also, the `__post_init__()` is required because it allows access to `self`, - # which is required for this model function implementation. + # remove the parameter from the init args and (b) set the default to the + # desired value. model_func: source_lib.SourceProfileFunction | None = dataclasses.field( init=False, - default_factory=lambda: None, + default_factory=lambda: ohmic_model_func, ) - def __post_init__(self): - self.model_func = self._model_func - - def _model_func( - self, - dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, - dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, - geo: geometry.Geometry, - core_profiles: state.CoreProfiles, - ) -> jax.Array: - """Returns the Ohmic source for electron heat equation.""" - del dynamic_source_runtime_params - jtot, _ = physics.calc_jtot_from_psi( - geo, - core_profiles.psi, - ) - - psidot = calc_psidot( - dynamic_runtime_params_slice, - geo, - core_profiles, - self.source_models, - ) - - pohm = jtot * psidot / (2 * jnp.pi * geo.Rmaj) - return pohm - OhmicHeatSourceBuilder = source_lib.make_source_builder( OhmicHeatSource, links_back=True diff --git a/torax/sources/tests/source.py b/torax/sources/tests/source.py index ee94ddd1..37a3a027 100644 --- a/torax/sources/tests/source.py +++ b/torax/sources/tests/source.py @@ -259,7 +259,7 @@ def test_overriding_default_formula(self): expected_output = jnp.ones(output_shape) source_builder = source_lib.SourceBuilder( output_shape_getter=lambda _0: output_shape, - formula=lambda _0, _1, _2, _3: expected_output, + formula=lambda _0, _1, _2, _3, _4: expected_output, affected_core_profiles=( source_lib.AffectedCoreProfile.TEMP_ION, source_lib.AffectedCoreProfile.TEMP_EL, @@ -302,7 +302,7 @@ def test_overriding_model(self): source_builder = source_lib.SourceBuilder( supported_modes=(runtime_params_lib.Mode.MODEL_BASED,), output_shape_getter=lambda _0: output_shape, - model_func=lambda _0, _1, _2, _3: expected_output, + model_func=lambda _0, _1, _2, _3, _4: expected_output, affected_core_profiles=( source_lib.AffectedCoreProfile.TEMP_ION, source_lib.AffectedCoreProfile.TEMP_EL, @@ -345,7 +345,7 @@ def test_retrieving_profile_for_affected_state(self): source = source_lib.Source( supported_modes=(runtime_params_lib.Mode.MODEL_BASED,), output_shape_getter=lambda _0: output_shape, - model_func=lambda _0, _1, _2, _3: profile, + model_func=lambda _0, _1, _2, _3, _4: profile, affected_core_profiles=( source_lib.AffectedCoreProfile.PSI, source_lib.AffectedCoreProfile.NE, @@ -379,7 +379,7 @@ def test_custom_formula(self): geo = geometry.build_circular_geometry(nr=5) expected_output = jnp.ones(5) # 5 matches the geo. source_builder = source_lib.SingleProfileSourceBuilder( - formula=lambda _0, _1, _2, _3: expected_output, + formula=lambda _0, _1, _2, _3, _4: expected_output, affected_core_profiles=(source_lib.AffectedCoreProfile.PSI,), ) source_builder.runtime_params.mode = runtime_params_lib.Mode.FORMULA_BASED @@ -413,7 +413,7 @@ def test_custom_formula(self): def test_multiple_profiles_raises_error(self): """A formula which outputs the wrong shape will raise an error.""" source_builder = source_lib.SingleProfileSourceBuilder( - formula=lambda _0, _1, _2, _3: jnp.ones((2, 5)), + formula=lambda _0, _1, _2, _3, _4: jnp.ones((2, 5)), affected_core_profiles=( source_lib.AffectedCoreProfile.TEMP_ION, source_lib.AffectedCoreProfile.NE, @@ -455,7 +455,7 @@ def test_retrieving_profile_for_affected_state(self): profile = jnp.asarray([1, 2, 3, 4]) # from get_value() source = source_lib.SingleProfileSource( supported_modes=(runtime_params_lib.Mode.MODEL_BASED,), - model_func=lambda _0, _1, _2, _3: profile, + model_func=lambda _0, _1, _2, _3, _4: profile, affected_core_profiles=(source_lib.AffectedCoreProfile.NE,), ) geo = geometry.build_circular_geometry(nr=4) diff --git a/torax/sources/tests/source_models.py b/torax/sources/tests/source_models.py index 954dc371..de946f02 100644 --- a/torax/sources/tests/source_models.py +++ b/torax/sources/tests/source_models.py @@ -134,6 +134,7 @@ def foo_formula( unused_sc, geo: geometry.Geometry, unused_state, + unused_source_models, ): return jnp.stack([ jnp.zeros(source_lib.ProfileType.CELL.get_profile_shape(geo)), diff --git a/torax/tests/sim_custom_sources.py b/torax/tests/sim_custom_sources.py index 40c25e3a..b710d6be 100644 --- a/torax/tests/sim_custom_sources.py +++ b/torax/tests/sim_custom_sources.py @@ -54,6 +54,7 @@ def custom_source_formula( dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams, geo: geometry.Geometry, unused_state: state_lib.CoreProfiles | None, + unused_source_models: ..., ): # Combine the outputs. assert isinstance( diff --git a/torax/tests/sim_output_source_profiles.py b/torax/tests/sim_output_source_profiles.py index 638f243d..207338c4 100644 --- a/torax/tests/sim_output_source_profiles.py +++ b/torax/tests/sim_output_source_profiles.py @@ -131,6 +131,7 @@ def custom_source_formula( source_conf, geo, unused_state, + unused_source_models, ): return jnp.ones_like(geo.r) * source_conf.foo