Allow compiling and then serializing jax.stages.Lowered.

This adds experimental APIs to `serialize_executable.py`:

`compile_and_serialize(lowered)`
and
`load_compiled(serialized, in_tree, out_tree)`

for serializing and deserializing executables.

PiperOrigin-RevId: 489014705
This commit is contained in:
Parker Schuh 2022-11-16 12:53:32 -08:00 committed by jax authors
parent fb4db5b60f
commit da765a2e54
2 changed files with 142 additions and 0 deletions

View File

@ -0,0 +1,104 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pickling support for precompiled binaries."""
import pickle
import io
from typing import Optional, Union
import jax
from jax._src.lib import xla_client as xc
def compile_and_serialize(lowered: jax.stages.Lowered):
"""Compiles a lowered executable, and then serializes the resulting binary.
Because pytrees are not serializable, they are returned so that
the user can handle them properly.
"""
from jax.interpreters import pxla
if (jax.config.jax_array and
isinstance(lowered._lowering, pxla.MeshComputation) and all(
pxla._is_unspecified(o)
for o in lowered._lowering.compile_args['out_shardings'])):
kw = dict(_allow_propagation_to_outputs=True)
else:
kw = {}
unloaded_compilation = lowered._lowering._compile_unloaded(**kw)
args_info_flat, in_tree = jax.tree_util.tree_flatten(lowered.args_info)
with io.BytesIO() as file:
_JaxPjrtPickler(file).dump(
(unloaded_compilation, args_info_flat, lowered._no_kwargs))
return file.getvalue(), in_tree, lowered.out_tree
def load_compiled(serialized,
in_tree,
out_tree,
backend: Optional[Union[str, xc.Client]] = None):
"""Constructs a jax.stages.Compiled from a serialized executable."""
if backend is None or isinstance(backend, str):
backend = jax.devices(backend)[0].client
(unloaded_compilation, args_info_flat,
no_kwargs) = _JaxPjrtUnpickler(io.BytesIO(serialized), backend).load()
args_info = in_tree.unflatten(args_info_flat)
loaded_compiled_obj = unloaded_compilation.load()
return jax.stages.Compiled(
loaded_compiled_obj,
args_info,
out_tree,
no_kwargs=no_kwargs,
create_cpp_call=None)
class _JaxPjrtPickler(pickle.Pickler):
device_types = (xc.Device,)
client_types = (xc.Client,)
def persistent_id(self, obj):
if isinstance(obj, xc.LoadedExecutable):
return ('exec', obj.client.serialize_executable(obj),
obj.compile_options())
if isinstance(obj, xc._xla.Executable):
return ('exec', obj.serialize(), obj.compile_options())
if isinstance(obj, self.device_types):
return ('device', obj.id)
if isinstance(obj, self.client_types):
return ('client',)
class _JaxPjrtUnpickler(pickle.Unpickler):
def __init__(self, file, backend):
super().__init__(file)
self.backend = backend
self.devices_by_id = {d.id: d for d in backend.devices()}
def persistent_load(self, pid):
if pid[0] == 'exec':
return self.backend.deserialize_executable(pid[1], pid[2])
if pid[0] == 'device':
return self.devices_by_id[pid[1]]
if pid[0] == 'client':
return self.backend
raise pickle.UnpicklingError

View File

@ -14,6 +14,7 @@
"""Tests for GlobalDeviceArray."""
import os
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
@ -28,9 +29,12 @@ from jax._src.util import prod, safe_zip
from jax.interpreters import pxla
from jax.experimental.pjit import pjit
from jax.experimental import PartitionSpec as P
from jax.experimental.serialize_executable import (
compile_and_serialize, load_compiled)
from jax._src import sharding
from jax._src import array
from jax._src import prng
from jax._src.lib import xla_extension_version
from jax.experimental import maps
from jax.config import config
@ -899,5 +903,39 @@ class RngShardingTest(jtu.JaxTestCase):
y_ref2 = f(jax.device_put(x, jax.devices()[0]))
self.assertArraysEqual(y, y_ref2)
@unittest.skipIf(xla_extension_version < 106,
'Pjit pickling requires newer jaxlib.')
def test_pickle_pjit_lower(self):
example_exe = jax.jit(lambda x: x * x).lower(
jax.core.ShapedArray(
(2, 2), dtype=np.float32)).compile()._executable.xla_executable
# Skip if CompileOptions is not available. This is true on
# CPU/GPU/Cloud TPU for now.
try:
example_exe.compile_options()
except Exception as e:
if str(e) == 'UNIMPLEMENTED: CompileOptions not available.':
raise unittest.SkipTest('Serialization not supported')
raise e
def fun(x):
return x * x
with maps.Mesh(np.array(jax.devices()), ('data',)):
lowered = pjit(
fun,
in_axis_resources=P('data'),
out_axis_resources=P(None, 'data'),
).lower(jax.ShapedArray(shape=(8, 8), dtype=np.float32))
def verify_serialization(lowered):
serialized, in_tree, out_tree = compile_and_serialize(lowered)
compiled = load_compiled(serialized, in_tree, out_tree)
self.assertEqual(compiled.as_text(), lowered.compile().as_text())
verify_serialization(lowered)
verify_serialization(jax.jit(lambda x: x * x).lower(np.arange(100)))
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())