[export] Move jax_export and shape_poly out of jax2tf.

These modules were initially developed for jax2tf, but they no longer
depend on TF and are now used for JAX native serialization. We move
them under jax.experimental.export (also renaming jax_export.py to
export.py) so that they can be used without depending on TF.

We leave behind stub modules jax2tf.jax_export and jax2tf.shape_poly
that simply re-export some of the public APIs; these stubs will be
cleaned up later.

PiperOrigin-RevId: 562988740
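
For context, the moved APIs implement JAX native serialization. A
minimal usage sketch of the module in its new location (mirroring
test_basic in tests/export_test.py below; the function and argument
values are illustrative):

    import numpy as np
    from jax import numpy as jnp
    from jax.experimental.export import export

    # Export jnp.sin lowered for a concrete f32[4] argument; the result
    # is an Exported object carrying the StableHLO module.
    exp = export.export(jnp.sin)(np.arange(4, dtype=np.float32))

    # Call the exported artifact back from JAX and check the result.
    x = np.arange(4, dtype=np.float32)
    assert np.allclose(export.call_exported(exp)(x), np.sin(x))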
George Necula 2023-09-05 22:15:22 -07:00 committed by jax authors
parent b2e5a1cf6a
commit 660a015652
16 changed files with 2843 additions and 2735 deletions

jax/experimental/export/BUILD

@@ -0,0 +1,44 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# JAX-export provides APIs for exporting StableHLO for serialization purposes.
load(
"//jaxlib:jax.bzl",
"py_deps",
)
load("@rules_python//python:defs.bzl", "py_library")
licenses(["notice"])
# Please add new users to :australis_users.
package(
default_applicable_licenses = [],
default_visibility = ["//visibility:private"],
)
py_library(
name = "export",
srcs = [
"export.py",
"shape_poly.py",
],
srcs_version = "PY3",
# TODO: b/255503696: enable pytype
tags = ["pytype_unchecked_annotations"],
visibility = ["//visibility:public"],
deps = [
"//jax",
] + py_deps("numpy"),
)

jax/experimental/export/__init__.py

@@ -0,0 +1,14 @@
# Copyright 2023 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

jax/experimental/export/export.py (file diff suppressed because it is too large)

jax/experimental/export/shape_poly.py (file diff suppressed because it is too large)

jax/experimental/jax2tf/BUILD

@@ -35,34 +35,21 @@ py_library(
deps = [":jax2tf_internal"],
)
py_library(
name = "jax_export",
srcs = [
"jax_export.py",
"shape_poly.py",
],
srcs_version = "PY3",
# TODO: b/255503696: enable pytype
tags = ["pytype_unchecked_annotations"],
visibility = ["//visibility:public"],
deps = [
"//jax",
] + py_deps("numpy"),
)
py_library(
name = "jax2tf_internal",
srcs = [
"call_tf.py",
"impl_no_xla.py",
"jax2tf.py",
"jax_export.py", # TODO(necula): remove stub
"shape_poly.py", # TODO(necula): remove stub
],
srcs_version = "PY3",
# TODO: b/255503696: enable pytype
tags = ["pytype_unchecked_annotations"],
visibility = jax_visibility("jax2tf_internal"),
deps = [
":jax_export",
"//jax",
"//jax/experimental/export",
] + py_deps("numpy") + py_deps("tensorflow_core") + jax2tf_deps,
)

jax/experimental/jax2tf/__init__.py

@@ -21,3 +21,7 @@ from jax.experimental.jax2tf.jax2tf import (
PolyShape as PolyShape
)
from jax.experimental.jax2tf.call_tf import call_tf as call_tf
# TODO(necula): remove stub. Needed by SAX
from jax.experimental.jax2tf import jax_export
# Needed by maths.qec.
from jax.experimental.jax2tf import shape_poly

jax/experimental/jax2tf/jax2tf.py

@@ -36,9 +36,9 @@ from jax import numpy as jnp
from jax import tree_util
from jax import sharding
from jax.experimental import maps
from jax.experimental.jax2tf import shape_poly
from jax.experimental.export import shape_poly
from jax.experimental.export import export
from jax.experimental.jax2tf import impl_no_xla
from jax.experimental.jax2tf import jax_export
from jax.interpreters import xla
from jax._src import ad_checkpoint
@@ -86,7 +86,7 @@ NameStack = source_info_util.NameStack
PolyShape = shape_poly.PolyShape
DType = Any
DisabledSafetyCheck = jax_export.DisabledSafetyCheck
DisabledSafetyCheck = export.DisabledSafetyCheck
# A temporary internal flag, to enable the wrapping of jax.jit functions
# with tf.function(jit_compile=True). See #7389. This change has triggered a
@@ -370,14 +370,14 @@ def convert(fun_jax: Callable,
_, a_jax_dtype = _tfval_to_tensor_jax_dtype(a)
return tf_arg_shape, a_jax_dtype
args_specs = jax_export.poly_specs(args_tf,
polymorphic_shapes=polymorphic_shapes,
get_shape_and_dtype=shape_and_dtype_tf)
args_specs = export.poly_specs(args_tf,
polymorphic_shapes=polymorphic_shapes,
get_shape_and_dtype=shape_and_dtype_tf)
# The polymorphic_shapes argument refers to positional arguments only.
# We assume None for the kwargs.
kwargs_specs = jax_export.poly_specs(kwargs_tf,
polymorphic_shapes=None,
get_shape_and_dtype=shape_and_dtype_tf)
kwargs_specs = export.poly_specs(kwargs_tf,
polymorphic_shapes=None,
get_shape_and_dtype=shape_and_dtype_tf)
combined_args_tf = (args_tf, kwargs_tf)
args_flat_tf: Sequence[TfVal]
args_flat_tf, args_kwargs_tree = tree_util.tree_flatten(combined_args_tf)
@@ -503,7 +503,7 @@ class NativeSerializationImpl(SerializationImpl):
_thread_local_state.call_tf_concrete_function_list = _prev_func_list
self._restore_context = _restore_context
self.exported = jax_export.export(
self.exported = export.export(
self.fun_jax,
lowering_platform=self.lowering_platform,
disabled_checks=self.native_serialization_disabled_checks
@@ -669,7 +669,7 @@ def eval_polymorphic_shape(fun_jax: Callable,
(c, a)
"""
def do_eval_polymorphic_shape(*args_specs) -> Any:
args_poly_specs = jax_export.poly_specs(
args_poly_specs = export.poly_specs(
args_specs, polymorphic_shapes=polymorphic_shapes)
res_poly_spec = jax.eval_shape(fun_jax, *args_poly_specs)
# TODO(necula): For now we export the polymorphic shapes using `str`.
@@ -803,7 +803,7 @@ def _interpret_fun_jax(
def _run_exported_as_tf(args_flat_tf: Sequence[TfVal],
exported: jax_export.Exported,
exported: export.Exported,
) -> Sequence[TfVal]:
"""Runs the `exported` as an XlaCallModule TF op.

jax/experimental/jax2tf/jax_export.py (file diff suppressed because it is too large)

jax/experimental/jax2tf/shape_poly.py (file diff suppressed because it is too large)
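
The two suppressed diffs above replace the full jax_export.py and
shape_poly.py implementations with redirect stubs. A sketch of what
such a stub plausibly looks like (the exact set of re-exported names
is an assumption, drawn from the APIs exercised elsewhere in this
commit):

    # jax/experimental/jax2tf/jax_export.py -- illustrative sketch only,
    # not the suppressed diff itself. Re-export the public APIs from
    # their new home so existing jax2tf.jax_export imports keep working.
    from jax.experimental.export.export import (  # noqa: F401
        DisabledSafetyCheck,
        Exported,
        call_exported,
        default_lowering_platform,
        export,
        poly_spec,
        poly_specs,
    )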

jax/experimental/jax2tf/tests/back_compat_test.py

@@ -28,7 +28,7 @@ import numpy as np
import jax
from jax import config
from jax import lax
from jax.experimental.jax2tf import jax_export
from jax.experimental.export import export
from jax.experimental.jax2tf.tests import back_compat_test_util as bctu
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_ducc_fft
@@ -96,7 +96,7 @@ class CompatTest(bctu.CompatTestBase):
def test_custom_call_coverage(self):
"""Tests that the back compat tests cover all the targets declared stable."""
targets_to_cover = set(jax_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
targets_to_cover = set(export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
# Add here all the testdatas that should cover the targets guaranteed
# stable
covering_testdatas = [

jax/experimental/jax2tf/tests/back_compat_test_util.py

@@ -21,7 +21,7 @@ The tests in this file refer to the test data in ./back_compat_testdata.
There is one test for each version of a custom call target, e.g.,
`test_ducc_fft` tests the FFT custom calls on CPU.
Only custom call targets tested here should be listed in
jax_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE. All other custom
export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE. All other custom
call targets will result in an error when encountered during serialization.
Once we stop using a custom call target in JAX, you can remove it from the
@@ -78,7 +78,7 @@ from numpy import array, float32
import jax
from jax import tree_util
from jax.experimental.jax2tf import jax_export
from jax.experimental.export import export
from jax.experimental import pjit
@@ -281,12 +281,12 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
a string (for debugging), and (c) the module serialization version.
"""
# Use the native exporter, to make sure we get the proper serialization.
args_specs = jax_export.poly_specs(data.inputs, polymorphic_shapes)
exported = jax_export.export(
args_specs = export.poly_specs(data.inputs, polymorphic_shapes)
exported = export.export(
jax.jit(func),
lowering_platform=self.default_jax_backend(),
disabled_checks=tuple(
jax_export.DisabledSafetyCheck.custom_call(target)
export.DisabledSafetyCheck.custom_call(target)
for target in allow_unstable_custom_call_targets)
)(*args_specs)
@@ -297,13 +297,13 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
def run_serialized(self, data: CompatTestData,
polymorphic_shapes: Optional[Sequence[str]] = None):
args_specs = jax_export.poly_specs(data.inputs, polymorphic_shapes)
args_specs = export.poly_specs(data.inputs, polymorphic_shapes)
def ndarray_to_aval(a: np.ndarray) -> core.ShapedArray:
return core.ShapedArray(a.shape, a.dtype)
in_avals_tree = tree_util.tree_map(ndarray_to_aval, args_specs)
# TODO: we ought to ensure that out_avals are polymorphic if need be. We
# could either save the in/out_avals (but we need to first implement that
# support in jax_export), or we can just re-use them from the current
# support in export), or we can just re-use them from the current
# exported.
out_avals_tree = tree_util.tree_map(ndarray_to_aval, data.expected_outputs)
# in_tree must be for (args, kwargs)
@@ -312,7 +312,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
def _get_vjp(_):
assert False # We do not have and do not need VJP
exported = jax_export.Exported(
exported = export.Exported(
fun_name="run_serialized",
in_tree=in_tree,
in_avals=tuple(in_avals),
@@ -331,4 +331,4 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
_get_vjp=_get_vjp)
# We use pjit in case there are shardings in the exported module.
return pjit.pjit(jax_export.call_exported(exported))(*data.inputs)
return pjit.pjit(export.call_exported(exported))(*data.inputs)

jax/experimental/jax2tf/tests/jax2tf_test.py

@@ -39,7 +39,7 @@ from jax._src.lib.mlir.dialects import hlo
import jax._src.xla_bridge
from jax import config
from jax.experimental import jax2tf
from jax.experimental.jax2tf import jax_export
from jax.experimental.export import export
from jax.experimental.jax2tf.tests import tf_test_util
from jax.experimental.maps import xmap
from jax.experimental.shard_map import shard_map
@@ -1522,7 +1522,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
stack.enter_context(mesh)
# Run the JAX native version, to check it works, and to fill caches.
_ = func_to_convert(*args)
exported = jax_export.export(
exported = export.export(
func_to_convert,
lowering_platform='tpu'
)(*(core.ShapedArray(a.shape, a.dtype) for a in args))

jax/experimental/jax2tf/tests/shape_poly_test.py

@@ -31,8 +31,8 @@ import re
import jax
from jax.experimental import jax2tf
from jax.experimental.jax2tf import shape_poly
from jax.experimental.jax2tf import jax_export
from jax.experimental.export import export
from jax.experimental.export import shape_poly
from jax.experimental import pjit
from jax import lax
import jax.numpy as jnp
@@ -72,7 +72,6 @@ expect_error_associative_scan = (
"associative scan over axis of non-constant size"))
class DimExprTest(tf_test_util.JaxToTfTestCase):
def sampled_assert_equal(self,
@@ -585,7 +584,7 @@ class PolyHarness(Harness):
len(polymorphic_shapes), len(args),
f"polymorphic_shapes {polymorphic_shapes} of length "
f"{len(polymorphic_shapes)} must match number of arguments {len(args)}")
args_specs = jax_export.poly_specs(args, polymorphic_shapes)
args_specs = export.poly_specs(args, polymorphic_shapes)
input_signature = [
tf.TensorSpec(
[d if isinstance(d, int) else None for d in a.shape],

jax/experimental/jax2tf/tests/tf_test_util.py

@@ -31,7 +31,7 @@ from jax import tree_util
from jax import config
from jax.experimental import jax2tf
from jax.experimental.jax2tf import jax_export
from jax.experimental.export import export
from jax._src import xla_bridge
import numpy as np
import tensorflow as tf # type: ignore[import]
@@ -158,7 +158,7 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence,
jax_legacy_prng_key="allow")
class JaxToTfTestCase(jtu.JaxTestCase):
# We want most tests to use the maximum available version, from the locally
# installed tfxla module and jax_export.
# installed tfxla module and export.
use_max_serialization_version = True
def setUp(self):
@@ -181,16 +181,16 @@ class JaxToTfTestCase(jtu.JaxTestCase):
self.addCleanup(functools.partial(config.update,
"jax_serialization_version", version))
if self.use_max_serialization_version:
# Use the largest supported by both jax_export and tfxla.call_module
version = min(jax_export.maximum_supported_serialization_version,
# Use the largest supported by both export and tfxla.call_module
version = min(export.maximum_supported_serialization_version,
tfxla.call_module_maximum_supported_version())
self.assertGreaterEqual(version,
jax_export.minimum_supported_serialization_version)
export.minimum_supported_serialization_version)
config.update("jax_serialization_version", version)
logging.info(
"Using JAX serialization version %s (jax_export.max_version %s, tf.XlaCallModule max version %s)",
"Using JAX serialization version %s (export.max_version %s, tf.XlaCallModule max version %s)",
version,
jax_export.maximum_supported_serialization_version,
export.maximum_supported_serialization_version,
tfxla.call_module_maximum_supported_version())
with contextlib.ExitStack() as stack:

tests/BUILD

@@ -1173,6 +1173,18 @@ py_test(
],
)
jax_test(
name = "export_test",
srcs = ["export_test.py"],
enable_configs = [
"tpu_df_2x2",
],
tags = [],
deps = [
"//jax/experimental/export",
],
)
exports_files(
[
"api_test.py",

tests/export_test.py

@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import math
import functools
import logging
import math
import re
from typing import Optional
import unittest
@@ -24,7 +24,7 @@ import jax
from jax import numpy as jnp
from jax import tree_util
from jax.config import config
from jax.experimental.jax2tf import jax_export
from jax.experimental.export import export
from jax._src import core
from jax._src import test_util as jtu
@@ -56,14 +56,14 @@ class JaxExportTest(jtu.JaxTestCase):
super().setUp()
# Run tests with the maximum supported version by default
self.override_serialization_version(
jax_export.maximum_supported_serialization_version)
export.maximum_supported_serialization_version)
def test_basic_export_only(self):
def my_fun(x):
return jnp.sin(x)
exp = jax_export.export(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32))
exp = export.export(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32))
self.assertEqual("my_fun", exp.fun_name)
self.assertEqual(jax_export.default_lowering_platform(), exp.lowering_platform)
self.assertEqual(export.default_lowering_platform(), exp.lowering_platform)
self.assertEqual(tree_util.tree_flatten(((1,), {}))[1], exp.in_tree)
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals)
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.out_avals)
@@ -74,7 +74,7 @@ class JaxExportTest(jtu.JaxTestCase):
def f(a_b_pair, *, a, b):
return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))
exp = jax_export.export(f, lowering_platform="cpu")((a, b), a=a, b=b)
exp = export.export(f, lowering_platform="cpu")((a, b), a=a, b=b)
a_aval = core.ShapedArray(a.shape, a.dtype)
b_aval = core.ShapedArray(b.shape, b.dtype)
self.assertEqual(exp.lowering_platform, "cpu")
@@ -90,9 +90,9 @@ class JaxExportTest(jtu.JaxTestCase):
def f(a, b): # a: f32[2w,h] b: f32[w,h]
return jnp.concatenate([a, b], axis=0)
exp = jax_export.export(f)(
jax_export.poly_spec(a.shape, a.dtype, "(2*w, h)"),
jax_export.poly_spec(a.shape, a.dtype, "(w, h)"))
exp = export.export(f)(
export.poly_spec(a.shape, a.dtype, "(2*w, h)"),
export.poly_spec(a.shape, a.dtype, "(w, h)"))
self.assertEqual("(2*w, h)", str(exp.in_avals[0].shape))
self.assertEqual("(w, h)", str(exp.in_avals[1].shape))
self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))
@@ -102,25 +102,25 @@ class JaxExportTest(jtu.JaxTestCase):
def f(a0, a1, *, ak):
return jnp.concatenate([a0, a1, ak], axis=0)
a_poly_spec = jax_export.poly_spec(a.shape, a.dtype, "(w, h)")
exp = jax_export.export(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec)
a_poly_spec = export.poly_spec(a.shape, a.dtype, "(w, h)")
exp = export.export(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec)
self.assertEqual("(w, h)", str(exp.in_avals[0].shape))
self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))
def test_basic(self):
f = jnp.sin
x = np.arange(4, dtype=np.float32)
exp_f = jax_export.export(f)(x)
exp_f = export.export(f)(x)
f1 = jax_export.call_exported(exp_f)
f1 = export.call_exported(exp_f)
self.assertAllClose(f(x), f1(x))
def test_call_exported_lambda(self):
# When we export a lambda, the exported.fun_name is not a valid MLIR function name
f = lambda x: jnp.sin(x)
x = np.arange(4, dtype=np.float32)
exp_f = jax_export.export(f)(x)
f1 = jax_export.call_exported(exp_f)
exp_f = export.export(f)(x)
f1 = export.call_exported(exp_f)
self.assertAllClose(f(x), f1(x))
def test_call_twice_exported(self):
@@ -129,8 +129,8 @@ class JaxExportTest(jtu.JaxTestCase):
@jax.jit
def f1(x):
exp_f = jax_export.export(f)(x)
return jax_export.call_exported(exp_f)(x) + jax_export.call_exported(exp_f)(x)
exp_f = export.export(f)(x)
return export.call_exported(exp_f)(x) + export.call_exported(exp_f)(x)
self.assertAllClose(2. * f(x), f1(x))
@@ -138,9 +138,9 @@ class JaxExportTest(jtu.JaxTestCase):
f = lambda x, y: jnp.sin(x)
x = np.arange(4, dtype=np.float32)
y = np.arange(6, dtype=np.float32)
exp_f = jax_export.export(f)(x, y)
exp_f = export.export(f)(x, y)
f1 = jax_export.call_exported(exp_f)
f1 = export.call_exported(exp_f)
self.assertAllClose(f(x, y), f1(x, y))
def test_pytree(self):
@@ -149,8 +149,8 @@ class JaxExportTest(jtu.JaxTestCase):
def f(a_b_pair, a, b):
return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))
exp_f = jax_export.export(f)((a, b), a=a, b=b)
f1 = jax_export.call_exported(exp_f)
exp_f = export.export(f)((a, b), a=a, b=b)
f1 = export.call_exported(exp_f)
self.assertAllClose(f((a, b), a=a, b=b),
f1((a, b), a=a, b=b))
@@ -158,34 +158,34 @@ class JaxExportTest(jtu.JaxTestCase):
def f(a_b_pair, *, c):
return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c
a = b = c = np.arange(4, dtype=np.float32)
exp_f = jax_export.export(f)((a, b), c=c)
exp_f = export.export(f)((a, b), c=c)
with self.assertRaisesRegex(
ValueError,
"The invocation args and kwargs must have the same pytree structure"):
jax_export.call_exported(exp_f)(a, b, c=(a, b))
export.call_exported(exp_f)(a, b, c=(a, b))
def test_error_wrong_avals(self):
def f(a, *, b): # a: f32[4] and b: f32[4]
return jnp.sin(a) + jnp.cos(b)
f32_4 = np.arange(4, dtype=np.float32)
exp_f = jax_export.export(f)(f32_4, b=f32_4)
exp_f = export.export(f)(f32_4, b=f32_4)
with self.assertRaisesRegex(ValueError,
r"Shape mismatch for args\[0\].shape\[0\]"):
jax_export.call_exported(exp_f)(np.arange(6, dtype=np.float32), b=f32_4)
export.call_exported(exp_f)(np.arange(6, dtype=np.float32), b=f32_4)
with self.assertRaisesRegex(ValueError,
r"Shape mismatch for kwargs\['b'\].shape\[0\]"):
jax_export.call_exported(exp_f)(f32_4, b=np.arange(6, dtype=np.float32))
export.call_exported(exp_f)(f32_4, b=np.arange(6, dtype=np.float32))
with self.assertRaisesRegex(ValueError,
r"Rank mismatch for args\[0\]"):
jax_export.call_exported(exp_f)(f32_4.reshape((1, 4)), b=f32_4)
export.call_exported(exp_f)(f32_4.reshape((1, 4)), b=f32_4)
with self.assertRaisesRegex(ValueError,
r"Dtype mismatch for args\[0\]"):
jax_export.call_exported(exp_f)(f32_4.astype(np.float16), b=f32_4)
export.call_exported(exp_f)(f32_4.astype(np.float16), b=f32_4)
@jtu.parameterized_filterable(
testcase_name=lambda kw: kw["platform"],
@@ -194,19 +194,19 @@ class JaxExportTest(jtu.JaxTestCase):
def test_error_wrong_platform(self, platform):
a = np.arange(4, dtype=np.float32)
exp_f = jax_export.export(jnp.sin, lowering_platform=platform)(a)
exp_f = export.export(jnp.sin, lowering_platform=platform)(a)
if xb.canonicalize_platform(jtu.device_under_test()) == platform:
raise unittest.SkipTest("Uninteresting scenario")
with self.assertRaisesRegex(
ValueError, "The exported function .* was lowered for platform"):
jax_export.call_exported(exp_f)(a)
export.call_exported(exp_f)(a)
# Now try with the platform check disabled
exp_f_no_platform_check = jax_export.export(
exp_f_no_platform_check = export.export(
jnp.sin, lowering_platform=platform,
disabled_checks=[jax_export.DisabledSafetyCheck.platform()])(a)
res = jax_export.call_exported(exp_f_no_platform_check)(a)
disabled_checks=[export.DisabledSafetyCheck.platform()])(a)
res = export.call_exported(exp_f_no_platform_check)(a)
self.assertAllClose(res, jnp.sin(a))
@jtu.parameterized_filterable(
@@ -230,23 +230,23 @@ class JaxExportTest(jtu.JaxTestCase):
a = np.arange(3, dtype=np.float32)
with self.assertRaisesRegex(ValueError,
"Cannot serialize code with custom calls whose targets .*"):
jax_export.export(
export.export(
lambda a: a + test_primitive.bind(a)
)(a)
# Now try again with the safety check disabled
exp = jax_export.export(
exp = export.export(
lambda a: a + test_primitive.bind(a),
disabled_checks=[jax_export.DisabledSafetyCheck.custom_call("disallowed_call_target")]
disabled_checks=[export.DisabledSafetyCheck.custom_call("disallowed_call_target")]
)(a)
self.assertIn("disallowed_call_target", exp.mlir_module())
def test_grad(self):
f = lambda x: jnp.sum(jnp.sin(x))
x = np.arange(4, dtype=np.float32)
exp_f = jax_export.export(f)(x)
exp_f = export.export(f)(x)
f1 = jax_export.call_exported(exp_f)
f1 = export.call_exported(exp_f)
self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x))
def test_pytree_vjp(self):
@@ -256,14 +256,14 @@ class JaxExportTest(jtu.JaxTestCase):
a = np.arange(4, dtype=np.float32)
b = np.arange(6, dtype=np.float32)
exp_f = jax_export.export(f)((a, b), a=a, b=b)
exp_f = export.export(f)((a, b), a=a, b=b)
out_ct = f((a, b), a=a, b=b) # The output has the right structure as the cotangent
def f1_jax(a, b): # For VJP, make a function without kwargs
res = f((a, b), a=a, b=b)
return res
def f1_exp(a, b): # For VJP, make a function without kwargs
res = jax_export.call_exported(exp_f)((a, b), a=a, b=b)
res = export.call_exported(exp_f)((a, b), a=a, b=b)
return res
jax_vjp = jax.vjp(f1_jax, a, b)[1](out_ct)
exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct)
@@ -273,33 +273,33 @@ class JaxExportTest(jtu.JaxTestCase):
def f1(x):
return jnp.sin(x)
a = np.arange(4, dtype=np.float32)
exp_f1 = jax_export.export(f1)(a)
exp_f1 = export.export(f1)(a)
def f2(x):
res1 = jax_export.call_exported(exp_f1)(x)
res2 = jax_export.call_exported(exp_f1)(res1)
res1 = export.call_exported(exp_f1)(x)
res2 = export.call_exported(exp_f1)(res1)
return jnp.cos(res2)
exp_f2 = jax_export.export(f2)(a)
exp_f2 = export.export(f2)(a)
self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))),
jax_export.call_exported(exp_f2)(a))
export.call_exported(exp_f2)(a))
@jtu.parameterized_filterable(
#one_containing="",
kwargs=[
dict(v=v)
for v in range(jax_export.minimum_supported_serialization_version - 1,
jax_export.maximum_supported_serialization_version + 2)])
for v in range(export.minimum_supported_serialization_version - 1,
export.maximum_supported_serialization_version + 2)])
def test_shape_poly_basic_versions(self, v: int):
self.override_serialization_version(v)
with contextlib.ExitStack() as e:
if not (jax_export.minimum_supported_serialization_version <= v
<= jax_export.maximum_supported_serialization_version):
if not (export.minimum_supported_serialization_version <= v
<= export.maximum_supported_serialization_version):
e.enter_context(self.assertRaisesRegex(
ValueError,
f"The requested jax_serialization version {v} is outside the range of supported versions"))
exp = jax_export.export(jnp.sin)(
jax_export.poly_spec((3, 4), np.float32, "w, h"))
exp = export.export(jnp.sin)(
export.poly_spec((3, 4), np.float32, "w, h"))
# Peek at the module
module_str = exp.mlir_module()
self.assertEqual(config.jax_serialization_version >= 7,
@@ -307,11 +307,11 @@ class JaxExportTest(jtu.JaxTestCase):
self.assertIn("jax.uses_shape_polymorphism = true",
module_str)
x = np.arange(30, dtype=np.float32).reshape((5, 6))
res = jax_export.call_exported(exp)(x)
res = export.call_exported(exp)(x)
self.assertAllClose(res, np.sin(x))
# A function is exported with f32[poly_spec] and is called with different arg
# shapes. We use jax_export.call_exported and we also run the shape check
# shapes. We use export.call_exported and we also run the shape check
# module.
@jtu.parameterized_filterable(
testcase_name=lambda kw:f"poly_spec={kw['poly_spec']}_arg_shape={kw['arg_shape']}", # type: ignore
@@ -348,8 +348,8 @@ class JaxExportTest(jtu.JaxTestCase):
return jnp.reshape(x, (-1, x.shape[1]))
disabled_checks = ()
exp_f = jax_export.export(f, disabled_checks=disabled_checks)(
jax_export.poly_spec((3, 4, 12), np.float32, poly_spec))
exp_f = export.export(f, disabled_checks=disabled_checks)(
export.poly_spec((3, 4, 12), np.float32, poly_spec))
self.assertEqual(exp_f.uses_shape_polymorphism, poly_spec != "3,4,12")
arg = np.arange(np.prod(arg_shape),
dtype=arg_dtype).reshape(arg_shape) # arg : f32[3,4,12]
@@ -359,7 +359,7 @@ class JaxExportTest(jtu.JaxTestCase):
stack.push(self.assertRaisesRegex(Exception, expect_error))
assert core.is_constant_shape(arg.shape)
res = jax_export.call_exported(exp_f)(arg)
res = export.call_exported(exp_f)(arg)
if not expect_error:
self.assertAllClose(res, f(arg))
@@ -450,23 +450,23 @@ class JaxExportTest(jtu.JaxTestCase):
arg = np.arange(np.prod(arg_shape),
dtype=arg_dtype).reshape(arg_shape) # x : f32[3,4,12]
inner_exp = jax_export.export(inner)(
jax_export.poly_spec((3, 4, 12), np.float32, inner_poly_spec))
inner_exp = export.export(inner)(
export.poly_spec((3, 4, 12), np.float32, inner_poly_spec))
self.assertEqual(inner_exp.uses_shape_polymorphism,
(inner_poly_spec != "3,4,12"))
def outer(x): # x: outer_poly_spec
# Use an addition to test that the shapes are refined properly for the
# result of the call_exported.
return jax_export.call_exported(inner_exp)(x) + inner(x)
return export.call_exported(inner_exp)(x) + inner(x)
with contextlib.ExitStack() as stack:
if expect_error_outer_exp is not None:
stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp))
# Call it after exporting again, with polymorphic shapes
outer_exp = jax_export.export(outer)(
jax_export.poly_spec(arg.shape, arg.dtype, outer_poly_spec))
outer_exp = export.export(outer)(
export.poly_spec(arg.shape, arg.dtype, outer_poly_spec))
if expect_error_outer_exp is not None:
return
@@ -478,7 +478,7 @@ class JaxExportTest(jtu.JaxTestCase):
if expect_error_run is not None:
stack.push(self.assertRaisesRegex(Exception, expect_error_run))
res = jax_export.call_exported(outer_exp)(arg)
res = export.call_exported(outer_exp)(arg)
if expect_error_run is not None:
return
@@ -543,9 +543,9 @@ class JaxExportTest(jtu.JaxTestCase):
with contextlib.ExitStack() as stack:
if expect_error is not None:
stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error)))
exp = jax_export.export(f_jax)(
jax_export.poly_spec(x.shape, x.dtype, poly_spec))
jax_export.call_exported(exp)(x)
exp = export.export(f_jax)(
export.poly_spec(x.shape, x.dtype, poly_spec))
export.call_exported(exp)(x)
def test_multi_platform(self):
if jtu.device_under_test() == "gpu":
@@ -553,10 +553,10 @@ class JaxExportTest(jtu.JaxTestCase):
raise unittest.SkipTest("Not intended for running on GPU")
x = np.arange(5, dtype=np.float32)
# TODO: use a function with different behavior for different platforms
exp = jax_export.export(jnp.sin,
exp = export.export(jnp.sin,
lowering_platforms=('cpu', 'tpu'))(x)
self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu'))
res = jax_export.call_exported(exp)(x)
res = export.call_exported(exp)(x)
self.assertAllClose(res, np.sin(x))
def test_multi_platform_nested(self):
@@ -565,7 +565,7 @@ class JaxExportTest(jtu.JaxTestCase):
raise unittest.SkipTest("Not intended for running on TPU")
x = np.arange(5, dtype=np.float32)
# TODO: use a function with different behavior for different platforms
exp = jax_export.export(jnp.sin,
exp = export.export(jnp.sin,
lowering_platforms=('cpu', 'tpu', 'cuda'))(x)
self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu', 'cuda'))
@@ -573,9 +573,9 @@ class JaxExportTest(jtu.JaxTestCase):
# lowering platforms, but included in the lowering platforms for the
# nested exported.
# TODO: improve this test once we implement true multi-platform lowering
exp2 = jax_export.export(jax_export.call_exported(exp),
exp2 = export.export(export.call_exported(exp),
lowering_platforms=('cpu', 'cuda'))(x)
res2 = jax_export.call_exported(exp2)(x)
res2 = export.call_exported(exp2)(x)
self.assertAllClose(res2, np.sin(x))
def test_multi_platform_and_poly(self):
@@ -583,16 +583,16 @@ class JaxExportTest(jtu.JaxTestCase):
# The export is not applicable to GPU
raise unittest.SkipTest("Not intended for running on GPU")
# TODO: use a function with different behavior for different platforms
exp = jax_export.export(lambda x: jnp.reshape(jnp.sin(x), (-1,)),
exp = export.export(lambda x: jnp.reshape(jnp.sin(x), (-1,)),
lowering_platforms=('cpu', 'tpu'))(
jax_export.poly_spec((5, 6), np.float32, "b1, b2")
export.poly_spec((5, 6), np.float32, "b1, b2")
)
x = np.arange(12, dtype=np.float32).reshape((3, 4))
res = jax_export.call_exported(exp)(x)
res = export.call_exported(exp)(x)
self.assertAllClose(res, np.sin(x).reshape((-1,)))
# Now serialize the call to the exported
exp2 = jax_export.export(jax_export.call_exported(exp))(x)
res2 = jax_export.call_exported(exp2)(x)
exp2 = export.export(export.call_exported(exp))(x)
res2 = export.call_exported(exp2)(x)
self.assertAllClose(res2, np.sin(x).reshape((-1,)))