Move export backwards compatibility tests out of jax2tf. Step 1.

These tests are independent of TensorFlow, yet by being in the jax2tf package they end up pulling in TensorFlow as a dependency.

This is part of a larger cl/562671314 that ran into OSS build problems.
I am attempting this smaller change first, and afterwards I will move more of the test data files, and then the actual test.

PiperOrigin-RevId: 591927484
This commit is contained in:
George Necula 2023-12-18 09:49:11 -08:00 committed by jax authors
parent 36cf5afa67
commit eed61f68aa
3 changed files with 17 additions and 6 deletions

View File

@ -137,12 +137,13 @@ py_library(
py_library(
name = "internal_test_util",
testonly = 1,
srcs = glob(
["_src/internal_test_util/**/*.py"], # include
srcs = [
"_src/internal_test_util/deprecation_module.py",
"_src/internal_test_util/lax_test_util.py",
] + glob(
[
"_src/internal_test_util/test_harnesses.py",
"_src/internal_test_util/export_back_compat_test_util.py",
], # exclude
"_src/internal_test_util/lazy_loader_module/*.py",
],
),
visibility = [":internal"],
deps = [
@ -173,6 +174,16 @@ py_library(
] + py_deps("numpy"),
)
py_library(
name = "internal_export_back_compat_test_data",
testonly = 1,
srcs = glob(["_src/internal_test_util/export_back_compat_test_data/*.py"]),
visibility = [
":internal",
],
deps = py_deps("numpy"),
)
py_library_providing_imports_info(
name = "jax",
srcs = [

View File

@ -30,7 +30,7 @@ from jax import lax
from jax.experimental.export import export
from jax._src.internal_test_util import export_back_compat_test_util as bctu
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_ducc_fft
from jax._src.internal_test_util.export_back_compat_test_data import cpu_ducc_fft
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_cholesky_lapack_potrf
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_eig_lapack_geev
from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_eigh_cusolver_syev