Skip to content

Commit

Permalink
Add tests for potentials class
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Kazuki Huguenin-Dumittan committed Jun 20, 2024
1 parent 6f1c866 commit 40b8808
Show file tree
Hide file tree
Showing 13 changed files with 440 additions and 196 deletions.
20 changes: 13 additions & 7 deletions src/meshlode/calculators/calculator_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
from meshlode.lib import InversePowerLawPotential
from typing import List, Optional, Union

import torch

from meshlode.lib import InversePowerLawPotential


def get_default_exponent():
return torch.tensor(1.0)


default_exponent = get_default_exponent()


@torch.jit.script
def _1d_tolist(x: torch.Tensor) -> List[int]:
Expand Down Expand Up @@ -38,17 +46,17 @@ class CalculatorBase(torch.nn.Module):
def __init__(
self,
all_types: Optional[List[int]] = None,
exponent: Optional[torch.Tensor] = torch.tensor(1., dtype=torch.float64),
exponent: Optional[torch.Tensor] = default_exponent,
):
super().__init__()

if all_types is None:
self.all_types = None
else:
self.all_types = _1d_tolist(torch.unique(torch.tensor(all_types)))

self.exponent = exponent
self.potential = InversePowerLawPotential(exponent = exponent)
self.potential = InversePowerLawPotential(exponent=exponent)

# This function is kept to keep this library compatible with the broader pytorch
# infrastructure, which require a "forward" function. We name this function
Expand All @@ -60,9 +68,7 @@ def forward(
charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""forward just calls :py:meth:`CalculatorModule.compute`"""
return self.compute(
types=types, positions=positions, charges=charges
)
return self.compute(types=types, positions=positions, charges=charges)

def compute(
self,
Expand Down
6 changes: 5 additions & 1 deletion src/meshlode/calculators/calculator_base_periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def compute(
positions = [positions]
if not isinstance(cell, list):
cell = [cell]
if (neighbor_indices is not None) and not isinstance(neighbor_indices, list):
neighbor_indices = [neighbor_indices]
if (neighbor_shifts is not None) and not isinstance(neighbor_shifts, list):
neighbor_shifts = [neighbor_shifts]

# Check that all inputs are consistent
for types_single, positions_single, cell_single in zip(types, positions, cell):
Expand Down Expand Up @@ -164,7 +168,7 @@ def compute(
# of inputs. Each "frame" is processed independently.
potentials = []

if neighbor_indices is None:
if neighbor_indices is None or neighbor_shifts is None:
for positions_single, cell_single, charges_single in zip(
positions, cell, charges
):
Expand Down
12 changes: 4 additions & 8 deletions src/meshlode/calculators/direct.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .calculator_base import CalculatorBase

import torch

from .calculator_base import CalculatorBase


class DirectPotential(CalculatorBase):
"""A specie-wise long-range potential computed using a direct summation over all
Expand Down Expand Up @@ -46,10 +46,6 @@ def _compute_single_system(
potential. Subtracting these from each other, one could recover the more
standard electrostatic potential in which Na and Cl have charges of +1 and
-1, respectively.
:param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the
structure, where cell[i] is the i-th basis vector. While redundant in this
particular implementation, the parameter is kept to keep the same inputs as
the other calculators.
:returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential
at the position of each atom for the `n_channels` independent meshes separately.
Expand All @@ -73,9 +69,9 @@ def _compute_single_system(
# obvious alternative of setting the same components to zero after the division
# had issues with autograd. I would appreciate any better alternatives.
distances_sq[diagonal_indices, diagonal_indices] += 1e50

# Compute potential
potentials_by_pair = distances_sq.pow(-self.exponent / 2.)
potentials_by_pair = distances_sq.pow(-self.exponent / 2.0)
potentials = torch.matmul(potentials_by_pair, charges)

return potentials
41 changes: 21 additions & 20 deletions src/meshlode/calculators/ewald.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import torch
from typing import List, Optional

from .calculator_base_periodic import CalculatorBasePeriodic
import torch

# extra imports for neighbor list
from ase import Atoms
from ase.neighborlist import neighbor_list

from .calculator_base import default_exponent
from .calculator_base_periodic import CalculatorBasePeriodic


class EwaldPotential(CalculatorBasePeriodic):
"""A specie-wise long-range potential computed using the Ewald sum, scaling as
O(N^2) with respect to the number of particles N used as a reference to test faster
Expand Down Expand Up @@ -62,12 +65,12 @@ class EwaldPotential(CalculatorBasePeriodic):
def __init__(
self,
all_types: Optional[List[int]] = None,
exponent: Optional[torch.Tensor] = torch.tensor(1., dtype=torch.float64),
sr_cutoff: Optional[float] = None,
exponent: Optional[torch.Tensor] = default_exponent,
sr_cutoff: Optional[torch.Tensor] = None,
atomic_smearing: Optional[float] = None,
lr_wavelength: Optional[float] = None,
subtract_self: Optional[bool] = True,
subtract_interior: Optional[bool] = False
subtract_interior: Optional[bool] = False,
):
super().__init__(all_types=all_types, exponent=exponent)

Expand Down Expand Up @@ -120,7 +123,7 @@ def _compute_single_system(
cutoff_max = torch.min(cell_dimensions) / 2 - 1e-6
if self.sr_cutoff is not None:
if self.sr_cutoff > torch.min(cell_dimensions) / 2:
raise ValueError(f"sr_cutoff {sr_cutoff} needs to be > {cutoff_max}")
raise ValueError(f"sr_cutoff {self.sr_cutoff} has to be > {cutoff_max}")

# Set the defaut values of convergence parameters
# The total computational cost = cost of SR part + cost of LR part
Expand Down Expand Up @@ -154,8 +157,6 @@ def _compute_single_system(
sr_cutoff=sr_cutoff,
)

##return charges * torch.sum(positions, dim=1) * self.exponent + potential_sr

potential_lr = self._compute_lr(
positions=positions,
charges=charges,
Expand All @@ -164,11 +165,9 @@ def _compute_single_system(
lr_wavelength=lr_wavelength,
)

#return potential_lr

potential_ewald = potential_sr + potential_lr
return potential_ewald

def _generate_kvectors(self, ns: torch.Tensor, cell: torch.Tensor) -> torch.Tensor:
"""
For a given unit cell, compute all reciprocal space vectors that are used to
Expand All @@ -189,9 +188,9 @@ def _generate_kvectors(self, ns: torch.Tensor, cell: torch.Tensor) -> torch.Tens
that will be used during Ewald summation (or related approaches).
``k_vectors[i]`` contains the i-th vector, where the order has no special
significance.
The total number N of k-vectors is NOT simply nx*ny*nz, and roughly corresponds
to nx*ny*nz/2 due since the vectors +k and -k can be grouped together during
summation.
The total number N of k-vectors is NOT simply nx*ny*nz, and roughly
corresponds to nx*ny*nz/2 due since the vectors +k and -k can be grouped
together during summation.
"""
# Check that the shapes of all inputs are correct
if ns.shape != (3,):
Expand Down Expand Up @@ -239,18 +238,18 @@ def _compute_lr(
structure, where cell[i] is the i-th basis vector.
:param smearing: torch.Tensor smearing paramter determining the splitting
between the SR and LR parts.
:param lr_wavelength: Spatial resolution used for the long-range (reciprocal space)
part of the Ewald sum. More conretely, all Fourier space vectors with a
wavelength >= this value will be kept.
:param lr_wavelength: Spatial resolution used for the long-range (reciprocal
space) part of the Ewald sum. More conretely, all Fourier space vectors with
a wavelength >= this value will be kept.
:returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential
at the position of each atom for the `n_channels` independent meshes separately.
"""
# Define k-space cutoff from required real-space resolution
k_cutoff = 2 * torch.pi / lr_wavelength

# Compute number of times each basis vector of the reciprocal space can be scaled
# until the cutoff is reached
# Compute number of times each basis vector of the reciprocal space can be
# scaled until the cutoff is reached
basis_norms = torch.linalg.norm(cell, dim=1)
ns_float = k_cutoff * basis_norms / 2 / torch.pi
ns = torch.ceil(ns_float).long()
Expand Down Expand Up @@ -346,7 +345,9 @@ def _compute_sr(
# Compute energy
potential = torch.zeros_like(charges)
for i, j, shift in zip(atom_is, atom_js, shifts):
dist = torch.linalg.norm(positions[j] - positions[i] + torch.tensor(shift.dot(struc.cell)))
dist = torch.linalg.norm(
positions[j] - positions[i] + torch.tensor(shift.dot(struc.cell))
)

# If the contribution from all atoms within the cutoff is to be subtracted
# this short-range part will simply use -V_LR as the potential
Expand Down
5 changes: 3 additions & 2 deletions src/meshlode/calculators/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from meshlode.lib.fourier_convolution import FourierSpaceConvolution
from meshlode.lib.mesh_interpolator import MeshInterpolator

from .calculator_base import default_exponent
from .calculator_base_periodic import CalculatorBasePeriodic


class MeshPotential(CalculatorBasePeriodic):
"""A specie-wise long-range potential, computed using the particle-mesh Ewald (PME)
method scaling as O(NlogN) with respect to the number of particles N.
Expand Down Expand Up @@ -56,7 +58,7 @@ def __init__(
interpolation_order: Optional[int] = 4,
subtract_self: Optional[bool] = False,
all_types: Optional[List[int]] = None,
exponent: Optional[torch.Tensor] = torch.tensor(1., dtype=torch.float64),
exponent: Optional[torch.Tensor] = default_exponent,
):
super().__init__(all_types=all_types, exponent=exponent)

Expand Down Expand Up @@ -120,7 +122,6 @@ def _compute_single_system(
assert positions.dtype == cell.dtype and charges.dtype == cell.dtype
assert positions.device == cell.device and charges.device == cell.device


# Define cutoff in reciprocal space
if mesh_spacing is None:
mesh_spacing = self.mesh_spacing
Expand Down
28 changes: 10 additions & 18 deletions src/meshlode/calculators/meshewald.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,15 +162,6 @@ def _compute_single_system(
:returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential
at the position of each atom for the `n_channels` independent meshes separately.
"""
# Check that the realspace cutoff (if provided) is not too large
# This is because the current implementation is not able to return multiple
# periodic images of the same atom as a neighbor
cell_dimensions = torch.linalg.norm(cell, dim=1)
cutoff_max = torch.min(cell_dimensions) / 2 - 1e-6
if self.sr_cutoff is not None:
if self.sr_cutoff > torch.min(cell_dimensions) / 2:
raise ValueError(f"sr_cutoff {self.sr_cutoff} has to be > {cutoff_max}")

# Set the defaut values of convergence parameters
# The total computational cost = cost of SR part + cost of LR part
# Bigger smearing increases the cost of the SR part while decreasing the cost
Expand All @@ -181,6 +172,8 @@ def _compute_single_system(
# chosen to reach a convergence on the order of 1e-4 to 1e-5 for the test
# structures.
if self.sr_cutoff is None:
cell_dimensions = torch.linalg.norm(cell, dim=1)
cutoff_max = torch.min(cell_dimensions) / 2 - 1e-6
sr_cutoff = cutoff_max
else:
sr_cutoff = self.sr_cutoff
Expand All @@ -203,7 +196,7 @@ def _compute_single_system(
smearing=smearing,
sr_cutoff=sr_cutoff,
neighbor_indices=neighbor_indices,
neighbor_shifts=neighbor_shifts
neighbor_shifts=neighbor_shifts,
)

# Compute long-range (LR) part using a Fourier / reciprocal space sum
Expand Down Expand Up @@ -325,23 +318,22 @@ def _compute_sr(
:returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential
at the position of each atom for the `n_channels` independent meshes separately.
"""
if neighbor_indices is None:
if neighbor_indices is None or neighbor_shifts is None:
# Get list of neighbors
struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True)
atom_is, atom_js, shifts = neighbor_list(
atom_is, atom_js, neighbor_shifts = neighbor_list(
"ijS", struc, sr_cutoff.item(), self_interaction=False
)
else:
atom_is = neighbor_indices[:,0]
atom_js = neighbor_indices[:,1]
shifts = neighbor_shifts.T

atom_is = neighbor_indices[0]
atom_js = neighbor_indices[1]

# Compute energy
potential = torch.zeros_like(charges)
for i, j, shift in zip(atom_is, atom_js, shifts):
for i, j, shift in zip(atom_is, atom_js, neighbor_shifts):
shift = shift.type(cell.dtype)
dist = torch.linalg.norm(
positions[j] - positions[i] + torch.tensor(shift.dot(struc.cell))
positions[j] - positions[i] + torch.tensor(shift @ cell)
)

# If the contribution from all atoms within the cutoff is to be subtracted
Expand Down
Loading

0 comments on commit 40b8808

Please sign in to comment.