From 67a750ce308b1dab032196c18c88b69ad9566889 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Fri, 27 Dec 2024 06:15:22 +0000 Subject: [PATCH] increase data.py coverage from 84% to 90% - add unit tests for Files enum and download_file - move load_df_wbm_with_preds from test_preds.py to test_data.py - .gitignore .coverage* files - contributing.md remove leading slashes from model pred file paths - better error handling in Files enum by raising a ValueError when a label not found --- .gitignore | 4 +- contributing.md | 4 +- matbench_discovery/data.py | 5 +- matbench_discovery/models.py | 2 +- matbench_discovery/preds/discovery.py | 3 - .../routes/tasks/geo-opt/geo-opt-readme.md | 6 +- tests/test_data.py | 128 ++++++++++++++++++ tests/test_preds.py | 63 +-------- 8 files changed, 143 insertions(+), 72 deletions(-) diff --git a/.gitignore b/.gitignore index c3bd3341..3a82ee35 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.egg-info dist build +.coverage* # cache __pycache__ @@ -33,8 +34,5 @@ models/**/*.tgz # auto-generated docs site/src/routes/api/*.md -# temporary ignore rules -data/mp/mptrj-gga-ggapu/* - # large files data/*-models-geo-opt-analysis-symprec=*.csv.gz diff --git a/contributing.md b/contributing.md index 345e3028..d21bc4fa 100644 --- a/contributing.md +++ b/contributing.md @@ -175,10 +175,10 @@ To submit a new model to this benchmark and add it to our leaderboard, please cr metrics: discovery: - pred_file: /models//--wbm-IS2RE.csv.gz # should contain the models energy predictions for the WBM test set + pred_file: models//--wbm-IS2RE.csv.gz # should contain the models energy predictions for the WBM test set pred_col: e_form_per_atom_ geo_opt: # only applicable if the model performed structure relaxation - pred_file: /models//--wbm-IS2RE.json.gz # should contain the models relaxed structures as ASE Atoms or pymatgen Structures, and separate columns for material_id and energies/forces/stresses at each relaxation step + pred_file: 
models//--wbm-IS2RE.json.gz # should contain the models relaxed structures as ASE Atoms or pymatgen Structures, and separate columns for material_id and energies/forces/stresses at each relaxation step pred_col: e_form_per_atom_ ``` diff --git a/matbench_discovery/data.py b/matbench_discovery/data.py index a1477557..4a927b73 100644 --- a/matbench_discovery/data.py +++ b/matbench_discovery/data.py @@ -275,7 +275,10 @@ def label(self) -> str: @classmethod def from_label(cls, label: str) -> Self: """Get enum member from pretty label.""" - return next(attr for attr in cls if attr.label == label) + file = next((attr for attr in cls if attr.label == label), None) + if file is None: + raise ValueError(f"{label=} not found in {cls.__name__}") + return file class DataFiles(Files): diff --git a/matbench_discovery/models.py b/matbench_discovery/models.py index 8ce01e7e..b2d9845a 100644 --- a/matbench_discovery/models.py +++ b/matbench_discovery/models.py @@ -39,7 +39,7 @@ try: ModelType(model_data.get("model_type")) # check if model_type is valid except ValueError as exc: - exc.add_note(f"{metadata_file=}") + exc.add_note(f"{metadata_file=}\nPick from {', '.join(ModelType)}") raise diff --git a/matbench_discovery/preds/discovery.py b/matbench_discovery/preds/discovery.py index 72d56535..fd4e0efd 100644 --- a/matbench_discovery/preds/discovery.py +++ b/matbench_discovery/preds/discovery.py @@ -15,9 +15,6 @@ # load WBM summary dataframe with all models' formation energy predictions (eV/atom) df_preds = load_df_wbm_with_preds().round(3) -# for combo in [("CHGNet", "M3GNet")]: -# df_preds[" + ".join(combo)] = df_preds[combo].mean(axis=1) # noqa: ERA001 -# Model[" + ".join(combo)] = "combo" # noqa: ERA001 df_metrics = pd.DataFrame() diff --git a/site/src/routes/tasks/geo-opt/geo-opt-readme.md b/site/src/routes/tasks/geo-opt/geo-opt-readme.md index 843fa1ea..62a6a48a 100644 --- a/site/src/routes/tasks/geo-opt/geo-opt-readme.md +++ b/site/src/routes/tasks/geo-opt/geo-opt-readme.md 
@@ -1,14 +1,14 @@ # MLFF Geometry Optimization Analysis -> Disclaimer: There is a caveat to the structure similarity analysis below. The WBM test set was generated using the `MPRelaxSet` which applies [`ISYM=2`](https://vasp.at/wiki/index.php/ISIF). This fixes the structure's symmetry. The MLFFs by contrast use the FIRE or LBFGS optimizers with no symmetry constraints. They may therefore in some cases relax to lower energy states with different symmetry. This is not a mistake of the model and so higher σmatch (the percentage of structures with matching ML and DFT spacegroups) is not necessarily indicative of a better model. Thanks to Alex Ganose for pointing this out! Undiscovered lower energy structures in the relatively well-explored chemical systems covered by WBM and MP are not expected to be a common occurrence. Hence we believe this analysis still provides some signal so we left it on this secluded page. +> Disclaimer: There is a caveat to the structure similarity analysis below. The WBM test set was generated using the `MPRelaxSet` which applies [`ISYM=2`](https://vasp.at/wiki/index.php/ISYM). This fixes the structure's symmetry. The MLFFs by contrast use the FIRE or LBFGS optimizers with no symmetry constraints. They may therefore in some cases relax to lower energy states with different symmetry. This is not a mistake of the model and so higher σmatch (the percentage of structures with matching ML and DFT spacegroups) is not necessarily indicative of a better model. Thanks to Alex Ganose for pointing this out! Undiscovered lower energy structures in the relatively well-explored chemical systems covered by WBM and MP are not expected to be a common occurrence. Hence we believe this analysis still provides some useful insight. All plots/metrics below evaluate the quality of MLFF relaxations for the 257k crystal structures in the [WBM test set](https://nature.com/articles/s41524-020-00481-6). 
Not all models were able to relax all structures (user/cluster error may explain some failures) but every model was evaluated on at least relaxations. -Symmetry detection was performed with the excellent Rust library [`moyopy`](https://github.com/spglib/moyo), a ~4x faster successor to the already outstanding [`spglib`](https://spglib.readthedocs.io). +Symmetry detection was performed with the excellent Rust library [`moyopy`](https://github.com/spglib/moyo), a ~4x faster successor to [`spglib`](https://spglib.readthedocs.io). -> σmatch / σdec / σinc denote the fraction of structures that retain, increase, or decrease the symmetry of the DFT-relaxed structure during MLFF relaxation. The match criterion is for the ML ground state to have identical spacegroup as DFT. For σdec / σinc, ML relaxation increased / decreased the set of symmetry operations on a structure. Note that the symmetry metrics are sensitive to the `symprec` value passed to `spglib` so we show results for multiple values. See the [`spglib` docs](https://spglib.readthedocs.io/en/latest/variable.html#symprec) and [paper](https://arxiv.org/html/1808.01590v2) for details. +> σmatch / σdec / σinc denote the fraction of structures that retain, decrease, or increase the symmetry of the DFT-relaxed structure during MLFF relaxation. The match criterion is for the ML ground state to have the same spacegroup as DFT. For σdec / σinc, the number of symmetry operations for a structure decreased / increased during MLFF relaxation. Note that the symmetry metrics are sensitive to the `symprec` value passed to `spglib` so we show results for multiple values. See the [`spglib` docs](https://spglib.readthedocs.io/en/latest/variable.html#symprec) and [paper](https://arxiv.org/html/1808.01590v2) for details.
diff --git a/tests/test_data.py b/tests/test_data.py index 2e668a61..235c8c0c 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -2,6 +2,7 @@ import zipfile from pathlib import Path from typing import Any +from unittest.mock import patch import numpy as np import pandas as pd @@ -11,15 +12,19 @@ from pymatgen.core import Lattice, Structure from pymatviz.enums import Key +from matbench_discovery import DATA_DIR from matbench_discovery.data import ( DataFiles, + Files, Model, as_dict_handler, ase_atoms_from_zip, ase_atoms_to_zip, df_wbm, glob_to_df, + load_df_wbm_with_preds, ) +from matbench_discovery.enums import MbdKey, TestSubset structure = Structure( lattice=Lattice.cubic(5), @@ -172,9 +177,28 @@ def test_ase_atoms_from_zip_with_limit(tmp_path: Path) -> None: assert len(read_atoms) == 2 +def test_files() -> None: + """Test error handling in Files enum.""" + + assert Files.base_dir == DATA_DIR + + # Test custom base_dir + class SubFiles(Files, base_dir="foo"): + pass + + assert SubFiles.base_dir == "foo" + + # Test invalid label lookup + label = "invalid-label" + with pytest.raises(ValueError, match=f"{label=} not found in Files"): + Files.from_label(label) + + def test_data_files() -> None: """Test DataFiles enum functionality.""" # Test that paths are constructed correctly + assert str(DataFiles.mp_energies) == f"{DATA_DIR}/mp/2023-01-10-mp-energies.csv.gz" + assert repr(DataFiles.mp_energies) == "DataFiles.mp_energies" assert DataFiles.mp_energies.name == "mp_energies" assert ( DataFiles.mp_energies.url == "https://figshare.com/ndownloader/files/49083124" @@ -183,6 +207,7 @@ def test_data_files() -> None: # Test that multiple files exist and have correct attributes assert DataFiles.wbm_summary.rel_path == "wbm/2023-12-13-wbm-summary.csv.gz" + assert DataFiles.wbm_summary.path == f"{DATA_DIR}/wbm/2023-12-13-wbm-summary.csv.gz" assert ( DataFiles.wbm_summary.url == "https://figshare.com/ndownloader/files/44225498" ) @@ -228,3 +253,106 @@ def 
test_data_files_urls(data_file: DataFiles) -> None: # check that the URL is valid by sending a head request response = requests.head(url, allow_redirects=True, timeout=5) assert response.status_code in {200, 403}, f"Invalid URL for {name}: {url}" + + +def test_download_file(tmp_path: Path, capsys: pytest.CaptureFixture) -> None: + """Test download_file function.""" + + from matbench_discovery.data import download_file + + url = "https://example.com/test.txt" + test_content = b"test content" + dest_path = tmp_path / "test.txt" + + # Mock successful request + mock_response = requests.Response() + mock_response.status_code = 200 + mock_response._content = test_content # noqa: SLF001 + + with patch("requests.get", return_value=mock_response): + download_file(str(dest_path), url) + assert dest_path.read_bytes() == test_content + + # Mock failed request + mock_response = requests.Response() + mock_response.status_code = 404 + mock_response._content = b"Not found" # noqa: SLF001 + + with patch("requests.get", return_value=mock_response): + download_file(str(dest_path), url) # Should print error but not raise + + stdout, stderr = capsys.readouterr() + assert f"Error downloading {url=}" in stdout + assert stderr == "" + + +@pytest.mark.parametrize("models", [[], ["wrenformer"]]) +@pytest.mark.parametrize("max_error_threshold", [None, 5.0, 1.0]) +def test_load_df_wbm_with_preds( + models: list[str], max_error_threshold: float | None +) -> None: + df_wbm_with_preds = load_df_wbm_with_preds( + models=models, max_error_threshold=max_error_threshold + ) + assert len(df_wbm_with_preds) == len(df_wbm) + + assert list(df_wbm_with_preds) == list(df_wbm) + [ + Model[model].label for model in models + ] + assert df_wbm_with_preds.index.name == Key.mat_id + + for model_name in models: + model = Model[model_name] + assert model.label in df_wbm_with_preds + if max_error_threshold is not None: + # Check if predictions exceeding the threshold are filtered out + error = abs( + 
df_wbm_with_preds[model.label] - df_wbm_with_preds[MbdKey.e_form_dft] + ) + assert np.all(error[~error.isna()] <= max_error_threshold) + else: + # If no threshold is set, all predictions should be present + assert df_wbm_with_preds[model.label].isna().sum() == 0 + + +def test_load_df_wbm_max_error_threshold() -> None: + models = {Model.mace.label: 38} # num missing preds for default max_error_threshold + df_no_thresh = load_df_wbm_with_preds(models=list(models)) + df_high_thresh = load_df_wbm_with_preds(models=list(models), max_error_threshold=10) + df_low_thresh = load_df_wbm_with_preds(models=list(models), max_error_threshold=0.1) + + for model, n_missing in models.items(): + assert df_no_thresh[model].isna().sum() == n_missing + assert df_high_thresh[model].isna().sum() <= df_no_thresh[model].isna().sum() + assert df_high_thresh[model].isna().sum() <= df_low_thresh[model].isna().sum() + + +def test_load_df_wbm_with_preds_errors(df_float: pd.DataFrame) -> None: + """Test error handling in load_df_wbm_with_preds function.""" + + # Test invalid model name + with pytest.raises(ValueError, match="expected subset of"): + load_df_wbm_with_preds(models=["InvalidModel"]) + + # Test negative error threshold + with pytest.raises( + ValueError, match="max_error_threshold must be a positive number" + ): + load_df_wbm_with_preds(max_error_threshold=-1) + + # Test pred_col not in predictions file + with ( + patch("pandas.read_csv", return_value=df_float), + pytest.raises(ValueError, match="pred_col.*not found in"), + ): + load_df_wbm_with_preds(models=["alignn"]) + + +@pytest.mark.parametrize( + "subset", + ["unique_prototypes", TestSubset.uniq_protos, ["wbm-1-1", "wbm-1-2"], None], +) +def test_load_df_wbm_with_preds_subset(subset: Any) -> None: + """Test subset handling in load_df_wbm_with_preds.""" + df_wbm = load_df_wbm_with_preds(subset=subset) + assert isinstance(df_wbm, pd.DataFrame) diff --git a/tests/test_preds.py b/tests/test_preds.py index 1bf307d0..23527a22 100644 
--- a/tests/test_preds.py +++ b/tests/test_preds.py @@ -1,9 +1,5 @@ import os -import numpy as np -import pytest -from pymatviz.enums import Key - from matbench_discovery.data import df_wbm from matbench_discovery.enums import MbdKey from matbench_discovery.preds.discovery import ( @@ -11,7 +7,6 @@ df_each_err, df_each_pred, df_metrics, - load_df_wbm_with_preds, ) @@ -57,61 +52,11 @@ def test_df_each_err() -> None: ) -@pytest.mark.parametrize("models", [[], ["wrenformer"]]) -@pytest.mark.parametrize("max_error_threshold", [None, 5.0, 1.0]) -def test_load_df_wbm_with_preds( - models: list[str], max_error_threshold: float | None -) -> None: - df_wbm_with_preds = load_df_wbm_with_preds( - models=models, max_error_threshold=max_error_threshold - ) - assert len(df_wbm_with_preds) == len(df_wbm) - - assert list(df_wbm_with_preds) == list(df_wbm) + [ - Model[model].label for model in models - ] - assert df_wbm_with_preds.index.name == Key.mat_id - - for model_name in models: - model = Model[model_name] - assert model.label in df_wbm_with_preds - if max_error_threshold is not None: - # Check if predictions exceeding the threshold are filtered out - error = abs( - df_wbm_with_preds[model.label] - df_wbm_with_preds[MbdKey.e_form_dft] - ) - assert np.all(error[~error.isna()] <= max_error_threshold) - else: - # If no threshold is set, all predictions should be present - assert df_wbm_with_preds[model.label].isna().sum() == 0 - - -def test_load_df_wbm_max_error_threshold() -> None: - models = {Model.mace.label: 38} # num missing preds for default max_error_threshold - df_no_thresh = load_df_wbm_with_preds(models=list(models)) - df_high_thresh = load_df_wbm_with_preds(models=list(models), max_error_threshold=10) - df_low_thresh = load_df_wbm_with_preds(models=list(models), max_error_threshold=0.1) - - for model, n_missing in models.items(): - assert df_no_thresh[model].isna().sum() == n_missing - assert df_high_thresh[model].isna().sum() <= df_no_thresh[model].isna().sum() - 
assert df_high_thresh[model].isna().sum() <= df_low_thresh[model].isna().sum() - - -def test_load_df_wbm_with_preds_raises() -> None: - with pytest.raises(ValueError, match="unknown_models='foo'"): - load_df_wbm_with_preds(models=["foo"]) - - with pytest.raises( - ValueError, match="max_error_threshold must be a positive number" - ): - load_df_wbm_with_preds(max_error_threshold=-1.0) - - def test_pred_files() -> None: assert len(Model) >= 6 for model in Model: - assert model.discovery_path.endswith(".csv.gz") + pred_path = model.discovery_path + assert pred_path.endswith(".csv.gz") assert os.path.isfile( - model.discovery_path - ), f"{model=} missing discovery pred file, expected at {model.discovery_path}" + pred_path + ), f"discovery pred file for {model=} not found, expected at {pred_path}"