diff --git a/pyscf_ipu/experimental/special.py b/pyscf_ipu/experimental/special.py index 734c252..0cd06e7 100644 --- a/pyscf_ipu/experimental/special.py +++ b/pyscf_ipu/experimental/special.py @@ -5,7 +5,7 @@ import numpy as np from jax import lax from jax.ops import segment_sum -from jax.scipy.special import betaln, gammaln +from jax.scipy.special import betaln, gammainc, gammaln from .types import FloatN, IntN from .units import LMAX @@ -77,7 +77,25 @@ def binom_lookup(x: IntN, y: IntN, nmax: int = LMAX) -> IntN: binom = binom_lookup -def gammanu(nu: IntN, t: FloatN, num_terms: int = 128) -> FloatN: +def gammanu_gamma(nu: IntN, t: FloatN, epsilon: float = 1e-10) -> FloatN: + """ + eq 2.11 from THO but simplified using SymPy and converted to jax + + t, u = symbols("t u", real=True, positive=True) + nu = Symbol("nu", integer=True, nonnegative=True) + + expr = simplify(integrate(u ** (2 * nu) * exp(-t * u**2), (u, 0, 1))) + f = lambdify((nu, t), expr, modules="scipy") + ?f + + We evaulate this in log-space to avoid overflow/nan + """ + x = nu + 0.5 + gn = jnp.log(0.5) - x * jnp.log(t) + jnp.log(gammainc(x, t)) + gammaln(x) + return jnp.where(t <= epsilon, 1 / (2 * nu + 1), jnp.exp(gn)) + + +def gammanu_series(nu: IntN, t: FloatN, num_terms: int = 128) -> FloatN: """ eq 2.11 from THO but simplified as derived in equation 19 of gammanu.ipynb """ @@ -93,6 +111,9 @@ def gammanu(nu: IntN, t: FloatN, num_terms: int = 128) -> FloatN: return jnp.exp(-t) / 2 * total +gammanu = gammanu_series + + def binom_factor(i: int, j: int, a: float, b: float, lmax: int = LMAX) -> FloatN: """ Eq. 15 from Augspurger JD, Dykstra CE. General quantum mechanical operators. An