diff --git a/examples/lode-demo.py b/examples/lode-demo.py index c4c24477..d510a6c6 100644 --- a/examples/lode-demo.py +++ b/examples/lode-demo.py @@ -14,19 +14,21 @@ """ -# %% +# %% # Loads libraries and defines utility functions # --------------------------------------------- -# +# import numpy as np import matplotlib.pyplot as plt import torch + torch.set_default_dtype(torch.float64) import chemiscope, ase import meshlode as ml + # plot a 3D mesh with a stack of 2D plots def sliceplot(mesh, sz=12, cmap="viridis", vmin=None, vmax=None): mesh = mesh.detach().numpy() @@ -34,68 +36,82 @@ def sliceplot(mesh, sz=12, cmap="viridis", vmin=None, vmax=None): vmin = mesh.min() if vmax is None: vmax = mesh.max() - fig, ax = plt.subplots(1,mesh.shape[-1],figsize=(sz,sz/mesh.shape[-1]), sharey=True, constrained_layout=True) + fig, ax = plt.subplots( + 1, + mesh.shape[-1], + figsize=(sz, sz / mesh.shape[-1]), + sharey=True, + constrained_layout=True, + ) for i in range(mesh.shape[-1]): - ax[i].matshow(mesh[:,:,i], vmin=vmin, vmax=vmax, cmap=cmap) + ax[i].matshow(mesh[:, :, i], vmin=vmin, vmax=vmax, cmap=cmap) ax[i].set_xticklabels([]) ax[i].set_yticklabels([]) -# %% +# %% # Builds the structure # -------------------- -# +# # Nothing special to see here. Builds a CsCl structure by replicating the # primitive cell. Add a bit of noise to make it less boring! -# +# -positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]])*4 +positions = torch.tensor([[0, 0, 0], [0.5, 0.5, 0.5]]) * 4 atomic_numbers = torch.tensor([55, 17]) # Cs and Cl -cell = torch.eye(3)*4 -ase_frame = ase.Atoms(positions=positions, cell=cell, numbers=atomic_numbers).repeat([2,2,2]) -ase_frame.positions[:] += np.random.normal(size=ase_frame.positions.shape)*0.1 -charges = torch.tensor([1.0, -1.0]*8) -frame = ml.System(species = torch.tensor(ase_frame.numbers), - positions = torch.tensor(np.array(ase_frame.positions)), - cell = torch.tensor(ase_frame.cell)) - -cs = chemiscope.show(frames=[ase_frame], - mode="structure", settings={"structure":[{"unitCell":True, "axes":"xyz"}]}) +cell = torch.eye(3) * 4 +ase_frame = ase.Atoms(positions=positions, cell=cell, numbers=atomic_numbers).repeat( + [2, 2, 2] +) +ase_frame.positions[:] += np.random.normal(size=ase_frame.positions.shape) * 0.1 +charges = torch.tensor([1.0, -1.0] * 8) +frame = ml.System( + species=torch.tensor(ase_frame.numbers), + positions=torch.tensor(np.array(ase_frame.positions)), + cell=torch.tensor(ase_frame.cell), +) + +cs = chemiscope.show( + frames=[ase_frame], + mode="structure", + settings={"structure": [{"unitCell": True, "axes": "xyz"}]}, +) if chemiscope.jupyter._is_running_in_notebook(): from IPython.display import display + display(cs) else: cs.save("cscl.json.gz") -# %% +# %% # MeshInterpolator # ---------------- -# +# # ``MeshInterpolator`` serves as a utility class to compute a mesh # representation of points, and/or to project a function defined on the # mesh on a set of points. Computing the mesh representation is a two-step # procedure. First, the weights associated with the interpolation of the # point positions are evaluated, then they are combined with one or more # list of atom weights to yield the mesh values. -# +# -interpol = ml.mesh_interpolator.MeshInterpolator(frame.cell, - torch.tensor([16,16,16]), - interpolation_order=3) +interpol = ml.mesh_interpolator.MeshInterpolator( + frame.cell, torch.tensor([16, 16, 16]), interpolation_order=3 +) interpol.compute_interpolation_weights(frame.positions) -# %% +# %% # We use two sets of weights: ones (giving the atom density irrespective # of the species) and charges (giving a smooth representation of the point # charges). -# +# -atom_weights = torch.ones((len(charges),2)) -atom_weights[:,1] = charges +atom_weights = torch.ones((len(charges), 2)) +atom_weights[:, 1] = charges mesh = interpol.points_to_mesh(atom_weights) # there are two densities @@ -103,77 +119,74 @@ def sliceplot(mesh, sz=12, cmap="viridis", vmin=None, vmax=None): # %% # the first corresponds to plain density -sliceplot(mesh[0,:,:,:5]) +sliceplot(mesh[0, :, :, :5]) # %% # the second to the charge-weighted one -sliceplot(mesh[1,:,:,:5], cmap="seismic", vmax=1, vmin=-1) +sliceplot(mesh[1, :, :, :5], cmap="seismic", vmax=1, vmin=-1) -# %% +# %% # Fourier filter # -------------- -# +# # This module computes a Fourier-domain filter, that can be used e.g. to # smear the density and/or compute a 1/r^p potential field. This can also # be easily extended to compute an arbitrary filter -# +# fsc = ml.fourier_convolution.FourierSpaceConvolution(frame.cell) # %% -# plain smearing -rho_mesh = fsc.compute(mesh, - potential_exponent=0, - smearing=1) +# plain smearing +rho_mesh = fsc.compute(mesh, potential_exponent=0, smearing=1) -sliceplot(rho_mesh[0,:,:,:5]) +sliceplot(rho_mesh[0, :, :, :5]) # %% # coulomb-like potential, no smearing -coulomb_mesh = fsc.compute(mesh, - potential_exponent=1, - smearing=0) +coulomb_mesh = fsc.compute(mesh, potential_exponent=1, smearing=0) -sliceplot(coulomb_mesh[1,:,:,:5], cmap="seismic") +sliceplot(coulomb_mesh[1, :, :, :5], cmap="seismic") -# %% +# %% # Back-interpolation (on the same points) # --------------------------------------- -# +# # The same ``MeshInterpolator`` object can be used to compute a field on # the same points used initially to generate the atom density -# +# potentials = interpol.mesh_to_points(coulomb_mesh) potentials -# %% +# %% # Back-interpolation (on different points) # ---------------------------------------- -# +# # In order to compute the field on a different set of points, it is # sufficient to build another ``MeshInterpolator`` object and to compute # it with the desired field. One can also use a different # ``interpolation_order``, if wanted. -# +# -interpol_slice = ml.mesh_interpolator.MeshInterpolator(frame.cell, - torch.tensor([16,16,16]), - interpolation_order=4) +interpol_slice = ml.mesh_interpolator.MeshInterpolator( + frame.cell, torch.tensor([16, 16, 16]), interpolation_order=4 +) # Compute a denser grid on a 2D slice -n_points=50 -x = torch.linspace(0, frame.cell[0,0], n_points+1)[:n_points] -y = torch.linspace(0, frame.cell[1,1], n_points+1)[:n_points] -xx, yy = torch.meshgrid(x, y, indexing='ij') +n_points = 50 +x = torch.linspace(0, frame.cell[0, 0], n_points + 1)[:n_points] +y = torch.linspace(0, frame.cell[1, 1], n_points + 1)[:n_points] +xx, yy = torch.meshgrid(x, y, indexing="ij") # Flatten xx and yy, and concatenate with a zero column for the z-coordinate -slice_points = torch.cat((xx.reshape(-1, 1), yy.reshape(-1, 1), - 0.5*torch.ones(n_points**2, 1)), dim=1) +slice_points = torch.cat( + (xx.reshape(-1, 1), yy.reshape(-1, 1), 0.5 * torch.ones(n_points**2, 1)), dim=1 +) # %% @@ -181,6 +194,11 @@ def sliceplot(mesh, sz=12, cmap="viridis", vmin=None, vmax=None): coulomb_slice = interpol_slice.mesh_to_points(coulomb_mesh) -plt.contourf(xx, yy, - coulomb_slice[:,1].reshape(n_points, n_points).T, - cmap="seismic", vmin=-1, vmax=1); +plt.contourf( + xx, + yy, + coulomb_slice[:, 1].reshape(n_points, n_points).T, + cmap="seismic", + vmin=-1, + vmax=1, +) diff --git a/tests/test_fourier_convolution.py b/tests/test_fourier_convolution.py index 04c5a8c5..f3fc3596 100644 --- a/tests/test_fourier_convolution.py +++ b/tests/test_fourier_convolution.py @@ -94,6 +94,8 @@ def test_convolution_for_delta(self, cell, mesh_vals): n_channels, nx, ny, nz = mesh_vals.shape n_fft = nx * ny * nz FSC = FourierSpaceConvolution(cell) - mesh_vals_new = FSC.compute(mesh_vals, potential_exponent=0) * volume / n_fft + mesh_vals_new = ( + FSC.compute(mesh_vals, potential_exponent=0, smearing=0.0) * volume / n_fft + ) assert_close(mesh_vals, mesh_vals_new, rtol=1e-4, atol=1e-6)