
Commit

moved index computation closer to matmul
mihaipgc committed Oct 10, 2023
1 parent 5ffa784 commit 0f017bf
Showing 1 changed file with 95 additions and 78 deletions.
173 changes: 95 additions & 78 deletions pyscf_ipu/nanoDFT/sparse_symmetric_intor_ERI.py
@@ -8,6 +8,7 @@
from tessellate_ipu import create_ipu_tile_primitive, ipu_cycle_count, tile_map, tile_put_sharded, tile_put_replicated
from functools import partial
from icecream import ic
from tqdm import tqdm
jax.config.update('jax_platform_name', "cpu")
#jax.config.update('jax_enable_x64', True)
HYB_B3LYP = 0.2
@@ -222,14 +223,18 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend):
# Step 1. Compute indices where ERI is non-zero due to geometry (pre-screening).
# Below computes: np.max([ERI[a,b,a,b] for a,b in zip(tril_idx[0], tril_idx[1])])
ERI_s8 = mol.intor("int2e_sph", aosym="s8") # TODO: make custom libcint code compute this.
lst = [ ]
lst_abab = np.zeros((N*(N+1)//2), dtype=np.float32)
lst_ab = np.zeros((N*(N+1)//2, 2), dtype=np.int32)
tril_idx = np.tril_indices(N)
for a, b in zip(tril_idx[0], tril_idx[1]):
for c, (a, b) in tqdm(enumerate(zip(tril_idx[0], tril_idx[1]))):
index_ab_s8 = a*(a+1)//2 + b
index_s8 = index_ab_s8*(index_ab_s8+3)//2
abab = np.abs(ERI_s8[index_s8])
lst.append((abab, a,b))
considered_indices = set([(a,b) for abab, a, b in lst if abab*np.max(lst) >= tolerance**2])
# lst.append((abab, a,b))
lst_abab[c] = abab
lst_ab[c, :] = (a, b)
abab_max = np.max(lst_abab)
considered_indices = set([(a,b) for abab, (a, b) in tqdm(zip(lst_abab, lst_ab)) if abab*abab_max >= tolerance**2])
print('n_bas', n_bas)
print('ao_loc', ao_loc)
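The pre-screening above relies on the packed s8 layout: a lower-triangle AO pair (a, b) with a >= b has pair index ab = a*(a+1)//2 + b, and the diagonal element (ab|ab) sits at ab*(ab+3)//2 in the s8 vector, so |(ab|ab)| can be read off directly and compared against the Schwarz-style bound |(ab|cd)|^2 <= (ab|ab)(cd|cd). A minimal standalone sketch of that screen (helper name is illustrative, not part of this file):

import numpy as np

def screen_pairs(ERI_s8, N, tolerance):
    """Keep AO pairs (a, b) whose Schwarz bound (ab|ab) * max_cd (cd|cd) can still reach tolerance**2."""
    a, b = np.tril_indices(N)
    ab = a * (a + 1) // 2 + b                   # packed pair index for a >= b
    abab = np.abs(ERI_s8[ab * (ab + 3) // 2])   # |(ab|ab)| read from the s8-packed vector
    keep = abab * abab.max() >= tolerance ** 2
    return set(zip(a[keep].tolist(), b[keep].tolist()))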

@@ -243,7 +248,7 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend):
nonzero_seed = sym_pattern
num_calls = 0

for i in range(n_bas): # consider all shells << all ijkl
for i in tqdm(range(n_bas)): # consider all shells << all ijkl
for j in range(i+1):
for k in range(i, n_bas):
for l in range(k+1):
@@ -427,16 +432,19 @@ def compute_diff_jk(dm, mol, nprog, nbatch, tolerance, backend):

BLOCK_ERI_SIZE = np.sum(np.array([eri.shape[0] for eri in all_eris]))

comp_distinct_idx_list = [None]*BLOCK_ERI_SIZE
comp_do_list = [None]*BLOCK_ERI_SIZE
comp_list_index = 0

print('[a.shape for a in all_eris]', [a.shape for a in all_eris])
print('[a.shape for a in all_indices]', [a.shape for a in all_indices])


# go from our memory layout to mol.intor("int2e_sph", "s8")
# comp_distinct_idx_list = [None]*BLOCK_ERI_SIZE

temp = 0
for zip_counter, (eri, idx) in enumerate(zip(all_eris, all_indices)):

# comp_list_index = 0

# go from our memory layout to mol.intor("int2e_sph", "s8")
# for zip_counter, (eri, idx) in enumerate(zip(all_eris, all_indices)):
comp_distinct_idx_list = []
print(eri.shape)
for ind in range(eri.shape[0]):
i, j, k, l = [idx[ind, z] for z in range(4)]
@@ -447,91 +455,100 @@
_j0:(_j0+_dj),
_k0:(_k0+_dk),
_l0:(_l0+_dl)].transpose(4, 3, 2, 1, 0).astype(np.int16)

comp_distinct_idx_list.append(block_idx.reshape(-1, 4))
# comp_list_index += 1

comp_distinct_idx = np.concatenate(comp_distinct_idx_list)
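block_idx above is built by slicing a 4-D index grid over the AO ranges of the four shells of a quartet and flattening it to one row per (i, j, k, l). A minimal sketch of one way to build such a block with np.mgrid (illustrative; the file's exact construction and row ordering sit partly outside this hunk):

import numpy as np

def block_ao_indices(i0, di, j0, dj, k0, dk, l0, dl):
    """AO index quadruples covered by one shell quartet, shape (di*dj*dk*dl, 4)."""
    grid = np.mgrid[i0:i0 + di, j0:j0 + dj, k0:k0 + dk, l0:l0 + dl]  # shape (4, di, dj, dk, dl)
    return grid.reshape(4, -1).T.astype(np.int16)

block = block_ao_indices(0, 3, 0, 3, 3, 1, 3, 1)  # e.g. a (p, p | s, s) block of 9 quadruples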

def ijkl_in_bounds(i, j, k, l):
return i>=j and k>=l and (i*(i+1)//2+j)>=(k*(k+1)//2+l)
ijkl_arr = np.sum([np.prod(np.array(a).shape) for a in input_ijkl])
print('input_ijkl.nbytes/1e6', ijkl_arr*2/1e6)
print('comp_distinct_idx.nbytes/1e6', comp_distinct_idx.astype(np.int16).nbytes/1e6)

remainder = comp_distinct_idx.shape[0] % (nprog*nbatch)

block_do = np.zeros((_dl*_dk*_dj*_di))
for ci, ijkl in enumerate(block_idx.reshape(-1, 4)):
block_do[ci] = ijkl_in_bounds(ijkl[2], ijkl[3], ijkl[0], ijkl[1]) and ~(nonzero_seed[ijkl[0]] ^ nonzero_seed[ijkl[1]]) ^ (nonzero_seed[ijkl[2]] ^ nonzero_seed[ijkl[3]])

comp_distinct_idx_list[comp_list_index] = block_idx.reshape(-1, 4)
comp_do_list[comp_list_index] = block_do
comp_list_index += 1
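block_do keeps an AO quadruple only when ijkl_in_bounds holds for it (note the call above swaps the two index pairs) and the XOR parity of nonzero_seed agrees between the bra pair and the ket pair; for booleans, ~p ^ q is True exactly when p == q. A minimal sketch of that per-quadruple predicate, assuming a boolean seed array like nonzero_seed:

def keep_quadruple(i, j, k, l, seed):
    """True if (i, j, k, l) is a canonical s8 representative whose pair parities under `seed` match."""
    canonical = i >= j and k >= l and i * (i + 1) // 2 + j >= k * (k + 1) // 2 + l
    parity_match = (bool(seed[i]) ^ bool(seed[j])) == (bool(seed[k]) ^ bool(seed[l]))
    return canonical and parity_match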
# unused for nipu==batches==1
if remainder != 0:
print('padding', remainder, nprog*nbatch-remainder, comp_distinct_idx.shape)
comp_distinct_idx = np.pad(comp_distinct_idx, ((0, nprog*nbatch-remainder), (0, 0)))
eri = jnp.pad(eri.reshape(-1), ((0, nprog*nbatch-remainder)))
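Before the reshape to (nprog, nbatch, ...), the flattened index and ERI arrays are zero-padded so their leading dimension divides evenly into nprog * nbatch. A minimal pad-to-multiple sketch of that step (helper name is illustrative):

import numpy as np

def pad_to_multiple(x, multiple):
    """Zero-pad the leading axis of x so its length is divisible by `multiple`."""
    remainder = x.shape[0] % multiple
    if remainder == 0:
        return x
    pad = [(0, multiple - remainder)] + [(0, 0)] * (x.ndim - 1)
    return np.pad(x, pad)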




comp_distinct_idx = np.concatenate(comp_distinct_idx_list)
comp_do = np.concatenate(comp_do_list)

remainder = comp_distinct_idx.shape[0] % (nprog*nbatch)
# output of mol.intor("int2e_sph", aosym="s8")
comp_distinct_ERI = eri.reshape(nprog, nbatch, -1) #jnp.concatenate([eri.reshape(-1) for eri in all_eris]).reshape(nprog, nbatch, -1)
comp_distinct_idx = comp_distinct_idx.reshape(nprog, nbatch, -1, 4)

if remainder != 0:
print('padding', remainder, nprog*nbatch-remainder, comp_distinct_idx.shape)
comp_distinct_idx = np.pad(comp_distinct_idx, ((0, nprog*nbatch-remainder), (0, 0)))
comp_do = np.pad(comp_do, ((0, nprog*nbatch-remainder)))
all_eris.append(jnp.zeros((nprog*nbatch-remainder), dtype=jnp.float32))

# output of mol.intor("int2e_sph", aosym="s8")
comp_distinct_ERI = jnp.concatenate([eri.reshape(-1) for eri in all_eris]).reshape(nprog, nbatch, -1)
comp_distinct_idx = comp_distinct_idx.reshape(nprog, nbatch, -1, 4)
comp_do = comp_do.reshape(nprog, nbatch, -1)
# comp_distinct_ERI *= comp_do

print('comp_distinct_ERI.shape', comp_distinct_ERI.shape)
print('comp_distinct_idx.shape', comp_distinct_idx.shape)

# compute repetitions caused by 8x symmetry when computing from the distinct_ERI form and scale accordingly
drep = num_repetitions_fast_4d(comp_distinct_idx[:, :, :, 0], comp_distinct_idx[:, :, :, 1], comp_distinct_idx[:, :, :, 2], comp_distinct_idx[:, :, :, 3], xnp=np, dtype=np.uint32)
comp_distinct_ERI = comp_distinct_ERI / drep
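Only one representative of each up-to-8-fold symmetric ERI value is stored, so each value is divided by its repetition count before the symmetric accumulation over the 16 J/K passes visits it. A sketch of the standard degeneracy count, assuming canonical indices with i >= j, k >= l, ij >= kl (the repository's num_repetitions_fast_4d may differ in implementation details):

import numpy as np

def num_repetitions(i, j, k, l):
    """How many entries of the full ERI tensor one distinct value represents; vectorised over numpy index arrays."""
    return ((1 + (i != j)) * (1 + (k != l)) *
            (1 + ((i != k) | (j != l)))).astype(np.uint32)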

# int16 storage supported but not slicing; use conversion trick to enable slicing
comp_distinct_idx = jax.lax.bitcast_convert_type(comp_distinct_idx, jnp.float16)
# reduce this from |eri_floats| to num_calls*4 ~ perhaps 10x smaller
print('comp_distinct_ERI.shape', comp_distinct_ERI.shape)
print('comp_distinct_idx.shape', comp_distinct_idx.shape)
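The bitcast above stores the int16 indices in a float16 container of the same bit width so the array can be sliced per batch, and the loop body bitcasts back before use (in the updated code below this conversion is commented out and the indices are cast straight to int32). A minimal round-trip sketch of the trick:

import jax
import jax.numpy as jnp
import numpy as np

idx_i16 = jnp.arange(8, dtype=jnp.int16).reshape(2, 4)
stored = jax.lax.bitcast_convert_type(idx_i16, jnp.float16)   # same bits, float16 container
batch0 = stored[0]                                            # slice on the float16 view
back = jax.lax.bitcast_convert_type(batch0, jnp.int16).astype(jnp.int32)
assert np.array_equal(np.asarray(back), np.arange(4))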

#diff_JK = jax.pmap(sparse_symmetric_einsum, in_axes=(0,0,None,None), static_broadcasted_argnums=(3,), backend=backend, axis_name="p")(comp_distinct_ERI, comp_distinct_idx, dm, backend)
#diff_JK = sparse_symmetric_einsum(comp_distinct_ERI[0], comp_distinct_idx[0], dm, backend)
nonzero_distinct_ERI, nonzero_indices, dm, backend = comp_distinct_ERI[0], comp_distinct_idx[0], dm, backend
dm = dm.reshape(-1)
diff_JK = jnp.zeros(dm.shape)
N = int(np.sqrt(dm.shape[0]))

def iteration(symmetry, vals):
diff_JK = vals
is_K_matrix = (symmetry >= 8)

# compute repetitions caused by 8x symmetry when computing from the distinct_ERI form and scale accordingly
drep = num_repetitions_fast_4d(comp_distinct_idx[:, :, :, 0], comp_distinct_idx[:, :, :, 1], comp_distinct_idx[:, :, :, 2], comp_distinct_idx[:, :, :, 3], xnp=np, dtype=np.uint32)
comp_distinct_ERI = comp_distinct_ERI / drep


def sequentialized_iter(i, vals):
# Generalized J/K computation: does J when symmetry is in range(0,8) and K when symmetry is in range(8,16)
# Trade-off: Using one function leads to smaller always-live memory.
# int16 storage supported but not slicing; use conversion trick to enable slicing
# comp_distinct_idx = jax.lax.bitcast_convert_type(comp_distinct_idx, jnp.float16)
# reduce this from |eri_floats| to num_calls*4 ~ perhaps 10x smaller

#diff_JK = jax.pmap(sparse_symmetric_einsum, in_axes=(0,0,None,None), static_broadcasted_argnums=(3,), backend=backend, axis_name="p")(comp_distinct_ERI, comp_distinct_idx, dm, backend)
#diff_JK = sparse_symmetric_einsum(comp_distinct_ERI[0], comp_distinct_idx[0], dm, backend)
nonzero_distinct_ERI, nonzero_indices, dm, backend = comp_distinct_ERI[0], comp_distinct_idx[0], dm, backend
dm = dm.reshape(-1)
diff_JK = jnp.zeros(dm.shape)
N = int(np.sqrt(dm.shape[0]))

def iteration(i, vals):
diff_JK = vals

indices = nonzero_indices[i]

indices = jax.lax.bitcast_convert_type(indices, np.int16).astype(np.int32)
# indices = jax.lax.bitcast_convert_type(indices, np.int16).astype(np.int32)
indices = indices.astype(jnp.int32)
eris = nonzero_distinct_ERI[i]
print(indices.shape)

if backend == "cpu": dm_indices = cpu_ijkl(indices, symmetry+is_K_matrix*8, indices_func)
else: dm_indices = ipu_ijkl(indices, symmetry+is_K_matrix*8, N)
dm_values = jnp.take(dm, dm_indices, axis=0)

print('nonzero_distinct_ERI.shape', nonzero_distinct_ERI.shape)
print('dm_values.shape', dm_values.shape)
print('eris.shape', eris.shape)
dm_values = dm_values.at[:].mul( eris ) # element-wise product; re-use the variable for the in-place update.

if backend == "cpu": ss_indices = cpu_ijkl(indices, symmetry+8+is_K_matrix*8, indices_func)
else: ss_indices = ipu_ijkl(indices, symmetry+8+is_K_matrix*8, N)
diff_JK = diff_JK + jax.ops.segment_sum(dm_values, ss_indices, N**2) * (-HYB_B3LYP/2)**is_K_matrix


def sequentialized_iter(symmetry, vals):
# Generalized J/K computation: does J when symmetry is in range(0,8) and K when symmetry is in range(8,16)
# Trade-off: Using one function leads to smaller always-live memory.
is_K_matrix = (symmetry >= 8)

diff_JK = vals



if backend == "cpu": dm_indices = cpu_ijkl(indices, symmetry+is_K_matrix*8, indices_func)
else: dm_indices = ipu_ijkl(indices, symmetry+is_K_matrix*8, N)
dm_values = jnp.take(dm, dm_indices, axis=0)

print('nonzero_distinct_ERI.shape', nonzero_distinct_ERI.shape)
print('dm_values.shape', dm_values.shape)
print('eris.shape', eris.shape)
dm_values = dm_values.at[:].mul( eris ) # element-wise product; re-use the variable for the in-place update.

if backend == "cpu": ss_indices = cpu_ijkl(indices, symmetry+8+is_K_matrix*8, indices_func)
else: ss_indices = ipu_ijkl(indices, symmetry+8+is_K_matrix*8, N)
diff_JK = diff_JK + jax.ops.segment_sum(dm_values, ss_indices, N**2) * (-HYB_B3LYP/2)**is_K_matrix

return diff_JK


# diff_JK = jax.lax.fori_loop(0, batches, sequentialized_iter, diff_JK)
diff_JK = jax.lax.fori_loop(0, 16, sequentialized_iter, diff_JK)
# diff_JK = sequentialized_iter(0, diff_JK)
return diff_JK

batches = nonzero_indices.shape[0] # before pmap, tensor had shape (nipus, batches, -1) so [0]=batches after pmap
diff_JK = jax.lax.fori_loop(0, batches, sequentialized_iter, diff_JK)
return diff_JK
for bi in range(batches):
diff_JK = iteration(bi, diff_JK)

return jax.lax.fori_loop(0, 16, iteration, diff_JK)
temp += diff_JK


# temp /= (zip_counter + 1)

return temp
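Each of the 16 symmetry passes gathers density-matrix entries at indices derived from (i, j, k, l), weights them by the distinct ERI values, and scatter-adds the products into the flattened output with jax.ops.segment_sum; passes 8-15 produce the exchange (K) contribution, scaled by -HYB_B3LYP/2. A minimal sketch of one such gather/scatter step, with gather_idx and scatter_idx standing in for the repository-specific cpu_ijkl / ipu_ijkl index mappings:

import jax
import jax.numpy as jnp

HYB_B3LYP = 0.2  # as defined at the top of the file

def jk_step(dm_flat, eris, gather_idx, scatter_idx, N, is_K):
    """One symmetry's contribution: gather dm values, weight by ERIs, segment-sum into an N*N accumulator."""
    vals = jnp.take(dm_flat, gather_idx, axis=0) * eris
    out = jax.ops.segment_sum(vals, scatter_idx, num_segments=N * N)
    return out * jnp.where(is_K, -HYB_B3LYP / 2, 1.0)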

if __name__ == "__main__":
import time
