[export] Refactor the imports for the public API of jax.experimental.export

Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialized(ser)
export.call(exp1)
```

This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.

In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.

PiperOrigin-RevId: 596563481
This commit is contained in:
George Necula 2024-01-08 05:29:11 -08:00 committed by jax authors
parent ed2a839884
commit 69788d18b6
14 changed files with 117 additions and 61 deletions

View File

@ -9,6 +9,7 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.24
* Changes
* JAX lowering to StableHLO does not depend on physical devices anymore.
If your primitive wraps custom_paritioning or JAX callbacks in the lowering
rule i.e. function passed to `rule` parameter of `mlir.register_lowering` then add your
@ -27,6 +28,10 @@ Remember to align the itemized text with the first line of an item within a list
({jax-issue}`#19231`; note that this may result in user-visible behavior
changes); improved the error messages for inconclusive inequality comparisons
({jax-issue}`#19235`).
* Refactored the API for `jax.experimental.export`. Instead of
`from jax.experimental.export import export` you should use now
`from jax.experimental import export`. The old way of importing will
continue to work for a deprecation period of 3 months.
* Deprecations & Removals
* A number of previously deprecated functions have been removed, following a
standard 3+ month deprecation cycle (see {ref}`api-compatibility`).

View File

@ -83,7 +83,7 @@ from numpy import array, float32
import jax
from jax import tree_util
from jax.experimental.export import export
from jax.experimental import export
from jax.experimental import pjit

View File

@ -22,7 +22,6 @@ load("@rules_python//python:defs.bzl", "py_library")
licenses(["notice"])
# Please add new users to :australis_users.
package(
default_applicable_licenses = [],
default_visibility = ["//visibility:private"],
@ -31,7 +30,8 @@ package(
py_library(
name = "export",
srcs = [
"export.py",
"__init__.py",
"_export.py",
"serialization.py",
"serialization_generated.py",
"shape_poly.py",

View File

@ -12,3 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from jax.experimental.export._export import (
minimum_supported_serialization_version,
maximum_supported_serialization_version,
Exported,
export,
call_exported, # TODO: deprecate
call,
DisabledSafetyCheck,
default_lowering_platform,
symbolic_shape,
args_specs,
)
from jax.experimental.export.serialization import (
serialize,
deserialize,
)

View File

@ -24,6 +24,7 @@ import functools
import itertools
import re
from typing import Any, Callable, Optional, TypeVar, Union
import warnings
from absl import logging
import numpy as np
@ -57,6 +58,20 @@ zip = util.safe_zip
DType = Any
Shape = jax._src.core.Shape
# The values of input and output sharding from the lowering.
LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]
# None means unspecified sharding
Sharding = Union[xla_client.HloSharding, None]
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
# for a description of the different versions.
minimum_supported_serialization_version = 6
maximum_supported_serialization_version = 9
_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7
_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9
class DisabledSafetyCheck:
"""A safety check should be skipped on (de)serialization.
@ -117,19 +132,6 @@ class DisabledSafetyCheck:
def __hash__(self) -> int:
return hash(self._impl)
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
# for a description of the different versions.
minimum_supported_serialization_version = 6
maximum_supported_serialization_version = 9
_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7
_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9
# The values of input and output sharding from the lowering.
LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]
# None means unspecified sharding
Sharding = Union[xla_client.HloSharding, None]
@dataclasses.dataclass(frozen=True)
class Exported:
@ -1052,7 +1054,7 @@ def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
### Calling the exported function
def call_exported(exported: Exported) -> Callable[..., jax.Array]:
def call(exported: Exported) -> Callable[..., jax.Array]:
if not isinstance(exported, Exported):
raise ValueError(
"The exported argument must be an export.Exported. "
@ -1107,6 +1109,7 @@ def call_exported(exported: Exported) -> Callable[..., jax.Array]:
return exported.out_tree.unflatten(res_flat)
return f_imported
call_exported = call
# A JAX primitive for invoking a serialized JAX function.
call_exported_p = core.Primitive("call_exported")
@ -1307,3 +1310,30 @@ def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
return x
return mlir.wrap_with_sharding_op(
ctx, x, x_aval, x_sharding.to_proto())
# TODO(necula): Previously, we had `from jax.experimental.export import export`
# Now we want to simplify the usage, and export the public APIs directly
# from `jax.experimental.export` and now `jax.experimental.export.export`
# refers to the `export` function. Since there may still be users of the
# old API in other packages, we add the old public API as attributes of the
# exported function. We will clean this up after a deprecation period.
def wrap_with_deprecation_warning(f):
msg = (f"You are using function `{f.__name__}` from "
"`jax.experimental.export.export`. You should instead use it directly "
"from `jax.experimental.export`. Instead of "
"`from jax.experimental.export import export` you should use "
"`from jax.experimental import export`.")
def wrapped_f(*args, **kwargs):
warnings.warn(msg, DeprecationWarning)
return f(*args, **kwargs)
return wrapped_f
export.export = wrap_with_deprecation_warning(export)
export.Exported = Exported
export.call_exported = wrap_with_deprecation_warning(call_exported)
export.DisabledSafetyCheck = DisabledSafetyCheck
export.default_lowering_platform = wrap_with_deprecation_warning(default_lowering_platform)
export.symbolic_shape = wrap_with_deprecation_warning(symbolic_shape)
export.args_specs = wrap_with_deprecation_warning(args_specs)
export.minimum_supported_serialization_version = minimum_supported_serialization_version
export.maximum_supported_serialization_version = maximum_supported_serialization_version

View File

@ -29,8 +29,10 @@ from jax._src import dtypes
from jax._src import effects
from jax._src import tree_util
from jax._src.lib import xla_client
from jax.experimental.export import export
from jax.experimental.export import serialization_generated as ser_flatbuf
from jax.experimental.export import _export
from jax.experimental import export
import numpy as np
T = TypeVar("T")
@ -359,7 +361,7 @@ def _deserialize_aval(aval: ser_flatbuf.AbstractValue) -> core.AbstractValue:
def _serialize_sharding(
builder: flatbuffers.Builder, s: export.Sharding
builder: flatbuffers.Builder, s: _export.Sharding
) -> int:
proto = None
if s is None:
@ -376,7 +378,7 @@ def _serialize_sharding(
return ser_flatbuf.ShardingEnd(builder)
def _deserialize_sharding(s: ser_flatbuf.Sharding) -> export.Sharding:
def _deserialize_sharding(s: ser_flatbuf.Sharding) -> _export.Sharding:
kind = s.Kind()
if kind == ser_flatbuf.ShardingKind.unspecified:
return None

View File

@ -38,7 +38,8 @@ from jax import tree_util
from jax import sharding
from jax.experimental import maps
from jax.experimental.export import shape_poly
from jax.experimental.export import export
from jax.experimental.export import _export
from jax.experimental import export
from jax.experimental.jax2tf import impl_no_xla
from jax.interpreters import xla
@ -515,7 +516,7 @@ class NativeSerializationImpl(SerializationImpl):
def get_vjp_fun(self) -> tuple[Callable,
Sequence[core.AbstractValue]]:
return export._get_vjp_fun(self.fun_jax,
return _export._get_vjp_fun(self.fun_jax,
in_tree=self.exported.in_tree,
in_avals=self.exported.in_avals,
in_shardings=self.exported.in_shardings,
@ -587,7 +588,7 @@ class GraphSerializationImpl(SerializationImpl):
# We reuse the code for native serialization to get the VJP functions,
# except we use unspecified shardings, and we do not apply a jit on the
# VJP. This matches the older behavior of jax2tf for graph serialization.
return export._get_vjp_fun(self.fun_jax,
return _export._get_vjp_fun(self.fun_jax,
in_tree=self.in_tree,
in_avals=self.args_avals_flat,
in_shardings=(None,) * len(self.args_avals_flat),

View File

@ -29,7 +29,7 @@ from jax._src import test_util as jtu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax.experimental import jax2tf
from jax.experimental.export import export
from jax.experimental import export
from jax.experimental.jax2tf.tests import tf_test_util
import numpy as np

View File

@ -37,9 +37,8 @@ from jax._src import core
from jax._src import source_info_util
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.interpreters import mlir
from jax.experimental import jax2tf
from jax.experimental.export import export
from jax.experimental import export
from jax.experimental.jax2tf.tests import tf_test_util
from jax.experimental.maps import xmap
from jax.experimental.shard_map import shard_map

View File

@ -31,7 +31,7 @@ from jax._src import test_util as jtu
from jax import tree_util
from jax.experimental import jax2tf
from jax.experimental.export import export
from jax.experimental import export
from jax._src import config
from jax._src import xla_bridge
import numpy as np

View File

@ -27,7 +27,8 @@ import numpy as np
import jax
from jax import lax
from jax.experimental.export import export
from jax.experimental import export
from jax.experimental.export import _export
from jax._src.internal_test_util import export_back_compat_test_util as bctu
from jax._src.internal_test_util.export_back_compat_test_data import cpu_ducc_fft
@ -97,7 +98,7 @@ class CompatTest(bctu.CompatTestBase):
def test_custom_call_coverage(self):
"""Tests that the back compat tests cover all the targets declared stable."""
targets_to_cover = set(export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
targets_to_cover = set(_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
# Add here all the testdatas that should cover the targets guaranteed
# stable
covering_testdatas = [

View File

@ -33,7 +33,7 @@ import numpy as np
import jax
from jax import lax
from jax._src import test_util as jtu
from jax.experimental.export import export
from jax.experimental import export
from jax._src.internal_test_util import test_harnesses

View File

@ -26,10 +26,10 @@ import jax
from jax import lax
from jax import numpy as jnp
from jax import tree_util
from jax.experimental.export import export
from jax.experimental.export import serialization
from jax.experimental.shard_map import shard_map
from jax.experimental import export
from jax.experimental.export import _export
from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
@ -146,8 +146,8 @@ def get_exported(fun, vjp_order=0,
"""Like export.export but with serialization + deserialization."""
def serde_exported(*fun_args, **fun_kwargs):
exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs)
serialized = serialization.serialize(exp, vjp_order=vjp_order)
return serialization.deserialize(serialized)
serialized = export.serialize(exp, vjp_order=vjp_order)
return export.deserialize(serialized)
return serde_exported
class JaxExportTest(jtu.JaxTestCase):
@ -1139,7 +1139,7 @@ class JaxExportTest(jtu.JaxTestCase):
)
exp = get_exported(f_jax)(x)
if exp.mlir_module_serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version >= _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
sorted(str(e) for e in exp.ordered_effects))
self.assertEqual(["ForTestingUnorderedEffect1()"],
@ -1169,11 +1169,11 @@ class JaxExportTest(jtu.JaxTestCase):
# Results
r"!stablehlo.token .*jax.token = true.*"
r"!stablehlo.token .*jax.token = true.*")
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
# The main function does not have tokens
self.assertNotRegex(mlir_module_str, r"@main.*token")
else:
@ -1191,7 +1191,7 @@ class JaxExportTest(jtu.JaxTestCase):
export.call_exported(exp)(x))
lowered_outer = jax.jit(f_outer).lower(x)
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
self.assertEqual(["ForTestingOrderedEffect2()"],
[str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]])
else:
@ -1201,7 +1201,7 @@ class JaxExportTest(jtu.JaxTestCase):
sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]]))
mlir_outer_module_str = str(lowered_outer.compiler_ir())
if exp.mlir_module_serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version >= _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
main_expected_re = main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_outer_module_str, main_expected_re)
@ -1229,11 +1229,11 @@ class JaxExportTest(jtu.JaxTestCase):
r"%arg3: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
# The main function does not have tokens
self.assertNotRegex(mlir_module_str, r"@main.*token")
else:
@ -1276,11 +1276,11 @@ class JaxExportTest(jtu.JaxTestCase):
r"%arg4: tensor<\?x\?xf32>.*\) -> \("
# Results
r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
# The main function does not have tokens
self.assertNotRegex(mlir_module_str, r"@main.*token")
else:
@ -1313,7 +1313,7 @@ class JaxExportTest(jtu.JaxTestCase):
f_jax = jax.jit(f_jax, donate_argnums=(0,))
exp = export.export(f_jax)(x)
mlir_module_str = str(exp.mlir_module())
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 0")
self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1")
else:

View File

@ -33,7 +33,7 @@ import operator as op
import re
import jax
from jax.experimental.export import export
from jax.experimental import export
from jax.experimental.export import shape_poly
from jax.experimental import pjit
from jax import lax