mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Add build targets for jax-rocm-plugin
and jax-rocm-pjrt
wheels.
PiperOrigin-RevId: 732149495
This commit is contained in:
parent
bb96226dd8
commit
8f57b8167b
@ -479,6 +479,7 @@ def jax_wheel(
|
||||
build_wheel_only = True,
|
||||
editable = False,
|
||||
enable_cuda = False,
|
||||
enable_rocm = False,
|
||||
platform_version = "",
|
||||
source_files = []):
|
||||
"""Create jax artifact wheels.
|
||||
@ -494,6 +495,7 @@ def jax_wheel(
|
||||
editable: whether to build an editable wheel
|
||||
platform_independent: whether to build a wheel without platform tag
|
||||
enable_cuda: whether to build a cuda wheel
|
||||
enable_rocm: whether to build a rocm wheel
|
||||
platform_version: the cuda version to use for the wheel
|
||||
source_files: the source files to include in the wheel
|
||||
|
||||
@ -509,6 +511,7 @@ def jax_wheel(
|
||||
build_wheel_only = build_wheel_only,
|
||||
editable = editable,
|
||||
enable_cuda = enable_cuda,
|
||||
enable_rocm = enable_rocm,
|
||||
platform_version = platform_version,
|
||||
# git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)`
|
||||
# flag in bazel command to pass the git hash for nightly or release builds.
|
||||
|
@ -238,6 +238,24 @@ jax_wheel(
|
||||
wheel_name = "jax_cuda12_plugin",
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jax_rocm_plugin_wheel",
|
||||
enable_rocm = True,
|
||||
no_abi = False,
|
||||
platform_version = "60",
|
||||
wheel_binary = ":build_gpu_kernels_wheel",
|
||||
wheel_name = "jax_rocm60_plugin",
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jax_rocm_plugin_wheel_editable",
|
||||
editable = True,
|
||||
enable_rocm = True,
|
||||
platform_version = "60",
|
||||
wheel_binary = ":build_gpu_kernels_wheel",
|
||||
wheel_name = "jax_rocm60_plugin",
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jax_cuda_pjrt_wheel",
|
||||
enable_cuda = True,
|
||||
@ -258,6 +276,24 @@ jax_wheel(
|
||||
wheel_name = "jax_cuda12_pjrt",
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jax_rocm_pjrt_wheel",
|
||||
enable_rocm = True,
|
||||
no_abi = True,
|
||||
platform_version = "60",
|
||||
wheel_binary = ":build_gpu_plugin_wheel",
|
||||
wheel_name = "jax_rocm60_pjrt",
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jax_rocm_pjrt_wheel_editable",
|
||||
editable = True,
|
||||
enable_rocm = True,
|
||||
platform_version = "60",
|
||||
wheel_binary = ":build_gpu_plugin_wheel",
|
||||
wheel_name = "jax_rocm60_pjrt",
|
||||
)
|
||||
|
||||
AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")])
|
||||
|
||||
PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")])
|
||||
|
Loading…
x
Reference in New Issue
Block a user