mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Allow tokens being passed to jit
and through dispatch and being returned from the jitted function.
Fixes https://github.com/google/jax/issues/21160 PiperOrigin-RevId: 632531105
This commit is contained in:
parent
0267ed0ba9
commit
bac3a6fa8f
@ -933,13 +933,6 @@ def _array_shard_arg(x, sharding):
|
||||
pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
|
||||
|
||||
|
||||
def _token_shard_arg(x, sharding):
|
||||
return _array_shard_arg(x._buf, sharding)
|
||||
|
||||
|
||||
pxla.shard_arg_handlers[core.Token] = _token_shard_arg
|
||||
|
||||
|
||||
def _array_global_result_handler(global_aval, out_sharding, committed):
|
||||
if global_aval.dtype == dtypes.float0:
|
||||
return lambda _: np.zeros(global_aval.shape, dtypes.float0) # type: ignore
|
||||
@ -952,22 +945,6 @@ def _array_global_result_handler(global_aval, out_sharding, committed):
|
||||
pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler
|
||||
pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler
|
||||
|
||||
|
||||
def _token_global_result_handler(global_aval, out_sharding, committed):
|
||||
array_handler = _array_global_result_handler(
|
||||
core.token_shaped_array, out_sharding, committed
|
||||
)
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
out_buf = array_handler(*args, **kwargs)
|
||||
return core.Token(out_buf)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler
|
||||
|
||||
|
||||
# Only used for Arrays that come out of pmap.
|
||||
def _array_local_result_handler(aval, sharding, indices):
|
||||
if aval.dtype == dtypes.float0:
|
||||
@ -980,3 +957,21 @@ def _array_local_result_handler(aval, sharding, indices):
|
||||
)
|
||||
pxla.local_result_handlers[core.ShapedArray] = _array_local_result_handler
|
||||
pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler
|
||||
|
||||
|
||||
# Token handlers
|
||||
|
||||
def _token_shard_arg(x, sharding):
|
||||
return _array_shard_arg(x._buf, sharding)
|
||||
pxla.shard_arg_handlers[core.Token] = _token_shard_arg
|
||||
|
||||
|
||||
def _token_global_result_handler(global_aval, out_sharding, committed):
|
||||
array_handler = _array_global_result_handler(
|
||||
core.token_shaped_array, out_sharding, committed)
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
out_buf = array_handler(*args, **kwargs)
|
||||
return core.Token(out_buf)
|
||||
return wrapper
|
||||
pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler
|
||||
|
@ -2528,7 +2528,8 @@ def _get_out_sharding_from_orig_sharding(
|
||||
out = []
|
||||
orig_handler = _orig_out_sharding_handlers[type(orig_in_s)]
|
||||
for o, out_aval in safe_zip(out_shardings, out_avals):
|
||||
if isinstance(o, sharding_impls.GSPMDSharding):
|
||||
if (isinstance(o, sharding_impls.GSPMDSharding) and
|
||||
out_aval is not core.abstract_token):
|
||||
# Only return the same input sharding object if the OpShardings and
|
||||
# in_aval.ndim and out_aval.ndim match. This is because if OpSharding is
|
||||
# replicated then, it doesn't encode the ndim in it. The devices
|
||||
|
@ -4051,6 +4051,15 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
jax.vmap(jax.grad(model), in_axes=(None, 0))(params, x) # doesn't crash
|
||||
|
||||
def test_jit_token_input(self):
|
||||
x = jnp.arange(8)
|
||||
token = jax.lax.create_token(None)
|
||||
device = jax.devices()[0]
|
||||
x = jax.device_put(x, device=device)
|
||||
out1, out2 = jax.jit(lambda x, t: (x, t))(x, token)
|
||||
self.assertArraysEqual(out1, x)
|
||||
self.assertIsInstance(out2, core.Token)
|
||||
|
||||
|
||||
class TempSharding(Sharding):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user