mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Redefine compile_and_serialize
as serialize(lowered.compile())
.
This has the downside of keeping around the UnloadedMeshComputation, but it makes the serialize() API easier to understand. PiperOrigin-RevId: 518715469
This commit is contained in:
parent
5f724cf9a2
commit
484eb26d2a
@ -1299,9 +1299,6 @@ class PmapComputation(stages.XlaLowering):
|
||||
self._hlo = hlo
|
||||
self.compile_args = compile_args
|
||||
|
||||
def _compile_unloaded(self) -> Union[UnloadedPmapExecutable, PmapExecutable]:
|
||||
return UnloadedPmapExecutable.from_hlo(self._hlo, **self.compile_args)
|
||||
|
||||
# -- stages.XlaLowering overrides
|
||||
|
||||
def hlo(self) -> xc.XlaComputation:
|
||||
@ -1319,10 +1316,8 @@ class PmapComputation(stages.XlaLowering):
|
||||
@profiler.annotate_function
|
||||
def compile(self) -> PmapExecutable:
|
||||
if self._executable is None:
|
||||
executable = self._compile_unloaded()
|
||||
if isinstance(executable, UnloadedPmapExecutable):
|
||||
executable = executable.load()
|
||||
self._executable = executable
|
||||
self._executable = UnloadedPmapExecutable.from_hlo(
|
||||
self._hlo, **self.compile_args)
|
||||
return self._executable
|
||||
|
||||
|
||||
@ -1471,9 +1466,9 @@ class UnloadedPmapExecutable:
|
||||
ordered_effects=ordered_effects,
|
||||
keepalive=keepalive,
|
||||
host_callbacks=host_callbacks,
|
||||
)
|
||||
).load()
|
||||
|
||||
def load(self) -> PmapExecutable:
|
||||
def build_execute_fun(self):
|
||||
input_indices = []
|
||||
for aval, spec in safe_zip(self.local_input_avals, self.input_shardings):
|
||||
assert isinstance(spec, sharding_impls.PmapSharding), spec
|
||||
@ -1489,10 +1484,13 @@ class UnloadedPmapExecutable:
|
||||
self.ordered_effects, self.keepalive,
|
||||
bool(self.host_callbacks),
|
||||
set(range(len(input_indices))))
|
||||
return execute_fun
|
||||
|
||||
def load(self) -> PmapExecutable:
|
||||
fingerprint = getattr(self.compiled, "fingerprint", None)
|
||||
|
||||
return PmapExecutable(self.compiled, execute_fun, fingerprint,
|
||||
self.local_input_avals)
|
||||
return PmapExecutable(self.compiled, self.build_execute_fun, fingerprint,
|
||||
self.local_input_avals, self)
|
||||
|
||||
|
||||
def _compile_replicated_pmap_executable_from_hlo(
|
||||
@ -1507,17 +1505,27 @@ def _compile_replicated_pmap_executable_from_hlo(
|
||||
in_indices=input_indices, in_shardings=in_shardings,
|
||||
kept_var_idx=set(range(len(pci.avals))), out_handler=handle_outs)
|
||||
# TODO(frostig): need `compile_replicated` to give us the XLA executable
|
||||
return PmapExecutable(None, execute_fun, None, pci.avals)
|
||||
return PmapExecutable(None, lambda: execute_fun, None, pci.avals, None)
|
||||
|
||||
|
||||
class PmapExecutable(stages.XlaExecutable):
|
||||
__slots__ = ["xla_executable", "unsafe_call", "fingerprint", "in_avals"]
|
||||
__slots__ = ["xla_executable", "_unsafe_call", "build_unsafe_call",
|
||||
"fingerprint", "in_avals", "_unloaded_executable"]
|
||||
|
||||
def __init__(self, xla_executable, unsafe_call, fingerprint, in_avals):
|
||||
def __init__(self, xla_executable, build_unsafe_call, fingerprint,
|
||||
in_avals, unloaded_executable):
|
||||
self.xla_executable = xla_executable
|
||||
self.unsafe_call = unsafe_call
|
||||
self._unsafe_call = None
|
||||
self.build_unsafe_call = build_unsafe_call
|
||||
self.fingerprint = fingerprint
|
||||
self.in_avals = in_avals
|
||||
self._unloaded_executable = unloaded_executable
|
||||
|
||||
@property
|
||||
def unsafe_call(self) -> Callable[..., Any]:
|
||||
if self._unsafe_call is None:
|
||||
self._unsafe_call = self.build_unsafe_call()
|
||||
return self._unsafe_call
|
||||
|
||||
# -- stages.XlaExecutable overrides
|
||||
|
||||
@ -1529,7 +1537,7 @@ class PmapExecutable(stages.XlaExecutable):
|
||||
# TODO(frostig): do we need to check sharding and sharded avals?
|
||||
arg_avals = map(xla.abstractify, args)
|
||||
check_arg_avals_for_call(self.in_avals, arg_avals)
|
||||
return self.unsafe_call(*args)
|
||||
return self.unsafe_call(*args) # pylint: disable=not-callable
|
||||
|
||||
|
||||
def _get_pmap_sharding(devices, specs):
|
||||
@ -2804,21 +2812,6 @@ class MeshComputation(stages.XlaLowering):
|
||||
self.compile_args = compile_args
|
||||
self._executable = None
|
||||
|
||||
def _compile_unloaded(
|
||||
self,
|
||||
_allow_propagation_to_outputs: Optional[Sequence[bool]] = None,
|
||||
_allow_compile_replicated: bool = True
|
||||
) -> Union[UnloadedMeshExecutable, MeshExecutable]:
|
||||
if self.is_trivial:
|
||||
return MeshExecutable.from_trivial_jaxpr(**self.compile_args)
|
||||
else:
|
||||
return UnloadedMeshExecutable.from_hlo(
|
||||
self._name,
|
||||
self._hlo,
|
||||
**self.compile_args,
|
||||
_allow_propagation_to_outputs=_allow_propagation_to_outputs,
|
||||
_allow_compile_replicated=_allow_compile_replicated) # type: ignore
|
||||
|
||||
# -- stages.XlaLowering overrides
|
||||
|
||||
def hlo(self) -> xc.XlaComputation:
|
||||
@ -2841,11 +2834,16 @@ class MeshComputation(stages.XlaLowering):
|
||||
_allow_propagation_to_outputs: Optional[Sequence[bool]] = None,
|
||||
_allow_compile_replicated: bool = True) -> MeshExecutable:
|
||||
if self._executable is None:
|
||||
executable = self._compile_unloaded(
|
||||
_allow_propagation_to_outputs, _allow_compile_replicated)
|
||||
if isinstance(executable, UnloadedMeshExecutable):
|
||||
executable = executable.load()
|
||||
self._executable = executable
|
||||
if self.is_trivial:
|
||||
self._executable = MeshExecutable.from_trivial_jaxpr(
|
||||
**self.compile_args)
|
||||
else:
|
||||
self._executable = UnloadedMeshExecutable.from_hlo(
|
||||
self._name,
|
||||
self._hlo,
|
||||
**self.compile_args,
|
||||
_allow_propagation_to_outputs=_allow_propagation_to_outputs,
|
||||
_allow_compile_replicated=_allow_compile_replicated)
|
||||
return self._executable
|
||||
|
||||
def cost_analysis(self) -> Dict[str, float]:
|
||||
@ -2952,7 +2950,7 @@ class UnloadedMeshExecutable:
|
||||
kept_var_idx: Set[int]
|
||||
auto_spmd_lowering: bool
|
||||
|
||||
def load(self) -> MeshExecutable:
|
||||
def build_unsafe_call(self):
|
||||
input_indices = _get_input_indices(self.input_avals, self.input_shardings)
|
||||
handle_args = InputsHandler(self.xla_executable.local_devices(),
|
||||
self.input_shardings, input_indices)
|
||||
@ -2964,11 +2962,14 @@ class UnloadedMeshExecutable:
|
||||
self.xla_executable, self.name, self.backend, handle_args,
|
||||
handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
|
||||
bool(self.host_callbacks), self.kept_var_idx)
|
||||
return unsafe_call
|
||||
|
||||
return MeshExecutable(self.xla_executable, unsafe_call, self.input_avals,
|
||||
def load(self) -> MeshExecutable:
|
||||
return MeshExecutable(self.xla_executable, self.build_unsafe_call,
|
||||
self.input_avals,
|
||||
self.input_shardings, self.output_shardings,
|
||||
self.auto_spmd_lowering, self.kept_var_idx,
|
||||
self.device_assignment)
|
||||
self.device_assignment, self)
|
||||
|
||||
# May return a MeshExecutable in the compile_replicated case.
|
||||
@staticmethod
|
||||
@ -2996,7 +2997,7 @@ class UnloadedMeshExecutable:
|
||||
device_assignment: Sequence[xc.Device],
|
||||
committed: bool,
|
||||
pmap_nreps: int = 1
|
||||
) -> Union[MeshExecutable, UnloadedMeshExecutable]:
|
||||
) -> MeshExecutable:
|
||||
|
||||
dev: np.ndarray
|
||||
if auto_spmd_lowering:
|
||||
@ -3119,7 +3120,7 @@ class UnloadedMeshExecutable:
|
||||
keepalive=keepalive,
|
||||
host_callbacks=host_callbacks,
|
||||
kept_var_idx=kept_var_idx,
|
||||
auto_spmd_lowering=auto_spmd_lowering)
|
||||
auto_spmd_lowering=auto_spmd_lowering).load()
|
||||
|
||||
|
||||
class MeshExecutableFastpathData(NamedTuple):
|
||||
@ -3134,24 +3135,35 @@ class MeshExecutableFastpathData(NamedTuple):
|
||||
|
||||
class MeshExecutable(stages.XlaExecutable):
|
||||
__slots__ = [
|
||||
"xla_executable", "unsafe_call", "in_avals", "_in_shardings",
|
||||
"_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
|
||||
"_device_assignment"
|
||||
"xla_executable", "_unsafe_call",
|
||||
"build_unsafe_call", "in_avals",
|
||||
"_in_shardings", "_out_shardings",
|
||||
"_auto_spmd_lowering", "_kept_var_idx",
|
||||
"_device_assignment",
|
||||
"_unloaded_executable",
|
||||
]
|
||||
|
||||
def __init__(self, xla_executable, unsafe_call, in_avals, in_shardings,
|
||||
def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
|
||||
out_shardings, auto_spmd_lowering, kept_var_idx,
|
||||
device_assignment):
|
||||
device_assignment, unloaded_executable=None):
|
||||
self.xla_executable = xla_executable
|
||||
self.unsafe_call = unsafe_call
|
||||
self.build_unsafe_call = build_unsafe_call
|
||||
# in_avals is a list of global and local avals. Aval is global if input
|
||||
# is a GDA or jax.Array else local.
|
||||
self.in_avals = in_avals
|
||||
self._unsafe_call = None
|
||||
self._in_shardings = in_shardings
|
||||
self._out_shardings = out_shardings
|
||||
self._auto_spmd_lowering = auto_spmd_lowering
|
||||
self._kept_var_idx = kept_var_idx
|
||||
self._device_assignment = device_assignment
|
||||
self._unloaded_executable = unloaded_executable
|
||||
|
||||
@property
|
||||
def unsafe_call(self) -> Callable[..., Any]:
|
||||
if self._unsafe_call is None:
|
||||
self._unsafe_call = self.build_unsafe_call()
|
||||
return self._unsafe_call
|
||||
|
||||
@staticmethod
|
||||
def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals,
|
||||
@ -3174,8 +3186,9 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
[False] * len(global_out_avals))
|
||||
unsafe_call = partial(_execute_trivial, jaxpr, consts, handle_ins,
|
||||
handle_outs, kept_var_idx)
|
||||
return MeshExecutable(None, unsafe_call, global_in_avals, in_shardings,
|
||||
out_shardings, False, kept_var_idx, device_assignment)
|
||||
return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
|
||||
in_shardings, out_shardings, False, kept_var_idx,
|
||||
device_assignment, None)
|
||||
|
||||
# -- stages.XlaExecutable overrides
|
||||
|
||||
@ -3189,7 +3202,7 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
check_arg_avals_for_call(ref_avals, arg_avals)
|
||||
# Check the GDA sharding and the input sharding.
|
||||
check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings)
|
||||
return self.unsafe_call(*args)
|
||||
return self.unsafe_call(*args) # pylint: disable=not-callable
|
||||
|
||||
def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
|
||||
return self._in_shardings
|
||||
@ -3297,9 +3310,9 @@ def _compile_replicated_mesh_executable_from_hlo(
|
||||
out_avals=global_out_avals, out_shardings=out_shardings,
|
||||
committed=committed, pmap_nreps=pmap_nreps)
|
||||
xla_executable = None
|
||||
return MeshExecutable(xla_executable, unsafe_call, global_in_avals,
|
||||
return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
|
||||
in_shardings, out_shardings, auto_spmd_lowering,
|
||||
kept_var_idx, device_assignment)
|
||||
kept_var_idx, device_assignment, None)
|
||||
|
||||
|
||||
def _compile_replicated_mesh_executable_from_trivial_jaxpr(
|
||||
@ -3320,9 +3333,9 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
|
||||
in_indices=input_indices, in_shardings=in_shardings,
|
||||
kept_var_idx=kept_var_idx, out_handler=handle_outs,
|
||||
out_shardings=out_shardings, pmap_nreps=pmap_nreps)
|
||||
return MeshExecutable(None, unsafe_call, global_in_avals, in_shardings,
|
||||
out_shardings, False, kept_var_idx,
|
||||
device_assignment)
|
||||
return MeshExecutable(None, lambda: unsafe_call, global_in_avals,
|
||||
in_shardings, out_shardings, False, kept_var_idx,
|
||||
device_assignment, None)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
|
@ -21,46 +21,39 @@ import jax
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
||||
|
||||
def compile_and_serialize(lowered: jax.stages.Lowered):
|
||||
"""Compiles a lowered executable, and then serializes the resulting binary.
|
||||
def serialize(compiled: jax.stages.Compiled):
|
||||
"""Serializes a compiled binary.
|
||||
|
||||
Because pytrees are not serializable, they are returned so that
|
||||
the user can handle them properly.
|
||||
"""
|
||||
|
||||
from jax.interpreters import pxla
|
||||
|
||||
if isinstance(lowered._lowering, pxla.MeshComputation):
|
||||
kw = dict(_allow_propagation_to_outputs=[
|
||||
pxla._is_unspecified(o)
|
||||
for o in lowered._lowering.compile_args['out_shardings']])
|
||||
else:
|
||||
kw = {}
|
||||
|
||||
unloaded_compilation = lowered._lowering._compile_unloaded(**kw)
|
||||
args_info_flat, in_tree = jax.tree_util.tree_flatten(lowered.args_info)
|
||||
unloaded_executable = getattr(compiled._executable,
|
||||
'_unloaded_executable', None)
|
||||
if unloaded_executable is None:
|
||||
raise ValueError("Compilation does not support serialization")
|
||||
args_info_flat, in_tree = jax.tree_util.tree_flatten(compiled.args_info)
|
||||
|
||||
with io.BytesIO() as file:
|
||||
_JaxPjrtPickler(file).dump(
|
||||
(unloaded_compilation, args_info_flat, lowered._no_kwargs))
|
||||
return file.getvalue(), in_tree, lowered.out_tree
|
||||
(unloaded_executable, args_info_flat, compiled._no_kwargs))
|
||||
return file.getvalue(), in_tree, compiled.out_tree
|
||||
|
||||
|
||||
def load_compiled(serialized,
|
||||
in_tree,
|
||||
out_tree,
|
||||
backend: Optional[Union[str, xc.Client]] = None):
|
||||
def deserialize_and_load(serialized,
|
||||
in_tree,
|
||||
out_tree,
|
||||
backend: Optional[Union[str, xc.Client]] = None):
|
||||
"""Constructs a jax.stages.Compiled from a serialized executable."""
|
||||
|
||||
if backend is None or isinstance(backend, str):
|
||||
backend = jax.devices(backend)[0].client
|
||||
|
||||
(unloaded_compilation, args_info_flat,
|
||||
(unloaded_executable, args_info_flat,
|
||||
no_kwargs) = _JaxPjrtUnpickler(io.BytesIO(serialized), backend).load()
|
||||
|
||||
args_info = in_tree.unflatten(args_info_flat)
|
||||
|
||||
loaded_compiled_obj = unloaded_compilation.load()
|
||||
loaded_compiled_obj = unloaded_executable.load()
|
||||
|
||||
return jax.stages.Compiled(
|
||||
loaded_compiled_obj, args_info, out_tree, no_kwargs=no_kwargs)
|
||||
|
@ -32,7 +32,7 @@ from jax._src.util import safe_zip
|
||||
from jax.interpreters import pxla
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.experimental.serialize_executable import (
|
||||
compile_and_serialize, load_compiled)
|
||||
serialize, deserialize_and_load)
|
||||
from jax.experimental import multihost_utils
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jax._src import array
|
||||
@ -1033,8 +1033,8 @@ class RngShardingTest(jtu.JaxTestCase):
|
||||
).lower(core.ShapedArray(shape=(8, 8), dtype=np.float32))
|
||||
|
||||
def verify_serialization(lowered):
|
||||
serialized, in_tree, out_tree = compile_and_serialize(lowered)
|
||||
compiled = load_compiled(serialized, in_tree, out_tree)
|
||||
serialized, in_tree, out_tree = serialize(lowered.compile())
|
||||
compiled = deserialize_and_load(serialized, in_tree, out_tree)
|
||||
self.assertEqual(compiled.as_text(), lowered.compile().as_text())
|
||||
|
||||
verify_serialization(lowered)
|
||||
|
Loading…
x
Reference in New Issue
Block a user