mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[key reuse] fix signature for device_put
This commit is contained in:
parent
a4a657bc43
commit
f83175fc94
@ -292,7 +292,6 @@ key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature(Sink(0))
|
||||
key_reuse_signatures[lax.broadcast_in_dim_p] = KeyReuseSignature(Forward(0, 0))
|
||||
key_reuse_signatures[lax.copy_p] = KeyReuseSignature(Forward(0, 0))
|
||||
key_reuse_signatures[lax.convert_element_type_p] = KeyReuseSignature(Forward(0, 0))
|
||||
key_reuse_signatures[lax.device_put_p] = KeyReuseSignature(Forward(0, 0))
|
||||
key_reuse_signatures[lax.reshape_p] = KeyReuseSignature(Forward(0, 0))
|
||||
key_reuse_signatures[lax.squeeze_p] = KeyReuseSignature(Forward(0, 0))
|
||||
key_reuse_signatures[prng.random_wrap_p] = KeyReuseSignature(Source(0))
|
||||
@ -561,6 +560,14 @@ def _remat_key_type_signature(eqn):
|
||||
key_reuse_signatures[remat_p] = _remat_key_type_signature
|
||||
|
||||
|
||||
@dynamic_key_reuse_signature
|
||||
def _device_put_signature(eqn):
|
||||
num_vals = len(eqn.invars)
|
||||
return KeyReuseSignature(*(Forward(i, i) for i in range(num_vals)))
|
||||
|
||||
key_reuse_signatures[lax.device_put_p] = _device_put_signature
|
||||
|
||||
|
||||
def call_impl_with_key_reuse_checks(prim: core.Primitive, raw_impl: Callable[..., Any], *args, **kwargs) -> Any:
|
||||
if prim not in key_reuse_signatures:
|
||||
# TODO(jakevdp): should we use an unknown signature here?
|
||||
|
@ -49,7 +49,6 @@ primitives_with_static_signatures = {
|
||||
jax.lax.broadcast_in_dim_p: (lambda key: key[None], key),
|
||||
jax.lax.copy_p: (jnp.array, key),
|
||||
jax.lax.convert_element_type_p: (lambda key: jnp.array(key, dtype=key.dtype), key),
|
||||
jax.lax.device_put_p: (jax.device_put, key),
|
||||
jax.lax.reshape_p: (lambda key: key.reshape((1,)), key),
|
||||
jax.lax.squeeze_p: (jnp.squeeze, key1D),
|
||||
jax.lax.dynamic_slice_p: (partial(jax.lax.dynamic_slice, slice_sizes=(1,)), key1D, (0,)),
|
||||
@ -178,14 +177,27 @@ class KeyReuseUnitTestWithForwarding(jtu.JaxTestCase):
|
||||
def test_device_put(self):
|
||||
def f(key):
|
||||
assert_unconsumed(key)
|
||||
key2 = jax.device_put(key)
|
||||
assert_unconsumed(key)
|
||||
assert_unconsumed(key2)
|
||||
key_d = jax.device_put(key)
|
||||
assert_unconsumed(key_d)
|
||||
consume(key)
|
||||
assert_consumed(key)
|
||||
assert_consumed(key2)
|
||||
assert_consumed(key_d)
|
||||
self.check_key_reuse(f, jax.random.key(0))
|
||||
|
||||
def test_device_put_multiple(self):
|
||||
def f(key1, key2):
|
||||
assert_unconsumed(key1)
|
||||
assert_unconsumed(key2)
|
||||
key1_d, key2_d = jax.device_put((key1, key2))
|
||||
|
||||
assert_unconsumed(key1_d)
|
||||
consume(key1)
|
||||
assert_consumed(key1_d)
|
||||
|
||||
assert_unconsumed(key2_d)
|
||||
consume(key2)
|
||||
assert_consumed(key2_d)
|
||||
self.check_key_reuse(f, jax.random.key(0), jax.random.key(1))
|
||||
|
||||
def test_squeeze(self):
|
||||
def f(key):
|
||||
assert_unconsumed(key)
|
||||
|
Loading…
x
Reference in New Issue
Block a user