Add UnloadedMeshExecutable to represent a MeshExecutable that is not loaded on any physical devices.

This type exists for the purposes of serialization: it is easier to serialize
because it has not yet been converted into arg-handlers.

Potential API:
```
  str, in_tree, out_tree = lowered.compile_and_serialize()
  exec = jax.experimental.load_serialized(str, in_tree, out_tree, backend)
  # exec is identical to lowered.compile().
```

PiperOrigin-RevId: 486751141
This commit is contained in:
Parker Schuh 2022-11-07 13:39:59 -08:00 committed by jax authors
parent 218305964f
commit 2c1fe45997

View File

@ -3127,6 +3127,21 @@ class MeshComputation(stages.XlaLowering):
self.compile_args = compile_args
self._executable = None
def _compile_unloaded(
    self,
    _allow_propagation_to_outputs: bool = False,
    _allow_compile_replicated: bool = True
) -> Union[UnloadedMeshExecutable, MeshExecutable]:
  """Compile this computation without building the call handlers.

  Args:
    _allow_propagation_to_outputs: forwarded unchanged to
      `UnloadedMeshExecutable.from_hlo`.
    _allow_compile_replicated: forwarded unchanged to
      `UnloadedMeshExecutable.from_hlo`.

  Returns:
    For a trivial computation, the `MeshExecutable` produced by
    `MeshExecutable.from_trivial_jaxpr`. Otherwise whatever
    `UnloadedMeshExecutable.from_hlo` returns for this computation's name,
    HLO and stored compile arguments.
  """
  # Trivial computations have no unloaded form; hand back the executable
  # built directly from the jaxpr.
  if self.is_trivial:
    return MeshExecutable.from_trivial_jaxpr(**self.compile_args)

  return UnloadedMeshExecutable.from_hlo(
      self._name,
      self._hlo,
      **self.compile_args,
      _allow_propagation_to_outputs=_allow_propagation_to_outputs,
      _allow_compile_replicated=_allow_compile_replicated)  # type: ignore
# -- stages.XlaLowering overrides
def hlo(self) -> xc.XlaComputation:
@ -3152,13 +3167,11 @@ class MeshComputation(stages.XlaLowering):
_allow_propagation_to_outputs : bool = False,
_allow_compile_replicated : bool = True) -> MeshExecutable:
if self._executable is None:
if self.is_trivial:
self._executable = MeshExecutable.from_trivial_jaxpr(**self.compile_args)
else:
self._executable = MeshExecutable.from_hlo(
self._name, self._hlo, **self.compile_args,
_allow_propagation_to_outputs=_allow_propagation_to_outputs,
_allow_compile_replicated=_allow_compile_replicated) # type: ignore
executable = self._compile_unloaded(
_allow_propagation_to_outputs, _allow_compile_replicated)
if isinstance(executable, UnloadedMeshExecutable):
executable = executable.load()
self._executable = executable
return self._executable
@ -3167,10 +3180,22 @@ def _get_input_metadata(
in_shardings: Sequence[XLACompatibleSharding], in_is_global: Sequence[bool]
) -> Tuple[Sequence[XLACompatibleSharding], Sequence[Tuple[Optional[Index], ...]],
Sequence[ShapedArray]]:
avals, shardings = _get_normalized_avals_and_shardings(
global_in_avals, in_shardings, in_is_global)
return shardings, _get_input_indices(avals, shardings), avals
def _get_normalized_avals_and_shardings(
global_in_avals: Sequence[ShapedArray],
in_shardings: Sequence[XLACompatibleSharding], in_is_global: Sequence[bool]
) -> Tuple[Sequence[ShapedArray], Sequence[XLACompatibleSharding]]:
from jax._src.sharding import MeshPspecSharding
shardings, input_indices, input_avals = [], [], []
for gaval, i, is_global in safe_zip(global_in_avals, in_shardings, in_is_global):
avals = []
shardings = []
for gaval, i, is_global in safe_zip(global_in_avals, in_shardings,
in_is_global):
if is_global:
aval = gaval
sharding = i
@ -3178,9 +3203,21 @@ def _get_input_metadata(
assert isinstance(i, MeshPspecSharding)
aval = i.mesh._global_to_local(cast(ArrayMapping, _get_array_mapping(i.spec)), gaval)
sharding = MeshPspecSharding(i.mesh.local_mesh, i.spec)
avals.append(aval)
shardings.append(sharding)
return avals, shardings
def _get_input_indices(
avals: Sequence[ShapedArray], shardings: Sequence[XLACompatibleSharding]
) -> Sequence[Tuple[Optional[Index], ...]]:
input_indices = []
for aval, sharding in zip(avals, shardings):
if aval is core.abstract_token:
index = tuple((slice(None),) for _ in range(len(sharding.addressable_devices)))
index = tuple(
(slice(None),) for _ in range(len(sharding.addressable_devices)))
else:
# We special case this logic to support fully replicated values because
# the mesh is global mesh and the indices returned by `spec_to_indices` will
@ -3188,15 +3225,16 @@ def _get_input_metadata(
# indices for the local devices of the global mesh.
proto = sharding._to_xla_op_sharding(aval.ndim)
if is_op_sharding_replicated(proto):
index = tuple((slice(None),) * aval.ndim
for _ in range(len(sharding.addressable_devices))) # type: ignore
index = tuple(
(slice(None),) * aval.ndim
for _ in range(len(sharding.addressable_devices))) # type: ignore
else:
index = tuple(sharding.addressable_devices_indices_map(aval.shape).values()) # type: ignore
shardings.append(sharding)
index = tuple(
sharding.addressable_devices_indices_map(
aval.shape).values()) # type: ignore
input_indices.append(index)
input_avals.append(aval)
return shardings, input_indices, input_avals
return input_indices
def _get_op_sharding_shardings_from_executable(
@ -3238,25 +3276,58 @@ def _get_mesh_pspec_shardings_from_executable(xla_executable, mesh):
[MeshPspecSharding(mesh, o) for o in out_pspec])
class MeshExecutable(stages.XlaExecutable):
__slots__ = ['xla_executable', 'unsafe_call', 'in_avals',
'_in_shardings', '_out_shardings', '_auto_spmd_lowering',
'_kept_var_idx', '_device_assignment']
@dataclasses.dataclass
class UnloadedMeshExecutable:
xla_executable: Any
device_assignment: Sequence[xc.Device]
backend: xb.XlaBackend
input_avals: Sequence[ShapedArray]
input_shardings: Sequence[XLACompatibleSharding]
output_avals: Sequence[ShapedArray]
output_shardings: Sequence[XLACompatibleSharding]
committed: bool
are_out_shardings_from_xla: Sequence[bool]
pmap_nreps: int
name: str
unordered_effects: List[core.Effect]
ordered_effects: List[core.Effect]
keepalive: Sequence[Any]
host_callbacks: Sequence[Any]
kept_var_idx: Set[int]
auto_spmd_lowering: bool
def __init__(self, xla_executable, unsafe_call, in_avals,
in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx,
device_assignment):
self.xla_executable = xla_executable
self.unsafe_call = unsafe_call
# in_avals is a list of global and local avals. Aval is global if input
# is a GDA or jax.Array else local.
self.in_avals = in_avals
self._in_shardings = in_shardings
self._out_shardings = out_shardings
self._auto_spmd_lowering = auto_spmd_lowering
self._kept_var_idx = kept_var_idx
self._device_assignment = device_assignment
def load(self) -> MeshExecutable:
  """Build the input/output handlers and wrap this into a `MeshExecutable`.

  Computes the input indices from the stored input avals and shardings,
  constructs the argument and result handlers, selects the callable used to
  run the executable, and returns a `MeshExecutable` carrying that callable
  together with the stored avals, shardings and device assignment.
  """
  indices = _get_input_indices(self.input_avals, self.input_shardings)
  exe = self.xla_executable
  args_handler = InputsHandler(
      exe.local_devices(), self.input_shardings, indices,
      InputsHandlerMode.pjit_or_xmap)
  results_handler = global_avals_to_results_handler(
      self.output_avals, self.output_shardings, self.committed,
      self.are_out_shardings_from_xla)  # type: ignore  # arg-type

  if self.pmap_nreps > 1:
    # This path is taken for `jit(pmap)` cases. Nothing else should flow
    # through this path. This is exactly same to what happens in `jit`.
    unordered = bool(self.unordered_effects)
    counts = dispatch.get_buffer_counts(
        self.output_avals, self.ordered_effects, unordered)
    call = partial(
        dispatch._execute_replicated, self.name, exe, None, counts,
        results_handler, unordered, self.ordered_effects,
        self.kept_var_idx, bool(self.host_callbacks),
        from_lower_sharding_computation=True)
  else:
    call = ExecuteReplicated(  # type: ignore  # assignment
        exe, self.name, self.backend, args_handler, results_handler,
        self.unordered_effects, self.ordered_effects, self.keepalive,
        bool(self.host_callbacks), self.kept_var_idx)

  return MeshExecutable(
      exe, call, self.input_avals, self.input_shardings,
      self.output_shardings, self.auto_spmd_lowering, self.kept_var_idx,
      self.device_assignment)
# May return a MeshExecutable in the compile_replicated case.
@staticmethod
def from_hlo(name: str,
computation: Union[ir.Module, xc.XlaComputation],
@ -3282,7 +3353,9 @@ class MeshExecutable(stages.XlaExecutable):
backend: xb.XlaBackend,
device_assignment: Sequence[xc.Device],
committed: bool,
pmap_nreps: int = 1) -> MeshExecutable:
pmap_nreps: int = 1
) -> Union[MeshExecutable, UnloadedMeshExecutable]:
dev: np.ndarray
if auto_spmd_lowering:
assert mesh is not None and spmd_lowering
@ -3357,33 +3430,51 @@ class MeshExecutable(stages.XlaExecutable):
else:
are_out_shardings_from_xla = (False,) * len(global_out_avals)
in_shardings, input_indices, input_avals = _get_input_metadata(
global_in_avals, in_shardings, in_is_global) # type: ignore
handle_args = InputsHandler(xla_executable.local_devices(), in_shardings,
input_indices, InputsHandlerMode.pjit_or_xmap)
handle_outs = global_avals_to_results_handler(
global_out_avals, out_shardings, committed, are_out_shardings_from_xla) # type: ignore # arg-type
input_avals, input_shardings = (
_get_normalized_avals_and_shardings(global_in_avals,
in_shardings, # type: ignore # arg-type
in_is_global))
# This path is taken for `jit(pmap)` cases. Nothing else should flow
# through this path. This is exactly same to what happens in `jit`.
if pmap_nreps > 1:
has_unordered_effects = bool(unordered_effects)
buffer_counts = dispatch.get_buffer_counts(
global_out_avals, ordered_effects, has_unordered_effects)
unsafe_call = partial(
dispatch._execute_replicated, name, xla_executable, None,
buffer_counts, handle_outs, has_unordered_effects, ordered_effects,
kept_var_idx, bool(host_callbacks),
from_lower_sharding_computation=True)
else:
unsafe_call = ExecuteReplicated( # type: ignore # assignment
xla_executable, name, backend, handle_args,
handle_outs, unordered_effects, ordered_effects, keepalive,
bool(host_callbacks), kept_var_idx)
return UnloadedMeshExecutable(
xla_executable=xla_executable,
device_assignment=device_assignment,
backend=backend,
input_avals=input_avals,
input_shardings=input_shardings,
output_avals=global_out_avals,
output_shardings=out_shardings, # type: ignore # arg-type
committed=committed,
are_out_shardings_from_xla=are_out_shardings_from_xla,
pmap_nreps=pmap_nreps,
name=name,
unordered_effects=unordered_effects,
ordered_effects=ordered_effects,
keepalive=keepalive,
host_callbacks=host_callbacks,
kept_var_idx=kept_var_idx,
auto_spmd_lowering=auto_spmd_lowering)
return MeshExecutable(xla_executable, unsafe_call, input_avals,
in_shardings, out_shardings, auto_spmd_lowering,
kept_var_idx, device_assignment)
class MeshExecutable(stages.XlaExecutable):
__slots__ = [
"xla_executable", "unsafe_call", "in_avals", "_in_shardings",
"_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
"_device_assignment"
]
def __init__(
    self,
    xla_executable,
    unsafe_call,
    in_avals,
    in_shardings,
    out_shardings,
    auto_spmd_lowering,
    kept_var_idx,
    device_assignment,
):
  """Store the executable, its callable and the associated metadata.

  Every argument is stored unchanged. The sharding, SPMD-lowering,
  kept-variable and device-assignment values are kept on
  underscore-prefixed attributes.
  """
  self.xla_executable = xla_executable
  self.unsafe_call = unsafe_call

  # in_avals is a list of global and local avals. Aval is global if input
  # is a GDA or jax.Array else local.
  self.in_avals = in_avals

  self._in_shardings = in_shardings
  self._out_shardings = out_shardings
  self._auto_spmd_lowering = auto_spmd_lowering
  self._kept_var_idx = kept_var_idx
  self._device_assignment = device_assignment
@staticmethod
def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals,