Skip to content

Commit

Permalink
Add shape tests, example and code blocks in docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Kazuki Huguenin-Dumittan committed Dec 1, 2023
1 parent 9cca2e0 commit 74b87df
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 63 deletions.
56 changes: 31 additions & 25 deletions src/meshlode/calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
Our calculator API follows the `rascaline <https://luthaf.fr/rascaline>`_ API and coding
guidelines to promote usability and interoperability with existing workflows.
"""
from typing import Dict, List, Union
from typing import Dict, List, Union, Optional

import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
Expand All @@ -30,10 +30,11 @@ def _my_1d_tolist(x: torch.Tensor):
class MeshPotential(torch.nn.Module):
"""A species wise long range potential.
:param atomic_gaussian_width: Width of the atom-centered gaussian used to create the
:param atomic_smearing: Width of the atom-centered gaussian used to create the
atomic density.
:param mesh_spacing: Value that determines the umber of Fourier-space grid points
that will be used along each axis.
that will be used along each axis. If set to ``None``, it will automatically
be set to half of ``atomic_smearing``
:param interpolation_order: Interpolation order for mapping onto the grid, where an
interpolation order of p corresponds to interpolation by a polynomial of degree
``p-1`` (e.g. ``p=4`` for cubic interpolation).
Expand All @@ -47,17 +48,15 @@ class MeshPotential(torch.nn.Module):
>>> import torch
>>> from meshlode import MeshPotential, System
Define simple example structure having the CsCl structure
Define simple example structure having the CsCl (Cesium Chloride) structure
>>> positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]])
>>> atomic_numbers = torch.tensor([55, 17]) # Cs and Cl
>>> frame = System(species=atomic_numbers, positions=positions, cell=torch.eye(3))
Compute features
>>> MP = MeshPotential(
... atomic_gaussian_width=0.2, mesh_spacing=0.1, interpolation_order=4
... )
>>> MP = MeshPotential(atomic_smearing=0.2, mesh_spacing=0.1, interpolation_order=4)
>>> features = MP.compute(frame)
All species combinations
Expand All @@ -70,10 +69,10 @@ class MeshPotential(torch.nn.Module):
55 17
55 55
)
>>> block_ClCl = features.block({"species_center": 17, "species_neighbor": 17})
The Cl-potential at the position of the Cl atom
>>> block_ClCl = features.block({"species_center": 17, "species_neighbor": 17})
>>> block_ClCl.values
tensor([[1.3755]])
Expand All @@ -83,14 +82,26 @@ class MeshPotential(torch.nn.Module):

def __init__(
self,
atomic_gaussian_width: float,
mesh_spacing: float = 0.2,
interpolation_order: float = 4,
subtract_self: bool = False,
atomic_smearing: float,
mesh_spacing: Optional[float] = None,
interpolation_order: Optional[float] = 4,
subtract_self: Optional[bool] = False,
):
super().__init__()

self.atomic_gaussian_width = atomic_gaussian_width
# Check that all provided values are correct
if interpolation_order not in [1, 2, 3, 4, 5]:
raise ValueError("Only `interpolation_order` from 1 to 5 are allowed")
if atomic_smearing <= 0:
raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive")

# If no explicit mesh_spacing is given, set it such that it can resolve
# the smeared potentials.
if mesh_spacing is None:
mesh_spacing = atomic_smearing / 2

# Store provided parameters
self.atomic_smearing = atomic_smearing
self.mesh_spacing = mesh_spacing
self.interpolation_order = interpolation_order
self.subtract_self = subtract_self
Expand All @@ -105,7 +116,6 @@ def forward(
"""forward just calls :py:meth:`CalculatorModule.compute`"""
res = self.compute(frames=systems)
return res
# return 0.

def compute(
self,
Expand Down Expand Up @@ -211,12 +221,12 @@ def _compute_single_frame(
Compute the "electrostatic" potential at the position of all atoms in a
structure.
:param cell: torch.tensor of shape (3,3). Describes the unit cell of the
:param cell: torch.tensor of shape `(3,3)`. Describes the unit cell of the
structure, where cell[i] is the i-th basis vector.
:param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian
coordinates of the atoms. The implementation also works if the positions
are not contained within the unit cell.
:param charges: torch.tensor of shape (n_atoms, n_channels). In the simplest
:param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest
case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the
charge of atom i. More generally, the potential for the same atom positions
is computed for n_channels independent meshes, and one can specify the
Expand All @@ -229,23 +239,19 @@ def _compute_single_frame(
standard electrostatic potential in which Na and Cl have charges of +1 and
-1, respectively.
:returns: torch.tensor of shape (n_atoms, n_channels) containing the potential
at the position of each atom for the n_channels independent meshes separately.
:returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential
at the position of each atom for the `n_channels` independent meshes separately.
"""
smearing = self.atomic_gaussian_width
mesh_resolution = self.mesh_spacing
smearing = self.atomic_smearing
interpolation_order = self.interpolation_order

# Initializations
n_atoms = len(positions)
assert positions.shape == (n_atoms, 3)
assert charges.shape[0] == n_atoms

# Define k-vectors
if mesh_resolution is None:
k_cutoff = 2 * torch.pi / (smearing / 2)
else:
k_cutoff = 2 * torch.pi / mesh_resolution
# Define cutoff in reciprocal space
k_cutoff = 2 * torch.pi / self.mesh_spacing

# Compute number of times each basis vector of the
# reciprocal space can be scaled until the cutoff
Expand Down
41 changes: 24 additions & 17 deletions src/meshlode/fourier_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,42 @@
===================
"""
import torch
from typing import Optional


class FourierSpaceConvolution:
"""
Class for handling all the steps necessary to compute the convolution f*G between
two functions f and G, where the values of f are provided on a discrete mesh.
:param cell: torch.tensor of shape (3,3) Tensor specifying the real space unit
:param cell: torch.tensor of shape ``(3,3)`` Tensor specifying the real space unit
cell of a structure, where cell[i] is the i-th basis vector
"""

def __init__(self, cell: torch.Tensor):
if cell.shape != (3, 3):
raise ValueError(f"cell of shape {cell.shape} should be of shape (3,3)")
self.cell: torch.Tensor = cell

def generate_kvectors(self, ns: torch.Tensor) -> torch.Tensor:
"""
For a given unit cell, compute all reciprocal space vectors that are used to
perform sums in the Fourier transformed space.
:param cell: torch.tensor of shape (3,3)
Tensor specifying the real space unit cell of a structure, where cell[i] is
the i-th basis vector
:param ns: torch.tensor of shape (3,)
ns = [nx, ny, nz] contains the number of mesh points in the x-, y- and
:param ns: torch.tensor of shape ``(3,)``
``ns = [nx, ny, nz]`` contains the number of mesh points in the x-, y- and
z-direction, respectively. For faster performance during the Fast Fourier
Transform (FFT) it is recommended to use values of nx, ny and nz that are
powers of 2.
:return: torch.tensor of shape [N_k,3] Contains all reciprocal space vectors
that will be used during Ewald summation (or related approaches). The number
N_k of such vectors is given by N_k = nx * ny * nz. k_vectors[i] contains
the i-th vector, where the order has no special significance.
:return: torch.tensor of shape ``(N,3)`` Contains all reciprocal space vectors
that will be used during Ewald summation (or related approaches).
``k_vectors[i]`` contains the i-th vector, where the order has no special
significance.
"""
if ns.shape != (3,):
raise ValueError(f"ns of shape {ns.shape} should be of shape (3,)")

# Define basis vectors of the reciprocal cell
reciprocal_cell = 2 * torch.pi * self.cell.inverse().T
bx = reciprocal_cell[0]
Expand All @@ -55,17 +58,18 @@ def generate_kvectors(self, ns: torch.Tensor) -> torch.Tensor:
return k_vectors

def kernel_func(
self, ksq: torch.Tensor, potential_exponent: int = 1, smearing: float = 0.2
self, ksq: torch.Tensor, potential_exponent: int = 1,
smearing: float = 0.2
) -> torch.Tensor:
"""
Fourier transform of the Coulomb potential or more general effective 1/r**p
potentials with additional smearing to remove the singularity at the origin.
:param ksq: torch.tensor of shape (N_k,) Squared norm of the k-vectors
:param ksq: torch.tensor of shape ``(N,)`` Squared norm of the k-vectors
:param potential_exponent: Exponent of the effective 1/r**p decay
:param smearing: Broadening of the 1/r**p decay close to the origin
:return: torch.tensor of shape (N_k,) with the values of the kernel function
:return: torch.tensor of shape ``(N,)`` with the values of the kernel function
G(k) evaluated at the provided (squared norms of the) k-vectors
"""
if potential_exponent == 1:
Expand All @@ -76,7 +80,7 @@ def kernel_func(
raise ValueError("Only potential exponents 0 and 1 are supported")

def value_at_origin(
self, potential_exponent: int = 1, smearing: float = 0.2
self, potential_exponent: int = 1, smearing: Optional[float] = 0.2
) -> float:
"""
Since the kernel function in reciprocal space typically has a (removable)
Expand All @@ -87,7 +91,7 @@ def value_at_origin(
:return: float of G(k=0), the value of the kernel function at the origin.
"""
if potential_exponent in [1, 2, 3]:
if potential_exponent == 1:
return 0.0
elif potential_exponent == 0:
return 1.0
Expand All @@ -104,18 +108,21 @@ def compute(
Compute the "electrostatic potential" from the density defined
on a discrete mesh.
:param mesh_values: torch.tensor of shape (n_channels, nx, ny, nz)
:param mesh_values: torch.tensor of shape ``(n_channels, nx, ny, nz)``
The values of the density defined on a mesh.
:param potential_exponent: int
The exponent in the 1/r**p decay of the effective potential, where p=1
corresponds to the Coulomb potential, and p=0 is set as a delta-potential.
:param smearing: float
Width of the Gaussian smearing (for the Coulomb potential).
:returns: torch.tensor of shape (n_channels, nx, ny, nz)
:returns: torch.tensor of shape ``(n_channels, nx, ny, nz)``
The potential evaluated on the same mesh points as the provided
density.
"""
if mesh_values.dim() != 4:
raise ValueError("`mesh_values`` needs to be a 4 dimensional tensor")

# Get shape information from mesh
n_channels, nx, ny, nz = mesh_values.shape
ns = torch.tensor([nx, ny, nz])
Expand Down
49 changes: 32 additions & 17 deletions src/meshlode/mesh_interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ class MeshInterpolator:
of calculations is identical, this is performed in a separate function called
"compute_interpolation_weights".
:param cell: torch.tensor of shape (3,3)
cell[i] is the i-th basis vector of the unit cell
:param ns_mesh: list of tuple of size 3
:param cell: torch.tensor of shape ``(3,3)``, where ``cell[i]`` is the i-th basis
vector of the unit cell
:param ns_mesh: toch.tensor of shape ``(3,)``
Number of mesh points to use along each of the three axes
:param interpolation_order: int
The degree of the polynomials used for interpolation. A higher order leads
Expand All @@ -34,6 +34,14 @@ class MeshInterpolator:
def __init__(
self, cell: torch.Tensor, ns_mesh: torch.Tensor, interpolation_order: int
):
# Check that the provided parameters match the specifications
if cell.shape != (3, 3):
raise ValueError(f"cell of shape {cell.shape} should be of shape (3,3)")
if ns_mesh.shape != (3,):
raise ValueError(f"shape {ns_mesh.shape} of `ns_mesh` has to be (3,)")
if interpolation_order not in [1, 2, 3, 4, 5]:
raise ValueError("Only `interpolation_order` from 1 to 5 are allowed")

self.cell = cell
self.ns_mesh = ns_mesh
self.interpolation_order = interpolation_order
Expand All @@ -48,7 +56,7 @@ def __init__(
self.y_indices: torch.Tensor = torch.tensor(0)
self.z_indices: torch.Tensor = torch.tensor(0)

def compute_1d_weights(self, x: torch.Tensor) -> torch.Tensor:
def _compute_1d_weights(self, x: torch.Tensor) -> torch.Tensor:
"""
Generate the smooth interpolation weights used to smear the particles onto a
mesh.
Expand All @@ -57,10 +65,10 @@ def compute_1d_weights(self, x: torch.Tensor) -> torch.Tensor:
J. Chem. Phys. 109, 7678–7693 (1998)
https://doi.org/10.1063/1.477414
:param x: torch.tensor of shape (n,)
:param x: torch.tensor of shape ``(n,)``
Set of relative positions in the interval [-1/2, 1/2].
:return: torch.tensor of shape (interpolation_order, n)
:return: torch.tensor of shape ``(interpolation_order, n)``
Interpolation weights
"""
# Compute weights based on the given order
Expand Down Expand Up @@ -108,12 +116,16 @@ def compute_interpolation_weights(self, positions: torch.Tensor):
"""
Compute the interpolation weights of each atom for a given cell (specified
during initialization of this class). The weights are not returned, but are used
when calling the forward (points_to_mesh) and backward (mesh_to_points)
when calling the forward (``points_to_mesh``) and backward (``mesh_to_points``)
interpolation functions.
:param positions: torch.tensor of shape (N,3)
:param positions: torch.tensor of shape ``(N,3)``
Absolute positions of atoms in Cartesian coordinates
"""
n_positions = len(positions)
if positions.shape != (n_positions, 3):
raise ValueError(f"shape {positions.shape} of `positions` has to be (N,3)")

# Compute positions relative to the mesh basis vectors
positions_rel = torch.linalg.solve(self.cell.T, positions.T).T
positions_rel *= self.ns_mesh
Expand All @@ -127,7 +139,7 @@ def compute_interpolation_weights(self, positions: torch.Tensor):
offsets = positions_rel - positions_rel_idx

# Compute weights based on distances and interpolation order
self.interpolation_weights = self.compute_1d_weights(offsets)
self.interpolation_weights = self._compute_1d_weights(offsets)

# Calculate indices of mesh points on which
# the particle weights are interpolated
Expand Down Expand Up @@ -169,15 +181,18 @@ def points_to_mesh(self, particle_weights: torch.Tensor) -> torch.Tensor:
"compute_interpolation_weights" has been called before to compute all the
necessary weights and indices.
:param particle_weights: torch.tensor of shape (n_atoms, n_channels)
particle_weights[i,a] is the "weight" or "charge" that atom i has to
:param particle_weights: torch.tensor of shape ``(n_points, n_channels)``
``particle_weights[i,a]`` is the weight (charge) that point (atom) i has to
generate the "a-th" potential. In practice, this can be used to compute e.g.
the Na and Cl contributions to the potential separately by using a one-hot
encoding of the species.
:return: torch.tensor of shape (n_channels, n_mesh, n_mesh, n_mesh)
:return: torch.tensor of shape ``(n_channels, n_mesh, n_mesh, n_mesh)``
Discrete density
"""
if particle_weights.dim() != 2:
raise ValueError("`particle_weights` needs to be a tensor of dimension 2")

# Update mesh values by combining particle weights and interpolation weights
n_channels = particle_weights.shape[1]
nx = int(self.ns_mesh[0])
Expand All @@ -203,18 +218,18 @@ def mesh_to_points(self, mesh_vals: torch.Tensor) -> torch.Tensor:
Take a function defined on a mesh and interpolate
its values on arbitrary positions.
:param mesh_vals: torch.tensor of shape (n_channels, nx, ny, nz)
:param mesh_vals: torch.tensor of shape ``(n_channels, nx, ny, nz)``
The tensor contains the values of a function evaluated on a
three-dimensional mesh. (nx, ny, nz) are the number of points along each of
the three directions, while n_channels provides the number of such functions
that are treated simulateously for the present system.
:param positions: torch.tensor of shape (n_points,3)
Absolute positions of particles in Cartesian coordinates, onto whose
locations we wish to interpolate the mesh values.
:return: interpolated_values: torch.tensor of shape (n_points, n_channels)
:return: interpolated_values: torch.tensor of shape ``(n_points, n_channels)``
Values of the interpolated function.
"""
if mesh_vals.dim() != 4:
raise ValueError("`mesh_vals` need to be a tensor of dimension 4")

interpolated_values = (
(
mesh_vals[:, self.x_indices, self.y_indices, self.z_indices]
Expand Down
Loading

0 comments on commit 74b87df

Please sign in to comment.