From 961c082f9fcc012f4ea367c34d955f3d667941e8 Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Fri, 1 Dec 2023 14:20:52 +0100 Subject: [PATCH] Implement doc changes suggested by Philip --- src/meshlode/__init__.py | 3 ++- src/meshlode/calculators.py | 45 ++++++++++++++++++++----------------- tests/test_calculators.py | 3 ++- tests/test_madelung.py | 15 +++++++------ 4 files changed, 37 insertions(+), 29 deletions(-) diff --git a/src/meshlode/__init__.py b/src/meshlode/__init__.py index d4ccc1ea..83097d9e 100644 --- a/src/meshlode/__init__.py +++ b/src/meshlode/__init__.py @@ -5,6 +5,7 @@ Particle-mesh based calculation of Long Distance Equivariants. """ from .calculators import MeshPotential +from .system import System -__all__ = ["MeshPotential"] +__all__ = ["MeshPotential", "System"] __version__ = "0.0.0-dev" diff --git a/src/meshlode/calculators.py b/src/meshlode/calculators.py index af64d200..fbfdba60 100644 --- a/src/meshlode/calculators.py +++ b/src/meshlode/calculators.py @@ -19,7 +19,7 @@ from meshlode.system import System -def my_1d_tolist(x: torch.Tensor): +def _my_1d_tolist(x: torch.Tensor): """Auxilary function to convert torch tensor to list of integers""" result: List[int] = [] for i in x: @@ -32,19 +32,34 @@ class MeshPotential(torch.nn.Module): :param atomic_gaussian_width: Width of the atom-centered gaussian used to create the atomic density. - :type atomic_gaussian_width: float :param mesh_spacing: Value that determines the umber of Fourier-space grid points that will be used along each axis. - :type mesh_spacing: float :param interpolation_order: Interpolation order for mapping onto the grid, where an interpolation order of p corresponds to interpolation by a polynomial of degree p-1 (e.g. p=4 for cubic interpolation). - :type interpolation_order: int + :param subtract_self: bool. If set to true, subtract from the features of an atom + the contributions to the potential arising from that atom itself (but not the + periodic images). Example ------- - >>> calculator = MeshPotential(atomic_gaussian_width=1.0) + >>> import torch + >>> from meshlode import MeshPotential, System + + >>> # Define simple example structure having the CsCl structure + >>> positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) + >>> atomic_numbers = torch.tensor([55, 17]) # Cs and Cl + >>> frame = System(species=atomic_numbers, positions=positions, cell=torch.eye(3)) + + >>> # Compute features + >>> MP = MeshPotential( + ... atomic_gaussian_width=0.2, mesh_spacing=0.1, interpolation_order=4 + ... ) + >>> features = MP.compute(frame) + >>> keys = features.keys # print to see all species combinations + >>> block_ClCl = features.block({"species_center": 17, "species_neighbor": 17}) + >>> values = block_ClCl.values # the Cl-potential at the position of the Cl atom """ @@ -55,12 +70,14 @@ def __init__( atomic_gaussian_width: float, mesh_spacing: float = 0.2, interpolation_order: float = 4, + subtract_self: bool = False, ): super().__init__() self.atomic_gaussian_width = atomic_gaussian_width self.mesh_spacing = mesh_spacing self.interpolation_order = interpolation_order + self.subtract_self = subtract_self # This function is kept to keep MeshLODE compatible with the broader pytorch # infrastructure, which require a "forward" function. We name this function @@ -78,7 +95,6 @@ def forward( def compute( self, frames: Union[List[System], System], - subtract_self: bool = False, ) -> TensorMap: """Compute the potential at the position of each atom for all Systems provided in "frames". @@ -88,9 +104,6 @@ def compute( ``requires_grad`` set to :py:obj:`True`, then the corresponding gradients are computed and registered as a custom node in the computational graph, to allow backward propagation of the gradients later. - :param subtract_self: bool. If set to true, subtract from the features of an - atom i the contributions to the potential arising from the "center" atom - itself (but not the periodic images). :return: TensorMap containing the potential of all atoms. The keys of the tensormap are "species_center" and "species_neighbor". @@ -110,7 +123,7 @@ def compute( n_atoms_tot += len(frame) all_species.append(frame.species) all_species = torch.hstack(all_species) - atomic_numbers = my_1d_tolist(torch.unique(all_species)) + atomic_numbers = _my_1d_tolist(torch.unique(all_species)) n_species = len(atomic_numbers) # Initialize dictionary for sparse storage of the features @@ -126,9 +139,7 @@ def compute( charges[species == atomic_number, i_specie] = 1.0 # Compute the potentials - potential = self._compute_single_frame( - frame.cell, frame.positions, charges, subtract_self - ) + potential = self._compute_single_frame(frame.cell, frame.positions, charges) # Reorder data into Metatensor format for spec_center, at_num_center in enumerate(atomic_numbers): @@ -180,7 +191,6 @@ def _compute_single_frame( cell: torch.Tensor, positions: torch.Tensor, charges: torch.Tensor, - subtract_self: bool = False, ) -> torch.Tensor: """ Compute the "electrostatic" potential at the position of all atoms in a @@ -204,11 +214,6 @@ def _compute_single_frame( standard electrostatic potential in which Na and Cl have charges of +1 and -1, respectively. - :param subtract_self: bool. If set to true, the contribution to the potential of - the center atom itself is subtracted, meaning that only the potential - generated by the remaining atoms + periodic images of the center atom is - taken into account. - :returns: torch.tensor of shape (n_atoms, n_channels) containing the potential at the position of each atom for the n_channels independent meshes separately. """ @@ -248,7 +253,7 @@ def _compute_single_frame( interpolated_potential = MI.mesh_to_points(potential_mesh) # Remove self contribution - if subtract_self: + if self.subtract_self: self_contrib = torch.sqrt(torch.tensor(2.0 / torch.pi)) / smearing interpolated_potential -= charges * self_contrib diff --git a/tests/test_calculators.py b/tests/test_calculators.py index fb239fc2..870040e7 100644 --- a/tests/test_calculators.py +++ b/tests/test_calculators.py @@ -94,8 +94,9 @@ class TestMultiFrameToySystem: atomic_gaussian_width=atomic_gaussian_width, mesh_spacing=mesh_spacing, interpolation_order=interpolation_order, + subtract_self=False, ) - tensormaps_list.append(MP.compute(frames, subtract_self=False)) + tensormaps_list.append(MP.compute(frames)) @pytest.mark.parametrize("features", tensormaps_list) def test_tensormap_labels(self, features): diff --git a/tests/test_madelung.py b/tests/test_madelung.py index 461bdd38..95b12a52 100644 --- a/tests/test_madelung.py +++ b/tests/test_madelung.py @@ -119,10 +119,10 @@ def test_madelung_low_order( madelung = dic["madelung"] / scaling_factor mesh_spacing = smearing / 2 * scaling_factor smearing_eff = smearing * scaling_factor - MP = MeshPotential(smearing_eff, mesh_spacing, interpolation_order) - potentials_mesh = MP._compute_single_frame( - cell, positions, charges, subtract_self=True + MP = MeshPotential( + smearing_eff, mesh_spacing, interpolation_order, subtract_self=True ) + potentials_mesh = MP._compute_single_frame(cell, positions, charges) energies = potentials_mesh * charges energies_target = -torch.ones_like(energies) * madelung assert_close(energies, energies_target, rtol=1e-4, atol=1e-6) @@ -153,10 +153,10 @@ def test_madelung_high_order( madelung = dic["madelung"] / scaling_factor mesh_spacing = smearing / 10 * scaling_factor smearing_eff = smearing * scaling_factor - MP = MeshPotential(smearing_eff, mesh_spacing, interpolation_order) - potentials_mesh = MP._compute_single_frame( - cell, positions, charges, subtract_self=True + MP = MeshPotential( + smearing_eff, mesh_spacing, interpolation_order, subtract_self=True ) + potentials_mesh = MP._compute_single_frame(cell, positions, charges) energies = potentials_mesh * charges energies_target = -torch.ones_like(energies) * madelung assert_close(energies, energies_target, rtol=1e-2, atol=1e-3) @@ -191,8 +191,9 @@ def test_madelung_low_order_metatensor( atomic_gaussian_width=smearing_eff, mesh_spacing=mesh_spacing, interpolation_order=interpolation_order, + subtract_self=True, ) - potentials_mesh = MP.compute(frame, subtract_self=True) + potentials_mesh = MP.compute(frame) # Compute the actual potential from the features energies = torch.zeros((n_atoms, 1))