From 1fc9afd03a9d7f7c1e84d25929659174ce161c90 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 1 Jul 2022 15:06:54 -0700 Subject: [PATCH] Add support for running JAX tests under Bazel. This is an alternative method for running the tests that some users may prefer: pytest is and will remain fully supported. To use this, one creates a .bazelrc by running the existing `build.py` script, and then one can run the tests by running: ``` bazel test -c opt //tests/... ``` Issue #7323 PiperOrigin-RevId: 458551208 --- jax/BUILD | 239 +++++++++ jax/BUILD.bazel | 1 - jax/experimental/jax2tf/BUILD | 48 ++ jax/tools/BUILD | 14 +- jaxlib/jax.bzl | 54 +++ tests/BUILD | 877 ++++++++++++++++++++++++++++++++++ 6 files changed, 1227 insertions(+), 6 deletions(-) create mode 100644 jax/BUILD delete mode 100644 jax/BUILD.bazel create mode 100644 jax/experimental/jax2tf/BUILD create mode 100644 tests/BUILD diff --git a/jax/BUILD b/jax/BUILD new file mode 100644 index 000000000..6cc20d8e8 --- /dev/null +++ b/jax/BUILD @@ -0,0 +1,239 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# JAX is Autograd and XLA + +load( + "//jaxlib:jax.bzl", + "absl_logging_py_deps", + "absl_testing_py_deps", + "jax_extra_deps", + "jax_internal_packages", + "jax_test_util_visibility", + "loops_visibility", + "numpy_py_deps", + "py_library_providing_imports_info", + "pytype_library", + "scipy_py_deps", + "sharded_jit_visibility", +) + +licenses(["notice"]) + +package(default_visibility = [":internal"]) + +exports_files([ + "LICENSE", + "version.py", +]) + +# Packages that have access to JAX-internal implementation details. +package_group( + name = "internal", + packages = [ + "//...", + ] + jax_internal_packages, +) + +# JAX-private test utilities. +py_library( + # This build target is required in order to use private test utilities in jax._src.test_util, + # and its visibility is intentionally restricted to discourage its use outside JAX itself. + # JAX does provide some public test utilities (see jax/test_util.py); + # these are available in jax.test_util via the standard :jax target. + name = "test_util", + testonly = 1, + srcs = [ + "_src/test_util.py", + ], + visibility = [ + ":internal", + ] + jax_test_util_visibility, + deps = [ + ":jax", + ] + absl_testing_py_deps + numpy_py_deps, +) + +py_library_providing_imports_info( + name = "jax", + srcs = glob( + [ + "*.py", + "_src/**/*.py", + "image/**/*.py", + "interpreters/**/*.py", + "lax/**/*.py", + "lib/**/*.py", + "nn/**/*.py", + "numpy/**/*.py", + "ops/**/*.py", + "scipy/**/*.py", + "third_party/**/*.py", + ], + exclude = [ + "_src/test_util.py", + "*_test.py", + "**/*_test.py", + "interpreters/sharded_jit.py", + ], + ) + [ + # until new parallelism APIs are moved out of experimental + "experimental/maps.py", + "experimental/pjit.py", + "experimental/global_device_array.py", + "experimental/array.py", + "experimental/sharding.py", + "experimental/multihost_utils.py", + # until checkify is moved out of experimental + "experimental/checkify.py", + # to avoid circular dependencies + "experimental/compilation_cache/compilation_cache.py", + "experimental/compilation_cache/gfile_cache.py", + "experimental/compilation_cache/cache_interface.py", + ], + lib_rule = pytype_library, + visibility = ["//visibility:public"], + deps = [ + "//jaxlib", + ] + numpy_py_deps + scipy_py_deps + jax_extra_deps, +) + +py_library_providing_imports_info( + name = "experimental", + srcs = glob([ + "experimental/*.py", + "example_libraries/*.py", + ]), + visibility = ["//visibility:public"], + deps = [ + ":jax", + ":sharded_jit", + ] + absl_logging_py_deps + numpy_py_deps, +) + +pytype_library( + name = "stax", + srcs = [ + "example_libraries/stax.py", + "experimental/stax.py", + ], + visibility = ["//visibility:public"], + deps = [":jax"], +) + +pytype_library( + name = "experimental_sparse", + srcs = glob([ + "experimental/sparse/*.py", + ]), + visibility = ["//visibility:public"], + deps = [":jax"], +) + +# sharded_jit is deprecated. Please do not add any more projects to the visibility. +pytype_library( + name = "sharded_jit", + srcs = ["interpreters/sharded_jit.py"], + visibility = [ + ":internal", + ] + sharded_jit_visibility, + deps = [":jax"], +) + +pytype_library( + name = "optimizers", + srcs = [ + "example_libraries/optimizers.py", + "experimental/optimizers.py", + ], + visibility = ["//visibility:public"], + deps = [":jax"], +) + +pytype_library( + name = "ode", + srcs = ["experimental/ode.py"], + visibility = ["//visibility:public"], + deps = [":jax"], +) + +# loops is deprecated. Please do not add any more projects to the visibility. +pytype_library( + name = "loops", + srcs = ["experimental/loops.py"], + visibility = [":internal"] + loops_visibility, + deps = [":jax"], +) + +pytype_library( + name = "callback", + srcs = ["experimental/callback.py"], + visibility = ["//visibility:public"], + deps = [":jax"], +) + +# TODO(apaszke): Remove this target +pytype_library( + name = "maps", + srcs = ["experimental/maps.py"], + visibility = ["//visibility:public"], + deps = [":jax"], +) + +# TODO(apaszke): Remove this target +pytype_library( + name = "pjit", + srcs = ["experimental/pjit.py"], + visibility = ["//visibility:public"], + deps = [ + ":experimental", + ":jax", + ], +) + +pytype_library( + name = "jet", + srcs = ["experimental/jet.py"], + visibility = ["//visibility:public"], + deps = [":jax"], +) + +pytype_library( + name = "experimental_host_callback", + srcs = ["experimental/host_callback.py"], + visibility = ["//visibility:public"], + deps = [ + ":jax", + ], +) + +pytype_library( + name = "compilation_cache", + srcs = [ + "experimental/compilation_cache/compilation_cache.py", + "experimental/compilation_cache/gfile_cache.py", + ], + visibility = ["//visibility:public"], + deps = [":jax"], +) + +pytype_library( + name = "mesh_utils", + srcs = ["experimental/mesh_utils.py"], + visibility = ["//visibility:public"], + deps = [ + ":experimental", + ":jax", + ], +) diff --git a/jax/BUILD.bazel b/jax/BUILD.bazel deleted file mode 100644 index 9ea5c9ef3..000000000 --- a/jax/BUILD.bazel +++ /dev/null @@ -1 +0,0 @@ -exports_files(["version.py"]) diff --git a/jax/experimental/jax2tf/BUILD b/jax/experimental/jax2tf/BUILD new file mode 100644 index 000000000..60325c8b1 --- /dev/null +++ b/jax/experimental/jax2tf/BUILD @@ -0,0 +1,48 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//jaxlib:jax.bzl", + "jax2tf_deps", + "numpy_py_deps", + "tensorflow_py_deps", +) + +licenses(["notice"]) # Apache 2 + +package( + default_visibility = ["//visibility:private"], +) + +py_library( + name = "jax2tf", + srcs = ["__init__.py"], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [":jax2tf_internal"], +) + +py_library( + name = "jax2tf_internal", + srcs = [ + "call_tf.py", + "impl_no_xla.py", + "jax2tf.py", + "shape_poly.py", + ], + srcs_version = "PY3", + deps = [ + "//jax", + ] + numpy_py_deps + tensorflow_py_deps + jax2tf_deps, +) diff --git a/jax/tools/BUILD b/jax/tools/BUILD index 38b7a52a9..34e5c2d8e 100644 --- a/jax/tools/BUILD +++ b/jax/tools/BUILD @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +load( + "//jaxlib:jax.bzl", + "tensorflow_py_deps", +) + licenses(["notice"]) package(default_visibility = ["//visibility:public"]) @@ -24,7 +29,7 @@ py_library( "ignore_for_dep=third_party.py.tensorflow", ], deps = [ - "//third_party/py/jax", + "//jax", ], ) @@ -32,8 +37,7 @@ py_library( name = "jax_to_ir_with_tensorflow", srcs = ["jax_to_ir.py"], deps = [ - "//third_party/py/jax", - "//third_party/py/jax/experimental/jax2tf", - "//third_party/py/tensorflow", - ], + "//jax", + "//jax/experimental/jax2tf", + ] + tensorflow_py_deps, ) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 8f61fb929..36b568929 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -32,6 +32,26 @@ if_rocm_is_configured = _if_rocm_is_configured if_windows = _if_windows flatbuffer_cc_library = _flatbuffer_cc_library +jax_internal_packages = [] +jax_test_util_visibility = [] +loops_visibility = [] +sharded_jit_visibility = [] + +absl_logging_py_deps = [] +absl_testing_py_deps = [] +cloudpickle_py_deps = [] +numpy_py_deps = [] +pil_py_deps = [] +portpicker_py_deps = [] +scipy_py_deps = [] +tensorflow_py_deps = [] + +jax_extra_deps = [] +jax2tf_deps = [] + +def py_library_providing_imports_info(*, name, lib_rule = native.py_library, **kwargs): + lib_rule(name = name, **kwargs) + def py_extension(name, srcs, copts, deps): pybind_extension(name, srcs = srcs, copts = copts, deps = deps, module_name = name) @@ -100,3 +120,37 @@ def windows_cc_shared_mlir_library(name, out, deps = [], srcs = []): interface_library = interface_library_file, shared_library = out, ) + +def jax_test( + name, + srcs, + args = [], + shard_count = None, + deps = [], + disable_backends = None, # buildifier: disable=unused-variable + backend_tags = {}, # buildifier: disable=unused-variable + disable_configs = None, # buildifier: disable=unused-variable + enable_configs = None, # buildifier: disable=unused-variable + tags = [], + main = None): + if shard_count == None or type(shard_count) == type(0): + shards = shard_count + else: + shards = shard_count.get("cpu", None) + native.py_test( + name = name, + srcs = srcs, + args = args, + deps = [ + "//jax", + "//jax:test_util", + ] + deps, + shard_count = shards, + tags = tags, + main = main, + ) + +def jax_generate_backend_suites(): + pass + +jax_test_file_visibility = [] diff --git a/tests/BUILD b/tests/BUILD new file mode 100644 index 000000000..2c0431811 --- /dev/null +++ b/tests/BUILD @@ -0,0 +1,877 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//jaxlib:jax.bzl", + "absl_logging_py_deps", + "cloudpickle_py_deps", + "jax_generate_backend_suites", + "jax_test", + "jax_test_file_visibility", + "pil_py_deps", + "portpicker_py_deps", + "pytype_library", + "scipy_py_deps", + "tensorflow_py_deps", +) + +licenses(["notice"]) # Apache 2 + +package(default_visibility = ["//visibility:private"]) + +jax_generate_backend_suites() + +jax_test( + name = "api_test", + srcs = ["api_test.py"], + shard_count = 5, +) + +jax_test( + name = "api_util_test", + srcs = ["api_util_test.py"], +) + +jax_test( + name = "array_interoperability_test", + srcs = ["array_interoperability_test.py"], + disable_backends = ["tpu"], + deps = tensorflow_py_deps, +) + +jax_test( + name = "batching_test", + srcs = ["batching_test.py"], +) + +jax_test( + name = "callback_test", + srcs = ["callback_test.py"], + deps = ["//jax:callback"], +) + +jax_test( + name = "core_test", + srcs = ["core_test.py"], + shard_count = { + "cpu": 5, + }, +) + +jax_test( + name = "custom_object_test", + srcs = ["custom_object_test.py"], +) + +py_test( + name = "debug_nans_test", + srcs = ["debug_nans_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ], +) + +py_test( + name = "distributed_test", + srcs = ["distributed_test.py"], + args = [ + "--exclude_test_targets=MultiProcessGpuTest", + ], + deps = [ + "//jax", + "//jax:test_util", + ] + portpicker_py_deps, +) + +jax_test( + name = "dtypes_test", + srcs = ["dtypes_test.py"], +) + +py_test( + name = "errors_test", + srcs = ["errors_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ], +) + +jax_test( + name = "fft_test", + srcs = ["fft_test.py"], + backend_tags = { + "tpu": ["notsan"], # Times out on TPU with tsan. + }, + shard_count = { + "tpu": 20, + }, +) + +jax_test( + name = "generated_fun_test", + srcs = ["generated_fun_test.py"], +) + +jax_test( + name = "lobpcg_test", + srcs = ["lobpcg_test.py"], + deps = [ + "//jax:experimental_sparse", + ], +) + +jax_test( + name = "svd_test", + srcs = ["svd_test.py"], + shard_count = { + "cpu": 10, + "gpu": 10, + "tpu": 10, + }, +) + +py_test( + name = "xla_interpreter_test", + srcs = ["xla_interpreter_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ], +) + +jax_test( + name = "xmap_test", + srcs = ["xmap_test.py"], + shard_count = { + "cpu": 10, + "tpu": 4, + }, + deps = [ + "//jax:maps", + ], +) + +jax_test( + name = "pjit_test", + srcs = ["pjit_test.py"], + deps = [ + "//jax:experimental", + ], +) + +jax_test( + name = "global_device_array_test", + srcs = ["global_device_array_test.py"], + deps = [ + "//jax:experimental", + ], +) + +jax_test( + name = "array_test", + srcs = ["array_test.py"], + deps = [ + "//jax:experimental", + ], +) + +jax_test( + name = "remote_transfer_test", + srcs = ["remote_transfer_test.py"], + disable_backends = [ + "gpu", + "cpu", + ], + deps = [ + "//jax:experimental", + ], +) + +jax_test( + name = "image_test", + srcs = ["image_test.py"], + shard_count = { + "cpu": 10, + "gpu": 10, + "tpu": 10, + "iree": 10, + }, + deps = pil_py_deps + tensorflow_py_deps, +) + +jax_test( + name = "infeed_test", + srcs = ["infeed_test.py"], + deps = [ + "//jax:experimental_host_callback", + ], +) + +py_test( + name = "jax_jit_test_x32", + srcs = ["jax_jit_test.py"], + main = "jax_jit_test.py", + visibility = ["//visibility:private"], + deps = [ + "//jax", + "//jax:test_util", + ], +) + +py_test( + name = "jax_jit_test_x64", + srcs = ["jax_jit_test.py"], + args = ["--jax_enable_x64=true"], + main = "jax_jit_test.py", + visibility = ["//visibility:private"], + deps = [ + "//jax", + "//jax:test_util", + ], +) + +test_suite( + name = "jax_jit_test", + tests = [ + "jax_jit_test_x32", + "jax_jit_test_x64", + ], +) + +py_test( + name = "jax_to_ir_test", + srcs = ["jax_to_ir_test.py"], + deps = [ + "//jax:test_util", + "//jax/experimental/jax2tf", + "//jax/tools:jax_to_ir", + ] + tensorflow_py_deps, +) + +jax_test( + name = "jaxpr_util_test", + srcs = ["jaxpr_util_test.py"], +) + +jax_test( + name = "jet_test", + srcs = ["jet_test.py"], + shard_count = { + "cpu": 10, + }, + deps = [ + "//jax:jet", + "//jax:stax", + ], +) + +jax_test( + name = "lax_control_flow_test", + srcs = ["lax_control_flow_test.py"], + shard_count = { + "cpu": 10, + "gpu": 10, + "tpu": 10, + "iree": 10, + }, +) + +jax_test( + name = "custom_root_test", + srcs = ["custom_root_test.py"], + shard_count = { + "cpu": 10, + "gpu": 10, + "tpu": 10, + "iree": 10, + }, +) + +jax_test( + name = "custom_linear_solve_test", + srcs = ["custom_linear_solve_test.py"], + shard_count = { + "cpu": 10, + "gpu": 10, + "tpu": 10, + "iree": 10, + }, +) + +jax_test( + name = "lax_numpy_test", + srcs = ["lax_numpy_test.py"], + shard_count = { + "cpu": 40, + "gpu": 40, + "tpu": 20, + }, +) + +jax_test( + name = "lax_numpy_indexing_test", + srcs = ["lax_numpy_indexing_test.py"], + shard_count = { + "cpu": 10, + "gpu": 10, + "tpu": 10, + "iree": 10, + }, +) + +jax_test( + name = "lax_numpy_einsum_test", + srcs = ["lax_numpy_einsum_test.py"], + shard_count = { + "cpu": 10, + "gpu": 10, + "tpu": 10, + "iree": 10, + }, +) + +jax_test( + name = "lax_numpy_vectorize_test", + srcs = ["lax_numpy_vectorize_test.py"], +) + +jax_test( + name = "lax_scipy_test", + srcs = ["lax_scipy_test.py"], + backend_tags = { + "tpu": ["noasan"], # Test times out. + }, + shard_count = { + "cpu": 10, + "gpu": 10, + "tpu": 10, + "iree": 10, + }, +) + +jax_test( + name = "lax_scipy_sparse_test", + srcs = ["lax_scipy_sparse_test.py"], + backend_tags = { + "cpu": ["nomsan"], # Test fails under msan because of fortran code inside scipy. + }, + shard_count = { + "cpu": 10, + "gpu": 10, + "tpu": 10, + "iree": 10, + }, +) + +jax_test( + name = "lax_test", + srcs = ["lax_test.py"], + shard_count = { + "cpu": 40, + "gpu": 40, + "tpu": 20, + "iree": 40, + }, +) + +pytype_library( + name = "lax_test_lib", + srcs = ["lax_test.py"], + srcs_version = "PY3", + deps = ["//jax"], +) + +pytype_library( + name = "lax_vmap_test_lib", + testonly = 1, + srcs = ["lax_vmap_test.py"], + srcs_version = "PY3", + deps = [ + ":lax_test_lib", + "//jax", + "//jax:test_util", + ], +) + +jax_test( + name = "lax_autodiff_test", + srcs = ["lax_autodiff_test.py"], + shard_count = { + "cpu": 40, + "gpu": 40, + "tpu": 20, + "iree": 40, + }, + deps = [":lax_test_lib"], +) + +jax_test( + name = "lax_vmap_test", + srcs = ["lax_vmap_test.py"], + shard_count = { + "cpu": 40, + "gpu": 40, + "tpu": 20, + "iree": 40, + }, + deps = [":lax_test_lib"], +) + +jax_test( + name = "linalg_test", + srcs = ["linalg_test.py"], + backend_tags = { + "tpu": [ + "cpu:8", + "noasan", # Times out. + ], + }, + shard_count = { + "cpu": 20, + "gpu": 20, + "tpu": 10, + "iree": 20, + }, +) + +jax_test( + name = "masking_test", + srcs = ["masking_test.py"], +) + +jax_test( + name = "metadata_test", + srcs = ["metadata_test.py"], + disable_backends = [ + "gpu", + "tpu", + ], +) + +jax_test( + name = "multibackend_test", + srcs = ["multibackend_test.py"], +) + +jax_test( + name = "multi_device_test", + srcs = ["multi_device_test.py"], + disable_backends = [ + "gpu", + "tpu", + ], +) + +jax_test( + name = "nn_test", + srcs = ["nn_test.py"], +) + +jax_test( + name = "optimizers_test", + srcs = ["optimizers_test.py"], + deps = ["//jax:optimizers"], +) + +jax_test( + name = "pickle_test", + srcs = ["pickle_test.py"], + deps = [ + "//jax:experimental", + ] + cloudpickle_py_deps, +) + +jax_test( + name = "pmap_test", + srcs = ["pmap_test.py"], + shard_count = { + "cpu": 5, + "gpu": 5, + "tpu": 5, + }, + deps = [ + ":lax_test_lib", + ":lax_vmap_test_lib", + ], +) + +jax_test( + name = "polynomial_test", + srcs = ["polynomial_test.py"], + # No implementation of nonsymmetric Eigendecomposition. + disable_backends = [ + "gpu", + "tpu", + ], + shard_count = { + "cpu": 10, + }, + # This test ends up calling Fortran code that initializes some memory and + # passes it to C code. MSan is not able to detect that the memory was + # initialized by Fortran, and it makes the test fail. This can usually be + # fixed by annotating the memory with `ANNOTATE_MEMORY_IS_INITIALIZED`, but + # in this case there's not a good place to do it, see b/197635968#comment19 + # for details. + tags = ["nomsan"], +) + +jax_test( + name = "heap_profiler_test", + srcs = ["heap_profiler_test.py"], + disable_backends = [ + "gpu", + "tpu", + ], +) + +jax_test( + name = "profiler_test", + srcs = ["profiler_test.py"], + disable_backends = [ + "gpu", + "tpu", + ], +) + +jax_test( + name = "qdwh_test", + srcs = ["qdwh_test.py"], +) + +jax_test( + name = "random_test", + srcs = ["random_test.py"], + shard_count = { + "cpu": 10, + "gpu": 10, + "tpu": 10, + "iree": 10, + }, +) + +# TODO(b/199564969): remove once we always enable_custom_prng +jax_test( + name = "random_test_with_custom_prng", + srcs = ["random_test.py"], + args = ["--jax_enable_custom_prng=true"], + backend_tags = { + "cpu": ["noasan"], # Times out under asan. + }, + main = "random_test.py", + shard_count = { + "cpu": 30, + "gpu": 20, + "tpu": 20, + "iree": 20, + }, +) + +jax_test( + name = "scipy_fft_test", + srcs = ["scipy_fft_test.py"], +) + +jax_test( + name = "scipy_interpolate_test", + srcs = ["scipy_interpolate_test.py"], +) + +jax_test( + name = "scipy_ndimage_test", + srcs = ["scipy_ndimage_test.py"], +) + +jax_test( + name = "scipy_optimize_test", + srcs = ["scipy_optimize_test.py"], +) + +jax_test( + name = "scipy_signal_test", + srcs = ["scipy_signal_test.py"], + backend_tags = { + "cpu": [ + "noasan", # Test times out under asan. + ], + "tpu": [ + "noasan", + "notsan", + ], # Test times out under asan/tsan. + }, + shard_count = { + "tpu": 5, + }, +) + +jax_test( + name = "scipy_stats_test", + srcs = ["scipy_stats_test.py"], + shard_count = { + "cpu": 10, + "gpu": 10, + "tpu": 10, + "iree": 10, + }, +) + +jax_test( + name = "sharded_jit_test", + srcs = ["sharded_jit_test.py"], + disable_backends = [ + "tpu", + "iree", + ], + deps = ["//jax:experimental"], +) + +jax_test( + name = "sparse_test", + srcs = ["sparse_test.py"], + args = ["--jax_bcoo_cusparse_lowering=true"], + backend_tags = { + "cpu": [ + "noasan", # Test times out under asan. + ], + }, + shard_count = { + "cpu": 10, + "gpu": 10, + "tpu": 10, + "iree": 10, + }, + deps = [ + "//jax:experimental_sparse", + ] + scipy_py_deps, +) + +jax_test( + name = "sparsify_test", + srcs = ["sparsify_test.py"], + args = ["--jax_bcoo_cusparse_lowering=true"], + deps = [ + "//jax:experimental_sparse", + ], +) + +jax_test( + name = "stack_test", + srcs = ["stack_test.py"], +) + +jax_test( + name = "checkify_test", + srcs = ["checkify_test.py"], +) + +jax_test( + name = "stax_test", + srcs = ["stax_test.py"], + shard_count = { + "cpu": 5, + "gpu": 5, + "iree": 5, + }, + deps = ["//jax:stax"], +) + +jax_test( + name = "linear_search_test", + srcs = ["third_party/scipy/line_search_test.py"], + main = "third_party/scipy/line_search_test.py", +) + +py_test( + name = "tree_util_test", + srcs = ["tree_util_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ], +) + +py_test( + name = "util_test", + srcs = ["util_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ], +) + +py_test( + name = "version_test", + srcs = ["version_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ], +) + +py_test( + name = "xla_bridge_test", + srcs = ["xla_bridge_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ] + absl_logging_py_deps, +) + +py_test( + name = "gfile_cache_test", + srcs = ["gfile_cache_test.py"], + deps = [ + "//jax", + "//jax:compilation_cache", + "//jax:test_util", + ], +) + +jax_test( + name = "compilation_cache_test", + srcs = ["compilation_cache_test.py"], + backend_tags = { + "tpu": ["nomsan"], # TODO(b/213388298): this test fails msan. + }, + deps = [ + "//jax:compilation_cache", + "//jax:experimental", + ], +) + +jax_test( + name = "ode_test", + srcs = ["ode_test.py"], + shard_count = { + "cpu": 10, + }, + deps = ["//jax:ode"], +) + +jax_test( + name = "host_callback_test", + srcs = ["host_callback_test.py"], + args = ["--jax_host_callback_outfeed=true"], + deps = [ + "//jax:experimental", + "//jax:experimental_host_callback", + "//jax:ode", + ], +) + +jax_test( + name = "host_callback_custom_call_test", + srcs = ["host_callback_test.py"], + args = ["--jax_host_callback_outfeed=false"], + disable_backends = [ + "gpu", + "tpu", # On TPU we always use outfeed + ], + main = "host_callback_test.py", + deps = [ + "//jax:experimental", + "//jax:experimental_host_callback", + "//jax:ode", + ], +) + +jax_test( + name = "host_callback_to_tf_test", + srcs = ["host_callback_to_tf_test.py"], + deps = [ + "//jax:experimental_host_callback", + "//jax:ode", + ] + tensorflow_py_deps, +) + +jax_test( + name = "x64_context_test", + srcs = ["x64_context_test.py"], + deps = [ + "//jax:experimental", + ], +) + +jax_test( + name = "ann_test", + srcs = ["ann_test.py"], + shard_count = 2, + deps = [ + ":lax_test_lib", + ], +) + +py_test( + name = "mesh_utils_test", + srcs = ["mesh_utils_test.py"], + deps = [ + "//jax", + "//jax:mesh_utils", + "//jax:test_util", + ], +) + +jax_test( + name = "transfer_guard_test", + srcs = ["transfer_guard_test.py"], +) + +jax_test( + name = "name_stack_test", + srcs = ["name_stack_test.py"], +) + +jax_test( + name = "jaxpr_effects_test", + srcs = ["jaxpr_effects_test.py"], +) + +jax_test( + name = "debugging_primitives_test", + srcs = ["debugging_primitives_test.py"], +) + +jax_test( + name = "debugger_test", + srcs = ["debugger_test.py"], +) + +exports_files( + [ + "api_test.py", + "pmap_test.py", + "sharded_jit_test.py", + "pjit_test.py", + "transfer_guard_test.py", + ], + visibility = jax_test_file_visibility, +) + +# This filegroup specifies the set of tests known to Bazel, used for a test that +# verifies every test has a Bazel test rule. +# If a test isn't meant to be tested with Bazel, add it to the exclude list. +filegroup( + name = "all_tests", + srcs = glob( + include = [ + "*_test.py", + "third_party/*/*_test.py", + ], + exclude = [], + ) + ["BUILD"], + visibility = [ + "//third_party/py/jax:__subpackages__", + ], +)