From 401d315091877806f77d2ef833919cbe0a5b482f Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Feb 2025 08:33:06 -0800 Subject: [PATCH] Add targets for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` editable wheels. PiperOrigin-RevId: 731737119 --- jaxlib/jax.bzl | 45 +++++++++++++++++++++++++--------------- jaxlib/tools/BUILD.bazel | 27 ++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 17 deletions(-) diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index 14f92058d..ffeda65b3 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -374,27 +374,34 @@ def _jax_wheel_impl(ctx): no_abi = ctx.attr.no_abi platform_independent = ctx.attr.platform_independent build_wheel_only = ctx.attr.build_wheel_only + editable = ctx.attr.editable platform_name = ctx.attr.platform_name - wheel_name = _get_full_wheel_name( - package_name = ctx.attr.wheel_name, - no_abi = no_abi, - platform_independent = platform_independent, - platform_name = platform_name, - cpu_name = cpu, - wheel_version = full_wheel_version, - ) - wheel_file = ctx.actions.declare_file(output_path + - "/" + wheel_name) - wheel_dir = wheel_file.path[:wheel_file.path.rfind("/")] - outputs = [wheel_file] - if not build_wheel_only: - source_distribution_name = _get_source_distribution_name( + if editable: + output_dir = ctx.actions.declare_directory(output_path + "/" + ctx.attr.wheel_name) + wheel_dir = output_dir.path + outputs = [output_dir] + args.add("--editable") + else: + wheel_name = _get_full_wheel_name( package_name = ctx.attr.wheel_name, + no_abi = no_abi, + platform_independent = platform_independent, + platform_name = platform_name, + cpu_name = cpu, wheel_version = full_wheel_version, ) - source_distribution_file = ctx.actions.declare_file(output_path + - "/" + source_distribution_name) - outputs.append(source_distribution_file) + wheel_file = ctx.actions.declare_file(output_path + + "/" + wheel_name) + wheel_dir = wheel_file.path[:wheel_file.path.rfind("/")] + outputs = [wheel_file] + if not build_wheel_only: + source_distribution_name = _get_source_distribution_name( + package_name = ctx.attr.wheel_name, + wheel_version = full_wheel_version, + ) + source_distribution_file = ctx.actions.declare_file(output_path + + "/" + source_distribution_name) + outputs.append(source_distribution_file) args.add("--output_path", wheel_dir) # required argument if not platform_independent: @@ -445,6 +452,7 @@ _jax_wheel = rule( "no_abi": attr.bool(default = False), "platform_independent": attr.bool(default = False), "build_wheel_only": attr.bool(default = True), + "editable": attr.bool(default = False), "cpu": attr.string(mandatory = True), "platform_name": attr.string(mandatory = True), "git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")), @@ -469,6 +477,7 @@ def jax_wheel( no_abi = False, platform_independent = False, build_wheel_only = True, + editable = False, enable_cuda = False, platform_version = "", source_files = []): @@ -482,6 +491,7 @@ def jax_wheel( wheel_name: the name of the wheel no_abi: whether to build a wheel without ABI build_wheel_only: whether to build a wheel without source distribution + editable: whether to build an editable wheel platform_independent: whether to build a wheel without platform tag enable_cuda: whether to build a cuda wheel platform_version: the cuda version to use for the wheel @@ -497,6 +507,7 @@ def jax_wheel( no_abi = no_abi, platform_independent = platform_independent, build_wheel_only = build_wheel_only, + editable = editable, enable_cuda = enable_cuda, platform_version = platform_version, # git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)` diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 318846381..de9f636ed 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -211,6 +211,13 @@ jax_wheel( wheel_name = "jaxlib", ) +jax_wheel( + name = "jaxlib_wheel_editable", + editable = True, + wheel_binary = ":build_wheel", + wheel_name = "jaxlib", +) + jax_wheel( name = "jax_cuda_plugin_wheel", enable_cuda = True, @@ -221,6 +228,16 @@ jax_wheel( wheel_name = "jax_cuda12_plugin", ) +jax_wheel( + name = "jax_cuda_plugin_wheel_editable", + editable = True, + enable_cuda = True, + # TODO(b/371217563) May use hermetic cuda version here. + platform_version = "12", + wheel_binary = ":build_gpu_kernels_wheel", + wheel_name = "jax_cuda12_plugin", +) + jax_wheel( name = "jax_cuda_pjrt_wheel", enable_cuda = True, @@ -231,6 +248,16 @@ jax_wheel( wheel_name = "jax_cuda12_pjrt", ) +jax_wheel( + name = "jax_cuda_pjrt_wheel_editable", + editable = True, + enable_cuda = True, + # TODO(b/371217563) May use hermetic cuda version here. + platform_version = "12", + wheel_binary = ":build_gpu_plugin_wheel", + wheel_name = "jax_cuda12_pjrt", +) + AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")]) PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")])