From 4c5e9b1fcc7e593fc9684f9d817f8357d649df57 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Thu, 23 Nov 2023 09:10:46 +0100 Subject: [PATCH] Add parameter for `interpolation_order` --- src/meshlode/calculators.py | 4 ++++ src/meshlode/system.py | 27 +++++++++++++-------------- tests/calculators.py | 6 +++--- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/meshlode/calculators.py b/src/meshlode/calculators.py index e49d8f8a..cf706d49 100644 --- a/src/meshlode/calculators.py +++ b/src/meshlode/calculators.py @@ -24,6 +24,8 @@ class MeshPotential(torch.nn.Module): atomic density. :param mesh_spacing: Value that determines the umber of Fourier-space grid points that will be used along each axis. + :param interpolation_order: Interpolation order for mapping onto the grid. + ``4`` equals cubic interpolation. Example ------- @@ -38,12 +40,14 @@ def __init__( self, atomic_gaussian_width: float, mesh_spacing: float = 0.2, + interpolation_order: float = 4, ): super().__init__() self.parameters = { "atomic_gaussian_width": atomic_gaussian_width, "mesh_spacing": mesh_spacing, + "interpolation_order": interpolation_order, } def compute( diff --git a/src/meshlode/system.py b/src/meshlode/system.py index 18695627..c5173904 100644 --- a/src/meshlode/system.py +++ b/src/meshlode/system.py @@ -2,7 +2,19 @@ class System: - """A single system for which we want to run a calculation.""" + """A single system for which we want to run a calculation. + + :param species: species of the atoms/particles in this system. This should + be a 1D array of integer containing different values for different + system. The species will typically match the atomic element, but does + not have to. + :param positions: positions of the atoms/particles in this system. This + should be a ``len(species) x 3`` 2D array containing the positions of + each atom. + :param cell: 3x3 cell matrix for periodic boundary conditions, where each + row is one of the cell vector. Use a matrix filled with ``0`` for + non-periodic systems. + """ def __init__( self, @@ -10,19 +22,6 @@ def __init__( positions: torch.Tensor, cell: torch.Tensor, ): - """ - :param species: species of the atoms/particles in this system. This should - be a 1D array of integer containing different values for different - system. The species will typically match the atomic element, but does - not have to. - :param positions: positions of the atoms/particles in this system. This - should be a ``len(species) x 3`` 2D array containing the positions of - each atom. - :param cell: 3x3 cell matrix for periodic boundary conditions, where each - row is one of the cell vector. Use a matrix filled with ``0`` for - non-periodic systems. - """ - self._species = species self._positions = positions self._cell = cell diff --git a/tests/calculators.py b/tests/calculators.py index 1f0ed7af..92787250 100644 --- a/tests/calculators.py +++ b/tests/calculators.py @@ -13,7 +13,7 @@ def system(): ) -def spherical_expansion(): +def descriptor(): return calculators.MeshPotential( atomic_gaussian_width=1, ) @@ -31,9 +31,9 @@ def check_operation(calculator): def test_operation_as_python(): - check_operation(spherical_expansion()) + check_operation(descriptor()) def test_operation_as_torch_script(): - scripted = torch.jit.script(spherical_expansion()) + scripted = torch.jit.script(descriptor()) check_operation(scripted)