Call shard_arg fallback in pjit's cpp fast path instead of dropping out completely.

PiperOrigin-RevId: 592344105
This commit is contained in:
Yash Katariya 2023-12-19 14:25:25 -08:00 committed by jax authors
parent 67d5c3bdea
commit 9b6bf2cab0
3 changed files with 46 additions and 12 deletions

View File

@ -2745,6 +2745,8 @@ class MeshExecutableFastpathData(NamedTuple):
out_avals: Sequence[ShapedArray]
out_committed: Sequence[bool]
kept_var_bitvec: Iterable[bool]
arg_handler_devices: Sequence[xc.Device]
arg_handler_indices: Sequence[tuple[Index | None, ...]]
def reflatten_outputs_for_dispatch(out_tree, out_flat):
@ -2845,13 +2847,19 @@ class MeshExecutable(stages.XlaExecutable):
for i in range(len(args_flat))]
fastpath_data = MeshExecutableFastpathData(
self.xla_executable, out_tree_dispatch, self._in_shardings,
self._out_shardings, out_avals, out_committed, kept_var_bitvec)
self._out_shardings, out_avals, out_committed, kept_var_bitvec,
self.unsafe_call.in_handler.local_devices,
self.unsafe_call.in_handler.input_indices)
else:
fastpath_data = None
return outs, fastpath_data
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
tree_util.dispatch_registry)
if xla_extension_version >= 226:
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
tree_util.dispatch_registry, shard_arg)
else:
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [], # type: ignore
tree_util.dispatch_registry)
def check_arg_avals_for_call(ref_avals, arg_avals,

View File

@ -56,6 +56,7 @@ from jax._src.interpreters import pxla
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.sharding_impls import (
NamedSharding, XLACompatibleSharding, GSPMDSharding,
XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
@ -214,7 +215,9 @@ def _get_fastpath_data(executable, out_tree, args_flat, out_flat):
for i in range(len(args_flat))]
fastpath_data = pxla.MeshExecutableFastpathData(
executable.xla_executable, out_tree, executable._in_shardings,
executable._out_shardings, out_avals, out_committed, kept_var_bitvec)
executable._out_shardings, out_avals, out_committed, kept_var_bitvec,
executable.unsafe_call.in_handler.local_devices,
executable.unsafe_call.in_handler.input_indices)
else:
fastpath_data = None
return fastpath_data
@ -260,11 +263,18 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
fastpath_data = _get_fastpath_data(executable, out_tree, args_flat, out_flat)
return outs, fastpath_data
cpp_pjit_f = xc._xla.pjit(
getattr(fun, "__name__", "<unnamed function>"),
fun, cache_miss, static_argnums, static_argnames,
donate_argnums, tree_util.dispatch_registry,
_get_cpp_global_cache(pjit_has_explicit_sharding))
if xla_extension_version >= 226:
cpp_pjit_f = xc._xla.pjit( # type: ignore
getattr(fun, "__name__", "<unnamed function>"),
fun, cache_miss, static_argnums, static_argnames,
donate_argnums, tree_util.dispatch_registry,
pxla.shard_arg, _get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore
else:
cpp_pjit_f = xc._xla.pjit( # type: ignore
getattr(fun, "__name__", "<unnamed function>"),
fun, cache_miss, static_argnums, static_argnames,
donate_argnums, tree_util.dispatch_registry,
_get_cpp_global_cache(pjit_has_explicit_sharding))
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
cpp_pjitted_f._fun = fun
@ -1242,9 +1252,14 @@ def _pjit_call_impl(*args, jaxpr,
donated_argnums = [i for i, d in enumerate(donated_invars) if d]
has_explicit_sharding = _pjit_explicit_sharding(
in_shardings, out_shardings, None, None)
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
tree_util.dispatch_registry,
_get_cpp_global_cache(has_explicit_sharding))(*args)
if xla_extension_version >= 226:
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
tree_util.dispatch_registry, pxla.shard_arg,
_get_cpp_global_cache(has_explicit_sharding))(*args)
else:
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, # type: ignore
tree_util.dispatch_registry,
_get_cpp_global_cache(has_explicit_sharding))(*args)
pjit_p.def_impl(_pjit_call_impl)

View File

@ -4206,6 +4206,17 @@ class PJitErrorTest(jtu.JaxTestCase):
r"sharding.*the computation was compiled with"):
g(x, y2)
def test_dce_no_array(self):
  """Repeated calls mixing a sharded array with plain floats must not crash.

  The jitted function ignores its middle argument and is invoked with one
  array sharded over a 2-device mesh plus two Python floats.
  """
  two_device_mesh = jtu.create_global_mesh((2,), ('x',))
  sharded_input = jax.device_put(
      np.arange(8.), NamedSharding(two_device_mesh, P('x')))

  @jax.jit
  def f(a, b, c):
    # `b` is deliberately unused so it can be dropped from the computation.
    return a, c

  # Call twice with identical arguments: the repeat call is the one that must
  # not crash (presumably it is served by the C++ fast path -- confirm).
  for _ in range(2):
    f(sharded_input, 2., 3.)
@jtu.pytest_mark_if_available('multiaccelerator')
class UtilTest(jtu.JaxTestCase):