Skip to content

Commit

Permalink
simplify base class even further
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jul 4, 2024
1 parent b1556ac commit 1cd02a4
Show file tree
Hide file tree
Showing 7 changed files with 296 additions and 212 deletions.
85 changes: 0 additions & 85 deletions src/meshlode/calculators/calculator_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from typing import List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -190,13 +189,6 @@ def _validate_compute_parameters(
f"cell ({cell_single.device})"
)

if type(neighbor_indices_single) is not type(neighbor_indices_single):
raise ValueError(
f"Inconsistent of neighbor_indices "
f"({type(neighbor_indices_single)}) and neighbor_indices "
f"({neighbor_indices_single})"
)

if neighbor_indices_single is not None:
# TODO validate shape and dtype

Expand Down Expand Up @@ -297,83 +289,6 @@ def _compute_impl(
else:
return potentials

def compute(
self,
types: Union[List[torch.Tensor], torch.Tensor],
positions: Union[List[torch.Tensor], torch.Tensor],
cell: Union[List[torch.Tensor], torch.Tensor] = None,
charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None,
neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Compute potential for all provided "systems" stacked inside list.
The computation is performed on the same ``device`` as ``systems`` is stored on.
The ``dtype`` of the output tensors will be the same as the input.
:param types: single or list of 1D tensor of integer representing the
particles identity. For atoms, this is typically their atomic numbers.
:param positions: single or 2D tensor of shape (len(types), 3) containing the
Cartesian positions of all particles in the system.
:param cell: Ignored.
:param charges: Optional single or list of 2D tensor of shape (len(types), n),
:param neighbor_indices: Optional single or list of 2D tensors of shape (2, n),
where n is the number of atoms. The 2 rows correspond to the indices of
the two atoms which are considered neighbors (e.g. within a cutoff distance)
:param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n),
where n is the number of atoms. The 3 rows correspond to the shift indices
for periodic images.
:return: List of torch Tensors containing the potentials for all frames and all
atoms. Each tensor in the list is of shape (n_atoms, n_types), where
n_types is the number of types in all systems combined. If the input was
a single system only a single torch tensor with the potentials is returned.
IMPORTANT: If multiple types are present, the different "types-channels"
are ordered according to atomic number. For example, if a structure contains
a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this
system, the feature tensor will be of shape (3, 2) = (``n_atoms``,
``n_types``), where ``features[0, 0]`` is the potential at the position of
the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms,
while ``features[0,1]`` is the potential at the position of the Oxygen atom
generated by the Oxygen atom(s).
"""
if cell is not None:
warnings.warn(
"`cell` parameter was proviced but will be ignored", stacklevel=2
)

return self._compute_impl(
types=types,
positions=positions,
cell=cell,
charges=charges,
neighbor_indices=neighbor_indices,
neighbor_shifts=neighbor_shifts,
)

# This function is kept to keep MeshLODE compatible with the broader pytorch
# infrastructure, which require a "forward" function. We name this function
# "compute" instead, for compatibility with other COSMO software.
def forward(
self,
types: Union[List[torch.Tensor], torch.Tensor],
positions: Union[List[torch.Tensor], torch.Tensor],
cell: Union[List[torch.Tensor], torch.Tensor] = None,
charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None,
neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""forward just calls :py:meth:`CalculatorModule.compute`"""
return self.compute(
types=types,
positions=positions,
cell=cell,
charges=charges,
neighbor_indices=neighbor_indices,
neighbor_shifts=neighbor_shifts,
)

def _compute_single_system(
self,
positions: torch.Tensor,
Expand Down
88 changes: 0 additions & 88 deletions src/meshlode/calculators/calculator_base_periodic.py

This file was deleted.

59 changes: 58 additions & 1 deletion src/meshlode/calculators/direct.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import List, Optional, Union

import torch

Expand All @@ -22,6 +22,63 @@ class DirectPotential(CalculatorBase):

name = "DirectPotential"

def compute(
self,
types: Union[List[torch.Tensor], torch.Tensor],
positions: Union[List[torch.Tensor], torch.Tensor],
charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Compute potential for all provided "systems" stacked inside list.
The computation is performed on the same ``device`` as ``systems`` is stored on.
The ``dtype`` of the output tensors will be the same as the input.
:param types: single or list of 1D tensor of integer representing the
particles identity. For atoms, this is typically their atomic numbers.
:param positions: single or 2D tensor of shape (len(types), 3) containing the
Cartesian positions of all particles in the system.
:param charges: Optional single or list of 2D tensor of shape (len(types), n),
:return: List of torch Tensors containing the potentials for all frames and all
atoms. Each tensor in the list is of shape (n_atoms, n_types), where
n_types is the number of types in all systems combined. If the input was
a single system only a single torch tensor with the potentials is returned.
IMPORTANT: If multiple types are present, the different "types-channels"
are ordered according to atomic number. For example, if a structure contains
a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this
system, the feature tensor will be of shape (3, 2) = (``n_atoms``,
``n_types``), where ``features[0, 0]`` is the potential at the position of
the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms,
while ``features[0,1]`` is the potential at the position of the Oxygen atom
generated by the Oxygen atom(s).
"""

return self._compute_impl(
types=types,
positions=positions,
cell=None,
charges=charges,
neighbor_indices=None,
neighbor_shifts=None,
)

# This function is kept to keep MeshLODE compatible with the broader pytorch
# infrastructure, which require a "forward" function. We name this function
# "compute" instead, for compatibility with other COSMO software.
def forward(
self,
types: Union[List[torch.Tensor], torch.Tensor],
positions: Union[List[torch.Tensor], torch.Tensor],
charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""forward just calls :py:meth:`CalculatorModule.compute`"""
return self.compute(
types=types,
positions=positions,
charges=charges,
)

def _compute_single_system(
self,
positions: torch.Tensor,
Expand Down
80 changes: 78 additions & 2 deletions src/meshlode/calculators/ewald.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from ase import Atoms
from ase.neighborlist import neighbor_list

from .calculator_base_periodic import CalculatorBasePeriodic
from .calculator_base import CalculatorBase


class EwaldPotential(CalculatorBasePeriodic):
class EwaldPotential(CalculatorBase):
"""A specie-wise long-range potential computed using the Ewald sum, scaling as
O(N^2) with respect to the number of particles N used as a reference to test faster
implementations.
Expand Down Expand Up @@ -84,6 +84,82 @@ def __init__(
self.subtract_self = subtract_self
self.subtract_interior = subtract_interior

def compute(
self,
types: Union[List[torch.Tensor], torch.Tensor],
positions: Union[List[torch.Tensor], torch.Tensor],
cell: Union[List[torch.Tensor], torch.Tensor],
charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None,
neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Compute potential for all provided "systems" stacked inside list.
The computation is performed on the same ``device`` as ``systems`` is stored on.
The ``dtype`` of the output tensors will be the same as the input.
:param types: single or list of 1D tensor of integer representing the
particles identity. For atoms, this is typically their atomic numbers.
:param positions: single or 2D tensor of shape (len(types), 3) containing the
Cartesian positions of all particles in the system.
:param cell: single or 2D tensor of shape (3, 3), describing the bounding
box/unit cell of the system. Each row should be one of the bounding box
vector; and columns should contain the x, y, and z components of these
vectors (i.e. the cell should be given in row-major order).
:param charges: Optional single or list of 2D tensor of shape (len(types), n),
:param neighbor_indices: Optional single or list of 2D tensors of shape (2, n),
where n is the number of atoms. The 2 rows correspond to the indices of
the two atoms which are considered neighbors (e.g. within a cutoff distance)
:param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n),
where n is the number of atoms. The 3 rows correspond to the shift indices
for periodic images.
:return: List of torch Tensors containing the potentials for all frames and all
atoms. Each tensor in the list is of shape (n_atoms, n_types), where
n_types is the number of types in all systems combined. If the input was
a single system only a single torch tensor with the potentials is returned.
IMPORTANT: If multiple types are present, the different "types-channels"
are ordered according to atomic number. For example, if a structure contains
a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this
system, the feature tensor will be of shape (3, 2) = (``n_atoms``,
``n_types``), where ``features[0, 0]`` is the potential at the position of
the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms,
while ``features[0,1]`` is the potential at the position of the Oxygen atom
generated by the Oxygen atom(s).
"""

return self._compute_impl(
types=types,
positions=positions,
cell=cell,
charges=charges,
neighbor_indices=neighbor_indices,
neighbor_shifts=neighbor_shifts,
)

# This function is kept to keep MeshLODE compatible with the broader pytorch
# infrastructure, which require a "forward" function. We name this function
# "compute" instead, for compatibility with other COSMO software.
def forward(
self,
types: Union[List[torch.Tensor], torch.Tensor],
positions: Union[List[torch.Tensor], torch.Tensor],
cell: Union[List[torch.Tensor], torch.Tensor],
charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None,
neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""forward just calls :py:meth:`CalculatorModule.compute`"""
return self.compute(
types=types,
positions=positions,
cell=cell,
charges=charges,
neighbor_indices=neighbor_indices,
neighbor_shifts=neighbor_shifts,
)

def _compute_single_system(
self,
positions: torch.Tensor,
Expand Down
Loading

0 comments on commit 1cd02a4

Please sign in to comment.