Skip to content

Commit

Permalink
Add parameter for interpolation_order
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Nov 23, 2023
1 parent 0062383 commit 4c5e9b1
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 17 deletions.
4 changes: 4 additions & 0 deletions src/meshlode/calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class MeshPotential(torch.nn.Module):
atomic density.
:param mesh_spacing: Value that determines the umber of Fourier-space grid points
that will be used along each axis.
:param interpolation_order: Interpolation order for mapping onto the grid.
``4`` equals cubic interpolation.
Example
-------
Expand All @@ -38,12 +40,14 @@ def __init__(
self,
atomic_gaussian_width: float,
mesh_spacing: float = 0.2,
interpolation_order: float = 4,
):
super().__init__()

self.parameters = {
"atomic_gaussian_width": atomic_gaussian_width,
"mesh_spacing": mesh_spacing,
"interpolation_order": interpolation_order,
}

def compute(
Expand Down
27 changes: 13 additions & 14 deletions src/meshlode/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,26 @@


class System:
"""A single system for which we want to run a calculation."""
"""A single system for which we want to run a calculation.
:param species: species of the atoms/particles in this system. This should
be a 1D array of integer containing different values for different
system. The species will typically match the atomic element, but does
not have to.
:param positions: positions of the atoms/particles in this system. This
should be a ``len(species) x 3`` 2D array containing the positions of
each atom.
:param cell: 3x3 cell matrix for periodic boundary conditions, where each
row is one of the cell vector. Use a matrix filled with ``0`` for
non-periodic systems.
"""

def __init__(
self,
species: torch.Tensor,
positions: torch.Tensor,
cell: torch.Tensor,
):
"""
:param species: species of the atoms/particles in this system. This should
be a 1D array of integer containing different values for different
system. The species will typically match the atomic element, but does
not have to.
:param positions: positions of the atoms/particles in this system. This
should be a ``len(species) x 3`` 2D array containing the positions of
each atom.
:param cell: 3x3 cell matrix for periodic boundary conditions, where each
row is one of the cell vector. Use a matrix filled with ``0`` for
non-periodic systems.
"""

self._species = species
self._positions = positions
self._cell = cell
Expand Down
6 changes: 3 additions & 3 deletions tests/calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def system():
)


def spherical_expansion():
def descriptor():
return calculators.MeshPotential(
atomic_gaussian_width=1,
)
Expand All @@ -31,9 +31,9 @@ def check_operation(calculator):


def test_operation_as_python():
check_operation(spherical_expansion())
check_operation(descriptor())


def test_operation_as_torch_script():
scripted = torch.jit.script(spherical_expansion())
scripted = torch.jit.script(descriptor())
check_operation(scripted)

0 comments on commit 4c5e9b1

Please sign in to comment.