diff --git a/pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py b/pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py index 6462fde..9873270 100644 --- a/pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py +++ b/pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py @@ -163,7 +163,7 @@ def get_shapes(input_ijkl, bas): return len, nf -def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend): +def compute_diff_jk(dm, mol, nbatch, tolerance, backend): dm = dm.reshape(-1) diff_JK = jnp.zeros(dm.shape) N = int(np.sqrt(dm.shape[0])) @@ -392,46 +392,16 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend): num_shells, shell_size = eri.shape # save original tensor shape - def compute_full_shell_idx(idx): - comp_distinct_idx_list = [] - for ind in range(eri.shape[0]): - i, j, k, l = [idx[ind, z] for z in range(4)] - _di, _dj, _dk, _dl = ao_loc[i+1] - ao_loc[i], ao_loc[j+1] - ao_loc[j], ao_loc[k+1] - ao_loc[k], ao_loc[l+1] - ao_loc[l] - _i0, _j0, _k0, _l0 = ao_loc[i], ao_loc[j], ao_loc[k], ao_loc[l] - block_idx = np.mgrid[ - _i0:(_i0+_di), - _j0:(_j0+_dj), - _k0:(_k0+_dk), - _l0:(_l0+_dl)].transpose(4, 3, 2, 1, 0) #.astype(np.int16) - - comp_distinct_idx_list.append(block_idx.reshape(-1, 4)) - comp_distinct_idx = np.concatenate(comp_distinct_idx_list) - return comp_distinct_idx - - comp_distinct_idx = compute_full_shell_idx(idx) - - ijkl_arr = np.sum([np.prod(np.array(a).shape) for a in input_ijkl]) - print('input_ijkl.nbytes/1e6', ijkl_arr*2/1e6) - print('comp_distinct_idx.nbytes/1e6', comp_distinct_idx.astype(np.int16).nbytes/1e6) - - remainder = (eri.shape[0]) % (nprog*nbatch) + remainder = (eri.shape[0]) % (nbatch) - # unused for nipu==batches==1 + # pad tensors; unused for nipu==batches==1 if remainder != 0: - print('padding', remainder, nprog*nbatch-remainder, comp_distinct_idx.shape) - comp_distinct_idx = np.pad(comp_distinct_idx.reshape(-1, shell_size, 4), ((0, (nprog*nbatch-remainder)), (0, 0), (0, 0))).reshape(-1, 4) - eri = jnp.pad(eri, ((0, nprog*nbatch-remainder), (0, 0))) - idx = jnp.pad(idx, ((0, nprog*nbatch-remainder), (0, 0))) + eri = jnp.pad(eri, ((0, nbatch-remainder), (0, 0))) + idx = jnp.pad(idx, ((0, nbatch-remainder), (0, 0))) - comp_distinct_ERI = eri.reshape(nprog, nbatch, -1) - comp_distinct_idx = comp_distinct_idx.reshape(nprog, nbatch, -1, 4) - idx = idx.reshape(nprog, nbatch, -1, 4) - + nonzero_distinct_ERI = eri.reshape(nbatch, -1) + nonzero_indices = idx.reshape(nbatch, -1, 4) - # nonzero_distinct_ERI, nonzero_indices, dm, backend = comp_distinct_ERI[0], comp_distinct_idx[0], dm, backend - nonzero_distinct_ERI, nonzero_indices, dm, backend = comp_distinct_ERI[0], idx[0], dm, backend - - dm = dm.reshape(-1) diff_JK = jnp.zeros(dm.shape) N = int(np.sqrt(dm.shape[0])) @@ -439,44 +409,28 @@ def compute_full_shell_idx(idx): def foreach_batch(i, vals): diff_JK, nonzero_indices, ao_loc = vals - eris = nonzero_distinct_ERI[i].reshape(-1) - - if False: - - indices = nonzero_indices[i] - # # indices = jax.lax.bitcast_convert_type(indices, np.int16).astype(np.int32) - indices = indices.astype(jnp.int32) - - print('eris.shape', eris.shape) - print('indices.shape', indices.shape) - - else: - # Compute offsets and sizes - idx = nonzero_indices[i] - _i, _j, _k, _l = [idx[:, z] for z in range(4)] - _di, _dj, _dk, _dl = [(ao_loc[z+1] - ao_loc[z]).reshape(-1, 1) for z in [_i, _j, _k, _l]] - _i0, _j0, _k0, _l0 = [ao_loc[z].reshape(-1, 1) for z in [_i, _j, _k, _l]] - - def gen_shell_idx(idx_sh): - idx_sh = idx_sh.reshape(-1, shell_size) - # Compute the indices - ind_i = (idx_sh ) % _di + _i0 - ind_j = (idx_sh // (_di) ) % _dj + _j0 - ind_k = (idx_sh // (_di*_dj) ) % _dk + _k0 - ind_l = (idx_sh // (_di*_dj*_dk)) % _dl + _l0 - print('>>', ind_i.shape) - # Update the array with the computed indices - return jnp.stack([ind_i.reshape(-1), ind_j.reshape(-1), ind_k.reshape(-1), ind_l.reshape(-1)], axis=1) - - indices = gen_shell_idx(jnp.arange((eris.shape[0]))) # <<<<<<<<<<<<<<<<<<<<<<<<< - - print('eris.shape', eris.shape) - print('indices.shape', indices.shape) + # Compute offsets and sizes + batch_idx = nonzero_indices[i] + _i, _j, _k, _l = [batch_idx[:, z] for z in range(4)] + _di, _dj, _dk, _dl = [(ao_loc[z+1] - ao_loc[z]).reshape(-1, 1) for z in [_i, _j, _k, _l]] + _i0, _j0, _k0, _l0 = [ao_loc[z].reshape(-1, 1) for z in [_i, _j, _k, _l]] + + def gen_shell_idx(idx_sh): + # Compute the indices + ind_i = (idx_sh ) % _di + _i0 + ind_j = (idx_sh // (_di) ) % _dj + _j0 + ind_k = (idx_sh // (_di*_dj) ) % _dk + _k0 + ind_l = (idx_sh // (_di*_dj*_dk)) % _dl + _l0 + + # Update the array with the computed indices + return jnp.stack([ind_i.reshape(-1), ind_j.reshape(-1), ind_k.reshape(-1), ind_l.reshape(-1)], axis=1) + + eris = nonzero_distinct_ERI[i].reshape(-1) + indices = gen_shell_idx(jnp.arange((eris.shape[0])).reshape(-1, shell_size)) # compute repetitions caused by 8x symmetry when computing from the distinct_ERI form and scale accordingly drep = num_repetitions_fast_4d(indices[:, 0], indices[:, 1], indices[:, 2], indices[:, 3], xnp=jnp, dtype=jnp.uint32) eris = eris / drep - def foreach_symmetry(sym, vals): # Generalized J/K computation: does J when symmetry is in range(0,8) and K when symmetry is in range(8,16) @@ -570,7 +524,7 @@ def foreach_symmetry(sym, vals): # ------------------------------------ # - diff_JK = jax.jit(compute_diff_jk, backend=backend, static_argnames=['mol', 'nprog', 'nbatch', 'tolerance', 'backend'])(dm, mol, args.nipu, args.batches, args.itol, args.backend) + diff_JK = jax.jit(compute_diff_jk, backend=backend, static_argnames=['mol', 'nbatch', 'tolerance', 'backend'])(dm, mol, args.batches, args.itol, args.backend) # ------------------------------------ #