Skip to content

Commit

Permalink
Fix plotting script following new xarray output.
Browse files Browse the repository at this point in the history
Plotting script now works for both legacy h5 files and new xarray netcdf output.

1. This PR also updates the filename suffix of new xarray output files to .nc , since .h5 is no longer appropriate.
2. To allow a combination of legacy h5 and new nc files as --outfile arguments, full filenames including file suffixes are now expected.
3. h5py is removed. h5 files can be read by the xarray h5netcdf backend

PiperOrigin-RevId: 625374339
  • Loading branch information
Torax team committed Apr 17, 2024
1 parent 651eb5b commit 630539e
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 83 deletions.
20 changes: 1 addition & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ python3 run_simulation_main.py \

### Post-simulation

Once complete, the time history of a simulation state and derived quantities is written to `state_history.h5`. The output path is written to stdout
Once complete, the time history of a simulation state and derived quantities is written to `state_history.nc`. The output path is written to stdout

To take advantage of the in-memory (non-persistent) cache, the process does not end upon simulation termination. It is possible to modify the config, toggle the `log_progress` and `plot_progress` flags, and rerun the simulation. Only the following modifications will then trigger a recompilation:

Expand All @@ -219,21 +219,3 @@ deactivate
# Simulation tutorials

Under construction

# FAQ

* On MacOS, you may get the error: .. ERROR:: Could not find a local HDF5
installation.:
* Solution: You need to tell the OS where HDF5 is, try

```shell
brew install hdf5
```

```shell
export HDF5_DIR="$(brew --prefix hdf5)"
```

```shell
pip install --no-binary=h5py h5py
```
1 change: 0 additions & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,5 @@ setuptools;python_version>="3.10"
chex>=0.1.85
fancyflags>=1.2
equinox @ git+https://github.com/patrick-kidger/equinox@1e601672d38d2c4d483535070a3572d8e8508a20
h5py>=3.10.0
PyYAML>=6.0.1
xarray>=2023.12.0
82 changes: 45 additions & 37 deletions torax/plotting/plotruns.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Basic post-run plotting tool. Plot a single run or comparison of two runs.
Includes a time slider. Reads output h5 files,
Includes a time slider. Reads output files with xarray data or legacy h5 data.
Plots:
(1) chi_i, chi_e (transport coefficients)
Expand All @@ -27,11 +27,11 @@

import argparse
import dataclasses
import h5py
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider # pylint: disable=g-importing-member
import numpy as np
import xarray as xr

matplotlib.use('TkAgg')

Expand Down Expand Up @@ -77,7 +77,7 @@ def __post_init__(self):
'--outfile',
nargs='*',
help=(
'Relative location of output h5 files (if two are provided, a'
'Relative location of output files (if two are provided, a'
' comparison is done)'
),
)
Expand All @@ -94,42 +94,50 @@ def __post_init__(self):
if not args.outfile:
raise ValueError('No output file provided')

with h5py.File(args.outfile[0] + '.h5', 'r') as hf:
plotdata1 = PlotData(
ti=hf['temp_ion'][:],
te=hf['temp_el'][:],
ne=hf['ne'][:],
j=hf['jtot'][:],
johm=hf['johm'][:],
j_bootstrap=hf['j_bootstrap'][:],
jext=hf['jext'][:],
q=hf['q_face'][:],
s=hf['s_face'][:],
chi_i=hf['chi_face_ion'][:],
chi_e=hf['chi_face_el'][:],
t=hf['t'][:],
rcell_coord=hf['r_cell_norm'][:],
rface_coord=hf['r_face_norm'][:],
)
ds1 = xr.open_dataset(args.outfile[0])
if 'time' in ds1:
t = ds1['time'].to_numpy()
else:
t = ds1['t'].to_numpy()
plotdata1 = PlotData(
ti=ds1['temp_ion'].to_numpy(),
te=ds1['temp_el'].to_numpy(),
ne=ds1['ne'].to_numpy(),
j=ds1['jtot'].to_numpy(),
johm=ds1['johm'].to_numpy(),
j_bootstrap=ds1['j_bootstrap'].to_numpy(),
jext=ds1['jext'].to_numpy(),
q=ds1['q_face'].to_numpy(),
s=ds1['s_face'].to_numpy(),
chi_i=ds1['chi_face_ion'].to_numpy(),
chi_e=ds1['chi_face_el'].to_numpy(),
rcell_coord=ds1['r_cell_norm'].to_numpy(),
rface_coord=ds1['r_face_norm'].to_numpy(),
t=t,
)

if comp_plot:
with h5py.File(args.outfile[1] + '.h5', 'r') as hf:
plotdata2 = PlotData(
ti=hf['temp_ion'][:],
te=hf['temp_el'][:],
ne=hf['ne'][:],
j=hf['jtot'][:],
johm=hf['johm'][:],
j_bootstrap=hf['j_bootstrap'][:],
jext=hf['jext'][:],
q=hf['q_face'][:],
s=hf['s_face'][:],
chi_i=hf['chi_face_ion'][:],
chi_e=hf['chi_face_el'][:],
t=hf['t'][:],
rcell_coord=hf['r_cell_norm'][:],
rface_coord=hf['r_face_norm'][:],
)
ds2 = xr.open_dataset(args.outfile[1])
if 'time' in ds2:
t = ds2['time'].to_numpy()
else:
t = ds2['t'].to_numpy()
plotdata2 = PlotData(
ti=ds2['temp_ion'].to_numpy(),
te=ds2['temp_el'].to_numpy(),
ne=ds2['ne'].to_numpy(),
j=ds2['jtot'].to_numpy(),
johm=ds2['johm'].to_numpy(),
j_bootstrap=ds2['j_bootstrap'].to_numpy(),
jext=ds2['jext'].to_numpy(),
q=ds2['q_face'].to_numpy(),
s=ds2['s_face'].to_numpy(),
chi_i=ds2['chi_face_ion'].to_numpy(),
chi_e=ds2['chi_face_el'].to_numpy(),
rcell_coord=ds2['r_cell_norm'].to_numpy(),
rface_coord=ds2['r_face_norm'].to_numpy(),
t=t,
)

fig = plt.figure(figsize=(15, 10))
ax1 = fig.add_subplot(231)
Expand Down
2 changes: 1 addition & 1 deletion torax/simulation_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class AnsiColors(enum.Enum):
_ANSI_END = '\033[0m'

_DEFAULT_OUTPUT_DIR_PREFIX = '/tmp/torax_results_'
_STATE_HISTORY_FILENAME = 'state_history.h5'
_STATE_HISTORY_FILENAME = 'state_history.nc'


def log_to_stdout(output: str, color: AnsiColors | None = None) -> None:
Expand Down
35 changes: 10 additions & 25 deletions torax/tests/test_lib/sim_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from absl.testing import absltest
from absl.testing import parameterized
import chex
import h5py
import jax.numpy as jnp
import numpy as np
import torax
Expand All @@ -36,17 +35,7 @@
from torax.stepper import stepper as stepper_lib
from torax.time_step_calculator import array_time_step_calculator
from torax.transport_model import transport_model as transport_model_lib


_TORAX_TO_PINT = {
'temp_ion': 'Ti',
'temp_el': 'Te',
'psi': 'psi',
's_face': 's',
'q_face': 'q',
'ne': 'ne',
}

import xarray as xr

_PYTHON_MODULE_PREFIX = '.tests.test_data.'
_PYTHON_CONFIG_PACKAGE = 'torax'
Expand Down Expand Up @@ -116,19 +105,15 @@ def _get_refs(
"""Gets reference values for the requested state profiles."""
expected_results_path = self._expected_results_path(ref_name)
self.assertTrue(os.path.exists(expected_results_path))

with open(expected_results_path, mode='rb') as f:
with h5py.File(f, 'r') as hf:
self.assertNotEmpty(profiles)
if 'Ti' in hf.keys(): # Determine if h5 file is PINT output
ref_profiles = {
profile: hf[_TORAX_TO_PINT[profile]][:] for profile in profiles
}
else:
ref_profiles = {profile: hf[profile][:] for profile in profiles}
ref_time = jnp.array(hf['t'])
self.assertEqual(ref_time.shape[0], ref_profiles[profiles[0]].shape[0])
return ref_profiles, ref_time
ds = xr.open_dataset(expected_results_path)
self.assertNotEmpty(profiles)
ref_profiles = {profile: ds[profile].to_numpy() for profile in profiles}
if 'time' in ds:
ref_time = ds['time'].to_numpy()
else:
ref_time = ds['t'].to_numpy()
self.assertEqual(ref_time.shape[0], ref_profiles[profiles[0]].shape[0])
return ref_profiles, ref_time

def _check_profiles_vs_expected(
self,
Expand Down

0 comments on commit 630539e

Please sign in to comment.