diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 8d47a748d..3fb0366a7 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2495,14 +2495,22 @@ class MeshComputation(stages.XlaLowering):
 
 
 def _get_input_indices(
-    avals: Sequence[ShapedArray], shardings: Sequence[sharding_impls.XLACompatibleSharding]
+    avals: Sequence[ShapedArray],
+    shardings: Sequence[sharding_impls.XLACompatibleSharding],
+    da_object: Union[_DeviceAssignment, Sequence[xc.Device]],
 ) -> Sequence[Tuple[Optional[Index], ...]]:
 
   input_indices = []
+  if isinstance(da_object, _DeviceAssignment):
+    num_addressable_devices = len(da_object.addressable_device_assignment)
+  else:
+    num_addressable_devices = len(
+        [d for d in da_object if d.process_index == d.client.process_index()])
+
   for aval, sharding in zip(avals, shardings):
     if aval is core.abstract_token:
       index = tuple(
-          (slice(None),) for _ in range(len(sharding.addressable_devices)))
+          (slice(None),) for _ in range(num_addressable_devices))
     else:
       # We special case this logic to support fully replicated values because
       # the mesh is global mesh and the indices returned by `spec_to_indices` will
@@ -2511,12 +2519,10 @@ def _get_input_indices(
       proto = sharding._to_xla_op_sharding(aval.ndim)
       if op_shardings.is_op_sharding_replicated(proto):
         index = tuple(
-            (slice(None),) * aval.ndim
-            for _ in range(len(sharding.addressable_devices)))  # type: ignore
+            (slice(None),) * aval.ndim for _ in range(num_addressable_devices))  # type: ignore
       else:
         index = tuple(
-            sharding.addressable_devices_indices_map(
-                aval.shape).values())  # type: ignore
+            sharding.addressable_devices_indices_map(aval.shape).values())  # type: ignore
     input_indices.append(index)
 
   return input_indices
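The key behavioral change in the two hunks above: `_get_input_indices` now derives the number of addressable devices from the new `da_object` argument instead of asking each sharding for `addressable_devices`. A minimal sketch of just that counting logic, using hypothetical `FakeClient`/`FakeDevice`/`FakeDeviceAssignment` stand-ins (not JAX types) that model only the attributes the new branch reads:

# Sketch only: the Fake* classes below are hypothetical stand-ins for
# xc.Device / _DeviceAssignment, modeling just what the new branch touches.
from typing import Sequence


class FakeClient:
  def __init__(self, process_index: int):
    self._process_index = process_index

  def process_index(self) -> int:  # mirrors xc.Client.process_index()
    return self._process_index


class FakeDevice:
  def __init__(self, process_index: int, client: FakeClient):
    self.process_index = process_index
    self.client = client


class FakeDeviceAssignment:
  """Stand-in for _DeviceAssignment: pre-computes the addressable subset."""

  def __init__(self, devices: Sequence[FakeDevice]):
    self.addressable_device_assignment = [
        d for d in devices if d.process_index == d.client.process_index()
    ]


def count_addressable(da_object) -> int:
  # Same shape as the new branch in _get_input_indices: a _DeviceAssignment
  # already knows its addressable subset; a raw device list is filtered.
  if isinstance(da_object, FakeDeviceAssignment):
    return len(da_object.addressable_device_assignment)
  return len(
      [d for d in da_object if d.process_index == d.client.process_index()])


client = FakeClient(process_index=0)
devices = [FakeDevice(p, client) for p in (0, 0, 1, 1)]  # 2 of 4 addressable
assert count_addressable(devices) == 2
assert count_addressable(FakeDeviceAssignment(devices)) == 2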
@@ -2683,7 +2689,7 @@ def _cached_compilation(computation, name, mesh, num_out_avals, spmd_lowering,
 @dataclasses.dataclass
 class UnloadedMeshExecutable:
   xla_executable: Any
-  device_assignment: Sequence[xc.Device]
+  device_assignment: Union[_DeviceAssignment, Sequence[xc.Device]]
   backend: xb.XlaBackend
   input_avals: Sequence[ShapedArray]
   input_shardings: Sequence[sharding_impls.XLACompatibleSharding]
@@ -2700,7 +2706,8 @@ class UnloadedMeshExecutable:
   auto_spmd_lowering: bool
 
   def build_unsafe_call(self):
-    input_indices = _get_input_indices(self.input_avals, self.input_shardings)
+    input_indices = _get_input_indices(self.input_avals, self.input_shardings,
+                                       self.device_assignment)
     handle_args = InputsHandler(self.xla_executable.local_devices(),
                                 self.input_shardings, input_indices)
     handle_outs = global_avals_to_results_handler(
@@ -2718,7 +2725,7 @@ class UnloadedMeshExecutable:
                           self.input_avals,
                           self.input_shardings, self.output_shardings,
                           self.auto_spmd_lowering, self.kept_var_idx,
-                          self.device_assignment, self)
+                          self)
 
   # May return a MeshExecutable in the compile_replicated case.
   @staticmethod
@@ -2756,6 +2763,7 @@ class UnloadedMeshExecutable:
         compiler_options.values()) if compiler_options is not None else None
     da = device_assignment if isinstance(
         device_assignment, _DeviceAssignment) else tuple(device_assignment)
+    del device_assignment
     xla_executable, compile_options = _cached_compilation(
         computation, name, mesh, len(global_out_avals), spmd_lowering,
         tuple_args, auto_spmd_lowering, allow_propagation_to_outputs,
@@ -2772,10 +2780,6 @@ class UnloadedMeshExecutable:
         tuple(ordered_effects), tuple(kept_var_idx), backend, da,
         committed, pmap_nreps)
 
-    del da
-    device_assignment = device_assignment.device_assignment if isinstance(
-        device_assignment, _DeviceAssignment) else device_assignment
-
     if auto_spmd_lowering:
       assert mesh is not None
       in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
@@ -2790,8 +2794,10 @@ class UnloadedMeshExecutable:
     elif (out_shardings and any(is_unspecified(o) for o in out_shardings)
           and pmap_nreps == 1):
       assert mesh is None
+      device_assignment = da.device_assignment if isinstance(  # type: ignore
+          da, _DeviceAssignment) else da
       _, out_shardings_xla = get_gspmd_shardings_from_executable(  # type: ignore
-          xla_executable, device_assignment,
+          xla_executable, device_assignment,  # type: ignore
           len(global_in_avals), len(global_out_avals))
       orig_out_shardings = out_shardings
       out_shardings, are_out_shardings_from_xla = [], []  # type: ignore
@@ -2813,26 +2819,15 @@ class UnloadedMeshExecutable:
       are_out_shardings_from_xla = (False,) * len(global_out_avals)
 
     if pmap_nreps > 1:
-      local_devices = xla_executable.local_devices()
-      # Create replicated shardings for jit(pmap) path with local devices
-      # because multihost jit(pmap) is not allowed.
-      in_shardings = [
-          sharding_impls.GSPMDSharding.get_replicated(local_devices)
-      ] * len(in_shardings)
-      out_shardings = [
-          sharding_impls.GSPMDSharding.get_replicated(local_devices)
-      ] * len(out_shardings)
-      # jit(pmap) will generate Arrays with multi-device sharding.
-      # It is unsupported for these shardings to be uncommited, so force
-      # the outputs to be committed.
-      committed = True
+      in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
+          xla_executable.local_devices(), len(in_shardings), len(out_shardings))
 
     out_shardings, are_out_shardings_from_xla = maybe_get_orig_out_sharding(
         in_shardings, out_shardings, are_out_shardings_from_xla)
 
     return UnloadedMeshExecutable(
         xla_executable=xla_executable,
-        device_assignment=device_assignment,
+        device_assignment=da,  # type: ignore
         backend=backend,
         input_avals=global_in_avals,
         input_shardings=in_shardings,  # type: ignore
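Note the pattern in the 2756/2772 hunks: `device_assignment` is normalized into `da` exactly once, and `del device_assignment` then removes the original binding so the rest of the long `from_hlo` body cannot accidentally read the unnormalized name (the old code instead deleted `da` and re-derived `device_assignment` later). A tiny sketch of the same normalize-then-delete pattern; the names here are hypothetical, not from the patch:

def normalize_devices(raw_devices):
  # Normalize once, then delete the raw name; any later use of
  # `raw_devices` below this point fails loudly with NameError
  # instead of silently using the unnormalized value.
  devices = tuple(raw_devices)
  del raw_devices
  return devices

assert normalize_devices([0, 1, 2]) == (0, 1, 2)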
@@ -2861,17 +2856,14 @@ class MeshExecutableFastpathData(NamedTuple):
 
 class MeshExecutable(stages.XlaExecutable):
   __slots__ = [
-      "xla_executable", "_unsafe_call",
-      "build_unsafe_call", "in_avals",
-      "_in_shardings", "_out_shardings",
-      "_auto_spmd_lowering", "_kept_var_idx",
-      "_device_assignment",
+      "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
+      "_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
       "_unloaded_executable",
   ]
 
   def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
                out_shardings, auto_spmd_lowering, kept_var_idx,
-               device_assignment, unloaded_executable=None):
+               unloaded_executable=None):
     self.xla_executable = xla_executable
     self.build_unsafe_call = build_unsafe_call
     # in_avals is a list of global and local avals. Aval is global if input
@@ -2882,7 +2874,6 @@ class MeshExecutable(stages.XlaExecutable):
     self._out_shardings = out_shardings
     self._auto_spmd_lowering = auto_spmd_lowering
     self._kept_var_idx = kept_var_idx
-    self._device_assignment = device_assignment
     self._unloaded_executable = unloaded_executable
 
   @property
@@ -2899,11 +2890,11 @@ class MeshExecutable(stages.XlaExecutable):
     if hasattr(backend, "compile_replicated"):
       return _compile_replicated_mesh_executable_from_trivial_jaxpr(
           jaxpr, consts, global_in_avals, global_out_avals, in_shardings,
-          backend, da_object.device_assignment, committed, kept_var_idx, 1)
+          backend, da_object, committed, kept_var_idx, 1)
 
     out_shardings = _out_shardings_for_trivial(
         jaxpr, consts, in_shardings, da_object.device_assignment)
-    indices = _get_input_indices(global_out_avals, out_shardings)
+    indices = _get_input_indices(global_out_avals, out_shardings, da_object)
     local_device_assignment = da_object.addressable_device_assignment
     handle_ins = InputsHandler(local_device_assignment, out_shardings, indices)
     handle_outs = global_avals_to_results_handler(
@@ -2913,7 +2904,7 @@ class MeshExecutable(stages.XlaExecutable):
                                       handle_outs, kept_var_idx)
     return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
                           in_shardings, out_shardings, False, kept_var_idx,
-                          da_object.device_assignment, None)
+                          None)
 
   # -- stages.XlaExecutable overrides
 
@@ -2975,6 +2966,22 @@ def check_arg_avals_for_call(ref_avals, arg_avals):
         f"called with:\n {arg_aval}")
 
 
+def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
+  # Create replicated shardings for jit(pmap) path with local devices
+  # because multihost jit(pmap) is not allowed.
+  in_shardings = [
+      sharding_impls.GSPMDSharding.get_replicated(local_devices)
+  ] * num_in_shardings
+  out_shardings = [
+      sharding_impls.GSPMDSharding.get_replicated(local_devices)
+  ] * num_out_shardings
+  # jit(pmap) will generate Arrays with multi-device sharding.
+  # It is unsupported for these shardings to be uncommited, so force
+  # the outputs to be committed.
+  committed = True
+  return in_shardings, out_shardings, committed, tuple(local_devices)
+
+
 def _out_shardings_for_trivial(
     jaxpr: core.Jaxpr, consts: Sequence[Any],
     in_shardings: Sequence[sharding_impls.XLACompatibleSharding],
@@ -3026,11 +3033,8 @@ def _compile_replicated_mesh_executable_from_hlo(
   in_shardings = semantics_in_shardings.shardings
   out_shardings = semantics_out_shardings.shardings
 
-  device_assignment = da.device_assignment if isinstance(
-      da, _DeviceAssignment) else da
-
-  input_indices = _get_input_indices(
-      global_in_avals, in_shardings)  # type: ignore
+  input_indices = _get_input_indices(global_in_avals, in_shardings, da)  # type: ignore
+
   if pmap_nreps > 1:
     # For a jit wrapping a pmap, replicate each input index to match the
     # devices of the replicated jit computation.
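The new `_get_metadata_jit_pmap` helper (the 2975 hunk above) is a straight extraction of the jit(pmap) branch deleted in the 2813 hunk: every input and output gets a sharding replicated across the local devices, and outputs are forced to be committed. A toy sketch of the same shape, where `fake_replicated` is a hypothetical stand-in for `sharding_impls.GSPMDSharding.get_replicated`:

# Sketch only: fake_replicated stands in for GSPMDSharding.get_replicated.
def fake_replicated(devices):
  return ("replicated", tuple(devices))

def get_metadata_jit_pmap(local_devices, num_in, num_out):
  # One fully replicated sharding per input and per output.
  in_shardings = [fake_replicated(local_devices)] * num_in
  out_shardings = [fake_replicated(local_devices)] * num_out
  committed = True  # multi-device pmap outputs must be committed
  return in_shardings, out_shardings, committed, tuple(local_devices)

ins, outs, committed, da = get_metadata_jit_pmap(["d0", "d1"], 2, 1)
assert len(ins) == 2 and len(outs) == 1 and committed and da == ("d0", "d1")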
@@ -3049,30 +3053,29 @@ def _compile_replicated_mesh_executable_from_hlo(
     xla_executable = None
   return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
                         in_shardings, out_shardings, auto_spmd_lowering,
-                        kept_var_idx, device_assignment, None)
+                        kept_var_idx, None)
 
 
 def _compile_replicated_mesh_executable_from_trivial_jaxpr(
     jaxpr, consts, global_in_avals, global_out_avals, in_shardings, backend,
-    device_assignment, committed, kept_var_idx, pmap_nreps):
+    da_object, committed, kept_var_idx, pmap_nreps):
   out_shardings = _out_shardings_for_trivial(
-      jaxpr, consts, in_shardings, device_assignment)
+      jaxpr, consts, in_shardings, da_object.device_assignment)
 
-  input_indices = _get_input_indices(
-      global_in_avals, in_shardings)  # type: ignore
+  input_indices = _get_input_indices(global_in_avals, in_shardings, da_object)  # type: ignore
   handle_outs = global_avals_to_results_handler(
       global_out_avals, out_shardings, committed,
       [False] * len(global_out_avals))
   # Use the standard out_handler.
   unsafe_call = backend.compile_replicated(
       is_trivial=True, jaxpr=jaxpr, consts=consts,
-      device_assignment=device_assignment, in_avals=global_in_avals,
+      device_assignment=da_object.device_assignment, in_avals=global_in_avals,
       in_indices=input_indices, in_shardings=in_shardings,
       kept_var_idx=kept_var_idx, out_handler=handle_outs,
       out_shardings=out_shardings, pmap_nreps=pmap_nreps)
   return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
                         in_shardings, out_shardings, False, kept_var_idx,
-                        device_assignment, None)
+                        None)
 
 
 @lru_cache()
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 724c8ba94..29d015d63 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -3534,7 +3534,8 @@ class UtilTest(jtu.JaxTestCase):
 
     mp = NamedSharding(global_mesh, P(None))
 
-    out_indices = pxla._get_input_indices(in_avals, [mp, mp, mp])
+    out_indices = pxla._get_input_indices(in_avals, [mp, mp, mp],
+                                          list(global_mesh.devices.flat))
 
     self.assertLen(out_indices, len(in_avals))
     self.assertTrue(all(len(out) == len(global_mesh.local_devices)
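The pjit_test.py hunk shows the new calling convention for `_get_input_indices`: an explicit device sequence (or a `_DeviceAssignment`) is now required as the third argument. A hedged sketch of a call site against the JAX revision this patch targets; `_get_input_indices` is a private API, so this is illustrative only and the mesh shape here is an assumption:

import numpy as np

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax._src import core
from jax._src.interpreters import pxla

# Build a 1-D mesh from whatever devices are available.
mesh = Mesh(np.array(jax.devices()), ("x",))
aval = core.ShapedArray((8, 2), np.float32)
mp = NamedSharding(mesh, P(None))  # fully replicated

# The third argument is new: a concrete device sequence.
indices = pxla._get_input_indices([aval], [mp], list(mesh.devices.flat))
assert len(indices) == 1
# One index entry per addressable device, as the test above asserts.
assert len(indices[0]) == len(mesh.local_devices)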