From 6f1c8667b5fd5b518fe7ce79f6651751cdf7abab Mon Sep 17 00:00:00 2001 From: Kevin Kazuki Huguenin-Dumittan Date: Thu, 20 Jun 2024 16:59:56 +0200 Subject: [PATCH] Add option to provide external neighborlist --- examples/neighborlist_example.ipynb | 199 ++++++++++++++++++ .../calculators/calculator_base_periodic.py | 48 ++++- src/meshlode/calculators/meshewald.py | 63 ++++-- src/meshlode/metatensor/__init__.py | 1 - src/meshlode/metatensor/meshewald.py | 45 +++- tests/calculators/test_workflow_meshewald.py | 12 +- 6 files changed, 325 insertions(+), 43 deletions(-) create mode 100644 examples/neighborlist_example.ipynb diff --git a/examples/neighborlist_example.ipynb b/examples/neighborlist_example.ipynb new file mode 100644 index 00000000..f5df529e --- /dev/null +++ b/examples/neighborlist_example.ipynb @@ -0,0 +1,199 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import meshlode\n", + "import torch\n", + "import numpy as np\n", + "import math\n", + "from metatensor.torch.atomistic import System\n", + "\n", + "from ase import Atoms\n", + "from ase.neighborlist import neighbor_list" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Define simple example structure having the CsCl structure and compute the reference\n", + "# values. MeshPotential by default outputs the types sorted according to the atomic\n", + "# number. Thus, we input the compound \"CsCl\" and \"ClCs\" since Cl and Cs have atomic\n", + "# numbers 17 and 55, respectively.\n", + "types = torch.tensor([17, 55]) # Cl and Cs\n", + "positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]])\n", + "charges = torch.tensor([-1.0, 1.0])\n", + "cell = torch.eye(3)\n", + "\n", + "# %%\n", + "# Define the expected values of the energy\n", + "n_atoms = len(types)\n", + "madelung = 2 * 1.7626 / math.sqrt(3)\n", + "energies_ref = -madelung * torch.ones((n_atoms, 1))\n", + "\n", + "# %%\n", + "# We first define general parameters for our calculation MeshLODE\n", + "\n", + "atomic_smearing = 0.1\n", + "cell = torch.eye(3)\n", + "mesh_spacing = atomic_smearing / 4\n", + "interpolation_order = 2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generate neighbor list using ASE" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "sr_cutoff = np.sqrt(3) * 0.8\n", + "struc = Atoms(positions=positions, cell=cell, pbc=True)\n", + "atom_is, atom_js, neighbor_shifts = neighbor_list(\"ijS\", struc, sr_cutoff, self_interaction=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Convert neighbor list from ASE to desired format (torch tensor of dtype int)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "atom_is = atom_is.reshape((-1,1))\n", + "atom_js = atom_js.reshape((-1,1))\n", + "neighbor_indices = torch.tensor(np.hstack([atom_is, atom_js]))\n", + "neighbor_shifts = torch.tensor(neighbor_shifts)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/khugueni/code/MeshLODE/src/meshlode/calculators/meshewald.py:336: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " positions[j] - positions[i] + torch.tensor(shift @ cell)\n" + ] + } + ], + "source": [ + "system = System(types=types, positions=positions, cell=cell)\n", + "\n", + "MP = meshlode.metatensor.MeshEwaldPotential(\n", + " atomic_smearing=atomic_smearing,\n", + " mesh_spacing=mesh_spacing,\n", + " interpolation_order=interpolation_order,\n", + " subtract_self=True,\n", + " sr_cutoff=sr_cutoff,\n", + ")\n", + "potential_metatensor = MP.compute(system, neighbor_indices=neighbor_indices, neighbor_shifts=neighbor_shifts)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Convert to Madelung constant and check that the value is correct" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(17) tensor(17) tensor(1.) tensor(-2.7745)\n", + "tensor(17) tensor(55) tensor(-1.) tensor(-0.7391)\n", + "tensor(55) tensor(17) tensor(-1.) tensor(-0.7391)\n", + "tensor(55) tensor(55) tensor(1.) tensor(-2.7745)\n", + "Using the metatensor version\n", + "Computed energies on each atom = [[-2.035360813140869], [-2.035360813140869]]\n", + "Reference Madelung constant = 2.035\n", + "Total energy = -4.071\n" + ] + } + ], + "source": [ + "atomic_energies_metatensor = torch.zeros((n_atoms, 1))\n", + "for idx_c, c in enumerate(types):\n", + " for idx_n, n in enumerate(types):\n", + " # Take the coefficients with the correct center atom and neighbor atom types\n", + " block = potential_metatensor.block(\n", + " {\"center_type\": int(c), \"neighbor_type\": int(n)}\n", + " )\n", + "\n", + " # The coulomb potential between atoms i and j is charge_i * charge_j / d_ij\n", + " # The features are simply computing a pure 1/r potential with no prefactors.\n", + " # Thus, to compute the energy between atoms of types i and j, we need to\n", + " # multiply by the charges of i and j.\n", + " print(c, n, charges[idx_c] * charges[idx_n], block.values[0, 0])\n", + " atomic_energies_metatensor[idx_c] += (\n", + " charges[idx_c] * charges[idx_n] * block.values[0, 0]\n", + " )\n", + "\n", + "# %%\n", + "# The total energy is just the sum of all atomic energies\n", + "total_energy_metatensor = torch.sum(atomic_energies_metatensor)\n", + "\n", + "# %%\n", + "# Compare against reference Madelung constant and reference energy:\n", + "print(\"Using the metatensor version\")\n", + "print(f\"Computed energies on each atom = {atomic_energies_metatensor.tolist()}\")\n", + "print(f\"Reference Madelung constant = {madelung:.3f}\")\n", + "print(f\"Total energy = {total_energy_metatensor:.3f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/meshlode/calculators/calculator_base_periodic.py b/src/meshlode/calculators/calculator_base_periodic.py index 883f6f71..9ed7d10c 100644 --- a/src/meshlode/calculators/calculator_base_periodic.py +++ b/src/meshlode/calculators/calculator_base_periodic.py @@ -36,6 +36,8 @@ def compute( positions: Union[List[torch.Tensor], torch.Tensor], cell: Union[List[torch.Tensor], torch.Tensor], charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Compute potential for all provided "systems" stacked inside list. @@ -51,6 +53,12 @@ def compute( vector; and columns should contain the x, y, and z components of these vectors (i.e. the cell should be given in row-major order). :param charges: Optional single or list of 2D tensor of shape (len(types), n), + :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), + where n is the number of atoms. The 2 rows correspond to the indices of + the two atoms which are considered neighbors (e.g. within a cutoff distance) + :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), + where n is the number of atoms. The 3 rows correspond to the shift indices + for periodic images. :return: List of torch Tensors containing the potentials for all frames and all atoms. Each tensor in the list is of shape (n_atoms, n_types), where @@ -155,17 +163,39 @@ def compute( # We don't require and test that all dtypes and devices are consistent if a list # of inputs. Each "frame" is processed independently. potentials = [] - for positions_single, cell_single, charges_single in zip( - positions, cell, charges - ): - # Compute the potentials - potentials.append( - self._compute_single_system( - positions=positions_single, charges=charges_single, cell=cell_single + + if neighbor_indices is None: + for positions_single, cell_single, charges_single in zip( + positions, cell, charges + ): + # Compute the potentials + potentials.append( + self._compute_single_system( + positions=positions_single, + charges=charges_single, + cell=cell_single, + ) + ) + else: + for ( + positions_single, + cell_single, + charges_single, + neighbor_indices_single, + neighbor_shifts_single, + ) in zip(positions, cell, charges, neighbor_indices, neighbor_shifts): + # Compute the potentials + potentials.append( + self._compute_single_system( + positions=positions_single, + charges=charges_single, + cell=cell_single, + neighbor_indices=neighbor_indices_single, + neighbor_shifts=neighbor_shifts_single, + ) ) - ) if len(types) == 1: return potentials[0] else: - return potentials \ No newline at end of file + return potentials diff --git a/src/meshlode/calculators/meshewald.py b/src/meshlode/calculators/meshewald.py index 4a129e13..da19320f 100644 --- a/src/meshlode/calculators/meshewald.py +++ b/src/meshlode/calculators/meshewald.py @@ -1,14 +1,19 @@ -import torch from typing import List, Optional -# from .mesh import MeshPotential -from .calculator_base_periodic import CalculatorBasePeriodic -from meshlode.lib.mesh_interpolator import MeshInterpolator +import torch # extra imports for neighbor list from ase import Atoms from ase.neighborlist import neighbor_list +from meshlode.lib.mesh_interpolator import MeshInterpolator + +from .calculator_base import default_exponent + +# from .mesh import MeshPotential +from .calculator_base_periodic import CalculatorBasePeriodic + + class MeshEwaldPotential(CalculatorBasePeriodic): """A specie-wise long-range potential computed using a mesh-based Ewald method, scaling as O(NlogN) with respect to the number of particles N used as a reference @@ -46,13 +51,13 @@ class MeshEwaldPotential(CalculatorBasePeriodic): def __init__( self, all_types: Optional[List[int]] = None, - exponent: Optional[torch.Tensor] = torch.tensor(1., dtype=torch.float64), - sr_cutoff: Optional[float] = None, + exponent: Optional[torch.Tensor] = default_exponent, + sr_cutoff: Optional[torch.Tensor] = None, atomic_smearing: Optional[float] = None, mesh_spacing: Optional[float] = None, subtract_self: Optional[bool] = True, interpolation_order: Optional[int] = 4, - subtract_interior: Optional[bool] = False + subtract_interior: Optional[bool] = False, ): super().__init__(all_types=all_types, exponent=exponent) @@ -129,6 +134,8 @@ def _compute_single_system( positions: torch.Tensor, charges: torch.Tensor, cell: torch.Tensor, + neighbor_indices: Optional[torch.Tensor] = None, + neighbor_shifts: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Compute the "electrostatic" potential at the position of all atoms in a @@ -162,7 +169,7 @@ def _compute_single_system( cutoff_max = torch.min(cell_dimensions) / 2 - 1e-6 if self.sr_cutoff is not None: if self.sr_cutoff > torch.min(cell_dimensions) / 2: - raise ValueError(f"sr_cutoff {sr_cutoff} needs to be > {cutoff_max}") + raise ValueError(f"sr_cutoff {self.sr_cutoff} has to be > {cutoff_max}") # Set the defaut values of convergence parameters # The total computational cost = cost of SR part + cost of LR part @@ -184,7 +191,7 @@ def _compute_single_system( smearing = self.atomic_smearing if self.mesh_spacing is None: - mesh_spacing = smearing / 8. + mesh_spacing = smearing / 8.0 else: mesh_spacing = self.mesh_spacing @@ -195,6 +202,8 @@ def _compute_single_system( cell=cell, smearing=smearing, sr_cutoff=sr_cutoff, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts ) # Compute long-range (LR) part using a Fourier / reciprocal space sum @@ -203,7 +212,8 @@ def _compute_single_system( charges=charges, cell=cell, smearing=smearing, - lr_wavelength=mesh_spacing) + lr_wavelength=mesh_spacing, + ) # Combine both parts to obtain the full potential potential_ewald = potential_sr + potential_lr @@ -233,16 +243,15 @@ def _compute_lr( structure, where cell[i] is the i-th basis vector. :param smearing: torch.Tensor smearing paramter determining the splitting between the SR and LR parts. - :param lr_wavelength: Spatial resolution used for the long-range (reciprocal space) - part of the Ewald sum. More conretely, all Fourier space vectors with a - wavelength >= this value will be kept. + :param lr_wavelength: Spatial resolution used for the long-range (reciprocal + space) part of the Ewald sum. More conretely, all Fourier space vectors with + a wavelength >= this value will be kept. :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential at the position of each atom for the `n_channels` independent meshes separately. """ # Step 0 (Preparation): Compute number of times each basis vector of the # reciprocal space can be scaled until the cutoff is reached - mesh_spacing = lr_wavelength k_cutoff = 2 * torch.pi / lr_wavelength basis_norms = torch.linalg.norm(cell, dim=1) ns_approx = k_cutoff * basis_norms / 2 / torch.pi @@ -258,11 +267,11 @@ def _compute_lr( # Step 2.1: Generate k-vectors and evaluate kernel function kvectors = self._generate_kvectors(ns=ns, cell=cell) knorm_sq = torch.sum(kvectors**2, dim=3) - + # Step 2.2: Evaluate kernel function (careful, tensor shapes are different from # the pure Ewald implementation since we are no longer flattening) G = self.potential.potential_fourier_from_k_sq(knorm_sq, smearing) - G[0,0,0] = self.potential.potential_fourier_at_zero(smearing) + G[0, 0, 0] = self.potential.potential_fourier_at_zero(smearing) potential_mesh = rho_mesh @@ -293,6 +302,8 @@ def _compute_sr( cell: torch.Tensor, smearing: torch.Tensor, sr_cutoff: torch.Tensor, + neighbor_indices: Optional[torch.Tensor] = None, + neighbor_shifts: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Compute the short-range part of the Ewald sum in realspace @@ -314,16 +325,24 @@ def _compute_sr( :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential at the position of each atom for the `n_channels` independent meshes separately. """ - # Get list of neighbors - struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True) - atom_is, atom_js, shifts = neighbor_list( - "ijS", struc, sr_cutoff.item(), self_interaction=False - ) + if neighbor_indices is None: + # Get list of neighbors + struc = Atoms(positions=positions.detach().numpy(), cell=cell, pbc=True) + atom_is, atom_js, shifts = neighbor_list( + "ijS", struc, sr_cutoff.item(), self_interaction=False + ) + else: + atom_is = neighbor_indices[:,0] + atom_js = neighbor_indices[:,1] + shifts = neighbor_shifts.T + # Compute energy potential = torch.zeros_like(charges) for i, j, shift in zip(atom_is, atom_js, shifts): - dist = torch.linalg.norm(positions[j] - positions[i] + torch.tensor(shift.dot(struc.cell))) + dist = torch.linalg.norm( + positions[j] - positions[i] + torch.tensor(shift.dot(struc.cell)) + ) # If the contribution from all atoms within the cutoff is to be subtracted # this short-range part will simply use -V_LR as the potential diff --git a/src/meshlode/metatensor/__init__.py b/src/meshlode/metatensor/__init__.py index ea6acb91..8afbae3a 100644 --- a/src/meshlode/metatensor/__init__.py +++ b/src/meshlode/metatensor/__init__.py @@ -1,5 +1,4 @@ from .meshpotential import MeshPotential -from .ewaldpotential import EwaldPotential from .meshewald import MeshEwaldPotential __all__ = ["MeshPotential", "EwaldPotential", "MeshEwaldPotential"] diff --git a/src/meshlode/metatensor/meshewald.py b/src/meshlode/metatensor/meshewald.py index 16213d28..9351aee2 100644 --- a/src/meshlode/metatensor/meshewald.py +++ b/src/meshlode/metatensor/meshewald.py @@ -30,13 +30,21 @@ class MeshEwaldPotential(calculators.MeshEwaldPotential): def forward( self, systems: Union[List[System], System], + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, ) -> TensorMap: """forward just calls :py:meth:`CalculatorModule.compute`""" - return self.compute(systems=systems) + return self.compute( + systems=systems, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) def compute( self, systems: Union[List[System], System], + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, ) -> TensorMap: """Compute potential for all provided ``systems``. @@ -61,6 +69,20 @@ def compute( # provided as input (for convenience of users testing out the code) if not isinstance(systems, list): systems = [systems] + if (neighbor_indices is not None) and not isinstance(neighbor_indices, list): + neighbor_indices = [neighbor_indices] + if (neighbor_shifts is not None) and not isinstance(neighbor_shifts, list): + neighbor_shifts = [neighbor_shifts] + + # Check that the lengths of the provided lists agree + if (neighbor_indices is not None) and len(neighbor_indices) != len(systems): + raise ValueError( + f"Numbers of systems (= {len(systems)}) needs to match number of neighbor lists (= {len(neighbor_indices)})" + ) + if (neighbor_shifts is not None) and len(neighbor_shifts) != len(systems): + raise ValueError( + f"Numbers of systems (= {len(systems)}) needs to match number of neighbor shifts (= {len(neighbor_shifts)})" + ) if len(systems) > 1: for system in systems[1:]: @@ -122,7 +144,7 @@ def compute( n_blocks = n_types * n_charges_channels feat_dic: Dict[int, List[torch.Tensor]] = {a: [] for a in range(n_blocks)} - for system in systems: + for i, system in enumerate(systems): if use_explicit_charges: charges = system.get_data("charges").values else: @@ -131,10 +153,21 @@ def compute( system.types, requested_types, dtype, device ) - # Compute the potentials - potential = self._compute_single_system( - system.positions, charges, system.cell - ) + if neighbor_indices is None or neighbor_shifts is None: + # Compute the potentials + potential = self._compute_single_system( + positions=system.positions, + charges=charges, + cell=system.cell, + ) + else: + potential = self._compute_single_system( + positions=system.positions, + charges=charges, + cell=system.cell, + neighbor_indices=neighbor_indices[i], + neighbor_shifts=neighbor_shifts[i], + ) # Reorder data into metatensor format for spec_center, at_num_center in enumerate(requested_types): diff --git a/tests/calculators/test_workflow_meshewald.py b/tests/calculators/test_workflow_meshewald.py index a5d73ef2..9fc5a644 100644 --- a/tests/calculators/test_workflow_meshewald.py +++ b/tests/calculators/test_workflow_meshewald.py @@ -7,7 +7,7 @@ import torch from torch.testing import assert_close -from meshlode import MeshPotential, MeshEwaldPotential +from meshlode import MeshEwaldPotential, MeshPotential from meshlode.calculators.calculator_base import _1d_tolist, _is_subset @@ -85,7 +85,8 @@ def check_operation(calculator): def test_operation_as_python(): check_operation(descriptor()) -""" + +""" # Similar to the above, but also testing that the code can be compiled as a torch script # Disabled for now since (1) the ASE neighbor list and (2) the use of the potential # class are clashing with the torch script capabilities. @@ -94,6 +95,7 @@ def test_operation_as_torch_script(): check_operation(scripted) """ + def test_single_frame(): values = descriptor().compute(*cscl_system()) assert_close( @@ -245,6 +247,7 @@ def test_inconsistent_dtype(): with pytest.raises(ValueError, match=match): MP.compute(types=types, positions=positions, cell=cell) + def test_inconsistent_device(): """Test if the cell and positions have inconsistent device and error is raised.""" types = torch.tensor([1], device="cpu") @@ -253,9 +256,8 @@ def test_inconsistent_device(): MP = MeshPotential(atomic_smearing=0.2) - match = ( - '`types`, `positions`, and `cell` must be on the same device, got cpu, cpu and meta.' - ) + match = "`types`, `positions`, and `cell` must be on the same device, got cpu, cpu " + match += "and meta." with pytest.raises(ValueError, match=match): MP.compute(types=types, positions=positions, cell=cell)