2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2022-07-01 15:06:54 -07:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
load(
|
|
|
|
"//jaxlib:jax.bzl",
|
|
|
|
"jax_generate_backend_suites",
|
|
|
|
"jax_test",
|
|
|
|
"jax_test_file_visibility",
|
2022-08-05 07:48:40 -07:00
|
|
|
"py_deps",
|
2022-09-14 10:38:54 -07:00
|
|
|
"pytype_test",
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
2022-08-24 08:16:16 -07:00
|
|
|
licenses(["notice"])
|
2022-07-01 15:06:54 -07:00
|
|
|
|
|
|
|
package(default_visibility = ["//visibility:private"])
|
|
|
|
|
|
|
|
jax_generate_backend_suites()
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "api_test",
|
|
|
|
srcs = ["api_test.py"],
|
2022-07-11 13:30:44 +00:00
|
|
|
shard_count = 10,
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
2022-07-14 20:18:14 -07:00
|
|
|
jax_test(
|
|
|
|
name = "dynamic_api_test",
|
|
|
|
srcs = ["dynamic_api_test.py"],
|
|
|
|
shard_count = 2,
|
|
|
|
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
jax_test(
|
|
|
|
name = "api_util_test",
|
|
|
|
srcs = ["api_util_test.py"],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "array_interoperability_test",
|
|
|
|
srcs = ["array_interoperability_test.py"],
|
|
|
|
disable_backends = ["tpu"],
|
2023-02-08 12:11:20 -08:00
|
|
|
deps = py_deps("tensorflow_core"),
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "batching_test",
|
|
|
|
srcs = ["batching_test.py"],
|
2022-07-11 13:30:44 +00:00
|
|
|
shard_count = {
|
|
|
|
"gpu": 5,
|
|
|
|
},
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "core_test",
|
|
|
|
srcs = ["core_test.py"],
|
|
|
|
shard_count = {
|
|
|
|
"cpu": 5,
|
2022-07-11 13:30:44 +00:00
|
|
|
"gpu": 10,
|
2022-07-01 15:06:54 -07:00
|
|
|
},
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "custom_object_test",
|
|
|
|
srcs = ["custom_object_test.py"],
|
|
|
|
)
|
|
|
|
|
2022-09-30 14:20:57 -07:00
|
|
|
jax_test(
|
2022-07-01 15:06:54 -07:00
|
|
|
name = "debug_nans_test",
|
|
|
|
srcs = ["debug_nans_test.py"],
|
|
|
|
)
|
|
|
|
|
2022-08-04 12:52:16 -07:00
|
|
|
py_test(
|
2022-08-25 15:27:07 -07:00
|
|
|
name = "multiprocess_gpu_test",
|
|
|
|
srcs = ["multiprocess_gpu_test.py"],
|
2022-07-01 15:06:54 -07:00
|
|
|
args = [
|
|
|
|
"--exclude_test_targets=MultiProcessGpuTest",
|
|
|
|
],
|
2022-09-16 07:07:29 -07:00
|
|
|
tags = ["manual"],
|
2022-07-01 15:06:54 -07:00
|
|
|
deps = [
|
|
|
|
"//jax",
|
|
|
|
"//jax:test_util",
|
2022-08-05 07:48:40 -07:00
|
|
|
] + py_deps("portpicker"),
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "dtypes_test",
|
|
|
|
srcs = ["dtypes_test.py"],
|
|
|
|
)
|
|
|
|
|
2023-01-17 18:42:21 -08:00
|
|
|
jax_test(
|
2022-07-01 15:06:54 -07:00
|
|
|
name = "errors_test",
|
|
|
|
srcs = ["errors_test.py"],
|
2023-01-17 18:42:21 -08:00
|
|
|
# No need to test all other configs.
|
|
|
|
enable_configs = [
|
|
|
|
"cpu",
|
2022-07-01 15:06:54 -07:00
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "fft_test",
|
|
|
|
srcs = ["fft_test.py"],
|
|
|
|
backend_tags = {
|
2022-11-16 06:00:07 -08:00
|
|
|
"tpu": [
|
|
|
|
"noasan",
|
|
|
|
"notsan",
|
|
|
|
], # Times out on TPU with asan/tsan.
|
2022-07-01 15:06:54 -07:00
|
|
|
},
|
|
|
|
shard_count = {
|
|
|
|
"tpu": 20,
|
2022-11-23 06:32:41 -08:00
|
|
|
"cpu": 20,
|
2022-07-01 15:06:54 -07:00
|
|
|
},
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "generated_fun_test",
|
|
|
|
srcs = ["generated_fun_test.py"],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "lobpcg_test",
|
|
|
|
srcs = ["lobpcg_test.py"],
|
2022-08-04 20:05:18 -07:00
|
|
|
env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"},
|
|
|
|
shard_count = {
|
|
|
|
"cpu": 48,
|
|
|
|
"gpu": 48,
|
|
|
|
"tpu": 48,
|
|
|
|
},
|
2022-07-01 15:06:54 -07:00
|
|
|
deps = [
|
|
|
|
"//jax:experimental_sparse",
|
2022-08-05 07:48:40 -07:00
|
|
|
] + py_deps("matplotlib"),
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "svd_test",
|
|
|
|
srcs = ["svd_test.py"],
|
|
|
|
shard_count = {
|
|
|
|
"cpu": 10,
|
|
|
|
"gpu": 10,
|
|
|
|
"tpu": 10,
|
|
|
|
},
|
|
|
|
)
|
|
|
|
|
|
|
|
py_test(
|
|
|
|
name = "xla_interpreter_test",
|
|
|
|
srcs = ["xla_interpreter_test.py"],
|
|
|
|
deps = [
|
|
|
|
"//jax",
|
|
|
|
"//jax:test_util",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "xmap_test",
|
|
|
|
srcs = ["xmap_test.py"],
|
2022-12-05 06:51:28 -08:00
|
|
|
backend_tags = {
|
|
|
|
"tpu": ["noasan"], # Times out.
|
|
|
|
},
|
2022-07-01 15:06:54 -07:00
|
|
|
shard_count = {
|
|
|
|
"cpu": 10,
|
2022-07-11 13:30:44 +00:00
|
|
|
"gpu": 4,
|
2022-07-01 15:06:54 -07:00
|
|
|
"tpu": 4,
|
|
|
|
},
|
2022-07-06 12:51:07 -07:00
|
|
|
tags = ["multiaccelerator"],
|
2022-07-01 15:06:54 -07:00
|
|
|
deps = [
|
|
|
|
"//jax:maps",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "pjit_test",
|
|
|
|
srcs = ["pjit_test.py"],
|
2022-09-21 08:50:17 -07:00
|
|
|
backend_tags = {
|
|
|
|
"tpu": ["notsan"], # Times out under tsan.
|
|
|
|
},
|
2022-08-10 20:11:06 -07:00
|
|
|
shard_count = {
|
|
|
|
"cpu": 5,
|
|
|
|
"gpu": 5,
|
|
|
|
"tpu": 5,
|
|
|
|
},
|
2022-07-06 12:51:07 -07:00
|
|
|
tags = ["multiaccelerator"],
|
2022-07-01 15:06:54 -07:00
|
|
|
deps = [
|
|
|
|
"//jax:experimental",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "array_test",
|
|
|
|
srcs = ["array_test.py"],
|
2022-07-06 12:51:07 -07:00
|
|
|
tags = ["multiaccelerator"],
|
2022-07-01 15:06:54 -07:00
|
|
|
deps = [
|
|
|
|
"//jax:experimental",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2023-04-03 14:47:14 -07:00
|
|
|
jax_test(
|
|
|
|
name = "aot_test",
|
|
|
|
srcs = ["aot_test.py"],
|
|
|
|
tags = ["multiaccelerator"],
|
|
|
|
deps = [
|
|
|
|
"//jax:experimental",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
jax_test(
|
|
|
|
name = "image_test",
|
|
|
|
srcs = ["image_test.py"],
|
|
|
|
shard_count = {
|
|
|
|
"cpu": 10,
|
2022-07-11 13:30:44 +00:00
|
|
|
"gpu": 20,
|
2022-07-01 15:06:54 -07:00
|
|
|
"tpu": 10,
|
|
|
|
"iree": 10,
|
|
|
|
},
|
2023-02-08 12:11:20 -08:00
|
|
|
deps = py_deps("pil") + py_deps("tensorflow_core"),
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "infeed_test",
|
|
|
|
srcs = ["infeed_test.py"],
|
|
|
|
deps = [
|
|
|
|
"//jax:experimental_host_callback",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2022-09-26 14:38:06 -07:00
|
|
|
jax_test(
|
|
|
|
name = "jax_jit_test",
|
2022-07-01 15:06:54 -07:00
|
|
|
srcs = ["jax_jit_test.py"],
|
|
|
|
main = "jax_jit_test.py",
|
|
|
|
)
|
|
|
|
|
|
|
|
py_test(
|
|
|
|
name = "jax_to_ir_test",
|
|
|
|
srcs = ["jax_to_ir_test.py"],
|
|
|
|
deps = [
|
|
|
|
"//jax:test_util",
|
|
|
|
"//jax/experimental/jax2tf",
|
|
|
|
"//jax/tools:jax_to_ir",
|
2023-02-08 12:11:20 -08:00
|
|
|
] + py_deps("tensorflow_core"),
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "jaxpr_util_test",
|
|
|
|
srcs = ["jaxpr_util_test.py"],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "jet_test",
|
|
|
|
srcs = ["jet_test.py"],
|
|
|
|
shard_count = {
|
|
|
|
"cpu": 10,
|
2022-07-11 13:30:44 +00:00
|
|
|
"gpu": 10,
|
2022-07-01 15:06:54 -07:00
|
|
|
},
|
|
|
|
deps = [
|
|
|
|
"//jax:jet",
|
|
|
|
"//jax:stax",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "lax_control_flow_test",
|
|
|
|
srcs = ["lax_control_flow_test.py"],
|
|
|
|
shard_count = {
|
2022-09-06 13:26:41 -07:00
|
|
|
"cpu": 30,
|
2022-09-11 14:30:18 -07:00
|
|
|
"gpu": 40,
|
2022-09-06 13:26:41 -07:00
|
|
|
"tpu": 30,
|
2022-07-01 15:06:54 -07:00
|
|
|
"iree": 10,
|
|
|
|
},
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "custom_root_test",
|
|
|
|
srcs = ["custom_root_test.py"],
|
|
|
|
shard_count = {
|
|
|
|
"cpu": 10,
|
|
|
|
"gpu": 10,
|
|
|
|
"tpu": 10,
|
|
|
|
"iree": 10,
|
|
|
|
},
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "custom_linear_solve_test",
|
|
|
|
srcs = ["custom_linear_solve_test.py"],
|
|
|
|
shard_count = {
|
|
|
|
"cpu": 10,
|
|
|
|
"gpu": 10,
|
|
|
|
"tpu": 10,
|
|
|
|
"iree": 10,
|
|
|
|
},
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "lax_numpy_test",
|
|
|
|
srcs = ["lax_numpy_test.py"],
|
2022-11-14 07:11:26 -08:00
|
|
|
backend_tags = {
|
2022-11-17 08:13:29 -08:00
|
|
|
"cpu": ["noasan"], # Test times out.
|
2022-11-14 07:11:26 -08:00
|
|
|
"tpu": ["noasan"], # Test times out.
|
|
|
|
},
|
2022-07-01 15:06:54 -07:00
|
|
|
shard_count = {
|
2022-11-09 18:57:28 -08:00
|
|
|
"cpu": 40,
|
|
|
|
"gpu": 40,
|
|
|
|
"tpu": 40,
|
2022-09-30 19:37:42 -07:00
|
|
|
},
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "lax_numpy_operators_test",
|
|
|
|
srcs = ["lax_numpy_operators_test.py"],
|
|
|
|
shard_count = {
|
2022-11-09 18:57:28 -08:00
|
|
|
"cpu": 30,
|
|
|
|
"gpu": 30,
|
|
|
|
"tpu": 20,
|
2022-09-30 19:37:42 -07:00
|
|
|
},
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "lax_numpy_reducers_test",
|
|
|
|
srcs = ["lax_numpy_reducers_test.py"],
|
|
|
|
shard_count = {
|
2022-11-09 18:57:28 -08:00
|
|
|
"cpu": 20,
|
|
|
|
"gpu": 20,
|
|
|
|
"tpu": 20,
|
2022-07-01 15:06:54 -07:00
|
|
|
},
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "lax_numpy_indexing_test",
|
|
|
|
srcs = ["lax_numpy_indexing_test.py"],
|
|
|
|
shard_count = {
|
|
|
|
"cpu": 10,
|
|
|
|
"gpu": 10,
|
|
|
|
"tpu": 10,
|
|
|
|
"iree": 10,
|
|
|
|
},
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "lax_numpy_einsum_test",
|
|
|
|
srcs = ["lax_numpy_einsum_test.py"],
|
|
|
|
shard_count = {
|
|
|
|
"cpu": 10,
|
|
|
|
"gpu": 10,
|
|
|
|
"tpu": 10,
|
|
|
|
"iree": 10,
|
|
|
|
},
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "lax_numpy_vectorize_test",
|
|
|
|
srcs = ["lax_numpy_vectorize_test.py"],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "lax_scipy_test",
|
|
|
|
srcs = ["lax_scipy_test.py"],
|
|
|
|
shard_count = {
|
2023-03-06 09:49:52 -08:00
|
|
|
"cpu": 20,
|
|
|
|
"gpu": 20,
|
|
|
|
"tpu": 20,
|
2022-07-01 15:06:54 -07:00
|
|
|
"iree": 10,
|
|
|
|
},
|
2023-03-06 09:49:52 -08:00
|
|
|
deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"),
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "lax_scipy_sparse_test",
|
|
|
|
srcs = ["lax_scipy_sparse_test.py"],
|
|
|
|
backend_tags = {
|
|
|
|
"cpu": ["nomsan"], # Test fails under msan because of fortran code inside scipy.
|
|
|
|
},
|
|
|
|
shard_count = {
|
|
|
|
"cpu": 10,
|
2022-07-08 08:29:56 -07:00
|
|
|
"gpu": 10,
|
2022-07-01 15:06:54 -07:00
|
|
|
"tpu": 10,
|
|
|
|
"iree": 10,
|
|
|
|
},
|
|
|
|
)
|
|
|
|
|
2023-03-06 09:49:52 -08:00
|
|
|
jax_test(
|
|
|
|
name = "lax_scipy_special_functions_test",
|
|
|
|
srcs = ["lax_scipy_special_functions_test.py"],
|
|
|
|
shard_count = {
|
|
|
|
"cpu": 20,
|
|
|
|
"gpu": 20,
|
|
|
|
"tpu": 20,
|
|
|
|
"iree": 10,
|
|
|
|
},
|
|
|
|
deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"),
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "lax_scipy_spectral_dac_test",
|
|
|
|
srcs = ["lax_scipy_spectral_dac_test.py"],
|
|
|
|
shard_count = {
|
|
|
|
"cpu": 40,
|
|
|
|
"gpu": 40,
|
|
|
|
"tpu": 40,
|
|
|
|
"iree": 40,
|
|
|
|
},
|
|
|
|
deps = [
|
|
|
|
"//jax:internal_test_util",
|
|
|
|
] + py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"),
|
|
|
|
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
jax_test(
|
|
|
|
name = "lax_test",
|
|
|
|
srcs = ["lax_test.py"],
|
|
|
|
shard_count = {
|
|
|
|
"cpu": 40,
|
|
|
|
"gpu": 40,
|
2022-11-09 18:57:28 -08:00
|
|
|
"tpu": 30,
|
2022-07-01 15:06:54 -07:00
|
|
|
"iree": 40,
|
|
|
|
},
|
2023-03-30 06:12:10 -07:00
|
|
|
deps = [
|
|
|
|
"//jax:internal_test_util",
|
|
|
|
"//jax:lax_reference",
|
|
|
|
] + py_deps("numpy"),
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "lax_autodiff_test",
|
|
|
|
srcs = ["lax_autodiff_test.py"],
|
|
|
|
shard_count = {
|
|
|
|
"cpu": 40,
|
|
|
|
"gpu": 40,
|
|
|
|
"tpu": 20,
|
|
|
|
"iree": 40,
|
|
|
|
},
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "lax_vmap_test",
|
|
|
|
srcs = ["lax_vmap_test.py"],
|
2023-03-07 08:49:05 -08:00
|
|
|
shard_count = {
|
|
|
|
"cpu": 40,
|
|
|
|
"gpu": 40,
|
|
|
|
"tpu": 40,
|
|
|
|
"iree": 40,
|
2022-11-14 07:11:26 -08:00
|
|
|
},
|
2023-03-07 08:49:05 -08:00
|
|
|
deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"),
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "lax_vmap_op_test",
|
|
|
|
srcs = ["lax_vmap_op_test.py"],
|
2022-07-01 15:06:54 -07:00
|
|
|
shard_count = {
|
|
|
|
"cpu": 40,
|
|
|
|
"gpu": 40,
|
2022-11-09 20:31:32 -08:00
|
|
|
"tpu": 40,
|
2022-07-01 15:06:54 -07:00
|
|
|
"iree": 40,
|
|
|
|
},
|
2023-03-07 08:49:05 -08:00
|
|
|
deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"),
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
2022-10-06 01:43:54 +00:00
|
|
|
py_test(
|
|
|
|
name = "lazy_loader_test",
|
|
|
|
srcs = [
|
|
|
|
"lazy_loader_test.py",
|
|
|
|
],
|
|
|
|
deps = [
|
2023-03-10 10:28:55 -08:00
|
|
|
"//jax:internal_test_util",
|
2022-10-06 01:43:54 +00:00
|
|
|
"//jax:test_util",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2023-02-09 05:47:59 -08:00
|
|
|
py_test(
|
|
|
|
name = "deprecation_test",
|
|
|
|
srcs = [
|
|
|
|
"deprecation_test.py",
|
|
|
|
],
|
|
|
|
deps = [
|
2023-02-17 10:55:04 -08:00
|
|
|
"//jax:internal_test_util",
|
2023-02-09 05:47:59 -08:00
|
|
|
"//jax:test_util",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
jax_test(
|
|
|
|
name = "linalg_test",
|
|
|
|
srcs = ["linalg_test.py"],
|
|
|
|
backend_tags = {
|
|
|
|
"tpu": [
|
|
|
|
"cpu:8",
|
|
|
|
"noasan", # Times out.
|
2022-11-17 08:13:29 -08:00
|
|
|
"nomsan", # Times out.
|
2022-09-22 11:44:22 -07:00
|
|
|
"nodebug", # Times out.
|
2022-08-05 08:33:11 -07:00
|
|
|
"notsan", # Times out.
|
2022-07-01 15:06:54 -07:00
|
|
|
],
|
|
|
|
},
|
|
|
|
shard_count = {
|
2022-08-12 13:21:40 +00:00
|
|
|
"cpu": 40,
|
2022-07-11 13:30:44 +00:00
|
|
|
"gpu": 40,
|
2022-11-09 18:57:28 -08:00
|
|
|
"tpu": 40,
|
2022-07-01 15:06:54 -07:00
|
|
|
"iree": 20,
|
|
|
|
},
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "metadata_test",
|
|
|
|
srcs = ["metadata_test.py"],
|
|
|
|
disable_backends = [
|
|
|
|
"gpu",
|
|
|
|
"tpu",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2022-11-15 12:41:08 -08:00
|
|
|
py_test(
|
|
|
|
name = "monitoring_test",
|
|
|
|
srcs = ["monitoring_test.py"],
|
|
|
|
deps = [
|
|
|
|
"//jax",
|
|
|
|
"//jax:test_util",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
jax_test(
|
|
|
|
name = "multibackend_test",
|
|
|
|
srcs = ["multibackend_test.py"],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "multi_device_test",
|
|
|
|
srcs = ["multi_device_test.py"],
|
|
|
|
disable_backends = [
|
|
|
|
"gpu",
|
|
|
|
"tpu",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "nn_test",
|
|
|
|
srcs = ["nn_test.py"],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "optimizers_test",
|
|
|
|
srcs = ["optimizers_test.py"],
|
|
|
|
deps = ["//jax:optimizers"],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "pickle_test",
|
|
|
|
srcs = ["pickle_test.py"],
|
|
|
|
deps = [
|
|
|
|
"//jax:experimental",
|
2022-11-22 08:30:38 -08:00
|
|
|
] + py_deps("cloudpickle") + py_deps("numpy"),
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "pmap_test",
|
|
|
|
srcs = ["pmap_test.py"],
|
2023-02-15 07:34:54 -08:00
|
|
|
backend_tags = {
|
|
|
|
"tpu": [
|
|
|
|
"noasan", # Times out under asan.
|
|
|
|
],
|
|
|
|
},
|
2022-07-01 15:06:54 -07:00
|
|
|
shard_count = {
|
2022-08-22 06:51:21 -07:00
|
|
|
"cpu": 30,
|
2022-08-17 10:45:47 -07:00
|
|
|
"gpu": 30,
|
2022-08-22 06:51:21 -07:00
|
|
|
"tpu": 30,
|
2022-07-01 15:06:54 -07:00
|
|
|
},
|
2022-07-06 12:51:07 -07:00
|
|
|
tags = ["multiaccelerator"],
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "polynomial_test",
|
|
|
|
srcs = ["polynomial_test.py"],
|
|
|
|
# No implementation of nonsymmetric Eigendecomposition.
|
|
|
|
disable_backends = [
|
|
|
|
"gpu",
|
|
|
|
"tpu",
|
|
|
|
],
|
|
|
|
shard_count = {
|
|
|
|
"cpu": 10,
|
|
|
|
},
|
|
|
|
# This test ends up calling Fortran code that initializes some memory and
|
|
|
|
# passes it to C code. MSan is not able to detect that the memory was
|
|
|
|
# initialized by Fortran, and it makes the test fail. This can usually be
|
|
|
|
# fixed by annotating the memory with `ANNOTATE_MEMORY_IS_INITIALIZED`, but
|
|
|
|
# in this case there's not a good place to do it, see b/197635968#comment19
|
|
|
|
# for details.
|
|
|
|
tags = ["nomsan"],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "heap_profiler_test",
|
|
|
|
srcs = ["heap_profiler_test.py"],
|
|
|
|
disable_backends = [
|
|
|
|
"gpu",
|
|
|
|
"tpu",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "profiler_test",
|
|
|
|
srcs = ["profiler_test.py"],
|
|
|
|
disable_backends = [
|
|
|
|
"gpu",
|
|
|
|
"tpu",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2023-02-10 12:10:19 -08:00
|
|
|
jax_test(
|
|
|
|
name = "pytorch_interoperability_test",
|
|
|
|
srcs = ["pytorch_interoperability_test.py"],
|
|
|
|
disable_backends = ["tpu"],
|
|
|
|
deps = py_deps("torch"),
|
|
|
|
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
jax_test(
|
|
|
|
name = "qdwh_test",
|
|
|
|
srcs = ["qdwh_test.py"],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "random_test",
|
|
|
|
srcs = ["random_test.py"],
|
2022-11-14 10:33:53 -08:00
|
|
|
backend_tags = {
|
|
|
|
"cpu": ["notsan"], # Times out
|
2022-11-16 08:46:57 -08:00
|
|
|
"tpu": ["optonly"],
|
2022-11-14 10:33:53 -08:00
|
|
|
},
|
2022-07-01 15:06:54 -07:00
|
|
|
shard_count = {
|
2022-11-09 18:57:28 -08:00
|
|
|
"cpu": 30,
|
|
|
|
"gpu": 30,
|
|
|
|
"tpu": 30,
|
|
|
|
"iree": 30,
|
2022-07-01 15:06:54 -07:00
|
|
|
},
|
2022-11-14 07:11:26 -08:00
|
|
|
tags = ["noasan"], # Times out
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
# TODO(b/199564969): remove once we always enable_custom_prng
|
|
|
|
jax_test(
|
|
|
|
name = "random_test_with_custom_prng",
|
|
|
|
srcs = ["random_test.py"],
|
|
|
|
args = ["--jax_enable_custom_prng=true"],
|
|
|
|
backend_tags = {
|
2022-11-14 07:11:26 -08:00
|
|
|
"cpu": [
|
2022-11-17 08:13:29 -08:00
|
|
|
"noasan", # Times out under asan/tsan.
|
2022-11-14 07:11:26 -08:00
|
|
|
"notsan",
|
2022-11-17 08:13:29 -08:00
|
|
|
],
|
|
|
|
"tpu": [
|
|
|
|
"noasan", # Times out under asan/tsan.
|
|
|
|
"optonly",
|
|
|
|
],
|
2022-07-01 15:06:54 -07:00
|
|
|
},
|
|
|
|
main = "random_test.py",
|
|
|
|
shard_count = {
|
2022-08-12 13:21:40 +00:00
|
|
|
"cpu": 40,
|
2022-07-11 13:30:44 +00:00
|
|
|
"gpu": 40,
|
2022-07-28 07:14:07 -07:00
|
|
|
"tpu": 40,
|
2022-07-01 15:06:54 -07:00
|
|
|
"iree": 20,
|
|
|
|
},
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "scipy_fft_test",
|
|
|
|
srcs = ["scipy_fft_test.py"],
|
2022-11-16 06:00:07 -08:00
|
|
|
backend_tags = {
|
|
|
|
"tpu": [
|
|
|
|
"noasan",
|
|
|
|
"notsan",
|
|
|
|
], # Times out on TPU with asan/tsan.
|
|
|
|
},
|
2022-11-09 18:57:28 -08:00
|
|
|
shard_count = 4,
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "scipy_interpolate_test",
|
|
|
|
srcs = ["scipy_interpolate_test.py"],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "scipy_ndimage_test",
|
|
|
|
srcs = ["scipy_ndimage_test.py"],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "scipy_optimize_test",
|
|
|
|
srcs = ["scipy_optimize_test.py"],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "scipy_signal_test",
|
|
|
|
srcs = ["scipy_signal_test.py"],
|
|
|
|
backend_tags = {
|
|
|
|
"cpu": [
|
|
|
|
"noasan", # Test times out under asan.
|
|
|
|
],
|
2022-11-29 08:58:19 -08:00
|
|
|
# TPU test times out under asan/msan/tsan (b/260710050)
|
2022-07-01 15:06:54 -07:00
|
|
|
"tpu": [
|
|
|
|
"noasan",
|
2022-11-29 08:58:19 -08:00
|
|
|
"nomsan",
|
2022-07-01 15:06:54 -07:00
|
|
|
"notsan",
|
2022-11-10 07:38:21 -08:00
|
|
|
"optonly",
|
2022-11-29 08:58:19 -08:00
|
|
|
],
|
2022-07-01 15:06:54 -07:00
|
|
|
},
|
|
|
|
shard_count = {
|
2022-11-09 18:57:28 -08:00
|
|
|
"cpu": 40,
|
|
|
|
"gpu": 40,
|
2022-11-09 20:31:32 -08:00
|
|
|
"tpu": 50,
|
2022-07-01 15:06:54 -07:00
|
|
|
},
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "scipy_stats_test",
|
|
|
|
srcs = ["scipy_stats_test.py"],
|
|
|
|
shard_count = {
|
2022-11-09 18:57:28 -08:00
|
|
|
"cpu": 40,
|
2022-11-02 12:06:46 -07:00
|
|
|
"gpu": 30,
|
2022-11-09 18:57:28 -08:00
|
|
|
"tpu": 40,
|
2022-07-01 15:06:54 -07:00
|
|
|
"iree": 10,
|
|
|
|
},
|
2022-11-14 07:11:26 -08:00
|
|
|
tags = [
|
|
|
|
"noasan",
|
2023-03-06 13:50:03 -08:00
|
|
|
"notap", # TODO(b/271883906): Failing on xla.opt.debug.tpu_hw tap
|
2022-11-14 07:11:26 -08:00
|
|
|
"notsan",
|
|
|
|
], # Times out
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "sparse_test",
|
|
|
|
srcs = ["sparse_test.py"],
|
|
|
|
args = ["--jax_bcoo_cusparse_lowering=true"],
|
|
|
|
backend_tags = {
|
2022-11-22 12:48:03 -08:00
|
|
|
"cpu": [
|
|
|
|
"nomsan", # Times out
|
|
|
|
"notsan", # Times out
|
|
|
|
],
|
2022-11-10 07:38:21 -08:00
|
|
|
"tpu": ["optonly"],
|
2022-07-01 15:06:54 -07:00
|
|
|
},
|
2023-02-15 14:57:54 -08:00
|
|
|
# Use fewer cases to prevent timeouts.
|
|
|
|
backend_variant_args = {
|
|
|
|
"cpu": ["--jax_num_generated_cases=40"],
|
|
|
|
"cpu_x32": ["--jax_num_generated_cases=40"],
|
|
|
|
"cpu_no_jax_array": ["--jax_num_generated_cases=40"],
|
|
|
|
"gpu": ["--jax_num_generated_cases=40"],
|
|
|
|
},
|
2022-07-01 15:06:54 -07:00
|
|
|
shard_count = {
|
2023-01-05 04:12:49 -08:00
|
|
|
"cpu": 50,
|
|
|
|
"gpu": 50,
|
2022-11-09 20:31:32 -08:00
|
|
|
"tpu": 50,
|
2022-07-01 15:06:54 -07:00
|
|
|
"iree": 10,
|
|
|
|
},
|
2022-11-14 10:33:53 -08:00
|
|
|
tags = [
|
|
|
|
"noasan",
|
2023-01-20 10:40:34 -08:00
|
|
|
"nomsan",
|
2022-11-14 10:33:53 -08:00
|
|
|
"notsan",
|
2023-01-20 10:40:34 -08:00
|
|
|
], # Test times out under asan/msan/tsan.
|
2022-07-01 15:06:54 -07:00
|
|
|
deps = [
|
|
|
|
"//jax:experimental_sparse",
|
2022-11-16 09:58:06 -08:00
|
|
|
"//jax:sparse_test_util",
|
2022-08-05 07:48:40 -07:00
|
|
|
] + py_deps("scipy"),
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "sparsify_test",
|
|
|
|
srcs = ["sparsify_test.py"],
|
|
|
|
args = ["--jax_bcoo_cusparse_lowering=true"],
|
2023-02-15 07:34:54 -08:00
|
|
|
backend_tags = {
|
|
|
|
"tpu": [
|
|
|
|
"noasan", # Times out under asan.
|
|
|
|
],
|
|
|
|
},
|
2022-07-11 13:30:44 +00:00
|
|
|
shard_count = {
|
2022-10-27 13:11:58 -07:00
|
|
|
"cpu": 5,
|
2022-07-11 13:30:44 +00:00
|
|
|
"gpu": 20,
|
2022-07-20 04:30:50 -07:00
|
|
|
"tpu": 10,
|
2022-07-11 13:30:44 +00:00
|
|
|
},
|
2022-07-01 15:06:54 -07:00
|
|
|
deps = [
|
|
|
|
"//jax:experimental_sparse",
|
2023-02-01 16:16:14 -08:00
|
|
|
"//jax:sparse_test_util",
|
2022-07-01 15:06:54 -07:00
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "stack_test",
|
|
|
|
srcs = ["stack_test.py"],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "checkify_test",
|
|
|
|
srcs = ["checkify_test.py"],
|
2023-01-11 18:25:30 -08:00
|
|
|
shard_count = {
|
|
|
|
"gpu": 2,
|
2023-01-20 10:26:01 -08:00
|
|
|
"tpu": 2,
|
2023-01-11 18:25:30 -08:00
|
|
|
},
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "stax_test",
|
|
|
|
srcs = ["stax_test.py"],
|
|
|
|
shard_count = {
|
|
|
|
"cpu": 5,
|
|
|
|
"gpu": 5,
|
|
|
|
"iree": 5,
|
|
|
|
},
|
|
|
|
deps = ["//jax:stax"],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "linear_search_test",
|
|
|
|
srcs = ["third_party/scipy/line_search_test.py"],
|
|
|
|
main = "third_party/scipy/line_search_test.py",
|
|
|
|
)
|
|
|
|
|
|
|
|
py_test(
|
|
|
|
name = "tree_util_test",
|
|
|
|
srcs = ["tree_util_test.py"],
|
|
|
|
deps = [
|
|
|
|
"//jax",
|
|
|
|
"//jax:test_util",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2022-09-14 10:38:54 -07:00
|
|
|
pytype_test(
|
2022-09-13 12:43:51 -07:00
|
|
|
name = "typing_test",
|
|
|
|
srcs = ["typing_test.py"],
|
|
|
|
deps = [
|
|
|
|
"//jax",
|
|
|
|
"//jax:test_util",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
py_test(
|
|
|
|
name = "util_test",
|
|
|
|
srcs = ["util_test.py"],
|
|
|
|
deps = [
|
|
|
|
"//jax",
|
|
|
|
"//jax:test_util",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
py_test(
|
|
|
|
name = "version_test",
|
|
|
|
srcs = ["version_test.py"],
|
|
|
|
deps = [
|
|
|
|
"//jax",
|
|
|
|
"//jax:test_util",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
py_test(
|
|
|
|
name = "xla_bridge_test",
|
|
|
|
srcs = ["xla_bridge_test.py"],
|
2023-03-21 16:52:49 -07:00
|
|
|
data = ["testdata/example_pjrt_plugin_config.json"],
|
2022-07-01 15:06:54 -07:00
|
|
|
deps = [
|
|
|
|
"//jax",
|
|
|
|
"//jax:test_util",
|
2022-08-05 07:48:40 -07:00
|
|
|
] + py_deps("absl/logging"),
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
py_test(
|
|
|
|
name = "gfile_cache_test",
|
|
|
|
srcs = ["gfile_cache_test.py"],
|
|
|
|
deps = [
|
|
|
|
"//jax",
|
|
|
|
"//jax:compilation_cache",
|
|
|
|
"//jax:test_util",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "compilation_cache_test",
|
|
|
|
srcs = ["compilation_cache_test.py"],
|
|
|
|
backend_tags = {
|
|
|
|
"tpu": ["nomsan"], # TODO(b/213388298): this test fails msan.
|
|
|
|
},
|
|
|
|
deps = [
|
|
|
|
"//jax:compilation_cache",
|
|
|
|
"//jax:experimental",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "ode_test",
|
|
|
|
srcs = ["ode_test.py"],
|
|
|
|
shard_count = {
|
|
|
|
"cpu": 10,
|
|
|
|
},
|
|
|
|
deps = ["//jax:ode"],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "host_callback_test",
|
|
|
|
srcs = ["host_callback_test.py"],
|
|
|
|
args = ["--jax_host_callback_outfeed=true"],
|
|
|
|
deps = [
|
|
|
|
"//jax:experimental",
|
|
|
|
"//jax:experimental_host_callback",
|
|
|
|
"//jax:ode",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "host_callback_custom_call_test",
|
|
|
|
srcs = ["host_callback_test.py"],
|
|
|
|
args = ["--jax_host_callback_outfeed=false"],
|
|
|
|
disable_backends = [
|
2022-07-10 13:04:44 -07:00
|
|
|
"gpu",
|
2022-07-01 15:06:54 -07:00
|
|
|
"tpu", # On TPU we always use outfeed
|
|
|
|
],
|
|
|
|
main = "host_callback_test.py",
|
2022-07-11 13:30:44 +00:00
|
|
|
shard_count = {
|
|
|
|
"gpu": 5,
|
|
|
|
},
|
2022-07-01 15:06:54 -07:00
|
|
|
deps = [
|
|
|
|
"//jax:experimental",
|
|
|
|
"//jax:experimental_host_callback",
|
|
|
|
"//jax:ode",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "host_callback_to_tf_test",
|
|
|
|
srcs = ["host_callback_to_tf_test.py"],
|
|
|
|
deps = [
|
|
|
|
"//jax:experimental_host_callback",
|
|
|
|
"//jax:ode",
|
2023-02-08 12:11:20 -08:00
|
|
|
] + py_deps("tensorflow_core"),
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "x64_context_test",
|
|
|
|
srcs = ["x64_context_test.py"],
|
|
|
|
deps = [
|
|
|
|
"//jax:experimental",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "ann_test",
|
|
|
|
srcs = ["ann_test.py"],
|
2022-11-09 18:57:28 -08:00
|
|
|
shard_count = 10,
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
py_test(
|
|
|
|
name = "mesh_utils_test",
|
|
|
|
srcs = ["mesh_utils_test.py"],
|
|
|
|
deps = [
|
|
|
|
"//jax",
|
|
|
|
"//jax:mesh_utils",
|
|
|
|
"//jax:test_util",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "transfer_guard_test",
|
|
|
|
srcs = ["transfer_guard_test.py"],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "name_stack_test",
|
|
|
|
srcs = ["name_stack_test.py"],
|
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "jaxpr_effects_test",
|
|
|
|
srcs = ["jaxpr_effects_test.py"],
|
2022-07-06 20:52:08 -07:00
|
|
|
enable_configs = [
|
|
|
|
"gpu",
|
|
|
|
"cpu",
|
|
|
|
],
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
jax_test(
|
|
|
|
name = "debugging_primitives_test",
|
|
|
|
srcs = ["debugging_primitives_test.py"],
|
2022-07-06 20:52:08 -07:00
|
|
|
enable_configs = [
|
|
|
|
"gpu",
|
|
|
|
"cpu",
|
|
|
|
],
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
2022-08-08 11:41:46 -07:00
|
|
|
jax_test(
|
|
|
|
name = "python_callback_test",
|
|
|
|
srcs = ["python_callback_test.py"],
|
2022-11-10 12:00:21 -08:00
|
|
|
deps = [
|
|
|
|
"//jax:experimental",
|
|
|
|
],
|
2022-08-08 11:41:46 -07:00
|
|
|
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
jax_test(
|
|
|
|
name = "debugger_test",
|
|
|
|
srcs = ["debugger_test.py"],
|
2022-07-06 20:52:08 -07:00
|
|
|
enable_configs = [
|
|
|
|
"gpu",
|
|
|
|
"cpu",
|
|
|
|
],
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
2022-08-01 17:05:43 -07:00
|
|
|
jax_test(
|
|
|
|
name = "state_test",
|
|
|
|
srcs = ["state_test.py"],
|
2023-04-04 10:56:25 -07:00
|
|
|
# Use fewer cases to prevent timeouts.
|
|
|
|
args = [
|
|
|
|
"--jax_num_generated_cases=5",
|
|
|
|
],
|
|
|
|
backend_variant_args = {
|
|
|
|
"tpu_pjrt_c_api": ["--jax_num_generated_cases=1"],
|
|
|
|
},
|
2022-08-01 17:05:43 -07:00
|
|
|
enable_configs = [
|
|
|
|
"gpu",
|
|
|
|
"cpu",
|
|
|
|
],
|
2023-04-04 10:56:25 -07:00
|
|
|
shard_count = {
|
|
|
|
"cpu": 2,
|
|
|
|
"gpu": 2,
|
|
|
|
"tpu": 2,
|
|
|
|
},
|
|
|
|
deps = py_deps("hypothesis"),
|
2022-08-01 17:05:43 -07:00
|
|
|
)
|
|
|
|
|
2022-09-12 14:40:11 -07:00
|
|
|
jax_test(
|
|
|
|
name = "for_loop_test",
|
|
|
|
srcs = ["for_loop_test.py"],
|
|
|
|
shard_count = {
|
2022-10-27 11:41:34 -07:00
|
|
|
"cpu": 20,
|
2022-09-13 12:33:40 -07:00
|
|
|
"gpu": 10,
|
|
|
|
"tpu": 10,
|
2022-09-12 14:40:11 -07:00
|
|
|
},
|
|
|
|
)
|
|
|
|
|
2022-11-04 15:29:10 -07:00
|
|
|
jax_test(
|
|
|
|
name = "shard_map_test",
|
|
|
|
srcs = ["shard_map_test.py"],
|
2023-03-08 12:56:54 -08:00
|
|
|
shard_count = {
|
2023-03-09 10:12:44 -08:00
|
|
|
"cpu": 20,
|
2023-03-08 12:56:54 -08:00
|
|
|
"gpu": 10,
|
|
|
|
"tpu": 10,
|
|
|
|
},
|
2023-03-10 14:51:08 -08:00
|
|
|
deps = [
|
|
|
|
"//jax:tree_util",
|
|
|
|
],
|
2022-11-04 15:29:10 -07:00
|
|
|
)
|
|
|
|
|
2022-07-20 15:09:47 -07:00
|
|
|
jax_test(
|
|
|
|
name = "clear_backends_test",
|
|
|
|
srcs = ["clear_backends_test.py"],
|
|
|
|
)
|
|
|
|
|
2022-11-28 14:31:10 -08:00
|
|
|
jax_test(
|
|
|
|
name = "experimental_rnn_test",
|
|
|
|
srcs = ["experimental_rnn_test.py"],
|
|
|
|
disable_backends = [
|
|
|
|
"tpu",
|
|
|
|
"cpu",
|
|
|
|
],
|
2023-02-14 12:01:35 -08:00
|
|
|
disable_configs = [
|
|
|
|
"gpu_a100", # Numerical precision problems.
|
|
|
|
],
|
2023-04-05 15:18:43 -07:00
|
|
|
shard_count = 8,
|
2022-11-28 14:31:10 -08:00
|
|
|
deps = [
|
|
|
|
"//jax:rnn",
|
|
|
|
],
|
|
|
|
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
exports_files(
|
|
|
|
[
|
|
|
|
"api_test.py",
|
|
|
|
"pmap_test.py",
|
2022-07-12 05:20:54 -07:00
|
|
|
"compilation_cache_test.py",
|
2022-07-01 15:06:54 -07:00
|
|
|
"pjit_test.py",
|
2022-08-16 17:42:31 -07:00
|
|
|
"python_callback_test.py",
|
2022-07-01 15:06:54 -07:00
|
|
|
"transfer_guard_test.py",
|
|
|
|
],
|
|
|
|
visibility = jax_test_file_visibility,
|
|
|
|
)
|
|
|
|
|
|
|
|
# This filegroup specifies the set of tests known to Bazel, used for a test that
|
|
|
|
# verifies every test has a Bazel test rule.
|
|
|
|
# If a test isn't meant to be tested with Bazel, add it to the exclude list.
|
|
|
|
filegroup(
|
|
|
|
name = "all_tests",
|
|
|
|
srcs = glob(
|
|
|
|
include = [
|
|
|
|
"*_test.py",
|
|
|
|
"third_party/*/*_test.py",
|
|
|
|
],
|
|
|
|
exclude = [],
|
|
|
|
) + ["BUILD"],
|
|
|
|
visibility = [
|
|
|
|
"//third_party/py/jax:__subpackages__",
|
|
|
|
],
|
|
|
|
)
|