add checkify rule for remat

fixes #23867
This commit is contained in:
Matthew Johnson 2024-10-01 01:56:39 +00:00
parent 31cb3fd36e
commit 11fdda9583
2 changed files with 34 additions and 3 deletions

View File

@ -27,6 +27,7 @@ from jax import lax
from jax.experimental import shard_map
from jax._src import api
from jax._src import ad_checkpoint
from jax._src import linear_util as lu
from jax._src import config
from jax._src import core
@ -933,6 +934,19 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
error_checks[pjit.pjit_p] = pjit_error_check
def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params):
err_vals, err_tree = jtu.tree_flatten(error)
new_vals_in = [*err_vals, *vals_in]
in_avals = tuple(map(get_shaped_aval, new_vals_in))
checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(
pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals)
checked_jaxpr, () = checked_jaxpr_.jaxpr, checked_jaxpr_.consts
err_and_out = ad_checkpoint.remat_p.bind(*new_vals_in, jaxpr=checked_jaxpr,
**params)
return tree_unflatten(out_tree, err_and_out)
error_checks[ad_checkpoint.remat_p] = remat_error_check
def shard_map_error_check(
error, enabled_errors, *vals_in, jaxpr, in_names, out_names, **kwargs
):
@ -950,12 +964,10 @@ def shard_map_error_check(
raise ValueError(f'Unsupported aval type: {type(v)}')
in_avals[i] = sharder(mesh, new_in_names[i], v)
if not isinstance(jaxpr, core.ClosedJaxpr):
jaxpr = core.ClosedJaxpr(jaxpr, ())
with core.extend_axis_env_nd(mesh.shape.items()):
# jaxpr to checked_jaxpr
checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
jaxpr, enabled_errors, err_tree, *in_avals
pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals
)
num_out_error_vals = out_tree.num_leaves - len(out_names)

View File

@ -912,6 +912,25 @@ class CheckifyTransformTests(jtu.JaxTestCase):
jax.jit(checkify.checkify(f))(0) # Does not crash bc of leaked tracer.
@parameterized.parameters(True, False)
def test_remat(self, jit):
# basic test from https://github.com/jax-ml/jax/issues/23867
def fn(x: jax.Array):
checkify.check(jnp.all(x > 0), "x must be positive")
return x + 1
fn = jax.remat(fn)
if jit:
fn = jax.jit(fn)
fn = checkify.checkify(fn)
err, y = fn(jnp.array([1, 2, 3]))
self.assertIsNone(err.get())
self.assertAllClose(y, jnp.array([2, 3, 4]), check_dtypes=False)
err, _ = fn(jnp.array([0, 2, 3]))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "x must be positive")
@jtu.with_config(jax_check_tracer_leaks=True)
class AssertPrimitiveTests(jtu.JaxTestCase):