Skip to content

Commit

Permalink
Linted + added a custom PyTorch exponential integral function to supp…
Browse files Browse the repository at this point in the history
…ort backpropagation
  • Loading branch information
E-Rum committed Jan 6, 2025
1 parent 1d2ed8a commit 54b6fd6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
13 changes: 5 additions & 8 deletions src/torchpme/potentials/integerspline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

import torch
from torch.special import gammaln, gammainc
from torch.special import gammainc, gammaln

from .potential import Potential
from .spline import SplinePotential
Expand All @@ -17,6 +17,7 @@ def gamma(x: torch.Tensor) -> torch.Tensor:
"""
return torch.exp(gammaln(x))


class InversePowerLawPotentialSpline(Potential):
"""
Inverse power-law potentials of the form :math:`1/r^p`.
Expand Down Expand Up @@ -102,16 +103,12 @@ def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor:
x = 0.5 * dist**2 / smearing**2
peff = exponent / 2
prefac = 1.0 / (2 * smearing**2) ** peff
return prefac * gammainc(peff, x) / x ** peff
return prefac * gammainc(peff, x) / x**peff

@torch.jit.export
def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
r"""
TODO: Fourier transform of the LR part potential in terms of :math:`\mathbf{k^2}`.
"""
spline = SplinePotential(
self.r_grid, self.lr_from_dist(self.r_grid)
)
r"""TODO: Fourier transform of the LR part potential in terms of :math:`\mathbf{k^2}`."""
spline = SplinePotential(self.r_grid, self.lr_from_dist(self.r_grid))
return spline.lr_from_k_sq(k_sq)

def self_contribution(self) -> torch.Tensor:
Expand Down
29 changes: 22 additions & 7 deletions src/torchpme/potentials/inversepowerlaw.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional

import torch
from torch.special import gammaln, gammainc
from scipy.special import exp1
from torch.special import gammainc, gammaln

from .potential import Potential

Expand All @@ -18,8 +18,24 @@ def gamma(x: torch.Tensor) -> torch.Tensor:
return torch.exp(gammaln(x))


# Custom exponential function to have an autograd-compatible version of the exponential
class CustomE1(torch.autograd.Function):
"""The exponential integral E1(x)"""

@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return exp1(input)

@staticmethod
def backward(ctx, grad_output):
(input,) = ctx.saved_tensors
return -grad_output * torch.exp(-input) / input


# Auxilary function for stable Fourier transform implementation
def gammaincc_over_powerlaw(exponent: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
"""Function to compute the regularized incomplete gamma function complement for integer exponents."""
if exponent not in [1, 2, 3, 4, 5, 6]:
raise ValueError(f"Unsupported exponent: {exponent}")

Expand All @@ -28,18 +44,19 @@ def gammaincc_over_powerlaw(exponent: torch.Tensor, z: torch.Tensor) -> torch.Te
if exponent == 2:
return torch.sqrt(torch.pi / z) * torch.erfc(torch.sqrt(z))
if exponent == 3:
return exp1(z)
return CustomE1.apply(z)
if exponent == 4:
return 2 * (
torch.exp(-z) - torch.sqrt(torch.pi * z) * torch.erfc(torch.sqrt(z))
)
if exponent == 5:
return torch.exp(-z) - z * exp1(z)
return torch.exp(-z) - z * CustomE1.apply(z)
if exponent == 6:
return (
(2 - 4 * z) * torch.exp(-z)
+ 4 * torch.sqrt(torch.pi) * z**1.5 * torch.erfc(torch.sqrt(z))
) / 3
return None


class InversePowerLawPotential(Potential):
Expand Down Expand Up @@ -128,7 +145,7 @@ def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor:
x = 0.5 * dist**2 / smearing**2
peff = exponent / 2
prefac = 1.0 / (2 * smearing**2) ** peff
return prefac * gammainc(peff, x) / x ** peff
return prefac * gammainc(peff, x) / x**peff

@torch.jit.export
def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -159,9 +176,7 @@ def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
# for consistency reasons.
masked = torch.where(x == 0, 1.0, x) # avoid NaNs in backwards, see Coulomb
return torch.where(
k_sq == 0,
0.0,
prefac * gammaincc_over_powerlaw(exponent,masked)
k_sq == 0, 0.0, prefac * gammaincc_over_powerlaw(exponent, masked)
)

def self_contribution(self) -> torch.Tensor:
Expand Down

0 comments on commit 54b6fd6

Please sign in to comment.