Skip to content

Commit

Permalink
Allow direct potential with a cell + neighbor list
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Aug 15, 2024
1 parent 37a8790 commit 7c687c1
Show file tree
Hide file tree
Showing 10 changed files with 415 additions and 86 deletions.
10 changes: 10 additions & 0 deletions docs/src/references/lib/neighbors.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Neighbors
=========

Functions for simple neighbor and distance calculations. For more advanced methods we
refer to external libraries like :py:func:`ase.neighborlist.neighbor_list` or
:py:class:`vesin.NeighborList`.

.. automodule:: torchpme.lib.neighbors
:members:
:undoc-members:
68 changes: 37 additions & 31 deletions src/torchpme/calculators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from torchpme.lib import InversePowerLawPotential
from ..lib import InversePowerLawPotential, distances


class CalculatorBaseTorch(torch.nn.Module):
Expand Down Expand Up @@ -138,6 +138,9 @@ def _validate_compute_parameters(
f"{cell_single.device}"
)

if neighbor_shifts_single is None:
raise ValueError("Provided `cell` but no `neighbor_shifts`.")

# check shape, dtype and device of charges
if charges_single.dim() != 2:
raise ValueError(
Expand Down Expand Up @@ -169,26 +172,42 @@ def _validate_compute_parameters(

# check shape, dtype and device of neighbor_indices and neighbor_shifts
if neighbor_indices_single is not None:
if neighbor_shifts_single is None:
raise ValueError(
"Need to provide both `neighbor_indices` and `neighbor_shifts` "
"together."
)

if neighbor_indices_single.shape[0] != 2:
raise ValueError(
"neighbor_indices is expected to have shape [2, num_neighbors]"
f", but got {list(neighbor_indices_single.shape)} for one "
"structure"
)

if neighbor_indices_single.device != self._device:
raise ValueError(
f"each `neighbor_indices` must be on the same device "
f"{self._device} as positions, got at least one tensor with "
f"device {neighbor_indices_single.device}"
)

if neighbor_shifts_single is not None:
if cell_single is None:
raise ValueError("Provided `neighbor_shifts` but no `cell`.")

if neighbor_shifts_single.shape[1] != 3:
raise ValueError(
"neighbor_shifts is expected to have shape [num_neighbors, 3]"
f", but got {list(neighbor_shifts_single.shape)} for one "
"structure"
)

if neighbor_shifts_single.device != self._device:
raise ValueError(
f"each `neighbor_shifts` must be on the same device "
f"{self._device} as positions, got at least one tensor with "
f"device {neighbor_shifts_single.device}"
)

if (
neighbor_indices_single is not None
and neighbor_shifts_single is not None
):
if neighbor_shifts_single.shape[0] != neighbor_indices_single.shape[1]:
raise ValueError(
"`neighbor_indices` and `neighbor_shifts` need to have shapes "
Expand All @@ -198,20 +217,6 @@ def _validate_compute_parameters(
"which is inconsistent"
)

if neighbor_indices_single.device != self._device:
raise ValueError(
f"each `neighbor_indices` must be on the same device "
f"{self._device} as positions, got at least one tensor with "
f"device {neighbor_indices_single.device}"
)

if neighbor_shifts_single.device != self._device:
raise ValueError(
f"each `neighbor_shifts` must be on the same device "
f"{self._device} as positions, got at least one tensor with "
f"device {neighbor_shifts_single.device}"
)

return positions, charges, cell, neighbor_indices, neighbor_shifts

def _compute_impl(
Expand Down Expand Up @@ -310,16 +315,12 @@ def _compute_sr(
neighbor_indices: torch.Tensor,
neighbor_shifts: torch.Tensor,
) -> torch.Tensor:
atom_is = neighbor_indices[0]
atom_js = neighbor_indices[1]
shifts = neighbor_shifts.type(cell.dtype)

# Compute energy
potential = torch.zeros_like(charges)

pos_is = positions[atom_is]
pos_js = positions[atom_js]
dists = torch.linalg.norm(pos_js - pos_is + shifts @ cell, dim=1)
dists = distances(
positions=positions,
cell=cell,
neighbor_indices=neighbor_indices,
neighbor_shifts=neighbor_shifts,
)
# If the contribution from all atoms within the cutoff is to be subtracted
# this short-range part will simply use -V_LR as the potential
if self.subtract_interior:
Expand All @@ -330,7 +331,12 @@ def _compute_sr(
else:
potentials_bare = self.potential.potential_sr_from_dist(dists, smearing)

atom_is = neighbor_indices[0]
atom_js = neighbor_indices[1]

contributions = charges[atom_js] * potentials_bare.unsqueeze(-1)

potential = torch.zeros_like(charges)
potential.index_add_(0, atom_is, contributions)

return potential
Expand Down
102 changes: 72 additions & 30 deletions src/torchpme/calculators/directpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

import torch

from ..lib import InversePowerLawPotential, all_neighbor_indices, distances
from .base import CalculatorBaseTorch


class _DirectPotentialImpl:
def __init__(self, exponent):
self.exponent = exponent
self.potential = InversePowerLawPotential(exponent=exponent)

def _compute_single_system(
self,
Expand All @@ -17,33 +19,32 @@ def _compute_single_system(
neighbor_indices: Optional[torch.Tensor],
neighbor_shifts: Optional[torch.Tensor],
) -> torch.Tensor:
# Compute matrix containing the squared distances from the Gram matrix
# The squared distance and the inner product between two vectors r_i and r_j are
# related by: d_ij^2 = |r_i - r_j|^2 = r_i^2 + r_j^2 - 2*r_i*r_j
num_atoms = len(positions)
dtype = positions.dtype
device = positions.device

diagonal_indices = torch.arange(num_atoms, device=device)
gram_matrix = positions @ positions.T
squared_norms = gram_matrix[diagonal_indices, diagonal_indices].reshape(-1, 1)
ones = torch.ones((1, len(positions)), dtype=dtype, device=device)
squared_norms_matrix = torch.matmul(squared_norms, ones)
distances_sq = squared_norms_matrix + squared_norms_matrix.T - 2 * gram_matrix

# Add terms to diagonal in order to avoid division by zero
# Since these components in the target tensor need to be set to zero, we add
# a huge number such that after taking the inverse (since we evaluate 1/r^p),
# the components will effectively be set to zero.
# This is not the most elegant solution, but I am doing this since the more
# obvious alternative of setting the same components to zero after the division
# had issues with autograd. I would appreciate any better alternatives.
distances_sq[diagonal_indices, diagonal_indices] += 1e50

# Compute potential
potentials_by_pair = distances_sq.pow(-self.exponent / 2.0)

return torch.matmul(potentials_by_pair, charges)

if neighbor_indices is None:
neighbor_indices_tensor = all_neighbor_indices(
len(charges), device=charges.device
)
else:
neighbor_indices_tensor = neighbor_indices

dists = distances(
positions=positions,
cell=cell,
neighbor_indices=neighbor_indices_tensor,
neighbor_shifts=neighbor_shifts,
)

potentials_bare = self.potential.potential_from_dist(dists)

atom_is = neighbor_indices_tensor[0]
atom_js = neighbor_indices_tensor[1]

contributions = charges[atom_js] * potentials_bare.unsqueeze(-1)

potential = torch.zeros_like(charges)
potential.index_add_(0, atom_is, contributions)

return potential


class DirectPotential(CalculatorBaseTorch, _DirectPotentialImpl):
Expand Down Expand Up @@ -87,9 +88,29 @@ def compute(
self,
positions: Union[List[torch.Tensor], torch.Tensor],
charges: Union[List[torch.Tensor], torch.Tensor],
cell: Union[List[Optional[torch.Tensor]], Optional[torch.Tensor]] = None,
neighbor_indices: Union[
List[Optional[torch.Tensor]], Optional[torch.Tensor]
] = None,
neighbor_shifts: Union[
List[Optional[torch.Tensor]], Optional[torch.Tensor]
] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Compute potential for all provided "systems" stacked inside list.
If the optional parameter ``neighbor_indices`` is provided only those indices
are taken into account for the compuation. Otherwise all particles are
considered for computing the potential. If ``cell`` and `neighbor_shifts` are
given, compuation is performed taking the periodicity of the system into
account.
.. warning ::
When passing the ``neighbor_shifts`` parameter withput explicit
``neighbor_indices``, the shape of the ``neighbor_shifts`` must have a shape
of ``(num_atoms * (num_atoms - 1), 3)``. Also the order of all pairs must
match the of :py:func:`torchpme.lib.neighbors.all_neighbor_indices`!
The computation is performed on the same ``device`` as ``dtype`` is the input is
stored on. The ``dtype`` of the output tensors will be the same as the input.
Expand All @@ -100,6 +121,17 @@ def compute(
potential should be calculated for a standard potential ``n_channels=1``. If
more than one "channel" is provided multiple potentials for the same
position but different are computed.
:param cell: single or 2D tensor of shape (3, 3), describing the bounding
box/unit cell of the system. Each row should be one of the bounding box
vector; and columns should contain the x, y, and z components of these
vectors (i.e. the cell should be given in row-major order).
:param neighbor_indices: Optional single or list of 2D tensors of shape ``(2,
n)``, where ``n`` is the number of atoms. The two rows correspond to the
indices of a **full neighbor list** for the two atoms which are considered
neighbors (e.g. within a cutoff distance).
:param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n),
where n is the number of atoms. The 3 rows correspond to the shift indices
for periodic images of a **full neighbor list**.
:return: Single or List of torch Tensors containing the potential(s) for all
positions. Each tensor in the list is of shape ``(len(positions),
len(charges))``, where If the inputs are only single tensors only a single
Expand All @@ -108,10 +140,10 @@ def compute(

return self._compute_impl(
positions=positions,
cell=None,
charges=charges,
neighbor_indices=None,
neighbor_shifts=None,
cell=cell,
neighbor_indices=neighbor_indices,
neighbor_shifts=neighbor_shifts,
)

# This function is kept to keep torch-pme compatible with the broader pytorch
Expand All @@ -121,9 +153,19 @@ def forward(
self,
positions: Union[List[torch.Tensor], torch.Tensor],
charges: Union[List[torch.Tensor], torch.Tensor],
cell: Union[List[Optional[torch.Tensor]], Optional[torch.Tensor]] = None,
neighbor_indices: Union[
List[Optional[torch.Tensor]], Optional[torch.Tensor]
] = None,
neighbor_shifts: Union[
List[Optional[torch.Tensor]], Optional[torch.Tensor]
] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Forward just calls :py:meth:`compute`."""
return self.compute(
positions=positions,
charges=charges,
cell=cell,
neighbor_indices=neighbor_indices,
neighbor_shifts=neighbor_shifts,
)
3 changes: 3 additions & 0 deletions src/torchpme/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
from .mesh_interpolator import MeshInterpolator
from .potentials import InversePowerLawPotential
from .kvectors import generate_kvectors_for_mesh, generate_kvectors_squeezed
from .neighbors import distances, all_neighbor_indices

__all__ = [
"all_neighbor_indices",
"distances",
"FourierSpaceConvolution",
"MeshInterpolator",
"InversePowerLawPotential",
Expand Down
94 changes: 94 additions & 0 deletions src/torchpme/lib/neighbors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from typing import Optional

import torch


def all_neighbor_indices(
num_atoms: int,
dtype: torch.dtype = torch.int64,
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""
Computes all neighbor indices between a given number of atoms, excluding self pairs.
:param num_atoms: number of atoms for which to compute the neighbor indices.
:param dtype: data type of the returned tensor.
:param device: The device on which the tensor will be allocated.
:returns: tensor of shape ``(2, num_atoms * (num_atoms - 1))`` containing all pairs
excluding self pairs.
Example
-------
>>> neighbor_indices = all_neighbor_indices(num_atoms=3)
>>> print(neighbor_indices)
tensor([[1, 2, 0, 2, 0, 1],
[0, 0, 1, 1, 2, 2]])
"""
indices = torch.arange(num_atoms, dtype=dtype, device=device).repeat(num_atoms, 1)

atom_is = indices.flatten()
atom_js = indices.T.flatten()

# Filter out the self pairs
mask = atom_is != atom_js

return torch.vstack((atom_is[mask], atom_js[mask]))


def distances(
positions: torch.Tensor,
neighbor_indices: torch.Tensor,
cell: Optional[torch.Tensor] = None,
neighbor_shifts: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Computes the pairwise distances based on positions and neighbor indices.
:param positions: Tensor of shape ``(num_atoms, 3)`` containing the positions of
each atom.
:param neighbor_indices: Tensor of shape ``(2, num_pairs)`` containing pairs of atom
indices.
:param cell: Optional tensor of shape ``(3, 3)`` representing the periodic boundary
conditions (PBC) cell vectors.
:param neighbor_shifts: Optional tensor of shape ``(num_pairs, 3)`` representing the
shift vectors for each neighbor pair under PBC.
:returns: Tensor of shape ``(num_pairs,)`` containing the distances between each
pair of neighbors.
:raises ValueError: If `cell` is provided without `neighbor_shifts` or vice versa.
Example
-------
>>> import torch
>>> positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
>>> neighbor_indices = torch.tensor([[0, 0, 1], [1, 2, 2]])
>>> dists = distances(positions, neighbor_indices)
>>> print(dists)
tensor([1.0000, 1.0000, 1.4142])
If periodic boundary conditions are applied:
>>> cell = torch.eye(3) # Identity matrix for cell vectors
>>> neighbor_shifts = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 0, 0]])
>>> dists = distances(positions, neighbor_indices, cell, neighbor_shifts)
>>> print(dists)
tensor([1.0000, 1.4142, 1.4142])
"""
atom_is = neighbor_indices[0]
atom_js = neighbor_indices[1]

pos_is = positions[atom_is]
pos_js = positions[atom_js]

distance_vectors = pos_js - pos_is

if cell is not None and neighbor_shifts is not None:
shifts = neighbor_shifts.type(cell.dtype)
distance_vectors += shifts @ cell
elif cell is not None and neighbor_shifts is None:
raise ValueError("Provided `cell` but no `neighbor_shifts`.")
elif cell is None and neighbor_shifts is not None:
raise ValueError("Provided `neighbor_shifts` but no `cell`.")

return torch.linalg.norm(distance_vectors, dim=1)
Loading

0 comments on commit 7c687c1

Please sign in to comment.