mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add UnloadedMeshExecutable to represent a MeshExecutable that is not loaded
on any physical devices for the purposes of serialization. This type is easier to serialize because it has not yet been converted into arg-handlers. Potential API: ``` str, in_tree, out_tree = lowered.compile_and_serialize() exec = jax.experimental.load_serialized(str, in_tree, out_tree, backend) exec // identical to lowered.compile(). ``` PiperOrigin-RevId: 486751141
This commit is contained in:
parent
218305964f
commit
2c1fe45997
@ -3127,6 +3127,21 @@ class MeshComputation(stages.XlaLowering):
|
||||
self.compile_args = compile_args
|
||||
self._executable = None
|
||||
|
||||
def _compile_unloaded(
|
||||
self,
|
||||
_allow_propagation_to_outputs: bool = False,
|
||||
_allow_compile_replicated: bool = True
|
||||
) -> Union[UnloadedMeshExecutable, MeshExecutable]:
|
||||
if self.is_trivial:
|
||||
return MeshExecutable.from_trivial_jaxpr(**self.compile_args)
|
||||
else:
|
||||
return UnloadedMeshExecutable.from_hlo(
|
||||
self._name,
|
||||
self._hlo,
|
||||
**self.compile_args,
|
||||
_allow_propagation_to_outputs=_allow_propagation_to_outputs,
|
||||
_allow_compile_replicated=_allow_compile_replicated) # type: ignore
|
||||
|
||||
# -- stages.XlaLowering overrides
|
||||
|
||||
def hlo(self) -> xc.XlaComputation:
|
||||
@ -3152,13 +3167,11 @@ class MeshComputation(stages.XlaLowering):
|
||||
_allow_propagation_to_outputs : bool = False,
|
||||
_allow_compile_replicated : bool = True) -> MeshExecutable:
|
||||
if self._executable is None:
|
||||
if self.is_trivial:
|
||||
self._executable = MeshExecutable.from_trivial_jaxpr(**self.compile_args)
|
||||
else:
|
||||
self._executable = MeshExecutable.from_hlo(
|
||||
self._name, self._hlo, **self.compile_args,
|
||||
_allow_propagation_to_outputs=_allow_propagation_to_outputs,
|
||||
_allow_compile_replicated=_allow_compile_replicated) # type: ignore
|
||||
executable = self._compile_unloaded(
|
||||
_allow_propagation_to_outputs, _allow_compile_replicated)
|
||||
if isinstance(executable, UnloadedMeshExecutable):
|
||||
executable = executable.load()
|
||||
self._executable = executable
|
||||
return self._executable
|
||||
|
||||
|
||||
@ -3167,10 +3180,22 @@ def _get_input_metadata(
|
||||
in_shardings: Sequence[XLACompatibleSharding], in_is_global: Sequence[bool]
|
||||
) -> Tuple[Sequence[XLACompatibleSharding], Sequence[Tuple[Optional[Index], ...]],
|
||||
Sequence[ShapedArray]]:
|
||||
avals, shardings = _get_normalized_avals_and_shardings(
|
||||
global_in_avals, in_shardings, in_is_global)
|
||||
return shardings, _get_input_indices(avals, shardings), avals
|
||||
|
||||
|
||||
def _get_normalized_avals_and_shardings(
|
||||
global_in_avals: Sequence[ShapedArray],
|
||||
in_shardings: Sequence[XLACompatibleSharding], in_is_global: Sequence[bool]
|
||||
) -> Tuple[Sequence[ShapedArray], Sequence[XLACompatibleSharding]]:
|
||||
from jax._src.sharding import MeshPspecSharding
|
||||
|
||||
shardings, input_indices, input_avals = [], [], []
|
||||
for gaval, i, is_global in safe_zip(global_in_avals, in_shardings, in_is_global):
|
||||
avals = []
|
||||
shardings = []
|
||||
|
||||
for gaval, i, is_global in safe_zip(global_in_avals, in_shardings,
|
||||
in_is_global):
|
||||
if is_global:
|
||||
aval = gaval
|
||||
sharding = i
|
||||
@ -3178,9 +3203,21 @@ def _get_input_metadata(
|
||||
assert isinstance(i, MeshPspecSharding)
|
||||
aval = i.mesh._global_to_local(cast(ArrayMapping, _get_array_mapping(i.spec)), gaval)
|
||||
sharding = MeshPspecSharding(i.mesh.local_mesh, i.spec)
|
||||
avals.append(aval)
|
||||
shardings.append(sharding)
|
||||
|
||||
return avals, shardings
|
||||
|
||||
|
||||
def _get_input_indices(
|
||||
avals: Sequence[ShapedArray], shardings: Sequence[XLACompatibleSharding]
|
||||
) -> Sequence[Tuple[Optional[Index], ...]]:
|
||||
|
||||
input_indices = []
|
||||
for aval, sharding in zip(avals, shardings):
|
||||
if aval is core.abstract_token:
|
||||
index = tuple((slice(None),) for _ in range(len(sharding.addressable_devices)))
|
||||
index = tuple(
|
||||
(slice(None),) for _ in range(len(sharding.addressable_devices)))
|
||||
else:
|
||||
# We special case this logic to support fully replicated values because
|
||||
# the mesh is global mesh and the indices returned by `spec_to_indices` will
|
||||
@ -3188,15 +3225,16 @@ def _get_input_metadata(
|
||||
# indices for the local devices of the global mesh.
|
||||
proto = sharding._to_xla_op_sharding(aval.ndim)
|
||||
if is_op_sharding_replicated(proto):
|
||||
index = tuple((slice(None),) * aval.ndim
|
||||
for _ in range(len(sharding.addressable_devices))) # type: ignore
|
||||
index = tuple(
|
||||
(slice(None),) * aval.ndim
|
||||
for _ in range(len(sharding.addressable_devices))) # type: ignore
|
||||
else:
|
||||
index = tuple(sharding.addressable_devices_indices_map(aval.shape).values()) # type: ignore
|
||||
|
||||
shardings.append(sharding)
|
||||
index = tuple(
|
||||
sharding.addressable_devices_indices_map(
|
||||
aval.shape).values()) # type: ignore
|
||||
input_indices.append(index)
|
||||
input_avals.append(aval)
|
||||
return shardings, input_indices, input_avals
|
||||
|
||||
return input_indices
|
||||
|
||||
|
||||
def _get_op_sharding_shardings_from_executable(
|
||||
@ -3238,25 +3276,58 @@ def _get_mesh_pspec_shardings_from_executable(xla_executable, mesh):
|
||||
[MeshPspecSharding(mesh, o) for o in out_pspec])
|
||||
|
||||
|
||||
class MeshExecutable(stages.XlaExecutable):
|
||||
__slots__ = ['xla_executable', 'unsafe_call', 'in_avals',
|
||||
'_in_shardings', '_out_shardings', '_auto_spmd_lowering',
|
||||
'_kept_var_idx', '_device_assignment']
|
||||
@dataclasses.dataclass
|
||||
class UnloadedMeshExecutable:
|
||||
xla_executable: Any
|
||||
device_assignment: Sequence[xc.Device]
|
||||
backend: xb.XlaBackend
|
||||
input_avals: Sequence[ShapedArray]
|
||||
input_shardings: Sequence[XLACompatibleSharding]
|
||||
output_avals: Sequence[ShapedArray]
|
||||
output_shardings: Sequence[XLACompatibleSharding]
|
||||
committed: bool
|
||||
are_out_shardings_from_xla: Sequence[bool]
|
||||
pmap_nreps: int
|
||||
name: str
|
||||
unordered_effects: List[core.Effect]
|
||||
ordered_effects: List[core.Effect]
|
||||
keepalive: Sequence[Any]
|
||||
host_callbacks: Sequence[Any]
|
||||
kept_var_idx: Set[int]
|
||||
auto_spmd_lowering: bool
|
||||
|
||||
def __init__(self, xla_executable, unsafe_call, in_avals,
|
||||
in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx,
|
||||
device_assignment):
|
||||
self.xla_executable = xla_executable
|
||||
self.unsafe_call = unsafe_call
|
||||
# in_avals is a list of global and local avals. Aval is global if input
|
||||
# is a GDA or jax.Array else local.
|
||||
self.in_avals = in_avals
|
||||
self._in_shardings = in_shardings
|
||||
self._out_shardings = out_shardings
|
||||
self._auto_spmd_lowering = auto_spmd_lowering
|
||||
self._kept_var_idx = kept_var_idx
|
||||
self._device_assignment = device_assignment
|
||||
def load(self) -> MeshExecutable:
|
||||
input_indices = _get_input_indices(self.input_avals, self.input_shardings)
|
||||
handle_args = InputsHandler(self.xla_executable.local_devices(),
|
||||
self.input_shardings, input_indices,
|
||||
InputsHandlerMode.pjit_or_xmap)
|
||||
handle_outs = global_avals_to_results_handler(
|
||||
self.output_avals, self.output_shardings, self.committed,
|
||||
self.are_out_shardings_from_xla) # type: ignore # arg-type
|
||||
|
||||
# This path is taken for `jit(pmap)` cases. Nothing else should flow
|
||||
# through this path. This is exactly same to what happens in `jit`.
|
||||
if self.pmap_nreps > 1:
|
||||
has_unordered_effects = bool(self.unordered_effects)
|
||||
buffer_counts = dispatch.get_buffer_counts(
|
||||
self.output_avals, self.ordered_effects, has_unordered_effects)
|
||||
unsafe_call = partial(
|
||||
dispatch._execute_replicated, self.name, self.xla_executable, None,
|
||||
buffer_counts, handle_outs, has_unordered_effects, self.ordered_effects,
|
||||
self.kept_var_idx, bool(self.host_callbacks),
|
||||
from_lower_sharding_computation=True)
|
||||
else:
|
||||
unsafe_call = ExecuteReplicated( # type: ignore # assignment
|
||||
self.xla_executable, self.name, self.backend, handle_args,
|
||||
handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
|
||||
bool(self.host_callbacks), self.kept_var_idx)
|
||||
|
||||
return MeshExecutable(self.xla_executable, unsafe_call, self.input_avals,
|
||||
self.input_shardings, self.output_shardings,
|
||||
self.auto_spmd_lowering, self.kept_var_idx,
|
||||
self.device_assignment)
|
||||
|
||||
# May return a MeshExecutable in the compile_replicated case.
|
||||
@staticmethod
|
||||
def from_hlo(name: str,
|
||||
computation: Union[ir.Module, xc.XlaComputation],
|
||||
@ -3282,7 +3353,9 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
backend: xb.XlaBackend,
|
||||
device_assignment: Sequence[xc.Device],
|
||||
committed: bool,
|
||||
pmap_nreps: int = 1) -> MeshExecutable:
|
||||
pmap_nreps: int = 1
|
||||
) -> Union[MeshExecutable, UnloadedMeshExecutable]:
|
||||
|
||||
dev: np.ndarray
|
||||
if auto_spmd_lowering:
|
||||
assert mesh is not None and spmd_lowering
|
||||
@ -3357,33 +3430,51 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
else:
|
||||
are_out_shardings_from_xla = (False,) * len(global_out_avals)
|
||||
|
||||
in_shardings, input_indices, input_avals = _get_input_metadata(
|
||||
global_in_avals, in_shardings, in_is_global) # type: ignore
|
||||
handle_args = InputsHandler(xla_executable.local_devices(), in_shardings,
|
||||
input_indices, InputsHandlerMode.pjit_or_xmap)
|
||||
handle_outs = global_avals_to_results_handler(
|
||||
global_out_avals, out_shardings, committed, are_out_shardings_from_xla) # type: ignore # arg-type
|
||||
input_avals, input_shardings = (
|
||||
_get_normalized_avals_and_shardings(global_in_avals,
|
||||
in_shardings, # type: ignore # arg-type
|
||||
in_is_global))
|
||||
|
||||
# This path is taken for `jit(pmap)` cases. Nothing else should flow
|
||||
# through this path. This is exactly same to what happens in `jit`.
|
||||
if pmap_nreps > 1:
|
||||
has_unordered_effects = bool(unordered_effects)
|
||||
buffer_counts = dispatch.get_buffer_counts(
|
||||
global_out_avals, ordered_effects, has_unordered_effects)
|
||||
unsafe_call = partial(
|
||||
dispatch._execute_replicated, name, xla_executable, None,
|
||||
buffer_counts, handle_outs, has_unordered_effects, ordered_effects,
|
||||
kept_var_idx, bool(host_callbacks),
|
||||
from_lower_sharding_computation=True)
|
||||
else:
|
||||
unsafe_call = ExecuteReplicated( # type: ignore # assignment
|
||||
xla_executable, name, backend, handle_args,
|
||||
handle_outs, unordered_effects, ordered_effects, keepalive,
|
||||
bool(host_callbacks), kept_var_idx)
|
||||
return UnloadedMeshExecutable(
|
||||
xla_executable=xla_executable,
|
||||
device_assignment=device_assignment,
|
||||
backend=backend,
|
||||
input_avals=input_avals,
|
||||
input_shardings=input_shardings,
|
||||
output_avals=global_out_avals,
|
||||
output_shardings=out_shardings, # type: ignore # arg-type
|
||||
committed=committed,
|
||||
are_out_shardings_from_xla=are_out_shardings_from_xla,
|
||||
pmap_nreps=pmap_nreps,
|
||||
name=name,
|
||||
unordered_effects=unordered_effects,
|
||||
ordered_effects=ordered_effects,
|
||||
keepalive=keepalive,
|
||||
host_callbacks=host_callbacks,
|
||||
kept_var_idx=kept_var_idx,
|
||||
auto_spmd_lowering=auto_spmd_lowering)
|
||||
|
||||
return MeshExecutable(xla_executable, unsafe_call, input_avals,
|
||||
in_shardings, out_shardings, auto_spmd_lowering,
|
||||
kept_var_idx, device_assignment)
|
||||
|
||||
class MeshExecutable(stages.XlaExecutable):
|
||||
__slots__ = [
|
||||
"xla_executable", "unsafe_call", "in_avals", "_in_shardings",
|
||||
"_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
|
||||
"_device_assignment"
|
||||
]
|
||||
|
||||
def __init__(self, xla_executable, unsafe_call, in_avals, in_shardings,
|
||||
out_shardings, auto_spmd_lowering, kept_var_idx,
|
||||
device_assignment):
|
||||
self.xla_executable = xla_executable
|
||||
self.unsafe_call = unsafe_call
|
||||
# in_avals is a list of global and local avals. Aval is global if input
|
||||
# is a GDA or jax.Array else local.
|
||||
self.in_avals = in_avals
|
||||
self._in_shardings = in_shardings
|
||||
self._out_shardings = out_shardings
|
||||
self._auto_spmd_lowering = auto_spmd_lowering
|
||||
self._kept_var_idx = kept_var_idx
|
||||
self._device_assignment = device_assignment
|
||||
|
||||
@staticmethod
|
||||
def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals,
|
||||
|
Loading…
x
Reference in New Issue
Block a user