From d3cd67249f2e4df2e411aeaeaa02bab9c53348ed Mon Sep 17 00:00:00 2001 From: Tamara Norman Date: Wed, 17 Jul 2024 08:46:05 -0700 Subject: [PATCH] Change interpolated var 2D to do rho interpolation on init PiperOrigin-RevId: 653247334 --- torax/config/build_sim.py | 30 +++++++++- torax/config/config_args.py | 4 +- torax/geometry.py | 24 +++++--- torax/geometry_loader.py | 9 ++- torax/geometry_provider.py | 19 ++++-- torax/interpolated_param.py | 60 ++++++------------- torax/tests/geometry.py | 96 +++++++++++++++---------------- torax/tests/interpolated_param.py | 28 ++++----- 8 files changed, 145 insertions(+), 125 deletions(-) diff --git a/torax/config/build_sim.py b/torax/config/build_sim.py index 0051cc28..4e5b1bad 100644 --- a/torax/config/build_sim.py +++ b/torax/config/build_sim.py @@ -90,6 +90,33 @@ def build_chease_geometry( return geo +def build_chease_geometry_provider( + Ip_from_parameters: bool = True, + geometry_dir: str | None = None, + **kwargs, +) -> geometry_provider.GeometryProvider: + """Constructs a geometry provider from CHEASE file or series of files.""" + if 'geometry_configs' in kwargs: + if not isinstance(kwargs['geometry_configs'], dict): + raise ValueError('geometry_configs must be a dict.') + geometries = {} + for time, config in kwargs['geometry_configs'].items(): + geometries[time] = build_chease_geometry( + Ip_from_parameters=Ip_from_parameters, + geometry_dir=geometry_dir, + **config, + ) + return geometry_provider.TimeDependentGeometryProvider( + geometry.StandardGeometryProvider.create_provider(geometries)) + return geometry_provider.ConstantGeometryProvider( + build_chease_geometry( + Ip_from_parameters=Ip_from_parameters, + geometry_dir=geometry_dir, + **kwargs, + ) + ) + + def build_sim_from_config( config: dict[str, Any], ) -> sim_lib.Sim: @@ -236,8 +263,7 @@ def build_geometry_provider_from_config( return geometry_provider.ConstantGeometryProvider( geometry.build_circular_geometry(**kwargs)) elif geometry_type == 'chease': - return geometry_provider.ConstantGeometryProvider( - build_chease_geometry(**kwargs)) + return build_chease_geometry_provider(**kwargs) raise ValueError(f'Unknown geometry type: {geometry_type}') diff --git a/torax/config/config_args.py b/torax/config/config_args.py index 55e31089..8cd68a78 100644 --- a/torax/config/config_args.py +++ b/torax/config/config_args.py @@ -168,9 +168,9 @@ def _interpolate_var_2d( ): # Dealing with a param input so convert it first. param_or_param_input = interpolated_param.InterpolatedVarTimeRho( - values=param_or_param_input, + values=param_or_param_input, rho=geo.torax_mesh.face_centers ) - return param_or_param_input.get_value(t, geo.torax_mesh.face_centers) + return param_or_param_input.get_value(t) def get_init_kwargs( diff --git a/torax/geometry.py b/torax/geometry.py index 5483a932..29beded7 100644 --- a/torax/geometry.py +++ b/torax/geometry.py @@ -19,6 +19,7 @@ from collections.abc import Mapping import dataclasses import enum +import functools from typing import Type import chex @@ -46,7 +47,7 @@ class Grid1D: """ nx: int - dx: chex.Numeric + dx: float face_centers: chex.Array cell_centers: chex.Array @@ -59,13 +60,16 @@ def __post_init__(self): def __eq__(self, other: Grid1D) -> bool: return ( self.nx == other.nx - and np.array_equal(self.dx, other.dx) + and self.dx == other.dx and np.array_equal(self.face_centers, other.face_centers) and np.array_equal(self.cell_centers, other.cell_centers) ) + def __hash__(self) -> int: + return hash((self.nx, self.dx)) + @classmethod - def construct(cls, nx: int, dx: chex.Array) -> Grid1D: + def construct(cls, nx: int, dx: float) -> Grid1D: """Constructs a Grid1D. Args: @@ -298,7 +302,7 @@ def create_provider( ): continue kwargs[attr.name] = interpolated_param.InterpolatedVarSingleAxis( - (times, np.stack([getattr(g, attr.name) for g in geos], axis=-1)) + (times, np.stack([getattr(g, attr.name) for g in geos], axis=0)) ) return cls(**kwargs) @@ -320,6 +324,7 @@ def _get_geometry_base(self, t: chex.Numeric, geometry_class: Type[Geometry]): kwargs[attr.name] = getattr(self, attr.name).get_value(t) return geometry_class(**kwargs) # pytype: disable=wrong-keyword-args + @functools.partial(jax_utils.jit, static_argnums=0) def get_geometry(self, t: chex.Numeric) -> Geometry: """Returns a Geometry instance at the given time.""" return self._get_geometry_base(t, Geometry) @@ -365,6 +370,7 @@ class StandardGeometryProvider(GeometryProvider): delta_upper_face: interpolated_param.InterpolatedVarSingleAxis delta_lower_face: interpolated_param.InterpolatedVarSingleAxis + @functools.partial(jax_utils.jit, static_argnums=0) def get_geometry(self, t: chex.Numeric) -> Geometry: """Returns a Geometry instance at the given time.""" return self._get_geometry_base(t, StandardGeometry) @@ -402,7 +408,7 @@ def build_circular_geometry( # r_norm coordinate is r/Rmin in circular, and rho_norm in standard # geometry (CHEASE/EQDSK) # Define mesh (Slab Uniform 1D with Jacobian = 1) - dr_norm = np.array(1) / nr + dr_norm = 1. / nr mesh = Grid1D.construct(nx=nr, dx=dr_norm) rmax = np.asarray(Rmin) # helper variables for mesh cells and faces @@ -537,7 +543,7 @@ def build_circular_geometry( return CircularAnalyticalGeometry( # Set the standard geometry params. geometry_type=GeometryType.CIRCULAR.value, - dr_norm=dr_norm, + dr_norm=np.asarray(dr_norm), torax_mesh=mesh, rmax=rmax, Rmaj=Rmaj, @@ -635,7 +641,7 @@ class StandardGeometryIntermediates: psi: chex.Array Ip_profile: chex.Array rho: chex.Array - rhon: chex.Array + rhon: np.ndarray Rin: chex.Array Rout: chex.Array RBphi: chex.Array @@ -816,7 +822,7 @@ def build_standard_geometry( # fill geometry structure # r_norm coordinate is rho_tor_norm - dr_norm = intermediate.rhon[-1] / intermediate.nr + dr_norm = float(intermediate.rhon[-1]) / intermediate.nr # normalized grid mesh = Grid1D.construct(nx=intermediate.nr, dx=dr_norm) rmax = intermediate.rho[-1] # radius denormalization constant @@ -897,7 +903,7 @@ def build_standard_geometry( return StandardGeometry( geometry_type=GeometryType.CHEASE.value, - dr_norm=dr_norm, + dr_norm=np.asarray(dr_norm), torax_mesh=mesh, rmax=rmax, Rmaj=intermediate.Rmaj, diff --git a/torax/geometry_loader.py b/torax/geometry_loader.py index 3f0a0094..e392290d 100644 --- a/torax/geometry_loader.py +++ b/torax/geometry_loader.py @@ -15,13 +15,12 @@ """File I/O for loading geometry files.""" import os -import jax -import jax.numpy as jnp +import numpy as np def initialize_CHEASE_dict( # pylint: disable=invalid-name file_path: str, -) -> dict[str, jax.Array]: +) -> dict[str, np.ndarray]: """Loads the data from a CHEASE file into a dictionary.""" # pyformat: disable with open(file_path, 'r') as file: @@ -40,14 +39,14 @@ def initialize_CHEASE_dict( # pylint: disable=invalid-name # Convert lists to jax arrays. return { - var_label: jnp.array(chease_data[var_label]) for var_label in chease_data + var_label: np.asarray(chease_data[var_label]) for var_label in chease_data } def load_chease_data( geometry_dir: str | None, geometry_file: str, -) -> dict[str, jax.Array]: +) -> dict[str, np.ndarray]: """Loads the data from a CHEASE file into a dictionary.""" # The code below does not use os.environ.get() in order to support an internal # version of the code. diff --git a/torax/geometry_provider.py b/torax/geometry_provider.py index 9f653e31..e51be5c4 100644 --- a/torax/geometry_provider.py +++ b/torax/geometry_provider.py @@ -81,10 +81,7 @@ class ConstantGeometryProvider(GeometryProvider): def __init__(self, geo: geometry.Geometry): self._geo = geo - def __call__( - self, - t: chex.Numeric, - ) -> geometry.Geometry: + def __call__(self, t: chex.Numeric) -> geometry.Geometry: # The API includes time as an arg even though it is unused in order # to match the API of a GeometryProvider. del t # Ignored. @@ -93,3 +90,17 @@ def __call__( @property def torax_mesh(self) -> geometry.Grid1D: return self._geo.torax_mesh + + +class TimeDependentGeometryProvider(GeometryProvider): + """Returns a Geometry that changes over time.""" + + def __init__(self, geometry_provider: geometry.GeometryProvider): + self._geometry_provider = geometry_provider + + def __call__(self, t: chex.Numeric) -> geometry.Geometry: + return self._geometry_provider.get_geometry(t) + + @property + def torax_mesh(self) -> geometry.Grid1D: + return self._geometry_provider.torax_mesh diff --git a/torax/interpolated_param.py b/torax/interpolated_param.py index 4a3eddd5..e635951e 100644 --- a/torax/interpolated_param.py +++ b/torax/interpolated_param.py @@ -274,20 +274,21 @@ def is_bool_param(self) -> bool: return self._is_bool_param -class InterpolatedVarTimeRho: +class InterpolatedVarTimeRho(InterpolatedParamBase): """Interpolates on a grid (time, rho). - Given `values` that map from time-values to `InterpolatedVarSingleAxis`s that tell you how to interpolate along rho for different time values this class linearly interpolates along time to provide a value at any (time, rho) pair. - - For time values that are outside the range of `values` the closest defined - `InterpolatedVarSingleAxis` is used. + - NOTE: We assume that rho interpolation is fixed per simulation so take this + at init and take just time at get_value. """ def __init__( self, values: InterpolatedVarTimeRhoInput, + rho: chex.Numeric, rho_interpolation_mode: InterpolationMode = ( InterpolationMode.PIECEWISE_LINEAR ), @@ -307,46 +308,23 @@ def __init__( raise ValueError('Indicies in values mapping must be unique.') if not values: raise ValueError('Values mapping must not be empty.') - self.times_values = { - v: InterpolatedVarSingleAxis(values[v], rho_interpolation_mode) - for v in values.keys() - } - self.sorted_indices = jnp.array(sorted(values.keys())) - - def get_value( - self, - time: chex.Numeric, - rho: chex.Numeric, - ) -> chex.Array: - """Returns the value of this parameter interpolated at the given (time,rho). - - This method is not jittable as it is. - - Args: - time: The time-coordinate to interpolate at. - rho: The rho-coordinate to interpolate at. - Returns: - The value of the interpolated at the given (time,rho). - """ - # Find the index that is left of value which time is closest to. - left = jnp.searchsorted(self.sorted_indices, time, side='left') - - # If time is either smaller or larger, than smallest and largest values - # we know how to interpolate for, use the boundary interpolater. - if left == 0: - return self.times_values[float(self.sorted_indices[0])].get_value(rho) - if left == len(self.sorted_indices): - return self.times_values[float(self.sorted_indices[-1])].get_value(rho) - - # Interpolate between the two closest defined interpolaters. - left_time = float(self.sorted_indices[left - 1]) - right_time = float(self.sorted_indices[left]) - return self.times_values[left_time].get_value(rho) * (right_time - time) / ( - right_time - left_time - ) + self.times_values[right_time].get_value(rho) * (time - left_time) / ( - right_time - left_time + rho_interpolated = np.stack( + [ + InterpolatedVarSingleAxis( + values[v], rho_interpolation_mode + ).get_value(rho) + for v in values + ], + axis=0, + ) + times = np.asarray(list(values.keys())) + self._time_interpolated_var = InterpolatedVarSingleAxis( + (times, rho_interpolated) ) + def get_value(self, x: chex.Numeric) -> chex.Array: + """Returns the value of this parameter interpolated at x=time.""" + return self._time_interpolated_var.get_value(x) # In runtime_params, users should be able to either specify the # InterpolatedVarSingleAxis/InterpolatedVarTimeRho object directly or the values diff --git a/torax/tests/geometry.py b/torax/tests/geometry.py index f66eeec7..b28c7c18 100644 --- a/torax/tests/geometry.py +++ b/torax/tests/geometry.py @@ -80,22 +80,22 @@ def foo(geo: geometry.Geometry): Rmin=2.0, B=5.3, # Use the same dummy value for the rest. - psi=jnp.arange(0, 1.0, 0.01), - Ip_profile=jnp.arange(0, 1.0, 0.01), - rho=jnp.arange(0, 1.0, 0.01), - rhon=jnp.arange(0, 1.0, 0.01), - Rin=jnp.arange(0, 1.0, 0.01), - Rout=jnp.arange(0, 1.0, 0.01), - RBphi=jnp.arange(0, 1.0, 0.01), - int_Jdchi=jnp.arange(0, 1.0, 0.01), - flux_norm_1_over_R2=jnp.arange(0, 1.0, 0.01), - flux_norm_Bp2=jnp.arange(0, 1.0, 0.01), - flux_norm_dpsi=jnp.arange(0, 1.0, 0.01), - flux_norm_dpsi2=jnp.arange(0, 1.0, 0.01), - delta_upper_face=jnp.arange(0, 1.0, 0.01), - delta_lower_face=jnp.arange(0, 1.0, 0.01), - volume=jnp.arange(0, 1.0, 0.01), - area=jnp.arange(0, 1.0, 0.01), + psi=np.arange(0, 1.0, 0.01), + Ip_profile=np.arange(0, 1.0, 0.01), + rho=np.arange(0, 1.0, 0.01), + rhon=np.arange(0, 1.0, 0.01), + Rin=np.arange(0, 1.0, 0.01), + Rout=np.arange(0, 1.0, 0.01), + RBphi=np.arange(0, 1.0, 0.01), + int_Jdchi=np.arange(0, 1.0, 0.01), + flux_norm_1_over_R2=np.arange(0, 1.0, 0.01), + flux_norm_Bp2=np.arange(0, 1.0, 0.01), + flux_norm_dpsi=np.arange(0, 1.0, 0.01), + flux_norm_dpsi2=np.arange(0, 1.0, 0.01), + delta_upper_face=np.arange(0, 1.0, 0.01), + delta_lower_face=np.arange(0, 1.0, 0.01), + volume=np.arange(0, 1.0, 0.01), + area=np.arange(0, 1.0, 0.01), hires_fac=4, ) geo = geometry.build_standard_geometry(intermediate) @@ -115,22 +115,22 @@ def test_build_geometry_provider(self): Rmin=2.0, B=5.3, # Use the same dummy value for the rest. - psi=jnp.arange(0, 1.0, 0.01), - Ip_profile=jnp.arange(0, 1.0, 0.01), - rho=jnp.arange(0, 1.0, 0.01), - rhon=jnp.arange(0, 1.0, 0.01), - Rin=jnp.arange(0, 1.0, 0.01), - Rout=jnp.arange(0, 1.0, 0.01), - RBphi=jnp.arange(0, 1.0, 0.01), - int_Jdchi=jnp.arange(0, 1.0, 0.01), - flux_norm_1_over_R2=jnp.arange(0, 1.0, 0.01), - flux_norm_Bp2=jnp.arange(0, 1.0, 0.01), - flux_norm_dpsi=jnp.arange(0, 1.0, 0.01), - flux_norm_dpsi2=jnp.arange(0, 1.0, 0.01), - delta_upper_face=jnp.arange(0, 1.0, 0.01), - delta_lower_face=jnp.arange(0, 1.0, 0.01), - volume=jnp.arange(0, 1.0, 0.01), - area=jnp.arange(0, 1.0, 0.01), + psi=np.arange(0, 1.0, 0.01), + Ip_profile=np.arange(0, 1.0, 0.01), + rho=np.arange(0, 1.0, 0.01), + rhon=np.arange(0, 1.0, 0.01), + Rin=np.arange(0, 1.0, 0.01), + Rout=np.arange(0, 1.0, 0.01), + RBphi=np.arange(0, 1.0, 0.01), + int_Jdchi=np.arange(0, 1.0, 0.01), + flux_norm_1_over_R2=np.arange(0, 1.0, 0.01), + flux_norm_Bp2=np.arange(0, 1.0, 0.01), + flux_norm_dpsi=np.arange(0, 1.0, 0.01), + flux_norm_dpsi2=np.arange(0, 1.0, 0.01), + delta_upper_face=np.arange(0, 1.0, 0.01), + delta_lower_face=np.arange(0, 1.0, 0.01), + volume=np.arange(0, 1.0, 0.01), + area=np.arange(0, 1.0, 0.01), hires_fac=4, ) geo_0 = geometry.build_standard_geometry(intermediate_0) @@ -142,22 +142,22 @@ def test_build_geometry_provider(self): Rmin=1.0, B=6.5, # Use the same dummy value for the rest. - psi=jnp.arange(0, 1.0, 0.01), - Ip_profile=jnp.arange(0, 2.0, 0.02), - rho=jnp.arange(0, 1.0, 0.01), - rhon=jnp.arange(0, 1.0, 0.01), - Rin=jnp.arange(0, 1.0, 0.01), - Rout=jnp.arange(0, 1.0, 0.01), - RBphi=jnp.arange(0, 1.0, 0.01), - int_Jdchi=jnp.arange(0, 1.0, 0.01), - flux_norm_1_over_R2=jnp.arange(0, 1.0, 0.01), - flux_norm_Bp2=jnp.arange(0, 1.0, 0.01), - flux_norm_dpsi=jnp.arange(0, 1.0, 0.01), - flux_norm_dpsi2=jnp.arange(0, 1.0, 0.01), - delta_upper_face=jnp.arange(0, 1.0, 0.01), - delta_lower_face=jnp.arange(0, 1.0, 0.01), - volume=jnp.arange(0, 2.0, 0.02), - area=jnp.arange(0, 2.0, 0.02), + psi=np.arange(0, 1.0, 0.01), + Ip_profile=np.arange(0, 2.0, 0.02), + rho=np.arange(0, 1.0, 0.01), + rhon=np.arange(0, 1.0, 0.01), + Rin=np.arange(0, 1.0, 0.01), + Rout=np.arange(0, 1.0, 0.01), + RBphi=np.arange(0, 1.0, 0.01), + int_Jdchi=np.arange(0, 1.0, 0.01), + flux_norm_1_over_R2=np.arange(0, 1.0, 0.01), + flux_norm_Bp2=np.arange(0, 1.0, 0.01), + flux_norm_dpsi=np.arange(0, 1.0, 0.01), + flux_norm_dpsi2=np.arange(0, 1.0, 0.01), + delta_upper_face=np.arange(0, 1.0, 0.01), + delta_lower_face=np.arange(0, 1.0, 0.01), + volume=np.arange(0, 2.0, 0.02), + area=np.arange(0, 2.0, 0.02), hires_fac=4, ) geo_1 = geometry.build_standard_geometry(intermediate_1) diff --git a/torax/tests/interpolated_param.py b/torax/tests/interpolated_param.py index af6f2f9a..38bd4265 100644 --- a/torax/tests/interpolated_param.py +++ b/torax/tests/interpolated_param.py @@ -360,43 +360,43 @@ def test_interpolated_param_need_xs_to_be_sorted(self, range_class): def test_interpolated_var_2d(self, values, x, y, expected_output): """Tests the doubly interpolated param gives correct outputs on 2D mesh.""" interpolated_var_2d = interpolated_param.InterpolatedVarTimeRho( - values + values, rho=y ) - output = interpolated_var_2d.get_value(time=x, rho=y) - np.testing.assert_allclose(output, expected_output) + output = interpolated_var_2d.get_value(x=x) + np.testing.assert_allclose(output, expected_output, atol=1e-6, rtol=1e-6) def test_interpolated_var_2d_parses_float_input(self): """Tests that InterpolatedVarTimeRho parses float inputs correctly.""" interpolated_var_2d = interpolated_param.InterpolatedVarTimeRho( - values=1.0, + values=1.0, rho=0.0 ) np.testing.assert_allclose( - interpolated_var_2d.get_value(time=0.0, rho=0.0), 1.0 - ) - np.testing.assert_allclose( - interpolated_var_2d.get_value(time=0.0, rho=1.0), 1.0 + interpolated_var_2d.get_value(x=0.0), 1.0 ) self.assertLen(interpolated_var_2d.values, 1) self.assertIn(0.0, interpolated_var_2d.values) def test_interpolated_var_2d_parses_single_dict_input(self): - """Tests that InterpolatedVarTimeRho parses float inputs correctly.""" + """Tests that InterpolatedVarTimeRho parses dict inputs correctly.""" interpolated_var_2d = interpolated_param.InterpolatedVarTimeRho( - values={0: 18.0, 0.95: 5.0,}, + values={0: 18.0, 0.95: 5.0,}, rho=0.0, ) np.testing.assert_allclose( - interpolated_var_2d.get_value(time=0.0, rho=0.0), 18.0 + interpolated_var_2d.get_value(x=0.0), 18.0 ) np.testing.assert_allclose( - interpolated_var_2d.get_value(time=0.5, rho=0.0), 18.0 + interpolated_var_2d.get_value(x=0.5), 18.0 ) + interpolated_var_2d = interpolated_param.InterpolatedVarTimeRho( + values={0: 18.0, 0.95: 5.0,}, rho=0.95, + ) np.testing.assert_allclose( - interpolated_var_2d.get_value(time=0.0, rho=0.95), 5.0 + interpolated_var_2d.get_value(x=0.0), 5.0 ) np.testing.assert_allclose( - interpolated_var_2d.get_value(time=0.5, rho=0.95), 5.0 + interpolated_var_2d.get_value(x=0.5), 5.0 )