From bac3a6fa8f099ec66412d6103834aa6ad931a7d1 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Fri, 10 May 2024 10:11:55 -0700
Subject: [PATCH] Allow tokens being passed to `jit` and through dispatch and
 being returned from the jitted function.

Fixes https://github.com/google/jax/issues/21160

PiperOrigin-RevId: 632531105
---
 jax/_src/array.py             | 41 +++++++++++++++--------------------
 jax/_src/interpreters/pxla.py |  3 ++-
 tests/pjit_test.py            |  9 ++++++++
 3 files changed, 29 insertions(+), 24 deletions(-)

diff --git a/jax/_src/array.py b/jax/_src/array.py
index 1afdcd2c5..555b2f7ac 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -933,13 +933,6 @@ def _array_shard_arg(x, sharding):
 pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
 
 
-def _token_shard_arg(x, sharding):
-  return _array_shard_arg(x._buf, sharding)
-
-
-pxla.shard_arg_handlers[core.Token] = _token_shard_arg
-
-
 def _array_global_result_handler(global_aval, out_sharding, committed):
   if global_aval.dtype == dtypes.float0:
     return lambda _: np.zeros(global_aval.shape, dtypes.float0)  # type: ignore
@@ -952,22 +945,6 @@ def _array_global_result_handler(global_aval, out_sharding, committed):
 pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler
 pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler
 
-
-def _token_global_result_handler(global_aval, out_sharding, committed):
-  array_handler = _array_global_result_handler(
-      core.token_shaped_array, out_sharding, committed
-  )
-
-  def wrapper(*args, **kwargs):
-    out_buf = array_handler(*args, **kwargs)
-    return core.Token(out_buf)
-
-  return wrapper
-
-
-pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler
-
-
 # Only used for Arrays that come out of pmap.
 def _array_local_result_handler(aval, sharding, indices):
   if aval.dtype == dtypes.float0:
@@ -980,3 +957,21 @@ def _array_local_result_handler(aval, sharding, indices):
 )
 pxla.local_result_handlers[core.ShapedArray] = _array_local_result_handler
 pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler
+
+
+# Token handlers
+
+def _token_shard_arg(x, sharding):
+  return _array_shard_arg(x._buf, sharding)
+pxla.shard_arg_handlers[core.Token] = _token_shard_arg
+
+
+def _token_global_result_handler(global_aval, out_sharding, committed):
+  array_handler = _array_global_result_handler(
+      core.token_shaped_array, out_sharding, committed)
+
+  def wrapper(*args, **kwargs):
+    out_buf = array_handler(*args, **kwargs)
+    return core.Token(out_buf)
+  return wrapper
+pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 16d5da7dc..edb08a078 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2528,7 +2528,8 @@ def _get_out_sharding_from_orig_sharding(
   out = []
   orig_handler = _orig_out_sharding_handlers[type(orig_in_s)]
   for o, out_aval in safe_zip(out_shardings, out_avals):
-    if isinstance(o, sharding_impls.GSPMDSharding):
+    if (isinstance(o, sharding_impls.GSPMDSharding) and
+        out_aval is not core.abstract_token):
       # Only return the same input sharding object if the OpShardings and
       # in_aval.ndim and out_aval.ndim match. This is because if OpSharding is
       # replicated then, it doesn't encode the ndim in it. The devices
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index fb47eb288..dac578b35 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -4051,6 +4051,15 @@ class ArrayPjitTest(jtu.JaxTestCase):
     jax.vmap(jax.grad(model), in_axes=(None, 0))(params, x)  # doesn't crash
 
+  def test_jit_token_input(self):
+    x = jnp.arange(8)
+    token = jax.lax.create_token(None)
+    device = jax.devices()[0]
+    x = jax.device_put(x, device=device)
+    out1, out2 = jax.jit(lambda x, t: (x, t))(x, token)
+    self.assertArraysEqual(out1, x)
+    self.assertIsInstance(out2, core.Token)
+
 
 class TempSharding(Sharding):
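
A minimal usage sketch of the behavior this patch enables, mirroring the added
test_jit_token_input test; the variable names and the identity lambda below are
illustrative, and `core` refers to the internal jax._src.core module that the
test also uses:

    import jax
    import jax.numpy as jnp
    from jax._src import core

    x = jnp.arange(8)
    token = jax.lax.create_token()  # creates a runtime token value
    # With this change, a token can be passed into a jitted function alongside
    # ordinary arrays and returned from it unchanged.
    out_x, out_token = jax.jit(lambda a, t: (a, t))(x, token)
    assert isinstance(out_token, core.Token)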