diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 02075f154..1c8a37c58 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1299,9 +1299,6 @@ class PmapComputation(stages.XlaLowering): self._hlo = hlo self.compile_args = compile_args - def _compile_unloaded(self) -> Union[UnloadedPmapExecutable, PmapExecutable]: - return UnloadedPmapExecutable.from_hlo(self._hlo, **self.compile_args) - # -- stages.XlaLowering overrides def hlo(self) -> xc.XlaComputation: @@ -1319,10 +1316,8 @@ class PmapComputation(stages.XlaLowering): @profiler.annotate_function def compile(self) -> PmapExecutable: if self._executable is None: - executable = self._compile_unloaded() - if isinstance(executable, UnloadedPmapExecutable): - executable = executable.load() - self._executable = executable + self._executable = UnloadedPmapExecutable.from_hlo( + self._hlo, **self.compile_args) return self._executable @@ -1471,9 +1466,9 @@ class UnloadedPmapExecutable: ordered_effects=ordered_effects, keepalive=keepalive, host_callbacks=host_callbacks, - ) + ).load() - def load(self) -> PmapExecutable: + def build_execute_fun(self): input_indices = [] for aval, spec in safe_zip(self.local_input_avals, self.input_shardings): assert isinstance(spec, sharding_impls.PmapSharding), spec @@ -1489,10 +1484,13 @@ class UnloadedPmapExecutable: self.ordered_effects, self.keepalive, bool(self.host_callbacks), set(range(len(input_indices)))) + return execute_fun + + def load(self) -> PmapExecutable: fingerprint = getattr(self.compiled, "fingerprint", None) - return PmapExecutable(self.compiled, execute_fun, fingerprint, - self.local_input_avals) + return PmapExecutable(self.compiled, self.build_execute_fun, fingerprint, + self.local_input_avals, self) def _compile_replicated_pmap_executable_from_hlo( @@ -1507,17 +1505,27 @@ def _compile_replicated_pmap_executable_from_hlo( in_indices=input_indices, in_shardings=in_shardings, kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs) # TODO(frostig): need `compile_replicated` to give us the XLA executable - return PmapExecutable(None, execute_fun, None, pci.avals) + return PmapExecutable(None, lambda: execute_fun, None, pci.avals, None) class PmapExecutable(stages.XlaExecutable): - __slots__ = ["xla_executable", "unsafe_call", "fingerprint", "in_avals"] + __slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call", + "fingerprint", "in_avals", "_unloaded_executable"] - def __init__(self, xla_executable, unsafe_call, fingerprint, in_avals): + def __init__(self, xla_executable, build_unsafe_call, fingerprint, + in_avals, unloaded_executable): self.xla_executable = xla_executable - self.unsafe_call = unsafe_call + self._unsafe_call = None + self.build_unsafe_call = build_unsafe_call self.fingerprint = fingerprint self.in_avals = in_avals + self._unloaded_executable = unloaded_executable + + @property + def unsafe_call(self) -> Callable[..., Any]: + if self._unsafe_call is None: + self._unsafe_call = self.build_unsafe_call() + return self._unsafe_call # -- stages.XlaExecutable overrides @@ -1529,7 +1537,7 @@ class PmapExecutable(stages.XlaExecutable): # TODO(frostig): do we need to check sharding and sharded avals? arg_avals = map(xla.abstractify, args) check_arg_avals_for_call(self.in_avals, arg_avals) - return self.unsafe_call(*args) + return self.unsafe_call(*args) # pylint: disable=not-callable def _get_pmap_sharding(devices, specs): @@ -2804,21 +2812,6 @@ class MeshComputation(stages.XlaLowering): self.compile_args = compile_args self._executable = None - def _compile_unloaded( - self, - _allow_propagation_to_outputs: Optional[Sequence[bool]] = None, - _allow_compile_replicated: bool = True - ) -> Union[UnloadedMeshExecutable, MeshExecutable]: - if self.is_trivial: - return MeshExecutable.from_trivial_jaxpr(**self.compile_args) - else: - return UnloadedMeshExecutable.from_hlo( - self._name, - self._hlo, - **self.compile_args, - _allow_propagation_to_outputs=_allow_propagation_to_outputs, - _allow_compile_replicated=_allow_compile_replicated) # type: ignore - # -- stages.XlaLowering overrides def hlo(self) -> xc.XlaComputation: @@ -2841,11 +2834,16 @@ class MeshComputation(stages.XlaLowering): _allow_propagation_to_outputs: Optional[Sequence[bool]] = None, _allow_compile_replicated: bool = True) -> MeshExecutable: if self._executable is None: - executable = self._compile_unloaded( - _allow_propagation_to_outputs, _allow_compile_replicated) - if isinstance(executable, UnloadedMeshExecutable): - executable = executable.load() - self._executable = executable + if self.is_trivial: + self._executable = MeshExecutable.from_trivial_jaxpr( + **self.compile_args) + else: + self._executable = UnloadedMeshExecutable.from_hlo( + self._name, + self._hlo, + **self.compile_args, + _allow_propagation_to_outputs=_allow_propagation_to_outputs, + _allow_compile_replicated=_allow_compile_replicated) return self._executable def cost_analysis(self) -> Dict[str, float]: @@ -2952,7 +2950,7 @@ class UnloadedMeshExecutable: kept_var_idx: Set[int] auto_spmd_lowering: bool - def load(self) -> MeshExecutable: + def build_unsafe_call(self): input_indices = _get_input_indices(self.input_avals, self.input_shardings) handle_args = InputsHandler(self.xla_executable.local_devices(), self.input_shardings, input_indices) @@ -2964,11 +2962,14 @@ class UnloadedMeshExecutable: self.xla_executable, self.name, self.backend, handle_args, handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive, bool(self.host_callbacks), self.kept_var_idx) + return unsafe_call - return MeshExecutable(self.xla_executable, unsafe_call, self.input_avals, + def load(self) -> MeshExecutable: + return MeshExecutable(self.xla_executable, self.build_unsafe_call, + self.input_avals, self.input_shardings, self.output_shardings, self.auto_spmd_lowering, self.kept_var_idx, - self.device_assignment) + self.device_assignment, self) # May return a MeshExecutable in the compile_replicated case. @staticmethod @@ -2996,7 +2997,7 @@ class UnloadedMeshExecutable: device_assignment: Sequence[xc.Device], committed: bool, pmap_nreps: int = 1 - ) -> Union[MeshExecutable, UnloadedMeshExecutable]: + ) -> MeshExecutable: dev: np.ndarray if auto_spmd_lowering: @@ -3119,7 +3120,7 @@ class UnloadedMeshExecutable: keepalive=keepalive, host_callbacks=host_callbacks, kept_var_idx=kept_var_idx, - auto_spmd_lowering=auto_spmd_lowering) + auto_spmd_lowering=auto_spmd_lowering).load() class MeshExecutableFastpathData(NamedTuple): @@ -3134,24 +3135,35 @@ class MeshExecutableFastpathData(NamedTuple): class MeshExecutable(stages.XlaExecutable): __slots__ = [ - "xla_executable", "unsafe_call", "in_avals", "_in_shardings", - "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx", - "_device_assignment" + "xla_executable", "_unsafe_call", + "build_unsafe_call", "in_avals", + "_in_shardings", "_out_shardings", + "_auto_spmd_lowering", "_kept_var_idx", + "_device_assignment", + "_unloaded_executable", ] - def __init__(self, xla_executable, unsafe_call, in_avals, in_shardings, + def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx, - device_assignment): + device_assignment, unloaded_executable=None): self.xla_executable = xla_executable - self.unsafe_call = unsafe_call + self.build_unsafe_call = build_unsafe_call # in_avals is a list of global and local avals. Aval is global if input # is a GDA or jax.Array else local. self.in_avals = in_avals + self._unsafe_call = None self._in_shardings = in_shardings self._out_shardings = out_shardings self._auto_spmd_lowering = auto_spmd_lowering self._kept_var_idx = kept_var_idx self._device_assignment = device_assignment + self._unloaded_executable = unloaded_executable + + @property + def unsafe_call(self) -> Callable[..., Any]: + if self._unsafe_call is None: + self._unsafe_call = self.build_unsafe_call() + return self._unsafe_call @staticmethod def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals, @@ -3174,8 +3186,9 @@ class MeshExecutable(stages.XlaExecutable): [False] * len(global_out_avals)) unsafe_call = partial(_execute_trivial, jaxpr, consts, handle_ins, handle_outs, kept_var_idx) - return MeshExecutable(None, unsafe_call, global_in_avals, in_shardings, - out_shardings, False, kept_var_idx, device_assignment) + return MeshExecutable(None, lambda: unsafe_call, global_in_avals, + in_shardings, out_shardings, False, kept_var_idx, + device_assignment, None) # -- stages.XlaExecutable overrides @@ -3189,7 +3202,7 @@ class MeshExecutable(stages.XlaExecutable): check_arg_avals_for_call(ref_avals, arg_avals) # Check the GDA sharding and the input sharding. check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings) - return self.unsafe_call(*args) + return self.unsafe_call(*args) # pylint: disable=not-callable def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]: return self._in_shardings @@ -3297,9 +3310,9 @@ def _compile_replicated_mesh_executable_from_hlo( out_avals=global_out_avals, out_shardings=out_shardings, committed=committed, pmap_nreps=pmap_nreps) xla_executable = None - return MeshExecutable(xla_executable, unsafe_call, global_in_avals, + return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals, in_shardings, out_shardings, auto_spmd_lowering, - kept_var_idx, device_assignment) + kept_var_idx, device_assignment, None) def _compile_replicated_mesh_executable_from_trivial_jaxpr( @@ -3320,9 +3333,9 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr( in_indices=input_indices, in_shardings=in_shardings, kept_var_idx=kept_var_idx, out_handler=handle_outs, out_shardings=out_shardings, pmap_nreps=pmap_nreps) - return MeshExecutable(None, unsafe_call, global_in_avals, in_shardings, - out_shardings, False, kept_var_idx, - device_assignment) + return MeshExecutable(None, lambda: unsafe_call, global_in_avals, + in_shardings, out_shardings, False, kept_var_idx, + device_assignment, None) @lru_cache() diff --git a/jax/experimental/serialize_executable.py b/jax/experimental/serialize_executable.py index d7b3a6fba..75b63a09d 100644 --- a/jax/experimental/serialize_executable.py +++ b/jax/experimental/serialize_executable.py @@ -21,46 +21,39 @@ import jax from jax._src.lib import xla_client as xc -def compile_and_serialize(lowered: jax.stages.Lowered): - """Compiles a lowered executable, and then serializes the resulting binary. +def serialize(compiled: jax.stages.Compiled): + """Serializes a compiled binary. Because pytrees are not serializable, they are returned so that the user can handle them properly. """ - - from jax.interpreters import pxla - - if isinstance(lowered._lowering, pxla.MeshComputation): - kw = dict(_allow_propagation_to_outputs=[ - pxla._is_unspecified(o) - for o in lowered._lowering.compile_args['out_shardings']]) - else: - kw = {} - - unloaded_compilation = lowered._lowering._compile_unloaded(**kw) - args_info_flat, in_tree = jax.tree_util.tree_flatten(lowered.args_info) + unloaded_executable = getattr(compiled._executable, + '_unloaded_executable', None) + if unloaded_executable is None: + raise ValueError("Compilation does not support serialization") + args_info_flat, in_tree = jax.tree_util.tree_flatten(compiled.args_info) with io.BytesIO() as file: _JaxPjrtPickler(file).dump( - (unloaded_compilation, args_info_flat, lowered._no_kwargs)) - return file.getvalue(), in_tree, lowered.out_tree + (unloaded_executable, args_info_flat, compiled._no_kwargs)) + return file.getvalue(), in_tree, compiled.out_tree -def load_compiled(serialized, - in_tree, - out_tree, - backend: Optional[Union[str, xc.Client]] = None): +def deserialize_and_load(serialized, + in_tree, + out_tree, + backend: Optional[Union[str, xc.Client]] = None): """Constructs a jax.stages.Compiled from a serialized executable.""" if backend is None or isinstance(backend, str): backend = jax.devices(backend)[0].client - (unloaded_compilation, args_info_flat, + (unloaded_executable, args_info_flat, no_kwargs) = _JaxPjrtUnpickler(io.BytesIO(serialized), backend).load() args_info = in_tree.unflatten(args_info_flat) - loaded_compiled_obj = unloaded_compilation.load() + loaded_compiled_obj = unloaded_executable.load() return jax.stages.Compiled( loaded_compiled_obj, args_info, out_tree, no_kwargs=no_kwargs) diff --git a/tests/array_test.py b/tests/array_test.py index a04e2dae5..632e15f84 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -32,7 +32,7 @@ from jax._src.util import safe_zip from jax.interpreters import pxla from jax.experimental.pjit import pjit from jax.experimental.serialize_executable import ( - compile_and_serialize, load_compiled) + serialize, deserialize_and_load) from jax.experimental import multihost_utils from jax.sharding import PartitionSpec as P from jax._src import array @@ -1033,8 +1033,8 @@ class RngShardingTest(jtu.JaxTestCase): ).lower(core.ShapedArray(shape=(8, 8), dtype=np.float32)) def verify_serialization(lowered): - serialized, in_tree, out_tree = compile_and_serialize(lowered) - compiled = load_compiled(serialized, in_tree, out_tree) + serialized, in_tree, out_tree = serialize(lowered.compile()) + compiled = deserialize_and_load(serialized, in_tree, out_tree) self.assertEqual(compiled.as_text(), lowered.compile().as_text()) verify_serialization(lowered)