From a46b2de0a019e7baee9f62dd180ffc82d9c30f22 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Thu, 11 Jul 2024 11:59:59 +0200 Subject: [PATCH] Remove internal neighbor calculations --- docs/requirements.txt | 1 - examples/neighborlist_example.py | 1 - pyproject.toml | 3 +- src/meshlode/calculators/base.py | 43 +++----- src/meshlode/calculators/directpotential.py | 6 +- src/meshlode/calculators/ewaldpotential.py | 73 +++++-------- src/meshlode/calculators/pmepotential.py | 108 ++++++++++--------- src/meshlode/metatensor/base.py | 29 ++--- src/meshlode/metatensor/ewaldpotential.py | 47 +------- src/meshlode/metatensor/pmepotential.py | 66 ++++++++++-- tests/calculators/test_values_periodic.py | 78 +++++++++----- tests/calculators/test_workflow.py | 31 +++--- tests/calculators/utils.py | 27 +++++ tests/metatensor/test_base_metatensor.py | 28 +++++ tests/metatensor/test_workflow_metatensor.py | 2 + tests/metatensor/utils_metatensor.py | 53 +++++++++ tox.ini | 5 +- 17 files changed, 367 insertions(+), 234 deletions(-) create mode 100644 tests/calculators/utils.py create mode 100644 tests/metatensor/utils_metatensor.py diff --git a/docs/requirements.txt b/docs/requirements.txt index 778cdd77..44ae3f86 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -3,4 +3,3 @@ sphinx > 7.0 sphinx-gallery sphinx-toggleprompt tomli -chemiscope diff --git a/examples/neighborlist_example.py b/examples/neighborlist_example.py index 5ca0647a..4b13fd0f 100644 --- a/examples/neighborlist_example.py +++ b/examples/neighborlist_example.py @@ -103,7 +103,6 @@ mesh_spacing=mesh_spacing, interpolation_order=interpolation_order, subtract_self=True, - sr_cutoff=sr_cutoff, ) potential = pme.compute(system) diff --git a/pyproject.toml b/pyproject.toml index 3f73d1b9..6909bb6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,12 +35,13 @@ keywords = [ ] dependencies = [ "torch >=1.11", - "ase >= 3.23.0", ] dynamic = ["version"] [project.optional-dependencies] examples = [ + "ase >= 3.23.0", + "chemiscope", "matplotlib", ] metatensor = [ diff --git a/src/meshlode/calculators/base.py b/src/meshlode/calculators/base.py index 0dac2f78..dab8a3a6 100644 --- a/src/meshlode/calculators/base.py +++ b/src/meshlode/calculators/base.py @@ -1,8 +1,6 @@ -from typing import List, Optional, Tuple, Union +from typing import List, Tuple, Union import torch -from ase import Atoms -from ase.neighborlist import neighbor_list from meshlode.lib import InversePowerLawPotential @@ -23,23 +21,12 @@ def _compute_sr( charges: torch.Tensor, cell: torch.Tensor, smearing: float, - sr_cutoff: torch.Tensor, - neighbor_indices: Optional[torch.Tensor] = None, - neighbor_shifts: Optional[torch.Tensor] = None, + neighbor_indices: torch.Tensor, + neighbor_shifts: torch.Tensor, ) -> torch.Tensor: - if neighbor_indices is None or neighbor_shifts is None: - # Get list of neighbors - struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True) - atom_is, atom_js, neighbor_shifts = neighbor_list( - "ijS", struc, sr_cutoff.item(), self_interaction=False - ) - atom_is = torch.tensor(atom_is) - atom_js = torch.tensor(atom_js) - shifts = torch.tensor(neighbor_shifts, dtype=cell.dtype) # N x 3 - else: - atom_is = neighbor_indices[0] - atom_js = neighbor_indices[1] - shifts = neighbor_shifts.type(cell.dtype).T + atom_is = neighbor_indices[0] + atom_js = neighbor_indices[1] + shifts = neighbor_shifts.type(cell.dtype) # Compute energy potential = torch.zeros_like(charges) @@ -65,15 +52,9 @@ def _compute_sr( class CalculatorBaseTorch(torch.nn.Module): - """ - Base calculator for the torch interface to MeshLODE. - - :param exponent: the exponent :math:`p` in :math:`1/r^p` potentials - """ + """Base calculator for the torch interface to MeshLODE.""" - def __init__( - self, - ): + def __init__(self): super().__init__() def _validate_compute_parameters( @@ -270,7 +251,7 @@ def _validate_compute_parameters( f"device {neighbor_shifts_single.device}" ) - return positions, cell, charges, neighbor_indices, neighbor_shifts + return positions, charges, cell, neighbor_indices, neighbor_shifts def _compute_impl( self, @@ -286,8 +267,8 @@ def _compute_impl( # more general case ( positions, - cell, charges, + cell, neighbor_indices, neighbor_shifts, ) = self._validate_compute_parameters( @@ -302,11 +283,11 @@ def _compute_impl( potentials = [] for ( positions_single, - cell_single, charges_single, + cell_single, neighbor_indices_single, neighbor_shifts_single, - ) in zip(positions, cell, charges, neighbor_indices, neighbor_shifts): + ) in zip(positions, charges, cell, neighbor_indices, neighbor_shifts): # Compute the potentials potentials.append( self._compute_single_system( diff --git a/src/meshlode/calculators/directpotential.py b/src/meshlode/calculators/directpotential.py index 802f795d..1502d828 100644 --- a/src/meshlode/calculators/directpotential.py +++ b/src/meshlode/calculators/directpotential.py @@ -12,10 +12,10 @@ def __init__(self, exponent): def _compute_single_system( self, positions: torch.Tensor, - cell: Union[None, torch.Tensor], + cell: None, charges: torch.Tensor, - neighbor_indices: Union[None, torch.Tensor], - neighbor_shifts: Union[None, torch.Tensor], + neighbor_indices: None, + neighbor_shifts: None, ) -> torch.Tensor: # Compute matrix containing the squared distances from the Gram matrix # The squared distance and the inner product between two vectors r_i and r_j are diff --git a/src/meshlode/calculators/ewaldpotential.py b/src/meshlode/calculators/ewaldpotential.py index ef6e6640..1aad2d8e 100644 --- a/src/meshlode/calculators/ewaldpotential.py +++ b/src/meshlode/calculators/ewaldpotential.py @@ -10,7 +10,6 @@ class _EwaldPotentialImpl(_ShortRange): def __init__( self, exponent: float, - sr_cutoff: Union[None, torch.Tensor], atomic_smearing: Union[None, float], lr_wavelength: Union[None, float], subtract_self: bool, @@ -25,7 +24,6 @@ def __init__( self, exponent=exponent, subtract_interior=subtract_interior ) self.atomic_smearing = atomic_smearing - self.sr_cutoff = sr_cutoff self.lr_wavelength = lr_wavelength # If interior contributions are to be subtracted, also do so for self term @@ -38,26 +36,22 @@ def _compute_single_system( positions: torch.Tensor, charges: torch.Tensor, cell: torch.Tensor, - neighbor_indices: Optional[torch.Tensor] = None, - neighbor_shifts: Optional[torch.Tensor] = None, + neighbor_indices: torch.Tensor, + neighbor_shifts: torch.Tensor, ) -> torch.Tensor: # Set the defaut values of convergence parameters # The total computational cost = cost of SR part + cost of LR part # Bigger smearing increases the cost of the SR part while decreasing the cost # of the LR part. Since the latter usually is more expensive, we maximize the # value of the smearing by default to minimize the cost of the LR part. - # The two auxilary parameters (sr_cutoff, lr_wavelength) then control the + # The auxilary parameter lr_wavelength then control the # convergence of the SR and LR sums, respectively. The default values are # chosen to reach a convergence on the order of 1e-4 to 1e-5 for the test # structures. - if self.sr_cutoff is None: - cell_dimensions = torch.linalg.norm(cell, dim=1) - sr_cutoff = torch.min(cell_dimensions) / 2 - 1e-6 - else: - sr_cutoff = self.sr_cutoff - if self.atomic_smearing is None: - smearing = sr_cutoff / 5.0 + cell_dimensions = torch.linalg.norm(cell, dim=1) + max_cutoff = torch.min(cell_dimensions) / 2 - 1e-6 + smearing = max_cutoff / 5.0 else: smearing = self.atomic_smearing @@ -72,7 +66,6 @@ def _compute_single_system( charges=charges, cell=cell, smearing=smearing, - sr_cutoff=sr_cutoff, neighbor_indices=neighbor_indices, neighbor_shifts=neighbor_shifts, ) @@ -170,15 +163,21 @@ class EwaldPotential(CalculatorBaseTorch, _EwaldPotentialImpl): Scaling as :math:`\mathcal{O}(N^2)` with respect to the number of particles :math:`N`. + For computing a **neighborlist** a reasonable ``cutoff`` is half the length of + the shortest cell vector, which can be for example computed according as + + .. code-block:: python + + cell_dimensions = torch.linalg.norm(cell, dim=1) + cutoff = torch.min(cell_dimensions) / 2 - 1e-6 + + This ensures a accuracy of the short range part of ``1e-5``. + :param exponent: the exponent :math:`p` in :math:`1/r^p` potentials - :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. If - not set to a global value, it will be set to be half of the shortest lattice - vector defining the cell (separately for each structure). :param atomic_smearing: Width of the atom-centered Gaussian used to split the Coulomb potential into the short- and long-range parts. If not set to a global - value, it will be set to 1/5 times the sr_cutoff value (separately for each - structure) to ensure convergence of the short-range part to a relative precision - of 1e-5. + value, it will be set to 1/5 times of half the larges box vector (separately for + each structure). :param lr_wavelength: Spatial resolution used for the long-range (reciprocal space) part of the Ewald sum. More conretely, all Fourier space vectors with a wavelength >= this value will be kept. If not set to a global value, it will be @@ -192,33 +191,12 @@ class EwaldPotential(CalculatorBaseTorch, _EwaldPotentialImpl): Note that if set to true, the self contribution (see previous) is also subtracted by default. - Example - ------- - We calculate the Madelung constant of a CsCl (Cesium-Chloride) crystal. The - reference value is :math:`2 \cdot 1.7626 / \sqrt{3} \approx 2.0354`. - - >>> import torch - - Define crystal structure - - >>> positions = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]) - >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) - >>> cell = torch.eye(3) - - Compute the potential - - >>> ewald = EwaldPotential() - >>> ewald.compute(positions=positions, charges=charges, cell=cell) - tensor([[-2.0354], - [ 2.0354]]) - - Which is the same as the reference value given above. + For an **example** on the usage refer to :py:class:`PMEPotential`. """ def __init__( self, exponent: float = 1.0, - sr_cutoff: Optional[torch.Tensor] = None, atomic_smearing: Optional[float] = None, lr_wavelength: Optional[float] = None, subtract_self: bool = True, @@ -227,7 +205,6 @@ def __init__( _EwaldPotentialImpl.__init__( self, exponent=exponent, - sr_cutoff=sr_cutoff, atomic_smearing=atomic_smearing, lr_wavelength=lr_wavelength, subtract_self=subtract_self, @@ -240,14 +217,22 @@ def compute( positions: Union[List[torch.Tensor], torch.Tensor], charges: Union[List[torch.Tensor], torch.Tensor], cell: Union[List[torch.Tensor], torch.Tensor], - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor], + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor], ) -> Union[torch.Tensor, List[torch.Tensor]]: """Compute potential for all provided "systems" stacked inside list. The computation is performed on the same ``device`` as ``dtype`` is the input is stored on. The ``dtype`` of the output tensors will be the same as the input. + For computing a **neighborlist** a reasonable ``cutoff`` is half the length of + the shortest cell vector, which can be for example computed according as + + .. code-block:: python + + cell_dimensions = torch.linalg.norm(cell, dim=1) + cutoff = torch.min(cell_dimensions) / 2 - 1e-6 + :param positions: Single or 2D tensor of shape (``len(charges), 3``) containing the Cartesian positions of all point charges in the system. :param charges: Single 2D tensor or list of 2D tensor of shape (``n_channels, diff --git a/src/meshlode/calculators/pmepotential.py b/src/meshlode/calculators/pmepotential.py index 4de06b47..6bcb53aa 100644 --- a/src/meshlode/calculators/pmepotential.py +++ b/src/meshlode/calculators/pmepotential.py @@ -11,7 +11,6 @@ class _PMEPotentialImpl(_ShortRange): def __init__( self, exponent: float, - sr_cutoff: Union[None, torch.Tensor], atomic_smearing: Union[None, float], mesh_spacing: Union[None, float], interpolation_order: int, @@ -32,50 +31,31 @@ def __init__( self.atomic_smearing = atomic_smearing self.mesh_spacing = mesh_spacing self.interpolation_order = interpolation_order - self.sr_cutoff = sr_cutoff # If interior contributions are to be subtracted, also do so for self term if self.subtract_interior: subtract_self = True self.subtract_self = subtract_self - self.atomic_smearing = atomic_smearing - self.mesh_spacing = mesh_spacing - self.interpolation_order = interpolation_order - self.sr_cutoff = sr_cutoff - - # If interior contributions are to be subtracted, also do so for self term - if subtract_interior: - subtract_self = True - self.subtract_self = subtract_self - self.subtract_interior = subtract_interior - def _compute_single_system( self, positions: torch.Tensor, charges: torch.Tensor, cell: torch.Tensor, - neighbor_indices: Optional[torch.Tensor] = None, - neighbor_shifts: Optional[torch.Tensor] = None, + neighbor_indices: torch.Tensor, + neighbor_shifts: torch.Tensor, ) -> torch.Tensor: - # Set the defaut values of convergence parameters - # The total computational cost = cost of SR part + cost of LR part - # Bigger smearing increases the cost of the SR part while decreasing the cost - # of the LR part. Since the latter usually is more expensive, we maximize the - # value of the smearing by default to minimize the cost of the LR part. - # The two auxilary parameters (sr_cutoff, lr_wavelength) then control the - # convergence of the SR and LR sums, respectively. The default values are - # chosen to reach a convergence on the order of 1e-4 to 1e-5 for the test - # structures. - if self.sr_cutoff is None: - cell_dimensions = torch.linalg.norm(cell, dim=1) - cutoff_max = torch.min(cell_dimensions) / 2 - 1e-6 - sr_cutoff = cutoff_max - else: - sr_cutoff = self.sr_cutoff - + # Set the defaut values of convergence parameters The total computational cost = + # cost of SR part + cost of LR part Bigger smearing increases the cost of the SR + # part while decreasing the cost of the LR part. Since the latter usually is + # more expensive, we maximize the value of the smearing by default to minimize + # the cost of the LR part. The auxilary parameter mesh_spacing then controls the + # convergence of the SR and LR sums, respectively. The default values are chosen + # to reach a convergence on the order of 1e-4 to 1e-5 for the test structures. if self.atomic_smearing is None: - smearing = sr_cutoff / 5.0 + cell_dimensions = torch.linalg.norm(cell, dim=1) + max_cutoff = torch.min(cell_dimensions) / 2 - 1e-6 + smearing = max_cutoff / 5.0 else: smearing = self.atomic_smearing @@ -90,7 +70,6 @@ def _compute_single_system( charges=charges, cell=cell, smearing=smearing, - sr_cutoff=sr_cutoff, neighbor_indices=neighbor_indices, neighbor_shifts=neighbor_shifts, ) @@ -170,20 +149,21 @@ class PMEPotential(CalculatorBaseTorch, _PMEPotentialImpl): Scaling as :math:`\mathcal{O}(NlogN)` with respect to the number of particles :math:`N` used as a reference to test faster implementations. - :param all_types: Optional global list of all atomic types that should be considered - for the computation. This option might be useful when running the calculation on - subset of a whole dataset and it required to keep the shape of the output - consistent. If this is not set the possible atomic types will be determined when - calling the :meth:`compute()`. + For computing a **neighborlist** a reasonable ``cutoff`` is half the length of + the shortest cell vector, which can be for example computed according as + + .. code-block:: python + + cell_dimensions = torch.linalg.norm(cell, dim=1) + cutoff = torch.min(cell_dimensions) / 2 - 1e-6 + + This ensures a accuracy of the short range part of ``1e-5``. + :param exponent: the exponent :math:`p` in :math:`1/r^p` potentials - :param sr_cutoff: Cutoff radius used for the short-range part of the Ewald sum. If - not set to a global value, it will be set to be half of the shortest lattice - vector defining the cell (separately for each structure). :param atomic_smearing: Width of the atom-centered Gaussian used to split the Coulomb potential into the short- and long-range parts. If not set to a global - value, it will be set to 1/5 times the sr_cutoff value (separately for each - structure) to ensure convergence of the short-range part to a relative precision - of 1e-5. + value, it will be set to 1/5 times of half the larges box vector (separately for + each structure). :param mesh_spacing: Value that determines the umber of Fourier-space grid points that will be used along each axis. If set to None, it will automatically be set to half of ``atomic_smearing``. @@ -204,6 +184,7 @@ class PMEPotential(CalculatorBaseTorch, _PMEPotentialImpl): reference value is :math:`2 \cdot 1.7626 / \sqrt{3} \approx 2.0354`. >>> import torch + >>> from vesin import NeighborList Define crystal structure @@ -211,10 +192,39 @@ class PMEPotential(CalculatorBaseTorch, _PMEPotentialImpl): >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) >>> cell = torch.eye(3) - Compute the potential + Compute the neighbor indices (``"i"``, ``"j"``) and the neighbor shifts ("``S``") + using the ``vesin`` package. Refer to the `documentation + `_ for details on the API. Similarly you can also use + ``ase``'s :py:func:`neighbor_list `. + + >>> cell_dimensions = torch.linalg.norm(cell, dim=1) + >>> cutoff = torch.min(cell_dimensions) / 2 - 1e-6 + >>> nl = NeighborList(cutoff=cutoff, full_list=True) + >>> i, j, S = nl.compute( + ... points=positions, box=cell, periodic=True, quantities="ijS" + ... ) + + The ``vesin`` calculator returned the indices and the neighbor shifts. We know stack + the together and convert them into the suitable types + + >>> i = torch.from_numpy(i.astype(int)) + >>> j = torch.from_numpy(j.astype(int)) + >>> neighbor_indices = torch.vstack([i, j]) + >>> neighbor_shifts = torch.from_numpy(S.astype(int)) + + If you inspect the neighborlist you will notice that they are empty for the given + system, which means the the whole potential will be calculated using the long range + part of the potential. Finally, we initlize the potential class and ``compute`` the + potential for the crystal >>> pme = PMEPotential() - >>> pme.compute(positions=positions, charges=charges, cell=cell) + >>> pme.compute( + ... positions=positions, + ... charges=charges, + ... cell=cell, + ... neighbor_indices=neighbor_indices, + ... neighbor_shifts=neighbor_shifts, + ... ) tensor([[-2.0384], [ 2.0384]]) @@ -224,7 +234,6 @@ class PMEPotential(CalculatorBaseTorch, _PMEPotentialImpl): def __init__( self, exponent: float = 1.0, - sr_cutoff: Optional[torch.Tensor] = None, atomic_smearing: Optional[float] = None, mesh_spacing: Optional[float] = None, interpolation_order: int = 3, @@ -234,7 +243,6 @@ def __init__( _PMEPotentialImpl.__init__( self, exponent=exponent, - sr_cutoff=sr_cutoff, atomic_smearing=atomic_smearing, mesh_spacing=mesh_spacing, interpolation_order=interpolation_order, @@ -248,8 +256,8 @@ def compute( positions: Union[List[torch.Tensor], torch.Tensor], charges: Union[List[torch.Tensor], torch.Tensor], cell: Union[List[torch.Tensor], torch.Tensor], - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor], + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor], ) -> Union[torch.Tensor, List[torch.Tensor]]: """Compute potential for all provided "systems" stacked inside list. diff --git a/src/meshlode/metatensor/base.py b/src/meshlode/metatensor/base.py index 0bd3280d..524eca95 100644 --- a/src/meshlode/metatensor/base.py +++ b/src/meshlode/metatensor/base.py @@ -92,19 +92,22 @@ def compute(self, systems: Union[List[System], System]) -> TensorMap: charges = system.get_data("charges").values # try to extract neighbor list from system object - neighbor_indices = None - neighbor_shifts = None - for neighbor_list_options in system.known_neighbor_lists(): - if ( - hasattr(self, "sr_cutoff") - and neighbor_list_options.cutoff == self.sr_cutoff - ): - neighbor_list = system.get_neighbor_list(neighbor_list_options) - - neighbor_indices = neighbor_list.samples.values[:, :2].T - neighbor_shifts = neighbor_list.samples.values[:, 2:].T - - break + all_neighbor_list_options = system.known_neighbor_lists() + if all_neighbor_list_options: + if len(system.known_neighbor_lists()) > 1: + warnings.warn( + "Multiple neighbor lists found " + f"({len(system.known_neighbor_lists())}). Using the first one.", + stacklevel=2, + ) + + neighbor_list = system.get_neighbor_list(all_neighbor_list_options[0]) + + neighbor_indices = neighbor_list.samples.values[:, :2].T + neighbor_shifts = neighbor_list.samples.values[:, 2:] + else: + neighbor_indices = None + neighbor_shifts = None potentials.append( self._compute_single_system( diff --git a/src/meshlode/metatensor/ewaldpotential.py b/src/meshlode/metatensor/ewaldpotential.py index f16efcee..1e9ec82a 100644 --- a/src/meshlode/metatensor/ewaldpotential.py +++ b/src/meshlode/metatensor/ewaldpotential.py @@ -1,7 +1,5 @@ from typing import Optional -import torch - from ..calculators.ewaldpotential import _EwaldPotentialImpl from .base import CalculatorBaseMetatensor @@ -11,53 +9,13 @@ class EwaldPotential(CalculatorBaseMetatensor, _EwaldPotentialImpl): Refer to :class:`meshlode.EwaldPotential` for parameter documentation. - Example - ------- - We calculate the Madelung constant of a CsCl (Cesium-Chloride) crystal. The - reference value is :math:`2 \cdot 1.7626 / \sqrt{3} \approx 2.0354`. - - >>> import torch - >>> from metatensor.torch import Labels, TensorBlock - >>> from metatensor.torch.atomistic import System - - Define simple example structure - - >>> system = System( - ... types=torch.tensor([55, 17]), - ... positions=torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]), - ... cell=torch.eye(3), - ... ) - - Next we attach the charges to our ``system`` - - >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) - >>> data = TensorBlock( - ... values=charges, - ... samples=Labels.range("atom", charges.shape[0]), - ... components=[], - ... properties=Labels.range("charge", charges.shape[1]), - ... ) - >>> system.add_data(name="charges", data=data) - - and compute the potenial - - >>> ewald = EwaldPotential() - >>> potential = ewald.compute(system) - - The results are stored inside the ``values`` property inside the first - :py:class:`TensorBlock ` of the ``potential``. - - >>> potential[0].values - tensor([[-2.0354], - [ 2.0354]]) - - Which is the same as the reference value given above. + For an **example** on the usage refer to :py:class:`metatensor.PMEPotential + `. """ def __init__( self, exponent: float = 1.0, - sr_cutoff: Optional[torch.Tensor] = None, atomic_smearing: Optional[float] = None, lr_wavelength: Optional[float] = None, subtract_self: bool = True, @@ -66,7 +24,6 @@ def __init__( _EwaldPotentialImpl.__init__( self, exponent=exponent, - sr_cutoff=sr_cutoff, atomic_smearing=atomic_smearing, lr_wavelength=lr_wavelength, subtract_self=subtract_self, diff --git a/src/meshlode/metatensor/pmepotential.py b/src/meshlode/metatensor/pmepotential.py index af4da3ad..2ee10dfa 100644 --- a/src/meshlode/metatensor/pmepotential.py +++ b/src/meshlode/metatensor/pmepotential.py @@ -1,7 +1,5 @@ from typing import Optional -import torch - from ..calculators.pmepotential import _PMEPotentialImpl from .base import CalculatorBaseMetatensor @@ -18,7 +16,8 @@ class PMEPotential(CalculatorBaseMetatensor, _PMEPotentialImpl): >>> import torch >>> from metatensor.torch import Labels, TensorBlock - >>> from metatensor.torch.atomistic import System + >>> from metatensor.torch.atomistic import System, NeighborListOptions + >>> from vesin import NeighborList Define simple example structure @@ -28,7 +27,7 @@ class PMEPotential(CalculatorBaseMetatensor, _PMEPotentialImpl): ... cell=torch.eye(3), ... ) - Next we attach the charges to our ``system`` + Next, we attach the charges to our ``system`` >>> charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) >>> data = TensorBlock( @@ -39,7 +38,62 @@ class PMEPotential(CalculatorBaseMetatensor, _PMEPotentialImpl): ... ) >>> system.add_data(name="charges", data=data) - and compute the potenial + Compute the neighbor indices (``"i"``, ``"j"``) and the neighbor shifts ("``S``") + using the ``vesin`` package. Refer to the `documentation + `_ for details on the API. Similarly you can also use + ``ase``'s :py:func:`neighbor_list `. + + >>> cell_dimensions = torch.linalg.norm(system.cell, dim=1) + >>> cutoff = torch.min(cell_dimensions) / 2 - 1e-6 + >>> nl = NeighborList(cutoff=cutoff, full_list=True) + >>> i, j, S, D = nl.compute( + ... points=system.positions, box=system.cell, periodic=True, quantities="ijSD" + ... ) + + The ``vesin`` calculator returned the indices and the neighbor shifts. We know stack + the together and convert them into the suitable types + + >>> i = torch.from_numpy(i.astype(int)) + >>> j = torch.from_numpy(j.astype(int)) + >>> neighbor_indices = torch.vstack([i, j]) + >>> neighbor_shifts = torch.from_numpy(S.astype(int)) + + If you inspect the neighborlist you will notice that they are empty for the given + system, which means the the whole potential will be calculated using the long range + part of the potential. + + We now attach the neighbor list to the above defined ``system`` object. For this we + first create the ``samples`` metatadata for the :py:class:`TensorBlock + ` which will hold the neighbor list. + + >>> sample_values = torch.hstack([neighbor_indices.T, neighbor_shifts]) + >>> samples = Labels( + ... names=[ + ... "first_atom", + ... "second_atom", + ... "cell_shift_a", + ... "cell_shift_b", + ... "cell_shift_c", + ... ], + ... values=sample_values, + ... ) + + And wrap everything together and add it to our ``system``. + + >>> values = torch.from_numpy(D).reshape(-1, 3, 1) + >>> values = values.type(system.positions.dtype) + >>> neighbors = TensorBlock( + ... values=values, + ... samples=samples, + ... components=[Labels.range("xyz", 3)], + ... properties=Labels.range("distance", 1), + ... ) + >>> nl_options = NeighborListOptions(cutoff=cutoff, full_list=True) + >>> system.add_neighbor_list(options=nl_options, neighbors=neighbors) + + + Finally, we initlize the potential class and ``compute`` the + potential for the crystal >>> pme = PMEPotential() >>> potential = pme.compute(system) @@ -57,7 +111,6 @@ class PMEPotential(CalculatorBaseMetatensor, _PMEPotentialImpl): def __init__( self, exponent: float = 1.0, - sr_cutoff: Optional[torch.Tensor] = None, atomic_smearing: Optional[float] = None, mesh_spacing: Optional[float] = None, interpolation_order: int = 3, @@ -67,7 +120,6 @@ def __init__( _PMEPotentialImpl.__init__( self, exponent=exponent, - sr_cutoff=sr_cutoff, atomic_smearing=atomic_smearing, mesh_spacing=mesh_spacing, interpolation_order=interpolation_order, diff --git a/tests/calculators/test_values_periodic.py b/tests/calculators/test_values_periodic.py index 7f9245d2..41b00467 100644 --- a/tests/calculators/test_values_periodic.py +++ b/tests/calculators/test_values_periodic.py @@ -7,6 +7,7 @@ # Imports for random structure from ase.io import read +from utils import neighbor_list_torch from meshlode import EwaldPotential, PMEPotential @@ -57,7 +58,6 @@ def define_crystal(crystal_name="CsCl"): # - 1 atom pair in the unit cell # - Cation-Anion ratio of 1:1 if crystal_name == "CsCl": - types = torch.tensor([17, 55]) # Cl and Cs positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]], dtype=dtype) charges = torch.tensor([-1.0, 1.0], dtype=dtype) cell = torch.eye(3, dtype=dtype) @@ -69,7 +69,6 @@ def define_crystal(crystal_name="CsCl"): # - 1 atom pair in the unit cell # - Cation-Anion ratio of 1:1 elif crystal_name == "NaCl_primitive": - types = torch.tensor([11, 17]) positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=dtype) charges = torch.tensor([1.0, -1.0], dtype=dtype) cell = torch.tensor([[0, 1.0, 1], [1, 0, 1], [1, 1, 0]], dtype=dtype) # fcc @@ -81,7 +80,6 @@ def define_crystal(crystal_name="CsCl"): # - 4 atom pairs in the unit cell # - Cation-Anion ratio of 1:1 elif crystal_name == "NaCl_cubic": - types = torch.tensor([11, 17, 17, 17, 11, 11, 11, 17]) positions = torch.tensor( [ [0.0, 0, 0], @@ -107,7 +105,6 @@ def define_crystal(crystal_name="CsCl"): # Remarks: we use a primitive unit cell which makes the lattice parameter of the # cubic cell equal to 2. elif crystal_name == "zincblende": - types = torch.tensor([16, 30]) positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]], dtype=dtype) charges = torch.tensor([1.0, -1], dtype=dtype) cell = torch.tensor([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=dtype) @@ -121,7 +118,6 @@ def define_crystal(crystal_name="CsCl"): elif crystal_name == "wurtzite": u = 3 / 8 c = np.sqrt(1 / u) - types = torch.tensor([16, 30, 16, 30]) positions = torch.tensor( [ [0.5, 0.5 / np.sqrt(3), 0.0], @@ -146,7 +142,6 @@ def define_crystal(crystal_name="CsCl"): elif crystal_name == "fluorite": a = 5.463 a = 1.0 - types = torch.tensor([9, 9, 20]) positions = a * torch.tensor( [[1 / 4, 1 / 4, 1 / 4], [3 / 4, 3 / 4, 3 / 4], [0, 0, 0]], dtype=dtype ) @@ -161,7 +156,6 @@ def define_crystal(crystal_name="CsCl"): # - Cation-Anion ratio of 2:1 elif crystal_name == "cu2o": a = 1.0 - types = torch.tensor([8, 8, 29, 29, 29, 29]) positions = a * torch.tensor( [ [0, 0, 0], @@ -188,7 +182,6 @@ def define_crystal(crystal_name="CsCl"): # Wigner crystal energies are taken from "Zero-Point Energy of an Electron Lattice" # by Rosemary A., Coldwell‐Horsfall and Alexei A. Maradudin (1960), eq. (A21). elif crystal_name == "wigner_sc": - types = torch.tensor([1]) positions = torch.tensor([[0, 0, 0]], dtype=dtype) charges = torch.tensor([1.0], dtype=dtype) cell = torch.tensor([[1.0, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype) @@ -204,7 +197,6 @@ def define_crystal(crystal_name="CsCl"): # See description of "wigner_sc" for a general explanation on Wigner crystals. # Used to test the code for cases in which the unit cell has a nonzero net charge. elif crystal_name == "wigner_bcc": - types = torch.tensor([1]) positions = torch.tensor([[0, 0, 0]], dtype=dtype) charges = torch.tensor([1.0], dtype=dtype) cell = torch.tensor( @@ -222,7 +214,6 @@ def define_crystal(crystal_name="CsCl"): # Same as above, but now using a cubic unit cell rather than the primitive bcc cell elif crystal_name == "wigner_bcc_cubiccell": - types = torch.tensor([1, 1]) positions = torch.tensor([[0, 0, 0], [1 / 2, 1 / 2, 1 / 2]], dtype=dtype) charges = torch.tensor([1.0, 1.0], dtype=dtype) cell = torch.tensor([[1.0, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype) @@ -240,7 +231,6 @@ def define_crystal(crystal_name="CsCl"): # See description of "wigner_sc" for a general explanation on Wigner crystals. # Used to test the code for cases in which the unit cell has a nonzero net charge. elif crystal_name == "wigner_fcc": - types = torch.tensor([1]) positions = torch.tensor([[0, 0, 0]], dtype=dtype) charges = torch.tensor([1.0], dtype=dtype) cell = torch.tensor([[1, 0, 1], [0, 1, 1], [1, 1, 0]], dtype=dtype) / 2 @@ -256,7 +246,6 @@ def define_crystal(crystal_name="CsCl"): # Same as above, but now using a cubic unit cell rather than the primitive fcc cell elif crystal_name == "wigner_fcc_cubiccell": - types = torch.tensor([1, 1, 1, 1]) positions = 0.5 * torch.tensor( [[0.0, 0, 0], [1, 0, 1], [1, 1, 0], [0, 1, 1]], dtype=dtype ) @@ -277,7 +266,8 @@ def define_crystal(crystal_name="CsCl"): madelung_ref = torch.tensor(madelung_ref, dtype=dtype) charges = charges.reshape((-1, 1)) - return types, positions, charges, cell, madelung_ref, num_formula_units + + return positions, charges, cell, madelung_ref, num_formula_units scaling_factors = torch.tensor([1 / 2.0353610, 1.0, 3.4951291], dtype=torch.float64) @@ -299,7 +289,7 @@ def test_madelung(crystal_name, scaling_factor, calc_name): to triclinic, as well as cation-anion ratios of 1:1, 1:2 and 2:1. """ # Get input parameters and adjust to account for scaling - types, pos, charges, cell, madelung_ref, num_units = define_crystal(crystal_name) + pos, charges, cell, madelung_ref, num_units = define_crystal(crystal_name) pos *= scaling_factor cell *= scaling_factor madelung_ref /= scaling_factor @@ -308,15 +298,28 @@ def test_madelung(crystal_name, scaling_factor, calc_name): # Define calculator and tolerances if calc_name == "ewald": sr_cutoff = scaling_factor * torch.tensor(1.0, dtype=dtype) - calc = EwaldPotential(sr_cutoff=sr_cutoff) + atomic_smearing = sr_cutoff / 5.0 + calc = EwaldPotential(atomic_smearing=atomic_smearing) rtol = 4e-6 elif calc_name == "pme": sr_cutoff = scaling_factor * torch.tensor(2.0, dtype=dtype) - calc = PMEPotential(sr_cutoff=sr_cutoff) + atomic_smearing = sr_cutoff / 5.0 + calc = PMEPotential(atomic_smearing=atomic_smearing) rtol = 9e-4 + # Compute neighbor list + neighbor_indices, neighbor_shifts = neighbor_list_torch( + positions=pos, cell=cell, cutoff=sr_cutoff.item() + ) + # Compute potential and compare against target value using default hypers - potentials = calc.compute(positions=pos, charges=charges, cell=cell) + potentials = calc.compute( + positions=pos, + charges=charges, + cell=cell, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) energies = potentials * charges madelung = -torch.sum(energies) / 2 / num_units @@ -340,7 +343,7 @@ def test_madelung(crystal_name, scaling_factor, calc_name): @pytest.mark.parametrize("scaling_factor", scaling_factors) def test_wigner(crystal_name, scaling_factor): """ - Check that the energy of a Wigner solid obtained from the Ewald sum calculator + Check that the energy of a Wigner solid obtained from the Ewald sum calculator matches the reference values. In this test, the Wigner solids are defined by placing arranging positively charged point particles on a bcc lattice, leading to a net charge of the unit cell if we @@ -351,11 +354,16 @@ def test_wigner(crystal_name, scaling_factor): to numerically slower convergence of the relevant sums. """ # Get parameters defining atomic positions, cell and charges - types, positions, charges, cell, madelung_ref, num = define_crystal(crystal_name) + positions, charges, cell, madelung_ref, _ = define_crystal(crystal_name) positions *= scaling_factor cell *= scaling_factor madelung_ref /= scaling_factor + # Compute neighbor list + neighbor_indices, neighbor_shifts = neighbor_list_torch( + positions=positions, cell=cell + ) + # Due to the slow convergence, we do not use the default values of the smearing, # but provide a range instead. The first value of 0.1 corresponds to what would be # chosen by default for the "wigner_sc" or "wigner_bcc_cubiccell" structure. @@ -373,7 +381,13 @@ def test_wigner(crystal_name, scaling_factor): # Compute potential and compare against reference calc = EwaldPotential(atomic_smearing=smeareff) - potentials = calc.compute(positions=positions, charges=charges, cell=cell) + potentials = calc.compute( + positions=positions, + charges=charges, + cell=cell, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) energies = potentials * charges energies_ref = -torch.ones_like(energies) * madelung_ref torch.testing.assert_close(energies, energies_ref, atol=0.0, rtol=rtol) @@ -415,22 +429,36 @@ def test_random_structure(sr_cutoff, frame_index, scaling_factor, ortho, calc_na # Convert into input format suitable for MeshLODE positions = scaling_factor * (torch.tensor(frame.positions, dtype=dtype) @ ortho) - positions.requires_grad = True cell = scaling_factor * torch.tensor(np.array(frame.cell), dtype=dtype) @ ortho charges = torch.tensor([1, 1, 1, 1, -1, -1, -1, -1], dtype=dtype).reshape((-1, 1)) + sr_cutoff = scaling_factor * sr_cutoff + atomic_smearing = sr_cutoff / 5.0 + + # Compute neighbor list + neighbor_indices, neighbor_shifts = neighbor_list_torch( + positions=positions, cell=cell, cutoff=sr_cutoff.item() + ) + + # Enable backward for positions + positions.requires_grad = True # Compute potential using MeshLODE and compare against reference values - sr_cutoff = scaling_factor * torch.tensor(sr_cutoff, dtype=dtype) if calc_name == "ewald": - calc = EwaldPotential(sr_cutoff=sr_cutoff) + calc = EwaldPotential(atomic_smearing=atomic_smearing) rtol_e = 2e-5 rtol_f = 3.6e-3 elif calc_name == "pme": - calc = PMEPotential(sr_cutoff=sr_cutoff) + calc = PMEPotential(atomic_smearing=atomic_smearing) rtol_e = 4.5e-3 # 1.5e-3 rtol_f = 2.5e-3 # 6e-3 - potentials = calc.compute(positions=positions, charges=charges, cell=cell) + potentials = calc.compute( + positions=positions, + charges=charges, + cell=cell, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) # Compute energy, taking into account the double counting of each pair energy = torch.sum(potentials * charges) / 2 diff --git a/tests/calculators/test_workflow.py b/tests/calculators/test_workflow.py index 26985820..55a7c4e9 100644 --- a/tests/calculators/test_workflow.py +++ b/tests/calculators/test_workflow.py @@ -6,6 +6,7 @@ import pytest import torch from torch.testing import assert_close +from utils import neighbor_list_torch from meshlode import DirectPotential, EwaldPotential, PMEPotential @@ -53,7 +54,11 @@ def cscl_system(self, periodic): charges = torch.tensor([1.0, -1.0]).reshape((-1, 1)) if periodic: cell = torch.eye(3) - return positions, charges, cell + + neighbor_indices, neighbor_shifts = neighbor_list_torch( + positions=positions, cell=cell + ) + return positions, charges, cell, neighbor_indices, neighbor_shifts else: return positions, charges @@ -86,11 +91,15 @@ def test_interpolation_order_error(self, CalculatorClass, params, periodic): def test_multi_frame(self, CalculatorClass, periodic, params): calculator = self.calculator(CalculatorClass, periodic, params) if periodic: - positions, charges, cell = self.cscl_system(periodic) + positions, charges, cell, neighbor_indices, neighbor_shifts = ( + self.cscl_system(periodic) + ) l_values = calculator.compute( positions=[positions, positions], cell=[cell, cell], charges=[charges, charges], + neighbor_indices=[neighbor_indices, neighbor_indices], + neighbor_shifts=[neighbor_shifts, neighbor_shifts], ) else: positions, charges = self.cscl_system(periodic) @@ -117,8 +126,14 @@ def test_dtype_device(self, CalculatorClass, periodic, params): calculator = self.calculator(CalculatorClass, periodic, params) if periodic: cell = torch.eye(3, dtype=dtype, device=device) + neighbor_indices = torch.tensor([0, 0]).reshape(-1, 1) + neighbor_shifts = torch.tensor([0, 0, 0]).reshape(1, -1) potential = calculator.compute( - positions=positions, charges=charges, cell=cell + positions=positions, + charges=charges, + cell=cell, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, ) else: potential = calculator.compute(positions=positions, charges=charges) @@ -130,15 +145,7 @@ def test_dtype_device(self, CalculatorClass, periodic, params): # and returns the correct output format (torch.Tensor) def check_operation(self, CalculatorClass, periodic, params): calculator = self.calculator(CalculatorClass, periodic, params) - - if periodic: - positions, charges, cell = self.cscl_system(periodic) - descriptor = calculator.compute( - positions=positions, charges=charges, cell=cell - ) - else: - positions, charges = self.cscl_system(periodic) - descriptor = calculator.compute(positions=positions, charges=charges) + descriptor = calculator.compute(*self.cscl_system(periodic)) assert type(descriptor) is torch.Tensor diff --git a/tests/calculators/utils.py b/tests/calculators/utils.py new file mode 100644 index 00000000..7a62a0bd --- /dev/null +++ b/tests/calculators/utils.py @@ -0,0 +1,27 @@ +"""Test utilities wrap common functions in the tests""" + +from typing import Optional, Tuple + +import torch +from vesin import NeighborList + + +def neighbor_list_torch( + positions: torch.tensor, cell: torch.tensor, cutoff: Optional[float] = None +) -> Tuple[torch.tensor, torch.tensor]: + + if cutoff is None: + cell_dimensions = torch.linalg.norm(cell, dim=1) + cutoff_torch = torch.min(cell_dimensions) / 2 - 1e-6 + cutoff = cutoff_torch.item() + + nl = NeighborList(cutoff=cutoff, full_list=True) + i, j, S = nl.compute(points=positions, box=cell, periodic=True, quantities="ijS") + + i = torch.from_numpy(i.astype(int)) + j = torch.from_numpy(j.astype(int)) + + neighbor_indices = torch.vstack([i, j]) + neighbor_shifts = torch.from_numpy(S.astype(int)) + + return neighbor_indices, neighbor_shifts diff --git a/tests/metatensor/test_base_metatensor.py b/tests/metatensor/test_base_metatensor.py index cca6dd43..43667a0b 100644 --- a/tests/metatensor/test_base_metatensor.py +++ b/tests/metatensor/test_base_metatensor.py @@ -1,6 +1,7 @@ import pytest import torch from packaging import version +from utils_metatensor import add_neighbor_list import meshlode @@ -195,3 +196,30 @@ def test_different_number_charge_channles(): ) with pytest.raises(ValueError, match=match): calculator.compute([system1, system2]) + + +def test_multiple_neighborlist_warning(): + system = mts_atomistic.System( + types=torch.tensor([1, 1]), + positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]), + cell=torch.eye(3), + ) + + charges = torch.tensor([1.0, -1.0]).reshape(-1, 1) + data = mts_torch.TensorBlock( + values=charges, + samples=mts_torch.Labels.range("atom", charges.shape[0]), + components=[], + properties=mts_torch.Labels.range("charge", charges.shape[1]), + ) + + system.add_data(name="charges", data=data) + + add_neighbor_list(system, cutoff=1.0) + add_neighbor_list(system, cutoff=2.0) + + calculator = CalculatorTest() + + match = r"Multiple neighbor lists found \(2\). Using the first one." + with pytest.warns(UserWarning, match=match): + calculator.compute(system) diff --git a/tests/metatensor/test_workflow_metatensor.py b/tests/metatensor/test_workflow_metatensor.py index f5a28680..0c634823 100644 --- a/tests/metatensor/test_workflow_metatensor.py +++ b/tests/metatensor/test_workflow_metatensor.py @@ -5,6 +5,7 @@ import pytest import torch from packaging import version +from utils_metatensor import add_neighbor_list import meshlode @@ -60,6 +61,7 @@ def cscl_system(self): properties=mts_torch.Labels("charge", torch.tensor([[0]])), ) system.add_data(name="charges", data=data) + add_neighbor_list(system) return system diff --git a/tests/metatensor/utils_metatensor.py b/tests/metatensor/utils_metatensor.py new file mode 100644 index 00000000..f0547477 --- /dev/null +++ b/tests/metatensor/utils_metatensor.py @@ -0,0 +1,53 @@ +"""Test utilities wrap common functions in the metatensor tests""" + +from typing import Optional + +import pytest +import torch +from vesin import NeighborList + + +mts_torch = pytest.importorskip("metatensor.torch") +mts_atomistic = pytest.importorskip("metatensor.torch.atomistic") + + +def add_neighbor_list(system, cutoff: Optional[float] = None) -> None: + if cutoff is None: + cell_dimensions = torch.linalg.norm(system.cell, dim=1) + cutoff_torch = torch.min(cell_dimensions) / 2 - 1e-6 + cutoff = cutoff_torch.item() + + nl = NeighborList(cutoff=cutoff, full_list=True) + i, j, S, D = nl.compute( + points=system.positions, box=system.cell, periodic=True, quantities="ijSD" + ) + + i = torch.from_numpy(i.astype(int)) + j = torch.from_numpy(j.astype(int)) + + neighbor_indices = torch.vstack([i, j]) + neighbor_shifts = torch.from_numpy(S.astype(int)) + + sample_values = torch.hstack([neighbor_indices.T, neighbor_shifts]) + samples = mts_torch.Labels( + names=[ + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ], + values=sample_values, + ) + + values = torch.from_numpy(D).reshape(-1, 3, 1) + values = values.type(system.positions.dtype) + neighbors = mts_torch.TensorBlock( + values=values, + samples=samples, + components=[mts_torch.Labels.range("xyz", 3)], + properties=mts_torch.Labels.range("distance", 1), + ) + + nl_options = mts_atomistic.NeighborListOptions(cutoff=cutoff, full_list=True) + system.add_neighbor_list(options=nl_options, neighbors=neighbors) diff --git a/tox.ini b/tox.ini index 9f533953..4b79f346 100644 --- a/tox.ini +++ b/tox.ini @@ -44,10 +44,12 @@ commands = description = Run ALL test suite with pytest and {basepython}. usedevelop = true deps = + ase coverage[toml] metatensor-operations pytest pytest-cov + vesin extras = metatensor commands = @@ -55,7 +57,7 @@ commands = pytest {[testenv]test_options} {posargs} # Run documentation tests - pytest {[testenv]warning_options} --doctest-modules --pyargs meshlode {posargs} + pytest {[testenv]warning_options} --doctest-modules --pyargs meshlode [testenv:tests-min] description = Run the minimal core tests with pytest and {basepython}. @@ -64,6 +66,7 @@ deps = coverage[toml] pytest pytest-cov + vesin commands = # Run unit tests