Skip to content

Commit

Permalink
FIXED TESTS AND LINTERS AGAIN!!!!
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Nov 14, 2023
1 parent 4a4ed59 commit 25586a0
Show file tree
Hide file tree
Showing 8 changed files with 429 additions and 341 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ jobs:
python-version: "3.8"
- os: macos-11
python-version: "3.11"
- os: windows-2019
python-version: "3.8"
- os: windows-2019
python-version: "3.11"
#- os: windows-2019
# python-version: "3.8"
#- os: windows-2019
# python-version: "3.11"

steps:
- uses: actions/checkout@v3
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ keywords = [
"Atomistic Simulations",
]
dependencies = [
"scipy",
"torch >= 1.11",
"metatensor[torch]",
"sphericart[torch]",
]
dynamic = ["version"]

Expand Down
85 changes: 48 additions & 37 deletions src/meshlode/fourier.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import torch
import math
import math
from time import time

from typing import Optional
from metatensor.torch import TensorBlock
from .system import System
import torch

from .mesh import Mesh

from time import time

# TODO we don't really need to re-compute the Fourier mesh at each call. one could separate the construction of the grid and the update of the values
# TODO we don't really need to re-compute the Fourier mesh at each call. one could
# separate the construction of the grid and the update of the values
class FourierFilter(torch.nn.Module):
def __init__(self, kspace_filter="coulomb", kzero_value=None):
"""
The `kspace_filter` argument defines a R->R function that is applied to the squared norm of the k vectors
The `kspace_filter` argument defines a R->R function that is applied to the
squared norm of the k vectors
"""

super(FourierFilter, self).__init__()
Expand All @@ -22,28 +21,38 @@ def __init__(self, kspace_filter="coulomb", kzero_value=None):
self.kspace_filter = torch.reciprocal
self.kzero_value = 0.0
else:
self.kspace_filter = kspace_filter
self.kspace_filter = kspace_filter

self.timings=dict(n_eval=0, r2k=0, k2r=0, filter=0,
filter_grid=0, filter_calc=0, filter_prod=0)
self.timings = dict(
n_eval=0,
r2k=0,
k2r=0,
filter=0,
filter_grid=0,
filter_calc=0,
filter_prod=0,
)
pass

def compute_r2k(self, mesh: Mesh) -> Mesh:

k_size = math.pi*2/mesh.spacing
k_mesh = Mesh(torch.eye(3)*k_size, n_channels=mesh.n_channels,
mesh_resolution=k_size/mesh.n_mesh,
mesh_style="rfft",
dtype=torch.complex64)

k_mesh.values[:] = torch.fft.rfftn(mesh.values, norm="ortho", dim=(1,2,3))

k_size = math.pi * 2 / mesh.spacing
k_mesh = Mesh(
torch.eye(3) * k_size,
n_channels=mesh.n_channels,
mesh_resolution=k_size / mesh.n_mesh,
mesh_style="rfft",
dtype=torch.complex64,
)

k_mesh.values[:] = torch.fft.rfftn(mesh.values, norm="ortho", dim=(1, 2, 3))

return k_mesh

def apply_filter(self, k_mesh: Mesh) -> Mesh:
self.timings["filter_grid"] -= time()
kxs, kys, kzs = torch.meshgrid(k_mesh.grid_x, k_mesh.grid_y, k_mesh.grid_z,
indexing="ij")

def apply_filter(self, k_mesh: Mesh) -> Mesh:
self.timings["filter_grid"] -= time()
kxs, kys, kzs = torch.meshgrid(
k_mesh.grid_x, k_mesh.grid_y, k_mesh.grid_z, indexing="ij"
)
self.timings["filter_grid"] += time()

self.timings["filter_calc"] -= time()
Expand All @@ -54,34 +63,36 @@ def apply_filter(self, k_mesh: Mesh) -> Mesh:
self.timings["filter_prod"] -= time()
k_mesh.values *= k_filter
if self.kzero_value is not None:
k_mesh.values[:,0,0,0] = self.kzero_value
k_mesh.values[:, 0, 0, 0] = self.kzero_value
self.timings["filter_prod"] += time()

pass

def compute_k2r(self, k_mesh: Mesh) -> Mesh:
box_size = math.pi * 2 / k_mesh.spacing
mesh = Mesh(
torch.eye(3) * box_size,
k_mesh.n_channels,
mesh_resolution=box_size / k_mesh.n_mesh,
dtype=torch.float64,
)

mesh.values[:] = torch.fft.irfftn(k_mesh.values, norm="ortho", dim=(1, 2, 3))

box_size = math.pi*2/k_mesh.spacing
mesh = Mesh(torch.eye(3)*box_size, k_mesh.n_channels, mesh_resolution=box_size/k_mesh.n_mesh, dtype=torch.float64)

mesh.values[:] = torch.fft.irfftn(k_mesh.values, norm="ortho", dim=(1,2,3))

return mesh

def forward(self, mesh:Mesh) -> Mesh:

self.timings["n_eval"]+=1
def forward(self, mesh: Mesh) -> Mesh:
self.timings["n_eval"] += 1
self.timings["r2k"] -= time()
k_mesh = self.compute_r2k(mesh)
self.timings["r2k"] += time()

self.timings["filter"] -= time()
self.apply_filter(k_mesh)
self.timings["filter"] += time()

self.timings["k2r"] -= time()
rval=self.compute_k2r(k_mesh)
rval = self.compute_k2r(k_mesh)
self.timings["k2r"] += time()

return rval

Loading

0 comments on commit 25586a0

Please sign in to comment.