Do the sharding.addressable_devices check only once in _get_input_indices since all shardings should have the same device_assignment.

That check happens at the start of lower_sharding_computation. Also use the optimized DeviceAssignment object which has all the calculations cached if this path is hit multiple times.

Also remove `device_assignment` from MeshExecutable since it is not used anywhere in that class

PiperOrigin-RevId: 523182028
This commit is contained in:
Yash Katariya 2023-04-10 12:22:45 -07:00 committed by jax authors
parent f25b701b26
commit 49438c78e4
2 changed files with 54 additions and 50 deletions

View File

@ -2495,14 +2495,22 @@ class MeshComputation(stages.XlaLowering):
def _get_input_indices(
avals: Sequence[ShapedArray], shardings: Sequence[sharding_impls.XLACompatibleSharding]
avals: Sequence[ShapedArray],
shardings: Sequence[sharding_impls.XLACompatibleSharding],
da_object: Union[_DeviceAssignment, Sequence[xc.Device]],
) -> Sequence[Tuple[Optional[Index], ...]]:
input_indices = []
if isinstance(da_object, _DeviceAssignment):
num_addressable_devices = len(da_object.addressable_device_assignment)
else:
num_addressable_devices = len(
[d for d in da_object if d.process_index == d.client.process_index()])
for aval, sharding in zip(avals, shardings):
if aval is core.abstract_token:
index = tuple(
(slice(None),) for _ in range(len(sharding.addressable_devices)))
(slice(None),) for _ in range(num_addressable_devices))
else:
# We special case this logic to support fully replicated values because
# the mesh is global mesh and the indices returned by `spec_to_indices` will
@ -2511,12 +2519,10 @@ def _get_input_indices(
proto = sharding._to_xla_op_sharding(aval.ndim)
if op_shardings.is_op_sharding_replicated(proto):
index = tuple(
(slice(None),) * aval.ndim
for _ in range(len(sharding.addressable_devices))) # type: ignore
(slice(None),) * aval.ndim for _ in range(num_addressable_devices)) # type: ignore
else:
index = tuple(
sharding.addressable_devices_indices_map(
aval.shape).values()) # type: ignore
sharding.addressable_devices_indices_map(aval.shape).values()) # type: ignore
input_indices.append(index)
return input_indices
@ -2683,7 +2689,7 @@ def _cached_compilation(computation, name, mesh, num_out_avals, spmd_lowering,
@dataclasses.dataclass
class UnloadedMeshExecutable:
xla_executable: Any
device_assignment: Sequence[xc.Device]
device_assignment: Union[_DeviceAssignment, Sequence[xc.Device]]
backend: xb.XlaBackend
input_avals: Sequence[ShapedArray]
input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
@ -2700,7 +2706,8 @@ class UnloadedMeshExecutable:
auto_spmd_lowering: bool
def build_unsafe_call(self):
input_indices = _get_input_indices(self.input_avals, self.input_shardings)
input_indices = _get_input_indices(self.input_avals, self.input_shardings,
self.device_assignment)
handle_args = InputsHandler(self.xla_executable.local_devices(),
self.input_shardings, input_indices)
handle_outs = global_avals_to_results_handler(
@ -2718,7 +2725,7 @@ class UnloadedMeshExecutable:
self.input_avals,
self.input_shardings, self.output_shardings,
self.auto_spmd_lowering, self.kept_var_idx,
self.device_assignment, self)
self)
# May return a MeshExecutable in the compile_replicated case.
@staticmethod
@ -2756,6 +2763,7 @@ class UnloadedMeshExecutable:
compiler_options.values()) if compiler_options is not None else None
da = device_assignment if isinstance(
device_assignment, _DeviceAssignment) else tuple(device_assignment)
del device_assignment
xla_executable, compile_options = _cached_compilation(
computation, name, mesh, len(global_out_avals), spmd_lowering,
tuple_args, auto_spmd_lowering, allow_propagation_to_outputs,
@ -2772,10 +2780,6 @@ class UnloadedMeshExecutable:
tuple(ordered_effects), tuple(kept_var_idx), backend, da, committed,
pmap_nreps)
del da
device_assignment = device_assignment.device_assignment if isinstance(
device_assignment, _DeviceAssignment) else device_assignment
if auto_spmd_lowering:
assert mesh is not None
in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
@ -2790,8 +2794,10 @@ class UnloadedMeshExecutable:
elif (out_shardings and any(is_unspecified(o) for o in out_shardings)
and pmap_nreps == 1):
assert mesh is None
device_assignment = da.device_assignment if isinstance( # type: ignore
da, _DeviceAssignment) else da
_, out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
xla_executable, device_assignment,
xla_executable, device_assignment, # type: ignore
len(global_in_avals), len(global_out_avals))
orig_out_shardings = out_shardings
out_shardings, are_out_shardings_from_xla = [], [] # type: ignore
@ -2813,26 +2819,15 @@ class UnloadedMeshExecutable:
are_out_shardings_from_xla = (False,) * len(global_out_avals)
if pmap_nreps > 1:
local_devices = xla_executable.local_devices()
# Create replicated shardings for jit(pmap) path with local devices
# because multihost jit(pmap) is not allowed.
in_shardings = [
sharding_impls.GSPMDSharding.get_replicated(local_devices)
] * len(in_shardings)
out_shardings = [
sharding_impls.GSPMDSharding.get_replicated(local_devices)
] * len(out_shardings)
# jit(pmap) will generate Arrays with multi-device sharding.
# It is unsupported for these shardings to be uncommited, so force
# the outputs to be committed.
committed = True
in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
xla_executable.local_devices(), len(in_shardings), len(out_shardings))
out_shardings, are_out_shardings_from_xla = maybe_get_orig_out_sharding(
in_shardings, out_shardings, are_out_shardings_from_xla)
return UnloadedMeshExecutable(
xla_executable=xla_executable,
device_assignment=device_assignment,
device_assignment=da, # type: ignore
backend=backend,
input_avals=global_in_avals,
input_shardings=in_shardings, # type: ignore
@ -2861,17 +2856,14 @@ class MeshExecutableFastpathData(NamedTuple):
class MeshExecutable(stages.XlaExecutable):
__slots__ = [
"xla_executable", "_unsafe_call",
"build_unsafe_call", "in_avals",
"_in_shardings", "_out_shardings",
"_auto_spmd_lowering", "_kept_var_idx",
"_device_assignment",
"xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
"_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
"_unloaded_executable",
]
def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
out_shardings, auto_spmd_lowering, kept_var_idx,
device_assignment, unloaded_executable=None):
unloaded_executable=None):
self.xla_executable = xla_executable
self.build_unsafe_call = build_unsafe_call
# in_avals is a list of global and local avals. Aval is global if input
@ -2882,7 +2874,6 @@ class MeshExecutable(stages.XlaExecutable):
self._out_shardings = out_shardings
self._auto_spmd_lowering = auto_spmd_lowering
self._kept_var_idx = kept_var_idx
self._device_assignment = device_assignment
self._unloaded_executable = unloaded_executable
@property
@ -2899,11 +2890,11 @@ class MeshExecutable(stages.XlaExecutable):
if hasattr(backend, "compile_replicated"):
return _compile_replicated_mesh_executable_from_trivial_jaxpr(
jaxpr, consts, global_in_avals, global_out_avals, in_shardings,
backend, da_object.device_assignment, committed, kept_var_idx, 1)
backend, da_object, committed, kept_var_idx, 1)
out_shardings = _out_shardings_for_trivial(
jaxpr, consts, in_shardings, da_object.device_assignment)
indices = _get_input_indices(global_out_avals, out_shardings)
indices = _get_input_indices(global_out_avals, out_shardings, da_object)
local_device_assignment = da_object.addressable_device_assignment
handle_ins = InputsHandler(local_device_assignment, out_shardings, indices)
handle_outs = global_avals_to_results_handler(
@ -2913,7 +2904,7 @@ class MeshExecutable(stages.XlaExecutable):
handle_outs, kept_var_idx)
return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
in_shardings, out_shardings, False, kept_var_idx,
da_object.device_assignment, None)
None)
# -- stages.XlaExecutable overrides
@ -2975,6 +2966,22 @@ def check_arg_avals_for_call(ref_avals, arg_avals):
f"called with:\n {arg_aval}")
def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
# Create replicated shardings for jit(pmap) path with local devices
# because multihost jit(pmap) is not allowed.
in_shardings = [
sharding_impls.GSPMDSharding.get_replicated(local_devices)
] * num_in_shardings
out_shardings = [
sharding_impls.GSPMDSharding.get_replicated(local_devices)
] * num_out_shardings
# jit(pmap) will generate Arrays with multi-device sharding.
# It is unsupported for these shardings to be uncommited, so force
# the outputs to be committed.
committed = True
return in_shardings, out_shardings, committed, tuple(local_devices)
def _out_shardings_for_trivial(
jaxpr: core.Jaxpr, consts: Sequence[Any],
in_shardings: Sequence[sharding_impls.XLACompatibleSharding],
@ -3026,11 +3033,8 @@ def _compile_replicated_mesh_executable_from_hlo(
in_shardings = semantics_in_shardings.shardings
out_shardings = semantics_out_shardings.shardings
device_assignment = da.device_assignment if isinstance(
da, _DeviceAssignment) else da
input_indices = _get_input_indices(
global_in_avals, in_shardings) # type: ignore
input_indices = _get_input_indices(global_in_avals, in_shardings, da) # type: ignore
if pmap_nreps > 1:
# For a jit wrapping a pmap, replicate each input index to match the
# devices of the replicated jit computation.
@ -3049,30 +3053,29 @@ def _compile_replicated_mesh_executable_from_hlo(
xla_executable = None
return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
in_shardings, out_shardings, auto_spmd_lowering,
kept_var_idx, device_assignment, None)
kept_var_idx, None)
def _compile_replicated_mesh_executable_from_trivial_jaxpr(
jaxpr, consts, global_in_avals, global_out_avals, in_shardings, backend,
device_assignment, committed, kept_var_idx, pmap_nreps):
da_object, committed, kept_var_idx, pmap_nreps):
out_shardings = _out_shardings_for_trivial(
jaxpr, consts, in_shardings, device_assignment)
jaxpr, consts, in_shardings, da_object.device_assignment)
input_indices = _get_input_indices(
global_in_avals, in_shardings) # type: ignore
input_indices = _get_input_indices(global_in_avals, in_shardings, da_object) # type: ignore
handle_outs = global_avals_to_results_handler(
global_out_avals, out_shardings, committed,
[False] * len(global_out_avals))
# Use the standard out_handler.
unsafe_call = backend.compile_replicated(
is_trivial=True, jaxpr=jaxpr, consts=consts,
device_assignment=device_assignment, in_avals=global_in_avals,
device_assignment=da_object.device_assignment, in_avals=global_in_avals,
in_indices=input_indices, in_shardings=in_shardings,
kept_var_idx=kept_var_idx, out_handler=handle_outs,
out_shardings=out_shardings, pmap_nreps=pmap_nreps)
return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
in_shardings, out_shardings, False, kept_var_idx,
device_assignment, None)
None)
@lru_cache()

View File

@ -3534,7 +3534,8 @@ class UtilTest(jtu.JaxTestCase):
mp = NamedSharding(global_mesh, P(None))
out_indices = pxla._get_input_indices(in_avals, [mp, mp, mp])
out_indices = pxla._get_input_indices(in_avals, [mp, mp, mp],
list(global_mesh.devices.flat))
self.assertLen(out_indices, len(in_avals))
self.assertTrue(all(len(out) == len(global_mesh.local_devices)