Skip to content

Commit

Permalink
Add caching for kvector generation
Browse files Browse the repository at this point in the history
* Add caching for reciprocal space vectors

* Rename smearing -> atomic_smearing

* Hide class inheritance in documentation

* Improve comments

---------

Co-authored-by: Kevin Kazuki Huguenin-Dumittan <kvhuguenin@gmail.com>
  • Loading branch information
PicoCentauri and Kevin Kazuki Huguenin-Dumittan authored Dec 22, 2023
1 parent 34766a2 commit d915d0f
Show file tree
Hide file tree
Showing 12 changed files with 159 additions and 80 deletions.
1 change: 0 additions & 1 deletion docs/src/references/calculators/meshpotential.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,3 @@ MeshPotential
.. autoclass:: meshlode.MeshPotential
:members:
:undoc-members:
:show-inheritance:
1 change: 0 additions & 1 deletion docs/src/references/lib/fourier_convolution.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,3 @@ Fourier Convolution
.. autoclass:: meshlode.lib.FourierSpaceConvolution
:members:
:undoc-members:
:show-inheritance:
1 change: 0 additions & 1 deletion docs/src/references/lib/mesh_interpolator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,3 @@ Mesh Interpolator
.. autoclass:: meshlode.lib.MeshInterpolator
:members:
:undoc-members:
:show-inheritance:
1 change: 0 additions & 1 deletion docs/src/references/lib/system.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,3 @@ System
.. autoclass:: meshlode.System
:members:
:undoc-members:
:show-inheritance:
1 change: 0 additions & 1 deletion docs/src/references/metatensor/meshpotential.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,3 @@ MeshPotential
.. autoclass:: meshlode.metatensor.MeshPotential
:members:
:undoc-members:
:show-inheritance:
14 changes: 9 additions & 5 deletions examples/library-tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,21 @@ def sliceplot(mesh, sz=12, cmap="viridis", vmin=None, vmax=None):
# be easily extended to compute an arbitrary filter
#

fsc = meshlode.lib.fourier_convolution.FourierSpaceConvolution(frame.cell)
fsc = meshlode.lib.fourier_convolution.FourierSpaceConvolution()

# %%
# plain smearing
rho_mesh = fsc.compute(mesh, potential_exponent=0, smearing=1)
# plain atomic_smearing
rho_mesh = fsc.compute(
mesh_values=mesh, cell=frame.cell, potential_exponent=0, atomic_smearing=1
)

sliceplot(rho_mesh[0, :, :, :5])

# %%
# coulomb-like potential, no smearing
coulomb_mesh = fsc.compute(mesh, potential_exponent=1, smearing=0)
# coulomb-like potential, no atomic_smearing
coulomb_mesh = fsc.compute(
mesh_values=mesh, cell=frame.cell, potential_exponent=1, atomic_smearing=0
)

sliceplot(coulomb_mesh[1, :, :, :5], cmap="seismic")

Expand Down
22 changes: 14 additions & 8 deletions src/meshlode/calculators/meshpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def __init__(
self.interpolation_order = interpolation_order
self.subtract_self = subtract_self

# Initilize auxiliary objects
self.fourier_space_convolution = FourierSpaceConvolution()

# This function is kept to keep MeshLODE compatible with the broader pytorch
# infrastructure, which require a "forward" function. We name this function
# "compute" instead, for compatibility with other COSMO software.
Expand Down Expand Up @@ -156,7 +159,7 @@ def _compute_single_frame(
Compute the "electrostatic" potential at the position of all atoms in a
structure.
:param cell: torch.tensor of shape `(3,3)`. Describes the unit cell of the
:param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the
structure, where cell[i] is the i-th basis vector.
:param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian
coordinates of the atoms. The implementation also works if the positions
Expand All @@ -177,9 +180,6 @@ def _compute_single_frame(
:returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential
at the position of each atom for the `n_channels` independent meshes separately.
"""
smearing = self.atomic_smearing
interpolation_order = self.interpolation_order

# Initializations
n_atoms = len(positions)
assert positions.shape == (n_atoms, 3)
Expand All @@ -197,20 +197,26 @@ def _compute_single_frame(
ns = 2 ** torch.ceil(torch.log2(ns_actual_approx)).long() # [nx, ny, nz]

# Step 1: Smear particles onto mesh
MI = MeshInterpolator(cell, ns, interpolation_order=interpolation_order)
MI = MeshInterpolator(cell, ns, interpolation_order=self.interpolation_order)
MI.compute_interpolation_weights(positions)
rho_mesh = MI.points_to_mesh(particle_weights=charges)

# Step 2: Perform Fourier space convolution (FSC)
FSC = FourierSpaceConvolution(cell)
potential_mesh = FSC.compute(rho_mesh, potential_exponent=1, smearing=smearing)
potential_mesh = self.fourier_space_convolution.compute(
mesh_values=rho_mesh,
cell=cell,
potential_exponent=1,
atomic_smearing=self.atomic_smearing,
)

# Step 3: Back interpolation
interpolated_potential = MI.mesh_to_points(potential_mesh)

# Remove self contribution
if self.subtract_self:
self_contrib = torch.sqrt(torch.tensor(2.0 / torch.pi)) / smearing
self_contrib = (
torch.sqrt(torch.tensor(2.0 / torch.pi)) / self.atomic_smearing
)
interpolated_potential -= charges * self_contrib

return interpolated_potential
120 changes: 87 additions & 33 deletions src/meshlode/lib/fourier_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,45 @@
import torch


@torch.jit.script
class FourierSpaceConvolution:
"""
Class for handling all the steps necessary to compute the convolution f*G between
two functions f and G, where the values of f are provided on a discrete mesh.
:param cell: torch.tensor of shape ``(3,3)`` Tensor specifying the real space unit
cell of a structure, where cell[i] is the i-th basis vector
Class for handling all the steps necessary to compute the convolution :math:`f*G`
between two functions :math:`f` and :math:`G`, where the values of :math:`f` are
provided on a discrete mesh. In practice, the convolution is performed in
reciprocal space using the fast Fourier transform algorithm.
Since the reciprocal space vectors used for the calculations only depend on the
cell for a given set of hypers, the vectors are cached to reduce the computational
cost in case multiple structures use identical cells.
Example
-------
To compute the "electrostatic potential" we first have to define the cell as
well as the grid points where we want to evaluate the potential:
>>> import torch
>>> L = torch.rand((1,)) * 20 + 1.0
>>> cell = L * torch.randn((3, 3))
>>> ns = torch.randint(1, 20, size=(4,))
>>> n_channels, nx, ny, nz = ns
>>> nz *= 2 # last dimension needs to be even
>>> mesh_values = torch.randn(size=(n_channels, nx, ny, nz))
With this definitions we just have to call the :meth:`compute` method and save the
results
>>> fsc = FourierSpaceConvolution()
>>> potential = fsc.compute(mesh_values=mesh_values, cell=cell)
"""

def __init__(self, cell: torch.Tensor):
if cell.shape != (3, 3):
raise ValueError(f"cell of shape {cell.shape} should be of shape (3,3)")
self.cell: torch.Tensor = cell
def __init__(self):
self._cell_cache = torch.zeros(3, 3)
self._ns_cache = torch.zeros(3)
self._knorm_sq_cache = torch.empty(1)

def generate_kvectors(self, ns: torch.Tensor) -> torch.Tensor:
def generate_kvectors(self, ns: torch.Tensor, cell: torch.Tensor) -> torch.Tensor:
"""
For a given unit cell, compute all reciprocal space vectors that are used to
perform sums in the Fourier transformed space.
Expand All @@ -27,17 +51,22 @@ def generate_kvectors(self, ns: torch.Tensor) -> torch.Tensor:
z-direction, respectively. For faster performance during the Fast Fourier
Transform (FFT) it is recommended to use values of nx, ny and nz that are
powers of 2.
:param cell: torch.tensor of shape ``(3, 3)`` Tensor specifying the real space
unit cell of a structure, where cell[i] is the i-th basis vector
:return: torch.tensor of shape ``(N,3)`` Contains all reciprocal space vectors
:return: torch.tensor of shape ``(N, 3)`` Contains all reciprocal space vectors
that will be used during Ewald summation (or related approaches).
``k_vectors[i]`` contains the i-th vector, where the order has no special
significance.
"""
if ns.shape != (3,):
raise ValueError(f"ns of shape {ns.shape} should be of shape (3,)")
raise ValueError(f"ns of shape {ns.shape} should be of shape (3, )")

if cell.shape != (3, 3):
raise ValueError(f"cell of shape {cell.shape} should be of shape (3, 3)")

# Define basis vectors of the reciprocal cell
reciprocal_cell = 2 * torch.pi * self.cell.inverse().T
reciprocal_cell = 2 * torch.pi * cell.inverse().T
bx = reciprocal_cell[0]
by = reciprocal_cell[1]
bz = reciprocal_cell[2]
Expand All @@ -55,36 +84,41 @@ def generate_kvectors(self, ns: torch.Tensor) -> torch.Tensor:
return k_vectors

def kernel_func(
self, ksq: torch.Tensor, potential_exponent: int = 1, smearing: float = 0.2
self,
ksq: torch.Tensor,
potential_exponent: int = 1,
atomic_smearing: float = 0.2,
) -> torch.Tensor:
"""
Fourier transform of the Coulomb potential or more general
effective :math:`1/r^p` potentials with additional smearing to remove the
Fourier transform of the Coulomb potential or more general effective
:math:`1/r^p` potentials with additional ``atomic_smearing`` to remove the
singularity at the origin.
:param ksq: torch.tensor of shape ``(N,)`` Squared norm of the k-vectors
:param potential_exponent: Exponent of the effective :math:`1/r^p` decay
:param smearing: Broadening of the :math:`1/r^p` decay close to the origin
:param atomic_smearing: Broadening of the :math:`1/r^p` decay close to the
origin
:return: torch.tensor of shape ``(N,)`` with the values of the kernel function
G(k) evaluated at the provided (squared norms of the) k-vectors
"""
if potential_exponent == 1:
return 4 * torch.pi / ksq * torch.exp(-0.5 * smearing**2 * ksq)
return 4 * torch.pi / ksq * torch.exp(-0.5 * atomic_smearing**2 * ksq)
elif potential_exponent == 0:
return torch.exp(-0.5 * smearing**2 * ksq)
return torch.exp(-0.5 * atomic_smearing**2 * ksq)
else:
raise ValueError("Only potential exponents 0 and 1 are supported")

def value_at_origin(
self, potential_exponent: int = 1, smearing: Optional[float] = 0.2
self, potential_exponent: int = 1, atomic_smearing: Optional[float] = 0.2
) -> float:
"""
Since the kernel function in reciprocal space typically has a (removable)
singularity at k=0, the value at that point needs to be specified explicitly.
:param potential_exponent: Exponent of the effective :math:`1/r^p` decay
:param smearing: Broadening of the :math:`1/r^p` decay close to the origin
:param atomic_smearing: Broadening of the :math:`1/r^p` decay close to the
origin
:return: float of G(k=0), the value of the kernel function at the origin.
"""
Expand All @@ -98,21 +132,24 @@ def value_at_origin(
def compute(
self,
mesh_values: torch.Tensor,
cell: torch.Tensor,
potential_exponent: int = 1,
smearing: float = 0.2,
atomic_smearing: float = 0.2,
) -> torch.Tensor:
"""
Compute the "electrostatic potential" from the density defined
on a discrete mesh.
:param mesh_values: torch.tensor of shape ``(n_channels, nx, ny, nz)``
The values of the density defined on a mesh.
:param cell: torch.tensor of shape ``(3, 3)`` Tensor specifying the real space
unit cell of a structure, where cell[i] is the i-th basis vector
:param potential_exponent: int
The exponent in the :math:`1/r^p` decay of the effective potential,
where :math:`p=1` corresponds to the Coulomb potential,
and :math:`p=0` is set as Gaussian smearing.
:param smearing: float
Width of the Gaussian smearing (for the Coulomb potential).
and :math:`p=0` is set as Gaussian atomic_smearing.
:param atomic_smearing: float
Width of the Gaussian atomic_smearing (for the Coulomb potential).
:returns: torch.tensor of shape ``(n_channels, nx, ny, nz)``
The potential evaluated on the same mesh points as the provided
Expand All @@ -121,14 +158,29 @@ def compute(
if mesh_values.dim() != 4:
raise ValueError("`mesh_values`` needs to be a 4 dimensional tensor")

if cell.shape != (3, 3):
raise ValueError(f"cell of shape {cell.shape} should be of shape (3, 3)")

# Get shape information from mesh
n_channels, nx, ny, nz = mesh_values.shape
_, nx, ny, nz = mesh_values.shape
ns = torch.tensor([nx, ny, nz])

# Get the relevant reciprocal space vectors (k-vectors)
# and compute their norm.
kvectors = self.generate_kvectors(ns)
knorm_sq = torch.sum(kvectors**2, dim=3)
# Use chached values if cell and number of mesh points have not changed since
# last call.
same_cell = torch.allclose(cell, self._cell_cache, atol=1e-15, rtol=1e-15)
if torch.all(ns == self._ns_cache) and same_cell:
knorm_sq = self._knorm_sq_cache
else:
# Get the relevant reciprocal space vectors (k-vectors)
# and compute their norm.
kvectors = self.generate_kvectors(ns=ns, cell=cell)
knorm_sq = torch.sum(kvectors**2, dim=3)

# Store values for the cache. We do not clone the arrays because we only
# read the value and do not perform any inplace operations
self._cell_cache = cell
self._ns_cache = ns
self._knorm_sq_cache = knorm_sq

# G(k) is the Fourier transform of the Coulomb potential
# generated by a Gaussian charge density
Expand All @@ -137,10 +189,12 @@ def compute(
# to the requirement that the net charge of the cell is zero.
# G = kernel_func(knorm_sq)
G = self.kernel_func(
knorm_sq, potential_exponent=potential_exponent, smearing=smearing
knorm_sq,
potential_exponent=potential_exponent,
atomic_smearing=atomic_smearing,
)
G[0, 0, 0] = self.value_at_origin(
potential_exponent=potential_exponent, smearing=smearing
potential_exponent=potential_exponent, atomic_smearing=atomic_smearing
)

# Fourier transforms consisting of the following substeps:
Expand All @@ -153,7 +207,7 @@ def compute(
# normalization option 'backward' (the convention in which 1/n_mesh
# is in the backward transformation) and vice versa for the
# inverse transform (irfft).
volume = self.cell.det()
volume = cell.det()
dims = (1, 2, 3) # dimensions along which to Fourier transform
mesh_hat = torch.fft.rfftn(mesh_values, norm="backward", dim=dims)
potential_hat = mesh_hat * G
Expand Down
8 changes: 4 additions & 4 deletions src/meshlode/lib/mesh_interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class MeshInterpolator:
of calculations is identical, this is performed in a separate function called
:func:`compute_interpolation_weights`.
:param cell: torch.tensor of shape ``(3,3)``, where ``cell[i]`` is the i-th basis
:param cell: torch.tensor of shape ``(3, 3)``, where ``cell[i]`` is the i-th basis
vector of the unit cell
:param ns_mesh: toch.tensor of shape ``(3,)``
Number of mesh points to use along each of the three axes
Expand All @@ -33,7 +33,7 @@ def __init__(
):
# Check that the provided parameters match the specifications
if cell.shape != (3, 3):
raise ValueError(f"cell of shape {cell.shape} should be of shape (3,3)")
raise ValueError(f"cell of shape {cell.shape} should be of shape (3, 3)")
if ns_mesh.shape != (3,):
raise ValueError(f"shape {ns_mesh.shape} of `ns_mesh` has to be (3,)")
if interpolation_order not in [1, 2, 3, 4, 5]:
Expand Down Expand Up @@ -115,12 +115,12 @@ def compute_interpolation_weights(self, positions: torch.Tensor):
when calling the forward (:func:`points_to_mesh`) and backward
(:func:`mesh_to_points`) interpolation functions.
:param positions: torch.tensor of shape ``(N,3)``
:param positions: torch.tensor of shape ``(N, 3)``
Absolute positions of atoms in Cartesian coordinates
"""
n_positions = len(positions)
if positions.shape != (n_positions, 3):
raise ValueError(f"shape {positions.shape} of `positions` has to be (N,3)")
raise ValueError(f"shape {positions.shape} of `positions` has to be (N, 3)")

# Compute positions relative to the mesh basis vectors
positions_rel = torch.matmul(positions, torch.inverse(self.cell))
Expand Down
Loading

0 comments on commit d915d0f

Please sign in to comment.