Skip to content

Commit

Permalink
Add option to provide external neighborlist
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Kazuki Huguenin-Dumittan committed Jun 20, 2024
1 parent 87f77dd commit 6f1c866
Show file tree
Hide file tree
Showing 6 changed files with 325 additions and 43 deletions.
199 changes: 199 additions & 0 deletions examples/neighborlist_example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import meshlode\n",
"import torch\n",
"import numpy as np\n",
"import math\n",
"from metatensor.torch.atomistic import System\n",
"\n",
"from ase import Atoms\n",
"from ase.neighborlist import neighbor_list"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Define simple example structure having the CsCl structure and compute the reference\n",
"# values. MeshPotential by default outputs the types sorted according to the atomic\n",
"# number. Thus, we input the compound \"CsCl\" and \"ClCs\" since Cl and Cs have atomic\n",
"# numbers 17 and 55, respectively.\n",
"types = torch.tensor([17, 55]) # Cl and Cs\n",
"positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]])\n",
"charges = torch.tensor([-1.0, 1.0])\n",
"cell = torch.eye(3)\n",
"\n",
"# %%\n",
"# Define the expected values of the energy\n",
"n_atoms = len(types)\n",
"madelung = 2 * 1.7626 / math.sqrt(3)\n",
"energies_ref = -madelung * torch.ones((n_atoms, 1))\n",
"\n",
"# %%\n",
"# We first define general parameters for our calculation MeshLODE\n",
"\n",
"atomic_smearing = 0.1\n",
"cell = torch.eye(3)\n",
"mesh_spacing = atomic_smearing / 4\n",
"interpolation_order = 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generate neighbor list using ASE"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"sr_cutoff = np.sqrt(3) * 0.8\n",
"struc = Atoms(positions=positions, cell=cell, pbc=True)\n",
"atom_is, atom_js, neighbor_shifts = neighbor_list(\"ijS\", struc, sr_cutoff, self_interaction=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convert neighbor list from ASE to desired format (torch tensor of dtype int)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"atom_is = atom_is.reshape((-1,1))\n",
"atom_js = atom_js.reshape((-1,1))\n",
"neighbor_indices = torch.tensor(np.hstack([atom_is, atom_js]))\n",
"neighbor_shifts = torch.tensor(neighbor_shifts)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/khugueni/code/MeshLODE/src/meshlode/calculators/meshewald.py:336: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" positions[j] - positions[i] + torch.tensor(shift @ cell)\n"
]
}
],
"source": [
"system = System(types=types, positions=positions, cell=cell)\n",
"\n",
"MP = meshlode.metatensor.MeshEwaldPotential(\n",
" atomic_smearing=atomic_smearing,\n",
" mesh_spacing=mesh_spacing,\n",
" interpolation_order=interpolation_order,\n",
" subtract_self=True,\n",
" sr_cutoff=sr_cutoff,\n",
")\n",
"potential_metatensor = MP.compute(system, neighbor_indices=neighbor_indices, neighbor_shifts=neighbor_shifts)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convert to Madelung constant and check that the value is correct"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(17) tensor(17) tensor(1.) tensor(-2.7745)\n",
"tensor(17) tensor(55) tensor(-1.) tensor(-0.7391)\n",
"tensor(55) tensor(17) tensor(-1.) tensor(-0.7391)\n",
"tensor(55) tensor(55) tensor(1.) tensor(-2.7745)\n",
"Using the metatensor version\n",
"Computed energies on each atom = [[-2.035360813140869], [-2.035360813140869]]\n",
"Reference Madelung constant = 2.035\n",
"Total energy = -4.071\n"
]
}
],
"source": [
"atomic_energies_metatensor = torch.zeros((n_atoms, 1))\n",
"for idx_c, c in enumerate(types):\n",
" for idx_n, n in enumerate(types):\n",
" # Take the coefficients with the correct center atom and neighbor atom types\n",
" block = potential_metatensor.block(\n",
" {\"center_type\": int(c), \"neighbor_type\": int(n)}\n",
" )\n",
"\n",
" # The coulomb potential between atoms i and j is charge_i * charge_j / d_ij\n",
" # The features are simply computing a pure 1/r potential with no prefactors.\n",
" # Thus, to compute the energy between atoms of types i and j, we need to\n",
" # multiply by the charges of i and j.\n",
" print(c, n, charges[idx_c] * charges[idx_n], block.values[0, 0])\n",
" atomic_energies_metatensor[idx_c] += (\n",
" charges[idx_c] * charges[idx_n] * block.values[0, 0]\n",
" )\n",
"\n",
"# %%\n",
"# The total energy is just the sum of all atomic energies\n",
"total_energy_metatensor = torch.sum(atomic_energies_metatensor)\n",
"\n",
"# %%\n",
"# Compare against reference Madelung constant and reference energy:\n",
"print(\"Using the metatensor version\")\n",
"print(f\"Computed energies on each atom = {atomic_energies_metatensor.tolist()}\")\n",
"print(f\"Reference Madelung constant = {madelung:.3f}\")\n",
"print(f\"Total energy = {total_energy_metatensor:.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
48 changes: 39 additions & 9 deletions src/meshlode/calculators/calculator_base_periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def compute(
positions: Union[List[torch.Tensor], torch.Tensor],
cell: Union[List[torch.Tensor], torch.Tensor],
charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None,
neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Compute potential for all provided "systems" stacked inside list.
Expand All @@ -51,6 +53,12 @@ def compute(
vector; and columns should contain the x, y, and z components of these
vectors (i.e. the cell should be given in row-major order).
:param charges: Optional single or list of 2D tensor of shape (len(types), n),
:param neighbor_indices: Optional single or list of 2D tensors of shape (2, n),
where n is the number of atoms. The 2 rows correspond to the indices of
the two atoms which are considered neighbors (e.g. within a cutoff distance)
:param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n),
where n is the number of atoms. The 3 rows correspond to the shift indices
for periodic images.
:return: List of torch Tensors containing the potentials for all frames and all
atoms. Each tensor in the list is of shape (n_atoms, n_types), where
Expand Down Expand Up @@ -155,17 +163,39 @@ def compute(
# We don't require and test that all dtypes and devices are consistent if a list
# of inputs. Each "frame" is processed independently.
potentials = []
for positions_single, cell_single, charges_single in zip(
positions, cell, charges
):
# Compute the potentials
potentials.append(
self._compute_single_system(
positions=positions_single, charges=charges_single, cell=cell_single

if neighbor_indices is None:
for positions_single, cell_single, charges_single in zip(
positions, cell, charges
):
# Compute the potentials
potentials.append(
self._compute_single_system(
positions=positions_single,
charges=charges_single,
cell=cell_single,
)
)
else:
for (
positions_single,
cell_single,
charges_single,
neighbor_indices_single,
neighbor_shifts_single,
) in zip(positions, cell, charges, neighbor_indices, neighbor_shifts):
# Compute the potentials
potentials.append(
self._compute_single_system(
positions=positions_single,
charges=charges_single,
cell=cell_single,
neighbor_indices=neighbor_indices_single,
neighbor_shifts=neighbor_shifts_single,
)
)
)

if len(types) == 1:
return potentials[0]
else:
return potentials
return potentials
Loading

0 comments on commit 6f1c866

Please sign in to comment.