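# File: rocm_jax/jaxlib/tools/BUILD.bazel
# Last commit: David Dunleavy 1a19d5594a "Update all uses of @tsl//third_party to @xla//third_party"
# PiperOrigin-RevId: 733495240
# Date: 2025-03-04 15:55:23 -08:00
#
# Size: 335 lines, 8.4 KiB
# Language: Starlark (Bazel BUILD file; the source viewer labeled it "Python")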

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# JAX is Autograd and XLA
load("@bazel_skylib//lib:selects.bzl", "selects")
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load(
"@xla//third_party/py:py_manylinux_compliance_test.bzl",
"verify_manylinux_compliance_test",
)
load(
"//jaxlib:jax.bzl",
"PLATFORM_TAGS_DICT",
"if_windows",
"jax_py_test",
"jax_wheel",
"pytype_strict_library",
)
licenses(["notice"]) # Apache 2
# Every target declared in this package is publicly visible unless it
# overrides `visibility` itself.
package(default_visibility = ["//visibility:public"])
# Writes PLATFORM_TAGS_DICT (loaded from //jaxlib:jax.bzl) out as a one-line
# Python module, platform_tags.py, so the Python build helpers in
# :build_utils see the same platform tag table as the Starlark rules.
# The dict is rendered with Starlark's str(); this assumes its keys/values are
# only strings and tuples of strings, so that rendering is also a valid Python
# literal.
genrule(
    name = "platform_tags_py",
    outs = ["platform_tags.py"],
    cmd = "echo 'PLATFORM_TAGS_DICT = {}' > $@;".format(PLATFORM_TAGS_DICT),
)
# Python helper library shared by the wheel-building binaries in this package
# (:build_wheel, :build_gpu_plugin_wheel, :build_gpu_kernels_wheel).
# Besides build_utils.py it includes the generated platform_tags.py produced
# by the :platform_tags_py genrule.
pytype_strict_library(
name = "build_utils",
srcs = [
"build_utils.py",
":platform_tags_py",
],
)
# Binary that builds the jaxlib wheel; it is the wheel_binary of the
# :jaxlib_wheel and :jaxlib_wheel_editable targets.
# `data` places its inputs in the binary's runfiles:
#   - the //jaxlib target, README.md and setup.py;
#   - the XLA FFI headers, xla_client.py and the xla_extension;
#   - on Windows only, the MLIR C-API DLL.
# The runfiles library dep is presumably how build_wheel.py locates these
# files; confirm in build_wheel.py.
py_binary(
name = "build_wheel",
srcs = ["build_wheel.py"],
data = [
"LICENSE.txt",
"//jaxlib",
"//jaxlib:README.md",
"//jaxlib:setup.py",
"@xla//xla/ffi/api:api.h",
"@xla//xla/ffi/api:c_api.h",
"@xla//xla/ffi/api:ffi.h",
"@xla//xla/python:xla_client.py",
"@xla//xla/python:xla_extension",
] + if_windows([
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
]),
deps = [
":build_utils",
"@bazel_tools//tools/python/runfiles",
"@pypi_build//:pkg",
"@pypi_setuptools//:pkg",
"@pypi_wheel//:pkg",
],
)
# Test for the wheel builder. The :build_wheel binary is passed as data, so it
# is present in the test's runfiles. The runfiles library dep is presumably
# used to locate it; confirm in build_wheel_test.py.
jax_py_test(
name = "build_wheel_test",
srcs = ["build_wheel_test.py"],
data = [":build_wheel"],
deps = [
"@bazel_tools//tools/python/runfiles",
],
)
# Shared library (linkshared) containing the PJRT C API GPU plugin.
# Link flags:
#   - gpu_version_script.lds is passed to the linker as a version script, which
#     controls the exported symbols. It is listed in deps so that
#     $(location :gpu_version_script.lds) in linkopts can be expanded.
#   - --no-undefined makes the link fail if any symbol is left unresolved.
# CUDA builds also link the Mosaic GPU custom call and the CUDA StreamExecutor
# platform; ROCm builds link the ROCm StreamExecutor platform instead.
# NOTE(review): the XLA pjrt_c_api_gpu_version_script.lds is also a dep but is
# not referenced in linkopts; confirm whether it is still required.
cc_binary(
name = "pjrt_c_api_gpu_plugin.so",
linkopts = [
"-Wl,--version-script,$(location :gpu_version_script.lds)",
"-Wl,--no-undefined",
],
linkshared = True,
deps = [
":gpu_version_script.lds",
"@xla//xla/pjrt/c:pjrt_c_api_gpu",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds",
"@xla//xla/service:gpu_plugin",
] + if_cuda([
"//jaxlib/mosaic/gpu:custom_call",
"@xla//xla/stream_executor:cuda_platform",
]) + if_rocm([
"@xla//xla/stream_executor:rocm_platform",
]),
)
# Binary that builds the GPU PJRT plugin wheels. It is the wheel_binary of the
# jax_cuda12_pjrt and jax_rocm60_pjrt jax_wheel targets.
# Runfiles contents:
#   - always: LICENSE.txt and the :pjrt_c_api_gpu_plugin.so shared library;
#   - per backend (CUDA or ROCm): //jaxlib:version, the backend's GPU support
#     target and the jax_plugins packaging files (pyproject.toml, setup.py,
#     __init__.py);
#   - CUDA only: NVVM from @local_config_cuda.
# NOTE(review): //jaxlib:version appears in both branches; fine while only one
# backend is enabled at a time.
py_binary(
name = "build_gpu_plugin_wheel",
srcs = ["build_gpu_plugin_wheel.py"],
data = [
"LICENSE.txt",
":pjrt_c_api_gpu_plugin.so",
] + if_cuda([
"//jaxlib:version",
"//jaxlib/cuda:cuda_gpu_support",
"//jax_plugins/cuda:pyproject.toml",
"//jax_plugins/cuda:setup.py",
"//jax_plugins/cuda:__init__.py",
"@local_config_cuda//cuda:cuda-nvvm",
]) + if_rocm([
"//jaxlib:version",
"//jaxlib/rocm:rocm_gpu_support",
"//jax_plugins/rocm:pyproject.toml",
"//jax_plugins/rocm:setup.py",
"//jax_plugins/rocm:__init__.py",
]),
deps = [
":build_utils",
"@bazel_tools//tools/python/runfiles",
"@pypi_build//:pkg",
"@pypi_setuptools//:pkg",
"@pypi_wheel//:pkg",
],
)
# Binary that builds the GPU kernels plugin wheels. It is the wheel_binary of
# the jax_cuda12_plugin and jax_rocm60_plugin jax_wheel targets.
# Runfiles contents:
#   - always: LICENSE.txt;
#   - per backend: the backend plugin extension, //jaxlib:version, the GPU
#     support target, and the jax_plugins plugin_pyproject.toml and
#     plugin_setup.py;
#   - CUDA only: Mosaic GPU and NVVM.
py_binary(
name = "build_gpu_kernels_wheel",
srcs = ["build_gpu_kernels_wheel.py"],
data = [
"LICENSE.txt",
] + if_cuda([
"//jaxlib/mosaic/gpu:mosaic_gpu",
"//jaxlib:cuda_plugin_extension",
"//jaxlib:version",
"//jaxlib/cuda:cuda_gpu_support",
"//jax_plugins/cuda:plugin_pyproject.toml",
"//jax_plugins/cuda:plugin_setup.py",
"@local_config_cuda//cuda:cuda-nvvm",
]) + if_rocm([
"//jaxlib:rocm_plugin_extension",
"//jaxlib:version",
"//jaxlib/rocm:rocm_gpu_support",
"//jax_plugins/rocm:plugin_pyproject.toml",
"//jax_plugins/rocm:plugin_setup.py",
]),
deps = [
":build_utils",
"@bazel_tools//tools/python/runfiles",
"@pypi_build//:pkg",
"@pypi_setuptools//:pkg",
"@pypi_wheel//:pkg",
],
)
# Platform condition groups usable as select() keys. None of them is
# referenced in this file; they are presumably consumed by macros in
# //jaxlib:jax.bzl. TODO confirm.

# True for either macOS OS constraint spelling (os:osx or os:macos).
selects.config_setting_group(
name = "macos",
match_any = [
"@platforms//os:osx",
"@platforms//os:macos",
],
)
# True for either 64-bit ARM CPU constraint spelling (aarch64 or arm64).
selects.config_setting_group(
name = "arm64",
match_any = [
"@platforms//cpu:aarch64",
"@platforms//cpu:arm64",
],
)
# True only when both :arm64 and :macos match (Apple Silicon macOS).
selects.config_setting_group(
name = "macos_arm64",
match_all = [
":arm64",
":macos",
],
)
# True only for an x86_64 CPU running Windows.
selects.config_setting_group(
name = "win_amd64",
match_all = [
"@platforms//cpu:x86_64",
"@platforms//os:windows",
],
)
# Command-line build settings. Set them with
# --//jaxlib/tools:<name>=<value>. Neither is read in this file; they are
# presumably consumed by jax_wheel in //jaxlib:jax.bzl. TODO confirm.

# Git hash of the jaxlib source; defaults to the empty string.
string_flag(
name = "jaxlib_git_hash",
build_setting_default = "",
)
# Output path for built wheels; defaults to "dist".
string_flag(
name = "output_path",
build_setting_default = "dist",
)
# Backend versions used both as jax_wheel's platform_version and inside the
# GPU wheel names (e.g. jax_cuda12_plugin, jax_rocm60_pjrt). Defining each
# once keeps a CUDA/ROCm version bump from leaving names and versions out of
# sync across the eight GPU targets below.
# TODO(b/371217563) May use hermetic cuda version here.
CUDA_PLATFORM_VERSION = "12"

ROCM_PLATFORM_VERSION = "60"

# jaxlib wheel (ABI-tagged) and its editable variant, both built by
# :build_wheel. The editable variants below leave no_abi at jax_wheel's
# default, which is not visible in this file.
jax_wheel(
    name = "jaxlib_wheel",
    no_abi = False,
    wheel_binary = ":build_wheel",
    wheel_name = "jaxlib",
)

jax_wheel(
    name = "jaxlib_wheel_editable",
    editable = True,
    wheel_binary = ":build_wheel",
    wheel_name = "jaxlib",
)

# GPU kernels plugin wheels, built by :build_gpu_kernels_wheel.
jax_wheel(
    name = "jax_cuda_plugin_wheel",
    enable_cuda = True,
    no_abi = False,
    platform_version = CUDA_PLATFORM_VERSION,
    wheel_binary = ":build_gpu_kernels_wheel",
    wheel_name = "jax_cuda%s_plugin" % CUDA_PLATFORM_VERSION,
)

jax_wheel(
    name = "jax_cuda_plugin_wheel_editable",
    editable = True,
    enable_cuda = True,
    platform_version = CUDA_PLATFORM_VERSION,
    wheel_binary = ":build_gpu_kernels_wheel",
    wheel_name = "jax_cuda%s_plugin" % CUDA_PLATFORM_VERSION,
)

jax_wheel(
    name = "jax_rocm_plugin_wheel",
    enable_rocm = True,
    no_abi = False,
    platform_version = ROCM_PLATFORM_VERSION,
    wheel_binary = ":build_gpu_kernels_wheel",
    wheel_name = "jax_rocm%s_plugin" % ROCM_PLATFORM_VERSION,
)

jax_wheel(
    name = "jax_rocm_plugin_wheel_editable",
    editable = True,
    enable_rocm = True,
    platform_version = ROCM_PLATFORM_VERSION,
    wheel_binary = ":build_gpu_kernels_wheel",
    wheel_name = "jax_rocm%s_plugin" % ROCM_PLATFORM_VERSION,
)

# GPU PJRT plugin wheels, built by :build_gpu_plugin_wheel. The non-editable
# ones set no_abi = True, unlike the kernels wheels above.
jax_wheel(
    name = "jax_cuda_pjrt_wheel",
    enable_cuda = True,
    no_abi = True,
    platform_version = CUDA_PLATFORM_VERSION,
    wheel_binary = ":build_gpu_plugin_wheel",
    wheel_name = "jax_cuda%s_pjrt" % CUDA_PLATFORM_VERSION,
)

jax_wheel(
    name = "jax_cuda_pjrt_wheel_editable",
    editable = True,
    enable_cuda = True,
    platform_version = CUDA_PLATFORM_VERSION,
    wheel_binary = ":build_gpu_plugin_wheel",
    wheel_name = "jax_cuda%s_pjrt" % CUDA_PLATFORM_VERSION,
)

jax_wheel(
    name = "jax_rocm_pjrt_wheel",
    enable_rocm = True,
    no_abi = True,
    platform_version = ROCM_PLATFORM_VERSION,
    wheel_binary = ":build_gpu_plugin_wheel",
    wheel_name = "jax_rocm%s_pjrt" % ROCM_PLATFORM_VERSION,
)

jax_wheel(
    name = "jax_rocm_pjrt_wheel_editable",
    editable = True,
    enable_rocm = True,
    platform_version = ROCM_PLATFORM_VERSION,
    wheel_binary = ":build_gpu_plugin_wheel",
    wheel_name = "jax_rocm%s_pjrt" % ROCM_PLATFORM_VERSION,
)
# Expected manylinux tag for each Linux architecture. Each tag is the
# PLATFORM_TAGS_DICT entry for ("Linux", <arch>), with its parts joined by "_".
X86_64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "x86_64")])

AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")])

PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")])

# One manual-only manylinux compliance test per wheel. For each prefix P this
# declares the target "P_manylinux_compliance_test", which checks ":P_wheel"
# against the tags above.
[
    verify_manylinux_compliance_test(
        name = wheel_prefix + "_manylinux_compliance_test",
        aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
        ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
        test_tags = ["manual"],
        wheel = ":" + wheel_prefix + "_wheel",
        x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
    )
    for wheel_prefix in [
        "jaxlib",
        "jax_cuda_plugin",
        "jax_cuda_pjrt",
    ]
]