mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #18980 from gnecula:export_api
PiperOrigin-RevId: 591172917
This commit is contained in:
commit
91faddd023
@ -9,6 +9,7 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
## jax 0.4.24
|
||||
|
||||
* Changes
|
||||
|
||||
* JAX lowering to StableHLO does not depend on physical devices anymore.
|
||||
If your primitive wraps custom_paritioning or JAX callbacks in the lowering
|
||||
rule i.e. function passed to `rule` parameter of `mlir.register_lowering` then add your
|
||||
@ -17,6 +18,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
devices to create `Sharding`s during lowering.
|
||||
This is a temporary state until we can create `Sharding`s without physical
|
||||
devices.
|
||||
* Refactored the API for `jax.experimental.export`. Instead of
|
||||
`from jax.experimental.export import export` you should use now
|
||||
`from jax.experimental import export`.
|
||||
|
||||
## jaxlib 0.4.24
|
||||
|
||||
|
@ -83,7 +83,7 @@ from numpy import array, float32
|
||||
|
||||
import jax
|
||||
from jax import tree_util
|
||||
from jax.experimental.export import export
|
||||
from jax.experimental import export
|
||||
|
||||
from jax.experimental import pjit
|
||||
|
||||
|
@ -22,7 +22,6 @@ load("@rules_python//python:defs.bzl", "py_library")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
# Please add new users to :australis_users.
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//visibility:private"],
|
||||
@ -31,7 +30,8 @@ package(
|
||||
py_library(
|
||||
name = "export",
|
||||
srcs = [
|
||||
"export.py",
|
||||
"__init__.py",
|
||||
"_export.py",
|
||||
"serialization.py",
|
||||
"serialization_generated.py",
|
||||
"shape_poly.py",
|
||||
|
@ -12,3 +12,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from jax.experimental.export._export import (
|
||||
minimum_supported_serialization_version,
|
||||
maximum_supported_serialization_version,
|
||||
Exported,
|
||||
export,
|
||||
call_exported, # TODO: deprecate
|
||||
call,
|
||||
DisabledSafetyCheck,
|
||||
default_lowering_platform,
|
||||
|
||||
symbolic_shape,
|
||||
args_specs,
|
||||
)
|
||||
from jax.experimental.export.serialization import (
|
||||
serialize,
|
||||
deserialize,
|
||||
)
|
||||
|
@ -24,6 +24,7 @@ import functools
|
||||
import itertools
|
||||
import re
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
import warnings
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
@ -55,6 +56,20 @@ zip = util.safe_zip
|
||||
|
||||
DType = Any
|
||||
Shape = jax._src.core.Shape
|
||||
# The values of input and output sharding from the lowering.
|
||||
LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]
|
||||
|
||||
# None means unspecified sharding
|
||||
Sharding = Union[xla_client.HloSharding, None]
|
||||
|
||||
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
|
||||
# for a description of the different versions.
|
||||
minimum_supported_serialization_version = 6
|
||||
maximum_supported_serialization_version = 9
|
||||
|
||||
_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7
|
||||
_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9
|
||||
|
||||
|
||||
class DisabledSafetyCheck:
|
||||
"""A safety check should be skipped on (de)serialization.
|
||||
@ -115,19 +130,6 @@ class DisabledSafetyCheck:
|
||||
def __hash__(self) -> int:
|
||||
return hash(self._impl)
|
||||
|
||||
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
|
||||
# for a description of the different versions.
|
||||
minimum_supported_serialization_version = 6
|
||||
maximum_supported_serialization_version = 9
|
||||
|
||||
_VERSION_START_SUPPORT_SHAPE_ASSERTIONS = 7
|
||||
_VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS = 9
|
||||
|
||||
# The values of input and output sharding from the lowering.
|
||||
LoweringSharding = Union[sharding.XLACompatibleSharding, pxla.UnspecifiedValue]
|
||||
|
||||
# None means unspecified sharding
|
||||
Sharding = Union[xla_client.HloSharding, None]
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Exported:
|
||||
@ -1050,7 +1052,7 @@ def _export_native_vjp(primal_fun, primal: Exported) -> Exported:
|
||||
|
||||
### Calling the exported function
|
||||
|
||||
def call_exported(exported: Exported) -> Callable[..., jax.Array]:
|
||||
def call(exported: Exported) -> Callable[..., jax.Array]:
|
||||
if not isinstance(exported, Exported):
|
||||
raise ValueError(
|
||||
"The exported argument must be an export.Exported. "
|
||||
@ -1096,6 +1098,7 @@ def call_exported(exported: Exported) -> Callable[..., jax.Array]:
|
||||
return exported.out_tree.unflatten(res_flat)
|
||||
return f_imported
|
||||
|
||||
call_exported = call
|
||||
|
||||
# A JAX primitive for invoking a serialized JAX function.
|
||||
call_exported_p = core.Primitive("call_exported")
|
||||
@ -1296,3 +1299,30 @@ def wrap_with_sharding(ctx: mlir.LoweringRuleContext,
|
||||
return x
|
||||
return mlir.wrap_with_sharding_op(
|
||||
ctx, x, x_aval, x_sharding.to_proto())
|
||||
|
||||
# TODO(necula): Previously, we had `from jax.experimental.export import export`
|
||||
# Now we want to simplify the usage, and export the public APIs directly
|
||||
# from `jax.experimental.export` and now `jax.experimental.export.export`
|
||||
# refers to the `export` function. Since there may still be users of the
|
||||
# old API in other packages, we add the old public API as attributes of the
|
||||
# exported function. We will clean this up after a deprecation period.
|
||||
def wrap_with_deprecation_warning(f):
|
||||
msg = (f"You are using function `{f.__name__}` from "
|
||||
"`jax.experimental.export.export`. You should instead use it directly "
|
||||
"from `jax.experimental.export`. Instead of "
|
||||
"`from jax.experimental.export import export` you should use "
|
||||
"`from jax.experimental import export`.")
|
||||
def wrapped_f(*args, **kwargs):
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
return f(*args, **kwargs)
|
||||
return wrapped_f
|
||||
|
||||
export.export = wrap_with_deprecation_warning(export)
|
||||
export.Exported = Exported
|
||||
export.call_exported = wrap_with_deprecation_warning(call_exported)
|
||||
export.DisabledSafetyCheck = DisabledSafetyCheck
|
||||
export.default_lowering_platform = wrap_with_deprecation_warning(default_lowering_platform)
|
||||
export.symbolic_shape = wrap_with_deprecation_warning(symbolic_shape)
|
||||
export.args_specs = wrap_with_deprecation_warning(args_specs)
|
||||
export.minimum_supported_serialization_version = minimum_supported_serialization_version
|
||||
export.maximum_supported_serialization_version = maximum_supported_serialization_version
|
@ -29,8 +29,10 @@ from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
from jax._src import tree_util
|
||||
from jax._src.lib import xla_client
|
||||
from jax.experimental.export import export
|
||||
from jax.experimental.export import serialization_generated as ser_flatbuf
|
||||
from jax.experimental.export import _export
|
||||
from jax.experimental import export
|
||||
|
||||
import numpy as np
|
||||
|
||||
T = TypeVar("T")
|
||||
@ -353,7 +355,7 @@ def _deserialize_aval(aval: ser_flatbuf.AbstractValue) -> core.AbstractValue:
|
||||
|
||||
|
||||
def _serialize_sharding(
|
||||
builder: flatbuffers.Builder, s: export.Sharding
|
||||
builder: flatbuffers.Builder, s: _export.Sharding
|
||||
) -> int:
|
||||
proto = None
|
||||
if s is None:
|
||||
@ -370,7 +372,7 @@ def _serialize_sharding(
|
||||
return ser_flatbuf.ShardingEnd(builder)
|
||||
|
||||
|
||||
def _deserialize_sharding(s: ser_flatbuf.Sharding) -> export.Sharding:
|
||||
def _deserialize_sharding(s: ser_flatbuf.Sharding) -> _export.Sharding:
|
||||
kind = s.Kind()
|
||||
if kind == ser_flatbuf.ShardingKind.unspecified:
|
||||
return None
|
||||
|
@ -38,7 +38,8 @@ from jax import tree_util
|
||||
from jax import sharding
|
||||
from jax.experimental import maps
|
||||
from jax.experimental.export import shape_poly
|
||||
from jax.experimental.export import export
|
||||
from jax.experimental.export import _export
|
||||
from jax.experimental import export
|
||||
from jax.experimental.jax2tf import impl_no_xla
|
||||
from jax.interpreters import xla
|
||||
|
||||
@ -514,14 +515,14 @@ class NativeSerializationImpl(SerializationImpl):
|
||||
|
||||
def get_vjp_fun(self) -> tuple[Callable,
|
||||
Sequence[core.AbstractValue]]:
|
||||
return export._get_vjp_fun(self.fun_jax,
|
||||
in_tree=self.exported.in_tree,
|
||||
in_avals=self.exported.in_avals,
|
||||
in_shardings=self.exported.in_shardings,
|
||||
out_avals=self.exported.out_avals,
|
||||
out_shardings=self.exported.out_shardings,
|
||||
nr_devices=self.exported.nr_devices,
|
||||
apply_jit=True)
|
||||
return _export._get_vjp_fun(self.fun_jax,
|
||||
in_tree=self.exported.in_tree,
|
||||
in_avals=self.exported.in_avals,
|
||||
in_shardings=self.exported.in_shardings,
|
||||
out_avals=self.exported.out_avals,
|
||||
out_shardings=self.exported.out_shardings,
|
||||
nr_devices=self.exported.nr_devices,
|
||||
apply_jit=True)
|
||||
|
||||
class GraphSerializationImpl(SerializationImpl):
|
||||
def __init__(self, fun_jax, *,
|
||||
@ -586,14 +587,14 @@ class GraphSerializationImpl(SerializationImpl):
|
||||
# We reuse the code for native serialization to get the VJP functions,
|
||||
# except we use unspecified shardings, and we do not apply a jit on the
|
||||
# VJP. This matches the older behavior of jax2tf for graph serialization.
|
||||
return export._get_vjp_fun(self.fun_jax,
|
||||
in_tree=self.in_tree,
|
||||
in_avals=self.args_avals_flat,
|
||||
in_shardings=(None,) * len(self.args_avals_flat),
|
||||
out_avals=self.outs_avals,
|
||||
out_shardings=(None,) * len(self.outs_avals),
|
||||
nr_devices=1, # Does not matter for unspecified shardings
|
||||
apply_jit=False)
|
||||
return _export._get_vjp_fun(self.fun_jax,
|
||||
in_tree=self.in_tree,
|
||||
in_avals=self.args_avals_flat,
|
||||
in_shardings=(None,) * len(self.args_avals_flat),
|
||||
out_avals=self.outs_avals,
|
||||
out_shardings=(None,) * len(self.outs_avals),
|
||||
nr_devices=1, # Does not matter for unspecified shardings
|
||||
apply_jit=False)
|
||||
|
||||
|
||||
def dtype_of_val(val: TfVal) -> DType:
|
||||
|
@ -27,7 +27,8 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.experimental.export import export
|
||||
from jax.experimental import export
|
||||
from jax.experimental.export import _export
|
||||
from jax._src.internal_test_util import export_back_compat_test_util as bctu
|
||||
|
||||
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_ducc_fft
|
||||
@ -97,7 +98,7 @@ class CompatTest(bctu.CompatTestBase):
|
||||
|
||||
def test_custom_call_coverage(self):
|
||||
"""Tests that the back compat tests cover all the targets declared stable."""
|
||||
targets_to_cover = set(export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
|
||||
targets_to_cover = set(_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
|
||||
# Add here all the testdatas that should cover the targets guaranteed
|
||||
# stable
|
||||
covering_testdatas = [
|
||||
|
@ -29,7 +29,7 @@ from jax._src import test_util as jtu
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental.export import export
|
||||
from jax.experimental import export
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
import numpy as np
|
||||
|
||||
|
@ -37,9 +37,8 @@ from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental.export import export
|
||||
from jax.experimental import export
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
from jax.experimental.maps import xmap
|
||||
from jax.experimental.shard_map import shard_map
|
||||
|
@ -31,7 +31,7 @@ from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental.export import export
|
||||
from jax.experimental import export
|
||||
from jax._src import config
|
||||
from jax._src import xla_bridge
|
||||
import numpy as np
|
||||
|
@ -26,8 +26,8 @@ import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
from jax.experimental.export import export
|
||||
from jax.experimental.export import serialization
|
||||
from jax.experimental import export
|
||||
from jax.experimental.export import _export
|
||||
from jax.experimental import pjit
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.sharding import Mesh
|
||||
@ -145,8 +145,8 @@ def get_exported(fun, max_vjp_orders=0,
|
||||
"""Like export.export but with serialization + deserialization."""
|
||||
def serde_exported(*fun_args, **fun_kwargs):
|
||||
exp = export.export(fun, **export_kwargs)(*fun_args, **fun_kwargs)
|
||||
serialized = serialization.serialize(exp, vjp_order=max_vjp_orders)
|
||||
return serialization.deserialize(serialized)
|
||||
serialized = export.serialize(exp, vjp_order=max_vjp_orders)
|
||||
return export.deserialize(serialized)
|
||||
return serde_exported
|
||||
|
||||
class JaxExportTest(jtu.JaxTestCase):
|
||||
@ -704,7 +704,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
x = np.array([True, False, True, False], dtype=np.bool_)
|
||||
exp = get_exported(f_jax)(jax.ShapeDtypeStruct(export.symbolic_shape("b"),
|
||||
x.dtype))
|
||||
x.dtype))
|
||||
res = export.call_exported(exp)(x)
|
||||
self.assertAllClose(f_jax(x), res)
|
||||
|
||||
@ -1042,7 +1042,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
)
|
||||
|
||||
exp = get_exported(f_jax)(x)
|
||||
if exp.mlir_module_serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
if exp.mlir_module_serialization_version >= _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
self.assertEqual(["ForTestingOrderedEffect1()", "ForTestingOrderedEffect2()"],
|
||||
sorted(str(e) for e in exp.ordered_effects))
|
||||
self.assertEqual(["ForTestingUnorderedEffect1()"],
|
||||
@ -1072,11 +1072,11 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
# Results
|
||||
r"!stablehlo.token .*jax.token = true.*"
|
||||
r"!stablehlo.token .*jax.token = true.*")
|
||||
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
|
||||
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
|
||||
|
||||
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
# The main function does not have tokens
|
||||
self.assertNotRegex(mlir_module_str, r"@main.*token")
|
||||
else:
|
||||
@ -1094,7 +1094,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
export.call_exported(exp)(x))
|
||||
|
||||
lowered_outer = jax.jit(f_outer).lower(x)
|
||||
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
self.assertEqual(["ForTestingOrderedEffect2()"],
|
||||
[str(e) for e in lowered_outer._lowering.compile_args["ordered_effects"]])
|
||||
else:
|
||||
@ -1104,7 +1104,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
sorted([str(e) for e in lowered_outer._lowering.compile_args["unordered_effects"]]))
|
||||
|
||||
mlir_outer_module_str = str(lowered_outer.compiler_ir())
|
||||
if exp.mlir_module_serialization_version >= export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
if exp.mlir_module_serialization_version >= _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
main_expected_re = main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
|
||||
self.assertRegex(mlir_outer_module_str, main_expected_re)
|
||||
|
||||
@ -1132,11 +1132,11 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
r"%arg3: tensor<\?x\?xf32>.*\) -> \("
|
||||
# Results
|
||||
r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
|
||||
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
|
||||
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
|
||||
|
||||
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
# The main function does not have tokens
|
||||
self.assertNotRegex(mlir_module_str, r"@main.*token")
|
||||
else:
|
||||
@ -1179,11 +1179,11 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
r"%arg4: tensor<\?x\?xf32>.*\) -> \("
|
||||
# Results
|
||||
r"!stablehlo.token {jax.token = true}, tensor<\?x\?xf32>.*\)")
|
||||
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
wrapped_main_expected_re = wrapped_main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
|
||||
self.assertRegex(mlir_module_str, wrapped_main_expected_re)
|
||||
|
||||
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
# The main function does not have tokens
|
||||
self.assertNotRegex(mlir_module_str, r"@main.*token")
|
||||
else:
|
||||
@ -1216,7 +1216,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
f_jax = jax.jit(f_jax, donate_argnums=(0,))
|
||||
exp = export.export(f_jax)(x)
|
||||
mlir_module_str = str(exp.mlir_module())
|
||||
if exp.mlir_module_serialization_version < export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
if exp.mlir_module_serialization_version < _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
|
||||
self.assertRegex(mlir_module_str, r"@main.*tf.aliasing_output = 0")
|
||||
self.assertRegex(mlir_module_str, r"@_wrapped_jax_export_main.*tf.aliasing_output = 1")
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user