diff --git a/notebooks/binom_factor_table.ipynb b/notebooks/binom_factor_table.ipynb index 0f4287d..0271742 100644 --- a/notebooks/binom_factor_table.ipynb +++ b/notebooks/binom_factor_table.ipynb @@ -61,9 +61,16 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 3, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + ] + }, { "name": "stdout", "output_type": "stream", @@ -276,26 +283,38 @@ "name": "stdout", "output_type": "stream", "text": [ - "(0, 2, 1) (0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(2*b, a, b, domain='ZZ')\n", - "(0, 2, 5) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n", - "(0, 6, 6) (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(1, a, b, domain='ZZ')\n", - "(1, 5, 5) (0, 5, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a + 5*b, a, b, domain='ZZ')\n", - "(2, 1, 2) (0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(2*a + b, a, b, domain='ZZ')\n", - "(2, 3, 5) (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(1, a, b, domain='ZZ')\n", - "(2, 5, 0) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**2*b**5, a, b, domain='ZZ')\n", - "(2, 7, 7) (0, 0, 21, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**2 + 14*a*b + 21*b**2, a, b, domain='ZZ')\n", - "(4, 0, 1) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(4*a**3, a, b, domain='ZZ')\n", - "(4, 2, 0) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**4*b**2, a, b, domain='ZZ')\n", + "(0, 3, 2) (0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(3*b, a, b, domain='ZZ')\n", + "(0, 4, 1) (0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(4*b**3, a, b, domain='ZZ')\n", + "(0, 6, 4) (0, 0, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(15*b**2, a, b, domain='ZZ')\n", + "(1, 1, 7) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n", + "(1, 2, 1) (0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(2*a*b + b**2, a, b, domain='ZZ')\n", + "(1, 3, 0) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a*b**3, a, b, domain='ZZ')\n", + "(1, 3, 5) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n", + "(1, 4, 1) (0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(4*a*b**3 + b**4, a, b, domain='ZZ')\n", + "(1, 6, 7) (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(1, a, b, domain='ZZ')\n", + "(2, 0, 4) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n", + "(2, 2, 4) (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(1, a, b, domain='ZZ')\n", + "(2, 2, 7) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n", + "(2, 3, 2) (0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(3*a**2*b + 6*a*b**2 + b**3, a, b, domain='ZZ')\n", + "(2, 4, 2) (0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(6*a**2*b**2 + 8*a*b**3 + b**4, a, b, domain='ZZ')\n", + "(3, 0, 6) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n", + "(3, 1, 4) (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(1, a, b, domain='ZZ')\n", + "(3, 2, 1) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(2*a**3*b + 3*a**2*b**2, a, b, domain='ZZ')\n", + "(3, 5, 3) (0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 30, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(10*a**3*b**2 + 30*a**2*b**3 + 15*a*b**4 + b**5, a, b, domain='ZZ')\n", + "(3, 7, 3) (0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 63, 0, 0, 0, 0, 0, 0, 35, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(35*a**3*b**4 + 63*a**2*b**5 + 21*a*b**6 + b**7, a, b, domain='ZZ')\n", + "(4, 0, 4) (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(1, a, b, domain='ZZ')\n", "(4, 2, 2) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**4 + 8*a**3*b + 6*a**2*b**2, a, b, domain='ZZ')\n", - "(5, 0, 4) (0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(5*a, a, b, domain='ZZ')\n", - "(5, 5, 0) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**5*b**5, a, b, domain='ZZ')\n", - "(6, 5, 4) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 100, 0, 0, 0, 0, 0, 0, 150, 0, 0, 0, 0, 0, 0, 60, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(5*a**6*b + 60*a**5*b**2 + 150*a**4*b**3 + 100*a**3*b**4 + 15*a**2*b**5, a, b, domain='ZZ')\n", + "(4, 4, 0) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**4*b**4, a, b, domain='ZZ')\n", + "(5, 2, 0) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**5*b**2, a, b, domain='ZZ')\n", + "(5, 7, 1) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(7*a**5*b**6 + 5*a**4*b**7, a, b, domain='ZZ')\n", + "(6, 0, 7) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n", + "(6, 6, 1) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(6*a**6*b**5 + 6*a**5*b**6, a, b, domain='ZZ')\n", "(6, 6, 7) (0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 90, 0, 0, 0, 0, 0, 0, 300, 0, 0, 0, 0, 0, 0, 300, 0, 0, 0, 0, 0, 0, 90, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(6*a**5 + 90*a**4*b + 300*a**3*b**2 + 300*a**2*b**3 + 90*a*b**4 + 6*b**5, a, b, domain='ZZ')\n", + "(6, 7, 7) (0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 126, 0, 0, 0, 0, 0, 0, 525, 0, 0, 0, 0, 0, 0, 700, 0, 0, 0, 0, 0, 0, 315, 0, 0, 0, 0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**6 + 42*a**5*b + 315*a**4*b**2 + 700*a**3*b**3 + 525*a**2*b**4 + 126*a*b**5 + 7*b**6, a, b, domain='ZZ')\n", + "(7, 1, 2) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(7*a**6 + 21*a**5*b, a, b, domain='ZZ')\n", "(7, 2, 4) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 0, 0, 0, 0, 0, 0, 70, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(21*a**5 + 70*a**4*b + 35*a**3*b**2, a, b, domain='ZZ')\n", - "(7, 2, 7) (0, 0, 1, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(21*a**2 + 14*a*b + b**2, a, b, domain='ZZ')\n", - "(7, 6, 4) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 0, 0, 0, 0, 0, 0, 210, 0, 0, 0, 0, 0, 0, 315, 0, 0, 0, 0, 0, 0, 140, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0) Poly(15*a**7*b**2 + 140*a**6*b**3 + 315*a**5*b**4 + 210*a**4*b**5 + 35*a**3*b**6, a, b, domain='ZZ')\n", - "(7, 6, 6) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 126, 0, 0, 0, 0, 0, 0, 525, 0, 0, 0, 0, 0, 0, 700, 0, 0, 0, 0, 0, 0, 315, 0, 0, 0, 0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0) Poly(a**7 + 42*a**6*b + 315*a**5*b**2 + 700*a**4*b**3 + 525*a**3*b**4 + 126*a**2*b**5 + 7*a*b**6, a, b, domain='ZZ')\n", - "(7, 7, 3) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 0, 0, 0, 0, 0, 0, 147, 0, 0, 0, 0, 0, 0, 147, 0, 0, 0, 0, 0, 0, 35, 0, 0, 0) Poly(35*a**7*b**4 + 147*a**6*b**5 + 147*a**5*b**6 + 35*a**4*b**7, a, b, domain='ZZ')\n" + "(7, 3, 7) (0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 63, 0, 0, 0, 0, 0, 0, 35, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(35*a**3 + 63*a**2*b + 21*a*b**2 + b**3, a, b, domain='ZZ')\n", + "(7, 4, 1) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0) Poly(4*a**7*b**3 + 7*a**6*b**4, a, b, domain='ZZ')\n" ] } ], @@ -490,49 +509,35 @@ "name": "stdout", "output_type": "stream", "text": [ - "(0, 0, 0) 1.0 [1.]\n", - "(0, 0, 5) 0 [0.]\n", - "(0, 0, 6) 0 [0.]\n", - "(0, 1, 2) 0 [0.]\n", - "(1, 1, 3) 0 [0.]\n", + "(0, 2, 2) 1.0 [1.]\n", + "(0, 5, 0) 51.53632558509851 [51.536327]\n", "(1, 1, 6) 0 [0.]\n", - "(1, 3, 5) 0 [0.]\n", - "(1, 4, 3) 38.72000133514405 [38.72]\n", - "(1, 4, 5) 1.0 [1.]\n", - "(1, 6, 7) 1.0 [1.]\n", - "(2, 1, 4) 0 [0.]\n", - "(2, 2, 0) 5.856400369262701 [5.8564005]\n", - "(2, 3, 3) 30.2500011253357 [30.25]\n", - "(2, 3, 6) 0 [0.]\n", - "(2, 5, 2) 438.05876595012853 [438.05878]\n", - "(3, 2, 5) 1.0 [1.]\n", + "(2, 1, 7) 0 [0.]\n", + "(2, 2, 5) 0 [0.]\n", + "(2, 6, 7) 15.40000033378601 [15.400001]\n", + "(2, 7, 4) 3336.977076303898 [3336.977]\n", + "(3, 1, 4) 1.0 [1.]\n", + "(3, 1, 6) 0 [0.]\n", + "(3, 1, 7) 0 [0.]\n", "(3, 3, 0) 14.172489843082529 [14.17249]\n", - "(3, 4, 4) 171.69900965380683 [171.699]\n", - "(3, 4, 5) 61.71000228881837 [61.710003]\n", "(3, 4, 6) 12.100000262260437 [12.1]\n", + "(3, 5, 0) 68.59485381402618 [68.59486]\n", "(3, 5, 1) 342.9742594246647 [342.97427]\n", - "(3, 6, 3) 2692.773055105913 [2692.773]\n", - "(3, 7, 2) 5144.614001991794 [5144.6143]\n", - "(4, 0, 6) 0 [0.]\n", - "(4, 0, 7) 0 [0.]\n", + "(3, 5, 3) 889.0016110117186 [889.0016]\n", "(4, 2, 2) 60.02810437345515 [60.028107]\n", - "(4, 2, 6) 1.0 [1.]\n", - "(4, 3, 0) 15.58973916528927 [15.589739]\n", - "(4, 5, 4) 1613.731182697727 [1613.7311]\n", - "(4, 7, 6) 7959.14124416379 [7959.1416]\n", - "(5, 0, 4) 5.5000001192092896 [5.5]\n", - "(5, 0, 6) 0 [0.]\n", - "(5, 3, 2) 258.64793837960883 [258.64795]\n", - "(5, 5, 0) 82.99977671291498 [82.99978]\n", - "(5, 6, 0) 182.5995127261507 [182.59952]\n", - "(6, 1, 1) 23.030295995009112 [23.030296]\n", - "(6, 3, 5) 531.4683398457538 [531.4683]\n", - "(6, 6, 3) 13581.78134955436 [13581.781]\n", - "(6, 6, 4) 20192.610065654717 [20192.611]\n", - "(6, 6, 6) 15924.563805685058 [15924.564]\n", - "(7, 0, 5) 25.410000801086426 [25.41]\n", - "(7, 2, 4) 484.76355986921754 [484.76355]\n", - "(7, 3, 4) 1475.7104961144375 [1475.7103]\n" + "(4, 5, 5) 997.0521781336806 [997.0522]\n", + "(4, 6, 3) 4926.357549691019 [4926.358]\n", + "(5, 1, 6) 1.0 [1.]\n", + "(5, 2, 4) 93.17000511407859 [93.170006]\n", + "(5, 4, 0) 37.72717041542877 [37.727173]\n", + "(5, 6, 2) 4338.624597774309 [4338.625]\n", + "(6, 3, 6) 252.89001389455817 [252.89001]\n", + "(6, 4, 5) 1903.623008021071 [1903.6229]\n", + "(6, 7, 0) 441.8908399527361 [441.89084]\n", + "(7, 0, 2) 33.8207136652209 [33.820713]\n", + "(7, 2, 7) 64.13000242233278 [64.130005]\n", + "(7, 5, 4) 12625.740273048325 [12625.74]\n", + "(7, 6, 2) 8353.927707221395 [8353.928]\n" ] } ], @@ -573,7 +578,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.18" + "version": "3.8.17" }, "orig_nbformat": 4 }, diff --git a/pyscf_ipu/experimental/integrals.py b/pyscf_ipu/experimental/integrals.py index 0036f93..3ec90ea 100644 --- a/pyscf_ipu/experimental/integrals.py +++ b/pyscf_ipu/experimental/integrals.py @@ -12,7 +12,7 @@ from .basis import Basis from .orbital import batch_orbitals from .primitive import Primitive, product -from .special import binom, binom_factor_default, factorial, factorial2, gammanu +from .special import binom, binom_factor, factorial, factorial2, gammanu from .types import Float3, FloatN, FloatNx3, FloatNxN from .units import LMAX @@ -112,9 +112,7 @@ def build_gindex(): return i, r, u -def _nuclear_primitives( - a: Primitive, b: Primitive, c: Float3, binom_factor=binom_factor_default -): +def _nuclear_primitives(a: Primitive, b: Primitive, c: Float3): p = product(a, b) pa = p.center - a.center pb = p.center - b.center @@ -157,7 +155,7 @@ def g_term(l1, l2, pa, pb, cp): overlap_primitives = jit(_overlap_primitives) kinetic_primitives = jit(_kinetic_primitives) -nuclear_primitives = jit(_nuclear_primitives, static_argnames="binom_factor") +nuclear_primitives = jit(_nuclear_primitives) vmap_overlap_primitives = jit(vmap(_overlap_primitives)) vmap_kinetic_primitives = jit(vmap(_kinetic_primitives)) @@ -185,13 +183,7 @@ def build_cindex(): return i1, i2, r1, r2, u -def _eri_primitives( - a: Primitive, - b: Primitive, - c: Primitive, - d: Primitive, - binom_factor=binom_factor_default, -) -> float: +def _eri_primitives(a: Primitive, b: Primitive, c: Primitive, d: Primitive) -> float: p = product(a, b) q = product(c, d) pa = p.center - a.center @@ -252,10 +244,8 @@ def c_term(la, lb, lc, ld, pa, pb, qc, qd, qp): ) -eri_primitives = jit(_eri_primitives, static_argnames="binom_factor") -vmap_eri_primitives = jit( - vmap(_eri_primitives, in_axes=(0, 0, 0, 0, None)), static_argnames="binom_factor" -) +eri_primitives = jit(_eri_primitives) +vmap_eri_primitives = jit(vmap(_eri_primitives)) def gen_ijkl(n: int): @@ -270,7 +260,7 @@ def gen_ijkl(n: int): yield idx, jdx, kdx, ldx -def eri_basis_sparse(b: Basis, binom_factor=binom_factor_default): +def eri_basis_sparse(b: Basis): indices = [] batch = [] offset = np.cumsum([o.num_primitives for o in b.orbitals]) @@ -288,12 +278,12 @@ def eri_basis_sparse(b: Basis, binom_factor=binom_factor_default): pijkl = [ tree_map(lambda x: jnp.take(x, idx, axis=0), primitives) for idx in indices ] - eris = cijkl * vmap_eri_primitives(*pijkl, binom_factor) + eris = cijkl * vmap_eri_primitives(*pijkl) return segment_sum(eris, batch, num_segments=count + 1) -def eri_basis(b: Basis, binom_factor=binom_factor_default): - unique_eris = eri_basis_sparse(b, binom_factor) +def eri_basis(b: Basis): + unique_eris = eri_basis_sparse(b) ii, jj, kk, ll = jnp.array(list(gen_ijkl(b.num_orbitals)), dtype=jnp.int32).T # Apply 8x permutation symmetry to build dense ERI from sparse ERI. diff --git a/pyscf_ipu/experimental/special.py b/pyscf_ipu/experimental/special.py index a94e2bd..625e500 100644 --- a/pyscf_ipu/experimental/special.py +++ b/pyscf_ipu/experimental/special.py @@ -140,7 +140,7 @@ def binom_factor_direct(i: int, j: int, a: float, b: float, s: int): def binom_factor_segment_sum( i: int, j: int, a: float, b: float, lmax: int = LMAX ) -> FloatN: - # Vectorized version of above + # Vectorized version of above, producing all values s in range(LMAX) s, t = jnp.tril_indices(lmax + 1) out = binom(i, s - t) * binom(j, t) * a ** (i - (s - t)) * b ** (j - t) mask = ((s - i) <= t) & (t <= j) @@ -148,23 +148,20 @@ def binom_factor_segment_sum( return segment_sum(out, s, num_segments=lmax + 1) -def binom_factor__via_segment_sum( - i: int, j: int, a: float, b: float, s: int, lmax=LMAX -): +def binom_factor_via_segment_sum(i: int, j: int, a: float, b: float, s: int, lmax=LMAX): return jnp.take(binom_factor_segment_sum(i, j, a, b, lmax), s) binom_factor_table_W = jnp.array(binom_factor_table.build_binom_factor_table()) -def binom_factor__via_lookup( +def binom_factor_via_lookup( i: int, j: int, a: float, b: float, s: int, lmax=None ) -> FloatN: # Lookup-table version of above -- see binom_factor_table.ipynb for the derivation # lmax is ignored, but used to allow easy swapping with above implementation - monomials = jnp.array(binom_factor_table.get_monomials(a, b)) - coeffs = binom_factor_table_W[i, j, s] - return coeffs @ monomials + monomials = binom_factor_table.get_monomials(a, b) + return jnp.dot(binom_factor_table_W[i, j, s, :], monomials) -binom_factor_default = binom_factor__via_segment_sum +binom_factor = binom_factor_via_segment_sum diff --git a/test/test_integrals.py b/test/test_integrals.py index ed54859..34d5c06 100644 --- a/test/test_integrals.py +++ b/test/test_integrals.py @@ -4,7 +4,6 @@ import pytest from numpy.testing import assert_allclose -import pyscf_ipu.experimental as pyscf_experimental from pyscf_ipu.experimental.basis import basisset from pyscf_ipu.experimental.integrals import ( eri_basis, @@ -79,49 +78,13 @@ def test_water_kinetic(basis_name): assert_allclose(actual, expect, atol=1e-4) -def check_recompile(recompile, function): - # Force recompile - if recompile == "recompile": - # TBH, this is a bit of a red herring - it will force recompilation, - # but the whole switch is only really useful if the False case - # runs after the true case in the same process - # i.e. timing from - # pytest -k test_nuclear[lookup-recompile] --durations=5 - # will be the same as - # pytest -k test_nuclear[lookup-cached] --durations=5 - # While - # pytest -k test_nuclear[lookup- --durations=5 - # will show both times, and cached will be lower - function._clear_cache() - - -@pytest.mark.parametrize("recompile", ["recompile", "cached"]) -@pytest.mark.parametrize("binom_factor_str", ["segment_sum", "lookup"]) -def test_nuclear(binom_factor_str, recompile): +@pytest.mark.parametrize("recompile", ["first", "cached"]) +def test_nuclear(recompile): # PyQuante test case for nuclear attraction integral p = Primitive() c = jnp.zeros(3) - # Choose the implementation of binom_factor - if binom_factor_str == "segment_sum": - binom_factor = pyscf_experimental.special.binom_factor__via_segment_sum - elif binom_factor_str == "lookup": - binom_factor = pyscf_experimental.special.binom_factor__via_lookup - else: - assert False - - check_recompile(recompile, nuclear_primitives) - assert_allclose(nuclear_primitives(p, p, c, binom_factor), -1.595769, atol=1e-5) - - # if recompile == 'recompile': - # from jaxutils.jaxpr_to_expr import show_jaxpr - # show_jaxpr( - # nuclear_primitives, - # (p, p, c, binom_factor), - # file=f"tmp/nuclear_primitives_jaxpr__binom_factor__via_{binom_factor_str}.py", - # optimize=False, - # static_argnums=3, - # ) + assert_allclose(nuclear_primitives(p, p, c), -1.595769, atol=1e-5) # Reproduce the nuclear attraction matrix for H2 using STO-3G basis set # See equation 3.231 and 3.232 of Szabo and Ostlund @@ -180,24 +143,18 @@ def is_mem_limited(): return total_mem_gib < 10 -@pytest.mark.parametrize("recompile", ["recompile", "cached"]) -@pytest.mark.parametrize("binom_factor_str", ["segment_sum", "lookup"]) @pytest.mark.parametrize("sparsity", ["sparse", "dense"]) @pytest.mark.skipif(is_mem_limited(), reason="Not enough host memory!") -def test_water_eri(recompile, binom_factor_str, sparsity): +def test_water_eri(sparsity, xpass): sparse = sparsity == "sparse" - check_recompile(recompile, eri_primitives) - binom_factor = eval( - "pyscf_experimental.special.binom_factor__via_" + binom_factor_str - ) basis_name = "sto-3g" h2o = molecule("water") basis = basisset(h2o, basis_name) if sparse: - actual = eri_basis_sparse(basis, binom_factor) + actual = eri_basis_sparse(basis) else: - actual = eri_basis(basis, binom_factor) + actual = eri_basis(basis) aosym = "s8" if sparse else "s1" expect = to_pyscf(h2o, basis_name=basis_name).intor("int2e_cart", aosym=aosym) assert_allclose(actual, expect, atol=1e-4) diff --git a/test/test_special.py b/test/test_special.py index 5e6545c..0227e82 100644 --- a/test/test_special.py +++ b/test/test_special.py @@ -1,10 +1,15 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. +import jax import jax.numpy as jnp +import numpy as np import pytest from numpy.testing import assert_allclose from pyscf_ipu.experimental.special import ( binom_beta, + binom_factor_direct, + binom_factor_via_lookup, + binom_factor_via_segment_sum, binom_fori, binom_lookup, factorial2_fori, @@ -49,3 +54,23 @@ def test_binom(binom_func): assert_allclose(binom_func(one, one), one) assert_allclose(binom_func(zero, -one), zero) assert_allclose(binom_func(zero, zero), one) + + +@pytest.mark.parametrize( + "binom_func", + [binom_factor_direct, binom_factor_via_lookup, binom_factor_via_segment_sum], +) +def test_binom_factor(binom_func): + if binom_func == binom_factor_direct: + n = 10 + else: + binom_func = jax.jit(binom_func) + n = 100000 + va = np.random.rand(n) + vb = np.random.rand(n) + for i, j, s in zip( + jnp.array([0, 1, 2, 3]), jnp.array([1, 2, 3, 1]), jnp.array([1, 2, 3, 4]) + ): + for a, b in zip(va, vb): + val = binom_func(i, j, a, b, s) + val.block_until_ready()