mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
parent
31cb3fd36e
commit
11fdda9583
@ -27,6 +27,7 @@ from jax import lax
|
||||
|
||||
from jax.experimental import shard_map
|
||||
from jax._src import api
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
@ -933,6 +934,19 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
|
||||
error_checks[pjit.pjit_p] = pjit_error_check
|
||||
|
||||
|
||||
def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params):
|
||||
err_vals, err_tree = jtu.tree_flatten(error)
|
||||
new_vals_in = [*err_vals, *vals_in]
|
||||
in_avals = tuple(map(get_shaped_aval, new_vals_in))
|
||||
checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(
|
||||
pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals)
|
||||
checked_jaxpr, () = checked_jaxpr_.jaxpr, checked_jaxpr_.consts
|
||||
err_and_out = ad_checkpoint.remat_p.bind(*new_vals_in, jaxpr=checked_jaxpr,
|
||||
**params)
|
||||
return tree_unflatten(out_tree, err_and_out)
|
||||
error_checks[ad_checkpoint.remat_p] = remat_error_check
|
||||
|
||||
|
||||
def shard_map_error_check(
|
||||
error, enabled_errors, *vals_in, jaxpr, in_names, out_names, **kwargs
|
||||
):
|
||||
@ -950,12 +964,10 @@ def shard_map_error_check(
|
||||
raise ValueError(f'Unsupported aval type: {type(v)}')
|
||||
in_avals[i] = sharder(mesh, new_in_names[i], v)
|
||||
|
||||
if not isinstance(jaxpr, core.ClosedJaxpr):
|
||||
jaxpr = core.ClosedJaxpr(jaxpr, ())
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
# jaxpr to checked_jaxpr
|
||||
checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
|
||||
jaxpr, enabled_errors, err_tree, *in_avals
|
||||
pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals
|
||||
)
|
||||
num_out_error_vals = out_tree.num_leaves - len(out_names)
|
||||
|
||||
|
@ -912,6 +912,25 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
|
||||
jax.jit(checkify.checkify(f))(0) # Does not crash bc of leaked tracer.
|
||||
|
||||
@parameterized.parameters(True, False)
|
||||
def test_remat(self, jit):
|
||||
# basic test from https://github.com/jax-ml/jax/issues/23867
|
||||
def fn(x: jax.Array):
|
||||
checkify.check(jnp.all(x > 0), "x must be positive")
|
||||
return x + 1
|
||||
|
||||
fn = jax.remat(fn)
|
||||
if jit:
|
||||
fn = jax.jit(fn)
|
||||
fn = checkify.checkify(fn)
|
||||
err, y = fn(jnp.array([1, 2, 3]))
|
||||
self.assertIsNone(err.get())
|
||||
self.assertAllClose(y, jnp.array([2, 3, 4]), check_dtypes=False)
|
||||
|
||||
err, _ = fn(jnp.array([0, 2, 3]))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "x must be positive")
|
||||
|
||||
|
||||
@jtu.with_config(jax_check_tracer_leaks=True)
|
||||
class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user