Do the `sharding.addressable_devices` check only once in `_get_input_indices`, since all shardings should have the same device assignment; that check already happens at the start of `lower_sharding_computation`. Also use the optimized `_DeviceAssignment` object, which caches all of its derived calculations, in case this path is hit multiple times. Finally, remove `device_assignment` from `MeshExecutable`, since it is not used anywhere in that class.

PiperOrigin-RevId: 523182028
parent f25b701b26
commit 49438c78e4
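The core optimization is the cached device-assignment object. Below is a minimal sketch of the idea, not JAX's actual `_DeviceAssignment` implementation: derived quantities such as the addressable subset are computed lazily and memoized, so repeated lowerings that share one device assignment pay the filtering cost once. `SimpleDevice` and the fixed process index are hypothetical stand-ins for `xc.Device` and `d.client.process_index()`.

```python
import dataclasses
import functools
from typing import Sequence, Tuple

@dataclasses.dataclass(frozen=True)
class SimpleDevice:
    """Hypothetical stand-in for xc.Device."""
    id: int
    process_index: int

class CachedDeviceAssignment:
    """Sketch of a device assignment with memoized derived data."""
    def __init__(self, devices: Sequence[SimpleDevice], process_index: int = 0):
        self._devices = tuple(devices)
        self._process_index = process_index

    @property
    def device_assignment(self) -> Tuple[SimpleDevice, ...]:
        return self._devices

    @functools.cached_property
    def addressable_device_assignment(self) -> Tuple[SimpleDevice, ...]:
        # Computed on first access, then cached on the instance, so hot
        # compilation paths that hit this repeatedly filter only once.
        return tuple(d for d in self._devices
                     if d.process_index == self._process_index)

da = CachedDeviceAssignment([SimpleDevice(i, i % 2) for i in range(8)])
assert len(da.addressable_device_assignment) == 4  # filtered once, then cached
```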
```diff
@@ -2495,14 +2495,22 @@ class MeshComputation(stages.XlaLowering):
 def _get_input_indices(
-    avals: Sequence[ShapedArray], shardings: Sequence[sharding_impls.XLACompatibleSharding]
+    avals: Sequence[ShapedArray],
+    shardings: Sequence[sharding_impls.XLACompatibleSharding],
+    da_object: Union[_DeviceAssignment, Sequence[xc.Device]],
 ) -> Sequence[Tuple[Optional[Index], ...]]:
 
   input_indices = []
+  if isinstance(da_object, _DeviceAssignment):
+    num_addressable_devices = len(da_object.addressable_device_assignment)
+  else:
+    num_addressable_devices = len(
+        [d for d in da_object if d.process_index == d.client.process_index()])
+
   for aval, sharding in zip(avals, shardings):
     if aval is core.abstract_token:
       index = tuple(
-          (slice(None),) for _ in range(len(sharding.addressable_devices)))
+          (slice(None),) for _ in range(num_addressable_devices))
     else:
       # We special case this logic to support fully replicated values because
       # the mesh is global mesh and the indices returned by `spec_to_indices` will
@@ -2511,12 +2519,10 @@ def _get_input_indices(
       proto = sharding._to_xla_op_sharding(aval.ndim)
       if op_shardings.is_op_sharding_replicated(proto):
         index = tuple(
-            (slice(None),) * aval.ndim
-            for _ in range(len(sharding.addressable_devices)))  # type: ignore
+            (slice(None),) * aval.ndim for _ in range(num_addressable_devices))  # type: ignore
       else:
         index = tuple(
-            sharding.addressable_devices_indices_map(
-                aval.shape).values())  # type: ignore
+            sharding.addressable_devices_indices_map(aval.shape).values())  # type: ignore
     input_indices.append(index)
 
   return input_indices
```
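The hoisted `num_addressable_devices` relies on the invariant stated in the commit message: every sharding in a lowered computation shares one device assignment, so `len(sharding.addressable_devices)` is the same for each sharding. A toy illustration of that equivalence, using hypothetical stand-ins for devices and shardings:

```python
import dataclasses
from typing import Sequence, Tuple

@dataclasses.dataclass(frozen=True)
class FakeDevice:  # hypothetical stand-in for xc.Device
    id: int
    process_index: int

LOCAL_PROCESS = 0  # stand-in for d.client.process_index()
devices: Tuple[FakeDevice, ...] = tuple(FakeDevice(i, i % 2) for i in range(8))

class FakeSharding:  # hypothetical: every sharding sees the same devices
    def __init__(self, devs: Sequence[FakeDevice]):
        self.addressable_devices = frozenset(
            d for d in devs if d.process_index == LOCAL_PROCESS)

shardings = [FakeSharding(devices) for _ in range(3)]

# Hoisted once, as in the new _get_input_indices:
num_addressable = len([d for d in devices if d.process_index == LOCAL_PROCESS])

# Equivalent to the per-sharding computation the old code repeated:
assert all(len(s.addressable_devices) == num_addressable for s in shardings)
```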
```diff
@@ -2683,7 +2689,7 @@ def _cached_compilation(computation, name, mesh, num_out_avals, spmd_lowering,
 @dataclasses.dataclass
 class UnloadedMeshExecutable:
   xla_executable: Any
-  device_assignment: Sequence[xc.Device]
+  device_assignment: Union[_DeviceAssignment, Sequence[xc.Device]]
   backend: xb.XlaBackend
   input_avals: Sequence[ShapedArray]
   input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
@@ -2700,7 +2706,8 @@ class UnloadedMeshExecutable:
   auto_spmd_lowering: bool
 
   def build_unsafe_call(self):
-    input_indices = _get_input_indices(self.input_avals, self.input_shardings)
+    input_indices = _get_input_indices(self.input_avals, self.input_shardings,
+                                       self.device_assignment)
     handle_args = InputsHandler(self.xla_executable.local_devices(),
                                 self.input_shardings, input_indices)
     handle_outs = global_avals_to_results_handler(
@@ -2718,7 +2725,7 @@ class UnloadedMeshExecutable:
         self.input_avals,
         self.input_shardings, self.output_shardings,
         self.auto_spmd_lowering, self.kept_var_idx,
-        self.device_assignment, self)
+        self)
 
   # May return a MeshExecutable in the compile_replicated case.
   @staticmethod
@@ -2756,6 +2763,7 @@ class UnloadedMeshExecutable:
         compiler_options.values()) if compiler_options is not None else None
 
+    da = device_assignment if isinstance(
+        device_assignment, _DeviceAssignment) else tuple(device_assignment)
+    del device_assignment
     xla_executable, compile_options = _cached_compilation(
         computation, name, mesh, len(global_out_avals), spmd_lowering,
         tuple_args, auto_spmd_lowering, allow_propagation_to_outputs,
@@ -2772,10 +2780,6 @@ class UnloadedMeshExecutable:
         tuple(ordered_effects), tuple(kept_var_idx), backend, da, committed,
         pmap_nreps)
 
-    del da
-    device_assignment = device_assignment.device_assignment if isinstance(
-        device_assignment, _DeviceAssignment) else device_assignment
-
     if auto_spmd_lowering:
       assert mesh is not None
       in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
@@ -2790,8 +2794,10 @@ class UnloadedMeshExecutable:
     elif (out_shardings and any(is_unspecified(o) for o in out_shardings)
           and pmap_nreps == 1):
       assert mesh is None
+      device_assignment = da.device_assignment if isinstance(  # type: ignore
+          da, _DeviceAssignment) else da
       _, out_shardings_xla = get_gspmd_shardings_from_executable(  # type: ignore
-          xla_executable, device_assignment,
+          xla_executable, device_assignment,  # type: ignore
           len(global_in_avals), len(global_out_avals))
       orig_out_shardings = out_shardings
       out_shardings, are_out_shardings_from_xla = [], []  # type: ignore
```
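`UnloadedMeshExecutable.device_assignment` now holds either the cached `_DeviceAssignment` wrapper or a raw device sequence, so consumers normalize with an `isinstance` check, as in the hunk above. A self-contained sketch of that pattern, with a hypothetical stand-in class:

```python
from typing import Any, Sequence, Tuple, Union

class DeviceAssignmentSketch:
    """Hypothetical stand-in for _DeviceAssignment."""
    def __init__(self, devices: Sequence[Any]):
        self._devices = tuple(devices)

    @property
    def device_assignment(self) -> Tuple[Any, ...]:
        return self._devices

def as_device_tuple(
    da: Union[DeviceAssignmentSketch, Sequence[Any]]) -> Tuple[Any, ...]:
    # Mirrors the isinstance dispatch in the diff: unwrap the cached object,
    # otherwise materialize the raw sequence as a tuple.
    return (da.device_assignment if isinstance(da, DeviceAssignmentSketch)
            else tuple(da))

assert as_device_tuple(DeviceAssignmentSketch(["d0", "d1"])) == ("d0", "d1")
assert as_device_tuple(["d0", "d1"]) == ("d0", "d1")
```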
```diff
@@ -2813,26 +2819,15 @@ class UnloadedMeshExecutable:
       are_out_shardings_from_xla = (False,) * len(global_out_avals)
 
     if pmap_nreps > 1:
-      local_devices = xla_executable.local_devices()
-      # Create replicated shardings for jit(pmap) path with local devices
-      # because multihost jit(pmap) is not allowed.
-      in_shardings = [
-          sharding_impls.GSPMDSharding.get_replicated(local_devices)
-      ] * len(in_shardings)
-      out_shardings = [
-          sharding_impls.GSPMDSharding.get_replicated(local_devices)
-      ] * len(out_shardings)
-      # jit(pmap) will generate Arrays with multi-device sharding.
-      # It is unsupported for these shardings to be uncommited, so force
-      # the outputs to be committed.
-      committed = True
+      in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
+          xla_executable.local_devices(), len(in_shardings), len(out_shardings))
 
     out_shardings, are_out_shardings_from_xla = maybe_get_orig_out_sharding(
         in_shardings, out_shardings, are_out_shardings_from_xla)
 
     return UnloadedMeshExecutable(
         xla_executable=xla_executable,
-        device_assignment=device_assignment,
+        device_assignment=da,  # type: ignore
         backend=backend,
         input_avals=global_in_avals,
         input_shardings=in_shardings,  # type: ignore
@@ -2861,17 +2856,14 @@ class MeshExecutableFastpathData(NamedTuple):
 
 class MeshExecutable(stages.XlaExecutable):
   __slots__ = [
-      "xla_executable", "_unsafe_call",
-      "build_unsafe_call", "in_avals",
-      "_in_shardings", "_out_shardings",
-      "_auto_spmd_lowering", "_kept_var_idx",
-      "_device_assignment",
+      "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
+      "_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
       "_unloaded_executable",
   ]
 
   def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
                out_shardings, auto_spmd_lowering, kept_var_idx,
-               device_assignment, unloaded_executable=None):
+               unloaded_executable=None):
     self.xla_executable = xla_executable
     self.build_unsafe_call = build_unsafe_call
     # in_avals is a list of global and local avals. Aval is global if input
@@ -2882,7 +2874,6 @@ class MeshExecutable(stages.XlaExecutable):
     self._out_shardings = out_shardings
     self._auto_spmd_lowering = auto_spmd_lowering
     self._kept_var_idx = kept_var_idx
-    self._device_assignment = device_assignment
     self._unloaded_executable = unloaded_executable
 
   @property
@@ -2899,11 +2890,11 @@ class MeshExecutable(stages.XlaExecutable):
     if hasattr(backend, "compile_replicated"):
       return _compile_replicated_mesh_executable_from_trivial_jaxpr(
           jaxpr, consts, global_in_avals, global_out_avals, in_shardings,
-          backend, da_object.device_assignment, committed, kept_var_idx, 1)
+          backend, da_object, committed, kept_var_idx, 1)
 
     out_shardings = _out_shardings_for_trivial(
         jaxpr, consts, in_shardings, da_object.device_assignment)
-    indices = _get_input_indices(global_out_avals, out_shardings)
+    indices = _get_input_indices(global_out_avals, out_shardings, da_object)
     local_device_assignment = da_object.addressable_device_assignment
     handle_ins = InputsHandler(local_device_assignment, out_shardings, indices)
     handle_outs = global_avals_to_results_handler(
@@ -2913,7 +2904,7 @@ class MeshExecutable(stages.XlaExecutable):
                               handle_outs, kept_var_idx)
     return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
                           in_shardings, out_shardings, False, kept_var_idx,
-                          da_object.device_assignment, None)
+                          None)
 
   # -- stages.XlaExecutable overrides
```
```diff
@@ -2975,6 +2966,22 @@ def check_arg_avals_for_call(ref_avals, arg_avals):
         f"called with:\n {arg_aval}")
 
 
+def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
+  # Create replicated shardings for jit(pmap) path with local devices
+  # because multihost jit(pmap) is not allowed.
+  in_shardings = [
+      sharding_impls.GSPMDSharding.get_replicated(local_devices)
+  ] * num_in_shardings
+  out_shardings = [
+      sharding_impls.GSPMDSharding.get_replicated(local_devices)
+  ] * num_out_shardings
+  # jit(pmap) will generate Arrays with multi-device sharding.
+  # It is unsupported for these shardings to be uncommited, so force
+  # the outputs to be committed.
+  committed = True
+  return in_shardings, out_shardings, committed, tuple(local_devices)
+
+
 def _out_shardings_for_trivial(
     jaxpr: core.Jaxpr, consts: Sequence[Any],
     in_shardings: Sequence[sharding_impls.XLACompatibleSharding],
```
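The new `_get_metadata_jit_pmap` helper centralizes the replicated-sharding setup that the `pmap_nreps > 1` branch previously inlined. A small self-contained sketch of its contract, using a hypothetical `replicated` stand-in for `GSPMDSharding.get_replicated`:

```python
from typing import Any, Sequence, Tuple

def replicated(devices: Sequence[Any]) -> Tuple[Any, ...]:
    # Hypothetical stand-in: a "sharding" replicated across all devices.
    return tuple(devices)

def get_metadata_jit_pmap_sketch(
    local_devices: Sequence[Any], num_in: int, num_out: int):
    in_shardings = [replicated(local_devices)] * num_in
    out_shardings = [replicated(local_devices)] * num_out
    committed = True  # jit(pmap) outputs are forced to be committed
    return in_shardings, out_shardings, committed, tuple(local_devices)

ins, outs, committed, da = get_metadata_jit_pmap_sketch(["d0", "d1"], 3, 2)
assert len(ins) == 3 and len(outs) == 2 and committed and da == ("d0", "d1")
```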
```diff
@@ -3026,11 +3033,8 @@ def _compile_replicated_mesh_executable_from_hlo(
 
   in_shardings = semantics_in_shardings.shardings
   out_shardings = semantics_out_shardings.shardings
-  device_assignment = da.device_assignment if isinstance(
-      da, _DeviceAssignment) else da
-
-  input_indices = _get_input_indices(
-      global_in_avals, in_shardings)  # type: ignore
+  input_indices = _get_input_indices(global_in_avals, in_shardings, da)  # type: ignore
   if pmap_nreps > 1:
     # For a jit wrapping a pmap, replicate each input index to match the
     # devices of the replicated jit computation.
@@ -3049,30 +3053,29 @@
     xla_executable = None
   return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
                         in_shardings, out_shardings, auto_spmd_lowering,
-                        kept_var_idx, device_assignment, None)
+                        kept_var_idx, None)
 
 
 def _compile_replicated_mesh_executable_from_trivial_jaxpr(
     jaxpr, consts, global_in_avals, global_out_avals, in_shardings, backend,
-    device_assignment, committed, kept_var_idx, pmap_nreps):
+    da_object, committed, kept_var_idx, pmap_nreps):
   out_shardings = _out_shardings_for_trivial(
-      jaxpr, consts, in_shardings, device_assignment)
+      jaxpr, consts, in_shardings, da_object.device_assignment)
 
-  input_indices = _get_input_indices(
-      global_in_avals, in_shardings)  # type: ignore
+  input_indices = _get_input_indices(global_in_avals, in_shardings, da_object)  # type: ignore
   handle_outs = global_avals_to_results_handler(
       global_out_avals, out_shardings, committed,
       [False] * len(global_out_avals))
   # Use the standard out_handler.
   unsafe_call = backend.compile_replicated(
       is_trivial=True, jaxpr=jaxpr, consts=consts,
-      device_assignment=device_assignment, in_avals=global_in_avals,
+      device_assignment=da_object.device_assignment, in_avals=global_in_avals,
       in_indices=input_indices, in_shardings=in_shardings,
       kept_var_idx=kept_var_idx, out_handler=handle_outs,
       out_shardings=out_shardings, pmap_nreps=pmap_nreps)
   return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
                         in_shardings, out_shardings, False, kept_var_idx,
-                        device_assignment, None)
+                        None)
 
 
 @lru_cache()
@@ -3534,7 +3534,8 @@ class UtilTest(jtu.JaxTestCase):
 
     mp = NamedSharding(global_mesh, P(None))
 
-    out_indices = pxla._get_input_indices(in_avals, [mp, mp, mp])
+    out_indices = pxla._get_input_indices(in_avals, [mp, mp, mp],
+                                          list(global_mesh.devices.flat))
 
     self.assertLen(out_indices, len(in_avals))
     self.assertTrue(all(len(out) == len(global_mesh.local_devices)
```