mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Checkify: switch to initial-style.
This commit is contained in:
parent
caf4f7b3f7
commit
0bce1cf129
1457
jax/_src/checkify.py
1457
jax/_src/checkify.py
File diff suppressed because it is too large
Load Diff
1760
jax/_src/checkify.py.orig
Normal file
1760
jax/_src/checkify.py.orig
Normal file
File diff suppressed because it is too large
Load Diff
@ -544,7 +544,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
return f(jnp.array([jnp.inf]))[0]
|
||||
err, _ = g(2.)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
|
||||
self.assertIn("nan generated by primitive: sin", err.get())
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_custom_jvp(self):
|
||||
@ -774,7 +774,6 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
jaxpr = jax.make_jaxpr(f)(jnp.ones(4, jnp.int32))
|
||||
self.assertSetEqual(jaxpr.effects,
|
||||
{ErrorEffect(FailedCheckError, (
|
||||
jax.ShapeDtypeStruct((0,), jnp.int32),
|
||||
jax.ShapeDtypeStruct((4,), jnp.int32),))})
|
||||
def g(x, y):
|
||||
checkify.check(False, "hi: {} {}", x, y)
|
||||
@ -783,7 +782,6 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
jax.make_jaxpr(g)(
|
||||
jnp.ones(4, jnp.int32), jnp.ones(2, jnp.float32)).effects,
|
||||
{ErrorEffect(FailedCheckError, (
|
||||
jax.ShapeDtypeStruct((0,), jnp.int32),
|
||||
jax.ShapeDtypeStruct((4,), jnp.int32),
|
||||
jax.ShapeDtypeStruct((2,), jnp.float32)))})
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user