Skip to content

Commit

Permalink
Remove internal neighbor calculations
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jul 11, 2024
1 parent 37dafc7 commit a46b2de
Show file tree
Hide file tree
Showing 17 changed files with 367 additions and 234 deletions.
1 change: 0 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@ sphinx > 7.0
sphinx-gallery
sphinx-toggleprompt
tomli
chemiscope
1 change: 0 additions & 1 deletion examples/neighborlist_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@
mesh_spacing=mesh_spacing,
interpolation_order=interpolation_order,
subtract_self=True,
sr_cutoff=sr_cutoff,
)
potential = pme.compute(system)

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@ keywords = [
]
dependencies = [
"torch >=1.11",
"ase >= 3.23.0",
]
dynamic = ["version"]

[project.optional-dependencies]
examples = [
"ase >= 3.23.0",
"chemiscope",
"matplotlib",
]
metatensor = [
Expand Down
43 changes: 12 additions & 31 deletions src/meshlode/calculators/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from typing import List, Optional, Tuple, Union
from typing import List, Tuple, Union

import torch
from ase import Atoms
from ase.neighborlist import neighbor_list

from meshlode.lib import InversePowerLawPotential

Expand All @@ -23,23 +21,12 @@ def _compute_sr(
charges: torch.Tensor,
cell: torch.Tensor,
smearing: float,
sr_cutoff: torch.Tensor,
neighbor_indices: Optional[torch.Tensor] = None,
neighbor_shifts: Optional[torch.Tensor] = None,
neighbor_indices: torch.Tensor,
neighbor_shifts: torch.Tensor,
) -> torch.Tensor:
if neighbor_indices is None or neighbor_shifts is None:
# Get list of neighbors
struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True)
atom_is, atom_js, neighbor_shifts = neighbor_list(
"ijS", struc, sr_cutoff.item(), self_interaction=False
)
atom_is = torch.tensor(atom_is)
atom_js = torch.tensor(atom_js)
shifts = torch.tensor(neighbor_shifts, dtype=cell.dtype) # N x 3
else:
atom_is = neighbor_indices[0]
atom_js = neighbor_indices[1]
shifts = neighbor_shifts.type(cell.dtype).T
atom_is = neighbor_indices[0]
atom_js = neighbor_indices[1]
shifts = neighbor_shifts.type(cell.dtype)

# Compute energy
potential = torch.zeros_like(charges)
Expand All @@ -65,15 +52,9 @@ def _compute_sr(


class CalculatorBaseTorch(torch.nn.Module):
"""
Base calculator for the torch interface to MeshLODE.
:param exponent: the exponent :math:`p` in :math:`1/r^p` potentials
"""
"""Base calculator for the torch interface to MeshLODE."""

def __init__(
self,
):
def __init__(self):
super().__init__()

def _validate_compute_parameters(
Expand Down Expand Up @@ -270,7 +251,7 @@ def _validate_compute_parameters(
f"device {neighbor_shifts_single.device}"
)

return positions, cell, charges, neighbor_indices, neighbor_shifts
return positions, charges, cell, neighbor_indices, neighbor_shifts

def _compute_impl(
self,
Expand All @@ -286,8 +267,8 @@ def _compute_impl(
# more general case
(
positions,
cell,
charges,
cell,
neighbor_indices,
neighbor_shifts,
) = self._validate_compute_parameters(
Expand All @@ -302,11 +283,11 @@ def _compute_impl(
potentials = []
for (
positions_single,
cell_single,
charges_single,
cell_single,
neighbor_indices_single,
neighbor_shifts_single,
) in zip(positions, cell, charges, neighbor_indices, neighbor_shifts):
) in zip(positions, charges, cell, neighbor_indices, neighbor_shifts):
# Compute the potentials
potentials.append(
self._compute_single_system(
Expand Down
6 changes: 3 additions & 3 deletions src/meshlode/calculators/directpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ def __init__(self, exponent):
def _compute_single_system(
self,
positions: torch.Tensor,
cell: Union[None, torch.Tensor],
cell: None,
charges: torch.Tensor,
neighbor_indices: Union[None, torch.Tensor],
neighbor_shifts: Union[None, torch.Tensor],
neighbor_indices: None,
neighbor_shifts: None,
) -> torch.Tensor:
# Compute matrix containing the squared distances from the Gram matrix
# The squared distance and the inner product between two vectors r_i and r_j are
Expand Down
73 changes: 29 additions & 44 deletions src/meshlode/calculators/ewaldpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ class _EwaldPotentialImpl(_ShortRange):
def __init__(
self,
exponent: float,
sr_cutoff: Union[None, torch.Tensor],
atomic_smearing: Union[None, float],
lr_wavelength: Union[None, float],
subtract_self: bool,
Expand All @@ -25,7 +24,6 @@ def __init__(
self, exponent=exponent, subtract_interior=subtract_interior
)
self.atomic_smearing = atomic_smearing
self.sr_cutoff = sr_cutoff
self.lr_wavelength = lr_wavelength

# If interior contributions are to be subtracted, also do so for self term
Expand All @@ -38,26 +36,22 @@ def _compute_single_system(
positions: torch.Tensor,
charges: torch.Tensor,
cell: torch.Tensor,
neighbor_indices: Optional[torch.Tensor] = None,
neighbor_shifts: Optional[torch.Tensor] = None,
neighbor_indices: torch.Tensor,
neighbor_shifts: torch.Tensor,
) -> torch.Tensor:
# Set the defaut values of convergence parameters
# The total computational cost = cost of SR part + cost of LR part
# Bigger smearing increases the cost of the SR part while decreasing the cost
# of the LR part. Since the latter usually is more expensive, we maximize the
# value of the smearing by default to minimize the cost of the LR part.
# The two auxilary parameters (sr_cutoff, lr_wavelength) then control the
# The auxilary parameter lr_wavelength then control the
# convergence of the SR and LR sums, respectively. The default values are
# chosen to reach a convergence on the order of 1e-4 to 1e-5 for the test
# structures.
if self.sr_cutoff is None:
cell_dimensions = torch.linalg.norm(cell, dim=1)
sr_cutoff = torch.min(cell_dimensions) / 2 - 1e-6
else:
sr_cutoff = self.sr_cutoff

if self.atomic_smearing is None:
smearing = sr_cutoff / 5.0
cell_dimensions = torch.linalg.norm(cell, dim=1)
max_cutoff = torch.min(cell_dimensions) / 2 - 1e-6
smearing = max_cutoff / 5.0
else:
smearing = self.atomic_smearing

Expand All @@ -72,7 +66,6 @@ def _compute_single_system(
charges=charges,
cell=cell,
smearing=smearing,
sr_cutoff=sr_cutoff,
neighbor_indices=neighbor_indices,
neighbor_shifts=neighbor_shifts,
)
Expand Down Expand Up @@ -170,15 +163,21 @@ class EwaldPotential(CalculatorBaseTorch, _EwaldPotentialImpl):
Scaling as :math:`\mathcal{O}(N^2)` with respect to the number of particles
:math:`N`.
For computing a **neighborlist** a reasonable ``cutoff`` is half the length of
the shortest cell vector, which can be for example computed according as
.. code-block:: python
cell_dimensions = torch.linalg.norm(cell, dim=1)
cutoff = torch.min(cell_dimensions) / 2 - 1e-6
This ensures a accuracy of the short range part of ``1e-5``.
:param exponent: the exponent :math:`p` in :math:`1/r^p` potentials
:param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. If
not set to a global value, it will be set to be half of the shortest lattice
vector defining the cell (separately for each structure).
:param atomic_smearing: Width of the atom-centered Gaussian used to split the
Coulomb potential into the short- and long-range parts. If not set to a global
value, it will be set to 1/5 times the sr_cutoff value (separately for each
structure) to ensure convergence of the short-range part to a relative precision
of 1e-5.
value, it will be set to 1/5 times of half the larges box vector (separately for
each structure).
:param lr_wavelength: Spatial resolution used for the long-range (reciprocal space)
part of the Ewald sum. More conretely, all Fourier space vectors with a
wavelength >= this value will be kept. If not set to a global value, it will be
Expand All @@ -192,33 +191,12 @@ class EwaldPotential(CalculatorBaseTorch, _EwaldPotentialImpl):
Note that if set to true, the self contribution (see previous) is also
subtracted by default.
Example
-------
We calculate the Madelung constant of a CsCl (Cesium-Chloride) crystal. The
reference value is :math:`2 \cdot 1.7626 / \sqrt{3} \approx 2.0354`.
>>> import torch
Define crystal structure
>>> positions = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]])
>>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1)
>>> cell = torch.eye(3)
Compute the potential
>>> ewald = EwaldPotential()
>>> ewald.compute(positions=positions, charges=charges, cell=cell)
tensor([[-2.0354],
[ 2.0354]])
Which is the same as the reference value given above.
For an **example** on the usage refer to :py:class:`PMEPotential`.
"""

def __init__(
self,
exponent: float = 1.0,
sr_cutoff: Optional[torch.Tensor] = None,
atomic_smearing: Optional[float] = None,
lr_wavelength: Optional[float] = None,
subtract_self: bool = True,
Expand All @@ -227,7 +205,6 @@ def __init__(
_EwaldPotentialImpl.__init__(
self,
exponent=exponent,
sr_cutoff=sr_cutoff,
atomic_smearing=atomic_smearing,
lr_wavelength=lr_wavelength,
subtract_self=subtract_self,
Expand All @@ -240,14 +217,22 @@ def compute(
positions: Union[List[torch.Tensor], torch.Tensor],
charges: Union[List[torch.Tensor], torch.Tensor],
cell: Union[List[torch.Tensor], torch.Tensor],
neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None,
neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None,
neighbor_indices: Union[List[torch.Tensor], torch.Tensor],
neighbor_shifts: Union[List[torch.Tensor], torch.Tensor],
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Compute potential for all provided "systems" stacked inside list.
The computation is performed on the same ``device`` as ``dtype`` is the input is
stored on. The ``dtype`` of the output tensors will be the same as the input.
For computing a **neighborlist** a reasonable ``cutoff`` is half the length of
the shortest cell vector, which can be for example computed according as
.. code-block:: python
cell_dimensions = torch.linalg.norm(cell, dim=1)
cutoff = torch.min(cell_dimensions) / 2 - 1e-6
:param positions: Single or 2D tensor of shape (``len(charges), 3``) containing
the Cartesian positions of all point charges in the system.
:param charges: Single 2D tensor or list of 2D tensor of shape (``n_channels,
Expand Down
Loading

0 comments on commit a46b2de

Please sign in to comment.