Skip to content

Commit

Permalink
lint and test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ceriottm committed Dec 3, 2023
1 parent a5ae480 commit 276686a
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 61 deletions.
138 changes: 78 additions & 60 deletions examples/lode-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,173 +14,191 @@
"""


# %%
# %%
# Loads libraries and defines utility functions
# ---------------------------------------------
#
#

import numpy as np
import matplotlib.pyplot as plt
import torch

torch.set_default_dtype(torch.float64)

import chemiscope, ase
import meshlode as ml


# plot a 3D mesh with a stack of 2D plots
def sliceplot(mesh, sz=12, cmap="viridis", vmin=None, vmax=None):
mesh = mesh.detach().numpy()
if vmin is None:
vmin = mesh.min()
if vmax is None:
vmax = mesh.max()
fig, ax = plt.subplots(1,mesh.shape[-1],figsize=(sz,sz/mesh.shape[-1]), sharey=True, constrained_layout=True)
fig, ax = plt.subplots(
1,
mesh.shape[-1],
figsize=(sz, sz / mesh.shape[-1]),
sharey=True,
constrained_layout=True,
)
for i in range(mesh.shape[-1]):
ax[i].matshow(mesh[:,:,i], vmin=vmin, vmax=vmax, cmap=cmap)
ax[i].matshow(mesh[:, :, i], vmin=vmin, vmax=vmax, cmap=cmap)
ax[i].set_xticklabels([])
ax[i].set_yticklabels([])


# %%
# %%
# Builds the structure
# --------------------
#
#
# Nothing special to see here. Builds a CsCl structure by replicating the
# primitive cell. Add a bit of noise to make it less boring!
#
#

positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]])*4
positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) * 4
atomic_numbers = torch.tensor([55, 17]) # Cs and Cl
cell = torch.eye(3)*4
ase_frame = ase.Atoms(positions=positions, cell=cell, numbers=atomic_numbers).repeat([2,2,2])
ase_frame.positions[:] += np.random.normal(size=ase_frame.positions.shape)*0.1
charges = torch.tensor([1.0, -1.0]*8)
frame = ml.System(species = torch.tensor(ase_frame.numbers),
positions = torch.tensor(np.array(ase_frame.positions)),
cell = torch.tensor(ase_frame.cell))

cs = chemiscope.show(frames=[ase_frame],
mode="structure", settings={"structure":[{"unitCell":True, "axes":"xyz"}]})
cell = torch.eye(3) * 4
ase_frame = ase.Atoms(positions=positions, cell=cell, numbers=atomic_numbers).repeat(
[2, 2, 2]
)
ase_frame.positions[:] += np.random.normal(size=ase_frame.positions.shape) * 0.1
charges = torch.tensor([1.0, -1.0] * 8)
frame = ml.System(
species=torch.tensor(ase_frame.numbers),
positions=torch.tensor(np.array(ase_frame.positions)),
cell=torch.tensor(ase_frame.cell),
)

cs = chemiscope.show(
frames=[ase_frame],
mode="structure",
settings={"structure": [{"unitCell": True, "axes": "xyz"}]},
)

if chemiscope.jupyter._is_running_in_notebook():
from IPython.display import display

display(cs)
else:
cs.save("cscl.json.gz")


# %%
# %%
# MeshInterpolator
# ----------------
#
#
# ``MeshInterpolator`` serves as a utility class to compute a mesh
# representation of points, and/or to project a function defined on the
# mesh on a set of points. Computing the mesh representation is a two-step
# procedure. First, the weights associated with the interpolation of the
# point positions are evaluated, then they are combined with one or more
# list of atom weights to yield the mesh values.
#
#

interpol = ml.mesh_interpolator.MeshInterpolator(frame.cell,
torch.tensor([16,16,16]),
interpolation_order=3)
interpol = ml.mesh_interpolator.MeshInterpolator(
frame.cell, torch.tensor([16, 16, 16]), interpolation_order=3
)

interpol.compute_interpolation_weights(frame.positions)


# %%
# %%
# We use two sets of weights: ones (giving the atom density irrespective
# of the species) and charges (giving a smooth representation of the point
# charges).
#
#

atom_weights = torch.ones((len(charges),2))
atom_weights[:,1] = charges
atom_weights = torch.ones((len(charges), 2))
atom_weights[:, 1] = charges
mesh = interpol.points_to_mesh(atom_weights)

# there are two densities
mesh.shape

# %%
# the first corresponds to plain density
sliceplot(mesh[0,:,:,:5])
sliceplot(mesh[0, :, :, :5])

# %%
# the second to the charge-weighted one
sliceplot(mesh[1,:,:,:5], cmap="seismic", vmax=1, vmin=-1)
sliceplot(mesh[1, :, :, :5], cmap="seismic", vmax=1, vmin=-1)


# %%
# %%
# Fourier filter
# --------------
#
#
# This module computes a Fourier-domain filter, that can be used e.g. to
# smear the density and/or compute a 1/r^p potential field. This can also
# be easily extended to compute an arbitrary filter
#
#

fsc = ml.fourier_convolution.FourierSpaceConvolution(frame.cell)

# %%
# plain smearing
rho_mesh = fsc.compute(mesh,
potential_exponent=0,
smearing=1)
# plain smearing
rho_mesh = fsc.compute(mesh, potential_exponent=0, smearing=1)

sliceplot(rho_mesh[0,:,:,:5])
sliceplot(rho_mesh[0, :, :, :5])

# %%
# coulomb-like potential, no smearing
coulomb_mesh = fsc.compute(mesh,
potential_exponent=1,
smearing=0)
coulomb_mesh = fsc.compute(mesh, potential_exponent=1, smearing=0)

sliceplot(coulomb_mesh[1,:,:,:5], cmap="seismic")
sliceplot(coulomb_mesh[1, :, :, :5], cmap="seismic")


# %%
# %%
# Back-interpolation (on the same points)
# ---------------------------------------
#
#
# The same ``MeshInterpolator`` object can be used to compute a field on
# the same points used initially to generate the atom density
#
#

potentials = interpol.mesh_to_points(coulomb_mesh)

potentials


# %%
# %%
# Back-interpolation (on different points)
# ----------------------------------------
#
#
# In order to compute the field on a different set of points, it is
# sufficient to build another ``MeshInterpolator`` object and to compute
# it with the desired field. One can also use a different
# ``interpolation_order``, if wanted.
#
#

interpol_slice = ml.mesh_interpolator.MeshInterpolator(frame.cell,
torch.tensor([16,16,16]),
interpolation_order=4)
interpol_slice = ml.mesh_interpolator.MeshInterpolator(
frame.cell, torch.tensor([16, 16, 16]), interpolation_order=4
)

# Compute a denser grid on a 2D slice
n_points=50
x = torch.linspace(0, frame.cell[0,0], n_points+1)[:n_points]
y = torch.linspace(0, frame.cell[1,1], n_points+1)[:n_points]
xx, yy = torch.meshgrid(x, y, indexing='ij')
n_points = 50
x = torch.linspace(0, frame.cell[0, 0], n_points + 1)[:n_points]
y = torch.linspace(0, frame.cell[1, 1], n_points + 1)[:n_points]
xx, yy = torch.meshgrid(x, y, indexing="ij")

# Flatten xx and yy, and concatenate with a zero column for the z-coordinate
slice_points = torch.cat((xx.reshape(-1, 1), yy.reshape(-1, 1),
0.5*torch.ones(n_points**2, 1)), dim=1)
slice_points = torch.cat(
(xx.reshape(-1, 1), yy.reshape(-1, 1), 0.5 * torch.ones(n_points**2, 1)), dim=1
)


# %%
interpol_slice.compute_interpolation_weights(slice_points)

coulomb_slice = interpol_slice.mesh_to_points(coulomb_mesh)

plt.contourf(xx, yy,
coulomb_slice[:,1].reshape(n_points, n_points).T,
cmap="seismic", vmin=-1, vmax=1);
plt.contourf(
xx,
yy,
coulomb_slice[:, 1].reshape(n_points, n_points).T,
cmap="seismic",
vmin=-1,
vmax=1,
)
4 changes: 3 additions & 1 deletion tests/test_fourier_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def test_convolution_for_delta(self, cell, mesh_vals):
n_channels, nx, ny, nz = mesh_vals.shape
n_fft = nx * ny * nz
FSC = FourierSpaceConvolution(cell)
mesh_vals_new = FSC.compute(mesh_vals, potential_exponent=0) * volume / n_fft
mesh_vals_new = (
FSC.compute(mesh_vals, potential_exponent=0, smearing=0.0) * volume / n_fft
)

assert_close(mesh_vals, mesh_vals_new, rtol=1e-4, atol=1e-6)

0 comments on commit 276686a

Please sign in to comment.