diff --git a/tests/BUILD b/tests/BUILD index 8038bc57c..de924bdca 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1392,6 +1392,16 @@ jax_test( ], ) +jax_test( + name = "export_back_compat_test", + srcs = ["export_back_compat_test.py"], + tags = [], + deps = [ + "//jax:internal_export_back_compat_test_data", + "//jax:internal_export_back_compat_test_util", + ], +) + exports_files( [ "api_test.py", diff --git a/jax/experimental/jax2tf/tests/back_compat_test.py b/tests/export_back_compat_test.py similarity index 97% rename from jax/experimental/jax2tf/tests/back_compat_test.py rename to tests/export_back_compat_test.py index abf11fddd..3be63dbd6 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -42,7 +42,6 @@ from jax._src.internal_test_util.export_back_compat_test_data import cpu_schur_l from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd from jax._src.internal_test_util.export_back_compat_test_data import cpu_triangular_solve_blas_trsm from jax._src.internal_test_util.export_back_compat_test_data import cuda_threefry2x32 -from jax.experimental.jax2tf.tests.back_compat_testdata import tf_call_tf_function from jax._src.internal_test_util.export_back_compat_test_data import tpu_Eigh from jax._src.internal_test_util.export_back_compat_test_data import tpu_Lu from jax._src.internal_test_util.export_back_compat_test_data import tpu_ApproxTopK @@ -61,6 +60,7 @@ from jax.sharding import PartitionSpec as P from jax._src import config from jax._src import test_util as jtu +from jax._src.lib import version as jaxlib_version config.parse_flags_with_absl() @@ -111,7 +111,6 @@ class CompatTest(bctu.CompatTestBase): cpu_schur_lapack_gees.data_2023_07_16, cpu_svd_lapack_gesdd.data_2023_06_19, cpu_triangular_solve_blas_trsm.data_2023_07_16, - tf_call_tf_function.data_2023_07_29, # This is tested in back_compat_tf_test.py tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17, tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17, tpu_ApproxTopK.data_2023_05_16, @@ -130,6 +129,7 @@ class CompatTest(bctu.CompatTestBase): covered_targets = covered_targets.union(data.custom_call_targets) covered_targets = covered_targets.union({ + "tf.call_tf_function", # tested in jax2tf/tests/back_compat_tf_test.py "tpu_custom_call", # tested separately }) not_covered = targets_to_cover.difference(covered_targets) @@ -144,14 +144,21 @@ class CompatTest(bctu.CompatTestBase): # An old lowering, with ducc_fft. We keep it for 6 months. data = self.load_testdata(cpu_ducc_fft.data_2023_03_17) - # We have changed the lowering for fft since we saved this data. - # FFT no longer lowers to a custom call. - self.run_one_test(func, data, expect_current_custom_calls=[]) + if jaxlib_version <= (0, 4, 20): + expect_current_custom_calls = ["dynamic_ducc_fft"] + else: + # We have changed the lowering for fft since we saved this data. + # FFT no longer lowers to a custom call. + expect_current_custom_calls = [] + + self.run_one_test(func, data, + expect_current_custom_calls=expect_current_custom_calls) # A newer lowering, with dynamic_ducc_fft. data = self.load_testdata(cpu_ducc_fft.data_2023_06_14) # FFT no longer lowers to a custom call. - self.run_one_test(func, data, expect_current_custom_calls=[]) + self.run_one_test(func, data, + expect_current_custom_calls=expect_current_custom_calls) def cholesky_input(self, shape, dtype): a = jtu.rand_default(self.rng())(shape, dtype) @@ -294,6 +301,8 @@ class CompatTest(bctu.CompatTestBase): # We use different custom calls for sizes <= 32 for variant in ["syevj", "syevd"]) def test_cuda_eigh_cusolver_syev(self, dtype_name="f32", variant="syevj"): + if not config.enable_x64.value and dtype_name == "f64": + self.skipTest("Test disabled for x32 mode") # For lax.linalg.eigh dtype = dict(f32=np.float32, f64=np.float64)[dtype_name] size = dict(syevj=8, syevd=36)[variant] @@ -646,6 +655,9 @@ class CompatTest(bctu.CompatTestBase): # Recent serializations also include shape_assertion, tested with dynamic_top_k expect_current_custom_calls=["stablehlo.dynamic_reduce_window", "shape_assertion"]) + @jtu.ignore_warning( + category=FutureWarning, + message="Raw arrays as random keys to jax.random functions are deprecated") def test_stablehlo_dynamic_rbg_bit_generator(self): # stablehlo.dynamic_rbg_bit_generator is used temporarily for a # rbg_bit_generator with dynamic shapes.