Add build targets for jax-rocm-plugin and jax-rocm-pjrt wheels.

PiperOrigin-RevId: 732149495
This commit is contained in:
jax authors 2025-02-28 08:36:08 -08:00
parent bb96226dd8
commit 8f57b8167b
2 changed files with 39 additions and 0 deletions

View File

@ -479,6 +479,7 @@ def jax_wheel(
build_wheel_only = True,
editable = False,
enable_cuda = False,
enable_rocm = False,
platform_version = "",
source_files = []):
"""Create jax artifact wheels.
@ -494,6 +495,7 @@ def jax_wheel(
editable: whether to build an editable wheel
platform_independent: whether to build a wheel without platform tag
enable_cuda: whether to build a cuda wheel
enable_rocm: whether to build a rocm wheel
platform_version: the cuda version to use for the wheel
source_files: the source files to include in the wheel
@ -509,6 +511,7 @@ def jax_wheel(
build_wheel_only = build_wheel_only,
editable = editable,
enable_cuda = enable_cuda,
enable_rocm = enable_rocm,
platform_version = platform_version,
# git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)`
# flag in bazel command to pass the git hash for nightly or release builds.

View File

@ -238,6 +238,24 @@ jax_wheel(
wheel_name = "jax_cuda12_plugin",
)
jax_wheel(
name = "jax_rocm_plugin_wheel",
enable_rocm = True,
no_abi = False,
platform_version = "60",
wheel_binary = ":build_gpu_kernels_wheel",
wheel_name = "jax_rocm60_plugin",
)
jax_wheel(
name = "jax_rocm_plugin_wheel_editable",
editable = True,
enable_rocm = True,
platform_version = "60",
wheel_binary = ":build_gpu_kernels_wheel",
wheel_name = "jax_rocm60_plugin",
)
jax_wheel(
name = "jax_cuda_pjrt_wheel",
enable_cuda = True,
@ -258,6 +276,24 @@ jax_wheel(
wheel_name = "jax_cuda12_pjrt",
)
jax_wheel(
name = "jax_rocm_pjrt_wheel",
enable_rocm = True,
no_abi = True,
platform_version = "60",
wheel_binary = ":build_gpu_plugin_wheel",
wheel_name = "jax_rocm60_pjrt",
)
jax_wheel(
name = "jax_rocm_pjrt_wheel_editable",
editable = True,
enable_rocm = True,
platform_version = "60",
wheel_binary = ":build_gpu_plugin_wheel",
wheel_name = "jax_rocm60_pjrt",
)
AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")])
PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")])