Skip to content

Commit

Permalink
cleanup base classes
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jul 4, 2024
1 parent 809d85e commit b1556ac
Show file tree
Hide file tree
Showing 14 changed files with 325 additions and 352 deletions.
360 changes: 246 additions & 114 deletions src/meshlode/calculators/calculator_base.py

Large diffs are not rendered by default.

159 changes: 21 additions & 138 deletions src/meshlode/calculators/calculator_base_periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,30 @@ class CalculatorBasePeriodic(CalculatorBase):

name = "CalculatorBasePeriodic"

# Note that the base class also has this function, but with the parameter "cell"
# only as an option. For periodic implementations, "cell" is a strictly required
# parameter, which is why this function is implemented again.
# This function is kept to keep MeshLODE compatible with the broader pytorch
# infrastructure, which require a "forward" function. We name this function
# "compute" instead, for compatibility with other COSMO software.
def forward(
self,
types: Union[List[torch.Tensor], torch.Tensor],
positions: Union[List[torch.Tensor], torch.Tensor],
cell: Union[List[torch.Tensor], torch.Tensor],
cell: Union[List[torch.Tensor], torch.Tensor] = None,
charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None,
neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""forward just calls :py:meth:`CalculatorModule.compute`"""
return self.compute(
types=types, positions=positions, cell=cell, charges=charges
types=types,
positions=positions,
cell=cell,
charges=charges,
neighbor_indices=neighbor_indices,
neighbor_shifts=neighbor_shifts,
)

def compute(
self,
types: Union[List[torch.Tensor], torch.Tensor],
positions: Union[List[torch.Tensor], torch.Tensor],
cell: Union[List[torch.Tensor], torch.Tensor],
cell: Union[List[torch.Tensor], torch.Tensor] = None,
charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None,
neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None,
Expand Down Expand Up @@ -74,132 +75,14 @@ def compute(
while ``features[0,1]`` is the potential at the position of the Oxygen atom
generated by the Oxygen atom(s).
"""
# make sure compute function works if only a single tensor are provided as input
if not isinstance(types, list):
types = [types]
if not isinstance(positions, list):
positions = [positions]
if not isinstance(cell, list):
cell = [cell]
if (neighbor_indices is not None) and not isinstance(neighbor_indices, list):
neighbor_indices = [neighbor_indices]
if (neighbor_shifts is not None) and not isinstance(neighbor_shifts, list):
neighbor_shifts = [neighbor_shifts]

# Check that all inputs are consistent
for types_single, positions_single, cell_single in zip(types, positions, cell):
if len(types_single.shape) != 1:
raise ValueError(
"each `types` must be a 1 dimensional tensor, got at least "
f"one tensor with {len(types_single.shape)} dimensions"
)

if positions_single.shape != (len(types_single), 3):
raise ValueError(
"each `positions` must be a (n_types x 3) tensor, got at least "
f"one tensor with shape {list(positions_single.shape)}"
)

if cell_single.shape != (3, 3):
raise ValueError(
"each `cell` must be a (3 x 3) tensor, got at least "
f"one tensor with shape {list(cell_single.shape)}"
)

if cell_single.dtype != positions_single.dtype:
raise ValueError(
"`cell` must be have the same dtype as `positions`, got "
f"{cell_single.dtype} and {positions_single.dtype}"
)

if (
positions_single.device != types_single.device
or cell_single.device != types_single.device
):
raise ValueError(
"`types`, `positions`, and `cell` must be on the same device, got "
f"{types_single.device}, {positions_single.device} and "
f"{cell_single.device}."
)

requested_types = self._get_requested_types(types)

# If charges are not provided, we assume that all types are treated separately
if charges is None:
charges = []
for types_single, positions_single in zip(types, positions):
# One-hot encoding of charge information
charges_single = self._one_hot_charges(
types=types_single,
requested_types=requested_types,
dtype=positions_single.dtype,
device=positions_single.device,
)
charges.append(charges_single)

# If charges are provided, we need to make sure that they are consistent with
# the provided types
else:
if not isinstance(charges, list):
charges = [charges]
if len(charges) != len(types):
raise ValueError(
"The number of `types` and `charges` tensors must be the same, "
f"got {len(types)} and {len(charges)}."
)
for charges_single, types_single in zip(charges, types):
if charges_single.shape[0] != len(types_single):
raise ValueError(
"The first dimension of `charges` must be the same as the "
f"length of `types`, got {charges_single.shape[0]} and "
f"{len(types_single)}."
)
if charges[0].dtype != positions[0].dtype:
raise ValueError(
"`charges` must be have the same dtype as `positions`, got "
f"{charges[0].dtype} and {positions[0].dtype}."
)
if charges[0].device != positions[0].device:
raise ValueError(
"`charges` must be on the same device as `positions`, got "
f"{charges[0].device} and {positions[0].device}."
)
# We don't require and test that all dtypes and devices are consistent if a list
# of inputs. Each "frame" is processed independently.
potentials = []

if neighbor_indices is None or neighbor_shifts is None:
for positions_single, cell_single, charges_single in zip(
positions, cell, charges
):
# Compute the potentials
potentials.append(
self._compute_single_system(
positions=positions_single,
charges=charges_single,
cell=cell_single,
)
)
else:
for (
positions_single,
cell_single,
charges_single,
neighbor_indices_single,
neighbor_shifts_single,
) in zip(positions, cell, charges, neighbor_indices, neighbor_shifts):
# Compute the potentials
potentials.append(
self._compute_single_system(
positions=positions_single,
charges=charges_single,
cell=cell_single,
neighbor_indices=neighbor_indices_single,
neighbor_shifts=neighbor_shifts_single,
)
)

if len(types) == 1:
return potentials[0]
else:
return potentials
if cell is None:
raise ValueError("cell must be provided")

return self._compute_impl(
types=types,
positions=positions,
cell=cell,
charges=charges,
neighbor_indices=neighbor_indices,
neighbor_shifts=neighbor_shifts,
)
33 changes: 6 additions & 27 deletions src/meshlode/calculators/direct.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Union

import torch

from .calculator_base import CalculatorBase
Expand All @@ -23,33 +25,11 @@ class DirectPotential(CalculatorBase):
def _compute_single_system(
self,
positions: torch.Tensor,
cell: Union[None, torch.Tensor],
charges: torch.Tensor,
neighbor_indices: Union[None, torch.Tensor],
neighbor_shifts: Union[None, torch.Tensor],
) -> torch.Tensor:
"""
Compute the "electrostatic" potential at the position of all atoms in a
structure.
This solver does not use periodic boundaries, and thus also does not take into
account potential periodic images.
:param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian
coordinates of the atoms. The implementation also works if the positions
are not contained within the unit cell.
:param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest
case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the
charge of atom i. More generally, the potential for the same atom positions
is computed for n_channels independent meshes, and one can specify the
"charge" of each atom on each of the meshes independently. For standard LODE
that treats all (atomic) types separately, one example could be: If n_atoms
= 4 and the types are [Na, Cl, Cl, Na], one could set n_channels=2 and use
the one-hot encoding charges = torch.tensor([[1,0],[0,1],[0,1],[1,0]]) for
the charges. This would then separately compute the "Na" potential and "Cl"
potential. Subtracting these from each other, one could recover the more
standard electrostatic potential in which Na and Cl have charges of +1 and
-1, respectively.
:returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential
at the position of each atom for the `n_channels` independent meshes separately.
"""
# Compute matrix containing the squared distances from the Gram matrix
# The squared distance and the inner product between two vectors r_i and r_j are
# related by: d_ij^2 = |r_i - r_j|^2 = r_i^2 + r_j^2 - 2*r_i*r_j
Expand All @@ -72,6 +52,5 @@ def _compute_single_system(

# Compute potential
potentials_by_pair = distances_sq.pow(-self.exponent / 2.0)
potentials = torch.matmul(potentials_by_pair, charges)

return potentials
return torch.matmul(potentials_by_pair, charges)
34 changes: 5 additions & 29 deletions src/meshlode/calculators/ewald.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from typing import List, Optional
from typing import List, Optional, Union

import torch

# extra imports for neighbor list
from ase import Atoms
from ase.neighborlist import neighbor_list

from .calculator_base import default_exponent
from .calculator_base_periodic import CalculatorBasePeriodic


Expand Down Expand Up @@ -65,7 +64,7 @@ class EwaldPotential(CalculatorBasePeriodic):
def __init__(
self,
all_types: Optional[List[int]] = None,
exponent: Optional[torch.Tensor] = default_exponent,
exponent: float = 1.0,
sr_cutoff: Optional[torch.Tensor] = None,
atomic_smearing: Optional[float] = None,
lr_wavelength: Optional[float] = None,
Expand All @@ -88,34 +87,11 @@ def __init__(
def _compute_single_system(
self,
positions: torch.Tensor,
cell: Union[None, torch.Tensor],
charges: torch.Tensor,
cell: torch.Tensor,
neighbor_indices: Union[None, torch.Tensor],
neighbor_shifts: Union[None, torch.Tensor],
) -> torch.Tensor:
"""
Compute the "electrostatic" potential at the position of all atoms in a
structure.
:param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian
coordinates of the atoms. The implementation also works if the positions
are not contained within the unit cell.
:param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest
case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the
charge of atom i. More generally, the potential for the same atom positions
is computed for n_channels independent meshes, and one can specify the
"charge" of each atom on each of the meshes independently. For standard LODE
that treats all (atomic) types separately, one example could be: If n_atoms
= 4 and the types are [Na, Cl, Cl, Na], one could set n_channels=2 and use
the one-hot encoding charges = torch.tensor([[1,0],[0,1],[0,1],[1,0]]) for
the charges. This would then separately compute the "Na" potential and "Cl"
potential. Subtracting these from each other, one could recover the more
standard electrostatic potential in which Na and Cl have charges of +1 and
-1, respectively.
:param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the
structure, where cell[i] is the i-th basis vector.
:returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential
at the position of each atom for the `n_channels` independent meshes separately.
"""
# Check that the realspace cutoff (if provided) is not too large
# This is because the current implementation is not able to return multiple
# periodic images of the same atom as a neighbor
Expand Down
27 changes: 9 additions & 18 deletions src/meshlode/calculators/mesh.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import List, Optional
from typing import List, Optional, Union

import torch

from meshlode.lib.fourier_convolution import FourierSpaceConvolution
from meshlode.lib.mesh_interpolator import MeshInterpolator

from .calculator_base import default_exponent
from .calculator_base_periodic import CalculatorBasePeriodic


Expand Down Expand Up @@ -58,7 +57,7 @@ def __init__(
interpolation_order: Optional[int] = 4,
subtract_self: Optional[bool] = False,
all_types: Optional[List[int]] = None,
exponent: Optional[torch.Tensor] = default_exponent,
exponent: float = 1.0,
):
super().__init__(all_types=all_types, exponent=exponent)

Expand All @@ -71,11 +70,12 @@ def __init__(
# If no explicit mesh_spacing is given, set it such that it can resolve
# the smeared potentials.
if mesh_spacing is None:
mesh_spacing = atomic_smearing / 2
self.mesh_spacing = atomic_smearing / 2
else:
self.mesh_spacing = mesh_spacing

# Store provided parameters
self.atomic_smearing = atomic_smearing
self.mesh_spacing = mesh_spacing
self.interpolation_order = interpolation_order
self.subtract_self = subtract_self

Expand All @@ -85,9 +85,10 @@ def __init__(
def _compute_single_system(
self,
positions: torch.Tensor,
cell: Union[None, torch.Tensor],
charges: torch.Tensor,
cell: torch.Tensor,
mesh_spacing: Optional[float] = None,
neighbor_indices: Union[None, torch.Tensor],
neighbor_shifts: Union[None, torch.Tensor],
) -> torch.Tensor:
"""
Compute the "electrostatic" potential at the position of all atoms in a
Expand Down Expand Up @@ -115,17 +116,7 @@ def _compute_single_system(
at the position of each atom for the `n_channels` independent meshes separately.
"""
# Initializations
n_atoms = len(positions)
assert positions.shape == (n_atoms, 3)
assert charges.shape[0] == n_atoms

assert positions.dtype == cell.dtype and charges.dtype == cell.dtype
assert positions.device == cell.device and charges.device == cell.device

# Define cutoff in reciprocal space
if mesh_spacing is None:
mesh_spacing = self.mesh_spacing
k_cutoff = 2 * torch.pi / mesh_spacing
k_cutoff = 2 * torch.pi / self.mesh_spacing

# Compute number of times each basis vector of the
# reciprocal space can be scaled until the cutoff
Expand Down
Loading

0 comments on commit b1556ac

Please sign in to comment.