Skip to content

Commit

Permalink
Implementation of the metatensor action
Browse files Browse the repository at this point in the history
This seems to be mostly working, we now need to test it and
kick the tires!
  • Loading branch information
Luthaf committed Apr 8, 2024
1 parent 622e424 commit ccc3768
Show file tree
Hide file tree
Showing 11 changed files with 1,780 additions and 26 deletions.
1 change: 1 addition & 0 deletions regtest/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
!/gridtools
!/clusters
!/unittest
!/metatensor
# These files we just want to ignore completely
tmp
report.txt
2 changes: 2 additions & 0 deletions regtest/metatensor/rt-soap/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
extensions/
soap_cv.pt
1 change: 1 addition & 0 deletions regtest/metatensor/rt-soap/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include ../../scripts/test.make
5 changes: 5 additions & 0 deletions regtest/metatensor/rt-soap/config
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
plumed_modules=metatensor
plumed_needs=metatensor
type=driver
# this is to test a different name
arg="--plumed plumed.dat --ixyz crystal_structure.xyz --dump-forces forces --dump-forces-fmt %8.4f --debug-forces forces.num"
178 changes: 178 additions & 0 deletions regtest/metatensor/rt-soap/crystal_structure.xyz
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
176
10.8151 11.1671 12.6671
N 1.28413 0.0356836 12.195
N 4.15648 11.1648 5.88207
N 6.66501 5.56845 0.433575
N 9.55986 5.61216 6.79708
N 3.19304 0.415601 0.79143
N 2.20149 10.7488 7.12262
N 8.66181 5.16664 11.8319
N 7.55431 5.98008 5.60745
O 0.588269 10.7008 10.069
O 4.8823 0.501289 3.74831
O 5.96103 6.07786 2.60388
O 10.243 5.10818 8.93897
O 9.14171 8.47003 10.3675
O 7.07259 2.7051 3.99376
O 3.70069 8.25989 2.31089
O 1.67804 2.88647 8.67235
O 10.6638 1.07613 7.81409
O 5.58016 10.0738 1.4989
O 5.26318 4.50881 4.91221
O 0.188342 6.66055 11.1768
O 2.59101 9.42059 0.233576
O 2.79327 1.7319 6.52566
O 8.0305 7.282 12.4527
O 8.2152 3.85759 6.13971
O 3.82264 2.485 1.51498
O 1.64119 8.66909 7.85617
O 9.20944 3.10331 11.2096
O 7.03065 7.99741 4.82086
O 10.1687 1.63869 1.57326
O 6.06204 9.48052 7.85816
O 4.79055 3.89767 11.091
O 0.625828 7.29331 4.78201
C 2.95238 1.72613 1.06281
C 2.39424 9.44611 7.32799
C 8.34889 3.87138 11.6213
C 7.84063 7.30003 5.3126
C 2.36095 10.6306 0.183825
C 3.05248 0.576928 6.49156
C 7.74299 6.09163 12.4796
C 8.47531 5.07354 6.11657
C 1.11456 1.43798 12.1146
C 4.34028 9.73065 5.75877
C 6.56198 4.11671 0.624847
C 9.70287 7.04878 6.88873
C 1.52986 2.17765 0.678822
C 3.85612 9.03075 7.04374
C 7.00303 3.4499 11.9504
C 9.23522 7.75214 5.61205
C 0.398007 10.3598 11.4259
C 5.00925 0.855939 5.14878
C 5.80886 6.44152 1.24675
C 10.4388 4.72998 7.54811
C 9.68665 10.4586 11.7387
C 6.44378 0.752554 5.38023
C 4.30831 6.3222 0.894521
C 1.08084 4.85311 7.26565
C 9.13597 9.89693 10.3582
C 7.10879 1.28477 4.08154
C 3.70376 6.83428 2.27981
C 1.70103 4.35143 8.59228
C 10.0922 10.4258 9.38103
C 6.11343 0.779889 3.03707
C 4.72635 6.30337 3.28925
C 0.706842 4.86811 9.62774
C 9.63945 0.573215 8.71565
C 6.58754 10.6145 2.36632
C 4.28809 4.99295 3.94499
C 1.17869 6.1342 10.2792
C 1.46661 3.64735 0.563839
C 3.91337 7.45062 6.83304
C 6.89236 1.9212 12.138
C 9.31088 9.26262 5.79701
C 0.49689 1.86112 1.86459
C 4.90182 9.32548 8.20692
C 5.9603 3.7177 10.8188
C 10.2874 7.39666 4.47252
C 1.05177 1.68814 3.25477
C 4.38145 9.48367 9.61965
C 6.45079 3.90566 9.38004
C 9.75263 7.25408 3.08129
C 0.991026 0.195369 3.56821
C 4.32584 10.96 9.93137
C 6.46726 5.35922 9.10466
C 9.80269 5.76208 2.72201
C 0.128048 2.46113 4.27152
C 5.29095 8.70348 10.617
C 5.55851 3.11593 8.44484
C 10.6702 8.04546 2.03703
H 4.10239 11.1459 1.09329
H 1.30213 0.0342548 7.42189
H 9.53281 5.60692 11.5551
H 6.66589 5.53935 5.23088
H 8.64718 8.09618 11.1453
H 7.60533 3.01489 4.78942
H 3.23645 8.63677 1.61018
H 2.20095 2.55921 7.88601
H 0.379193 1.73168 8.24643
H 5.04373 9.3881 1.93904
H 5.78935 3.80089 4.40439
H 10.4567 7.37919 10.7535
H 0.104051 1.73809 11.8608
H 5.30499 9.43917 5.54202
H 5.53299 3.87341 0.812386
H 10.6884 7.36664 7.10621
H 1.78896 1.81187 11.2638
H 3.68439 9.3429 4.9022
H 7.12753 3.75643 1.45471
H 9.08085 7.38722 7.78966
H 0.697568 9.29512 11.6168
H 4.66031 1.90419 5.29401
H 6.09731 7.44726 1.04688
H 10.0692 3.7367 7.3867
H 9.38022 9.89395 12.6339
H 6.79461 1.33138 6.26903
H 3.96278 6.9026 0.0896576
H 1.42227 4.31828 6.41678
H 9.44192 0.322678 11.8104
H 6.81651 10.8324 5.48627
H 3.96238 5.27012 0.847479
H 1.43779 5.92571 7.14938
H 8.11352 10.2845 10.2269
H 8.10148 0.852127 3.86806
H 2.70619 6.46082 2.46064
H 2.7613 4.72198 8.75886
H 10.3267 9.63132 8.60634
H 5.9242 1.48199 2.2331
H 4.8782 7.07518 4.06084
H 0.525666 4.08661 10.4007
H 8.76306 0.369849 8.15227
H 7.48526 10.8057 1.82079
H 3.3422 5.20994 4.52094
H 2.07818 5.96684 10.8643
H 9.41381 1.29284 9.49431
H 6.77874 9.85185 3.13693
H 4.03685 4.26764 3.12524
H 1.38637 6.94525 9.50275
H 2.1846 3.98948 12.4268
H 3.1876 7.08306 6.13166
H 7.60275 1.57516 0.261558
H 8.61406 9.57521 6.56965
H 0.494526 4.01383 0.245591
H 4.94626 7.18663 6.57171
H 5.86353 1.59166 12.4312
H 10.3224 9.5727 6.09657
H 1.77827 4.23366 1.4863
H 3.69279 6.9806 7.78454
H 7.17693 1.38145 11.1517
H 9.05764 9.78475 4.85955
H 2.05845 2.03468 3.34088
H 3.31336 9.105 9.65719
H 7.44442 3.52356 9.32561
H 8.75636 7.64728 3.02387
H 1.7089 10.8043 2.93884
H 3.66568 0.380257 9.29099
H 7.17053 5.95779 9.6717
H 9.10914 5.19244 3.33095
H 1.35776 0.0583836 4.6365
H 4.09743 0.00380453 10.9624
H 6.77825 5.49627 8.02518
H 9.44701 5.64256 1.66552
H 0.0490674 10.9316 3.47795
H 5.36527 0.220938 9.78787
H 5.44707 5.78211 9.23447
H 10.7855 5.34902 2.8569
H 0.46903 2.26044 5.28333
H 4.9299 8.92286 11.6417
H 5.88028 3.32144 7.4085
H 10.3598 7.85218 1.02466
H 0.215689 3.5216 4.08985
H 5.25018 7.63447 10.4261
H 5.58329 2.03707 8.55595
H 10.6224 9.11714 2.23242
H 9.89429 2.11629 4.1664
H 6.30059 9.05835 10.5129
H 4.49554 3.4428 8.48274
H 0.872511 7.66384 2.14979
12 changes: 12 additions & 0 deletions regtest/metatensor/rt-soap/plumed.dat
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
soap: METATENSOR ...
MODEL=soap_cv.pt
EXTENSIONS_DIRECTORY=../extensions
SPECIES=1-40
...


scalar: SUM ARG=soap PERIODIC=NO
BIASVALUE ARG=scalar


PRINT ARG=soap FILE=soap.matx STRIDE=1 FMT=%8.4f
105 changes: 105 additions & 0 deletions regtest/metatensor/rt-soap/soap_cv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import Dict, List, Optional

import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.atomistic import (
MetatensorAtomisticModel,
ModelCapabilities,
ModelMetadata,
ModelOutput,
System,
)
from rascaline.torch import SoapPowerSpectrum


class SOAP_CV(torch.nn.Module):
    """Collective variable projecting SOAP power spectra onto a fixed matrix.

    Computes the SOAP power spectrum for each atom with rascaline, merges all
    species keys into samples/properties, and projects the per-atom feature
    vectors onto ``self.pca_projection`` (a stand-in for a real PCA matrix in
    this test model), returning 3 values per atom.
    """

    def __init__(self, species):
        super().__init__()

        # All unordered pairs of neighbor species (t1 <= t2), used below to
        # move the per-pair keys into the property dimension of the TensorMap.
        self.neighbor_type_pairs = Labels(
            names=["neighbor_1_type", "neighbor_2_type"],
            values=torch.tensor(
                [[t1, t2] for t1 in species for t2 in species if t1 <= t2]
            ),
        )
        self.calculator = SoapPowerSpectrum(
            cutoff=4.0,
            max_angular=6,
            max_radial=6,
            radial_basis={"Gto": {}},
            cutoff_function={"ShiftedCosine": {"width": 0.5}},
            center_atom_weight=1.0,
            atomic_gaussian_width=0.3,
        )

        # Random projection used in place of a trained PCA for testing.
        # 2520 is the SOAP feature size for the hyper-parameters above
        # with 4 species — TODO confirm if species list changes.
        self.pca_projection = torch.rand(2520, 3, dtype=torch.float64)

    def forward(
        self,
        systems: List[System],
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels],
    ) -> Dict[str, TensorMap]:
        """Compute the "collective_variable" output for the given systems.

        Returns an empty dict when that output was not requested. Raises
        ``ValueError`` when the caller asks for a globally-summed (not
        per-atom) output, which this model does not support.
        """
        if "collective_variable" not in outputs:
            return {}

        output = outputs["collective_variable"]

        # BUG FIX: validate the requested layout *before* running the
        # expensive SOAP calculation; the original only checked afterwards,
        # wasting the full calculation when raising.
        if not output.per_atom:
            raise ValueError("per_atom=False is not supported")

        soap = self.calculator(systems, selected_samples=selected_atoms)
        # Flatten the keys: center species become extra samples, neighbor
        # species pairs become extra properties, leaving a single block.
        soap = soap.keys_to_samples("center_type")
        soap = soap.keys_to_properties(self.neighbor_type_pairs)

        soap_block = soap.block()
        projected = soap_block.values @ self.pca_projection

        block = TensorBlock(
            values=projected,
            samples=soap_block.samples,
            components=[],
            properties=Labels("soap_pca", torch.tensor([[0], [1], [2]])),
        )
        cv = TensorMap(keys=Labels("_", torch.tensor([[0]])), blocks=[block])

        return {"collective_variable": cv}


# Build, wrap and export the test model so the PLUMED regression test can
# load it as soap_cv.pt (see plumed.dat in the same directory).
cv = SOAP_CV(species=[1, 6, 7, 8])
cv.eval()


capabilities = ModelCapabilities(
    outputs={
        "collective_variable": ModelOutput(
            quantity="",
            unit="",
            per_atom=True,
            # BUG FIX: was ["postions"] — typo; metatensor's gradient key
            # for positions is spelled "positions".
            explicit_gradients=["positions"],
        )
    },
    interaction_range=4.0,
    supported_devices=["cpu", "cuda"],
    length_unit="nm",
    # NOTE(review): order differs from the species list passed to SOAP_CV
    # above ([1, 6, 7, 8]); kept as-is in case the ordering is meaningful —
    # confirm and harmonize if not.
    atomic_types=[6, 1, 7, 8],
    # dtype=TODO
)

metadata = ModelMetadata(
    name="Collective Variable test",
    description="""
A simple collective variable for testing purposes
""",
    authors=["..."],
    references={
        "implementation": ["ref to SOAP code"],
        "architecture": ["ref to SOAP"],
        "model": ["ref to paper"],
    },
)


model = MetatensorAtomisticModel(cv, metadata, capabilities)
# collect_extensions copies the TorchScript extensions (rascaline, etc.)
# next to the exported model so PLUMED can load them at runtime.
model.export("soap_cv.pt", collect_extensions="extensions")
12 changes: 9 additions & 3 deletions src/metatensor/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ cmake --build . --target install --parallel
METATENSOR_CMAKE_PREFIX=$(python -c "import metatensor; print(metatensor.utils.cmake_prefix_path)")
METATENSOR_PREFIX=$(cd "$METATENSOR_CMAKE_PREFIX/../.." && pwd)

METATENSOR_TORCH_CMAKE_PREFIX=$(python -c "import torch; print(torch.utils.cmake_prefix_path)")
METATENSOR_TORCH_CMAKE_PREFIX=$(python -c "import metatensor.torch; print(metatensor.torch.utils.cmake_prefix_path)")
METATENSOR_TORCH_PREFIX=$(cd "$METATENSOR_TORCH_CMAKE_PREFIX/../.." && pwd)
```

Expand All @@ -61,18 +61,24 @@ METATENSOR_TORCH_PREFIX=$(cd "$METATENSOR_TORCH_CMAKE_PREFIX/../.." && pwd)
```bash
cd <PLUMED/DIR>

# set the rpath to make sure plumed executable will be able to find the right libraries
RPATH="-Wl,-rpath,$TORCH_PREFIX/lib -Wl,-rpath,$METATENSOR_PREFIX/lib -Wl,-rpath,$METATENSOR_TORCH_PREFIX/lib"

# configure PLUMED with metatensor
./configure --enable-libtorch --enable-metatensor --enable-modules=+metatensor \
LDFLAGS="-L$TORCH_PREFIX/lib -L$METATENSOR_PREFIX/lib -L$METATENSOR_TORCH_PREFIX/lib" \
LDFLAGS="-L$TORCH_PREFIX/lib -L$METATENSOR_PREFIX/lib -L$METATENSOR_TORCH_PREFIX/lib $RPATH" \
CPPFLAGS="$TORCH_INCLUDES -I$METATENSOR_PREFIX/include -I$METATENSOR_TORCH_PREFIX/include"

# If you are on Linux and use a pip-installed version of libtorch, or the
# pre-cxx11-ABI build of libtorch, you'll need to add "-D_GLIBCXX_USE_CXX11_ABI=0"
# to the compilation flags:
./configure --enable-libtorch --enable-metatensor --enable-modules=+metatensor \
LDFLAGS="-L$TORCH_PREFIX/lib -L$METATENSOR_PREFIX/lib -L$METATENSOR_TORCH_PREFIX/lib" \
LDFLAGS="-L$TORCH_PREFIX/lib -L$METATENSOR_PREFIX/lib -L$METATENSOR_TORCH_PREFIX/lib $RPATH" \
CPPFLAGS="$TORCH_INCLUDES -I$METATENSOR_PREFIX/include -I$METATENSOR_TORCH_PREFIX/include" \
CXXFLAGS="-D_GLIBCXX_USE_CXX11_ABI=0"

make -j && make install
```


<!-- TODO: explain vesin update process -->
Loading

0 comments on commit ccc3768

Please sign in to comment.