diff --git a/docs/src/references/changelog.rst b/docs/src/references/changelog.rst index b943cab8..d5b7960b 100644 --- a/docs/src/references/changelog.rst +++ b/docs/src/references/changelog.rst @@ -27,6 +27,7 @@ changelog `_ format. This project follows Fixed ##### +* Fixed consistency of ``dtype`` and ``device`` in the ``SplinePotential`` class * Fix inconsistent ``cutoff`` in neighbor list example * All calculators now check if the cell is zero if the potential is range-separated diff --git a/src/torchpme/potentials/spline.py b/src/torchpme/potentials/spline.py index 4230ddbe..a89120f4 100644 --- a/src/torchpme/potentials/spline.py +++ b/src/torchpme/potentials/spline.py @@ -74,6 +74,9 @@ def __init__( if len(y_grid) != len(r_grid): raise ValueError("Length of radial grid and value array mismatch.") + r_grid = r_grid.to(dtype=dtype, device=device) + y_grid = y_grid.to(dtype=dtype, device=device) + if reciprocal: if torch.min(r_grid) <= 0.0: raise ValueError( @@ -89,6 +92,8 @@ def __init__( k_grid = torch.pi * 2 * torch.reciprocal(r_grid).flip(dims=[0]) else: k_grid = r_grid.clone() + else: + k_grid = k_grid.to(dtype=dtype, device=device) if yhat_grid is None: # computes automatically! @@ -98,6 +103,8 @@ def __init__( y_grid, compute_second_derivatives(r_grid, y_grid), ) + else: + yhat_grid = yhat_grid.to(dtype=dtype, device=device) # the function is defined for k**2, so we define the grid accordingly if reciprocal: @@ -108,12 +115,14 @@ def __init__( self._krn_spline = CubicSpline(k_grid**2, yhat_grid) if y_at_zero is None: - self._y_at_zero = self._spline(torch.tensor([0.0])) + self._y_at_zero = self._spline(torch.zeros(1, dtype=dtype, device=device)) else: self._y_at_zero = y_at_zero if yhat_at_zero is None: - self._yhat_at_zero = self._krn_spline(torch.tensor([0.0])) + self._yhat_at_zero = self._krn_spline( + torch.zeros(1, dtype=dtype, device=device) + ) else: self._yhat_at_zero = yhat_at_zero @@ -140,7 +149,7 @@ def self_contribution(self) -> torch.Tensor: return self._y_at_zero def background_correction(self) -> torch.Tensor: - return torch.tensor([0.0]) + return torch.zeros(1) from_dist.__doc__ = Potential.from_dist.__doc__ lr_from_dist.__doc__ = Potential.lr_from_dist.__doc__ diff --git a/src/torchpme/utils/splines.py b/src/torchpme/utils/splines.py index 2d66c5f6..036ded7c 100644 --- a/src/torchpme/utils/splines.py +++ b/src/torchpme/utils/splines.py @@ -198,7 +198,7 @@ def compute_second_derivatives( d2y = _solve_tridiagonal(a, b, c, d) # Converts back to the original dtype - return d2y.to(x_points.dtype) + return d2y.to(dtype=x_points.dtype, device=x_points.device) def compute_spline_ft( diff --git a/tests/test_potentials.py b/tests/test_potentials.py index a20b71b7..87d590bd 100644 --- a/tests/test_potentials.py +++ b/tests/test_potentials.py @@ -573,3 +573,35 @@ def test_combined_potential_learnable_weights(): loss.backward() optimizer.step() assert torch.allclose(combined.weights, weights - 0.1) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +@pytest.mark.parametrize( + "potential_class", [CoulombPotential, InversePowerLawPotential, SplinePotential] +) +def test_potential_device_dtype(potential_class, device, dtype): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + smearing = 1.0 + exponent = 1.0 + + if potential_class is InversePowerLawPotential: + potential = potential_class( + exponent=exponent, smearing=smearing, dtype=dtype, device=device + ) + elif potential_class is SplinePotential: + x_grid = torch.linspace(0, 20, 100, device=device, dtype=dtype) + y_grid = torch.exp(-(x_grid**2) * 0.5) + potential = potential_class( + r_grid=x_grid, y_grid=y_grid, reciprocal=False, dtype=dtype, device=device + ) + else: + potential = potential_class(smearing=smearing, dtype=dtype, device=device) + + dists = torch.linspace(0.1, 10.0, 100, device=device, dtype=dtype) + potential_lr = potential.lr_from_dist(dists) + + assert potential_lr.device.type == device + assert potential_lr.dtype == dtype