Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Manual charges #14

Merged
merged 11 commits into from
Apr 11, 2024
69 changes: 54 additions & 15 deletions src/meshlode/calculators/meshpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,19 @@ def forward(
types: Union[List[torch.Tensor], torch.Tensor],
positions: Union[List[torch.Tensor], torch.Tensor],
cell: Union[List[torch.Tensor], torch.Tensor],
charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""forward just calls :py:meth:`CalculatorModule.compute`"""
return self.compute(types=types, positions=positions, cell=cell)
return self.compute(
types=types, positions=positions, cell=cell, charges=charges
)

def compute(
self,
types: Union[List[torch.Tensor], torch.Tensor],
positions: Union[List[torch.Tensor], torch.Tensor],
cell: Union[List[torch.Tensor], torch.Tensor],
charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Compute potential for all provided "systems" stacked inside list.

Expand All @@ -131,6 +135,7 @@ def compute(
box/unit cell of the system. Each row should be one of the bounding box
vector; and columns should contain the x, y, and z components of these
vectors (i.e. the cell should be given in row-major order).
:param charges: Optional single or list of 2D tensor of shape (len(types), n),

:return: List of torch Tensors containing the potentials for all frames and all
atoms. Each tensor in the list is of shape (n_atoms, n_types), where
Expand All @@ -154,6 +159,7 @@ def compute(
if not isinstance(cell, list):
cell = [cell]

# Check that all inputs are consistent
for types_single, positions_single, cell_single in zip(types, positions, cell):
if len(types_single.shape) != 1:
raise ValueError(
Expand Down Expand Up @@ -189,27 +195,60 @@ def compute(
f"{cell_single.device}."
)

# We don't require and test that all dtypes and devices are consistent if a list
# of inputs. Each "frame" is processed independently.

requested_types = self._get_requested_types(types)
n_types = len(requested_types)

potentials = []
for types_single, positions_single, cell_single in zip(types, positions, cell):
# One-hot encoding of charge information
charges = torch.zeros(
(len(types_single), n_types),
dtype=positions_single.dtype,
device=positions_single.device,
)
for i_type, atomic_type in enumerate(requested_types):
charges[types_single == atomic_type, i_type] = 1.0
# If charges are not provided, we assume that all types are treated separately
PicoCentauri marked this conversation as resolved.
Show resolved Hide resolved
if charges is None:
charges = []
for types_single, positions_single in zip(types, positions):
charges_single = torch.zeros(
(len(types_single), n_types),
dtype=positions_single.dtype,
device=positions_single.device,
)
for i_type, atomic_type in enumerate(requested_types):
charges_single[types_single == atomic_type, i_type] = 1.0
E-Rum marked this conversation as resolved.
Show resolved Hide resolved
charges.append(charges_single)

# If charges are provided, we need to make sure that they are consistent with
# the provided types

else:
if not isinstance(charges, list):
charges = [charges]
if len(charges) != len(types):
raise ValueError(
"The number of `types` and `charges` tensors must be the same, "
f"got {len(types)} and {len(charges)}."
)
for charges_single, types_single in zip(charges, types):
if charges_single.shape[0] != len(types_single):
raise ValueError(
"The first dimension of `charges` must be the same as the "
f"length of `types`, got {charges_single.shape[0]} and "
f"{len(types_single)}."
)
if charges[0].dtype != positions[0].dtype:
raise ValueError(
"`charges` must be have the same dtype as `positions`, got "
f"{charges[0].dtype} and {positions[0].dtype}"
)
if charges[0].device != positions[0].device:
raise ValueError(
"`charges` must be on the same device as `positions`, got "
f"{charges[0].device} and {positions[0].device}"
)
E-Rum marked this conversation as resolved.
Show resolved Hide resolved
# We don't require and test that all dtypes and devices are consistent if a list
# of inputs. Each "frame" is processed independently.
potentials = []
for types_single, positions_single, cell_single, charges_single in zip(
types, positions, cell, charges
):
# Compute the potentials
potentials.append(
self._compute_single_system(
positions=positions_single, charges=charges, cell=cell_single
positions=positions_single, charges=charges_single, cell=cell_single
)
)

Expand Down
22 changes: 22 additions & 0 deletions tests/calculators/test_meshpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ def cscl_system():
return types, positions, cell


def cscl_system_with_charges():
"""CsCl crystal. Same as in the madelung test"""
types = torch.tensor([55, 17])
positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]])
cell = torch.eye(3)

return types, positions, cell, torch.tensor([[0.0, 1.0], [1.0, 0]])
E-Rum marked this conversation as resolved.
Show resolved Hide resolved


# Initialize the calculators. For now, only the MeshPotential is implemented.
def descriptor() -> MeshPotential:
atomic_smearing = 0.1
Expand Down Expand Up @@ -95,6 +104,19 @@ def test_single_frame():
)


def test_single_frame_with_charges():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a comment here that this test uses an explicit one hot encoding for the charges.

values = descriptor().compute(*cscl_system_with_charges())
print(values)
print(MADELUNG_CSCL),
print(CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1])
E-Rum marked this conversation as resolved.
Show resolved Hide resolved
assert_close(
MADELUNG_CSCL,
CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1],
atol=1e4,
rtol=1e-5,
)


def test_multi_frame():
types, positions, cell = cscl_system()
l_values = descriptor().compute(
Expand Down
Loading