mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #26977 from jakevdp:fix-expn
PiperOrigin-RevId: 735506133
This commit is contained in:
commit
c942b0fef0
@ -2106,7 +2106,7 @@ def expi_jvp(primals, tangents):
|
||||
return expi(x), jnp.exp(x) / x * x_dot
|
||||
|
||||
|
||||
def _expn1(n: Array, x: Array) -> Array:
|
||||
def _expn1(x: Array, n: Array) -> Array:
|
||||
# exponential integral En
|
||||
_c = _lax_const
|
||||
MACHEP = jnp.finfo(x.dtype).eps
|
||||
@ -2143,7 +2143,7 @@ def _expn1(n: Array, x: Array) -> Array:
|
||||
return d["z"] ** r * psi / jnp.exp(gammaln(t)) - d["ans"]
|
||||
|
||||
|
||||
def _expn2(n: Array, x: Array) -> Array:
|
||||
def _expn2(x: Array, n: Array) -> Array:
|
||||
# x > 1.
|
||||
_c = _lax_const
|
||||
BIG = _c(x, 1.44115188075855872e17)
|
||||
@ -2194,7 +2194,7 @@ def _expn2(n: Array, x: Array) -> Array:
|
||||
return d["ans"] * jnp.exp(-x)
|
||||
|
||||
|
||||
def _expn3(n: Array, x: Array) -> Array:
|
||||
def _expn3(x: Array, n: Array) -> Array:
|
||||
# n >= 5000
|
||||
_c = _lax_const
|
||||
one = _c(x, 1.0)
|
||||
@ -2248,11 +2248,11 @@ def expn(n: ArrayLike, x: ArrayLike) -> Array:
|
||||
jnp.inf,
|
||||
one / n1, # prevent div by zero
|
||||
jnp.exp(-x) / x,
|
||||
partial(_expn3, n),
|
||||
partial(_expn2, n),
|
||||
partial(_expn1, n),
|
||||
_expn3,
|
||||
_expn2,
|
||||
_expn1,
|
||||
]
|
||||
ret = jnp.piecewise(x, conds, vals)
|
||||
ret = jnp.piecewise(x, conds, vals, n=n)
|
||||
return ret
|
||||
|
||||
|
||||
|
@ -273,6 +273,11 @@ class LaxScipySpcialFunctionsTest(jtu.JaxTestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
lsp_special.beta(x=1, y=1)
|
||||
|
||||
def testExpnTracerLeaks(self):
|
||||
# Regression test for https://github.com/jax-ml/jax/issues/26972
|
||||
with jax.checking_leaks():
|
||||
lsp_special.expi(jnp.ones(()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user