Skip to content

Commit

Permalink
fix tests-min
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jul 10, 2024
1 parent 9be73a6 commit 19134e5
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 31 deletions.
59 changes: 30 additions & 29 deletions tests/metatensor/test_base_metatensor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import pytest
import torch
from metatensor.torch import Labels, TensorBlock
from metatensor.torch.atomistic import System
from packaging import version

from meshlode.metatensor.base import CalculatorBaseMetatensor

mts_torch = pytest.importorskip("metatensor.torch")
mts_atomistic = pytest.importorskip("metatensor.torch.atomistic")
meshlode_metatensor = pytest.importorskip("meshlode.metatensor")

class CalculatorTest(CalculatorBaseMetatensor):

class CalculatorTest(meshlode_metatensor.base.CalculatorBaseMetatensor):
def _compute_single_system(
self, positions, charges, cell, neighbor_indices, neighbor_shifts
):
Expand All @@ -16,18 +17,18 @@ def _compute_single_system(

@pytest.mark.parametrize("method_name", ["compute", "forward"])
def test_compute_output_shapes_single(method_name):
system = System(
system = mts_atomistic.System(
types=torch.tensor([1, 1]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]),
cell=torch.zeros([3, 3]),
)

charges = torch.tensor([1.0, -1.0]).reshape(-1, 1)
data = TensorBlock(
data = mts_torch.TensorBlock(
values=charges,
samples=Labels.range("atom", charges.shape[0]),
samples=mts_torch.Labels.range("atom", charges.shape[0]),
components=[],
properties=Labels.range("charge", charges.shape[1]),
properties=mts_torch.Labels.range("charge", charges.shape[1]),
)

system.add_data(name="charges", data=data)
Expand All @@ -50,18 +51,18 @@ def test_compute_output_shapes_single(method_name):

def test_compute_output_shapes_multiple():

system = System(
system = mts_atomistic.System(
types=torch.tensor([1, 1]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]),
cell=torch.zeros([3, 3]),
)

charges = torch.tensor([1.0, -1.0]).reshape(-1, 1)
data = TensorBlock(
data = mts_torch.TensorBlock(
values=charges,
samples=Labels.range("atom", charges.shape[0]),
samples=mts_torch.Labels.range("atom", charges.shape[0]),
components=[],
properties=Labels.range("charge", charges.shape[1]),
properties=mts_torch.Labels.range("charge", charges.shape[1]),
)

system.add_data(name="charges", data=data)
Expand All @@ -82,13 +83,13 @@ def test_compute_output_shapes_multiple():


def test_wrong_system_dtype():
system1 = System(
system1 = mts_atomistic.System(
types=torch.tensor([1, 1]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]),
cell=torch.zeros([3, 3]),
)

system2 = System(
system2 = mts_atomistic.System(
types=torch.tensor([1, 1]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]], dtype=torch.float64),
cell=torch.zeros([3, 3], dtype=torch.float64),
Expand All @@ -102,13 +103,13 @@ def test_wrong_system_dtype():


def test_wrong_system_device():
system1 = System(
system1 = mts_atomistic.System(
types=torch.tensor([1, 1]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]),
cell=torch.zeros([3, 3]),
)

system2 = System(
system2 = mts_atomistic.System(
types=torch.tensor([1, 1], device="meta"),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]], device="meta"),
cell=torch.zeros([3, 3], device="meta"),
Expand All @@ -122,23 +123,23 @@ def test_wrong_system_device():


def test_wrong_system_not_all_charges():
system1 = System(
system1 = mts_atomistic.System(
types=torch.tensor([1, 1]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]),
cell=torch.zeros([3, 3]),
)

charges = torch.tensor([1.0, -1.0]).reshape(-1, 1)
data = TensorBlock(
data = mts_torch.TensorBlock(
values=charges,
samples=Labels.range("atom", charges.shape[0]),
samples=mts_torch.Labels.range("atom", charges.shape[0]),
components=[],
properties=Labels.range("charge", charges.shape[1]),
properties=mts_torch.Labels.range("charge", charges.shape[1]),
)

system1.add_data(name="charges", data=data)

system2 = System(
system2 = mts_atomistic.System(
types=torch.tensor(
[1, 1],
),
Expand All @@ -154,34 +155,34 @@ def test_wrong_system_not_all_charges():


def test_different_number_charge_channles():
system1 = System(
system1 = mts_atomistic.System(
types=torch.tensor([1, 1]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]),
cell=torch.zeros([3, 3]),
)

charges1 = torch.tensor([1.0, -1.0]).reshape(-1, 1)
data1 = TensorBlock(
data1 = mts_torch.TensorBlock(
values=charges1,
samples=Labels.range("atom", charges1.shape[0]),
samples=mts_torch.Labels.range("atom", charges1.shape[0]),
components=[],
properties=Labels.range("charge", charges1.shape[1]),
properties=mts_torch.Labels.range("charge", charges1.shape[1]),
)

system1.add_data(name="charges", data=data1)

system2 = System(
system2 = mts_atomistic.System(
types=torch.tensor([1, 1]),
positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]),
cell=torch.zeros([3, 3]),
)

charges2 = torch.tensor([[1.0, 2.0], [-1.0, -2.0]])
data2 = TensorBlock(
data2 = mts_torch.TensorBlock(
values=charges2,
samples=Labels.range("atom", charges2.shape[0]),
samples=mts_torch.Labels.range("atom", charges2.shape[0]),
components=[],
properties=Labels.range("charge", charges2.shape[1]),
properties=mts_torch.Labels.range("charge", charges2.shape[1]),
)
system2.add_data(name="charges", data=data2)

Expand Down
2 changes: 1 addition & 1 deletion tests/metatensor/test_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from packaging import version


meshlode_metatensor = pytest.importorskip("meshlode.metatensor")
mts_torch = pytest.importorskip("metatensor.torch")
mts_atomistic = pytest.importorskip("metatensor.torch.atomistic")
meshlode_metatensor = pytest.importorskip("meshlode.metatensor")


ATOMIC_SMEARING = 0.1
Expand Down
2 changes: 1 addition & 1 deletion tests/metatensor/test_workflow_metatensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from packaging import version


meshlode_metatensor = pytest.importorskip("meshlode.metatensor")
mts_torch = pytest.importorskip("metatensor.torch")
mts_atomistic = pytest.importorskip("metatensor.torch.atomistic")
meshlode_metatensor = pytest.importorskip("meshlode.metatensor")


ATOMIC_SMEARING = 0.1
Expand Down

0 comments on commit 19134e5

Please sign in to comment.