[key reuse]: handle remat

This commit is contained in:
Jake VanderPlas 2024-02-05 16:02:12 -08:00
parent 9e94e6ef71
commit 7360edd404
3 changed files with 61 additions and 0 deletions

View File

@ -28,6 +28,7 @@ from jax._src import pjit
from jax._src import prng
from jax._src import random
from jax._src import util
from jax._src.ad_checkpoint import remat_p
from jax._src.debugging import debug_callback_p
from jax._src.interpreters import partial_eval as pe
@ -322,3 +323,22 @@ def _while_key_type_signature(eqn, args_consumed):
return body_signature
key_reuse_signatures_dynamic[jax.lax.while_p] = _while_key_type_signature
def _remat_key_type_signature(eqn, args_consumed):
  """Key-reuse signature rule for equations of the `remat_p` primitive.

  Args:
    eqn: the remat equation. Reads `eqn.params['differentiated']`,
      `eqn.params['jaxpr']` and `eqn.invars`.
    args_consumed: consumption state of the equation's inputs, indexed by
      sink index and combined with each sink's mask via `&`. A falsy value
      skips the double-consumption check.

  Returns:
    An empty `KeyReuseSignatureWithForwards` for the differentiated pass,
    otherwise the signature computed from the inner jaxpr.
  """
  # The assumption here is that the non-differentiated pass contains all
  # relevant key usage, and the differentiated pass
  #  1) will only consume keys that are already consumed in the
  #     non-differentiated pass
  #  2) will never create keys
  # Therefore, the differentiated pass is a no-op.
  if eqn.params['differentiated']:
    return KeyReuseSignatureWithForwards([], [])

  inner_jaxpr = eqn.params['jaxpr']

  # Map the position of each repeated input to the position of its first
  # occurrence among the equation's inputs.
  operands = eqn.invars
  forwarded_inputs = {}
  for pos, operand in enumerate(operands):
    if operand in operands[:pos]:
      forwarded_inputs[pos] = operands.index(operand)

  signature = get_jaxpr_type_signature(inner_jaxpr)
  if args_consumed:
    for sink in signature.sinks:
      if np.any(args_consumed[sink.idx] & sink.mask):
        # Double consumption detected: re-trace with context for better
        # errors. The result is discarded; one re-trace is enough.
        get_jaxpr_type_signature(inner_jaxpr, args_consumed, forwarded_inputs)
        break
  return signature

key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature

View File

@ -28,6 +28,7 @@ from jax._src import pjit
from jax._src import prng
from jax._src import random
from jax._src import util
from jax._src.ad_checkpoint import remat_p
from jax._src.debugging import debug_callback_p
from jax._src.interpreters import partial_eval as pe
@ -293,3 +294,21 @@ def _while_key_type_signature(eqn, args_consumed):
return body_signature
key_reuse_signatures_dynamic[jax.lax.while_p] = _while_key_type_signature
def _remat_key_type_signature(eqn, args_consumed):
  """Key-reuse signature rule for equations of the `remat_p` primitive.

  Args:
    eqn: the remat equation. Reads `eqn.params['differentiated']` and
      `eqn.params['jaxpr']`.
    args_consumed: consumption state of the equation's inputs, indexed by
      sink index and combined with each sink's mask via `&`. A falsy value
      skips the double-consumption check.

  Returns:
    An empty `KeyReuseSignature` for the differentiated pass, otherwise the
    signature computed from the inner jaxpr.
  """
  # The assumption here is that the non-differentiated pass contains all
  # relevant key usage, and the differentiated pass
  #  1) will only consume keys that are already consumed in the
  #     non-differentiated pass
  #  2) will never create keys
  # Therefore, the differentiated pass is a no-op.
  if eqn.params['differentiated']:
    return KeyReuseSignature([], [])

  inner_jaxpr = eqn.params['jaxpr']
  signature = get_jaxpr_type_signature(inner_jaxpr)

  if args_consumed:
    for sink in signature.sinks:
      if np.any(args_consumed[sink.idx] & sink.mask):
        # Double consumption detected: re-trace with context for better
        # errors. The result is discarded; one re-trace is enough.
        get_jaxpr_type_signature(inner_jaxpr, args_consumed)
        break
  return signature

key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature

View File

@ -802,6 +802,28 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
self.check_key_reuse(f, 0)
@jax.numpy_dtype_promotion('standard')
def test_remat(self):
  """Key reuse is checked through `jax.checkpoint`, with and without `grad`."""
  @jax.checkpoint
  def f_bad(x, key):
    # Passes the same key to two separate `bits` calls.
    return x * jax.random.bits(key) + jax.random.bits(key)

  @jax.checkpoint
  def f_good(x, key):
    # Passes the key to a single `bits` call.
    return x * jax.random.bits(key)

  x = jnp.float32(1.0)
  key = jax.random.key(0)

  # Each function is checked as-is first, then under jax.grad.
  transforms = (lambda fn: fn, jax.grad)

  for transform in transforms:
    with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
      self.check_key_reuse(transform(f_bad), x, key)

  for transform in transforms:
    self.check_key_reuse(transform(f_good), x, key)
# Subclass of KeyReuseIntegrationTest that inherits its tests and overrides the
# `use_forwarding` class attribute to False.
# NOTE(review): presumably this selects the non-forwarding ("simple") key-reuse
# checker inside check_key_reuse — confirm against the base class.
class KeyReuseIntegrationTestSimple(KeyReuseIntegrationTest):
use_forwarding = False