mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Call shard_arg fallback in pjit's cpp fast path instead of dropping out completely.
PiperOrigin-RevId: 592344105
This commit is contained in:
parent
67d5c3bdea
commit
9b6bf2cab0
@ -2745,6 +2745,8 @@ class MeshExecutableFastpathData(NamedTuple):
|
||||
out_avals: Sequence[ShapedArray]
|
||||
out_committed: Sequence[bool]
|
||||
kept_var_bitvec: Iterable[bool]
|
||||
arg_handler_devices: Sequence[xc.Device]
|
||||
arg_handler_indices: Sequence[tuple[Index | None, ...]]
|
||||
|
||||
|
||||
def reflatten_outputs_for_dispatch(out_tree, out_flat):
|
||||
@ -2845,13 +2847,19 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
for i in range(len(args_flat))]
|
||||
fastpath_data = MeshExecutableFastpathData(
|
||||
self.xla_executable, out_tree_dispatch, self._in_shardings,
|
||||
self._out_shardings, out_avals, out_committed, kept_var_bitvec)
|
||||
self._out_shardings, out_avals, out_committed, kept_var_bitvec,
|
||||
self.unsafe_call.in_handler.local_devices,
|
||||
self.unsafe_call.in_handler.input_indices)
|
||||
else:
|
||||
fastpath_data = None
|
||||
return outs, fastpath_data
|
||||
|
||||
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
|
||||
tree_util.dispatch_registry)
|
||||
if xla_extension_version >= 226:
|
||||
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
|
||||
tree_util.dispatch_registry, shard_arg)
|
||||
else:
|
||||
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [], # type: ignore
|
||||
tree_util.dispatch_registry)
|
||||
|
||||
|
||||
def check_arg_avals_for_call(ref_avals, arg_avals,
|
||||
|
@ -56,6 +56,7 @@ from jax._src.interpreters import pxla
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.sharding_impls import (
|
||||
NamedSharding, XLACompatibleSharding, GSPMDSharding,
|
||||
XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
|
||||
@ -214,7 +215,9 @@ def _get_fastpath_data(executable, out_tree, args_flat, out_flat):
|
||||
for i in range(len(args_flat))]
|
||||
fastpath_data = pxla.MeshExecutableFastpathData(
|
||||
executable.xla_executable, out_tree, executable._in_shardings,
|
||||
executable._out_shardings, out_avals, out_committed, kept_var_bitvec)
|
||||
executable._out_shardings, out_avals, out_committed, kept_var_bitvec,
|
||||
executable.unsafe_call.in_handler.local_devices,
|
||||
executable.unsafe_call.in_handler.input_indices)
|
||||
else:
|
||||
fastpath_data = None
|
||||
return fastpath_data
|
||||
@ -260,11 +263,18 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
|
||||
fastpath_data = _get_fastpath_data(executable, out_tree, args_flat, out_flat)
|
||||
return outs, fastpath_data
|
||||
|
||||
cpp_pjit_f = xc._xla.pjit(
|
||||
getattr(fun, "__name__", "<unnamed function>"),
|
||||
fun, cache_miss, static_argnums, static_argnames,
|
||||
donate_argnums, tree_util.dispatch_registry,
|
||||
_get_cpp_global_cache(pjit_has_explicit_sharding))
|
||||
if xla_extension_version >= 226:
|
||||
cpp_pjit_f = xc._xla.pjit( # type: ignore
|
||||
getattr(fun, "__name__", "<unnamed function>"),
|
||||
fun, cache_miss, static_argnums, static_argnames,
|
||||
donate_argnums, tree_util.dispatch_registry,
|
||||
pxla.shard_arg, _get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore
|
||||
else:
|
||||
cpp_pjit_f = xc._xla.pjit( # type: ignore
|
||||
getattr(fun, "__name__", "<unnamed function>"),
|
||||
fun, cache_miss, static_argnums, static_argnames,
|
||||
donate_argnums, tree_util.dispatch_registry,
|
||||
_get_cpp_global_cache(pjit_has_explicit_sharding))
|
||||
|
||||
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
|
||||
cpp_pjitted_f._fun = fun
|
||||
@ -1242,9 +1252,14 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
donated_argnums = [i for i, d in enumerate(donated_invars) if d]
|
||||
has_explicit_sharding = _pjit_explicit_sharding(
|
||||
in_shardings, out_shardings, None, None)
|
||||
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
|
||||
tree_util.dispatch_registry,
|
||||
_get_cpp_global_cache(has_explicit_sharding))(*args)
|
||||
if xla_extension_version >= 226:
|
||||
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
|
||||
tree_util.dispatch_registry, pxla.shard_arg,
|
||||
_get_cpp_global_cache(has_explicit_sharding))(*args)
|
||||
else:
|
||||
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, # type: ignore
|
||||
tree_util.dispatch_registry,
|
||||
_get_cpp_global_cache(has_explicit_sharding))(*args)
|
||||
|
||||
pjit_p.def_impl(_pjit_call_impl)
|
||||
|
||||
|
@ -4206,6 +4206,17 @@ class PJitErrorTest(jtu.JaxTestCase):
|
||||
r"sharding.*the computation was compiled with"):
|
||||
g(x, y2)
|
||||
|
||||
def test_dce_no_array(self):
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
arr = jax.device_put(np.arange(8.), NamedSharding(mesh, P('x')))
|
||||
|
||||
@jax.jit
|
||||
def f(a, b, c):
|
||||
return a, c
|
||||
|
||||
f(arr, 2., 3.)
|
||||
f(arr, 2., 3.) # doesn't crash
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class UtilTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user