mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[key reuse]: handle remat
This commit is contained in:
parent
9e94e6ef71
commit
7360edd404
@ -28,6 +28,7 @@ from jax._src import pjit
|
||||
from jax._src import prng
|
||||
from jax._src import random
|
||||
from jax._src import util
|
||||
from jax._src.ad_checkpoint import remat_p
|
||||
from jax._src.debugging import debug_callback_p
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
|
||||
@ -322,3 +323,22 @@ def _while_key_type_signature(eqn, args_consumed):
|
||||
return body_signature
|
||||
|
||||
key_reuse_signatures_dynamic[jax.lax.while_p] = _while_key_type_signature
|
||||
|
||||
def _remat_key_type_signature(eqn, args_consumed):
|
||||
# The assumption here is that the non-differentiated pass contains all relevant
|
||||
# key usage, and the differentiated pass
|
||||
# 1) will only consume keys that are already consumed in the non-differentiated pass
|
||||
# 2) will never create keys
|
||||
# Therefore, the differentiated pass is a no-op.
|
||||
if eqn.params['differentiated']:
|
||||
return KeyReuseSignatureWithForwards([], [])
|
||||
jaxpr = eqn.params['jaxpr']
|
||||
forwarded_inputs = {i: eqn.invars.index(var) for i, var in enumerate(eqn.invars)
|
||||
if var in eqn.invars[:i]}
|
||||
sig = get_jaxpr_type_signature(jaxpr)
|
||||
if args_consumed and any(np.any(args_consumed[s.idx] & s.mask) for s in sig.sinks):
|
||||
# Double consumption detected: re-trace with context for better errors.
|
||||
get_jaxpr_type_signature(jaxpr, args_consumed, forwarded_inputs)
|
||||
return sig
|
||||
|
||||
key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature
|
||||
|
@ -28,6 +28,7 @@ from jax._src import pjit
|
||||
from jax._src import prng
|
||||
from jax._src import random
|
||||
from jax._src import util
|
||||
from jax._src.ad_checkpoint import remat_p
|
||||
from jax._src.debugging import debug_callback_p
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
|
||||
@ -293,3 +294,21 @@ def _while_key_type_signature(eqn, args_consumed):
|
||||
return body_signature
|
||||
|
||||
key_reuse_signatures_dynamic[jax.lax.while_p] = _while_key_type_signature
|
||||
|
||||
|
||||
def _remat_key_type_signature(eqn, args_consumed):
|
||||
# The assumption here is that the non-differentiated pass contains all relevant
|
||||
# key usage, and the differentiated pass
|
||||
# 1) will only consume keys that are already consumed in the non-differentiated pass
|
||||
# 2) will never create keys
|
||||
# Therefore, the differentiated pass is a no-op.
|
||||
if eqn.params['differentiated']:
|
||||
return KeyReuseSignature([], [])
|
||||
jaxpr = eqn.params['jaxpr']
|
||||
sig = get_jaxpr_type_signature(jaxpr)
|
||||
if args_consumed and any(np.any(args_consumed[s.idx] & s.mask) for s in sig.sinks):
|
||||
# Double consumption detected: re-trace with context for better errors.
|
||||
get_jaxpr_type_signature(jaxpr, args_consumed)
|
||||
return sig
|
||||
|
||||
key_reuse_signatures_dynamic[remat_p] = _remat_key_type_signature
|
||||
|
@ -802,6 +802,28 @@ class KeyReuseIntegrationTest(jtu.JaxTestCase):
|
||||
|
||||
self.check_key_reuse(f, 0)
|
||||
|
||||
@jax.numpy_dtype_promotion('standard')
|
||||
def test_remat(self):
|
||||
@jax.checkpoint
|
||||
def f_bad(x, key):
|
||||
return x * jax.random.bits(key) + jax.random.bits(key)
|
||||
|
||||
@jax.checkpoint
|
||||
def f_good(x, key):
|
||||
return x * jax.random.bits(key)
|
||||
|
||||
x = jnp.float32(1.0)
|
||||
key = jax.random.key(0)
|
||||
|
||||
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
|
||||
self.check_key_reuse(f_bad, x, key)
|
||||
|
||||
with self.assertRaisesRegex(KeyReuseError, self.random_bits_error):
|
||||
self.check_key_reuse(jax.grad(f_bad), x, key)
|
||||
|
||||
self.check_key_reuse(f_good, x, key)
|
||||
self.check_key_reuse(jax.grad(f_good), x, key)
|
||||
|
||||
|
||||
class KeyReuseIntegrationTestSimple(KeyReuseIntegrationTest):
|
||||
use_forwarding = False
|
||||
|
Loading…
x
Reference in New Issue
Block a user