Skip to content

Commit

Permalink
Change interpolated var 2D to do rho interpolation on init
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653247334
  • Loading branch information
tamaranorman authored and Torax team committed Jul 18, 2024
1 parent 4a86307 commit d3cd672
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 125 deletions.
30 changes: 28 additions & 2 deletions torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,33 @@ def build_chease_geometry(
return geo


def build_chease_geometry_provider(
Ip_from_parameters: bool = True,
geometry_dir: str | None = None,
**kwargs,
) -> geometry_provider.GeometryProvider:
"""Constructs a geometry provider from CHEASE file or series of files."""
if 'geometry_configs' in kwargs:
if not isinstance(kwargs['geometry_configs'], dict):
raise ValueError('geometry_configs must be a dict.')
geometries = {}
for time, config in kwargs['geometry_configs'].items():
geometries[time] = build_chease_geometry(
Ip_from_parameters=Ip_from_parameters,
geometry_dir=geometry_dir,
**config,
)
return geometry_provider.TimeDependentGeometryProvider(
geometry.StandardGeometryProvider.create_provider(geometries))
return geometry_provider.ConstantGeometryProvider(
build_chease_geometry(
Ip_from_parameters=Ip_from_parameters,
geometry_dir=geometry_dir,
**kwargs,
)
)


def build_sim_from_config(
config: dict[str, Any],
) -> sim_lib.Sim:
Expand Down Expand Up @@ -236,8 +263,7 @@ def build_geometry_provider_from_config(
return geometry_provider.ConstantGeometryProvider(
geometry.build_circular_geometry(**kwargs))
elif geometry_type == 'chease':
return geometry_provider.ConstantGeometryProvider(
build_chease_geometry(**kwargs))
return build_chease_geometry_provider(**kwargs)
raise ValueError(f'Unknown geometry type: {geometry_type}')


Expand Down
4 changes: 2 additions & 2 deletions torax/config/config_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,9 @@ def _interpolate_var_2d(
):
# Dealing with a param input so convert it first.
param_or_param_input = interpolated_param.InterpolatedVarTimeRho(
values=param_or_param_input,
values=param_or_param_input, rho=geo.torax_mesh.face_centers
)
return param_or_param_input.get_value(t, geo.torax_mesh.face_centers)
return param_or_param_input.get_value(t)


def get_init_kwargs(
Expand Down
24 changes: 15 additions & 9 deletions torax/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections.abc import Mapping
import dataclasses
import enum
import functools
from typing import Type

import chex
Expand Down Expand Up @@ -46,7 +47,7 @@ class Grid1D:
"""

nx: int
dx: chex.Numeric
dx: float
face_centers: chex.Array
cell_centers: chex.Array

Expand All @@ -59,13 +60,16 @@ def __post_init__(self):
def __eq__(self, other: Grid1D) -> bool:
return (
self.nx == other.nx
and np.array_equal(self.dx, other.dx)
and self.dx == other.dx
and np.array_equal(self.face_centers, other.face_centers)
and np.array_equal(self.cell_centers, other.cell_centers)
)

def __hash__(self) -> int:
return hash((self.nx, self.dx))

@classmethod
def construct(cls, nx: int, dx: chex.Array) -> Grid1D:
def construct(cls, nx: int, dx: float) -> Grid1D:
"""Constructs a Grid1D.
Args:
Expand Down Expand Up @@ -298,7 +302,7 @@ def create_provider(
):
continue
kwargs[attr.name] = interpolated_param.InterpolatedVarSingleAxis(
(times, np.stack([getattr(g, attr.name) for g in geos], axis=-1))
(times, np.stack([getattr(g, attr.name) for g in geos], axis=0))
)
return cls(**kwargs)

Expand All @@ -320,6 +324,7 @@ def _get_geometry_base(self, t: chex.Numeric, geometry_class: Type[Geometry]):
kwargs[attr.name] = getattr(self, attr.name).get_value(t)
return geometry_class(**kwargs) # pytype: disable=wrong-keyword-args

@functools.partial(jax_utils.jit, static_argnums=0)
def get_geometry(self, t: chex.Numeric) -> Geometry:
"""Returns a Geometry instance at the given time."""
return self._get_geometry_base(t, Geometry)
Expand Down Expand Up @@ -365,6 +370,7 @@ class StandardGeometryProvider(GeometryProvider):
delta_upper_face: interpolated_param.InterpolatedVarSingleAxis
delta_lower_face: interpolated_param.InterpolatedVarSingleAxis

@functools.partial(jax_utils.jit, static_argnums=0)
def get_geometry(self, t: chex.Numeric) -> Geometry:
"""Returns a Geometry instance at the given time."""
return self._get_geometry_base(t, StandardGeometry)
Expand Down Expand Up @@ -402,7 +408,7 @@ def build_circular_geometry(
# r_norm coordinate is r/Rmin in circular, and rho_norm in standard
# geometry (CHEASE/EQDSK)
# Define mesh (Slab Uniform 1D with Jacobian = 1)
dr_norm = np.array(1) / nr
dr_norm = 1. / nr
mesh = Grid1D.construct(nx=nr, dx=dr_norm)
rmax = np.asarray(Rmin)
# helper variables for mesh cells and faces
Expand Down Expand Up @@ -537,7 +543,7 @@ def build_circular_geometry(
return CircularAnalyticalGeometry(
# Set the standard geometry params.
geometry_type=GeometryType.CIRCULAR.value,
dr_norm=dr_norm,
dr_norm=np.asarray(dr_norm),
torax_mesh=mesh,
rmax=rmax,
Rmaj=Rmaj,
Expand Down Expand Up @@ -635,7 +641,7 @@ class StandardGeometryIntermediates:
psi: chex.Array
Ip_profile: chex.Array
rho: chex.Array
rhon: chex.Array
rhon: np.ndarray
Rin: chex.Array
Rout: chex.Array
RBphi: chex.Array
Expand Down Expand Up @@ -816,7 +822,7 @@ def build_standard_geometry(

# fill geometry structure
# r_norm coordinate is rho_tor_norm
dr_norm = intermediate.rhon[-1] / intermediate.nr
dr_norm = float(intermediate.rhon[-1]) / intermediate.nr
# normalized grid
mesh = Grid1D.construct(nx=intermediate.nr, dx=dr_norm)
rmax = intermediate.rho[-1] # radius denormalization constant
Expand Down Expand Up @@ -897,7 +903,7 @@ def build_standard_geometry(

return StandardGeometry(
geometry_type=GeometryType.CHEASE.value,
dr_norm=dr_norm,
dr_norm=np.asarray(dr_norm),
torax_mesh=mesh,
rmax=rmax,
Rmaj=intermediate.Rmaj,
Expand Down
9 changes: 4 additions & 5 deletions torax/geometry_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
"""File I/O for loading geometry files."""
import os

import jax
import jax.numpy as jnp
import numpy as np


def initialize_CHEASE_dict( # pylint: disable=invalid-name
file_path: str,
) -> dict[str, jax.Array]:
) -> dict[str, np.ndarray]:
"""Loads the data from a CHEASE file into a dictionary."""
# pyformat: disable
with open(file_path, 'r') as file:
Expand All @@ -40,14 +39,14 @@ def initialize_CHEASE_dict( # pylint: disable=invalid-name

# Convert lists to jax arrays.
return {
var_label: jnp.array(chease_data[var_label]) for var_label in chease_data
var_label: np.asarray(chease_data[var_label]) for var_label in chease_data
}


def load_chease_data(
geometry_dir: str | None,
geometry_file: str,
) -> dict[str, jax.Array]:
) -> dict[str, np.ndarray]:
"""Loads the data from a CHEASE file into a dictionary."""
# The code below does not use os.environ.get() in order to support an internal
# version of the code.
Expand Down
19 changes: 15 additions & 4 deletions torax/geometry_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,7 @@ class ConstantGeometryProvider(GeometryProvider):
def __init__(self, geo: geometry.Geometry):
self._geo = geo

def __call__(
self,
t: chex.Numeric,
) -> geometry.Geometry:
def __call__(self, t: chex.Numeric) -> geometry.Geometry:
# The API includes time as an arg even though it is unused in order
# to match the API of a GeometryProvider.
del t # Ignored.
Expand All @@ -93,3 +90,17 @@ def __call__(
@property
def torax_mesh(self) -> geometry.Grid1D:
return self._geo.torax_mesh


class TimeDependentGeometryProvider(GeometryProvider):
"""Returns a Geometry that changes over time."""

def __init__(self, geometry_provider: geometry.GeometryProvider):
self._geometry_provider = geometry_provider

def __call__(self, t: chex.Numeric) -> geometry.Geometry:
return self._geometry_provider.get_geometry(t)

@property
def torax_mesh(self) -> geometry.Grid1D:
return self._geometry_provider.torax_mesh
60 changes: 19 additions & 41 deletions torax/interpolated_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,20 +274,21 @@ def is_bool_param(self) -> bool:
return self._is_bool_param


class InterpolatedVarTimeRho:
class InterpolatedVarTimeRho(InterpolatedParamBase):
"""Interpolates on a grid (time, rho).
- Given `values` that map from time-values to `InterpolatedVarSingleAxis`s
that tell you how to interpolate along rho for different time values this
class linearly interpolates along time to provide a value at any (time, rho)
pair.
- For time values that are outside the range of `values` the closest defined
`InterpolatedVarSingleAxis` is used.
- NOTE: We assume that rho interpolation is fixed per simulation so take this
at init and take just time at get_value.
"""

def __init__(
self,
values: InterpolatedVarTimeRhoInput,
rho: chex.Numeric,
rho_interpolation_mode: InterpolationMode = (
InterpolationMode.PIECEWISE_LINEAR
),
Expand All @@ -307,46 +308,23 @@ def __init__(
raise ValueError('Indicies in values mapping must be unique.')
if not values:
raise ValueError('Values mapping must not be empty.')
self.times_values = {
v: InterpolatedVarSingleAxis(values[v], rho_interpolation_mode)
for v in values.keys()
}
self.sorted_indices = jnp.array(sorted(values.keys()))

def get_value(
self,
time: chex.Numeric,
rho: chex.Numeric,
) -> chex.Array:
"""Returns the value of this parameter interpolated at the given (time,rho).
This method is not jittable as it is.
Args:
time: The time-coordinate to interpolate at.
rho: The rho-coordinate to interpolate at.
Returns:
The value of the interpolated at the given (time,rho).
"""
# Find the index that is left of value which time is closest to.
left = jnp.searchsorted(self.sorted_indices, time, side='left')

# If time is either smaller or larger, than smallest and largest values
# we know how to interpolate for, use the boundary interpolater.
if left == 0:
return self.times_values[float(self.sorted_indices[0])].get_value(rho)
if left == len(self.sorted_indices):
return self.times_values[float(self.sorted_indices[-1])].get_value(rho)

# Interpolate between the two closest defined interpolaters.
left_time = float(self.sorted_indices[left - 1])
right_time = float(self.sorted_indices[left])
return self.times_values[left_time].get_value(rho) * (right_time - time) / (
right_time - left_time
) + self.times_values[right_time].get_value(rho) * (time - left_time) / (
right_time - left_time
rho_interpolated = np.stack(
[
InterpolatedVarSingleAxis(
values[v], rho_interpolation_mode
).get_value(rho)
for v in values
],
axis=0,
)
times = np.asarray(list(values.keys()))
self._time_interpolated_var = InterpolatedVarSingleAxis(
(times, rho_interpolated)
)

def get_value(self, x: chex.Numeric) -> chex.Array:
"""Returns the value of this parameter interpolated at x=time."""
return self._time_interpolated_var.get_value(x)

# In runtime_params, users should be able to either specify the
# InterpolatedVarSingleAxis/InterpolatedVarTimeRho object directly or the values
Expand Down
Loading

0 comments on commit d3cd672

Please sign in to comment.