From 458bee1ccf306cd110e1c97b832336d735085508 Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Sun, 14 Jan 2024 15:48:46 +0000 Subject: [PATCH] . --- pyscf_ipu/direct/train.py | 566 +++++++++++++++++++++++--------------- 1 file changed, 349 insertions(+), 217 deletions(-) diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py index c94dc83..18d1c5b 100644 --- a/pyscf_ipu/direct/train.py +++ b/pyscf_ipu/direct/train.py @@ -1,8 +1,9 @@ import os -os.environ['OMP_NUM_THREADS'] = '8' +os.environ['OMP_NUM_THREADS'] = '16' import jax jax.config.update('jax_enable_x64', True) import jax.numpy as jnp +import scipy import numpy as np import pyscf import optax @@ -18,6 +19,7 @@ import random random.seed(42) +MD17_WATER, MD17_ETHANOL, MD17_ALDEHYDE, MD17_URACIL = 1, 2, 3, 4 cfg, HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = None, 27.2114079527, 1e-20, 0.2 def T(x): return jnp.transpose(x, (0,2,1)) @@ -25,20 +27,24 @@ def T(x): return jnp.transpose(x, (0,2,1)) B, BxNxN, BxNxK = None, None, None # Only need to recompute: L_inv, grid_AO, grid_weights, H_core, ERI and E_nuc. -def dm_energy(W: BxNxK, state, normal, nn): +def dm_energy(W: BxNxK, state, normal, nn, cfg=None, opts=None): if nn: W = jax.vmap(transformer, in_axes=(None, None, 0, 0, 0, 0), out_axes=(0))(cfg, \ W, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32), state.L_inv.astype(jnp.float32)) + #W, state.ao_types, state.pos.astype(jnp.float64), state.H_core.astype(jnp.float64), state.L_inv.astype(jnp.float64)) W = W.astype(jnp.float64) # we can interpret state.H_core + W as hamiltonian, and predict hlgap from these! - L_inv_Q: BxNxN = state.L_inv_T @ jnp.linalg.eigh(state.L_inv @ (state.H_core + W) @ state.L_inv_T)[1] # O(B*N*K^2) FLOP O(B*N*K) FLOP/FLIO + H = state.H_core + W + L_inv_Q: BxNxN = state.L_inv_T @ jnp.linalg.eigh(state.L_inv @ H @ state.L_inv_T)[1] # O(B*N*K^2) FLOP O(B*N*K) FLOP/FLIO density_matrix: BxNxN = 2 * (L_inv_Q*state.mask) @ T(L_inv_Q) # O(B*N*K^2) FLOP/FLIO E_xc: B = exchange_correlation(density_matrix, state, normal, opts.xc_f32) # O(B*gsize*N^2) FLOP O(gsize*N^2) FLIO - diff_JK: BxNxN = JK(density_matrix, state, normal, opts.foriloop, opts.eri_f32) # O(B*num_ERIs) FLOP O(num_ERIs) FLIO + diff_JK: BxNxN = JK(density_matrix, state, normal, opts.foriloop, opts.eri_f32, opts.bs) # O(B*num_ERIs) FLOP O(num_ERIs) FLIO energies: B = E_xc + state.E_nuc + jnp.sum((density_matrix * (state.H_core + diff_JK/2)).reshape(W.shape[0], -1), axis=-1) energy: float = jnp.sum(energies) - return energy, (energies, E_xc, density_matrix, W) + return energy, (energies, E_xc, density_matrix, W, H) + + def sparse_mult(values, dm, state, gsize): in_ = dm.take(state.cols, axis=0) @@ -70,7 +76,7 @@ def exchange_correlation(density_matrix: BxNxN, state, normal, xc_f32): E_xc = jnp.sum(rho[:, 0] * state.grid_weights * E_xc, axis=-1).reshape(B) return E_xc -def JK(density_matrix, state, normal, jax_foriloop, eri_f32): +def JK(density_matrix, state, normal, jax_foriloop, eri_f32, bs): if normal: J = jnp.einsum('bijkl,bji->bkl', state.ERI, density_matrix) K = jnp.einsum('bijkl,bjk->bil', state.ERI, density_matrix) @@ -80,32 +86,43 @@ def JK(density_matrix, state, normal, jax_foriloop, eri_f32): if eri_f32: density_matrix = density_matrix.astype(jnp.float32) - '''diff_JK: BxNxN = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0, None))( - state.nonzero_distinct_ERI[0], - state.nonzero_indices[0], - density_matrix, - jax_foriloop - ) - diff_JK: BxNxN = diff_JK - jax.vmap(sparse_symmetric_einsum, in_axes=(0, None, 0, None))( - state.diffs_ERI, - state.indxs, - density_matrix, - jax_foriloop - )''' - - diff_JK: BxNxN = jax.vmap(sparse_einsum, in_axes=(None, None, 0, None))( - state.nonzero_distinct_ERI[0], - state.precomputed_nonzero_indices, - density_matrix, - jax_foriloop - ) - correction = jax.vmap(sparse_einsum, in_axes=(0, None, 0, None))( - state.diffs_ERI, - state.precomputed_indxs, - density_matrix, - jax_foriloop - ) - diff_JK: BxNxN = diff_JK - correction + if bs == 1: + diff_JK: BxNxN = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0, None))( + state.nonzero_distinct_ERI[0], + state.nonzero_indices[0], + density_matrix, + jax_foriloop + ) + + + else: + '''diff_JK: BxNxN = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0, None))( + state.nonzero_distinct_ERI[0], + state.nonzero_indices[0], + density_matrix, + jax_foriloop + ) + diff_JK: BxNxN = diff_JK - jax.vmap(sparse_symmetric_einsum, in_axes=(0, None, 0, None))( + state.diffs_ERI, + state.indxs, + density_matrix, + jax_foriloop + )''' + + diff_JK: BxNxN = jax.vmap(sparse_einsum, in_axes=(None, None, 0, None))( + state.nonzero_distinct_ERI[0], + state.precomputed_nonzero_indices, + density_matrix, + jax_foriloop + ) + if bs > 1: + correction = jax.vmap(sparse_einsum, in_axes=(0, None, 0, None))( + state.diffs_ERI, + state.precomputed_indxs, + density_matrix, + jax_foriloop + ) + diff_JK: BxNxN = diff_JK - correction return diff_JK.astype(jnp.float64) @@ -125,8 +142,9 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_sparse_diff_grid=200000, mol_idx=42, ): - - do_print = False + start_time = time.time() + do_print = opts.do_print + if do_print: print("\t[%.4fs] start of 'batched_state'. "%(time.time()-start_time)) # pad molecule if using nn. if not opts.nn: pad_electrons, pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = \ @@ -142,24 +160,12 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_nonzero_distinct_ERI = 20000 pad_sparse_diff_grid = 20000 - if opts.qm9: + if opts.qm9 or opts.qh9: pad_electrons=60 - '''pad_diff_ERIs=120000 - pad_distinct_ERIs=400000 - pad_grid_AO=50000 - pad_nonzero_distinct_ERI=400000 - pad_sparse_diff_grid=400000''' - #padding_estimate = [37426, 149710, 17010, 140122, 138369] - - # idea: - # train w/o atom pertubation until convergence, then 1, then 2, then 3. - padding_estimate = [48330, 163222, 17034, 159361, 139505] - if opts.nperturb == 2 or opts.nperturb == 1: padding_estimate = [int(a*2.1) for a in padding_estimate] else: padding_estimate = [int(a*1.1) for a in padding_estimate] - pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate if opts.basis == "def2-svp": @@ -185,7 +191,39 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, if opts.waters: pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = [a//3 for a in [ pad_diff_ERIs , pad_distinct_ERIs , pad_grid_AO , pad_nonzero_distinct_ERI , pad_sparse_diff_grid ]] - mol = build_mol(mol_str, opts.basis) + if opts.md17 > 0: + if opts.md17 == MD17_WATER: + if opts.level == 1: padding_estimate = [ 3361, 5024, 10172 , 5024 , 155958] + if opts.level == 3: padding_estimate = [ 3361, 5024, 34310 , 5024 , 494370] + if opts.bs == 2 and opts.wiggle_var == 0: padding_estimate = [ 1, 5024, 10172 , 5024 , 1] + padding_estimate = [int(a*1.5) for a in padding_estimate] + pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate + + elif opts.md17 == MD17_ETHANOL: + pad_electrons = 72 + #padding_estimate = [ 99042, 197660, 34310*5 , 197308 , 494370*5] + #padding_estimate = [113522, 224275, 30348, 197308, 609163] + if opts.level == 1: padding_estimate = [ 1, 415186, 30348, 415145, 1] + padding_estimate = [int(a*1.1) for a in padding_estimate] + pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate + + elif opts.md17 == MD17_ALDEHYDE: + pad_electrons = 90 + #padding_estimate = [130646, 224386 , 30348, 223233, 626063] + #padding_estimate = [ 34074, 204235 , 30348, 203632, 285934] + if opts.level == 1: padding_estimate = [1, 939479, 35704, 939479, 1] + padding_estimate = [int(a*1.1) for a in padding_estimate] + pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate + + elif opts.md17 == MD17_URACIL: + #raise Exception("not ready yet") + pad_electrons = 132 + padding_estimate = [1, 3769764 , 51184, 3769764, 1] + padding_estimate = [int(a*1.05) for a in padding_estimate] + pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate + + + mol = build_mol(mol_str, opts.basis, unit="bohr") pad_electrons = min(pad_electrons, mol.nao_nr()) # Set seed to ensure different rotation; initially all workers did same rotation! @@ -195,7 +233,13 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, water1_xyz = np.array([mol_str[i][1] for i in range(0,3)]) water2_xyz = np.array([mol_str[i][1] for i in range(3,6)]) - if opts.qm9: + + if opts.md17 > 0: + natm = len(mol_str) + atom_num = random.sample(range(natm), 1)[0] + atoms = np.array([mol_str[i][1] for i in range(0,natm)]) + + if opts.qm9 or opts.qh9: # pick random atom to perturb (of the first 9 heavy ones) atom_num1, atom_num2, atom_num3 = random.sample(range(9), 3) atoms = np.array([mol_str[i][1] for i in range(0,10)]) @@ -235,7 +279,7 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, states = [] for iteration in range(bs): - + if do_print: print("\t[%.4fs] initializing state %i. "%(time.time()-start_time, iteration)) if opts.alanine: from rdkit import Chem from rdkit.Chem import AllChem @@ -277,6 +321,10 @@ def get_atom_positions(mol): import wandb wandb.log({"mol_valid=%s"%validation: create_rdkit_mol(str, pos) })''' + if opts.md17 > 0: + if iteration > 0: + mol_str[atom_num][1] = tuple(atoms[atom_num] + np.random.normal(0,opts.wiggle_var, (3))) + if opts.waters: # todo: rotate both water molecules and draw x=phi, y=psi. rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] center = water2_xyz.mean(axis=0) @@ -293,7 +341,7 @@ def get_atom_positions(mol): pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], iteration): create_rdkit_mol(str, pos) })''' - elif opts.qm9: + elif opts.qm9 or opts.qh9: if iteration == 0 and (validation or extrapolate): pass else: s = mol_str[atom_num1][0]+ mol_str[atom_num2][0]+ mol_str[atom_num3][0] @@ -318,16 +366,14 @@ def get_atom_positions(mol): c, w = state.grid_coords, state.grid_weights elif iteration <= 1 or not opts.prof: # when profiling create fake molecule to skip waiting state = init_dft(mol_str, opts, c, w, do_pyscf=do_pyscf and iteration < 3, state=state, pad_electrons=pad_electrons) - states.append(state) - if do_print: print("cat states") state = cats(states) N = state.N[0] + if do_print: print("\t[%.4fs] concatenated states. "%(time.time()-start_time)) - if do_print: print("get sparsity") #_nonzero = None @@ -341,24 +387,24 @@ def get_atom_positions(mol): #_nonzero = indxs if _nonzero is None else np.logical_or(_nonzero, indxs) + if do_print: print("\t[%.4fs] got sparsity. "%(time.time()-start_time)) # Merge nonzero indices and prepare (ij, kl). # rep is the number of repetitions we include in the sparse representation. - if do_print: print("union") # bottleneck for def2-svp. todo: fix above logical_or trick to do faster. union = nonzero[0] for i in range(1, len(nonzero)): # this takes 12s/it for def2-svp. union = np.union1d(union, nonzero[i]) nonzero_indices = union - #exit() + if do_print: print("\t[%.4fs] got union of sparsity. "%(time.time()-start_time)) from sparse_symmetric_ERI import get_i_j, num_repetitions_fast ij, kl = get_i_j(nonzero_indices) rep = num_repetitions_fast(ij, kl) + if do_print: print("\t[%.4fs] got (ij) and reps. "%(time.time()-start_time)) batches = opts.eri_bs es = [] - if do_print: print("pad") for e,i in zip(state.nonzero_distinct_ERI, state.nonzero_indices): nonzero_distinct_ERI = e[nonzero_indices] / rep remainder = nonzero_indices.shape[0] % (batches) @@ -369,9 +415,10 @@ def get_atom_positions(mol): state.nonzero_distinct_ERI = np.concatenate([np.expand_dims(a, axis=0) for a in es]) - if do_print: print("pad ijkl indices") + if do_print: print("\t[%.4fs] padded ERI and nonzero_indices. . "%(time.time()-start_time)) i, j = get_i_j(ij.reshape(-1)) k, l = get_i_j(kl.reshape(-1)) + if do_print: print("\t[%.4fs] got ijkl. "%(time.time()-start_time)) if remainder != 0: i = np.pad(i, ((0,batches-remainder))) @@ -380,13 +427,15 @@ def get_atom_positions(mol): l = np.pad(l, ((0,batches-remainder))) nonzero_indices = np.vstack([i,j,k,l]).T.reshape(batches, -1, 4).astype(np.int32) # todo: use int16 or int32 here? state.nonzero_indices = nonzero_indices + if do_print: print("\t[%.4fs] padded and vstacked ijkl. "%(time.time()-start_time)) # batching (w/ same sparsity pattern across batch) allows precomputing all {ss,dm}_indices instead of computing in sparse_sym_eri every iteration. # function below does this. # todo: consider removing, didn't get expecting 3x (only 5%; not sure if additional memory/complication justifies). - if do_print: print("precompute indices. ") from sparse_symmetric_ERI import precompute_indices + + if opts.normal: diff_state = None else: main_grid_AO = state.grid_AO[:1] @@ -395,39 +444,41 @@ def get_atom_positions(mol): sparse_diffs_grid_AO = diffs_grid_AO[:, 0, rows,cols] # use the same sparsity pattern across a batch. - diff_ERIs = state.nonzero_distinct_ERI[:1] - state.nonzero_distinct_ERI - diff_indxs = state.nonzero_indices.reshape(1, batches, -1, 4) - nzr = np.abs(diff_ERIs[1]).reshape(batches, -1) > 1e-10 + if opts.bs > 1: + diff_ERIs = state.nonzero_distinct_ERI[:1] - state.nonzero_distinct_ERI + diff_indxs = state.nonzero_indices.reshape(1, batches, -1, 4) + nzr = np.abs(diff_ERIs[1]).reshape(batches, -1) > 1e-10 - diff_ERIs = diff_ERIs[:, nzr].reshape(bs, -1) - diff_indxs = diff_indxs[:, nzr].reshape(-1, 4) + diff_ERIs = diff_ERIs[:, nzr].reshape(bs, -1) + diff_indxs = diff_indxs[:, nzr].reshape(-1, 4) - remainder = np.sum(nzr) % batches - if remainder != 0: - diff_ERIs = np.pad(diff_ERIs, ((0,0),(0,batches-remainder))) - diff_indxs = np.pad(diff_indxs, ((0,batches-remainder),(0,0))) + remainder = np.sum(nzr) % batches + if remainder != 0: + diff_ERIs = np.pad(diff_ERIs, ((0,0),(0,batches-remainder))) + diff_indxs = np.pad(diff_indxs, ((0,batches-remainder),(0,0))) - diff_ERIs = diff_ERIs.reshape(bs, batches, -1) - diff_indxs = diff_indxs.reshape(batches, -1, 4) + diff_ERIs = diff_ERIs.reshape(bs, batches, -1) + diff_indxs = diff_indxs.reshape(batches, -1, 4) - precomputed_indxs = precompute_indices(diff_indxs, N).astype(np.int16) + if opts.bs > 1: precomputed_indxs = precompute_indices(diff_indxs, N).astype(np.int16) - if pad_diff_ERIs == -1: - state.indxs=diff_indxs - state.diffs_ERI=diff_ERIs - assert False, "deal with precomputed_indxs; only added in else branch below" - else: - max_pad_diff_ERIs = diff_ERIs.shape[2] - # pad ERIs with 0 and indices with -1 so they point to 0. - assert diff_indxs.shape[1] == diff_ERIs.shape[2] - pad = pad_diff_ERIs - diff_indxs.shape[1] - assert pad > 0, (pad_diff_ERIs, diff_indxs.shape[1]) - state.indxs = np.pad(diff_indxs, ((0,0), (0, pad), (0, 0)), 'constant', constant_values=(-1)) - state.diffs_ERI = np.pad(diff_ERIs, ((0,0), (0, 0), (0, pad))) # pad zeros - #print(diff_indxs.shape, precomputed_indxs.shape) - state.precomputed_indxs = np.pad(precomputed_indxs, ((0,0), (0,0),(0,0), (0, pad), (0,0)), 'constant', constant_values=(-1)) - - #if opts.wandb: wandb.log({"pad_diff_ERIs": pad/diff_ERIs.shape[2]}) + if pad_diff_ERIs == -1: + state.indxs=diff_indxs + state.diffs_ERI=diff_ERIs + assert False, "deal with precomputed_indxs; only added in else branch below" + else: + max_pad_diff_ERIs = diff_ERIs.shape[2] + if do_print: print("\t[%.4fs] max_pad_diff_ERIs=%i"%(time.time()-start_time, max_pad_diff_ERIs)) + # pad ERIs with 0 and indices with -1 so they point to 0. + assert diff_indxs.shape[1] == diff_ERIs.shape[2] + pad = pad_diff_ERIs - diff_indxs.shape[1] + assert pad > 0, (pad_diff_ERIs, diff_indxs.shape[1]) + state.indxs = np.pad(diff_indxs, ((0,0), (0, pad), (0, 0)), 'constant', constant_values=(-1)) + state.diffs_ERI = np.pad(diff_ERIs, ((0,0), (0, 0), (0, pad))) # pad zeros + #print(diff_indxs.shape, precomputed_indxs.shape) + if opts.bs > 1: state.precomputed_indxs = np.pad(precomputed_indxs, ((0,0), (0,0),(0,0), (0, pad), (0,0)), 'constant', constant_values=(-1)) + + #if opts.wandb: wandb.log({"pad_diff_ERIs": pad/diff_ERIs.shape[2]}) state.rows=rows state.cols=cols @@ -439,6 +490,7 @@ def get_atom_positions(mol): if pad_sparse_diff_grid != -1: max_pad_sparse_diff_grid = state.rows.shape[0] + if do_print: print("\t[%.4fs] max_pad_sparse_diff_grid=%i"%(time.time()-start_time, max_pad_sparse_diff_grid)) assert state.sparse_diffs_grid_AO.shape[1] == state.rows.shape[0] assert state.sparse_diffs_grid_AO.shape[1] == state.cols.shape[0] pad = pad_sparse_diff_grid - state.rows.shape[0] @@ -457,6 +509,7 @@ def get_atom_positions(mol): # todo: looks like we're padding, then looking for zeros, then padding; this can be simplified. if pad_distinct_ERIs != -1: max_pad_distinct_ERIs = state.nonzero_distinct_ERI.shape[2] + if do_print: print("\t[%.4fs] max_pad_distinct_ERIs=%i"%(time.time()-start_time, max_pad_diff_ERIs)) assert state.nonzero_distinct_ERI.shape[2] == state.nonzero_indices.shape[2] pad = pad_distinct_ERIs - state.nonzero_distinct_ERI.shape[2] assert pad > 0, (pad_distinct_ERIs, state.nonzero_distinct_ERI.shape[2]) @@ -467,6 +520,7 @@ def get_atom_positions(mol): if pad_grid_AO != -1: max_pad_grid_AO = state.grid_AO.shape[2] + if do_print: print("\t[%.4fs] max_pad_grid_AO=%i"%(time.time()-start_time, max_pad_grid_AO)) prev_size = state.grid_AO.shape[2] assert state.grid_AO.shape[2] == state.grid_weights.shape[1] @@ -501,11 +555,12 @@ def get_atom_positions(mol): state.nonzero_distinct_ERI = state.nonzero_distinct_ERI.reshape(1, batches, -1) state.nonzero_indices = state.nonzero_indices.reshape(1, batches, -1, 4) - precomputed_nonzero_indices = precompute_indices(state.nonzero_indices[0], N).astype(np.int16) + if opts.bs > 1: precomputed_nonzero_indices = precompute_indices(state.nonzero_indices[0], N).astype(np.int16) #print(state.nonzero_indices.shape, precomputed_nonzero_indices.shape) if pad_nonzero_distinct_ERI != -1: max_pad_nonzero_distinct_ERI = state.nonzero_distinct_ERI.shape[2] + if do_print: print("\t[%.4fs] max_pad_nonzero_distinct_ERI=%i"%(time.time()-start_time, max_pad_nonzero_distinct_ERI)) assert state.nonzero_distinct_ERI.shape[2] == state.nonzero_indices.shape[2] pad = pad_nonzero_distinct_ERI - state.nonzero_distinct_ERI.shape[2] @@ -513,7 +568,7 @@ def get_atom_positions(mol): state.nonzero_distinct_ERI = np.pad(state.nonzero_distinct_ERI, ((0,0),(0,0),(0,pad))) state.nonzero_indices = np.pad(state.nonzero_indices, ((0,0),(0,0),(0,pad), (0,0)), 'constant', constant_values=(-1)) - state.precomputed_nonzero_indices = np.pad(precomputed_nonzero_indices, ((0,0), (0,0), (0,0), (0, pad),(0,0)), 'constant', constant_values=(-1)) + if opts.bs > 1: state.precomputed_nonzero_indices = np.pad(precomputed_nonzero_indices, ((0,0), (0,0), (0,0), (0, pad),(0,0)), 'constant', constant_values=(-1)) #print(state.precomputed_nonzero_indices.shape, state.nonzero_indices.shape) #if opts.wandb: wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2]}) @@ -535,6 +590,7 @@ def get_atom_positions(mol): def nanoDFT(mol_str, opts): + start_time = time.time() print() # Initialize validation set. # This consists of DFT tensors initialized with PySCF/CPU. @@ -546,6 +602,8 @@ def nanoDFT(mol_str, opts): run = wandb.init(project='ndft_alanine') elif opts.qm9: run = wandb.init(project='ndft_qm9') + elif opts.md17 > 0: + run = wandb.init(project='md17') else: run = wandb.init(project='ndft') opts.name = run.name @@ -588,13 +646,17 @@ def nanoDFT(mol_str, opts): d_model= 1024 n_heads = 16 n_layers = 24 - if opts.large: - d_model= 1280 + if opts.large: # this is 600M; + d_model= 1280 # 80*16 n_heads = 16 n_layers = 36 - if opts.xlarge: - d_model= 1600 - n_heads = 25 + if opts.largep: # interpolated between large and largep. + d_model= 91*16 # halway from 80 to 100 + n_heads = 16*1 + n_layers = 43 + if opts.xlarge: # this is 1.3B; decrease parameter count 30%. + d_model= 1600 # 100*16 + n_heads = 25 n_layers = 48 if opts.nn: @@ -606,6 +668,7 @@ def nanoDFT(mol_str, opts): n_heads =n_heads, d_ff =d_model*4, ) + print("[%.4fs] initialized transformer. "%(time.time()-start_time) ) params = params.to_float32() if opts.resume: @@ -616,30 +679,35 @@ def nanoDFT(mol_str, opts): if opts.nn: #https://arxiv.org/pdf/1706.03762.pdf see 5.3 optimizer - - # try to mimic karpathy as closely as possible ;) - # https://github.com/karpathy/nanoGPT/blob/master/train.py - # still differs on - # [ ] weight initialization - - def custom_schedule(it, learning_rate=opts.lr, min_lr=opts.lr/10, warmup_iters=2000, lr_decay_iters=600000): # 600k/30 = 20k; so hit mi + def custom_schedule(it, learning_rate=opts.lr, min_lr=opts.min_lr, warmup_iters=opts.warmup_iters, lr_decay_iters=opts.lr_decay): cond1 = (it < warmup_iters) * learning_rate * it / warmup_iters cond2 = (it > lr_decay_iters) * min_lr decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) coeff = 0.5 * (1.0 + jnp.cos(jnp.pi * decay_ratio)) cond3 = (it >= warmup_iters) * (it <= lr_decay_iters) * (min_lr + coeff * (learning_rate - min_lr)) - return cond1 + cond2 + cond3 + if not opts.resume: return cond1 + cond2 + cond3 + else: return learning_rate adam = optax.chain( optax.clip_by_global_norm(1), - optax.scale_by_adam(b1=0.9, b2=0.95, eps=1e-12), + #optax.scale_by_adam(b1=0.9, b2=0.95, eps=1e-12), + optax.scale_by_adam(b1=0.99, b2=0.999, eps=1e-12), + #optax.scale_by_factored_rms(), # use this for larger model (more memory efficient) optax.add_decayed_weights(0.1), optax.scale_by_schedule(custom_schedule), optax.scale(-1), + #optax.ema(opts.ema) if opts.ema != 0 else None ) - w = params + df = None + if opts.qh9: + df = pd.read_pickle("qh9/qh9stable_processed_shuffled.pickle") + df = df[df["N_sto3g"]==55] + print(df.shape) + elif opts.qm9: + df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules + if nao != -1: df = df[df["nao"]==nao] from torch.utils.data import DataLoader, Dataset class OnTheFlyQM9(Dataset): @@ -647,31 +715,20 @@ class OnTheFlyQM9(Dataset): # dataloader is very keen on throwing segfaults (e.g. using jnp in dataloader throws segfaul). # problem: second epoch always gives segfault. # hacky fix; make __len__ = real_length*num_epochs and __getitem__ do idx%real_num_examples - def __init__(self, opts, nao=294, train=True, num_epochs=10**9, extrapolate=False): - - if opts.qh9: - df = pd.read_pickle("qh9/qh9stable_processed_shuffled.pickle") - if nao != -1: df = df[df["N_sto3g"]==55] - print(df.shape) - else: - # only take molecules with use {CNOFH}, nao=nao and spin=0. - #qm9 - df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules - if nao != -1: df = df[df["nao"]==nao] - - + def __init__(self, opts, df=None, nao=294, train=True, num_epochs=10**9, extrapolate=False): # df.sample is not deterministic; moved to pre-processing, so file is shuffled already. # this shuffling is important, because it makes the last 10 samples iid (used for validation) #df = df.sample(frac=1).reset_index(drop=True) # is this deterministic? - - if train: self.mol_strs = df["pyscf"].values[:-10] - else: self.mol_strs = df["pyscf"].values[-10:] + if opts.qh9 or opts.qm9: + if train: self.mol_strs = df["pyscf"].values[:-10] + else: self.mol_strs = df["pyscf"].values[-10:] #print(df["pyscf"].) # todo: print smile strings self.num_epochs = num_epochs self.opts = opts self.validation = not train self.extrapolate = extrapolate + self.do_pyscf = self.validation or self.extrapolate self.benzene = [ ["C", ( 0.0000, 0.0000, 0.0000)], @@ -697,39 +754,58 @@ def __init__(self, opts, nao=294, train=True, num_epochs=10**9, extrapolate=Fals ] if opts.benzene: self.mol_strs = [self.benzene] - if opts.waters: self.mol_strs = [self.waters] + + if opts.md17 > 0: + mol = {MD17_WATER: "water", MD17_ALDEHYDE: "malondialdehyde", MD17_ETHANOL: "ethanol", MD17_URACIL: "uracil"}[opts.md17] + mode = {True: "train", False: "val"}[train] + filename = "md17/%s_%s.pickle"%(mode, mol) + df = pd.read_pickle(filename) + + self.mol_strs = df["pyscf"].values.tolist() + N = int(np.sqrt(df["H"].values.tolist()[0].reshape(-1).size)) + self.H = [a.reshape(N, N) for a in df["H"].values.tolist()] + self.E = df["E"].values.tolist() + self.mol_strs = [eval(a) for a in self.mol_strs] + else: + self.H = [0 for _ in self.mol_strs] + self.E = [0 for _ in self.mol_strs] + + if opts.alanine: self.mol_strs = mol_str if train: self.bs = opts.bs else: self.bs = opts.val_bs - def __len__(self): - return len(self.mol_strs)*self.num_epochs + + def __len__(self): return len(self.mol_strs)*self.num_epochs def __getitem__(self, idx): return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.bs, \ - wiggle_num=0, do_pyscf=self.validation or self.extrapolate, validation=False, \ - extrapolate=self.extrapolate, mol_idx=idx) + wiggle_num=0, do_pyscf=self.do_pyscf, validation=False, \ + extrapolate=self.extrapolate, mol_idx=idx), self.H[idx%len(self.mol_strs)], self.E[idx%len(self.mol_strs)] - val_qm9 = OnTheFlyQM9(opts, train=False) - ext_qm9 = OnTheFlyQM9(opts, extrapolate=True) + print("[%.4fs] initialized datasets. "%(time.time()-start_time) ) + val_qm9 = OnTheFlyQM9(opts, train=False, df=df) + print("[%.4fs] initialized datasets. "%(time.time()-start_time) ) + ext_qm9 = OnTheFlyQM9(opts, extrapolate=True, df=df) + print("[%.4fs] initialized datasets. "%(time.time()-start_time) ) - # parallel dataloader bug; precompute here is not slow but causes dataloader later to die. - # run once to quickly precompute. if opts.precompute: val_state = val_qm9[0] ext_state = ext_qm9[0] exit() - qm9 = OnTheFlyQM9(opts, train=True) + qm9 = OnTheFlyQM9(opts, train=True, df=df) + print("[%.4fs] initialized datasets. "%(time.time()-start_time) ) if opts.workers != 0: train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=opts.workers, prefetch_factor=2, collate_fn=lambda x: x[0]) else: train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=opts.workers, collate_fn=lambda x: x[0]) pbar = tqdm(train_dataloader) + print("[%.4fs] initialized dataloaders. "%(time.time()-start_time) ) if opts.test_dataloader: t0 = time.time() - for iteration, state in enumerate(pbar): + for iteration, (state, H, E) in enumerate(pbar): if iteration == 0: summary(state) print(time.time()-t0) t0 = time.time() @@ -750,9 +826,10 @@ def __next__(self): return self.item adam = optax.adabelief(opts.lr) summary(states[0]) - vandg = jax.jit(jax.value_and_grad(dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", 'nn')) - valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn')) + vandg = jax.jit(jax.value_and_grad(dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", 'nn', "cfg", "opts")) + valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn', "cfg", "opts")) adam_state = adam.init(w) + print("[%.4fs] jitted vandg and valf."%(time.time()-start_time) ) if opts.resume: print("loading adam state") @@ -760,6 +837,7 @@ def __next__(self): return self.item print("done") w, adam_state = jax.device_put(w), jax.device_put(adam_state) + print("[%.4fs] jax.device_put(w,adam_state)."%(time.time()-start_time) ) @partial(jax.jit, backend=opts.backend) @@ -781,7 +859,9 @@ def update(w, adam_state, accumulated_grad): paddings = [] states = [] - for iteration, state in enumerate(pbar): + + print("[%.4fs] first iteration."%(time.time()-start_time) ) + for iteration, (state, H, E) in enumerate(pbar): if iteration == 0: summary(state) state = jax.device_put(state) @@ -800,59 +880,65 @@ def update(w, adam_state, accumulated_grad): if opts.shuffle: random.shuffle(states) # load_time, t0 = time.time()-t0, time.time() - if opts.checkpoint != -1 and iteration % opts.checkpoint == 0: # and iteration > 0: - t0 = time.time() - try: - name = opts.name.replace("-", "_") - path_model = "checkpoints/%s_%i_model.pickle"%(name, iteration) - path_adam = "checkpoints/%s_%i_adam_state.pickle"%(name, iteration) - print("trying to checkpoint to %s and %s"%(path_model, path_adam)) - pickle.dump(jax.device_get(w), open(path_model, "wb")) - pickle.dump(jax.device_get(adam_state), open(path_adam, "wb")) - print("done!") - print("\t-resume \"%s\""%(path_model.replace("_model.pickle", ""))) - except: - print("fail!") - pass - print("tried saving model took %fs"%(time.time()-t0)) - save_time, t0 = time.time()-t0, time.time() - - - if len(states) < 50: print(len(states)) - for j, state in enumerate(states): - print(". ", end="", flush=True) - if j == 0: _t0 =time.time() - (val, (vals, E_xc, density_matrix, _W)), grad = vandg(w, state, opts.normal, opts.nn) - print(",", end="", flush=True) - if j == 0: time_step1 = time.time()-_t0 - - if opts.grad_acc == 0 or len(states) < opts.mol_repeats: - w, adam_state = update(w, adam_state, grad) - else: - accumulated_grad = grad if accumulated_grad is None else jax.tree_map(lambda x, y: x + y, accumulated_grad, grad) - if j % opts.grad_acc == 0 and j > 0: # we assume opts.grad_acc divides opts.mol_repeats; prev was basically grad_acc=0 or grad_acc=mol_repeats, can now do hybrid. - w, adam_state = update(w, adam_state, grad) - accumulated_grad = None - print("#", end="", flush=True) + if len(states) < 50: print(len(states), opts.name) + + reps = 1 + if opts.md17 == 4: reps = 5 + if opts.md17 == 3: reps = 2 + if opts.md17 == 2: reps = 2 + + for _ in range(reps): + for j, state in enumerate(states): + print(". ", end="", flush=True) + if j == 0: _t0 =time.time() + (val, (vals, E_xc, density_matrix, _W, _H)), grad = vandg(w, state, opts.normal, opts.nn, cfg, opts) + print(",", end="", flush=True) + if j == 0: time_step1 = time.time()-_t0 + + if opts.grad_acc == 0 or len(states) < opts.mol_repeats: + w, adam_state = update(w, adam_state, grad) + else: + accumulated_grad = grad if accumulated_grad is None else jax.tree_map(lambda x, y: x + y, accumulated_grad, grad) + + if (j+1) % opts.grad_acc == 0 and j > 0: # we assume opts.grad_acc divides opts.mol_repeats; prev was basically grad_acc=0 or grad_acc=mol_repeats, can now do hybrid. + w, adam_state = update(w, adam_state, grad) + accumulated_grad = None + print("#", end="", flush=True) + + + if opts.checkpoint != -1 and adam_state[1].count % opts.checkpoint == 0 and adam_state[1].count > 0: + t0 = time.time() + try: + name = opts.name.replace("-", "_") + path_model = "checkpoints/%s_%i_model.pickle"%(name, iteration) + path_adam = "checkpoints/%s_%i_adam_state.pickle"%(name, iteration) + print("trying to checkpoint to %s and %s"%(path_model, path_adam)) + pickle.dump(jax.device_get(w), open(path_model, "wb")) + pickle.dump(jax.device_get(adam_state), open(path_adam, "wb")) + print("done!") + print("\t-resume \"%s\""%(path_model.replace("_model.pickle", ""))) + except: + print("fail!") + pass + print("tried saving model took %fs"%(time.time()-t0)) + save_time, t0 = time.time()-t0, time.time() + # todo: rename global_batch_size = len(states)*opts.bs if opts.wandb: dct["global_batch_size"] = global_batch_size train_time, t0 = time.time()-t0, time.time() - - # plot grad norm - #if iteration % 10 == 0: - # for k,v in accumulated_grad.items(): dct[k + "_norm"] = np.linalg.norm(v .reshape(-1) ) update_time, t0 = time.time()-t0, time.time() if not opts.nn: str = "error=" + "".join(["%.7f "%(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) for i in range(2)]) + " [eV]" str += "pyscf=%.7f us=%.7f"%(state.pyscf_E[0]/HARTREE_TO_EV, vals[0]) else: - pbar.set_description("train=".join(["%.2f"%i for i in vals[:1]]) + "[Ha] "+ valid_str + "time=%.1f %.1f %.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, update_time, val_time, plot_time)) + #print(vals[0], E) + pbar.set_description("train=%.4f"%(vals[0]*HARTREE_TO_EV) + "[eV] "+ valid_str + "time=%.1f %.1f %.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, update_time, val_time, plot_time)) if opts.wandb: dct["time_load"] = load_time @@ -860,44 +946,66 @@ def update(w, adam_state, accumulated_grad): dct["time_train"] = train_time dct["time_val"] = val_time plot_iteration = iteration % 10 == 0 - for i in range(0, 2): - if not opts.nn: - dct['train_l%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) - dct['train_pyscf%i'%i ] = np.abs(state.pyscf_E[i]) - dct['train_E%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV) - if plot_iteration: - dct['img/dm%i'%i] = wandb.Image(np.expand_dims(density_matrix[i], axis=-1)) - dct['img/W%i'%i] = wandb.Image(np.expand_dims(_W[i], axis=-1)) + + dct["train_E"] = np.abs(E*HARTREE_TO_EV) + dct["train_E_pred"] = np.abs(vals[0]*HARTREE_TO_EV) step = adam_state[1].count plot_time, t0 = time.time()-t0, time.time() - - - # TODO: Plot molecules and val/ext angles. if opts.nn and (iteration < 250 or iteration % 10 == 0): - if val_state is None: val_state = jax.device_put(val_qm9[0]) - _, (valid_vals, _, vdensity_matrix, vW) = valf(w, val_state, opts.normal, opts.nn) - if ext_state is None: ext_state = jax.device_put(ext_qm9[0]) - _, (ext_vals, _, edensity_matrix, eW) = valf(w, ext_state, opts.normal, opts.nn) + val_idx = 1 + if val_state is None: val_state, val_H, val_E = jax.device_put(val_qm9[val_idx]) # todo: cat 8 of these. + _, (valid_vals, _, vdensity_matrix, vW, _val_H) = valf(w, val_state, opts.normal, opts.nn, cfg, opts) + + if opts.md17 > 0: + def get_H_from_dm(dm): + import pyscf + from pyscf import gto, dft + m = pyscf.gto.Mole(atom=val_qm9.mol_strs[val_idx], basis="def2-svp", unit="bohr") + m.build() + mf = dft.RKS(m) + mf.xc = 'B3LYP5' + mf.verbose = 0 + mf.diis_space = 8 + mf.conv_tol = 1e-13 + mf.grad_tol = 3.16e-5 + mf.grids.level = 3 + #mf.kernel() + h_core = mf.get_hcore() + S = mf.get_ovlp() + vxc = mf.get_veff(m, dm) + H = h_core + vxc + S = mf.get_ovlp() + return H, S + + matrix = np.array(vdensity_matrix[0]) + N = int(np.sqrt(matrix.size)) + _val_H, S = get_H_from_dm(matrix.reshape(N, N)) + + # compare eigenvalues + pred_vals = scipy.linalg.eigh(_val_H, S)[0] + label_vals = scipy.linalg.eigh(val_H, S)[0] + MAE_vals = np.mean(np.abs(pred_vals - label_vals)) + dct["val_eps"] = MAE_vals + lr = custom_schedule(step) - valid_str = "lr=%.3e"%lr + "val=" + "".join(["%.4f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" - valid_str += "ext=" + "".join(["%.4f "%(ext_vals[i]*HARTREE_TO_EV-ext_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" + valid_str = "lr=%.3e"%lr + "val=%.4f [eV] "%(valid_vals[0]*HARTREE_TO_EV-val_E*HARTREE_TO_EV) + "mae_H=%.4f "%( + np.mean(np.abs(val_H/np.abs(val_H) - _val_H/np.abs(_val_H))) + ) + if opts.md17> 0:valid_str+= " eps=%.4f"%(MAE_vals) + valid_str += "val'=" + "".join(["%.4f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" + + dct['val_E'] = np.abs(valid_vals[0]*HARTREE_TO_EV-val_E*HARTREE_TO_EV ) + dct['val_H_MAE'] = np.mean(np.abs(val_H - _val_H)) # perhaps sign doesn't matter? + if opts.wandb: - for i in range(0, opts.val_bs): + for i in range(0, 3): dct['valid_l%i'%i ] = np.abs(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) dct['valid_E%i'%i ] = np.abs(valid_vals[i]*HARTREE_TO_EV) dct['valid_pyscf%i'%i ] = np.abs(val_state.pyscf_E[i]) - dct['img/val_dm%i'%i] = wandb.Image(np.expand_dims(vdensity_matrix[i], axis=-1)) - dct['img/val_W%i'%i] = wandb.Image(np.expand_dims(vW[i], axis=-1)) - - dct['ext_l%i'%i ] = np.abs(ext_vals[i]*HARTREE_TO_EV-ext_state.pyscf_E[i]) - dct['ext_E%i'%i ] = np.abs(ext_vals[i]*HARTREE_TO_EV) - dct['ext_pyscf%i'%i ] = np.abs(ext_state.pyscf_E[i]) - dct['img/ext_dm%i'%i] = wandb.Image(np.expand_dims(edensity_matrix[i], axis=-1)) - dct['img/ext_W%i'%i] = wandb.Image(np.expand_dims(eW[i], axis=-1)) dct["scheduled_lr"] = lr @@ -1104,8 +1212,8 @@ def grids_from_pyscf_mol( def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=True, state=None, pad_electrons=-1): do_print = False #t0 = time.time() - mol = build_mol(mol_str, opts.basis) - if do_pyscf: pyscf_E, pyscf_hlgap, pycsf_forces = reference(mol_str, opts) + mol = build_mol(mol_str, opts.basis, unit="bohr") + if do_pyscf: pyscf_E, pyscf_hlgap, pycsf_forces = reference(mol_str, opts, unit="bohr") else: pyscf_E, pyscf_hlgap, pyscf_forces = np.zeros(1), np.zeros(1), np.zeros(1) N = mol.nao_nr() # N=66 for C6H6 (number of atomic **and** molecular orbitals) @@ -1366,12 +1474,19 @@ def hcore_deriv(atm_id, aoslices, h1): # <\nabla|1/r|> def pyscf_reference(mol_str, opts): from pyscf import __config__ __config__.dft_rks_RKS_grids_level = opts.level - mol = build_mol(mol_str, opts.basis) + mol = build_mol(mol_str, opts.basis, unit="bohr") mol.max_cycle = 50 mf = pyscf.scf.RKS(mol) - mf.max_cycle = 50 - mf.xc = "b3lyp5" + #mf.max_cycle = 50 + #mf.xc = "b3lyp5" + #mf.diis_space = 8 + mf.xc = 'B3LYP5' + mf.verbose = 0 # put this to 4 to check i set parameters correctly! mf.diis_space = 8 + # options from qh9 + mf.conv_tol=1e-13 + mf.grad_tol=3.16e-5 + mf.grids.level = 3 pyscf_energies = [] pyscf_hlgaps = [] lumo = mol.nelectron//2 @@ -1410,23 +1525,23 @@ def print_difference(nanoDFT_E, nanoDFT_forces, nanoDFT_logged_E, nanoDFT_hlgap, cosine_similarity = dot_products / (norm_X * norm_Y) print("Force cosine similarity:",cosine_similarity) -def build_mol(mol_str, basis_name): +def build_mol(mol_str, basis_name, unit="bohr"): mol = pyscf.gto.mole.Mole() - mol.build(atom=mol_str, unit="Angstrom", basis=basis_name, spin=0, verbose=0) + mol.build(atom=mol_str, unit=unit, basis=basis_name, spin=0, verbose=0) return mol -def reference(mol_str, opts): +def reference(mol_str, opts, unit="bohr"): import pickle import hashlib if opts.skip: return np.zeros(1), np.zeros(1), np.zeros(1) - filename = "precomputed/%s.pkl"%hashlib.sha256((str(mol_str) + str(opts.basis) + str(opts.level)).encode('utf-8')).hexdigest() + filename = "precomputed/%s.pkl"%hashlib.sha256((str(mol_str) + str(opts.basis) + str(opts.level) + unit).encode('utf-8')).hexdigest() print(filename) if not os.path.exists(filename): pyscf_E, pyscf_hlgap, pyscf_forces = pyscf_reference(mol_str, opts) with open(filename, "wb") as file: - pickle.dump([pyscf_E, pyscf_hlgap, pyscf_forces], file) + pickle.dump([pyscf_E, pyscf_hlgap, pyscf_forces, unit], file) else: - pyscf_E, pyscf_hlgap, pyscf_forces = pickle.load(open(filename, "rb")) + pyscf_E, pyscf_hlgap, pyscf_forces, unit = pickle.load(open(filename, "rb")) return pyscf_E, pyscf_hlgap, pyscf_forces @@ -1441,10 +1556,15 @@ def reference(mol_str, opts): # GD options parser.add_argument('-backend', type=str, default="cpu") - parser.add_argument('-lr', type=float, default=2.5e-4) + parser.add_argument('-lr', type=float, default=5e-4) + parser.add_argument('-min_lr', type=float, default=1e-7) + parser.add_argument('-warmup_iters', type=float, default=1000) + parser.add_argument('-lr_decay', type=float, default=200000) + parser.add_argument('-ema', type=float, default=0.0) + parser.add_argument('-steps', type=int, default=100000) parser.add_argument('-bs', type=int, default=8) - parser.add_argument('-val_bs', type=int, default=8) + parser.add_argument('-val_bs', type=int, default=3) parser.add_argument('-mol_repeats', type=int, default=16) # How many time to optimize wrt each molecule. parser.add_argument('-grad_acc', type=int, default=0) # integer, deciding how many steps to accumulate. parser.add_argument('-shuffle', action="store_true") # whether to to shuffle the window of states each step. @@ -1462,14 +1582,16 @@ def reference(mol_str, opts): parser.add_argument('-skip', action="store_true", help="skip pyscf test case") # dataset - parser.add_argument('-nperturb', type=int, default=1, help="how many atoms to perturb (supports 1,2,3)") + parser.add_argument('-nperturb', type=int, default=0, help="How many atoms to perturb (supports 1,2,3)") parser.add_argument('-qm9', action="store_true") + parser.add_argument('-md17', type=int, default=-1) parser.add_argument('-qh9', action="store_true") parser.add_argument('-benzene', action="store_true") parser.add_argument('-hydrogens', action="store_true") parser.add_argument('-water', action="store_true") parser.add_argument('-waters', action="store_true") parser.add_argument('-alanine', action="store_true") + parser.add_argument('-do_print', action="store_true") # useful for debugging. parser.add_argument('-states', type=int, default=1) parser.add_argument('-workers', type=int, default=5) parser.add_argument('-precompute', action="store_true") # precompute labels; only run once for data{set/augmentation}. @@ -1479,6 +1601,8 @@ def reference(mol_str, opts): parser.add_argument('-rotate_deg', type=float, default=90, help="how many degrees to rotate") parser.add_argument('-test_dataloader', action="store_true", help="no training, just test/loop through dataloader. ") + + # models parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") parser.add_argument('-tiny', action="store_true") @@ -1487,23 +1611,31 @@ def reference(mol_str, opts): parser.add_argument('-medium', action="store_true") parser.add_argument('-large', action="store_true") parser.add_argument('-xlarge', action="store_true") + parser.add_argument('-largep', action="store_true") # large "plus" parser.add_argument("-checkpoint", default=-1, type=int, help="which iteration to save model (default -1 = no saving)") # checkpoint model parser.add_argument("-resume", default="", help="path to checkpoint pickle file") # checkpoint model opts = parser.parse_args() if opts.tiny or opts.small or opts.base or opts.large or opts.xlarge: opts.nn = True - assert opts.mol_repeats % opts.grad_acc == 0, "mol_repeats needs to be a multiple of grad_acc (gradient accumulation)." + assert opts.grad_acc == 0 or opts.mol_repeats % opts.grad_acc == 0, "mol_repeats needs to be a multiple of grad_acc (gradient accumulation)." + + class HashableNamespace: + def __init__(self, namespace): self.__dict__.update(namespace.__dict__) + def __hash__(self): return hash(tuple(sorted(self.__dict__.items()))) + opts = HashableNamespace(opts) args_dict = vars(opts) print(args_dict) - if opts.qm9: df = pd.read_pickle("alchemy/atom_9.pickle") df = df[df["spin"] == 0] # only consider spin=0 mol_strs = df["pyscf"].values + if opts.qh9: + mol_strs = [] + # benzene if opts.benzene: mol_strs = [[ @@ -1526,7 +1658,7 @@ def reference(mol_str, opts): ["H", ( 0.0000, 0.0000, 0.0000)], ["H", ( 1.4000, 0.0000, 0.0000)], ]] - if opts.water: + if opts.md17 > 0 : mol_strs = [[ ["O", ( 0.0000, 0.0000, 0.0000)], ["H", ( 0.0000, 1.4000, 0.0000)], @@ -1573,4 +1705,4 @@ def reference(mol_str, opts): exit() pyscf_E, pyscf_hlgap, pyscf_forces = reference(mol_str, opts) nanoDFT_forces = grad(mol, grid_coords, grid_weights, mo_coeff, mo_energy, np.array(dm), np.array(H)) - print_difference(nanoDFT_E, nanoDFT_forces, 0 , nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap) + print_difference(nanoDFT_E, nanoDFT_forces, 0 , nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap) \ No newline at end of file