mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00

List of changes: 1. Allow us to build an RC wheel when building release artifacts. This is done by modifying the build CLI to use the new JAX build rule and passing in the build options that control the wheel tag. A new build argument `use_new_wheel_build_rule` is introduced to the build CLI to avoid breaking anyone that uses the CLI and the old build rule. Note that this option will go away in the future when the build CLI migrates fully to the new build rule. 2. Change the upload script to upload both rc and release tagged wheels (changes internal) PiperOrigin-RevId: 733464219
701 lines
22 KiB
Python
Executable File
701 lines
22 KiB
Python
Executable File
#!/usr/bin/python
|
|
#
|
|
# Copyright 2018 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
# CLI for building JAX wheel packages from source and for updating the
|
|
# requirements_lock.txt files
|
|
|
|
import argparse
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import platform
|
|
import sys
|
|
import copy
|
|
|
|
from tools import command, utils
|
|
|
|
|
|
# Configure the root logger once at import time. `main` lowers the root level
# to DEBUG when --verbose is passed.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
# ASCII-art logo that `main` logs at startup. It is a raw string, so the
# backslashes are literal characters rather than escapes.
BANNER = r"""
     _   _  __  __
    | | / \ \ \/ /
 _  | |/ _ \ \  /
| |_| / ___ \/  \
 \___/_/   \/_/\_\

"""
|
|
|
|
# Text printed after the argument help of the top-level parser. That parser uses
# RawDescriptionHelpFormatter, so this layout is shown verbatim.
EPILOG = """
From the root directory of the JAX repository, run
  `python build/build.py build --wheels=<list of JAX wheels>`
to build JAX artifacts.

Multiple wheels can be built with a single invocation of the CLI, e.g.
  python build/build.py build --wheels=jaxlib,jax-cuda-plugin

To update the requirements_lock.txt files, run
  `python build/build.py requirements_update`
"""
|
|
|
|
# Legacy build target for each wheel name. Without --use_new_wheel_build_rule
# these targets are invoked with `bazel run`. The CUDA and ROCm variants share
# the same GPU kernel / GPU plugin targets.
WHEEL_BUILD_TARGET_DICT = dict(
    [
        ("jaxlib", "//jaxlib/tools:build_wheel"),
        ("jax-cuda-plugin", "//jaxlib/tools:build_gpu_kernels_wheel"),
        ("jax-cuda-pjrt", "//jaxlib/tools:build_gpu_plugin_wheel"),
        ("jax-rocm-plugin", "//jaxlib/tools:build_gpu_kernels_wheel"),
        ("jax-rocm-pjrt", "//jaxlib/tools:build_gpu_plugin_wheel"),
    ]
)
|
|
|
|
# Build targets of the new wheel build rule, selected with
# --use_new_wheel_build_rule (they are invoked with `bazel build`). Unlike the
# legacy table, this one can also build the pure-Python `jax` wheel. Once JAX
# migrates fully to the new rule, the build CLI will use it by default.
WHEEL_BUILD_TARGET_DICT_NEW = dict(
    [
        ("jax", "//:jax_wheel"),
        ("jaxlib", "//jaxlib/tools:jaxlib_wheel"),
        ("jax-cuda-plugin", "//jaxlib/tools:jax_cuda_plugin_wheel"),
        ("jax-cuda-pjrt", "//jaxlib/tools:jax_cuda_pjrt_wheel"),
        ("jax-rocm-plugin", "//jaxlib/tools:jax_rocm_plugin_wheel"),
        ("jax-rocm-pjrt", "//jaxlib/tools:jax_rocm_pjrt_wheel"),
    ]
)
|
|
|
|
def add_global_arguments(parser: argparse.ArgumentParser):
  """Registers the options shared by every CLI subcommand on `parser`.

  These cover the hermetic Python version, the Bazel binary and extra Bazel
  startup/build options, plus the dry-run, verbosity and timestamped-logging
  switches. `main` applies this to both the `build` and the
  `requirements_update` subparsers.

  Args:
    parser: The (sub)parser that receives the shared options.
  """
  # "<major>.<minor>" of the interpreter running this CLI.
  running_python = "{}.{}".format(*sys.version_info[:2])
  parser.add_argument(
      "--python_version",
      type=str,
      default=running_python,
      help=
      """
      Hermetic Python version to use. Default is to use the version of the
      Python binary that executed the CLI.
      """,
  )

  bazel_opts = parser.add_argument_group('Bazel Options')
  bazel_opts.add_argument(
      "--bazel_path",
      type=str,
      default="",
      help="""
      Path to the Bazel binary to use. The default is to find bazel via the
      PATH; if none is found, downloads a fresh copy of Bazel from GitHub.
      """,
  )

  # Both option lists are repeatable: each occurrence appends one entry.
  repeatable_bazel_flags = (
      (
          "--bazel_startup_options",
          """
      Additional startup options to pass to Bazel, can be specified multiple
      times to pass multiple options.
      E.g. --bazel_startup_options='--nobatch'
      """,
      ),
      (
          "--bazel_options",
          """
      Additional build options to pass to Bazel, can be specified multiple
      times to pass multiple options.
      E.g. --bazel_options='--local_resources=HOST_CPUS'
      """,
      ),
  )
  for flag_name, flag_help in repeatable_bazel_flags:
    bazel_opts.add_argument(
        flag_name, action="append", default=[], help=flag_help
    )

  # Boolean switches that control the CLI itself rather than Bazel.
  cli_switches = (
      ("--dry_run", "Prints the Bazel command that is going to be executed."),
      ("--verbose", "Produce verbose output for debugging."),
      (
          "--detailed_timestamped_log",
          """
      Enable detailed logging of the Bazel command with timestamps. The logs
      will be stored and can be accessed as artifacts.
      """,
      ),
  )
  for switch_name, switch_help in cli_switches:
    parser.add_argument(switch_name, action="store_true", help=switch_help)
|
|
|
|
|
|
def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
  """Registers the options of the artifact `build` subcommand on `parser`.

  Adds the wheel selection and output options, followed by the "CUDA Options",
  "ROCm Options" and "Compile Options" argument groups.

  Args:
    parser: The `build` subparser.
  """
  parser.add_argument(
      "--wheels",
      type=str,
      default="jaxlib",
      help=
      """
      A comma separated list of JAX wheels to build. E.g: --wheels="jaxlib",
      --wheels="jaxlib,jax-cuda-plugin", etc.
      Valid options are: jaxlib, jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt,
      jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt. The jax wheel
      can additionally be built when --use_new_wheel_build_rule is set.
      """,
  )

  parser.add_argument(
      "--use_new_wheel_build_rule",
      action="store_true",
      help=
      """
      Whether to use the new wheel build rule. Temporary flag and will be
      removed once JAX migrates to the new wheel build rule fully.
      """,
  )

  parser.add_argument(
      "--editable",
      action="store_true",
      help="Create an 'editable' build instead of a wheel.",
  )

  parser.add_argument(
      "--output_path",
      type=str,
      default=os.path.join(os.getcwd(), "dist"),
      help="Directory to which the JAX wheel packages should be written.",
  )

  parser.add_argument(
      "--configure_only",
      action="store_true",
      help="""
      If true, writes the Bazel options to the .jax_configure.bazelrc file but
      does not build the artifacts.
      """,
  )

  # CUDA Options
  cuda_group = parser.add_argument_group('CUDA Options')
  cuda_group.add_argument(
      "--cuda_version",
      type=str,
      help=
      """
      Hermetic CUDA version to use. Default is to use the version specified
      in the .bazelrc.
      """,
  )

  cuda_group.add_argument(
      "--cuda_major_version",
      type=str,
      default="12",
      help=
      """
      Which CUDA major version should the wheel be tagged as? Auto-detected if
      --cuda_version is set. When --cuda_version is not set, the default is to
      set the major version to 12 to match the default in .bazelrc.
      """,
  )

  cuda_group.add_argument(
      "--cudnn_version",
      type=str,
      help=
      """
      Hermetic cuDNN version to use. Default is to use the version specified
      in the .bazelrc.
      """,
  )

  cuda_group.add_argument(
      "--disable_nccl",
      action="store_true",
      help="Should NCCL be disabled?",
  )

  cuda_group.add_argument(
      "--cuda_compute_capabilities",
      type=str,
      default=None,
      help=
      """
      A comma-separated list of CUDA compute capabilities to support. Default
      is to use the values specified in the .bazelrc.
      """,
  )

  cuda_group.add_argument(
      "--build_cuda_with_clang",
      action="store_true",
      help="""
      Should CUDA code be compiled using Clang? The default behavior is to
      compile CUDA with NVCC.
      """,
  )

  # ROCm Options
  rocm_group = parser.add_argument_group('ROCm Options')
  rocm_group.add_argument(
      "--rocm_version",
      type=str,
      default="60",
      help="""
      Which ROCm version should the ROCm plugin and PJRT wheels be tagged as?
      Forwarded to the wheel build as --platform_version.
      """,
  )

  rocm_group.add_argument(
      "--rocm_amdgpu_targets",
      type=str,
      default="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100,gfx1200,gfx1201",
      help="A comma-separated list of ROCm amdgpu targets to support.",
  )

  rocm_group.add_argument(
      "--rocm_path",
      type=str,
      default="",
      help="Path to the ROCm toolkit.",
  )

  # Compile Options
  compile_group = parser.add_argument_group('Compile Options')

  compile_group.add_argument(
      "--use_clang",
      # Accepts "--use_clang" (-> True via const) or "--use_clang=<bool str>".
      type=utils._parse_string_as_bool,
      default="true",
      const=True,
      nargs="?",
      help="""
      Whether to use Clang as the compiler. Not recommended to set this to
      False as JAX uses Clang as the default compiler.
      """,
  )

  compile_group.add_argument(
      "--clang_path",
      type=str,
      default="",
      help="""
      Path to the Clang binary to use.
      """,
  )

  compile_group.add_argument(
      "--gcc_path",
      type=str,
      default="",
      help="""
      Path to the GCC binary to use.
      """,
  )

  compile_group.add_argument(
      "--disable_mkl_dnn",
      action="store_true",
      help="""
      Disables MKL-DNN.
      """,
  )

  compile_group.add_argument(
      "--target_cpu_features",
      choices=["release", "native", "default"],
      default="release",
      help="""
      What CPU features should we target? Release enables CPU features that
      should be enabled for a release build, which on x86-64 architectures
      enables AVX. Native enables -march=native, which generates code targeted
      to use all features of the current machine. Default means don't opt-in
      to any architectural features and use whatever the C compiler generates
      by default.
      """,
  )

  compile_group.add_argument(
      "--target_cpu",
      default=None,
      help="CPU platform to target. Default is the same as the host machine.",
  )

  compile_group.add_argument(
      "--local_xla_path",
      type=str,
      # Read once, when the parser is built.
      default=os.environ.get("JAXCI_XLA_GIT_DIR", ""),
      help="""
      Path to local XLA repository to use. Defaults to the value of the
      JAXCI_XLA_GIT_DIR environment variable, if set. If neither is set, Bazel
      uses the XLA at the pinned version in workspace.bzl.
      """,
  )
|
|
|
|
def _append_compiler_options(wheel_build_command_base, args):
  """Appends the C/C++ compiler flags: Clang by default, otherwise GCC.

  Args:
    wheel_build_command_base: The `command.CommandBuilder` being assembled.
    args: Parsed arguments of the `build` subcommand.

  Returns:
    The Clang binary path that was configured, or "" when GCC is used. The
    CUDA/ROCm options reuse this path.
  """
  if not args.use_clang:
    gcc_path = args.gcc_path or utils.get_gcc_path_or_exit()
    logging.debug(
        "Using GCC as the compiler, gcc path: %s",
        gcc_path,
    )
    wheel_build_command_base.append(f"--repo_env=CC=\"{gcc_path}\"")
    wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{gcc_path}\"")

    gcc_major_version = utils.get_gcc_major_version(gcc_path)
    # NOTE(review): presumably older compilers cannot build XNNPACK's
    # AVX-VNNI-INT8 kernels; same treatment as Clang < 19 below.
    if gcc_major_version < 13:
      wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")
    return ""

  clang_path = args.clang_path or utils.get_clang_path_or_exit()
  clang_major_version = utils.get_clang_major_version(clang_path)
  logging.debug(
      "Using Clang as the compiler, clang path: %s, clang version: %s",
      clang_path,
      clang_major_version,
  )

  # Use double quotes around clang path to avoid path issues on Windows.
  wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"")
  wheel_build_command_base.append(f"--repo_env=CC=\"{clang_path}\"")
  wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"")

  if clang_major_version >= 16:
    # Enable clang settings that are needed for the build to work with newer
    # versions of Clang.
    wheel_build_command_base.append("--config=clang")
  if clang_major_version < 19:
    wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")
  return clang_path


def _append_cpu_feature_options(
    wheel_build_command_base, target_cpu_features, arch, os_name
):
  """Appends the --config selected by --target_cpu_features."""
  if target_cpu_features == "release":
    # Release CPU features only add flags on x86-64 hosts.
    if arch in ["x86_64", "AMD64"]:
      logging.debug(
          "Using release cpu features: --config=avx_%s",
          "windows" if os_name == "windows" else "posix",
      )
      wheel_build_command_base.append(
          "--config=avx_windows"
          if os_name == "windows"
          else "--config=avx_posix"
      )
  elif target_cpu_features == "native":
    if os_name == "windows":
      logger.warning(
          "--target_cpu_features=native is not supported on Windows;"
          " ignoring."
      )
    else:
      logging.debug("Using native cpu features: --config=native_arch_posix")
      wheel_build_command_base.append("--config=native_arch_posix")
  else:
    logging.debug("Using default cpu features")


def _append_gpu_options(wheel_build_command_base, args, clang_path):
  """Appends the CUDA or ROCm build options requested via --wheels.

  `clang_path` is the value returned by `_append_compiler_options` ("" when
  GCC is the host compiler).
  """
  if "cuda" in args.wheels:
    wheel_build_command_base.append("--config=cuda")
    wheel_build_command_base.append("--config=cuda_libraries_from_stubs")
    if args.use_clang:
      wheel_build_command_base.append(
          f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\""
      )
    # CUDA is compiled with Clang only when Clang is also the host compiler
    # and --build_cuda_with_clang is set; NVCC is used otherwise.
    if args.use_clang and args.build_cuda_with_clang:
      logging.debug("Building CUDA with Clang")
      wheel_build_command_base.append("--config=build_cuda_with_clang")
    else:
      logging.debug("Building CUDA with NVCC")
      wheel_build_command_base.append("--config=build_cuda_with_nvcc")

    if args.cuda_version:
      logging.debug("Hermetic CUDA version: %s", args.cuda_version)
      wheel_build_command_base.append(
          f"--repo_env=HERMETIC_CUDA_VERSION={args.cuda_version}"
      )
    if args.cudnn_version:
      logging.debug("Hermetic cuDNN version: %s", args.cudnn_version)
      wheel_build_command_base.append(
          f"--repo_env=HERMETIC_CUDNN_VERSION={args.cudnn_version}"
      )
    if args.cuda_compute_capabilities:
      logging.debug(
          "Hermetic CUDA compute capabilities: %s",
          args.cuda_compute_capabilities,
      )
      wheel_build_command_base.append(
          f"--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES={args.cuda_compute_capabilities}"
      )

  if "rocm" in args.wheels:
    wheel_build_command_base.append("--config=rocm_base")
    if args.use_clang:
      wheel_build_command_base.append("--config=rocm")
      wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"")
    if args.rocm_path:
      logging.debug("ROCm tookit path: %s", args.rocm_path)
      wheel_build_command_base.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"")
    if args.rocm_amdgpu_targets:
      logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets)
      wheel_build_command_base.append(
          f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}"
      )


def _resolve_wheel_names(wheels, wheel_build_targets):
  """Normalizes the comma-separated --wheels value into target-table keys.

  Plugin and PJRT wheels may be given without the "jax-" prefix (e.g.
  "cuda-plugin"). All names are validated before anything is built; the
  process exits with status 1 on the first name that has no entry in
  `wheel_build_targets`.

  Returns:
    The list of normalized wheel names, in the order given.
  """
  resolved = []
  for wheel in wheels.split(","):
    wheel = wheel.strip()
    # Allow CUDA/ROCm wheels without the "jax-" prefix.
    if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel:
      wheel = "jax-" + wheel
    if wheel not in wheel_build_targets:
      logging.error(
          "Incorrect wheel name provided: %r. Valid choices are %s; plugin and"
          " pjrt wheels may omit the 'jax-' prefix (e.g. cuda-plugin).",
          wheel,
          ", ".join(wheel_build_targets),
      )
      sys.exit(1)
    resolved.append(wheel)
  return resolved


def _append_legacy_wheel_script_args(
    wheel_build_command, args, wheel, target_cpu, git_hash
):
  """Appends the arguments that `bazel run` forwards to the legacy wheel script.

  Everything after the "--" separator goes to the executed wheel build script
  rather than to Bazel. The new wheel build rule is invoked with `bazel build`,
  which would reject these flags, so this is only used for the legacy rule.
  """
  wheel_build_command.append("--")

  output_path = args.output_path
  if args.editable:
    logger.info("Building an editable build")
    output_path = os.path.join(output_path, wheel)
    wheel_build_command.append("--editable")

  wheel_build_command.append(f'--output_path="{output_path}"')
  wheel_build_command.append(f"--cpu={target_cpu}")

  if "cuda" in wheel:
    wheel_build_command.append("--enable-cuda=True")
    if args.cuda_version:
      cuda_major_version = args.cuda_version.split(".")[0]
    else:
      cuda_major_version = args.cuda_major_version
    wheel_build_command.append(f"--platform_version={cuda_major_version}")

  if "rocm" in wheel:
    wheel_build_command.append("--enable-rocm=True")
    wheel_build_command.append(f"--platform_version={args.rocm_version}")

  wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")


async def main():
  """Entry point of the JAX build CLI.

  Parses the command line and assembles a Bazel command. It then either
  updates the requirements_lock.txt files (`requirements_update`) or writes
  .jax_configure.bazelrc and builds the requested wheels (`build`). Success
  always ends with `sys.exit(0)`; invalid input ends with `sys.exit(1)`.

  Raises:
    RuntimeError: If --bazel_options sets HERMETIC_PYTHON_VERSION (use
      --python_version instead), or if a Bazel invocation fails.
  """
  parser = argparse.ArgumentParser(
      description=r"""
      CLI for building JAX wheel packages from source and for updating the
      requirements_lock.txt files
      """,
      epilog=EPILOG,
      formatter_class=argparse.RawDescriptionHelpFormatter
  )

  # Create subparsers for build and requirements_update
  subparsers = parser.add_subparsers(dest="command", required=True)

  # requirements_update subcommand
  requirements_update_parser = subparsers.add_parser(
      "requirements_update", help="Updates the requirements_lock.txt files"
  )
  requirements_update_parser.add_argument(
      "--nightly_update",
      action="store_true",
      help="""
      If true, updates requirements_lock.txt for a corresponding version of
      Python and will consider dev, nightly and pre-release versions of
      packages.
      """,
  )
  add_global_arguments(requirements_update_parser)

  # Artifact build subcommand
  build_artifact_parser = subparsers.add_parser(
      "build", help="Builds the jaxlib, plugin, and pjrt artifact"
  )
  add_artifact_subcommand_arguments(build_artifact_parser)
  add_global_arguments(build_artifact_parser)

  arch = platform.machine()
  os_name = platform.system().lower()

  args = parser.parse_args()

  logger.info("%s", BANNER)

  if args.verbose:
    logging.getLogger().setLevel(logging.DEBUG)
    logger.info("Verbose logging enabled")

  # --use_new_wheel_build_rule is only registered on the `build` subparser, so
  # the attribute does not exist for `requirements_update`. Read it
  # defensively instead of raising AttributeError.
  use_new_wheel_build_rule = getattr(args, "use_new_wheel_build_rule", False)

  bazel_path, bazel_version = utils.get_bazel_path(args.bazel_path)

  logging.debug("Bazel path: %s", bazel_path)
  logging.debug("Bazel version: %s", bazel_version)

  executor = command.SubprocessExecutor()

  # Start constructing the Bazel command
  bazel_command_base = command.CommandBuilder(bazel_path)

  if args.bazel_startup_options:
    logging.debug(
        "Additional Bazel startup options: %s", args.bazel_startup_options
    )
    for option in args.bazel_startup_options:
      bazel_command_base.append(option)

  # The requirements-update targets and the legacy wheel targets are executed
  # with `bazel run`; the new wheel build rule is invoked with `bazel build`.
  if args.command == "requirements_update" or not use_new_wheel_build_rule:
    bazel_command_base.append("run")
  else:
    bazel_command_base.append("build")

  if args.python_version:
    # The hermetic Python version must be set through --python_version only;
    # reject a conflicting --repo_env=HERMETIC_PYTHON_VERSION=... passed via
    # --bazel_options instead of silently emitting two values.
    python_version_opt = "--repo_env=HERMETIC_PYTHON_VERSION="
    if any(python_version_opt in opt for opt in args.bazel_options):
      raise RuntimeError(
          "Please use python_version to set hermetic python version instead of "
          "setting --repo_env=HERMETIC_PYTHON_VERSION=<python version> bazel option"
      )
    logging.debug("Hermetic Python version: %s", args.python_version)
    bazel_command_base.append(
        f"--repo_env=HERMETIC_PYTHON_VERSION={args.python_version}"
    )
    # Let's interpret X.YY-ft version as free-threading python and set rules_python config flag:
    if args.python_version.endswith("-ft"):
      bazel_command_base.append(
          "--@rules_python//python/config_settings:py_freethreaded='yes'"
      )

  # Enable verbose failures.
  bazel_command_base.append("--verbose_failures=true")

  # Requirements update subcommand execution
  if args.command == "requirements_update":
    requirements_command = copy.deepcopy(bazel_command_base)
    if args.bazel_options:
      logging.debug(
          "Using additional build options: %s", args.bazel_options
      )
      for option in args.bazel_options:
        requirements_command.append(option)

    if args.nightly_update:
      logging.info(
          "--nightly_update is set. Bazel will run"
          " //build:requirements_nightly.update"
      )
      requirements_command.append("//build:requirements_nightly.update")
    else:
      requirements_command.append("//build:requirements.update")

    result = await executor.run(requirements_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log)
    if result.return_code != 0:
      raise RuntimeError(f"Command failed with return code {result.return_code}")
    sys.exit(0)

  # Map --target_cpu values to the CPU name passed to the legacy wheel script.
  wheel_cpus = {
      "darwin_arm64": "arm64",
      "darwin_x86_64": "x86_64",
      "ppc": "ppc64le",
      "aarch64": "aarch64",
  }
  if args.target_cpu is None:
    target_cpu = arch
  elif args.target_cpu in wheel_cpus:
    target_cpu = wheel_cpus[args.target_cpu]
  else:
    logging.error(
        "Unsupported --target_cpu=%s, valid choices are: %s",
        args.target_cpu,
        ", ".join(wheel_cpus),
    )
    sys.exit(1)

  # Fail fast, before probing compilers or git.
  if "cuda" in args.wheels and "rocm" in args.wheels:
    logging.error("CUDA and ROCm cannot be enabled at the same time.")
    sys.exit(1)

  wheel_build_command_base = copy.deepcopy(bazel_command_base)

  if args.local_xla_path:
    logging.debug("Local XLA path: %s", args.local_xla_path)
    wheel_build_command_base.append(f"--override_repository=xla=\"{args.local_xla_path}\"")

  if args.target_cpu:
    logging.debug("Target CPU: %s", args.target_cpu)
    wheel_build_command_base.append(f"--cpu={args.target_cpu}")

  if args.disable_nccl:
    logging.debug("Disabling NCCL")
    wheel_build_command_base.append("--config=nonccl")

  git_hash = utils.get_githash()

  clang_path = _append_compiler_options(wheel_build_command_base, args)

  if not args.disable_mkl_dnn:
    logging.debug("Enabling MKL DNN")
    if target_cpu == "aarch64":
      wheel_build_command_base.append("--config=mkl_aarch64_threadpool")
    else:
      wheel_build_command_base.append("--config=mkl_open_source_only")

  _append_cpu_feature_options(
      wheel_build_command_base, args.target_cpu_features, arch, os_name
  )
  _append_gpu_options(wheel_build_command_base, args, clang_path)

  # Append additional build options at the end to override any options set in
  # .bazelrc or above.
  if args.bazel_options:
    logging.debug(
        "Additional Bazel build options: %s", args.bazel_options
    )
    for option in args.bazel_options:
      wheel_build_command_base.append(option)
    # NOTE(review): presumably re-applied after the user options so the CUDA
    # stub-library config still wins over anything they override.
    if "cuda" in args.wheels:
      wheel_build_command_base.append("--config=cuda_libraries_from_stubs")

  # Compute the options before opening the file: opening with "w" truncates
  # it, so a failure here must not wipe an existing .jax_configure.bazelrc.
  jax_configure_options = utils.get_jax_configure_bazel_options(
      wheel_build_command_base.get_command_as_list(), use_new_wheel_build_rule
  )
  if not jax_configure_options:
    logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.")
    sys.exit(1)
  with open(".jax_configure.bazelrc", "w") as f:
    f.write(jax_configure_options)
  logging.info("Bazel options written to .jax_configure.bazelrc")

  if use_new_wheel_build_rule:
    logging.info("Using new wheel build rule")
    wheel_build_targets = WHEEL_BUILD_TARGET_DICT_NEW
  else:
    wheel_build_targets = WHEEL_BUILD_TARGET_DICT

  if args.configure_only:
    logging.info("--configure_only is set so not running any Bazel commands.")
  else:
    # Validate every requested wheel before starting the first build.
    wheels = _resolve_wheel_names(args.wheels, wheel_build_targets)

    # Wheel build command execution
    for wheel in wheels:
      logger.debug("Artifacts output directory: %s", args.output_path)

      wheel_build_command = copy.deepcopy(wheel_build_command_base)
      print("\n")
      logger.info(
          "Building %s for %s %s...",
          wheel,
          os_name,
          arch,
      )

      # Append the build target to the Bazel command.
      wheel_build_command.append(wheel_build_targets[wheel])

      if not use_new_wheel_build_rule:
        _append_legacy_wheel_script_args(
            wheel_build_command, args, wheel, target_cpu, git_hash
        )

      result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log)
      # Exit with error if any wheel build fails.
      if result.return_code != 0:
        raise RuntimeError(f"Command failed with return code {result.return_code}")

  # Exit with success if all wheels in the list were built successfully.
  sys.exit(0)
|
|
|
|
|
|
if __name__ == "__main__":
  # `main` is a coroutine function; asyncio.run drives it on a new event loop.
  cli_coroutine = main()
  asyncio.run(cli_coroutine)
|