Prevent tracer leaks in scipy.special.expn

This commit is contained in:
Jake VanderPlas 2025-03-06 14:38:11 -08:00
parent 4b49c03523
commit b441b2b7a5
2 changed files with 12 additions and 7 deletions

View File

@ -2106,7 +2106,7 @@ def expi_jvp(primals, tangents):
return expi(x), jnp.exp(x) / x * x_dot
def _expn1(n: Array, x: Array) -> Array:
def _expn1(x: Array, n: Array) -> Array:
# exponential integral En
_c = _lax_const
MACHEP = jnp.finfo(x.dtype).eps
@ -2143,7 +2143,7 @@ def _expn1(n: Array, x: Array) -> Array:
return d["z"] ** r * psi / jnp.exp(gammaln(t)) - d["ans"]
def _expn2(n: Array, x: Array) -> Array:
def _expn2(x: Array, n: Array) -> Array:
# x > 1.
_c = _lax_const
BIG = _c(x, 1.44115188075855872e17)
@ -2194,7 +2194,7 @@ def _expn2(n: Array, x: Array) -> Array:
return d["ans"] * jnp.exp(-x)
def _expn3(n: Array, x: Array) -> Array:
def _expn3(x: Array, n: Array) -> Array:
# n >= 5000
_c = _lax_const
one = _c(x, 1.0)
@ -2248,11 +2248,11 @@ def expn(n: ArrayLike, x: ArrayLike) -> Array:
jnp.inf,
one / n1, # prevent div by zero
jnp.exp(-x) / x,
partial(_expn3, n),
partial(_expn2, n),
partial(_expn1, n),
_expn3,
_expn2,
_expn1,
]
ret = jnp.piecewise(x, conds, vals)
ret = jnp.piecewise(x, conds, vals, n=n)
return ret

View File

@ -273,6 +273,11 @@ class LaxScipySpcialFunctionsTest(jtu.JaxTestCase):
with self.assertRaises(TypeError):
lsp_special.beta(x=1, y=1)
def testExpnTracerLeaks(self):
# Regression test for https://github.com/jax-ml/jax/issues/26972
with jax.checking_leaks():
lsp_special.expi(jnp.ones(()))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())