From 69788d18b6da4358cc288b10468f575120bb4825 Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 8 Jan 2024 05:29:11 -0800 Subject: [PATCH] [export] Refactor the imports for the public API of jax.experimental.export Previously we used `from jax.experimental.export import export` and `export.export(fun)`. Now we want to add the public API directly to `jax.experimental.export`, for the following desired usage: ``` from jax.experimental import export exp: export.Exported = export.export(fun) ser: bytearray = export.serialize(exp) exp1: export.Exported = export.deserialized(ser) export.call(exp1) ``` This change requires changing the type of `jax.experimental.export.export` from a module to a function. This confuses pytype for the targets with strict type checking, which is why I attempt to make this change atomically throughout the internal code base. In order to support backwards compatibility with OSS packages, this change also includes explicit JAX version checks in several OSS packages, and also adds to the `export` function the attributes that the old export module had. PiperOrigin-RevId: 596563481 --- CHANGELOG.md | 5 ++ .../export_back_compat_test_util.py | 2 +- jax/experimental/export/BUILD | 4 +- jax/experimental/export/__init__.py | 18 ++++++ .../export/{export.py => _export.py} | 58 ++++++++++++++----- jax/experimental/export/serialization.py | 8 ++- jax/experimental/jax2tf/jax2tf.py | 35 +++++------ jax/experimental/jax2tf/tests/call_tf_test.py | 2 +- jax/experimental/jax2tf/tests/jax2tf_test.py | 3 +- jax/experimental/jax2tf/tests/tf_test_util.py | 2 +- tests/export_back_compat_test.py | 5 +- tests/export_harnesses_multi_platform_test.py | 2 +- tests/export_test.py | 32 +++++----- tests/shape_poly_test.py | 2 +- 14 files changed, 117 insertions(+), 61 deletions(-) rename jax/experimental/export/{export.py => _export.py} (97%) diff --git a/CHANGELOG.md b/CHANGELOG.md index d00703139..1923394a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.24 * Changes + * JAX lowering to StableHLO does not depend on physical devices anymore. If your primitive wraps custom_paritioning or JAX callbacks in the lowering rule i.e. function passed to `rule` parameter of `mlir.register_lowering` then add your @@ -27,6 +28,10 @@ Remember to align the itemized text with the first line of an item within a list ({jax-issue}`#19231`; note that this may result in user-visible behavior changes); improved the error messages for inconclusive inequality comparisons ({jax-issue}`#19235`). + * Refactored the API for `jax.experimental.export`. Instead of + `from jax.experimental.export import export` you should use now + `from jax.experimental import export`. The old way of importing will + continue to work for a deprecation period of 3 months. * Deprecations & Removals * A number of previously deprecated functions have been removed, following a standard 3+ month deprecation cycle (see {ref}`api-compatibility`). diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py index fbef4b05c..2c8fbf713 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_util.py +++ b/jax/_src/internal_test_util/export_back_compat_test_util.py @@ -83,7 +83,7 @@ from numpy import array, float32 import jax from jax import tree_util -from jax.experimental.export import export +from jax.experimental import export from jax.experimental import pjit diff --git a/jax/experimental/export/BUILD b/jax/experimental/export/BUILD index 1eabba6b5..a0189039a 100644 --- a/jax/experimental/export/BUILD +++ b/jax/experimental/export/BUILD @@ -22,7 +22,6 @@ load("@rules_python//python:defs.bzl", "py_library") licenses(["notice"]) -# Please add new users to :australis_users. package( default_applicable_licenses = [], default_visibility = ["//visibility:private"], @@ -31,7 +30,8 @@ package( py_library( name = "export", srcs = [ - "export.py", + "__init__.py", + "_export.py", "serialization.py", "serialization_generated.py", "shape_poly.py", diff --git a/jax/experimental/export/__init__.py b/jax/experimental/export/__init__.py index e9a5c3c22..e00e80d54 100644 --- a/jax/experimental/export/__init__.py +++ b/jax/experimental/export/__init__.py @@ -12,3 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== + +from jax.experimental.export._export import ( + minimum_supported_serialization_version, + maximum_supported_serialization_version, + Exported, + export, + call_exported, # TODO: deprecate + call, + DisabledSafetyCheck, + default_lowering_platform, + + symbolic_shape, + args_specs, +) +from jax.experimental.export.serialization import ( + serialize, + deserialize, +) diff --git a/jax/experimental/export/export.py b/jax/experimental/export/_export.py similarity index 97% rename from jax/experimental/export/export.py rename to jax/experimental/export/_export.py index 03fffc9ee..12994e227 100644 --- a/jax/experimental/export/export.py +++ b/jax/experimental/export/_export.py @@ -24,6 +24,7 @@ import functools import itertools import re from typing import Any, Callable, Optional, TypeVar, Union +import warnings from absl import logging import numpy as np @@ -57,6 +58,20 @@ zip = util.safe_zip DType = Any Shape = jax._src.core.Shape +# The values of input and output sharding from the lowering. +LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue] + +# None means unspecified sharding +Sharding = Union[xla_client.HloSharding, None] + +# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions +# for a description of the different versions. +minimum_supported_serialization_version = 6 +maximum_supported_serialization_version = 9 + +_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7 +_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9 + class DisabledSafetyCheck: """A safety check should be skipped on (de)serialization. @@ -117,19 +132,6 @@ class DisabledSafetyCheck: def __hash__(self) -> int: return hash(self._impl) -# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions -# for a description of the different versions. -minimum_supported_serialization_version = 6 -maximum_supported_serialization_version = 9 - -_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7 -_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9 - -# The values of input and output sharding from the lowering. -LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue] - -# None means unspecified sharding -Sharding = Union[xla_client.HloSharding, None] @dataclasses.dataclass(frozen=True) class Exported: @@ -1052,7 +1054,7 @@ def _export_native_vjp(primal_fun, primal: Exported) -> Exported: ### Calling the exported function -def call_exported(exported: Exported) -> Callable[..., jax.Array]: +def call(exported: Exported) -> Callable[..., jax.Array]: if not isinstance(exported, Exported): raise ValueError( "The exported argument must be an export.Exported. " @@ -1107,6 +1109,7 @@ def call_exported(exported: Exported) -> Callable[..., jax.Array]: return exported.out_tree.unflatten(res_flat) return f_imported +call_exported = call # A JAX primitive for invoking a serialized JAX function. call_exported_p = core.Primitive("call_exported") @@ -1307,3 +1310,30 @@ def wrap_with_sharding(ctx: mlir.LoweringRuleContext, return x return mlir.wrap_with_sharding_op( ctx, x, x_aval, x_sharding.to_proto()) + +# TODO(necula): Previously, we had `from jax.experimental.export import export` +# Now we want to simplify the usage, and export the public APIs directly +# from `jax.experimental.export` and now `jax.experimental.export.export` +# refers to the `export` function. Since there may still be users of the +# old API in other packages, we add the old public API as attributes of the +# exported function. We will clean this up after a deprecation period. +def wrap_with_deprecation_warning(f): + msg = (f"You are using function `{f.__name__}` from " + "`jax.experimental.export.export`. You should instead use it directly " + "from `jax.experimental.export`. Instead of " + "`from jax.experimental.export import export` you should use " + "`from jax.experimental import export`.") + def wrapped_f(*args, **kwargs): + warnings.warn(msg, DeprecationWarning) + return f(*args, **kwargs) + return wrapped_f + +export.export = wrap_with_deprecation_warning(export) +export.Exported = Exported +export.call_exported = wrap_with_deprecation_warning(call_exported) +export.DisabledSafetyCheck = DisabledSafetyCheck +export.default_lowering_platform = wrap_with_deprecation_warning(default_lowering_platform) +export.symbolic_shape = wrap_with_deprecation_warning(symbolic_shape) +export.args_specs = wrap_with_deprecation_warning(args_specs) +export.minimum_supported_serialization_version = minimum_supported_serialization_version +export.maximum_supported_serialization_version = maximum_supported_serialization_version diff --git a/jax/experimental/export/serialization.py b/jax/experimental/export/serialization.py index b4b0e215f..bb1893d86 100644 --- a/jax/experimental/export/serialization.py +++ b/jax/experimental/export/serialization.py @@ -29,8 +29,10 @@ from jax._src import dtypes from jax._src import effects from jax._src import tree_util from jax._src.lib import xla_client -from jax.experimental.export import export from jax.experimental.export import serialization_generated as ser_flatbuf +from jax.experimental.export import _export +from jax.experimental import export + import numpy as np T = TypeVar("T") @@ -359,7 +361,7 @@ def _deserialize_aval(aval: ser_flatbuf.AbstractValue) -> core.AbstractValue: def _serialize_sharding( - builder: flatbuffers.Builder, s: export.Sharding + builder: flatbuffers.Builder, s: _export.Sharding ) -> int: proto = None if s is None: @@ -376,7 +378,7 @@ def _serialize_sharding( return ser_flatbuf.ShardingEnd(builder) -def _deserialize_sharding(s: ser_flatbuf.Sharding) -> export.Sharding: +def _deserialize_sharding(s: ser_flatbuf.Sharding) -> _export.Sharding: kind = s.Kind() if kind == ser_flatbuf.ShardingKind.unspecified: return None diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 41985df63..653b46591 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -38,7 +38,8 @@ from jax import tree_util from jax import sharding from jax.experimental import maps from jax.experimental.export import shape_poly -from jax.experimental.export import export +from jax.experimental.export import _export +from jax.experimental import export from jax.experimental.jax2tf import impl_no_xla from jax.interpreters import xla @@ -515,14 +516,14 @@ class NativeSerializationImpl(SerializationImpl): def get_vjp_fun(self) -> tuple[Callable, Sequence[core.AbstractValue]]: - return export._get_vjp_fun(self.fun_jax, - in_tree=self.exported.in_tree, - in_avals=self.exported.in_avals, - in_shardings=self.exported.in_shardings, - out_avals=self.exported.out_avals, - out_shardings=self.exported.out_shardings, - nr_devices=self.exported.nr_devices, - apply_jit=True) + return _export._get_vjp_fun(self.fun_jax, + in_tree=self.exported.in_tree, + in_avals=self.exported.in_avals, + in_shardings=self.exported.in_shardings, + out_avals=self.exported.out_avals, + out_shardings=self.exported.out_shardings, + nr_devices=self.exported.nr_devices, + apply_jit=True) class GraphSerializationImpl(SerializationImpl): def __init__(self, fun_jax, *, @@ -587,14 +588,14 @@ class GraphSerializationImpl(SerializationImpl): # We reuse the code for native serialization to get the VJP functions, # except we use unspecified shardings, and we do not apply a jit on the # VJP. This matches the older behavior of jax2tf for graph serialization. - return export._get_vjp_fun(self.fun_jax, - in_tree=self.in_tree, - in_avals=self.args_avals_flat, - in_shardings=(None,) * len(self.args_avals_flat), - out_avals=self.outs_avals, - out_shardings=(None,) * len(self.outs_avals), - nr_devices=1, # Does not matter for unspecified shardings - apply_jit=False) + return _export._get_vjp_fun(self.fun_jax, + in_tree=self.in_tree, + in_avals=self.args_avals_flat, + in_shardings=(None,) * len(self.args_avals_flat), + out_avals=self.outs_avals, + out_shardings=(None,) * len(self.outs_avals), + nr_devices=1, # Does not matter for unspecified shardings + apply_jit=False) def dtype_of_val(val: TfVal) -> DType: diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index b729c1193..2a79c62e1 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -29,7 +29,7 @@ from jax._src import test_util as jtu from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax.experimental import jax2tf -from jax.experimental.export import export +from jax.experimental import export from jax.experimental.jax2tf.tests import tf_test_util import numpy as np diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index def5af73b..b28bdb01f 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -37,9 +37,8 @@ from jax._src import core from jax._src import source_info_util from jax._src import test_util as jtu from jax._src import xla_bridge as xb -from jax._src.interpreters import mlir from jax.experimental import jax2tf -from jax.experimental.export import export +from jax.experimental import export from jax.experimental.jax2tf.tests import tf_test_util from jax.experimental.maps import xmap from jax.experimental.shard_map import shard_map diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 9bfcd37fe..b1558dd9a 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -31,7 +31,7 @@ from jax._src import test_util as jtu from jax import tree_util from jax.experimental import jax2tf -from jax.experimental.export import export +from jax.experimental import export from jax._src import config from jax._src import xla_bridge import numpy as np diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 3be63dbd6..f0f53c382 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -27,7 +27,8 @@ import numpy as np import jax from jax import lax -from jax.experimental.export import export +from jax.experimental import export +from jax.experimental.export import _export from jax._src.internal_test_util import export_back_compat_test_util as bctu from jax._src.internal_test_util.export_back_compat_test_data import cpu_ducc_fft @@ -97,7 +98,7 @@ class CompatTest(bctu.CompatTestBase): def test_custom_call_coverage(self): """Tests that the back compat tests cover all the targets declared stable.""" - targets_to_cover = set(export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) + targets_to_cover = set(_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) # Add here all the testdatas that should cover the targets guaranteed # stable covering_testdatas = [ diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index 406c09227..fbc1a4b54 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -33,7 +33,7 @@ import numpy as np import jax from jax import lax from jax._src import test_util as jtu -from jax.experimental.export import export +from jax.experimental import export from jax._src.internal_test_util import test_harnesses diff --git a/tests/export_test.py b/tests/export_test.py index 0ff30804c..63e99e889 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -26,10 +26,10 @@ import jax from jax import lax from jax import numpy as jnp from jax import tree_util -from jax.experimental.export import export -from jax.experimental.export import serialization -from jax.experimental.shard_map import shard_map +from jax.experimental import export +from jax.experimental.export import _export from jax.experimental import pjit +from jax.experimental.shard_map import shard_map from jax.sharding import NamedSharding from jax.sharding import Mesh from jax.sharding import PartitionSpec as P @@ -146,8 +146,8 @@ def get_exported(fun, vjp_order=0, """Like export.export but with serialization + deserialization.""" def serde_exported(*fun_args, **fun_kwargs): exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs) - serialized = serialization.serialize(exp, vjp_order=vjp_order) - return serialization.deserialize(serialized) + serialized = export.serialize(exp, vjp_order=vjp_order) + return export.deserialize(serialized) return serde_exported class JaxExportTest(jtu.JaxTestCase): @@ -744,7 +744,7 @@ class JaxExportTest(jtu.JaxTestCase): x = np.array([True, False, True, False], dtype=np.bool_) exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"), - x.dtype)) + x.dtype)) res = export.call_exported(exp)(x) self.assertAllClose(f_jax(x), res) @@ -1139,7 +1139,7 @@ class JaxExportTest(jtu.JaxTestCase): ) exp = get_exported(f_jax)(x) - if exp.mlir_module_serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version >= _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"], sorted(str(e) for e in exp.ordered_effects)) self.assertEqual(["ForTestingUnorderedEffect1()"], @@ -1169,11 +1169,11 @@ class JaxExportTest(jtu.JaxTestCase): # Results r"!stablehlo.token .*jax.token = true.*" r"!stablehlo.token .*jax.token = true.*") - if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") self.assertRegex(mlir_module_str, wrapped_main_expected_re) - if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: # The main function does not have tokens self.assertNotRegex(mlir_module_str, r"@main.*token") else: @@ -1191,7 +1191,7 @@ class JaxExportTest(jtu.JaxTestCase): export.call_exported(exp)(x)) lowered_outer = jax.jit(f_outer).lower(x) - if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: self.assertEqual(["ForTestingOrderedEffect2()"], [str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]]) else: @@ -1201,7 +1201,7 @@ class JaxExportTest(jtu.JaxTestCase): sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]])) mlir_outer_module_str = str(lowered_outer.compiler_ir()) - if exp.mlir_module_serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version >= _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: main_expected_re = main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") self.assertRegex(mlir_outer_module_str, main_expected_re) @@ -1229,11 +1229,11 @@ class JaxExportTest(jtu.JaxTestCase): r"%arg3: tensor<\?x\?xf32>.*\) -> \(" # Results r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)") - if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") self.assertRegex(mlir_module_str, wrapped_main_expected_re) - if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: # The main function does not have tokens self.assertNotRegex(mlir_module_str, r"@main.*token") else: @@ -1276,11 +1276,11 @@ class JaxExportTest(jtu.JaxTestCase): r"%arg4: tensor<\?x\?xf32>.*\) -> \(" # Results r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)") - if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>") self.assertRegex(mlir_module_str, wrapped_main_expected_re) - if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: # The main function does not have tokens self.assertNotRegex(mlir_module_str, r"@main.*token") else: @@ -1313,7 +1313,7 @@ class JaxExportTest(jtu.JaxTestCase): f_jax = jax.jit(f_jax, donate_argnums=(0,)) exp = export.export(f_jax)(x) mlir_module_str = str(exp.mlir_module()) - if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: + if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS: self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 0") self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1") else: diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index ee4fb1621..85e76c768 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -33,7 +33,7 @@ import operator as op import re import jax -from jax.experimental.export import export +from jax.experimental import export from jax.experimental.export import shape_poly from jax.experimental import pjit from jax import lax