Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Manual charges #14

Merged
merged 11 commits into from
Apr 11, 2024
84 changes: 69 additions & 15 deletions src/meshlode/calculators/meshpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,19 @@ def forward(
types: Union[List[torch.Tensor], torch.Tensor],
positions: Union[List[torch.Tensor], torch.Tensor],
cell: Union[List[torch.Tensor], torch.Tensor],
charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""forward just calls :py:meth:`CalculatorModule.compute`"""
return self.compute(types=types, positions=positions, cell=cell)
return self.compute(
types=types, positions=positions, cell=cell, charges=charges
)

def compute(
self,
types: Union[List[torch.Tensor], torch.Tensor],
positions: Union[List[torch.Tensor], torch.Tensor],
cell: Union[List[torch.Tensor], torch.Tensor],
charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Compute potential for all provided "systems" stacked inside list.

Expand All @@ -131,6 +135,7 @@ def compute(
box/unit cell of the system. Each row should be one of the bounding box
vector; and columns should contain the x, y, and z components of these
vectors (i.e. the cell should be given in row-major order).
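:param charges: Optional single or list of 2D tensor of shape (len(types), n),
    where n is the number of charge channels. If not provided, a one-hot
    encoding of the atomic types is used as charges.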

:return: List of torch Tensors containing the potentials for all frames and all
atoms. Each tensor in the list is of shape (n_atoms, n_types), where
Expand All @@ -154,6 +159,7 @@ def compute(
if not isinstance(cell, list):
cell = [cell]

# Check that all inputs are consistent
for types_single, positions_single, cell_single in zip(types, positions, cell):
if len(types_single.shape) != 1:
raise ValueError(
Expand Down Expand Up @@ -189,27 +195,61 @@ def compute(
f"{cell_single.device}."
)

# We don't require and test that all dtypes and devices are consistent if a list
# of inputs. Each "frame" is processed independently.

requested_types = self._get_requested_types(types)
n_types = len(requested_types)

potentials = []
for types_single, positions_single, cell_single in zip(types, positions, cell):
# One-hot encoding of charge information
charges = torch.zeros(
(len(types_single), n_types),
dtype=positions_single.dtype,
device=positions_single.device,
)
for i_type, atomic_type in enumerate(requested_types):
charges[types_single == atomic_type, i_type] = 1.0
# If charges are not provided, we assume that all types are treated separately
PicoCentauri marked this conversation as resolved.
Show resolved Hide resolved
if charges is None:
charges = []
for types_single, positions_single in zip(types, positions):
# One-hot encoding of charge information
charges_single = self._one_hot_charges(
types_single,
requested_types,
n_types,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just len(requested_types), as you write in L199. I think you don't need this as a parameter.

positions_single.dtype,
positions_single.device,
)
charges.append(charges_single)

# If charges are provided, we need to make sure that they are consistent with
# the provided types

else:
if not isinstance(charges, list):
charges = [charges]
if len(charges) != len(types):
raise ValueError(
"The number of `types` and `charges` tensors must be the same, "
f"got {len(types)} and {len(charges)}."
)
for charges_single, types_single in zip(charges, types):
if charges_single.shape[0] != len(types_single):
raise ValueError(
"The first dimension of `charges` must be the same as the "
f"length of `types`, got {charges_single.shape[0]} and "
f"{len(types_single)}."
)
if charges[0].dtype != positions[0].dtype:
raise ValueError(
"`charges` must be have the same dtype as `positions`, got "
f"{charges[0].dtype} and {positions[0].dtype}."
)
if charges[0].device != positions[0].device:
raise ValueError(
"`charges` must be on the same device as `positions`, got "
f"{charges[0].device} and {positions[0].device}."
)
# We neither require nor test that dtypes and devices are consistent across a
# list of inputs. Each "frame" is processed independently.
potentials = []
for positions_single, cell_single, charges_single in zip(
positions, cell, charges
):
# Compute the potentials
potentials.append(
self._compute_single_system(
positions=positions_single, charges=charges, cell=cell_single
positions=positions_single, charges=charges_single, cell=cell_single
)
)

Expand Down Expand Up @@ -312,3 +352,17 @@ def _compute_single_system(
interpolated_potential -= charges * self_contrib

return interpolated_potential

def _one_hot_charges(
    self,
    types: torch.Tensor,
    requested_types: List[int],
    n_types: int,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Build a one-hot encoded "charges" tensor from atomic types.

    :param types: tensor holding the atomic type of each atom.
    :param requested_types: atomic types that get their own column; column ``i``
        of the result corresponds to ``requested_types[i]``.
    :param n_types: number of columns of the returned tensor. The callers in this
        file pass ``len(requested_types)``.
    :param dtype: dtype of the returned tensor. If :py:obj:`None`, the torch
        default is used (same convention as :py:func:`torch.zeros`).
    :param device: device of the returned tensor. If :py:obj:`None`, the torch
        default is used (same convention as :py:func:`torch.zeros`).

    :return: tensor of shape ``(len(types), n_types)`` that is ``1.0`` at
        ``[i, j]`` if atom ``i`` has type ``requested_types[j]`` and ``0.0``
        otherwise. Atoms whose type is not in ``requested_types`` keep an
        all-zero row.
    """
    one_hot_charges = torch.zeros((len(types), n_types), dtype=dtype, device=device)
    for i_type, atomic_type in enumerate(requested_types):
        # Boolean mask selects all atoms of the current type
        one_hot_charges[types == atomic_type, i_type] = 1.0

    return one_hot_charges
90 changes: 67 additions & 23 deletions src/meshlode/metatensor/meshpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,22 @@ def compute(
) -> TensorMap:
"""Compute potential for all provided ``systems``.

All ``systems`` must have the same ``dtype`` and the same ``device``.
All ``systems`` must have the same ``dtype`` and the same ``device``. If each
system contains a custom data field ``charges``, the potential will be calculated
for each "charges-channel". The number of "charges-channels" must be the same in
all ``systems``. If no explicit charges are set, the potential will be calculated
for each "types-channel".

Refer to :meth:`meshlode.MeshPotential.compute()` for additional details on how
"charges-channel" and "types-channels" are computed.

:param systems: single System or list of
:py:class:`metatensor.torch.atomistic.System` on which to run the
calculation.
calculations.

:return: TensorMap containing the potential of all types. The keys of the
tensormap are "center_type" and "neighbor_type".
TensorMap are "center_type" and "neighbor_type" if no charges are associated
with the systems, and "center_type" and "charges_channel" otherwise.
"""
# Make sure that the compute function also works if only a single frame is
# provided as input (for convenience of users testing out the code)
Expand Down Expand Up @@ -111,38 +119,71 @@ def compute(
)
n_types = len(requested_types)

# Initialize dictionary for sparse storage of the features to map atomic types
# to array indices. In general, the types are sorted according to their
# (integer) type and assigned the array indices 0, 1, 2,... Example: for H2O:
# `H` is mapped to `0` and `O` is mapped to `1`.
n_types_sq = n_types * n_types
feat_dic: Dict[int, List[torch.Tensor]] = {a: [] for a in range(n_types_sq)}
has_charges = torch.tensor(["charges" in s.known_data() for s in systems])
all_charges = torch.all(has_charges)
any_charges = torch.any(has_charges)

if any_charges and not all_charges:
raise ValueError("`systems` do not consistently contain `charges` data")
if all_charges:
use_explicit_charges = True
n_charges_channels = systems[0].get_data("charges").values.shape[1]
spec_channels = list(range(n_charges_channels))
key_names = ["center_type", "charges_channel"]

for i_system, system in enumerate(systems):
n_channels = system.get_data("charges").values.shape[1]
if n_channels != n_charges_channels:
raise ValueError(
f"number of charges-channels in system index {i_system} "
f"({n_channels}) is inconsistent with first system "
f"({n_charges_channels})"
)
else:
# Use one hot encoded type channel per species for charges channel
use_explicit_charges = False
n_charges_channels = n_types
spec_channels = requested_types
key_names = ["center_type", "neighbor_type"]

# Initialize dictionary for TensorBlock storage.
#
# If `use_explicit_charges=False`, the blocks are sorted according to the
# (integer) center_type and neighbor_type. Blocks are assigned the array indices
# 0, 1, 2,... Example: for H2O: `H` is mapped to `0` and `O` is mapped to `1`.
#
# For `use_explicit_charges=True` the blocks are stored according to the
# center_type and charge_channel
n_blocks = n_types * n_charges_channels
feat_dic: Dict[int, List[torch.Tensor]] = {a: [] for a in range(n_blocks)}

for system in systems:
# One-hot encoding of charge information
types = system.types
charges = torch.zeros((len(system), n_types), dtype=dtype, device=device)
for i_specie, atomic_type in enumerate(requested_types):
charges[types == atomic_type, i_specie] = 1.0
if use_explicit_charges:
charges = system.get_data("charges").values
else:
# One-hot encoding of charge information
charges = self._one_hot_charges(
system.types, requested_types, n_types, dtype, device
)

# Compute the potentials
potential = self._compute_single_system(
system.positions, charges, system.cell
)

# Reorder data into Metatensor format
# Reorder data into metatensor format
for spec_center, at_num_center in enumerate(requested_types):
for spec_neighbor in range(len(requested_types)):
a_pair = spec_center * n_types + spec_neighbor
for spec_channel in range(len(spec_channels)):
a_pair = spec_center * n_charges_channels + spec_channel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh yes good catch!

feat_dic[a_pair] += [
potential[types == at_num_center, spec_neighbor]
potential[system.types == at_num_center, spec_channel]
]

# Assemble all computed potential values into TensorBlocks for each combination
# of center_type and neighbor_type
# of center_type and neighbor_type/charge_channel
blocks: List[TensorBlock] = []
for keys, values in feat_dic.items():
spec_center = requested_types[keys // n_types]
spec_center = requested_types[keys // n_charges_channels]

# Generate the Labels objects for the samples and properties of the
# TensorBlock.
Expand Down Expand Up @@ -176,14 +217,17 @@ def compute(

blocks.append(block)

assert len(blocks) == n_blocks

# Generate TensorMap from TensorBlocks by defining suitable keys
key_values: List[torch.Tensor] = []
for spec_center in requested_types:
for spec_neighbor in requested_types:
for spec_channel in spec_channels:
key_values.append(
torch.tensor([spec_center, spec_neighbor], device=device)
torch.tensor([spec_center, spec_channel], device=device)
)
key_values = torch.vstack(key_values)
labels_keys = Labels(["center_type", "neighbor_type"], key_values)

labels_keys = Labels(key_names, key_values)

return TensorMap(keys=labels_keys, blocks=blocks)
79 changes: 78 additions & 1 deletion tests/calculators/test_meshpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def cscl_system():
return types, positions, cell


def cscl_system_with_charges():
    """Return the CsCl test system extended by an explicit charges tensor.

    The charges tensor has one row per atom and a single non-zero entry per row.
    It is appended as the last element of the tuple returned by ``cscl_system``.
    """
    explicit_charges = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
    return (*cscl_system(), explicit_charges)


# Initialize the calculators. For now, only the MeshPotential is implemented.
def descriptor() -> MeshPotential:
atomic_smearing = 0.1
Expand Down Expand Up @@ -86,7 +92,17 @@ def test_operation_as_torch_script():

def test_single_frame():
    """Compare the charge-weighted potential on the first atom with MADELUNG_CSCL."""
    values = descriptor().compute(*cscl_system())
    # NOTE(review): ``atol=1e4`` makes this comparison pass for practically any
    # value; it looks like a typo for ``1e-4``. Left unchanged until the
    # achievable accuracy has been confirmed.
    assert_close(
        MADELUNG_CSCL,
        CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1],
        atol=1e4,
        rtol=1e-5,
    )


# Test with explicit charges
def test_single_frame_with_charges():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a comment here that this test uses an explicit one hot encoding for the charges.

values = descriptor().compute(*cscl_system_with_charges())
assert_close(
MADELUNG_CSCL,
CHARGES_CSCL[0] * values[0, 0] + CHARGES_CSCL[1] * values[0, 1],
Expand Down Expand Up @@ -138,6 +154,36 @@ def test_positions_error():
descriptor().compute(types=types, positions=positions, cell=cell)


def test_charges_error_dimension_mismatch():
    """Check that charges with a wrong number of rows are rejected.

    Two atoms are passed, but the charges tensor only contains a single row.
    """
    system = {
        "types": torch.tensor([1, 2]),
        "positions": torch.zeros((2, 3)),
        "cell": torch.eye(3),
    }
    # Only one row of charges although there are two atoms
    bad_charges = torch.zeros((1, 2))

    expected_message = (
        "The first dimension of `charges` must be the same as the "
        "length of `types`, got 1 and 2."
    )

    with pytest.raises(ValueError, match=expected_message):
        descriptor().compute(charges=bad_charges, **system)


def test_charges_error_length_mismatch():
    """Check that a charges list shorter than the types list is rejected."""
    n_atoms_per_frame = (2, 3)
    types = [torch.tensor([1, 2]), torch.tensor([1, 2, 3])]
    positions = [torch.zeros((n_atoms, 3)) for n_atoms in n_atoms_per_frame]
    cell = torch.eye(3)
    # Only one charges tensor although two frames are passed
    charges = [torch.zeros(2, 1)]

    expected_message = (
        "The number of `types` and `charges` tensors must be the same, "
        "got 2 and 1."
    )

    with pytest.raises(ValueError, match=expected_message):
        descriptor().compute(
            types=types,
            positions=positions,
            cell=cell,
            charges=charges,
        )


def test_cell_error():
types = torch.tensor([1, 2, 3])
positions = torch.zeros((3, 3))
Expand Down Expand Up @@ -212,6 +258,37 @@ def test_inconsistent_device():
MP.compute(types=types, positions=positions, cell=cell)


def test_inconsistent_device_charges():
    """Test that an error is raised if charges and positions are on different devices."""
    calculator = MeshPotential(atomic_smearing=0.2)

    inputs = {
        "types": torch.tensor([1], device="cpu"),
        "positions": torch.tensor([[0.0, 0.0, 0.0]], device="cpu"),
        "cell": torch.eye(3, device="cpu"),
        # The charges deliberately live on another device than the positions
        "charges": torch.tensor([0.0], device="meta"),
    }

    expected_message = (
        "`charges` must be on the same device as `positions`, got meta and cpu."
    )
    with pytest.raises(ValueError, match=expected_message):
        calculator.compute(**inputs)


def test_inconsistent_dtype_charges():
    """Test that an error is raised if charges and positions have different dtypes."""
    calculator = MeshPotential(atomic_smearing=0.2)

    single = torch.float32
    inputs = {
        "types": torch.tensor([1], dtype=single),
        "positions": torch.tensor([[0.0, 0.0, 0.0]], dtype=single),
        "cell": torch.eye(3, dtype=single),
        # The charges deliberately use another dtype than the positions
        "charges": torch.tensor([0.0], dtype=torch.float64),
    }

    expected_message = (
        "`charges` must be have the same dtype as `positions`, "
        "got torch.float64 and torch.float32"
    )
    with pytest.raises(ValueError, match=expected_message):
        calculator.compute(**inputs)


def test_1d_tolist():
in_list = [1, 2, 7, 3, 4, 42]
in_tensor = torch.tensor(in_list)
Expand Down
Loading