mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/. Cleanup only, no functional changes intended. PiperOrigin-RevId: 738183402
392 lines
9.6 KiB
Python
392 lines
9.6 KiB
Python
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# JAX is Autograd and XLA
|
|
|
|
load("@bazel_skylib//lib:selects.bzl", "selects")
|
|
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
|
|
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
|
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
|
|
load(
|
|
"@xla//third_party/py:py_import.bzl",
|
|
"py_import",
|
|
)
|
|
load(
|
|
"@xla//third_party/py:py_manylinux_compliance_test.bzl",
|
|
"verify_manylinux_compliance_test",
|
|
)
|
|
load(
|
|
"//jaxlib:jax.bzl",
|
|
"PLATFORM_TAGS_DICT",
|
|
"if_windows",
|
|
"jax_py_test",
|
|
"jax_wheel",
|
|
"pytype_strict_library",
|
|
)
|
|
|
|
# License classification applied to the targets declared in this file.
licenses(["notice"])  # Apache 2

# Targets in this package are publicly visible unless they override it.
package(default_visibility = ["//visibility:public"])
|
|
|
|
# Generates platform_tags.py: a one-line Python module that assigns the
# PLATFORM_TAGS_DICT loaded from //jaxlib:jax.bzl to a variable of the same
# name. The dict is formatted into the shell command when this file is loaded.
genrule(
    name = "platform_tags_py",
    srcs = [],
    outs = ["platform_tags.py"],
    cmd = "echo 'PLATFORM_TAGS_DICT = %s' > $@;" % PLATFORM_TAGS_DICT,
)
|
|
|
|
# Python library combining the hand-written build_utils.py with the
# platform_tags.py module produced by :platform_tags_py.
pytype_strict_library(
    name = "build_utils",
    srcs = [
        "build_utils.py",
        ":platform_tags_py",
    ],
)
|
|
|
|
# Python binary wrapping build_wheel.py. It is the wheel_binary of the
# :jaxlib_wheel and :jaxlib_wheel_editable targets in this file.
py_binary(
    name = "build_wheel",
    srcs = ["build_wheel.py"],
    # Runtime inputs: the license, the jaxlib package and its packaging
    # metadata, the XLA FFI headers, and the XLA Python client/extension.
    data = [
        "LICENSE.txt",
        "//jaxlib",
        "//jaxlib:README.md",
        "//jaxlib:setup.py",
        "@xla//xla/ffi/api:api.h",
        "@xla//xla/ffi/api:c_api.h",
        "@xla//xla/ffi/api:ffi.h",
        "@xla//xla/python:xla_client.py",
        "@xla//xla/python:xla_extension",
    ] + if_windows([
        # Extra DLL contributed through the if_windows() macro.
        "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
    ]),
    deps = [
        ":build_utils",
        "@bazel_tools//tools/python/runfiles",
        "@pypi_build//:pkg",
        "@pypi_setuptools//:pkg",
        "@pypi_wheel//:pkg",
    ],
)
|
|
|
|
# Test driven by build_wheel_test.py; the :build_wheel binary is supplied as
# a data dependency and the runfiles library as a Python dependency.
jax_py_test(
    name = "build_wheel_test",
    srcs = ["build_wheel_test.py"],
    data = [":build_wheel"],
    deps = [
        "@bazel_tools//tools/python/runfiles",
    ],
)
|
|
|
|
# Shared library (linkshared = True) for the PJRT C API GPU plugin.
cc_binary(
    name = "pjrt_c_api_gpu_plugin.so",
    linkopts = [
        # Link against the local linker version script.
        "-Wl,--version-script,$(location :gpu_version_script.lds)",
        # Report unresolved symbols at link time.
        "-Wl,--no-undefined",
    ],
    linkshared = True,
    deps = [
        # Listed here as well so that $(location ...) in linkopts can resolve
        # the script's path.
        ":gpu_version_script.lds",
        "@xla//xla/pjrt/c:pjrt_c_api_gpu",
        "@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds",
        "@xla//xla/service:gpu_plugin",
    ] + if_cuda([
        # Dependencies contributed through the if_cuda() macro.
        "//jaxlib/mosaic/gpu:custom_call",
        "@xla//xla/stream_executor:cuda_platform",
    ]) + if_rocm([
        # Dependencies contributed through the if_rocm() macro.
        "@xla//xla/stream_executor:rocm_platform",
    ]),
)
|
|
|
|
# Python binary wrapping build_gpu_plugin_wheel.py. It is the wheel_binary of
# the jax_cuda_pjrt_* and jax_rocm_pjrt_* wheel targets in this file.
py_binary(
    name = "build_gpu_plugin_wheel",
    srcs = ["build_gpu_plugin_wheel.py"],
    # Runtime inputs: the license and the plugin shared library, plus the
    # backend-specific files added through if_cuda() / if_rocm().
    data = [
        "LICENSE.txt",
        ":pjrt_c_api_gpu_plugin.so",
    ] + if_cuda([
        "//jaxlib:version",
        "//jaxlib/cuda:cuda_gpu_support",
        "//jax_plugins/cuda:pyproject.toml",
        "//jax_plugins/cuda:setup.py",
        "//jax_plugins/cuda:__init__.py",
        "@local_config_cuda//cuda:cuda-nvvm",
    ]) + if_rocm([
        "//jaxlib:version",
        "//jaxlib/rocm:rocm_gpu_support",
        "//jax_plugins/rocm:pyproject.toml",
        "//jax_plugins/rocm:setup.py",
        "//jax_plugins/rocm:__init__.py",
    ]),
    deps = [
        ":build_utils",
        "@bazel_tools//tools/python/runfiles",
        "@pypi_build//:pkg",
        "@pypi_setuptools//:pkg",
        "@pypi_wheel//:pkg",
    ],
)
|
|
|
|
# Python binary wrapping build_gpu_kernels_wheel.py. It is the wheel_binary of
# the jax_cuda_plugin_* and jax_rocm_plugin_* wheel targets in this file.
py_binary(
    name = "build_gpu_kernels_wheel",
    srcs = ["build_gpu_kernels_wheel.py"],
    # Runtime inputs: the license, plus the backend-specific extension
    # modules and packaging files added through if_cuda() / if_rocm().
    data = [
        "LICENSE.txt",
    ] + if_cuda([
        "//jaxlib:version",
        "//jaxlib/mosaic/gpu:mosaic_gpu",
        "//jaxlib/cuda:cuda_plugin_extension",
        "//jaxlib/cuda:cuda_gpu_support",
        "//jax_plugins/cuda:plugin_pyproject.toml",
        "//jax_plugins/cuda:plugin_setup.py",
        "@local_config_cuda//cuda:cuda-nvvm",
    ]) + if_rocm([
        "//jaxlib:version",
        "//jaxlib/rocm:rocm_plugin_extension",
        "//jaxlib/rocm:rocm_gpu_support",
        "//jax_plugins/rocm:plugin_pyproject.toml",
        "//jax_plugins/rocm:plugin_setup.py",
    ]),
    deps = [
        ":build_utils",
        "@bazel_tools//tools/python/runfiles",
        "@pypi_build//:pkg",
        "@pypi_setuptools//:pkg",
        "@pypi_wheel//:pkg",
    ],
)
|
|
|
|
# Matches when either of the two macOS OS constraint labels (osx / macos)
# is satisfied.
selects.config_setting_group(
    name = "macos",
    match_any = [
        "@platforms//os:osx",
        "@platforms//os:macos",
    ],
)
|
|
|
|
# Matches when either of the two 64-bit ARM CPU constraint labels
# (aarch64 / arm64) is satisfied.
selects.config_setting_group(
    name = "arm64",
    match_any = [
        "@platforms//cpu:aarch64",
        "@platforms//cpu:arm64",
    ],
)
|
|
|
|
# Matches only when both the :arm64 and :macos groups match.
selects.config_setting_group(
    name = "macos_arm64",
    match_all = [
        ":arm64",
        ":macos",
    ],
)
|
|
|
|
# Matches only when the CPU is x86_64 and the :macos group matches.
selects.config_setting_group(
    name = "macos_x86_64",
    match_all = [
        "@platforms//cpu:x86_64",
        ":macos",
    ],
)
|
|
|
|
# Matches only when the CPU is x86_64 and the OS is Windows.
selects.config_setting_group(
    name = "win_amd64",
    match_all = [
        "@platforms//cpu:x86_64",
        "@platforms//os:windows",
    ],
)
|
|
|
|
# Matches only when the CPU is x86_64 and the OS is Linux.
selects.config_setting_group(
    name = "linux_x86_64",
    match_all = [
        "@platforms//cpu:x86_64",
        "@platforms//os:linux",
    ],
)
|
|
|
|
# Matches only when the :arm64 group matches and the OS is Linux.
selects.config_setting_group(
    name = "linux_aarch64",
    match_all = [
        ":arm64",
        "@platforms//os:linux",
    ],
)
|
|
|
|
# String build setting named jaxlib_git_hash; empty unless set on the
# command line. Its consumers are not visible in this file.
string_flag(
    name = "jaxlib_git_hash",
    build_setting_default = "",
)
|
|
|
|
# String build setting named output_path; defaults to "dist". Its consumers
# are not visible in this file.
string_flag(
    name = "output_path",
    build_setting_default = "dist",
)
|
|
|
|
# NVIDIA cu12 wheel labels passed (through if_cuda) as wheel_deps to the CUDA
# py_import targets in this file.
NVIDIA_WHEELS_DEPS = [
    "@pypi_nvidia_cublas_cu12//:whl",
    "@pypi_nvidia_cuda_cupti_cu12//:whl",
    "@pypi_nvidia_cuda_runtime_cu12//:whl",
    "@pypi_nvidia_cudnn_cu12//:whl",
    "@pypi_nvidia_cufft_cu12//:whl",
    "@pypi_nvidia_cusolver_cu12//:whl",
    "@pypi_nvidia_cusparse_cu12//:whl",
    "@pypi_nvidia_nccl_cu12//:whl",
    "@pypi_nvidia_nvjitlink_cu12//:whl",
]
|
|
|
|
# The jaxlib wheel, declared through the jax_wheel macro with :build_wheel as
# its wheel binary.
jax_wheel(
    name = "jaxlib_wheel",
    no_abi = False,
    wheel_binary = ":build_wheel",
    wheel_name = "jaxlib",
)
|
|
|
|
# py_import wrapper around :jaxlib_wheel.
py_import(
    name = "jaxlib_py_import",
    wheel = ":jaxlib_wheel",
)
|
|
|
|
# Editable (editable = True) counterpart of :jaxlib_wheel.
jax_wheel(
    name = "jaxlib_wheel_editable",
    editable = True,
    wheel_binary = ":build_wheel",
    wheel_name = "jaxlib",
)
|
|
|
|
# The jax_cuda12_plugin wheel, declared with CUDA enabled and
# :build_gpu_kernels_wheel as its wheel binary.
jax_wheel(
    name = "jax_cuda_plugin_wheel",
    enable_cuda = True,
    no_abi = False,
    # TODO(b/371217563) May use hermetic cuda version here.
    platform_version = "12",
    wheel_binary = ":build_gpu_kernels_wheel",
    wheel_name = "jax_cuda12_plugin",
)
|
|
|
|
# py_import wrapper around :jax_cuda_plugin_wheel; the NVIDIA wheels are
# passed as wheel_deps through if_cuda().
py_import(
    name = "jax_cuda_plugin_py_import",
    wheel = ":jax_cuda_plugin_wheel",
    wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS),
)
|
|
|
|
# Editable (editable = True) counterpart of :jax_cuda_plugin_wheel.
jax_wheel(
    name = "jax_cuda_plugin_wheel_editable",
    editable = True,
    enable_cuda = True,
    # TODO(b/371217563) May use hermetic cuda version here.
    platform_version = "12",
    wheel_binary = ":build_gpu_kernels_wheel",
    wheel_name = "jax_cuda12_plugin",
)
|
|
|
|
# The jax_rocm60_plugin wheel, declared with ROCm enabled and
# :build_gpu_kernels_wheel as its wheel binary.
jax_wheel(
    name = "jax_rocm_plugin_wheel",
    enable_rocm = True,
    no_abi = False,
    platform_version = "60",
    wheel_binary = ":build_gpu_kernels_wheel",
    wheel_name = "jax_rocm60_plugin",
)
|
|
|
|
# Editable (editable = True) counterpart of :jax_rocm_plugin_wheel.
jax_wheel(
    name = "jax_rocm_plugin_wheel_editable",
    editable = True,
    enable_rocm = True,
    platform_version = "60",
    wheel_binary = ":build_gpu_kernels_wheel",
    wheel_name = "jax_rocm60_plugin",
)
|
|
|
|
# The jax_cuda12_pjrt wheel, declared with CUDA enabled, no_abi = True, and
# :build_gpu_plugin_wheel as its wheel binary.
jax_wheel(
    name = "jax_cuda_pjrt_wheel",
    enable_cuda = True,
    no_abi = True,
    # TODO(b/371217563) May use hermetic cuda version here.
    platform_version = "12",
    wheel_binary = ":build_gpu_plugin_wheel",
    wheel_name = "jax_cuda12_pjrt",
)
|
|
|
|
# py_import wrapper around :jax_cuda_pjrt_wheel; the NVIDIA wheels are passed
# as wheel_deps through if_cuda().
py_import(
    name = "jax_cuda_pjrt_py_import",
    wheel = ":jax_cuda_pjrt_wheel",
    wheel_deps = if_cuda(NVIDIA_WHEELS_DEPS),
)
|
|
|
|
# Editable (editable = True) counterpart of :jax_cuda_pjrt_wheel.
jax_wheel(
    name = "jax_cuda_pjrt_wheel_editable",
    editable = True,
    enable_cuda = True,
    # TODO(b/371217563) May use hermetic cuda version here.
    platform_version = "12",
    wheel_binary = ":build_gpu_plugin_wheel",
    wheel_name = "jax_cuda12_pjrt",
)
|
|
|
|
# The jax_rocm60_pjrt wheel, declared with ROCm enabled, no_abi = True, and
# :build_gpu_plugin_wheel as its wheel binary.
jax_wheel(
    name = "jax_rocm_pjrt_wheel",
    enable_rocm = True,
    no_abi = True,
    platform_version = "60",
    wheel_binary = ":build_gpu_plugin_wheel",
    wheel_name = "jax_rocm60_pjrt",
)
|
|
|
|
# Editable (editable = True) counterpart of :jax_rocm_pjrt_wheel.
jax_wheel(
    name = "jax_rocm_pjrt_wheel_editable",
    editable = True,
    enable_rocm = True,
    platform_version = "60",
    wheel_binary = ":build_gpu_plugin_wheel",
    wheel_name = "jax_rocm60_pjrt",
)
|
|
|
|
# Per-architecture Linux tag strings: each joins the PLATFORM_TAGS_DICT entry
# for the ("Linux", <cpu>) key with underscores. They are passed to the
# manylinux compliance tests in this file.
AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")])

PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")])

X86_64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "x86_64")])
|
|
|
|
# Manylinux compliance check for :jaxlib_wheel against the per-architecture
# tags. Carries the "manual" test tag, which presumably keeps it out of
# wildcard test runs (confirm against the macro).
verify_manylinux_compliance_test(
    name = "jaxlib_manylinux_compliance_test",
    aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
    ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
    test_tags = [
        "manual",
    ],
    wheel = ":jaxlib_wheel",
    x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
)
|
|
|
|
# Manylinux compliance check for :jax_cuda_plugin_wheel against the
# per-architecture tags. Carries the "manual" test tag, which presumably keeps
# it out of wildcard test runs (confirm against the macro).
verify_manylinux_compliance_test(
    name = "jax_cuda_plugin_manylinux_compliance_test",
    aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
    ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
    test_tags = [
        "manual",
    ],
    wheel = ":jax_cuda_plugin_wheel",
    x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
)
|
|
|
|
# Manylinux compliance check for :jax_cuda_pjrt_wheel against the
# per-architecture tags. Carries the "manual" test tag, which presumably keeps
# it out of wildcard test runs (confirm against the macro).
verify_manylinux_compliance_test(
    name = "jax_cuda_pjrt_manylinux_compliance_test",
    aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
    ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
    test_tags = [
        "manual",
    ],
    wheel = ":jax_cuda_pjrt_wheel",
    x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
)
|