Merge pull request #138 from lab-cosmo/splinegpu
SplinePotential device compatibility
E-Rum authored Jan 8, 2025
2 parents 9742132 + 24cd5de commit 04edb22
Showing 4 changed files with 46 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/src/references/changelog.rst

@@ -27,6 +27,7 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
 Fixed
 #####
 
+* Fixed consistency of ``dtype`` and ``device`` in the ``SplinePotential`` class
 * Fix inconsistent ``cutoff`` in neighbor list example
 * All calculators now check if the cell is zero if the potential is range-separated
 
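With this change, grids created on the CPU can be handed to a potential that targets another device: the constructor now moves them onto the requested dtype and device. A minimal sketch, assuming SplinePotential is exported at the torchpme top level and that a CUDA device is available (the constructor keywords match the test added below):

import torch
from torchpme import SplinePotential

# CPU-resident grids in the default dtype...
r = torch.linspace(0.1, 10.0, 100)
y = torch.exp(-(r**2) * 0.5)

# ...are converted inside __init__, so no manual .to() calls are needed.
pot = SplinePotential(
    r_grid=r, y_grid=y, reciprocal=False, dtype=torch.float32, device="cuda"
)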
15 changes: 12 additions & 3 deletions src/torchpme/potentials/spline.py

@@ -74,6 +74,9 @@ def __init__(
         if len(y_grid) != len(r_grid):
             raise ValueError("Length of radial grid and value array mismatch.")
 
+        r_grid = r_grid.to(dtype=dtype, device=device)
+        y_grid = y_grid.to(dtype=dtype, device=device)
+
         if reciprocal:
             if torch.min(r_grid) <= 0.0:
                 raise ValueError(
@@ -89,6 +92,8 @@
                 k_grid = torch.pi * 2 * torch.reciprocal(r_grid).flip(dims=[0])
             else:
                 k_grid = r_grid.clone()
+        else:
+            k_grid = k_grid.to(dtype=dtype, device=device)
 
         if yhat_grid is None:
             # computes automatically!
@@ -98,6 +103,8 @@
                 y_grid,
                 compute_second_derivatives(r_grid, y_grid),
             )
+        else:
+            yhat_grid = yhat_grid.to(dtype=dtype, device=device)
 
         # the function is defined for k**2, so we define the grid accordingly
         if reciprocal:
@@ -108,12 +115,14 @@
         self._krn_spline = CubicSpline(k_grid**2, yhat_grid)
 
         if y_at_zero is None:
-            self._y_at_zero = self._spline(torch.tensor([0.0]))
+            self._y_at_zero = self._spline(torch.zeros(1, dtype=dtype, device=device))
         else:
             self._y_at_zero = y_at_zero
 
         if yhat_at_zero is None:
-            self._yhat_at_zero = self._krn_spline(torch.tensor([0.0]))
+            self._yhat_at_zero = self._krn_spline(
+                torch.zeros(1, dtype=dtype, device=device)
+            )
         else:
             self._yhat_at_zero = yhat_at_zero
 
@@ -140,7 +149,7 @@ def self_contribution(self) -> torch.Tensor:
         return self._y_at_zero
 
     def background_correction(self) -> torch.Tensor:
-        return torch.tensor([0.0])
+        return torch.zeros(1)
 
     from_dist.__doc__ = Potential.from_dist.__doc__
     lr_from_dist.__doc__ = Potential.lr_from_dist.__doc__
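For reference, the reciprocal branch kept as context above maps the real-space grid to k = 2*pi/r and flips it so the resulting k-grid is ascending. A small worked check of that line:

import torch

r_grid = torch.tensor([0.5, 1.0, 2.0], dtype=torch.float64)
# 2*pi/r gives [4*pi, 2*pi, pi]; flipping restores ascending order.
k_grid = torch.pi * 2 * torch.reciprocal(r_grid).flip(dims=[0])
print(k_grid)  # tensor([ 3.1416,  6.2832, 12.5664], dtype=torch.float64)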
2 changes: 1 addition & 1 deletion src/torchpme/utils/splines.py

@@ -198,7 +198,7 @@ def compute_second_derivatives(
     d2y = _solve_tridiagonal(a, b, c, d)
 
     # Converts back to the original dtype
-    return d2y.to(x_points.dtype)
+    return d2y.to(dtype=x_points.dtype, device=x_points.device)
 
 
 def compute_spline_ft(
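Per the "Converts back to the original dtype" comment, the tridiagonal solve apparently runs at a fixed internal precision, so the result must be cast back; the one-line fix now restores the device as well. A quick check of the restored behavior, assuming compute_second_derivatives is importable from torchpme.utils.splines (the module path shown in this diff):

import torch
from torchpme.utils.splines import compute_second_derivatives

x = torch.linspace(0.1, 5.0, 50, dtype=torch.float32)
y = torch.sin(x)
d2y = compute_second_derivatives(x, y)
# The output now matches the input's dtype and device (previously only
# the dtype was restored); the same holds for CUDA inputs.
assert d2y.dtype == x.dtype and d2y.device == x.device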
32 changes: 32 additions & 0 deletions tests/test_potentials.py

@@ -573,3 +573,35 @@ def test_combined_potential_learnable_weights():
     loss.backward()
     optimizer.step()
     assert torch.allclose(combined.weights, weights - 0.1)
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
+@pytest.mark.parametrize(
+    "potential_class", [CoulombPotential, InversePowerLawPotential, SplinePotential]
+)
+def test_potential_device_dtype(potential_class, device, dtype):
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+
+    smearing = 1.0
+    exponent = 1.0
+
+    if potential_class is InversePowerLawPotential:
+        potential = potential_class(
+            exponent=exponent, smearing=smearing, dtype=dtype, device=device
+        )
+    elif potential_class is SplinePotential:
+        x_grid = torch.linspace(0, 20, 100, device=device, dtype=dtype)
+        y_grid = torch.exp(-(x_grid**2) * 0.5)
+        potential = potential_class(
+            r_grid=x_grid, y_grid=y_grid, reciprocal=False, dtype=dtype, device=device
+        )
+    else:
+        potential = potential_class(smearing=smearing, dtype=dtype, device=device)
+
+    dists = torch.linspace(0.1, 10.0, 100, device=device, dtype=dtype)
+    potential_lr = potential.lr_from_dist(dists)
+
+    assert potential_lr.device.type == device
+    assert potential_lr.dtype == dtype

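To run only the new parametrized checks locally, one option is to drive pytest from Python (a sketch; run from the repository root). The CUDA cases skip themselves on machines without a GPU:

import pytest

# -k selects the new test across all device/dtype/potential combinations.
raise SystemExit(pytest.main(["tests/test_potentials.py", "-k", "test_potential_device_dtype", "-v"]))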