diff --git a/tests/metatensor/test_base_metatensor.py b/tests/metatensor/test_base_metatensor.py index 239bec22..dbaa985b 100644 --- a/tests/metatensor/test_base_metatensor.py +++ b/tests/metatensor/test_base_metatensor.py @@ -1,13 +1,14 @@ import pytest import torch -from metatensor.torch import Labels, TensorBlock -from metatensor.torch.atomistic import System from packaging import version -from meshlode.metatensor.base import CalculatorBaseMetatensor +mts_torch = pytest.importorskip("metatensor.torch") +mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") +meshlode_metatensor = pytest.importorskip("meshlode.metatensor") -class CalculatorTest(CalculatorBaseMetatensor): + +class CalculatorTest(meshlode_metatensor.base.CalculatorBaseMetatensor): def _compute_single_system( self, positions, charges, cell, neighbor_indices, neighbor_shifts ): @@ -16,18 +17,18 @@ def _compute_single_system( @pytest.mark.parametrize("method_name", ["compute", "forward"]) def test_compute_output_shapes_single(method_name): - system = System( + system = mts_atomistic.System( types=torch.tensor([1, 1]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), cell=torch.zeros([3, 3]), ) charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) - data = TensorBlock( + data = mts_torch.TensorBlock( values=charges, - samples=Labels.range("atom", charges.shape[0]), + samples=mts_torch.Labels.range("atom", charges.shape[0]), components=[], - properties=Labels.range("charge", charges.shape[1]), + properties=mts_torch.Labels.range("charge", charges.shape[1]), ) system.add_data(name="charges", data=data) @@ -50,18 +51,18 @@ def test_compute_output_shapes_single(method_name): def test_compute_output_shapes_multiple(): - system = System( + system = mts_atomistic.System( types=torch.tensor([1, 1]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), cell=torch.zeros([3, 3]), ) charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) - data = TensorBlock( + data = mts_torch.TensorBlock( values=charges, - samples=Labels.range("atom", charges.shape[0]), + samples=mts_torch.Labels.range("atom", charges.shape[0]), components=[], - properties=Labels.range("charge", charges.shape[1]), + properties=mts_torch.Labels.range("charge", charges.shape[1]), ) system.add_data(name="charges", data=data) @@ -82,13 +83,13 @@ def test_compute_output_shapes_multiple(): def test_wrong_system_dtype(): - system1 = System( + system1 = mts_atomistic.System( types=torch.tensor([1, 1]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), cell=torch.zeros([3, 3]), ) - system2 = System( + system2 = mts_atomistic.System( types=torch.tensor([1, 1]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]], dtype=torch.float64), cell=torch.zeros([3, 3], dtype=torch.float64), @@ -102,13 +103,13 @@ def test_wrong_system_dtype(): def test_wrong_system_device(): - system1 = System( + system1 = mts_atomistic.System( types=torch.tensor([1, 1]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), cell=torch.zeros([3, 3]), ) - system2 = System( + system2 = mts_atomistic.System( types=torch.tensor([1, 1], device="meta"), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]], device="meta"), cell=torch.zeros([3, 3], device="meta"), @@ -122,23 +123,23 @@ def test_wrong_system_device(): def test_wrong_system_not_all_charges(): - system1 = System( + system1 = mts_atomistic.System( types=torch.tensor([1, 1]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), cell=torch.zeros([3, 3]), ) charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) - data = TensorBlock( + data = mts_torch.TensorBlock( values=charges, - samples=Labels.range("atom", charges.shape[0]), + samples=mts_torch.Labels.range("atom", charges.shape[0]), components=[], - properties=Labels.range("charge", charges.shape[1]), + properties=mts_torch.Labels.range("charge", charges.shape[1]), ) system1.add_data(name="charges", data=data) - system2 = System( + system2 = mts_atomistic.System( types=torch.tensor( [1, 1], ), @@ -154,34 +155,34 @@ def test_wrong_system_not_all_charges(): def test_different_number_charge_channles(): - system1 = System( + system1 = mts_atomistic.System( types=torch.tensor([1, 1]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), cell=torch.zeros([3, 3]), ) charges1 = torch.tensor([1.0, -1.0]).reshape(-1, 1) - data1 = TensorBlock( + data1 = mts_torch.TensorBlock( values=charges1, - samples=Labels.range("atom", charges1.shape[0]), + samples=mts_torch.Labels.range("atom", charges1.shape[0]), components=[], - properties=Labels.range("charge", charges1.shape[1]), + properties=mts_torch.Labels.range("charge", charges1.shape[1]), ) system1.add_data(name="charges", data=data1) - system2 = System( + system2 = mts_atomistic.System( types=torch.tensor([1, 1]), positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), cell=torch.zeros([3, 3]), ) charges2 = torch.tensor([[1.0, 2.0], [-1.0, -2.0]]) - data2 = TensorBlock( + data2 = mts_torch.TensorBlock( values=charges2, - samples=Labels.range("atom", charges2.shape[0]), + samples=mts_torch.Labels.range("atom", charges2.shape[0]), components=[], - properties=Labels.range("charge", charges2.shape[1]), + properties=mts_torch.Labels.range("charge", charges2.shape[1]), ) system2.add_data(name="charges", data=data2) diff --git a/tests/metatensor/test_calculators.py b/tests/metatensor/test_calculators.py index 09d95fa2..137d158f 100644 --- a/tests/metatensor/test_calculators.py +++ b/tests/metatensor/test_calculators.py @@ -7,9 +7,9 @@ from packaging import version -meshlode_metatensor = pytest.importorskip("meshlode.metatensor") mts_torch = pytest.importorskip("metatensor.torch") mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") +meshlode_metatensor = pytest.importorskip("meshlode.metatensor") ATOMIC_SMEARING = 0.1 diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py index 09d95fa2..137d158f 100644 --- a/tests/metatensor/test_workflow_metatensor.py +++ b/tests/metatensor/test_workflow_metatensor.py @@ -7,9 +7,9 @@ from packaging import version -meshlode_metatensor = pytest.importorskip("meshlode.metatensor") mts_torch = pytest.importorskip("metatensor.torch") mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") +meshlode_metatensor = pytest.importorskip("meshlode.metatensor") ATOMIC_SMEARING = 0.1