mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Allow compiling and then serializing jax.stages.Lowered.
This adds experimental APIs to `serialize_executable.py`: `compile_and_serialize(lowered)` and `load_compiled(serialized, in_tree, out_tree)` for serializing and deserializing executables. PiperOrigin-RevId: 489014705
This commit is contained in:
parent
fb4db5b60f
commit
da765a2e54
104
jax/experimental/serialize_executable.py
Normal file
104
jax/experimental/serialize_executable.py
Normal file
@ -0,0 +1,104 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Pickling support for precompiled binaries."""
|
||||
|
||||
import pickle
|
||||
import io
|
||||
from typing import Optional, Union
|
||||
|
||||
import jax
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
||||
|
||||
def compile_and_serialize(lowered: jax.stages.Lowered):
|
||||
"""Compiles a lowered executable, and then serializes the resulting binary.
|
||||
|
||||
Because pytrees are not serializable, they are returned so that
|
||||
the user can handle them properly.
|
||||
"""
|
||||
|
||||
from jax.interpreters import pxla
|
||||
|
||||
if (jax.config.jax_array and
|
||||
isinstance(lowered._lowering, pxla.MeshComputation) and all(
|
||||
pxla._is_unspecified(o)
|
||||
for o in lowered._lowering.compile_args['out_shardings'])):
|
||||
kw = dict(_allow_propagation_to_outputs=True)
|
||||
else:
|
||||
kw = {}
|
||||
|
||||
unloaded_compilation = lowered._lowering._compile_unloaded(**kw)
|
||||
args_info_flat, in_tree = jax.tree_util.tree_flatten(lowered.args_info)
|
||||
|
||||
with io.BytesIO() as file:
|
||||
_JaxPjrtPickler(file).dump(
|
||||
(unloaded_compilation, args_info_flat, lowered._no_kwargs))
|
||||
return file.getvalue(), in_tree, lowered.out_tree
|
||||
|
||||
|
||||
def load_compiled(serialized,
|
||||
in_tree,
|
||||
out_tree,
|
||||
backend: Optional[Union[str, xc.Client]] = None):
|
||||
"""Constructs a jax.stages.Compiled from a serialized executable."""
|
||||
|
||||
if backend is None or isinstance(backend, str):
|
||||
backend = jax.devices(backend)[0].client
|
||||
|
||||
(unloaded_compilation, args_info_flat,
|
||||
no_kwargs) = _JaxPjrtUnpickler(io.BytesIO(serialized), backend).load()
|
||||
|
||||
args_info = in_tree.unflatten(args_info_flat)
|
||||
|
||||
loaded_compiled_obj = unloaded_compilation.load()
|
||||
|
||||
return jax.stages.Compiled(
|
||||
loaded_compiled_obj,
|
||||
args_info,
|
||||
out_tree,
|
||||
no_kwargs=no_kwargs,
|
||||
create_cpp_call=None)
|
||||
|
||||
|
||||
class _JaxPjrtPickler(pickle.Pickler):
|
||||
device_types = (xc.Device,)
|
||||
client_types = (xc.Client,)
|
||||
|
||||
def persistent_id(self, obj):
|
||||
if isinstance(obj, xc.LoadedExecutable):
|
||||
return ('exec', obj.client.serialize_executable(obj),
|
||||
obj.compile_options())
|
||||
if isinstance(obj, xc._xla.Executable):
|
||||
return ('exec', obj.serialize(), obj.compile_options())
|
||||
if isinstance(obj, self.device_types):
|
||||
return ('device', obj.id)
|
||||
if isinstance(obj, self.client_types):
|
||||
return ('client',)
|
||||
|
||||
|
||||
class _JaxPjrtUnpickler(pickle.Unpickler):
|
||||
|
||||
def __init__(self, file, backend):
|
||||
super().__init__(file)
|
||||
self.backend = backend
|
||||
self.devices_by_id = {d.id: d for d in backend.devices()}
|
||||
|
||||
def persistent_load(self, pid):
|
||||
if pid[0] == 'exec':
|
||||
return self.backend.deserialize_executable(pid[1], pid[2])
|
||||
if pid[0] == 'device':
|
||||
return self.devices_by_id[pid[1]]
|
||||
if pid[0] == 'client':
|
||||
return self.backend
|
||||
raise pickle.UnpicklingError
|
@ -14,6 +14,7 @@
|
||||
"""Tests for GlobalDeviceArray."""
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
@ -28,9 +29,12 @@ from jax._src.util import prod, safe_zip
|
||||
from jax.interpreters import pxla
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.experimental import PartitionSpec as P
|
||||
from jax.experimental.serialize_executable import (
|
||||
compile_and_serialize, load_compiled)
|
||||
from jax._src import sharding
|
||||
from jax._src import array
|
||||
from jax._src import prng
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax.experimental import maps
|
||||
|
||||
from jax.config import config
|
||||
@ -899,5 +903,39 @@ class RngShardingTest(jtu.JaxTestCase):
|
||||
y_ref2 = f(jax.device_put(x, jax.devices()[0]))
|
||||
self.assertArraysEqual(y, y_ref2)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 106,
|
||||
'Pjit pickling requires newer jaxlib.')
|
||||
def test_pickle_pjit_lower(self):
|
||||
example_exe = jax.jit(lambda x: x * x).lower(
|
||||
jax.core.ShapedArray(
|
||||
(2, 2), dtype=np.float32)).compile()._executable.xla_executable
|
||||
|
||||
# Skip if CompileOptions is not available. This is true on
|
||||
# CPU/GPU/Cloud TPU for now.
|
||||
try:
|
||||
example_exe.compile_options()
|
||||
except Exception as e:
|
||||
if str(e) == 'UNIMPLEMENTED: CompileOptions not available.':
|
||||
raise unittest.SkipTest('Serialization not supported')
|
||||
raise e
|
||||
|
||||
def fun(x):
|
||||
return x * x
|
||||
|
||||
with maps.Mesh(np.array(jax.devices()), ('data',)):
|
||||
lowered = pjit(
|
||||
fun,
|
||||
in_axis_resources=P('data'),
|
||||
out_axis_resources=P(None, 'data'),
|
||||
).lower(jax.ShapedArray(shape=(8, 8), dtype=np.float32))
|
||||
|
||||
def verify_serialization(lowered):
|
||||
serialized, in_tree, out_tree = compile_and_serialize(lowered)
|
||||
compiled = load_compiled(serialized, in_tree, out_tree)
|
||||
self.assertEqual(compiled.as_text(), lowered.compile().as_text())
|
||||
|
||||
verify_serialization(lowered)
|
||||
verify_serialization(jax.jit(lambda x: x * x).lower(np.arange(100)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user