Skip to content

Commit

Permalink
Add source_models to model_func interface.
Browse files Browse the repository at this point in the history
This allows cleaning up OhmicHeatSource, whose `mode_func` technically
doesn't fit the interface due to the `self` argument.
OhmicHeatSource had to use __post_init__ to install model_func,
which won't be allowed anymore once the Sources are immutable.

PiperOrigin-RevId: 653037490
  • Loading branch information
goodfeli authored and Torax team committed Jul 17, 2024
1 parent bb08d5a commit 1391c9b
Show file tree
Hide file tree
Showing 12 changed files with 99 additions and 60 deletions.
2 changes: 2 additions & 0 deletions torax/sources/bremsstrahlung_heat_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from torax.config import runtime_params_slice
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source
from torax.sources import source_models


@chex.dataclass(frozen=True)
Expand Down Expand Up @@ -113,6 +114,7 @@ def bremsstrahlung_model_func(
dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
unused_model_func: source_models.SourceModels | None,
) -> jax.Array:
"""Model function for the Bremsstrahlung heat sink."""
assert isinstance(dynamic_source_runtime_params, DynamicRuntimeParams)
Expand Down
4 changes: 4 additions & 0 deletions torax/sources/electron_density_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from torax.sources import formulas
from torax.sources import runtime_params as runtime_params_lib
from torax.sources import source
from torax.sources import source_models


# pylint: disable=invalid-name
Expand Down Expand Up @@ -64,6 +65,7 @@ def _calc_puff_source(
dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams,
geo: geometry.Geometry,
unused_state: state.CoreProfiles | None = None,
unused_source_models: source_models.SourceModels | None = None,
) -> jax.Array:
"""Calculates external source term for n from puffs."""
assert isinstance(dynamic_source_runtime_params, DynamicGasPuffRuntimeParams)
Expand Down Expand Up @@ -122,6 +124,7 @@ def _calc_nbi_source(
dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams,
geo: geometry.Geometry,
unused_state: state.CoreProfiles | None = None,
unused_source_models: source_models.SourceModels | None = None,
) -> jax.Array:
"""Calculates external source term for n from SBI."""
assert isinstance(
Expand Down Expand Up @@ -184,6 +187,7 @@ def _calc_pellet_source(
dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams,
geo: geometry.Geometry,
unused_state: state.CoreProfiles | None = None,
unused_source_models: source_models.SourceModels | None = None,
) -> jax.Array:
"""Calculates external source term for n from pellets."""
assert isinstance(dynamic_source_runtime_params, DynamicPelletRuntimeParams)
Expand Down
19 changes: 15 additions & 4 deletions torax/sources/external_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import dataclasses
from typing import Optional

import chex
import jax
Expand Down Expand Up @@ -87,11 +88,13 @@ def __post_init__(self):
_trapz = integrate.trapezoid


def _calculate_jext_face(
# pytype bug: does not treat 'source_models.SourceModels' as a forward reference
def _calculate_jext_face( # pytype: disable=name-error
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams,
geo: geometry.Geometry,
unused_state: state.CoreProfiles | None = None,
unused_source_models: Optional['source_models.SourceModels'] = None,
) -> jax.Array:
"""Calculates the external current density profiles.
Expand Down Expand Up @@ -123,11 +126,13 @@ def _calculate_jext_face(
return jext_face


def _calculate_jext_hires(
# pytype bug: does not treat 'source_models.SourceModels' as a forward reference
def _calculate_jext_hires( # pytype: disable=name-error
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams,
geo: geometry.Geometry,
unused_state: state.CoreProfiles | None = None,
unused_source_models: Optional['source_models.SourceModels'] = None,
) -> jax.Array:
"""Calculates the external current density profile along the hires grid.
Expand Down Expand Up @@ -212,10 +217,13 @@ def get_value(
core_profiles=core_profiles,
# There is no model implementation.
model_func=(
lambda _0, _1, _2, _3: source.ProfileType.FACE.get_zero_profile(geo)
lambda _0, _1, _2, _3, _4: source.ProfileType.FACE.get_zero_profile(
geo
)
),
formula=self.formula,
output_shape=source.ProfileType.FACE.get_profile_shape(geo),
source_models=getattr(self, 'source_models', None),
)
return profile, geometry.face_to_cell(profile)

Expand All @@ -234,9 +242,12 @@ def jext_hires(
geo=geo,
core_profiles=None,
# There is no model for this source.
model_func=(lambda _0, _1, _2, _3: jnp.zeros_like(geo.r_hires_norm)),
model_func=(
lambda _0, _1, _2, _3, _4: jnp.zeros_like(geo.r_hires_norm)
),
formula=self.hires_formula,
output_shape=geo.r_hires_norm.shape,
source_models=getattr(self, 'source_models', None),
)

def get_source_profile_for_affected_core_profile(
Expand Down
7 changes: 5 additions & 2 deletions torax/sources/formulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Prescribed formulas for computing source profiles."""

import dataclasses
from typing import Optional
import jax
from jax import numpy as jnp
from torax import geometry
Expand Down Expand Up @@ -116,12 +117,13 @@ def gaussian_profile(
class Exponential:
"""Callable class providing an exponential profile."""

def __call__(
def __call__( # pytype: disable=name-error
self,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_source_runtime_params: runtime_params.DynamicRuntimeParams,
geo: geometry.Geometry,
unused_state: state.CoreProfiles | None,
unused_source_models: Optional['source_models.SourceModels'] = None,
) -> jax.Array:
exp_config = dynamic_source_runtime_params.formula
assert isinstance(exp_config, formula_config.DynamicExponential)
Expand All @@ -138,12 +140,13 @@ def __call__(
class Gaussian:
"""Callable class providing a gaussian profile."""

def __call__(
def __call__( # pytype: disable=name-error
self,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_source_runtime_params: runtime_params.DynamicRuntimeParams,
geo: geometry.Geometry,
unused_state: state.CoreProfiles | None,
unused_source_models: Optional['source_models.SourceModels'] = None,
) -> jax.Array:
gaussian_config = dynamic_source_runtime_params.formula
assert isinstance(gaussian_config, formula_config.DynamicGaussian)
Expand Down
6 changes: 5 additions & 1 deletion torax/sources/fusion_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import dataclasses
from typing import Optional

import jax
from jax import numpy as jnp
Expand Down Expand Up @@ -120,12 +121,15 @@ def calc_fusion(
return Ptot, Pfus_i, Pfus_e


def fusion_heat_model_func(
# pytype bug: does not treat 'source_models.SourceModels' as forward reference
def fusion_heat_model_func( # pytype: disable=name-error
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
unused_source_models: Optional['source_models.SourceModels'],
) -> jax.Array:
"""Model function for fusion heating."""
del dynamic_source_runtime_params # Unused.
# pylint: disable=invalid-name
_, Pfus_i, Pfus_e = calc_fusion(
Expand Down
4 changes: 3 additions & 1 deletion torax/sources/generic_ion_el_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import dataclasses
from typing import Optional

import chex
import jax
Expand Down Expand Up @@ -101,11 +102,12 @@ def calc_generic_heat_source(
return source_ion, source_el


def _default_formula(
def _default_formula( # pytype: disable=name-error
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
unused_source_models: Optional['source_models.SourceModels'],
) -> jax.Array:
"""Returns the default formula-based ion/electron heat source profile."""
del dynamic_runtime_params_slice, core_profiles # Unused.
Expand Down
33 changes: 23 additions & 10 deletions torax/sources/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
import enum
import types
import typing
from typing import Any, Callable, Protocol
from typing import Any, Callable, Optional, Protocol

# We use Optional here because | doesn't work with string name types.
# We use string name 'source_models.SourceModels' in this file to avoid
# circular imports.

import chex
import jax
Expand All @@ -40,12 +44,14 @@


# Sources implement these functions to be able to provide source profiles.
SourceProfileFunction = Callable[
# pytype bug: 'source_models.SourceModels' not treated as forward reference
SourceProfileFunction = Callable[ # pytype: disable=name-error
[ # Arguments
runtime_params_slice.DynamicRuntimeParamsSlice, # General config params
runtime_params_lib.DynamicRuntimeParams, # Source-specific params.
geometry.Geometry,
state.CoreProfiles | None,
Optional['source_models.SourceModels'],
],
# Returns a JAX array, tuple of arrays, or mapping of arrays.
chex.ArrayTree,
Expand Down Expand Up @@ -207,12 +213,12 @@ def get_value(
self.check_mode(dynamic_source_runtime_params.mode)
output_shape = self.output_shape_getter(geo)
model_func = (
(lambda _0, _1, _2, _3: jnp.zeros(output_shape))
(lambda _0, _1, _2, _3, _4: jnp.zeros(output_shape))
if self.model_func is None
else self.model_func
)
formula = (
(lambda _0, _1, _2, _3: jnp.zeros(output_shape))
(lambda _0, _1, _2, _3, _4: jnp.zeros(output_shape))
if self.formula is None
else self.formula
)
Expand All @@ -224,6 +230,7 @@ def get_value(
model_func=model_func,
formula=formula,
output_shape=output_shape,
source_models=getattr(self, 'source_models', None),
)

def get_source_profile_for_affected_core_profile(
Expand Down Expand Up @@ -287,7 +294,7 @@ class SingleProfileSource(Source):
.. code-block:: python
# Define an electron-density source with a Gaussian profile.
my_custom_source = source.SingleProfileSource(
my_custom_source_builder = source.SingleProfileSourceBuilder(
supported_modes=(
runtime_params_lib.Mode.ZERO,
runtime_params_lib.Mode.FORMULA_BASED,
Expand All @@ -297,17 +304,17 @@ class SingleProfileSource(Source):
)
# Define its runtime parameters (this could be done in the constructor as
# well).
my_custom_source.runtime_params = runtime_params_lib.RuntimeParams(
my_custom_source_builder.runtime_params = runtime_params_lib.RuntimeParams(
mode=runtime_params_lib.Mode.FORMULA_BASED,
formula=formula_config.Gaussian(
total=1.0,
c1=2.0,
c2=3.0,
),
)
all_torax_sources = source_models_lib.SourceModels(
sources={
'my_custom_source': my_custom_source,
all_torax_sources_builder = source_models_lib.SourceModelsBuilder(
sources_builder={
'my_custom_source': my_custom_source_builder,
}
)
Expand Down Expand Up @@ -345,6 +352,7 @@ def _my_foo_model(
dynamic_source_runtime_params,
geo,
core_profiles,
source_models,
) -> jax.Array:
assert isinstance(dynamic_source_runtime_params, DynamicFooRuntimeParams)
# implement your foo model.
Expand Down Expand Up @@ -435,14 +443,16 @@ def get_zero_profile(self, geo: geometry.Geometry) -> jax.Array:
return jnp.zeros(self.get_profile_shape(geo))


def get_source_profiles(
# pytype bug: 'source_models.SourceModels' not treated as a forward ref
def get_source_profiles( # pytype: disable=name-error
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
dynamic_source_runtime_params: runtime_params_lib.DynamicRuntimeParams,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles | None,
model_func: SourceProfileFunction,
formula: SourceProfileFunction,
output_shape: tuple[int, ...],
source_models: Optional['source_models.SourceModels'],
) -> jax.Array:
"""Returns source profiles requested by the runtime_params_lib.
Expand All @@ -460,6 +470,7 @@ def get_source_profiles(
model_func: Model function.
formula: Formula implementation.
output_shape: Expected shape of the outut array.
source_models: The SourceModels if the Source `links_back`
Returns:
Output array of a profile or concatenated/stacked profiles.
Expand All @@ -474,6 +485,7 @@ def get_source_profiles(
dynamic_source_runtime_params,
geo,
core_profiles,
source_models,
),
zeros,
)
Expand All @@ -484,6 +496,7 @@ def get_source_profiles(
dynamic_source_runtime_params,
geo,
core_profiles,
source_models,
),
zeros,
)
Expand Down
Loading

0 comments on commit 1391c9b

Please sign in to comment.