Skip to content

Commit

Permalink
Add a jit call to the get_geometry method
Browse files Browse the repository at this point in the history
This allows compolation across the repeated block of code and reduces setup speed from 5 to 3s for the time-dependent test

PiperOrigin-RevId: 653247333
  • Loading branch information
tamaranorman authored and Torax team committed Jul 18, 2024
1 parent 4a86307 commit c32be8a
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 12 deletions.
30 changes: 28 additions & 2 deletions torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,33 @@ def build_chease_geometry(
return geo


def build_chease_geometry_provider(
Ip_from_parameters: bool = True,
geometry_dir: str | None = None,
**kwargs,
) -> geometry_provider.GeometryProvider:
"""Constructs a geometry provider from CHEASE file or series of files."""
if 'geometry_configs' in kwargs:
if not isinstance(kwargs['geometry_configs'], dict):
raise ValueError('geometry_configs must be a dict.')
geometries = {}
for time, config in kwargs['geometry_configs'].items():
geometries[time] = build_chease_geometry(
Ip_from_parameters=Ip_from_parameters,
geometry_dir=geometry_dir,
**config,
)
return geometry_provider.TimeDependentGeometryProvider(
geometry.StandardGeometryProvider.create_provider(geometries))
return geometry_provider.ConstantGeometryProvider(
build_chease_geometry(
Ip_from_parameters=Ip_from_parameters,
geometry_dir=geometry_dir,
**kwargs,
)
)


def build_sim_from_config(
config: dict[str, Any],
) -> sim_lib.Sim:
Expand Down Expand Up @@ -236,8 +263,7 @@ def build_geometry_provider_from_config(
return geometry_provider.ConstantGeometryProvider(
geometry.build_circular_geometry(**kwargs))
elif geometry_type == 'chease':
return geometry_provider.ConstantGeometryProvider(
build_chease_geometry(**kwargs))
return build_chease_geometry_provider(**kwargs)
raise ValueError(f'Unknown geometry type: {geometry_type}')


Expand Down
8 changes: 7 additions & 1 deletion torax/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections.abc import Mapping
import dataclasses
import enum
import functools
from typing import Type

import chex
Expand Down Expand Up @@ -64,6 +65,9 @@ def __eq__(self, other: Grid1D) -> bool:
and np.array_equal(self.cell_centers, other.cell_centers)
)

def __hash__(self) -> int:
return hash((self.nx, self.dx))

@classmethod
def construct(cls, nx: int, dx: chex.Array) -> Grid1D:
"""Constructs a Grid1D.
Expand Down Expand Up @@ -298,7 +302,7 @@ def create_provider(
):
continue
kwargs[attr.name] = interpolated_param.InterpolatedVarSingleAxis(
(times, np.stack([getattr(g, attr.name) for g in geos], axis=-1))
(times, np.stack([getattr(g, attr.name) for g in geos], axis=0))
)
return cls(**kwargs)

Expand All @@ -320,6 +324,7 @@ def _get_geometry_base(self, t: chex.Numeric, geometry_class: Type[Geometry]):
kwargs[attr.name] = getattr(self, attr.name).get_value(t)
return geometry_class(**kwargs) # pytype: disable=wrong-keyword-args

@functools.partial(jax_utils.jit, static_argnums=0)
def get_geometry(self, t: chex.Numeric) -> Geometry:
"""Returns a Geometry instance at the given time."""
return self._get_geometry_base(t, Geometry)
Expand Down Expand Up @@ -365,6 +370,7 @@ class StandardGeometryProvider(GeometryProvider):
delta_upper_face: interpolated_param.InterpolatedVarSingleAxis
delta_lower_face: interpolated_param.InterpolatedVarSingleAxis

@functools.partial(jax_utils.jit, static_argnums=0)
def get_geometry(self, t: chex.Numeric) -> Geometry:
"""Returns a Geometry instance at the given time."""
return self._get_geometry_base(t, StandardGeometry)
Expand Down
9 changes: 4 additions & 5 deletions torax/geometry_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
"""File I/O for loading geometry files."""
import os

import jax
import jax.numpy as jnp
import numpy as np


def initialize_CHEASE_dict( # pylint: disable=invalid-name
file_path: str,
) -> dict[str, jax.Array]:
) -> dict[str, np.ndarray]:
"""Loads the data from a CHEASE file into a dictionary."""
# pyformat: disable
with open(file_path, 'r') as file:
Expand All @@ -40,14 +39,14 @@ def initialize_CHEASE_dict( # pylint: disable=invalid-name

# Convert lists to jax arrays.
return {
var_label: jnp.array(chease_data[var_label]) for var_label in chease_data
var_label: np.asarray(chease_data[var_label]) for var_label in chease_data
}


def load_chease_data(
geometry_dir: str | None,
geometry_file: str,
) -> dict[str, jax.Array]:
) -> dict[str, np.ndarray]:
"""Loads the data from a CHEASE file into a dictionary."""
# The code below does not use os.environ.get() in order to support an internal
# version of the code.
Expand Down
19 changes: 15 additions & 4 deletions torax/geometry_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,7 @@ class ConstantGeometryProvider(GeometryProvider):
def __init__(self, geo: geometry.Geometry):
self._geo = geo

def __call__(
self,
t: chex.Numeric,
) -> geometry.Geometry:
def __call__(self, t: chex.Numeric) -> geometry.Geometry:
# The API includes time as an arg even though it is unused in order
# to match the API of a GeometryProvider.
del t # Ignored.
Expand All @@ -93,3 +90,17 @@ def __call__(
@property
def torax_mesh(self) -> geometry.Grid1D:
return self._geo.torax_mesh


class TimeDependentGeometryProvider(GeometryProvider):
"""Returns a Geometry that changes over time."""

def __init__(self, geometry_provider: geometry.GeometryProvider):
self._geometry_provider = geometry_provider

def __call__(self, t: chex.Numeric) -> geometry.Geometry:
return self._geometry_provider.get_geometry(t)

@property
def torax_mesh(self) -> geometry.Grid1D:
return self._geometry_provider.torax_mesh

0 comments on commit c32be8a

Please sign in to comment.