diff --git a/src/meshlode/calculators/calculator_base.py b/src/meshlode/calculators/calculator_base.py index 37537a0c..2216c7e0 100644 --- a/src/meshlode/calculators/calculator_base.py +++ b/src/meshlode/calculators/calculator_base.py @@ -1,4 +1,3 @@ -import warnings from typing import List, Optional, Tuple, Union import torch @@ -190,13 +189,6 @@ def _validate_compute_parameters( f"cell ({cell_single.device})" ) - if type(neighbor_indices_single) is not type(neighbor_indices_single): - raise ValueError( - f"Inconsistent of neighbor_indices " - f"({type(neighbor_indices_single)}) and neighbor_indices " - f"({neighbor_indices_single})" - ) - if neighbor_indices_single is not None: # TODO validate shape and dtype @@ -297,83 +289,6 @@ def _compute_impl( else: return potentials - def compute( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor] = None, - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """Compute potential for all provided "systems" stacked inside list. - - The computation is performed on the same ``device`` as ``systems`` is stored on. - The ``dtype`` of the output tensors will be the same as the input. - - :param types: single or list of 1D tensor of integer representing the - particles identity. For atoms, this is typically their atomic numbers. - :param positions: single or 2D tensor of shape (len(types), 3) containing the - Cartesian positions of all particles in the system. - :param cell: Ignored. - :param charges: Optional single or list of 2D tensor of shape (len(types), n), - :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), - where n is the number of atoms. The 2 rows correspond to the indices of - the two atoms which are considered neighbors (e.g. within a cutoff distance) - :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), - where n is the number of atoms. The 3 rows correspond to the shift indices - for periodic images. - - :return: List of torch Tensors containing the potentials for all frames and all - atoms. Each tensor in the list is of shape (n_atoms, n_types), where - n_types is the number of types in all systems combined. If the input was - a single system only a single torch tensor with the potentials is returned. - - IMPORTANT: If multiple types are present, the different "types-channels" - are ordered according to atomic number. For example, if a structure contains - a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this - system, the feature tensor will be of shape (3, 2) = (``n_atoms``, - ``n_types``), where ``features[0, 0]`` is the potential at the position of - the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, - while ``features[0,1]`` is the potential at the position of the Oxygen atom - generated by the Oxygen atom(s). - """ - if cell is not None: - warnings.warn( - "`cell` parameter was proviced but will be ignored", stacklevel=2 - ) - - return self._compute_impl( - types=types, - positions=positions, - cell=cell, - charges=charges, - neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts, - ) - - # This function is kept to keep MeshLODE compatible with the broader pytorch - # infrastructure, which require a "forward" function. We name this function - # "compute" instead, for compatibility with other COSMO software. - def forward( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor] = None, - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """forward just calls :py:meth:`CalculatorModule.compute`""" - return self.compute( - types=types, - positions=positions, - cell=cell, - charges=charges, - neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts, - ) - def _compute_single_system( self, positions: torch.Tensor, diff --git a/src/meshlode/calculators/calculator_base_periodic.py b/src/meshlode/calculators/calculator_base_periodic.py deleted file mode 100644 index 0d64a4f4..00000000 --- a/src/meshlode/calculators/calculator_base_periodic.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import List, Optional, Union - -import torch - -from .calculator_base import CalculatorBase - - -class CalculatorBasePeriodic(CalculatorBase): - """ - Base calculator for periodic implementations - """ - - name = "CalculatorBasePeriodic" - - def forward( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor] = None, - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """forward just calls :py:meth:`CalculatorModule.compute`""" - return self.compute( - types=types, - positions=positions, - cell=cell, - charges=charges, - neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts, - ) - - def compute( - self, - types: Union[List[torch.Tensor], torch.Tensor], - positions: Union[List[torch.Tensor], torch.Tensor], - cell: Union[List[torch.Tensor], torch.Tensor] = None, - charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, - neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """Compute potential for all provided "systems" stacked inside list. - - The computation is performed on the same ``device`` as ``systems`` is stored on. - The ``dtype`` of the output tensors will be the same as the input. - - :param types: single or list of 1D tensor of integer representing the - particles identity. For atoms, this is typically their atomic numbers. - :param positions: single or 2D tensor of shape (len(types), 3) containing the - Cartesian positions of all particles in the system. - :param cell: single or 2D tensor of shape (3, 3), describing the bounding - box/unit cell of the system. Each row should be one of the bounding box - vector; and columns should contain the x, y, and z components of these - vectors (i.e. the cell should be given in row-major order). - :param charges: Optional single or list of 2D tensor of shape (len(types), n), - :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), - where n is the number of atoms. The 2 rows correspond to the indices of - the two atoms which are considered neighbors (e.g. within a cutoff distance) - :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), - where n is the number of atoms. The 3 rows correspond to the shift indices - for periodic images. - - :return: List of torch Tensors containing the potentials for all frames and all - atoms. Each tensor in the list is of shape (n_atoms, n_types), where - n_types is the number of types in all systems combined. If the input was - a single system only a single torch tensor with the potentials is returned. - - IMPORTANT: If multiple types are present, the different "types-channels" - are ordered according to atomic number. For example, if a structure contains - a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this - system, the feature tensor will be of shape (3, 2) = (``n_atoms``, - ``n_types``), where ``features[0, 0]`` is the potential at the position of - the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, - while ``features[0,1]`` is the potential at the position of the Oxygen atom - generated by the Oxygen atom(s). - """ - if cell is None: - raise ValueError("cell must be provided") - - return self._compute_impl( - types=types, - positions=positions, - cell=cell, - charges=charges, - neighbor_indices=neighbor_indices, - neighbor_shifts=neighbor_shifts, - ) diff --git a/src/meshlode/calculators/direct.py b/src/meshlode/calculators/direct.py index 8d0d4c3a..f53d5978 100644 --- a/src/meshlode/calculators/direct.py +++ b/src/meshlode/calculators/direct.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import List, Optional, Union import torch @@ -22,6 +22,63 @@ class DirectPotential(CalculatorBase): name = "DirectPotential" + def compute( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Compute potential for all provided "systems" stacked inside list. + + The computation is performed on the same ``device`` as ``systems`` is stored on. + The ``dtype`` of the output tensors will be the same as the input. + + :param types: single or list of 1D tensor of integer representing the + particles identity. For atoms, this is typically their atomic numbers. + :param positions: single or 2D tensor of shape (len(types), 3) containing the + Cartesian positions of all particles in the system. + :param charges: Optional single or list of 2D tensor of shape (len(types), n), + + :return: List of torch Tensors containing the potentials for all frames and all + atoms. Each tensor in the list is of shape (n_atoms, n_types), where + n_types is the number of types in all systems combined. If the input was + a single system only a single torch tensor with the potentials is returned. + + IMPORTANT: If multiple types are present, the different "types-channels" + are ordered according to atomic number. For example, if a structure contains + a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this + system, the feature tensor will be of shape (3, 2) = (``n_atoms``, + ``n_types``), where ``features[0, 0]`` is the potential at the position of + the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, + while ``features[0,1]`` is the potential at the position of the Oxygen atom + generated by the Oxygen atom(s). + """ + + return self._compute_impl( + types=types, + positions=positions, + cell=None, + charges=charges, + neighbor_indices=None, + neighbor_shifts=None, + ) + + # This function is kept to keep MeshLODE compatible with the broader pytorch + # infrastructure, which require a "forward" function. We name this function + # "compute" instead, for compatibility with other COSMO software. + def forward( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """forward just calls :py:meth:`CalculatorModule.compute`""" + return self.compute( + types=types, + positions=positions, + charges=charges, + ) + def _compute_single_system( self, positions: torch.Tensor, diff --git a/src/meshlode/calculators/ewald.py b/src/meshlode/calculators/ewald.py index 6503d2da..74105405 100644 --- a/src/meshlode/calculators/ewald.py +++ b/src/meshlode/calculators/ewald.py @@ -6,10 +6,10 @@ from ase import Atoms from ase.neighborlist import neighbor_list -from .calculator_base_periodic import CalculatorBasePeriodic +from .calculator_base import CalculatorBase -class EwaldPotential(CalculatorBasePeriodic): +class EwaldPotential(CalculatorBase): """A specie-wise long-range potential computed using the Ewald sum, scaling as O(N^2) with respect to the number of particles N used as a reference to test faster implementations. @@ -84,6 +84,82 @@ def __init__( self.subtract_self = subtract_self self.subtract_interior = subtract_interior + def compute( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Compute potential for all provided "systems" stacked inside list. + + The computation is performed on the same ``device`` as ``systems`` is stored on. + The ``dtype`` of the output tensors will be the same as the input. + + :param types: single or list of 1D tensor of integer representing the + particles identity. For atoms, this is typically their atomic numbers. + :param positions: single or 2D tensor of shape (len(types), 3) containing the + Cartesian positions of all particles in the system. + :param cell: single or 2D tensor of shape (3, 3), describing the bounding + box/unit cell of the system. Each row should be one of the bounding box + vector; and columns should contain the x, y, and z components of these + vectors (i.e. the cell should be given in row-major order). + :param charges: Optional single or list of 2D tensor of shape (len(types), n), + :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), + where n is the number of atoms. The 2 rows correspond to the indices of + the two atoms which are considered neighbors (e.g. within a cutoff distance) + :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), + where n is the number of atoms. The 3 rows correspond to the shift indices + for periodic images. + + :return: List of torch Tensors containing the potentials for all frames and all + atoms. Each tensor in the list is of shape (n_atoms, n_types), where + n_types is the number of types in all systems combined. If the input was + a single system only a single torch tensor with the potentials is returned. + + IMPORTANT: If multiple types are present, the different "types-channels" + are ordered according to atomic number. For example, if a structure contains + a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this + system, the feature tensor will be of shape (3, 2) = (``n_atoms``, + ``n_types``), where ``features[0, 0]`` is the potential at the position of + the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, + while ``features[0,1]`` is the potential at the position of the Oxygen atom + generated by the Oxygen atom(s). + """ + + return self._compute_impl( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) + + # This function is kept to keep MeshLODE compatible with the broader pytorch + # infrastructure, which require a "forward" function. We name this function + # "compute" instead, for compatibility with other COSMO software. + def forward( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """forward just calls :py:meth:`CalculatorModule.compute`""" + return self.compute( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) + def _compute_single_system( self, positions: torch.Tensor, diff --git a/src/meshlode/calculators/mesh.py b/src/meshlode/calculators/mesh.py index f799ae41..c7da4811 100644 --- a/src/meshlode/calculators/mesh.py +++ b/src/meshlode/calculators/mesh.py @@ -2,13 +2,12 @@ import torch -from meshlode.lib.fourier_convolution import FourierSpaceConvolution -from meshlode.lib.mesh_interpolator import MeshInterpolator +from ..lib.fourier_convolution import FourierSpaceConvolution +from ..lib.mesh_interpolator import MeshInterpolator +from .calculator_base import CalculatorBase -from .calculator_base_periodic import CalculatorBasePeriodic - -class MeshPotential(CalculatorBasePeriodic): +class MeshPotential(CalculatorBase): """A specie-wise long-range potential, computed using the particle-mesh Ewald (PME) method scaling as O(NlogN) with respect to the number of particles N. @@ -82,6 +81,82 @@ def __init__( # Initilize auxiliary objects self.fourier_space_convolution = FourierSpaceConvolution() + def compute( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Compute potential for all provided "systems" stacked inside list. + + The computation is performed on the same ``device`` as ``systems`` is stored on. + The ``dtype`` of the output tensors will be the same as the input. + + :param types: single or list of 1D tensor of integer representing the + particles identity. For atoms, this is typically their atomic numbers. + :param positions: single or 2D tensor of shape (len(types), 3) containing the + Cartesian positions of all particles in the system. + :param cell: single or 2D tensor of shape (3, 3), describing the bounding + box/unit cell of the system. Each row should be one of the bounding box + vector; and columns should contain the x, y, and z components of these + vectors (i.e. the cell should be given in row-major order). + :param charges: Optional single or list of 2D tensor of shape (len(types), n), + :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), + where n is the number of atoms. The 2 rows correspond to the indices of + the two atoms which are considered neighbors (e.g. within a cutoff distance) + :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), + where n is the number of atoms. The 3 rows correspond to the shift indices + for periodic images. + + :return: List of torch Tensors containing the potentials for all frames and all + atoms. Each tensor in the list is of shape (n_atoms, n_types), where + n_types is the number of types in all systems combined. If the input was + a single system only a single torch tensor with the potentials is returned. + + IMPORTANT: If multiple types are present, the different "types-channels" + are ordered according to atomic number. For example, if a structure contains + a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this + system, the feature tensor will be of shape (3, 2) = (``n_atoms``, + ``n_types``), where ``features[0, 0]`` is the potential at the position of + the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, + while ``features[0,1]`` is the potential at the position of the Oxygen atom + generated by the Oxygen atom(s). + """ + + return self._compute_impl( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) + + # This function is kept to keep MeshLODE compatible with the broader pytorch + # infrastructure, which require a "forward" function. We name this function + # "compute" instead, for compatibility with other COSMO software. + def forward( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """forward just calls :py:meth:`CalculatorModule.compute`""" + return self.compute( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) + def _compute_single_system( self, positions: torch.Tensor, @@ -90,31 +165,6 @@ def _compute_single_system( neighbor_indices: Union[None, torch.Tensor], neighbor_shifts: Union[None, torch.Tensor], ) -> torch.Tensor: - """ - Compute the "electrostatic" potential at the position of all atoms in a - structure. - - :param positions: torch.tensor of shape (n_atoms, 3). Contains the Cartesian - coordinates of the atoms. The implementation also works if the positions - are not contained within the unit cell. - :param charges: torch.tensor of shape `(n_atoms, n_channels)`. In the simplest - case, this would be a tensor of shape (n_atoms, 1) where charges[i,0] is the - charge of atom i. More generally, the potential for the same atom positions - is computed for n_channels independent meshes, and one can specify the - "charge" of each atom on each of the meshes independently. For standard LODE - that treats all (atomic) types separately, one example could be: If n_atoms - = 4 and the types are [Na, Cl, Cl, Na], one could set n_channels=2 and use - the one-hot encoding charges = torch.tensor([[1,0],[0,1],[0,1],[1,0]]) for - the charges. This would then separately compute the "Na" potential and "Cl" - potential. Subtracting these from each other, one could recover the more - standard electrostatic potential in which Na and Cl have charges of +1 and - -1, respectively. - :param cell: torch.tensor of shape `(3, 3)`. Describes the unit cell of the - structure, where cell[i] is the i-th basis vector. - - :returns: torch.tensor of shape `(n_atoms, n_channels)` containing the potential - at the position of each atom for the `n_channels` independent meshes separately. - """ # Initializations k_cutoff = 2 * torch.pi / self.mesh_spacing diff --git a/src/meshlode/calculators/meshewald.py b/src/meshlode/calculators/meshewald.py index 46563ed6..74479f49 100644 --- a/src/meshlode/calculators/meshewald.py +++ b/src/meshlode/calculators/meshewald.py @@ -6,13 +6,11 @@ from ase import Atoms from ase.neighborlist import neighbor_list -from meshlode.lib.mesh_interpolator import MeshInterpolator +from ..lib.mesh_interpolator import MeshInterpolator +from .calculator_base import CalculatorBase -# from .mesh import MeshPotential -from .calculator_base_periodic import CalculatorBasePeriodic - -class MeshEwaldPotential(CalculatorBasePeriodic): +class MeshEwaldPotential(CalculatorBase): """A specie-wise long-range potential computed using a mesh-based Ewald method, scaling as O(NlogN) with respect to the number of particles N used as a reference to test faster implementations. @@ -77,6 +75,82 @@ def __init__( self.subtract_self = subtract_self self.subtract_interior = subtract_interior + def compute( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Compute potential for all provided "systems" stacked inside list. + + The computation is performed on the same ``device`` as ``systems`` is stored on. + The ``dtype`` of the output tensors will be the same as the input. + + :param types: single or list of 1D tensor of integer representing the + particles identity. For atoms, this is typically their atomic numbers. + :param positions: single or 2D tensor of shape (len(types), 3) containing the + Cartesian positions of all particles in the system. + :param cell: single or 2D tensor of shape (3, 3), describing the bounding + box/unit cell of the system. Each row should be one of the bounding box + vector; and columns should contain the x, y, and z components of these + vectors (i.e. the cell should be given in row-major order). + :param charges: Optional single or list of 2D tensor of shape (len(types), n), + :param neighbor_indices: Optional single or list of 2D tensors of shape (2, n), + where n is the number of atoms. The 2 rows correspond to the indices of + the two atoms which are considered neighbors (e.g. within a cutoff distance) + :param neighbor_shifts: Optional single or list of 2D tensors of shape (3, n), + where n is the number of atoms. The 3 rows correspond to the shift indices + for periodic images. + + :return: List of torch Tensors containing the potentials for all frames and all + atoms. Each tensor in the list is of shape (n_atoms, n_types), where + n_types is the number of types in all systems combined. If the input was + a single system only a single torch tensor with the potentials is returned. + + IMPORTANT: If multiple types are present, the different "types-channels" + are ordered according to atomic number. For example, if a structure contains + a water molecule with atoms 0, 1, 2 being of types O, H, H, then for this + system, the feature tensor will be of shape (3, 2) = (``n_atoms``, + ``n_types``), where ``features[0, 0]`` is the potential at the position of + the Oxygen atom (atom 0, first index) generated by the HYDROGEN(!) atoms, + while ``features[0,1]`` is the potential at the position of the Oxygen atom + generated by the Oxygen atom(s). + """ + + return self._compute_impl( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) + + # This function is kept to keep MeshLODE compatible with the broader pytorch + # infrastructure, which require a "forward" function. We name this function + # "compute" instead, for compatibility with other COSMO software. + def forward( + self, + types: Union[List[torch.Tensor], torch.Tensor], + positions: Union[List[torch.Tensor], torch.Tensor], + cell: Union[List[torch.Tensor], torch.Tensor], + charges: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + neighbor_indices: Union[List[torch.Tensor], torch.Tensor] = None, + neighbor_shifts: Union[List[torch.Tensor], torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """forward just calls :py:meth:`CalculatorModule.compute`""" + return self.compute( + types=types, + positions=positions, + cell=cell, + charges=charges, + neighbor_indices=neighbor_indices, + neighbor_shifts=neighbor_shifts, + ) + def _generate_kvectors(self, ns: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: """ For a given unit cell, compute all reciprocal space vectors that are used to diff --git a/tests/calculators/test_workflow_direct.py b/tests/calculators/test_workflow_direct.py index 9139bad5..5ea9d15f 100644 --- a/tests/calculators/test_workflow_direct.py +++ b/tests/calculators/test_workflow_direct.py @@ -27,7 +27,7 @@ def cscl_system(): def cscl_system_with_charges(): """CsCl crystal with (cell) and charges.""" charges = torch.tensor([[0.0, 1.0], [1.0, 0]]) - return cscl_system() + (None, charges,) + return cscl_system() + (charges,) # Initialize the calculators. For now, only the DirectPotential is implemented.