Checkify: switch to initial-style.

This commit is contained in:
lenamartens 2022-12-06 16:10:27 +00:00
parent caf4f7b3f7
commit 0bce1cf129
3 changed files with 2474 additions and 747 deletions

File diff suppressed because it is too large Load Diff

1760
jax/_src/checkify.py.orig Normal file

File diff suppressed because it is too large Load Diff

View File

@ -544,7 +544,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
return f(jnp.array([jnp.inf]))[0]
err, _ = g(2.)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
self.assertIn("nan generated by primitive: sin", err.get())
@jtu.skip_on_devices("tpu")
def test_custom_jvp(self):
@ -774,7 +774,6 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
jaxpr = jax.make_jaxpr(f)(jnp.ones(4, jnp.int32))
self.assertSetEqual(jaxpr.effects,
{ErrorEffect(FailedCheckError, (
jax.ShapeDtypeStruct((0,), jnp.int32),
jax.ShapeDtypeStruct((4,), jnp.int32),))})
def g(x, y):
checkify.check(False, "hi: {} {}", x, y)
@ -783,7 +782,6 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
jax.make_jaxpr(g)(
jnp.ones(4, jnp.int32), jnp.ones(2, jnp.float32)).effects,
{ErrorEffect(FailedCheckError, (
jax.ShapeDtypeStruct((0,), jnp.int32),
jax.ShapeDtypeStruct((4,), jnp.int32),
jax.ShapeDtypeStruct((2,), jnp.float32)))})