Remove deprecated code from JAX lowering and compilation

PiperOrigin-RevId: 622530123
This commit is contained in:
Yash Katariya 2024-04-06 19:42:39 -07:00 committed by jax authors
parent 3b5980fd73
commit c6804f92d0

View File

@ -1004,19 +1004,6 @@ class UnloadedPmapExecutable:
shards.out_sharded_avals, pci.out_axes)]
out_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
if hasattr(pci.backend, "compile_replicated"):
input_indices = [
sharding_specs.spec_to_indices(aval.shape, spec)
if spec is not None else None
for aval, spec in safe_zip(pci.avals, input_sharding_specs)
]
handle_outs = local_avals_to_results_handler(local_unmapped_avals,
out_shardings)
return _compile_replicated_pmap_executable_from_hlo(
hlo, pci, input_indices, in_shardings, handle_outs,
compile_options, host_callbacks, bool(unordered_effects),
ordered_effects, jaxpr_debug_info)
with dispatch.log_elapsed_time(
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT):
@ -1038,23 +1025,6 @@ class UnloadedPmapExecutable:
jaxpr_debug_info=jaxpr_debug_info).load()
def _compile_replicated_pmap_executable_from_hlo(
    hlo: ir.Module, pci, input_indices, in_shardings, handle_outs,
    compile_options, host_callbacks, has_unordered_effects, ordered_effects,
    jaxpr_debug_info):
  """Build a `PmapExecutable` via the backend's `compile_replicated` hook.

  The backend's `compile_replicated` is handed the HLO module together with
  the compile options, callbacks, effects and input metadata, and returns the
  callable that executes the computation. That callable is wrapped in a
  `PmapExecutable` whose first argument is `None` (per the TODO below, the
  backend does not hand back an XLA executable on this path).

  Args:
    hlo: The module to compile.
    pci: Object providing `backend`, `name` and `avals` for this pmap.
    input_indices: Forwarded to the backend as `in_indices`.
    in_shardings: Forwarded to the backend as `in_shardings`.
    handle_outs: Forwarded to the backend as `out_handler`.
    compile_options: Forwarded to the backend unchanged.
    host_callbacks: Forwarded to the backend unchanged.
    has_unordered_effects: Forwarded to the backend unchanged.
    ordered_effects: Forwarded to the backend unchanged.
    jaxpr_debug_info: Passed through to the `PmapExecutable` constructor.

  Returns:
    A `PmapExecutable` whose call builder returns the backend's callable.
  """
  # No input is pruned on this path: every index 0..len(pci.avals)-1 is kept.
  all_input_idx = set(range(len(pci.avals)))

  # Use the standard out_handler.
  replicated_call = pci.backend.compile_replicated(
      is_trivial=False,
      name=pci.name,
      computation=hlo,
      compile_options=compile_options,
      host_callbacks=host_callbacks,
      has_unordered_effects=has_unordered_effects,
      ordered_effects=ordered_effects,
      in_avals=pci.avals,
      in_indices=input_indices,
      in_shardings=in_shardings,
      kept_var_idx=all_input_idx,
      out_handler=handle_outs,
  )

  def build_unsafe_call():
    # The callable already exists; the builder just hands it back.
    return replicated_call

  # TODO(frostig): need `compile_replicated` to give us the XLA executable
  return PmapExecutable(None, build_unsafe_call, None, pci.avals,
                        jaxpr_debug_info, None)
class PmapExecutable(stages.XlaExecutable):
__slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call",
"fingerprint", "in_avals", "_jaxpr_debug_info",
@ -2109,7 +2079,7 @@ def lower_sharding_computation(
any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
any(not is_unspecified(o) for o in out_shardings))
if xla_extension_version < 241 or hasattr(backend, "compile_replicated"):
if xla_extension_version < 241:
gs = GSPMDSharding.get_replicated(device_assignment)
in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
@ -2720,15 +2690,12 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
opts.allow_spmd_sharding_propagation_to_parameters = list(allow_prop_to_inputs)
opts.allow_spmd_sharding_propagation_to_output = list(allow_prop_to_outputs)
if hasattr(backend, "compile_replicated"):
return None, compile_options
with dispatch.log_elapsed_time(
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
xla_executable = compiler.compile_or_get_cached(
backend, computation, dev, compile_options, host_callbacks)
return xla_executable, compile_options
return xla_executable
def _maybe_get_and_check_in_shardings(
@ -2858,7 +2825,6 @@ class UnloadedMeshExecutable:
self.in_layouts, self.out_layouts,
self.all_args_info, self)
# May return a MeshExecutable in the compile_replicated case.
@staticmethod
def from_hlo(name: str,
hlo: ir.Module,
@ -2911,24 +2877,12 @@ class UnloadedMeshExecutable:
mesh = i.mesh # type: ignore
break
xla_executable, compile_options = _cached_compilation(
xla_executable = _cached_compilation(
hlo, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps,
compiler_options_keys, compiler_options_values)
if hasattr(backend, "compile_replicated"):
semantics_in_shardings = SemanticallyEqualShardings(
in_shardings, global_in_avals) # type: ignore
semantics_out_shardings = SemanticallyEqualShardings(
out_shardings, global_out_avals) # type: ignore
return _compile_replicated_mesh_executable_from_hlo(
hlo, name, tuple(global_in_avals), tuple(global_out_avals),
semantics_in_shardings, semantics_out_shardings, auto_spmd_lowering,
compile_options, tuple(host_callbacks), bool(unordered_effects),
tuple(ordered_effects), tuple(kept_var_idx), backend, da, committed,
pmap_nreps)
if auto_spmd_lowering:
assert mesh is not None
in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
@ -3177,35 +3131,6 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
return in_shardings, out_shardings, committed, tuple(local_devices)
@weakref_lru_cache
def _compile_replicated_mesh_executable_from_hlo(
    computation, name, global_in_avals, global_out_avals, semantics_in_shardings,
    semantics_out_shardings, auto_spmd_lowering, compile_options,
    host_callbacks, has_unordered_effects, ordered_effects, kept_var_idx,
    backend, da, committed, pmap_nreps):
  """Build a `MeshExecutable` via the backend's `compile_replicated` hook.

  The sharding wrappers are unpacked through their `.shardings` attribute,
  `kept_var_idx` is converted to a set, and everything is forwarded to
  `backend.compile_replicated`, which returns the callable used to run the
  computation. No XLA executable is available on this path, so `None` is
  passed in its place.

  Args:
    computation: The computation handed to the backend for compilation.
    name: Name forwarded to the backend.
    global_in_avals: Input avals; also sizes the per-input `None` tuple.
    global_out_avals: Output avals; also sizes the per-output `None` tuple.
    semantics_in_shardings: Wrapper exposing the input shardings as
      `.shardings`.
    semantics_out_shardings: Wrapper exposing the output shardings as
      `.shardings`.
    auto_spmd_lowering: Must be falsy; asserted on entry.
    compile_options: Forwarded to the backend unchanged.
    host_callbacks: Forwarded to the backend unchanged.
    has_unordered_effects: Forwarded to the backend unchanged.
    ordered_effects: Forwarded to the backend unchanged.
    kept_var_idx: Iterable of kept input indices; converted to a set.
    backend: Object providing `compile_replicated`.
    da: Forwarded to the backend as `device_assignment`.
    committed: Forwarded to the backend unchanged.
    pmap_nreps: Forwarded to the backend unchanged.

  Returns:
    A `MeshExecutable` whose call builder returns the backend's callable.
  """
  assert not auto_spmd_lowering

  input_shardings = semantics_in_shardings.shardings
  output_shardings = semantics_out_shardings.shardings
  kept_indices = set(kept_var_idx)

  # Will compute out_handler with executable information.
  replicated_call = backend.compile_replicated(
      is_trivial=False,
      name=name,
      computation=computation,
      compile_options=compile_options,
      host_callbacks=host_callbacks,
      has_unordered_effects=has_unordered_effects,
      device_assignment=da,
      ordered_effects=ordered_effects,
      in_avals=global_in_avals,
      in_shardings=input_shardings,
      kept_var_idx=kept_indices,
      out_avals=global_out_avals,
      out_shardings=output_shardings,
      committed=committed,
      pmap_nreps=pmap_nreps,
  )

  def build_unsafe_call():
    # The callable already exists; the builder just hands it back.
    return replicated_call

  # One `None` placeholder per input and per output for the two trailing
  # positional arguments of `MeshExecutable`.
  per_input_none = (None,) * len(global_in_avals)
  per_output_none = (None,) * len(global_out_avals)

  return MeshExecutable(None, build_unsafe_call, global_in_avals,
                        global_out_avals, input_shardings, output_shardings,
                        auto_spmd_lowering, kept_indices,
                        per_input_none, per_output_none)
@lru_cache
def create_mesh_pspec_sharding(
mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None,
@ -3358,12 +3283,3 @@ def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk):
def maybe_extend_axis_env(*args, **kwargs):
  """Yield once while `core.extend_axis_env(*args, **kwargs)` is active.

  All positional and keyword arguments are forwarded unchanged to
  `core.extend_axis_env`.

  NOTE(review): presumably wrapped by a context-manager decorator above this
  definition; the decorator is not visible here - confirm.
  """
  axis_env_scope = core.extend_axis_env(*args, **kwargs)
  with axis_env_scope:
    yield
def device_put(x, devices: Sequence[xc.Device],
               replicate: bool = False) -> list[xc.ArrayImpl]:
  """Call device_put on a sequence of devices and return a flat list of buffers.

  Args:
    x: With `replicate=True`, a single value that is placed on every device.
      Otherwise an iterable of values, paired element-wise with `devices`.
    devices: The target devices; each one is passed as the device argument of
      `jax.device_put`.
    replicate: If true, put the same `x` on each device. If false, put the
      i-th element of `x` on the i-th device.

  Returns:
    A list with one `jax.device_put` result per produced (value, device) pair.
  """
  if replicate:
    return [jax.device_put(x, device) for device in devices]
  # `safe_zip` pairs values with devices; presumably it rejects mismatched
  # lengths - verify against its definition.
  return [jax.device_put(val, device) for val, device in safe_zip(x, devices)]