diff --git a/src/meshlode/calculators.py b/src/meshlode/calculators.py index 182a851b..6792e38d 100644 --- a/src/meshlode/calculators.py +++ b/src/meshlode/calculators.py @@ -9,7 +9,7 @@ Our calculator API follows the `rascaline `_ API and coding guidelines to promote usability and interoperability with existing workflows. """ -from typing import Dict, List, Union +from typing import Dict, List, Union, Optional import torch from metatensor.torch import Labels, TensorBlock, TensorMap @@ -30,10 +30,11 @@ def _my_1d_tolist(x: torch.Tensor): class MeshPotential(torch.nn.Module): """A species wise long range potential. - :param atomic_gaussian_width: Width of the atom-centered gaussian used to create the + :param atomic_smearing: Width of the atom-centered gaussian used to create the atomic density. :param mesh_spacing: Value that determines the umber of Fourier-space grid points - that will be used along each axis. + that will be used along each axis. If set to ``None``, it will automatically + be set to half of ``atomic_smearing`` :param interpolation_order: Interpolation order for mapping onto the grid, where an interpolation order of p corresponds to interpolation by a polynomial of degree ``p-1`` (e.g. ``p=4`` for cubic interpolation). @@ -47,7 +48,7 @@ class MeshPotential(torch.nn.Module): >>> import torch >>> from meshlode import MeshPotential, System - Define simple example structure having the CsCl structure + Define simple example structure having the CsCl (Cesium Chloride) structure >>> positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) >>> atomic_numbers = torch.tensor([55, 17]) # Cs and Cl @@ -55,9 +56,7 @@ class MeshPotential(torch.nn.Module): Compute features - >>> MP = MeshPotential( - ... atomic_gaussian_width=0.2, mesh_spacing=0.1, interpolation_order=4 - ... ) + >>> MP = MeshPotential(atomic_smearing=0.2, mesh_spacing=0.1, interpolation_order=4) >>> features = MP.compute(frame) All species combinations @@ -70,10 +69,10 @@ class MeshPotential(torch.nn.Module): 55 17 55 55 ) - >>> block_ClCl = features.block({"species_center": 17, "species_neighbor": 17}) The Cl-potential at the position of the Cl atom + >>> block_ClCl = features.block({"species_center": 17, "species_neighbor": 17}) >>> block_ClCl.values tensor([[1.3755]]) @@ -83,14 +82,26 @@ class MeshPotential(torch.nn.Module): def __init__( self, - atomic_gaussian_width: float, - mesh_spacing: float = 0.2, - interpolation_order: float = 4, - subtract_self: bool = False, + atomic_smearing: float, + mesh_spacing: Optional[float] = None, + interpolation_order: Optional[float] = 4, + subtract_self: Optional[bool] = False, ): super().__init__() - self.atomic_gaussian_width = atomic_gaussian_width + # Check that all provided values are correct + if interpolation_order not in [1, 2, 3, 4, 5]: + raise ValueError("Only `interpolation_order` from 1 to 5 are allowed") + if atomic_smearing <= 0: + raise ValueError(f"`atomic_smearing` {atomic_smearing} has to be positive") + + # If no explicit mesh_spacing is given, set it such that it can resolve + # the smeared potentials. + if mesh_spacing is None: + mesh_spacing = atomic_smearing / 2 + + # Store provided parameters + self.atomic_smearing = atomic_smearing self.mesh_spacing = mesh_spacing self.interpolation_order = interpolation_order self.subtract_self = subtract_self @@ -105,7 +116,6 @@ def forward( """forward just calls :py:meth:`CalculatorModule.compute`""" res = self.compute(frames=systems) return res - # return 0. def compute( self, @@ -211,12 +221,12 @@ def _compute_single_frame( Compute the "electrostatic" potential at the position of all atoms in a structure. - :param cell: torch.tensor of shape (3,3). Describes the unit cell of the + :param cell: torch.tensor of shape `(3,3)`. Describes the unit cell of the structure, where cell[i] is the i-th basis vector. :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian coordinates of the atoms. The implementation also works if the positions are not contained within the unit cell. - :param charges: torch.tensor of shape (n_atoms, n_channels). In the simplest + :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the charge of atom i. More generally, the potential for the same atom positions is computed for n_channels independent meshes, and one can specify the @@ -229,11 +239,10 @@ def _compute_single_frame( standard electrostatic potential in which Na and Cl have charges of +1 and -1, respectively. - :returns: torch.tensor of shape (n_atoms, n_channels) containing the potential - at the position of each atom for the n_channels independent meshes separately. + :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential + at the position of each atom for the `n_channels` independent meshes separately. """ - smearing = self.atomic_gaussian_width - mesh_resolution = self.mesh_spacing + smearing = self.atomic_smearing interpolation_order = self.interpolation_order # Initializations @@ -241,11 +250,8 @@ def _compute_single_frame( assert positions.shape == (n_atoms, 3) assert charges.shape[0] == n_atoms - # Define k-vectors - if mesh_resolution is None: - k_cutoff = 2 * torch.pi / (smearing / 2) - else: - k_cutoff = 2 * torch.pi / mesh_resolution + # Define cutoff in reciprocal space + k_cutoff = 2 * torch.pi / self.mesh_spacing # Compute number of times each basis vector of the # reciprocal space can be scaled until the cutoff diff --git a/src/meshlode/fourier_convolution.py b/src/meshlode/fourier_convolution.py index 14fde2a0..92bcbba6 100644 --- a/src/meshlode/fourier_convolution.py +++ b/src/meshlode/fourier_convolution.py @@ -3,6 +3,7 @@ =================== """ import torch +from typing import Optional class FourierSpaceConvolution: @@ -10,11 +11,13 @@ class FourierSpaceConvolution: Class for handling all the steps necessary to compute the convolution f*G between two functions f and G, where the values of f are provided on a discrete mesh. - :param cell: torch.tensor of shape (3,3) Tensor specifying the real space unit + :param cell: torch.tensor of shape ``(3,3)`` Tensor specifying the real space unit cell of a structure, where cell[i] is the i-th basis vector """ def __init__(self, cell: torch.Tensor): + if cell.shape != (3, 3): + raise ValueError(f"cell of shape {cell.shape} should be of shape (3,3)") self.cell: torch.Tensor = cell def generate_kvectors(self, ns: torch.Tensor) -> torch.Tensor: @@ -22,20 +25,20 @@ def generate_kvectors(self, ns: torch.Tensor) -> torch.Tensor: For a given unit cell, compute all reciprocal space vectors that are used to perform sums in the Fourier transformed space. - :param cell: torch.tensor of shape (3,3) - Tensor specifying the real space unit cell of a structure, where cell[i] is - the i-th basis vector - :param ns: torch.tensor of shape (3,) - ns = [nx, ny, nz] contains the number of mesh points in the x-, y- and + :param ns: torch.tensor of shape ``(3,)`` + ``ns = [nx, ny, nz]`` contains the number of mesh points in the x-, y- and z-direction, respectively. For faster performance during the Fast Fourier Transform (FFT) it is recommended to use values of nx, ny and nz that are powers of 2. - :return: torch.tensor of shape [N_k,3] Contains all reciprocal space vectors - that will be used during Ewald summation (or related approaches). The number - N_k of such vectors is given by N_k = nx * ny * nz. k_vectors[i] contains - the i-th vector, where the order has no special significance. + :return: torch.tensor of shape ``(N,3)`` Contains all reciprocal space vectors + that will be used during Ewald summation (or related approaches). + ``k_vectors[i]`` contains the i-th vector, where the order has no special + significance. """ + if ns.shape != (3,): + raise ValueError(f"ns of shape {ns.shape} should be of shape (3,)") + # Define basis vectors of the reciprocal cell reciprocal_cell = 2 * torch.pi * self.cell.inverse().T bx = reciprocal_cell[0] @@ -55,17 +58,18 @@ def generate_kvectors(self, ns: torch.Tensor) -> torch.Tensor: return k_vectors def kernel_func( - self, ksq: torch.Tensor, potential_exponent: int = 1, smearing: float = 0.2 + self, ksq: torch.Tensor, potential_exponent: int = 1, + smearing: float = 0.2 ) -> torch.Tensor: """ Fourier transform of the Coulomb potential or more general effective 1/r**p potentials with additional smearing to remove the singularity at the origin. - :param ksq: torch.tensor of shape (N_k,) Squared norm of the k-vectors + :param ksq: torch.tensor of shape ``(N,)`` Squared norm of the k-vectors :param potential_exponent: Exponent of the effective 1/r**p decay :param smearing: Broadening of the 1/r**p decay close to the origin - :return: torch.tensor of shape (N_k,) with the values of the kernel function + :return: torch.tensor of shape ``(N,)`` with the values of the kernel function G(k) evaluated at the provided (squared norms of the) k-vectors """ if potential_exponent == 1: @@ -76,7 +80,7 @@ def kernel_func( raise ValueError("Only potential exponents 0 and 1 are supported") def value_at_origin( - self, potential_exponent: int = 1, smearing: float = 0.2 + self, potential_exponent: int = 1, smearing: Optional[float] = 0.2 ) -> float: """ Since the kernel function in reciprocal space typically has a (removable) @@ -87,7 +91,7 @@ def value_at_origin( :return: float of G(k=0), the value of the kernel function at the origin. """ - if potential_exponent in [1, 2, 3]: + if potential_exponent == 1: return 0.0 elif potential_exponent == 0: return 1.0 @@ -104,7 +108,7 @@ def compute( Compute the "electrostatic potential" from the density defined on a discrete mesh. - :param mesh_values: torch.tensor of shape (n_channels, nx, ny, nz) + :param mesh_values: torch.tensor of shape ``(n_channels, nx, ny, nz)`` The values of the density defined on a mesh. :param potential_exponent: int The exponent in the 1/r**p decay of the effective potential, where p=1 @@ -112,10 +116,13 @@ def compute( :param smearing: float Width of the Gaussian smearing (for the Coulomb potential). - :returns: torch.tensor of shape (n_channels, nx, ny, nz) + :returns: torch.tensor of shape ``(n_channels, nx, ny, nz)`` The potential evaluated on the same mesh points as the provided density. """ + if mesh_values.dim() != 4: + raise ValueError("`mesh_values`` needs to be a 4 dimensional tensor") + # Get shape information from mesh n_channels, nx, ny, nz = mesh_values.shape ns = torch.tensor([nx, ny, nz]) diff --git a/src/meshlode/mesh_interpolator.py b/src/meshlode/mesh_interpolator.py index e9daeed7..88796183 100644 --- a/src/meshlode/mesh_interpolator.py +++ b/src/meshlode/mesh_interpolator.py @@ -21,9 +21,9 @@ class MeshInterpolator: of calculations is identical, this is performed in a separate function called "compute_interpolation_weights". - :param cell: torch.tensor of shape (3,3) - cell[i] is the i-th basis vector of the unit cell - :param ns_mesh: list of tuple of size 3 + :param cell: torch.tensor of shape ``(3,3)``, where ``cell[i]`` is the i-th basis + vector of the unit cell + :param ns_mesh: toch.tensor of shape ``(3,)`` Number of mesh points to use along each of the three axes :param interpolation_order: int The degree of the polynomials used for interpolation. A higher order leads @@ -34,6 +34,14 @@ class MeshInterpolator: def __init__( self, cell: torch.Tensor, ns_mesh: torch.Tensor, interpolation_order: int ): + # Check that the provided parameters match the specifications + if cell.shape != (3, 3): + raise ValueError(f"cell of shape {cell.shape} should be of shape (3,3)") + if ns_mesh.shape != (3,): + raise ValueError(f"shape {ns_mesh.shape} of `ns_mesh` has to be (3,)") + if interpolation_order not in [1, 2, 3, 4, 5]: + raise ValueError("Only `interpolation_order` from 1 to 5 are allowed") + self.cell = cell self.ns_mesh = ns_mesh self.interpolation_order = interpolation_order @@ -48,7 +56,7 @@ def __init__( self.y_indices: torch.Tensor = torch.tensor(0) self.z_indices: torch.Tensor = torch.tensor(0) - def compute_1d_weights(self, x: torch.Tensor) -> torch.Tensor: + def _compute_1d_weights(self, x: torch.Tensor) -> torch.Tensor: """ Generate the smooth interpolation weights used to smear the particles onto a mesh. @@ -57,10 +65,10 @@ def compute_1d_weights(self, x: torch.Tensor) -> torch.Tensor: J. Chem. Phys. 109, 7678–7693 (1998) https://doi.org/10.1063/1.477414 - :param x: torch.tensor of shape (n,) + :param x: torch.tensor of shape ``(n,)`` Set of relative positions in the interval [-1/2, 1/2]. - :return: torch.tensor of shape (interpolation_order, n) + :return: torch.tensor of shape ``(interpolation_order, n)`` Interpolation weights """ # Compute weights based on the given order @@ -108,12 +116,16 @@ def compute_interpolation_weights(self, positions: torch.Tensor): """ Compute the interpolation weights of each atom for a given cell (specified during initialization of this class). The weights are not returned, but are used - when calling the forward (points_to_mesh) and backward (mesh_to_points) + when calling the forward (``points_to_mesh``) and backward (``mesh_to_points``) interpolation functions. - :param positions: torch.tensor of shape (N,3) + :param positions: torch.tensor of shape ``(N,3)`` Absolute positions of atoms in Cartesian coordinates """ + n_positions = len(positions) + if positions.shape != (n_positions, 3): + raise ValueError(f"shape {positions.shape} of `positions` has to be (N,3)") + # Compute positions relative to the mesh basis vectors positions_rel = torch.linalg.solve(self.cell.T, positions.T).T positions_rel *= self.ns_mesh @@ -127,7 +139,7 @@ def compute_interpolation_weights(self, positions: torch.Tensor): offsets = positions_rel - positions_rel_idx # Compute weights based on distances and interpolation order - self.interpolation_weights = self.compute_1d_weights(offsets) + self.interpolation_weights = self._compute_1d_weights(offsets) # Calculate indices of mesh points on which # the particle weights are interpolated @@ -169,15 +181,18 @@ def points_to_mesh(self, particle_weights: torch.Tensor) -> torch.Tensor: "compute_interpolation_weights" has been called before to compute all the necessary weights and indices. - :param particle_weights: torch.tensor of shape (n_atoms, n_channels) - particle_weights[i,a] is the "weight" or "charge" that atom i has to + :param particle_weights: torch.tensor of shape ``(n_points, n_channels)`` + ``particle_weights[i,a]`` is the weight (charge) that point (atom) i has to generate the "a-th" potential. In practice, this can be used to compute e.g. the Na and Cl contributions to the potential separately by using a one-hot encoding of the species. - :return: torch.tensor of shape (n_channels, n_mesh, n_mesh, n_mesh) + :return: torch.tensor of shape ``(n_channels, n_mesh, n_mesh, n_mesh)`` Discrete density """ + if particle_weights.dim() != 2: + raise ValueError("`particle_weights` needs to be a tensor of dimension 2") + # Update mesh values by combining particle weights and interpolation weights n_channels = particle_weights.shape[1] nx = int(self.ns_mesh[0]) @@ -203,18 +218,18 @@ def mesh_to_points(self, mesh_vals: torch.Tensor) -> torch.Tensor: Take a function defined on a mesh and interpolate its values on arbitrary positions. - :param mesh_vals: torch.tensor of shape (n_channels, nx, ny, nz) + :param mesh_vals: torch.tensor of shape ``(n_channels, nx, ny, nz)`` The tensor contains the values of a function evaluated on a three-dimensional mesh. (nx, ny, nz) are the number of points along each of the three directions, while n_channels provides the number of such functions that are treated simulateously for the present system. - :param positions: torch.tensor of shape (n_points,3) - Absolute positions of particles in Cartesian coordinates, onto whose - locations we wish to interpolate the mesh values. - :return: interpolated_values: torch.tensor of shape (n_points, n_channels) + :return: interpolated_values: torch.tensor of shape ``(n_points, n_channels)`` Values of the interpolated function. """ + if mesh_vals.dim() != 4: + raise ValueError("`mesh_vals` need to be a tensor of dimension 4") + interpolated_values = ( ( mesh_vals[:, self.x_indices, self.y_indices, self.z_indices] diff --git a/tests/test_calculators.py b/tests/test_calculators.py index 870040e7..3697d782 100644 --- a/tests/test_calculators.py +++ b/tests/test_calculators.py @@ -21,7 +21,7 @@ def toy_system_single_frame() -> System: # Initialize the calculators. For now, only the MeshPotential is implemented. def descriptor() -> MeshPotential: return MeshPotential( - atomic_gaussian_width=1.0, + atomic_smearing=1.0, ) @@ -87,11 +87,11 @@ class TestMultiFrameToySystem: # extreme values. tensormaps_list = [] frames = toy_system_2() - for atomic_gaussian_width in [0.01, 0.3, 3.7]: + for atomic_smearing in [0.01, 0.3, 3.7]: for mesh_spacing in [15.3, 0.19]: for interpolation_order in [1, 2, 3, 4, 5]: MP = MeshPotential( - atomic_gaussian_width=atomic_gaussian_width, + atomic_smearing=atomic_smearing, mesh_spacing=mesh_spacing, interpolation_order=interpolation_order, subtract_self=False, diff --git a/tests/test_madelung.py b/tests/test_madelung.py index 95b12a52..a588c36b 100644 --- a/tests/test_madelung.py +++ b/tests/test_madelung.py @@ -188,7 +188,7 @@ def test_madelung_low_order_metatensor( n_atoms = len(positions) frame = System(species=atomic_numbers, positions=positions, cell=cell) MP = MeshPotential( - atomic_gaussian_width=smearing_eff, + atomic_smearing=smearing_eff, mesh_spacing=mesh_spacing, interpolation_order=interpolation_order, subtract_self=True,