# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
    "//jaxlib:jax.bzl",
    "jax_generate_backend_suites",
    "jax_gpu_support_deps",
    "jax_multiplatform_test",
    "jax_py_test",
    "py_deps",
)
# Package-level license kind declaration for Bazel license checking.
licenses(["notice"])
package(
    default_applicable_licenses = [],
    default_visibility = ["//visibility:private"],
)
# Generates per-backend test-suite targets for the tests in this package.
jax_generate_backend_suites()
jax_multiplatform_test(
    name = "pallas_test",
    srcs = [
        "pallas_test.py",
    ],
    enable_backends = [
        "cpu",
        "tpu",
    ],
    enable_configs = [
        "gpu_a100",
        "gpu_h100",
    ],
    shard_count = {
        "cpu": 8,
        "gpu": 4,
        "tpu": 4,
    },
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",
        "//jax:pallas_gpu_ops",
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
    name = "pallas_cost_estimate_test",
    srcs = [
        "pallas_cost_estimate_test.py",
    ],
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",
        "//jax:pallas_gpu_ops",
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
    name = "pallas_jumble_test",
    srcs = [
        "pallas_jumble_test.py",
    ],
    disable_configs = [
        "gpu_v100",
        "gpu_v100_x32",
        "gpu_a100",
        "gpu_p100",
        "gpu_p100_x32",
        "gpu_h100",
    ],
    deps = [
        "//jax:pallas",
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
    name = "ops_test",
    srcs = [
        "ops_test.py",
    ],
    disable_configs = [
        "gpu_v100",
        "gpu_v100_x32",
        "gpu_p100",
        "gpu_p100_x32",
    ],
    enable_configs = [
        "gpu_a100",
        "gpu_a100_x32",
        "gpu_h100",
        "gpu_h100_x32",
        "tpu_v6e_1x1",
    ],
    shard_count = {
        "cpu": 16,
        "gpu": 16,
        "tpu": 16,
    },
    tags = [
        "noasan",  # Times out.
        "nomsan",  # Times out.
        "notsan",  # Times out.
    ],
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",  # build_cleaner: keep
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
)
jax_multiplatform_test(
    name = "indexing_test",
    srcs = [
        "indexing_test.py",
    ],
    enable_backends = [
        "cpu",
        "tpu",
    ],
    tags = [
        "noasan",  # Times out.
        "nomsan",  # Times out.
        "notsan",  # Times out.
    ],
    deps = [
        "//jax:pallas",
        "//jax:pallas_tpu",
    ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
)
jax_multiplatform_test(
    name = "pallas_vmap_test",
    srcs = [
        "pallas_vmap_test.py",
    ],
    enable_backends = ["cpu"],
    enable_configs = [
        "gpu_a100_x32",
        "gpu_h100_x32",
    ],
    shard_count = 4,
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",
        "//jax:pallas_gpu_ops",
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
    name = "mosaic_gpu_test",
    srcs = [
        "mosaic_gpu_test.py",
    ],
    enable_backends = [],
    enable_configs = [
        "gpu_h100_x32",
        "gpu_h100",
    ],
    env = {
        "JAX_PALLAS_USE_MOSAIC_GPU": "1",
        "JAX_PALLAS_VERBOSE_ERRORS": "0",
    },
    deps = [
        "//jax:pallas",
        "//jax:pallas_mosaic_gpu",  # build_cleaner: keep
        "//jax/_src/pallas/mosaic_gpu",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
    name = "export_back_compat_pallas_test",
    srcs = ["export_back_compat_pallas_test.py"],
    enable_backends = ["cpu"],
    enable_configs = [
        "gpu_a100_x32",
        "gpu_h100_x32",
    ],
    tags = [],
    deps = [
        "//jax:internal_export_back_compat_test_data",
        "//jax:internal_export_back_compat_test_util",
        "//jax:pallas",
        "//jax:pallas_gpu",  # build_cleaner: keep
        "//jax:pallas_tpu_ops",  # build_cleaner: keep
    ],
)
jax_py_test(
    name = "export_pallas_test_cpu_only",
    srcs = ["export_pallas_test.py"],
    args = ["--jax_test_dut=cpu"],
    main = "export_pallas_test.py",
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",  # build_cleaner: keep
        "//jax:pallas_tpu",  # build_cleaner: keep
        "//jax:test_util",
    ] + jax_gpu_support_deps,
)
jax_multiplatform_test(
    name = "export_pallas_test",
    srcs = ["export_pallas_test.py"],
    # Cross-compilation on CPU is tested separately.
    disable_configs = [
        "cpu",
        "cpu_shardy",
        "cpu_x32",
    ],
    enable_configs = [
        "gpu_a100_x32",
    ],
    tags = [],
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",  # build_cleaner: keep
        "//jax:pallas_tpu",  # build_cleaner: keep
    ],
)
jax_multiplatform_test(
    name = "pallas_shape_poly_test",
    srcs = ["pallas_shape_poly_test.py"],
    disable_configs = [
        "gpu_h100",
        "gpu_p100",
        "gpu_p100_x32",
        "gpu_v100_x32",
        "gpu_p100_pjrt_c_api",
    ],
    enable_configs = [
        "gpu_a100_x32",
    ],
    tags = [],
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",  # build_cleaner: keep
        "//jax:pallas_tpu",  # build_cleaner: keep
    ],
)
jax_multiplatform_test(
    name = "pallas_error_handling_test",
    srcs = [
        "pallas_error_handling_test.py",
    ],
    enable_backends = ["tpu"],
    deps = [
        "//jax:pallas",
        "//jax:pallas_tpu",
        "//jax/_src/pallas/mosaic:random",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
    name = "tpu_all_gather_test",
    srcs = [
        "tpu_all_gather_test.py",
    ],
    enable_backends = [],
    enable_configs = [
        "tpu_v5e_4x2",
    ],
    deps = [
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
)
jax_multiplatform_test(
    name = "tpu_gmm_test",
    srcs = [
        "tpu_gmm_test.py",
    ],
    enable_backends = ["tpu"],
    shard_count = 50,
    tags = [
        "noasan",  # Times out.
        "nomsan",  # Times out.
        "notsan",  # Times out.
    ],
    deps = [
        "//jax:pallas_tpu_ops",
    ] + py_deps([
        "absl/testing",
        "absl/flags",
        "numpy",
        "hypothesis",
    ]),
)
jax_multiplatform_test(
    name = "tpu_pallas_test",
    srcs = ["tpu_pallas_test.py"],
    # The flag is necessary for ``pl.debug_print`` tests to work on TPU.
    args = ["--logtostderr"],
    enable_backends = ["tpu"],
    enable_configs = [
        "tpu_v5e",
        "tpu_v5p_1x1",
    ],
    deps = [
        "//jax:extend",
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ],
)
jax_multiplatform_test(
    name = "tpu_ops_test",
    srcs = [
        "tpu_ops_test.py",
    ],
    enable_backends = [
        "cpu",
        "tpu",
    ],
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",  # build_cleaner: keep
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"),
)
jax_multiplatform_test(
    name = "tpu_pallas_distributed_test",
    srcs = ["tpu_pallas_distributed_test.py"],
    enable_backends = ["tpu"],
    enable_configs = [
        "tpu_v5e_4x2",
        "tpu_v5p_2x2",
        "tpu_v4_2x2",
        "tpu_v3_2x2",
    ],
    deps = [
        "//jax:extend",
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ],
)
jax_multiplatform_test(
    name = "tpu_pallas_pipeline_test",
    srcs = ["tpu_pallas_pipeline_test.py"],
    enable_backends = ["tpu"],
    enable_configs = [
        "tpu_v5e_4x2",
        "tpu_v5p_1x1",
    ],
    shard_count = 5,
    tags = [
        "noasan",  # Times out.
        "nomsan",  # Times out.
        "notsan",  # Times out.
    ],
    deps = [
        "//jax:extend",
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
    ] + py_deps("hypothesis"),
)
jax_multiplatform_test(
    name = "tpu_pallas_async_test",
    srcs = ["tpu_pallas_async_test.py"],
    enable_backends = ["tpu"],
    enable_configs = [
        "tpu_v5e_4x2",
        "tpu_v5p_1x1",
    ],
    deps = [
        "//jax:pallas_tpu",
    ],
)
# NOTE(review): per the change that introduced this target, it appears to cover
# the experimental lowering of shard_map + run_state to a pallas_call on TPU —
# confirm against tpu_pallas_state_test.py.
jax_multiplatform_test(
    name = "tpu_pallas_state_test",
    srcs = ["tpu_pallas_state_test.py"],
    enable_backends = ["tpu"],
    tags = [
        "noasan",
        "nomsan",
        "notsan",
    ],
    deps = [
        "//jax:extend",
        "//jax:pallas_tpu",
    ],
)
jax_multiplatform_test(
    name = "tpu_pallas_random_test",
    srcs = [
        "tpu_pallas_random_test.py",
    ],
    enable_backends = ["tpu"],
    enable_configs = [
        "tpu_v5p_2x2",
    ],
    deps = [
        "//jax:pallas",
        "//jax:pallas_tpu",
        "//jax:pallas_tpu_ops",
        "//jax/_src/pallas/mosaic:random",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
# Tests for the Pallas TPU interpret mode, which runs Pallas TPU kernels on
# CPU while simulating TPU shared memory, multiple devices/cores, remote DMAs,
# and synchronization.
jax_multiplatform_test(
    name = "tpu_pallas_interpret_test",
    srcs = [
        "tpu_pallas_interpret_test.py",
    ],
    disable_configs = ["cpu_shardy"],
    enable_backends = ["cpu"],
    deps = [
        "//jax:pallas",
        "//jax:pallas_tpu",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
    name = "tpu_pallas_interpret_distributed_test",
    srcs = [
        "tpu_pallas_interpret_distributed_test.py",
    ],
    disable_configs = ["cpu_shardy"],
    enable_backends = ["cpu"],
    deps = [
        # Use the same label form as every other rule in this file
        # (was "//third_party/py/jax:...", inconsistent with the rest).
        "//jax:pallas",
        "//jax:pallas_tpu",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
    name = "tpu_paged_attention_kernel_test",
    srcs = ["tpu_paged_attention_kernel_test.py"],
    disable_configs = [
        "tpu_v5p_1x1",
    ],
    enable_backends = ["tpu"],
    shard_count = 5,
    tags = [
        "noasan",  # Times out.
        "nomsan",  # Times out.
        "notap",  # this code has data race issues that XLA improvements unhide. b/392946030
        "notsan",  # Times out.
    ],
    deps = [
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
    name = "tpu_splash_attention_kernel_test",
    srcs = [
        "tpu_splash_attention_kernel_test.py",
    ],
    enable_backends = ["tpu"],
    shard_count = 24,
    tags = [
        "noasan",  # Times out.
        "nomsan",  # Times out.
        "notsan",  # Times out.
    ],
    deps = [
        "//jax:pallas_tpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
)
# This test doesn't need a TPU; it only tests numpy-using helpers.
jax_py_test(
    name = "tpu_splash_attention_mask_test",
    srcs = [
        "tpu_splash_attention_mask_test.py",
    ],
    deps = [
        "//jax",
        "//jax:pallas_tpu_ops",
        "//jax:test_util",
    ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"),
)
jax_multiplatform_test(
    name = "gpu_attention_test",
    srcs = [
        "gpu_attention_test.py",
    ],
    enable_backends = ["cpu"],
    enable_configs = [
        "gpu_a100_x32",
        "gpu_h100_x32",
    ],
    shard_count = 1,
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",  # build_cleaner: keep
        "//jax:pallas_gpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
    name = "gpu_ops_test",
    srcs = [
        "gpu_ops_test.py",
    ],
    enable_backends = ["cpu"],
    enable_configs = [
        "gpu_a100_x32",
        "gpu_h100_x32",
    ],
    shard_count = 20,
    tags = [
        "noasan",  # Times out.
        "nomsan",  # Times out.
        "notsan",  # Times out.
    ],
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",
        "//jax:pallas_gpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
    name = "gpu_paged_attention_test",
    srcs = [
        "gpu_paged_attention_test.py",
    ],
    enable_backends = ["cpu"],
    enable_configs = [
        "gpu_a100_x32",
        "gpu_h100_x32",
    ],
    shard_count = 6,
    deps = [
        "//jax:pallas",
        "//jax:pallas_gpu",
        "//jax:pallas_gpu_ops",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
    name = "mgpu_attention_run",
    srcs = ["//jax/experimental/pallas/ops/gpu:attention_mgpu.py"],
    enable_backends = [],
    enable_configs = ["gpu_h100_x32"],
    env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"},
    tags = [
        "manual",
        "notap",
    ],
    deps = [
        "//jax:pallas",
        "//jax:pallas_mosaic_gpu",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)
jax_multiplatform_test(
    name = "mgpu_attention_test",
    srcs = ["mgpu_attention_test.py"],
    enable_backends = [],
    enable_configs = ["gpu_h100_x32"],
    env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"},
    deps = [
        "//jax:pallas",
        "//jax:pallas_experimental_gpu_ops",
        "//jax:pallas_mosaic_gpu",
    ] + py_deps("absl/testing") + py_deps("numpy"),
)