Skip to content

Commit

Permalink
Make all calculators device-independent (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri authored Jul 25, 2024
1 parent 64dc478 commit 32dcccb
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 30 deletions.
7 changes: 5 additions & 2 deletions src/meshlode/calculators/directpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ def _compute_single_system(
# The squared distance and the inner product between two vectors r_i and r_j are
# related by: d_ij^2 = |r_i - r_j|^2 = |r_i|^2 + |r_j|^2 - 2*(r_i . r_j)
num_atoms = len(positions)
diagonal_indices = torch.arange(num_atoms)
dtype = positions.dtype
device = positions.device

diagonal_indices = torch.arange(num_atoms, device=device)
gram_matrix = positions @ positions.T
squared_norms = gram_matrix[diagonal_indices, diagonal_indices].reshape(-1, 1)
ones = torch.ones((1, len(positions)), dtype=positions.dtype)
ones = torch.ones((1, len(positions)), dtype=dtype, device=device)
squared_norms_matrix = torch.matmul(squared_norms, ones)
distances_sq = squared_norms_matrix + squared_norms_matrix.T - 2 * gram_matrix

Expand Down
12 changes: 6 additions & 6 deletions src/meshlode/metatensor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,15 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap:
for i_atom in range(len(system)):
values_samples.append([i_system, i_atom])

samples_vals_tensor = torch.tensor(values_samples, device=self._device)
samples_values = torch.tensor(values_samples, device=self._device)
properties_values = torch.arange(self._n_charges_channels, device=self._device)

block = TensorBlock(
values=torch.vstack(potentials),
samples=Labels(["system", "atom"], samples_vals_tensor),
samples=Labels(["system", "atom"], samples_values),
components=[],
properties=Labels(
"charges_channel", torch.arange(self._n_charges_channels).reshape(-1, 1)
),
properties=Labels("charges_channel", properties_values.reshape(-1, 1)),
)

return TensorMap(keys=Labels("_", torch.tensor([[0]])), blocks=[block])
keys = Labels("_", torch.tensor([[0]], device=self._device))
return TensorMap(keys=keys, blocks=[block])
36 changes: 24 additions & 12 deletions tests/calculators/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
from meshlode import DirectPotential, EwaldPotential, PMEPotential


AVAILABLE_DEVICES = [torch.device("cpu")] + torch.cuda.is_available() * [
torch.device("cuda")
]
MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3))
CHARGES_CSCL = torch.tensor([1.0, -1.0])


ATOMIC_SMEARING = 0.1
LR_WAVELENGTH = ATOMIC_SMEARING / 4
MESH_SPACING = ATOMIC_SMEARING / 4
Expand Down Expand Up @@ -49,8 +50,11 @@
],
)
class TestWorkflow:
def cscl_system(self, periodic):
def cscl_system(self, periodic, device=None):
"""CsCl crystal. Same as in the madelung test"""
if device is None:
device = torch.device("cpu")

positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]])
charges = torch.tensor([1.0, -1.0]).reshape((-1, 1))
if periodic:
Expand All @@ -59,9 +63,15 @@ def cscl_system(self, periodic):
neighbor_indices, neighbor_shifts = neighbor_list_torch(
positions=positions, cell=cell
)
return positions, charges, cell, neighbor_indices, neighbor_shifts
return (
positions.to(device=device),
charges.to(device=device),
cell.to(device=device),
neighbor_indices.to(device=device),
neighbor_shifts.to(device=device),
)
else:
return positions, charges
return positions.to(device=device), charges.to(device=device)

def test_interpolation_order_error(self, CalculatorClass, params, periodic):
if type(CalculatorClass) in [PMEPotential]:
Expand Down Expand Up @@ -122,25 +132,27 @@ def test_dtype_device(self, CalculatorClass, params, periodic):
assert potential.dtype == dtype
assert potential.device.type == device

def check_operation(self, calculator, periodic):
def check_operation(self, calculator, periodic, device):
"""Make sure computation runs and returns a torch.Tensor."""
descriptor_compute = calculator.compute(*self.cscl_system(periodic))
descriptor_forward = calculator.forward(*self.cscl_system(periodic))
descriptor_compute = calculator.compute(*self.cscl_system(periodic, device))
descriptor_forward = calculator.forward(*self.cscl_system(periodic, device))

assert type(descriptor_compute) is torch.Tensor
assert type(descriptor_forward) is torch.Tensor
assert torch.equal(descriptor_forward, descriptor_compute)

def test_operation_as_python(self, CalculatorClass, params, periodic):
@pytest.mark.parametrize("device", AVAILABLE_DEVICES)
def test_operation_as_python(self, CalculatorClass, params, periodic, device):
"""Run `check_operation` as a normal python script"""
calculator = CalculatorClass(**params)
self.check_operation(calculator, periodic)
self.check_operation(calculator=calculator, periodic=periodic, device=device)

def test_operation_as_torch_script(self, CalculatorClass, params, periodic):
@pytest.mark.parametrize("device", AVAILABLE_DEVICES)
def test_operation_as_torch_script(self, CalculatorClass, params, periodic, device):
"""Run `check_operation` as a compiled torch script module."""
calculator = CalculatorClass(**params)
scripted = torch.jit.script(calculator)
self.check_operation(scripted, periodic)
self.check_operation(calculator=scripted, periodic=periodic, device=device)

def test_save_load(self, CalculatorClass, params, periodic):
calculator = CalculatorClass(**params)
Expand Down
27 changes: 17 additions & 10 deletions tests/metatensor/test_workflow_metatensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
mts_torch = pytest.importorskip("metatensor.torch")
mts_atomistic = pytest.importorskip("metatensor.torch.atomistic")


AVAILABLE_DEVICES = [torch.device("cpu")] + torch.cuda.is_available() * [
torch.device("cuda")
]
ATOMIC_SMEARING = 0.1
LR_WAVELENGTH = ATOMIC_SMEARING / 4
MESH_SPACING = ATOMIC_SMEARING / 4
Expand Down Expand Up @@ -47,9 +49,12 @@
],
)
class TestWorkflow:
def cscl_system(self):
def cscl_system(self, device=None):
"""CsCl crystal. Same as in the madelung test"""

if device is None:
device = torch.device("cpu")

system = mts_atomistic.System(
types=torch.tensor([17, 55]),
positions=torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]),
Expand All @@ -65,12 +70,12 @@ def cscl_system(self):
system.add_data(name="charges", data=data)
add_neighbor_list(system)

return system
return system.to(device=device)

def check_operation(self, calculator):
def check_operation(self, calculator, device):
"""Make sure computation runs and returns a metatensor.TensorMap."""
descriptor_compute = calculator.compute(self.cscl_system())
descriptor_forward = calculator.forward(self.cscl_system())
descriptor_compute = calculator.compute(self.cscl_system(device))
descriptor_forward = calculator.forward(self.cscl_system(device))

assert isinstance(descriptor_compute, torch.ScriptObject)
assert isinstance(descriptor_forward, torch.ScriptObject)
Expand All @@ -80,16 +85,18 @@ def check_operation(self, calculator):

assert mts_torch.equal(descriptor_forward, descriptor_compute)

def test_operation_as_python(self, CalculatorClass, params):
@pytest.mark.parametrize("device", AVAILABLE_DEVICES)
def test_operation_as_python(self, CalculatorClass, params, device):
"""Run `check_operation` as a normal python script"""
calculator = CalculatorClass(**params)
self.check_operation(calculator)
self.check_operation(calculator=calculator, device=device)

def test_operation_as_torch_script(self, CalculatorClass, params):
@pytest.mark.parametrize("device", AVAILABLE_DEVICES)
def test_operation_as_torch_script(self, CalculatorClass, params, device):
"""Run `check_operation` as a compiled torch script module."""
calculator = CalculatorClass(**params)
scripted = torch.jit.script(calculator)
self.check_operation(scripted)
self.check_operation(calculator=scripted, device=device)

def test_save_load(self, CalculatorClass, params):
calculator = CalculatorClass(**params)
Expand Down

0 comments on commit 32dcccb

Please sign in to comment.