forked from plumed/plumed2
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implementation of the metatensor action
This seems to be mostly working, we now need to test it and quick the tires!
- Loading branch information
Showing
11 changed files
with
1,780 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,6 +41,7 @@ | |
!/gridtools | ||
!/clusters | ||
!/unittest | ||
!/metatensor | ||
# These files we just want to ignore completely | ||
tmp | ||
report.txt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
extensions/ | ||
soap_cv.pt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
include ../../scripts/test.make |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
plumed_modules=metatensor | ||
plumed_needs=metatensor | ||
type=driver | ||
# this is to test a different name | ||
arg="--plumed plumed.dat --ixyz crystal_structure.xyz --dump-forces forces --dump-forces-fmt %8.4f --debug-forces forces.num" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
176 | ||
10.8151 11.1671 12.6671 | ||
N 1.28413 0.0356836 12.195 | ||
N 4.15648 11.1648 5.88207 | ||
N 6.66501 5.56845 0.433575 | ||
N 9.55986 5.61216 6.79708 | ||
N 3.19304 0.415601 0.79143 | ||
N 2.20149 10.7488 7.12262 | ||
N 8.66181 5.16664 11.8319 | ||
N 7.55431 5.98008 5.60745 | ||
O 0.588269 10.7008 10.069 | ||
O 4.8823 0.501289 3.74831 | ||
O 5.96103 6.07786 2.60388 | ||
O 10.243 5.10818 8.93897 | ||
O 9.14171 8.47003 10.3675 | ||
O 7.07259 2.7051 3.99376 | ||
O 3.70069 8.25989 2.31089 | ||
O 1.67804 2.88647 8.67235 | ||
O 10.6638 1.07613 7.81409 | ||
O 5.58016 10.0738 1.4989 | ||
O 5.26318 4.50881 4.91221 | ||
O 0.188342 6.66055 11.1768 | ||
O 2.59101 9.42059 0.233576 | ||
O 2.79327 1.7319 6.52566 | ||
O 8.0305 7.282 12.4527 | ||
O 8.2152 3.85759 6.13971 | ||
O 3.82264 2.485 1.51498 | ||
O 1.64119 8.66909 7.85617 | ||
O 9.20944 3.10331 11.2096 | ||
O 7.03065 7.99741 4.82086 | ||
O 10.1687 1.63869 1.57326 | ||
O 6.06204 9.48052 7.85816 | ||
O 4.79055 3.89767 11.091 | ||
O 0.625828 7.29331 4.78201 | ||
C 2.95238 1.72613 1.06281 | ||
C 2.39424 9.44611 7.32799 | ||
C 8.34889 3.87138 11.6213 | ||
C 7.84063 7.30003 5.3126 | ||
C 2.36095 10.6306 0.183825 | ||
C 3.05248 0.576928 6.49156 | ||
C 7.74299 6.09163 12.4796 | ||
C 8.47531 5.07354 6.11657 | ||
C 1.11456 1.43798 12.1146 | ||
C 4.34028 9.73065 5.75877 | ||
C 6.56198 4.11671 0.624847 | ||
C 9.70287 7.04878 6.88873 | ||
C 1.52986 2.17765 0.678822 | ||
C 3.85612 9.03075 7.04374 | ||
C 7.00303 3.4499 11.9504 | ||
C 9.23522 7.75214 5.61205 | ||
C 0.398007 10.3598 11.4259 | ||
C 5.00925 0.855939 5.14878 | ||
C 5.80886 6.44152 1.24675 | ||
C 10.4388 4.72998 7.54811 | ||
C 9.68665 10.4586 11.7387 | ||
C 6.44378 0.752554 5.38023 | ||
C 4.30831 6.3222 0.894521 | ||
C 1.08084 4.85311 7.26565 | ||
C 9.13597 9.89693 10.3582 | ||
C 7.10879 1.28477 4.08154 | ||
C 3.70376 6.83428 2.27981 | ||
C 1.70103 4.35143 8.59228 | ||
C 10.0922 10.4258 9.38103 | ||
C 6.11343 0.779889 3.03707 | ||
C 4.72635 6.30337 3.28925 | ||
C 0.706842 4.86811 9.62774 | ||
C 9.63945 0.573215 8.71565 | ||
C 6.58754 10.6145 2.36632 | ||
C 4.28809 4.99295 3.94499 | ||
C 1.17869 6.1342 10.2792 | ||
C 1.46661 3.64735 0.563839 | ||
C 3.91337 7.45062 6.83304 | ||
C 6.89236 1.9212 12.138 | ||
C 9.31088 9.26262 5.79701 | ||
C 0.49689 1.86112 1.86459 | ||
C 4.90182 9.32548 8.20692 | ||
C 5.9603 3.7177 10.8188 | ||
C 10.2874 7.39666 4.47252 | ||
C 1.05177 1.68814 3.25477 | ||
C 4.38145 9.48367 9.61965 | ||
C 6.45079 3.90566 9.38004 | ||
C 9.75263 7.25408 3.08129 | ||
C 0.991026 0.195369 3.56821 | ||
C 4.32584 10.96 9.93137 | ||
C 6.46726 5.35922 9.10466 | ||
C 9.80269 5.76208 2.72201 | ||
C 0.128048 2.46113 4.27152 | ||
C 5.29095 8.70348 10.617 | ||
C 5.55851 3.11593 8.44484 | ||
C 10.6702 8.04546 2.03703 | ||
H 4.10239 11.1459 1.09329 | ||
H 1.30213 0.0342548 7.42189 | ||
H 9.53281 5.60692 11.5551 | ||
H 6.66589 5.53935 5.23088 | ||
H 8.64718 8.09618 11.1453 | ||
H 7.60533 3.01489 4.78942 | ||
H 3.23645 8.63677 1.61018 | ||
H 2.20095 2.55921 7.88601 | ||
H 0.379193 1.73168 8.24643 | ||
H 5.04373 9.3881 1.93904 | ||
H 5.78935 3.80089 4.40439 | ||
H 10.4567 7.37919 10.7535 | ||
H 0.104051 1.73809 11.8608 | ||
H 5.30499 9.43917 5.54202 | ||
H 5.53299 3.87341 0.812386 | ||
H 10.6884 7.36664 7.10621 | ||
H 1.78896 1.81187 11.2638 | ||
H 3.68439 9.3429 4.9022 | ||
H 7.12753 3.75643 1.45471 | ||
H 9.08085 7.38722 7.78966 | ||
H 0.697568 9.29512 11.6168 | ||
H 4.66031 1.90419 5.29401 | ||
H 6.09731 7.44726 1.04688 | ||
H 10.0692 3.7367 7.3867 | ||
H 9.38022 9.89395 12.6339 | ||
H 6.79461 1.33138 6.26903 | ||
H 3.96278 6.9026 0.0896576 | ||
H 1.42227 4.31828 6.41678 | ||
H 9.44192 0.322678 11.8104 | ||
H 6.81651 10.8324 5.48627 | ||
H 3.96238 5.27012 0.847479 | ||
H 1.43779 5.92571 7.14938 | ||
H 8.11352 10.2845 10.2269 | ||
H 8.10148 0.852127 3.86806 | ||
H 2.70619 6.46082 2.46064 | ||
H 2.7613 4.72198 8.75886 | ||
H 10.3267 9.63132 8.60634 | ||
H 5.9242 1.48199 2.2331 | ||
H 4.8782 7.07518 4.06084 | ||
H 0.525666 4.08661 10.4007 | ||
H 8.76306 0.369849 8.15227 | ||
H 7.48526 10.8057 1.82079 | ||
H 3.3422 5.20994 4.52094 | ||
H 2.07818 5.96684 10.8643 | ||
H 9.41381 1.29284 9.49431 | ||
H 6.77874 9.85185 3.13693 | ||
H 4.03685 4.26764 3.12524 | ||
H 1.38637 6.94525 9.50275 | ||
H 2.1846 3.98948 12.4268 | ||
H 3.1876 7.08306 6.13166 | ||
H 7.60275 1.57516 0.261558 | ||
H 8.61406 9.57521 6.56965 | ||
H 0.494526 4.01383 0.245591 | ||
H 4.94626 7.18663 6.57171 | ||
H 5.86353 1.59166 12.4312 | ||
H 10.3224 9.5727 6.09657 | ||
H 1.77827 4.23366 1.4863 | ||
H 3.69279 6.9806 7.78454 | ||
H 7.17693 1.38145 11.1517 | ||
H 9.05764 9.78475 4.85955 | ||
H 2.05845 2.03468 3.34088 | ||
H 3.31336 9.105 9.65719 | ||
H 7.44442 3.52356 9.32561 | ||
H 8.75636 7.64728 3.02387 | ||
H 1.7089 10.8043 2.93884 | ||
H 3.66568 0.380257 9.29099 | ||
H 7.17053 5.95779 9.6717 | ||
H 9.10914 5.19244 3.33095 | ||
H 1.35776 0.0583836 4.6365 | ||
H 4.09743 0.00380453 10.9624 | ||
H 6.77825 5.49627 8.02518 | ||
H 9.44701 5.64256 1.66552 | ||
H 0.0490674 10.9316 3.47795 | ||
H 5.36527 0.220938 9.78787 | ||
H 5.44707 5.78211 9.23447 | ||
H 10.7855 5.34902 2.8569 | ||
H 0.46903 2.26044 5.28333 | ||
H 4.9299 8.92286 11.6417 | ||
H 5.88028 3.32144 7.4085 | ||
H 10.3598 7.85218 1.02466 | ||
H 0.215689 3.5216 4.08985 | ||
H 5.25018 7.63447 10.4261 | ||
H 5.58329 2.03707 8.55595 | ||
H 10.6224 9.11714 2.23242 | ||
H 9.89429 2.11629 4.1664 | ||
H 6.30059 9.05835 10.5129 | ||
H 4.49554 3.4428 8.48274 | ||
H 0.872511 7.66384 2.14979 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
soap: METATENSOR ... | ||
MODEL=soap_cv.pt | ||
EXTENSIONS_DIRECTORY=../extensions | ||
SPECIES=1-40 | ||
... | ||
|
||
|
||
scalar: SUM ARG=soap PERIODIC=NO | ||
BIASVALUE ARG=scalar | ||
|
||
|
||
PRINT ARG=soap FILE=soap.matx STRIDE=1 FMT=%8.4f |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from typing import Dict, List, Optional | ||
|
||
import torch | ||
from metatensor.torch import Labels, TensorBlock, TensorMap | ||
from metatensor.torch.atomistic import ( | ||
MetatensorAtomisticModel, | ||
ModelCapabilities, | ||
ModelMetadata, | ||
ModelOutput, | ||
System, | ||
) | ||
from rascaline.torch import SoapPowerSpectrum | ||
|
||
|
||
class SOAP_CV(torch.nn.Module): | ||
def __init__(self, species): | ||
super().__init__() | ||
|
||
self.neighbor_type_pairs = Labels( | ||
names=["neighbor_1_type", "neighbor_2_type"], | ||
values=torch.tensor( | ||
[[t1, t2] for t1 in species for t2 in species if t1 <= t2] | ||
), | ||
) | ||
self.calculator = SoapPowerSpectrum( | ||
cutoff=4.0, | ||
max_angular=6, | ||
max_radial=6, | ||
radial_basis={"Gto": {}}, | ||
cutoff_function={"ShiftedCosine": {"width": 0.5}}, | ||
center_atom_weight=1.0, | ||
atomic_gaussian_width=0.3, | ||
) | ||
|
||
self.pca_projection = torch.rand(2520, 3, dtype=torch.float64) | ||
|
||
def forward( | ||
self, | ||
systems: List[System], | ||
outputs: Dict[str, ModelOutput], | ||
selected_atoms: Optional[Labels], | ||
) -> Dict[str, TensorMap]: | ||
|
||
if "collective_variable" not in outputs: | ||
return {} | ||
|
||
output = outputs["collective_variable"] | ||
|
||
soap = self.calculator(systems, selected_samples=selected_atoms) | ||
soap = soap.keys_to_samples("center_type") | ||
soap = soap.keys_to_properties(self.neighbor_type_pairs) | ||
|
||
if not output.per_atom: | ||
raise ValueError("per_atom=False is not supported") | ||
|
||
soap_block = soap.block() | ||
projected = soap_block.values @ self.pca_projection | ||
|
||
block = TensorBlock( | ||
values=projected, | ||
samples=soap_block.samples, | ||
components=[], | ||
properties=Labels("soap_pca", torch.tensor([[0], [1], [2]])), | ||
) | ||
cv = TensorMap(keys=Labels("_", torch.tensor([[0]])), blocks=[block]) | ||
|
||
return {"collective_variable": cv} | ||
|
||
|
||
cv = SOAP_CV(species=[1, 6, 7, 8]) | ||
cv.eval() | ||
|
||
|
||
capabilites = ModelCapabilities( | ||
outputs={ | ||
"collective_variable": ModelOutput( | ||
quantity="", | ||
unit="", | ||
per_atom=True, | ||
explicit_gradients=["postions"], | ||
) | ||
}, | ||
interaction_range=4.0, | ||
supported_devices=["cpu", "cuda"], | ||
length_unit="nm", | ||
atomic_types=[6, 1, 7, 8], | ||
# dtype=TODO | ||
) | ||
|
||
metadata = ModelMetadata( | ||
name="Collective Variable test", | ||
description=""" | ||
A simple collective variable for testing purposes | ||
""", | ||
authors=["..."], | ||
references={ | ||
"implementation": ["ref to SOAP code"], | ||
"architecture": ["ref to SOAP"], | ||
"model": ["ref to paper"], | ||
}, | ||
) | ||
|
||
|
||
model = MetatensorAtomisticModel(cv, metadata, capabilites) | ||
model.export("soap_cv.pt", collect_extensions="extensions") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.