mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Remove deprecated code from JAX lowering and compilation
PiperOrigin-RevId: 622530123
This commit is contained in:
parent
3b5980fd73
commit
c6804f92d0
@ -1004,19 +1004,6 @@ class UnloadedPmapExecutable:
|
||||
shards.out_sharded_avals, pci.out_axes)]
|
||||
out_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
|
||||
|
||||
if hasattr(pci.backend, "compile_replicated"):
|
||||
input_indices = [
|
||||
sharding_specs.spec_to_indices(aval.shape, spec)
|
||||
if spec is not None else None
|
||||
for aval, spec in safe_zip(pci.avals, input_sharding_specs)
|
||||
]
|
||||
handle_outs = local_avals_to_results_handler(local_unmapped_avals,
|
||||
out_shardings)
|
||||
return _compile_replicated_pmap_executable_from_hlo(
|
||||
hlo, pci, input_indices, in_shardings, handle_outs,
|
||||
compile_options, host_callbacks, bool(unordered_effects),
|
||||
ordered_effects, jaxpr_debug_info)
|
||||
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
|
||||
fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT):
|
||||
@ -1038,23 +1025,6 @@ class UnloadedPmapExecutable:
|
||||
jaxpr_debug_info=jaxpr_debug_info).load()
|
||||
|
||||
|
||||
def _compile_replicated_pmap_executable_from_hlo(
|
||||
hlo: ir.Module, pci, input_indices, in_shardings, handle_outs,
|
||||
compile_options, host_callbacks, has_unordered_effects, ordered_effects,
|
||||
jaxpr_debug_info):
|
||||
# Use the standard out_handler.
|
||||
execute_fun = pci.backend.compile_replicated(
|
||||
is_trivial=False, name=pci.name, computation=hlo,
|
||||
compile_options=compile_options, host_callbacks=host_callbacks,
|
||||
has_unordered_effects=has_unordered_effects,
|
||||
ordered_effects=ordered_effects, in_avals=pci.avals,
|
||||
in_indices=input_indices, in_shardings=in_shardings,
|
||||
kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
|
||||
# TODO(frostig): need `compile_replicated` to give us the XLA executable
|
||||
return PmapExecutable(None, lambda: execute_fun, None, pci.avals,
|
||||
jaxpr_debug_info, None)
|
||||
|
||||
|
||||
class PmapExecutable(stages.XlaExecutable):
|
||||
__slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call",
|
||||
"fingerprint", "in_avals", "_jaxpr_debug_info",
|
||||
@ -2109,7 +2079,7 @@ def lower_sharding_computation(
|
||||
any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
|
||||
any(not is_unspecified(o) for o in out_shardings))
|
||||
|
||||
if xla_extension_version < 241 or hasattr(backend, "compile_replicated"):
|
||||
if xla_extension_version < 241:
|
||||
gs = GSPMDSharding.get_replicated(device_assignment)
|
||||
in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
|
||||
|
||||
@ -2720,15 +2690,12 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
|
||||
opts.allow_spmd_sharding_propagation_to_parameters = list(allow_prop_to_inputs)
|
||||
opts.allow_spmd_sharding_propagation_to_output = list(allow_prop_to_outputs)
|
||||
|
||||
if hasattr(backend, "compile_replicated"):
|
||||
return None, compile_options
|
||||
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
|
||||
fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
|
||||
xla_executable = compiler.compile_or_get_cached(
|
||||
backend, computation, dev, compile_options, host_callbacks)
|
||||
return xla_executable, compile_options
|
||||
return xla_executable
|
||||
|
||||
|
||||
def _maybe_get_and_check_in_shardings(
|
||||
@ -2858,7 +2825,6 @@ class UnloadedMeshExecutable:
|
||||
self.in_layouts, self.out_layouts,
|
||||
self.all_args_info, self)
|
||||
|
||||
# May return a MeshExecutable in the compile_replicated case.
|
||||
@staticmethod
|
||||
def from_hlo(name: str,
|
||||
hlo: ir.Module,
|
||||
@ -2911,24 +2877,12 @@ class UnloadedMeshExecutable:
|
||||
mesh = i.mesh # type: ignore
|
||||
break
|
||||
|
||||
xla_executable, compile_options = _cached_compilation(
|
||||
xla_executable = _cached_compilation(
|
||||
hlo, name, mesh, spmd_lowering,
|
||||
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
|
||||
allow_prop_to_outputs, tuple(host_callbacks), backend, da, pmap_nreps,
|
||||
compiler_options_keys, compiler_options_values)
|
||||
|
||||
if hasattr(backend, "compile_replicated"):
|
||||
semantics_in_shardings = SemanticallyEqualShardings(
|
||||
in_shardings, global_in_avals) # type: ignore
|
||||
semantics_out_shardings = SemanticallyEqualShardings(
|
||||
out_shardings, global_out_avals) # type: ignore
|
||||
return _compile_replicated_mesh_executable_from_hlo(
|
||||
hlo, name, tuple(global_in_avals), tuple(global_out_avals),
|
||||
semantics_in_shardings, semantics_out_shardings, auto_spmd_lowering,
|
||||
compile_options, tuple(host_callbacks), bool(unordered_effects),
|
||||
tuple(ordered_effects), tuple(kept_var_idx), backend, da, committed,
|
||||
pmap_nreps)
|
||||
|
||||
if auto_spmd_lowering:
|
||||
assert mesh is not None
|
||||
in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
|
||||
@ -3177,35 +3131,6 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
|
||||
return in_shardings, out_shardings, committed, tuple(local_devices)
|
||||
|
||||
|
||||
@weakref_lru_cache
|
||||
def _compile_replicated_mesh_executable_from_hlo(
|
||||
computation, name, global_in_avals, global_out_avals, semantics_in_shardings,
|
||||
semantics_out_shardings, auto_spmd_lowering, compile_options,
|
||||
host_callbacks, has_unordered_effects, ordered_effects, kept_var_idx,
|
||||
backend, da, committed, pmap_nreps):
|
||||
assert not auto_spmd_lowering
|
||||
in_shardings = semantics_in_shardings.shardings
|
||||
out_shardings = semantics_out_shardings.shardings
|
||||
|
||||
kept_var_idx = set(kept_var_idx)
|
||||
# Will compute out_handler with executable information.
|
||||
unsafe_call = backend.compile_replicated(
|
||||
is_trivial=False, name=name, computation=computation,
|
||||
compile_options=compile_options, host_callbacks=host_callbacks,
|
||||
has_unordered_effects=has_unordered_effects,
|
||||
device_assignment=da, ordered_effects=ordered_effects,
|
||||
in_avals=global_in_avals,
|
||||
in_shardings=in_shardings, kept_var_idx=kept_var_idx,
|
||||
out_avals=global_out_avals, out_shardings=out_shardings,
|
||||
committed=committed, pmap_nreps=pmap_nreps)
|
||||
xla_executable = None
|
||||
return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
|
||||
global_out_avals, in_shardings, out_shardings,
|
||||
auto_spmd_lowering, kept_var_idx,
|
||||
(None,) * len(global_in_avals),
|
||||
(None,) * len(global_out_avals))
|
||||
|
||||
|
||||
@lru_cache
|
||||
def create_mesh_pspec_sharding(
|
||||
mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None,
|
||||
@ -3358,12 +3283,3 @@ def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk):
|
||||
def maybe_extend_axis_env(*args, **kwargs):
|
||||
with core.extend_axis_env(*args, **kwargs):
|
||||
yield
|
||||
|
||||
|
||||
def device_put(x, devices: Sequence[xc.ArrayImpl],
|
||||
replicate: bool=False) -> list[xc.ArrayImpl]:
|
||||
"""Call device_put on a sequence of devices and return a flat sequence of buffers."""
|
||||
if replicate:
|
||||
return [jax.device_put(x, device) for device in devices]
|
||||
else:
|
||||
return [jax.device_put(val, device) for val, device in safe_zip(x, devices)]
|
||||
|
Loading…
x
Reference in New Issue
Block a user