Refactor JAX wheel build rules to control the wheel filename and maintain reproducible wheel content and filename results.

This change is a part of the initiative to test the JAX wheels in the presubmit properly.

The list of the changes:
1. JAX wheel build rule verifies that `--@local_config_cuda//cuda:include_cuda_libs=false` during the wheel build. There is a way to pass the restriction by providing `--@local_config_cuda//cuda:override_include_cuda_libs=true`.

2. The JAX version number (which is also used in the wheel filenames) is stored in `_version` variable in the file [version.py](https://github.com/jax-ml/jax/blob/main/jax/version.py). The custom repository rule `jax_python_wheel_version_repository` saves this value in `wheel_version.bzl`, so it becomes available in Bazel build phase.

3. The version suffix of the wheel in the build rule output depends on the environment variables.

   The version suffix chunks that are not reproducible shouldn’t be calculated as a part of the wheel binary: for example, the current date changes every day, thus the wheels built today and tomorrow on the same code version will be technically different. To maintain reproducible wheel content, we need to pass suffix chunks in a form of environment variables.

4. Environment variables combinations for creating wheels with different versions:
  * `0.5.1.dev0+selfbuilt` (local build, default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * `0.5.1` (release): `--repo_env=ML_WHEEL_TYPE=release`
  * `0.5.1rc1` (release candidate): `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=rc1`
  * `0.5.1.dev20250128+3e75e20c7` (nightly build): `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=20250128 --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`

PiperOrigin-RevId: 723552265
This commit is contained in:
jax authors 2025-02-05 10:00:49 -08:00
parent 9f53dfae0b
commit d424f5b5b3
11 changed files with 265 additions and 60 deletions

View File

@ -62,6 +62,21 @@ xla_workspace0()
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
flatbuffers()
load("//jaxlib:jax_python_wheel.bzl", "jax_python_wheel_repository")
jax_python_wheel_repository(
name = "jax_wheel",
version_key = "_version",
version_source = "//jax:version.py",
)
load(
"@tsl//third_party/py:python_wheel.bzl",
"python_wheel_version_suffix_repository",
)
python_wheel_version_suffix_repository(
name = "jax_wheel_version_suffix",
)
load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"cuda_json_init_repository",

View File

@ -16,7 +16,6 @@ load("@rules_python//python:defs.bzl", "py_library")
load(
"//jaxlib:jax.bzl",
"py_deps",
"pytype_strict_library",
)
licenses(["notice"])
@ -46,8 +45,3 @@ py_library(
"//jax/experimental/jax2tf",
] + py_deps("tensorflow_core"),
)
pytype_strict_library(
name = "build_utils",
srcs = ["build_utils.py"],
)

View File

@ -35,6 +35,8 @@ def _get_version_string() -> str:
# In this case we return it directly.
if _release_version is not None:
return _release_version
if os.getenv("WHEEL_VERSION_SUFFIX"):
return _version + os.getenv("WHEEL_VERSION_SUFFIX", "")
return _version_from_git_tree(_version) or _version_from_todays_date(_version)
@ -71,16 +73,23 @@ def _get_version_for_build() -> str:
"""Determine the version at build time.
The returned version string depends on which environment variables are set:
- if WHEEL_VERSION_SUFFIX is set: version looks like "0.5.1.dev20230906+ge58560fdc"
Here the WHEEL_VERSION_SUFFIX value is ".dev20230906+ge58560fdc".
Please note that the WHEEL_VERSION_SUFFIX value is not the same as the
JAX_CUSTOM_VERSION_SUFFIX value, and WHEEL_VERSION_SUFFIX is set by Bazel
wheel build rule.
- if JAX_RELEASE or JAXLIB_RELEASE are set: version looks like "0.4.16"
- if JAX_NIGHTLY or JAXLIB_NIGHTLY are set: version looks like "0.4.16.dev20230906"
- if none are set: version looks like "0.4.16.dev20230906+ge58560fdc
"""
if _release_version is not None:
return _release_version
if os.environ.get('JAX_NIGHTLY') or os.environ.get('JAXLIB_NIGHTLY'):
return _version_from_todays_date(_version)
if os.environ.get('JAX_RELEASE') or os.environ.get('JAXLIB_RELEASE'):
if os.getenv("WHEEL_VERSION_SUFFIX"):
return _version + os.getenv("WHEEL_VERSION_SUFFIX", "")
if os.getenv("JAX_RELEASE") or os.getenv("JAXLIB_RELEASE"):
return _version
if os.getenv("JAX_NIGHTLY") or os.getenv("JAXLIB_NIGHTLY"):
return _version_from_todays_date(_version)
return _version_from_git_tree(_version) or _version_from_todays_date(_version)

View File

@ -14,7 +14,10 @@
"""Bazel macros used by the JAX build."""
load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
load("@com_github_google_flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library")
load("@jax_wheel//:wheel.bzl", "WHEEL_VERSION")
load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "BUILD_TAG", "WHEEL_VERSION_SUFFIX")
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")
@ -50,6 +53,15 @@ jax_internal_test_harnesses_visibility = []
jax_test_util_visibility = []
loops_visibility = []
PLATFORM_TAGS_DICT = {
("Linux", "x86_64"): ("manylinux2014", "x86_64"),
("Linux", "aarch64"): ("manylinux2014", "aarch64"),
("Linux", "ppc64le"): ("manylinux2014", "ppc64le"),
("Darwin", "x86_64"): ("macosx_10_14", "x86_64"),
("Darwin", "arm64"): ("macosx_11_0", "arm64"),
("Windows", "AMD64"): ("win", "amd64"),
}
# TODO(vam): remove this once zstandard builds against Python 3.13
def get_zstandard():
if HERMETIC_PYTHON_VERSION == "3.13":
@ -268,7 +280,7 @@ def jax_multiplatform_test(
]
test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, [])
if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]):
test_tags += ["manual"]
test_tags.append("manual")
if backend == "gpu":
test_tags += tf_cuda_tests_tags()
native.py_test(
@ -309,15 +321,60 @@ def jax_generate_backend_suites(backends = []):
tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"],
)
def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_version):
if no_abi:
wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl"
else:
wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl"
python_version = HERMETIC_PYTHON_VERSION.replace(".", "")
return wheel_name_template.format(
package_name = package_name,
python_version = python_version,
major_python_version = python_version[0],
wheel_version = wheel_version,
wheel_platform_tag = "_".join(PLATFORM_TAGS_DICT[platform_name, cpu_name]),
)
def _jax_wheel_impl(ctx):
include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value
override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value
output_path = ctx.attr.output_path[BuildSettingInfo].value
git_hash = ctx.attr.git_hash[BuildSettingInfo].value
executable = ctx.executable.wheel_binary
output = ctx.actions.declare_directory(ctx.label.name)
if include_cuda_libs and not override_include_cuda_libs:
fail("JAX wheel shouldn't be built directly against the CUDA libraries." +
" Please provide `--config=cuda_libraries_from_stubs` for bazel build command." +
" If you absolutely need to build links directly against the CUDA libraries, provide" +
" `--@local_config_cuda//cuda:override_include_cuda_libs=true`.")
env = {}
args = ctx.actions.args()
args.add("--output_path", output.path) # required argument
args.add("--cpu", ctx.attr.platform_tag) # required argument
jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path
args.add("--jaxlib_git_hash", jaxlib_git_hash) # required argument
full_wheel_version = (WHEEL_VERSION + WHEEL_VERSION_SUFFIX)
env["WHEEL_VERSION_SUFFIX"] = WHEEL_VERSION_SUFFIX
if BUILD_TAG:
env["WHEEL_VERSION_SUFFIX"] = ".dev{}+selfbuilt".format(BUILD_TAG)
full_wheel_version += env["WHEEL_VERSION_SUFFIX"]
if not WHEEL_VERSION_SUFFIX and not BUILD_TAG:
env["JAX_RELEASE"] = "1"
cpu = ctx.attr.cpu
platform_name = ctx.attr.platform_name
wheel_name = _get_full_wheel_name(
package_name = ctx.attr.wheel_name,
no_abi = ctx.attr.no_abi,
platform_name = platform_name,
cpu_name = cpu,
wheel_version = full_wheel_version,
)
output_file = ctx.actions.declare_file(output_path +
"/" + wheel_name)
wheel_dir = output_file.path[:output_file.path.rfind("/")]
args.add("--output_path", wheel_dir) # required argument
args.add("--cpu", cpu) # required argument
args.add("--jaxlib_git_hash", git_hash) # required argument
if ctx.attr.enable_cuda:
args.add("--enable-cuda", "True")
@ -336,11 +393,13 @@ def _jax_wheel_impl(ctx):
args.use_param_file("@%s", use_always = False)
ctx.actions.run(
arguments = [args],
inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [],
outputs = [output],
inputs = [],
outputs = [output_file],
executable = executable,
env = env,
)
return [DefaultInfo(files = depset(direct = [output]))]
return [DefaultInfo(files = depset(direct = [output_file]))]
_jax_wheel = rule(
attrs = {
@ -350,19 +409,25 @@ _jax_wheel = rule(
# b/365588895 Investigate cfg = "exec" for multi platform builds
cfg = "target",
),
"platform_tag": attr.string(mandatory = True),
"git_hash": attr.label(allow_single_file = True),
"wheel_name": attr.string(mandatory = True),
"no_abi": attr.bool(default = False),
"cpu": attr.string(mandatory = True),
"platform_name": attr.string(mandatory = True),
"git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")),
"output_path": attr.label(default = Label("//jaxlib/tools:output_path")),
"enable_cuda": attr.bool(default = False),
# A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string.
"platform_version": attr.string(mandatory = True, default = ""),
"skip_gpu_kernels": attr.bool(default = False),
"enable_rocm": attr.bool(default = False),
"include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")),
"override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")),
},
implementation = _jax_wheel_impl,
executable = False,
)
def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = False, platform_version = ""):
"""Create jax artifact wheels.
Common artifact attributes are grouped within a single macro.
@ -370,6 +435,8 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
Args:
name: the name of the wheel
wheel_binary: the binary to use to build the wheel
wheel_name: the name of the wheel
no_abi: whether to build a wheel without ABI
enable_cuda: whether to build a cuda wheel
platform_version: the cuda version to use for the wheel
@ -379,18 +446,20 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
_jax_wheel(
name = name,
wheel_binary = wheel_binary,
wheel_name = wheel_name,
no_abi = no_abi,
enable_cuda = enable_cuda,
platform_version = platform_version,
# Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to
# pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to
# the git hash file needs to be created first.
git_hash = select({
"//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink",
"//conditions:default": None,
# git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)`
# flag in bazel command to pass the git hash for nightly or release builds.
platform_name = select({
"@platforms//os:osx": "Darwin",
"@platforms//os:macos": "Darwin",
"@platforms//os:windows": "Windows",
"@platforms//os:linux": "Linux",
}),
# Following the convention in jax/tools/build_utils.py.
# TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0.
platform_tag = select({
cpu = select({
"//jaxlib/tools:macos_arm64": "arm64",
"//jaxlib/tools:win_amd64": "AMD64",
"//jaxlib/tools:arm64": "aarch64",

View File

@ -0,0 +1,43 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Repository rule to generate a file with JAX wheel version. """
def _jax_python_wheel_repository_impl(repository_ctx):
version_source = repository_ctx.attr.version_source
version_key = repository_ctx.attr.version_key
version_file_content = repository_ctx.read(
repository_ctx.path(version_source),
)
version_start_index = version_file_content.find(version_key)
version_end_index = version_start_index + version_file_content[version_start_index:].find("\n")
wheel_version = version_file_content[version_start_index:version_end_index].replace(
version_key,
"WHEEL_VERSION",
)
repository_ctx.file(
"wheel.bzl",
wheel_version,
)
repository_ctx.file("BUILD", "")
jax_python_wheel_repository = repository_rule(
implementation = _jax_python_wheel_repository_impl,
attrs = {
"version_source": attr.label(mandatory = True, allow_single_file = True),
"version_key": attr.string(mandatory = True),
},
)

View File

@ -18,12 +18,38 @@ load("@bazel_skylib//lib:selects.bzl", "selects")
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("//jaxlib:jax.bzl", "if_windows", "jax_py_test", "jax_wheel")
load(
"@tsl//third_party/py:py_manylinux_compliance_test.bzl",
"verify_manylinux_compliance_test",
)
load(
"//jaxlib:jax.bzl",
"PLATFORM_TAGS_DICT",
"if_windows",
"jax_py_test",
"jax_wheel",
"pytype_strict_library",
)
licenses(["notice"]) # Apache 2
package(default_visibility = ["//visibility:public"])
genrule(
name = "platform_tags_py",
srcs = [],
outs = ["platform_tags.py"],
cmd = "echo 'PLATFORM_TAGS_DICT = %s' > $@;" % PLATFORM_TAGS_DICT,
)
pytype_strict_library(
name = "build_utils",
srcs = [
"build_utils.py",
":platform_tags_py",
],
)
py_binary(
name = "build_wheel",
srcs = ["build_wheel.py"],
@ -41,7 +67,7 @@ py_binary(
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
]),
deps = [
"//jax/tools:build_utils",
":build_utils",
"@bazel_tools//tools/python/runfiles",
"@pypi_build//:pkg",
"@pypi_setuptools//:pkg",
@ -99,7 +125,7 @@ py_binary(
"//jax_plugins/rocm:__init__.py",
]),
deps = [
"//jax/tools:build_utils",
":build_utils",
"@bazel_tools//tools/python/runfiles",
"@pypi_build//:pkg",
"@pypi_setuptools//:pkg",
@ -128,7 +154,7 @@ py_binary(
"//jax_plugins/rocm:plugin_setup.py",
]),
deps = [
"//jax/tools:build_utils",
":build_utils",
"@bazel_tools//tools/python/runfiles",
"@pypi_build//:pkg",
"@pypi_setuptools//:pkg",
@ -173,30 +199,73 @@ string_flag(
build_setting_default = "",
)
config_setting(
name = "jaxlib_git_hash_nightly_or_release",
flag_values = {
":jaxlib_git_hash": "nightly",
},
string_flag(
name = "output_path",
build_setting_default = "dist",
)
jax_wheel(
name = "jaxlib_wheel",
no_abi = False,
wheel_binary = ":build_wheel",
wheel_name = "jaxlib",
)
jax_wheel(
name = "jax_cuda_plugin_wheel",
enable_cuda = True,
no_abi = False,
# TODO(b/371217563) May use hermetic cuda version here.
platform_version = "12",
wheel_binary = ":build_gpu_kernels_wheel",
wheel_name = "jax_cuda12_plugin",
)
jax_wheel(
name = "jax_cuda_pjrt_wheel",
enable_cuda = True,
no_abi = True,
# TODO(b/371217563) May use hermetic cuda version here.
platform_version = "12",
wheel_binary = ":build_gpu_plugin_wheel",
wheel_name = "jax_cuda12_pjrt",
)
AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")])
PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")])
X86_64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "x86_64")])
verify_manylinux_compliance_test(
name = "jaxlib_manylinux_compliance_test",
aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
test_tags = [
"manual",
],
wheel = ":jaxlib_wheel",
x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
)
verify_manylinux_compliance_test(
name = "jax_cuda_plugin_manylinux_compliance_test",
aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
test_tags = [
"manual",
],
wheel = ":jax_cuda_plugin_wheel",
x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
)
verify_manylinux_compliance_test(
name = "jax_cuda_pjrt_manylinux_compliance_test",
aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
test_tags = [
"manual",
],
wheel = ":jax_cuda_pjrt_wheel",
x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
)

View File

@ -24,7 +24,7 @@ import pathlib
import tempfile
from bazel_tools.tools.python.runfiles import runfiles
from jax.tools import build_utils
from jaxlib.tools import build_utils
parser = argparse.ArgumentParser()
parser.add_argument(
@ -174,12 +174,11 @@ try:
if args.editable:
build_utils.build_editable(sources_path, args.output_path, package_name)
else:
git_hash = build_utils.get_githash(args.jaxlib_git_hash)
build_utils.build_wheel(
sources_path,
args.output_path,
package_name,
git_hash=git_hash,
git_hash=args.jaxlib_git_hash,
)
finally:
tmpdir.cleanup()

View File

@ -24,7 +24,7 @@ import pathlib
import tempfile
from bazel_tools.tools.python.runfiles import runfiles
from jax.tools import build_utils
from jaxlib.tools import build_utils
parser = argparse.ArgumentParser()
parser.add_argument(
@ -167,12 +167,11 @@ try:
if args.editable:
build_utils.build_editable(sources_path, args.output_path, package_name)
else:
git_hash = build_utils.get_githash(args.jaxlib_git_hash)
build_utils.build_wheel(
sources_path,
args.output_path,
package_name,
git_hash=git_hash,
git_hash=args.jaxlib_git_hash,
)
finally:
if tmpdir:

View File

@ -24,6 +24,7 @@ import sys
import subprocess
import glob
from collections.abc import Sequence
from jaxlib.tools import platform_tags
def is_windows() -> bool:
@ -52,21 +53,11 @@ def copy_file(
def platform_tag(cpu: str) -> str:
platform_name, cpu_name = {
("Linux", "x86_64"): ("manylinux2014", "x86_64"),
("Linux", "aarch64"): ("manylinux2014", "aarch64"),
("Linux", "ppc64le"): ("manylinux2014", "ppc64le"),
("Darwin", "x86_64"): ("macosx_10_14", "x86_64"),
("Darwin", "arm64"): ("macosx_11_0", "arm64"),
("Windows", "AMD64"): ("win", "amd64"),
}[(platform.system(), cpu)]
platform_name, cpu_name = platform_tags.PLATFORM_TAGS_DICT[
(platform.system(), cpu)
]
return f"{platform_name}_{cpu_name}"
def get_githash(jaxlib_git_hash):
if jaxlib_git_hash != "" and os.path.isfile(jaxlib_git_hash):
with open(jaxlib_git_hash, "r") as f:
return f.readline().strip()
return jaxlib_git_hash
def build_wheel(
sources_path: str, output_path: str, package_name: str, git_hash: str = ""

View File

@ -27,7 +27,7 @@ import subprocess
import tempfile
from bazel_tools.tools.python.runfiles import runfiles
from jax.tools import build_utils
from jaxlib.tools import build_utils
parser = argparse.ArgumentParser()
parser.add_argument(
@ -387,8 +387,12 @@ try:
if args.editable:
build_utils.build_editable(sources_path, args.output_path, package_name)
else:
git_hash = build_utils.get_githash(args.jaxlib_git_hash)
build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=git_hash)
build_utils.build_wheel(
sources_path,
args.output_path,
package_name,
git_hash=args.jaxlib_git_hash,
)
finally:
if tmpdir:
tmpdir.cleanup()

View File

@ -104,6 +104,7 @@ class JaxVersionTest(unittest.TestCase):
self.assertEqual(version, "1.2.3.dev4567")
self.assertValidVersion(version)
@jtu.thread_unsafe_test() # Setting environment variables is not thread-safe.
@patch_jax_version("1.2.3", None)
def testBuildVersionFromEnvironment(self):
# This test covers build-time construction of version strings in the
@ -157,6 +158,18 @@ class JaxVersionTest(unittest.TestCase):
self.assertTrue(version.endswith("test"))
self.assertValidVersion(version)
with jtu.set_env(
JAX_RELEASE=None,
JAXLIB_RELEASE=None,
JAX_NIGHTLY=None,
JAXLIB_NIGHTLY="1",
WHEEL_VERSION_SUFFIX=".dev20250101+1c0f1076erc1",
):
with assert_no_subprocess_call():
version = jax.version._get_version_for_build()
self.assertEqual(version, f"{base_version}.dev20250101+1c0f1076erc1")
self.assertValidVersion(version)
def testVersions(self):
check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.2.3",
minimum_jaxlib_version="1.2.3")