mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Checkify: fix nan check when primitive has multiple results.
PiperOrigin-RevId: 488653856
This commit is contained in:
parent
108bc83520
commit
3116ed52a9
@ -568,10 +568,10 @@ def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
|
||||
def isnan(x):
|
||||
if isinstance(x, prng.PRNGKeyArray):
|
||||
return False
|
||||
return jnp.isnan(x)
|
||||
return jnp.any(jnp.isnan(x))
|
||||
|
||||
any_nans = (jnp.any(isnan(x) for x in out)
|
||||
if prim.multiple_results else jnp.any(isnan(out)))
|
||||
any_nans = (jnp.any(jnp.array([isnan(x) for x in out]))
|
||||
if prim.multiple_results else isnan(out))
|
||||
msg = f'nan generated by primitive {prim.name} at {summary()}'
|
||||
return out, assert_func(error, any_nans, msg, None)
|
||||
|
||||
|
@ -973,6 +973,17 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "hi!")
|
||||
|
||||
def test_psum_nan_check(self):
|
||||
@partial(jax.vmap, axis_name="i")
|
||||
def f(x, y):
|
||||
return lax.psum((x, y), axis_name="i")
|
||||
|
||||
cf = checkify.checkify(f, errors=checkify.nan_checks)
|
||||
err, _ = cf(jnp.array([-jnp.inf, 0, jnp.inf]), jnp.ones((3, 2)))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive psum")
|
||||
|
||||
|
||||
class LowerableChecksTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
Loading…
x
Reference in New Issue
Block a user