Skip to content

Commit

Permalink
styling
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jun 13, 2024
1 parent 6d3455a commit 8e325d6
Show file tree
Hide file tree
Showing 14 changed files with 353 additions and 172 deletions.
13 changes: 6 additions & 7 deletions src/meshlode/calculators/calculator_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from meshlode.lib import InversePowerLawPotential
from typing import List, Optional, Union

import torch

from meshlode.lib import InversePowerLawPotential


@torch.jit.script
def _1d_tolist(x: torch.Tensor) -> List[int]:
Expand Down Expand Up @@ -38,17 +39,17 @@ class CalculatorBase(torch.nn.Module):
def __init__(
self,
all_types: Optional[List[int]] = None,
exponent: Optional[torch.Tensor] = torch.tensor(1., dtype=torch.float64),
exponent: Optional[torch.Tensor] = torch.tensor(1.0, dtype=torch.float64),
):
super().__init__()

if all_types is None:
self.all_types = None
else:
self.all_types = _1d_tolist(torch.unique(torch.tensor(all_types)))

self.exponent = exponent
self.potential = InversePowerLawPotential(exponent = exponent)
self.potential = InversePowerLawPotential(exponent=exponent)

# This function is kept to keep this library compatible with the broader pytorch
# infrastructure, which require a "forward" function. We name this function
Expand All @@ -60,9 +61,7 @@ def forward(
charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""forward just calls :py:meth:`CalculatorModule.compute`"""
return self.compute(
types=types, positions=positions, charges=charges
)
return self.compute(types=types, positions=positions, charges=charges)

def compute(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/meshlode/calculators/calculator_base_periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,4 @@ def compute(
if len(types) == 1:
return potentials[0]
else:
return potentials
return potentials
8 changes: 4 additions & 4 deletions src/meshlode/calculators/direct.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .calculator_base import CalculatorBase

import torch

from .calculator_base import CalculatorBase


class DirectPotential(CalculatorBase):
"""A specie-wise long-range potential computed using a direct summation over all
Expand Down Expand Up @@ -73,9 +73,9 @@ def _compute_single_system(
# obvious alternative of setting the same components to zero after the division
# had issues with autograd. I would appreciate any better alternatives.
distances_sq[diagonal_indices, diagonal_indices] += 1e50

# Compute potential
potentials_by_pair = distances_sq.pow(-self.exponent / 2.)
potentials_by_pair = distances_sq.pow(-self.exponent / 2.0)
potentials = torch.matmul(potentials_by_pair, charges)

return potentials
22 changes: 13 additions & 9 deletions src/meshlode/calculators/ewald.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import torch
from typing import List, Optional

from .calculator_base_periodic import CalculatorBasePeriodic
import torch

# extra imports for neighbor list
from ase import Atoms
from ase.neighborlist import neighbor_list

from .calculator_base_periodic import CalculatorBasePeriodic


class EwaldPotential(CalculatorBasePeriodic):
"""A specie-wise long-range potential computed using the Ewald sum, scaling as
O(N^2) with respect to the number of particles N used as a reference to test faster
Expand Down Expand Up @@ -62,12 +64,12 @@ class EwaldPotential(CalculatorBasePeriodic):
def __init__(
self,
all_types: Optional[List[int]] = None,
exponent: Optional[torch.Tensor] = torch.tensor(1., dtype=torch.float64),
exponent: Optional[torch.Tensor] = torch.tensor(1.0, dtype=torch.float64),
sr_cutoff: Optional[float] = None,
atomic_smearing: Optional[float] = None,
lr_wavelength: Optional[float] = None,
subtract_self: Optional[bool] = True,
subtract_interior: Optional[bool] = False
subtract_interior: Optional[bool] = False,
):
super().__init__(all_types=all_types, exponent=exponent)

Expand Down Expand Up @@ -154,8 +156,8 @@ def _compute_single_system(
sr_cutoff=sr_cutoff,
)

##return charges * torch.sum(positions, dim=1) * self.exponent + potential_sr
##return charges * torch.sum(positions, dim=1) * self.exponent + potential_sr

potential_lr = self._compute_lr(
positions=positions,
charges=charges,
Expand All @@ -164,11 +166,11 @@ def _compute_single_system(
lr_wavelength=lr_wavelength,
)

#return potential_lr
# return potential_lr

potential_ewald = potential_sr + potential_lr
return potential_ewald

def _generate_kvectors(self, ns: torch.Tensor, cell: torch.Tensor) -> torch.Tensor:
"""
For a given unit cell, compute all reciprocal space vectors that are used to
Expand Down Expand Up @@ -346,7 +348,9 @@ def _compute_sr(
# Compute energy
potential = torch.zeros_like(charges)
for i, j, shift in zip(atom_is, atom_js, shifts):
dist = torch.linalg.norm(positions[j] - positions[i] + torch.tensor(shift.dot(struc.cell)))
dist = torch.linalg.norm(
positions[j] - positions[i] + torch.tensor(shift.dot(struc.cell))
)

# If the contribution from all atoms within the cutoff is to be subtracted
# this short-range part will simply use -V_LR as the potential
Expand Down
4 changes: 2 additions & 2 deletions src/meshlode/calculators/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .calculator_base_periodic import CalculatorBasePeriodic


class MeshPotential(CalculatorBasePeriodic):
"""A specie-wise long-range potential, computed using the particle-mesh Ewald (PME)
method scaling as O(NlogN) with respect to the number of particles N.
Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(
interpolation_order: Optional[int] = 4,
subtract_self: Optional[bool] = False,
all_types: Optional[List[int]] = None,
exponent: Optional[torch.Tensor] = torch.tensor(1., dtype=torch.float64),
exponent: Optional[torch.Tensor] = torch.tensor(1.0, dtype=torch.float64),
):
super().__init__(all_types=all_types, exponent=exponent)

Expand Down Expand Up @@ -120,7 +121,6 @@ def _compute_single_system(
assert positions.dtype == cell.dtype and charges.dtype == cell.dtype
assert positions.device == cell.device and charges.device == cell.device


# Define cutoff in reciprocal space
if mesh_spacing is None:
mesh_spacing = self.mesh_spacing
Expand Down
41 changes: 27 additions & 14 deletions src/meshlode/calculators/meshewald.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import torch
from typing import List, Optional

# from .mesh import MeshPotential
from .calculator_base_periodic import CalculatorBasePeriodic
from meshlode.lib.mesh_interpolator import MeshInterpolator
import torch

# extra imports for neighbor list
from ase import Atoms
from ase.neighborlist import neighbor_list

from meshlode.lib.mesh_interpolator import MeshInterpolator

# from .mesh import MeshPotential
from .calculator_base_periodic import CalculatorBasePeriodic


class MeshEwaldPotential(CalculatorBasePeriodic):
"""A specie-wise long-range potential computed using a mesh-based Ewald method,
scaling as O(NlogN) with respect to the number of particles N used as a reference
Expand Down Expand Up @@ -46,15 +49,17 @@ class MeshEwaldPotential(CalculatorBasePeriodic):
def __init__(
self,
all_types: Optional[List[int]] = None,
exponent: Optional[torch.Tensor] = torch.tensor(1., dtype=torch.float64),
exponent: Optional[int] = 1,
sr_cutoff: Optional[float] = None,
atomic_smearing: Optional[float] = None,
mesh_spacing: Optional[float] = None,
subtract_self: Optional[bool] = True,
interpolation_order: Optional[int] = 4,
subtract_interior: Optional[bool] = False
subtract_interior: Optional[bool] = False,
):
super().__init__(all_types=all_types, exponent=exponent)
super().__init__(
all_types=all_types, exponent=torch.tensor(exponent, dtype=torch.float64)
)

# Check that all provided values are correct
if interpolation_order not in [1, 2, 3, 4, 5]:
Expand All @@ -66,7 +71,10 @@ def __init__(
self.atomic_smearing = atomic_smearing
self.mesh_spacing = mesh_spacing
self.interpolation_order = interpolation_order
self.sr_cutoff = torch.tensor(sr_cutoff)
if sr_cutoff is not None:
self.sr_cutoff = torch.tensor(sr_cutoff)
else:
self.sr_cutoff = sr_cutoff

# If interior contributions are to be subtracted, also do so for self term
if subtract_interior:
Expand Down Expand Up @@ -162,7 +170,9 @@ def _compute_single_system(
cutoff_max = torch.min(cell_dimensions) / 2 - 1e-6
if self.sr_cutoff is not None:
if self.sr_cutoff > cutoff_max:
raise ValueError(f"sr_cutoff {self.sr_cutoff} needs to be < {cutoff_max}")
raise ValueError(
f"sr_cutoff {self.sr_cutoff} needs to be < {cutoff_max}"
)

# Set the defaut values of convergence parameters
# The total computational cost = cost of SR part + cost of LR part
Expand All @@ -184,7 +194,7 @@ def _compute_single_system(
smearing = self.atomic_smearing

if self.mesh_spacing is None:
mesh_spacing = smearing / 8.
mesh_spacing = smearing / 8.0
else:
mesh_spacing = self.mesh_spacing

Expand All @@ -203,7 +213,8 @@ def _compute_single_system(
charges=charges,
cell=cell,
smearing=smearing,
lr_wavelength=mesh_spacing)
lr_wavelength=mesh_spacing,
)

# Combine both parts to obtain the full potential
potential_ewald = potential_sr + potential_lr
Expand Down Expand Up @@ -258,11 +269,11 @@ def _compute_lr(
# Step 2.1: Generate k-vectors and evaluate kernel function
kvectors = self._generate_kvectors(ns=ns, cell=cell)
knorm_sq = torch.sum(kvectors**2, dim=3)

# Step 2.2: Evaluate kernel function (careful, tensor shapes are different from
# the pure Ewald implementation since we are no longer flattening)
G = self.potential.potential_fourier_from_k_sq(knorm_sq, smearing)
G[0,0,0] = self.potential.potential_fourier_at_zero(smearing)
G[0, 0, 0] = self.potential.potential_fourier_at_zero(smearing)

potential_mesh = rho_mesh

Expand Down Expand Up @@ -323,7 +334,9 @@ def _compute_sr(
# Compute energy
potential = torch.zeros_like(charges)
for i, j, shift in zip(atom_is, atom_js, shifts):
dist = torch.linalg.norm(positions[j] - positions[i] + torch.tensor(shift.dot(struc.cell)))
dist = torch.linalg.norm(
positions[j] - positions[i] + torch.tensor(shift.dot(struc.cell))
)

# If the contribution from all atoms within the cutoff is to be subtracted
# this short-range part will simply use -V_LR as the potential
Expand Down
Loading

0 comments on commit 8e325d6

Please sign in to comment.