Skip to content

Commit

Permalink
fix linter
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jul 10, 2024
1 parent 22fd41f commit 9be73a6
Show file tree
Hide file tree
Showing 11 changed files with 74 additions and 107 deletions.
4 changes: 2 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ You can install *MeshLode* using pip with
You can then ``import meshlode`` and use it in your projects!

We also provide bindings to `metatensor <https://docs.metatensor.org/latest/>`_ which can
optionally be installed together and used as ``meshlode.metatensor`` via
We also provide bindings to `metatensor <https://docs.metatensor.org/latest/>`_ which
can optionally be installed together and used as ``meshlode.metatensor`` via

.. code-block:: bash
Expand Down
43 changes: 6 additions & 37 deletions src/meshlode/calculators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,16 @@
from meshlode.lib import InversePowerLawPotential


class CalculatorBase(torch.nn.Module):
"""Base class providing general funtionality."""
class _ShortRange:
"""Base class providing general funtionality for short range interactions."""

def __init__(
self,
exponent: float,
):
def __init__(self, exponent: float, subtract_interior: bool):
# Attach the function handling all computations related to the
# power-law potential for later convenience
self.exponent = exponent
self.subtract_interior = subtract_interior
self.potential = InversePowerLawPotential(exponent=exponent)

super().__init__()

def _compute_sr(
self,
positions: torch.Tensor,
Expand All @@ -31,32 +27,6 @@ def _compute_sr(
neighbor_indices: Optional[torch.Tensor] = None,
neighbor_shifts: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Compute the short-range part of the Ewald sum in realspace
:param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian
coordinates of the atoms. The implementation also works if the positions
are not contained within the unit cell.
:param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest
case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the
charge of atom i. More generally, the potential for the same atom positions
is computed for n_channels independent meshes, and one can specify the
"charge" of each atom on each of the meshes independently.
:param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the
structure, where cell[i] is the i-th basis vector.
:param smearing: torch.Tensor smearing paramter determining the splitting
between the SR and LR parts.
:param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum.
:param neighbor_indices: Optional single or list of 2D tensors of shape (2, n),
where n is the number of atoms. The 2 rows correspond to the indices of
the two atoms which are considered neighbors (e.g. within a cutoff distance)
:param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n),
where n is the number of atoms. The 3 rows correspond to the shift indices
for periodic images.
:returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential
at the position of each atom for the `n_channels` independent meshes separately.
"""
if neighbor_indices is None or neighbor_shifts is None:
# Get list of neighbors
struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True)
Expand Down Expand Up @@ -94,7 +64,7 @@ def _compute_sr(
return potential


class CalculatorBaseTorch(CalculatorBase):
class CalculatorBaseTorch(torch.nn.Module):
"""
Base calculator for the torch interface to MeshLODE.
Expand All @@ -103,9 +73,8 @@ class CalculatorBaseTorch(CalculatorBase):

def __init__(
self,
exponent: float,
):
super().__init__(exponent=exponent)
super().__init__()

def _validate_compute_parameters(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/meshlode/calculators/directpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class DirectPotential(CalculatorBaseTorch, _DirectPotentialImpl):

def __init__(self, exponent: float = 1.0):
_DirectPotentialImpl.__init__(self, exponent=exponent)
CalculatorBaseTorch.__init__(self, exponent=exponent)
CalculatorBaseTorch.__init__(self)

def compute(
self,
Expand Down
23 changes: 13 additions & 10 deletions src/meshlode/calculators/ewaldpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,35 @@
import torch

from ..lib import generate_kvectors_squeezed
from .base import CalculatorBaseTorch
from .base import CalculatorBaseTorch, _ShortRange


class _EwaldPotentialImpl:
class _EwaldPotentialImpl(_ShortRange):
def __init__(
self,
exponent: float,
sr_cutoff: Union[None, torch.Tensor],
atomic_smearing: Union[None, float],
lr_wavelength: Union[None, float],
subtract_self: Union[None, bool],
subtract_interior: Union[None, bool],
subtract_self: bool,
subtract_interior: bool,
):
if exponent < 0.0 or exponent > 3.0:
raise ValueError(f"`exponent` p={exponent} has to satisfy 0 < p < 3")
if atomic_smearing is not None and atomic_smearing <= 0:
raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive")

_ShortRange.__init__(
self, exponent=exponent, subtract_interior=subtract_interior
)
self.atomic_smearing = atomic_smearing
self.sr_cutoff = sr_cutoff
self.lr_wavelength = lr_wavelength

# If interior contributions are to be subtracted, also do so for self term
if subtract_interior:
if self.subtract_interior:
subtract_self = True
self.subtract_self = subtract_self
self.subtract_interior = subtract_interior

def _compute_single_system(
self,
Expand Down Expand Up @@ -154,7 +156,8 @@ def _compute_lr(
# TODO: modify to expression for general p
if subtract_self:
self_contrib = (
torch.sqrt(torch.tensor(2.0 / torch.pi, device=self._device)) / smearing
torch.sqrt(torch.tensor(2.0 / torch.pi, device=positions.device))
/ smearing
)
energy -= charges * self_contrib

Expand Down Expand Up @@ -218,8 +221,8 @@ def __init__(
sr_cutoff: Optional[torch.Tensor] = None,
atomic_smearing: Optional[float] = None,
lr_wavelength: Optional[float] = None,
subtract_self: Optional[bool] = True,
subtract_interior: Optional[bool] = False,
subtract_self: bool = True,
subtract_interior: bool = False,
):
_EwaldPotentialImpl.__init__(
self,
Expand All @@ -230,7 +233,7 @@ def __init__(
subtract_self=subtract_self,
subtract_interior=subtract_interior,
)
CalculatorBaseTorch.__init__(self, exponent=exponent)
CalculatorBaseTorch.__init__(self)

def compute(
self,
Expand Down
24 changes: 13 additions & 11 deletions src/meshlode/calculators/pmepotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@

from ..lib import generate_kvectors_for_mesh
from ..lib.mesh_interpolator import MeshInterpolator
from .base import CalculatorBaseTorch
from .base import CalculatorBaseTorch, _ShortRange


class _PMEPotentialImpl:
class _PMEPotentialImpl(_ShortRange):
def __init__(
self,
exponent: float,
sr_cutoff: Union[None, torch.Tensor],
atomic_smearing: Union[None, float],
mesh_spacing: Union[None, float],
interpolation_order: Union[None, int],
subtract_self: Union[None, bool],
subtract_interior: Union[None, bool],
interpolation_order: int,
subtract_self: bool,
subtract_interior: bool,
):
# Check that all provided values are correct
if exponent < 0.0 or exponent > 3.0:
Expand All @@ -26,16 +26,18 @@ def __init__(
if atomic_smearing is not None and atomic_smearing <= 0:
raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive")

_ShortRange.__init__(
self, exponent=exponent, subtract_interior=subtract_interior
)
self.atomic_smearing = atomic_smearing
self.mesh_spacing = mesh_spacing
self.interpolation_order = interpolation_order
self.sr_cutoff = sr_cutoff

# If interior contributions are to be subtracted, also do so for self term
if subtract_interior:
if self.subtract_interior:
subtract_self = True
self.subtract_self = subtract_self
self.subtract_interior = subtract_interior

self.atomic_smearing = atomic_smearing
self.mesh_spacing = mesh_spacing
Expand Down Expand Up @@ -225,9 +227,9 @@ def __init__(
sr_cutoff: Optional[torch.Tensor] = None,
atomic_smearing: Optional[float] = None,
mesh_spacing: Optional[float] = None,
interpolation_order: Optional[int] = 3,
subtract_self: Optional[bool] = True,
subtract_interior: Optional[bool] = False,
interpolation_order: int = 3,
subtract_self: bool = True,
subtract_interior: bool = False,
):
_PMEPotentialImpl.__init__(
self,
Expand All @@ -239,7 +241,7 @@ def __init__(
subtract_self=subtract_self,
subtract_interior=subtract_interior,
)
CalculatorBaseTorch.__init__(self, exponent=exponent)
CalculatorBaseTorch.__init__(self)

def compute(
self,
Expand Down
8 changes: 3 additions & 5 deletions src/meshlode/metatensor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@
"Try installing it with:\npip install metatensor[torch]"
)

from ..calculators.base import CalculatorBase


class CalculatorBaseMetatensor(CalculatorBase):
def __init__(self, exponent: float):
super().__init__(exponent)
class CalculatorBaseMetatensor(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, systems: Union[List[System], System]) -> TensorMap:
"""Forward just calls :py:meth:`compute`."""
Expand Down
2 changes: 1 addition & 1 deletion src/meshlode/metatensor/directpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ class DirectPotential(CalculatorBaseMetatensor, _DirectPotentialImpl):

def __init__(self, exponent: float = 1.0):
_DirectPotentialImpl.__init__(self, exponent=exponent)
CalculatorBaseMetatensor.__init__(self, exponent=exponent)
CalculatorBaseMetatensor.__init__(self)
6 changes: 3 additions & 3 deletions src/meshlode/metatensor/ewaldpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def __init__(
sr_cutoff: Optional[torch.Tensor] = None,
atomic_smearing: Optional[float] = None,
lr_wavelength: Optional[float] = None,
subtract_self: Optional[bool] = True,
subtract_interior: Optional[bool] = False,
subtract_self: bool = True,
subtract_interior: bool = False,
):
_EwaldPotentialImpl.__init__(
self,
Expand All @@ -72,4 +72,4 @@ def __init__(
subtract_self=subtract_self,
subtract_interior=subtract_interior,
)
CalculatorBaseMetatensor.__init__(self, exponent=exponent)
CalculatorBaseMetatensor.__init__(self)
8 changes: 4 additions & 4 deletions src/meshlode/metatensor/pmepotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def __init__(
sr_cutoff: Optional[torch.Tensor] = None,
atomic_smearing: Optional[float] = None,
mesh_spacing: Optional[float] = None,
interpolation_order: Optional[int] = 3,
subtract_self: Optional[bool] = True,
subtract_interior: Optional[bool] = False,
interpolation_order: int = 3,
subtract_self: bool = True,
subtract_interior: bool = False,
):
_PMEPotentialImpl.__init__(
self,
Expand All @@ -74,4 +74,4 @@ def __init__(
subtract_self=subtract_self,
subtract_interior=subtract_interior,
)
CalculatorBaseMetatensor.__init__(self, exponent=exponent)
CalculatorBaseMetatensor.__init__(self)
Loading

0 comments on commit 9be73a6

Please sign in to comment.