diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 403ee11c5..6fe95dbbe 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1004,19 +1004,6 @@ class UnloadedPmapExecutable: shards.out_sharded_avals, pci.out_axes)] out_shardings = _get_pmap_sharding(local_device_assignment, out_specs) - if hasattr(pci.backend, "compile_replicated"): - input_indices = [ - sharding_specs.spec_to_indices(aval.shape, spec) - if spec is not None else None - for aval, spec in safe_zip(pci.avals, input_sharding_specs) - ] - handle_outs = local_avals_to_results_handler(local_unmapped_avals, - out_shardings) - return _compile_replicated_pmap_executable_from_hlo( - hlo, pci, input_indices, in_shardings, handle_outs, - compile_options, host_callbacks, bool(unordered_effects), - ordered_effects, jaxpr_debug_info) - with dispatch.log_elapsed_time( "Finished XLA compilation of {fun_name} in {elapsed_time} sec", fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT): @@ -1038,23 +1025,6 @@ class UnloadedPmapExecutable: jaxpr_debug_info=jaxpr_debug_info).load() -def _compile_replicated_pmap_executable_from_hlo( - hlo: ir.Module, pci, input_indices, in_shardings, handle_outs, - compile_options, host_callbacks, has_unordered_effects, ordered_effects, - jaxpr_debug_info): - # Use the standard out_handler. - execute_fun = pci.backend.compile_replicated( - is_trivial=False, name=pci.name, computation=hlo, - compile_options=compile_options, host_callbacks=host_callbacks, - has_unordered_effects=has_unordered_effects, - ordered_effects=ordered_effects, in_avals=pci.avals, - in_indices=input_indices, in_shardings=in_shardings, - kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs) - # TODO(frostig): need `compile_replicated` to give us the XLA executable - return PmapExecutable(None, lambda: execute_fun, None, pci.avals, - jaxpr_debug_info, None) - - class PmapExecutable(stages.XlaExecutable): __slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call", "fingerprint", "in_avals", "_jaxpr_debug_info", @@ -2109,7 +2079,7 @@ def lower_sharding_computation( any(not is_unspecified(js) for js, _ in jaxpr_sharding) or any(not is_unspecified(o) for o in out_shardings)) - if xla_extension_version < 241 or hasattr(backend, "compile_replicated"): + if xla_extension_version < 241: gs = GSPMDSharding.get_replicated(device_assignment) in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings) @@ -2720,15 +2690,12 @@ def _cached_compilation(computation, name, mesh, spmd_lowering, opts.allow_spmd_sharding_propagation_to_parameters = list(allow_prop_to_inputs) opts.allow_spmd_sharding_propagation_to_output = list(allow_prop_to_outputs) - if hasattr(backend, "compile_replicated"): - return None, compile_options - with dispatch.log_elapsed_time( "Finished XLA compilation of {fun_name} in {elapsed_time} sec", fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT): xla_executable = compiler.compile_or_get_cached( backend, computation, dev, compile_options, host_callbacks) - return xla_executable, compile_options + return xla_executable def _maybe_get_and_check_in_shardings( @@ -2858,7 +2825,6 @@ class UnloadedMeshExecutable: self.in_layouts, self.out_layouts, self.all_args_info, self) - # May return a MeshExecutable in the compile_replicated case. @staticmethod def from_hlo(name: str, hlo: ir.Module, @@ -2911,24 +2877,12 @@ class UnloadedMeshExecutable: mesh = i.mesh # type: ignore break - xla_executable, compile_options = _cached_compilation( + xla_executable = _cached_compilation( hlo, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, allow_prop_to_inputs, allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps, compiler_options_keys, compiler_options_values) - if hasattr(backend, "compile_replicated"): - semantics_in_shardings = SemanticallyEqualShardings( - in_shardings, global_in_avals) # type: ignore - semantics_out_shardings = SemanticallyEqualShardings( - out_shardings, global_out_avals) # type: ignore - return _compile_replicated_mesh_executable_from_hlo( - hlo, name, tuple(global_in_avals), tuple(global_out_avals), - semantics_in_shardings, semantics_out_shardings, auto_spmd_lowering, - compile_options, tuple(host_callbacks), bool(unordered_effects), - tuple(ordered_effects), tuple(kept_var_idx), backend, da, committed, - pmap_nreps) - if auto_spmd_lowering: assert mesh is not None in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable( @@ -3177,35 +3131,6 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings): return in_shardings, out_shardings, committed, tuple(local_devices) -@weakref_lru_cache -def _compile_replicated_mesh_executable_from_hlo( - computation, name, global_in_avals, global_out_avals, semantics_in_shardings, - semantics_out_shardings, auto_spmd_lowering, compile_options, - host_callbacks, has_unordered_effects, ordered_effects, kept_var_idx, - backend, da, committed, pmap_nreps): - assert not auto_spmd_lowering - in_shardings = semantics_in_shardings.shardings - out_shardings = semantics_out_shardings.shardings - - kept_var_idx = set(kept_var_idx) - # Will compute out_handler with executable information. - unsafe_call = backend.compile_replicated( - is_trivial=False, name=name, computation=computation, - compile_options=compile_options, host_callbacks=host_callbacks, - has_unordered_effects=has_unordered_effects, - device_assignment=da, ordered_effects=ordered_effects, - in_avals=global_in_avals, - in_shardings=in_shardings, kept_var_idx=kept_var_idx, - out_avals=global_out_avals, out_shardings=out_shardings, - committed=committed, pmap_nreps=pmap_nreps) - xla_executable = None - return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals, - global_out_avals, in_shardings, out_shardings, - auto_spmd_lowering, kept_var_idx, - (None,) * len(global_in_avals), - (None,) * len(global_out_avals)) - - @lru_cache def create_mesh_pspec_sharding( mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None, @@ -3358,12 +3283,3 @@ def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk): def maybe_extend_axis_env(*args, **kwargs): with core.extend_axis_env(*args, **kwargs): yield - - -def device_put(x, devices: Sequence[xc.ArrayImpl], - replicate: bool=False) -> list[xc.ArrayImpl]: - """Call device_put on a sequence of devices and return a flat sequence of buffers.""" - if replicate: - return [jax.device_put(x, device) for device in devices] - else: - return [jax.device_put(val, device) for val, device in safe_zip(x, devices)]