#!/usr/bin/python
#
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# CLI for building JAX wheel packages from source and for updating the
# requirements_lock.txt files

import argparse
import asyncio
import copy
import logging
import os
import platform
import sys

from tools import command, utils


logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

BANNER = r"""
     _   _  __  __
    | | / \ \ \/ /
 _  | |/ _ \ \  /
| |_| / ___ \/  \
 \___/_/   \/_/\_\

"""

EPILOG = """
From the root directory of the JAX repository, run
`python build/build.py build --wheels=<list of JAX wheels>` to build JAX
artifacts.

Multiple wheels can be built with a single invocation of the CLI.
E.g. python build/build.py build --wheels=jaxlib,jax-cuda-plugin

To update the requirements_lock.txt files, run
`python build/build.py requirements_update`
"""

# Define the build target for each wheel.
WHEEL_BUILD_TARGET_DICT = {
    "jaxlib": "//jaxlib/tools:build_wheel",
    "jax-cuda-plugin": "//jaxlib/tools:build_gpu_kernels_wheel",
    "jax-cuda-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel",
    "jax-rocm-plugin": "//jaxlib/tools:build_gpu_kernels_wheel",
    "jax-rocm-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel",
}

# Build targets for the new wheel build rule. Once JAX fully migrates to the
# new wheel build rule, the build CLI will switch to it as the default.
WHEEL_BUILD_TARGET_DICT_NEW = {
    "jax": "//:jax_wheel",
    "jaxlib": "//jaxlib/tools:jaxlib_wheel",
    "jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel",
    "jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel",
    "jax-rocm-plugin": "//jaxlib/tools:jax_rocm_plugin_wheel",
    "jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel",
}


def add_global_arguments(parser: argparse.ArgumentParser):
  """Adds all the global arguments that apply to all the CLI subcommands."""
  parser.add_argument(
      "--python_version",
      type=str,
      default=f"{sys.version_info.major}.{sys.version_info.minor}",
      help="""
        Hermetic Python version to use. Default is to use the version of the
        Python binary that executed the CLI.
        """,
  )

  bazel_group = parser.add_argument_group('Bazel Options')
  bazel_group.add_argument(
      "--bazel_path",
      type=str,
      default="",
      help="""
        Path to the Bazel binary to use. The default is to find bazel via the
        PATH; if none is found, downloads a fresh copy of Bazel from GitHub.
        """,
  )

  bazel_group.add_argument(
      "--bazel_startup_options",
      action="append",
      default=[],
      help="""
        Additional startup options to pass to Bazel; can be specified multiple
        times to pass multiple options.
        E.g. --bazel_startup_options='--nobatch'
        """,
  )

  bazel_group.add_argument(
      "--bazel_options",
      action="append",
      default=[],
      help="""
        Additional build options to pass to Bazel; can be specified multiple
        times to pass multiple options.
        E.g. --bazel_options='--local_resources=HOST_CPUS'
        """,
  )

  parser.add_argument(
      "--dry_run",
      action="store_true",
      help="Prints the Bazel command that is going to be executed.",
  )

  parser.add_argument(
      "--verbose",
      action="store_true",
      help="Produce verbose output for debugging.",
  )

  parser.add_argument(
      "--detailed_timestamped_log",
      action="store_true",
      help="""
        Enable detailed logging of the Bazel command with timestamps. The logs
        will be stored and can be accessed as artifacts.
        """,
  )


def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
  """Adds all the arguments that apply to the artifact subcommands."""
  parser.add_argument(
      "--wheels",
      type=str,
      default="jaxlib",
      help="""
        A comma-separated list of JAX wheels to build. E.g. --wheels="jaxlib",
        --wheels="jaxlib,jax-cuda-plugin", etc.
        Valid options are: jaxlib, jax-cuda-plugin or cuda-plugin,
        jax-cuda-pjrt or cuda-pjrt, jax-rocm-plugin or rocm-plugin,
        jax-rocm-pjrt or rocm-pjrt.
        """,
  )

  parser.add_argument(
      "--use_new_wheel_build_rule",
      action="store_true",
      help="""
        Whether to use the new wheel build rule. This is a temporary flag that
        will be removed once JAX fully migrates to the new wheel build rule.
        """,
  )

  parser.add_argument(
      "--editable",
      action="store_true",
      help="Create an 'editable' build instead of a wheel.",
  )

  parser.add_argument(
      "--output_path",
      type=str,
      default=os.path.join(os.getcwd(), "dist"),
      help="Directory to which the JAX wheel packages should be written.",
  )

  parser.add_argument(
      "--configure_only",
      action="store_true",
      help="""
        If true, writes the Bazel options to the .jax_configure.bazelrc file
        but does not build the artifacts.
        """,
  )

  # CUDA Options
  cuda_group = parser.add_argument_group('CUDA Options')
  cuda_group.add_argument(
      "--cuda_version",
      type=str,
      help="""
        Hermetic CUDA version to use. Default is to use the version specified
        in the .bazelrc.
        """,
  )

  cuda_group.add_argument(
      "--cuda_major_version",
      type=str,
      default="12",
      help="""
        Which CUDA major version should the wheel be tagged as? Auto-detected
        if --cuda_version is set. When --cuda_version is not set, the default
        is to set the major version to 12 to match the default in .bazelrc.
        """,
  )

  cuda_group.add_argument(
      "--cudnn_version",
      type=str,
      help="""
        Hermetic cuDNN version to use. Default is to use the version specified
        in the .bazelrc.
        """,
  )

  cuda_group.add_argument(
      "--disable_nccl",
      action="store_true",
      help="Should NCCL be disabled?",
  )

  cuda_group.add_argument(
      "--cuda_compute_capabilities",
      type=str,
      default=None,
      help="""
        A comma-separated list of CUDA compute capabilities to support.
        Default is to use the values specified in the .bazelrc.
        """,
  )

  cuda_group.add_argument(
      "--build_cuda_with_clang",
      action="store_true",
      help="""
        Should CUDA code be compiled using Clang? The default behavior is to
        compile CUDA with NVCC.
        """,
  )

  # ROCm Options
  rocm_group = parser.add_argument_group('ROCm Options')
  rocm_group.add_argument(
      "--rocm_version",
      type=str,
      default="60",
      help="ROCm version to use.",
  )

  rocm_group.add_argument(
      "--rocm_amdgpu_targets",
      type=str,
      default="gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201",
      help="A comma-separated list of ROCm amdgpu targets to support.",
  )

  rocm_group.add_argument(
      "--rocm_path",
      type=str,
      default="",
      help="Path to the ROCm toolkit.",
  )

  # Compile Options
  compile_group = parser.add_argument_group('Compile Options')

  compile_group.add_argument(
      "--use_clang",
      type=utils._parse_string_as_bool,
      default="true",
      const=True,
      nargs="?",
      help="""
        Whether to use Clang as the compiler. Setting this to False is not
        recommended, as JAX uses Clang as its default compiler.
        """,
  )

  compile_group.add_argument(
      "--clang_path",
      type=str,
      default="",
      help="Path to the Clang binary to use.",
  )

  compile_group.add_argument(
      "--gcc_path",
      type=str,
      default="",
      help="Path to the GCC binary to use.",
  )

  compile_group.add_argument(
      "--disable_mkl_dnn",
      action="store_true",
      help="Disables MKL-DNN.",
  )

  compile_group.add_argument(
      "--target_cpu_features",
      choices=["release", "native", "default"],
      default="release",
      help="""
        What CPU features should we target? Release enables CPU features that
        should be enabled for a release build, which on x86-64 architectures
        enables AVX. Native enables -march=native, which generates code
        targeted to use all features of the current machine. Default means
        don't opt in to any architectural features and use whatever the C
        compiler generates by default.
        """,
  )

  compile_group.add_argument(
      "--target_cpu",
      default=None,
      help="CPU platform to target. Default is the same as the host machine.",
  )

  compile_group.add_argument(
      "--local_xla_path",
      type=str,
      default=os.environ.get("JAXCI_XLA_GIT_DIR", ""),
      help="""
        Path to a local XLA repository to use. If not set, Bazel uses the XLA
        at the pinned version in workspace.bzl.
        """,
  )


async def main():
  parser = argparse.ArgumentParser(
      description=r"""
CLI for building JAX wheel packages from source and for updating the
requirements_lock.txt files
""",
      epilog=EPILOG,
      formatter_class=argparse.RawDescriptionHelpFormatter,
  )

  # Create subparsers for build and requirements_update.
  subparsers = parser.add_subparsers(dest="command", required=True)

  # requirements_update subcommand
  requirements_update_parser = subparsers.add_parser(
      "requirements_update", help="Updates the requirements_lock.txt files"
  )
  requirements_update_parser.add_argument(
      "--nightly_update",
      action="store_true",
      help="""
        If true, updates requirements_lock.txt for the corresponding version
        of Python and considers dev, nightly and pre-release versions of
        packages.
        """,
  )
  add_global_arguments(requirements_update_parser)

  # Artifact build subcommand
  build_artifact_parser = subparsers.add_parser(
      "build", help="Builds the jaxlib, plugin, and pjrt artifacts"
  )
  add_artifact_subcommand_arguments(build_artifact_parser)
  add_global_arguments(build_artifact_parser)

  arch = platform.machine()
  os_name = platform.system().lower()

  args = parser.parse_args()

  logger.info("%s", BANNER)

  if args.verbose:
    logging.getLogger().setLevel(logging.DEBUG)
    logger.info("Verbose logging enabled")

  bazel_path, bazel_version = utils.get_bazel_path(args.bazel_path)

  logging.debug("Bazel path: %s", bazel_path)
  logging.debug("Bazel version: %s", bazel_version)

  executor = command.SubprocessExecutor()

  # Start constructing the Bazel command.
  bazel_command_base = command.CommandBuilder(bazel_path)

  if args.bazel_startup_options:
    logging.debug(
        "Additional Bazel startup options: %s", args.bazel_startup_options
    )
    for option in args.bazel_startup_options:
      bazel_command_base.append(option)

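  # The old wheel build rule and the requirements update targets are invoked
  # via `bazel run`; with the new wheel build rule the wheels are produced by
  # `bazel build`.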
  if not args.use_new_wheel_build_rule or args.command == "requirements_update":
    bazel_command_base.append("run")
  else:
    bazel_command_base.append("build")

  if args.python_version:
    # The hermetic Python version must be set via --python_version; setting
    # --repo_env=HERMETIC_PYTHON_VERSION directly in --bazel_options is
    # rejected below.
    python_version_opt = "--repo_env=HERMETIC_PYTHON_VERSION="
    if any(python_version_opt in opt for opt in args.bazel_options):
      raise RuntimeError(
          "Please use --python_version to set the hermetic Python version"
          " instead of setting the --repo_env=HERMETIC_PYTHON_VERSION=<python"
          " version> Bazel option"
      )
    logging.debug("Hermetic Python version: %s", args.python_version)
    bazel_command_base.append(
        f"--repo_env=HERMETIC_PYTHON_VERSION={args.python_version}"
    )
    # Interpret an X.YY-ft version as free-threaded Python and set the
    # corresponding rules_python config flag.
    if args.python_version.endswith("-ft"):
      bazel_command_base.append(
          "--@rules_python//python/config_settings:py_freethreaded='yes'"
      )

  # Enable verbose failures.
  bazel_command_base.append("--verbose_failures=true")

  # Requirements update subcommand execution
  if args.command == "requirements_update":
    requirements_command = copy.deepcopy(bazel_command_base)
    if args.bazel_options:
      logging.debug(
          "Using additional build options: %s", args.bazel_options
      )
      for option in args.bazel_options:
        requirements_command.append(option)

    if args.nightly_update:
      logging.info(
          "--nightly_update is set. Bazel will run"
          " //build:requirements_nightly.update"
      )
      requirements_command.append("//build:requirements_nightly.update")
    else:
      requirements_command.append("//build:requirements.update")

    result = await executor.run(
        requirements_command.get_command_as_string(),
        args.dry_run,
        args.detailed_timestamped_log,
    )
    if result.return_code != 0:
      raise RuntimeError(f"Command failed with return code {result.return_code}")
    else:
      sys.exit(0)

  wheel_build_command_base = copy.deepcopy(bazel_command_base)

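  # Map --target_cpu values to the CPU names used for the wheels.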
  wheel_cpus = {
      "darwin_arm64": "arm64",
      "darwin_x86_64": "x86_64",
      "ppc": "ppc64le",
      "aarch64": "aarch64",
  }
  target_cpu = (
      wheel_cpus[args.target_cpu] if args.target_cpu is not None else arch
  )

  if args.local_xla_path:
    logging.debug("Local XLA path: %s", args.local_xla_path)
    wheel_build_command_base.append(
        f"--override_repository=xla=\"{args.local_xla_path}\""
    )

  if args.target_cpu:
    logging.debug("Target CPU: %s", args.target_cpu)
    wheel_build_command_base.append(f"--cpu={args.target_cpu}")

  if args.disable_nccl:
    logging.debug("Disabling NCCL")
    wheel_build_command_base.append("--config=nonccl")

  git_hash = utils.get_githash()

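  # Configure the C/C++ toolchain: Clang by default, otherwise GCC. Older
  # compiler versions additionally disable xnn_enable_avxvnniint8 below.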
  clang_path = ""
  if args.use_clang:
    clang_path = args.clang_path or utils.get_clang_path_or_exit()
    clang_major_version = utils.get_clang_major_version(clang_path)
    logging.debug(
        "Using Clang as the compiler, clang path: %s, clang version: %s",
        clang_path,
        clang_major_version,
    )

    # Use double quotes around clang path to avoid path issues on Windows.
    wheel_build_command_base.append(
        f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\""
    )
    wheel_build_command_base.append(f"--repo_env=CC=\"{clang_path}\"")
    wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"")

    if clang_major_version >= 16:
      # Enable clang settings that are needed for the build to work with newer
      # versions of Clang.
      wheel_build_command_base.append("--config=clang")
    if clang_major_version < 19:
      wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")

  else:
    gcc_path = args.gcc_path or utils.get_gcc_path_or_exit()
    logging.debug(
        "Using GCC as the compiler, gcc path: %s",
        gcc_path,
    )
    wheel_build_command_base.append(f"--repo_env=CC=\"{gcc_path}\"")
    wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{gcc_path}\"")

    gcc_major_version = utils.get_gcc_major_version(gcc_path)
    if gcc_major_version < 13:
      wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")

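  # oneDNN (MKL-DNN) is enabled by default; aarch64 builds use the threadpool
  # config, other CPUs the open-source-only config.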
  if not args.disable_mkl_dnn:
    logging.debug("Enabling MKL DNN")
    if target_cpu == "aarch64":
      wheel_build_command_base.append("--config=mkl_aarch64_threadpool")
    else:
      wheel_build_command_base.append("--config=mkl_open_source_only")

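  # Translate --target_cpu_features into the corresponding Bazel config:
  # AVX for release builds on x86-64, -march=native, or the compiler defaults.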
  if args.target_cpu_features == "release":
    if arch in ["x86_64", "AMD64"]:
      logging.debug(
          "Using release cpu features: --config=avx_%s",
          "windows" if os_name == "windows" else "posix",
      )
      wheel_build_command_base.append(
          "--config=avx_windows"
          if os_name == "windows"
          else "--config=avx_posix"
      )
  elif args.target_cpu_features == "native":
    if os_name == "windows":
      logger.warning(
          "--target_cpu_features=native is not supported on Windows;"
          " ignoring."
      )
    else:
      logging.debug("Using native cpu features: --config=native_arch_posix")
      wheel_build_command_base.append("--config=native_arch_posix")
  else:
    logging.debug("Using default cpu features")

if "cuda" in args.wheels and "rocm" in args.wheels:
|
|
logging.error("CUDA and ROCm cannot be enabled at the same time.")
|
|
sys.exit(1)
|
|
|
|
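  # CUDA-specific Bazel options: the CUDA configs, the CUDA compiler (Clang or
  # NVCC), and the hermetic CUDA/cuDNN versions and compute capabilities.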
if "cuda" in args.wheels:
|
|
wheel_build_command_base.append("--config=cuda")
|
|
wheel_build_command_base.append("--config=cuda_libraries_from_stubs")
|
|
if args.use_clang:
|
|
wheel_build_command_base.append(
|
|
f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\""
|
|
)
|
|
if args.build_cuda_with_clang:
|
|
logging.debug("Building CUDA with Clang")
|
|
wheel_build_command_base.append("--config=build_cuda_with_clang")
|
|
else:
|
|
logging.debug("Building CUDA with NVCC")
|
|
wheel_build_command_base.append("--config=build_cuda_with_nvcc")
|
|
else:
|
|
logging.debug("Building CUDA with NVCC")
|
|
wheel_build_command_base.append("--config=build_cuda_with_nvcc")
|
|
|
|
if args.cuda_version:
|
|
logging.debug("Hermetic CUDA version: %s", args.cuda_version)
|
|
wheel_build_command_base.append(
|
|
f"--repo_env=HERMETIC_CUDA_VERSION={args.cuda_version}"
|
|
)
|
|
if args.cudnn_version:
|
|
logging.debug("Hermetic cuDNN version: %s", args.cudnn_version)
|
|
wheel_build_command_base.append(
|
|
f"--repo_env=HERMETIC_CUDNN_VERSION={args.cudnn_version}"
|
|
)
|
|
if args.cuda_compute_capabilities:
|
|
logging.debug(
|
|
"Hermetic CUDA compute capabilities: %s",
|
|
args.cuda_compute_capabilities,
|
|
)
|
|
wheel_build_command_base.append(
|
|
f"--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES={args.cuda_compute_capabilities}"
|
|
)
|
|
|
|
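  # ROCm-specific Bazel options: the ROCm configs, the toolkit path, and the
  # AMDGPU targets to compile for.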
if "rocm" in args.wheels:
|
|
wheel_build_command_base.append("--config=rocm_base")
|
|
if args.use_clang:
|
|
wheel_build_command_base.append("--config=rocm")
|
|
wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"")
|
|
if args.rocm_path:
|
|
logging.debug("ROCm tookit path: %s", args.rocm_path)
|
|
wheel_build_command_base.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"")
|
|
if args.rocm_amdgpu_targets:
|
|
logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets)
|
|
wheel_build_command_base.append(
|
|
f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}"
|
|
)
|
|
|
|
  # Append additional build options at the end to override any options set in
  # .bazelrc or above.
  if args.bazel_options:
    logging.debug(
        "Additional Bazel build options: %s", args.bazel_options
    )
    for option in args.bazel_options:
      wheel_build_command_base.append(option)
    if "cuda" in args.wheels:
      wheel_build_command_base.append("--config=cuda_libraries_from_stubs")

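  # Write the assembled Bazel options to .jax_configure.bazelrc so that later
  # Bazel invocations can pick them up; exit if they cannot be retrieved.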
with open(".jax_configure.bazelrc", "w") as f:
|
|
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list(), args.use_new_wheel_build_rule)
|
|
if not jax_configure_options:
|
|
logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.")
|
|
sys.exit(1)
|
|
f.write(jax_configure_options)
|
|
logging.info("Bazel options written to .jax_configure.bazelrc")
|
|
|
|
  if args.use_new_wheel_build_rule:
    logging.info("Using new wheel build rule")
    wheel_build_targets = WHEEL_BUILD_TARGET_DICT_NEW
  else:
    wheel_build_targets = WHEEL_BUILD_TARGET_DICT

  if args.configure_only:
    logging.info("--configure_only is set so not running any Bazel commands.")
  else:
    # Wheel build command execution
    for wheel in args.wheels.split(","):
      output_path = args.output_path
      logger.debug("Artifacts output directory: %s", output_path)

      # Allow CUDA/ROCm wheels without the "jax-" prefix.
      if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel:
        wheel = "jax-" + wheel

      if wheel not in wheel_build_targets.keys():
        logging.error(
            "Incorrect wheel name provided, valid choices are jaxlib,"
            " jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt,"
            " jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt"
        )
        sys.exit(1)

      wheel_build_command = copy.deepcopy(wheel_build_command_base)
      print("\n")
      logger.info(
          "Building %s for %s %s...",
          wheel,
          os_name,
          arch,
      )

      # Append the build target to the Bazel command.
      build_target = wheel_build_targets[wheel]
      wheel_build_command.append(build_target)

      if not args.use_new_wheel_build_rule:
        wheel_build_command.append("--")

        if args.editable:
          logger.info("Building an editable build")
          output_path = os.path.join(output_path, wheel)
          wheel_build_command.append("--editable")

        wheel_build_command.append(f'--output_path="{output_path}"')
        wheel_build_command.append(f"--cpu={target_cpu}")

        if "cuda" in wheel:
          wheel_build_command.append("--enable-cuda=True")
          if args.cuda_version:
            cuda_major_version = args.cuda_version.split(".")[0]
          else:
            cuda_major_version = args.cuda_major_version
          wheel_build_command.append(f"--platform_version={cuda_major_version}")

        if "rocm" in wheel:
          wheel_build_command.append("--enable-rocm=True")
          wheel_build_command.append(f"--platform_version={args.rocm_version}")

        wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")

      result = await executor.run(
          wheel_build_command.get_command_as_string(),
          args.dry_run,
          args.detailed_timestamped_log,
      )
      # Exit with error if any wheel build fails.
      if result.return_code != 0:
        raise RuntimeError(f"Command failed with return code {result.return_code}")

  # Exit with success if all wheels in the list were built successfully.
  sys.exit(0)


if __name__ == "__main__":
  asyncio.run(main())