Merge pull request #18980 from gnecula:export_api

PiperOrigin-RevId: 591172917
This commit is contained in:
jax authors 2023-12-15 00:52:19 -08:00
commit 91faddd023
12 changed files with 113 additions and 58 deletions

View File

@ -9,6 +9,7 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.24
* Changes
* JAX lowering to StableHLO does not depend on physical devices anymore.
If your primitive wraps custom_paritioning or JAX callbacks in the lowering
rule i.e. function passed to `rule` parameter of `mlir.register_lowering` then add your
@ -17,6 +18,9 @@ Remember to align the itemized text with the first line of an item within a list
devices to create `Sharding`s during lowering.
This is a temporary state until we can create `Sharding`s without physical
devices.
* Refactored the API for `jax.experimental.export`. Instead of
`from jax.experimental.export import export` you should use now
`from jax.experimental import export`.
## jaxlib 0.4.24

View File

@ -83,7 +83,7 @@ from numpy import array, float32
import jax
from jax import tree_util
from jax.experimental.export import export
from jax.experimental import export
from jax.experimental import pjit

View File

@ -22,7 +22,6 @@ load("@rules_python//python:defs.bzl", "py_library")
licenses(["notice"])
# Please add new users to :australis_users.
package(
default_applicable_licenses = [],
default_visibility = ["//visibility:private"],
@ -31,7 +30,8 @@ package(
py_library(
name = "export",
srcs = [
"export.py",
"__init__.py",
"_export.py",
"serialization.py",
"serialization_generated.py",
"shape_poly.py",

View File

@ -12,3 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from jax.experimental.export._export import (
minimum_supported_serialization_version,
maximum_supported_serialization_version,
Exported,
export,
call_exported, # TODO: deprecate
call,
DisabledSafetyCheck,
default_lowering_platform,
symbolic_shape,
args_specs,
)
from jax.experimental.export.serialization import (
serialize,
deserialize,
)

View File

@ -24,6 +24,7 @@ import functools
import itertools
import re
from typing import Any, Callable, Optional, TypeVar, Union
import warnings
from absl import logging
import numpy as np
@ -55,6 +56,20 @@ zip = util.safe_zip
DType = Any
Shape = jax._src.core.Shape
# The values of input and output sharding from the lowering.
LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]
# None means unspecified sharding
Sharding = Union[xla_client.HloSharding, None]
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
# for a description of the different versions.
minimum_supported_serialization_version = 6
maximum_supported_serialization_version = 9
_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7
_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9
class DisabledSafetyCheck:
"""A safety check should be skipped on (de)serialization.
@ -115,19 +130,6 @@ class DisabledSafetyCheck:
def __hash__(self) -> int:
return hash(self._impl)
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
# for a description of the different versions.
minimum_supported_serialization_version = 6
maximum_supported_serialization_version = 9
_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7
_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9
# The values of input and output sharding from the lowering.
LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]
# None means unspecified sharding
Sharding = Union[xla_client.HloSharding, None]
@dataclasses.dataclass(frozen=True)
class Exported:
@ -1050,7 +1052,7 @@ def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
### Calling the exported function
def call_exported(exported: Exported) -> Callable[..., jax.Array]:
def call(exported: Exported) -> Callable[..., jax.Array]:
if not isinstance(exported, Exported):
raise ValueError(
"The exported argument must be an export.Exported. "
@ -1096,6 +1098,7 @@ def call_exported(exported: Exported) -> Callable[..., jax.Array]:
return exported.out_tree.unflatten(res_flat)
return f_imported
call_exported = call
# A JAX primitive for invoking a serialized JAX function.
call_exported_p = core.Primitive("call_exported")
@ -1296,3 +1299,30 @@ def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
return x
return mlir.wrap_with_sharding_op(
ctx, x, x_aval, x_sharding.to_proto())
# TODO(necula): Previously, we had `from jax.experimental.export import export`
# Now we want to simplify the usage, and export the public APIs directly
# from `jax.experimental.export` and now `jax.experimental.export.export`
# refers to the `export` function. Since there may still be users of the
# old API in other packages, we add the old public API as attributes of the
# exported function. We will clean this up after a deprecation period.
def wrap_with_deprecation_warning(f):
msg = (f"You are using function `{f.__name__}` from "
"`jax.experimental.export.export`. You should instead use it directly "
"from `jax.experimental.export`. Instead of "
"`from jax.experimental.export import export` you should use "
"`from jax.experimental import export`.")
def wrapped_f(*args, **kwargs):
warnings.warn(msg, DeprecationWarning)
return f(*args, **kwargs)
return wrapped_f
export.export = wrap_with_deprecation_warning(export)
export.Exported = Exported
export.call_exported = wrap_with_deprecation_warning(call_exported)
export.DisabledSafetyCheck = DisabledSafetyCheck
export.default_lowering_platform = wrap_with_deprecation_warning(default_lowering_platform)
export.symbolic_shape = wrap_with_deprecation_warning(symbolic_shape)
export.args_specs = wrap_with_deprecation_warning(args_specs)
export.minimum_supported_serialization_version = minimum_supported_serialization_version
export.maximum_supported_serialization_version = maximum_supported_serialization_version

View File

@ -29,8 +29,10 @@ from jax._src import dtypes
from jax._src import effects
from jax._src import tree_util
from jax._src.lib import xla_client
from jax.experimental.export import export
from jax.experimental.export import serialization_generated as ser_flatbuf
from jax.experimental.export import _export
from jax.experimental import export
import numpy as np
T = TypeVar("T")
@ -353,7 +355,7 @@ def _deserialize_aval(aval: ser_flatbuf.AbstractValue) -> core.AbstractValue:
def _serialize_sharding(
builder: flatbuffers.Builder, s: export.Sharding
builder: flatbuffers.Builder, s: _export.Sharding
) -> int:
proto = None
if s is None:
@ -370,7 +372,7 @@ def _serialize_sharding(
return ser_flatbuf.ShardingEnd(builder)
def _deserialize_sharding(s: ser_flatbuf.Sharding) -> export.Sharding:
def _deserialize_sharding(s: ser_flatbuf.Sharding) -> _export.Sharding:
kind = s.Kind()
if kind == ser_flatbuf.ShardingKind.unspecified:
return None

View File

@ -38,7 +38,8 @@ from jax import tree_util
from jax import sharding
from jax.experimental import maps
from jax.experimental.export import shape_poly
from jax.experimental.export import export
from jax.experimental.export import _export
from jax.experimental import export
from jax.experimental.jax2tf import impl_no_xla
from jax.interpreters import xla
@ -514,14 +515,14 @@ class NativeSerializationImpl(SerializationImpl):
def get_vjp_fun(self) -> tuple[Callable,
Sequence[core.AbstractValue]]:
return export._get_vjp_fun(self.fun_jax,
in_tree=self.exported.in_tree,
in_avals=self.exported.in_avals,
in_shardings=self.exported.in_shardings,
out_avals=self.exported.out_avals,
out_shardings=self.exported.out_shardings,
nr_devices=self.exported.nr_devices,
apply_jit=True)
return _export._get_vjp_fun(self.fun_jax,
in_tree=self.exported.in_tree,
in_avals=self.exported.in_avals,
in_shardings=self.exported.in_shardings,
out_avals=self.exported.out_avals,
out_shardings=self.exported.out_shardings,
nr_devices=self.exported.nr_devices,
apply_jit=True)
class GraphSerializationImpl(SerializationImpl):
def __init__(self, fun_jax, *,
@ -586,14 +587,14 @@ class GraphSerializationImpl(SerializationImpl):
# We reuse the code for native serialization to get the VJP functions,
# except we use unspecified shardings, and we do not apply a jit on the
# VJP. This matches the older behavior of jax2tf for graph serialization.
return export._get_vjp_fun(self.fun_jax,
in_tree=self.in_tree,
in_avals=self.args_avals_flat,
in_shardings=(None,) * len(self.args_avals_flat),
out_avals=self.outs_avals,
out_shardings=(None,) * len(self.outs_avals),
nr_devices=1, # Does not matter for unspecified shardings
apply_jit=False)
return _export._get_vjp_fun(self.fun_jax,
in_tree=self.in_tree,
in_avals=self.args_avals_flat,
in_shardings=(None,) * len(self.args_avals_flat),
out_avals=self.outs_avals,
out_shardings=(None,) * len(self.outs_avals),
nr_devices=1, # Does not matter for unspecified shardings
apply_jit=False)
def dtype_of_val(val: TfVal) -> DType:

View File

@ -27,7 +27,8 @@ import numpy as np
import jax
from jax import lax
from jax.experimental.export import export
from jax.experimental import export
from jax.experimental.export import _export
from jax._src.internal_test_util import export_back_compat_test_util as bctu
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_ducc_fft
@ -97,7 +98,7 @@ class CompatTest(bctu.CompatTestBase):
def test_custom_call_coverage(self):
"""Tests that the back compat tests cover all the targets declared stable."""
targets_to_cover = set(export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
targets_to_cover = set(_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
# Add here all the testdatas that should cover the targets guaranteed
# stable
covering_testdatas = [

View File

@ -29,7 +29,7 @@ from jax._src import test_util as jtu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax.experimental import jax2tf
from jax.experimental.export import export
from jax.experimental import export
from jax.experimental.jax2tf.tests import tf_test_util
import numpy as np

View File

@ -37,9 +37,8 @@ from jax._src import core
from jax._src import source_info_util
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.interpreters import mlir
from jax.experimental import jax2tf
from jax.experimental.export import export
from jax.experimental import export
from jax.experimental.jax2tf.tests import tf_test_util
from jax.experimental.maps import xmap
from jax.experimental.shard_map import shard_map

View File

@ -31,7 +31,7 @@ from jax._src import test_util as jtu
from jax import tree_util
from jax.experimental import jax2tf
from jax.experimental.export import export
from jax.experimental import export
from jax._src import config
from jax._src import xla_bridge
import numpy as np

View File

@ -26,8 +26,8 @@ import jax
from jax import lax
from jax import numpy as jnp
from jax import tree_util
from jax.experimental.export import export
from jax.experimental.export import serialization
from jax.experimental import export
from jax.experimental.export import _export
from jax.experimental import pjit
from jax.sharding import NamedSharding
from jax.sharding import Mesh
@ -145,8 +145,8 @@ def get_exported(fun, max_vjp_orders=0,
"""Like export.export but with serialization + deserialization."""
def serde_exported(*fun_args, **fun_kwargs):
exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs)
serialized = serialization.serialize(exp, vjp_order=max_vjp_orders)
return serialization.deserialize(serialized)
serialized = export.serialize(exp, vjp_order=max_vjp_orders)
return export.deserialize(serialized)
return serde_exported
class JaxExportTest(jtu.JaxTestCase):
@ -704,7 +704,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.array([True, False, True, False], dtype=np.bool_)
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
x.dtype))
x.dtype))
res = export.call_exported(exp)(x)
self.assertAllClose(f_jax(x), res)
@ -1042,7 +1042,7 @@ class JaxExportTest(jtu.JaxTestCase):
)
exp = get_exported(f_jax)(x)
if exp.mlir_module_serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version >= _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
sorted(str(e) for e in exp.ordered_effects))
self.assertEqual(["ForTestingUnorderedEffect1()"],
@ -1072,11 +1072,11 @@ class JaxExportTest(jtu.JaxTestCase):
# Results
r"!stablehlo.token .*jax.token = true.*"
r"!stablehlo.token .*jax.token = true.*")
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
# The main function does not have tokens
self.assertNotRegex(mlir_module_str, r"@main.*token")
else:
@ -1094,7 +1094,7 @@ class JaxExportTest(jtu.JaxTestCase):
export.call_exported(exp)(x))
lowered_outer = jax.jit(f_outer).lower(x)
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
self.assertEqual(["ForTestingOrderedEffect2()"],
[str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]])
else:
@ -1104,7 +1104,7 @@ class JaxExportTest(jtu.JaxTestCase):
sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]]))
mlir_outer_module_str = str(lowered_outer.compiler_ir())
if exp.mlir_module_serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version >= _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
main_expected_re = main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_outer_module_str, main_expected_re)
@ -1132,11 +1132,11 @@ class JaxExportTest(jtu.JaxTestCase):
r"%arg3: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
# The main function does not have tokens
self.assertNotRegex(mlir_module_str, r"@main.*token")
else:
@ -1179,11 +1179,11 @@ class JaxExportTest(jtu.JaxTestCase):
r"%arg4: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
# The main function does not have tokens
self.assertNotRegex(mlir_module_str, r"@main.*token")
else:
@ -1216,7 +1216,7 @@ class JaxExportTest(jtu.JaxTestCase):
f_jax = jax.jit(f_jax, donate_argnums=(0,))
exp = export.export(f_jax)(x)
mlir_module_str = str(exp.mlir_module())
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 0")
self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1")
else: