diff --git a/src/meshlode/metatensor/base.py b/src/meshlode/metatensor/base.py index 57dbb0c7..92ac8b74 100644 --- a/src/meshlode/metatensor/base.py +++ b/src/meshlode/metatensor/base.py @@ -93,8 +93,9 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: """ systems = self._validate_compute_parameters(systems) potentials: List[torch.Tensor] = [] + values_samples: List[List[int]] = [] - for system in systems: + for i_system, system in enumerate(systems): charges = system.get_data("charges").values all_neighbor_lists = system.known_neighbor_lists() if all_neighbor_lists: @@ -137,11 +138,8 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: neighbor_shifts=neighbor_shifts, ) ) - system = systems[-1] - values_samples: List[List[int]] = [] - for i_system in range(len(systems)): - for i_atom in range(len(system)): - values_samples.append([i_system, i_atom]) + + values_samples += [[i_system, i_atom] for i_atom in range(len(system))] samples_values = torch.tensor(values_samples, device=self._device) properties_values = torch.arange(self._n_charges_channels, device=self._device) diff --git a/tests/metatensor/test_base_metatensor.py b/tests/metatensor/test_base_metatensor.py index 623de5d5..3943fd14 100644 --- a/tests/metatensor/test_base_metatensor.py +++ b/tests/metatensor/test_base_metatensor.py @@ -300,3 +300,42 @@ def test_neighborlist_half_error(): match = r"Found 1 neighbor list\(s\) but no full list, which is required." with pytest.raises(ValueError, match=match): calculator.compute(system) + + +def test_systems_with_different_number_of_atoms(): + system1 = mts_atomistic.System( + types=torch.tensor([1, 1, 8]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0], [0.0, 2.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + charges1 = torch.tensor([1.0, -1.0, 2.0]).reshape(-1, 1) + data1 = mts_torch.TensorBlock( + values=charges1, + samples=mts_torch.Labels.range("atom", charges1.shape[0]), + components=[], + properties=mts_torch.Labels.range("charge", charges1.shape[1]), + ) + + system1.add_data(name="charges", data=data1) + + system2 = mts_atomistic.System( + types=torch.tensor( + [1, 1], + ), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.zeros([3, 3]), + ) + + charges2 = torch.tensor([1.0, -1.0]).reshape(-1, 1) + data2 = mts_torch.TensorBlock( + values=charges2, + samples=mts_torch.Labels.range("atom", charges2.shape[0]), + components=[], + properties=mts_torch.Labels.range("charge", charges2.shape[1]), + ) + + system2.add_data(name="charges", data=data2) + calculator = CalculatorTest() + + calculator.compute([system1, system2])