From 071e17768595386da2a3b3562ff98c64cf7606f2 Mon Sep 17 00:00:00 2001 From: Adam Krzywaniak Date: Sat, 20 Jan 2024 17:28:33 +0000 Subject: [PATCH] code for paper plots --- pyscf_ipu/direct/another_plot.py | 79 ++ pyscf_ipu/direct/inference_heatmap_plot.py | 303 ++++++++ .../direct/inference_heatmap_plot_small_bs.py | 311 ++++++++ pyscf_ipu/direct/plot_heatmap_for_paper.py | 44 ++ pyscf_ipu/direct/train.py | 734 +++++++----------- pyscf_ipu/direct/transformer.py | 72 +- 6 files changed, 1035 insertions(+), 508 deletions(-) create mode 100644 pyscf_ipu/direct/another_plot.py create mode 100644 pyscf_ipu/direct/inference_heatmap_plot.py create mode 100644 pyscf_ipu/direct/inference_heatmap_plot_small_bs.py create mode 100644 pyscf_ipu/direct/plot_heatmap_for_paper.py diff --git a/pyscf_ipu/direct/another_plot.py b/pyscf_ipu/direct/another_plot.py new file mode 100644 index 0000000..0b73ffe --- /dev/null +++ b/pyscf_ipu/direct/another_plot.py @@ -0,0 +1,79 @@ +import pickle +import numpy as np +import matplotlib.pyplot as plt + + +ml_file = "heatmap_data_009.pkl" +pyscf_file = "heatmap_pyscf_009.pkl" +# Load data from the pickle file +with open(ml_file, 'rb') as file: + data_list = pickle.load(file) + +with open(pyscf_file, 'rb') as file: + pyscf_list = pickle.load(file) + +# Extract phi, psi, and values from the loaded data +phi_values, psi_values, heatmap_val = zip(*data_list) + +# Extract phi, psi, and values from the loaded data +phi_values_p, psi_values_p, heatmap_pyscf = zip(*pyscf_list) + +matrix_size = int(len(data_list) ** 0.5) + +heatmap_val = np.array(heatmap_val).reshape(matrix_size, matrix_size) +heatmap_pyscf = np.array(heatmap_pyscf).reshape(matrix_size, matrix_size) + +# valid_E = NN(molecule) \approx E +# state.pyscf_E = DFT(molecule) = E +# state.valid_l = | NN(molecule) - DFT(molecule) | +# +heatmap_pyscf = -heatmap_pyscf + +phi_coordinates, psi_coordinates = np.meshgrid(np.linspace(min(phi_values), max(phi_values), matrix_size), + np.linspace(min(psi_values), max(psi_values), matrix_size)) + +fig, ax = plt.subplots(2,3, figsize=(10, 8)) +# im = ax[0,0].imshow( heatmap_val ) +im = ax[0,0].imshow(heatmap_val, cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) + +# ax[0,0].set_xlim(phi_values) +# ax[0,0].set_ylim(psi_values) +im2 = ax[0,1].imshow( heatmap_pyscf, cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) +diff = ax[0,2].imshow( np.abs(heatmap_val - heatmap_pyscf), cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) + +log = ax[1,0].imshow( np.log(np.abs(heatmap_val )), cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) +log2 = ax[1,1].imshow( np.log(np.abs(heatmap_pyscf )), cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) +difflog = ax[1,2].imshow( np.log(np.abs((heatmap_val - heatmap_pyscf))), cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) + +for i in range(3): + for j in range(2): + ax[j, i].set_xticks(np.arange(phi_values[0], phi_values[-1], 45)) + ax[j, i].set_yticks(np.arange(psi_values[0], psi_values[-1], 45)) + # ax[j, i].set_xlim([phi_values[0], phi_values[-1]]) + # ax[j, i].set_ylim([psi_values[0], psi_values[-1]]) + ax[j, i].set_xlabel("phi [deg]") + ax[j, i].set_ylabel("psi [deg]") + +# orient = 'vertical' +orient = 'horizontal' +cbar = fig.colorbar(im, ax=ax[0, 0], orientation=orient, fraction=0.05, pad=0.28) +cbar = fig.colorbar(im2, ax=ax[0, 1], orientation=orient, fraction=0.05, pad=0.28) +cbar = fig.colorbar(diff, ax=ax[0, 2], orientation=orient, fraction=0.05, pad=0.28) +cbar = fig.colorbar(log, ax=ax[1, 0], orientation=orient, fraction=0.05, pad=0.28) +cbar = fig.colorbar(log2, ax=ax[1, 1], orientation=orient, fraction=0.05, pad=0.28) +cbar = fig.colorbar(difflog, ax=ax[1, 2], orientation=orient, fraction=0.05, pad=0.28) + +# for a in ax.reshape(-1): a.axis("off") +ax[0,0].set_title("NN Energy") +ax[0,1].set_title("PySCF Energy") +ax[0,2].set_title("|NN-PySCF| Energy") + +ax[1,0].set_title("NN log(|Energy|)") +ax[1,1].set_title("PySCF log(|Energy|)") +ax[1,2].set_title("|NN-PySCF| log(|Energy|)") +# ax[0,0].set_ylabel("Energy") # may fail with axis("off") +# ax[1,0].set_ylabel("log(|Energy|)") # may fail with axis("off") +plt.tight_layout() + +# Save the plot to a PNG file +plt.savefig("poc.png") \ No newline at end of file diff --git a/pyscf_ipu/direct/inference_heatmap_plot.py b/pyscf_ipu/direct/inference_heatmap_plot.py new file mode 100644 index 0000000..cd86d5b --- /dev/null +++ b/pyscf_ipu/direct/inference_heatmap_plot.py @@ -0,0 +1,303 @@ +import pickle +import jax +jax.config.update('jax_enable_x64', True) +import jax.numpy as jnp +import numpy as np + +HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = 27.2114079527, 1e-20, 0.2 + +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('-basis', type=str, default="sto3g") +parser.add_argument('-level', type=int, default=0) + +# GD options +parser.add_argument('-backend', type=str, default="cpu") +parser.add_argument('-lr', type=float, default=2.5e-4) +parser.add_argument('-steps', type=int, default=100000) +parser.add_argument('-bs', type=int, default=8) +parser.add_argument('-val_bs', type=int, default=8) +parser.add_argument('-mol_repeats', type=int, default=16) # How many time to optimize wrt each molecule. + +# energy computation speedups +parser.add_argument('-foriloop', action="store_true") # whether to use jax.lax.foriloop for sparse_symmetric_eri (faster compile time but slower training. ) +parser.add_argument('-xc_f32', action="store_true") +parser.add_argument('-eri_f32', action="store_true") +parser.add_argument('-eri_bs', type=int, default=8) + +parser.add_argument('-normal', action="store_true") +parser.add_argument('-wandb', action="store_true") +parser.add_argument('-prof', action="store_true") +parser.add_argument('-visualize', action="store_true") +parser.add_argument('-skip', action="store_true", help="skip pyscf test case") + +# dataset +parser.add_argument('-qm9', action="store_true") +parser.add_argument('-benzene', action="store_true") +parser.add_argument('-hydrogens', action="store_true") +parser.add_argument('-water', action="store_true") +parser.add_argument('-waters', action="store_true") +parser.add_argument('-alanine', action="store_true") +parser.add_argument('-states', type=int, default=1) +parser.add_argument('-workers', type=int, default=5) +parser.add_argument('-precompute', action="store_true") # precompute labels; only run once for data{set/augmentation}. + # do noise schedule, start small slowly increase +parser.add_argument('-wiggle_var', type=float, default=0.05, help="wiggle N(0, wiggle_var), bondlength=1.5/30") +parser.add_argument('-eri_threshold', type=float, default=1e-10, help="loss function threshold only") +parser.add_argument('-rotate_deg', type=float, default=90, help="how many degrees to rotate") + +# models +parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") +parser.add_argument('-tiny', action="store_true") +parser.add_argument('-small', action="store_true") +parser.add_argument('-base', action="store_true") +parser.add_argument('-medium', action="store_true") +parser.add_argument('-large', action="store_true") +parser.add_argument('-xlarge', action="store_true") + +parser.add_argument("-checkpoint", default=-1, type=int, help="which iteration to save model (default -1 = no saving)") # checkpoint model +parser.add_argument("-resume", default="", help="path to checkpoint pickle file") # checkpoint model + +# inference heatmap plot args +parser.add_argument("-heatmap_step", type=int, default=10) +parser.add_argument("-plot_range", type=int, default=360) +opts = parser.parse_args() + +assert opts.val_bs * opts.heatmap_step == opts.plot_range, "[Temporary dependency] Try adjusting VAL_BS and HEATMAP_STEP so that their product is equal to PLOT_RANGE (by default 360)" + +if opts.tiny or opts.small or opts.base or opts.large or opts.xlarge: opts.nn = True + +if opts.alanine: + mol_str = [[ # 22 atoms (12 hydrogens) => 10 heavy atoms (i.e. larger than QM9). + ["H", ( 2.000 , 1.000, -0.000)], + ["C", ( 2.000 , 2.090, 0.000)], + ["H", ( 1.486 , 2.454, 0.890)], + ["H", ( 1.486 , 2.454, -0.890)], + ["C", ( 3.427 , 2.641, -0.000)], + ["O", ( 4.391 , 1.877, -0.000)], + ["N", ( 3.555 , 3.970, -0.000)], + ["H", ( 2.733 , 4.556, -0.000)], + ["C", ( 4.853 , 4.614, -0.000)], # carbon alpha + ["H", ( 5.408 , 4.316, 0.890)], # hydrogne attached to carbon alpha + ["C", ( 5.661 , 4.221, -1.232)], # carbon beta + ["H", ( 5.123 , 4.521, -2.131)], # hydrogens attached to carbon beta + ["H", ( 6.630 , 4.719, -1.206)], # hydrogens attached to carbon beta + ["H", ( 5.809 , 3.141, -1.241)], # hydrogens attached to carbon beta + ["C", ( 4.713 , 6.129, 0.000)], + ["O", ( 3.601 , 6.653, 0.000)], + ["N", ( 5.846 , 6.835, 0.000)], + ["H", ( 6.737 , 6.359, -0.000)], + ["C", ( 5.846 , 8.284, 0.000)], + ["H", ( 4.819 , 8.648, 0.000)], + ["H", ( 6.360 , 8.648, 0.890)], + ["H", ( 6.360 , 8.648, -0.890)], + ]] + +B, BxNxN, BxNxK = None, None, None +cfg = None +from train import dm_energy + +from transformer import transformer_init +from train import nao +# global cfg +'''Model ViT model embedding #heads #layers #params training throughput +dimension resolution (im/sec) +DeiT-Ti N/A 192 3 12 5M 224 2536 +DeiT-S N/A 384 6 12 22M 224 940 +DeiT-B ViT-B 768 12 12 86M 224 292 +Parameters Layers dmodel +117M 12 768 +345M 24 1024 +762M 36 1280 +1542M 48 1600 +''' +if opts.tiny: # 5M + d_model= 192 + n_heads = 6 + n_layers = 12 +if opts.small: + d_model= 384 + n_heads = 6 + n_layers = 12 +if opts.base: + d_model= 768 + n_heads = 12 + n_layers = 12 +if opts.medium: + d_model= 1024 + n_heads = 16 + n_layers = 24 +if opts.large: + d_model= 1280 + n_heads = 16 + n_layers = 36 +if opts.xlarge: + d_model= 1600 + n_heads = 25 + n_layers = 48 + +if opts.nn: + rnd_key = jax.random.PRNGKey(42) + n_vocab = nao("C", opts.basis) + nao("N", opts.basis) + \ + nao("O", opts.basis) + nao("F", opts.basis) + \ + nao("H", opts.basis) + rnd_key, cfg, params, total_params = transformer_init( + rnd_key, + n_vocab, + d_model =d_model, + n_layers=n_layers, + n_heads =n_heads, + d_ff =d_model*4, + ) + +# vandg = jax.jit(jax.value_and_grad(dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", 'nn')) +valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn', "cfg", "opts")) + +from train import batched_state +from torch.utils.data import DataLoader, Dataset +class OnTheFlyQM9(Dataset): + # prepares dft tensors with pyscf "on the fly". + # dataloader is very keen on throwing segfaults (e.g. using jnp in dataloader throws segfaul). + # problem: second epoch always gives segfault. + # hacky fix; make __len__ = real_length*num_epochs and __getitem__ do idx%real_num_examples + def __init__(self, opts, nao=294, train=True, num_epochs=10**9, extrapolate=False, init_phi_psi = None): + # only take molecules with use {CNOFH}, nao=nao and spin=0. + import pandas as pd + df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules + if nao != -1: df = df[df["nao"]==nao] + # df.sample is not deterministic; moved to pre-processing, so file is shuffled already. + # this shuffling is important, because it makes the last 10 samples iid (used for validation) + #df = df.sample(frac=1).reset_index(drop=True) # is this deterministic? + + if train: self.mol_strs = df["pyscf"].values[:-10] + else: self.mol_strs = df["pyscf"].values[-10:] + #print(df["pyscf"].) # todo: print smile strings + + self.num_epochs = num_epochs + self.opts = opts + self.validation = not train + self.extrapolate = extrapolate + self.init_phi_psi = init_phi_psi + + # self.benzene = [ + # ["C", ( 0.0000, 0.0000, 0.0000)], + # ["C", ( 1.4000, 0.0000, 0.0000)], + # ["C", ( 2.1000, 1.2124, 0.0000)], + # ["C", ( 1.4000, 2.4249, 0.0000)], + # ["C", ( 0.0000, 2.4249, 0.0000)], + # ["C", (-0.7000, 1.2124, 0.0000)], + # ["H", (-0.5500, -0.9526, 0.0000)], + # ["H", (-0.5500, 3.3775, 0.0000)], + # ["H", ( 1.9500, -0.9526, 0.0000)], + # ["H", (-1.8000, 1.2124, 0.0000)], + # ["H", ( 3.2000, 1.2124, 0.0000)], + # ["H", ( 1.9500, 3.3775, 0.0000)] + # ] + # self.waters = [ + # ["O", (-1.464, 0.099, 0.300)], + # ["H", (-1.956, 0.624, -0.340)], + # ["H", (-1.797, -0.799, 0.206)], + # ["O", ( 1.369, 0.146, -0.395)], + # ["H", ( 1.894, 0.486, 0.335)], + # ["H", ( 0.451, 0.165, -0.083)] + # ] + + # if opts.benzene: self.mol_strs = [self.benzene] + # if opts.waters: self.mol_strs = [self.waters] + if opts.alanine: self.mol_strs = mol_str + + if train: self.bs = opts.bs + else: self.bs = opts.val_bs + + def __len__(self): + return len(self.mol_strs)*self.num_epochs + + def __getitem__(self, idx): + return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.bs, \ + wiggle_num=0, do_pyscf=self.validation or self.extrapolate, validation=False, \ + extrapolate=self.extrapolate, mol_idx=idx, init_phi_psi = self.init_phi_psi, inference=True, inference_psi_step=opts.heatmap_step) + + +print("loading checkpoint") +weights = pickle.load(open("%s_model.pickle"%opts.resume, "rb")) +print("done loading. ") + +# print("loading adam state") +# adam_state = pickle.load(open("%s_adam_state.pickle"%opts.resume, "rb")) +# print("done") + +# weights, adam_state = jax.device_put(weights), jax.device_put(adam_state) +weights = jax.device_put(weights) + +from train import HashableNamespace + +# make `opts` hashable so that JAX will not complain about the static parameter that is passed as arg +opts = HashableNamespace(opts) + +data = [] +pyscf = [] +# data.append((1,1,344)) +# data.append((2,4,323)) +# data.append((3,3,334)) +# data.append((4,2,331)) + +for phi in range(0, opts.plot_range, opts.heatmap_step): + for psi in range(0, opts.plot_range, opts.val_bs * opts.heatmap_step): + val_qm9 = OnTheFlyQM9(opts, train=False, init_phi_psi=(phi, psi)) + val_state = jax.device_put(val_qm9[0]) + # print("\n^^^^^^^^^^^\nJUST VAL QM9 [0]:", val_qm9[0]) + # print("WHOLE VAL QM9:", val_qm9) + print("VAL_QM9[0].pyscf_E:", val_qm9[0].pyscf_E) + _, (valid_vals, _, vdensity_matrix, vW) = valf(weights, val_state, opts.normal, opts.nn, cfg, opts) + + valid_l = np.abs(valid_vals*HARTREE_TO_EV-val_state.pyscf_E) + valid_E = np.abs(valid_vals*HARTREE_TO_EV) + + print("valid_l: ", valid_l, "\nvalid_E: ", valid_E, "\nphi ", phi, " psi ", psi) + + for i in range(0, opts.val_bs): + data.append((phi, psi + i * opts.heatmap_step, valid_E[i])) + pyscf.append((phi, psi + i * opts.heatmap_step, val_state.pyscf_E[i].item())) + + # data.append((phi, psi, valid_E[0])) + +#data = np.log(np.abs(data)) +import matplotlib.pyplot as plt +from scipy.interpolate import griddata +# Extract phi, psi, and values from the data +phi_values, psi_values, heatmap_values = zip(*data) + +# Define a grid +phi_grid, psi_grid = np.meshgrid(np.linspace(min(phi_values), max(phi_values), 100), + np.linspace(min(psi_values), max(psi_values), 100)) +# Interpolate values on the grid +heatmap_interpolated = griddata((phi_values, psi_values), heatmap_values, (phi_grid, psi_grid), method='cubic', fill_value=0) + + +# Create a filled contour plot +plt.contourf(psi_grid, phi_grid, heatmap_interpolated, cmap='viridis', levels=100) +plt.colorbar(label='Intensity') + +# Set axis labels and title +plt.xlabel('Psi Angle') +plt.ylabel('Phi Angle') +plt.title('2D Heatmap with Interpolation') + +# Save the plot to a PNG file +plt.savefig('heatmap_plot.png') + +# Show the plot +plt.show() + +import pickle + +print("DATA ML", data) +print("DATA PYSCF", pyscf) +# Save data to a pickle file +with open('heatmap_data.pkl', 'wb') as file: + pickle.dump(data, file) + + +# Save pyscf to a pickle file +with open('heatmap_pyscf.pkl', 'wb') as file: + pickle.dump(pyscf, file) \ No newline at end of file diff --git a/pyscf_ipu/direct/inference_heatmap_plot_small_bs.py b/pyscf_ipu/direct/inference_heatmap_plot_small_bs.py new file mode 100644 index 0000000..6bcbfa9 --- /dev/null +++ b/pyscf_ipu/direct/inference_heatmap_plot_small_bs.py @@ -0,0 +1,311 @@ +import pickle +import jax +jax.config.update('jax_enable_x64', True) +import jax.numpy as jnp +import numpy as np + +HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = 27.2114079527, 1e-20, 0.2 + +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('-basis', type=str, default="sto3g") +parser.add_argument('-level', type=int, default=0) + +# GD options +parser.add_argument('-backend', type=str, default="cpu") +parser.add_argument('-lr', type=float, default=2.5e-4) +parser.add_argument('-steps', type=int, default=100000) +parser.add_argument('-bs', type=int, default=8) +parser.add_argument('-val_bs', type=int, default=8) +parser.add_argument('-mol_repeats', type=int, default=16) # How many time to optimize wrt each molecule. + +# energy computation speedups +parser.add_argument('-foriloop', action="store_true") # whether to use jax.lax.foriloop for sparse_symmetric_eri (faster compile time but slower training. ) +parser.add_argument('-xc_f32', action="store_true") +parser.add_argument('-eri_f32', action="store_true") +parser.add_argument('-eri_bs', type=int, default=8) + +parser.add_argument('-normal', action="store_true") +parser.add_argument('-wandb', action="store_true") +parser.add_argument('-prof', action="store_true") +parser.add_argument('-visualize', action="store_true") +parser.add_argument('-skip', action="store_true", help="skip pyscf test case") + +# dataset +parser.add_argument('-qm9', action="store_true") +parser.add_argument('-benzene', action="store_true") +parser.add_argument('-hydrogens', action="store_true") +parser.add_argument('-water', action="store_true") +parser.add_argument('-waters', action="store_true") +parser.add_argument('-alanine', action="store_true") +parser.add_argument('-states', type=int, default=1) +parser.add_argument('-workers', type=int, default=5) +parser.add_argument('-precompute', action="store_true") # precompute labels; only run once for data{set/augmentation}. + # do noise schedule, start small slowly increase +parser.add_argument('-wiggle_var', type=float, default=0.05, help="wiggle N(0, wiggle_var), bondlength=1.5/30") +parser.add_argument('-eri_threshold', type=float, default=1e-10, help="loss function threshold only") +parser.add_argument('-rotate_deg', type=float, default=90, help="how many degrees to rotate") + +# models +parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") +parser.add_argument('-tiny', action="store_true") +parser.add_argument('-small', action="store_true") +parser.add_argument('-base', action="store_true") +parser.add_argument('-medium', action="store_true") +parser.add_argument('-large', action="store_true") +parser.add_argument('-xlarge', action="store_true") + +parser.add_argument("-checkpoint", default=-1, type=int, help="which iteration to save model (default -1 = no saving)") # checkpoint model +parser.add_argument("-resume", default="", help="path to checkpoint pickle file") # checkpoint model + +# inference heatmap plot args +parser.add_argument("-heatmap_step", type=int, default=10) +parser.add_argument("-plot_range", type=int, default=360) +opts = parser.parse_args() + +# assert opts.val_bs * opts.heatmap_step == opts.plot_range, "[Temporary dependency] Try adjusting VAL_BS and HEATMAP_STEP so that their product is equal to PLOT_RANGE (by default 360)" +assert (opts.plot_range % (opts.val_bs * opts.heatmap_step)) == 0, "batch * step will not fit within the range with integer number of subranges" +if opts.tiny or opts.small or opts.base or opts.large or opts.xlarge: opts.nn = True + +if opts.alanine: + mol_str = [[ # 22 atoms (12 hydrogens) => 10 heavy atoms (i.e. larger than QM9). + ["H", ( 2.000 , 1.000, -0.000)], + ["C", ( 2.000 , 2.090, 0.000)], + ["H", ( 1.486 , 2.454, 0.890)], + ["H", ( 1.486 , 2.454, -0.890)], + ["C", ( 3.427 , 2.641, -0.000)], + ["O", ( 4.391 , 1.877, -0.000)], + ["N", ( 3.555 , 3.970, -0.000)], + ["H", ( 2.733 , 4.556, -0.000)], + ["C", ( 4.853 , 4.614, -0.000)], # carbon alpha + ["H", ( 5.408 , 4.316, 0.890)], # hydrogne attached to carbon alpha + ["C", ( 5.661 , 4.221, -1.232)], # carbon beta + ["H", ( 5.123 , 4.521, -2.131)], # hydrogens attached to carbon beta + ["H", ( 6.630 , 4.719, -1.206)], # hydrogens attached to carbon beta + ["H", ( 5.809 , 3.141, -1.241)], # hydrogens attached to carbon beta + ["C", ( 4.713 , 6.129, 0.000)], + ["O", ( 3.601 , 6.653, 0.000)], + ["N", ( 5.846 , 6.835, 0.000)], + ["H", ( 6.737 , 6.359, -0.000)], + ["C", ( 5.846 , 8.284, 0.000)], + ["H", ( 4.819 , 8.648, 0.000)], + ["H", ( 6.360 , 8.648, 0.890)], + ["H", ( 6.360 , 8.648, -0.890)], + ]] + +B, BxNxN, BxNxK = None, None, None +cfg = None +from train import dm_energy + +from transformer import transformer_init +from train import nao +# global cfg +'''Model ViT model embedding #heads #layers #params training throughput +dimension resolution (im/sec) +DeiT-Ti N/A 192 3 12 5M 224 2536 +DeiT-S N/A 384 6 12 22M 224 940 +DeiT-B ViT-B 768 12 12 86M 224 292 +Parameters Layers dmodel +117M 12 768 +345M 24 1024 +762M 36 1280 +1542M 48 1600 +''' +if opts.tiny: # 5M + d_model= 192 + n_heads = 6 + n_layers = 12 +if opts.small: + d_model= 384 + n_heads = 6 + n_layers = 12 +if opts.base: + d_model= 768 + n_heads = 12 + n_layers = 12 +if opts.medium: + d_model= 1024 + n_heads = 16 + n_layers = 24 +if opts.large: + d_model= 1280 + n_heads = 16 + n_layers = 36 +if opts.xlarge: + d_model= 1600 + n_heads = 25 + n_layers = 48 + +if opts.nn: + rnd_key = jax.random.PRNGKey(42) + n_vocab = nao("C", opts.basis) + nao("N", opts.basis) + \ + nao("O", opts.basis) + nao("F", opts.basis) + \ + nao("H", opts.basis) + rnd_key, cfg, params, total_params = transformer_init( + rnd_key, + n_vocab, + d_model =d_model, + n_layers=n_layers, + n_heads =n_heads, + d_ff =d_model*4, + ) + +# vandg = jax.jit(jax.value_and_grad(dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", 'nn')) +valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn', "cfg", "opts")) + +from train import batched_state +from torch.utils.data import DataLoader, Dataset +class OnTheFlyQM9(Dataset): + # prepares dft tensors with pyscf "on the fly". + # dataloader is very keen on throwing segfaults (e.g. using jnp in dataloader throws segfaul). + # problem: second epoch always gives segfault. + # hacky fix; make __len__ = real_length*num_epochs and __getitem__ do idx%real_num_examples + def __init__(self, opts, nao=294, train=True, num_epochs=10**9, extrapolate=False, init_phi_psi = None): + # only take molecules with use {CNOFH}, nao=nao and spin=0. + import pandas as pd + df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules + if nao != -1: df = df[df["nao"]==nao] + # df.sample is not deterministic; moved to pre-processing, so file is shuffled already. + # this shuffling is important, because it makes the last 10 samples iid (used for validation) + #df = df.sample(frac=1).reset_index(drop=True) # is this deterministic? + + if train: self.mol_strs = df["pyscf"].values[:-10] + else: self.mol_strs = df["pyscf"].values[-10:] + #print(df["pyscf"].) # todo: print smile strings + + self.num_epochs = num_epochs + self.opts = opts + self.validation = not train + self.extrapolate = extrapolate + self.init_phi_psi = init_phi_psi + + # self.benzene = [ + # ["C", ( 0.0000, 0.0000, 0.0000)], + # ["C", ( 1.4000, 0.0000, 0.0000)], + # ["C", ( 2.1000, 1.2124, 0.0000)], + # ["C", ( 1.4000, 2.4249, 0.0000)], + # ["C", ( 0.0000, 2.4249, 0.0000)], + # ["C", (-0.7000, 1.2124, 0.0000)], + # ["H", (-0.5500, -0.9526, 0.0000)], + # ["H", (-0.5500, 3.3775, 0.0000)], + # ["H", ( 1.9500, -0.9526, 0.0000)], + # ["H", (-1.8000, 1.2124, 0.0000)], + # ["H", ( 3.2000, 1.2124, 0.0000)], + # ["H", ( 1.9500, 3.3775, 0.0000)] + # ] + # self.waters = [ + # ["O", (-1.464, 0.099, 0.300)], + # ["H", (-1.956, 0.624, -0.340)], + # ["H", (-1.797, -0.799, 0.206)], + # ["O", ( 1.369, 0.146, -0.395)], + # ["H", ( 1.894, 0.486, 0.335)], + # ["H", ( 0.451, 0.165, -0.083)] + # ] + + # if opts.benzene: self.mol_strs = [self.benzene] + # if opts.waters: self.mol_strs = [self.waters] + if opts.alanine: self.mol_strs = mol_str + + if train: self.bs = opts.bs + else: self.bs = opts.val_bs + + def __len__(self): + return len(self.mol_strs)*self.num_epochs + + def __getitem__(self, idx): + return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.bs, \ + wiggle_num=0, do_pyscf=self.validation or self.extrapolate, validation=False, \ + extrapolate=self.extrapolate, mol_idx=idx, init_phi_psi = self.init_phi_psi, inference=True, inference_psi_step=opts.heatmap_step) + + +print("loading checkpoint") +weights = pickle.load(open("%s_model.pickle"%opts.resume, "rb")) +print("done loading. ") + +# print("loading adam state") +# adam_state = pickle.load(open("%s_adam_state.pickle"%opts.resume, "rb")) +# print("done") + +# weights, adam_state = jax.device_put(weights), jax.device_put(adam_state) +weights = jax.device_put(weights) + +from train import HashableNamespace + +# make `opts` hashable so that JAX will not complain about the static parameter that is passed as arg +opts = HashableNamespace(opts) + +data = [] +pyscf = [] +# data.append((1,1,344)) +# data.append((2,4,323)) +# data.append((3,3,334)) +# data.append((4,2,331)) + +valid_E = None +val_state = None +for phi in range(0, opts.plot_range, opts.heatmap_step): + # psi_start = 0 + # psi_end = psi_start + opts.val_bs * opts.heatmap_step + # while psi_end <= opts.plot_range: + # for psi in range(psi_start, psi_end, opts.heatmap_step): + for psi in range(0, opts.plot_range, opts.val_bs * opts.heatmap_step): + # print(psi, psi_start, psi_end, "<<<<<<<<<<<<<<<<<<") + val_qm9 = OnTheFlyQM9(opts, train=False, init_phi_psi=(phi, psi)) + val_state = jax.device_put(val_qm9[0]) + # print("\n^^^^^^^^^^^\nJUST VAL QM9 [0]:", val_qm9[0]) + # print("WHOLE VAL QM9:", val_qm9) + print("VAL_QM9[0].pyscf_E:", val_qm9[0].pyscf_E) + _, (valid_vals, _, vdensity_matrix, vW) = valf(weights, val_state, opts.normal, opts.nn, cfg, opts) + + valid_l = np.abs(valid_vals*HARTREE_TO_EV-val_state.pyscf_E) + valid_E = np.abs(valid_vals*HARTREE_TO_EV) + + print("valid_l: ", valid_l, "\nvalid_E: ", valid_E, "\nphi ", phi, " psi ", psi) + + for i in range(0, opts.val_bs): + data.append((phi, psi + i * opts.heatmap_step, valid_E[i])) + pyscf.append((phi, psi + i * opts.heatmap_step, val_state.pyscf_E[i].item())) + # psi_start = 0 + psi_end + # psi_end += opts.val_bs * opts.heatmap_step + # data.append((phi, psi, valid_E[0])) + +#data = np.log(np.abs(data)) +import matplotlib.pyplot as plt +from scipy.interpolate import griddata +# Extract phi, psi, and values from the data +phi_values, psi_values, heatmap_values = zip(*data) + +# Define a grid +phi_grid, psi_grid = np.meshgrid(np.linspace(min(phi_values), max(phi_values), 100), + np.linspace(min(psi_values), max(psi_values), 100)) +# Interpolate values on the grid +heatmap_interpolated = griddata((phi_values, psi_values), heatmap_values, (phi_grid, psi_grid), method='cubic', fill_value=0) + + +# Create a filled contour plot +plt.contourf(psi_grid, phi_grid, heatmap_interpolated, cmap='viridis', levels=100) +plt.colorbar(label='Intensity') + +# Set axis labels and title +plt.xlabel('Psi Angle') +plt.ylabel('Phi Angle') +plt.title('2D Heatmap with Interpolation') + +# Save the plot to a PNG file +plt.savefig('heatmap_plot.png') + +# Show the plot +plt.show() + +import pickle + +print("DATA ML", data) +print("DATA PYSCF", pyscf) +# Save data to a pickle file +with open('heatmap_data_bs2.pkl', 'wb') as file: + pickle.dump(data, file) + + +# Save pyscf to a pickle file +with open('heatmap_pyscf_bs2.pkl', 'wb') as file: + pickle.dump(pyscf, file) \ No newline at end of file diff --git a/pyscf_ipu/direct/plot_heatmap_for_paper.py b/pyscf_ipu/direct/plot_heatmap_for_paper.py new file mode 100644 index 0000000..58fab93 --- /dev/null +++ b/pyscf_ipu/direct/plot_heatmap_for_paper.py @@ -0,0 +1,44 @@ +import pickle +import numpy as np +import matplotlib.pyplot as plt +from scipy.interpolate import griddata + +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('-data_file', type=str) +parser.add_argument('-output_name', type=str, default="default_output.png") +parser.add_argument('-log', type=bool, default=False) +opts = parser.parse_args() + +# Load data from the pickle file +with open(opts.data_file, 'rb') as file: + data_list = pickle.load(file) + + +# Extract phi, psi, and values from the loaded data +phi_values, psi_values, heatmap_values = zip(*data_list) + +if opts.log: + heatmap_values = np.log(np.abs(heatmap_values - np.mean(heatmap_values))) + +print(heatmap_values) +# Create a meshgrid of phi and psi coordinates +phi_coordinates, psi_coordinates = np.meshgrid(np.linspace(min(phi_values), max(phi_values), 100), + np.linspace(min(psi_values), max(psi_values), 100)) + +# Interpolate values on the grid +heatmap_interpolated = griddata((phi_values, psi_values), heatmap_values, (phi_coordinates, psi_coordinates), method='cubic', fill_value=0) + +# Display the 2D matrix as an image +plt.imshow(heatmap_interpolated, cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) +plt.colorbar(label='Intensity') # Add colorbar with label + +# Set axis labels and title +plt.xlabel('Psi Angle') +plt.ylabel('Phi Angle') +plt.title('2D Heatmap from Pickle File') + +# Save the plot to a PNG file +plt.savefig(opts.output_name) + +# Show the plot diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py index 18d1c5b..fee2dbc 100644 --- a/pyscf_ipu/direct/train.py +++ b/pyscf_ipu/direct/train.py @@ -1,9 +1,8 @@ import os -os.environ['OMP_NUM_THREADS'] = '16' +os.environ['OMP_NUM_THREADS'] = '8' import jax jax.config.update('jax_enable_x64', True) import jax.numpy as jnp -import scipy import numpy as np import pyscf import optax @@ -16,10 +15,7 @@ import math from functools import partial import pickle -import random -random.seed(42) -MD17_WATER, MD17_ETHANOL, MD17_ALDEHYDE, MD17_URACIL = 1, 2, 3, 4 cfg, HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = None, 27.2114079527, 1e-20, 0.2 def T(x): return jnp.transpose(x, (0,2,1)) @@ -27,24 +23,20 @@ def T(x): return jnp.transpose(x, (0,2,1)) B, BxNxN, BxNxK = None, None, None # Only need to recompute: L_inv, grid_AO, grid_weights, H_core, ERI and E_nuc. -def dm_energy(W: BxNxK, state, normal, nn, cfg=None, opts=None): +def dm_energy(W: BxNxK, state, normal, nn, cfg, opts):#): if nn: - W = jax.vmap(transformer, in_axes=(None, None, 0, 0, 0, 0), out_axes=(0))(cfg, \ - W, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32), state.L_inv.astype(jnp.float32)) - #W, state.ao_types, state.pos.astype(jnp.float64), state.H_core.astype(jnp.float64), state.L_inv.astype(jnp.float64)) + W = jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0))(cfg, \ + W, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32)) W = W.astype(jnp.float64) # we can interpret state.H_core + W as hamiltonian, and predict hlgap from these! - H = state.H_core + W - L_inv_Q: BxNxN = state.L_inv_T @ jnp.linalg.eigh(state.L_inv @ H @ state.L_inv_T)[1] # O(B*N*K^2) FLOP O(B*N*K) FLOP/FLIO + L_inv_Q: BxNxN = state.L_inv_T @ jnp.linalg.eigh(state.L_inv @ (state.H_core + W) @ state.L_inv_T)[1] # O(B*N*K^2) FLOP O(B*N*K) FLOP/FLIO density_matrix: BxNxN = 2 * (L_inv_Q*state.mask) @ T(L_inv_Q) # O(B*N*K^2) FLOP/FLIO E_xc: B = exchange_correlation(density_matrix, state, normal, opts.xc_f32) # O(B*gsize*N^2) FLOP O(gsize*N^2) FLIO - diff_JK: BxNxN = JK(density_matrix, state, normal, opts.foriloop, opts.eri_f32, opts.bs) # O(B*num_ERIs) FLOP O(num_ERIs) FLIO + diff_JK: BxNxN = JK(density_matrix, state, normal, opts.foriloop, opts.eri_f32) # O(B*num_ERIs) FLOP O(num_ERIs) FLIO energies: B = E_xc + state.E_nuc + jnp.sum((density_matrix * (state.H_core + diff_JK/2)).reshape(W.shape[0], -1), axis=-1) energy: float = jnp.sum(energies) - return energy, (energies, E_xc, density_matrix, W, H) - - + return energy, (energies, E_xc, density_matrix, W) def sparse_mult(values, dm, state, gsize): in_ = dm.take(state.cols, axis=0) @@ -62,7 +54,6 @@ def exchange_correlation(density_matrix: BxNxN, state, normal, xc_f32): if False: main: BxGsizexN = state.main_grid_AO @ density_matrix # (1, gsize, N) @ (B, N, N) = O(B gsize N^2) FLOPs and O(gsize*N + N^2 +B * gsize * N) FLIOs correction: BxGsizexN = jax.vmap(sparse_mult, in_axes=(0,0,None, None))(state.sparse_diffs_grid_AO, density_matrix, state, gsize) - # todo: remove state.grid_AO w/ sparsity tricks => reduce memory 10x. rho_a = jnp.einsum("bpij,bqij->bpi", state.grid_AO, main.reshape(B,1,gsize,N)) rho_b = jnp.einsum("bpij,bqij->bpi", state.grid_AO, correction.reshape(B,1,gsize,N)) rho = rho_a - rho_b @@ -76,7 +67,7 @@ def exchange_correlation(density_matrix: BxNxN, state, normal, xc_f32): E_xc = jnp.sum(rho[:, 0] * state.grid_weights * E_xc, axis=-1).reshape(B) return E_xc -def JK(density_matrix, state, normal, jax_foriloop, eri_f32, bs): +def JK(density_matrix, state, normal, jax_foriloop, eri_f32): if normal: J = jnp.einsum('bijkl,bji->bkl', state.ERI, density_matrix) K = jnp.einsum('bijkl,bjk->bil', state.ERI, density_matrix) @@ -86,43 +77,32 @@ def JK(density_matrix, state, normal, jax_foriloop, eri_f32, bs): if eri_f32: density_matrix = density_matrix.astype(jnp.float32) - if bs == 1: - diff_JK: BxNxN = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0, None))( - state.nonzero_distinct_ERI[0], - state.nonzero_indices[0], - density_matrix, - jax_foriloop - ) - - - else: - '''diff_JK: BxNxN = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0, None))( - state.nonzero_distinct_ERI[0], - state.nonzero_indices[0], - density_matrix, - jax_foriloop - ) - diff_JK: BxNxN = diff_JK - jax.vmap(sparse_symmetric_einsum, in_axes=(0, None, 0, None))( - state.diffs_ERI, - state.indxs, - density_matrix, - jax_foriloop - )''' - - diff_JK: BxNxN = jax.vmap(sparse_einsum, in_axes=(None, None, 0, None))( - state.nonzero_distinct_ERI[0], - state.precomputed_nonzero_indices, - density_matrix, - jax_foriloop - ) - if bs > 1: - correction = jax.vmap(sparse_einsum, in_axes=(0, None, 0, None))( - state.diffs_ERI, - state.precomputed_indxs, - density_matrix, - jax_foriloop - ) - diff_JK: BxNxN = diff_JK - correction + '''diff_JK: BxNxN = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0, None))( + state.nonzero_distinct_ERI[0], + state.nonzero_indices[0], + density_matrix, + jax_foriloop + ) + diff_JK: BxNxN = diff_JK - jax.vmap(sparse_symmetric_einsum, in_axes=(0, None, 0, None))( + state.diffs_ERI, + state.indxs, + density_matrix, + jax_foriloop + )''' + + diff_JK: BxNxN = jax.vmap(sparse_einsum, in_axes=(None, None, 0, None))( + state.nonzero_distinct_ERI[0], + state.precomputed_nonzero_indices, + density_matrix, + jax_foriloop + ) + correction = jax.vmap(sparse_einsum, in_axes=(0, None, 0, None))( + state.diffs_ERI, + state.precomputed_indxs, + density_matrix, + jax_foriloop + ) + diff_JK: BxNxN = diff_JK - correction return diff_JK.astype(jnp.float64) @@ -141,10 +121,10 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_nonzero_distinct_ERI=200000, pad_sparse_diff_grid=200000, mol_idx=42, + init_phi_psi=None, + inference=False, + inference_psi_step=5, # degrees ): - start_time = time.time() - do_print = opts.do_print - if do_print: print("\t[%.4fs] start of 'batched_state'. "%(time.time()-start_time)) # pad molecule if using nn. if not opts.nn: pad_electrons, pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = \ @@ -160,19 +140,18 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_nonzero_distinct_ERI = 20000 pad_sparse_diff_grid = 20000 - if opts.qm9 or opts.qh9: + if opts.qm9: pad_electrons=60 - padding_estimate = [48330, 163222, 17034, 159361, 139505] - if opts.nperturb == 2 or opts.nperturb == 1: padding_estimate = [int(a*2.1) for a in padding_estimate] - else: padding_estimate = [int(a*1.1) for a in padding_estimate] - + '''pad_diff_ERIs=120000 + pad_distinct_ERIs=400000 + pad_grid_AO=50000 + pad_nonzero_distinct_ERI=400000 + pad_sparse_diff_grid=400000''' + #padding_estimate = [37426, 149710, 17010, 140122, 138369] + padding_estimate = [48330, 163222, 17034, 159361, 139505] + padding_estimate = [int(a*1.1) for a in padding_estimate] pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate - if opts.basis == "def2-svp": - # disable padding temporarily - max_pad_electrons, max_pad_diff_ERIs, max_pad_distinct_ERIs, max_pad_grid_AO, max_pad_nonzero_distinct_ERI, max_pad_sparse_diff_grid = \ - -1, -1, -1, -1, -1, -1 - if opts.alanine: # todo: (adam) the ERI padding may change when rotating molecule more! pad_electrons = 70 @@ -191,39 +170,7 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, if opts.waters: pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = [a//3 for a in [ pad_diff_ERIs , pad_distinct_ERIs , pad_grid_AO , pad_nonzero_distinct_ERI , pad_sparse_diff_grid ]] - if opts.md17 > 0: - if opts.md17 == MD17_WATER: - if opts.level == 1: padding_estimate = [ 3361, 5024, 10172 , 5024 , 155958] - if opts.level == 3: padding_estimate = [ 3361, 5024, 34310 , 5024 , 494370] - if opts.bs == 2 and opts.wiggle_var == 0: padding_estimate = [ 1, 5024, 10172 , 5024 , 1] - padding_estimate = [int(a*1.5) for a in padding_estimate] - pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate - - elif opts.md17 == MD17_ETHANOL: - pad_electrons = 72 - #padding_estimate = [ 99042, 197660, 34310*5 , 197308 , 494370*5] - #padding_estimate = [113522, 224275, 30348, 197308, 609163] - if opts.level == 1: padding_estimate = [ 1, 415186, 30348, 415145, 1] - padding_estimate = [int(a*1.1) for a in padding_estimate] - pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate - - elif opts.md17 == MD17_ALDEHYDE: - pad_electrons = 90 - #padding_estimate = [130646, 224386 , 30348, 223233, 626063] - #padding_estimate = [ 34074, 204235 , 30348, 203632, 285934] - if opts.level == 1: padding_estimate = [1, 939479, 35704, 939479, 1] - padding_estimate = [int(a*1.1) for a in padding_estimate] - pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate - - elif opts.md17 == MD17_URACIL: - #raise Exception("not ready yet") - pad_electrons = 132 - padding_estimate = [1, 3769764 , 51184, 3769764, 1] - padding_estimate = [int(a*1.05) for a in padding_estimate] - pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate - - - mol = build_mol(mol_str, opts.basis, unit="bohr") + mol = build_mol(mol_str, opts.basis) pad_electrons = min(pad_electrons, mol.nao_nr()) # Set seed to ensure different rotation; initially all workers did same rotation! @@ -233,53 +180,28 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, water1_xyz = np.array([mol_str[i][1] for i in range(0,3)]) water2_xyz = np.array([mol_str[i][1] for i in range(3,6)]) - - if opts.md17 > 0: - natm = len(mol_str) - atom_num = random.sample(range(natm), 1)[0] - atoms = np.array([mol_str[i][1] for i in range(0,natm)]) - - if opts.qm9 or opts.qh9: - # pick random atom to perturb (of the first 9 heavy ones) - atom_num1, atom_num2, atom_num3 = random.sample(range(9), 3) - atoms = np.array([mol_str[i][1] for i in range(0,10)]) - atom_type = [mol_str[i][0] for i in range(0,10)] + if opts.qm9: + atoms = np.array([mol_str[i][1] for i in range(0,3)]) + # pick random atom to permute (of the first 9 heavy ones) + atom_num = int(np.random.uniform(0, 8)) if opts.alanine: # train on [-180, 180], validate [-180, 180] extrapolate [-360, 360]\[180, -180] # todo: draw picture (in training loop) - if extrapolate: + if extrapolate and not inference: phi, psi = [float(a) for a in np.random.uniform(180, 360, 2)] - else: + elif inference: + phi, psi = init_phi_psi + else: phi, psi = [float(a) for a in np.random.uniform(0, 180, 2)] angles = [] - # Combinatorics of atom substitutions (n_electrons has to be even). - # P := single atom modification, n_electrons remain even (C-> O, O->C, N->F, F->N) - # B1, B2 := single atom modification, n_electrons becomes odd (C->{N,F}, O->{N,F}, N->{O,C}, F->{O,C}) - I = {"C":"C", "O":"O", "F":"F", "N":"N"} - P = {"C": "O", "O":"C", "F":"N", "N":"F"} - B1 = {"C":"N", "O":"N", "F":"O", "N":"O"} - B2 = {"C":"F", "O":"F", "F":"C", "N":"C"} - - if opts.nperturb == 3: - allowed_pertubations=[ - (I,I,I), (I,I,P), (I,P,I), (I,P,P), (P,I,I), (P,I,P), (P,P,I), (P,P,P), - (I,B1,B2), (B1,I,B2), (B1,B2,I), (I,B2,B1), (B2,I,B1), (B2,B1,I), (P,B1,B2), (B1,P,B2), (B1,B2,P), (P,B2,B1), (B2,P,B1), (B2,B1,P), - (I,B1,B1), (B1,I,B1), (B1,B1,I), (I,B1,B1), (B1,I,B1), (B1,B1,I), (P,B1,B1), (B1,P,B1), (B1,B1,P), (P,B1,B1), (B1,P,B1), (B1,B1,P), - (I,B2,B2), (B2,I,B2), (B2,B2,I), (I,B2,B2), (B2,I,B2), (B2,B2,I), (P,B2,B2), (B2,P,B2), (B2,B2,P), (P,B2,B2), (B2,P,B2), (B2,B2,P), - ] - if opts.nperturb == 2: - allowed_pertubations=[ (I, I, I), (I, I, P), (I, P, I), (I, P, P), (I, B1, B1), (I, B1, B2), (I, B2, B1), (I, B2, B2), ] - if opts.nperturb == 1: - allowed_pertubations=[ (I, I, I), (I, I, P), ]*10 - - if opts.nperturb: random.shuffle(allowed_pertubations) - states = [] for iteration in range(bs): - if do_print: print("\t[%.4fs] initializing state %i. "%(time.time()-start_time, iteration)) + import copy + new_str = copy.deepcopy(mol_str) + if opts.alanine: from rdkit import Chem from rdkit.Chem import AllChem @@ -296,86 +218,94 @@ def get_atom_positions(mol): conf = mol.GetConformer() return np.concatenate([xyz(conf.GetAtomPosition(i)) for i in range(mol.GetNumAtoms())], axis=0) - str = [mol_str[j][0] for j in range(len(mol_str))] - pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) + str = [new_str[j][0] for j in range(len(new_str))] + pos = np.concatenate([np.array(new_str[j][1]).reshape(1, 3) for j in range(len(new_str))]) # todo: save=wandb.log({"pair": angle1, angle2, NN_energy ) (rotation, NN_energy) for train/val molecule (for val also save PySCF energy) # only saving angles (angle not paired up with energy) AllChem.SetDihedralDeg(molecule.GetConformer(), *phi_atoms, phi) - angle = psi + float(np.random.uniform(0, opts.rotate_deg, 1)) # perhaps add 45 and mod 360? + angle = psi + float(np.random.uniform(0, opts.rotate_deg, 1)) # perhaps add 45 and mod 360? # todo: check math whether val/extra/train have uniform distribution on their respective domains. - if extrapolate: # make sure angle is in [] + if extrapolate and not inference: # make sure angle is in [] angle = angle % 180 + 180 # angle should be in [180, 360] - else: + elif inference: + angle = psi + iteration * inference_psi_step # overwrite the angle when in inference mode with fixed, not randomized, step + angle = angle % 360 # angle should be [0, 360] for heatmap + else: # validation angle = angle % 180 # angle should be [0, 180] AllChem.SetDihedralDeg(molecule.GetConformer(), *psi_atoms, angle ) pos = get_atom_positions(molecule) angles.append((phi, angle)) - for j in range(len(mol_str)): mol_str[j][1] = tuple(pos[j]) + for j in range(len(new_str)): new_str[j][1] = tuple(pos[j]) '''if iteration == 0 and opts.wandb: from plot import create_rdkit_mol import wandb wandb.log({"mol_valid=%s"%validation: create_rdkit_mol(str, pos) })''' - if opts.md17 > 0: - if iteration > 0: - mol_str[atom_num][1] = tuple(atoms[atom_num] + np.random.normal(0,opts.wiggle_var, (3))) - if opts.waters: # todo: rotate both water molecules and draw x=phi, y=psi. rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] center = water2_xyz.mean(axis=0) water_xyz = np.dot(water2_xyz - center, rotation_matrix) + center - mol_str[3][1] = tuple(water_xyz[0]) - mol_str[4][1] = tuple(water_xyz[1]) - mol_str[5][1] = tuple(water_xyz[2]) + new_str[3][1] = tuple(water_xyz[0]) + new_str[4][1] = tuple(water_xyz[1]) + new_str[5][1] = tuple(water_xyz[2]) '''if opts.wandb and iteration == 0: from plot import create_rdkit_mol import wandb - str = [mol_str[j][0] for j in range(len(mol_str))] - pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) + str = [new_str[j][0] for j in range(len(new_str))] + pos = np.concatenate([np.array(new_str[j][1]).reshape(1, 3) for j in range(len(new_str))]) wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], iteration): create_rdkit_mol(str, pos) })''' - elif opts.qm9 or opts.qh9: - if iteration == 0 and (validation or extrapolate): pass - else: - s = mol_str[atom_num1][0]+ mol_str[atom_num2][0]+ mol_str[atom_num3][0] - - # for small stuff, wiggle a single atom a bit. - if opts.nperturb <= 1: mol_str[atom_num3][1] = tuple(atoms[atom_num3] + np.random.normal(0, opts.wiggle_var, (3))) + elif opts.qm9: + # todo: find dihedral to rotate over similar to alanine dipeptide. + # broken; rotate first three atoms around their center of mass + # this breaks molecule; should use dihedral angle as done with the dipeptide. + #rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] + #center = atoms.mean(axis=0) + #rotated_atoms = np.dot(atoms - center, rotation_matrix) + center - # assuming all atoms are non-hydrogen; if wrong don't do any pertubations. - if opts.nperturb > 0 and mol_str[atom_num1][0] != "H" and mol_str[atom_num2][0] != "H" and mol_str[atom_num3][0] != "H": - A,B,C = allowed_pertubations[iteration] - mol_str[atom_num1][0] = A[atom_type[atom_num1]] - mol_str[atom_num2][0] = B[atom_type[atom_num2]] - mol_str[atom_num3][0] = C[atom_type[atom_num3]] + # for extrapolation, do even more. - s += "->" + mol_str[atom_num1][0]+ mol_str[atom_num2][0]+ mol_str[atom_num3][0] - if do_print: print(iteration, s) + if iteration == 0 and (validation or extrapolate): + pass + else: + #new_str[0][1] = tuple(atoms[0] + np.random.normal(0, opts.wiggle_var, (3))) + new_str[atom_num][1] = tuple(atoms[atom_num] + np.random.normal(0, opts.wiggle_var, (3))) + #new_str[1][1] = tuple(atoms[1] + np.random.normal(0, opts.wiggle_var, (3))) + #new_str[2][1] = tuple(atoms[2] + np.random.normal(0, opts.wiggle_var, (3))) - # -nperturb 2 is a lot slower; changing self.prune=False roughly doubles AO matrices. + '''if opts.wandb and iteration == 0: + from plot import create_rdkit_mol + import wandb + str = [new_str[j][0] for j in range(len(new_str))] + pos = np.concatenate([np.array(new_str[j][1]).reshape(1, 3) for j in range(len(new_str))]) + wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], iteration): create_rdkit_mol(str, pos) })''' if iteration == 0: - state = init_dft(mol_str, opts, do_pyscf=do_pyscf, pad_electrons=pad_electrons) + state = init_dft(new_str, opts, do_pyscf=do_pyscf, pad_electrons=pad_electrons) c, w = state.grid_coords, state.grid_weights elif iteration <= 1 or not opts.prof: # when profiling create fake molecule to skip waiting - state = init_dft(mol_str, opts, c, w, do_pyscf=do_pyscf and iteration < 3, state=state, pad_electrons=pad_electrons) + state = init_dft(new_str, opts, c, w, do_pyscf=do_pyscf and iteration < 80, state=state, pad_electrons=pad_electrons) states.append(state) - + # If we add energy here we get plot basically! + # todo: save and store in training loop, then we can match with energy + # can't get to work in wandb, but can just use download api and the plot. + '''if opts.alanine and opts.wandb: + for phi, psi in angles: + if not validation: + wandb.log({"phi_train": phi , "psi_train": psi}) + else: + wandb.log({"phi_valid": phi, "psi_valid": psi})''' state = cats(states) N = state.N[0] - if do_print: print("\t[%.4fs] concatenated states. "%(time.time()-start_time)) - - - #_nonzero = None # Compute ERI sparsity. nonzero = [] @@ -385,23 +315,16 @@ def get_atom_positions(mol): e[indxs] = 0 nonzero.append(np.nonzero(e)[0]) - #_nonzero = indxs if _nonzero is None else np.logical_or(_nonzero, indxs) - - if do_print: print("\t[%.4fs] got sparsity. "%(time.time()-start_time)) - # Merge nonzero indices and prepare (ij, kl). # rep is the number of repetitions we include in the sparse representation. + #nonzero_indices = np.union1d(nonzero[0], nonzero[1]) union = nonzero[0] - for i in range(1, len(nonzero)): # this takes 12s/it for def2-svp. + for i in range(1, len(nonzero)): union = np.union1d(union, nonzero[i]) nonzero_indices = union - if do_print: print("\t[%.4fs] got union of sparsity. "%(time.time()-start_time)) - - from sparse_symmetric_ERI import get_i_j, num_repetitions_fast ij, kl = get_i_j(nonzero_indices) rep = num_repetitions_fast(ij, kl) - if do_print: print("\t[%.4fs] got (ij) and reps. "%(time.time()-start_time)) batches = opts.eri_bs es = [] @@ -415,10 +338,8 @@ def get_atom_positions(mol): state.nonzero_distinct_ERI = np.concatenate([np.expand_dims(a, axis=0) for a in es]) - if do_print: print("\t[%.4fs] padded ERI and nonzero_indices. . "%(time.time()-start_time)) i, j = get_i_j(ij.reshape(-1)) k, l = get_i_j(kl.reshape(-1)) - if do_print: print("\t[%.4fs] got ijkl. "%(time.time()-start_time)) if remainder != 0: i = np.pad(i, ((0,batches-remainder))) @@ -427,15 +348,12 @@ def get_atom_positions(mol): l = np.pad(l, ((0,batches-remainder))) nonzero_indices = np.vstack([i,j,k,l]).T.reshape(batches, -1, 4).astype(np.int32) # todo: use int16 or int32 here? state.nonzero_indices = nonzero_indices - if do_print: print("\t[%.4fs] padded and vstacked ijkl. "%(time.time()-start_time)) # batching (w/ same sparsity pattern across batch) allows precomputing all {ss,dm}_indices instead of computing in sparse_sym_eri every iteration. # function below does this. # todo: consider removing, didn't get expecting 3x (only 5%; not sure if additional memory/complication justifies). from sparse_symmetric_ERI import precompute_indices - - if opts.normal: diff_state = None else: main_grid_AO = state.grid_AO[:1] @@ -444,41 +362,39 @@ def get_atom_positions(mol): sparse_diffs_grid_AO = diffs_grid_AO[:, 0, rows,cols] # use the same sparsity pattern across a batch. - if opts.bs > 1: - diff_ERIs = state.nonzero_distinct_ERI[:1] - state.nonzero_distinct_ERI - diff_indxs = state.nonzero_indices.reshape(1, batches, -1, 4) - nzr = np.abs(diff_ERIs[1]).reshape(batches, -1) > 1e-10 + diff_ERIs = state.nonzero_distinct_ERI[:1] - state.nonzero_distinct_ERI + diff_indxs = state.nonzero_indices.reshape(1, batches, -1, 4) + nzr = np.abs(diff_ERIs[1]).reshape(batches, -1) > 1e-10 - diff_ERIs = diff_ERIs[:, nzr].reshape(bs, -1) - diff_indxs = diff_indxs[:, nzr].reshape(-1, 4) + diff_ERIs = diff_ERIs[:, nzr].reshape(bs, -1) + diff_indxs = diff_indxs[:, nzr].reshape(-1, 4) - remainder = np.sum(nzr) % batches - if remainder != 0: - diff_ERIs = np.pad(diff_ERIs, ((0,0),(0,batches-remainder))) - diff_indxs = np.pad(diff_indxs, ((0,batches-remainder),(0,0))) + remainder = np.sum(nzr) % batches + if remainder != 0: + diff_ERIs = np.pad(diff_ERIs, ((0,0),(0,batches-remainder))) + diff_indxs = np.pad(diff_indxs, ((0,batches-remainder),(0,0))) - diff_ERIs = diff_ERIs.reshape(bs, batches, -1) - diff_indxs = diff_indxs.reshape(batches, -1, 4) + diff_ERIs = diff_ERIs.reshape(bs, batches, -1) + diff_indxs = diff_indxs.reshape(batches, -1, 4) - if opts.bs > 1: precomputed_indxs = precompute_indices(diff_indxs, N).astype(np.int16) + precomputed_indxs = precompute_indices(diff_indxs, N).astype(np.int16) - if pad_diff_ERIs == -1: - state.indxs=diff_indxs - state.diffs_ERI=diff_ERIs - assert False, "deal with precomputed_indxs; only added in else branch below" - else: - max_pad_diff_ERIs = diff_ERIs.shape[2] - if do_print: print("\t[%.4fs] max_pad_diff_ERIs=%i"%(time.time()-start_time, max_pad_diff_ERIs)) - # pad ERIs with 0 and indices with -1 so they point to 0. - assert diff_indxs.shape[1] == diff_ERIs.shape[2] - pad = pad_diff_ERIs - diff_indxs.shape[1] - assert pad > 0, (pad_diff_ERIs, diff_indxs.shape[1]) - state.indxs = np.pad(diff_indxs, ((0,0), (0, pad), (0, 0)), 'constant', constant_values=(-1)) - state.diffs_ERI = np.pad(diff_ERIs, ((0,0), (0, 0), (0, pad))) # pad zeros - #print(diff_indxs.shape, precomputed_indxs.shape) - if opts.bs > 1: state.precomputed_indxs = np.pad(precomputed_indxs, ((0,0), (0,0),(0,0), (0, pad), (0,0)), 'constant', constant_values=(-1)) - - #if opts.wandb: wandb.log({"pad_diff_ERIs": pad/diff_ERIs.shape[2]}) + if pad_diff_ERIs == -1: + state.indxs=diff_indxs + state.diffs_ERI=diff_ERIs + assert False, "deal with precomputed_indxs; only added in else branch below" + else: + max_pad_diff_ERIs = diff_ERIs.shape[2] + # pad ERIs with 0 and indices with -1 so they point to 0. + assert diff_indxs.shape[1] == diff_ERIs.shape[2] + pad = pad_diff_ERIs - diff_indxs.shape[1] + assert pad > 0, (pad_diff_ERIs, diff_indxs.shape[1]) + state.indxs = np.pad(diff_indxs, ((0,0), (0, pad), (0, 0)), 'constant', constant_values=(-1)) + state.diffs_ERI = np.pad(diff_ERIs, ((0,0), (0, 0), (0, pad))) # pad zeros + #print(diff_indxs.shape, precomputed_indxs.shape) + state.precomputed_indxs = np.pad(precomputed_indxs, ((0,0), (0,0),(0,0), (0, pad), (0,0)), 'constant', constant_values=(-1)) + + #if opts.wandb: wandb.log({"pad_diff_ERIs": pad/diff_ERIs.shape[2]}) state.rows=rows state.cols=cols @@ -490,7 +406,6 @@ def get_atom_positions(mol): if pad_sparse_diff_grid != -1: max_pad_sparse_diff_grid = state.rows.shape[0] - if do_print: print("\t[%.4fs] max_pad_sparse_diff_grid=%i"%(time.time()-start_time, max_pad_sparse_diff_grid)) assert state.sparse_diffs_grid_AO.shape[1] == state.rows.shape[0] assert state.sparse_diffs_grid_AO.shape[1] == state.cols.shape[0] pad = pad_sparse_diff_grid - state.rows.shape[0] @@ -509,7 +424,6 @@ def get_atom_positions(mol): # todo: looks like we're padding, then looking for zeros, then padding; this can be simplified. if pad_distinct_ERIs != -1: max_pad_distinct_ERIs = state.nonzero_distinct_ERI.shape[2] - if do_print: print("\t[%.4fs] max_pad_distinct_ERIs=%i"%(time.time()-start_time, max_pad_diff_ERIs)) assert state.nonzero_distinct_ERI.shape[2] == state.nonzero_indices.shape[2] pad = pad_distinct_ERIs - state.nonzero_distinct_ERI.shape[2] assert pad > 0, (pad_distinct_ERIs, state.nonzero_distinct_ERI.shape[2]) @@ -520,7 +434,6 @@ def get_atom_positions(mol): if pad_grid_AO != -1: max_pad_grid_AO = state.grid_AO.shape[2] - if do_print: print("\t[%.4fs] max_pad_grid_AO=%i"%(time.time()-start_time, max_pad_grid_AO)) prev_size = state.grid_AO.shape[2] assert state.grid_AO.shape[2] == state.grid_weights.shape[1] @@ -555,12 +468,11 @@ def get_atom_positions(mol): state.nonzero_distinct_ERI = state.nonzero_distinct_ERI.reshape(1, batches, -1) state.nonzero_indices = state.nonzero_indices.reshape(1, batches, -1, 4) - if opts.bs > 1: precomputed_nonzero_indices = precompute_indices(state.nonzero_indices[0], N).astype(np.int16) + precomputed_nonzero_indices = precompute_indices(state.nonzero_indices[0], N).astype(np.int16) #print(state.nonzero_indices.shape, precomputed_nonzero_indices.shape) if pad_nonzero_distinct_ERI != -1: max_pad_nonzero_distinct_ERI = state.nonzero_distinct_ERI.shape[2] - if do_print: print("\t[%.4fs] max_pad_nonzero_distinct_ERI=%i"%(time.time()-start_time, max_pad_nonzero_distinct_ERI)) assert state.nonzero_distinct_ERI.shape[2] == state.nonzero_indices.shape[2] pad = pad_nonzero_distinct_ERI - state.nonzero_distinct_ERI.shape[2] @@ -568,7 +480,7 @@ def get_atom_positions(mol): state.nonzero_distinct_ERI = np.pad(state.nonzero_distinct_ERI, ((0,0),(0,0),(0,pad))) state.nonzero_indices = np.pad(state.nonzero_indices, ((0,0),(0,0),(0,pad), (0,0)), 'constant', constant_values=(-1)) - if opts.bs > 1: state.precomputed_nonzero_indices = np.pad(precomputed_nonzero_indices, ((0,0), (0,0), (0,0), (0, pad),(0,0)), 'constant', constant_values=(-1)) + state.precomputed_nonzero_indices = np.pad(precomputed_nonzero_indices, ((0,0), (0,0), (0,0), (0, pad),(0,0)), 'constant', constant_values=(-1)) #print(state.precomputed_nonzero_indices.shape, state.nonzero_indices.shape) #if opts.wandb: wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2]}) @@ -590,7 +502,6 @@ def get_atom_positions(mol): def nanoDFT(mol_str, opts): - start_time = time.time() print() # Initialize validation set. # This consists of DFT tensors initialized with PySCF/CPU. @@ -602,14 +513,10 @@ def nanoDFT(mol_str, opts): run = wandb.init(project='ndft_alanine') elif opts.qm9: run = wandb.init(project='ndft_qm9') - elif opts.md17 > 0: - run = wandb.init(project='md17') else: run = wandb.init(project='ndft') opts.name = run.name - wandb.log(vars(opts)) - else: opts.name = "%i"%time.time() @@ -646,17 +553,13 @@ def nanoDFT(mol_str, opts): d_model= 1024 n_heads = 16 n_layers = 24 - if opts.large: # this is 600M; - d_model= 1280 # 80*16 + if opts.large: + d_model= 1280 n_heads = 16 n_layers = 36 - if opts.largep: # interpolated between large and largep. - d_model= 91*16 # halway from 80 to 100 - n_heads = 16*1 - n_layers = 43 - if opts.xlarge: # this is 1.3B; decrease parameter count 30%. - d_model= 1600 # 100*16 - n_heads = 25 + if opts.xlarge: + d_model= 1600 + n_heads = 25 n_layers = 48 if opts.nn: @@ -668,7 +571,6 @@ def nanoDFT(mol_str, opts): n_heads =n_heads, d_ff =d_model*4, ) - print("[%.4fs] initialized transformer. "%(time.time()-start_time) ) params = params.to_float32() if opts.resume: @@ -679,35 +581,42 @@ def nanoDFT(mol_str, opts): if opts.nn: #https://arxiv.org/pdf/1706.03762.pdf see 5.3 optimizer - def custom_schedule(it, learning_rate=opts.lr, min_lr=opts.min_lr, warmup_iters=opts.warmup_iters, lr_decay_iters=opts.lr_decay): + + # try to mimic karpathy as closely as possible ;) + # https://github.com/karpathy/nanoGPT/blob/master/train.py + # still differs on + # [ ] weight initialization + + def custom_schedule(it, learning_rate=opts.lr, min_lr=opts.lr/10, warmup_iters=2000, lr_decay_iters=600000): # 600k/30 = 20k; so hit mi + #return learning_rate * it / warmup_iters # to allow jax jit? + # allow jax jit + '''if it < warmup_iters: return learning_rate * it / warmup_iters # linearly increase until hit warmup iters. + if it > lr_decay_iters: return min_lr # after decay (600k iterations) go to 10x lower + + # in between, decay learning rate using this function; this is from 2k steps to 600k steps + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) + return min_lr + coeff * (learning_rate - min_lr)''' + #if it < warmup_iters: return learning_rate * it / warmup_iters cond1 = (it < warmup_iters) * learning_rate * it / warmup_iters cond2 = (it > lr_decay_iters) * min_lr + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) coeff = 0.5 * (1.0 + jnp.cos(jnp.pi * decay_ratio)) cond3 = (it >= warmup_iters) * (it <= lr_decay_iters) * (min_lr + coeff * (learning_rate - min_lr)) - if not opts.resume: return cond1 + cond2 + cond3 - else: return learning_rate + return cond1 + cond2 + cond3 adam = optax.chain( optax.clip_by_global_norm(1), - #optax.scale_by_adam(b1=0.9, b2=0.95, eps=1e-12), - optax.scale_by_adam(b1=0.99, b2=0.999, eps=1e-12), - #optax.scale_by_factored_rms(), # use this for larger model (more memory efficient) - optax.add_decayed_weights(0.1), + optax.scale_by_adam(b1=0.9, b2=0.95, eps=1e-12), + optax.add_decayed_weights(0.1),#, configure_decay_mask(params)), optax.scale_by_schedule(custom_schedule), optax.scale(-1), - #optax.ema(opts.ema) if opts.ema != 0 else None ) + w = params - df = None - if opts.qh9: - df = pd.read_pickle("qh9/qh9stable_processed_shuffled.pickle") - df = df[df["N_sto3g"]==55] - print(df.shape) - elif opts.qm9: - df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules - if nao != -1: df = df[df["nao"]==nao] from torch.utils.data import DataLoader, Dataset class OnTheFlyQM9(Dataset): @@ -715,20 +624,22 @@ class OnTheFlyQM9(Dataset): # dataloader is very keen on throwing segfaults (e.g. using jnp in dataloader throws segfaul). # problem: second epoch always gives segfault. # hacky fix; make __len__ = real_length*num_epochs and __getitem__ do idx%real_num_examples - def __init__(self, opts, df=None, nao=294, train=True, num_epochs=10**9, extrapolate=False): + def __init__(self, opts, nao=294, train=True, num_epochs=10**9, extrapolate=False): + # only take molecules with use {CNOFH}, nao=nao and spin=0. + df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules + if nao != -1: df = df[df["nao"]==nao] # df.sample is not deterministic; moved to pre-processing, so file is shuffled already. # this shuffling is important, because it makes the last 10 samples iid (used for validation) #df = df.sample(frac=1).reset_index(drop=True) # is this deterministic? - if opts.qh9 or opts.qm9: - if train: self.mol_strs = df["pyscf"].values[:-10] - else: self.mol_strs = df["pyscf"].values[-10:] + + if train: self.mol_strs = df["pyscf"].values[:-10] + else: self.mol_strs = df["pyscf"].values[-10:] #print(df["pyscf"].) # todo: print smile strings self.num_epochs = num_epochs self.opts = opts self.validation = not train self.extrapolate = extrapolate - self.do_pyscf = self.validation or self.extrapolate self.benzene = [ ["C", ( 0.0000, 0.0000, 0.0000)], @@ -754,64 +665,34 @@ def __init__(self, opts, df=None, nao=294, train=True, num_epochs=10**9, extrapo ] if opts.benzene: self.mol_strs = [self.benzene] - - if opts.md17 > 0: - mol = {MD17_WATER: "water", MD17_ALDEHYDE: "malondialdehyde", MD17_ETHANOL: "ethanol", MD17_URACIL: "uracil"}[opts.md17] - mode = {True: "train", False: "val"}[train] - filename = "md17/%s_%s.pickle"%(mode, mol) - df = pd.read_pickle(filename) - - self.mol_strs = df["pyscf"].values.tolist() - N = int(np.sqrt(df["H"].values.tolist()[0].reshape(-1).size)) - self.H = [a.reshape(N, N) for a in df["H"].values.tolist()] - self.E = df["E"].values.tolist() - self.mol_strs = [eval(a) for a in self.mol_strs] - else: - self.H = [0 for _ in self.mol_strs] - self.E = [0 for _ in self.mol_strs] - - + if opts.waters: self.mol_strs = [self.waters] if opts.alanine: self.mol_strs = mol_str if train: self.bs = opts.bs else: self.bs = opts.val_bs - - def __len__(self): return len(self.mol_strs)*self.num_epochs + def __len__(self): + return len(self.mol_strs)*self.num_epochs def __getitem__(self, idx): return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.bs, \ - wiggle_num=0, do_pyscf=self.do_pyscf, validation=False, \ - extrapolate=self.extrapolate, mol_idx=idx), self.H[idx%len(self.mol_strs)], self.E[idx%len(self.mol_strs)] + wiggle_num=0, do_pyscf=self.validation or self.extrapolate, validation=False, \ + extrapolate=self.extrapolate, mol_idx=idx) - print("[%.4fs] initialized datasets. "%(time.time()-start_time) ) - val_qm9 = OnTheFlyQM9(opts, train=False, df=df) - print("[%.4fs] initialized datasets. "%(time.time()-start_time) ) - ext_qm9 = OnTheFlyQM9(opts, extrapolate=True, df=df) - print("[%.4fs] initialized datasets. "%(time.time()-start_time) ) + val_qm9 = OnTheFlyQM9(opts, train=False) + ext_qm9 = OnTheFlyQM9(opts, extrapolate=True) + # parallel dataloader bug; precompute here is not slow but causes dataloader later to die. + # run once to quickly precompute. if opts.precompute: val_state = val_qm9[0] ext_state = ext_qm9[0] exit() - qm9 = OnTheFlyQM9(opts, train=True, df=df) - print("[%.4fs] initialized datasets. "%(time.time()-start_time) ) + qm9 = OnTheFlyQM9(opts, train=True) if opts.workers != 0: train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=opts.workers, prefetch_factor=2, collate_fn=lambda x: x[0]) else: train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=opts.workers, collate_fn=lambda x: x[0]) pbar = tqdm(train_dataloader) - print("[%.4fs] initialized dataloaders. "%(time.time()-start_time) ) - - if opts.test_dataloader: - - t0 = time.time() - for iteration, (state, H, E) in enumerate(pbar): - if iteration == 0: summary(state) - print(time.time()-t0) - t0 = time.time() - print(state.pad_sizes.reshape(1, -1)) - - exit() else: @@ -829,7 +710,6 @@ def __next__(self): return self.item vandg = jax.jit(jax.value_and_grad(dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", 'nn', "cfg", "opts")) valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn', "cfg", "opts")) adam_state = adam.init(w) - print("[%.4fs] jitted vandg and valf."%(time.time()-start_time) ) if opts.resume: print("loading adam state") @@ -837,13 +717,11 @@ def __next__(self): return self.item print("done") w, adam_state = jax.device_put(w), jax.device_put(adam_state) - print("[%.4fs] jax.device_put(w,adam_state)."%(time.time()-start_time) ) @partial(jax.jit, backend=opts.backend) def update(w, adam_state, accumulated_grad): - if opts.grad_acc: accumulated_grad = jax.tree_map(lambda x: x / (opts.bs * opts.mol_repeats), accumulated_grad) - else: accumulated_grad = jax.tree_map(lambda x: x / opts.bs, accumulated_grad) + accumulated_grad = jax.tree_map(lambda x: x / opts.bs, accumulated_grad) updates, adam_state = adam.update(accumulated_grad, adam_state, w) w = optax.apply_updates(w, updates) return w, adam_state @@ -854,14 +732,10 @@ def update(w, adam_state, accumulated_grad): min_val, min_dm, mins, valid_str, step, val_state, ext_state = 0, 0, np.ones(opts.bs)*1e6, "", 0, None, None t0, load_time, train_time, val_time, plot_time = time.time(), 0, 0, 0, 0 - accumulated_grad = None paddings = [] states = [] - - - print("[%.4fs] first iteration."%(time.time()-start_time) ) - for iteration, (state, H, E) in enumerate(pbar): + for iteration, state in enumerate(pbar): if iteration == 0: summary(state) state = jax.device_put(state) @@ -877,68 +751,55 @@ def update(w, adam_state, accumulated_grad): states.append(state) if len(states) > opts.mol_repeats: states.pop(0) - if opts.shuffle: random.shuffle(states) # load_time, t0 = time.time()-t0, time.time() - + if opts.checkpoint != -1 and iteration % opts.checkpoint == 0: # and iteration > 0: + t0 = time.time() + try: + name = opts.name.replace("-", "_") + path_model = "checkpoints/%s_%i_model.pickle"%(name, iteration) + path_adam = "checkpoints/%s_%i_adam_state.pickle"%(name, iteration) + print("trying to checkpoint to %s and %s"%(path_model, path_adam)) + pickle.dump(jax.device_get(w), open(path_model, "wb")) + pickle.dump(jax.device_get(adam_state), open(path_adam, "wb")) + print("done!") + print("\t-resume \"%s\""%(path_model.replace("_model.pickle", ""))) + except: + print("fail!") + pass + print("tried saving model took %fs"%(time.time()-t0)) + save_time, t0 = time.time()-t0, time.time() + + + if len(states) < 50: print(len(states)) + + for j, state in enumerate(states): + print(". ", end="", flush=True) + if j == 0: _t0 =time.time() + (val, (vals, E_xc, density_matrix, _W)), grad = vandg(w, state, opts.normal, opts.nn, cfg, opts) + print(",", end="", flush=True) + if j == 0: time_step1 = time.time()-_t0 - if len(states) < 50: print(len(states), opts.name) - - reps = 1 - if opts.md17 == 4: reps = 5 - if opts.md17 == 3: reps = 2 - if opts.md17 == 2: reps = 2 - - for _ in range(reps): - for j, state in enumerate(states): - print(". ", end="", flush=True) - if j == 0: _t0 =time.time() - (val, (vals, E_xc, density_matrix, _W, _H)), grad = vandg(w, state, opts.normal, opts.nn, cfg, opts) - print(",", end="", flush=True) - if j == 0: time_step1 = time.time()-_t0 - - if opts.grad_acc == 0 or len(states) < opts.mol_repeats: - w, adam_state = update(w, adam_state, grad) - else: - accumulated_grad = grad if accumulated_grad is None else jax.tree_map(lambda x, y: x + y, accumulated_grad, grad) - - if (j+1) % opts.grad_acc == 0 and j > 0: # we assume opts.grad_acc divides opts.mol_repeats; prev was basically grad_acc=0 or grad_acc=mol_repeats, can now do hybrid. - w, adam_state = update(w, adam_state, grad) - accumulated_grad = None - print("#", end="", flush=True) - - - if opts.checkpoint != -1 and adam_state[1].count % opts.checkpoint == 0 and adam_state[1].count > 0: - t0 = time.time() - try: - name = opts.name.replace("-", "_") - path_model = "checkpoints/%s_%i_model.pickle"%(name, iteration) - path_adam = "checkpoints/%s_%i_adam_state.pickle"%(name, iteration) - print("trying to checkpoint to %s and %s"%(path_model, path_adam)) - pickle.dump(jax.device_get(w), open(path_model, "wb")) - pickle.dump(jax.device_get(adam_state), open(path_adam, "wb")) - print("done!") - print("\t-resume \"%s\""%(path_model.replace("_model.pickle", ""))) - except: - print("fail!") - pass - print("tried saving model took %fs"%(time.time()-t0)) - save_time, t0 = time.time()-t0, time.time() + # todo: have hyper parameter that accumulates gradient or takes step? + w, adam_state = update(w, adam_state, grad) # todo: rename global_batch_size = len(states)*opts.bs if opts.wandb: dct["global_batch_size"] = global_batch_size train_time, t0 = time.time()-t0, time.time() + + # plot grad norm + #if iteration % 10 == 0: + # for k,v in accumulated_grad.items(): dct[k + "_norm"] = np.linalg.norm(v .reshape(-1) ) update_time, t0 = time.time()-t0, time.time() if not opts.nn: str = "error=" + "".join(["%.7f "%(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) for i in range(2)]) + " [eV]" str += "pyscf=%.7f us=%.7f"%(state.pyscf_E[0]/HARTREE_TO_EV, vals[0]) else: - #print(vals[0], E) - pbar.set_description("train=%.4f"%(vals[0]*HARTREE_TO_EV) + "[eV] "+ valid_str + "time=%.1f %.1f %.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, update_time, val_time, plot_time)) + pbar.set_description("train=".join(["%.2f"%i for i in vals[:1]]) + "[Ha] "+ valid_str + "time=%.1f %.1f %.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, update_time, val_time, plot_time)) if opts.wandb: dct["time_load"] = load_time @@ -946,66 +807,45 @@ def update(w, adam_state, accumulated_grad): dct["time_train"] = train_time dct["time_val"] = val_time plot_iteration = iteration % 10 == 0 - - dct["train_E"] = np.abs(E*HARTREE_TO_EV) - dct["train_E_pred"] = np.abs(vals[0]*HARTREE_TO_EV) + for i in range(0, 2): + if not opts.nn: + dct['train_l%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) + dct['train_pyscf%i'%i ] = np.abs(state.pyscf_E[i]) + dct['train_E%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV) + if plot_iteration: + dct['img/dm%i'%i] = wandb.Image(np.expand_dims(density_matrix[i], axis=-1)) + dct['img/W%i'%i] = wandb.Image(np.expand_dims(_W[i], axis=-1)) step = adam_state[1].count plot_time, t0 = time.time()-t0, time.time() - if opts.nn and (iteration < 250 or iteration % 10 == 0): - val_idx = 1 - if val_state is None: val_state, val_H, val_E = jax.device_put(val_qm9[val_idx]) # todo: cat 8 of these. - _, (valid_vals, _, vdensity_matrix, vW, _val_H) = valf(w, val_state, opts.normal, opts.nn, cfg, opts) - - if opts.md17 > 0: - def get_H_from_dm(dm): - import pyscf - from pyscf import gto, dft - m = pyscf.gto.Mole(atom=val_qm9.mol_strs[val_idx], basis="def2-svp", unit="bohr") - m.build() - mf = dft.RKS(m) - mf.xc = 'B3LYP5' - mf.verbose = 0 - mf.diis_space = 8 - mf.conv_tol = 1e-13 - mf.grad_tol = 3.16e-5 - mf.grids.level = 3 - #mf.kernel() - h_core = mf.get_hcore() - S = mf.get_ovlp() - vxc = mf.get_veff(m, dm) - H = h_core + vxc - S = mf.get_ovlp() - return H, S - - matrix = np.array(vdensity_matrix[0]) - N = int(np.sqrt(matrix.size)) - _val_H, S = get_H_from_dm(matrix.reshape(N, N)) - - # compare eigenvalues - pred_vals = scipy.linalg.eigh(_val_H, S)[0] - label_vals = scipy.linalg.eigh(val_H, S)[0] - MAE_vals = np.mean(np.abs(pred_vals - label_vals)) - dct["val_eps"] = MAE_vals - + - lr = custom_schedule(step) - valid_str = "lr=%.3e"%lr + "val=%.4f [eV] "%(valid_vals[0]*HARTREE_TO_EV-val_E*HARTREE_TO_EV) + "mae_H=%.4f "%( - np.mean(np.abs(val_H/np.abs(val_H) - _val_H/np.abs(_val_H))) - ) - if opts.md17> 0:valid_str+= " eps=%.4f"%(MAE_vals) - valid_str += "val'=" + "".join(["%.4f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" + # TODO: Plot molecules and val/ext angles. + if opts.nn and (iteration < 250 or iteration % 10 == 0): - dct['val_E'] = np.abs(valid_vals[0]*HARTREE_TO_EV-val_E*HARTREE_TO_EV ) - dct['val_H_MAE'] = np.mean(np.abs(val_H - _val_H)) # perhaps sign doesn't matter? + if val_state is None: val_state = jax.device_put(val_qm9[0]) + _, (valid_vals, _, vdensity_matrix, vW) = valf(w, val_state, opts.normal, opts.nn, cfg, opts) + if ext_state is None: ext_state = jax.device_put(ext_qm9[0]) + _, (ext_vals, _, edensity_matrix, eW) = valf(w, ext_state, opts.normal, opts.nn, cfg, opts) + lr = custom_schedule(step) + valid_str = "lr=%.3e"%lr + "val=" + "".join(["%.4f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" + valid_str += "ext=" + "".join(["%.4f "%(ext_vals[i]*HARTREE_TO_EV-ext_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" if opts.wandb: - for i in range(0, 3): + for i in range(0, opts.val_bs): dct['valid_l%i'%i ] = np.abs(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) dct['valid_E%i'%i ] = np.abs(valid_vals[i]*HARTREE_TO_EV) dct['valid_pyscf%i'%i ] = np.abs(val_state.pyscf_E[i]) + dct['img/val_dm%i'%i] = wandb.Image(np.expand_dims(vdensity_matrix[i], axis=-1)) + dct['img/val_W%i'%i] = wandb.Image(np.expand_dims(vW[i], axis=-1)) + + dct['ext_l%i'%i ] = np.abs(ext_vals[i]*HARTREE_TO_EV-ext_state.pyscf_E[i]) + dct['ext_E%i'%i ] = np.abs(ext_vals[i]*HARTREE_TO_EV) + dct['ext_pyscf%i'%i ] = np.abs(ext_state.pyscf_E[i]) + dct['img/ext_dm%i'%i] = wandb.Image(np.expand_dims(edensity_matrix[i], axis=-1)) + dct['img/ext_W%i'%i] = wandb.Image(np.expand_dims(eW[i], axis=-1)) dct["scheduled_lr"] = lr @@ -1154,6 +994,7 @@ def get_partition( coords_all = [] weights_all = [] + # [ ] consider another grid? for ia in range(mol.natm): coords, vol = atom_grids_tab[mol.atom_symbol(ia)] coords = coords + atom_coords[ia] # [ngrid, 3] @@ -1178,9 +1019,7 @@ def build(self, atom_coords, state=None) : mol = self.mol atom_grids_tab = self.gen_atomic_grids( - mol, self.atom_grid, self.radi_method, self.level, - self.prune, - #False, # WARNING: disabling self.prune; this makes sizes of C,N,O,F all a bit larger, but the same ; allow atom substitution + mol, self.atom_grid, self.radi_method, self.level, self.prune ) coords, weights = get_partition( @@ -1210,10 +1049,9 @@ def grids_from_pyscf_mol( def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=True, state=None, pad_electrons=-1): - do_print = False #t0 = time.time() - mol = build_mol(mol_str, opts.basis, unit="bohr") - if do_pyscf: pyscf_E, pyscf_hlgap, pycsf_forces = reference(mol_str, opts, unit="bohr") + mol = build_mol(mol_str, opts.basis) + if do_pyscf: pyscf_E, pyscf_hlgap, pycsf_forces = reference(mol_str, opts) else: pyscf_E, pyscf_hlgap, pyscf_forces = np.zeros(1), np.zeros(1), np.zeros(1) N = mol.nao_nr() # N=66 for C6H6 (number of atomic **and** molecular orbitals) @@ -1221,8 +1059,6 @@ def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=T E_nuc = mol.energy_nuc() # float = 202.4065 [Hartree] for C6H6. TODO(): Port to jax. from pyscf import dft - if do_print: print("grid", end="", flush=True) - #grids = pyscf.dft.gen_grid.Grids(mol) grids = DifferentiableGrids(mol) grids.level = opts.level @@ -1234,8 +1070,6 @@ def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=T coord_str = 'GTOval_cart_deriv1' if mol.cart else 'GTOval_sph_deriv1' grid_AO = mol.eval_gto(coord_str, grids.coords, 4) # (4, grid_size, N) = (4, 45624, 9) for C6H6. - if do_print: print("int1e", end="", flush=True) - # TODO(): Add integral math formulas for kinetic/nuclear/O/ERI. kinetic = mol.intor_symmetric('int1e_kin') # (N,N) nuclear = mol.intor_symmetric('int1e_nuc') # (N,N) @@ -1260,14 +1094,10 @@ def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=T eri_threshold = 0 batches = 1 nipu = 1 - - # todo: rewrite int2e_sph to only recompute changing atomic orbitals (will be N times faster). - if do_print: print("int2e",end ="", flush=True) nonzero_distinct_ERI = mol.intor("int2e_sph", aosym="s8") #ERI = [nonzero_distinct_ERI, nonzero_indices] #ERI = ERI ERI = np.zeros(1) - if do_print: print(nonzero_distinct_ERI.shape, nonzero_distinct_ERI.nbytes/10**9) #ERI = mol.intor("int2e_sph") def e(x): return np.expand_dims(x, axis=0) @@ -1474,19 +1304,12 @@ def hcore_deriv(atm_id, aoslices, h1): # <\nabla|1/r|> def pyscf_reference(mol_str, opts): from pyscf import __config__ __config__.dft_rks_RKS_grids_level = opts.level - mol = build_mol(mol_str, opts.basis, unit="bohr") + mol = build_mol(mol_str, opts.basis) mol.max_cycle = 50 mf = pyscf.scf.RKS(mol) - #mf.max_cycle = 50 - #mf.xc = "b3lyp5" - #mf.diis_space = 8 - mf.xc = 'B3LYP5' - mf.verbose = 0 # put this to 4 to check i set parameters correctly! + mf.max_cycle = 50 + mf.xc = "b3lyp5" mf.diis_space = 8 - # options from qh9 - mf.conv_tol=1e-13 - mf.grad_tol=3.16e-5 - mf.grids.level = 3 pyscf_energies = [] pyscf_hlgaps = [] lumo = mol.nelectron//2 @@ -1525,26 +1348,35 @@ def print_difference(nanoDFT_E, nanoDFT_forces, nanoDFT_logged_E, nanoDFT_hlgap, cosine_similarity = dot_products / (norm_X * norm_Y) print("Force cosine similarity:",cosine_similarity) -def build_mol(mol_str, basis_name, unit="bohr"): +def build_mol(mol_str, basis_name): mol = pyscf.gto.mole.Mole() - mol.build(atom=mol_str, unit=unit, basis=basis_name, spin=0, verbose=0) + mol.build(atom=mol_str, unit="Angstrom", basis=basis_name, spin=0, verbose=0) return mol -def reference(mol_str, opts, unit="bohr"): +def reference(mol_str, opts): import pickle import hashlib if opts.skip: return np.zeros(1), np.zeros(1), np.zeros(1) - filename = "precomputed/%s.pkl"%hashlib.sha256((str(mol_str) + str(opts.basis) + str(opts.level) + unit).encode('utf-8')).hexdigest() + filename = "precomputed/%s.pkl"%hashlib.sha256((str(mol_str) + str(opts.basis) + str(opts.level)).encode('utf-8')).hexdigest() print(filename) if not os.path.exists(filename): pyscf_E, pyscf_hlgap, pyscf_forces = pyscf_reference(mol_str, opts) with open(filename, "wb") as file: - pickle.dump([pyscf_E, pyscf_hlgap, pyscf_forces, unit], file) + pickle.dump([pyscf_E, pyscf_hlgap, pyscf_forces], file) else: - pyscf_E, pyscf_hlgap, pyscf_forces, unit = pickle.load(open(filename, "rb")) + pyscf_E, pyscf_hlgap, pyscf_forces = pickle.load(open(filename, "rb")) return pyscf_E, pyscf_hlgap, pyscf_forces +class HashableNamespace: + def __init__(self, namespace): + self.__dict__.update(namespace.__dict__) + + def __hash__(self): + # Convert the relevant attributes to a tuple for hashing + return hash(tuple(sorted(self.__dict__.items()))) + + if __name__ == "__main__": import os import argparse @@ -1556,18 +1388,11 @@ def reference(mol_str, opts, unit="bohr"): # GD options parser.add_argument('-backend', type=str, default="cpu") - parser.add_argument('-lr', type=float, default=5e-4) - parser.add_argument('-min_lr', type=float, default=1e-7) - parser.add_argument('-warmup_iters', type=float, default=1000) - parser.add_argument('-lr_decay', type=float, default=200000) - parser.add_argument('-ema', type=float, default=0.0) - + parser.add_argument('-lr', type=float, default=2.5e-4) parser.add_argument('-steps', type=int, default=100000) parser.add_argument('-bs', type=int, default=8) - parser.add_argument('-val_bs', type=int, default=3) + parser.add_argument('-val_bs', type=int, default=8) parser.add_argument('-mol_repeats', type=int, default=16) # How many time to optimize wrt each molecule. - parser.add_argument('-grad_acc', type=int, default=0) # integer, deciding how many steps to accumulate. - parser.add_argument('-shuffle', action="store_true") # whether to to shuffle the window of states each step. # energy computation speedups parser.add_argument('-foriloop', action="store_true") # whether to use jax.lax.foriloop for sparse_symmetric_eri (faster compile time but slower training. ) @@ -1582,16 +1407,12 @@ def reference(mol_str, opts, unit="bohr"): parser.add_argument('-skip', action="store_true", help="skip pyscf test case") # dataset - parser.add_argument('-nperturb', type=int, default=0, help="How many atoms to perturb (supports 1,2,3)") parser.add_argument('-qm9', action="store_true") - parser.add_argument('-md17', type=int, default=-1) - parser.add_argument('-qh9', action="store_true") parser.add_argument('-benzene', action="store_true") parser.add_argument('-hydrogens', action="store_true") parser.add_argument('-water', action="store_true") parser.add_argument('-waters', action="store_true") parser.add_argument('-alanine', action="store_true") - parser.add_argument('-do_print', action="store_true") # useful for debugging. parser.add_argument('-states', type=int, default=1) parser.add_argument('-workers', type=int, default=5) parser.add_argument('-precompute', action="store_true") # precompute labels; only run once for data{set/augmentation}. @@ -1599,9 +1420,6 @@ def reference(mol_str, opts, unit="bohr"): parser.add_argument('-wiggle_var', type=float, default=0.05, help="wiggle N(0, wiggle_var), bondlength=1.5/30") parser.add_argument('-eri_threshold', type=float, default=1e-10, help="loss function threshold only") parser.add_argument('-rotate_deg', type=float, default=90, help="how many degrees to rotate") - parser.add_argument('-test_dataloader', action="store_true", help="no training, just test/loop through dataloader. ") - - # models parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") @@ -1611,19 +1429,12 @@ def reference(mol_str, opts, unit="bohr"): parser.add_argument('-medium', action="store_true") parser.add_argument('-large', action="store_true") parser.add_argument('-xlarge', action="store_true") - parser.add_argument('-largep', action="store_true") # large "plus" parser.add_argument("-checkpoint", default=-1, type=int, help="which iteration to save model (default -1 = no saving)") # checkpoint model parser.add_argument("-resume", default="", help="path to checkpoint pickle file") # checkpoint model opts = parser.parse_args() - if opts.tiny or opts.small or opts.base or opts.large or opts.xlarge: opts.nn = True - assert opts.grad_acc == 0 or opts.mol_repeats % opts.grad_acc == 0, "mol_repeats needs to be a multiple of grad_acc (gradient accumulation)." - - class HashableNamespace: - def __init__(self, namespace): self.__dict__.update(namespace.__dict__) - def __hash__(self): return hash(tuple(sorted(self.__dict__.items()))) - opts = HashableNamespace(opts) + if opts.tiny or opts.small or opts.base or opts.large or opts.xlarge: opts.nn = True args_dict = vars(opts) print(args_dict) @@ -1633,9 +1444,6 @@ def __hash__(self): return hash(tuple(sorted(self.__dict__.items()))) df = df[df["spin"] == 0] # only consider spin=0 mol_strs = df["pyscf"].values - if opts.qh9: - mol_strs = [] - # benzene if opts.benzene: mol_strs = [[ @@ -1658,7 +1466,7 @@ def __hash__(self): return hash(tuple(sorted(self.__dict__.items()))) ["H", ( 0.0000, 0.0000, 0.0000)], ["H", ( 1.4000, 0.0000, 0.0000)], ]] - if opts.md17 > 0 : + if opts.water: mol_strs = [[ ["O", ( 0.0000, 0.0000, 0.0000)], ["H", ( 0.0000, 1.4000, 0.0000)], @@ -1699,10 +1507,12 @@ def __hash__(self): return hash(tuple(sorted(self.__dict__.items()))) ["H", ( 6.360 , 8.648, -0.890)], ]] + # make opts hashable so that JAX will not complain about the static parameter that is passed as arg + opts = HashableNamespace(opts) nanoDFT_E, (nanoDFT_hlgap, mo_energy, mo_coeff, grid_coords, grid_weights, dm, H) = nanoDFT(mol_strs, opts) exit() pyscf_E, pyscf_hlgap, pyscf_forces = reference(mol_str, opts) nanoDFT_forces = grad(mol, grid_coords, grid_weights, mo_coeff, mo_energy, np.array(dm), np.array(H)) - print_difference(nanoDFT_E, nanoDFT_forces, 0 , nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap) \ No newline at end of file + print_difference(nanoDFT_E, nanoDFT_forces, 0 , nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap) diff --git a/pyscf_ipu/direct/transformer.py b/pyscf_ipu/direct/transformer.py index 45e0e05..9cc2b9d 100644 --- a/pyscf_ipu/direct/transformer.py +++ b/pyscf_ipu/direct/transformer.py @@ -58,7 +58,7 @@ def transformer_init( total_params += np.prod(params.embeddings.shape) print("%26s %26s %26s"%("params.embeddings",params.embeddings.shape, np.prod(params.embeddings.shape))) - rng, params.project_positions, shape = linear_init_uniform(rng, 123, d_model) + rng, params.project_positions, shape = linear_init_uniform(rng, 12, d_model) total_params += np.prod(shape) print("%26s %26s %26s"%("params.project_positions",shape, np.prod(shape))) @@ -95,32 +95,23 @@ def transformer_init( @partial(jax.jit, static_argnums=0) -def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp.ndarray, L_inv): +def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp.ndarray): """ cfg: Config, from transformer_init, holds hyperparameters params: Current transformer parameters, initialized in init x: 1D array of L integers, representing the input sequence output: L x n_vocab logits """ + L, = x.shape # x is just 1D. Vmap/pmap will handle batching + embeddings = cfg.lambda_e * params.embeddings[x, :] # L x Dm - L, Dm = embeddings.shape - - # Roughly get f( {R@ri+t}_i ) = f( {r_i}_i ) - position = position - jnp.mean(position, axis=0).reshape(1, 3) # makes jnp.mean(position, axis=0) = [0,0,0] - cov = jnp.cov(position.T) - eigvects = jnp.linalg.eigh(cov)[1] - position = position @ eigvects # makes jnp.cov(positions.T)=jnp.eye(3) - - # Mix of sin/cos and 3d point cloud transformers. - #position = jnp.concatenate([position, jnp.cos(position), jnp.sin(position), jnp.tanh(position)], axis=1) #(N,3) -> (N,12) - position = jnp.concatenate([position] + \ - [jnp.cos(position*f/20*2*np.pi) for f in range(20)] + \ - [jnp.sin(position*f/20*2*np.pi) for f in range(20)], - axis=1) #(N,3) -> (N,3+60+60) = (N, 123) + + all_pairs = jnp.linalg.norm(position.reshape(1, -1, 3) - position.reshape(-1, 1, 3), axis=-1) + + # inspired by 3d point cloud transformers; + # nspired by andrew: use trigonometric functions as feature transformations + position = jnp.concatenate([position, jnp.cos(position), jnp.sin(position), jnp.tanh(position)], axis=1) #(N,3) -> (N,12) positions = linear(params.project_positions, position) # L x Dm - del position - all_pairs = jnp.linalg.norm(positions.reshape(1, -1, Dm) - positions.reshape(-1, 1, Dm), axis=-1) - all_pairs = all_pairs / jnp.max(all_pairs) # Add (learned) positional encodings x = embeddings + positions # L x Dm @@ -137,13 +128,12 @@ def block(x, layer_num, layer): q = jnp.transpose(q.reshape(L, nheads, Dm//nheads), (1, 0, 2)) k = jnp.transpose(k.reshape(L, nheads, Dm//nheads), (1, 0, 2)) v = jnp.transpose(v.reshape(L, nheads, Dm//nheads), (1, 0, 2)) - score = (q @ jnp.transpose(k, (0, 2, 1))) / math.sqrt(Dm//nheads) + score = (q @ jnp.transpose(k, (0, 2, 1))) / math.sqrt(Dm) - if True: # todo: why does this improve loss from ~1000 to ~300 first step (qm9). - score += H_core - #score += all_pairs # => NaNs for some reason - #score += L_inv - score += L_inv @ H_core @ L_inv.T + # do like graphformer and append position here? + #if layer_num < 6: # doesn't look like it helps + # score += H_core + # score += all_pairs attn = jax.nn.softmax(score , axis=1) x = x + (attn @ v).reshape(L, Dm) @@ -159,26 +149,16 @@ def block(x, layer_num, layer): # Residual connection x = x + t2 - return x + return x, score # Apply the transformer layers # todo: cut jit time by making this jax.lax.foriloop - for layer_num, layer in enumerate(params.layers[:-1]): - x = jax.checkpoint(block)(x, layer_num, layer) - - layer = params.layers[-1] - # Prediction is last attention (without nhead = 1), and q=k so score is symmetric! - nheads = 1 - t1 = vmap(standardize)(x) # L x Dm - t1 = elementwise_linear(layer.norm_self_attn, t1) # L x Dm - qkv = linear(layer.kqv, t1) - q,k,v = jnp.split(qkv, 3, axis=1) - q = jnp.transpose(q.reshape(L, nheads, Dm//nheads), (1, 0, 2)) - k = q - #v = jnp.transpose(v.reshape(L, nheads, Dm//nheads), (1, 0, 2)) - score = (q @ jnp.transpose(k, (0, 2, 1))) / math.sqrt(Dm*nheads) # symmetric: initial loss goes from 1200 to 980 (qm9). + for layer_num, layer in enumerate(params.layers): + x, score = jax.checkpoint(block)(x, layer_num, layer) - M = score[0] + # todo: if this isn't symmetric eigh gives imaginary eigenvalues? (bad) + M = score[0] # take first attention head + #M = (M + M.T)/2 # make symmetric! return M import types @@ -294,7 +274,7 @@ def convert_to_float32(x): parser.add_argument('-large', action="store_true") parser.add_argument('-xlarge', action="store_true") opts = parser.parse_args() - + # initialize model # transformer tiny 5M d_model= 192 @@ -331,10 +311,10 @@ def convert_to_float32(x): extrapolate=False, mol_idx=0) summary(state) - output = jax.jit(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0, 0), out_axes=(0)), + output = jax.jit(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0)), static_argnums=(0,), backend="cpu")(cfg, \ - params, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32), state.L_inv.astype(jnp.float32)) + params, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32)) print(np.sum(output)) # 162.58726108305348 @@ -348,10 +328,10 @@ def convert_to_float32(x): new_params = pickle.load(open("checkpoints/example.pickle", "rb")) # check that output remains the same - new_output = jax.jit(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0, 0), out_axes=(0)), + new_output = jax.jit(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0)), static_argnums=(0,), backend="cpu")(cfg, \ - new_params, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32), state.L_inv.astype(jnp.float32)) + new_params, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32)) assert np.allclose(output, new_output) print("TEST CASE PASSED!")