Skip to content

Commit

Permalink
un-parameterize tests
Browse files Browse the repository at this point in the history
  • Loading branch information
awf committed Oct 13, 2023
1 parent 777e096 commit 8cd225c
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 135 deletions.
119 changes: 62 additions & 57 deletions notebooks/binom_factor_table.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,16 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
Expand Down Expand Up @@ -276,26 +283,38 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0, 2, 1) (0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(2*b, a, b, domain='ZZ')\n",
"(0, 2, 5) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n",
"(0, 6, 6) (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(1, a, b, domain='ZZ')\n",
"(1, 5, 5) (0, 5, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a + 5*b, a, b, domain='ZZ')\n",
"(2, 1, 2) (0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(2*a + b, a, b, domain='ZZ')\n",
"(2, 3, 5) (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(1, a, b, domain='ZZ')\n",
"(2, 5, 0) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**2*b**5, a, b, domain='ZZ')\n",
"(2, 7, 7) (0, 0, 21, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**2 + 14*a*b + 21*b**2, a, b, domain='ZZ')\n",
"(4, 0, 1) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(4*a**3, a, b, domain='ZZ')\n",
"(4, 2, 0) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**4*b**2, a, b, domain='ZZ')\n",
"(0, 3, 2) (0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(3*b, a, b, domain='ZZ')\n",
"(0, 4, 1) (0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(4*b**3, a, b, domain='ZZ')\n",
"(0, 6, 4) (0, 0, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(15*b**2, a, b, domain='ZZ')\n",
"(1, 1, 7) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n",
"(1, 2, 1) (0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(2*a*b + b**2, a, b, domain='ZZ')\n",
"(1, 3, 0) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a*b**3, a, b, domain='ZZ')\n",
"(1, 3, 5) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n",
"(1, 4, 1) (0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(4*a*b**3 + b**4, a, b, domain='ZZ')\n",
"(1, 6, 7) (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(1, a, b, domain='ZZ')\n",
"(2, 0, 4) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n",
"(2, 2, 4) (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(1, a, b, domain='ZZ')\n",
"(2, 2, 7) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n",
"(2, 3, 2) (0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(3*a**2*b + 6*a*b**2 + b**3, a, b, domain='ZZ')\n",
"(2, 4, 2) (0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(6*a**2*b**2 + 8*a*b**3 + b**4, a, b, domain='ZZ')\n",
"(3, 0, 6) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n",
"(3, 1, 4) (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(1, a, b, domain='ZZ')\n",
"(3, 2, 1) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(2*a**3*b + 3*a**2*b**2, a, b, domain='ZZ')\n",
"(3, 5, 3) (0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 30, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(10*a**3*b**2 + 30*a**2*b**3 + 15*a*b**4 + b**5, a, b, domain='ZZ')\n",
"(3, 7, 3) (0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 63, 0, 0, 0, 0, 0, 0, 35, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(35*a**3*b**4 + 63*a**2*b**5 + 21*a*b**6 + b**7, a, b, domain='ZZ')\n",
"(4, 0, 4) (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(1, a, b, domain='ZZ')\n",
"(4, 2, 2) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**4 + 8*a**3*b + 6*a**2*b**2, a, b, domain='ZZ')\n",
"(5, 0, 4) (0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(5*a, a, b, domain='ZZ')\n",
"(5, 5, 0) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**5*b**5, a, b, domain='ZZ')\n",
"(6, 5, 4) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 100, 0, 0, 0, 0, 0, 0, 150, 0, 0, 0, 0, 0, 0, 60, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(5*a**6*b + 60*a**5*b**2 + 150*a**4*b**3 + 100*a**3*b**4 + 15*a**2*b**5, a, b, domain='ZZ')\n",
"(4, 4, 0) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**4*b**4, a, b, domain='ZZ')\n",
"(5, 2, 0) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**5*b**2, a, b, domain='ZZ')\n",
"(5, 7, 1) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(7*a**5*b**6 + 5*a**4*b**7, a, b, domain='ZZ')\n",
"(6, 0, 7) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n",
"(6, 6, 1) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(6*a**6*b**5 + 6*a**5*b**6, a, b, domain='ZZ')\n",
"(6, 6, 7) (0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 90, 0, 0, 0, 0, 0, 0, 300, 0, 0, 0, 0, 0, 0, 300, 0, 0, 0, 0, 0, 0, 90, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(6*a**5 + 90*a**4*b + 300*a**3*b**2 + 300*a**2*b**3 + 90*a*b**4 + 6*b**5, a, b, domain='ZZ')\n",
"(6, 7, 7) (0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 126, 0, 0, 0, 0, 0, 0, 525, 0, 0, 0, 0, 0, 0, 700, 0, 0, 0, 0, 0, 0, 315, 0, 0, 0, 0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**6 + 42*a**5*b + 315*a**4*b**2 + 700*a**3*b**3 + 525*a**2*b**4 + 126*a*b**5 + 7*b**6, a, b, domain='ZZ')\n",
"(7, 1, 2) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(7*a**6 + 21*a**5*b, a, b, domain='ZZ')\n",
"(7, 2, 4) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 0, 0, 0, 0, 0, 0, 70, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(21*a**5 + 70*a**4*b + 35*a**3*b**2, a, b, domain='ZZ')\n",
"(7, 2, 7) (0, 0, 1, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(21*a**2 + 14*a*b + b**2, a, b, domain='ZZ')\n",
"(7, 6, 4) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 0, 0, 0, 0, 0, 0, 210, 0, 0, 0, 0, 0, 0, 315, 0, 0, 0, 0, 0, 0, 140, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0) Poly(15*a**7*b**2 + 140*a**6*b**3 + 315*a**5*b**4 + 210*a**4*b**5 + 35*a**3*b**6, a, b, domain='ZZ')\n",
"(7, 6, 6) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 126, 0, 0, 0, 0, 0, 0, 525, 0, 0, 0, 0, 0, 0, 700, 0, 0, 0, 0, 0, 0, 315, 0, 0, 0, 0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0) Poly(a**7 + 42*a**6*b + 315*a**5*b**2 + 700*a**4*b**3 + 525*a**3*b**4 + 126*a**2*b**5 + 7*a*b**6, a, b, domain='ZZ')\n",
"(7, 7, 3) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 0, 0, 0, 0, 0, 0, 147, 0, 0, 0, 0, 0, 0, 147, 0, 0, 0, 0, 0, 0, 35, 0, 0, 0) Poly(35*a**7*b**4 + 147*a**6*b**5 + 147*a**5*b**6 + 35*a**4*b**7, a, b, domain='ZZ')\n"
"(7, 3, 7) (0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 63, 0, 0, 0, 0, 0, 0, 35, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(35*a**3 + 63*a**2*b + 21*a*b**2 + b**3, a, b, domain='ZZ')\n",
"(7, 4, 1) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0) Poly(4*a**7*b**3 + 7*a**6*b**4, a, b, domain='ZZ')\n"
]
}
],
Expand Down Expand Up @@ -490,49 +509,35 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(0, 0, 0) 1.0 [1.]\n",
"(0, 0, 5) 0 [0.]\n",
"(0, 0, 6) 0 [0.]\n",
"(0, 1, 2) 0 [0.]\n",
"(1, 1, 3) 0 [0.]\n",
"(0, 2, 2) 1.0 [1.]\n",
"(0, 5, 0) 51.53632558509851 [51.536327]\n",
"(1, 1, 6) 0 [0.]\n",
"(1, 3, 5) 0 [0.]\n",
"(1, 4, 3) 38.72000133514405 [38.72]\n",
"(1, 4, 5) 1.0 [1.]\n",
"(1, 6, 7) 1.0 [1.]\n",
"(2, 1, 4) 0 [0.]\n",
"(2, 2, 0) 5.856400369262701 [5.8564005]\n",
"(2, 3, 3) 30.2500011253357 [30.25]\n",
"(2, 3, 6) 0 [0.]\n",
"(2, 5, 2) 438.05876595012853 [438.05878]\n",
"(3, 2, 5) 1.0 [1.]\n",
"(2, 1, 7) 0 [0.]\n",
"(2, 2, 5) 0 [0.]\n",
"(2, 6, 7) 15.40000033378601 [15.400001]\n",
"(2, 7, 4) 3336.977076303898 [3336.977]\n",
"(3, 1, 4) 1.0 [1.]\n",
"(3, 1, 6) 0 [0.]\n",
"(3, 1, 7) 0 [0.]\n",
"(3, 3, 0) 14.172489843082529 [14.17249]\n",
"(3, 4, 4) 171.69900965380683 [171.699]\n",
"(3, 4, 5) 61.71000228881837 [61.710003]\n",
"(3, 4, 6) 12.100000262260437 [12.1]\n",
"(3, 5, 0) 68.59485381402618 [68.59486]\n",
"(3, 5, 1) 342.9742594246647 [342.97427]\n",
"(3, 6, 3) 2692.773055105913 [2692.773]\n",
"(3, 7, 2) 5144.614001991794 [5144.6143]\n",
"(4, 0, 6) 0 [0.]\n",
"(4, 0, 7) 0 [0.]\n",
"(3, 5, 3) 889.0016110117186 [889.0016]\n",
"(4, 2, 2) 60.02810437345515 [60.028107]\n",
"(4, 2, 6) 1.0 [1.]\n",
"(4, 3, 0) 15.58973916528927 [15.589739]\n",
"(4, 5, 4) 1613.731182697727 [1613.7311]\n",
"(4, 7, 6) 7959.14124416379 [7959.1416]\n",
"(5, 0, 4) 5.5000001192092896 [5.5]\n",
"(5, 0, 6) 0 [0.]\n",
"(5, 3, 2) 258.64793837960883 [258.64795]\n",
"(5, 5, 0) 82.99977671291498 [82.99978]\n",
"(5, 6, 0) 182.5995127261507 [182.59952]\n",
"(6, 1, 1) 23.030295995009112 [23.030296]\n",
"(6, 3, 5) 531.4683398457538 [531.4683]\n",
"(6, 6, 3) 13581.78134955436 [13581.781]\n",
"(6, 6, 4) 20192.610065654717 [20192.611]\n",
"(6, 6, 6) 15924.563805685058 [15924.564]\n",
"(7, 0, 5) 25.410000801086426 [25.41]\n",
"(7, 2, 4) 484.76355986921754 [484.76355]\n",
"(7, 3, 4) 1475.7104961144375 [1475.7103]\n"
"(4, 5, 5) 997.0521781336806 [997.0522]\n",
"(4, 6, 3) 4926.357549691019 [4926.358]\n",
"(5, 1, 6) 1.0 [1.]\n",
"(5, 2, 4) 93.17000511407859 [93.170006]\n",
"(5, 4, 0) 37.72717041542877 [37.727173]\n",
"(5, 6, 2) 4338.624597774309 [4338.625]\n",
"(6, 3, 6) 252.89001389455817 [252.89001]\n",
"(6, 4, 5) 1903.623008021071 [1903.6229]\n",
"(6, 7, 0) 441.8908399527361 [441.89084]\n",
"(7, 0, 2) 33.8207136652209 [33.820713]\n",
"(7, 2, 7) 64.13000242233278 [64.130005]\n",
"(7, 5, 4) 12625.740273048325 [12625.74]\n",
"(7, 6, 2) 8353.927707221395 [8353.928]\n"
]
}
],
Expand Down Expand Up @@ -573,7 +578,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
"version": "3.8.17"
},
"orig_nbformat": 4
},
Expand Down
30 changes: 10 additions & 20 deletions pyscf_ipu/experimental/integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .basis import Basis
from .orbital import batch_orbitals
from .primitive import Primitive, product
from .special import binom, binom_factor_default, factorial, factorial2, gammanu
from .special import binom, binom_factor, factorial, factorial2, gammanu
from .types import Float3, FloatN, FloatNx3, FloatNxN
from .units import LMAX

Expand Down Expand Up @@ -112,9 +112,7 @@ def build_gindex():
return i, r, u


def _nuclear_primitives(
a: Primitive, b: Primitive, c: Float3, binom_factor=binom_factor_default
):
def _nuclear_primitives(a: Primitive, b: Primitive, c: Float3):
p = product(a, b)
pa = p.center - a.center
pb = p.center - b.center
Expand Down Expand Up @@ -157,7 +155,7 @@ def g_term(l1, l2, pa, pb, cp):

overlap_primitives = jit(_overlap_primitives)
kinetic_primitives = jit(_kinetic_primitives)
nuclear_primitives = jit(_nuclear_primitives, static_argnames="binom_factor")
nuclear_primitives = jit(_nuclear_primitives)

vmap_overlap_primitives = jit(vmap(_overlap_primitives))
vmap_kinetic_primitives = jit(vmap(_kinetic_primitives))
Expand Down Expand Up @@ -185,13 +183,7 @@ def build_cindex():
return i1, i2, r1, r2, u


def _eri_primitives(
a: Primitive,
b: Primitive,
c: Primitive,
d: Primitive,
binom_factor=binom_factor_default,
) -> float:
def _eri_primitives(a: Primitive, b: Primitive, c: Primitive, d: Primitive) -> float:
p = product(a, b)
q = product(c, d)
pa = p.center - a.center
Expand Down Expand Up @@ -252,10 +244,8 @@ def c_term(la, lb, lc, ld, pa, pb, qc, qd, qp):
)


eri_primitives = jit(_eri_primitives, static_argnames="binom_factor")
vmap_eri_primitives = jit(
vmap(_eri_primitives, in_axes=(0, 0, 0, 0, None)), static_argnames="binom_factor"
)
eri_primitives = jit(_eri_primitives)
vmap_eri_primitives = jit(vmap(_eri_primitives))


def gen_ijkl(n: int):
Expand All @@ -270,7 +260,7 @@ def gen_ijkl(n: int):
yield idx, jdx, kdx, ldx


def eri_basis_sparse(b: Basis, binom_factor=binom_factor_default):
def eri_basis_sparse(b: Basis):
indices = []
batch = []
offset = np.cumsum([o.num_primitives for o in b.orbitals])
Expand All @@ -288,12 +278,12 @@ def eri_basis_sparse(b: Basis, binom_factor=binom_factor_default):
pijkl = [
tree_map(lambda x: jnp.take(x, idx, axis=0), primitives) for idx in indices
]
eris = cijkl * vmap_eri_primitives(*pijkl, binom_factor)
eris = cijkl * vmap_eri_primitives(*pijkl)
return segment_sum(eris, batch, num_segments=count + 1)


def eri_basis(b: Basis, binom_factor=binom_factor_default):
unique_eris = eri_basis_sparse(b, binom_factor)
def eri_basis(b: Basis):
unique_eris = eri_basis_sparse(b)
ii, jj, kk, ll = jnp.array(list(gen_ijkl(b.num_orbitals)), dtype=jnp.int32).T

# Apply 8x permutation symmetry to build dense ERI from sparse ERI.
Expand Down
Loading

0 comments on commit 8cd225c

Please sign in to comment.