mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[export] Move jax_export and shape_poly out of jax2tf.
Those modules have been developed initially for jax2tf but they do not depend on TF anymore. They are used for JAX native serialization. We move them under jax.experimental.export (also renaming jax_export.py to export.py) so that we can use them without depending on TF. We are leaving behind stub modules jax2tf.jax_export and jax2tf.shape_poly that just redirect some of the public APIs. To be cleaned later. PiperOrigin-RevId: 562988740
This commit is contained in:
parent
b2e5a1cf6a
commit
660a015652
44
jax/experimental/export/BUILD
Normal file
44
jax/experimental/export/BUILD
Normal file
@ -0,0 +1,44 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# JAX-export provides APIs for exporting StableHLO for serialization purposes.
|
||||
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"py_deps",
|
||||
)
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
# Please add new users to :australis_users.
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "export",
|
||||
srcs = [
|
||||
"export.py",
|
||||
"shape_poly.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
# TODO: b/255503696: enable pytype
|
||||
tags = ["pytype_unchecked_annotations"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//jax",
|
||||
] + py_deps("numpy"),
|
||||
)
|
14
jax/experimental/export/__init__.py
Normal file
14
jax/experimental/export/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
# Copyright 2023 The JAX Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
1046
jax/experimental/export/export.py
Normal file
1046
jax/experimental/export/export.py
Normal file
File diff suppressed because it is too large
Load Diff
1593
jax/experimental/export/shape_poly.py
Normal file
1593
jax/experimental/export/shape_poly.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -35,34 +35,21 @@ py_library(
|
||||
deps = [":jax2tf_internal"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "jax_export",
|
||||
srcs = [
|
||||
"jax_export.py",
|
||||
"shape_poly.py",
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
# TODO: b/255503696: enable pytype
|
||||
tags = ["pytype_unchecked_annotations"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//jax",
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "jax2tf_internal",
|
||||
srcs = [
|
||||
"call_tf.py",
|
||||
"impl_no_xla.py",
|
||||
"jax2tf.py",
|
||||
"jax_export.py", # TODO(necula): remove stub
|
||||
"shape_poly.py", # TODO(necula): remove stub
|
||||
],
|
||||
srcs_version = "PY3",
|
||||
# TODO: b/255503696: enable pytype
|
||||
tags = ["pytype_unchecked_annotations"],
|
||||
visibility = jax_visibility("jax2tf_internal"),
|
||||
deps = [
|
||||
":jax_export",
|
||||
"//jax",
|
||||
"//jax/experimental/export",
|
||||
] + py_deps("numpy") + py_deps("tensorflow_core") + jax2tf_deps,
|
||||
)
|
||||
|
@ -21,3 +21,7 @@ from jax.experimental.jax2tf.jax2tf import (
|
||||
PolyShape as PolyShape
|
||||
)
|
||||
from jax.experimental.jax2tf.call_tf import call_tf as call_tf
|
||||
# TODO(necula): remove stub. Needed by SAX
|
||||
from jax.experimental.jax2tf import jax_export
|
||||
# Needed by maths.qec.
|
||||
from jax.experimental.jax2tf import shape_poly
|
||||
|
@ -36,9 +36,9 @@ from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
from jax import sharding
|
||||
from jax.experimental import maps
|
||||
from jax.experimental.jax2tf import shape_poly
|
||||
from jax.experimental.export import shape_poly
|
||||
from jax.experimental.export import export
|
||||
from jax.experimental.jax2tf import impl_no_xla
|
||||
from jax.experimental.jax2tf import jax_export
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax._src import ad_checkpoint
|
||||
@ -86,7 +86,7 @@ NameStack = source_info_util.NameStack
|
||||
PolyShape = shape_poly.PolyShape
|
||||
DType = Any
|
||||
|
||||
DisabledSafetyCheck = jax_export.DisabledSafetyCheck
|
||||
DisabledSafetyCheck = export.DisabledSafetyCheck
|
||||
|
||||
# A temporary internal flag, to enable the wrapping of jax.jit functions
|
||||
# with tf.function(jit_compile=True). See #7389. This change has triggered a
|
||||
@ -370,14 +370,14 @@ def convert(fun_jax: Callable,
|
||||
_, a_jax_dtype = _tfval_to_tensor_jax_dtype(a)
|
||||
return tf_arg_shape, a_jax_dtype
|
||||
|
||||
args_specs = jax_export.poly_specs(args_tf,
|
||||
polymorphic_shapes=polymorphic_shapes,
|
||||
get_shape_and_dtype=shape_and_dtype_tf)
|
||||
args_specs = export.poly_specs(args_tf,
|
||||
polymorphic_shapes=polymorphic_shapes,
|
||||
get_shape_and_dtype=shape_and_dtype_tf)
|
||||
# The polymorphic_shapes argument refers to positional arguments only.
|
||||
# We assume None for the kwargs.
|
||||
kwargs_specs = jax_export.poly_specs(kwargs_tf,
|
||||
polymorphic_shapes=None,
|
||||
get_shape_and_dtype=shape_and_dtype_tf)
|
||||
kwargs_specs = export.poly_specs(kwargs_tf,
|
||||
polymorphic_shapes=None,
|
||||
get_shape_and_dtype=shape_and_dtype_tf)
|
||||
combined_args_tf = (args_tf, kwargs_tf)
|
||||
args_flat_tf: Sequence[TfVal]
|
||||
args_flat_tf, args_kwargs_tree = tree_util.tree_flatten(combined_args_tf)
|
||||
@ -503,7 +503,7 @@ class NativeSerializationImpl(SerializationImpl):
|
||||
_thread_local_state.call_tf_concrete_function_list = _prev_func_list
|
||||
|
||||
self._restore_context = _restore_context
|
||||
self.exported = jax_export.export(
|
||||
self.exported = export.export(
|
||||
self.fun_jax,
|
||||
lowering_platform=self.lowering_platform,
|
||||
disabled_checks=self.native_serialization_disabled_checks
|
||||
@ -669,7 +669,7 @@ def eval_polymorphic_shape(fun_jax: Callable,
|
||||
(c, a)
|
||||
"""
|
||||
def do_eval_polymorphic_shape(*args_specs) -> Any:
|
||||
args_poly_specs = jax_export.poly_specs(
|
||||
args_poly_specs = export.poly_specs(
|
||||
args_specs, polymorphic_shapes=polymorphic_shapes)
|
||||
res_poly_spec = jax.eval_shape(fun_jax, *args_poly_specs)
|
||||
# TODO(necula): For now we export the polymorphic shapes using `str`.
|
||||
@ -803,7 +803,7 @@ def _interpret_fun_jax(
|
||||
|
||||
|
||||
def _run_exported_as_tf(args_flat_tf: Sequence[TfVal],
|
||||
exported: jax_export.Exported,
|
||||
exported: export.Exported,
|
||||
) -> Sequence[TfVal]:
|
||||
"""Runs the `exported` as an XlaCallModule TF op.
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -28,7 +28,7 @@ import numpy as np
|
||||
import jax
|
||||
from jax import config
|
||||
from jax import lax
|
||||
from jax.experimental.jax2tf import jax_export
|
||||
from jax.experimental.export import export
|
||||
from jax.experimental.jax2tf.tests import back_compat_test_util as bctu
|
||||
|
||||
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_ducc_fft
|
||||
@ -96,7 +96,7 @@ class CompatTest(bctu.CompatTestBase):
|
||||
|
||||
def test_custom_call_coverage(self):
|
||||
"""Tests that the back compat tests cover all the targets declared stable."""
|
||||
targets_to_cover = set(jax_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
|
||||
targets_to_cover = set(export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE)
|
||||
# Add here all the testdatas that should cover the targets guaranteed
|
||||
# stable
|
||||
covering_testdatas = [
|
||||
|
@ -21,7 +21,7 @@ The tests in this file refer to the test data in ./back_compat_testdata.
|
||||
There is one test for each version of a custom call target, e.g.,
|
||||
`test_ducc_fft` tests the FFT custom calls on CPU.
|
||||
Only custom call targets tested here should be listed in
|
||||
jax_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE. All other custom
|
||||
export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE. All other custom
|
||||
call targets will result in an error when encountered during serialization.
|
||||
|
||||
Once we stop using a custom call target in JAX, you can remove it from the
|
||||
@ -78,7 +78,7 @@ from numpy import array, float32
|
||||
|
||||
import jax
|
||||
from jax import tree_util
|
||||
from jax.experimental.jax2tf import jax_export
|
||||
from jax.experimental.export import export
|
||||
|
||||
from jax.experimental import pjit
|
||||
|
||||
@ -281,12 +281,12 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
|
||||
a string (for debugging), and (c) the module serialization version.
|
||||
"""
|
||||
# Use the native exporter, to make sure we get the proper serialization.
|
||||
args_specs = jax_export.poly_specs(data.inputs, polymorphic_shapes)
|
||||
exported = jax_export.export(
|
||||
args_specs = export.poly_specs(data.inputs, polymorphic_shapes)
|
||||
exported = export.export(
|
||||
jax.jit(func),
|
||||
lowering_platform=self.default_jax_backend(),
|
||||
disabled_checks=tuple(
|
||||
jax_export.DisabledSafetyCheck.custom_call(target)
|
||||
export.DisabledSafetyCheck.custom_call(target)
|
||||
for target in allow_unstable_custom_call_targets)
|
||||
)(*args_specs)
|
||||
|
||||
@ -297,13 +297,13 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
|
||||
|
||||
def run_serialized(self, data: CompatTestData,
|
||||
polymorphic_shapes: Optional[Sequence[str]] = None):
|
||||
args_specs = jax_export.poly_specs(data.inputs, polymorphic_shapes)
|
||||
args_specs = export.poly_specs(data.inputs, polymorphic_shapes)
|
||||
def ndarray_to_aval(a: np.ndarray) -> core.ShapedArray:
|
||||
return core.ShapedArray(a.shape, a.dtype)
|
||||
in_avals_tree = tree_util.tree_map(ndarray_to_aval, args_specs)
|
||||
# TODO: we ought to ensure that out_avals are polymorphic if need be. We
|
||||
# could either save the in/out_avals (but we need to first implement that
|
||||
# support in jax_export), or we can just re-use them from the current
|
||||
# support in export), or we can just re-use them from the current
|
||||
# exported.
|
||||
out_avals_tree = tree_util.tree_map(ndarray_to_aval, data.expected_outputs)
|
||||
# in_tree must be for (args, kwargs)
|
||||
@ -312,7 +312,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
|
||||
def _get_vjp(_):
|
||||
assert False # We do not have and do not need VJP
|
||||
|
||||
exported = jax_export.Exported(
|
||||
exported = export.Exported(
|
||||
fun_name="run_serialized",
|
||||
in_tree=in_tree,
|
||||
in_avals=tuple(in_avals),
|
||||
@ -331,4 +331,4 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
|
||||
_get_vjp=_get_vjp)
|
||||
|
||||
# We use pjit in case there are shardings in the exported module.
|
||||
return pjit.pjit(jax_export.call_exported(exported))(*data.inputs)
|
||||
return pjit.pjit(export.call_exported(exported))(*data.inputs)
|
||||
|
@ -39,7 +39,7 @@ from jax._src.lib.mlir.dialects import hlo
|
||||
import jax._src.xla_bridge
|
||||
from jax import config
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental.jax2tf import jax_export
|
||||
from jax.experimental.export import export
|
||||
from jax.experimental.jax2tf.tests import tf_test_util
|
||||
from jax.experimental.maps import xmap
|
||||
from jax.experimental.shard_map import shard_map
|
||||
@ -1522,7 +1522,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
|
||||
stack.enter_context(mesh)
|
||||
# Run the JAX native version, to check it works, and to fill caches.
|
||||
_ = func_to_convert(*args)
|
||||
exported = jax_export.export(
|
||||
exported = export.export(
|
||||
func_to_convert,
|
||||
lowering_platform='tpu'
|
||||
)(*(core.ShapedArray(a.shape, a.dtype) for a in args))
|
||||
|
@ -31,8 +31,8 @@ import re
|
||||
|
||||
import jax
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental.jax2tf import shape_poly
|
||||
from jax.experimental.jax2tf import jax_export
|
||||
from jax.experimental.export import export
|
||||
from jax.experimental.export import shape_poly
|
||||
from jax.experimental import pjit
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
@ -72,7 +72,6 @@ expect_error_associative_scan = (
|
||||
"associative scan over axis of non-constant size"))
|
||||
|
||||
|
||||
|
||||
class DimExprTest(tf_test_util.JaxToTfTestCase):
|
||||
|
||||
def sampled_assert_equal(self,
|
||||
@ -585,7 +584,7 @@ class PolyHarness(Harness):
|
||||
len(polymorphic_shapes), len(args),
|
||||
f"polymorphic_shapes {polymorphic_shapes} of length "
|
||||
f"{len(polymorphic_shapes)} must match number of arguments {len(args)}")
|
||||
args_specs = jax_export.poly_specs(args, polymorphic_shapes)
|
||||
args_specs = export.poly_specs(args, polymorphic_shapes)
|
||||
input_signature = [
|
||||
tf.TensorSpec(
|
||||
[d if isinstance(d, int) else None for d in a.shape],
|
||||
|
@ -31,7 +31,7 @@ from jax import tree_util
|
||||
|
||||
from jax import config
|
||||
from jax.experimental import jax2tf
|
||||
from jax.experimental.jax2tf import jax_export
|
||||
from jax.experimental.export import export
|
||||
from jax._src import xla_bridge
|
||||
import numpy as np
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
@ -158,7 +158,7 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence,
|
||||
jax_legacy_prng_key="allow")
|
||||
class JaxToTfTestCase(jtu.JaxTestCase):
|
||||
# We want most tests to use the maximum available version, from the locally
|
||||
# installed tfxla module and jax_export.
|
||||
# installed tfxla module and export.
|
||||
use_max_serialization_version = True
|
||||
|
||||
def setUp(self):
|
||||
@ -181,16 +181,16 @@ class JaxToTfTestCase(jtu.JaxTestCase):
|
||||
self.addCleanup(functools.partial(config.update,
|
||||
"jax_serialization_version", version))
|
||||
if self.use_max_serialization_version:
|
||||
# Use the largest supported by both jax_export and tfxla.call_module
|
||||
version = min(jax_export.maximum_supported_serialization_version,
|
||||
# Use the largest supported by both export and tfxla.call_module
|
||||
version = min(export.maximum_supported_serialization_version,
|
||||
tfxla.call_module_maximum_supported_version())
|
||||
self.assertGreaterEqual(version,
|
||||
jax_export.minimum_supported_serialization_version)
|
||||
export.minimum_supported_serialization_version)
|
||||
config.update("jax_serialization_version", version)
|
||||
logging.info(
|
||||
"Using JAX serialization version %s (jax_export.max_version %s, tf.XlaCallModule max version %s)",
|
||||
"Using JAX serialization version %s (export.max_version %s, tf.XlaCallModule max version %s)",
|
||||
version,
|
||||
jax_export.maximum_supported_serialization_version,
|
||||
export.maximum_supported_serialization_version,
|
||||
tfxla.call_module_maximum_supported_version())
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
|
12
tests/BUILD
12
tests/BUILD
@ -1173,6 +1173,18 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "export_test",
|
||||
srcs = ["export_test.py"],
|
||||
enable_configs = [
|
||||
"tpu_df_2x2",
|
||||
],
|
||||
tags = [],
|
||||
deps = [
|
||||
"//jax/experimental/export",
|
||||
],
|
||||
)
|
||||
|
||||
exports_files(
|
||||
[
|
||||
"api_test.py",
|
||||
|
@ -12,9 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import math
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from typing import Optional
|
||||
import unittest
|
||||
@ -24,7 +24,7 @@ import jax
|
||||
from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
from jax.config import config
|
||||
from jax.experimental.jax2tf import jax_export
|
||||
from jax.experimental.export import export
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
@ -56,14 +56,14 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
super().setUp()
|
||||
# Run tests with the maximum supported version by default
|
||||
self.override_serialization_version(
|
||||
jax_export.maximum_supported_serialization_version)
|
||||
export.maximum_supported_serialization_version)
|
||||
|
||||
def test_basic_export_only(self):
|
||||
def my_fun(x):
|
||||
return jnp.sin(x)
|
||||
exp = jax_export.export(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32))
|
||||
exp = export.export(my_fun)(jax.ShapeDtypeStruct((4,), dtype=np.float32))
|
||||
self.assertEqual("my_fun", exp.fun_name)
|
||||
self.assertEqual(jax_export.default_lowering_platform(), exp.lowering_platform)
|
||||
self.assertEqual(export.default_lowering_platform(), exp.lowering_platform)
|
||||
self.assertEqual(tree_util.tree_flatten(((1,), {}))[1], exp.in_tree)
|
||||
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals)
|
||||
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.out_avals)
|
||||
@ -74,7 +74,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
def f(a_b_pair, *, a, b):
|
||||
return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))
|
||||
|
||||
exp = jax_export.export(f, lowering_platform="cpu")((a, b), a=a, b=b)
|
||||
exp = export.export(f, lowering_platform="cpu")((a, b), a=a, b=b)
|
||||
a_aval = core.ShapedArray(a.shape, a.dtype)
|
||||
b_aval = core.ShapedArray(b.shape, b.dtype)
|
||||
self.assertEqual(exp.lowering_platform, "cpu")
|
||||
@ -90,9 +90,9 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
def f(a, b): # a: f32[2w,h] b: f32[w,h]
|
||||
return jnp.concatenate([a, b], axis=0)
|
||||
|
||||
exp = jax_export.export(f)(
|
||||
jax_export.poly_spec(a.shape, a.dtype, "(2*w, h)"),
|
||||
jax_export.poly_spec(a.shape, a.dtype, "(w, h)"))
|
||||
exp = export.export(f)(
|
||||
export.poly_spec(a.shape, a.dtype, "(2*w, h)"),
|
||||
export.poly_spec(a.shape, a.dtype, "(w, h)"))
|
||||
self.assertEqual("(2*w, h)", str(exp.in_avals[0].shape))
|
||||
self.assertEqual("(w, h)", str(exp.in_avals[1].shape))
|
||||
self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))
|
||||
@ -102,25 +102,25 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
def f(a0, a1, *, ak):
|
||||
return jnp.concatenate([a0, a1, ak], axis=0)
|
||||
|
||||
a_poly_spec = jax_export.poly_spec(a.shape, a.dtype, "(w, h)")
|
||||
exp = jax_export.export(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec)
|
||||
a_poly_spec = export.poly_spec(a.shape, a.dtype, "(w, h)")
|
||||
exp = export.export(f)(a_poly_spec, a_poly_spec, ak=a_poly_spec)
|
||||
self.assertEqual("(w, h)", str(exp.in_avals[0].shape))
|
||||
self.assertEqual("(3*w, h)", str(exp.out_avals[0].shape))
|
||||
|
||||
def test_basic(self):
|
||||
f = jnp.sin
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
exp_f = jax_export.export(f)(x)
|
||||
exp_f = export.export(f)(x)
|
||||
|
||||
f1 = jax_export.call_exported(exp_f)
|
||||
f1 = export.call_exported(exp_f)
|
||||
self.assertAllClose(f(x), f1(x))
|
||||
|
||||
def test_call_exported_lambda(self):
|
||||
# When we export a lambda, the exported.fun_name is not a valid MLIR function name
|
||||
f = lambda x: jnp.sin(x)
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
exp_f = jax_export.export(f)(x)
|
||||
f1 = jax_export.call_exported(exp_f)
|
||||
exp_f = export.export(f)(x)
|
||||
f1 = export.call_exported(exp_f)
|
||||
self.assertAllClose(f(x), f1(x))
|
||||
|
||||
def test_call_twice_exported(self):
|
||||
@ -129,8 +129,8 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f1(x):
|
||||
exp_f = jax_export.export(f)(x)
|
||||
return jax_export.call_exported(exp_f)(x) + jax_export.call_exported(exp_f)(x)
|
||||
exp_f = export.export(f)(x)
|
||||
return export.call_exported(exp_f)(x) + export.call_exported(exp_f)(x)
|
||||
|
||||
self.assertAllClose(2. * f(x), f1(x))
|
||||
|
||||
@ -138,9 +138,9 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
f = lambda x, y: jnp.sin(x)
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
y = np.arange(6, dtype=np.float32)
|
||||
exp_f = jax_export.export(f)(x, y)
|
||||
exp_f = export.export(f)(x, y)
|
||||
|
||||
f1 = jax_export.call_exported(exp_f)
|
||||
f1 = export.call_exported(exp_f)
|
||||
self.assertAllClose(f(x, y), f1(x, y))
|
||||
|
||||
def test_pytree(self):
|
||||
@ -149,8 +149,8 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
def f(a_b_pair, a, b):
|
||||
return (dict(res=a_b_pair, a=a, b=b), jnp.sin(a), jnp.cos(b))
|
||||
|
||||
exp_f = jax_export.export(f)((a, b), a=a, b=b)
|
||||
f1 = jax_export.call_exported(exp_f)
|
||||
exp_f = export.export(f)((a, b), a=a, b=b)
|
||||
f1 = export.call_exported(exp_f)
|
||||
self.assertAllClose(f((a, b), a=a, b=b),
|
||||
f1((a, b), a=a, b=b))
|
||||
|
||||
@ -158,34 +158,34 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
def f(a_b_pair, *, c):
|
||||
return jnp.sin(a_b_pair[0]) + jnp.cos(a_b_pair[1]) + c
|
||||
a = b = c = np.arange(4, dtype=np.float32)
|
||||
exp_f = jax_export.export(f)((a, b), c=c)
|
||||
exp_f = export.export(f)((a, b), c=c)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"The invocation args and kwargs must have the same pytree structure"):
|
||||
jax_export.call_exported(exp_f)(a, b, c=(a, b))
|
||||
export.call_exported(exp_f)(a, b, c=(a, b))
|
||||
|
||||
def test_error_wrong_avals(self):
|
||||
def f(a, *, b): # a: f32[4] and b: f32[4]
|
||||
return jnp.sin(a) + jnp.cos(b)
|
||||
f32_4 = np.arange(4, dtype=np.float32)
|
||||
exp_f = jax_export.export(f)(f32_4, b=f32_4)
|
||||
exp_f = export.export(f)(f32_4, b=f32_4)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Shape mismatch for args\[0\].shape\[0\]"):
|
||||
jax_export.call_exported(exp_f)(np.arange(6, dtype=np.float32), b=f32_4)
|
||||
export.call_exported(exp_f)(np.arange(6, dtype=np.float32), b=f32_4)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Shape mismatch for kwargs\['b'\].shape\[0\]"):
|
||||
jax_export.call_exported(exp_f)(f32_4, b=np.arange(6, dtype=np.float32))
|
||||
export.call_exported(exp_f)(f32_4, b=np.arange(6, dtype=np.float32))
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Rank mismatch for args\[0\]"):
|
||||
jax_export.call_exported(exp_f)(f32_4.reshape((1, 4)), b=f32_4)
|
||||
export.call_exported(exp_f)(f32_4.reshape((1, 4)), b=f32_4)
|
||||
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r"Dtype mismatch for args\[0\]"):
|
||||
jax_export.call_exported(exp_f)(f32_4.astype(np.float16), b=f32_4)
|
||||
export.call_exported(exp_f)(f32_4.astype(np.float16), b=f32_4)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
testcase_name=lambda kw: kw["platform"],
|
||||
@ -194,19 +194,19 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
def test_error_wrong_platform(self, platform):
|
||||
a = np.arange(4, dtype=np.float32)
|
||||
|
||||
exp_f = jax_export.export(jnp.sin, lowering_platform=platform)(a)
|
||||
exp_f = export.export(jnp.sin, lowering_platform=platform)(a)
|
||||
if xb.canonicalize_platform(jtu.device_under_test()) == platform:
|
||||
raise unittest.SkipTest("Uninteresting scenario")
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "The exported function .* was lowered for platform"):
|
||||
jax_export.call_exported(exp_f)(a)
|
||||
export.call_exported(exp_f)(a)
|
||||
|
||||
# Now try with the platform check disabled
|
||||
exp_f_no_platform_check = jax_export.export(
|
||||
exp_f_no_platform_check = export.export(
|
||||
jnp.sin, lowering_platform=platform,
|
||||
disabled_checks=[jax_export.DisabledSafetyCheck.platform()])(a)
|
||||
res = jax_export.call_exported(exp_f_no_platform_check)(a)
|
||||
disabled_checks=[export.DisabledSafetyCheck.platform()])(a)
|
||||
res = export.call_exported(exp_f_no_platform_check)(a)
|
||||
self.assertAllClose(res, jnp.sin(a))
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
@ -230,23 +230,23 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
a = np.arange(3, dtype=np.float32)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Cannot serialize code with custom calls whose targets .*"):
|
||||
jax_export.export(
|
||||
export.export(
|
||||
lambda a: a + test_primitive.bind(a)
|
||||
)(a)
|
||||
|
||||
# Now try again with the safety check disabled
|
||||
exp = jax_export.export(
|
||||
exp = export.export(
|
||||
lambda a: a + test_primitive.bind(a),
|
||||
disabled_checks=[jax_export.DisabledSafetyCheck.custom_call("disallowed_call_target")]
|
||||
disabled_checks=[export.DisabledSafetyCheck.custom_call("disallowed_call_target")]
|
||||
)(a)
|
||||
self.assertIn("disallowed_call_target", exp.mlir_module())
|
||||
|
||||
def test_grad(self):
|
||||
f = lambda x: jnp.sum(jnp.sin(x))
|
||||
x = np.arange(4, dtype=np.float32)
|
||||
exp_f = jax_export.export(f)(x)
|
||||
exp_f = export.export(f)(x)
|
||||
|
||||
f1 = jax_export.call_exported(exp_f)
|
||||
f1 = export.call_exported(exp_f)
|
||||
self.assertAllClose(jax.grad(f)(x), jax.grad(f1)(x))
|
||||
|
||||
def test_pytree_vjp(self):
|
||||
@ -256,14 +256,14 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
a = np.arange(4, dtype=np.float32)
|
||||
b = np.arange(6, dtype=np.float32)
|
||||
exp_f = jax_export.export(f)((a, b), a=a, b=b)
|
||||
exp_f = export.export(f)((a, b), a=a, b=b)
|
||||
|
||||
out_ct = f((a, b), a=a, b=b) # The output has the right structure as the cotangent
|
||||
def f1_jax(a, b): # For VJP, make a function without kwargs
|
||||
res = f((a, b), a=a, b=b)
|
||||
return res
|
||||
def f1_exp(a, b): # For VJP, make a function without kwargs
|
||||
res = jax_export.call_exported(exp_f)((a, b), a=a, b=b)
|
||||
res = export.call_exported(exp_f)((a, b), a=a, b=b)
|
||||
return res
|
||||
jax_vjp = jax.vjp(f1_jax, a, b)[1](out_ct)
|
||||
exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct)
|
||||
@ -273,33 +273,33 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
def f1(x):
|
||||
return jnp.sin(x)
|
||||
a = np.arange(4, dtype=np.float32)
|
||||
exp_f1 = jax_export.export(f1)(a)
|
||||
exp_f1 = export.export(f1)(a)
|
||||
def f2(x):
|
||||
res1 = jax_export.call_exported(exp_f1)(x)
|
||||
res2 = jax_export.call_exported(exp_f1)(res1)
|
||||
res1 = export.call_exported(exp_f1)(x)
|
||||
res2 = export.call_exported(exp_f1)(res1)
|
||||
return jnp.cos(res2)
|
||||
exp_f2 = jax_export.export(f2)(a)
|
||||
exp_f2 = export.export(f2)(a)
|
||||
|
||||
self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))),
|
||||
jax_export.call_exported(exp_f2)(a))
|
||||
export.call_exported(exp_f2)(a))
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
#one_containing="",
|
||||
kwargs=[
|
||||
dict(v=v)
|
||||
for v in range(jax_export.minimum_supported_serialization_version - 1,
|
||||
jax_export.maximum_supported_serialization_version + 2)])
|
||||
for v in range(export.minimum_supported_serialization_version - 1,
|
||||
export.maximum_supported_serialization_version + 2)])
|
||||
def test_shape_poly_basic_versions(self, v: int):
|
||||
self.override_serialization_version(v)
|
||||
with contextlib.ExitStack() as e:
|
||||
if not (jax_export.minimum_supported_serialization_version <= v
|
||||
<= jax_export.maximum_supported_serialization_version):
|
||||
if not (export.minimum_supported_serialization_version <= v
|
||||
<= export.maximum_supported_serialization_version):
|
||||
e.enter_context(self.assertRaisesRegex(
|
||||
ValueError,
|
||||
f"The requested jax_serialization version {v} is outside the range of supported versions"))
|
||||
|
||||
exp = jax_export.export(jnp.sin)(
|
||||
jax_export.poly_spec((3, 4), np.float32, "w, h"))
|
||||
exp = export.export(jnp.sin)(
|
||||
export.poly_spec((3, 4), np.float32, "w, h"))
|
||||
# Peek at the module
|
||||
module_str = exp.mlir_module()
|
||||
self.assertEqual(config.jax_serialization_version >= 7,
|
||||
@ -307,11 +307,11 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
self.assertIn("jax.uses_shape_polymorphism = true",
|
||||
module_str)
|
||||
x = np.arange(30, dtype=np.float32).reshape((5, 6))
|
||||
res = jax_export.call_exported(exp)(x)
|
||||
res = export.call_exported(exp)(x)
|
||||
self.assertAllClose(res, np.sin(x))
|
||||
|
||||
# A function is exported with f32[poly_spec] and is called with different arg
|
||||
# shapes. We use jax_export.call_exported and we also run the shape check
|
||||
# shapes. We use export.call_exported and we also run the shape check
|
||||
# module.
|
||||
@jtu.parameterized_filterable(
|
||||
testcase_name=lambda kw:f"poly_spec={kw['poly_spec']}_arg_shape={kw['arg_shape']}", # type: ignore
|
||||
@ -348,8 +348,8 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
return jnp.reshape(x, (-1, x.shape[1]))
|
||||
|
||||
disabled_checks = ()
|
||||
exp_f = jax_export.export(f, disabled_checks=disabled_checks)(
|
||||
jax_export.poly_spec((3, 4, 12), np.float32, poly_spec))
|
||||
exp_f = export.export(f, disabled_checks=disabled_checks)(
|
||||
export.poly_spec((3, 4, 12), np.float32, poly_spec))
|
||||
self.assertEqual(exp_f.uses_shape_polymorphism, poly_spec != "3,4,12")
|
||||
arg = np.arange(np.prod(arg_shape),
|
||||
dtype=arg_dtype).reshape(arg_shape) # arg : f32[3,4,12]
|
||||
@ -359,7 +359,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
stack.push(self.assertRaisesRegex(Exception, expect_error))
|
||||
|
||||
assert core.is_constant_shape(arg.shape)
|
||||
res = jax_export.call_exported(exp_f)(arg)
|
||||
res = export.call_exported(exp_f)(arg)
|
||||
|
||||
if not expect_error:
|
||||
self.assertAllClose(res, f(arg))
|
||||
@ -450,23 +450,23 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
|
||||
arg = np.arange(np.prod(arg_shape),
|
||||
dtype=arg_dtype).reshape(arg_shape) # x : f32[3,4,12]
|
||||
inner_exp = jax_export.export(inner)(
|
||||
jax_export.poly_spec((3, 4, 12), np.float32, inner_poly_spec))
|
||||
inner_exp = export.export(inner)(
|
||||
export.poly_spec((3, 4, 12), np.float32, inner_poly_spec))
|
||||
|
||||
self.assertEqual(inner_exp.uses_shape_polymorphism,
|
||||
(inner_poly_spec != "3,4,12"))
|
||||
def outer(x): # x: outer_poly_spec
|
||||
# Use an addition to test that the shapes are refined properly for the
|
||||
# result of the call_exported.
|
||||
return jax_export.call_exported(inner_exp)(x) + inner(x)
|
||||
return export.call_exported(inner_exp)(x) + inner(x)
|
||||
|
||||
with contextlib.ExitStack() as stack:
|
||||
if expect_error_outer_exp is not None:
|
||||
stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp))
|
||||
|
||||
# Call it after exporting again, with polymorphic shapes
|
||||
outer_exp = jax_export.export(outer)(
|
||||
jax_export.poly_spec(arg.shape, arg.dtype, outer_poly_spec))
|
||||
outer_exp = export.export(outer)(
|
||||
export.poly_spec(arg.shape, arg.dtype, outer_poly_spec))
|
||||
|
||||
if expect_error_outer_exp is not None:
|
||||
return
|
||||
@ -478,7 +478,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
if expect_error_run is not None:
|
||||
stack.push(self.assertRaisesRegex(Exception, expect_error_run))
|
||||
|
||||
res = jax_export.call_exported(outer_exp)(arg)
|
||||
res = export.call_exported(outer_exp)(arg)
|
||||
|
||||
if expect_error_run is not None:
|
||||
return
|
||||
@ -543,9 +543,9 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
with contextlib.ExitStack() as stack:
|
||||
if expect_error is not None:
|
||||
stack.push(self.assertRaisesRegex(Exception, re.escape(expect_error)))
|
||||
exp = jax_export.export(f_jax)(
|
||||
jax_export.poly_spec(x.shape, x.dtype, poly_spec))
|
||||
jax_export.call_exported(exp)(x)
|
||||
exp = export.export(f_jax)(
|
||||
export.poly_spec(x.shape, x.dtype, poly_spec))
|
||||
export.call_exported(exp)(x)
|
||||
|
||||
def test_multi_platform(self):
|
||||
if jtu.device_under_test() == "gpu":
|
||||
@ -553,10 +553,10 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
raise unittest.SkipTest("Not intended for running on GPU")
|
||||
x = np.arange(5, dtype=np.float32)
|
||||
# TODO: use a function with different behavior for different platforms
|
||||
exp = jax_export.export(jnp.sin,
|
||||
exp = export.export(jnp.sin,
|
||||
lowering_platforms=('cpu', 'tpu'))(x)
|
||||
self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu'))
|
||||
res = jax_export.call_exported(exp)(x)
|
||||
res = export.call_exported(exp)(x)
|
||||
self.assertAllClose(res, np.sin(x))
|
||||
|
||||
def test_multi_platform_nested(self):
|
||||
@ -565,7 +565,7 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
raise unittest.SkipTest("Not intended for running on TPU")
|
||||
x = np.arange(5, dtype=np.float32)
|
||||
# TODO: use a function with different behavior for different platforms
|
||||
exp = jax_export.export(jnp.sin,
|
||||
exp = export.export(jnp.sin,
|
||||
lowering_platforms=('cpu', 'tpu', 'cuda'))(x)
|
||||
self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu', 'cuda'))
|
||||
|
||||
@ -573,9 +573,9 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
# lowering platforms, but included in the lowering platforms for the
|
||||
# nested exported.
|
||||
# TODO: improve this test once we implement true multi-platform lowering
|
||||
exp2 = jax_export.export(jax_export.call_exported(exp),
|
||||
exp2 = export.export(export.call_exported(exp),
|
||||
lowering_platforms=('cpu', 'cuda'))(x)
|
||||
res2 = jax_export.call_exported(exp2)(x)
|
||||
res2 = export.call_exported(exp2)(x)
|
||||
self.assertAllClose(res2, np.sin(x))
|
||||
|
||||
def test_multi_platform_and_poly(self):
|
||||
@ -583,16 +583,16 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
# The export is not applicable to GPU
|
||||
raise unittest.SkipTest("Not intended for running on GPU")
|
||||
# TODO: use a function with different behavior for different platforms
|
||||
exp = jax_export.export(lambda x: jnp.reshape(jnp.sin(x), (-1,)),
|
||||
exp = export.export(lambda x: jnp.reshape(jnp.sin(x), (-1,)),
|
||||
lowering_platforms=('cpu', 'tpu'))(
|
||||
jax_export.poly_spec((5, 6), np.float32, "b1, b2")
|
||||
export.poly_spec((5, 6), np.float32, "b1, b2")
|
||||
)
|
||||
x = np.arange(12, dtype=np.float32).reshape((3, 4))
|
||||
res = jax_export.call_exported(exp)(x)
|
||||
res = export.call_exported(exp)(x)
|
||||
self.assertAllClose(res, np.sin(x).reshape((-1,)))
|
||||
# Now serialize the call to the exported
|
||||
exp2 = jax_export.export(jax_export.call_exported(exp))(x)
|
||||
res2 = jax_export.call_exported(exp2)(x)
|
||||
exp2 = export.export(export.call_exported(exp))(x)
|
||||
res2 = export.call_exported(exp2)(x)
|
||||
self.assertAllClose(res2, np.sin(x).reshape((-1,)))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user