mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Move export backwards compatibility tests out of jax2tf. Step 3.
The last part of moving the tests: move jax2tf/tests/back_compat_test.py to tests/export_back_compat_test.py. PiperOrigin-RevId: 596555577
This commit is contained in:
parent
3d608b1850
commit
ed2a839884
10
tests/BUILD
10
tests/BUILD
@ -1392,6 +1392,16 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "export_back_compat_test",
|
||||
srcs = ["export_back_compat_test.py"],
|
||||
tags = [],
|
||||
deps = [
|
||||
"//jax:internal_export_back_compat_test_data",
|
||||
"//jax:internal_export_back_compat_test_util",
|
||||
],
|
||||
)
|
||||
|
||||
exports_files(
|
||||
[
|
||||
"api_test.py",
|
||||
|
@ -42,7 +42,6 @@ from jax._src.internal_test_util.export_back_compat_test_data import cpu_schur_l
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import cpu_triangular_solve_blas_trsm
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import cuda_threefry2x32
|
||||
from jax.experimental.jax2tf.tests.back_compat_testdata import tf_call_tf_function
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Eigh
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import tpu_Lu
|
||||
from jax._src.internal_test_util.export_back_compat_test_data import tpu_ApproxTopK
|
||||
@ -61,6 +60,7 @@ from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -111,7 +111,6 @@ class CompatTest(bctu.CompatTestBase):
|
||||
cpu_schur_lapack_gees.data_2023_07_16,
|
||||
cpu_svd_lapack_gesdd.data_2023_06_19,
|
||||
cpu_triangular_solve_blas_trsm.data_2023_07_16,
|
||||
tf_call_tf_function.data_2023_07_29, # This is tested in back_compat_tf_test.py
|
||||
tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17,
|
||||
tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17,
|
||||
tpu_ApproxTopK.data_2023_05_16,
|
||||
@ -130,6 +129,7 @@ class CompatTest(bctu.CompatTestBase):
|
||||
covered_targets = covered_targets.union(data.custom_call_targets)
|
||||
|
||||
covered_targets = covered_targets.union({
|
||||
"tf.call_tf_function", # tested in jax2tf/tests/back_compat_tf_test.py
|
||||
"tpu_custom_call", # tested separately
|
||||
})
|
||||
not_covered = targets_to_cover.difference(covered_targets)
|
||||
@ -144,14 +144,21 @@ class CompatTest(bctu.CompatTestBase):
|
||||
|
||||
# An old lowering, with ducc_fft. We keep it for 6 months.
|
||||
data = self.load_testdata(cpu_ducc_fft.data_2023_03_17)
|
||||
# We have changed the lowering for fft since we saved this data.
|
||||
# FFT no longer lowers to a custom call.
|
||||
self.run_one_test(func, data, expect_current_custom_calls=[])
|
||||
if jaxlib_version <= (0, 4, 20):
|
||||
expect_current_custom_calls = ["dynamic_ducc_fft"]
|
||||
else:
|
||||
# We have changed the lowering for fft since we saved this data.
|
||||
# FFT no longer lowers to a custom call.
|
||||
expect_current_custom_calls = []
|
||||
|
||||
self.run_one_test(func, data,
|
||||
expect_current_custom_calls=expect_current_custom_calls)
|
||||
|
||||
# A newer lowering, with dynamic_ducc_fft.
|
||||
data = self.load_testdata(cpu_ducc_fft.data_2023_06_14)
|
||||
# FFT no longer lowers to a custom call.
|
||||
self.run_one_test(func, data, expect_current_custom_calls=[])
|
||||
self.run_one_test(func, data,
|
||||
expect_current_custom_calls=expect_current_custom_calls)
|
||||
|
||||
def cholesky_input(self, shape, dtype):
|
||||
a = jtu.rand_default(self.rng())(shape, dtype)
|
||||
@ -294,6 +301,8 @@ class CompatTest(bctu.CompatTestBase):
|
||||
# We use different custom calls for sizes <= 32
|
||||
for variant in ["syevj", "syevd"])
|
||||
def test_cuda_eigh_cusolver_syev(self, dtype_name="f32", variant="syevj"):
|
||||
if not config.enable_x64.value and dtype_name == "f64":
|
||||
self.skipTest("Test disabled for x32 mode")
|
||||
# For lax.linalg.eigh
|
||||
dtype = dict(f32=np.float32, f64=np.float64)[dtype_name]
|
||||
size = dict(syevj=8, syevd=36)[variant]
|
||||
@ -646,6 +655,9 @@ class CompatTest(bctu.CompatTestBase):
|
||||
# Recent serializations also include shape_assertion, tested with dynamic_top_k
|
||||
expect_current_custom_calls=["stablehlo.dynamic_reduce_window", "shape_assertion"])
|
||||
|
||||
@jtu.ignore_warning(
|
||||
category=FutureWarning,
|
||||
message="Raw arrays as random keys to jax.random functions are deprecated")
|
||||
def test_stablehlo_dynamic_rbg_bit_generator(self):
|
||||
# stablehlo.dynamic_rbg_bit_generator is used temporarily for a
|
||||
# rbg_bit_generator with dynamic shapes.
|
Loading…
x
Reference in New Issue
Block a user