From 9b6bf2cab014a06fad8920959759d2b36c55fb3b Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 19 Dec 2023 14:25:25 -0800 Subject: [PATCH] Call shard_arg fallback in pjit's cpp fast path instead of dropping out completely. PiperOrigin-RevId: 592344105 --- jax/_src/interpreters/pxla.py | 14 +++++++++++--- jax/_src/pjit.py | 33 ++++++++++++++++++++++++--------- tests/pjit_test.py | 11 +++++++++++ 3 files changed, 46 insertions(+), 12 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index aada85d1e..a63c2268c 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2745,6 +2745,8 @@ class MeshExecutableFastpathData(NamedTuple): out_avals: Sequence[ShapedArray] out_committed: Sequence[bool] kept_var_bitvec: Iterable[bool] + arg_handler_devices: Sequence[xc.Device] + arg_handler_indices: Sequence[tuple[Index | None, ...]] def reflatten_outputs_for_dispatch(out_tree, out_flat): @@ -2845,13 +2847,19 @@ class MeshExecutable(stages.XlaExecutable): for i in range(len(args_flat))] fastpath_data = MeshExecutableFastpathData( self.xla_executable, out_tree_dispatch, self._in_shardings, - self._out_shardings, out_avals, out_committed, kept_var_bitvec) + self._out_shardings, out_avals, out_committed, kept_var_bitvec, + self.unsafe_call.in_handler.local_devices, + self.unsafe_call.in_handler.input_indices) else: fastpath_data = None return outs, fastpath_data - return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [], - tree_util.dispatch_registry) + if xla_extension_version >= 226: + return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [], + tree_util.dispatch_registry, shard_arg) + else: + return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [], # type: ignore + tree_util.dispatch_registry) def check_arg_avals_for_call(ref_avals, arg_avals, diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 21831f90c..deadc0403 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -56,6 +56,7 @@ from jax._src.interpreters import pxla from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.sharding_impls import ( NamedSharding, XLACompatibleSharding, GSPMDSharding, XLADeviceAssignment, SingleDeviceSharding, PmapSharding, @@ -214,7 +215,9 @@ def _get_fastpath_data(executable, out_tree, args_flat, out_flat): for i in range(len(args_flat))] fastpath_data = pxla.MeshExecutableFastpathData( executable.xla_executable, out_tree, executable._in_shardings, - executable._out_shardings, out_avals, out_committed, kept_var_bitvec) + executable._out_shardings, out_avals, out_committed, kept_var_bitvec, + executable.unsafe_call.in_handler.local_devices, + executable.unsafe_call.in_handler.input_indices) else: fastpath_data = None return fastpath_data @@ -260,11 +263,18 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames, fastpath_data = _get_fastpath_data(executable, out_tree, args_flat, out_flat) return outs, fastpath_data - cpp_pjit_f = xc._xla.pjit( - getattr(fun, "__name__", ""), - fun, cache_miss, static_argnums, static_argnames, - donate_argnums, tree_util.dispatch_registry, - _get_cpp_global_cache(pjit_has_explicit_sharding)) + if xla_extension_version >= 226: + cpp_pjit_f = xc._xla.pjit( # type: ignore + getattr(fun, "__name__", ""), + fun, cache_miss, static_argnums, static_argnames, + donate_argnums, tree_util.dispatch_registry, + pxla.shard_arg, _get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore + else: + cpp_pjit_f = xc._xla.pjit( # type: ignore + getattr(fun, "__name__", ""), + fun, cache_miss, static_argnums, static_argnames, + donate_argnums, tree_util.dispatch_registry, + _get_cpp_global_cache(pjit_has_explicit_sharding)) cpp_pjitted_f = wraps(fun)(cpp_pjit_f) cpp_pjitted_f._fun = fun @@ -1242,9 +1252,14 @@ def _pjit_call_impl(*args, jaxpr, donated_argnums = [i for i, d in enumerate(donated_invars) if d] has_explicit_sharding = _pjit_explicit_sharding( in_shardings, out_shardings, None, None) - return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, - tree_util.dispatch_registry, - _get_cpp_global_cache(has_explicit_sharding))(*args) + if xla_extension_version >= 226: + return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, + tree_util.dispatch_registry, pxla.shard_arg, + _get_cpp_global_cache(has_explicit_sharding))(*args) + else: + return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, # type: ignore + tree_util.dispatch_registry, + _get_cpp_global_cache(has_explicit_sharding))(*args) pjit_p.def_impl(_pjit_call_impl) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 73642ee1d..b37f14f3a 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4206,6 +4206,17 @@ class PJitErrorTest(jtu.JaxTestCase): r"sharding.*the computation was compiled with"): g(x, y2) + def test_dce_no_array(self): + mesh = jtu.create_global_mesh((2,), ('x',)) + arr = jax.device_put(np.arange(8.), NamedSharding(mesh, P('x'))) + + @jax.jit + def f(a, b, c): + return a, c + + f(arr, 2., 3.) + f(arr, 2., 3.) # doesn't crash + @jtu.pytest_mark_if_available('multiaccelerator') class UtilTest(jtu.JaxTestCase):