Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix plotting script following new xarray output. #42

Merged
1 commit merged into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 1 addition & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ python3 run_simulation_main.py \

### Post-simulation

Once complete, the time history of a simulation state and derived quantities is written to `state_history.h5`. The output path is written to stdout
Once complete, the time history of a simulation state and derived quantities is written to `state_history.nc`. The output path is written to stdout

To take advantage of the in-memory (non-persistent) cache, the process does not end upon simulation termination. It is possible to modify the config, toggle the `log_progress` and `plot_progress` flags, and rerun the simulation. Only the following modifications will then trigger a recompilation:

Expand All @@ -219,21 +219,3 @@ deactivate
# Simulation tutorials

Under construction

# FAQ

* On MacOS, you may get the error: .. ERROR:: Could not find a local HDF5
installation.:
* Solution: You need to tell the OS where HDF5 is, try

```shell
brew install hdf5
```

```shell
export HDF5_DIR="$(brew --prefix hdf5)"
```

```shell
pip install --no-binary=h5py h5py
```
1 change: 0 additions & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,5 @@ setuptools;python_version>="3.10"
chex>=0.1.85
fancyflags>=1.2
equinox @ git+https://github.com/patrick-kidger/equinox@1e601672d38d2c4d483535070a3572d8e8508a20
h5py>=3.10.0
PyYAML>=6.0.1
xarray>=2023.12.0
82 changes: 45 additions & 37 deletions torax/plotting/plotruns.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Basic post-run plotting tool. Plot a single run or comparison of two runs.

Includes a time slider. Reads output h5 files,
Includes a time slider. Reads output files with xarray data or legacy h5 data.

Plots:
(1) chi_i, chi_e (transport coefficients)
Expand All @@ -27,11 +27,11 @@

import argparse
import dataclasses
import h5py
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider # pylint: disable=g-importing-member
import numpy as np
import xarray as xr

matplotlib.use('TkAgg')

Expand Down Expand Up @@ -77,7 +77,7 @@ def __post_init__(self):
'--outfile',
nargs='*',
help=(
'Relative location of output h5 files (if two are provided, a'
'Relative location of output files (if two are provided, a'
' comparison is done)'
),
)
Expand All @@ -94,42 +94,50 @@ def __post_init__(self):
if not args.outfile:
raise ValueError('No output file provided')

with h5py.File(args.outfile[0] + '.h5', 'r') as hf:
plotdata1 = PlotData(
ti=hf['temp_ion'][:],
te=hf['temp_el'][:],
ne=hf['ne'][:],
j=hf['jtot'][:],
johm=hf['johm'][:],
j_bootstrap=hf['j_bootstrap'][:],
jext=hf['jext'][:],
q=hf['q_face'][:],
s=hf['s_face'][:],
chi_i=hf['chi_face_ion'][:],
chi_e=hf['chi_face_el'][:],
t=hf['t'][:],
rcell_coord=hf['r_cell_norm'][:],
rface_coord=hf['r_face_norm'][:],
)
ds1 = xr.open_dataset(args.outfile[0])
if 'time' in ds1:
t = ds1['time'].to_numpy()
else:
t = ds1['t'].to_numpy()
plotdata1 = PlotData(
ti=ds1['temp_ion'].to_numpy(),
te=ds1['temp_el'].to_numpy(),
ne=ds1['ne'].to_numpy(),
j=ds1['jtot'].to_numpy(),
johm=ds1['johm'].to_numpy(),
j_bootstrap=ds1['j_bootstrap'].to_numpy(),
jext=ds1['jext'].to_numpy(),
q=ds1['q_face'].to_numpy(),
s=ds1['s_face'].to_numpy(),
chi_i=ds1['chi_face_ion'].to_numpy(),
chi_e=ds1['chi_face_el'].to_numpy(),
rcell_coord=ds1['r_cell_norm'].to_numpy(),
rface_coord=ds1['r_face_norm'].to_numpy(),
t=t,
)

if comp_plot:
with h5py.File(args.outfile[1] + '.h5', 'r') as hf:
plotdata2 = PlotData(
ti=hf['temp_ion'][:],
te=hf['temp_el'][:],
ne=hf['ne'][:],
j=hf['jtot'][:],
johm=hf['johm'][:],
j_bootstrap=hf['j_bootstrap'][:],
jext=hf['jext'][:],
q=hf['q_face'][:],
s=hf['s_face'][:],
chi_i=hf['chi_face_ion'][:],
chi_e=hf['chi_face_el'][:],
t=hf['t'][:],
rcell_coord=hf['r_cell_norm'][:],
rface_coord=hf['r_face_norm'][:],
)
ds2 = xr.open_dataset(args.outfile[1])
if 'time' in ds2:
t = ds2['time'].to_numpy()
else:
t = ds2['t'].to_numpy()
plotdata2 = PlotData(
ti=ds2['temp_ion'].to_numpy(),
te=ds2['temp_el'].to_numpy(),
ne=ds2['ne'].to_numpy(),
j=ds2['jtot'].to_numpy(),
johm=ds2['johm'].to_numpy(),
j_bootstrap=ds2['j_bootstrap'].to_numpy(),
jext=ds2['jext'].to_numpy(),
q=ds2['q_face'].to_numpy(),
s=ds2['s_face'].to_numpy(),
chi_i=ds2['chi_face_ion'].to_numpy(),
chi_e=ds2['chi_face_el'].to_numpy(),
rcell_coord=ds2['r_cell_norm'].to_numpy(),
rface_coord=ds2['r_face_norm'].to_numpy(),
t=t,
)

fig = plt.figure(figsize=(15, 10))
ax1 = fig.add_subplot(231)
Expand Down
2 changes: 1 addition & 1 deletion torax/simulation_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class AnsiColors(enum.Enum):
_ANSI_END = '\033[0m'

_DEFAULT_OUTPUT_DIR_PREFIX = '/tmp/torax_results_'
_STATE_HISTORY_FILENAME = 'state_history.h5'
_STATE_HISTORY_FILENAME = 'state_history.nc'


def log_to_stdout(output: str, color: AnsiColors | None = None) -> None:
Expand Down
35 changes: 10 additions & 25 deletions torax/tests/test_lib/sim_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from absl.testing import absltest
from absl.testing import parameterized
import chex
import h5py
import jax.numpy as jnp
import numpy as np
import torax
Expand All @@ -36,17 +35,7 @@
from torax.stepper import stepper as stepper_lib
from torax.time_step_calculator import array_time_step_calculator
from torax.transport_model import transport_model as transport_model_lib


_TORAX_TO_PINT = {
'temp_ion': 'Ti',
'temp_el': 'Te',
'psi': 'psi',
's_face': 's',
'q_face': 'q',
'ne': 'ne',
}

import xarray as xr

_PYTHON_MODULE_PREFIX = '.tests.test_data.'
_PYTHON_CONFIG_PACKAGE = 'torax'
Expand Down Expand Up @@ -116,19 +105,15 @@ def _get_refs(
"""Gets reference values for the requested state profiles."""
expected_results_path = self._expected_results_path(ref_name)
self.assertTrue(os.path.exists(expected_results_path))

with open(expected_results_path, mode='rb') as f:
with h5py.File(f, 'r') as hf:
self.assertNotEmpty(profiles)
if 'Ti' in hf.keys(): # Determine if h5 file is PINT output
ref_profiles = {
profile: hf[_TORAX_TO_PINT[profile]][:] for profile in profiles
}
else:
ref_profiles = {profile: hf[profile][:] for profile in profiles}
ref_time = jnp.array(hf['t'])
self.assertEqual(ref_time.shape[0], ref_profiles[profiles[0]].shape[0])
return ref_profiles, ref_time
ds = xr.open_dataset(expected_results_path)
self.assertNotEmpty(profiles)
ref_profiles = {profile: ds[profile].to_numpy() for profile in profiles}
if 'time' in ds:
ref_time = ds['time'].to_numpy()
else:
ref_time = ds['t'].to_numpy()
self.assertEqual(ref_time.shape[0], ref_profiles[profiles[0]].shape[0])
return ref_profiles, ref_time

def _check_profiles_vs_expected(
self,
Expand Down