rocm_jax/build/build.py
github-actions[bot] 6ee76a8a6d
Merge pull request #271 from ROCm/ci-upstream-sync-142_1
CI: 03/11/25 upstream sync
2025-03-11 14:10:03 -05:00

701 lines
22 KiB
Python
Executable File

#!/usr/bin/python
#
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# CLI for building JAX wheel packages from source and for updating the
# requirements_lock.txt files
import argparse
import asyncio
import logging
import os
import platform
import sys
import copy
from tools import command, utils
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
BANNER = r"""
_ _ __ __
| | / \ \ \/ /
_ | |/ _ \ \ /
| |_| / ___ \/ \
\___/_/ \/_/\_\
"""
EPILOG = """
From the root directory of the JAX repository, run
`python build/build.py build --wheels=<list of JAX wheels>` to build JAX
artifacts.
Multiple wheels can be built with a single invocation of the CLI.
E.g. python build/build.py build --wheels=jaxlib,jax-cuda-plugin
To update the requirements_lock.txt files, run
`python build/build.py requirements_update`
"""
# Define the build target for each wheel.
WHEEL_BUILD_TARGET_DICT = {
"jaxlib": "//jaxlib/tools:build_wheel",
"jax-cuda-plugin": "//jaxlib/tools:build_gpu_kernels_wheel",
"jax-cuda-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel",
"jax-rocm-plugin": "//jaxlib/tools:build_gpu_kernels_wheel",
"jax-rocm-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel",
}
# Dictionary with the new wheel build rule. Note that when JAX migrates to the
# new wheel build rule fully, the build CLI will switch to the new wheel build
# rule as the default.
WHEEL_BUILD_TARGET_DICT_NEW = {
"jax": "//:jax_wheel",
"jaxlib": "//jaxlib/tools:jaxlib_wheel",
"jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel",
"jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel",
"jax-rocm-plugin": "//jaxlib/tools:jax_rocm_plugin_wheel",
"jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel",
}
def add_global_arguments(parser: argparse.ArgumentParser):
"""Adds all the global arguments that applies to all the CLI subcommands."""
parser.add_argument(
"--python_version",
type=str,
default=f"{sys.version_info.major}.{sys.version_info.minor}",
help=
"""
Hermetic Python version to use. Default is to use the version of the
Python binary that executed the CLI.
""",
)
bazel_group = parser.add_argument_group('Bazel Options')
bazel_group.add_argument(
"--bazel_path",
type=str,
default="",
help="""
Path to the Bazel binary to use. The default is to find bazel via the
PATH; if none is found, downloads a fresh copy of Bazel from GitHub.
""",
)
bazel_group.add_argument(
"--bazel_startup_options",
action="append",
default=[],
help="""
Additional startup options to pass to Bazel, can be specified multiple
times to pass multiple options.
E.g. --bazel_startup_options='--nobatch'
""",
)
bazel_group.add_argument(
"--bazel_options",
action="append",
default=[],
help="""
Additional build options to pass to Bazel, can be specified multiple
times to pass multiple options.
E.g. --bazel_options='--local_resources=HOST_CPUS'
""",
)
parser.add_argument(
"--dry_run",
action="store_true",
help="Prints the Bazel command that is going to be executed.",
)
parser.add_argument(
"--verbose",
action="store_true",
help="Produce verbose output for debugging.",
)
parser.add_argument(
"--detailed_timestamped_log",
action="store_true",
help="""
Enable detailed logging of the Bazel command with timestamps. The logs
will be stored and can be accessed as artifacts.
""",
)
def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
"""Adds all the arguments that applies to the artifact subcommands."""
parser.add_argument(
"--wheels",
type=str,
default="jaxlib",
help=
"""
A comma separated list of JAX wheels to build. E.g: --wheels="jaxlib",
--wheels="jaxlib,jax-cuda-plugin", etc.
Valid options are: jaxlib, jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt,
jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt
""",
)
parser.add_argument(
"--use_new_wheel_build_rule",
action="store_true",
help=
"""
Whether to use the new wheel build rule. Temporary flag and will be
removed once JAX migrates to the new wheel build rule fully.
""",
)
parser.add_argument(
"--editable",
action="store_true",
help="Create an 'editable' build instead of a wheel.",
)
parser.add_argument(
"--output_path",
type=str,
default=os.path.join(os.getcwd(), "dist"),
help="Directory to which the JAX wheel packages should be written.",
)
parser.add_argument(
"--configure_only",
action="store_true",
help="""
If true, writes the Bazel options to the .jax_configure.bazelrc file but
does not build the artifacts.
""",
)
# CUDA Options
cuda_group = parser.add_argument_group('CUDA Options')
cuda_group.add_argument(
"--cuda_version",
type=str,
help=
"""
Hermetic CUDA version to use. Default is to use the version specified
in the .bazelrc.
""",
)
cuda_group.add_argument(
"--cuda_major_version",
type=str,
default="12",
help=
"""
Which CUDA major version should the wheel be tagged as? Auto-detected if
--cuda_version is set. When --cuda_version is not set, the default is to
set the major version to 12 to match the default in .bazelrc.
""",
)
cuda_group.add_argument(
"--cudnn_version",
type=str,
help=
"""
Hermetic cuDNN version to use. Default is to use the version specified
in the .bazelrc.
""",
)
cuda_group.add_argument(
"--disable_nccl",
action="store_true",
help="Should NCCL be disabled?",
)
cuda_group.add_argument(
"--cuda_compute_capabilities",
type=str,
default=None,
help=
"""
A comma-separated list of CUDA compute capabilities to support. Default
is to use the values specified in the .bazelrc.
""",
)
cuda_group.add_argument(
"--build_cuda_with_clang",
action="store_true",
help="""
Should CUDA code be compiled using Clang? The default behavior is to
compile CUDA with NVCC.
""",
)
# ROCm Options
rocm_group = parser.add_argument_group('ROCm Options')
rocm_group.add_argument(
"--rocm_version",
type=str,
default="60",
help="ROCm version to use",
)
rocm_group.add_argument(
"--rocm_amdgpu_targets",
type=str,
default="gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1030,gfx1100,gfx1101,gfx1200,gfx1201",
help="A comma-separated list of ROCm amdgpu targets to support.",
)
rocm_group.add_argument(
"--rocm_path",
type=str,
default="",
help="Path to the ROCm toolkit.",
)
# Compile Options
compile_group = parser.add_argument_group('Compile Options')
compile_group.add_argument(
"--use_clang",
type=utils._parse_string_as_bool,
default="true",
const=True,
nargs="?",
help="""
Whether to use Clang as the compiler. Not recommended to set this to
False as JAX uses Clang as the default compiler.
""",
)
compile_group.add_argument(
"--clang_path",
type=str,
default="",
help="""
Path to the Clang binary to use.
""",
)
compile_group.add_argument(
"--gcc_path",
type=str,
default="",
help="""
Path to the GCC binary to use.
""",
)
compile_group.add_argument(
"--disable_mkl_dnn",
action="store_true",
help="""
Disables MKL-DNN.
""",
)
compile_group.add_argument(
"--target_cpu_features",
choices=["release", "native", "default"],
default="release",
help="""
What CPU features should we target? Release enables CPU features that
should be enabled for a release build, which on x86-64 architectures
enables AVX. Native enables -march=native, which generates code targeted
to use all features of the current machine. Default means don't opt-in
to any architectural features and use whatever the C compiler generates
by default.
""",
)
compile_group.add_argument(
"--target_cpu",
default=None,
help="CPU platform to target. Default is the same as the host machine.",
)
compile_group.add_argument(
"--local_xla_path",
type=str,
default=os.environ.get("JAXCI_XLA_GIT_DIR", ""),
help="""
Path to local XLA repository to use. If not set, Bazel uses the XLA at
the pinned version in workspace.bzl.
""",
)
async def main():
parser = argparse.ArgumentParser(
description=r"""
CLI for building JAX wheel packages from source and for updating the
requirements_lock.txt files
""",
epilog=EPILOG,
formatter_class=argparse.RawDescriptionHelpFormatter
)
# Create subparsers for build and requirements_update
subparsers = parser.add_subparsers(dest="command", required=True)
# requirements_update subcommand
requirements_update_parser = subparsers.add_parser(
"requirements_update", help="Updates the requirements_lock.txt files"
)
requirements_update_parser.add_argument(
"--nightly_update",
action="store_true",
help="""
If true, updates requirements_lock.txt for a corresponding version of
Python and will consider dev, nightly and pre-release versions of
packages.
""",
)
add_global_arguments(requirements_update_parser)
# Artifact build subcommand
build_artifact_parser = subparsers.add_parser(
"build", help="Builds the jaxlib, plugin, and pjrt artifact"
)
add_artifact_subcommand_arguments(build_artifact_parser)
add_global_arguments(build_artifact_parser)
arch = platform.machine()
os_name = platform.system().lower()
args = parser.parse_args()
logger.info("%s", BANNER)
if args.verbose:
logging.getLogger().setLevel(logging.DEBUG)
logger.info("Verbose logging enabled")
bazel_path, bazel_version = utils.get_bazel_path(args.bazel_path)
logging.debug("Bazel path: %s", bazel_path)
logging.debug("Bazel version: %s", bazel_version)
executor = command.SubprocessExecutor()
# Start constructing the Bazel command
bazel_command_base = command.CommandBuilder(bazel_path)
if args.bazel_startup_options:
logging.debug(
"Additional Bazel startup options: %s", args.bazel_startup_options
)
for option in args.bazel_startup_options:
bazel_command_base.append(option)
if not args.use_new_wheel_build_rule or args.command == "requirements_update":
bazel_command_base.append("run")
else:
bazel_command_base.append("build")
if args.python_version:
# Do not add --repo_env=HERMETIC_PYTHON_VERSION with default args.python_version
# if bazel_options override it
python_version_opt = "--repo_env=HERMETIC_PYTHON_VERSION="
if any([python_version_opt in opt for opt in args.bazel_options]):
raise RuntimeError(
"Please use python_version to set hermetic python version instead of "
"setting --repo_env=HERMETIC_PYTHON_VERSION=<python version> bazel option"
)
logging.debug("Hermetic Python version: %s", args.python_version)
bazel_command_base.append(
f"--repo_env=HERMETIC_PYTHON_VERSION={args.python_version}"
)
# Let's interpret X.YY-ft version as free-threading python and set rules_python config flag:
if args.python_version.endswith("-ft"):
bazel_command_base.append(
"--@rules_python//python/config_settings:py_freethreaded='yes'"
)
# Enable verbose failures.
bazel_command_base.append("--verbose_failures=true")
# Requirements update subcommand execution
if args.command == "requirements_update":
requirements_command = copy.deepcopy(bazel_command_base)
if args.bazel_options:
logging.debug(
"Using additional build options: %s", args.bazel_options
)
for option in args.bazel_options:
requirements_command.append(option)
if args.nightly_update:
logging.info(
"--nightly_update is set. Bazel will run"
" //build:requirements_nightly.update"
)
requirements_command.append("//build:requirements_nightly.update")
else:
requirements_command.append("//build:requirements.update")
result = await executor.run(requirements_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log)
if result.return_code != 0:
raise RuntimeError(f"Command failed with return code {result.return_code}")
else:
sys.exit(0)
wheel_build_command_base = copy.deepcopy(bazel_command_base)
wheel_cpus = {
"darwin_arm64": "arm64",
"darwin_x86_64": "x86_64",
"ppc": "ppc64le",
"aarch64": "aarch64",
}
target_cpu = (
wheel_cpus[args.target_cpu] if args.target_cpu is not None else arch
)
if args.local_xla_path:
logging.debug("Local XLA path: %s", args.local_xla_path)
wheel_build_command_base.append(f"--override_repository=xla=\"{args.local_xla_path}\"")
if args.target_cpu:
logging.debug("Target CPU: %s", args.target_cpu)
wheel_build_command_base.append(f"--cpu={args.target_cpu}")
if args.disable_nccl:
logging.debug("Disabling NCCL")
wheel_build_command_base.append("--config=nonccl")
git_hash = utils.get_githash()
clang_path = ""
if args.use_clang:
clang_path = args.clang_path or utils.get_clang_path_or_exit()
clang_major_version = utils.get_clang_major_version(clang_path)
logging.debug(
"Using Clang as the compiler, clang path: %s, clang version: %s",
clang_path,
clang_major_version,
)
# Use double quotes around clang path to avoid path issues on Windows.
wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"")
wheel_build_command_base.append(f"--repo_env=CC=\"{clang_path}\"")
wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"")
if clang_major_version >= 16:
# Enable clang settings that are needed for the build to work with newer
# versions of Clang.
wheel_build_command_base.append("--config=clang")
if clang_major_version < 19:
wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")
else:
gcc_path = args.gcc_path or utils.get_gcc_path_or_exit()
logging.debug(
"Using GCC as the compiler, gcc path: %s",
gcc_path,
)
wheel_build_command_base.append(f"--repo_env=CC=\"{gcc_path}\"")
wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{gcc_path}\"")
gcc_major_version = utils.get_gcc_major_version(gcc_path)
if gcc_major_version < 13:
wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")
if not args.disable_mkl_dnn:
logging.debug("Enabling MKL DNN")
if target_cpu == "aarch64":
wheel_build_command_base.append("--config=mkl_aarch64_threadpool")
else:
wheel_build_command_base.append("--config=mkl_open_source_only")
if args.target_cpu_features == "release":
if arch in ["x86_64", "AMD64"]:
logging.debug(
"Using release cpu features: --config=avx_%s",
"windows" if os_name == "windows" else "posix",
)
wheel_build_command_base.append(
"--config=avx_windows"
if os_name == "windows"
else "--config=avx_posix"
)
elif args.target_cpu_features == "native":
if os_name == "windows":
logger.warning(
"--target_cpu_features=native is not supported on Windows;"
" ignoring."
)
else:
logging.debug("Using native cpu features: --config=native_arch_posix")
wheel_build_command_base.append("--config=native_arch_posix")
else:
logging.debug("Using default cpu features")
if "cuda" in args.wheels and "rocm" in args.wheels:
logging.error("CUDA and ROCm cannot be enabled at the same time.")
sys.exit(1)
if "cuda" in args.wheels:
wheel_build_command_base.append("--config=cuda")
wheel_build_command_base.append("--config=cuda_libraries_from_stubs")
if args.use_clang:
wheel_build_command_base.append(
f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\""
)
if args.build_cuda_with_clang:
logging.debug("Building CUDA with Clang")
wheel_build_command_base.append("--config=build_cuda_with_clang")
else:
logging.debug("Building CUDA with NVCC")
wheel_build_command_base.append("--config=build_cuda_with_nvcc")
else:
logging.debug("Building CUDA with NVCC")
wheel_build_command_base.append("--config=build_cuda_with_nvcc")
if args.cuda_version:
logging.debug("Hermetic CUDA version: %s", args.cuda_version)
wheel_build_command_base.append(
f"--repo_env=HERMETIC_CUDA_VERSION={args.cuda_version}"
)
if args.cudnn_version:
logging.debug("Hermetic cuDNN version: %s", args.cudnn_version)
wheel_build_command_base.append(
f"--repo_env=HERMETIC_CUDNN_VERSION={args.cudnn_version}"
)
if args.cuda_compute_capabilities:
logging.debug(
"Hermetic CUDA compute capabilities: %s",
args.cuda_compute_capabilities,
)
wheel_build_command_base.append(
f"--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES={args.cuda_compute_capabilities}"
)
if "rocm" in args.wheels:
wheel_build_command_base.append("--config=rocm_base")
if args.use_clang:
wheel_build_command_base.append("--config=rocm")
wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"")
if args.rocm_path:
logging.debug("ROCm tookit path: %s", args.rocm_path)
wheel_build_command_base.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"")
if args.rocm_amdgpu_targets:
logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets)
wheel_build_command_base.append(
f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}"
)
# Append additional build options at the end to override any options set in
# .bazelrc or above.
if args.bazel_options:
logging.debug(
"Additional Bazel build options: %s", args.bazel_options
)
for option in args.bazel_options:
wheel_build_command_base.append(option)
if "cuda" in args.wheels:
wheel_build_command_base.append("--config=cuda_libraries_from_stubs")
with open(".jax_configure.bazelrc", "w") as f:
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list(), args.use_new_wheel_build_rule)
if not jax_configure_options:
logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.")
sys.exit(1)
f.write(jax_configure_options)
logging.info("Bazel options written to .jax_configure.bazelrc")
if args.use_new_wheel_build_rule:
logging.info("Using new wheel build rule")
wheel_build_targets = WHEEL_BUILD_TARGET_DICT_NEW
else:
wheel_build_targets = WHEEL_BUILD_TARGET_DICT
if args.configure_only:
logging.info("--configure_only is set so not running any Bazel commands.")
else:
# Wheel build command execution
for wheel in args.wheels.split(","):
output_path = args.output_path
logger.debug("Artifacts output directory: %s", output_path)
# Allow CUDA/ROCm wheels without the "jax-" prefix.
if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel:
wheel = "jax-" + wheel
if wheel not in wheel_build_targets.keys():
logging.error(
"Incorrect wheel name provided, valid choices are jaxlib,"
" jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt,"
" jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt"
)
sys.exit(1)
wheel_build_command = copy.deepcopy(wheel_build_command_base)
print("\n")
logger.info(
"Building %s for %s %s...",
wheel,
os_name,
arch,
)
# Append the build target to the Bazel command.
build_target = wheel_build_targets[wheel]
wheel_build_command.append(build_target)
if not args.use_new_wheel_build_rule:
wheel_build_command.append("--")
if args.editable:
logger.info("Building an editable build")
output_path = os.path.join(output_path, wheel)
wheel_build_command.append("--editable")
wheel_build_command.append(f'--output_path="{output_path}"')
wheel_build_command.append(f"--cpu={target_cpu}")
if "cuda" in wheel:
wheel_build_command.append("--enable-cuda=True")
if args.cuda_version:
cuda_major_version = args.cuda_version.split(".")[0]
else:
cuda_major_version = args.cuda_major_version
wheel_build_command.append(f"--platform_version={cuda_major_version}")
if "rocm" in wheel:
wheel_build_command.append("--enable-rocm=True")
wheel_build_command.append(f"--platform_version={args.rocm_version}")
wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")
result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log)
# Exit with error if any wheel build fails.
if result.return_code != 0:
raise RuntimeError(f"Command failed with return code {result.return_code}")
# Exit with success if all wheels in the list were built successfully.
sys.exit(0)
if __name__ == "__main__":
asyncio.run(main())