# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# JAX is Autograd and XLA

load("@bazel_skylib//lib:selects.bzl", "selects")
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load(
    "@xla//third_party/py:py_manylinux_compliance_test.bzl",
    "verify_manylinux_compliance_test",
)
load(
    "//jaxlib:jax.bzl",
    "PLATFORM_TAGS_DICT",
    "if_windows",
    "jax_py_test",
    "jax_wheel",
    "pytype_strict_library",
)

licenses(["notice"])  # Apache 2

package(default_visibility = ["//visibility:public"])

# Dependencies shared verbatim by the three wheel-builder binaries below
# (build_wheel, build_gpu_plugin_wheel, build_gpu_kernels_wheel).
WHEEL_BUILDER_DEPS = [
    ":build_utils",
    "@bazel_tools//tools/python/runfiles",
    "@pypi_build//:pkg",
    "@pypi_setuptools//:pkg",
    "@pypi_wheel//:pkg",
]

# Writes platform_tags.py: a single assignment of PLATFORM_TAGS_DICT (as
# loaded from //jaxlib:jax.bzl) so the value is importable from Python.
genrule(
    name = "platform_tags_py",
    srcs = [],
    outs = ["platform_tags.py"],
    cmd = "echo 'PLATFORM_TAGS_DICT = %s' > $@;" % PLATFORM_TAGS_DICT,
)

# Python library made of build_utils.py plus the generated platform_tags.py.
pytype_strict_library(
    name = "build_utils",
    srcs = [
        "build_utils.py",
        ":platform_tags_py",
    ],
)

# Wheel builder for jaxlib. The MLIR C API DLL is added to the runfiles only
# under if_windows().
py_binary(
    name = "build_wheel",
    srcs = ["build_wheel.py"],
    data = [
        "LICENSE.txt",
        "//jaxlib",
        "//jaxlib:README.md",
        "//jaxlib:setup.py",
        "@xla//xla/ffi/api:api.h",
        "@xla//xla/ffi/api:c_api.h",
        "@xla//xla/ffi/api:ffi.h",
        "@xla//xla/python:xla_client.py",
        "@xla//xla/python:xla_extension",
    ] + if_windows([
        "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
    ]),
    deps = WHEEL_BUILDER_DEPS,
)

# Test for the jaxlib wheel builder; the :build_wheel binary is supplied as
# runtime data.
jax_py_test(
    name = "build_wheel_test",
    srcs = ["build_wheel_test.py"],
    data = [":build_wheel"],
    deps = [
        "@bazel_tools//tools/python/runfiles",
    ],
)

# PJRT C API GPU plugin, linked as a shared object with the
# gpu_version_script.lds version script and with undefined symbols rejected.
# CUDA- and ROCm-specific deps are appended through if_cuda() / if_rocm().
cc_binary(
    name = "pjrt_c_api_gpu_plugin.so",
    linkopts = [
        "-Wl,--version-script,$(location :gpu_version_script.lds)",
        "-Wl,--no-undefined",
    ],
    linkshared = True,
    deps = [
        ":gpu_version_script.lds",
        "@xla//xla/pjrt/c:pjrt_c_api_gpu",
        "@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds",
        "@xla//xla/service:gpu_plugin",
    ] + if_cuda([
        "//jaxlib/mosaic/gpu:custom_call",
        "@xla//xla/stream_executor:cuda_platform",
    ]) + if_rocm([
        "@xla//xla/stream_executor:rocm_platform",
    ]),
)

# Wheel builder that ships the PJRT plugin shared object above, together with
# the jax_plugins/{cuda,rocm} packaging files selected by if_cuda()/if_rocm().
py_binary(
    name = "build_gpu_plugin_wheel",
    srcs = ["build_gpu_plugin_wheel.py"],
    data = [
        "LICENSE.txt",
        ":pjrt_c_api_gpu_plugin.so",
    ] + if_cuda([
        "//jaxlib:version",
        "//jaxlib/cuda:cuda_gpu_support",
        "//jax_plugins/cuda:pyproject.toml",
        "//jax_plugins/cuda:setup.py",
        "//jax_plugins/cuda:__init__.py",
        "@local_config_cuda//cuda:cuda-nvvm",
    ]) + if_rocm([
        "//jaxlib:version",
        "//jaxlib/rocm:rocm_gpu_support",
        "//jax_plugins/rocm:pyproject.toml",
        "//jax_plugins/rocm:setup.py",
        "//jax_plugins/rocm:__init__.py",
    ]),
    deps = WHEEL_BUILDER_DEPS,
)

# Wheel builder for the GPU kernels/plugin-extension wheel; its data is the
# CUDA or ROCm support targets plus the matching plugin_* packaging files.
py_binary(
    name = "build_gpu_kernels_wheel",
    srcs = ["build_gpu_kernels_wheel.py"],
    data = [
        "LICENSE.txt",
    ] + if_cuda([
        "//jaxlib/mosaic/gpu:mosaic_gpu",
        "//jaxlib:cuda_plugin_extension",
        "//jaxlib:version",
        "//jaxlib/cuda:cuda_gpu_support",
        "//jax_plugins/cuda:plugin_pyproject.toml",
        "//jax_plugins/cuda:plugin_setup.py",
        "@local_config_cuda//cuda:cuda-nvvm",
    ]) + if_rocm([
        "//jaxlib:rocm_plugin_extension",
        "//jaxlib:version",
        "//jaxlib/rocm:rocm_gpu_support",
        "//jax_plugins/rocm:plugin_pyproject.toml",
        "//jax_plugins/rocm:plugin_setup.py",
    ]),
    deps = WHEEL_BUILDER_DEPS,
)

# Matches either spelling of the macOS OS constraint.
selects.config_setting_group(
    name = "macos",
    match_any = [
        "@platforms//os:osx",
        "@platforms//os:macos",
    ],
)
# Matches either spelling of the 64-bit ARM CPU constraint.
selects.config_setting_group(
    name = "arm64",
    match_any = [
        "@platforms//cpu:aarch64",
        "@platforms//cpu:arm64",
    ],
)

# Matches only when both :arm64 and :macos match.
selects.config_setting_group(
    name = "macos_arm64",
    match_all = [
        ":arm64",
        ":macos",
    ],
)

# Matches only when the CPU is x86_64 and the OS is Windows.
selects.config_setting_group(
    name = "win_amd64",
    match_all = [
        "@platforms//cpu:x86_64",
        "@platforms//os:windows",
    ],
)

# String build setting; defaults to the empty string.
string_flag(
    name = "jaxlib_git_hash",
    build_setting_default = "",
)

# String build setting; defaults to "dist".
string_flag(
    name = "output_path",
    build_setting_default = "dist",
)

# Platform versions passed to the GPU wheel targets below.
# TODO(b/371217563) May use hermetic cuda version here.
CUDA_PLATFORM_VERSION = "12"

ROCM_PLATFORM_VERSION = "60"

# --- jaxlib wheels -----------------------------------------------------------

jax_wheel(
    name = "jaxlib_wheel",
    no_abi = False,
    wheel_binary = ":build_wheel",
    wheel_name = "jaxlib",
)

jax_wheel(
    name = "jaxlib_wheel_editable",
    editable = True,
    wheel_binary = ":build_wheel",
    wheel_name = "jaxlib",
)

# --- GPU kernels ("plugin") wheels, built by :build_gpu_kernels_wheel --------

jax_wheel(
    name = "jax_cuda_plugin_wheel",
    enable_cuda = True,
    no_abi = False,
    platform_version = CUDA_PLATFORM_VERSION,
    wheel_binary = ":build_gpu_kernels_wheel",
    wheel_name = "jax_cuda12_plugin",
)

jax_wheel(
    name = "jax_cuda_plugin_wheel_editable",
    editable = True,
    enable_cuda = True,
    platform_version = CUDA_PLATFORM_VERSION,
    wheel_binary = ":build_gpu_kernels_wheel",
    wheel_name = "jax_cuda12_plugin",
)

jax_wheel(
    name = "jax_rocm_plugin_wheel",
    enable_rocm = True,
    no_abi = False,
    platform_version = ROCM_PLATFORM_VERSION,
    wheel_binary = ":build_gpu_kernels_wheel",
    wheel_name = "jax_rocm60_plugin",
)

jax_wheel(
    name = "jax_rocm_plugin_wheel_editable",
    editable = True,
    enable_rocm = True,
    platform_version = ROCM_PLATFORM_VERSION,
    wheel_binary = ":build_gpu_kernels_wheel",
    wheel_name = "jax_rocm60_plugin",
)

# --- PJRT plugin wheels, built by :build_gpu_plugin_wheel --------------------

jax_wheel(
    name = "jax_cuda_pjrt_wheel",
    enable_cuda = True,
    no_abi = True,
    platform_version = CUDA_PLATFORM_VERSION,
    wheel_binary = ":build_gpu_plugin_wheel",
    wheel_name = "jax_cuda12_pjrt",
)

jax_wheel(
    name = "jax_cuda_pjrt_wheel_editable",
    editable = True,
    enable_cuda = True,
    platform_version = CUDA_PLATFORM_VERSION,
    wheel_binary = ":build_gpu_plugin_wheel",
    wheel_name = "jax_cuda12_pjrt",
)

jax_wheel(
    name = "jax_rocm_pjrt_wheel",
    enable_rocm = True,
    no_abi = True,
    platform_version = ROCM_PLATFORM_VERSION,
    wheel_binary = ":build_gpu_plugin_wheel",
    wheel_name = "jax_rocm60_pjrt",
)

jax_wheel(
    name = "jax_rocm_pjrt_wheel_editable",
    editable = True,
    enable_rocm = True,
    platform_version = ROCM_PLATFORM_VERSION,
    wheel_binary = ":build_gpu_plugin_wheel",
    wheel_name = "jax_rocm60_pjrt",
)

# --- manylinux compliance ----------------------------------------------------

# Per-architecture compliance tags: the PLATFORM_TAGS_DICT entry for each
# ("Linux", <arch>) key, with its elements joined by "_".
AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")])

PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")])

X86_64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "x86_64")])

# One compliance test per (test name, wheel target) pair. Every test receives
# the same three architecture tags and the "manual" test tag.
[
    verify_manylinux_compliance_test(
        name = test_name,
        aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
        ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
        test_tags = [
            "manual",
        ],
        wheel = wheel_target,
        x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
    )
    for test_name, wheel_target in [
        ("jaxlib_manylinux_compliance_test", ":jaxlib_wheel"),
        ("jax_cuda_plugin_manylinux_compliance_test", ":jax_cuda_plugin_wheel"),
        ("jax_cuda_pjrt_manylinux_compliance_test", ":jax_cuda_pjrt_wheel"),
    ]
]