mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Replace disable_backends with enable_backends on jax_multiplatform_test.
Most users of disable_backends were actually using it to enable only a single backend, so things are simpler if we negate the sense of the option to say that. Change disable_backends to enable_backends, with a default `None` value meaning "everything is enabled".

We change the relationship between enable_backends, disable_configs, and enable_configs to be the following:
* `enable_backends` selects an initial set of test configurations to enable, based on backend only.
* `disable_configs` then prunes that set of test configurations, removing elements from the set.
* `enable_configs` then adds additional configurations to the set.

Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target.

PiperOrigin-RevId: 679563155
This commit is contained in:
parent
5740ab3b02
commit
26632fd344
@ -28,25 +28,11 @@ package(
|
||||
|
||||
jax_generate_backend_suites()
|
||||
|
||||
DISABLED_BACKENDS = [
|
||||
"cpu",
|
||||
"tpu",
|
||||
]
|
||||
|
||||
DISABLED_CONFIGS = [
|
||||
"gpu_v100",
|
||||
"gpu_a100",
|
||||
"gpu_p100",
|
||||
"gpu_p100_x32",
|
||||
"gpu_x32",
|
||||
"gpu_pjrt_c_api",
|
||||
]
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "matmul_bench",
|
||||
srcs = ["matmul_bench.py"],
|
||||
disable_backends = DISABLED_BACKENDS,
|
||||
disable_configs = DISABLED_CONFIGS,
|
||||
enable_backends = [],
|
||||
enable_configs = ["gpu_h100"],
|
||||
tags = ["notap"],
|
||||
deps = [
|
||||
"//jax:mosaic_gpu",
|
||||
|
@ -32,10 +32,7 @@ jax_multiplatform_test(
|
||||
name = "cuda_custom_call_test",
|
||||
srcs = ["cuda_custom_call_test.py"],
|
||||
data = [":foo"],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"tpu",
|
||||
],
|
||||
enable_backends = ["gpu"],
|
||||
tags = ["notap"],
|
||||
deps = [
|
||||
"//jax:extend",
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
load("//jaxlib:jax.bzl", "jax_py_test", "py_deps")
|
||||
load("//jaxlib:jax.bzl", "jax_multiplatform_test", "py_deps")
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
@ -48,18 +48,17 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
jax_py_test(
|
||||
jax_multiplatform_test(
|
||||
name = "run_matmul",
|
||||
srcs = ["matmul.py"],
|
||||
enable_backends = [],
|
||||
enable_configs = ["gpu_h100"],
|
||||
main = "matmul.py",
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
"requires-gpu-sm90-only",
|
||||
],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:mosaic_gpu",
|
||||
"//learning/brain/research/jax:gpu_support",
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
@ -231,15 +231,22 @@ def jax_multiplatform_test(
|
||||
shard_count = None,
|
||||
deps = [],
|
||||
data = [],
|
||||
disable_backends = None, # buildifier: disable=unused-variable
|
||||
enable_backends = None,
|
||||
backend_variant_args = {}, # buildifier: disable=unused-variable
|
||||
backend_tags = {}, # buildifier: disable=unused-variable
|
||||
disable_configs = None, # buildifier: disable=unused-variable
|
||||
enable_configs = None, # buildifier: disable=unused-variable
|
||||
enable_configs = [],
|
||||
config_tags_overrides = None, # buildifier: disable=unused-variable
|
||||
tags = [],
|
||||
main = None,
|
||||
pjrt_c_api_bypass = False): # buildifier: disable=unused-variable
|
||||
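# disable_configs does not do anything in OSS, only in Google's CI. In OSS, enable_configs is only consulted to keep a backend's test from being tagged "manual" when one of the named configs starts with that backend's name.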
|
||||
# The order in which `enable_backends`, `enable_configs`, and `disable_configs` are applied is
|
||||
# as follows:
|
||||
# 1. `enable_backends` is applied first, enabling all test configs for the given backends.
|
||||
# 2. `disable_configs` is applied second, disabling the named test configs.
|
||||
# 3. `enable_configs` is applied last, enabling the named test configs.
|
||||
|
||||
if main == None:
|
||||
if len(srcs) == 1:
|
||||
main = srcs[0]
|
||||
@ -256,7 +263,7 @@ def jax_multiplatform_test(
|
||||
"--jax_platform_name=" + backend,
|
||||
]
|
||||
test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, [])
|
||||
if disable_backends and backend in disable_backends:
|
||||
if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]):
|
||||
test_tags += ["manual"]
|
||||
if backend == "gpu":
|
||||
test_tags += tf_cuda_tests_tags()
|
||||
|
87
tests/BUILD
87
tests/BUILD
@ -66,7 +66,10 @@ jax_py_test(
|
||||
jax_multiplatform_test(
|
||||
name = "array_interoperability_test",
|
||||
srcs = ["array_interoperability_test.py"],
|
||||
disable_backends = ["tpu"],
|
||||
enable_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
tags = ["multiaccelerator"],
|
||||
deps = py_deps("tensorflow_core"),
|
||||
)
|
||||
@ -160,10 +163,7 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "gpu_memory_flags_test_no_preallocation",
|
||||
srcs = ["gpu_memory_flags_test.py"],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"tpu",
|
||||
],
|
||||
enable_backends = ["gpu"],
|
||||
env = {
|
||||
"XLA_PYTHON_CLIENT_PREALLOCATE": "0",
|
||||
},
|
||||
@ -173,10 +173,7 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "gpu_memory_flags_test",
|
||||
srcs = ["gpu_memory_flags_test.py"],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"tpu",
|
||||
],
|
||||
enable_backends = ["gpu"],
|
||||
env = {
|
||||
"XLA_PYTHON_CLIENT_PREALLOCATE": "1",
|
||||
},
|
||||
@ -273,10 +270,7 @@ jax_multiplatform_test(
|
||||
backend_tags = {
|
||||
"gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
|
||||
},
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"tpu",
|
||||
],
|
||||
enable_backends = ["gpu"],
|
||||
env = {"XLA_FLAGS": "--xla_dump_to=sponge --xla_gpu_enable_latency_hiding_scheduler=true"},
|
||||
tags = [
|
||||
"config-cuda-only",
|
||||
@ -290,10 +284,7 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "mock_gpu_test",
|
||||
srcs = ["mock_gpu_test.py"],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"tpu",
|
||||
],
|
||||
enable_backends = ["gpu"],
|
||||
tags = [
|
||||
"config-cuda-only",
|
||||
],
|
||||
@ -556,11 +547,7 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "lax_metal_test",
|
||||
srcs = ["lax_metal_test.py"],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
"tpu",
|
||||
],
|
||||
enable_backends = ["metal"],
|
||||
tags = ["notap"],
|
||||
deps = [
|
||||
"//jax:internal_test_util",
|
||||
@ -649,10 +636,7 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "metadata_test",
|
||||
srcs = ["metadata_test.py"],
|
||||
disable_backends = [
|
||||
"gpu",
|
||||
"tpu",
|
||||
],
|
||||
enable_backends = ["cpu"],
|
||||
)
|
||||
|
||||
jax_py_test(
|
||||
@ -672,10 +656,7 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "multi_device_test",
|
||||
srcs = ["multi_device_test.py"],
|
||||
disable_backends = [
|
||||
"gpu",
|
||||
"tpu",
|
||||
],
|
||||
enable_backends = ["cpu"],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
@ -734,10 +715,7 @@ jax_multiplatform_test(
|
||||
name = "polynomial_test",
|
||||
srcs = ["polynomial_test.py"],
|
||||
# No implementation of nonsymmetric Eigendecomposition.
|
||||
disable_backends = [
|
||||
"gpu",
|
||||
"tpu",
|
||||
],
|
||||
enable_backends = ["cpu"],
|
||||
shard_count = {
|
||||
"cpu": 10,
|
||||
},
|
||||
@ -753,25 +731,18 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "heap_profiler_test",
|
||||
srcs = ["heap_profiler_test.py"],
|
||||
disable_backends = [
|
||||
"gpu",
|
||||
"tpu",
|
||||
],
|
||||
enable_backends = ["cpu"],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "profiler_test",
|
||||
srcs = ["profiler_test.py"],
|
||||
disable_backends = [
|
||||
"gpu",
|
||||
"tpu",
|
||||
],
|
||||
enable_backends = ["cpu"],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "pytorch_interoperability_test",
|
||||
srcs = ["pytorch_interoperability_test.py"],
|
||||
disable_backends = ["tpu"],
|
||||
# The following cases are disabled because they time out in Google's CI, mostly because the
|
||||
# CUDA kernels in Torch take a very long time to compile.
|
||||
disable_configs = [
|
||||
@ -779,6 +750,10 @@ jax_multiplatform_test(
|
||||
"gpu_a100", # Pytorch A100 build times out in Google's CI.
|
||||
"gpu_h100", # Pytorch H100 build times out in Google's CI.
|
||||
],
|
||||
enable_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
tags = [
|
||||
"not_build:arm",
|
||||
# TODO(b/355237462): Re-enable once MSAN issue is addressed.
|
||||
@ -1019,16 +994,7 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "sparse_nm_test",
|
||||
srcs = ["sparse_nm_test.py"],
|
||||
config_tags_overrides = {
|
||||
"gpu_a100": {
|
||||
"ondemand": False, # Include in presubmit.
|
||||
},
|
||||
},
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
"tpu",
|
||||
],
|
||||
enable_backends = [],
|
||||
enable_configs = [
|
||||
"gpu_a100",
|
||||
"gpu_h100",
|
||||
@ -1386,13 +1352,10 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "experimental_rnn_test",
|
||||
srcs = ["experimental_rnn_test.py"],
|
||||
disable_backends = [
|
||||
"tpu",
|
||||
"cpu",
|
||||
],
|
||||
disable_configs = [
|
||||
"gpu_a100", # Numerical precision problems.
|
||||
],
|
||||
enable_backends = ["gpu"],
|
||||
shard_count = 15,
|
||||
deps = [
|
||||
"//jax:rnn",
|
||||
@ -1505,10 +1468,7 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "fused_attention_stablehlo_test",
|
||||
srcs = ["fused_attention_stablehlo_test.py"],
|
||||
disable_backends = [
|
||||
"tpu",
|
||||
"cpu",
|
||||
],
|
||||
enable_backends = ["gpu"],
|
||||
shard_count = {
|
||||
"gpu": 4,
|
||||
},
|
||||
@ -1542,10 +1502,7 @@ jax_py_test(
|
||||
jax_multiplatform_test(
|
||||
name = "cudnn_fusion_test",
|
||||
srcs = ["cudnn_fusion_test.py"],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"tpu",
|
||||
],
|
||||
enable_backends = ["gpu"],
|
||||
enable_configs = [
|
||||
"gpu_a100",
|
||||
"gpu_h100",
|
||||
|
@ -28,31 +28,16 @@ package(
|
||||
|
||||
jax_generate_backend_suites()
|
||||
|
||||
DISABLED_BACKENDS = [
|
||||
"cpu",
|
||||
"tpu",
|
||||
]
|
||||
|
||||
DISABLED_CONFIGS = [
|
||||
"gpu_a100",
|
||||
"gpu_a100_x32",
|
||||
"gpu_p100",
|
||||
"gpu_p100_x32",
|
||||
"gpu_pjrt_c_api",
|
||||
"gpu_v100",
|
||||
"gpu_x32",
|
||||
]
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "gpu_test",
|
||||
srcs = ["gpu_test.py"],
|
||||
disable_backends = DISABLED_BACKENDS,
|
||||
disable_configs = DISABLED_CONFIGS,
|
||||
enable_backends = [],
|
||||
enable_configs = [
|
||||
"gpu_h100",
|
||||
"gpu_h100_2gpu",
|
||||
],
|
||||
shard_count = 4,
|
||||
tags = ["multiaccelerator"],
|
||||
deps = [
|
||||
"//jax:mosaic_gpu",
|
||||
] + py_deps("absl/testing") + py_deps("numpy"),
|
||||
@ -61,8 +46,8 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "matmul_test",
|
||||
srcs = ["matmul_test.py"],
|
||||
disable_backends = DISABLED_BACKENDS,
|
||||
disable_configs = DISABLED_CONFIGS,
|
||||
enable_backends = [],
|
||||
enable_configs = ["gpu_h100"],
|
||||
shard_count = 5,
|
||||
deps = [
|
||||
"//jax:mosaic_gpu",
|
||||
@ -73,8 +58,8 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "flash_attention",
|
||||
srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"],
|
||||
disable_backends = DISABLED_BACKENDS,
|
||||
disable_configs = DISABLED_CONFIGS,
|
||||
enable_backends = [],
|
||||
enable_configs = ["gpu_h100"],
|
||||
main = "//jax/experimental/mosaic/gpu/examples:flash_attention.py",
|
||||
tags = ["notap"],
|
||||
deps = [
|
||||
@ -85,8 +70,8 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "flash_attention_test",
|
||||
srcs = ["flash_attention_test.py"],
|
||||
disable_backends = DISABLED_BACKENDS,
|
||||
disable_configs = DISABLED_CONFIGS,
|
||||
enable_backends = [],
|
||||
enable_configs = ["gpu_h100"],
|
||||
deps = [
|
||||
"//jax:mosaic_gpu",
|
||||
"//jax/experimental/mosaic/gpu/examples:flash_attention",
|
||||
|
@ -38,11 +38,9 @@ jax_multiplatform_test(
|
||||
"ondemand": False, # Include in presubmit.
|
||||
},
|
||||
},
|
||||
disable_configs = [
|
||||
"gpu_v100",
|
||||
"gpu_x32",
|
||||
"gpu_p100",
|
||||
"gpu_p100_x32",
|
||||
enable_backends = [
|
||||
"cpu",
|
||||
"tpu",
|
||||
],
|
||||
enable_configs = [
|
||||
"gpu_a100_x32",
|
||||
@ -75,9 +73,6 @@ jax_multiplatform_test(
|
||||
"gpu_p100_x32",
|
||||
"gpu_h100",
|
||||
],
|
||||
shard_count = {
|
||||
"tpu": 1,
|
||||
},
|
||||
deps = [
|
||||
"//jax:pallas",
|
||||
"//jax:pallas_tpu",
|
||||
@ -130,8 +125,9 @@ jax_multiplatform_test(
|
||||
srcs = [
|
||||
"indexing_test.py",
|
||||
],
|
||||
disable_backends = [
|
||||
"gpu",
|
||||
enable_backends = [
|
||||
"cpu",
|
||||
"tpu",
|
||||
],
|
||||
tags = [
|
||||
"noasan", # Times out.
|
||||
@ -154,14 +150,7 @@ jax_multiplatform_test(
|
||||
"ondemand": False, # Include in presubmit.
|
||||
},
|
||||
},
|
||||
disable_configs = [
|
||||
"gpu_v100",
|
||||
"gpu_x32",
|
||||
"gpu_a100",
|
||||
"gpu_h100",
|
||||
"gpu_p100",
|
||||
"gpu_p100_x32",
|
||||
],
|
||||
enable_backends = ["cpu"],
|
||||
enable_configs = [
|
||||
"gpu_a100_x32",
|
||||
"gpu_h100_x32",
|
||||
@ -186,19 +175,7 @@ jax_multiplatform_test(
|
||||
"ondemand": False, # Include in presubmit.
|
||||
},
|
||||
},
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"tpu",
|
||||
],
|
||||
disable_configs = [
|
||||
"gpu_v100",
|
||||
"gpu_x32",
|
||||
"gpu_a100",
|
||||
"gpu_a100_x32",
|
||||
"gpu_p100",
|
||||
"gpu_p100_x32",
|
||||
"gpu_h100",
|
||||
],
|
||||
enable_backends = [],
|
||||
enable_configs = [
|
||||
"gpu_h100_x32",
|
||||
],
|
||||
@ -220,15 +197,7 @@ jax_multiplatform_test(
|
||||
"ondemand": False, # Include in presubmit.
|
||||
},
|
||||
},
|
||||
disable_configs = [
|
||||
"gpu_v100",
|
||||
"gpu_x32",
|
||||
"gpu_a100",
|
||||
"gpu_h100",
|
||||
"gpu_p100",
|
||||
"gpu_p100_x32",
|
||||
"gpu_pjrt_c_api",
|
||||
],
|
||||
enable_backends = ["cpu"],
|
||||
enable_configs = [
|
||||
"gpu_a100_x32",
|
||||
"gpu_h100_x32",
|
||||
@ -251,15 +220,7 @@ jax_multiplatform_test(
|
||||
"ondemand": False, # Include in presubmit.
|
||||
},
|
||||
},
|
||||
disable_configs = [
|
||||
"gpu_v100",
|
||||
"gpu_x32",
|
||||
"gpu_a100",
|
||||
"gpu_h100",
|
||||
"gpu_p100",
|
||||
"gpu_p100_x32",
|
||||
"gpu_pjrt_c_api",
|
||||
],
|
||||
enable_backends = ["cpu"],
|
||||
enable_configs = [
|
||||
"gpu_a100_x32",
|
||||
],
|
||||
@ -303,10 +264,7 @@ jax_multiplatform_test(
|
||||
srcs = [
|
||||
"pallas_error_handling_test.py",
|
||||
],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
enable_backends = ["tpu"],
|
||||
deps = [
|
||||
"//jax:pallas",
|
||||
"//jax:pallas_tpu",
|
||||
@ -321,10 +279,7 @@ jax_multiplatform_test(
|
||||
srcs = [
|
||||
"tpu_all_gather_test.py",
|
||||
],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
enable_backends = ["tpu"],
|
||||
deps = [
|
||||
"//jax:pallas_tpu_ops",
|
||||
] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
|
||||
@ -335,10 +290,7 @@ jax_multiplatform_test(
|
||||
srcs = [
|
||||
"tpu_gmm_test.py",
|
||||
],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
enable_backends = ["tpu"],
|
||||
shard_count = 50,
|
||||
tags = [
|
||||
"noasan", # Times out.
|
||||
@ -360,10 +312,7 @@ jax_multiplatform_test(
|
||||
srcs = ["tpu_pallas_test.py"],
|
||||
# The flag is necessary for ``pl.debug_print`` tests to work on TPU.
|
||||
args = ["--logtostderr"],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
enable_backends = ["tpu"],
|
||||
deps = [
|
||||
"//jax:extend",
|
||||
"//jax:pallas_tpu",
|
||||
@ -376,8 +325,9 @@ jax_multiplatform_test(
|
||||
srcs = [
|
||||
"tpu_ops_test.py",
|
||||
],
|
||||
disable_backends = [
|
||||
"gpu",
|
||||
enable_backends = [
|
||||
"cpu",
|
||||
"tpu",
|
||||
],
|
||||
deps = [
|
||||
"//jax:pallas",
|
||||
@ -390,10 +340,7 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_pallas_distributed_test",
|
||||
srcs = ["tpu_pallas_distributed_test.py"],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
enable_backends = ["tpu"],
|
||||
deps = [
|
||||
"//jax:extend",
|
||||
"//jax:pallas_tpu",
|
||||
@ -404,10 +351,7 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_pallas_pipeline_test",
|
||||
srcs = ["tpu_pallas_pipeline_test.py"],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
enable_backends = ["tpu"],
|
||||
shard_count = 5,
|
||||
tags = [
|
||||
"noasan", # Times out.
|
||||
@ -424,10 +368,7 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_pallas_async_test",
|
||||
srcs = ["tpu_pallas_async_test.py"],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
enable_backends = ["tpu"],
|
||||
tags = [
|
||||
],
|
||||
deps = [
|
||||
@ -438,10 +379,7 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_pallas_mesh_test",
|
||||
srcs = ["tpu_pallas_mesh_test.py"],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
enable_backends = ["tpu"],
|
||||
tags = [
|
||||
"noasan",
|
||||
"nomsan",
|
||||
@ -458,10 +396,7 @@ jax_multiplatform_test(
|
||||
srcs = [
|
||||
"tpu_pallas_random_test.py",
|
||||
],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
enable_backends = ["tpu"],
|
||||
deps = [
|
||||
"//jax:pallas",
|
||||
"//jax:pallas_tpu",
|
||||
@ -474,10 +409,7 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "tpu_paged_attention_kernel_test",
|
||||
srcs = ["tpu_paged_attention_kernel_test.py"],
|
||||
disable_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
enable_backends = ["tpu"],
|
||||
shard_count = 5,
|
||||
tags = [
|
||||
"noasan", # Times out.
|
||||
@ -494,10 +426,7 @@ jax_multiplatform_test(
|
||||
srcs = [
|
||||
"tpu_splash_attention_kernel_test.py",
|
||||
],
|
||||
disable_backends = [
|
||||
"gpu",
|
||||
"cpu",
|
||||
],
|
||||
enable_backends = ["tpu"],
|
||||
shard_count = 24,
|
||||
tags = [
|
||||
"noasan", # Times out.
|
||||
@ -514,8 +443,9 @@ jax_multiplatform_test(
|
||||
srcs = [
|
||||
"tpu_splash_attention_mask_test.py",
|
||||
],
|
||||
disable_backends = [
|
||||
"gpu",
|
||||
enable_backends = [
|
||||
"cpu",
|
||||
"tpu",
|
||||
],
|
||||
deps = [
|
||||
"//jax:pallas_tpu_ops",
|
||||
@ -532,17 +462,7 @@ jax_multiplatform_test(
|
||||
"ondemand": False, # Include in presubmit.
|
||||
},
|
||||
},
|
||||
disable_backends = [
|
||||
"tpu",
|
||||
],
|
||||
disable_configs = [
|
||||
"gpu_v100",
|
||||
"gpu_x32",
|
||||
"gpu_p100",
|
||||
"gpu_p100_x32",
|
||||
"gpu_a100",
|
||||
"gpu_h100",
|
||||
],
|
||||
enable_backends = ["cpu"],
|
||||
enable_configs = [
|
||||
"gpu_a100_x32",
|
||||
"gpu_h100_x32",
|
||||
@ -565,17 +485,7 @@ jax_multiplatform_test(
|
||||
"ondemand": False, # Include in presubmit.
|
||||
},
|
||||
},
|
||||
disable_backends = [
|
||||
"tpu",
|
||||
],
|
||||
disable_configs = [
|
||||
"gpu_v100",
|
||||
"gpu_x32",
|
||||
"gpu_a100",
|
||||
"gpu_h100",
|
||||
"gpu_p100",
|
||||
"gpu_p100_x32",
|
||||
],
|
||||
enable_backends = ["cpu"],
|
||||
enable_configs = [
|
||||
"gpu_a100_x32",
|
||||
"gpu_h100_x32",
|
||||
|
Loading…
x
Reference in New Issue
Block a user