diff --git a/jax/experimental/key_reuse/_forwarding.py b/jax/experimental/key_reuse/_forwarding.py index 3e46c8b12..ebe3737cc 100644 --- a/jax/experimental/key_reuse/_forwarding.py +++ b/jax/experimental/key_reuse/_forwarding.py @@ -28,6 +28,7 @@ from jax._src import pjit from jax._src import prng from jax._src import random from jax._src import util +from jax._src.ad_checkpoint import remat_p from jax._src.debugging import debug_callback_p from jax._src.interpreters import partial_eval as pe @@ -322,3 +323,22 @@ def _while_key_type_signature(eqn, args_consumed): return body_signature key_reuse_signatures_dynamic[jax.lax.while_p] = _while_key_type_signature + +def _remat_key_type_signature(eqn, args_consumed): + # The assumption here is that the non-differentiated pass contains all relevant + # key usage, and the differentiated pass + # 1) will only consume keys that are already consumed in the non-differentiated pass + # 2) will never create keys + # Therefore, the differentiated pass is a no-op. + if eqn.params['differentiated']: + return KeyReuseSignatureWithForwards([], []) + jaxpr = eqn.params['jaxpr'] + forwarded_inputs = {i: eqn.invars.index(var) for i, var in enumerate(eqn.invars) + if var in eqn.invars[:i]} + sig = get_jaxpr_type_signature(jaxpr) + if args_consumed and any(np.any(args_consumed[s.idx] & s.mask) for s in sig.sinks): + # Double consumption detected: re-trace with context for better errors. + get_jaxpr_type_signature(jaxpr, args_consumed, forwarded_inputs) + return sig + +key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature diff --git a/jax/experimental/key_reuse/_simple.py b/jax/experimental/key_reuse/_simple.py index 0d379587c..06f8e15b6 100644 --- a/jax/experimental/key_reuse/_simple.py +++ b/jax/experimental/key_reuse/_simple.py @@ -28,6 +28,7 @@ from jax._src import pjit from jax._src import prng from jax._src import random from jax._src import util +from jax._src.ad_checkpoint import remat_p from jax._src.debugging import debug_callback_p from jax._src.interpreters import partial_eval as pe @@ -293,3 +294,21 @@ def _while_key_type_signature(eqn, args_consumed): return body_signature key_reuse_signatures_dynamic[jax.lax.while_p] = _while_key_type_signature + + +def _remat_key_type_signature(eqn, args_consumed): + # The assumption here is that the non-differentiated pass contains all relevant + # key usage, and the differentiated pass + # 1) will only consume keys that are already consumed in the non-differentiated pass + # 2) will never create keys + # Therefore, the differentiated pass is a no-op. + if eqn.params['differentiated']: + return KeyReuseSignature([], []) + jaxpr = eqn.params['jaxpr'] + sig = get_jaxpr_type_signature(jaxpr) + if args_consumed and any(np.any(args_consumed[s.idx] & s.mask) for s in sig.sinks): + # Double consumption detected: re-trace with context for better errors. + get_jaxpr_type_signature(jaxpr, args_consumed) + return sig + +key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index 0cda84d99..cae82aa34 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -802,6 +802,28 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase): self.check_key_reuse(f, 0) + @jax.numpy_dtype_promotion('standard') + def test_remat(self): + @jax.checkpoint + def f_bad(x, key): + return x * jax.random.bits(key) + jax.random.bits(key) + + @jax.checkpoint + def f_good(x, key): + return x * jax.random.bits(key) + + x = jnp.float32(1.0) + key = jax.random.key(0) + + with self.assertRaisesRegex(KeyReuseError, self.random_bits_error): + self.check_key_reuse(f_bad, x, key) + + with self.assertRaisesRegex(KeyReuseError, self.random_bits_error): + self.check_key_reuse(jax.grad(f_bad), x, key) + + self.check_key_reuse(f_good, x, key) + self.check_key_reuse(jax.grad(f_good), x, key) + class KeyReuseIntegrationTestSimple(KeyReuseIntegrationTest): use_forwarding = False