increase data.py coverage from 84% to 90%
- add unit tests for Files enum and download_file
- move load_df_wbm_with_preds from test_preds.py to test_data.py
- add .coverage* files to .gitignore
- contributing.md: remove leading slashes from model pred file paths
- better error handling in Files enum: raise a ValueError when a label is not found
janosh committed Dec 27, 2024
1 parent a937a1a commit 67a750c
Showing 8 changed files with 143 additions and 72 deletions.
4 changes: 1 addition & 3 deletions .gitignore
@@ -2,6 +2,7 @@
*.egg-info
dist
build
+.coverage*

# cache
__pycache__
@@ -33,8 +34,5 @@ models/**/*.tgz
# auto-generated docs
site/src/routes/api/*.md

-# temporary ignore rules
-data/mp/mptrj-gga-ggapu/*

# large files
data/*-models-geo-opt-analysis-symprec=*.csv.gz
4 changes: 2 additions & 2 deletions contributing.md
@@ -175,10 +175,10 @@ To submit a new model to this benchmark and add it to our leaderboard, please cr
metrics:
  discovery:
-    pred_file: /models/<model_dir>/<yyyy-mm-dd>-<model_name>-wbm-IS2RE.csv.gz # should contain the model's energy predictions for the WBM test set
+    pred_file: models/<model_dir>/<yyyy-mm-dd>-<model_name>-wbm-IS2RE.csv.gz # should contain the model's energy predictions for the WBM test set
    pred_col: e_form_per_atom_<model_name>
  geo_opt: # only applicable if the model performed structure relaxation
-    pred_file: /models/<model_dir>/<yyyy-mm-dd>-<model_name>-wbm-IS2RE.json.gz # should contain the model's relaxed structures as ASE Atoms or pymatgen Structures, and separate columns for material_id and energies/forces/stresses at each relaxation step
+    pred_file: models/<model_dir>/<yyyy-mm-dd>-<model_name>-wbm-IS2RE.json.gz # should contain the model's relaxed structures as ASE Atoms or pymatgen Structures, and separate columns for material_id and energies/forces/stresses at each relaxation step
    pred_col: e_form_per_atom_<model_name>
```
5 changes: 4 additions & 1 deletion matbench_discovery/data.py
@@ -275,7 +275,10 @@ def label(self) -> str:
    @classmethod
    def from_label(cls, label: str) -> Self:
        """Get enum member from pretty label."""
-        return next(attr for attr in cls if attr.label == label)
+        file = next((attr for attr in cls if attr.label == label), None)
+        if file is None:
+            raise ValueError(f"{label=} not found in {cls.__name__}")
+        return file


class DataFiles(Files):
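For reference, a minimal sketch of the new `from_label` error path — any string that matches no member's pretty label now raises a `ValueError` instead of an opaque `StopIteration` (the label below is made up):

```python
from matbench_discovery.data import Files

try:
    Files.from_label("no-such-label")  # hypothetical label, matches no member
except ValueError as exc:
    print(exc)  # label='no-such-label' not found in Files
```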
2 changes: 1 addition & 1 deletion matbench_discovery/models.py
@@ -39,7 +39,7 @@
try:
    ModelType(model_data.get("model_type"))  # check if model_type is valid
except ValueError as exc:
-    exc.add_note(f"{metadata_file=}")
+    exc.add_note(f"{metadata_file=}\nPick from {', '.join(ModelType)}")
    raise


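The new note relies on `BaseException.add_note` (Python ≥ 3.11), and `', '.join(ModelType)` works because `ModelType` members are string-backed and join like plain strings. A standalone sketch of the pattern, with illustrative members rather than the real `ModelType` values:

```python
from enum import StrEnum


class ModelType(StrEnum):  # stand-in for matbench_discovery.enums.ModelType
    UIP = "UIP"  # illustrative members, not the actual enum values
    RF = "RF"


try:
    ModelType("not-a-model-type")  # invalid value -> ValueError
except ValueError as exc:
    # add_note appends a hint that is rendered below the traceback message
    exc.add_note(f"Pick from {', '.join(ModelType)}")
    raise
```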
3 changes: 0 additions & 3 deletions matbench_discovery/preds/discovery.py
@@ -15,9 +15,6 @@

# load WBM summary dataframe with all models' formation energy predictions (eV/atom)
df_preds = load_df_wbm_with_preds().round(3)
-# for combo in [("CHGNet", "M3GNet")]:
-# df_preds[" + ".join(combo)] = df_preds[combo].mean(axis=1)  # noqa: ERA001
-# Model[" + ".join(combo)] = "combo"  # noqa: ERA001


df_metrics = pd.DataFrame()
6 changes: 3 additions & 3 deletions site/src/routes/tasks/geo-opt/geo-opt-readme.md
@@ -1,14 +1,14 @@
# MLFF Geometry Optimization Analysis

-> Disclaimer: There is a caveat to the structure similarity analysis below. The WBM test set was generated using the `MPRelaxSet` which applies [`ISYM=2`](https://vasp.at/wiki/index.php/ISIF). This fixes the structure's symmetry. The MLFFs by contrast use the FIRE or LBFGS optimizers with no symmetry constraints. They may therefore in some cases relax to lower energy states with different symmetry. This is not a mistake of the model and so higher σ<sub>match</sub> (the percentage of structures with matching ML and DFT spacegroups) is not necessarily indicative of a better model. Thanks to Alex Ganose for pointing this out! Undiscovered lower energy structures in the relatively well-explored chemical systems covered by WBM and MP are not expected to be a common occurrence. Hence we believe this analysis still provides some signal so we left it on this secluded page.
+> Disclaimer: There is a caveat to the structure similarity analysis below. The WBM test set was generated using the `MPRelaxSet` which applies [`ISYM=2`](https://vasp.at/wiki/index.php/ISIF). This fixes the structure's symmetry. The MLFFs by contrast use the FIRE or LBFGS optimizers with no symmetry constraints. They may therefore in some cases relax to lower energy states with different symmetry. This is not a mistake of the model and so higher σ<sub>match</sub> (the percentage of structures with matching ML and DFT spacegroups) is not necessarily indicative of a better model. Thanks to Alex Ganose for pointing this out! Undiscovered lower energy structures in the relatively well-explored chemical systems covered by WBM and MP are not expected to be a common occurrence. Hence we believe this analysis still provides some useful insight.
All plots/metrics below evaluate the quality of MLFF relaxations for the 257k crystal structures in the [WBM test set](https://nature.com/articles/s41524-020-00481-6). Not all models were able to relax all structures (user/cluster error may explain some failures) but every model was evaluated on at least <slot name="min-relaxed-structures"/> relaxations.

-Symmetry detection was performed with the excellent Rust library [`moyopy`](https://github.com/spglib/moyo), a ~4x faster successor to the already outstanding [`spglib`](https://spglib.readthedocs.io).
+Symmetry detection was performed with the excellent Rust library [`moyopy`](https://github.com/spglib/moyo), a ~4x faster successor to [`spglib`](https://spglib.readthedocs.io).

<slot name="geo-opt-metrics-table"/>

-> σ<sub>match</sub> / σ<sub>dec</sub> / σ<sub>inc</sub> denote the fraction of structures that retain, increase, or decrease the symmetry of the DFT-relaxed structure during MLFF relaxation. The match criterion is for the ML ground state to have identical spacegroup as DFT. For σ<sub>dec</sub> / σ<sub>inc</sub>, ML relaxation increased / decreased the set of symmetry operations on a structure. Note that the symmetry metrics are sensitive to the `symprec` value passed to `spglib` so we show results for multiple values. See the [`spglib` docs](https://spglib.readthedocs.io/en/latest/variable.html#symprec) and [paper](https://arxiv.org/html/1808.01590v2) for details.
+> σ<sub>match</sub> / σ<sub>dec</sub> / σ<sub>inc</sub> denote the fraction of structures that retain, decrease, or increase the symmetry of the DFT-relaxed structure during MLFF relaxation. The match criterion is for the ML ground state to have an identical spacegroup to DFT. For σ<sub>dec</sub> / σ<sub>inc</sub>, the number of symmetry operations for a structure decreased / increased during MLFF relaxation. Note that the symmetry metrics are sensitive to the `symprec` value passed to `spglib`, so we show results for multiple values. See the [`spglib` docs](https://spglib.readthedocs.io/en/latest/variable.html#symprec) and [paper](https://arxiv.org/html/1808.01590v2) for details.
<hr />

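To illustrate the `symprec` sensitivity discussed above, a minimal sketch using `spglib` directly (a toy structure with one atom nudged off a high-symmetry site; not part of the benchmark code):

```python
import numpy as np
import spglib

lattice = np.eye(3) * 4.0  # cubic cell with 4 Å edges
positions = [[0, 0, 0], [0.5, 0.5, 0.501]]  # second atom slightly off the body center
numbers = [1, 1]  # two atoms of the same species
cell = (lattice, positions, numbers)

for symprec in (1e-5, 1e-2, 1e-1):
    # a looser tolerance absorbs the distortion and reports the higher-symmetry group
    print(f"{symprec=}: {spglib.get_spacegroup(cell, symprec=symprec)}")
```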
128 changes: 128 additions & 0 deletions tests/test_data.py
@@ -2,6 +2,7 @@
import zipfile
from pathlib import Path
from typing import Any
from unittest.mock import patch

import numpy as np
import pandas as pd
@@ -11,15 +12,19 @@
from pymatgen.core import Lattice, Structure
from pymatviz.enums import Key

from matbench_discovery import DATA_DIR
from matbench_discovery.data import (
    DataFiles,
    Files,
    Model,
    as_dict_handler,
    ase_atoms_from_zip,
    ase_atoms_to_zip,
    df_wbm,
    glob_to_df,
    load_df_wbm_with_preds,
)
from matbench_discovery.enums import MbdKey, TestSubset

structure = Structure(
    lattice=Lattice.cubic(5),
@@ -172,9 +177,28 @@ def test_ase_atoms_from_zip_with_limit(tmp_path: Path) -> None:
    assert len(read_atoms) == 2


def test_files() -> None:
    """Test error handling in Files enum."""

    assert Files.base_dir == DATA_DIR

    # Test custom base_dir
    class SubFiles(Files, base_dir="foo"):
        pass

    assert SubFiles.base_dir == "foo"

    # Test invalid label lookup
    label = "invalid-label"
    with pytest.raises(ValueError, match=f"{label=} not found in Files"):
        Files.from_label(label)


def test_data_files() -> None:
    """Test DataFiles enum functionality."""
    # Test that paths are constructed correctly
    assert str(DataFiles.mp_energies) == f"{DATA_DIR}/mp/2023-01-10-mp-energies.csv.gz"
    assert repr(DataFiles.mp_energies) == "DataFiles.mp_energies"
    assert DataFiles.mp_energies.name == "mp_energies"
    assert (
        DataFiles.mp_energies.url == "https://figshare.com/ndownloader/files/49083124"
@@ -183,6 +207,7 @@ def test_data_files() -> None:

    # Test that multiple files exist and have correct attributes
    assert DataFiles.wbm_summary.rel_path == "wbm/2023-12-13-wbm-summary.csv.gz"
    assert DataFiles.wbm_summary.path == f"{DATA_DIR}/wbm/2023-12-13-wbm-summary.csv.gz"
    assert (
        DataFiles.wbm_summary.url == "https://figshare.com/ndownloader/files/44225498"
    )
@@ -228,3 +253,106 @@ def test_data_files_urls(data_file: DataFiles) -> None:
    # check that the URL is valid by sending a head request
    response = requests.head(url, allow_redirects=True, timeout=5)
    assert response.status_code in {200, 403}, f"Invalid URL for {name}: {url}"


def test_download_file(tmp_path: Path, capsys: pytest.CaptureFixture) -> None:
    """Test download_file function."""

    from matbench_discovery.data import download_file

    url = "https://example.com/test.txt"
    test_content = b"test content"
    dest_path = tmp_path / "test.txt"

    # Mock successful request
    mock_response = requests.Response()
    mock_response.status_code = 200
    mock_response._content = test_content  # noqa: SLF001

    with patch("requests.get", return_value=mock_response):
        download_file(str(dest_path), url)
    assert dest_path.read_bytes() == test_content

    # Mock failed request
    mock_response = requests.Response()
    mock_response.status_code = 404
    mock_response._content = b"Not found"  # noqa: SLF001

    with patch("requests.get", return_value=mock_response):
        download_file(str(dest_path), url)  # Should print error but not raise

    stdout, stderr = capsys.readouterr()
    assert f"Error downloading {url=}" in stdout
    assert stderr == ""


@pytest.mark.parametrize("models", [[], ["wrenformer"]])
@pytest.mark.parametrize("max_error_threshold", [None, 5.0, 1.0])
def test_load_df_wbm_with_preds(
models: list[str], max_error_threshold: float | None
) -> None:
df_wbm_with_preds = load_df_wbm_with_preds(
models=models, max_error_threshold=max_error_threshold
)
assert len(df_wbm_with_preds) == len(df_wbm)

assert list(df_wbm_with_preds) == list(df_wbm) + [
Model[model].label for model in models
]
assert df_wbm_with_preds.index.name == Key.mat_id

for model_name in models:
model = Model[model_name]
assert model.label in df_wbm_with_preds
if max_error_threshold is not None:
# Check if predictions exceeding the threshold are filtered out
error = abs(
df_wbm_with_preds[model.label] - df_wbm_with_preds[MbdKey.e_form_dft]
)
assert np.all(error[~error.isna()] <= max_error_threshold)
else:
# If no threshold is set, all predictions should be present
assert df_wbm_with_preds[model.label].isna().sum() == 0


def test_load_df_wbm_max_error_threshold() -> None:
    models = {Model.mace.label: 38}  # num missing preds for default max_error_threshold
    df_no_thresh = load_df_wbm_with_preds(models=list(models))
    df_high_thresh = load_df_wbm_with_preds(models=list(models), max_error_threshold=10)
    df_low_thresh = load_df_wbm_with_preds(models=list(models), max_error_threshold=0.1)

    for model, n_missing in models.items():
        assert df_no_thresh[model].isna().sum() == n_missing
        assert df_high_thresh[model].isna().sum() <= df_no_thresh[model].isna().sum()
        assert df_high_thresh[model].isna().sum() <= df_low_thresh[model].isna().sum()


def test_load_df_wbm_with_preds_errors(df_float: pd.DataFrame) -> None:
    """Test error handling in load_df_wbm_with_preds function."""

    # Test invalid model name
    with pytest.raises(ValueError, match="expected subset of"):
        load_df_wbm_with_preds(models=["InvalidModel"])

    # Test negative error threshold
    with pytest.raises(
        ValueError, match="max_error_threshold must be a positive number"
    ):
        load_df_wbm_with_preds(max_error_threshold=-1)

    # Test pred_col not in predictions file
    with (
        patch("pandas.read_csv", return_value=df_float),
        pytest.raises(ValueError, match="pred_col.*not found in"),
    ):
        load_df_wbm_with_preds(models=["alignn"])


@pytest.mark.parametrize(
    "subset",
    ["unique_prototypes", TestSubset.uniq_protos, ["wbm-1-1", "wbm-1-2"], None],
)
def test_load_df_wbm_with_preds_subset(subset: Any) -> None:
    """Test subset handling in load_df_wbm_with_preds."""
    df_wbm = load_df_wbm_with_preds(subset=subset)
    assert isinstance(df_wbm, pd.DataFrame)
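As an aside, the `requests` stubbing pattern used in `test_download_file` above works standalone; a minimal sketch (URL and payload are placeholders):

```python
from unittest.mock import patch

import requests

resp = requests.Response()
resp.status_code = 200
resp._content = b"hello"  # noqa: SLF001 - Response has no public content setter

with patch("requests.get", return_value=resp) as mock_get:
    fetched = requests.get("https://example.com")  # returns the stub, no network I/O
    assert fetched.content == b"hello"
    mock_get.assert_called_once()
```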
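And a hedged usage sketch of the API these tests exercise — per the assertions above, `max_error_threshold` masks predictions whose absolute error vs DFT exceeds the threshold as NaN (the model name must be a `Model` member; the eV/atom unit follows from the formation-energy columns):

```python
from matbench_discovery.data import Model, load_df_wbm_with_preds

# load WBM summary plus one model's formation-energy predictions,
# masking predictions more than 5 eV/atom off the DFT reference
df_preds = load_df_wbm_with_preds(models=["wrenformer"], max_error_threshold=5.0)

label = Model["wrenformer"].label
print(f"{df_preds[label].isna().sum()} predictions masked for {label}")
```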
63 changes: 4 additions & 59 deletions tests/test_preds.py
@@ -1,17 +1,12 @@
import os

-import numpy as np
-import pytest
from pymatviz.enums import Key

-from matbench_discovery.data import df_wbm
-from matbench_discovery.enums import MbdKey
from matbench_discovery.preds.discovery import (
    Model,
    df_each_err,
    df_each_pred,
    df_metrics,
-    load_df_wbm_with_preds,
)


@@ -57,61 +52,11 @@ def test_df_each_err() -> None:
    )


@pytest.mark.parametrize("models", [[], ["wrenformer"]])
@pytest.mark.parametrize("max_error_threshold", [None, 5.0, 1.0])
def test_load_df_wbm_with_preds(
models: list[str], max_error_threshold: float | None
) -> None:
df_wbm_with_preds = load_df_wbm_with_preds(
models=models, max_error_threshold=max_error_threshold
)
assert len(df_wbm_with_preds) == len(df_wbm)

assert list(df_wbm_with_preds) == list(df_wbm) + [
Model[model].label for model in models
]
assert df_wbm_with_preds.index.name == Key.mat_id

for model_name in models:
model = Model[model_name]
assert model.label in df_wbm_with_preds
if max_error_threshold is not None:
# Check if predictions exceeding the threshold are filtered out
error = abs(
df_wbm_with_preds[model.label] - df_wbm_with_preds[MbdKey.e_form_dft]
)
assert np.all(error[~error.isna()] <= max_error_threshold)
else:
# If no threshold is set, all predictions should be present
assert df_wbm_with_preds[model.label].isna().sum() == 0


def test_load_df_wbm_max_error_threshold() -> None:
models = {Model.mace.label: 38} # num missing preds for default max_error_threshold
df_no_thresh = load_df_wbm_with_preds(models=list(models))
df_high_thresh = load_df_wbm_with_preds(models=list(models), max_error_threshold=10)
df_low_thresh = load_df_wbm_with_preds(models=list(models), max_error_threshold=0.1)

for model, n_missing in models.items():
assert df_no_thresh[model].isna().sum() == n_missing
assert df_high_thresh[model].isna().sum() <= df_no_thresh[model].isna().sum()
assert df_high_thresh[model].isna().sum() <= df_low_thresh[model].isna().sum()


def test_load_df_wbm_with_preds_raises() -> None:
with pytest.raises(ValueError, match="unknown_models='foo'"):
load_df_wbm_with_preds(models=["foo"])

with pytest.raises(
ValueError, match="max_error_threshold must be a positive number"
):
load_df_wbm_with_preds(max_error_threshold=-1.0)


def test_pred_files() -> None:
    assert len(Model) >= 6
    for model in Model:
-        assert model.discovery_path.endswith(".csv.gz")
+        pred_path = model.discovery_path
+        assert pred_path.endswith(".csv.gz")
        assert os.path.isfile(
-            model.discovery_path
-        ), f"{model=} missing discovery pred file, expected at {model.discovery_path}"
+            pred_path
+        ), f"discovery pred file for {model=} not found, expected at {pred_path}"
