Skip to content

Commit

Permalink
Test that run_simulation_main cc results are correct
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653717622
  • Loading branch information
goodfeli authored and Torax team committed Jul 18, 2024
1 parent 4a86307 commit 61792df
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 0 deletions.
Binary file modified torax/tests/test_data/test_changing_config_after.nc
Binary file not shown.
Binary file modified torax/tests/test_data/test_changing_config_before.nc
Binary file not shown.
50 changes: 50 additions & 0 deletions torax/tests/test_run_simulation_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from absl.testing import absltest
from absl.testing import flagsaver
from absl.testing import parameterized
import numpy as np
import torax
# run_simulation_main.py is in the repo root, which is the parent directory
# of the actual module
Expand All @@ -49,6 +50,7 @@
del sys.path[-1]
from torax import simulation_app
from torax.tests.test_lib import paths
import xarray as xr


class RunSimulationMainTest(parameterized.TestCase):
Expand Down Expand Up @@ -175,6 +177,11 @@ def mock_input(prompt):
# The second call to `input` is confirming that we should run with
# this config.
response = "y"
elif call_count == 2:
# After changing the config, we go back to the main menu.
# Now we need to send an 'r' to run the config.
self.assertEqual(prompt, run_simulation_main.CHOICE_PROMPT)
response = "r"
else:
# The second run on the simulation just completed. Fetch its output.
filepaths.append(get_latest_filepath(captured_stdout))
Expand All @@ -196,6 +203,49 @@ def mock_input(prompt):
finally:
logging.get_absl_logger().removeHandler(handler)

# We should have received 2 runs, the before change and after change runs
self.assertLen(filepaths, 2)
self.assertNotEqual(filepaths[0], filepaths[1])

ground_truth_before = before[: -len(".py")] + ".nc"
ground_truth_after = after[: -len(".py")] + ".nc"

def check(output_path, ground_truth_path):
output = xr.open_dataset(output_path)
ground_truth = xr.open_dataset(ground_truth_path)

for key in output:
self.assertIn(key, ground_truth)

for key in ground_truth:
self.assertIn(key, output)

ov = output[key].to_numpy()
gv = ground_truth[key].to_numpy()

# Same tolerances as test_iterhybrid_newton
if not np.allclose(
ov,
gv,
# GitHub CI behaves very differently from Google internal for
# the mode=zero case, needing looser tolerance for this than
# for other tests.
# rtol=0.0,
atol=1.0e-9,
):
diff = ov - gv
max_diff = np.abs(diff).max()
raise AssertionError(
f"{key} does not match. "
f"Output: {ov}. "
f"Ground truth: {gv}."
f"Diff: {diff}"
f"Max diff: {max_diff}"
)

check(filepaths[0], ground_truth_before)
check(filepaths[1], ground_truth_after)


def get_latest_filepath(stream: io.StringIO) -> str:
"""Returns the last filepath written to by the app."""
Expand Down

0 comments on commit 61792df

Please sign in to comment.