Move export backwards compatibility tests out of jax2tf. Step 3.

The last part of moving the tests: move jax2tf/tests/back_compat_test.py to tests/export_back_compat_test.py.

PiperOrigin-RevId: 596555577
This commit is contained in:
George Necula 2024-01-08 04:47:36 -08:00 committed by jax authors
parent 3d608b1850
commit ed2a839884
2 changed files with 28 additions and 6 deletions

View File

@ -1392,6 +1392,16 @@ jax_test(
],
)
jax_test(
name = "export_back_compat_test",
srcs = ["export_back_compat_test.py"],
tags = [],
deps = [
"//jax:internal_export_back_compat_test_data",
"//jax:internal_export_back_compat_test_util",
],
)
exports_files(
[
"api_test.py",

View File

@ -42,7 +42,6 @@ from jax._src.internal_test_util.export_back_compat_test_data import cpu_schur_l
from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd
from jax._src.internal_test_util.export_back_compat_test_data import cpu_triangular_solve_blas_trsm
from jax._src.internal_test_util.export_back_compat_test_data import cuda_threefry2x32
from jax.experimental.jax2tf.tests.back_compat_testdata import tf_call_tf_function
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Eigh
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Lu
from jax._src.internal_test_util.export_back_compat_test_data import tpu_ApproxTopK
@ -61,6 +60,7 @@ from jax.sharding import PartitionSpec as P
from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import version as jaxlib_version
config.parse_flags_with_absl()
@ -111,7 +111,6 @@ class CompatTest(bctu.CompatTestBase):
cpu_schur_lapack_gees.data_2023_07_16,
cpu_svd_lapack_gesdd.data_2023_06_19,
cpu_triangular_solve_blas_trsm.data_2023_07_16,
tf_call_tf_function.data_2023_07_29, # This is tested in back_compat_tf_test.py
tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17,
tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17,
tpu_ApproxTopK.data_2023_05_16,
@ -130,6 +129,7 @@ class CompatTest(bctu.CompatTestBase):
covered_targets = covered_targets.union(data.custom_call_targets)
covered_targets = covered_targets.union({
"tf.call_tf_function", # tested in jax2tf/tests/back_compat_tf_test.py
"tpu_custom_call", # tested separately
})
not_covered = targets_to_cover.difference(covered_targets)
@ -144,14 +144,21 @@ class CompatTest(bctu.CompatTestBase):
# An old lowering, with ducc_fft. We keep it for 6 months.
data = self.load_testdata(cpu_ducc_fft.data_2023_03_17)
# We have changed the lowering for fft since we saved this data.
# FFT no longer lowers to a custom call.
self.run_one_test(func, data, expect_current_custom_calls=[])
if jaxlib_version <= (0, 4, 20):
expect_current_custom_calls = ["dynamic_ducc_fft"]
else:
# We have changed the lowering for fft since we saved this data.
# FFT no longer lowers to a custom call.
expect_current_custom_calls = []
self.run_one_test(func, data,
expect_current_custom_calls=expect_current_custom_calls)
# A newer lowering, with dynamic_ducc_fft.
data = self.load_testdata(cpu_ducc_fft.data_2023_06_14)
# FFT no longer lowers to a custom call.
self.run_one_test(func, data, expect_current_custom_calls=[])
self.run_one_test(func, data,
expect_current_custom_calls=expect_current_custom_calls)
def cholesky_input(self, shape, dtype):
a = jtu.rand_default(self.rng())(shape, dtype)
@ -294,6 +301,8 @@ class CompatTest(bctu.CompatTestBase):
# We use different custom calls for sizes <= 32
for variant in ["syevj", "syevd"])
def test_cuda_eigh_cusolver_syev(self, dtype_name="f32", variant="syevj"):
if not config.enable_x64.value and dtype_name == "f64":
self.skipTest("Test disabled for x32 mode")
# For lax.linalg.eigh
dtype = dict(f32=np.float32, f64=np.float64)[dtype_name]
size = dict(syevj=8, syevd=36)[variant]
@ -646,6 +655,9 @@ class CompatTest(bctu.CompatTestBase):
# Recent serializations also include shape_assertion, tested with dynamic_top_k
expect_current_custom_calls=["stablehlo.dynamic_reduce_window", "shape_assertion"])
@jtu.ignore_warning(
category=FutureWarning,
message="Raw arrays as random keys to jax.random functions are deprecated")
def test_stablehlo_dynamic_rbg_bit_generator(self):
# stablehlo.dynamic_rbg_bit_generator is used temporarily for a
# rbg_bit_generator with dynamic shapes.