diff --git a/src/meshlode/calculators/meshpotential.py b/src/meshlode/calculators/meshpotential.py
index 6ab15748..0c43472b 100644
--- a/src/meshlode/calculators/meshpotential.py
+++ b/src/meshlode/calculators/meshpotential.py
@@ -108,15 +108,19 @@ def forward(
         types: Union[List[torch.Tensor], torch.Tensor],
         positions: Union[List[torch.Tensor], torch.Tensor],
         cell: Union[List[torch.Tensor], torch.Tensor],
+        charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
         """forward just calls :py:meth:`CalculatorModule.compute`"""
-        return self.compute(types=types, positions=positions, cell=cell)
+        return self.compute(
+            types=types, positions=positions, cell=cell, charges=charges
+        )
 
     def compute(
         self,
         types: Union[List[torch.Tensor], torch.Tensor],
         positions: Union[List[torch.Tensor], torch.Tensor],
         cell: Union[List[torch.Tensor], torch.Tensor],
+        charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
         """Compute potential for all provided "systems" stacked inside list.
 
@@ -131,6 +135,7 @@ def compute(
             box/unit cell of the system. Each row should be one of the bounding box
             vector; and columns should contain the x, y, and z components of these
             vectors (i.e. the cell should be given in row-major order).
+        :param charges: Optional single or list of 2D tensors of shape (len(types), n),
+            where n is the number of charge channels. If not given, each atomic type
+            is assigned its own one-hot encoded charge channel.
 
         :return: List of torch Tensors containing the potentials for all frames and
             all atoms. Each tensor in the list is of shape (n_atoms, n_types), where
@@ -154,6 +159,7 @@ def compute(
         if not isinstance(cell, list):
             cell = [cell]
 
+        # Check that all inputs are consistent
         for types_single, positions_single, cell_single in zip(types, positions, cell):
             if len(types_single.shape) != 1:
                 raise ValueError(
@@ -189,27 +195,59 @@ def compute(
                     f"{cell_single.device}."
                 )
 
-        # We don't require and test that all dtypes and devices are consistent if a list
-        # of inputs. Each "frame" is processed independently.
-        requested_types = self._get_requested_types(types)
-        n_types = len(requested_types)
-
-        potentials = []
-        for types_single, positions_single, cell_single in zip(types, positions, cell):
-            # One-hot encoding of charge information
-            charges = torch.zeros(
-                (len(types_single), n_types),
-                dtype=positions_single.dtype,
-                device=positions_single.device,
-            )
-            for i_type, atomic_type in enumerate(requested_types):
-                charges[types_single == atomic_type, i_type] = 1.0
+        # If charges are not provided, we assume that all types are treated separately
+        if charges is None:
+            requested_types = self._get_requested_types(types)
+            charges = []
+            for types_single, positions_single in zip(types, positions):
+                # One-hot encoding of charge information
+                charges_single = self._one_hot_charges(
+                    types=types_single,
+                    requested_types=requested_types,
+                    dtype=positions_single.dtype,
+                    device=positions_single.device,
+                )
+                charges.append(charges_single)
+
+        # If charges are provided, we need to make sure that they are consistent with
+        # the provided types
+        else:
+            if not isinstance(charges, list):
+                charges = [charges]
+            if len(charges) != len(types):
+                raise ValueError(
+                    "The number of `types` and `charges` tensors must be the same, "
+                    f"got {len(types)} and {len(charges)}."
+                )
+            for charges_single, types_single in zip(charges, types):
+                if charges_single.shape[0] != len(types_single):
+                    raise ValueError(
+                        "The first dimension of `charges` must be the same as the "
+                        f"length of `types`, got {charges_single.shape[0]} and "
+                        f"{len(types_single)}."
+                    )
+            if charges[0].dtype != positions[0].dtype:
+                raise ValueError(
+                    "`charges` must have the same dtype as `positions`, got "
+                    f"{charges[0].dtype} and {positions[0].dtype}."
+                )
+            if charges[0].device != positions[0].device:
+                raise ValueError(
+                    "`charges` must be on the same device as `positions`, got "
+                    f"{charges[0].device} and {positions[0].device}."
+                )
+
+        # We don't require and test that all dtypes and devices are consistent if a list
+        # of inputs. Each "frame" is processed independently.
+        potentials = []
+        for positions_single, cell_single, charges_single in zip(
+            positions, cell, charges
+        ):
             # Compute the potentials
             potentials.append(
                 self._compute_single_system(
-                    positions=positions_single, charges=charges, cell=cell_single
+                    positions=positions_single, charges=charges_single, cell=cell_single
                 )
             )
 
@@ -233,6 +271,21 @@ def _get_requested_types(self, types: List[torch.Tensor]) -> List[int]:
         else:
             return types_requested
 
+    def _one_hot_charges(
+        self,
+        types: torch.Tensor,
+        requested_types: List[int],
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[torch.device] = None,
+    ) -> torch.Tensor:
+        n_types = len(requested_types)
+        one_hot_charges = torch.zeros((len(types), n_types), dtype=dtype, device=device)
+
+        for i_type, atomic_type in enumerate(requested_types):
+            one_hot_charges[types == atomic_type, i_type] = 1.0
+
+        return one_hot_charges
+
     def _compute_single_system(
         self,
         positions: torch.Tensor,
diff --git a/src/meshlode/metatensor/meshpotential.py b/src/meshlode/metatensor/meshpotential.py
index 0c11f7e3..134d725b 100644
--- a/src/meshlode/metatensor/meshpotential.py
+++ b/src/meshlode/metatensor/meshpotential.py
@@ -75,14 +75,22 @@ def compute(
     ) -> TensorMap:
         """Compute potential for all provided ``systems``.
 
-        All ``systems`` must have the same ``dtype`` and the same ``device``.
+        All ``systems`` must have the same ``dtype`` and the same ``device``. If each
+        system contains a custom data field ``charges``, the potential will be
+        calculated for each "charges-channel". The number of charge channels must be
+        the same in all ``systems``. If no explicit charges are set, the potential
+        will be calculated for each "types-channel".
+
+        Refer to :meth:`meshlode.MeshPotential.compute()` for additional details on how
+        the "charges-channel" and "types-channel" potentials are computed.
 
         :param systems: single System or list of
             :py:class:`metatensor.torch.atomisic.System` on which to run the
-            calculation.
+            calculations.
 
         :return: TensorMap containing the potential of all types. The keys of the
-            tensormap are "center_type" and "neighbor_type".
+            TensorMap are "center_type" and "neighbor_type" if no charges are
+            associated with the ``systems``, and "center_type" and "charges_channel"
+            if explicit charges are provided.
         """
         # Make sure that the compute function also works if only a single frame is
         # provided as input (for convenience of users testing out the code)
@@ -111,38 +119,71 @@ def compute(
         )
         n_types = len(requested_types)
 
-        # Initialize dictionary for sparse storage of the features to map atomic types
-        # to array indices. In general, the types are sorted according to their
-        # (integer) type and assigned the array indices 0, 1, 2,... Example: for H2O:
-        # `H` is mapped to `0` and `O` is mapped to `1`.
-        n_types_sq = n_types * n_types
-        feat_dic: Dict[int, List[torch.Tensor]] = {a: [] for a in range(n_types_sq)}
+        has_charges = torch.tensor(["charges" in s.known_data() for s in systems])
+        all_charges = torch.all(has_charges)
+        any_charges = torch.any(has_charges)
+
+        if any_charges and not all_charges:
+            raise ValueError("`systems` do not consistently contain `charges` data")
+        if all_charges:
+            use_explicit_charges = True
+            n_charges_channels = systems[0].get_data("charges").values.shape[1]
+            spec_channels = list(range(n_charges_channels))
+            key_names = ["center_type", "charges_channel"]
+
+            for i_system, system in enumerate(systems):
+                n_channels = system.get_data("charges").values.shape[1]
+                if n_channels != n_charges_channels:
+                    raise ValueError(
+                        f"number of charges-channels in system index {i_system} "
+                        f"({n_channels}) is inconsistent with first system "
+                        f"({n_charges_channels})"
+                    )
+        else:
+            # Use one hot encoded type channel per species for charges channel
+            use_explicit_charges = False
+            n_charges_channels = n_types
+            spec_channels = requested_types
+            key_names = ["center_type", "neighbor_type"]
+
+        # Initialize dictionary for TensorBlock storage.
+        #
+        # If `use_explicit_charges=False`, the blocks are sorted according to the
+        # (integer) center_type and neighbor_type. Blocks are assigned the array
+        # indices 0, 1, 2,... Example: for H2O: `H` is mapped to `0` and `O` is
+        # mapped to `1`.
+        #
+        # For `use_explicit_charges=True` the blocks are stored according to the
+        # center_type and charge_channel.
+        n_blocks = n_types * n_charges_channels
+        feat_dic: Dict[int, List[torch.Tensor]] = {a: [] for a in range(n_blocks)}
 
         for system in systems:
-            # One-hot encoding of charge information
-            types = system.types
-            charges = torch.zeros((len(system), n_types), dtype=dtype, device=device)
-            for i_specie, atomic_type in enumerate(requested_types):
-                charges[types == atomic_type, i_specie] = 1.0
+            if use_explicit_charges:
+                charges = system.get_data("charges").values
+            else:
+                # One-hot encoding of charge information
+                charges = self._one_hot_charges(
+                    system.types, requested_types, dtype, device
+                )
 
             # Compute the potentials
             potential = self._compute_single_system(
                 system.positions, charges, system.cell
             )
 
-            # Reorder data into Metatensor format
+            # Reorder data into metatensor format
             for spec_center, at_num_center in enumerate(requested_types):
-                for spec_neighbor in range(len(requested_types)):
-                    a_pair = spec_center * n_types + spec_neighbor
+                for spec_channel in range(len(spec_channels)):
+                    a_pair = spec_center * n_charges_channels + spec_channel
                     feat_dic[a_pair] += [
-                        potential[types == at_num_center, spec_neighbor]
+                        potential[system.types == at_num_center, spec_channel]
                     ]
 
         # Assemble all computed potential values into TensorBlocks for each combination
-        # of center_type and neighbor_type
+        # of center_type and neighbor_type/charge_channel
         blocks: List[TensorBlock] = []
         for keys, values in feat_dic.items():
-            spec_center = requested_types[keys // n_types]
+            spec_center = requested_types[keys // n_charges_channels]
 
             # Generate the Labels objects for the samples and properties of the
             # TensorBlock.
@@ -176,14 +217,17 @@ def compute(
 
             blocks.append(block)
 
+        assert len(blocks) == n_blocks
+
         # Generate TensorMap from TensorBlocks by defining suitable keys
         key_values: List[torch.Tensor] = []
         for spec_center in requested_types:
-            for spec_neighbor in requested_types:
+            for spec_channel in spec_channels:
                 key_values.append(
-                    torch.tensor([spec_center, spec_neighbor], device=device)
+                    torch.tensor([spec_center, spec_channel], device=device)
                 )
         key_values = torch.vstack(key_values)
-        labels_keys = Labels(["center_type", "neighbor_type"], key_values)
+
+        labels_keys = Labels(key_names, key_values)
 
         return TensorMap(keys=labels_keys, blocks=blocks)
diff --git a/tests/calculators/test_meshpotential.py b/tests/calculators/test_meshpotential.py
index 53663a0a..58e73808 100644
--- a/tests/calculators/test_meshpotential.py
+++ b/tests/calculators/test_meshpotential.py
@@ -24,6 +24,12 @@ def cscl_system():
     return types, positions, cell
 
 
+def cscl_system_with_charges():
+    """CsCl crystal with charges."""
+    charges = torch.tensor([[0.0, 1.0], [1.0, 0]])
+    return cscl_system() + (charges,)
+
+
 # Initialize the calculators. For now, only the MeshPotential is implemented.
 def descriptor() -> MeshPotential:
     atomic_smearing = 0.1
@@ -86,7 +92,17 @@ def test_operation_as_torch_script():
 
 def test_single_frame():
     values = descriptor().compute(*cscl_system())
-    print(values)
+    assert_close(
+        MADELUNG_CSCL,
+        CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1],
+        atol=1e4,
+        rtol=1e-5,
+    )
+
+
+# Test with explicit charges
+def test_single_frame_with_charges():
+    values = descriptor().compute(*cscl_system_with_charges())
     assert_close(
         MADELUNG_CSCL,
         CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1],
@@ -138,6 +154,36 @@ def test_positions_error():
     descriptor().compute(types=types, positions=positions, cell=cell)
 
 
+def test_charges_error_dimension_mismatch():
+    types = torch.tensor([1, 2])
+    positions = torch.zeros((2, 3))
+    cell = torch.eye(3)
+    charges = torch.zeros((1, 2))  # This should have the same first dimension as types
+
+    match = (
+        "The first dimension of `charges` must be the same as the length "
+        "of `types`, got 1 and 2."
+    )
+
+    with pytest.raises(ValueError, match=match):
+        descriptor().compute(
+            types=types, positions=positions, cell=cell, charges=charges
+        )
+
+
+def test_charges_error_length_mismatch():
+    types = [torch.tensor([1, 2]), torch.tensor([1, 2, 3])]
+    positions = [torch.zeros((2, 3)), torch.zeros((3, 3))]
+    cell = torch.eye(3)
+    charges = [torch.zeros(2, 1)]  # This should have the same length as types
+    match = "The number of `types` and `charges` tensors must be the same, got 2 and 1."
+
+    with pytest.raises(ValueError, match=match):
+        descriptor().compute(
+            types=types, positions=positions, cell=cell, charges=charges
+        )
+
+
 def test_cell_error():
     types = torch.tensor([1, 2, 3])
     positions = torch.zeros((3, 3))
@@ -212,6 +258,37 @@ def test_inconsistent_device():
         MP.compute(types=types, positions=positions, cell=cell)
 
 
+def test_inconsistent_device_charges():
+    """Test that an error is raised if charges and positions are on different devices."""
+    types = torch.tensor([1], device="cpu")
+    positions = torch.tensor([[0.0, 0.0, 0.0]], device="cpu")
+    cell = torch.eye(3, device="cpu")
+    charges = torch.tensor([0.0], device="meta")  # different device
+
+    MP = MeshPotential(atomic_smearing=0.2)
+
+    match = "`charges` must be on the same device as `positions`, got meta and cpu."
+    with pytest.raises(ValueError, match=match):
+        MP.compute(types=types, positions=positions, cell=cell, charges=charges)
+
+
+def test_inconsistent_dtype_charges():
+    """Test that an error is raised if charges and positions have different dtypes."""
+    types = torch.tensor([1], dtype=torch.float32)
+    positions = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float32)
+    cell = torch.eye(3, dtype=torch.float32)
+    charges = torch.tensor([0.0], dtype=torch.float64)  # Different dtype
+
+    MP = MeshPotential(atomic_smearing=0.2)
+
+    match = (
+        "`charges` must have the same dtype as `positions`, got torch.float64 and "
+        "torch.float32"
+    )
+    with pytest.raises(ValueError, match=match):
+        MP.compute(types=types, positions=positions, cell=cell, charges=charges)
+
+
 def test_1d_tolist():
     in_list = [1, 2, 7, 3, 4, 42]
     in_tensor = torch.tensor(in_list)
diff --git a/tests/metatensor/test_metatensor_meshpotential.py b/tests/metatensor/test_metatensor_meshpotential.py
index 77260b6c..a373c766 100644
--- a/tests/metatensor/test_metatensor_meshpotential.py
+++ b/tests/metatensor/test_metatensor_meshpotential.py
@@ -28,6 +28,54 @@ def toy_system_single_frame(
     )
 
 
+def toy_system_single_frame_charges():
+    system = toy_system_single_frame()
+
+    # Create a system with hand-written one-hot charges
+    charges = torch.tensor([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
+
+    # Create a metatensor.TensorBlock for the charges and add it to the system
+    samples = metatensor_torch.Labels("atom", torch.arange(len(system)).reshape(-1, 1))
+    properties = metatensor_torch.Labels(
+        "charge", torch.arange(charges.shape[1]).reshape(-1, 1)
+    )
+
+    charges_block = metatensor_torch.TensorBlock(
+        samples=samples,
+        components=[],
+        properties=properties,
+        values=charges,
+    )
+
+    system.add_data("charges", charges_block)
+
+    return system
+
+
+def toy_system_single_frame_charges_arbitrary_charges():
+    system = toy_system_single_frame()
+
+    # Create a system with random charges with 4 samples and 5 channels
+    charges = torch.rand(4, 5)
+
+    # Create a metatensor.TensorBlock for the charges and add it to the system
+    samples = metatensor_torch.Labels("atom", torch.arange(len(system)).reshape(-1, 1))
+    properties = metatensor_torch.Labels(
+        "charge", torch.arange(charges.shape[1]).reshape(-1, 1)
+    )
+
+    charges_block = metatensor_torch.TensorBlock(
+        samples=samples,
+        components=[],
+        properties=properties,
+        values=charges,
+    )
+
+    system.add_data("charges", charges_block)
+
+    return system
+
+
 # Initialize the calculators. For now, only the meshlode_metatensor.MeshPotential is
 # implemented.
 def descriptor() -> meshlode_metatensor.MeshPotential:
@@ -96,6 +144,94 @@ def test_wrong_device_between_systems():
     )
 
 
+def test_explicit_charges():
+    mp = descriptor()
+    potential = mp.compute(toy_system_single_frame())
+    potential_charges = mp.compute(toy_system_single_frame_charges())
+
+    # Test metadata
+    assert potential_charges.keys.names == ["center_type", "charges_channel"]
+    assert torch.all(
+        potential_charges.keys.values == torch.tensor([[1, 0], [1, 1], [8, 0], [8, 1]])
+    )
+
+    # Test values
+    for block, block_charges in zip(potential, potential_charges):
+        assert block_charges.samples == block.samples
+        assert block_charges.components == block.components
+        assert block_charges.properties == block.properties
+        assert torch.all(block_charges.values == block.values)
+
+
+def test_explicit_arbitrarycharges():
+    mp = descriptor()
+    potential_charges = mp.compute(toy_system_single_frame_charges_arbitrary_charges())
+
+    # Test metadata
+    assert potential_charges.keys.names == ["center_type", "charges_channel"]
+    assert torch.all(
+        potential_charges.keys.values
+        == torch.tensor(
+            [
+                [1, 0],
+                [1, 1],
+                [1, 2],
+                [1, 3],
+                [1, 4],
+                [8, 0],
+                [8, 1],
+                [8, 2],
+                [8, 3],
+                [8, 4],
+            ]
+        )
+    )
+
+
+def test_error_raise_charges_no_charges():
+    systems = [toy_system_single_frame(), toy_system_single_frame_charges()]
+    match = "`systems` do not consistently contain `charges` data"
+
+    with pytest.raises(ValueError, match=match):
+        descriptor().compute(systems)
+
+
+def test_error_raise_charge_shape():
+    system = toy_system_single_frame()
+
+    # Create a system with hand-written charges that have three channels
+    charges = torch.tensor(
+        [[1.0, 0.0, 2.0], [1.0, 0.0, 2.0], [0.0, 1.0, 2.0], [0.0, 1.0, 2.0]]
+    )
+
+    # Create a metatensor.TensorBlock for the charges and add it to the system
+    samples = metatensor_torch.Labels(
+        "atom", torch.arange(charges.shape[0]).reshape(-1, 1)
+    )
+    properties = metatensor_torch.Labels(
+        "charge", torch.arange(charges.shape[1]).reshape(-1, 1)
+    )
+
+    charges_block = metatensor_torch.TensorBlock(
+        samples=samples,
+        components=[],
+        properties=properties,
+        values=charges,
+    )
+
+    system.add_data("charges", charges_block)
+
+    systems = [system, toy_system_single_frame_charges()]
+
+    match = (
+        r"number of charges-channels in system index 1 \(2\) is inconsistent with "
+        r"first system \(3\)"
+    )
+
+    with pytest.raises(ValueError, match=match):
+        descriptor().compute(systems)
+
+
 # Make sure that the calculators are computing the features without raising errors,
 # and returns the correct output format (TensorMap)
 def check_operation(calculator):
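
A minimal usage sketch of the new `charges` argument on the torch-level calculator, assuming only the API shown in this diff (`MeshPotential(atomic_smearing=...)` and `compute(types, positions, cell, charges)`); the atoms, cell, and charge values below are illustrative, not taken from the test data:

```python
import torch

from meshlode import MeshPotential

# Two atoms in a cubic cell (illustrative values).
types = torch.tensor([55, 17])
positions = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]])
cell = torch.eye(3)

calculator = MeshPotential(atomic_smearing=0.2)

# Explicit charges: one row per atom, one column per charge channel.
charges = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
potential = calculator.compute(
    types=types, positions=positions, cell=cell, charges=charges
)
# `potential` has shape (n_atoms, n_channels), here (2, 2).

# Without `charges`, each atomic type is one-hot encoded as its own charge
# channel, reproducing the previous behaviour.
potential_one_hot = calculator.compute(types=types, positions=positions, cell=cell)
```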