Add support for running JAX tests under Bazel.

This is an alternative method for running the tests that some users may prefer: pytest is and will remain fully supported.

To use this, one creates a .bazelrc by running the existing `build.py` script, and then one can run the tests by running:
```
bazel test -c opt //tests/...
```

Issue #7323

PiperOrigin-RevId: 458551208
This commit is contained in:
Peter Hawkins 2022-07-01 15:06:54 -07:00 committed by jax authors
parent 270f73e346
commit 1fc9afd03a
6 changed files with 1227 additions and 6 deletions

239
jax/BUILD Normal file
View File

@ -0,0 +1,239 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# JAX is Autograd and XLA
load(
"//jaxlib:jax.bzl",
"absl_logging_py_deps",
"absl_testing_py_deps",
"jax_extra_deps",
"jax_internal_packages",
"jax_test_util_visibility",
"loops_visibility",
"numpy_py_deps",
"py_library_providing_imports_info",
"pytype_library",
"scipy_py_deps",
"sharded_jit_visibility",
)
licenses(["notice"])
package(default_visibility = [":internal"])
exports_files([
"LICENSE",
"version.py",
])
# Packages that have access to JAX-internal implementation details.
package_group(
name = "internal",
packages = [
"//...",
] + jax_internal_packages,
)
# JAX-private test utilities.
py_library(
# This build target is required in order to use private test utilities in jax._src.test_util,
# and its visibility is intentionally restricted to discourage its use outside JAX itself.
# JAX does provide some public test utilities (see jax/test_util.py);
# these are available in jax.test_util via the standard :jax target.
name = "test_util",
testonly = 1,
srcs = [
"_src/test_util.py",
],
visibility = [
":internal",
] + jax_test_util_visibility,
deps = [
":jax",
] + absl_testing_py_deps + numpy_py_deps,
)
py_library_providing_imports_info(
name = "jax",
srcs = glob(
[
"*.py",
"_src/**/*.py",
"image/**/*.py",
"interpreters/**/*.py",
"lax/**/*.py",
"lib/**/*.py",
"nn/**/*.py",
"numpy/**/*.py",
"ops/**/*.py",
"scipy/**/*.py",
"third_party/**/*.py",
],
exclude = [
"_src/test_util.py",
"*_test.py",
"**/*_test.py",
"interpreters/sharded_jit.py",
],
) + [
# until new parallelism APIs are moved out of experimental
"experimental/maps.py",
"experimental/pjit.py",
"experimental/global_device_array.py",
"experimental/array.py",
"experimental/sharding.py",
"experimental/multihost_utils.py",
# until checkify is moved out of experimental
"experimental/checkify.py",
# to avoid circular dependencies
"experimental/compilation_cache/compilation_cache.py",
"experimental/compilation_cache/gfile_cache.py",
"experimental/compilation_cache/cache_interface.py",
],
lib_rule = pytype_library,
visibility = ["//visibility:public"],
deps = [
"//jaxlib",
] + numpy_py_deps + scipy_py_deps + jax_extra_deps,
)
py_library_providing_imports_info(
name = "experimental",
srcs = glob([
"experimental/*.py",
"example_libraries/*.py",
]),
visibility = ["//visibility:public"],
deps = [
":jax",
":sharded_jit",
] + absl_logging_py_deps + numpy_py_deps,
)
pytype_library(
name = "stax",
srcs = [
"example_libraries/stax.py",
"experimental/stax.py",
],
visibility = ["//visibility:public"],
deps = [":jax"],
)
pytype_library(
name = "experimental_sparse",
srcs = glob([
"experimental/sparse/*.py",
]),
visibility = ["//visibility:public"],
deps = [":jax"],
)
# sharded_jit is deprecated. Please do not add any more projects to the visibility.
pytype_library(
name = "sharded_jit",
srcs = ["interpreters/sharded_jit.py"],
visibility = [
":internal",
] + sharded_jit_visibility,
deps = [":jax"],
)
pytype_library(
name = "optimizers",
srcs = [
"example_libraries/optimizers.py",
"experimental/optimizers.py",
],
visibility = ["//visibility:public"],
deps = [":jax"],
)
pytype_library(
name = "ode",
srcs = ["experimental/ode.py"],
visibility = ["//visibility:public"],
deps = [":jax"],
)
# loops is deprecated. Please do not add any more projects to the visibility.
pytype_library(
name = "loops",
srcs = ["experimental/loops.py"],
visibility = [":internal"] + loops_visibility,
deps = [":jax"],
)
pytype_library(
name = "callback",
srcs = ["experimental/callback.py"],
visibility = ["//visibility:public"],
deps = [":jax"],
)
# TODO(apaszke): Remove this target
pytype_library(
name = "maps",
srcs = ["experimental/maps.py"],
visibility = ["//visibility:public"],
deps = [":jax"],
)
# TODO(apaszke): Remove this target
pytype_library(
name = "pjit",
srcs = ["experimental/pjit.py"],
visibility = ["//visibility:public"],
deps = [
":experimental",
":jax",
],
)
pytype_library(
name = "jet",
srcs = ["experimental/jet.py"],
visibility = ["//visibility:public"],
deps = [":jax"],
)
pytype_library(
name = "experimental_host_callback",
srcs = ["experimental/host_callback.py"],
visibility = ["//visibility:public"],
deps = [
":jax",
],
)
pytype_library(
name = "compilation_cache",
srcs = [
"experimental/compilation_cache/compilation_cache.py",
"experimental/compilation_cache/gfile_cache.py",
],
visibility = ["//visibility:public"],
deps = [":jax"],
)
pytype_library(
name = "mesh_utils",
srcs = ["experimental/mesh_utils.py"],
visibility = ["//visibility:public"],
deps = [
":experimental",
":jax",
],
)

View File

@ -1 +0,0 @@
exports_files(["version.py"])

View File

@ -0,0 +1,48 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//jaxlib:jax.bzl",
"jax2tf_deps",
"numpy_py_deps",
"tensorflow_py_deps",
)
licenses(["notice"]) # Apache 2
package(
default_visibility = ["//visibility:private"],
)
py_library(
name = "jax2tf",
srcs = ["__init__.py"],
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = [":jax2tf_internal"],
)
py_library(
name = "jax2tf_internal",
srcs = [
"call_tf.py",
"impl_no_xla.py",
"jax2tf.py",
"shape_poly.py",
],
srcs_version = "PY3",
deps = [
"//jax",
] + numpy_py_deps + tensorflow_py_deps + jax2tf_deps,
)

View File

@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//jaxlib:jax.bzl",
"tensorflow_py_deps",
)
licenses(["notice"])
package(default_visibility = ["//visibility:public"])
@ -24,7 +29,7 @@ py_library(
"ignore_for_dep=third_party.py.tensorflow",
],
deps = [
"//third_party/py/jax",
"//jax",
],
)
@ -32,8 +37,7 @@ py_library(
name = "jax_to_ir_with_tensorflow",
srcs = ["jax_to_ir.py"],
deps = [
"//third_party/py/jax",
"//third_party/py/jax/experimental/jax2tf",
"//third_party/py/tensorflow",
],
"//jax",
"//jax/experimental/jax2tf",
] + tensorflow_py_deps,
)

View File

@ -32,6 +32,26 @@ if_rocm_is_configured = _if_rocm_is_configured
if_windows = _if_windows
flatbuffer_cc_library = _flatbuffer_cc_library
jax_internal_packages = []
jax_test_util_visibility = []
loops_visibility = []
sharded_jit_visibility = []
absl_logging_py_deps = []
absl_testing_py_deps = []
cloudpickle_py_deps = []
numpy_py_deps = []
pil_py_deps = []
portpicker_py_deps = []
scipy_py_deps = []
tensorflow_py_deps = []
jax_extra_deps = []
jax2tf_deps = []
def py_library_providing_imports_info(*, name, lib_rule = native.py_library, **kwargs):
lib_rule(name = name, **kwargs)
def py_extension(name, srcs, copts, deps):
pybind_extension(name, srcs = srcs, copts = copts, deps = deps, module_name = name)
@ -100,3 +120,37 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []):
interface_library = interface_library_file,
shared_library = out,
)
def jax_test(
name,
srcs,
args = [],
shard_count = None,
deps = [],
disable_backends = None, # buildifier: disable=unused-variable
backend_tags = {}, # buildifier: disable=unused-variable
disable_configs = None, # buildifier: disable=unused-variable
enable_configs = None, # buildifier: disable=unused-variable
tags = [],
main = None):
if shard_count == None or type(shard_count) == type(0):
shards = shard_count
else:
shards = shard_count.get("cpu", None)
native.py_test(
name = name,
srcs = srcs,
args = args,
deps = [
"//jax",
"//jax:test_util",
] + deps,
shard_count = shards,
tags = tags,
main = main,
)
def jax_generate_backend_suites():
pass
jax_test_file_visibility = []

877
tests/BUILD Normal file
View File

@ -0,0 +1,877 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//jaxlib:jax.bzl",
"absl_logging_py_deps",
"cloudpickle_py_deps",
"jax_generate_backend_suites",
"jax_test",
"jax_test_file_visibility",
"pil_py_deps",
"portpicker_py_deps",
"pytype_library",
"scipy_py_deps",
"tensorflow_py_deps",
)
licenses(["notice"]) # Apache 2
package(default_visibility = ["//visibility:private"])
jax_generate_backend_suites()
jax_test(
name = "api_test",
srcs = ["api_test.py"],
shard_count = 5,
)
jax_test(
name = "api_util_test",
srcs = ["api_util_test.py"],
)
jax_test(
name = "array_interoperability_test",
srcs = ["array_interoperability_test.py"],
disable_backends = ["tpu"],
deps = tensorflow_py_deps,
)
jax_test(
name = "batching_test",
srcs = ["batching_test.py"],
)
jax_test(
name = "callback_test",
srcs = ["callback_test.py"],
deps = ["//jax:callback"],
)
jax_test(
name = "core_test",
srcs = ["core_test.py"],
shard_count = {
"cpu": 5,
},
)
jax_test(
name = "custom_object_test",
srcs = ["custom_object_test.py"],
)
py_test(
name = "debug_nans_test",
srcs = ["debug_nans_test.py"],
deps = [
"//jax",
"//jax:test_util",
],
)
py_test(
name = "distributed_test",
srcs = ["distributed_test.py"],
args = [
"--exclude_test_targets=MultiProcessGpuTest",
],
deps = [
"//jax",
"//jax:test_util",
] + portpicker_py_deps,
)
jax_test(
name = "dtypes_test",
srcs = ["dtypes_test.py"],
)
py_test(
name = "errors_test",
srcs = ["errors_test.py"],
deps = [
"//jax",
"//jax:test_util",
],
)
jax_test(
name = "fft_test",
srcs = ["fft_test.py"],
backend_tags = {
"tpu": ["notsan"], # Times out on TPU with tsan.
},
shard_count = {
"tpu": 20,
},
)
jax_test(
name = "generated_fun_test",
srcs = ["generated_fun_test.py"],
)
jax_test(
name = "lobpcg_test",
srcs = ["lobpcg_test.py"],
deps = [
"//jax:experimental_sparse",
],
)
jax_test(
name = "svd_test",
srcs = ["svd_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
"tpu": 10,
},
)
py_test(
name = "xla_interpreter_test",
srcs = ["xla_interpreter_test.py"],
deps = [
"//jax",
"//jax:test_util",
],
)
jax_test(
name = "xmap_test",
srcs = ["xmap_test.py"],
shard_count = {
"cpu": 10,
"tpu": 4,
},
deps = [
"//jax:maps",
],
)
jax_test(
name = "pjit_test",
srcs = ["pjit_test.py"],
deps = [
"//jax:experimental",
],
)
jax_test(
name = "global_device_array_test",
srcs = ["global_device_array_test.py"],
deps = [
"//jax:experimental",
],
)
jax_test(
name = "array_test",
srcs = ["array_test.py"],
deps = [
"//jax:experimental",
],
)
jax_test(
name = "remote_transfer_test",
srcs = ["remote_transfer_test.py"],
disable_backends = [
"gpu",
"cpu",
],
deps = [
"//jax:experimental",
],
)
jax_test(
name = "image_test",
srcs = ["image_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
"tpu": 10,
"iree": 10,
},
deps = pil_py_deps + tensorflow_py_deps,
)
jax_test(
name = "infeed_test",
srcs = ["infeed_test.py"],
deps = [
"//jax:experimental_host_callback",
],
)
py_test(
name = "jax_jit_test_x32",
srcs = ["jax_jit_test.py"],
main = "jax_jit_test.py",
visibility = ["//visibility:private"],
deps = [
"//jax",
"//jax:test_util",
],
)
py_test(
name = "jax_jit_test_x64",
srcs = ["jax_jit_test.py"],
args = ["--jax_enable_x64=true"],
main = "jax_jit_test.py",
visibility = ["//visibility:private"],
deps = [
"//jax",
"//jax:test_util",
],
)
test_suite(
name = "jax_jit_test",
tests = [
"jax_jit_test_x32",
"jax_jit_test_x64",
],
)
py_test(
name = "jax_to_ir_test",
srcs = ["jax_to_ir_test.py"],
deps = [
"//jax:test_util",
"//jax/experimental/jax2tf",
"//jax/tools:jax_to_ir",
] + tensorflow_py_deps,
)
jax_test(
name = "jaxpr_util_test",
srcs = ["jaxpr_util_test.py"],
)
jax_test(
name = "jet_test",
srcs = ["jet_test.py"],
shard_count = {
"cpu": 10,
},
deps = [
"//jax:jet",
"//jax:stax",
],
)
jax_test(
name = "lax_control_flow_test",
srcs = ["lax_control_flow_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
"tpu": 10,
"iree": 10,
},
)
jax_test(
name = "custom_root_test",
srcs = ["custom_root_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
"tpu": 10,
"iree": 10,
},
)
jax_test(
name = "custom_linear_solve_test",
srcs = ["custom_linear_solve_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
"tpu": 10,
"iree": 10,
},
)
jax_test(
name = "lax_numpy_test",
srcs = ["lax_numpy_test.py"],
shard_count = {
"cpu": 40,
"gpu": 40,
"tpu": 20,
},
)
jax_test(
name = "lax_numpy_indexing_test",
srcs = ["lax_numpy_indexing_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
"tpu": 10,
"iree": 10,
},
)
jax_test(
name = "lax_numpy_einsum_test",
srcs = ["lax_numpy_einsum_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
"tpu": 10,
"iree": 10,
},
)
jax_test(
name = "lax_numpy_vectorize_test",
srcs = ["lax_numpy_vectorize_test.py"],
)
jax_test(
name = "lax_scipy_test",
srcs = ["lax_scipy_test.py"],
backend_tags = {
"tpu": ["noasan"], # Test times out.
},
shard_count = {
"cpu": 10,
"gpu": 10,
"tpu": 10,
"iree": 10,
},
)
jax_test(
name = "lax_scipy_sparse_test",
srcs = ["lax_scipy_sparse_test.py"],
backend_tags = {
"cpu": ["nomsan"], # Test fails under msan because of fortran code inside scipy.
},
shard_count = {
"cpu": 10,
"gpu": 10,
"tpu": 10,
"iree": 10,
},
)
jax_test(
name = "lax_test",
srcs = ["lax_test.py"],
shard_count = {
"cpu": 40,
"gpu": 40,
"tpu": 20,
"iree": 40,
},
)
pytype_library(
name = "lax_test_lib",
srcs = ["lax_test.py"],
srcs_version = "PY3",
deps = ["//jax"],
)
pytype_library(
name = "lax_vmap_test_lib",
testonly = 1,
srcs = ["lax_vmap_test.py"],
srcs_version = "PY3",
deps = [
":lax_test_lib",
"//jax",
"//jax:test_util",
],
)
jax_test(
name = "lax_autodiff_test",
srcs = ["lax_autodiff_test.py"],
shard_count = {
"cpu": 40,
"gpu": 40,
"tpu": 20,
"iree": 40,
},
deps = [":lax_test_lib"],
)
jax_test(
name = "lax_vmap_test",
srcs = ["lax_vmap_test.py"],
shard_count = {
"cpu": 40,
"gpu": 40,
"tpu": 20,
"iree": 40,
},
deps = [":lax_test_lib"],
)
jax_test(
name = "linalg_test",
srcs = ["linalg_test.py"],
backend_tags = {
"tpu": [
"cpu:8",
"noasan", # Times out.
],
},
shard_count = {
"cpu": 20,
"gpu": 20,
"tpu": 10,
"iree": 20,
},
)
jax_test(
name = "masking_test",
srcs = ["masking_test.py"],
)
jax_test(
name = "metadata_test",
srcs = ["metadata_test.py"],
disable_backends = [
"gpu",
"tpu",
],
)
jax_test(
name = "multibackend_test",
srcs = ["multibackend_test.py"],
)
jax_test(
name = "multi_device_test",
srcs = ["multi_device_test.py"],
disable_backends = [
"gpu",
"tpu",
],
)
jax_test(
name = "nn_test",
srcs = ["nn_test.py"],
)
jax_test(
name = "optimizers_test",
srcs = ["optimizers_test.py"],
deps = ["//jax:optimizers"],
)
jax_test(
name = "pickle_test",
srcs = ["pickle_test.py"],
deps = [
"//jax:experimental",
] + cloudpickle_py_deps,
)
jax_test(
name = "pmap_test",
srcs = ["pmap_test.py"],
shard_count = {
"cpu": 5,
"gpu": 5,
"tpu": 5,
},
deps = [
":lax_test_lib",
":lax_vmap_test_lib",
],
)
jax_test(
name = "polynomial_test",
srcs = ["polynomial_test.py"],
# No implementation of nonsymmetric Eigendecomposition.
disable_backends = [
"gpu",
"tpu",
],
shard_count = {
"cpu": 10,
},
# This test ends up calling Fortran code that initializes some memory and
# passes it to C code. MSan is not able to detect that the memory was
# initialized by Fortran, and it makes the test fail. This can usually be
# fixed by annotating the memory with `ANNOTATE_MEMORY_IS_INITIALIZED`, but
# in this case there's not a good place to do it, see b/197635968#comment19
# for details.
tags = ["nomsan"],
)
jax_test(
name = "heap_profiler_test",
srcs = ["heap_profiler_test.py"],
disable_backends = [
"gpu",
"tpu",
],
)
jax_test(
name = "profiler_test",
srcs = ["profiler_test.py"],
disable_backends = [
"gpu",
"tpu",
],
)
jax_test(
name = "qdwh_test",
srcs = ["qdwh_test.py"],
)
jax_test(
name = "random_test",
srcs = ["random_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
"tpu": 10,
"iree": 10,
},
)
# TODO(b/199564969): remove once we always enable_custom_prng
jax_test(
name = "random_test_with_custom_prng",
srcs = ["random_test.py"],
args = ["--jax_enable_custom_prng=true"],
backend_tags = {
"cpu": ["noasan"], # Times out under asan.
},
main = "random_test.py",
shard_count = {
"cpu": 30,
"gpu": 20,
"tpu": 20,
"iree": 20,
},
)
jax_test(
name = "scipy_fft_test",
srcs = ["scipy_fft_test.py"],
)
jax_test(
name = "scipy_interpolate_test",
srcs = ["scipy_interpolate_test.py"],
)
jax_test(
name = "scipy_ndimage_test",
srcs = ["scipy_ndimage_test.py"],
)
jax_test(
name = "scipy_optimize_test",
srcs = ["scipy_optimize_test.py"],
)
jax_test(
name = "scipy_signal_test",
srcs = ["scipy_signal_test.py"],
backend_tags = {
"cpu": [
"noasan", # Test times out under asan.
],
"tpu": [
"noasan",
"notsan",
], # Test times out under asan/tsan.
},
shard_count = {
"tpu": 5,
},
)
jax_test(
name = "scipy_stats_test",
srcs = ["scipy_stats_test.py"],
shard_count = {
"cpu": 10,
"gpu": 10,
"tpu": 10,
"iree": 10,
},
)
jax_test(
name = "sharded_jit_test",
srcs = ["sharded_jit_test.py"],
disable_backends = [
"tpu",
"iree",
],
deps = ["//jax:experimental"],
)
jax_test(
name = "sparse_test",
srcs = ["sparse_test.py"],
args = ["--jax_bcoo_cusparse_lowering=true"],
backend_tags = {
"cpu": [
"noasan", # Test times out under asan.
],
},
shard_count = {
"cpu": 10,
"gpu": 10,
"tpu": 10,
"iree": 10,
},
deps = [
"//jax:experimental_sparse",
] + scipy_py_deps,
)
jax_test(
name = "sparsify_test",
srcs = ["sparsify_test.py"],
args = ["--jax_bcoo_cusparse_lowering=true"],
deps = [
"//jax:experimental_sparse",
],
)
jax_test(
name = "stack_test",
srcs = ["stack_test.py"],
)
jax_test(
name = "checkify_test",
srcs = ["checkify_test.py"],
)
jax_test(
name = "stax_test",
srcs = ["stax_test.py"],
shard_count = {
"cpu": 5,
"gpu": 5,
"iree": 5,
},
deps = ["//jax:stax"],
)
jax_test(
name = "linear_search_test",
srcs = ["third_party/scipy/line_search_test.py"],
main = "third_party/scipy/line_search_test.py",
)
py_test(
name = "tree_util_test",
srcs = ["tree_util_test.py"],
deps = [
"//jax",
"//jax:test_util",
],
)
py_test(
name = "util_test",
srcs = ["util_test.py"],
deps = [
"//jax",
"//jax:test_util",
],
)
py_test(
name = "version_test",
srcs = ["version_test.py"],
deps = [
"//jax",
"//jax:test_util",
],
)
py_test(
name = "xla_bridge_test",
srcs = ["xla_bridge_test.py"],
deps = [
"//jax",
"//jax:test_util",
] + absl_logging_py_deps,
)
py_test(
name = "gfile_cache_test",
srcs = ["gfile_cache_test.py"],
deps = [
"//jax",
"//jax:compilation_cache",
"//jax:test_util",
],
)
jax_test(
name = "compilation_cache_test",
srcs = ["compilation_cache_test.py"],
backend_tags = {
"tpu": ["nomsan"], # TODO(b/213388298): this test fails msan.
},
deps = [
"//jax:compilation_cache",
"//jax:experimental",
],
)
jax_test(
name = "ode_test",
srcs = ["ode_test.py"],
shard_count = {
"cpu": 10,
},
deps = ["//jax:ode"],
)
jax_test(
name = "host_callback_test",
srcs = ["host_callback_test.py"],
args = ["--jax_host_callback_outfeed=true"],
deps = [
"//jax:experimental",
"//jax:experimental_host_callback",
"//jax:ode",
],
)
jax_test(
name = "host_callback_custom_call_test",
srcs = ["host_callback_test.py"],
args = ["--jax_host_callback_outfeed=false"],
disable_backends = [
"gpu",
"tpu", # On TPU we always use outfeed
],
main = "host_callback_test.py",
deps = [
"//jax:experimental",
"//jax:experimental_host_callback",
"//jax:ode",
],
)
jax_test(
name = "host_callback_to_tf_test",
srcs = ["host_callback_to_tf_test.py"],
deps = [
"//jax:experimental_host_callback",
"//jax:ode",
] + tensorflow_py_deps,
)
jax_test(
name = "x64_context_test",
srcs = ["x64_context_test.py"],
deps = [
"//jax:experimental",
],
)
jax_test(
name = "ann_test",
srcs = ["ann_test.py"],
shard_count = 2,
deps = [
":lax_test_lib",
],
)
py_test(
name = "mesh_utils_test",
srcs = ["mesh_utils_test.py"],
deps = [
"//jax",
"//jax:mesh_utils",
"//jax:test_util",
],
)
jax_test(
name = "transfer_guard_test",
srcs = ["transfer_guard_test.py"],
)
jax_test(
name = "name_stack_test",
srcs = ["name_stack_test.py"],
)
jax_test(
name = "jaxpr_effects_test",
srcs = ["jaxpr_effects_test.py"],
)
jax_test(
name = "debugging_primitives_test",
srcs = ["debugging_primitives_test.py"],
)
jax_test(
name = "debugger_test",
srcs = ["debugger_test.py"],
)
exports_files(
[
"api_test.py",
"pmap_test.py",
"sharded_jit_test.py",
"pjit_test.py",
"transfer_guard_test.py",
],
visibility = jax_test_file_visibility,
)
# This filegroup specifies the set of tests known to Bazel, used for a test that
# verifies every test has a Bazel test rule.
# If a test isn't meant to be tested with Bazel, add it to the exclude list.
filegroup(
name = "all_tests",
srcs = glob(
include = [
"*_test.py",
"third_party/*/*_test.py",
],
exclude = [],
) + ["BUILD"],
visibility = [
"//third_party/py/jax:__subpackages__",
],
)