diff --git a/docs/src/references/lib/index.rst b/docs/src/references/lib/index.rst index 9a69879a..fb2dc6c3 100644 --- a/docs/src/references/lib/index.rst +++ b/docs/src/references/lib/index.rst @@ -9,4 +9,3 @@ are used for the meshLODE calculators. fourier_convolution mesh_interpolator - system diff --git a/docs/src/references/lib/system.rst b/docs/src/references/lib/system.rst deleted file mode 100644 index 00d52b2c..00000000 --- a/docs/src/references/lib/system.rst +++ /dev/null @@ -1,6 +0,0 @@ -System -====== - -.. autoclass:: meshlode.System - :members: - :undoc-members: diff --git a/examples/library-tutorial.py b/examples/library-tutorial.py index 5cf93e0e..014d247c 100644 --- a/examples/library-tutorial.py +++ b/examples/library-tutorial.py @@ -1,7 +1,6 @@ """ Basic Tutorial for Library functions ==================================== - This examples provides an illustration of the functioning of the underlaying library functions of ``meshlode`` and the construction LODE descriptors (`Grisafi 2019 `__, `Grisafi 2021 @@ -18,6 +17,7 @@ import matplotlib.pyplot as plt import numpy as np import torch +from metatensor.torch.atomistic import System import meshlode @@ -48,22 +48,19 @@ def sliceplot(mesh, sz=12, cmap="viridis", vmin=None, vmax=None): # %% # Builds the structure # -------------------- -# # Builds a CsCl structure by replicating the primitive cell using ase and convert it to -# a :py:class:`list` of :py:class:`meshlode.System`. We add a bit of noise to make -# it less boring! +# a :py:class:`list` of :py:class:`metatensor.torch.atomistic.System`. We add a bit of +# noise to make it less boring! 
# positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) * 4 -atomic_types = torch.tensor([55, 17]) # Cs and Cl +types = torch.tensor([55, 17]) # Cs and Cl cell = torch.eye(3) * 4 -ase_frame = ase.Atoms(positions=positions, cell=cell, numbers=atomic_types).repeat( - [2, 2, 2] -) +ase_frame = ase.Atoms(positions=positions, cell=cell, numbers=types).repeat([2, 2, 2]) ase_frame.positions[:] += np.random.normal(size=ase_frame.positions.shape) * 0.1 charges = torch.tensor([1.0, -1.0] * 8) -frame = meshlode.System( - species=torch.tensor(ase_frame.numbers), +system = System( + types=torch.tensor(ase_frame.numbers), positions=torch.tensor(np.array(ase_frame.positions)), cell=torch.tensor(ase_frame.cell), ) @@ -85,7 +82,6 @@ def sliceplot(mesh, sz=12, cmap="viridis", vmin=None, vmax=None): # %% # MeshInterpolator # ---------------- -# # ``MeshInterpolator`` serves as a utility class to compute a mesh # representation of points, and/or to project a function defined on the # mesh on a set of points. Computing the mesh representation is a two-step @@ -95,15 +91,15 @@ def sliceplot(mesh, sz=12, cmap="viridis", vmin=None, vmax=None): # interpol = meshlode.lib.mesh_interpolator.MeshInterpolator( - frame.cell, torch.tensor([16, 16, 16]), interpolation_order=3 + system.cell, torch.tensor([16, 16, 16]), interpolation_order=3 ) -interpol.compute_interpolation_weights(frame.positions) +interpol.compute_interpolation_weights(system.positions) # %% # We use two sets of weights: ones (giving the atom density irrespective -# of the species) and charges (giving a smooth representation of the point +# of the types) and charges (giving a smooth representation of the point # charges). # @@ -126,7 +122,6 @@ def sliceplot(mesh, sz=12, cmap="viridis", vmin=None, vmax=None): # %% # Fourier filter # -------------- -# # This module computes a Fourier-domain filter, that can be used e.g. to # smear the density and/or compute a 1/r^p potential field. 
This can also # be easily extended to compute an arbitrary filter @@ -137,7 +132,7 @@ def sliceplot(mesh, sz=12, cmap="viridis", vmin=None, vmax=None): # %% # plain atomic_smearing rho_mesh = fsc.compute( - mesh_values=mesh, cell=frame.cell, potential_exponent=0, atomic_smearing=1 + mesh_values=mesh, cell=system.cell, potential_exponent=0, atomic_smearing=1 ) sliceplot(rho_mesh[0, :, :, :5]) @@ -145,7 +140,7 @@ def sliceplot(mesh, sz=12, cmap="viridis", vmin=None, vmax=None): # %% # coulomb-like potential, no atomic_smearing coulomb_mesh = fsc.compute( - mesh_values=mesh, cell=frame.cell, potential_exponent=1, atomic_smearing=0 + mesh_values=mesh, cell=system.cell, potential_exponent=1, atomic_smearing=0 ) sliceplot(coulomb_mesh[1, :, :, :5], cmap="seismic") @@ -154,7 +149,6 @@ def sliceplot(mesh, sz=12, cmap="viridis", vmin=None, vmax=None): # %% # Back-interpolation (on the same points) # --------------------------------------- -# # The same ``MeshInterpolator`` object can be used to compute a field on # the same points used initially to generate the atom density # @@ -167,7 +161,6 @@ def sliceplot(mesh, sz=12, cmap="viridis", vmin=None, vmax=None): # %% # Back-interpolation (on different points) # ---------------------------------------- -# # In order to compute the field on a different set of points, it is # sufficient to build another ``MeshInterpolator`` object and to compute # it with the desired field. 
One can also use a different @@ -175,13 +168,13 @@ def sliceplot(mesh, sz=12, cmap="viridis", vmin=None, vmax=None): # interpol_slice = meshlode.lib.mesh_interpolator.MeshInterpolator( - frame.cell, torch.tensor([16, 16, 16]), interpolation_order=4 + system.cell, torch.tensor([16, 16, 16]), interpolation_order=4 ) # Compute a denser grid on a 2D slice n_points = 50 -x = torch.linspace(0, frame.cell[0, 0], n_points + 1)[:n_points] -y = torch.linspace(0, frame.cell[1, 1], n_points + 1)[:n_points] +x = torch.linspace(0, system.cell[0, 0], n_points + 1)[:n_points] +y = torch.linspace(0, system.cell[1, 1], n_points + 1)[:n_points] xx, yy = torch.meshgrid(x, y, indexing="ij") # Flatten xx and yy, and concatenate with a zero column for the z-coordinate diff --git a/examples/madelung.py b/examples/madelung.py index 7193746c..fbfc6fa4 100644 --- a/examples/madelung.py +++ b/examples/madelung.py @@ -1,7 +1,6 @@ """ Compute Madelung Constants ========================== - In this tutorial we show how to calculate the Madelung constants and total electrostatic energy of atomic structures using the :py:class:`meshlode.MeshPotential` and :py:class:`meshlode.metatensor.MeshPotential` calculator. @@ -11,24 +10,24 @@ import math import torch +from metatensor.torch.atomistic import System import meshlode # %% # Define simple example structure having the CsCl structure and compute the reference -# values. MeshPotential by default outputs the species sorted according to the atomic +# values. MeshPotential by default outputs the types sorted according to the atomic # number. Thus, we input the compound "CsCl" and "ClCs" since Cl and Cs have atomic # numbers 17 and 55, respectively. 
-atomic_types = torch.tensor([17, 55]) # Cl and Cs +types = torch.tensor([17, 55]) # Cl and Cs +positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) charges = torch.tensor([-1.0, 1.0]) cell = torch.eye(3) -positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) -frame = meshlode.System(species=atomic_types, positions=positions, cell=torch.eye(3)) # %% # Define the expected values of the energy -n_atoms = len(positions) +n_atoms = len(types) madelung = 2 * 1.7626 / math.sqrt(3) energies_ref = -madelung * torch.ones((n_atoms, 1)) @@ -43,7 +42,6 @@ # %% # Computation using ``meshlode`` # ------------------------------ -# # Compute features using MP = meshlode.MeshPotential( @@ -52,7 +50,7 @@ interpolation_order=interpolation_order, subtract_self=True, ) -potentials_torch = MP.compute(frame) +potentials_torch = MP.compute(types=types, positions=positions, cell=cell) # %% # The "potentials" that have been computed so far are not the actual electrostatic @@ -66,7 +64,7 @@ for idx_n in range(n_atoms): # The coulomb potential between atoms i and j is charge_i * charge_j / d_ij # The features are simply computing a pure 1/r potential with no prefactors. - # Thus, to compute the energy between atoms of species i and j, we need to + # Thus, to compute the energy between atoms of types i and j, we need to # multiply by the charges of i and j. print(charges[idx_c] * charges[idx_n], potentials_torch[idx_n, idx_c]) atomic_energies_torch[idx_c] += ( @@ -88,8 +86,11 @@ # %% # Computation using ``meshlode.metatensor`` # ----------------------------------------- -# -# We now compute the same constants using the metatensor based calculator +# We now compute the same constants using the metatensor based calculator. To achieve +# this we first store our system parameters like the ``types``, ``positions`` and the +# ``cell`` defined above into a :py:class:`metatensor.torch.atomistic.System` class. 
+ +system = System(types=types, positions=positions, cell=cell) MP = meshlode.metatensor.MeshPotential( atomic_smearing=atomic_smearing, @@ -97,7 +98,7 @@ interpolation_order=interpolation_order, subtract_self=True, ) -potential_metatensor = MP.compute(frame) +potential_metatensor = MP.compute(system) # %% @@ -105,16 +106,16 @@ # of the "potentials" weighted by the charges of the atoms. atomic_energies_metatensor = torch.zeros((n_atoms, 1)) -for idx_c, c in enumerate(atomic_types): - for idx_n, n in enumerate(atomic_types): - # Take the coefficients with the correct center atom and neighbor atom species +for idx_c, c in enumerate(types): + for idx_n, n in enumerate(types): + # Take the coefficients with the correct center atom and neighbor atom types block = potential_metatensor.block( {"center_type": int(c), "neighbor_type": int(n)} ) # The coulomb potential between atoms i and j is charge_i * charge_j / d_ij # The features are simply computing a pure 1/r potential with no prefactors. - # Thus, to compute the energy between atoms of species i and j, we need to + # Thus, to compute the energy between atoms of types i and j, we need to # multiply by the charges of i and j. print(c, n, charges[idx_c] * charges[idx_n], block.values[0, 0]) atomic_energies_metatensor[idx_c] += ( diff --git a/pyproject.toml b/pyproject.toml index 2b3cb229..cf3c1daa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ keywords = [ "Atomistic Simulations", ] dependencies = [ - "torch >= 1.11", + "torch >=1.11", ] dynamic = ["version"] @@ -44,7 +44,7 @@ examples = [ "matplotlib", ] metatensor = [ - "metatensor[torch]", + "metatensor-torch >=0.3", ] [project.urls] diff --git a/src/meshlode/__init__.py b/src/meshlode/__init__.py index 527135a4..ee18ff6d 100644 --- a/src/meshlode/__init__.py +++ b/src/meshlode/__init__.py @@ -1,5 +1,4 @@ from .calculators.meshpotential import MeshPotential -from .lib.system import System try: from . 
import metatensor # noqa @@ -7,5 +6,5 @@ pass -__all__ = ["MeshPotential", "System"] +__all__ = ["MeshPotential"] __version__ = "0.0.0-dev" diff --git a/src/meshlode/calculators/meshpotential.py b/src/meshlode/calculators/meshpotential.py index 823e0ccd..778d7db6 100644 --- a/src/meshlode/calculators/meshpotential.py +++ b/src/meshlode/calculators/meshpotential.py @@ -4,7 +4,6 @@ from meshlode.lib.fourier_convolution import FourierSpaceConvolution from meshlode.lib.mesh_interpolator import MeshInterpolator -from meshlode.lib.system import System def _1d_tolist(x: torch.Tensor) -> List[int]: @@ -47,14 +46,14 @@ class MeshPotential(torch.nn.Module): Define simple example structure having the CsCl (Cesium Chloride) structure - >>> atomic_types = torch.tensor([55, 17]) # Cs and Cl + >>> types = torch.tensor([55, 17]) # Cs and Cl >>> positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) >>> cell = torch.eye(3) Compute features >>> MP = MeshPotential(atomic_smearing=0.2, mesh_spacing=0.1, interpolation_order=4) - >>> MP.compute(atomic_types=atomic_types, positions=positions, cell=cell) + >>> MP.compute(types=types, positions=positions, cell=cell) tensor([[-0.5467, 1.3755], [ 1.3755, -0.5467]]) """ @@ -91,9 +90,7 @@ def __init__( if all_types is None: self.all_types = None else: - self.all_types = _1d_tolist( - torch.unique(torch.tensor(all_types)) - ) + self.all_types = _1d_tolist(torch.unique(torch.tensor(all_types))) # Initilize auxiliary objects self.fourier_space_convolution = FourierSpaceConvolution() @@ -116,60 +113,103 @@ def compute( positions: Union[List[torch.Tensor], torch.Tensor], cell: Union[List[torch.Tensor], torch.Tensor], ) -> Union[torch.Tensor, List[torch.Tensor]]: - """Compute the potential at the position of each atom for all provided systems. + """Compute potential for all provided "systems" stacked inside list. 
- :param types: TODO - :param positions: TODO - :param cell: TODO + :param types: single or list of 1D tensor of integer representing the + particles identity. For atoms, this is typically their atomic numbers. + :param positions: single or 2D tensor of shape (len(types), 3) containing the + Cartesian positions of all particles in the system. + :param cell: single or 2D tensor of shape (3, 3), describing the bounding + box/unit cell of the system. Each row should be one of the bounding box + vector; and columns should contain the x, y, and z components of these + vectors (i.e. the cell should be given in row-major order). :return: List of torch Tensors containing the potentials for all frames and all - atoms. Each tensor in the list is of shape (n_atoms,n_species), where - n_species is the number of species in all systems combined. If the input was + atoms. Each tensor in the list is of shape (n_atoms, n_types), where + n_types is the number of types in all systems combined. If the input was a single system only a single torch tensor with the potentials is returned. - IMPORTANT: If multiple species are present, the different "species-channels" + IMPORTANT: If multiple types are present, the different "types-channels" are ordered according to atomic number. For example, if a structure contains - a water molecule with atoms 0, 1, 2 being of species O, H, H, then for this + a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this system, the feature tensor will be of shape (3, 2) = (``n_atoms``, - ``n_species``), where ``features[0, 0]`` is the potential at the position of + ``n_types``), where ``features[0, 0]`` is the potential at the position of the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, while ``features[0,1]`` is the potential at the position of the Oxygen atom generated by the Oxygen atom(s). 
""" - # Make sure that the compute function also works if only a single frame is - # provided as input (for convenience of users testing out the code) + # make sure compute function works if only a single tensor are provided as input if not isinstance(types, list): types = [types] - if not isinstance(species, list): + if not isinstance(positions, list): positions = [positions] - if not isinstance(species, list): + if not isinstance(cell, list): cell = [cell] - # TODO check that all inputs are on the same device and that positions and cell - # have the same dtype! + for types_single, positions_single, cell_single in zip(types, positions, cell): + if len(types_single.shape) != 1: + raise ValueError( + "each `types` must be a 1 dimensional tensor, got at least " + f"one tensor with {len(types_single.shape)} dimensions" + ) + + if positions_single.shape != (len(types_single), 3): + raise ValueError( + "each `positions` must be a (n_types x 3) tensor, got at least " + f"one tensor with shape {list(positions_single.shape)}" + ) + + if cell_single.shape != (3, 3): + raise ValueError( + "each `cell` must be a (3 x 3) tensor, got at least " + f"one tensor with shape {list(cell_single.shape)}" + ) + + if cell_single.dtype != positions_single.dtype: + raise ValueError( + "`cell` must be have the same dtype as `positions`, got " + f"{cell_single.dtype} and {positions_single.dtype}" + ) + + if ( + positions_single.device != types_single.device + or cell_single.device != types_single.device + ): + raise ValueError( + "`types`, `positions`, and `cell` must be on the same device, got " + f"{types_single.device}, {positions_single.device} and " + f"{cell_single.device}." + ) + + # We don't require and test that all dtypes and devices are consistent if a list + # of inputs. Each "frame" is processed independently. 
requested_types = self._get_requested_types(types) - n_types = len(atomic_types) + n_types = len(requested_types) potentials = [] for types_single, positions_single, cell_single in zip(types, positions, cell): # One-hot encoding of charge information - charges = torch.zeros((len(types_single), n_types), dtype=positions_single.dtype) + charges = torch.zeros( + (len(types_single), n_types), dtype=positions_single.dtype + ) for i_type, atomic_type in enumerate(requested_types): charges[types_single == atomic_type, i_type] = 1.0 # Compute the potentials potentials.append( - self._compute_single_frame(positions=positions_single, charges=charges, cell=cell_single) + self._compute_single_system( + positions=positions_single, charges=charges, cell=cell_single + ) ) - if len(species) == 1: + if len(types) == 1: return potentials[0] else: return potentials def _get_requested_types(self, types: List[torch.Tensor]) -> List[int]: - """Extract a list of all present types from the list of types.""" + """Extract a list of all unique and present types from the list of types.""" all_types = torch.hstack(types) types_requested = _1d_tolist(torch.unique(all_types)) @@ -183,7 +223,7 @@ def _get_requested_types(self, types: List[torch.Tensor]) -> List[int]: else: return types_requested - def _compute_single_frame( + def _compute_single_system( self, positions: torch.Tensor, charges: torch.Tensor, @@ -201,8 +241,8 @@ def _compute_single_frame( charge of atom i. More generally, the potential for the same atom positions is computed for n_channels independent meshes, and one can specify the "charge" of each atom on each of the meshes independently. 
For standard LODE - that treats all atomic species separately, one example could be: If n_atoms - = 4 and the species are [Na, Cl, Cl, Na], one could set n_channels=2 and use + that treats all (atomic) types separately, one example could be: If n_atoms + = 4 and the types are [Na, Cl, Cl, Na], one could set n_channels=2 and use the one-hot encoding charges = torch.tensor([[1,0],[0,1],[0,1],[1,0]]) for the charges. This would then separately compute the "Na" potential and "Cl" potential. Subtracting these from each other, one could recover the more diff --git a/src/meshlode/lib/__init__.py b/src/meshlode/lib/__init__.py index 136dcddd..e7c78e89 100644 --- a/src/meshlode/lib/__init__.py +++ b/src/meshlode/lib/__init__.py @@ -1,5 +1,4 @@ -from .system import System from .fourier_convolution import FourierSpaceConvolution from .mesh_interpolator import MeshInterpolator -__all__ = ["FourierSpaceConvolution", "MeshInterpolator", "System"] +__all__ = ["FourierSpaceConvolution", "MeshInterpolator"] diff --git a/src/meshlode/lib/mesh_interpolator.py b/src/meshlode/lib/mesh_interpolator.py index 35a6e4e7..9ae31540 100644 --- a/src/meshlode/lib/mesh_interpolator.py +++ b/src/meshlode/lib/mesh_interpolator.py @@ -181,7 +181,7 @@ def points_to_mesh(self, particle_weights: torch.Tensor) -> torch.Tensor: ``particle_weights[i,a]`` is the weight (charge) that point (atom) i has to generate the "a-th" potential. In practice, this can be used to compute e.g. the Na and Cl contributions to the potential separately by using a one-hot - encoding of the species. + encoding of the types. :return: torch.tensor of shape ``(n_channels, n_mesh, n_mesh, n_mesh)`` Discrete density diff --git a/src/meshlode/metatensor/meshpotential.py b/src/meshlode/metatensor/meshpotential.py index 5092f857..7fc7ccc8 100644 --- a/src/meshlode/metatensor/meshpotential.py +++ b/src/meshlode/metatensor/meshpotential.py @@ -16,29 +16,35 @@ from .. 
import calculators +# We are breaking the Liskov substitution principle here by changing the signature of +# "compute" compared to the supertype of "MeshPotential". +# mypy: disable-error-code="override" + + class MeshPotential(calculators.MeshPotential): - """A species wise long range potential. + """An (atomic) type wise long range potential. Refer to :class:`meshlode.MeshPotential` for full documentation. Example ------- >>> import torch - >>> from meshlode.lib import System + >>> from metatensor.torch.atomistic import System >>> from meshlode.metatensor import MeshPotential Define simple example structure having the CsCl (Cesium Chloride) structure + >>> types = torch.tensor([55, 17]) # Cs and Cl >>> positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) - >>> atomic_types = torch.tensor([55, 17]) # Cs and Cl - >>> frame = System(species=atomic_types, positions=positions, cell=torch.eye(3)) + >>> cell = torch.eye(3) + >>> system = System(types=types, positions=positions, cell=cell) Compute features >>> MP = MeshPotential(atomic_smearing=0.2, mesh_spacing=0.1, interpolation_order=4) - >>> features = MP.compute(frame) + >>> features = MP.compute(system) - All species combinations + All (atomic) type combinations >>> features.keys Labels( @@ -67,16 +73,15 @@ def compute( self, systems: Union[List[System], System], ) -> TensorMap: - """Compute the potential at the position of each atom for all Systems provided - in "frames". + """Compute potential for all provided ``systems``. - :param systems: single System or list of Systems on which to run the - calculation. If any of the systems' ``positions`` or ``cell`` has - ``requires_grad`` set to :py:obj:`True`, then the corresponding gradients - are computed and registered as a custom node in the computational graph, to - allow backward propagation of the gradients later. + All ``systems`` must have the same ``dtype`` and the same ``device``. - :return: TensorMap containing the potential of all atoms.
The keys of the + :param systems: single System or list of + :py:class:`metatensor.torch.atomistic.System` on which to run the + calculation. + + :return: TensorMap containing the potential of all types. The keys of the tensormap are "center_type" and "neighbor_type". """ # Make sure that the compute function also works if only a single frame is @@ -84,54 +89,69 @@ def compute( if not isinstance(systems, list): systems = [systems] - atomic_types = self._get_atomic_types(systems) - n_species = len(atomic_types) - - # Initialize dictionary for sparse storage of the features Generate a dictionary - # to map atomic species to array indices In general, the species are sorted - # according to atomic number and assigned the array indices 0, 1, 2,... - # Example: for H2O: `H` is mapped to `0` and `O` is mapped to `1`. - n_species_sq = n_species * n_species - feat_dic: Dict[int, List[torch.Tensor]] = {a: [] for a in range(n_species_sq)} + if len(systems) > 1: + for system in systems[1:]: + if system.dtype != systems[0].dtype: + raise ValueError( + "`dtype` of all systems must be the same, got " + f"{system.dtype} and {systems[0].dtype}`" + ) + + if system.device != systems[0].device: + raise ValueError( + "`device of all systems must be the same, got " + f"{system.device} and {systems[0].device}`" + ) + + requested_types = self._get_requested_types( + [system.types for system in systems] + ) + n_types = len(requested_types) + + # Initialize dictionary for sparse storage of the features to map atomic types + # to array indices. In general, the types are sorted according to their + # (integer) type and assigned the array indices 0, 1, 2,... Example: for H2O: + # `H` is mapped to `0` and `O` is mapped to `1`.
+ n_types_sq = n_types * n_types + feat_dic: Dict[int, List[torch.Tensor]] = {a: [] for a in range(n_types_sq)} for system in systems: # One-hot encoding of charge information - n_atoms = len(system) - species = system.species - charges = torch.zeros((n_atoms, n_species), dtype=torch.float) - for i_specie, atomic_type in enumerate(atomic_types): - charges[species == atomic_type, i_specie] = 1.0 + types = system.types + charges = torch.zeros((len(system), n_types), dtype=system.positions.dtype) + for i_specie, atomic_type in enumerate(requested_types): + charges[types == atomic_type, i_specie] = 1.0 # Compute the potentials - potential = self._compute_single_frame( + potential = self._compute_single_system( system.positions, charges, system.cell ) # Reorder data into Metatensor format - for spec_center, at_num_center in enumerate(atomic_types): - for spec_neighbor in range(len(atomic_types)): - a_pair = spec_center * n_species + spec_neighbor + for spec_center, at_num_center in enumerate(requested_types): + for spec_neighbor in range(len(requested_types)): + a_pair = spec_center * n_types + spec_neighbor feat_dic[a_pair] += [ - potential[species == at_num_center, spec_neighbor] + potential[types == at_num_center, spec_neighbor] ] # Assemble all computed potential values into TensorBlocks for each combination # of center_type and neighbor_type blocks: List[TensorBlock] = [] for keys, values in feat_dic.items(): - spec_center = atomic_types[keys // n_species] + spec_center = requested_types[keys // n_types] # Generate the Labels objects for the samples and properties of the # TensorBlock. 
values_samples: List[List[int]] = [] for i_frame, system in enumerate(systems): for i_atom in range(len(system)): - if system.species[i_atom] == spec_center: + if system.types[i_atom] == spec_center: values_samples.append([i_frame, i_atom]) samples_vals_tensor = torch.tensor(values_samples, dtype=torch.int32) - # If no atoms are found that match the species pair `samples_vals_tensor` + # If no atoms are found that match the types pair `samples_vals_tensor` # will be empty. We have to reshape the empty tensor to be a valid input for # `Labels`. if len(samples_vals_tensor) == 0: @@ -151,8 +171,8 @@ def compute( # Generate TensorMap from TensorBlocks by defining suitable keys key_values: List[torch.Tensor] = [] - for spec_center in atomic_types: - for spec_neighbor in atomic_types: + for spec_center in requested_types: + for spec_neighbor in requested_types: key_values.append(torch.tensor([spec_center, spec_neighbor])) key_values = torch.vstack(key_values) labels_keys = Labels(["center_type", "neighbor_type"], key_values) diff --git a/tests/calculators/test_meshpotential.py b/tests/calculators/test_meshpotential.py index ca4daffc..c04c398e 100644 --- a/tests/calculators/test_meshpotential.py +++ b/tests/calculators/test_meshpotential.py @@ -7,20 +7,20 @@ import torch from torch.testing import assert_close -from meshlode import MeshPotential, System +from meshlode import MeshPotential MADELUNG_CSCL = torch.tensor(2 * 1.7626 / math.sqrt(3)) CHARGES_CSCL = torch.tensor([1.0, -1.0]) -def cscl_system() -> System: +def cscl_system(): """CsCl crystal. Same as in the madelung test""" - return System( - species=torch.tensor([55, 17]), - positions=torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]), - cell=torch.eye(3), - ) + types = torch.tensor([55, 17]) + positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) + cell = torch.eye(3) + + return types, positions, cell # Initialize the calculators. For now, only the MeshPotential is implemented. 
@@ -46,21 +46,21 @@ def test_interpolation_order_error(): def test_all_types(): descriptor = MeshPotential(atomic_smearing=0.1, all_types=[8, 55, 17]) - values = descriptor.compute(cscl_system()) + values = descriptor.compute(*cscl_system()) assert values.shape == (2, 3) assert torch.equal(values[:, 0], torch.zeros(2)) def test_all_types_error(): descriptor = MeshPotential(atomic_smearing=0.1, all_types=[17]) - with pytest.raises(ValueError, match="Global list of atomic numbers"): - descriptor.compute(cscl_system()) + with pytest.raises(ValueError, match="Global list of types"): + descriptor.compute(*cscl_system()) # Make sure that the calculators are computing the features without raising errors, # and returns the correct output format (TensorMap) def check_operation(calculator): - descriptor = calculator.compute(cscl_system()) + descriptor = calculator.compute(*cscl_system()) assert type(descriptor) is torch.Tensor @@ -76,7 +76,7 @@ def test_operation_as_torch_script(): def test_single_frame(): - values = descriptor().compute(cscl_system()) + values = descriptor().compute(*cscl_system()) print(values) assert_close( MADELUNG_CSCL, @@ -87,7 +87,10 @@ def test_single_frame(): def test_multi_frame(): - l_values = descriptor().compute([cscl_system(), cscl_system()]) + types, positions, cell = cscl_system() + l_values = descriptor().compute( + types=[types, types], positions=[positions, positions], cell=[cell, cell] + ) for values in l_values: assert_close( MADELUNG_CSCL, @@ -95,3 +98,58 @@ def test_multi_frame(): atol=1e4, rtol=1e-5, ) + + +def test_types_error(): + types = torch.tensor([[1, 2], [3, 4]]) # This is a 2D tensor, should be 1D + positions = torch.zeros((2, 3)) + cell = torch.eye(3) + + match = ( + "each `types` must be a 1 dimensional tensor, got at least one tensor with " + "2 dimensions" + ) + with pytest.raises(ValueError, match=match): + descriptor().compute(types=types, positions=positions, cell=cell) + + +def test_positions_error(): + types = 
torch.tensor([1, 2]) + positions = torch.zeros( + (1, 3) + ) # This should have the same first dimension as types + cell = torch.eye(3) + + match = ( + "each `positions` must be a \\(n_types x 3\\) tensor, got at least " + "one tensor with shape \\[1, 3\\]" + ) + + with pytest.raises(ValueError, match=match): + descriptor().compute(types=types, positions=positions, cell=cell) + + +def test_cell_error(): + types = torch.tensor([1, 2, 3]) + positions = torch.zeros((3, 3)) + cell = torch.eye(2) # This is a 2x2 tensor, should be 3x3 + + match = ( + "each `cell` must be a \\(3 x 3\\) tensor, got at least one tensor " + "with shape \\[2, 2\\]" + ) + with pytest.raises(ValueError, match=match): + descriptor().compute(types=types, positions=positions, cell=cell) + + +def test_positions_cell_dtype_error(): + types = torch.tensor([1, 2, 3]) + positions = torch.zeros((3, 3), dtype=torch.float32) + cell = torch.eye(3, dtype=torch.float64) + + match = ( + "`cell` must be have the same dtype as `positions`, got torch.float64 " + "and torch.float32" + ) + with pytest.raises(ValueError, match=match): + descriptor().compute(types=types, positions=positions, cell=cell) diff --git a/tests/metatensor/test_madelung.py b/tests/metatensor/test_madelung.py index 70b34827..8bb7c630 100644 --- a/tests/metatensor/test_madelung.py +++ b/tests/metatensor/test_madelung.py @@ -4,10 +4,9 @@ import pytest import torch +from metatensor.torch.atomistic import System from torch.testing import assert_close -from meshlode import System - meshlode_metatensor = pytest.importorskip("meshlode.metatensor") @@ -44,7 +43,7 @@ def crystal_dictionary(self): # closest Na-Cl pair is exactly 1. The cubic unit cell # in these units would have a length of 2. 
d["NaCl"]["symbols"] = ["Na", "Cl"] - d["NaCl"]["atomic_types"] = torch.tensor([11, 17]) + d["NaCl"]["types"] = torch.tensor([11, 17]) d["NaCl"]["charges"] = torch.tensor([[1.0, -1]]).T d["NaCl"]["positions"] = torch.tensor([[0, 0, 0], [1.0, 0, 0]]) d["NaCl"]["cell"] = torch.tensor([[0, 1.0, 1], [1, 0, 1], [1, 1, 0]]) @@ -56,7 +55,7 @@ def crystal_dictionary(self): # The closest Cs-Cl distance is sqrt(3)/2. We thus divide # the Madelung constant by this value to match the reference. d["CsCl"]["symbols"] = ["Cs", "Cl"] - d["CsCl"]["atomic_types"] = torch.tensor([55, 17]) + d["CsCl"]["types"] = torch.tensor([55, 17]) d["CsCl"]["charges"] = torch.tensor([[1.0, -1]]).T d["CsCl"]["positions"] = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) d["CsCl"]["cell"] = torch.eye(3) @@ -70,7 +69,7 @@ def crystal_dictionary(self): # If, on the other han_pylode_without_centerd, we set the lattice constant of # the cubic cell equal to 1, the Zn-S distance is sqrt(3)/4. d["ZnS"]["symbols"] = ["S", "Zn"] - d["ZnS"]["atomic_types"] = torch.tensor([16, 30]) + d["ZnS"]["types"] = torch.tensor([16, 30]) d["ZnS"]["charges"] = torch.tensor([[1.0, -1]]).T d["ZnS"]["positions"] = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) d["ZnS"]["cell"] = torch.tensor([[0, 1.0, 1], [1, 0, 1], [1, 1, 0]]) @@ -80,7 +79,7 @@ def crystal_dictionary(self): u = torch.tensor([3 / 8]) c = torch.sqrt(1 / u) d["ZnSO4"]["symbols"] = ["S", "Zn", "S", "Zn"] - d["ZnSO4"]["atomic_types"] = torch.tensor([16, 30, 16, 30]) + d["ZnSO4"]["types"] = torch.tensor([16, 30, 16, 30]) d["ZnSO4"]["charges"] = torch.tensor([[1.0, -1, 1, -1]]).T d["ZnSO4"]["positions"] = torch.tensor( [ @@ -125,7 +124,9 @@ def test_madelung_low_order( MP = meshlode_metatensor.MeshPotential( smearing_eff, mesh_spacing, interpolation_order, subtract_self=True ) - potentials_mesh = MP._compute_single_frame(cell, positions, charges) + potentials_mesh = MP._compute_single_system( + positions=positions, charges=charges, cell=cell + ) energies = potentials_mesh * 
charges energies_target = -torch.ones_like(energies) * madelung assert_close(energies, energies_target, rtol=1e-4, atol=1e-6) @@ -159,7 +160,9 @@ def test_madelung_high_order( MP = meshlode_metatensor.MeshPotential( smearing_eff, mesh_spacing, interpolation_order, subtract_self=True ) - potentials_mesh = MP._compute_single_frame(cell, positions, charges) + potentials_mesh = MP._compute_single_system( + positions=positions, charges=charges, cell=cell + ) energies = potentials_mesh * charges energies_target = -torch.ones_like(energies) * madelung assert_close(energies, energies_target, rtol=1e-2, atol=1e-3) @@ -183,25 +186,25 @@ def test_madelung_low_order_metatensor( dic = crystal_dictionary[crystal_name] positions = dic["positions"] * scaling_factor cell = dic["cell"] * scaling_factor - atomic_types = dic["atomic_types"] + types = dic["types"] charges = dic["charges"] madelung = dic["madelung"] / scaling_factor mesh_spacing = atomic_smearing / 2 * scaling_factor smearing_eff = atomic_smearing * scaling_factor n_atoms = len(positions) - frame = System(species=atomic_types, positions=positions, cell=cell) + system = System(types=types, positions=positions, cell=cell) MP = meshlode_metatensor.MeshPotential( atomic_smearing=smearing_eff, mesh_spacing=mesh_spacing, interpolation_order=interpolation_order, subtract_self=True, ) - potentials_mesh = MP.compute(frame) + potentials_mesh = MP.compute(system) # Compute the actual potential from the features energies = torch.zeros((n_atoms, 1)) - for idx_c, c in enumerate(atomic_types): - for idx_n, n in enumerate(atomic_types): + for idx_c, c in enumerate(types): + for idx_n, n in enumerate(types): block = potentials_mesh.block( {"center_type": int(c), "neighbor_type": int(n)} ) diff --git a/tests/metatensor/test_metatensor_meshpotential.py b/tests/metatensor/test_metatensor_meshpotential.py index f888374b..75dcf415 100644 --- a/tests/metatensor/test_metatensor_meshpotential.py +++ 
b/tests/metatensor/test_metatensor_meshpotential.py @@ -2,21 +2,22 @@ import pytest import torch +from metatensor.torch.atomistic import System from packaging import version -from meshlode import System - metatensor_torch = pytest.importorskip("metatensor.torch") meshlode_metatensor = pytest.importorskip("meshlode.metatensor") # Define toy system consisting of a single structure for testing -def toy_system_single_frame() -> System: +def toy_system_single_frame(dtype=torch.float32) -> System: return System( - species=torch.tensor([1, 1, 8, 8]), - positions=torch.tensor([[0.0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]]), - cell=torch.tensor([[10.0, 0, 0], [0, 10, 0], [0, 0, 10]]), + types=torch.tensor([1, 1, 8, 8]), + positions=torch.tensor( + [[0.0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3]], dtype=dtype + ), + cell=torch.tensor([[10.0, 0, 0], [0, 10, 0], [0, 0, 10]], dtype=dtype), ) @@ -46,6 +47,17 @@ def test_all_types(): ) +def test_wrong_dtype_between_systems(): + match = "`dtype` of all systems must be the same, got 7 and 6" + with pytest.raises(ValueError, match=match): + descriptor().compute( + [ + toy_system_single_frame(dtype=torch.float32), + toy_system_single_frame(dtype=torch.float64), + ] + ) + + # Make sure that the calculators are computing the features without raising errors, # and returns the correct output format (TensorMap) def check_operation(calculator): @@ -66,28 +78,28 @@ def test_operation_as_torch_script(): check_operation(scripted) -# Define a more complex toy system consisting of multiple frames, mixing three species. +# Define a more complex toy system consisting of multiple frames, mixing three types. 
def toy_system_2() -> List[System]: # First few frames containing Nitrogen L = 2.0 frames = [] frames.append( System( - species=torch.tensor([7]), + types=torch.tensor([7]), positions=torch.zeros((1, 3)), cell=L * 2 * torch.eye(3), ) ) frames.append( System( - species=torch.tensor([7, 7]), + types=torch.tensor([7, 7]), positions=torch.zeros((2, 3)), cell=L * 2 * torch.eye(3), ) ) frames.append( System( - species=torch.tensor([7, 7, 7]), + types=torch.tensor([7, 7, 7]), positions=torch.zeros((3, 3)), cell=L * 2 * torch.eye(3), ) @@ -96,9 +108,7 @@ def toy_system_2() -> List[System]: # One more frame containing Na and Cl positions = torch.tensor([[0, 0, 0], [1.0, 0, 0]]) cell = torch.tensor([[0, 1.0, 1], [1, 0, 1], [1, 1, 0]]) - frames.append( - System(species=torch.tensor([11, 17]), positions=positions, cell=cell) - ) + frames.append(System(types=torch.tensor([11, 17]), positions=positions, cell=cell)) return frames