# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Build macros shared by all JAX test targets.
load(
    "//jaxlib:jax.bzl",
    "jax_generate_backend_suites",
    "jax_multiplatform_test",
    "py_deps",
)
|
|
|
|
|
|
|
|
licenses(["notice"])

# Targets in this package are private by default; tests are reached via the
# generated backend suites.
package(
    default_applicable_licenses = [],
    default_visibility = ["//visibility:private"],
)
|
|
|
|
|
# Generates per-backend (cpu/gpu/tpu) test_suite targets for this package.
jax_generate_backend_suites()
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "pallas_test",
    srcs = [
        "pallas_test.py",
    ],
    config_tags_overrides = {
        "gpu_a100_x32": {
            "ondemand": False,  # Include in presubmit.
        },
    },
    enable_backends = [
        "cpu",
        "tpu",
    ],
    enable_configs = [
        "gpu_a100_x32",
        "gpu_h100_x32",
    ],
    shard_count = {
        "cpu": 8,
        "gpu": 4,
        "tpu": 4,
    },
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",
        "//jax:pallas_gpu_ops",
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "pallas_jumble_test",
    srcs = [
        "pallas_jumble_test.py",
    ],
    disable_configs = [
        "gpu_v100",
        "gpu_x32",
        "gpu_a100",
        "gpu_p100",
        "gpu_p100_x32",
        "gpu_h100",
    ],
    deps = [
        "//jax:pallas",
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "ops_test",
    srcs = [
        "ops_test.py",
    ],
    config_tags_overrides = {
        "gpu_a100_x32": {
            "ondemand": False,  # Include in presubmit.
        },
    },
    disable_configs = [
        "gpu_v100",
        "gpu_x32",
        "gpu_p100",
        "gpu_p100_x32",
    ],
    enable_configs = [
        "gpu_a100",
        "gpu_a100_x32",
        "gpu_h100",
        "gpu_h100_x32",
    ],
    shard_count = {
        "cpu": 8,
        "gpu": 8,
        "tpu": 8,
    },
    tags = [
        "noasan",  # Times out.
        "nomsan",  # Times out.
        "notsan",  # Times out.
    ],
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",  # build_cleaner: keep
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "indexing_test",
    srcs = [
        "indexing_test.py",
    ],
    enable_backends = [
        "cpu",
        "tpu",
    ],
    tags = [
        "noasan",  # Times out.
        "nomsan",  # Times out.
        "notsan",  # Times out.
    ],
    deps = [
        "//jax:pallas",
        "//jax:pallas_tpu",
    ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
)
|
2023-12-05 00:09:34 -08:00
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "pallas_vmap_test",
    srcs = [
        "pallas_vmap_test.py",
    ],
    config_tags_overrides = {
        "gpu_a100_x32": {
            "ondemand": False,  # Include in presubmit.
        },
    },
    enable_backends = ["cpu"],
    enable_configs = [
        "gpu_a100_x32",
        "gpu_h100_x32",
    ],
    shard_count = 4,
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",
        "//jax:pallas_gpu_ops",
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "mosaic_gpu_test",
    srcs = [
        "mosaic_gpu_test.py",
    ],
    config_tags_overrides = {
        "gpu_h100_x32": {
            "ondemand": False,  # Include in presubmit.
        },
    },
    enable_backends = [],
    enable_configs = [
        "gpu_h100_x32",
    ],
    env = {
        "JAX_PALLAS_USE_MOSAIC_GPU": "1",
    },
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",  # build_cleaner: keep
        "//jax/_src/pallas/mosaic_gpu",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "export_back_compat_pallas_test",
    srcs = ["export_back_compat_pallas_test.py"],
    config_tags_overrides = {
        "gpu_a100_x32": {
            "ondemand": False,  # Include in presubmit.
        },
    },
    enable_backends = ["cpu"],
    enable_configs = [
        "gpu_a100_x32",
        "gpu_h100_x32",
    ],
    tags = [],
    deps = [
        "//jax:internal_export_back_compat_test_data",
        "//jax:internal_export_back_compat_test_util",
        "//jax:pallas",
        "//jax:pallas_gpu",  # build_cleaner: keep
        "//jax:pallas_tpu_ops",  # build_cleaner: keep
    ],
)
|
2024-06-10 03:12:39 -07:00
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "export_pallas_test",
    srcs = ["export_pallas_test.py"],
    config_tags_overrides = {
        "gpu_a100_x32": {
            "ondemand": False,  # Include in presubmit.
        },
    },
    enable_backends = ["cpu"],
    enable_configs = [
        "gpu_a100_x32",
    ],
    tags = [],
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",  # build_cleaner: keep
        "//jax:pallas_tpu",  # build_cleaner: keep
    ],
)
|
2024-07-01 10:17:21 -07:00
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "pallas_shape_poly_test",
    srcs = ["pallas_shape_poly_test.py"],
    config_tags_overrides = {
        "gpu_a100_x32": {
            "ondemand": False,  # Include in presubmit.
        },
    },
    disable_configs = [
        "gpu_x32",
        "gpu_h100",
        "gpu_p100",
        "gpu_p100_x32",
        "gpu_pjrt_c_api",
    ],
    enable_configs = [
        "gpu_a100_x32",
    ],
    tags = [],
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",  # build_cleaner: keep
        "//jax:pallas_tpu",  # build_cleaner: keep
        "//jax/experimental/export",
    ],
)
|
2024-07-10 08:22:20 -07:00
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "pallas_error_handling_test",
    srcs = [
        "pallas_error_handling_test.py",
    ],
    enable_backends = ["tpu"],
    deps = [
        "//jax:pallas",
        "//jax:pallas_tpu",
        "//jax/_src/pallas/mosaic:random",
        "//third_party/py/absl/testing:absltest",
        "//third_party/py/absl/testing:parameterized",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "tpu_all_gather_test",
    srcs = [
        "tpu_all_gather_test.py",
    ],
    enable_backends = ["tpu"],
    deps = [
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "tpu_gmm_test",
    srcs = [
        "tpu_gmm_test.py",
    ],
    enable_backends = ["tpu"],
    shard_count = 50,
    tags = [
        "noasan",  # Times out.
        "nomsan",  # Times out.
        "notsan",  # Times out.
    ],
    deps = [
        "//jax:pallas_tpu_ops",
    ] + py_deps([
        "absl/testing",
        "absl/flags",
        "numpy",
        "hypothesis",
    ]),
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "tpu_pallas_test",
    srcs = ["tpu_pallas_test.py"],
    # The flag is necessary for ``pl.debug_print`` tests to work on TPU.
    args = ["--logtostderr"],
    enable_backends = ["tpu"],
    deps = [
        "//jax:extend",
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ],
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "tpu_ops_test",
    srcs = [
        "tpu_ops_test.py",
    ],
    enable_backends = [
        "cpu",
        "tpu",
    ],
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",  # build_cleaner: keep
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "tpu_pallas_distributed_test",
    srcs = ["tpu_pallas_distributed_test.py"],
    enable_backends = ["tpu"],
    deps = [
        "//jax:extend",
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ],
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "tpu_pallas_pipeline_test",
    srcs = ["tpu_pallas_pipeline_test.py"],
    enable_backends = ["tpu"],
    shard_count = 5,
    tags = [
        "noasan",  # Times out.
        "nomsan",  # Times out.
        "notsan",  # Times out.
    ],
    deps = [
        "//jax:extend",
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ] + py_deps("hypothesis"),
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "tpu_pallas_async_test",
    srcs = ["tpu_pallas_async_test.py"],
    enable_backends = ["tpu"],
    tags = [
    ],
    deps = [
        "//jax:pallas_tpu",
    ],
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "tpu_pallas_mesh_test",
    srcs = ["tpu_pallas_mesh_test.py"],
    enable_backends = ["tpu"],
    tags = [
        "noasan",
        "nomsan",
        "notsan",
    ],
    deps = [
        "//jax:extend",
        "//jax:pallas_tpu",
    ],
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "tpu_pallas_random_test",
    srcs = [
        "tpu_pallas_random_test.py",
    ],
    enable_backends = ["tpu"],
    deps = [
        "//jax:pallas",
        "//jax:pallas_tpu",
        "//jax/_src/pallas/mosaic:random",
        "//third_party/py/absl/testing:absltest",
        "//third_party/py/absl/testing:parameterized",
    ] + py_deps("numpy"),
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "tpu_paged_attention_kernel_test",
    srcs = ["tpu_paged_attention_kernel_test.py"],
    enable_backends = ["tpu"],
    shard_count = 5,
    tags = [
        "noasan",  # Times out.
        "nomsan",  # Times out.
        "notsan",  # Times out.
    ],
    deps = [
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "tpu_splash_attention_kernel_test",
    srcs = [
        "tpu_splash_attention_kernel_test.py",
    ],
    enable_backends = ["tpu"],
    shard_count = 24,
    tags = [
        "noasan",  # Times out.
        "nomsan",  # Times out.
        "notsan",  # Times out.
    ],
    deps = [
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "tpu_splash_attention_mask_test",
    srcs = [
        "tpu_splash_attention_mask_test.py",
    ],
    enable_backends = [
        "cpu",
        "tpu",
    ],
    deps = [
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "gpu_attention_test",
    srcs = [
        "gpu_attention_test.py",
    ],
    config_tags_overrides = {
        "gpu_a100_x32": {
            "ondemand": False,  # Include in presubmit.
        },
    },
    enable_backends = ["cpu"],
    enable_configs = [
        "gpu_a100_x32",
        "gpu_h100_x32",
    ],
    shard_count = 1,
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",  # build_cleaner: keep
        "//jax:pallas_gpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
|
|
|
|
|
2024-09-24 12:28:32 -07:00
|
|
|
jax_multiplatform_test(
    name = "gpu_ops_test",
    srcs = [
        "gpu_ops_test.py",
    ],
    config_tags_overrides = {
        "gpu_a100_x32": {
            "ondemand": False,  # Include in presubmit.
        },
    },
    enable_backends = ["cpu"],
    enable_configs = [
        "gpu_a100_x32",
        "gpu_h100_x32",
    ],
    shard_count = 2,
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",
        "//jax:pallas_gpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
|