mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Refactor JAX wheel build rules to control the wheel filename and maintain reproducible wheel content and filename results.
This change is a part of the initiative to test the JAX wheels in the presubmit properly. The list of the changes: 1. JAX wheel build rule verifies that `--@local_config_cuda//cuda:include_cuda_libs=false` during the wheel build. There is a way to pass the restriction by providing `--@local_config_cuda//cuda:override_include_cuda_libs=true`. 2. The JAX version number (which is also used in the wheel filenames) is stored in `_version` variable in the file [version.py](https://github.com/jax-ml/jax/blob/main/jax/version.py). The custom repository rule `jax_python_wheel_version_repository` saves this value in `wheel_version.bzl`, so it becomes available in Bazel build phase. 3. The version suffix of the wheel in the build rule output depends on the environment variables. The version suffix chunks that are not reproducible shouldn’t be calculated as a part of the wheel binary: for example, the current date changes every day, thus the wheels built today and tomorrow on the same code version will be technically different. To maintain reproducible wheel content, we need to pass suffix chunks in a form of environment variables. 4. Environment variables combinations for creating wheels with different versions: * `0.5.1.dev0+selfbuilt` (local build, default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot` * `0.5.1` (release): `--repo_env=ML_WHEEL_TYPE=release` * `0.5.1rc1` (release candidate): `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=rc1` * `0.5.1.dev20250128+3e75e20c7` (nightly build): `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=20250128 --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)` PiperOrigin-RevId: 723552265
This commit is contained in:
parent
9f53dfae0b
commit
d424f5b5b3
15
WORKSPACE
15
WORKSPACE
@ -62,6 +62,21 @@ xla_workspace0()
|
||||
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
|
||||
flatbuffers()
|
||||
|
||||
load("//jaxlib:jax_python_wheel.bzl", "jax_python_wheel_repository")
|
||||
jax_python_wheel_repository(
|
||||
name = "jax_wheel",
|
||||
version_key = "_version",
|
||||
version_source = "//jax:version.py",
|
||||
)
|
||||
|
||||
load(
|
||||
"@tsl//third_party/py:python_wheel.bzl",
|
||||
"python_wheel_version_suffix_repository",
|
||||
)
|
||||
python_wheel_version_suffix_repository(
|
||||
name = "jax_wheel_version_suffix",
|
||||
)
|
||||
|
||||
load(
|
||||
"@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
|
||||
"cuda_json_init_repository",
|
||||
|
@ -16,7 +16,6 @@ load("@rules_python//python:defs.bzl", "py_library")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"py_deps",
|
||||
"pytype_strict_library",
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
@ -46,8 +45,3 @@ py_library(
|
||||
"//jax/experimental/jax2tf",
|
||||
] + py_deps("tensorflow_core"),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "build_utils",
|
||||
srcs = ["build_utils.py"],
|
||||
)
|
||||
|
@ -35,6 +35,8 @@ def _get_version_string() -> str:
|
||||
# In this case we return it directly.
|
||||
if _release_version is not None:
|
||||
return _release_version
|
||||
if os.getenv("WHEEL_VERSION_SUFFIX"):
|
||||
return _version + os.getenv("WHEEL_VERSION_SUFFIX", "")
|
||||
return _version_from_git_tree(_version) or _version_from_todays_date(_version)
|
||||
|
||||
|
||||
@ -71,16 +73,23 @@ def _get_version_for_build() -> str:
|
||||
"""Determine the version at build time.
|
||||
|
||||
The returned version string depends on which environment variables are set:
|
||||
- if WHEEL_VERSION_SUFFIX is set: version looks like "0.5.1.dev20230906+ge58560fdc"
|
||||
Here the WHEEL_VERSION_SUFFIX value is ".dev20230906+ge58560fdc".
|
||||
Please note that the WHEEL_VERSION_SUFFIX value is not the same as the
|
||||
JAX_CUSTOM_VERSION_SUFFIX value, and WHEEL_VERSION_SUFFIX is set by Bazel
|
||||
wheel build rule.
|
||||
- if JAX_RELEASE or JAXLIB_RELEASE are set: version looks like "0.4.16"
|
||||
- if JAX_NIGHTLY or JAXLIB_NIGHTLY are set: version looks like "0.4.16.dev20230906"
|
||||
- if none are set: version looks like "0.4.16.dev20230906+ge58560fdc
|
||||
"""
|
||||
if _release_version is not None:
|
||||
return _release_version
|
||||
if os.environ.get('JAX_NIGHTLY') or os.environ.get('JAXLIB_NIGHTLY'):
|
||||
return _version_from_todays_date(_version)
|
||||
if os.environ.get('JAX_RELEASE') or os.environ.get('JAXLIB_RELEASE'):
|
||||
if os.getenv("WHEEL_VERSION_SUFFIX"):
|
||||
return _version + os.getenv("WHEEL_VERSION_SUFFIX", "")
|
||||
if os.getenv("JAX_RELEASE") or os.getenv("JAXLIB_RELEASE"):
|
||||
return _version
|
||||
if os.getenv("JAX_NIGHTLY") or os.getenv("JAXLIB_NIGHTLY"):
|
||||
return _version_from_todays_date(_version)
|
||||
return _version_from_git_tree(_version) or _version_from_todays_date(_version)
|
||||
|
||||
|
||||
|
109
jaxlib/jax.bzl
109
jaxlib/jax.bzl
@ -14,7 +14,10 @@
|
||||
|
||||
"""Bazel macros used by the JAX build."""
|
||||
|
||||
load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
|
||||
load("@com_github_google_flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library")
|
||||
load("@jax_wheel//:wheel.bzl", "WHEEL_VERSION")
|
||||
load("@jax_wheel_version_suffix//:wheel_version_suffix.bzl", "BUILD_TAG", "WHEEL_VERSION_SUFFIX")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
|
||||
load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")
|
||||
@ -50,6 +53,15 @@ jax_internal_test_harnesses_visibility = []
|
||||
jax_test_util_visibility = []
|
||||
loops_visibility = []
|
||||
|
||||
PLATFORM_TAGS_DICT = {
|
||||
("Linux", "x86_64"): ("manylinux2014", "x86_64"),
|
||||
("Linux", "aarch64"): ("manylinux2014", "aarch64"),
|
||||
("Linux", "ppc64le"): ("manylinux2014", "ppc64le"),
|
||||
("Darwin", "x86_64"): ("macosx_10_14", "x86_64"),
|
||||
("Darwin", "arm64"): ("macosx_11_0", "arm64"),
|
||||
("Windows", "AMD64"): ("win", "amd64"),
|
||||
}
|
||||
|
||||
# TODO(vam): remove this once zstandard builds against Python 3.13
|
||||
def get_zstandard():
|
||||
if HERMETIC_PYTHON_VERSION == "3.13":
|
||||
@ -268,7 +280,7 @@ def jax_multiplatform_test(
|
||||
]
|
||||
test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, [])
|
||||
if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]):
|
||||
test_tags += ["manual"]
|
||||
test_tags.append("manual")
|
||||
if backend == "gpu":
|
||||
test_tags += tf_cuda_tests_tags()
|
||||
native.py_test(
|
||||
@ -309,15 +321,60 @@ def jax_generate_backend_suites(backends = []):
|
||||
tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"],
|
||||
)
|
||||
|
||||
def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_version):
|
||||
if no_abi:
|
||||
wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl"
|
||||
else:
|
||||
wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl"
|
||||
python_version = HERMETIC_PYTHON_VERSION.replace(".", "")
|
||||
return wheel_name_template.format(
|
||||
package_name = package_name,
|
||||
python_version = python_version,
|
||||
major_python_version = python_version[0],
|
||||
wheel_version = wheel_version,
|
||||
wheel_platform_tag = "_".join(PLATFORM_TAGS_DICT[platform_name, cpu_name]),
|
||||
)
|
||||
|
||||
def _jax_wheel_impl(ctx):
|
||||
include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value
|
||||
override_include_cuda_libs = ctx.attr.override_include_cuda_libs[BuildSettingInfo].value
|
||||
output_path = ctx.attr.output_path[BuildSettingInfo].value
|
||||
git_hash = ctx.attr.git_hash[BuildSettingInfo].value
|
||||
executable = ctx.executable.wheel_binary
|
||||
|
||||
output = ctx.actions.declare_directory(ctx.label.name)
|
||||
if include_cuda_libs and not override_include_cuda_libs:
|
||||
fail("JAX wheel shouldn't be built directly against the CUDA libraries." +
|
||||
" Please provide `--config=cuda_libraries_from_stubs` for bazel build command." +
|
||||
" If you absolutely need to build links directly against the CUDA libraries, provide" +
|
||||
" `--@local_config_cuda//cuda:override_include_cuda_libs=true`.")
|
||||
|
||||
env = {}
|
||||
args = ctx.actions.args()
|
||||
args.add("--output_path", output.path) # required argument
|
||||
args.add("--cpu", ctx.attr.platform_tag) # required argument
|
||||
jaxlib_git_hash = "" if ctx.file.git_hash == None else ctx.file.git_hash.path
|
||||
args.add("--jaxlib_git_hash", jaxlib_git_hash) # required argument
|
||||
|
||||
full_wheel_version = (WHEEL_VERSION + WHEEL_VERSION_SUFFIX)
|
||||
env["WHEEL_VERSION_SUFFIX"] = WHEEL_VERSION_SUFFIX
|
||||
if BUILD_TAG:
|
||||
env["WHEEL_VERSION_SUFFIX"] = ".dev{}+selfbuilt".format(BUILD_TAG)
|
||||
full_wheel_version += env["WHEEL_VERSION_SUFFIX"]
|
||||
if not WHEEL_VERSION_SUFFIX and not BUILD_TAG:
|
||||
env["JAX_RELEASE"] = "1"
|
||||
|
||||
cpu = ctx.attr.cpu
|
||||
platform_name = ctx.attr.platform_name
|
||||
wheel_name = _get_full_wheel_name(
|
||||
package_name = ctx.attr.wheel_name,
|
||||
no_abi = ctx.attr.no_abi,
|
||||
platform_name = platform_name,
|
||||
cpu_name = cpu,
|
||||
wheel_version = full_wheel_version,
|
||||
)
|
||||
output_file = ctx.actions.declare_file(output_path +
|
||||
"/" + wheel_name)
|
||||
wheel_dir = output_file.path[:output_file.path.rfind("/")]
|
||||
|
||||
args.add("--output_path", wheel_dir) # required argument
|
||||
args.add("--cpu", cpu) # required argument
|
||||
args.add("--jaxlib_git_hash", git_hash) # required argument
|
||||
|
||||
if ctx.attr.enable_cuda:
|
||||
args.add("--enable-cuda", "True")
|
||||
@ -336,11 +393,13 @@ def _jax_wheel_impl(ctx):
|
||||
args.use_param_file("@%s", use_always = False)
|
||||
ctx.actions.run(
|
||||
arguments = [args],
|
||||
inputs = [ctx.file.git_hash] if ctx.file.git_hash != None else [],
|
||||
outputs = [output],
|
||||
inputs = [],
|
||||
outputs = [output_file],
|
||||
executable = executable,
|
||||
env = env,
|
||||
)
|
||||
return [DefaultInfo(files = depset(direct = [output]))]
|
||||
|
||||
return [DefaultInfo(files = depset(direct = [output_file]))]
|
||||
|
||||
_jax_wheel = rule(
|
||||
attrs = {
|
||||
@ -350,19 +409,25 @@ _jax_wheel = rule(
|
||||
# b/365588895 Investigate cfg = "exec" for multi platform builds
|
||||
cfg = "target",
|
||||
),
|
||||
"platform_tag": attr.string(mandatory = True),
|
||||
"git_hash": attr.label(allow_single_file = True),
|
||||
"wheel_name": attr.string(mandatory = True),
|
||||
"no_abi": attr.bool(default = False),
|
||||
"cpu": attr.string(mandatory = True),
|
||||
"platform_name": attr.string(mandatory = True),
|
||||
"git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")),
|
||||
"output_path": attr.label(default = Label("//jaxlib/tools:output_path")),
|
||||
"enable_cuda": attr.bool(default = False),
|
||||
# A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string.
|
||||
"platform_version": attr.string(mandatory = True, default = ""),
|
||||
"skip_gpu_kernels": attr.bool(default = False),
|
||||
"enable_rocm": attr.bool(default = False),
|
||||
"include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")),
|
||||
"override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")),
|
||||
},
|
||||
implementation = _jax_wheel_impl,
|
||||
executable = False,
|
||||
)
|
||||
|
||||
def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
|
||||
def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = False, platform_version = ""):
|
||||
"""Create jax artifact wheels.
|
||||
|
||||
Common artifact attributes are grouped within a single macro.
|
||||
@ -370,6 +435,8 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
|
||||
Args:
|
||||
name: the name of the wheel
|
||||
wheel_binary: the binary to use to build the wheel
|
||||
wheel_name: the name of the wheel
|
||||
no_abi: whether to build a wheel without ABI
|
||||
enable_cuda: whether to build a cuda wheel
|
||||
platform_version: the cuda version to use for the wheel
|
||||
|
||||
@ -379,18 +446,20 @@ def jax_wheel(name, wheel_binary, enable_cuda = False, platform_version = ""):
|
||||
_jax_wheel(
|
||||
name = name,
|
||||
wheel_binary = wheel_binary,
|
||||
wheel_name = wheel_name,
|
||||
no_abi = no_abi,
|
||||
enable_cuda = enable_cuda,
|
||||
platform_version = platform_version,
|
||||
# Empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=nightly` flag in bazel command to
|
||||
# pass the git hash for nightly or release builds. Note that the symlink git_hash_symlink to
|
||||
# the git hash file needs to be created first.
|
||||
git_hash = select({
|
||||
"//jaxlib/tools:jaxlib_git_hash_nightly_or_release": "git_hash_symlink",
|
||||
"//conditions:default": None,
|
||||
# git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)`
|
||||
# flag in bazel command to pass the git hash for nightly or release builds.
|
||||
platform_name = select({
|
||||
"@platforms//os:osx": "Darwin",
|
||||
"@platforms//os:macos": "Darwin",
|
||||
"@platforms//os:windows": "Windows",
|
||||
"@platforms//os:linux": "Linux",
|
||||
}),
|
||||
# Following the convention in jax/tools/build_utils.py.
|
||||
# TODO(kanglan) Add @platforms//cpu:ppc64le once JAX Bazel is upgraded > 6.5.0.
|
||||
platform_tag = select({
|
||||
cpu = select({
|
||||
"//jaxlib/tools:macos_arm64": "arm64",
|
||||
"//jaxlib/tools:win_amd64": "AMD64",
|
||||
"//jaxlib/tools:arm64": "aarch64",
|
||||
|
43
jaxlib/jax_python_wheel.bzl
Normal file
43
jaxlib/jax_python_wheel.bzl
Normal file
@ -0,0 +1,43 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
""" Repository rule to generate a file with JAX wheel version. """
|
||||
|
||||
def _jax_python_wheel_repository_impl(repository_ctx):
|
||||
version_source = repository_ctx.attr.version_source
|
||||
version_key = repository_ctx.attr.version_key
|
||||
|
||||
version_file_content = repository_ctx.read(
|
||||
repository_ctx.path(version_source),
|
||||
)
|
||||
version_start_index = version_file_content.find(version_key)
|
||||
version_end_index = version_start_index + version_file_content[version_start_index:].find("\n")
|
||||
|
||||
wheel_version = version_file_content[version_start_index:version_end_index].replace(
|
||||
version_key,
|
||||
"WHEEL_VERSION",
|
||||
)
|
||||
repository_ctx.file(
|
||||
"wheel.bzl",
|
||||
wheel_version,
|
||||
)
|
||||
repository_ctx.file("BUILD", "")
|
||||
|
||||
jax_python_wheel_repository = repository_rule(
|
||||
implementation = _jax_python_wheel_repository_impl,
|
||||
attrs = {
|
||||
"version_source": attr.label(mandatory = True, allow_single_file = True),
|
||||
"version_key": attr.string(mandatory = True),
|
||||
},
|
||||
)
|
@ -18,12 +18,38 @@ load("@bazel_skylib//lib:selects.bzl", "selects")
|
||||
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
|
||||
load("//jaxlib:jax.bzl", "if_windows", "jax_py_test", "jax_wheel")
|
||||
load(
|
||||
"@tsl//third_party/py:py_manylinux_compliance_test.bzl",
|
||||
"verify_manylinux_compliance_test",
|
||||
)
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"PLATFORM_TAGS_DICT",
|
||||
"if_windows",
|
||||
"jax_py_test",
|
||||
"jax_wheel",
|
||||
"pytype_strict_library",
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
genrule(
|
||||
name = "platform_tags_py",
|
||||
srcs = [],
|
||||
outs = ["platform_tags.py"],
|
||||
cmd = "echo 'PLATFORM_TAGS_DICT = %s' > $@;" % PLATFORM_TAGS_DICT,
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "build_utils",
|
||||
srcs = [
|
||||
"build_utils.py",
|
||||
":platform_tags_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "build_wheel",
|
||||
srcs = ["build_wheel.py"],
|
||||
@ -41,7 +67,7 @@ py_binary(
|
||||
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
|
||||
]),
|
||||
deps = [
|
||||
"//jax/tools:build_utils",
|
||||
":build_utils",
|
||||
"@bazel_tools//tools/python/runfiles",
|
||||
"@pypi_build//:pkg",
|
||||
"@pypi_setuptools//:pkg",
|
||||
@ -99,7 +125,7 @@ py_binary(
|
||||
"//jax_plugins/rocm:__init__.py",
|
||||
]),
|
||||
deps = [
|
||||
"//jax/tools:build_utils",
|
||||
":build_utils",
|
||||
"@bazel_tools//tools/python/runfiles",
|
||||
"@pypi_build//:pkg",
|
||||
"@pypi_setuptools//:pkg",
|
||||
@ -128,7 +154,7 @@ py_binary(
|
||||
"//jax_plugins/rocm:plugin_setup.py",
|
||||
]),
|
||||
deps = [
|
||||
"//jax/tools:build_utils",
|
||||
":build_utils",
|
||||
"@bazel_tools//tools/python/runfiles",
|
||||
"@pypi_build//:pkg",
|
||||
"@pypi_setuptools//:pkg",
|
||||
@ -173,30 +199,73 @@ string_flag(
|
||||
build_setting_default = "",
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "jaxlib_git_hash_nightly_or_release",
|
||||
flag_values = {
|
||||
":jaxlib_git_hash": "nightly",
|
||||
},
|
||||
string_flag(
|
||||
name = "output_path",
|
||||
build_setting_default = "dist",
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jaxlib_wheel",
|
||||
no_abi = False,
|
||||
wheel_binary = ":build_wheel",
|
||||
wheel_name = "jaxlib",
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jax_cuda_plugin_wheel",
|
||||
enable_cuda = True,
|
||||
no_abi = False,
|
||||
# TODO(b/371217563) May use hermetic cuda version here.
|
||||
platform_version = "12",
|
||||
wheel_binary = ":build_gpu_kernels_wheel",
|
||||
wheel_name = "jax_cuda12_plugin",
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jax_cuda_pjrt_wheel",
|
||||
enable_cuda = True,
|
||||
no_abi = True,
|
||||
# TODO(b/371217563) May use hermetic cuda version here.
|
||||
platform_version = "12",
|
||||
wheel_binary = ":build_gpu_plugin_wheel",
|
||||
wheel_name = "jax_cuda12_pjrt",
|
||||
)
|
||||
|
||||
AARCH64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "aarch64")])
|
||||
|
||||
PPC64LE_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "ppc64le")])
|
||||
|
||||
X86_64_MANYLINUX_TAG = "_".join(PLATFORM_TAGS_DICT[("Linux", "x86_64")])
|
||||
|
||||
verify_manylinux_compliance_test(
|
||||
name = "jaxlib_manylinux_compliance_test",
|
||||
aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
|
||||
ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
|
||||
test_tags = [
|
||||
"manual",
|
||||
],
|
||||
wheel = ":jaxlib_wheel",
|
||||
x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
|
||||
)
|
||||
|
||||
verify_manylinux_compliance_test(
|
||||
name = "jax_cuda_plugin_manylinux_compliance_test",
|
||||
aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
|
||||
ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
|
||||
test_tags = [
|
||||
"manual",
|
||||
],
|
||||
wheel = ":jax_cuda_plugin_wheel",
|
||||
x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
|
||||
)
|
||||
|
||||
verify_manylinux_compliance_test(
|
||||
name = "jax_cuda_pjrt_manylinux_compliance_test",
|
||||
aarch64_compliance_tag = AARCH64_MANYLINUX_TAG,
|
||||
ppc64le_compliance_tag = PPC64LE_MANYLINUX_TAG,
|
||||
test_tags = [
|
||||
"manual",
|
||||
],
|
||||
wheel = ":jax_cuda_pjrt_wheel",
|
||||
x86_64_compliance_tag = X86_64_MANYLINUX_TAG,
|
||||
)
|
||||
|
@ -24,7 +24,7 @@ import pathlib
|
||||
import tempfile
|
||||
|
||||
from bazel_tools.tools.python.runfiles import runfiles
|
||||
from jax.tools import build_utils
|
||||
from jaxlib.tools import build_utils
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@ -174,12 +174,11 @@ try:
|
||||
if args.editable:
|
||||
build_utils.build_editable(sources_path, args.output_path, package_name)
|
||||
else:
|
||||
git_hash = build_utils.get_githash(args.jaxlib_git_hash)
|
||||
build_utils.build_wheel(
|
||||
sources_path,
|
||||
args.output_path,
|
||||
package_name,
|
||||
git_hash=git_hash,
|
||||
git_hash=args.jaxlib_git_hash,
|
||||
)
|
||||
finally:
|
||||
tmpdir.cleanup()
|
||||
|
@ -24,7 +24,7 @@ import pathlib
|
||||
import tempfile
|
||||
|
||||
from bazel_tools.tools.python.runfiles import runfiles
|
||||
from jax.tools import build_utils
|
||||
from jaxlib.tools import build_utils
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@ -167,12 +167,11 @@ try:
|
||||
if args.editable:
|
||||
build_utils.build_editable(sources_path, args.output_path, package_name)
|
||||
else:
|
||||
git_hash = build_utils.get_githash(args.jaxlib_git_hash)
|
||||
build_utils.build_wheel(
|
||||
sources_path,
|
||||
args.output_path,
|
||||
package_name,
|
||||
git_hash=git_hash,
|
||||
git_hash=args.jaxlib_git_hash,
|
||||
)
|
||||
finally:
|
||||
if tmpdir:
|
||||
|
@ -24,6 +24,7 @@ import sys
|
||||
import subprocess
|
||||
import glob
|
||||
from collections.abc import Sequence
|
||||
from jaxlib.tools import platform_tags
|
||||
|
||||
|
||||
def is_windows() -> bool:
|
||||
@ -52,21 +53,11 @@ def copy_file(
|
||||
|
||||
|
||||
def platform_tag(cpu: str) -> str:
|
||||
platform_name, cpu_name = {
|
||||
("Linux", "x86_64"): ("manylinux2014", "x86_64"),
|
||||
("Linux", "aarch64"): ("manylinux2014", "aarch64"),
|
||||
("Linux", "ppc64le"): ("manylinux2014", "ppc64le"),
|
||||
("Darwin", "x86_64"): ("macosx_10_14", "x86_64"),
|
||||
("Darwin", "arm64"): ("macosx_11_0", "arm64"),
|
||||
("Windows", "AMD64"): ("win", "amd64"),
|
||||
}[(platform.system(), cpu)]
|
||||
platform_name, cpu_name = platform_tags.PLATFORM_TAGS_DICT[
|
||||
(platform.system(), cpu)
|
||||
]
|
||||
return f"{platform_name}_{cpu_name}"
|
||||
|
||||
def get_githash(jaxlib_git_hash):
|
||||
if jaxlib_git_hash != "" and os.path.isfile(jaxlib_git_hash):
|
||||
with open(jaxlib_git_hash, "r") as f:
|
||||
return f.readline().strip()
|
||||
return jaxlib_git_hash
|
||||
|
||||
def build_wheel(
|
||||
sources_path: str, output_path: str, package_name: str, git_hash: str = ""
|
@ -27,7 +27,7 @@ import subprocess
|
||||
import tempfile
|
||||
|
||||
from bazel_tools.tools.python.runfiles import runfiles
|
||||
from jax.tools import build_utils
|
||||
from jaxlib.tools import build_utils
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@ -387,8 +387,12 @@ try:
|
||||
if args.editable:
|
||||
build_utils.build_editable(sources_path, args.output_path, package_name)
|
||||
else:
|
||||
git_hash = build_utils.get_githash(args.jaxlib_git_hash)
|
||||
build_utils.build_wheel(sources_path, args.output_path, package_name, git_hash=git_hash)
|
||||
build_utils.build_wheel(
|
||||
sources_path,
|
||||
args.output_path,
|
||||
package_name,
|
||||
git_hash=args.jaxlib_git_hash,
|
||||
)
|
||||
finally:
|
||||
if tmpdir:
|
||||
tmpdir.cleanup()
|
||||
|
@ -104,6 +104,7 @@ class JaxVersionTest(unittest.TestCase):
|
||||
self.assertEqual(version, "1.2.3.dev4567")
|
||||
self.assertValidVersion(version)
|
||||
|
||||
@jtu.thread_unsafe_test() # Setting environment variables is not thread-safe.
|
||||
@patch_jax_version("1.2.3", None)
|
||||
def testBuildVersionFromEnvironment(self):
|
||||
# This test covers build-time construction of version strings in the
|
||||
@ -157,6 +158,18 @@ class JaxVersionTest(unittest.TestCase):
|
||||
self.assertTrue(version.endswith("test"))
|
||||
self.assertValidVersion(version)
|
||||
|
||||
with jtu.set_env(
|
||||
JAX_RELEASE=None,
|
||||
JAXLIB_RELEASE=None,
|
||||
JAX_NIGHTLY=None,
|
||||
JAXLIB_NIGHTLY="1",
|
||||
WHEEL_VERSION_SUFFIX=".dev20250101+1c0f1076erc1",
|
||||
):
|
||||
with assert_no_subprocess_call():
|
||||
version = jax.version._get_version_for_build()
|
||||
self.assertEqual(version, f"{base_version}.dev20250101+1c0f1076erc1")
|
||||
self.assertValidVersion(version)
|
||||
|
||||
def testVersions(self):
|
||||
check_jaxlib_version(jax_version="1.2.3", jaxlib_version="1.2.3",
|
||||
minimum_jaxlib_version="1.2.3")
|
||||
|
Loading…
x
Reference in New Issue
Block a user