2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2022-07-01 15:06:54 -07:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2024-03-27 10:26:38 -07:00
|
|
|
load("@rules_python//python:defs.bzl", "py_test")
|
2022-07-01 15:06:54 -07:00
|
|
|
load(
|
|
|
|
"//jaxlib:jax.bzl",
|
|
|
|
"jax_generate_backend_suites",
|
|
|
|
"jax_test",
|
|
|
|
"jax_test_file_visibility",
|
2022-08-05 07:48:40 -07:00
|
|
|
"py_deps",
|
2022-09-14 10:38:54 -07:00
|
|
|
"pytype_test",
|
2022-07-01 15:06:54 -07:00
|
|
|
)
|
|
|
|
|
2022-08-24 08:16:16 -07:00
|
|
|
# License category for this package; the file header declares Apache 2.0,
# which Bazel classifies as "notice".
licenses(["notice"])

# Package-wide defaults: no default applicable-license targets, and every
# target is private unless it sets its own `visibility`.
package(
    default_applicable_licenses = [],
    default_visibility = ["//visibility:private"],
)

# Macro loaded from //jaxlib:jax.bzl. NOTE(review): presumably declares the
# per-backend test suites that the jax_test() targets below are collected
# into; confirm against jax.bzl.
jax_generate_backend_suites()
|
|
|
|
|
|
|
|
# Runs api_test.py. An integer shard_count (rather than the per-backend dicts
# used elsewhere in this file) presumably applies to every backend.
# NOTE(review): "test_cpu_thunks" looks like an opt-in to an extra CPU-thunks
# variant generated by jax_test; confirm in //jaxlib:jax.bzl.
jax_test(
    name = "api_test",
    srcs = ["api_test.py"],
    shard_count = 10,
    tags = ["test_cpu_thunks"],
)
|
|
|
|
|
2024-07-11 06:53:32 -07:00
|
|
|
# Runs device_test.py with jax_test defaults (no sharding, tags or extra deps).
jax_test(
    name = "device_test",
    srcs = ["device_test.py"],
)
|
|
|
|
|
2022-07-14 20:18:14 -07:00
|
|
|
# Runs dynamic_api_test.py split into 2 shards.
jax_test(
    name = "dynamic_api_test",
    srcs = ["dynamic_api_test.py"],
    shard_count = 2,
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Runs api_util_test.py with jax_test defaults.
jax_test(
    name = "api_util_test",
    srcs = ["api_util_test.py"],
)
|
|
|
|
|
2023-11-27 15:20:41 -08:00
|
|
|
# Plain py_test (not a per-backend jax_test) for array_api_test.py, exercising
# //jax:experimental_array_api. Core deps are listed explicitly here;
# NOTE(review): jax_test targets omit "//jax" and absl, so the macro presumably
# adds them implicitly.
py_test(
    name = "array_api_test",
    srcs = ["array_api_test.py"],
    deps = [
        "//jax",
        "//jax:experimental_array_api",
        "//jax:test_util",
    ] + py_deps("absl/testing"),
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Runs array_interoperability_test.py everywhere except TPU; needs TensorFlow
# (tensorflow_core) at test time.
# NOTE(review): "multiaccelerator" presumably requests a multi-device machine.
jax_test(
    name = "array_interoperability_test",
    srcs = ["array_interoperability_test.py"],
    disable_backends = ["tpu"],
    tags = ["multiaccelerator"],
    deps = py_deps("tensorflow_core"),
)
|
|
|
|
|
|
|
|
# Runs batching_test.py. Only the GPU variant has a shard override (5 shards);
# other backends use the jax_test default.
jax_test(
    name = "batching_test",
    srcs = ["batching_test.py"],
    shard_count = {
        "gpu": 5,
    },
)
|
|
|
|
|
2024-06-25 09:02:32 -07:00
|
|
|
# Runs config_test.py with jax_test defaults.
jax_test(
    name = "config_test",
    srcs = ["config_test.py"],
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Runs core_test.py with per-backend sharding (CPU 5, GPU 10); TPU uses the
# jax_test default.
jax_test(
    name = "core_test",
    srcs = ["core_test.py"],
    shard_count = {
        "cpu": 5,
        "gpu": 10,
    },
)
|
|
|
|
|
|
|
|
# Runs custom_object_test.py with jax_test defaults.
jax_test(
    name = "custom_object_test",
    srcs = ["custom_object_test.py"],
)
|
|
|
|
|
2022-09-30 14:20:57 -07:00
|
|
|
# Runs debug_nans_test.py with jax_test defaults.
jax_test(
    name = "debug_nans_test",
    srcs = ["debug_nans_test.py"],
)
|
|
|
|
|
2022-08-04 12:52:16 -07:00
|
|
|
# Plain py_test for multiprocess_gpu_test.py. The "manual" tag keeps it out of
# wildcard target patterns (e.g. `:all`, `...`); it runs only when named
# explicitly.
py_test(
    name = "multiprocess_gpu_test",
    srcs = ["multiprocess_gpu_test.py"],
    # NOTE(review): presumably makes the JAX test harness skip the
    # MultiProcessGpuTest class; confirm the flag's semantics.
    args = [
        "--exclude_test_targets=MultiProcessGpuTest",
    ],
    tags = ["manual"],
    deps = [
        "//jax",
        "//jax:test_util",
    ] + py_deps("portpicker"),
)
|
|
|
|
|
|
|
|
# Runs dtypes_test.py with jax_test defaults.
jax_test(
    name = "dtypes_test",
    srcs = ["dtypes_test.py"],
)
|
|
|
|
|
2023-01-17 18:42:21 -08:00
|
|
|
# Runs errors_test.py in the "cpu" config only.
jax_test(
    name = "errors_test",
    srcs = ["errors_test.py"],
    # No need to test all other configs.
    enable_configs = [
        "cpu",
    ],
)
|
|
|
|
|
2023-08-24 14:40:10 -07:00
|
|
|
# Runs extend_test.py against the //jax:extend library.
jax_test(
    name = "extend_test",
    srcs = ["extend_test.py"],
    deps = ["//jax:extend"],
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Runs fft_test.py with per-backend sharding. TPU is excluded from ASan/TSan
# runs because it times out there.
# shard_count keys are ordered cpu/gpu/tpu to match every other target in this
# file (and buildifier's unsorted-dict-items lint); the values are unchanged.
jax_test(
    name = "fft_test",
    srcs = ["fft_test.py"],
    backend_tags = {
        "tpu": [
            "noasan",
            "notsan",
        ],  # Times out on TPU with asan/tsan.
    },
    shard_count = {
        "cpu": 20,
        "gpu": 10,
        "tpu": 20,
    },
)
|
|
|
|
|
|
|
|
# Runs generated_fun_test.py with jax_test defaults.
jax_test(
    name = "generated_fun_test",
    srcs = ["generated_fun_test.py"],
)
|
|
|
|
|
2023-12-19 13:05:26 -08:00
|
|
|
# GPU-only run of gpu_memory_flags_test.py with XLA client memory
# preallocation disabled (XLA_PYTHON_CLIENT_PREALLOCATE=0). The companion
# target gpu_memory_flags_test runs the same source with preallocation on.
jax_test(
    name = "gpu_memory_flags_test_no_preallocation",
    srcs = ["gpu_memory_flags_test.py"],
    disable_backends = [
        "cpu",
        "tpu",
    ],
    env = {
        "XLA_PYTHON_CLIENT_PREALLOCATE": "0",
    },
    # Needed because the target name does not match the source file name.
    main = "gpu_memory_flags_test.py",
)
|
|
|
|
|
|
|
|
# GPU-only run of gpu_memory_flags_test.py with XLA client memory
# preallocation enabled (XLA_PYTHON_CLIENT_PREALLOCATE=1). The *_no_preallocation
# target runs the same source with preallocation off.
jax_test(
    name = "gpu_memory_flags_test",
    srcs = ["gpu_memory_flags_test.py"],
    disable_backends = [
        "cpu",
        "tpu",
    ],
    env = {
        "XLA_PYTHON_CLIENT_PREALLOCATE": "1",
    },
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Runs lobpcg_test.py with 48 shards on every backend.
# NOTE(review): LOBPCG_EMIT_DEBUG_PLOTS=1 presumably makes the test write
# debug plots, which would explain the matplotlib dependency; confirm in the
# test source.
jax_test(
    name = "lobpcg_test",
    srcs = ["lobpcg_test.py"],
    env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"},
    shard_count = {
        "cpu": 48,
        "gpu": 48,
        "tpu": 48,
    },
    deps = [
        "//jax:experimental_sparse",
    ] + py_deps("matplotlib"),
)
|
|
|
|
|
|
|
|
# Runs svd_test.py; TPU gets 4x the shards of CPU/GPU.
jax_test(
    name = "svd_test",
    srcs = ["svd_test.py"],
    shard_count = {
        "cpu": 10,
        "gpu": 10,
        "tpu": 40,
    },
)
|
|
|
|
|
|
|
|
# Plain py_test (not per-backend) for xla_interpreter_test.py.
py_test(
    name = "xla_interpreter_test",
    srcs = ["xla_interpreter_test.py"],
    deps = [
        "//jax",
        "//jax:test_util",
    ],
)
|
|
|
|
|
2023-09-11 17:41:18 -07:00
|
|
|
# Runs memories_test.py; only TPU has a shard override (5 shards).
jax_test(
    name = "memories_test",
    srcs = ["memories_test.py"],
    shard_count = {
        "tpu": 5,
    },
    deps = [
        "//jax:experimental",
    ],
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Runs pjit_test.py with 5 shards per backend.
# NOTE(review): "multiaccelerator" presumably requests a multi-device machine.
# backend_tags keys are ordered gpu/tpu to match the sorted-key convention used
# by every other target in this file; the tag values are unchanged.
jax_test(
    name = "pjit_test",
    srcs = ["pjit_test.py"],
    backend_tags = {
        "gpu": ["noasan"],  # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
        "tpu": ["notsan"],  # Times out under tsan.
    },
    shard_count = {
        "cpu": 5,
        "gpu": 5,
        "tpu": 5,
    },
    tags = ["multiaccelerator"],
    deps = [
        "//jax:experimental",
    ],
)
|
|
|
|
|
2023-11-15 08:48:17 -08:00
|
|
|
# Runs layout_test.py; the TPU variant asks for 16G of memory.
jax_test(
    name = "layout_test",
    srcs = ["layout_test.py"],
    backend_tags = {
        "tpu": ["requires-mem:16g"],  # Under tsan on 2x2 this test exceeds the default 12G memory limit.
    },
    tags = ["multiaccelerator"],
)
|
|
|
|
|
2023-12-19 16:30:48 -08:00
|
|
|
# Runs shard_alike_test.py against //jax:experimental.
jax_test(
    name = "shard_alike_test",
    srcs = ["shard_alike_test.py"],
    deps = [
        "//jax:experimental",
    ],
)
|
|
|
|
|
2023-07-18 14:17:56 -07:00
|
|
|
# GPU-only run of pgle_test.py (CPU/TPU disabled, CUDA-only config). XLA is
# configured to dump its output and enable the GPU latency-hiding scheduler.
# NOTE(review): "sponge" as the dump target looks like an internal CI artifact
# location; confirm before reusing outside Google's infrastructure.
jax_test(
    name = "pgle_test",
    srcs = ["pgle_test.py"],
    backend_tags = {
        "gpu": ["noasan"],  # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
    },
    disable_backends = [
        "cpu",
        "tpu",
    ],
    env = {"XLA_FLAGS": "--xla_dump_to=sponge --xla_gpu_enable_latency_hiding_scheduler=true"},
    tags = [
        "config-cuda-only",
        "multiaccelerator",
    ],
    deps = [
        "//jax:experimental",
    ],
)
|
|
|
|
|
2023-10-02 13:04:03 -07:00
|
|
|
# GPU-only run of mock_gpu_test.py (CPU/TPU disabled, CUDA-only config).
jax_test(
    name = "mock_gpu_test",
    srcs = ["mock_gpu_test.py"],
    disable_backends = [
        "cpu",
        "tpu",
    ],
    tags = [
        "config-cuda-only",
    ],
    deps = [
        "//jax:experimental",
    ],
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Runs array_test.py; the TPU variant asks for 16G of memory.
jax_test(
    name = "array_test",
    srcs = ["array_test.py"],
    backend_tags = {
        "tpu": ["requires-mem:16g"],  # Under tsan on 2x2 this test exceeds the default 12G memory limit.
    },
    tags = ["multiaccelerator"],
    deps = [
        "//jax:experimental",
        "//jax:internal_test_util",
    ],
)
|
|
|
|
|
2023-04-03 14:47:14 -07:00
|
|
|
# Runs aot_test.py (multi-accelerator) with //jax:experimental and numpy.
jax_test(
    name = "aot_test",
    srcs = ["aot_test.py"],
    tags = ["multiaccelerator"],
    deps = [
        "//jax:experimental",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Runs image_test.py. Depends on PIL and TensorFlow, and is excluded from ASan
# builds because linking TensorFlow there runs the linker out of memory.
jax_test(
    name = "image_test",
    srcs = ["image_test.py"],
    shard_count = {
        "cpu": 10,
        "gpu": 20,
        "tpu": 10,
    },
    tags = ["noasan"],  # Linking TF causes a linker OOM.
    deps = py_deps("pil") + py_deps("tensorflow_core"),
)
|
|
|
|
|
|
|
|
# Runs infeed_test.py against //jax:experimental_host_callback.
# NOTE(review): "test_cpu_thunks" presumably opts into an extra CPU-thunks
# variant generated by jax_test; confirm in //jaxlib:jax.bzl.
jax_test(
    name = "infeed_test",
    srcs = ["infeed_test.py"],
    tags = ["test_cpu_thunks"],
    deps = [
        "//jax:experimental_host_callback",
    ],
)
|
|
|
|
|
2022-09-26 14:38:06 -07:00
|
|
|
# Runs jax_jit_test.py. `main` is set explicitly even though it matches the
# target name; NOTE(review): possibly required because jax_test derives
# backend-suffixed target names; confirm before removing.
jax_test(
    name = "jax_jit_test",
    srcs = ["jax_jit_test.py"],
    main = "jax_jit_test.py",
    tags = ["test_cpu_thunks"],
)
|
|
|
|
|
|
|
|
# Plain py_test for jax_to_ir_test.py, covering //jax/tools:jax_to_ir; needs
# jax2tf and TensorFlow.
py_test(
    name = "jax_to_ir_test",
    srcs = ["jax_to_ir_test.py"],
    deps = [
        "//jax:test_util",
        "//jax/experimental/jax2tf",
        "//jax/tools:jax_to_ir",
    ] + py_deps("tensorflow_core"),
)
|
|
|
|
|
2023-08-08 10:08:19 -07:00
|
|
|
# Plain py_test for jaxpr_util_test.py, covering //jax:jaxpr_util.
py_test(
    name = "jaxpr_util_test",
    srcs = ["jaxpr_util_test.py"],
    deps = [
        "//jax",
        "//jax:jaxpr_util",
        "//jax:test_util",
    ],
)
|
|
|
|
|
|
|
|
# Runs jet_test.py (CPU/GPU 10 shards each) with the jet and stax libraries.
jax_test(
    name = "jet_test",
    srcs = ["jet_test.py"],
    shard_count = {
        "cpu": 10,
        "gpu": 10,
    },
    deps = [
        "//jax:jet",
        "//jax:stax",
    ],
)
|
|
|
|
|
|
|
|
# Runs lax_control_flow_test.py; GPU gets the most shards (40).
jax_test(
    name = "lax_control_flow_test",
    srcs = ["lax_control_flow_test.py"],
    shard_count = {
        "cpu": 30,
        "gpu": 40,
        "tpu": 30,
    },
)
|
|
|
|
|
|
|
|
# Runs custom_root_test.py with 10 shards per backend.
jax_test(
    name = "custom_root_test",
    srcs = ["custom_root_test.py"],
    shard_count = {
        "cpu": 10,
        "gpu": 10,
        "tpu": 10,
    },
)
|
|
|
|
|
|
|
|
# Runs custom_linear_solve_test.py with 10 shards per backend.
jax_test(
    name = "custom_linear_solve_test",
    srcs = ["custom_linear_solve_test.py"],
    shard_count = {
        "cpu": 10,
        "gpu": 10,
        "tpu": 10,
    },
)
|
|
|
|
|
|
|
|
# Runs lax_numpy_test.py. Excluded from ASan everywhere and from TSan on CPU
# because of timeouts.
# NOTE(review): "test_cpu_thunks" presumably opts into an extra CPU-thunks
# variant generated by jax_test.
jax_test(
    name = "lax_numpy_test",
    srcs = ["lax_numpy_test.py"],
    backend_tags = {
        "cpu": ["notsan"],  # Test times out.
    },
    shard_count = {
        "cpu": 40,
        "gpu": 40,
        "tpu": 50,
    },
    tags = [
        "noasan",  # Test times out on all backends
        "test_cpu_thunks",
    ],
)
|
|
|
|
|
|
|
|
# Runs lax_numpy_operators_test.py; TPU gets the most shards (40).
# NOTE(review): "test_cpu_thunks" presumably opts into a CPU-thunks variant.
jax_test(
    name = "lax_numpy_operators_test",
    srcs = ["lax_numpy_operators_test.py"],
    shard_count = {
        "cpu": 30,
        "gpu": 30,
        "tpu": 40,
    },
    tags = ["test_cpu_thunks"],
)
|
|
|
|
|
|
|
|
# Runs lax_numpy_reducers_test.py with 20 shards per backend.
# NOTE(review): "test_cpu_thunks" presumably opts into a CPU-thunks variant.
jax_test(
    name = "lax_numpy_reducers_test",
    srcs = ["lax_numpy_reducers_test.py"],
    shard_count = {
        "cpu": 20,
        "gpu": 20,
        "tpu": 20,
    },
    tags = ["test_cpu_thunks"],
)
|
|
|
|
|
|
|
|
# Runs lax_numpy_indexing_test.py with 10 shards per backend.
jax_test(
    name = "lax_numpy_indexing_test",
    srcs = ["lax_numpy_indexing_test.py"],
    shard_count = {
        "cpu": 10,
        "gpu": 10,
        "tpu": 10,
    },
)
|
|
|
|
|
|
|
|
# Runs lax_numpy_einsum_test.py with 10 shards per backend.
# NOTE(review): "test_cpu_thunks" presumably opts into a CPU-thunks variant.
jax_test(
    name = "lax_numpy_einsum_test",
    srcs = ["lax_numpy_einsum_test.py"],
    shard_count = {
        "cpu": 10,
        "gpu": 10,
        "tpu": 10,
    },
    tags = ["test_cpu_thunks"],
)
|
|
|
|
|
2023-08-10 14:58:18 -07:00
|
|
|
# Runs lax_numpy_ufuncs_test.py with 10 shards per backend.
# NOTE(review): "test_cpu_thunks" presumably opts into a CPU-thunks variant.
jax_test(
    name = "lax_numpy_ufuncs_test",
    srcs = ["lax_numpy_ufuncs_test.py"],
    shard_count = {
        "cpu": 10,
        "gpu": 10,
        "tpu": 10,
    },
    tags = ["test_cpu_thunks"],
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Runs lax_numpy_vectorize_test.py (unsharded).
# NOTE(review): "test_cpu_thunks" presumably opts into a CPU-thunks variant.
jax_test(
    name = "lax_numpy_vectorize_test",
    srcs = ["lax_numpy_vectorize_test.py"],
    tags = ["test_cpu_thunks"],
)
|
|
|
|
|
|
|
|
# Runs lax_scipy_test.py with 20 shards per backend; needs numpy, scipy and
# absl testing.
jax_test(
    name = "lax_scipy_test",
    srcs = ["lax_scipy_test.py"],
    shard_count = {
        "cpu": 20,
        "gpu": 20,
        "tpu": 20,
    },
    deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"),
)
|
|
|
|
|
|
|
|
# Runs lax_scipy_sparse_test.py with 10 shards per backend. The CPU variant is
# excluded from MSan runs.
jax_test(
    name = "lax_scipy_sparse_test",
    srcs = ["lax_scipy_sparse_test.py"],
    backend_tags = {
        "cpu": ["nomsan"],  # Test fails under msan because of Fortran code inside scipy.
    },
    shard_count = {
        "cpu": 10,
        "gpu": 10,
        "tpu": 10,
    },
)
|
|
|
|
|
2023-03-06 09:49:52 -08:00
|
|
|
# Runs lax_scipy_special_functions_test.py with 20 shards per backend. The GPU
# variant is excluded from ASan runs because it times out.
jax_test(
    name = "lax_scipy_special_functions_test",
    srcs = ["lax_scipy_special_functions_test.py"],
    backend_tags = {
        "gpu": ["noasan"],  # Times out.
    },
    shard_count = {
        "cpu": 20,
        "gpu": 20,
        "tpu": 20,
    },
    deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"),
)
|
|
|
|
|
|
|
|
# Runs lax_scipy_spectral_dac_test.py with 40 shards per backend.
jax_test(
    name = "lax_scipy_spectral_dac_test",
    srcs = ["lax_scipy_spectral_dac_test.py"],
    shard_count = {
        "cpu": 40,
        "gpu": 40,
        "tpu": 40,
    },
    deps = [
        "//jax:internal_test_util",
    ] + py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"),
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Runs lax_test.py with 40 shards per backend, checking against
# //jax:lax_reference. TPU is excluded from ASan runs because it times out.
# NOTE(review): "test_cpu_thunks" presumably opts into a CPU-thunks variant.
jax_test(
    name = "lax_test",
    srcs = ["lax_test.py"],
    backend_tags = {
        "tpu": ["noasan"],  # Times out.
    },
    shard_count = {
        "cpu": 40,
        "gpu": 40,
        "tpu": 40,
    },
    tags = ["test_cpu_thunks"],
    deps = [
        "//jax:internal_test_util",
        "//jax:lax_reference",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2024-03-12 17:17:20 -07:00
|
|
|
# lax_metal_test.py: all standard backends (cpu/gpu/tpu) are disabled.
# NOTE(review): presumably this only runs on an Apple Metal setup outside this
# build's backend suites, and "notap" presumably keeps it out of Google's TAP
# CI; confirm both.
jax_test(
    name = "lax_metal_test",
    srcs = ["lax_metal_test.py"],
    disable_backends = [
        "cpu",
        "gpu",
        "tpu",
    ],
    tags = ["notap"],
    deps = [
        "//jax:internal_test_util",
        "//jax:lax_reference",
    ] + py_deps("numpy"),
)
|
|
|
|
|
|
|
|
# Runs lax_autodiff_test.py; TPU gets half the shards of CPU/GPU.
# NOTE(review): "test_cpu_thunks" presumably opts into a CPU-thunks variant.
jax_test(
    name = "lax_autodiff_test",
    srcs = ["lax_autodiff_test.py"],
    shard_count = {
        "cpu": 40,
        "gpu": 40,
        "tpu": 20,
    },
    tags = ["test_cpu_thunks"],
)
|
|
|
|
|
|
|
|
# Runs lax_vmap_test.py with 40 shards per backend.
# NOTE(review): "test_cpu_thunks" presumably opts into a CPU-thunks variant.
jax_test(
    name = "lax_vmap_test",
    srcs = ["lax_vmap_test.py"],
    shard_count = {
        "cpu": 40,
        "gpu": 40,
        "tpu": 40,
    },
    tags = ["test_cpu_thunks"],
    deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"),
)
|
|
|
|
|
|
|
|
# Runs lax_vmap_op_test.py with 40 shards per backend.
# NOTE(review): "test_cpu_thunks" presumably opts into a CPU-thunks variant.
jax_test(
    name = "lax_vmap_op_test",
    srcs = ["lax_vmap_op_test.py"],
    shard_count = {
        "cpu": 40,
        "gpu": 40,
        "tpu": 40,
    },
    tags = ["test_cpu_thunks"],
    deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"),
)
|
|
|
|
|
2022-10-06 01:43:54 +00:00
|
|
|
# Plain py_test (not per-backend) for lazy_loader_test.py.
py_test(
    name = "lazy_loader_test",
    srcs = [
        "lazy_loader_test.py",
    ],
    deps = [
        "//jax:internal_test_util",
        "//jax:test_util",
    ],
)
|
|
|
|
|
2023-02-09 05:47:59 -08:00
|
|
|
# Plain py_test (not per-backend) for deprecation_test.py.
py_test(
    name = "deprecation_test",
    srcs = [
        "deprecation_test.py",
    ],
    deps = [
        "//jax:internal_test_util",
        "//jax:test_util",
    ],
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Runs linalg_test.py with 40 shards per backend. The TPU variant reserves 8
# CPUs ("cpu:8") and skips sanitizer and debug builds, all of which time out.
# NOTE(review): "test_cpu_thunks" presumably opts into a CPU-thunks variant.
jax_test(
    name = "linalg_test",
    srcs = ["linalg_test.py"],
    backend_tags = {
        "tpu": [
            "cpu:8",
            "noasan",  # Times out.
            "nomsan",  # Times out.
            "nodebug",  # Times out.
            "notsan",  # Times out.
        ],
    },
    shard_count = {
        "cpu": 40,
        "gpu": 40,
        "tpu": 40,
    },
    tags = ["test_cpu_thunks"],
)
|
|
|
|
|
2024-05-14 15:20:37 -07:00
|
|
|
# Runs cholesky_update_test.py with jax_test defaults.
jax_test(
    name = "cholesky_update_test",
    srcs = ["cholesky_update_test.py"],
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Runs metadata_test.py on CPU only (GPU and TPU disabled).
jax_test(
    name = "metadata_test",
    srcs = ["metadata_test.py"],
    disable_backends = [
        "gpu",
        "tpu",
    ],
)
|
|
|
|
|
2022-11-15 12:41:08 -08:00
|
|
|
# Plain py_test (not per-backend) for monitoring_test.py.
py_test(
    name = "monitoring_test",
    srcs = ["monitoring_test.py"],
    deps = [
        "//jax",
        "//jax:test_util",
    ],
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Runs multibackend_test.py with jax_test defaults.
jax_test(
    name = "multibackend_test",
    srcs = ["multibackend_test.py"],
)
|
|
|
|
|
|
|
|
# Runs multi_device_test.py on CPU only (GPU and TPU disabled).
jax_test(
    name = "multi_device_test",
    srcs = ["multi_device_test.py"],
    disable_backends = [
        "gpu",
        "tpu",
    ],
)
|
|
|
|
|
|
|
|
# Runs nn_test.py with 10 shards per backend.
# shard_count keys are ordered cpu/gpu/tpu to match every other target in this
# file (and buildifier's unsorted-dict-items lint); the values are unchanged.
jax_test(
    name = "nn_test",
    srcs = ["nn_test.py"],
    shard_count = {
        "cpu": 10,
        "gpu": 10,
        "tpu": 10,
    },
)
|
|
|
|
|
|
|
|
# Runs optimizers_test.py against //jax:optimizers.
jax_test(
    name = "optimizers_test",
    srcs = ["optimizers_test.py"],
    deps = ["//jax:optimizers"],
)
|
|
|
|
|
|
|
|
# Runs pickle_test.py; needs cloudpickle and numpy in addition to
# //jax:experimental.
jax_test(
    name = "pickle_test",
    srcs = ["pickle_test.py"],
    deps = [
        "//jax:experimental",
    ] + py_deps("cloudpickle") + py_deps("numpy"),
)
|
|
|
|
|
|
|
|
# Runs pmap_test.py with 30 shards per backend. The TPU variant skips ASan runs
# and requests 16G of memory.
# NOTE(review): "multiaccelerator" presumably requests a multi-device machine.
jax_test(
    name = "pmap_test",
    srcs = ["pmap_test.py"],
    backend_tags = {
        "tpu": [
            "noasan",  # Times out under asan.
            "requires-mem:16g",  # Under tsan on 2x2 this test exceeds the default 12G memory limit.
        ],
    },
    shard_count = {
        "cpu": 30,
        "gpu": 30,
        "tpu": 30,
    },
    tags = ["multiaccelerator"],
    deps = [
        "//jax:internal_test_util",
    ],
)
|
|
|
|
|
|
|
|
# Runs polynomial_test.py on CPU only, with 10 shards.
jax_test(
    name = "polynomial_test",
    srcs = ["polynomial_test.py"],
    # GPU and TPU are disabled: they have no implementation of nonsymmetric
    # eigendecomposition.
    disable_backends = [
        "gpu",
        "tpu",
    ],
    shard_count = {
        "cpu": 10,
    },
    # This test ends up calling Fortran code that initializes some memory and
    # passes it to C code. MSan is not able to detect that the memory was
    # initialized by Fortran, and it makes the test fail. This can usually be
    # fixed by annotating the memory with `ANNOTATE_MEMORY_IS_INITIALIZED`, but
    # in this case there's not a good place to do it, see b/197635968#comment19
    # for details.
    tags = ["nomsan"],
)
|
|
|
|
|
|
|
|
# Runs heap_profiler_test.py on CPU only (GPU and TPU disabled).
jax_test(
    name = "heap_profiler_test",
    srcs = ["heap_profiler_test.py"],
    disable_backends = [
        "gpu",
        "tpu",
    ],
)
|
|
|
|
|
|
|
|
# Runs profiler_test.py on CPU only (GPU and TPU disabled).
jax_test(
    name = "profiler_test",
    srcs = ["profiler_test.py"],
    disable_backends = [
        "gpu",
        "tpu",
    ],
)
|
|
|
|
|
2023-02-10 12:10:19 -08:00
|
|
|
# Runs pytorch_interoperability_test.py (needs PyTorch). TPU is disabled, and
# several GPU configs are disabled because PyTorch builds time out there.
# NOTE(review): "not_build:arm" presumably excludes ARM builds; confirm.
jax_test(
    name = "pytorch_interoperability_test",
    srcs = ["pytorch_interoperability_test.py"],
    disable_backends = ["tpu"],
    # The following cases are disabled because they time out in Google's CI, mostly because the
    # CUDA kernels in Torch take a very long time to compile.
    disable_configs = [
        "gpu_p100",  # Pytorch P100 build times out in Google's CI.
        "gpu_a100",  # Pytorch A100 build times out in Google's CI.
        "gpu_h100",  # Pytorch H100 build times out in Google's CI.
    ],
    tags = [
        # PyTorch leaks dlpack metadata https://github.com/pytorch/pytorch/issues/117058, and
        # compilation times out on CPU.
        "noasan",
        "not_build:arm",
    ],
    deps = py_deps("torch"),
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Runs qdwh_test.py with 10 shards. The TPU variant skips all sanitizer builds
# because they time out.
jax_test(
    name = "qdwh_test",
    srcs = ["qdwh_test.py"],
    backend_tags = {
        "tpu": [
            "noasan",  # Times out
            "nomsan",  # Times out
            "notsan",  # Times out
        ],
    },
    shard_count = 10,
)
|
|
|
|
|
|
|
|
# Runs random_test.py. Excluded from ASan everywhere, and from MSan/TSan on CPU
# and TPU, because of timeouts. The random_test_with_custom_prng target reruns
# this same file with the custom PRNG flag enabled.
# NOTE(review): "optonly" presumably restricts the TPU variant to optimized
# builds.
jax_test(
    name = "random_test",
    srcs = ["random_test.py"],
    backend_tags = {
        "cpu": [
            "notsan",  # Times out
            "nomsan",  # Times out
        ],
        "tpu": [
            "optonly",
            "nomsan",  # Times out
            "notsan",  # Times out
        ],
    },
    shard_count = {
        "cpu": 30,
        "gpu": 30,
        "tpu": 40,
    },
    tags = ["noasan"],  # Times out
)
|
|
|
|
|
|
|
|
# Runs random_lax_test.py with 40 shards per backend. The GPU variant limits
# the number of generated test cases to 40. Sanitizer exclusions mirror
# random_test (timeouts).
# NOTE(review): "optonly" presumably restricts the TPU variant to optimized
# builds.
jax_test(
    name = "random_lax_test",
    srcs = ["random_lax_test.py"],
    backend_tags = {
        "cpu": [
            "notsan",  # Times out
            "nomsan",  # Times out
        ],
        "tpu": [
            "optonly",
            "nomsan",  # Times out
            "notsan",  # Times out
        ],
    },
    backend_variant_args = {
        "gpu": ["--jax_num_generated_cases=40"],
    },
    shard_count = {
        "cpu": 40,
        "gpu": 40,
        "tpu": 40,
    },
    tags = ["noasan"],  # Times out
)
|
|
|
|
|
|
|
|
# Reruns random_test.py with --jax_enable_custom_prng=true. `main` is required
# because the target name differs from the source file name.
# TODO(b/199564969): remove once we always enable_custom_prng
jax_test(
    name = "random_test_with_custom_prng",
    srcs = ["random_test.py"],
    args = ["--jax_enable_custom_prng=true"],
    backend_tags = {
        "cpu": [
            "noasan",  # Times out under asan/msan/tsan.
            "nomsan",
            "notsan",
        ],
        "tpu": [
            "noasan",  # Times out under asan/msan/tsan.
            "nomsan",
            "notsan",
            "optonly",
        ],
    },
    main = "random_test.py",
    shard_count = {
        "cpu": 40,
        "gpu": 40,
        "tpu": 40,
    },
)
|
|
|
|
|
|
|
|
# Runs scipy_fft_test.py with 4 shards. The TPU variant skips all sanitizer
# builds because they time out.
jax_test(
    name = "scipy_fft_test",
    srcs = ["scipy_fft_test.py"],
    backend_tags = {
        "tpu": [
            "noasan",
            "notsan",
            "nomsan",
        ],  # Times out on TPU with asan/tsan/msan.
    },
    shard_count = 4,
)
|
|
|
|
|
|
|
|
# Runs scipy_interpolate_test.py with jax_test defaults.
jax_test(
    name = "scipy_interpolate_test",
    srcs = ["scipy_interpolate_test.py"],
)
|
|
|
|
|
|
|
|
# Runs scipy_ndimage_test.py with jax_test defaults.
jax_test(
    name = "scipy_ndimage_test",
    srcs = ["scipy_ndimage_test.py"],
)
|
|
|
|
|
|
|
|
# Runs scipy_optimize_test.py with jax_test defaults.
jax_test(
    name = "scipy_optimize_test",
    srcs = ["scipy_optimize_test.py"],
)
|
|
|
|
|
|
|
|
# Runs scipy_signal_test.py. Sanitizer runs are trimmed because of timeouts,
# and the gpu_h100 config is disabled pending a numerical-failure fix.
jax_test(
    name = "scipy_signal_test",
    srcs = ["scipy_signal_test.py"],
    backend_tags = {
        "cpu": [
            "noasan",  # Test times out under asan.
        ],
        # TPU test times out under asan/msan/tsan (b/260710050)
        "tpu": [
            "noasan",
            "nomsan",
            "notsan",
            "optonly",
        ],
    },
    disable_configs = [
        "gpu_h100",  # TODO(phawkins): numerical failure on h100
    ],
    shard_count = {
        "cpu": 40,
        "gpu": 40,
        "tpu": 50,
    },
)
|
|
|
|
|
2023-06-02 22:38:15 -05:00
|
|
|
# Runs scipy_spatial_test.py; needs scipy.
jax_test(
    name = "scipy_spatial_test",
    srcs = ["scipy_spatial_test.py"],
    deps = py_deps("scipy"),
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Runs scipy_stats_test.py. Excluded from ASan and TSan everywhere, and from
# MSan on TPU, because of timeouts.
jax_test(
    name = "scipy_stats_test",
    srcs = ["scipy_stats_test.py"],
    backend_tags = {
        "tpu": ["nomsan"],  # Times out
    },
    shard_count = {
        "cpu": 40,
        "gpu": 30,
        "tpu": 40,
    },
    tags = [
        "noasan",
        "notsan",
    ],  # Times out
)
|
|
|
|
|
|
|
|
# Runs sparse_test.py with the BCOO cuSPARSE lowering enabled and 50 shards per
# backend. CPU and GPU variants cap generated cases at 40, and sanitizer builds
# are skipped, to avoid timeouts.
jax_test(
    name = "sparse_test",
    srcs = ["sparse_test.py"],
    args = ["--jax_bcoo_cusparse_lowering=true"],
    backend_tags = {
        "cpu": [
            "nomsan",  # Times out
            "notsan",  # Times out
        ],
        "tpu": ["optonly"],
    },
    # Use fewer cases to prevent timeouts.
    backend_variant_args = {
        "cpu": ["--jax_num_generated_cases=40"],
        "cpu_x32": ["--jax_num_generated_cases=40"],
        "gpu": ["--jax_num_generated_cases=40"],
    },
    shard_count = {
        "cpu": 50,
        "gpu": 50,
        "tpu": 50,
    },
    tags = [
        "noasan",
        "nomsan",
        "notsan",
    ],  # Test times out under asan/msan/tsan.
    deps = [
        "//jax:experimental_sparse",
        "//jax:sparse_test_util",
    ] + py_deps("scipy"),
)
|
|
|
|
|
|
|
|
# Runs sparse_bcoo_bcsr_test.py with the same settings as sparse_test: BCOO
# cuSPARSE lowering on, 50 shards per backend, generated cases capped at 40 on
# CPU/GPU, and no sanitizer builds, to avoid timeouts.
jax_test(
    name = "sparse_bcoo_bcsr_test",
    srcs = ["sparse_bcoo_bcsr_test.py"],
    args = ["--jax_bcoo_cusparse_lowering=true"],
    backend_tags = {
        "cpu": [
            "nomsan",  # Times out
            "notsan",  # Times out
        ],
        "tpu": ["optonly"],
    },
    # Use fewer cases to prevent timeouts.
    backend_variant_args = {
        "cpu": ["--jax_num_generated_cases=40"],
        "cpu_x32": ["--jax_num_generated_cases=40"],
        "gpu": ["--jax_num_generated_cases=40"],
    },
    shard_count = {
        "cpu": 50,
        "gpu": 50,
        "tpu": 50,
    },
    tags = [
        "noasan",
        "nomsan",
        "notsan",
    ],  # Test times out under asan/msan/tsan.
    deps = [
        "//jax:experimental_sparse",
        "//jax:sparse_test_util",
    ] + py_deps("scipy"),
)
|
|
|
|
|
2024-04-24 01:05:45 -07:00
|
|
|
# All generic backends are disabled; the test only targets the GPU configs
# listed in enable_configs.
jax_test(
    name = "sparse_nm_test",
    srcs = ["sparse_nm_test.py"],
    config_tags_overrides = {
        "gpu_a100": {
            "ondemand": False,  # Include in presubmit.
        },
    },
    disable_backends = [
        "cpu",
        "gpu",
        "tpu",
    ],
    enable_configs = [
        "gpu_a100",
        "gpu_h100",
    ],
    deps = [
        "//jax:experimental_sparse",
        "//jax:pallas_gpu",
    ],
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# sparsify_test.py, run with cuSPARSE lowering of BCOO ops enabled.
jax_test(
    name = "sparsify_test",
    srcs = ["sparsify_test.py"],
    args = ["--jax_bcoo_cusparse_lowering=true"],
    backend_tags = {
        "cpu": [
            "noasan",  # Times out under asan
            "notsan",  # Times out under tsan
        ],
        "tpu": [
            "noasan",  # Times out under asan.
        ],
    },
    shard_count = {
        "cpu": 5,
        "gpu": 20,
        "tpu": 10,
    },
    deps = [
        "//jax:experimental_sparse",
        "//jax:sparse_test_util",
    ],
)
|
|
|
|
|
|
|
|
jax_test(
    name = "stack_test",
    srcs = ["stack_test.py"],
)
|
|
|
|
|
|
|
|
jax_test(
    name = "checkify_test",
    srcs = ["checkify_test.py"],
    shard_count = {
        "gpu": 2,
        "tpu": 2,
    },
)
|
|
|
|
|
|
|
|
jax_test(
    name = "stax_test",
    srcs = ["stax_test.py"],
    shard_count = {
        "cpu": 5,
        "gpu": 5,
    },
    deps = ["//jax:stax"],
)
|
|
|
|
|
|
|
|
# NOTE: the target name ("linear_search_test") differs from the source file
# basename (line_search_test.py), so the entry point is named explicitly via
# `main`. The target name is kept as-is because other labels may reference it.
jax_test(
    name = "linear_search_test",
    srcs = ["third_party/scipy/line_search_test.py"],
    main = "third_party/scipy/line_search_test.py",
)
|
|
|
|
|
2024-06-24 11:19:59 -07:00
|
|
|
jax_test(
    name = "blocked_sampler_test",
    srcs = ["blocked_sampler_test.py"],
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Plain py_test (not jax_test): no per-backend variants are generated.
py_test(
    name = "tree_util_test",
    srcs = ["tree_util_test.py"],
    deps = [
        "//jax",
        "//jax:test_util",
    ],
)
|
|
|
|
|
2022-09-14 10:38:54 -07:00
|
|
|
pytype_test(
    name = "typing_test",
    srcs = ["typing_test.py"],
    deps = [
        "//jax",
        "//jax:test_util",
    ],
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
py_test(
    name = "util_test",
    srcs = ["util_test.py"],
    deps = [
        "//jax",
        "//jax:test_util",
    ],
)
|
|
|
|
|
|
|
|
py_test(
    name = "version_test",
    srcs = ["version_test.py"],
    deps = [
        "//jax",
        "//jax:test_util",
    ],
)
|
|
|
|
|
|
|
|
py_test(
    name = "xla_bridge_test",
    srcs = ["xla_bridge_test.py"],
    # Example PJRT plugin config read by the test at runtime.
    data = ["testdata/example_pjrt_plugin_config.json"],
    deps = [
        "//jax",
        "//jax:compiler",
        "//jax:test_util",
    ] + py_deps("absl/logging"),
)
|
|
|
|
|
2024-05-30 17:59:05 +04:00
|
|
|
py_test(
    name = "lru_cache_test",
    srcs = ["lru_cache_test.py"],
    deps = [
        "//jax",
        "//jax:lru_cache",
        "//jax:test_util",
    ] + py_deps("filelock"),
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
jax_test(
    name = "compilation_cache_test",
    srcs = ["compilation_cache_test.py"],
    tags = ["test_cpu_thunks"],
    deps = [
        "//jax:compilation_cache_internal",
        "//jax:compiler",
    ],
)
|
|
|
|
|
|
|
|
jax_test(
    name = "cache_key_test",
    srcs = ["cache_key_test.py"],
    deps = [
        "//jax:cache_key",
        "//jax:compiler",
    ],
)
|
|
|
|
|
|
|
|
jax_test(
    name = "ode_test",
    srcs = ["ode_test.py"],
    shard_count = {
        "cpu": 10,
    },
    deps = ["//jax:ode"],
)
|
|
|
|
|
|
|
|
# Same source as host_callback_test, but with the outfeed-based host callback
# implementation enabled via --jax_host_callback_outfeed=true.
jax_test(
    name = "host_callback_outfeed_test",
    srcs = ["host_callback_test.py"],
    args = ["--jax_host_callback_outfeed=true"],
    shard_count = {
        "tpu": 5,
    },
    deps = [
        "//jax:experimental",
        "//jax:experimental_host_callback",
        "//jax:ode",
    ],
)
|
|
|
|
|
|
|
|
# host_callback_test.py with the outfeed-based implementation disabled; see
# host_callback_outfeed_test for the outfeed=true variant of the same file.
jax_test(
    name = "host_callback_test",
    srcs = ["host_callback_test.py"],
    args = ["--jax_host_callback_outfeed=false"],
    main = "host_callback_test.py",
    shard_count = {
        "gpu": 5,
    },
    deps = [
        "//jax:experimental",
        "//jax:experimental_host_callback",
        "//jax:ode",
    ],
)
|
|
|
|
|
|
|
|
jax_test(
    name = "host_callback_to_tf_test",
    srcs = ["host_callback_to_tf_test.py"],
    tags = ["noasan"],  # Linking TF causes a linker OOM.
    deps = [
        "//jax:experimental_host_callback",
        "//jax:ode",
    ] + py_deps("tensorflow_core"),
)
|
|
|
|
|
2023-12-11 12:03:48 -08:00
|
|
|
jax_test(
    name = "key_reuse_test",
    srcs = ["key_reuse_test.py"],
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
jax_test(
    name = "x64_context_test",
    srcs = ["x64_context_test.py"],
    deps = [
        "//jax:experimental",
    ],
)
|
|
|
|
|
|
|
|
jax_test(
    name = "ann_test",
    srcs = ["ann_test.py"],
    shard_count = 10,
)
|
|
|
|
|
|
|
|
py_test(
    name = "mesh_utils_test",
    srcs = ["mesh_utils_test.py"],
    deps = [
        "//jax",
        "//jax:mesh_utils",
        "//jax:test_util",
    ],
)
|
|
|
|
|
|
|
|
jax_test(
    name = "transfer_guard_test",
    srcs = ["transfer_guard_test.py"],
)
|
|
|
|
|
|
|
|
jax_test(
    name = "name_stack_test",
    srcs = ["name_stack_test.py"],
)
|
|
|
|
|
|
|
|
jax_test(
    name = "jaxpr_effects_test",
    srcs = ["jaxpr_effects_test.py"],
    backend_tags = {
        "gpu": ["noasan"],  # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
    },
    enable_configs = [
        "gpu",
        "cpu",
    ],
    tags = ["multiaccelerator"],
)
|
|
|
|
|
|
|
|
jax_test(
    name = "debugging_primitives_test",
    srcs = ["debugging_primitives_test.py"],
    enable_configs = [
        "gpu",
        "cpu",
    ],
)
|
|
|
|
|
2022-08-08 11:41:46 -07:00
|
|
|
jax_test(
    name = "python_callback_test",
    srcs = ["python_callback_test.py"],
    backend_tags = {
        "gpu": ["noasan"],  # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
    },
    tags = ["multiaccelerator"],
    deps = [
        "//jax:experimental",
    ],
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
jax_test(
    name = "debugger_test",
    srcs = ["debugger_test.py"],
    enable_configs = [
        "gpu",
        "cpu",
    ],
)
|
|
|
|
|
2022-08-01 17:05:43 -07:00
|
|
|
jax_test(
    name = "state_test",
    srcs = ["state_test.py"],
    # Use fewer cases to prevent timeouts.
    args = [
        "--jax_num_generated_cases=5",
    ],
    backend_variant_args = {
        "tpu_pjrt_c_api": ["--jax_num_generated_cases=1"],
    },
    enable_configs = [
        "gpu",
        "cpu",
    ],
    shard_count = {
        "cpu": 2,
        "gpu": 2,
        "tpu": 2,
    },
    deps = py_deps("hypothesis"),
)
|
|
|
|
|
2024-04-28 21:16:13 -04:00
|
|
|
jax_test(
    name = "mutable_array_test",
    srcs = ["mutable_array_test.py"],
)
|
|
|
|
|
2022-09-12 14:40:11 -07:00
|
|
|
jax_test(
    name = "for_loop_test",
    srcs = ["for_loop_test.py"],
    shard_count = {
        "cpu": 20,
        "gpu": 10,
        "tpu": 20,
    },
)
|
|
|
|
|
2022-11-04 15:29:10 -07:00
|
|
|
jax_test(
    name = "shard_map_test",
    srcs = ["shard_map_test.py"],
    shard_count = {
        "cpu": 50,
        "gpu": 10,
        "tpu": 50,
    },
    tags = [
        "multiaccelerator",
        "noasan",
        "nomsan",
        "notsan",
    ],  # Times out under *SAN.
    deps = [
        "//jax:experimental",
        "//jax:tree_util",
    ],
)
|
|
|
|
|
2022-07-20 15:09:47 -07:00
|
|
|
jax_test(
    name = "clear_backends_test",
    srcs = ["clear_backends_test.py"],
)
|
|
|
|
|
2024-01-25 22:20:36 -08:00
|
|
|
jax_test(
    name = "attrs_test",
    srcs = ["attrs_test.py"],
    deps = [
        "//jax:experimental",
    ],
)
|
|
|
|
|
2022-11-28 14:31:10 -08:00
|
|
|
# GPU-only: the TPU and CPU backends are disabled, as is the A100 config.
jax_test(
    name = "experimental_rnn_test",
    srcs = ["experimental_rnn_test.py"],
    disable_backends = [
        "tpu",
        "cpu",
    ],
    disable_configs = [
        "gpu_a100",  # Numerical precision problems.
    ],
    shard_count = 15,
    deps = [
        "//jax:rnn",
    ],
)
|
|
|
|
|
2023-07-20 18:28:18 -07:00
|
|
|
py_test(
    name = "mosaic_test",
    srcs = ["mosaic_test.py"],
    deps = [
        "//jax",
        "//jax:mosaic",
        "//jax:test_util",
    ],
)
|
|
|
|
|
2023-04-19 13:46:33 -04:00
|
|
|
py_test(
    name = "source_info_test",
    srcs = ["source_info_test.py"],
    deps = [
        "//jax",
        "//jax:test_util",
    ],
)
|
|
|
|
|
2023-04-21 13:20:16 -07:00
|
|
|
py_test(
    name = "package_structure_test",
    srcs = ["package_structure_test.py"],
    deps = [
        "//jax",
        "//jax:test_util",
    ],
)
|
|
|
|
|
2024-05-08 12:43:52 +00:00
|
|
|
jax_test(
    name = "logging_test",
    srcs = ["logging_test.py"],
)
|
|
|
|
|
2023-09-05 22:15:22 -07:00
|
|
|
jax_test(
    name = "export_test",
    srcs = ["export_test.py"],
    enable_configs = [
        "tpu_df_2x2",
    ],
    tags = [],
    deps = [
        "//jax/experimental/export",
    ],
)
|
|
|
|
|
2023-11-19 08:59:23 -08:00
|
|
|
jax_test(
    name = "shape_poly_test",
    srcs = ["shape_poly_test.py"],
    disable_configs = [
        "gpu_a100",  # TODO(b/269593297): matmul precision issues
    ],
    enable_configs = [
        "cpu",
        "cpu_x32",
    ],
    shard_count = {
        "cpu": 4,
        "gpu": 4,
        "tpu": 4,
    },
    tags = [
        "noasan",  # Times out
        "nomsan",  # Times out
        "notsan",  # Times out
    ],
    deps = [
        "//jax:internal_test_harnesses",
    ],
)
|
|
|
|
|
2023-11-18 02:52:06 -08:00
|
|
|
jax_test(
    name = "export_harnesses_multi_platform_test",
    srcs = ["export_harnesses_multi_platform_test.py"],
    disable_configs = [
        "gpu_a100",  # TODO(b/269593297): matmul precision issues
        "gpu_h100",  # Scarce resources.
    ],
    shard_count = {
        "cpu": 40,
        "gpu": 20,
        "tpu": 20,
    },
    tags = [
        "noasan",  # Times out, TODO(b/314760446): test failures on Sapphire Rapids.
        "nodebug",  # Times out.
        "nomsan",  # TODO(b/314760446): test failures on Sapphire Rapids.
        "notsan",  # TODO(b/314760446): test failures on Sapphire Rapids.
    ],
    deps = [
        "//jax:internal_test_harnesses",
    ],
)
|
|
|
|
|
2024-01-08 04:47:36 -08:00
|
|
|
jax_test(
    name = "export_back_compat_test",
    srcs = ["export_back_compat_test.py"],
    tags = [],
    deps = [
        "//jax:internal_export_back_compat_test_data",
        "//jax:internal_export_back_compat_test_util",
    ],
)
|
|
|
|
|
2024-01-17 16:09:09 -08:00
|
|
|
# GPU-only: the TPU and CPU backends are disabled.
jax_test(
    name = "fused_attention_stablehlo_test",
    srcs = ["fused_attention_stablehlo_test.py"],
    disable_backends = [
        "tpu",
        "cpu",
    ],
    shard_count = {
        "gpu": 4,
    },
    tags = ["multiaccelerator"],
)
|
|
|
|
|
2024-05-06 09:59:18 -04:00
|
|
|
py_test(
    name = "pretty_printer_test",
    srcs = ["pretty_printer_test.py"],
    deps = [
        "//jax",
        "//jax:test_util",
    ],
)
|
|
|
|
|
2024-05-08 05:44:08 -07:00
|
|
|
py_test(
    name = "sourcemap_test",
    srcs = ["sourcemap_test.py"],
    deps = [
        "//jax",
        "//jax:test_util",
    ],
)
|
|
|
|
|
2022-07-01 15:06:54 -07:00
|
|
|
# Test sources exported so that packages listed in jax_test_file_visibility
# can reference them directly. Kept sorted alphabetically.
exports_files(
    [
        "api_test.py",
        "array_test.py",
        "cache_key_test.py",
        "compilation_cache_test.py",
        "layout_test.py",
        "memories_test.py",
        "pjit_test.py",
        "pmap_test.py",
        "python_callback_test.py",
        "shard_map_test.py",
        "transfer_guard_test.py",
    ],
    visibility = jax_test_file_visibility,
)
|
|
|
|
|
|
|
|
# This filegroup specifies the set of tests known to Bazel, used for a test that
# verifies every test has a Bazel test rule.
# If a test isn't meant to be tested with Bazel, add it to the exclude list.
filegroup(
    name = "all_tests",
    srcs = glob(
        include = [
            "*_test.py",
            "third_party/*/*_test.py",
        ],
        exclude = [],
    ) + ["BUILD"],
    visibility = [
        "//:__subpackages__",
    ],
)
|