Skip to content

Commit

Permalink
Implement doc changes suggested by Philip
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Kazuki Huguenin-Dumittan committed Dec 1, 2023
1 parent 227588e commit 961c082
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 29 deletions.
3 changes: 2 additions & 1 deletion src/meshlode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Particle-mesh based calculation of Long Distance Equivariants.
"""
from .calculators import MeshPotential
from .system import System

__all__ = ["MeshPotential"]
__all__ = ["MeshPotential", "System"]
__version__ = "0.0.0-dev"
45 changes: 25 additions & 20 deletions src/meshlode/calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from meshlode.system import System


def my_1d_tolist(x: torch.Tensor):
def _my_1d_tolist(x: torch.Tensor):
"""Auxilary function to convert torch tensor to list of integers"""
result: List[int] = []
for i in x:
Expand All @@ -32,19 +32,34 @@ class MeshPotential(torch.nn.Module):
:param atomic_gaussian_width: Width of the atom-centered gaussian used to create the
atomic density.
:type atomic_gaussian_width: float
:param mesh_spacing: Value that determines the umber of Fourier-space grid points
that will be used along each axis.
:type mesh_spacing: float
:param interpolation_order: Interpolation order for mapping onto the grid, where an
interpolation order of p corresponds to interpolation by a polynomial of degree
p-1 (e.g. p=4 for cubic interpolation).
:type interpolation_order: int
:param subtract_self: bool. If set to true, subtract from the features of an atom
the contributions to the potential arising from that atom itself (but not the
periodic images).
Example
-------
>>> calculator = MeshPotential(atomic_gaussian_width=1.0)
>>> import torch
>>> from meshlode import MeshPotential, System
>>> # Define simple example structure having the CsCl structure
>>> positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]])
>>> atomic_numbers = torch.tensor([55, 17]) # Cs and Cl
>>> frame = System(species=atomic_numbers, positions=positions, cell=torch.eye(3))
>>> # Compute features
>>> MP = MeshPotential(
... atomic_gaussian_width=0.2, mesh_spacing=0.1, interpolation_order=4
... )
>>> features = MP.compute(frame)
>>> keys = features.keys # print to see all species combinations
>>> block_ClCl = features.block({"species_center": 17, "species_neighbor": 17})
>>> values = block_ClCl.values # the Cl-potential at the position of the Cl atom
"""

Expand All @@ -55,12 +70,14 @@ def __init__(
atomic_gaussian_width: float,
mesh_spacing: float = 0.2,
interpolation_order: float = 4,
subtract_self: bool = False,
):
super().__init__()

self.atomic_gaussian_width = atomic_gaussian_width
self.mesh_spacing = mesh_spacing
self.interpolation_order = interpolation_order
self.subtract_self = subtract_self

# This function is kept to keep MeshLODE compatible with the broader pytorch
# infrastructure, which require a "forward" function. We name this function
Expand All @@ -78,7 +95,6 @@ def forward(
def compute(
self,
frames: Union[List[System], System],
subtract_self: bool = False,
) -> TensorMap:
"""Compute the potential at the position of each atom for all Systems provided
in "frames".
Expand All @@ -88,9 +104,6 @@ def compute(
``requires_grad`` set to :py:obj:`True`, then the corresponding gradients
are computed and registered as a custom node in the computational graph, to
allow backward propagation of the gradients later.
:param subtract_self: bool. If set to true, subtract from the features of an
atom i the contributions to the potential arising from the "center" atom
itself (but not the periodic images).
:return: TensorMap containing the potential of all atoms. The keys of the
tensormap are "species_center" and "species_neighbor".
Expand All @@ -110,7 +123,7 @@ def compute(
n_atoms_tot += len(frame)
all_species.append(frame.species)
all_species = torch.hstack(all_species)
atomic_numbers = my_1d_tolist(torch.unique(all_species))
atomic_numbers = _my_1d_tolist(torch.unique(all_species))
n_species = len(atomic_numbers)

# Initialize dictionary for sparse storage of the features
Expand All @@ -126,9 +139,7 @@ def compute(
charges[species == atomic_number, i_specie] = 1.0

# Compute the potentials
potential = self._compute_single_frame(
frame.cell, frame.positions, charges, subtract_self
)
potential = self._compute_single_frame(frame.cell, frame.positions, charges)

# Reorder data into Metatensor format
for spec_center, at_num_center in enumerate(atomic_numbers):
Expand Down Expand Up @@ -180,7 +191,6 @@ def _compute_single_frame(
cell: torch.Tensor,
positions: torch.Tensor,
charges: torch.Tensor,
subtract_self: bool = False,
) -> torch.Tensor:
"""
Compute the "electrostatic" potential at the position of all atoms in a
Expand All @@ -204,11 +214,6 @@ def _compute_single_frame(
standard electrostatic potential in which Na and Cl have charges of +1 and
-1, respectively.
:param subtract_self: bool. If set to true, the contribution to the potential of
the center atom itself is subtracted, meaning that only the potential
generated by the remaining atoms + periodic images of the center atom is
taken into account.
:returns: torch.tensor of shape (n_atoms, n_channels) containing the potential
at the position of each atom for the n_channels independent meshes separately.
"""
Expand Down Expand Up @@ -248,7 +253,7 @@ def _compute_single_frame(
interpolated_potential = MI.mesh_to_points(potential_mesh)

# Remove self contribution
if subtract_self:
if self.subtract_self:
self_contrib = torch.sqrt(torch.tensor(2.0 / torch.pi)) / smearing
interpolated_potential -= charges * self_contrib

Expand Down
3 changes: 2 additions & 1 deletion tests/test_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,9 @@ class TestMultiFrameToySystem:
atomic_gaussian_width=atomic_gaussian_width,
mesh_spacing=mesh_spacing,
interpolation_order=interpolation_order,
subtract_self=False,
)
tensormaps_list.append(MP.compute(frames, subtract_self=False))
tensormaps_list.append(MP.compute(frames))

@pytest.mark.parametrize("features", tensormaps_list)
def test_tensormap_labels(self, features):
Expand Down
15 changes: 8 additions & 7 deletions tests/test_madelung.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,10 @@ def test_madelung_low_order(
madelung = dic["madelung"] / scaling_factor
mesh_spacing = smearing / 2 * scaling_factor
smearing_eff = smearing * scaling_factor
MP = MeshPotential(smearing_eff, mesh_spacing, interpolation_order)
potentials_mesh = MP._compute_single_frame(
cell, positions, charges, subtract_self=True
MP = MeshPotential(
smearing_eff, mesh_spacing, interpolation_order, subtract_self=True
)
potentials_mesh = MP._compute_single_frame(cell, positions, charges)
energies = potentials_mesh * charges
energies_target = -torch.ones_like(energies) * madelung
assert_close(energies, energies_target, rtol=1e-4, atol=1e-6)
Expand Down Expand Up @@ -153,10 +153,10 @@ def test_madelung_high_order(
madelung = dic["madelung"] / scaling_factor
mesh_spacing = smearing / 10 * scaling_factor
smearing_eff = smearing * scaling_factor
MP = MeshPotential(smearing_eff, mesh_spacing, interpolation_order)
potentials_mesh = MP._compute_single_frame(
cell, positions, charges, subtract_self=True
MP = MeshPotential(
smearing_eff, mesh_spacing, interpolation_order, subtract_self=True
)
potentials_mesh = MP._compute_single_frame(cell, positions, charges)
energies = potentials_mesh * charges
energies_target = -torch.ones_like(energies) * madelung
assert_close(energies, energies_target, rtol=1e-2, atol=1e-3)
Expand Down Expand Up @@ -191,8 +191,9 @@ def test_madelung_low_order_metatensor(
atomic_gaussian_width=smearing_eff,
mesh_spacing=mesh_spacing,
interpolation_order=interpolation_order,
subtract_self=True,
)
potentials_mesh = MP.compute(frame, subtract_self=True)
potentials_mesh = MP.compute(frame)

# Compute the actual potential from the features
energies = torch.zeros((n_atoms, 1))
Expand Down

0 comments on commit 961c082

Please sign in to comment.