diff --git a/CHANGELOG.md b/CHANGELOG.md index d00703139..1923394a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.24 * Changes + * JAX lowering to StableHLO does not depend on physical devices anymore. If your primitive wraps custom_paritioning or JAX callbacks in the lowering rule i.e. function passed to `rule` parameter of `mlir.register_lowering` then add your @@ -27,6 +28,10 @@ Remember to align the itemized text with the first line of an item within a list ({jax-issue}`#19231`; note that this may result in user-visible behavior changes); improved the error messages for inconclusive inequality comparisons ({jax-issue}`#19235`). + * Refactored the API for `jax.experimental.export`. Instead of + `from jax.experimental.export import export` you should use now + `from jax.experimental import export`. The old way of importing will + continue to work for a deprecation period of 3 months. * Deprecations & Removals * A number of previously deprecated functions have been removed, following a standard 3+ month deprecation cycle (see {ref}`api-compatibility`). diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index fbef4b05c..2c8fbf713 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -83,7 +83,7 @@ from numpy import array, float32 import jax from jax import tree_util -from jax.experimental.export import export +from jax.experimental import export from jax.experimental import pjit diff --git a/jax/experimental/export/BUILD b/jax/experimental/export/BUILD index 1eabba6b5..a0189039a 100644 --- a/jax/experimental/export/BUILD +++ b/jax/experimental/export/BUILD @@ -22,7 +22,6 @@ load("@rules_python//python:defs.bzl", "py_library") licenses(["notice"]) -# Please add new users to :australis_users. package( default_applicable_licenses = [], default_visibility = ["//visibility:private"], @@ -31,7 +30,8 @@ package( py_library( name = "export", srcs = [ - "export.py", + "__init__.py", + "_export.py", "serialization.py", "serialization_generated.py", "shape_poly.py", diff --git a/jax/experimental/export/__init__.py b/jax/experimental/export/__init__.py index e9a5c3c22..e00e80d54 100644 --- a/jax/experimental/export/__init__.py +++ b/jax/experimental/export/__init__.py @@ -12,3 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== + +from jax.experimental.export._export import ( + minimum_supported_serialization_version, + maximum_supported_serialization_version, + Exported, + export, + call_exported, # TODO: deprecate + call, + DisabledSafetyCheck, + default_lowering_platform, + + symbolic_shape, + args_specs, +) +from jax.experimental.export.serialization import ( + serialize, + deserialize, +) diff --git a/jax/experimental/export/export.py b/jax/experimental/export/_export.py similarity index 97% rename from jax/experimental/export/export.py rename to jax/experimental/export/_export.py index 03fffc9ee..12994e227 100644 --- a/jax/experimental/export/export.py +++ b/jax/experimental/export/_export.py @@ -24,6 +24,7 @@ import functools import itertools import re from typing import Any, Callable, Optional, TypeVar, Union +import warnings from absl import logging import numpy as np @@ -57,6 +58,20 @@ zip = util.safe_zip DType = Any Shape = jax._src.core.Shape +# The values of input and output sharding from the lowering. +LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue] + +# None means unspecified sharding +Sharding = Union[xla_client.HloSharding, None] + +# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions +# for a description of the different versions. +minimum_supported_serialization_version = 6 +maximum_supported_serialization_version = 9 + +_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7 +_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9 + class DisabledSafetyCheck: """A safety check should be skipped on (de)serialization. @@ -117,19 +132,6 @@ class DisabledSafetyCheck: def __hash__(self) -> int: return hash(self._impl) -# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions -# for a description of the different versions. -minimum_supported_serialization_version = 6 -maximum_supported_serialization_version = 9 - -_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7 -_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9 - -# The values of input and output sharding from the lowering. -LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue] - -# None means unspecified sharding -Sharding = Union[xla_client.HloSharding, None] @dataclasses.dataclass(frozen=True) class Exported: @@ -1052,7 +1054,7 @@ def _export_native_vjp(primal_fun, primal: Exported) -> Exported: ### Calling the exported function -def call_exported(exported: Exported) -> Callable[..., jax.Array]: +def call(exported: Exported) -> Callable[..., jax.Array]: if not isinstance(exported, Exported): raise ValueError( "The exported argument must be an export.Exported. " @@ -1107,6 +1109,7 @@ def call_exported(exported: Exported) -> Callable[..., jax.Array]: return exported.out_tree.unflatten(res_flat) return f_imported +call_exported = call # A JAX primitive for invoking a serialized JAX function. call_exported_p = core.Primitive("call_exported") @@ -1307,3 +1310,30 @@ def wrap_with_sharding(ctx: mlir.LoweringRuleContext, return x return mlir.wrap_with_sharding_op( ctx, x, x_aval, x_sharding.to_proto()) + +# TODO(necula): Previously, we had `from jax.experimental.export import export` +# Now we want to simplify the usage, and export the public APIs directly +# from `jax.experimental.export` and now `jax.experimental.export.export` +# refers to the `export` function. Since there may still be users of the +# old API in other packages, we add the old public API as attributes of the +# exported function. We will clean this up after a deprecation period. +def wrap_with_deprecation_warning(f): + msg = (f"You are using function `{f.__name__}` from " + "`jax.experimental.export.export`. You should instead use it directly " + "from `jax.experimental.export`. Instead of " + "`from jax.experimental.export import export` you should use " + "`from jax.experimental import export`.") + def wrapped_f(*args, **kwargs): + warnings.warn(msg, DeprecationWarning) + return f(*args, **kwargs) + return wrapped_f + +export.export = wrap_with_deprecation_warning(export) +export.Exported = Exported +export.call_exported = wrap_with_deprecation_warning(call_exported) +export.DisabledSafetyCheck = DisabledSafetyCheck +export.default_lowering_platform = wrap_with_deprecation_warning(default_lowering_platform) +export.symbolic_shape = wrap_with_deprecation_warning(symbolic_shape) +export.args_specs = wrap_with_deprecation_warning(args_specs) +export.minimum_supported_serialization_version = minimum_supported_serialization_version +export.maximum_supported_serialization_version = maximum_supported_serialization_version diff --git a/jax/experimental/export/serialization.py b/jax/experimental/export/serialization.py index b4b0e215f..bb1893d86 100644 --- a/jax/experimental/export/serialization.py +++ b/jax/experimental/export/serialization.py @@ -29,8 +29,10 @@ from jax._src import dtypes from jax._src import effects from jax._src import tree_util from jax._src.lib import xla_client -from jax.experimental.export import export from jax.experimental.export import serialization_generated as ser_flatbuf +from jax.experimental.export import _export +from jax.experimental import export + import numpy as np T = TypeVar("T") @@ -359,7 +361,7 @@ def _deserialize_aval(aval: ser_flatbuf.AbstractValue) -> core.AbstractValue: def _serialize_sharding( - builder: flatbuffers.Builder, s: export.Sharding + builder: flatbuffers.Builder, s: _export.Sharding ) -> int: proto = None if s is None: @@ -376,7 +378,7 @@ def _serialize_sharding( return ser_flatbuf.ShardingEnd(builder) -def _deserialize_sharding(s: ser_flatbuf.Sharding) -> export.Sharding: +def _deserialize_sharding(s: ser_flatbuf.Sharding) -> _export.Sharding: kind = s.Kind() if kind == ser_flatbuf.ShardingKind.unspecified: return None diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 41985df63..653b46591 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -38,7 +38,8 @@ from jax import tree_util from jax import sharding from jax.experimental import maps from jax.experimental.export import shape_poly -from jax.experimental.export import export +from jax.experimental.export import _export +from jax.experimental import export from jax.experimental.jax2tf import impl_no_xla from jax.interpreters import xla @@ -515,14 +516,14 @@ class NativeSerializationImpl(SerializationImpl): def get_vjp_fun(self) -> tuple[Callable, Sequence[core.AbstractValue]]: - return export._get_vjp_fun(self.fun_jax, - in_tree=self.exported.in_tree, - in_avals=self.exported.in_avals, - in_shardings=self.exported.in_shardings, - out_avals=self.exported.out_avals, - out_shardings=self.exported.out_shardings, - nr_devices=self.exported.nr_devices, - apply_jit=True) + return _export._get_vjp_fun(self.fun_jax, + in_tree=self.exported.in_tree, + in_avals=self.exported.in_avals, + in_shardings=self.exported.in_shardings, + out_avals=self.exported.out_avals, + out_shardings=self.exported.out_shardings, + nr_devices=self.exported.nr_devices, + apply_jit=True) class GraphSerializationImpl(SerializationImpl): def __init__(self, fun_jax, *, @@ -587,14 +588,14 @@ class GraphSerializationImpl(SerializationImpl): # We reuse the code for native serialization to get the VJP functions, # except we use unspecified shardings, and we do not apply a jit on the # VJP. This matches the older behavior of jax2tf for graph serialization. - return export._get_vjp_fun(self.fun_jax, - in_tree=self.in_tree, - in_avals=self.args_avals_flat, - in_shardings=(None,) * len(self.args_avals_flat), - out_avals=self.outs_avals, - out_shardings=(None,) * len(self.outs_avals), - nr_devices=1, # Does not matter for unspecified shardings - apply_jit=False) + return _export._get_vjp_fun(self.fun_jax, + in_tree=self.in_tree, + in_avals=self.args_avals_flat, + in_shardings=(None,) * len(self.args_avals_flat), + out_avals=self.outs_avals, + out_shardings=(None,) * len(self.outs_avals), + nr_devices=1, # Does not matter for unspecified shardings + apply_jit=False) def dtype_of_val(val: TfVal) -> DType: diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index b729c1193..2a79c62e1 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -29,7 +29,7 @@ from jax._src import test_util as jtu from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax.experimental import jax2tf -from jax.experimental.export import export +from jax.experimental import export from jax.experimental.jax2tf.tests import tf_test_util import numpy as np diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index def5af73b..b28bdb01f 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -37,9 +37,8 @@ from jax._src import core from jax._src import source_info_util from jax._src import test_util as jtu from jax._src import xla_bridge as xb -from jax._src.interpreters import mlir from jax.experimental import jax2tf -from jax.experimental.export import export +from jax.experimental import export from jax.experimental.jax2tf.tests import tf_test_util from jax.experimental.maps import xmap from jax.experimental.shard_map import shard_map diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 9bfcd37fe..b1558dd9a 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -31,7 +31,7 @@ from jax._src import test_util as jtu from jax import tree_util from jax.experimental import jax2tf -from jax.experimental.export import export +from jax.experimental import export from jax._src import config from jax._src import xla_bridge import numpy as np diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 3be63dbd6..f0f53c382 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -27,7 +27,8 @@ import numpy as np import jax from jax import lax -from jax.experimental.export import export +from jax.experimental import export +from jax.experimental.export import _export from jax._src.internal_test_util import export_back_compat_test_util as bctu from jax._src.internal_test_util.export_back_compat_test_data import cpu_ducc_fft @@ -97,7 +98,7 @@ class CompatTest(bctu.CompatTestBase): def test_custom_call_coverage(self): """Tests that the back compat tests cover all the targets declared stable.""" - targets_to_cover = set(export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) + targets_to_cover = set(_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) # Add here all the testdatas that should cover the targets guaranteed # stable covering_testdatas = [ diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index 406c09227..fbc1a4b54 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -33,7 +33,7 @@ import numpy as np import jax from jax import lax from jax._src import test_util as jtu -from jax.experimental.export import export +from jax.experimental import export from jax._src.internal_test_util import test_harnesses diff --git a/tests/export_test.py b/tests/export_test.py index 0ff30804c..63e99e889 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -26,10 +26,10 @@ import jax from jax import lax from jax import numpy as jnp from jax import tree_util -from jax.experimental.export import export -from jax.experimental.export import serialization -from jax.experimental.shard_map import shard_map +from jax.experimental import export +from jax.experimental.export import _export from jax.experimental import pjit +from jax.experimental.shard_map import shard_map from jax.sharding import NamedSharding from jax.sharding import Mesh from jax.sharding import PartitionSpec as P @@ -146,8 +146,8 @@ def get_exported(fun, vjp_order=0, """Like export.export but with serialization + deserialization.""" def serde_exported(*fun_args, **fun_kwargs): exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs) - serialized = serialization.serialize(exp, vjp_order=vjp_order) - return serialization.deserialize(serialized) + serialized = export.serialize(exp, vjp_order=vjp_order) + return export.deserialize(serialized) return serde_exported class JaxExportTest(jtu.JaxTestCase): @@ -744,7 +744,7 @@ class JaxExportTest(jtu.JaxTestCase): x = np.array([True, False, True, False], dtype=np.bool_) exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), - x.dtype)) + x.dtype)) res = export.call_exported(exp)(x) self.assertAllClose(f_jax(x), res) @@ -1139,7 +1139,7 @@ class JaxExportTest(jtu.JaxTestCase): ) exp = get_exported(f_jax)(x) - if exp.mlir_module_serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version >= _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"], sorted(str(e) for e in exp.ordered_effects)) self.assertEqual(["ForTestingUnorderedEffect1()"], @@ -1169,11 +1169,11 @@ class JaxExportTest(jtu.JaxTestCase): # Results r"!stablehlo.token .*jax.token = true.*" r"!stablehlo.token .*jax.token = true.*") - if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") self.assertRegex(mlir_module_str, wrapped_main_expected_re) - if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: # The main function does not have tokens self.assertNotRegex(mlir_module_str, r"@main.*token") else: @@ -1191,7 +1191,7 @@ class JaxExportTest(jtu.JaxTestCase): export.call_exported(exp)(x)) lowered_outer = jax.jit(f_outer).lower(x) - if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: self.assertEqual(["ForTestingOrderedEffect2()"], [str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]]) else: @@ -1201,7 +1201,7 @@ class JaxExportTest(jtu.JaxTestCase): sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]])) mlir_outer_module_str = str(lowered_outer.compiler_ir()) - if exp.mlir_module_serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version >= _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: main_expected_re = main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") self.assertRegex(mlir_outer_module_str, main_expected_re) @@ -1229,11 +1229,11 @@ class JaxExportTest(jtu.JaxTestCase): r"%arg3: tensor<\?x\?xf32>.*\) -> \(" # Results r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)") - if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") self.assertRegex(mlir_module_str, wrapped_main_expected_re) - if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: # The main function does not have tokens self.assertNotRegex(mlir_module_str, r"@main.*token") else: @@ -1276,11 +1276,11 @@ class JaxExportTest(jtu.JaxTestCase): r"%arg4: tensor<\?x\?xf32>.*\) -> \(" # Results r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)") - if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") self.assertRegex(mlir_module_str, wrapped_main_expected_re) - if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: # The main function does not have tokens self.assertNotRegex(mlir_module_str, r"@main.*token") else: @@ -1313,7 +1313,7 @@ class JaxExportTest(jtu.JaxTestCase): f_jax = jax.jit(f_jax, donate_argnums=(0,)) exp = export.export(f_jax)(x) mlir_module_str = str(exp.mlir_module()) - if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 0") self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1") else: diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index ee4fb1621..85e76c768 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -33,7 +33,7 @@ import operator as op import re import jax -from jax.experimental.export import export +from jax.experimental import export from jax.experimental.export import shape_poly from jax.experimental import pjit from jax import lax