diff --git a/regtest/.gitignore b/regtest/.gitignore index 8ce027ff27..d7c6e0cf2a 100644 --- a/regtest/.gitignore +++ b/regtest/.gitignore @@ -41,6 +41,7 @@ !/gridtools !/clusters !/unittest +!/metatensor # These files we just want to ignore completely tmp report.txt diff --git a/regtest/metatensor/rt-soap/.gitignore b/regtest/metatensor/rt-soap/.gitignore new file mode 100644 index 0000000000..7d07b69c6c --- /dev/null +++ b/regtest/metatensor/rt-soap/.gitignore @@ -0,0 +1,2 @@ +extensions/ +soap_cv.pt diff --git a/regtest/metatensor/rt-soap/Makefile b/regtest/metatensor/rt-soap/Makefile new file mode 100644 index 0000000000..3703b27cea --- /dev/null +++ b/regtest/metatensor/rt-soap/Makefile @@ -0,0 +1 @@ +include ../../scripts/test.make diff --git a/regtest/metatensor/rt-soap/config b/regtest/metatensor/rt-soap/config new file mode 100644 index 0000000000..307851a20e --- /dev/null +++ b/regtest/metatensor/rt-soap/config @@ -0,0 +1,5 @@ +plumed_modules=metatensor +plumed_needs=metatensor +type=driver +# this is to test a different name +arg="--plumed plumed.dat --ixyz crystal_structure.xyz --dump-forces forces --dump-forces-fmt %8.4f --debug-forces forces.num" diff --git a/regtest/metatensor/rt-soap/crystal_structure.xyz b/regtest/metatensor/rt-soap/crystal_structure.xyz new file mode 100644 index 0000000000..236e93d355 --- /dev/null +++ b/regtest/metatensor/rt-soap/crystal_structure.xyz @@ -0,0 +1,178 @@ +176 +10.8151 11.1671 12.6671 +N 1.28413 0.0356836 12.195 +N 4.15648 11.1648 5.88207 +N 6.66501 5.56845 0.433575 +N 9.55986 5.61216 6.79708 +N 3.19304 0.415601 0.79143 +N 2.20149 10.7488 7.12262 +N 8.66181 5.16664 11.8319 +N 7.55431 5.98008 5.60745 +O 0.588269 10.7008 10.069 +O 4.8823 0.501289 3.74831 +O 5.96103 6.07786 2.60388 +O 10.243 5.10818 8.93897 +O 9.14171 8.47003 10.3675 +O 7.07259 2.7051 3.99376 +O 3.70069 8.25989 2.31089 +O 1.67804 2.88647 8.67235 +O 10.6638 1.07613 7.81409 +O 5.58016 10.0738 1.4989 +O 5.26318 4.50881 4.91221 +O 0.188342 6.66055 11.1768 +O 2.59101 9.42059 0.233576 +O 2.79327 1.7319 6.52566 +O 8.0305 7.282 12.4527 +O 8.2152 3.85759 6.13971 +O 3.82264 2.485 1.51498 +O 1.64119 8.66909 7.85617 +O 9.20944 3.10331 11.2096 +O 7.03065 7.99741 4.82086 +O 10.1687 1.63869 1.57326 +O 6.06204 9.48052 7.85816 +O 4.79055 3.89767 11.091 +O 0.625828 7.29331 4.78201 +C 2.95238 1.72613 1.06281 +C 2.39424 9.44611 7.32799 +C 8.34889 3.87138 11.6213 +C 7.84063 7.30003 5.3126 +C 2.36095 10.6306 0.183825 +C 3.05248 0.576928 6.49156 +C 7.74299 6.09163 12.4796 +C 8.47531 5.07354 6.11657 +C 1.11456 1.43798 12.1146 +C 4.34028 9.73065 5.75877 +C 6.56198 4.11671 0.624847 +C 9.70287 7.04878 6.88873 +C 1.52986 2.17765 0.678822 +C 3.85612 9.03075 7.04374 +C 7.00303 3.4499 11.9504 +C 9.23522 7.75214 5.61205 +C 0.398007 10.3598 11.4259 +C 5.00925 0.855939 5.14878 +C 5.80886 6.44152 1.24675 +C 10.4388 4.72998 7.54811 +C 9.68665 10.4586 11.7387 +C 6.44378 0.752554 5.38023 +C 4.30831 6.3222 0.894521 +C 1.08084 4.85311 7.26565 +C 9.13597 9.89693 10.3582 +C 7.10879 1.28477 4.08154 +C 3.70376 6.83428 2.27981 +C 1.70103 4.35143 8.59228 +C 10.0922 10.4258 9.38103 +C 6.11343 0.779889 3.03707 +C 4.72635 6.30337 3.28925 +C 0.706842 4.86811 9.62774 +C 9.63945 0.573215 8.71565 +C 6.58754 10.6145 2.36632 +C 4.28809 4.99295 3.94499 +C 1.17869 6.1342 10.2792 +C 1.46661 3.64735 0.563839 +C 3.91337 7.45062 6.83304 +C 6.89236 1.9212 12.138 +C 9.31088 9.26262 5.79701 +C 0.49689 1.86112 1.86459 +C 4.90182 9.32548 8.20692 +C 5.9603 3.7177 10.8188 +C 10.2874 7.39666 4.47252 +C 1.05177 1.68814 3.25477 +C 4.38145 9.48367 9.61965 +C 6.45079 3.90566 9.38004 +C 9.75263 7.25408 3.08129 +C 0.991026 0.195369 3.56821 +C 4.32584 10.96 9.93137 +C 6.46726 5.35922 9.10466 +C 9.80269 5.76208 2.72201 +C 0.128048 2.46113 4.27152 +C 5.29095 8.70348 10.617 +C 5.55851 3.11593 8.44484 +C 10.6702 8.04546 2.03703 +H 4.10239 11.1459 1.09329 +H 1.30213 0.0342548 7.42189 +H 9.53281 5.60692 11.5551 +H 6.66589 5.53935 5.23088 +H 8.64718 8.09618 11.1453 +H 7.60533 3.01489 4.78942 +H 3.23645 8.63677 1.61018 +H 2.20095 2.55921 7.88601 +H 0.379193 1.73168 8.24643 +H 5.04373 9.3881 1.93904 +H 5.78935 3.80089 4.40439 +H 10.4567 7.37919 10.7535 +H 0.104051 1.73809 11.8608 +H 5.30499 9.43917 5.54202 +H 5.53299 3.87341 0.812386 +H 10.6884 7.36664 7.10621 +H 1.78896 1.81187 11.2638 +H 3.68439 9.3429 4.9022 +H 7.12753 3.75643 1.45471 +H 9.08085 7.38722 7.78966 +H 0.697568 9.29512 11.6168 +H 4.66031 1.90419 5.29401 +H 6.09731 7.44726 1.04688 +H 10.0692 3.7367 7.3867 +H 9.38022 9.89395 12.6339 +H 6.79461 1.33138 6.26903 +H 3.96278 6.9026 0.0896576 +H 1.42227 4.31828 6.41678 +H 9.44192 0.322678 11.8104 +H 6.81651 10.8324 5.48627 +H 3.96238 5.27012 0.847479 +H 1.43779 5.92571 7.14938 +H 8.11352 10.2845 10.2269 +H 8.10148 0.852127 3.86806 +H 2.70619 6.46082 2.46064 +H 2.7613 4.72198 8.75886 +H 10.3267 9.63132 8.60634 +H 5.9242 1.48199 2.2331 +H 4.8782 7.07518 4.06084 +H 0.525666 4.08661 10.4007 +H 8.76306 0.369849 8.15227 +H 7.48526 10.8057 1.82079 +H 3.3422 5.20994 4.52094 +H 2.07818 5.96684 10.8643 +H 9.41381 1.29284 9.49431 +H 6.77874 9.85185 3.13693 +H 4.03685 4.26764 3.12524 +H 1.38637 6.94525 9.50275 +H 2.1846 3.98948 12.4268 +H 3.1876 7.08306 6.13166 +H 7.60275 1.57516 0.261558 +H 8.61406 9.57521 6.56965 +H 0.494526 4.01383 0.245591 +H 4.94626 7.18663 6.57171 +H 5.86353 1.59166 12.4312 +H 10.3224 9.5727 6.09657 +H 1.77827 4.23366 1.4863 +H 3.69279 6.9806 7.78454 +H 7.17693 1.38145 11.1517 +H 9.05764 9.78475 4.85955 +H 2.05845 2.03468 3.34088 +H 3.31336 9.105 9.65719 +H 7.44442 3.52356 9.32561 +H 8.75636 7.64728 3.02387 +H 1.7089 10.8043 2.93884 +H 3.66568 0.380257 9.29099 +H 7.17053 5.95779 9.6717 +H 9.10914 5.19244 3.33095 +H 1.35776 0.0583836 4.6365 +H 4.09743 0.00380453 10.9624 +H 6.77825 5.49627 8.02518 +H 9.44701 5.64256 1.66552 +H 0.0490674 10.9316 3.47795 +H 5.36527 0.220938 9.78787 +H 5.44707 5.78211 9.23447 +H 10.7855 5.34902 2.8569 +H 0.46903 2.26044 5.28333 +H 4.9299 8.92286 11.6417 +H 5.88028 3.32144 7.4085 +H 10.3598 7.85218 1.02466 +H 0.215689 3.5216 4.08985 +H 5.25018 7.63447 10.4261 +H 5.58329 2.03707 8.55595 +H 10.6224 9.11714 2.23242 +H 9.89429 2.11629 4.1664 +H 6.30059 9.05835 10.5129 +H 4.49554 3.4428 8.48274 +H 0.872511 7.66384 2.14979 diff --git a/regtest/metatensor/rt-soap/plumed.dat b/regtest/metatensor/rt-soap/plumed.dat new file mode 100644 index 0000000000..8237e14c41 --- /dev/null +++ b/regtest/metatensor/rt-soap/plumed.dat @@ -0,0 +1,12 @@ +soap: METATENSOR ... + MODEL=soap_cv.pt + EXTENSIONS_DIRECTORY=../extensions + SPECIES=1-40 +... + + +scalar: SUM ARG=soap PERIODIC=NO +BIASVALUE ARG=scalar + + +PRINT ARG=soap FILE=soap.matx STRIDE=1 FMT=%8.4f diff --git a/regtest/metatensor/rt-soap/soap_cv.py b/regtest/metatensor/rt-soap/soap_cv.py new file mode 100644 index 0000000000..9b22113501 --- /dev/null +++ b/regtest/metatensor/rt-soap/soap_cv.py @@ -0,0 +1,105 @@ +from typing import Dict, List, Optional + +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatensor.torch.atomistic import ( + MetatensorAtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + System, +) +from rascaline.torch import SoapPowerSpectrum + + +class SOAP_CV(torch.nn.Module): + def __init__(self, species): + super().__init__() + + self.neighbor_type_pairs = Labels( + names=["neighbor_1_type", "neighbor_2_type"], + values=torch.tensor( + [[t1, t2] for t1 in species for t2 in species if t1 <= t2] + ), + ) + self.calculator = SoapPowerSpectrum( + cutoff=4.0, + max_angular=6, + max_radial=6, + radial_basis={"Gto": {}}, + cutoff_function={"ShiftedCosine": {"width": 0.5}}, + center_atom_weight=1.0, + atomic_gaussian_width=0.3, + ) + + self.pca_projection = torch.rand(2520, 3, dtype=torch.float64) + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + + if "collective_variable" not in outputs: + return {} + + output = outputs["collective_variable"] + + soap = self.calculator(systems, selected_samples=selected_atoms) + soap = soap.keys_to_samples("center_type") + soap = soap.keys_to_properties(self.neighbor_type_pairs) + + if not output.per_atom: + raise ValueError("per_atom=False is not supported") + + soap_block = soap.block() + projected = soap_block.values @ self.pca_projection + + block = TensorBlock( + values=projected, + samples=soap_block.samples, + components=[], + properties=Labels("soap_pca", torch.tensor([[0], [1], [2]])), + ) + cv = TensorMap(keys=Labels("_", torch.tensor([[0]])), blocks=[block]) + + return {"collective_variable": cv} + + +cv = SOAP_CV(species=[1, 6, 7, 8]) +cv.eval() + + +capabilites = ModelCapabilities( + outputs={ + "collective_variable": ModelOutput( + quantity="", + unit="", + per_atom=True, + explicit_gradients=["postions"], + ) + }, + interaction_range=4.0, + supported_devices=["cpu", "cuda"], + length_unit="nm", + atomic_types=[6, 1, 7, 8], + # dtype=TODO +) + +metadata = ModelMetadata( + name="Collective Variable test", + description=""" +A simple collective variable for testing purposes +""", + authors=["..."], + references={ + "implementation": ["ref to SOAP code"], + "architecture": ["ref to SOAP"], + "model": ["ref to paper"], + }, +) + + +model = MetatensorAtomisticModel(cv, metadata, capabilites) +model.export("soap_cv.pt", collect_extensions="extensions") diff --git a/src/metatensor/README.md b/src/metatensor/README.md index c21bc6185b..addae311d2 100644 --- a/src/metatensor/README.md +++ b/src/metatensor/README.md @@ -52,7 +52,7 @@ cmake --build . --target install --parallel METATENSOR_CMAKE_PREFIX=$(python -c "import metatensor; print(metatensor.utils.cmake_prefix_path)") METATENSOR_PREFIX=$(cd "$METATENSOR_CMAKE_PREFIX/../.." && pwd) -METATENSOR_TORCH_CMAKE_PREFIX=$(python -c "import torch; print(torch.utils.cmake_prefix_path)") +METATENSOR_TORCH_CMAKE_PREFIX=$(python -c "import metatensor.torch; print(metatensor.torch.utils.cmake_prefix_path)") METATENSOR_TORCH_PREFIX=$(cd "$METATENSOR_TORCH_CMAKE_PREFIX/../.." && pwd) ``` @@ -61,18 +61,24 @@ METATENSOR_TORCH_PREFIX=$(cd "$METATENSOR_TORCH_CMAKE_PREFIX/../.." && pwd) ```bash cd +# set the rpath to make sure plumed executable will be able to find the right libraries +RPATH="-Wl,-rpath,$TORCH_PREFIX/lib -Wl,-rpath,$METATENSOR_PREFIX/lib -Wl,-rpath,$METATENSOR_TORCH_PREFIX/lib" + # configure PLUMED with metatensor ./configure --enable-libtorch --enable-metatensor --enable-modules=+metatensor \ - LDFLAGS="-L$TORCH_PREFIX/lib -L$METATENSOR_PREFIX/lib -L$METATENSOR_TORCH_PREFIX/lib" \ + LDFLAGS="-L$TORCH_PREFIX/lib -L$METATENSOR_PREFIX/lib -L$METATENSOR_TORCH_PREFIX/lib $RPATH" \ CPPFLAGS="$TORCH_INCLUDES -I$METATENSOR_PREFIX/include -I$METATENSOR_TORCH_PREFIX/include" # If you are on Linux and use a pip-installed version of libtorch, or the # pre-cxx11-ABI build of libtorch, you'll need to add "-D_GLIBCXX_USE_CXX11_ABI=0" # to the compilation flags: ./configure --enable-libtorch --enable-metatensor --enable-modules=+metatensor \ - LDFLAGS="-L$TORCH_PREFIX/lib -L$METATENSOR_PREFIX/lib -L$METATENSOR_TORCH_PREFIX/lib" \ + LDFLAGS="-L$TORCH_PREFIX/lib -L$METATENSOR_PREFIX/lib -L$METATENSOR_TORCH_PREFIX/lib $RPATH" \ CPPFLAGS="$TORCH_INCLUDES -I$METATENSOR_PREFIX/include -I$METATENSOR_TORCH_PREFIX/include" \ CXXFLAGS="-D_GLIBCXX_USE_CXX11_ABI=0" make -j && make install ``` + + + diff --git a/src/metatensor/metatensor.cpp b/src/metatensor/metatensor.cpp index e27373ee8f..f705d2394e 100644 --- a/src/metatensor/metatensor.cpp +++ b/src/metatensor/metatensor.cpp @@ -7,25 +7,57 @@ any later version. ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */ -#if !defined(__PLUMED_HAS_LIBTORCH) || !defined(__PLUMED_HAS_METATENSOR) - -// give a nice error message if the user tries to enable -// metatensor without enabling the corresponding libraries -#error "can not compile the metatensor module without the corresponding libraries, either the disable metatensor module or configure with `--enable-metatensor --enable-libtorch` and make sure the libraries can be found" - -#else - #include "core/ActionAtomistic.h" #include "core/ActionWithValue.h" #include "core/ActionRegister.h" #include "core/PlumedMain.h" +#if !defined(__PLUMED_HAS_LIBTORCH) || !defined(__PLUMED_HAS_METATENSOR) + +namespace PLMD { namespace metatensor { +class MetatensorPlumedAction: public ActionAtomistic, public ActionWithValue { +public: + static void registerKeywords(Keywords& keys); + explicit MetatensorPlumedAction(const ActionOptions&) { + throw std::runtime_error( + "Can not use metatensor action without the corresponding libraries. \n" + "Make sure to configure with `--enable-metatensor --enable-libtorch` " + "and that the corresponding libraries are found" + ); + } +}; + +}} // namespace PLMD::metatensor + +#else + +#include + #include +#include "torch/csrc/autograd/autograd.h" + #include +#include + +#include "vesin.h" + + +// TEMPORARY HACK +#include +// TEMPORARY HACK + +namespace PLMD { namespace metatensor { +// We will cast Vector/Tensor to pointers to arrays and doubles, so let's make +// sure this is legal to do +static_assert(std::is_standard_layout::value); +static_assert(sizeof(PLMD::Vector) == sizeof(std::array)); +static_assert(alignof(PLMD::Vector) == alignof(std::array)); -namespace PLMD { +static_assert(std::is_standard_layout::value); +static_assert(sizeof(PLMD::Tensor) == sizeof(std::array, 3>)); +static_assert(alignof(PLMD::Tensor) == alignof(std::array, 3>)); class MetatensorPlumedAction: public ActionAtomistic, public ActionWithValue { public: @@ -37,26 +69,145 @@ class MetatensorPlumedAction: public ActionAtomistic, public ActionWithValue { unsigned getNumberOfDerivatives() override; private: + void createSystem(); - metatensor_torch::TorchTensorMap output_; -}; + // compute a neighbor list following metatensor format, using data from PLUMED + metatensor_torch::TorchTensorBlock computeNeighbors( + metatensor_torch::NeighborsListOptions request, + const std::vector& positions, + const PLMD::Tensor& cell + ); -PLUMED_REGISTER_ACTION(MetatensorPlumedAction, "METATENSOR") + torch::jit::Module model_; -void MetatensorPlumedAction::registerKeywords(Keywords& keys) { - Action::registerKeywords(keys); - ActionAtomistic::registerKeywords(keys); - ActionWithValue::registerKeywords(keys); + torch::Tensor atomic_types_; + // store the strain to be able to compute the virial with autograd + torch::Tensor strain_; + + metatensor_torch::System system_; + metatensor_torch::ModelEvaluationOptions evaluations_options_; + bool check_consistency_ = true; + metatensor_torch::TorchTensorMap output_; +}; - throw std::runtime_error("unimplemented"); -} MetatensorPlumedAction::MetatensorPlumedAction(const ActionOptions& options): Action(options), ActionAtomistic(options), ActionWithValue(options) { - throw std::runtime_error("unimplemented"); + std::string extensions_directory; + this->parse("EXTENSIONS_DIRECTORY", extensions_directory); + + // TEMPORARY BAD CODE, TO BE REMOVED + dlopen( + (extensions_directory + "/rascaline/lib/librascaline.dylib").c_str(), + RTLD_LOCAL | RTLD_NOW + ); + + dlopen( + (extensions_directory + "/rascaline/torch/lib/librascaline_torch.dylib").c_str(), + RTLD_LOCAL | RTLD_NOW + ); + // END OF TEMPORARY BAD CODE, TO BE REMOVED + + // load the model + std::string model_path; + this->parse("MODEL", model_path); + + try { + this->model_ = metatensor_torch::load_atomistic_model(model_path); + } catch (const std::exception& e) { + error("failed to load model at '" + model_path + "': " + e.what()); + } + + + // parse the atomic types from the input file + std::vector atomic_types; + std::vector species_to_metatensor_types; + parseVector("SPECIES_TO_METATENSOR_TYPES", species_to_metatensor_types); + bool has_custom_types = !species_to_metatensor_types.empty(); + + std::vector all_atoms; + parseAtomList("SPECIES", all_atoms); + + auto n_species = 0; + if (all_atoms.empty()) { + std::vector t; + for (int i=1;;i++) { + parseAtomList("SPECIES", i, t); + if (t.empty()) { + break; + } + + int32_t type = i; + if (has_custom_types) { + if (species_to_metatensor_types.size() < i) { + error( + "SPECIES_TO_METATENSOR_TYPES is too small, " + "it should have one entry for each species (we have at least " + + std::to_string(i) + " species and " + + std::to_string(species_to_metatensor_types.size()) + + "entries in SPECIES_TO_METATENSOR_TYPES)" + ); + } + + type = species_to_metatensor_types[i - 1]; + } + + log.printf(" Species %d includes atoms : ", i); + for(unsigned j=0; jwarning( + "SPECIES_TO_METATENSOR_TYPES contains more entries (" + + std::to_string(species_to_metatensor_types.size()) + + ") than there where species (" + std::to_string(n_species) + ")" + ); + } + + this->atomic_types_ = torch::tensor(std::move(atomic_types)); + + // Request the atoms and check we have read in everything + requestAtoms(all_atoms); + + // TODO: selected_atoms + // evaluations_options_->set_selected_atoms() + + // setup the output + // TODO: define the size/type of output a bit better + this->addValue({1, 1}); + this->setNotPeriodic(); + this->getPntrToComponent(0)->buildDataStore(); + + // create evaluation options for the model. These won't change during the + // simulation, so we initialize them once here. + evaluations_options_ = torch::make_intrusive(); + evaluations_options_->set_length_unit(getUnits().getLengthString()); + + auto output = torch::make_intrusive(); + // TODO: should this be configurable? + output->per_atom = true; + // we are using torch autograd system to compute gradients, so we don't need + // any explicit gradients. + output->explicit_gradients = {}; + evaluations_options_->outputs.insert("collective_variable", output); } unsigned MetatensorPlumedAction::getNumberOfDerivatives() { @@ -65,16 +216,336 @@ unsigned MetatensorPlumedAction::getNumberOfDerivatives() { } +void MetatensorPlumedAction::createSystem() { + const auto& cell = this->getPbc().getBox(); + + auto tensor_options = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU); + auto torch_cell = torch::zeros({3, 3}, tensor_options); + + // TODO: check if cell is stored in row or column major order + // TODO: check if cell is zero for non-periodic systems + torch_cell[0][0] = cell(0, 0); + torch_cell[0][1] = cell(0, 1); + torch_cell[0][2] = cell(0, 2); + + torch_cell[1][0] = cell(1, 0); + torch_cell[1][1] = cell(1, 1); + torch_cell[1][2] = cell(1, 2); + + torch_cell[2][0] = cell(2, 0); + torch_cell[2][1] = cell(2, 1); + torch_cell[2][2] = cell(2, 2); + + const auto& positions = this->getPositions(); + + auto torch_positions = torch::from_blob( + const_cast(positions.data()), + {static_cast(positions.size()), 3}, + tensor_options + ); + + // setup torch's automatic gradient tracking + if (!this->doNotCalculateDerivatives()) { + torch_positions.requires_grad_(true); + + this->strain_ = torch::eye(3, tensor_options.requires_grad(true)); + + // pretend to scale positions/cell by the strain so that it enters the + // computational graph. + torch_positions = torch_positions.matmul(this->strain_); + torch_positions.retain_grad(); + + torch_cell = torch_cell.matmul(this->strain_); + } + + + // TODO: move data to another dtype/device as requested by the model or user + this->system_ = torch::make_intrusive( + this->atomic_types_, + torch_positions, + torch_cell + ); + + // compute the neighbors list requested by the model, and register them with + // the system + auto nl_requests = this->model_.run_method("requested_neighbors_lists"); + for (auto request_ivalue: nl_requests.toList()) { + auto request = request_ivalue.get().toCustomClass(); + + auto neighbors = this->computeNeighbors(request, positions, cell); + metatensor_torch::register_autograd_neighbors(this->system_, neighbors, this->check_consistency_); + this->system_->add_neighbors_list(request, neighbors); + } +} + + +metatensor_torch::TorchTensorBlock MetatensorPlumedAction::computeNeighbors( + metatensor_torch::NeighborsListOptions request, + const std::vector& positions, + const PLMD::Tensor& cell +) { + auto labels_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); + auto neighbor_component = torch::make_intrusive( + "xyz", + torch::tensor({0, 1, 2}, labels_options).reshape({3, 1}) + ); + auto neighbor_properties = torch::make_intrusive( + "distance", torch::zeros({1, 1}, labels_options) + ); + + auto cutoff = request->engine_cutoff(this->getUnits().getLengthString()); + + auto periodic = ( + cell(0, 0) == 0.0 && cell(0, 1) == 0.0 && cell(0, 2) == 0.0 && + cell(1, 0) == 0.0 && cell(1, 1) == 0.0 && cell(1, 2) == 0.0 && + cell(2, 0) == 0.0 && cell(2, 2) == 0.0 && cell(2, 2) == 0.0 + ); + + // use https://github.com/Luthaf/vesin to compute the requested neighbor + // lists since we can not get these from PLUMED + VesinOptions options; + options.cutoff = cutoff; + options.full = request->full_list(); + options.return_shifts = true; + options.return_distances = false; + options.return_vectors = true; + + VesinNeighborsList* vesin_neighbor_list = new VesinNeighborsList(); + memset(vesin_neighbor_list, 0, sizeof(VesinNeighborsList)); + + const char* error_message = NULL; + int status = vesin_neighbors( + reinterpret_cast(positions.data()), + positions.size(), + reinterpret_cast(&cell(0, 0)), + periodic, + VesinCPU, + options, + vesin_neighbor_list, + &error_message + ); + + if (status != EXIT_SUCCESS) { + this->error( + "failed to compute neighbor list (cutoff=" + std::to_string(cutoff) + + "full=" + (request->full_list() ? "true" : "false") + "): " + error_message + ); + } + + // transform from vesin to metatensor format + auto n_pairs = static_cast(vesin_neighbor_list->length); + + auto pair_vectors = torch::from_blob( + vesin_neighbor_list->vectors, + {n_pairs, 3, 1}, + /*deleter*/ [=](void*) { + vesin_free(vesin_neighbor_list); + delete vesin_neighbor_list; + }, + torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU) + ); + + auto pair_samples_values = torch::zeros({n_pairs, 5}, labels_options); + for (unsigned i=0; i(vesin_neighbor_list->pairs[i][0]); + pair_samples_values[i][1] = static_cast(vesin_neighbor_list->pairs[i][1]); + pair_samples_values[i][2] = vesin_neighbor_list->shifts[i][0]; + pair_samples_values[i][3] = vesin_neighbor_list->shifts[i][1]; + pair_samples_values[i][4] = vesin_neighbor_list->shifts[i][2]; + } + + auto neighbor_samples = torch::make_intrusive( + std::vector{"first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"}, + pair_samples_values + ); + + auto neighbors = torch::make_intrusive( + pair_vectors, + neighbor_samples, + std::vector{neighbor_component}, + neighbor_properties + ); + + return neighbors; +} + + void MetatensorPlumedAction::calculate() { - throw std::runtime_error("unimplemented"); + this->createSystem(); + + try { + auto ivalue_output = this->model_.forward({ + std::vector{this->system_}, + evaluations_options_, + this->check_consistency_, + }); + + auto dict_output = ivalue_output.toGenericDict(); + auto cv = dict_output.at("collective_variable"); + this->output_ = cv.toCustomClass(); + } catch (const std::exception& e) { + error("failed to evaluate the model: " + std::string(e.what())); + } + + // send the output back to plumed + plumed_massert(this->output_->keys()->count() == 1, "output should have a single block"); + auto block = metatensor_torch::TensorMapHolder::block_by_id(this->output_, 0); + plumed_massert(block->components().empty(), "components are not yet supported in the output"); + auto torch_values = block->values().to(torch::kCPU).to(torch::kFloat64); + auto n_samples = torch_values.size(0); + auto n_properties = torch_values.size(1); + + Value* value = this->getPntrToComponent(0); + const auto& value_shape = value->getShape(); + // reshape the plumed `Value` to hold the data returned by the model + if (n_samples == 1) { + if (n_properties == 1) { + // the CV is a single scalar + if (value->getRank() != 0) { + log.printf(" output of metatensor model is a scalar\n"); + value->setShape({}); + } + + value->set(torch_values.item()); + } else { + // we have multiple CV describing a single thing (atom or full system) + if (value->getRank() != 1 || value_shape[0] != n_properties) { + log.printf(" output of metatensor model is a 1x%d vector\n", n_properties); + value->setShape({static_cast(n_properties)}); + } + + for (unsigned i=0; iset(i, torch_values[0][i].item()); + } + } + } else { + if (n_properties == 1) { + // we have a single CV describing multiple things (i.e. atoms) + if (value->getRank() != 1 || value_shape[0] != n_samples) { + log.printf(" output of metatensor model is a %dx1 vector\n", n_samples); + value->setShape({static_cast(n_samples)}); + } + + // TODO: check sample order? + for (unsigned i=0; iset(i, torch_values[i][0].item()); + } + } else { + // the CV is a matrix + if (value->getRank() != 2 || value_shape[0] != n_samples || value_shape[1] != n_properties) { + log.printf(" output of metatensor model is a %dx%d matrix\n", n_samples, n_properties); + value->setShape({ + static_cast(n_samples), + static_cast(n_properties), + }); + value->reshapeMatrixStore(n_properties); + } + + // TODO: check sample order? + for (unsigned i=0; iset(i * n_properties + j, torch_values[i][j].item()); + } + } + } + } } void MetatensorPlumedAction::apply() { - throw std::runtime_error("unimplemented"); -} + auto* value = this->getPntrToComponent(0); + if (!value->forcesWereAdded()) { + return; + } + + auto block = metatensor_torch::TensorMapHolder::block_by_id(this->output_, 0); + auto torch_values = block->values().to(torch::kCPU).to(torch::kFloat64); + auto n_samples = torch_values.size(0); + auto n_properties = torch_values.size(1); + + auto output_grad = torch::zeros_like(torch_values); + if (n_samples == 1) { + if (n_properties == 1) { + output_grad[0][0] = value->getForce(); + } else { + for (unsigned i=0; igetForce(i); + } + } + } else { + if (n_properties == 1) { + // TODO: check sample order? + for (unsigned i=0; igetForce(i); + } + } else { + // TODO: check sample order? + for (unsigned i=0; igetForce(i * n_properties + j); + } + } + } + } + + auto input_grad = torch::autograd::grad( + {torch_values}, + {this->system_->positions(), this->strain_}, + {output_grad} + ); + plumed_assert(input_grad[0].is_cpu()); + plumed_assert(input_grad[0].is_contiguous()); + + plumed_assert(input_grad[1].is_cpu()); + plumed_assert(input_grad[1].is_contiguous()); -} // namespace PLMD + auto positions_grad = input_grad[0]; + auto strain_grad = input_grad[1]; + auto derivatives = std::vector( + positions_grad.data_ptr(), + positions_grad.data_ptr() + 3 * this->system_->size() + ); + + // add virials to the derivatives + derivatives.push_back(strain_grad[0][0].item()); + derivatives.push_back(strain_grad[0][1].item()); + derivatives.push_back(strain_grad[0][2].item()); + + derivatives.push_back(strain_grad[1][0].item()); + derivatives.push_back(strain_grad[1][1].item()); + derivatives.push_back(strain_grad[1][2].item()); + + derivatives.push_back(strain_grad[2][0].item()); + derivatives.push_back(strain_grad[2][1].item()); + derivatives.push_back(strain_grad[2][2].item()); + + + unsigned index = 0; + this->setForcesOnAtoms(derivatives, index); +} + +}} // namespace PLMD::metatensor #endif + + +namespace PLMD { namespace metatensor { + // use the same implementation for both the actual action and the dummy one + // (when libtorch and libmetatensor could not be found). + void MetatensorPlumedAction::registerKeywords(Keywords& keys) { + Action::registerKeywords(keys); + ActionAtomistic::registerKeywords(keys); + ActionWithValue::registerKeywords(keys); + + keys.add("compulsory", "MODEL", "path to the exported metatensor model"); + keys.add("optional", "EXTENSIONS_DIRECTORY", "path to the directory containing TorchScript extensions to load"); + + keys.add("numbered", "SPECIES", "the atoms in each PLUMED species"); + keys.reset_style("SPECIES", "atoms"); + + keys.add("optional", "SPECIES_TO_METATENSOR_TYPES", "mapping from PLUMED SPECIES to metatensor's atomic types"); + } + + PLUMED_REGISTER_ACTION(MetatensorPlumedAction, "METATENSOR") +}} diff --git a/src/metatensor/vesin-single-build.cpp b/src/metatensor/vesin-single-build.cpp new file mode 100644 index 0000000000..3821b30e98 --- /dev/null +++ b/src/metatensor/vesin-single-build.cpp @@ -0,0 +1,839 @@ +#include +#include +#include + +#include +#include +#include + +#ifndef VESIN_CPU_CELL_LIST_HPP +#define VESIN_CPU_CELL_LIST_HPP + +#include + +#include "vesin.h" + +#ifndef VESIN_TYPES_HPP +#define VESIN_TYPES_HPP + +#ifndef VESIN_MATH_HPP +#define VESIN_MATH_HPP + +#include +#include +#include + +namespace vesin { +struct Vector; + +Vector operator*(Vector vector, double scalar); + +struct Vector: public std::array { + double dot(Vector other) const { + return (*this)[0] * other[0] + (*this)[1] * other[1] + (*this)[2] * other[2]; + } + + double norm() const { + return std::sqrt(this->dot(*this)); + } + + Vector normalize() const { + return *this * (1.0 / this->norm()); + } + + Vector cross(Vector other) const { + return Vector{ + (*this)[1] * other[2] - (*this)[2] * other[1], + (*this)[2] * other[0] - (*this)[0] * other[2], + (*this)[0] * other[1] - (*this)[1] * other[0], + }; + } +}; + +inline Vector operator+(Vector u, Vector v) { + return Vector{ + u[0] + v[0], + u[1] + v[1], + u[2] + v[2], + }; +} + +inline Vector operator-(Vector u, Vector v) { + return Vector{ + u[0] - v[0], + u[1] - v[1], + u[2] - v[2], + }; +} + +inline Vector operator*(double scalar, Vector vector) { + return Vector{ + scalar * vector[0], + scalar * vector[1], + scalar * vector[2], + }; +} + +inline Vector operator*(Vector vector, double scalar) { + return Vector{ + scalar * vector[0], + scalar * vector[1], + scalar * vector[2], + }; +} + + +struct Matrix: public std::array, 3> { + double determinant() const { + return (*this)[0][0] * ((*this)[1][1] * (*this)[2][2] - (*this)[2][1] * (*this)[1][2]) + - (*this)[0][1] * ((*this)[1][0] * (*this)[2][2] - (*this)[1][2] * (*this)[2][0]) + + (*this)[0][2] * ((*this)[1][0] * (*this)[2][1] - (*this)[1][1] * (*this)[2][0]); + } + + Matrix inverse() const { + auto det = this->determinant(); + + if (std::abs(det) < 1e-30) { + throw std::runtime_error("this matrix is not invertible"); + } + + auto inverse = Matrix(); + inverse[0][0] = ((*this)[1][1] * (*this)[2][2] - (*this)[2][1] * (*this)[1][2]) / det; + inverse[0][1] = ((*this)[0][2] * (*this)[2][1] - (*this)[0][1] * (*this)[2][2]) / det; + inverse[0][2] = ((*this)[0][1] * (*this)[1][2] - (*this)[0][2] * (*this)[1][1]) / det; + inverse[1][0] = ((*this)[1][2] * (*this)[2][0] - (*this)[1][0] * (*this)[2][2]) / det; + inverse[1][1] = ((*this)[0][0] * (*this)[2][2] - (*this)[0][2] * (*this)[2][0]) / det; + inverse[1][2] = ((*this)[1][0] * (*this)[0][2] - (*this)[0][0] * (*this)[1][2]) / det; + inverse[2][0] = ((*this)[1][0] * (*this)[2][1] - (*this)[2][0] * (*this)[1][1]) / det; + inverse[2][1] = ((*this)[2][0] * (*this)[0][1] - (*this)[0][0] * (*this)[2][1]) / det; + inverse[2][2] = ((*this)[0][0] * (*this)[1][1] - (*this)[1][0] * (*this)[0][1]) / det; + return inverse; + } +}; + + +inline Vector operator*(Matrix matrix, Vector vector) { + return Vector{ + matrix[0][0] * vector[0] + matrix[0][1] * vector[1] + matrix[0][2] * vector[2], + matrix[1][0] * vector[0] + matrix[1][1] * vector[1] + matrix[1][2] * vector[2], + matrix[2][0] * vector[0] + matrix[2][1] * vector[1] + matrix[2][2] * vector[2], + }; +} + +inline Vector operator*(Vector vector, Matrix matrix) { + return Vector{ + vector[0] * matrix[0][0] + vector[1] * matrix[1][0] + vector[2] * matrix[2][0], + vector[0] * matrix[0][1] + vector[1] * matrix[1][1] + vector[2] * matrix[2][1], + vector[0] * matrix[0][2] + vector[1] * matrix[1][2] + vector[2] * matrix[2][2], + }; +} + +} + +#endif + +namespace vesin { + +class BoundingBox { +public: + BoundingBox(Matrix matrix, bool periodic): matrix_(std::move(matrix)), periodic_(periodic) { + if (periodic) { + this->inverse_ = matrix_.inverse(); + } else { + this->matrix_ = Matrix{{{ + {{1, 0, 0}}, + {{0, 1, 0}}, + {{0, 0, 1}} + }}}; + this->inverse_ = matrix_; + } + } + + const Matrix& matrix() const { + return this->matrix_; + } + + bool periodic() const { + return this->periodic_; + } + + /// Convert a vector from cartesian coordinates to fractional coordinates + Vector cartesian_to_fractional(Vector cartesian) const { + return cartesian * inverse_; + } + + /// Convert a vector from fractional coordinates to cartesian coordinates + Vector fractional_to_cartesian(Vector fractional) const { + return fractional * matrix_; + } + + /// Get the three distances between faces of the bounding box + Vector distances_between_faces() const { + auto a = Vector{matrix_[0]}; + auto b = Vector{matrix_[1]}; + auto c = Vector{matrix_[2]}; + + // Plans normal vectors + auto na = b.cross(c).normalize(); + auto nb = c.cross(a).normalize(); + auto nc = a.cross(b).normalize(); + + return Vector{ + std::abs(na.dot(a)), + std::abs(nb.dot(b)), + std::abs(nc.dot(c)), + }; + } + +private: + Matrix matrix_; + Matrix inverse_; + bool periodic_; +}; + + +/// A cell shift represents the displacement along cell axis between the actual +/// position of an atom and a periodic image of this atom. +/// +/// The cell shift can be used to reconstruct the vector between two points, +/// wrapped inside the unit cell. +struct CellShift: public std::array { + /// Compute the shift vector in cartesian coordinates, using the given cell + /// matrix (stored in row major order). + Vector cartesian(Matrix cell) const { + auto vector = Vector{ + static_cast((*this)[0]), + static_cast((*this)[1]), + static_cast((*this)[2]), + }; + return vector * cell; + } +}; + +inline CellShift operator+(CellShift a, CellShift b) { + return CellShift{ + a[0] + b[0], + a[1] + b[1], + a[2] + b[2], + }; +} + +inline CellShift operator-(CellShift a, CellShift b) { + return CellShift{ + a[0] - b[0], + a[1] - b[1], + a[2] - b[2], + }; +} + + + +} + +#endif + +namespace vesin { namespace cpu { + +void free_neighbors(VesinNeighborList& neighbors); + +void neighbors( + const Vector* points, + size_t n_points, + BoundingBox cell, + VesinOptions options, + VesinNeighborList& neighbors +); + + +/// The cell list is used to sort atoms inside bins/cells. +/// +/// The list of potential pairs is then constructed by looking through all +/// neighboring cells (the number of cells to search depends on the cutoff and +/// the size of the cells) for each atom to create pair candidates. +class CellList { +public: + /// Create a new `CellList` for the given bounding box and cutoff, + /// determining all required parameters. + CellList(BoundingBox box, double cutoff); + + /// Add a single point to the cell list at the given `position`. The point + /// is uniquely identified by its `index`. + void add_point(size_t index, Vector position); + + /// Iterate over all possible pairs, calling the given callback every time + template + void foreach_pair(Function callback); + +private: + /// How many cells do we need to look at when searching neighbors to include + /// all neighbors below cutoff + std::array n_search_; + + /// the cells themselves are a list of points & corresponding + /// shift to place the point inside the cell + struct Point { + size_t index; + CellShift shift; + }; + struct Cell: public std::vector {}; + + // raw data for the cells + std::vector cells_; + // shape of the cell array + std::array cells_shape_; + + BoundingBox box_; + + Cell& get_cell(std::array index); +}; + +/// Wrapper around `VesinNeighborsList` that behaves like a std::vector, +/// automatically growing memory allocations. +class GrowableNeighborsList { +public: + VesinNeighborList& neighbors; + size_t capacity; + VesinOptions options; + + size_t length() const { + return neighbors.length; + } + + void increment_length() { + neighbors.length += 1; + } + + void set_pair(size_t index, size_t first, size_t second); + void set_shift(size_t index, vesin::CellShift shift); + void set_distance(size_t index, double distance); + void set_vector(size_t index, vesin::Vector vector); + + // reset length to 0, and allocate/deallocate members of + // `neighbors` according to `options` + void reset(); + + // allocate more memory & update capacity + void grow(); +}; + +}} + +#endif + +using namespace vesin::cpu; + +void vesin::cpu::neighbors( + const Vector* points, + size_t n_points, + BoundingBox cell, + VesinOptions options, + VesinNeighborList& raw_neighbors +) { + auto cell_list = CellList(cell, options.cutoff); + + for (size_t i=0; i second) { + return; + } + + if (first == second) { + // When creating pairs between a point and one of its periodic + // images, the code generate multiple redundant pairs (e.g. with + // shifts 0 1 1 and 0 -1 -1); and we want to only keep one of + // these. + if (shift[0] + shift[1] + shift[2] < 0) { + // drop shifts on the negative half-space + return; + } + + if ((shift[0] + shift[1] + shift[2] == 0) + && (shift[2] < 0 || (shift[2] == 0 && shift[1] < 0))) { + // drop shifts in the negative half plane or the negative + // shift[1] axis. See below for a graphical representation: + // we are keeping the shifts indicated with `O` and dropping + // the ones indicated with `X` + // + // O O O │ O O O + // O O O │ O O O + // O O O │ O O O + // ─X─X─X─┼─O─O─O─ + // X X X │ X X X + // X X X │ X X X + // X X X │ X X X + return; + } + } + } + + auto vector = points[second] - points[first] + shift.cartesian(cell_matrix); + auto distance2 = vector.dot(vector); + + if (distance2 < cutoff2) { + auto index = neighbors.length(); + neighbors.set_pair(index, first, second); + + if (options.return_shifts) { + neighbors.set_shift(index, shift); + } + + if (options.return_distances) { + neighbors.set_distance(index, std::sqrt(distance2)); + } + + if (options.return_vectors) { + neighbors.set_vector(index, vector); + } + + neighbors.increment_length(); + } + }); +} + +/* ========================================================================== */ + +/// Maximal number of cells, we need to use this to prevent having too many +/// cells with a small bounding box and a large cutoff +#define MAX_NUMBER_OF_CELLS 1e5 + + +/// Function to compute both quotient and remainder of the division of a by b. +/// This function follows Python convention, making sure the remainder have the +/// same sign as `b`. +static std::tuple divmod(int32_t a, size_t b) { + assert(b < (std::numeric_limits::max())); + auto b_32 = static_cast(b); + auto quotient = a / b_32; + auto remainder = a % b_32; + if (remainder < 0) { + remainder += b; + quotient -= 1; + } + return std::make_tuple(quotient, remainder); +} + +/// Apply the `divmod` function to three components at the time +static std::tuple, std::array> +divmod(std::array a, std::array b) { + auto [qx, rx] = divmod(a[0], b[0]); + auto [qy, ry] = divmod(a[1], b[1]); + auto [qz, rz] = divmod(a[2], b[2]); + return std::make_tuple( + std::array{qx, qy, qz}, + std::array{rx, ry, rz} + ); +} + +CellList::CellList(BoundingBox box, double cutoff): + n_search_({0, 0, 0}), + cells_shape_({0, 0, 0}), + box_(std::move(box)) +{ + auto distances_between_faces = box_.distances_between_faces(); + + auto n_cells = Vector{ + std::clamp(std::trunc(distances_between_faces[0] / cutoff), 1.0, HUGE_VAL), + std::clamp(std::trunc(distances_between_faces[1] / cutoff), 1.0, HUGE_VAL), + std::clamp(std::trunc(distances_between_faces[2] / cutoff), 1.0, HUGE_VAL), + }; + + assert(std::isfinite(n_cells[0]) && std::isfinite(n_cells[1]) && std::isfinite(n_cells[2])); + + // limit memory consumption by ensuring we have less than `MAX_N_CELLS` + // cells to look though + auto n_cells_total = n_cells[0] * n_cells[1] * n_cells[2]; + if (n_cells_total > MAX_NUMBER_OF_CELLS) { + // set the total number of cells close to MAX_N_CELLS, while keeping + // roughly the ratio of cells in each direction + auto ratio_x_y = n_cells[0] / n_cells[1]; + auto ratio_y_z = n_cells[1] / n_cells[2]; + + n_cells[2] = std::trunc(std::cbrt(MAX_NUMBER_OF_CELLS / (ratio_x_y * ratio_y_z * ratio_y_z))); + n_cells[1] = std::trunc(ratio_y_z * n_cells[2]); + n_cells[0] = std::trunc(ratio_x_y * n_cells[1]); + } + + // number of cells to search in each direction to make sure all possible + // pairs below the cutoff are accounted for. + this->n_search_ = std::array{ + static_cast(std::ceil(cutoff * n_cells[0] / distances_between_faces[0])), + static_cast(std::ceil(cutoff * n_cells[1] / distances_between_faces[1])), + static_cast(std::ceil(cutoff * n_cells[2] / distances_between_faces[2])), + }; + + this->cells_shape_ = std::array{ + static_cast(n_cells[0]), + static_cast(n_cells[1]), + static_cast(n_cells[2]), + }; + + for (size_t spatial=0; spatial<3; spatial++) { + if (n_search_[spatial] < 1) { + n_search_[spatial] = 1; + } + + // don't look for neighboring cells if we have only one cell and no + // periodic boundary condition + if (n_cells[spatial] == 1 && !box.periodic()) { + n_search_[spatial] = 0; + } + } + + this->cells_.resize(cells_shape_[0] * cells_shape_[1] * cells_shape_[2]); +} + +void CellList::add_point(size_t index, Vector position) { + auto fractional = box_.cartesian_to_fractional(position); + + // find the cell in which this atom should go + auto cell_index = std::array{ + static_cast(std::floor(fractional[0] * cells_shape_[0])), + static_cast(std::floor(fractional[1] * cells_shape_[1])), + static_cast(std::floor(fractional[2] * cells_shape_[2])), + }; + + // deal with pbc by wrapping the atom inside if it was outside of the + // cell + CellShift shift; + // auto (shift, cell_index) = + if (box_.periodic()) { + auto result = divmod(cell_index, cells_shape_); + shift = CellShift{std::move(std::get<0>(result))}; + cell_index = std::move(std::get<1>(result)); + } else { + shift = CellShift({0, 0, 0}); + cell_index = std::array{ + std::clamp(cell_index[0], 0, static_cast(cells_shape_[0] - 1)), + std::clamp(cell_index[1], 0, static_cast(cells_shape_[1] - 1)), + std::clamp(cell_index[2], 0, static_cast(cells_shape_[2] - 1)), + }; + } + + this->get_cell(cell_index).emplace_back(Point{index, shift}); +} + + +template +void CellList::foreach_pair(Function callback) { + for (int32_t cell_i_x=0; cell_i_x(cells_shape_[0]); cell_i_x++) { + for (int32_t cell_i_y=0; cell_i_y(cells_shape_[1]); cell_i_y++) { + for (int32_t cell_i_z=0; cell_i_z(cells_shape_[2]); cell_i_z++) { + auto& current_cell = this->get_cell({cell_i_x, cell_i_y, cell_i_z}); + // look through each neighboring cell + for (int32_t delta_x=-n_search_[0]; delta_x<=n_search_[0]; delta_x++) { + for (int32_t delta_y=-n_search_[1]; delta_y<=n_search_[1]; delta_y++) { + for (int32_t delta_z=-n_search_[2]; delta_z<=n_search_[2]; delta_z++) { + auto cell_i = std::array{ + cell_i_x + delta_x, + cell_i_y + delta_y, + cell_i_z + delta_z, + }; + + // shift vector from one cell to the other and index of + // the neighboring cell + auto [cell_shift, neighbor_cell_i] = divmod(cell_i, cells_shape_); + + for (const auto& atom_i: current_cell) { + for (const auto& atom_j: this->get_cell(neighbor_cell_i)) { + auto shift = CellShift{cell_shift} + atom_i.shift - atom_j.shift; + auto shift_is_zero = shift[0] == 0 && shift[1] == 0 && shift[2] == 0; + + if (!box_.periodic() && !shift_is_zero) { + // do not create pairs crossing the periodic + // boundaries in a non-periodic box + continue; + } + + if (atom_i.index == atom_j.index && shift_is_zero) { + // only create pairs with the same atom twice if the + // pair spans more than one bounding box + continue; + } + + callback(atom_i.index, atom_j.index, shift); + } + } // loop over atoms in current neighbor cells + }}} + }}} // loop over neighboring cells +} + +CellList::Cell& CellList::get_cell(std::array index) { + size_t linear_index = (cells_shape_[0] * cells_shape_[1] * index[2]) + + (cells_shape_[0] * index[1]) + + index[0]; + return cells_[linear_index]; +} + +/* ========================================================================== */ + + +void GrowableNeighborsList::set_pair(size_t index, size_t first, size_t second) { + if (index >= this->capacity) { + this->grow(); + } + + this->neighbors.pairs[index][0] = first; + this->neighbors.pairs[index][1] = second; +} + +void GrowableNeighborsList::set_shift(size_t index, vesin::CellShift shift) { + if (index >= this->capacity) { + this->grow(); + } + + this->neighbors.shifts[index][0] = shift[0]; + this->neighbors.shifts[index][1] = shift[1]; + this->neighbors.shifts[index][2] = shift[2]; +} + +void GrowableNeighborsList::set_distance(size_t index, double distance) { + if (index >= this->capacity) { + this->grow(); + } + + this->neighbors.distances[index] = distance; +} + +void GrowableNeighborsList::set_vector(size_t index, vesin::Vector vector) { + if (index >= this->capacity) { + this->grow(); + } + + this->neighbors.vectors[index][0] = vector[0]; + this->neighbors.vectors[index][1] = vector[1]; + this->neighbors.vectors[index][2] = vector[2]; +} + +template +static scalar_t (*alloc(scalar_t (*ptr)[N], size_t size, size_t new_size))[N] { + ptr = reinterpret_cast(std::realloc(ptr, new_size * sizeof(scalar_t[N]))); + + if (ptr == nullptr) { + throw std::bad_alloc(); + } + + // initialize with a bit pattern that maps to NaN for double + std::memset(ptr + size, 0b11111111, (new_size - size) * sizeof(scalar_t[N])); + + return ptr; +} + +template +static scalar_t* alloc(scalar_t* ptr, size_t size, size_t new_size) { + ptr = reinterpret_cast(std::realloc(ptr, new_size * sizeof(scalar_t))); + + if (ptr == nullptr) { + throw std::bad_alloc(); + } + + // initialize with a bit pattern that maps to NaN for double + std::memset(ptr + size, 0b11111111, (new_size - size) * sizeof(scalar_t)); + + return ptr; +} + +void GrowableNeighborsList::grow() { + auto new_size = neighbors.length * 2; + if (new_size == 0) { + new_size = 1; + } + + auto new_pairs = alloc(neighbors.pairs, neighbors.length, new_size); + + int32_t (*new_shifts)[3] = nullptr; + if (options.return_shifts) { + new_shifts = alloc(neighbors.shifts, neighbors.length, new_size); + } + + double *new_distances = nullptr; + if (options.return_distances) { + new_distances = alloc(neighbors.distances, neighbors.length, new_size); + } + + double (*new_vectors)[3] = nullptr; + if (options.return_vectors) { + new_vectors = alloc(neighbors.vectors, neighbors.length, new_size); + } + + this->neighbors.pairs = new_pairs; + this->neighbors.shifts = new_shifts; + this->neighbors.distances = new_distances; + this->neighbors.vectors = new_vectors; + + this->capacity = new_size; +} + +void GrowableNeighborsList::reset() { + // set all allocated data to zero + auto size = this->neighbors.length; + std::memset(this->neighbors.pairs, 0, size * sizeof(size_t[2])); + + if (this->neighbors.shifts != nullptr) { + std::memset(this->neighbors.shifts, 0, size * sizeof(int32_t[3])); + } + + if (this->neighbors.distances != nullptr) { + std::memset(this->neighbors.distances, 0, size * sizeof(double)); + } + + if (this->neighbors.vectors != nullptr) { + std::memset(this->neighbors.vectors, 0, size * sizeof(double[3])); + } + + // reset length (but keep the capacity where it's at) + this->neighbors.length = 0; + + // allocate/deallocate pointers as required + auto shifts = this->neighbors.shifts; + if (this->options.return_shifts && shifts == nullptr) { + shifts = alloc(shifts, 0, capacity); + } else if (!this->options.return_shifts && shifts != nullptr) { + std::free(shifts); + shifts = nullptr; + } + + auto distances = this->neighbors.distances; + if (this->options.return_distances && distances == nullptr) { + distances = alloc(distances, 0, capacity); + } else if (!this->options.return_distances && distances != nullptr) { + std::free(distances); + distances = nullptr; + } + + auto vectors = this->neighbors.vectors; + if (this->options.return_vectors && vectors == nullptr) { + vectors = alloc(vectors, 0, capacity); + } else if (!this->options.return_vectors && vectors != nullptr) { + std::free(vectors); + vectors = nullptr; + } + + this->neighbors.shifts = shifts; + this->neighbors.distances = distances; + this->neighbors.vectors = vectors; +} + + +void vesin::cpu::free_neighbors(VesinNeighborList& neighbors) { + assert(neighbors.device == VesinCPU); + + std::free(neighbors.pairs); + std::free(neighbors.shifts); + std::free(neighbors.vectors); + std::free(neighbors.distances); +} +#include +#include + + + +thread_local std::string LAST_ERROR; + +extern "C" int vesin_neighbors( + const double (*points)[3], + size_t n_points, + const double box[3][3], + bool periodic, + VesinDevice device, + VesinOptions options, + VesinNeighborList* neighbors, + const char** error_message +) { + if (error_message == nullptr) { + return EXIT_FAILURE; + } + + if (points == nullptr) { + *error_message = "`points` can not be a NULL pointer"; + return EXIT_FAILURE; + } + + if (box == nullptr) { + *error_message = "`cell` can not be a NULL pointer"; + return EXIT_FAILURE; + } + + if (neighbors == nullptr) { + *error_message = "`neighbors` can not be a NULL pointer"; + return EXIT_FAILURE; + } + + if (neighbors->device != VesinUnknownDevice && neighbors->device != device) { + *error_message = "`neighbors` device and data `device` do not match, free the neighbors first"; + return EXIT_FAILURE; + } + + if (device == VesinUnknownDevice) { + *error_message = "got an unknown device to use when running simulation"; + return EXIT_FAILURE; + } + + if (neighbors->device == VesinUnknownDevice) { + // initialize the device + neighbors->device = device; + } else if (neighbors->device != device) { + *error_message = "`neighbors.device` and `device` do not match, free the neighbors first"; + return EXIT_FAILURE; + } + + try { + if (device == VesinCPU) { + auto matrix = vesin::Matrix{{{ + {{box[0][0], box[0][1], box[0][2]}}, + {{box[1][0], box[1][1], box[1][2]}}, + {{box[2][0], box[2][1], box[2][2]}}, + }}}; + + vesin::cpu::neighbors( + reinterpret_cast(points), + n_points, + vesin::BoundingBox(matrix, periodic), + options, + *neighbors + ); + } else { + throw std::runtime_error("unknown device " + std::to_string(device)); + } + } catch (const std::bad_alloc&) { + LAST_ERROR = "failed to allocate memory"; + *error_message = LAST_ERROR.c_str(); + return EXIT_FAILURE; + } catch (const std::exception& e) { + LAST_ERROR = e.what(); + *error_message = LAST_ERROR.c_str(); + return EXIT_FAILURE; + } catch (...) { + *error_message = "fatal error: unknown type thrown as exception"; + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} + + +extern "C" void vesin_free(VesinNeighborList* neighbors) { + if (neighbors == nullptr) { + return; + } + + if (neighbors->device == VesinUnknownDevice) { + // nothing to do + } else if (neighbors->device == VesinCPU) { + vesin::cpu::free_neighbors(*neighbors); + } + + std::memset(neighbors, 0, sizeof(VesinNeighborList)); +} diff --git a/src/metatensor/vesin.h b/src/metatensor/vesin.h new file mode 100644 index 0000000000..e1456693ac --- /dev/null +++ b/src/metatensor/vesin.h @@ -0,0 +1,134 @@ +#ifndef VESIN_H +#define VESIN_H + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Options for a neighbors list calculation +typedef struct VesinOptions { + /// Spherical cutoff, only pairs below this cutoff will be included + double cutoff; + /// Should the returned neighbors list be a full list (include both `i -> j` + /// and `j -> i` pairs) or a half list (include only `i -> j`). + bool full; + // TODO: sort option? + + /// Should the returned `VesinNeighborsList` contain `shifts`? + bool return_shifts; + /// Should the returned `VesinNeighborsList` contain `distances`? + bool return_distances; + /// Should the returned `VesinNeighborsList` contain `vector`? + bool return_vectors; +} VesinOptions; + +/// Device on which the data can be +enum VesinDevice { + /// Unknown device, used for default initialization and to indicate no + /// allocated data. + VesinUnknownDevice = 0, + /// CPU device + VesinCPU = 1, +}; + + +/// The actual neighbors list +/// +/// This is organized as a list of pairs, where each pair can contain the +/// following data: +/// +/// - indices of the points in the pair; +/// - distance between points in the pair, accounting for periodic boundary +/// conditions; +/// - vector between points in the pair, accounting for periodic boundary +/// conditions; +/// - periodic shift that created the pair. This is only relevant when using +/// periodic boundary conditions, and contains the number of bounding box we +/// need to cross to create the pair. If the positions of the points are `r_i` +/// and `r_j`, the bounding box is described by a matrix of three vectors `H`, +/// and the periodic shift is `S`, the distance vector for a given pair will +/// be given by `r_ij = r_j - r_i + S @ H`. +/// +/// Under periodic boundary conditions, two atoms can be part of multiple pairs, +/// each pair having a different periodic shift. +typedef struct VesinNeighborsList { +#ifdef __cplusplus + VesinNeighborsList(): + length(0), + device(VesinUnknownDevice), + pairs(nullptr), + shifts(nullptr), + distances(nullptr), + vectors(nullptr) + {} +#endif + + /// Number of pairs in this neighbors list + size_t length; + /// Device used for the data allocations + VesinDevice device; + /// Array of pairs (storing the indices of the first and second point in the + /// pair), containing `length` elements. + size_t (*pairs)[2]; + /// Array of box shifts, one for each `pair`. This is only set if + /// `options.return_pairs` was `true` during the calculation. + int32_t (*shifts)[3]; + /// Array of pair distance (i.e. distance between the two points), one for + /// each pair. This is only set if `options.return_distances` was `true` + /// during the calculation. + double *distances; + /// Array of pair vector (i.e. vector between the two points), one for + /// each pair. This is only set if `options.return_vector` was `true` + /// during the calculation. + double (*vectors)[3]; + + // TODO: custom memory allocators? +} VesinNeighborList; + +/// Free all allocated memory inside a `VesinNeighborsList`, according the it's +/// `device`. +void vesin_free(VesinNeighborList* neighbors); + +/// Compute a neighbors list. +/// +/// The data is returned in a `VesinNeighborsList`. For an initial call, the +/// `VesinNeighborsList` should be zero-initialized (or default-initalized in +/// C++). The `VesinNeighborsList` can be re-used across calls to this functions +/// to re-use memory allocations, and once it is no longer needed, users should +/// call `vesin_free` to release the corresponding memory. +/// +/// @param points positions of all points in the system; +/// @param n_points number of elements in the `points` array +/// @param box bounding box for the system. If the system is non-periodic, +/// this is ignored. This should contain the three vectors of the bounding +/// box, one vector per row of the matrix. +/// @param periodic is the system using periodic boundary conditions? +/// @param device device where the `points` and `box` data is allocated. +/// @param options options for the calculation +/// @param neighbors non-NULL pointer to `VesinNeighborsList` that will be used +/// to store the computed list of neighbors. +/// @param error_message Pointer to a `char*` that wil be set to the error +/// message if this function fails. This does not need to be freed when no +/// longer needed. +int vesin_neighbors( + const double (*points)[3], + size_t n_points, + const double box[3][3], + bool periodic, + VesinDevice device, + VesinOptions options, + VesinNeighborList* neighbors, + const char** error_message +); + + +#ifdef __cplusplus + +} // extern "C" + +#endif + +#endif