Allow tokens to be passed to jit, flow through dispatch, and be returned from the jitted function.

Fixes https://github.com/google/jax/issues/21160

PiperOrigin-RevId: 632531105
Yash Katariya, 2024-05-10 10:11:55 -07:00 (committed by jax authors)
parent 0267ed0ba9
commit bac3a6fa8f
3 changed files with 29 additions and 24 deletions
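
As context for the diffs below, a minimal sketch of the user-facing behavior this change enables; it mirrors the new test added at the bottom of this commit, and jax.lax.create_token and jax.jit are public JAX APIs:

import jax
import jax.numpy as jnp

# After this change, a runtime token can ride through jit next to arrays.
token = jax.lax.create_token()
x = jnp.arange(8)

out_x, out_token = jax.jit(lambda a, t: (a, t))(x, token)
# out_x is a regular jax.Array; out_token comes back as a core.Token.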

jax/_src/array.py

@@ -933,13 +933,6 @@ def _array_shard_arg(x, sharding):
 pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
 
-def _token_shard_arg(x, sharding):
-  return _array_shard_arg(x._buf, sharding)
-
-pxla.shard_arg_handlers[core.Token] = _token_shard_arg
-
 def _array_global_result_handler(global_aval, out_sharding, committed):
   if global_aval.dtype == dtypes.float0:
     return lambda _: np.zeros(global_aval.shape, dtypes.float0)  # type: ignore
@@ -952,22 +945,6 @@ def _array_global_result_handler(global_aval, out_sharding, committed):
 pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler
 pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler
 
-def _token_global_result_handler(global_aval, out_sharding, committed):
-  array_handler = _array_global_result_handler(
-      core.token_shaped_array, out_sharding, committed
-  )
-
-  def wrapper(*args, **kwargs):
-    out_buf = array_handler(*args, **kwargs)
-    return core.Token(out_buf)
-  return wrapper
-
-pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler
-
 # Only used for Arrays that come out of pmap.
 def _array_local_result_handler(aval, sharding, indices):
   if aval.dtype == dtypes.float0:
@@ -980,3 +957,21 @@ def _array_local_result_handler(aval, sharding, indices):
 )
 pxla.local_result_handlers[core.ShapedArray] = _array_local_result_handler
 pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler
+
+# Token handlers
+def _token_shard_arg(x, sharding):
+  return _array_shard_arg(x._buf, sharding)
+
+pxla.shard_arg_handlers[core.Token] = _token_shard_arg
+
+def _token_global_result_handler(global_aval, out_sharding, committed):
+  array_handler = _array_global_result_handler(
+      core.token_shaped_array, out_sharding, committed)
+
+  def wrapper(*args, **kwargs):
+    out_buf = array_handler(*args, **kwargs)
+    return core.Token(out_buf)
+  return wrapper
+
+pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler
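
The moved code follows the registry pattern used throughout this file: handlers are keyed by the value's type (or its abstract value), and the token handlers delegate to the array handlers, unwrapping x._buf on the way in and rewrapping the result buffer in a core.Token on the way out. A self-contained sketch of that pattern, using hypothetical stand-in names (Token, result_handlers, make_array_handler are not JAX internals):

class Token:
  # Thin wrapper holding the array buffer that backs a runtime token.
  def __init__(self, buf):
    self._buf = buf

result_handlers = {}  # keyed by abstract-value kind

def make_array_handler(aval):
  return lambda buf: buf  # arrays: return the buffer unchanged

def make_token_handler(aval):
  array_handler = make_array_handler(aval)
  def wrapper(buf):
    return Token(array_handler(buf))  # tokens: rewrap the buffer
  return wrapper

result_handlers['array'] = make_array_handler
result_handlers['token'] = make_token_handler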

jax/_src/interpreters/pxla.py

@@ -2528,7 +2528,8 @@ def _get_out_sharding_from_orig_sharding(
   out = []
   orig_handler = _orig_out_sharding_handlers[type(orig_in_s)]
   for o, out_aval in safe_zip(out_shardings, out_avals):
-    if isinstance(o, sharding_impls.GSPMDSharding):
+    if (isinstance(o, sharding_impls.GSPMDSharding) and
+        out_aval is not core.abstract_token):
       # Only return the same input sharding object if the OpShardings and
       # in_aval.ndim and out_aval.ndim match. This is because if OpSharding is
       # replicated, then it doesn't encode the ndim in it. The devices
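
The new conjunct is needed because a token's abstract value, core.abstract_token, carries no shape or ndim, so the ndim-matching fast path guarded by this condition cannot apply to it; token outputs fall through to the default handling instead. Roughly the shape of the guard, with hypothetical names rather than the real pxla logic:

ABSTRACT_TOKEN = object()  # stand-in for core.abstract_token

def pick_sharding(out_aval, default_sharding, input_sharding):
  # Skip the shape-dependent reuse path for tokens, which have no ndim.
  if out_aval is ABSTRACT_TOKEN:
    return default_sharding
  return input_sharding  # ndim-matching reuse logic elided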

tests/pjit_test.py

@@ -4051,6 +4051,15 @@ class ArrayPjitTest(jtu.JaxTestCase):
     jax.vmap(jax.grad(model), in_axes=(None, 0))(params, x)  # doesn't crash
 
+  def test_jit_token_input(self):
+    x = jnp.arange(8)
+    token = jax.lax.create_token(None)
+    device = jax.devices()[0]
+    x = jax.device_put(x, device=device)
+    out1, out2 = jax.jit(lambda x, t: (x, t))(x, token)
+    self.assertArraysEqual(out1, x)
+    self.assertIsInstance(out2, core.Token)
+
 class TempSharding(Sharding):
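
Assuming the test lands in tests/pjit_test.py as the ArrayPjitTest context suggests, the new case can be run in isolation with pytest:

pytest -q tests/pjit_test.py -k test_jit_token_input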