[export] Refactor the imports for the public API of jax.experimental.export

Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialized(ser)
export.call(exp1)
```

This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.

In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.

PiperOrigin-RevId: 596563481
This commit is contained in:
George Necula 2024-01-08 05:29:11 -08:00 committed by jax authors
parent ed2a839884
commit 69788d18b6
14 changed files with 117 additions and 61 deletions

View File

@ -9,6 +9,7 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.24 ## jax 0.4.24
* Changes * Changes
* JAX lowering to StableHLO does not depend on physical devices anymore. * JAX lowering to StableHLO does not depend on physical devices anymore.
If your primitive wraps custom_paritioning or JAX callbacks in the lowering If your primitive wraps custom_paritioning or JAX callbacks in the lowering
rule i.e. function passed to `rule` parameter of `mlir.register_lowering` then add your rule i.e. function passed to `rule` parameter of `mlir.register_lowering` then add your
@ -27,6 +28,10 @@ Remember to align the itemized text with the first line of an item within a list
({jax-issue}`#19231`; note that this may result in user-visible behavior ({jax-issue}`#19231`; note that this may result in user-visible behavior
changes); improved the error messages for inconclusive inequality comparisons changes); improved the error messages for inconclusive inequality comparisons
({jax-issue}`#19235`). ({jax-issue}`#19235`).
* Refactored the API for `jax.experimental.export`. Instead of
`from jax.experimental.export import export` you should use now
`from jax.experimental import export`. The old way of importing will
continue to work for a deprecation period of 3 months.
* Deprecations & Removals * Deprecations & Removals
* A number of previously deprecated functions have been removed, following a * A number of previously deprecated functions have been removed, following a
standard 3+ month deprecation cycle (see {ref}`api-compatibility`). standard 3+ month deprecation cycle (see {ref}`api-compatibility`).

View File

@ -83,7 +83,7 @@ from numpy import array, float32
import jax import jax
from jax import tree_util from jax import tree_util
from jax.experimental.export import export from jax.experimental import export
from jax.experimental import pjit from jax.experimental import pjit

View File

@ -22,7 +22,6 @@ load("@rules_python//python:defs.bzl", "py_library")
licenses(["notice"]) licenses(["notice"])
# Please add new users to :australis_users.
package( package(
default_applicable_licenses = [], default_applicable_licenses = [],
default_visibility = ["//visibility:private"], default_visibility = ["//visibility:private"],
@ -31,7 +30,8 @@ package(
py_library( py_library(
name = "export", name = "export",
srcs = [ srcs = [
"export.py", "__init__.py",
"_export.py",
"serialization.py", "serialization.py",
"serialization_generated.py", "serialization_generated.py",
"shape_poly.py", "shape_poly.py",

View File

@ -12,3 +12,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
from jax.experimental.export._export import (
minimum_supported_serialization_version,
maximum_supported_serialization_version,
Exported,
export,
call_exported, # TODO: deprecate
call,
DisabledSafetyCheck,
default_lowering_platform,
symbolic_shape,
args_specs,
)
from jax.experimental.export.serialization import (
serialize,
deserialize,
)

View File

@ -24,6 +24,7 @@ import functools
import itertools import itertools
import re import re
from typing import Any, Callable, Optional, TypeVar, Union from typing import Any, Callable, Optional, TypeVar, Union
import warnings
from absl import logging from absl import logging
import numpy as np import numpy as np
@ -57,6 +58,20 @@ zip = util.safe_zip
DType = Any DType = Any
Shape = jax._src.core.Shape Shape = jax._src.core.Shape
# The values of input and output sharding from the lowering.
LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]
# None means unspecified sharding
Sharding = Union[xla_client.HloSharding, None]
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
# for a description of the different versions.
minimum_supported_serialization_version = 6
maximum_supported_serialization_version = 9
_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7
_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9
class DisabledSafetyCheck: class DisabledSafetyCheck:
"""A safety check should be skipped on (de)serialization. """A safety check should be skipped on (de)serialization.
@ -117,19 +132,6 @@ class DisabledSafetyCheck:
def __hash__(self) -> int: def __hash__(self) -> int:
return hash(self._impl) return hash(self._impl)
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
# for a description of the different versions.
minimum_supported_serialization_version = 6
maximum_supported_serialization_version = 9
_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7
_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9
# The values of input and output sharding from the lowering.
LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]
# None means unspecified sharding
Sharding = Union[xla_client.HloSharding, None]
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class Exported: class Exported:
@ -1052,7 +1054,7 @@ def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
### Calling the exported function ### Calling the exported function
def call_exported(exported: Exported) -> Callable[..., jax.Array]: def call(exported: Exported) -> Callable[..., jax.Array]:
if not isinstance(exported, Exported): if not isinstance(exported, Exported):
raise ValueError( raise ValueError(
"The exported argument must be an export.Exported. " "The exported argument must be an export.Exported. "
@ -1107,6 +1109,7 @@ def call_exported(exported: Exported) -> Callable[..., jax.Array]:
return exported.out_tree.unflatten(res_flat) return exported.out_tree.unflatten(res_flat)
return f_imported return f_imported
call_exported = call
# A JAX primitive for invoking a serialized JAX function. # A JAX primitive for invoking a serialized JAX function.
call_exported_p = core.Primitive("call_exported") call_exported_p = core.Primitive("call_exported")
@ -1307,3 +1310,30 @@ def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
return x return x
return mlir.wrap_with_sharding_op( return mlir.wrap_with_sharding_op(
ctx, x, x_aval, x_sharding.to_proto()) ctx, x, x_aval, x_sharding.to_proto())
# TODO(necula): Previously, we had `from jax.experimental.export import export`
# Now we want to simplify the usage, and export the public APIs directly
# from `jax.experimental.export` and now `jax.experimental.export.export`
# refers to the `export` function. Since there may still be users of the
# old API in other packages, we add the old public API as attributes of the
# exported function. We will clean this up after a deprecation period.
def wrap_with_deprecation_warning(f):
msg = (f"You are using function `{f.__name__}` from "
"`jax.experimental.export.export`. You should instead use it directly "
"from `jax.experimental.export`. Instead of "
"`from jax.experimental.export import export` you should use "
"`from jax.experimental import export`.")
def wrapped_f(*args, **kwargs):
warnings.warn(msg, DeprecationWarning)
return f(*args, **kwargs)
return wrapped_f
export.export = wrap_with_deprecation_warning(export)
export.Exported = Exported
export.call_exported = wrap_with_deprecation_warning(call_exported)
export.DisabledSafetyCheck = DisabledSafetyCheck
export.default_lowering_platform = wrap_with_deprecation_warning(default_lowering_platform)
export.symbolic_shape = wrap_with_deprecation_warning(symbolic_shape)
export.args_specs = wrap_with_deprecation_warning(args_specs)
export.minimum_supported_serialization_version = minimum_supported_serialization_version
export.maximum_supported_serialization_version = maximum_supported_serialization_version

View File

@ -29,8 +29,10 @@ from jax._src import dtypes
from jax._src import effects from jax._src import effects
from jax._src import tree_util from jax._src import tree_util
from jax._src.lib import xla_client from jax._src.lib import xla_client
from jax.experimental.export import export
from jax.experimental.export import serialization_generated as ser_flatbuf from jax.experimental.export import serialization_generated as ser_flatbuf
from jax.experimental.export import _export
from jax.experimental import export
import numpy as np import numpy as np
T = TypeVar("T") T = TypeVar("T")
@ -359,7 +361,7 @@ def _deserialize_aval(aval: ser_flatbuf.AbstractValue) -> core.AbstractValue:
def _serialize_sharding( def _serialize_sharding(
builder: flatbuffers.Builder, s: export.Sharding builder: flatbuffers.Builder, s: _export.Sharding
) -> int: ) -> int:
proto = None proto = None
if s is None: if s is None:
@ -376,7 +378,7 @@ def _serialize_sharding(
return ser_flatbuf.ShardingEnd(builder) return ser_flatbuf.ShardingEnd(builder)
def _deserialize_sharding(s: ser_flatbuf.Sharding) -> export.Sharding: def _deserialize_sharding(s: ser_flatbuf.Sharding) -> _export.Sharding:
kind = s.Kind() kind = s.Kind()
if kind == ser_flatbuf.ShardingKind.unspecified: if kind == ser_flatbuf.ShardingKind.unspecified:
return None return None

View File

@ -38,7 +38,8 @@ from jax import tree_util
from jax import sharding from jax import sharding
from jax.experimental import maps from jax.experimental import maps
from jax.experimental.export import shape_poly from jax.experimental.export import shape_poly
from jax.experimental.export import export from jax.experimental.export import _export
from jax.experimental import export
from jax.experimental.jax2tf import impl_no_xla from jax.experimental.jax2tf import impl_no_xla
from jax.interpreters import xla from jax.interpreters import xla
@ -515,14 +516,14 @@ class NativeSerializationImpl(SerializationImpl):
def get_vjp_fun(self) -> tuple[Callable, def get_vjp_fun(self) -> tuple[Callable,
Sequence[core.AbstractValue]]: Sequence[core.AbstractValue]]:
return export._get_vjp_fun(self.fun_jax, return _export._get_vjp_fun(self.fun_jax,
in_tree=self.exported.in_tree, in_tree=self.exported.in_tree,
in_avals=self.exported.in_avals, in_avals=self.exported.in_avals,
in_shardings=self.exported.in_shardings, in_shardings=self.exported.in_shardings,
out_avals=self.exported.out_avals, out_avals=self.exported.out_avals,
out_shardings=self.exported.out_shardings, out_shardings=self.exported.out_shardings,
nr_devices=self.exported.nr_devices, nr_devices=self.exported.nr_devices,
apply_jit=True) apply_jit=True)
class GraphSerializationImpl(SerializationImpl): class GraphSerializationImpl(SerializationImpl):
def __init__(self, fun_jax, *, def __init__(self, fun_jax, *,
@ -587,14 +588,14 @@ class GraphSerializationImpl(SerializationImpl):
# We reuse the code for native serialization to get the VJP functions, # We reuse the code for native serialization to get the VJP functions,
# except we use unspecified shardings, and we do not apply a jit on the # except we use unspecified shardings, and we do not apply a jit on the
# VJP. This matches the older behavior of jax2tf for graph serialization. # VJP. This matches the older behavior of jax2tf for graph serialization.
return export._get_vjp_fun(self.fun_jax, return _export._get_vjp_fun(self.fun_jax,
in_tree=self.in_tree, in_tree=self.in_tree,
in_avals=self.args_avals_flat, in_avals=self.args_avals_flat,
in_shardings=(None,) * len(self.args_avals_flat), in_shardings=(None,) * len(self.args_avals_flat),
out_avals=self.outs_avals, out_avals=self.outs_avals,
out_shardings=(None,) * len(self.outs_avals), out_shardings=(None,) * len(self.outs_avals),
nr_devices=1, # Does not matter for unspecified shardings nr_devices=1, # Does not matter for unspecified shardings
apply_jit=False) apply_jit=False)
def dtype_of_val(val: TfVal) -> DType: def dtype_of_val(val: TfVal) -> DType:

View File

@ -29,7 +29,7 @@ from jax._src import test_util as jtu
from jax._src.lib.mlir import ir from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import hlo
from jax.experimental import jax2tf from jax.experimental import jax2tf
from jax.experimental.export import export from jax.experimental import export
from jax.experimental.jax2tf.tests import tf_test_util from jax.experimental.jax2tf.tests import tf_test_util
import numpy as np import numpy as np

View File

@ -37,9 +37,8 @@ from jax._src import core
from jax._src import source_info_util from jax._src import source_info_util
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src import xla_bridge as xb from jax._src import xla_bridge as xb
from jax._src.interpreters import mlir
from jax.experimental import jax2tf from jax.experimental import jax2tf
from jax.experimental.export import export from jax.experimental import export
from jax.experimental.jax2tf.tests import tf_test_util from jax.experimental.jax2tf.tests import tf_test_util
from jax.experimental.maps import xmap from jax.experimental.maps import xmap
from jax.experimental.shard_map import shard_map from jax.experimental.shard_map import shard_map

View File

@ -31,7 +31,7 @@ from jax._src import test_util as jtu
from jax import tree_util from jax import tree_util
from jax.experimental import jax2tf from jax.experimental import jax2tf
from jax.experimental.export import export from jax.experimental import export
from jax._src import config from jax._src import config
from jax._src import xla_bridge from jax._src import xla_bridge
import numpy as np import numpy as np

View File

@ -27,7 +27,8 @@ import numpy as np
import jax import jax
from jax import lax from jax import lax
from jax.experimental.export import export from jax.experimental import export
from jax.experimental.export import _export
from jax._src.internal_test_util import export_back_compat_test_util as bctu from jax._src.internal_test_util import export_back_compat_test_util as bctu
from jax._src.internal_test_util.export_back_compat_test_data import cpu_ducc_fft from jax._src.internal_test_util.export_back_compat_test_data import cpu_ducc_fft
@ -97,7 +98,7 @@ class CompatTest(bctu.CompatTestBase):
def test_custom_call_coverage(self): def test_custom_call_coverage(self):
"""Tests that the back compat tests cover all the targets declared stable.""" """Tests that the back compat tests cover all the targets declared stable."""
targets_to_cover = set(export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) targets_to_cover = set(_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
# Add here all the testdatas that should cover the targets guaranteed # Add here all the testdatas that should cover the targets guaranteed
# stable # stable
covering_testdatas = [ covering_testdatas = [

View File

@ -33,7 +33,7 @@ import numpy as np
import jax import jax
from jax import lax from jax import lax
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax.experimental.export import export from jax.experimental import export
from jax._src.internal_test_util import test_harnesses from jax._src.internal_test_util import test_harnesses

View File

@ -26,10 +26,10 @@ import jax
from jax import lax from jax import lax
from jax import numpy as jnp from jax import numpy as jnp
from jax import tree_util from jax import tree_util
from jax.experimental.export import export from jax.experimental import export
from jax.experimental.export import serialization from jax.experimental.export import _export
from jax.experimental.shard_map import shard_map
from jax.experimental import pjit from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
from jax.sharding import NamedSharding from jax.sharding import NamedSharding
from jax.sharding import Mesh from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
@ -146,8 +146,8 @@ def get_exported(fun, vjp_order=0,
"""Like export.export but with serialization + deserialization.""" """Like export.export but with serialization + deserialization."""
def serde_exported(*fun_args, **fun_kwargs): def serde_exported(*fun_args, **fun_kwargs):
exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs) exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs)
serialized = serialization.serialize(exp, vjp_order=vjp_order) serialized = export.serialize(exp, vjp_order=vjp_order)
return serialization.deserialize(serialized) return export.deserialize(serialized)
return serde_exported return serde_exported
class JaxExportTest(jtu.JaxTestCase): class JaxExportTest(jtu.JaxTestCase):
@ -744,7 +744,7 @@ class JaxExportTest(jtu.JaxTestCase):
x = np.array([True, False, True, False], dtype=np.bool_) x = np.array([True, False, True, False], dtype=np.bool_)
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
x.dtype)) x.dtype))
res = export.call_exported(exp)(x) res = export.call_exported(exp)(x)
self.assertAllClose(f_jax(x), res) self.assertAllClose(f_jax(x), res)
@ -1139,7 +1139,7 @@ class JaxExportTest(jtu.JaxTestCase):
) )
exp = get_exported(f_jax)(x) exp = get_exported(f_jax)(x)
if exp.mlir_module_serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: if exp.mlir_module_serialization_version >= _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"], self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
sorted(str(e) for e in exp.ordered_effects)) sorted(str(e) for e in exp.ordered_effects))
self.assertEqual(["ForTestingUnorderedEffect1()"], self.assertEqual(["ForTestingUnorderedEffect1()"],
@ -1169,11 +1169,11 @@ class JaxExportTest(jtu.JaxTestCase):
# Results # Results
r"!stablehlo.token .*jax.token = true.*" r"!stablehlo.token .*jax.token = true.*"
r"!stablehlo.token .*jax.token = true.*") r"!stablehlo.token .*jax.token = true.*")
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_module_str, wrapped_main_expected_re) self.assertRegex(mlir_module_str, wrapped_main_expected_re)
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
# The main function does not have tokens # The main function does not have tokens
self.assertNotRegex(mlir_module_str, r"@main.*token") self.assertNotRegex(mlir_module_str, r"@main.*token")
else: else:
@ -1191,7 +1191,7 @@ class JaxExportTest(jtu.JaxTestCase):
export.call_exported(exp)(x)) export.call_exported(exp)(x))
lowered_outer = jax.jit(f_outer).lower(x) lowered_outer = jax.jit(f_outer).lower(x)
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
self.assertEqual(["ForTestingOrderedEffect2()"], self.assertEqual(["ForTestingOrderedEffect2()"],
[str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]]) [str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]])
else: else:
@ -1201,7 +1201,7 @@ class JaxExportTest(jtu.JaxTestCase):
sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]])) sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]]))
mlir_outer_module_str = str(lowered_outer.compiler_ir()) mlir_outer_module_str = str(lowered_outer.compiler_ir())
if exp.mlir_module_serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: if exp.mlir_module_serialization_version >= _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
main_expected_re = main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") main_expected_re = main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_outer_module_str, main_expected_re) self.assertRegex(mlir_outer_module_str, main_expected_re)
@ -1229,11 +1229,11 @@ class JaxExportTest(jtu.JaxTestCase):
r"%arg3: tensor<\?x\?xf32>.*\) -> \(" r"%arg3: tensor<\?x\?xf32>.*\) -> \("
# Results # Results
r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)") r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_module_str, wrapped_main_expected_re) self.assertRegex(mlir_module_str, wrapped_main_expected_re)
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
# The main function does not have tokens # The main function does not have tokens
self.assertNotRegex(mlir_module_str, r"@main.*token") self.assertNotRegex(mlir_module_str, r"@main.*token")
else: else:
@ -1276,11 +1276,11 @@ class JaxExportTest(jtu.JaxTestCase):
r"%arg4: tensor<\?x\?xf32>.*\) -> \(" r"%arg4: tensor<\?x\?xf32>.*\) -> \("
# Results # Results
r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)") r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
self.assertRegex(mlir_module_str, wrapped_main_expected_re) self.assertRegex(mlir_module_str, wrapped_main_expected_re)
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
# The main function does not have tokens # The main function does not have tokens
self.assertNotRegex(mlir_module_str, r"@main.*token") self.assertNotRegex(mlir_module_str, r"@main.*token")
else: else:
@ -1313,7 +1313,7 @@ class JaxExportTest(jtu.JaxTestCase):
f_jax = jax.jit(f_jax, donate_argnums=(0,)) f_jax = jax.jit(f_jax, donate_argnums=(0,))
exp = export.export(f_jax)(x) exp = export.export(f_jax)(x)
mlir_module_str = str(exp.mlir_module()) mlir_module_str = str(exp.mlir_module())
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 0") self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 0")
self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1") self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1")
else: else:

View File

@ -33,7 +33,7 @@ import operator as op
import re import re
import jax import jax
from jax.experimental.export import export from jax.experimental import export
from jax.experimental.export import shape_poly from jax.experimental.export import shape_poly
from jax.experimental import pjit from jax.experimental import pjit
from jax import lax from jax import lax