Redefine compile_and_serialize as serialize(lowered.compile()).

This has the downside of keeping around the UnloadedMeshComputation,
but it makes the serialize() API easier to understand.

PiperOrigin-RevId: 518715469
This commit is contained in:
Parker Schuh 2023-03-22 17:22:39 -07:00 committed by jax authors
parent 5f724cf9a2
commit 484eb26d2a
3 changed files with 86 additions and 80 deletions

View File

@ -1299,9 +1299,6 @@ class PmapComputation(stages.XlaLowering):
self._hlo = hlo
self.compile_args = compile_args
def _compile_unloaded(self) -> Union[UnloadedPmapExecutable, PmapExecutable]:
return UnloadedPmapExecutable.from_hlo(self._hlo, **self.compile_args)
# -- stages.XlaLowering overrides
def hlo(self) -> xc.XlaComputation:
@ -1319,10 +1316,8 @@ class PmapComputation(stages.XlaLowering):
@profiler.annotate_function
def compile(self) -> PmapExecutable:
if self._executable is None:
executable = self._compile_unloaded()
if isinstance(executable, UnloadedPmapExecutable):
executable = executable.load()
self._executable = executable
self._executable = UnloadedPmapExecutable.from_hlo(
self._hlo, **self.compile_args)
return self._executable
@ -1471,9 +1466,9 @@ class UnloadedPmapExecutable:
ordered_effects=ordered_effects,
keepalive=keepalive,
host_callbacks=host_callbacks,
)
).load()
def load(self) -> PmapExecutable:
def build_execute_fun(self):
input_indices = []
for aval, spec in safe_zip(self.local_input_avals, self.input_shardings):
assert isinstance(spec, sharding_impls.PmapSharding), spec
@ -1489,10 +1484,13 @@ class UnloadedPmapExecutable:
self.ordered_effects, self.keepalive,
bool(self.host_callbacks),
set(range(len(input_indices))))
return execute_fun
def load(self) -> PmapExecutable:
fingerprint = getattr(self.compiled, "fingerprint", None)
return PmapExecutable(self.compiled, execute_fun, fingerprint,
self.local_input_avals)
return PmapExecutable(self.compiled, self.build_execute_fun, fingerprint,
self.local_input_avals, self)
def _compile_replicated_pmap_executable_from_hlo(
@ -1507,17 +1505,27 @@ def _compile_replicated_pmap_executable_from_hlo(
in_indices=input_indices, in_shardings=in_shardings,
kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
# TODO(frostig): need `compile_replicated` to give us the XLA executable
return PmapExecutable(None, execute_fun, None, pci.avals)
return PmapExecutable(None, lambda: execute_fun, None, pci.avals, None)
class PmapExecutable(stages.XlaExecutable):
__slots__ = ["xla_executable", "unsafe_call", "fingerprint", "in_avals"]
__slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call",
"fingerprint", "in_avals", "_unloaded_executable"]
def __init__(self, xla_executable, unsafe_call, fingerprint, in_avals):
def __init__(self, xla_executable, build_unsafe_call, fingerprint,
in_avals, unloaded_executable):
self.xla_executable = xla_executable
self.unsafe_call = unsafe_call
self._unsafe_call = None
self.build_unsafe_call = build_unsafe_call
self.fingerprint = fingerprint
self.in_avals = in_avals
self._unloaded_executable = unloaded_executable
@property
def unsafe_call(self) -> Callable[..., Any]:
if self._unsafe_call is None:
self._unsafe_call = self.build_unsafe_call()
return self._unsafe_call
# -- stages.XlaExecutable overrides
@ -1529,7 +1537,7 @@ class PmapExecutable(stages.XlaExecutable):
# TODO(frostig): do we need to check sharding and sharded avals?
arg_avals = map(xla.abstractify, args)
check_arg_avals_for_call(self.in_avals, arg_avals)
return self.unsafe_call(*args)
return self.unsafe_call(*args) # pylint: disable=not-callable
def _get_pmap_sharding(devices, specs):
@ -2804,21 +2812,6 @@ class MeshComputation(stages.XlaLowering):
self.compile_args = compile_args
self._executable = None
def _compile_unloaded(
self,
_allow_propagation_to_outputs: Optional[Sequence[bool]] = None,
_allow_compile_replicated: bool = True
) -> Union[UnloadedMeshExecutable, MeshExecutable]:
if self.is_trivial:
return MeshExecutable.from_trivial_jaxpr(**self.compile_args)
else:
return UnloadedMeshExecutable.from_hlo(
self._name,
self._hlo,
**self.compile_args,
_allow_propagation_to_outputs=_allow_propagation_to_outputs,
_allow_compile_replicated=_allow_compile_replicated) # type: ignore
# -- stages.XlaLowering overrides
def hlo(self) -> xc.XlaComputation:
@ -2841,11 +2834,16 @@ class MeshComputation(stages.XlaLowering):
_allow_propagation_to_outputs: Optional[Sequence[bool]] = None,
_allow_compile_replicated: bool = True) -> MeshExecutable:
if self._executable is None:
executable = self._compile_unloaded(
_allow_propagation_to_outputs, _allow_compile_replicated)
if isinstance(executable, UnloadedMeshExecutable):
executable = executable.load()
self._executable = executable
if self.is_trivial:
self._executable = MeshExecutable.from_trivial_jaxpr(
**self.compile_args)
else:
self._executable = UnloadedMeshExecutable.from_hlo(
self._name,
self._hlo,
**self.compile_args,
_allow_propagation_to_outputs=_allow_propagation_to_outputs,
_allow_compile_replicated=_allow_compile_replicated)
return self._executable
def cost_analysis(self) -> Dict[str, float]:
@ -2952,7 +2950,7 @@ class UnloadedMeshExecutable:
kept_var_idx: Set[int]
auto_spmd_lowering: bool
def load(self) -> MeshExecutable:
def build_unsafe_call(self):
input_indices = _get_input_indices(self.input_avals, self.input_shardings)
handle_args = InputsHandler(self.xla_executable.local_devices(),
self.input_shardings, input_indices)
@ -2964,11 +2962,14 @@ class UnloadedMeshExecutable:
self.xla_executable, self.name, self.backend, handle_args,
handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
bool(self.host_callbacks), self.kept_var_idx)
return unsafe_call
return MeshExecutable(self.xla_executable, unsafe_call, self.input_avals,
def load(self) -> MeshExecutable:
return MeshExecutable(self.xla_executable, self.build_unsafe_call,
self.input_avals,
self.input_shardings, self.output_shardings,
self.auto_spmd_lowering, self.kept_var_idx,
self.device_assignment)
self.device_assignment, self)
# May return a MeshExecutable in the compile_replicated case.
@staticmethod
@ -2996,7 +2997,7 @@ class UnloadedMeshExecutable:
device_assignment: Sequence[xc.Device],
committed: bool,
pmap_nreps: int = 1
) -> Union[MeshExecutable, UnloadedMeshExecutable]:
) -> MeshExecutable:
dev: np.ndarray
if auto_spmd_lowering:
@ -3119,7 +3120,7 @@ class UnloadedMeshExecutable:
keepalive=keepalive,
host_callbacks=host_callbacks,
kept_var_idx=kept_var_idx,
auto_spmd_lowering=auto_spmd_lowering)
auto_spmd_lowering=auto_spmd_lowering).load()
class MeshExecutableFastpathData(NamedTuple):
@ -3134,24 +3135,35 @@ class MeshExecutableFastpathData(NamedTuple):
class MeshExecutable(stages.XlaExecutable):
__slots__ = [
"xla_executable", "unsafe_call", "in_avals", "_in_shardings",
"_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
"_device_assignment"
"xla_executable", "_unsafe_call",
"build_unsafe_call", "in_avals",
"_in_shardings", "_out_shardings",
"_auto_spmd_lowering", "_kept_var_idx",
"_device_assignment",
"_unloaded_executable",
]
def __init__(self, xla_executable, unsafe_call, in_avals, in_shardings,
def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
out_shardings, auto_spmd_lowering, kept_var_idx,
device_assignment):
device_assignment, unloaded_executable=None):
self.xla_executable = xla_executable
self.unsafe_call = unsafe_call
self.build_unsafe_call = build_unsafe_call
# in_avals is a list of global and local avals. Aval is global if input
# is a GDA or jax.Array else local.
self.in_avals = in_avals
self._unsafe_call = None
self._in_shardings = in_shardings
self._out_shardings = out_shardings
self._auto_spmd_lowering = auto_spmd_lowering
self._kept_var_idx = kept_var_idx
self._device_assignment = device_assignment
self._unloaded_executable = unloaded_executable
@property
def unsafe_call(self) -> Callable[..., Any]:
if self._unsafe_call is None:
self._unsafe_call = self.build_unsafe_call()
return self._unsafe_call
@staticmethod
def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals,
@ -3174,8 +3186,9 @@ class MeshExecutable(stages.XlaExecutable):
[False] * len(global_out_avals))
unsafe_call = partial(_execute_trivial, jaxpr, consts, handle_ins,
handle_outs, kept_var_idx)
return MeshExecutable(None, unsafe_call, global_in_avals, in_shardings,
out_shardings, False, kept_var_idx, device_assignment)
return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
in_shardings, out_shardings, False, kept_var_idx,
device_assignment, None)
# -- stages.XlaExecutable overrides
@ -3189,7 +3202,7 @@ class MeshExecutable(stages.XlaExecutable):
check_arg_avals_for_call(ref_avals, arg_avals)
# Check the GDA sharding and the input sharding.
check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings)
return self.unsafe_call(*args)
return self.unsafe_call(*args) # pylint: disable=not-callable
def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
return self._in_shardings
@ -3297,9 +3310,9 @@ def _compile_replicated_mesh_executable_from_hlo(
out_avals=global_out_avals, out_shardings=out_shardings,
committed=committed, pmap_nreps=pmap_nreps)
xla_executable = None
return MeshExecutable(xla_executable, unsafe_call, global_in_avals,
return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
in_shardings, out_shardings, auto_spmd_lowering,
kept_var_idx, device_assignment)
kept_var_idx, device_assignment, None)
def _compile_replicated_mesh_executable_from_trivial_jaxpr(
@ -3320,9 +3333,9 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
in_indices=input_indices, in_shardings=in_shardings,
kept_var_idx=kept_var_idx, out_handler=handle_outs,
out_shardings=out_shardings, pmap_nreps=pmap_nreps)
return MeshExecutable(None, unsafe_call, global_in_avals, in_shardings,
out_shardings, False, kept_var_idx,
device_assignment)
return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
in_shardings, out_shardings, False, kept_var_idx,
device_assignment, None)
@lru_cache()

View File

@ -21,46 +21,39 @@ import jax
from jax._src.lib import xla_client as xc
def compile_and_serialize(lowered: jax.stages.Lowered):
"""Compiles a lowered executable, and then serializes the resulting binary.
def serialize(compiled: jax.stages.Compiled):
"""Serializes a compiled binary.
Because pytrees are not serializable, they are returned so that
the user can handle them properly.
"""
from jax.interpreters import pxla
if isinstance(lowered._lowering, pxla.MeshComputation):
kw = dict(_allow_propagation_to_outputs=[
pxla._is_unspecified(o)
for o in lowered._lowering.compile_args['out_shardings']])
else:
kw = {}
unloaded_compilation = lowered._lowering._compile_unloaded(**kw)
args_info_flat, in_tree = jax.tree_util.tree_flatten(lowered.args_info)
unloaded_executable = getattr(compiled._executable,
'_unloaded_executable', None)
if unloaded_executable is None:
raise ValueError("Compilation does not support serialization")
args_info_flat, in_tree = jax.tree_util.tree_flatten(compiled.args_info)
with io.BytesIO() as file:
_JaxPjrtPickler(file).dump(
(unloaded_compilation, args_info_flat, lowered._no_kwargs))
return file.getvalue(), in_tree, lowered.out_tree
(unloaded_executable, args_info_flat, compiled._no_kwargs))
return file.getvalue(), in_tree, compiled.out_tree
def load_compiled(serialized,
in_tree,
out_tree,
backend: Optional[Union[str, xc.Client]] = None):
def deserialize_and_load(serialized,
in_tree,
out_tree,
backend: Optional[Union[str, xc.Client]] = None):
"""Constructs a jax.stages.Compiled from a serialized executable."""
if backend is None or isinstance(backend, str):
backend = jax.devices(backend)[0].client
(unloaded_compilation, args_info_flat,
(unloaded_executable, args_info_flat,
no_kwargs) = _JaxPjrtUnpickler(io.BytesIO(serialized), backend).load()
args_info = in_tree.unflatten(args_info_flat)
loaded_compiled_obj = unloaded_compilation.load()
loaded_compiled_obj = unloaded_executable.load()
return jax.stages.Compiled(
loaded_compiled_obj, args_info, out_tree, no_kwargs=no_kwargs)

View File

@ -32,7 +32,7 @@ from jax._src.util import safe_zip
from jax.interpreters import pxla
from jax.experimental.pjit import pjit
from jax.experimental.serialize_executable import (
compile_and_serialize, load_compiled)
serialize, deserialize_and_load)
from jax.experimental import multihost_utils
from jax.sharding import PartitionSpec as P
from jax._src import array
@ -1033,8 +1033,8 @@ class RngShardingTest(jtu.JaxTestCase):
).lower(core.ShapedArray(shape=(8, 8), dtype=np.float32))
def verify_serialization(lowered):
serialized, in_tree, out_tree = compile_and_serialize(lowered)
compiled = load_compiled(serialized, in_tree, out_tree)
serialized, in_tree, out_tree = serialize(lowered.compile())
compiled = deserialize_and_load(serialized, in_tree, out_tree)
self.assertEqual(compiled.as_text(), lowered.compile().as_text())
verify_serialization(lowered)