
Commit

Added dtype parameterization for tests and made cosmetic changes
E-Rum committed Jan 8, 2025
1 parent 7755255 commit 62598e3
Showing 3 changed files with 8 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/src/references/changelog.rst
@@ -27,6 +27,7 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
 Fixed
 #####
 
+* Fixed consistency of dtype and device in the SplinePotential class
 * All calculators now check if the cell is zero if the potential is range-separated


8 changes: 3 additions & 5 deletions src/torchpme/potentials/spline.py
@@ -115,15 +115,13 @@ def __init__(
         self._krn_spline = CubicSpline(k_grid**2, yhat_grid)
 
         if y_at_zero is None:
-            self._y_at_zero = self._spline(
-                torch.tensor([0.0], dtype=dtype, device=device)
-            )
+            self._y_at_zero = self._spline(torch.zeros(1, dtype=dtype, device=device))
         else:
             self._y_at_zero = y_at_zero
 
         if yhat_at_zero is None:
             self._yhat_at_zero = self._krn_spline(
-                torch.tensor([0.0], dtype=dtype, device=device)
+                torch.zeros(1, dtype=dtype, device=device)
             )
         else:
             self._yhat_at_zero = yhat_at_zero
@@ -151,7 +149,7 @@ def self_contribution(self) -> torch.Tensor:
         return self._y_at_zero
 
     def background_correction(self) -> torch.Tensor:
-        return torch.tensor([0.0])
+        return torch.zeros(1)
 
     from_dist.__doc__ = Potential.from_dist.__doc__
     lr_from_dist.__doc__ = Potential.lr_from_dist.__doc__
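For context (not part of this commit), a minimal standalone sketch of why the zero-point tensors need an explicit dtype and device: torch.tensor([0.0]) follows the global default dtype (float32) and lands on the CPU, so combining it with float64 spline grids silently promotes dtypes instead of staying consistent.

import torch

# Default-constructed tensors use the global default dtype (float32) and the CPU,
# regardless of the dtype/device the rest of the potential was built with.
naive = torch.tensor([0.0])
print(naive.dtype, naive.device)            # torch.float32 cpu

# Passing dtype/device explicitly keeps the zero-point consistent with the grids.
dtype, device = torch.float64, "cpu"
consistent = torch.zeros(1, dtype=dtype, device=device)
print(consistent.dtype, consistent.device)  # torch.float64 cpu

# Mixing the two silently promotes to float64 in arithmetic rather than raising,
# which is the kind of inconsistency the changelog entry refers to.
grid_value = torch.ones(1, dtype=torch.float64)
print((grid_value + naive).dtype)           # torch.float64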
7 changes: 4 additions & 3 deletions tests/test_potentials.py
@@ -576,22 +576,22 @@ def test_combined_potential_learnable_weights():
 
 
 @pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
 @pytest.mark.parametrize(
     "potential_class", [CoulombPotential, InversePowerLawPotential, SplinePotential]
 )
-def test_potential_device(potential_class, device):
+def test_potential_device_dtype(potential_class, device, dtype):
     if device == "cuda" and not torch.cuda.is_available():
         pytest.skip("CUDA is not available")
 
     smearing = 1.0
     exponent = 1.0
-    dtype = torch.float64
 
     if potential_class is InversePowerLawPotential:
         potential = potential_class(
             exponent=exponent, smearing=smearing, dtype=dtype, device=device
         )
-    elif potential_class == SplinePotential:
+    elif potential_class is SplinePotential:
         x_grid = torch.linspace(0, 20, 100, device=device, dtype=dtype)
         y_grid = torch.exp(-(x_grid**2) * 0.5)
         potential = potential_class(
@@ -604,3 +604,4 @@ def test_potential_device(potential_class, device):
     potential_lr = potential.lr_from_dist(dists)
 
     assert potential_lr.device.type == device
+    assert potential_lr.dtype == dtype
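For reference (not from the repository), a self-contained sketch of how stacked @pytest.mark.parametrize decorators like the ones above expand into the full device × dtype matrix, using the same CUDA-availability skip:

import pytest
import torch

# Stacked parametrize decorators yield the cross-product of cases:
# 2 devices x 2 dtypes = 4 independently reported tests.
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_zeros_respects_dtype_and_device(device, dtype):
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    x = torch.zeros(1, dtype=dtype, device=device)
    assert x.dtype == dtype
    assert x.device.type == device

Individual combinations can be selected with pytest's -k filter, e.g. pytest -k "float64 and cpu".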
