rocm_jax/build/build.py
Nitin Srinivasan 6761512658 Re-factor build CLI to a subcommand based approach
This commit reworks the JAX build CLI to a subcommand based approach where CLI use cases are now defined as subcommands. Two subcommands are defined: build and requirements_update. "build" is to be used when wanting to build a JAX wheel package. "requirements_update" is to be used when wanting to update the requirements_lock.txt files. The new structure offers a clear and organized CLI that enables users to execute specific build tasks without having to navigate through a monolithic script.

Each subcommand has specific arguments that apply to its respective build process. In addition, arguments are separated into groups, which achieves a cleaner separation and improves readability when the CLI subcommands are run with `--help`. It also makes it clear which parts of the build they affect. E.g.: CUDA arguments only apply to CUDA builds, ROCm arguments only apply to ROCm builds, etc. This reduces the complexity and the potential for errors during the build process. Segregating functionalities into distinct subcommands also simplifies the code, which should help with maintenance and future extensions.

There is also a transition from `subprocess.check_output` to `asyncio.create_subprocess_shell` for executing the build commands, which allows streaming logs and showing the build progress in real time.

Usage:
* Building `jaxlib`:
```
python build/build.py build --wheels=jaxlib --python_version=3.10
```
* Building `jax-cuda-plugin`:
```
python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building multiple packages:
```
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building `jax-rocm-pjrt`:
```
python build/build.py build --wheels=jax-rocm-pjrt --rocm_version=60 --rocm_path=/path/to/rocm
```
* Using a local XLA path:
```
python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla
```
* Updating requirements_lock.txt files:
```
python build/build.py requirements_update --python_version=3.10
```

For more details on each argument and to see available options, run:
```
python build/build.py build --help
```
or
```
python build/build.py requirements_update --help
```

PiperOrigin-RevId: 700075411
2024-11-25 13:03:04 -08:00

602 lines
18 KiB
Python
Executable File

#!/usr/bin/python
#
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# CLI for building JAX wheel packages from source and for updating the
# requirements_lock.txt files
import argparse
import asyncio
import logging
import os
import platform
import sys
import copy
from tools import command, utils
# Configure root logging once at import time so every record carries a
# timestamp and severity level; module loggers inherit this format.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# ASCII-art "JAX" banner logged at the start of every CLI invocation.
# NOTE(review): the banner's internal spacing appears mangled in this copy
# of the file — confirm the art against the upstream source before relying
# on its exact rendering.
BANNER = r"""
_ _ __ __
| | / \ \ \/ /
_ | |/ _ \ \ /
| |_| / ___ \/ \
\___/_/ \/_/\_\
"""

# Epilog appended to the `--help` output of the top-level parser; it gives a
# quick usage summary for both subcommands.
EPILOG = """
From the root directory of the JAX repository, run
`python build/build.py build --wheels=<list of JAX wheels>` to build JAX
artifacts.
Multiple wheels can be built with a single invocation of the CLI.
E.g. python build/build.py build --wheels=jaxlib,jax-cuda-plugin
To update the requirements_lock.txt files, run
`python build/build.py requirements_update`
"""

# Define the build target for each wheel.
# Note that the CUDA and ROCm plugin/pjrt wheels intentionally share the same
# generic GPU build targets; the GPU backend is selected later via Bazel
# configs and wheel-builder flags.
WHEEL_BUILD_TARGET_DICT = {
    "jaxlib": "//jaxlib/tools:build_wheel",
    "jax-cuda-plugin": "//jaxlib/tools:build_gpu_kernels_wheel",
    "jax-cuda-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel",
    "jax-rocm-plugin": "//jaxlib/tools:build_gpu_kernels_wheel",
    "jax-rocm-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel",
}
def add_global_arguments(parser: argparse.ArgumentParser) -> None:
  """Adds the global arguments that apply to all the CLI subcommands.

  Args:
    parser: The parser (or subcommand parser) to attach the shared
      arguments to. Modified in place.
  """
  # Hermetic Python version for Bazel; defaults to the interpreter that is
  # running this script.
  parser.add_argument(
      "--python_version",
      type=str,
      choices=["3.10", "3.11", "3.12", "3.13"],
      default=f"{sys.version_info.major}.{sys.version_info.minor}",
      help=
      """
Hermetic Python version to use. Default is to use the version of the
Python binary that executed the CLI.
""",
  )

  # Bazel-specific flags live in their own argument group so `--help`
  # renders them under a "Bazel Options" heading.
  bazel_group = parser.add_argument_group('Bazel Options')
  bazel_group.add_argument(
      "--bazel_path",
      type=str,
      default="",
      help="""
Path to the Bazel binary to use. The default is to find bazel via the
PATH; if none is found, downloads a fresh copy of Bazel from GitHub.
""",
  )
  # action="append" lets the flag be repeated; each occurrence adds one
  # startup option (passed to Bazel before the `run` verb).
  bazel_group.add_argument(
      "--bazel_startup_options",
      action="append",
      default=[],
      help="""
Additional startup options to pass to Bazel, can be specified multiple
times to pass multiple options.
E.g. --bazel_startup_options='--nobatch'
""",
  )
  # Build options are appended last by main() so they override anything set
  # in .bazelrc or computed by the CLI.
  bazel_group.add_argument(
      "--bazel_options",
      action="append",
      default=[],
      help="""
Additional build options to pass to Bazel, can be specified multiple
times to pass multiple options.
E.g. --bazel_options='--local_resources=HOST_CPUS'
""",
  )

  parser.add_argument(
      "--dry_run",
      action="store_true",
      help="Prints the Bazel command that is going to be executed.",
  )

  parser.add_argument(
      "--verbose",
      action="store_true",
      help="Produce verbose output for debugging.",
  )
def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser) -> None:
  """Adds the arguments that apply only to the "build" subcommand.

  Arguments are organized into CUDA, ROCm, and Compile groups so that
  `--help` clearly shows which options affect which kind of build.

  Args:
    parser: The "build" subcommand parser to attach the arguments to.
      Modified in place.
  """
  parser.add_argument(
      "--wheels",
      type=str,
      default="jaxlib",
      help=
      """
A comma separated list of JAX wheels to build. E.g: --wheels="jaxlib",
--wheels="jaxlib,jax-cuda-plugin", etc.
Valid options are: jaxlib, jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt,
jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt
""",
  )

  parser.add_argument(
      "--editable",
      action="store_true",
      help="Create an 'editable' build instead of a wheel.",
  )

  parser.add_argument(
      "--output_path",
      type=str,
      default=os.path.join(os.getcwd(), "dist"),
      help="Directory to which the JAX wheel packages should be written.",
  )

  parser.add_argument(
      "--configure_only",
      action="store_true",
      help="""
If true, writes the Bazel options to the .jax_configure.bazelrc file but
does not build the artifacts.
""",
  )

  # CUDA Options
  cuda_group = parser.add_argument_group('CUDA Options')
  cuda_group.add_argument(
      "--cuda_version",
      type=str,
      help=
      """
Hermetic CUDA version to use. Default is to use the version specified
in the .bazelrc.
""",
  )
  # Only consulted when --cuda_version is not given; otherwise the major
  # version is derived from --cuda_version itself.
  cuda_group.add_argument(
      "--cuda_major_version",
      type=str,
      default="12",
      help=
      """
Which CUDA major version should the wheel be tagged as? Auto-detected if
--cuda_version is set. When --cuda_version is not set, the default is to
set the major version to 12 to match the default in .bazelrc.
""",
  )
  cuda_group.add_argument(
      "--cudnn_version",
      type=str,
      help=
      """
Hermetic cuDNN version to use. Default is to use the version specified
in the .bazelrc.
""",
  )
  cuda_group.add_argument(
      "--disable_nccl",
      action="store_true",
      help="Should NCCL be disabled?",
  )
  cuda_group.add_argument(
      "--cuda_compute_capabilities",
      type=str,
      default=None,
      help=
      """
A comma-separated list of CUDA compute capabilities to support. Default
is to use the values specified in the .bazelrc.
""",
  )
  cuda_group.add_argument(
      "--build_cuda_with_clang",
      action="store_true",
      help="""
Should CUDA code be compiled using Clang? The default behavior is to
compile CUDA with NVCC.
""",
  )

  # ROCm Options
  rocm_group = parser.add_argument_group('ROCm Options')
  rocm_group.add_argument(
      "--rocm_version",
      type=str,
      default="60",
      help="ROCm version to use",
  )
  rocm_group.add_argument(
      "--rocm_amdgpu_targets",
      type=str,
      default="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100",
      help="A comma-separated list of ROCm amdgpu targets to support.",
  )
  rocm_group.add_argument(
      "--rocm_path",
      type=str,
      default="",
      help="Path to the ROCm toolkit.",
  )

  # Compile Options
  compile_group = parser.add_argument_group('Compile Options')
  # nargs="?" with const=True makes a bare `--use_clang` equivalent to
  # `--use_clang=true`; the string value is parsed into a bool by the
  # project helper.
  compile_group.add_argument(
      "--use_clang",
      type=utils._parse_string_as_bool,
      default="true",
      const=True,
      nargs="?",
      help="""
Whether to use Clang as the compiler. Not recommended to set this to
False as JAX uses Clang as the default compiler.
""",
  )
  compile_group.add_argument(
      "--clang_path",
      type=str,
      default="",
      help="""
Path to the Clang binary to use.
""",
  )
  compile_group.add_argument(
      "--disable_mkl_dnn",
      action="store_true",
      help="""
Disables MKL-DNN.
""",
  )
  compile_group.add_argument(
      "--target_cpu_features",
      choices=["release", "native", "default"],
      default="release",
      help="""
What CPU features should we target? Release enables CPU features that
should be enabled for a release build, which on x86-64 architectures
enables AVX. Native enables -march=native, which generates code targeted
to use all features of the current machine. Default means don't opt-in
to any architectural features and use whatever the C compiler generates
by default.
""",
  )
  compile_group.add_argument(
      "--target_cpu",
      default=None,
      help="CPU platform to target. Default is the same as the host machine.",
  )
  # Falls back to the JAXCI_XLA_GIT_DIR environment variable so CI jobs can
  # point at a local XLA checkout without passing the flag explicitly.
  compile_group.add_argument(
      "--local_xla_path",
      type=str,
      default=os.environ.get("JAXCI_XLA_GIT_DIR", ""),
      help="""
Path to local XLA repository to use. If not set, Bazel uses the XLA at
the pinned version in workspace.bzl.
""",
  )
async def main():
  """Entry point for the JAX build CLI.

  Builds the argument parser with its two subcommands ("build" and
  "requirements_update"), assembles the matching Bazel `run` command from
  the parsed arguments, and executes it through the asynchronous subprocess
  executor so the build logs stream in real time.

  Exits the process with status 0 after a requirements update, and with
  status 1 on an invalid wheel name or a configuration error.
  """
  parser = argparse.ArgumentParser(
      description=r"""
CLI for building JAX wheel packages from source and for updating the
requirements_lock.txt files
""",
      epilog=EPILOG,
      formatter_class=argparse.RawDescriptionHelpFormatter
  )

  # Create subparsers for build and requirements_update
  subparsers = parser.add_subparsers(dest="command", required=True)

  # requirements_update subcommand
  requirements_update_parser = subparsers.add_parser(
      "requirements_update", help="Updates the requirements_lock.txt files"
  )
  requirements_update_parser.add_argument(
      "--nightly_update",
      action="store_true",
      help="""
If true, updates requirements_lock.txt for a corresponding version of
Python and will consider dev, nightly and pre-release versions of
packages.
""",
  )
  add_global_arguments(requirements_update_parser)

  # Artifact build subcommand
  build_artifact_parser = subparsers.add_parser(
      "build", help="Builds the jaxlib, plugin, and pjrt artifact"
  )
  add_artifact_subcommand_arguments(build_artifact_parser)
  add_global_arguments(build_artifact_parser)

  arch = platform.machine()
  os_name = platform.system().lower()

  args = parser.parse_args()

  logger.info("%s", BANNER)

  if args.verbose:
    logging.getLogger().setLevel(logging.DEBUG)
    logger.info("Verbose logging enabled")

  bazel_path, bazel_version = utils.get_bazel_path(args.bazel_path)
  logging.debug("Bazel path: %s", bazel_path)
  logging.debug("Bazel version: %s", bazel_version)

  executor = command.SubprocessExecutor()

  # Start constructing the Bazel command
  bazel_command_base = command.CommandBuilder(bazel_path)

  # Startup options must precede the "run" verb on the Bazel command line.
  if args.bazel_startup_options:
    logging.debug(
        "Additional Bazel startup options: %s", args.bazel_startup_options
    )
    for option in args.bazel_startup_options:
      bazel_command_base.append(option)

  bazel_command_base.append("run")

  if args.python_version:
    logging.debug("Hermetic Python version: %s", args.python_version)
    bazel_command_base.append(
        f"--repo_env=HERMETIC_PYTHON_VERSION={args.python_version}"
    )

  # Enable verbose failures.
  bazel_command_base.append("--verbose_failures=true")

  # Requirements update subcommand execution
  if args.command == "requirements_update":
    # Copy the base command so subcommand-specific options don't mutate it.
    requirements_command = copy.deepcopy(bazel_command_base)
    if args.bazel_options:
      logging.debug(
          "Using additional build options: %s", args.bazel_options
      )
      for option in args.bazel_options:
        requirements_command.append(option)

    if args.nightly_update:
      logging.info(
          "--nightly_update is set. Bazel will run"
          " //build:requirements_nightly.update"
      )
      requirements_command.append("//build:requirements_nightly.update")
    else:
      requirements_command.append("//build:requirements.update")

    await executor.run(requirements_command.get_command_as_string(), args.dry_run)
    sys.exit(0)

  # Normalize --target_cpu / host-arch spellings to the CPU names the wheel
  # build tooling expects. Bug fix: the previous direct dict lookup raised a
  # KeyError for any --target_cpu value not in this 4-entry map (e.g.
  # "x86_64"); unknown spellings now pass through unchanged.
  wheel_cpus = {
      "darwin_arm64": "arm64",
      "darwin_x86_64": "x86_64",
      "ppc": "ppc64le",
      "aarch64": "aarch64",
  }
  target_cpu = (
      wheel_cpus.get(args.target_cpu, args.target_cpu)
      if args.target_cpu is not None
      else arch
  )

  if args.local_xla_path:
    logging.debug("Local XLA path: %s", args.local_xla_path)
    bazel_command_base.append(f"--override_repository=xla=\"{args.local_xla_path}\"")

  if args.target_cpu:
    logging.debug("Target CPU: %s", args.target_cpu)
    bazel_command_base.append(f"--cpu={args.target_cpu}")

  if args.disable_nccl:
    logging.debug("Disabling NCCL")
    bazel_command_base.append("--config=nonccl")

  git_hash = utils.get_githash()

  # Wheel build command execution
  for wheel in args.wheels.split(","):
    # Allow CUDA/ROCm wheels without the "jax-" prefix.
    if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel:
      wheel = "jax-" + wheel

    if wheel not in WHEEL_BUILD_TARGET_DICT:
      logging.error(
          "Incorrect wheel name provided, valid choices are jaxlib,"
          " jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt,"
          " jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt"
      )
      sys.exit(1)

    # Each wheel gets its own copy of the base command so per-wheel options
    # never leak into the next iteration.
    wheel_build_command = copy.deepcopy(bazel_command_base)
    print("\n")
    logger.info(
        "Building %s for %s %s...",
        wheel,
        os_name,
        arch,
    )

    clang_path = ""
    if args.use_clang:
      clang_path = args.clang_path or utils.get_clang_path_or_exit()
      clang_major_version = utils.get_clang_major_version(clang_path)
      logging.debug(
          "Using Clang as the compiler, clang path: %s, clang version: %s",
          clang_path,
          clang_major_version,
      )

      # Use double quotes around clang path to avoid path issues on Windows.
      wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"")
      wheel_build_command.append(f"--repo_env=CC=\"{clang_path}\"")
      wheel_build_command.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"")
    else:
      logging.debug("Use Clang: False")

    # Do not apply --config=clang on Mac as these settings do not apply to
    # Apple Clang.
    if os_name != "darwin":
      wheel_build_command.append("--config=clang")

    if not args.disable_mkl_dnn:
      logging.debug("Enabling MKL DNN")
      wheel_build_command.append("--config=mkl_open_source_only")

    if args.target_cpu_features == "release":
      if arch in ["x86_64", "AMD64"]:
        logging.debug(
            "Using release cpu features: --config=avx_%s",
            "windows" if os_name == "windows" else "posix",
        )
        wheel_build_command.append(
            "--config=avx_windows"
            if os_name == "windows"
            else "--config=avx_posix"
        )
    # Bug fix: this branch previously compared the CommandBuilder object
    # (`wheel_build_command`) against the string "native", which is never
    # true, so --target_cpu_features=native was silently ignored.
    elif args.target_cpu_features == "native":
      if os_name == "windows":
        logger.warning(
            "--target_cpu_features=native is not supported on Windows;"
            " ignoring."
        )
      else:
        logging.debug("Using native cpu features: --config=native_arch_posix")
        wheel_build_command.append("--config=native_arch_posix")
    else:
      logging.debug("Using default cpu features")

    if "cuda" in wheel:
      wheel_build_command.append("--config=cuda")
      wheel_build_command.append(
          f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\""
      )
      if args.build_cuda_with_clang:
        logging.debug("Building CUDA with Clang")
        wheel_build_command.append("--config=build_cuda_with_clang")
      else:
        logging.debug("Building CUDA with NVCC")
        wheel_build_command.append("--config=build_cuda_with_nvcc")

      if args.cuda_version:
        logging.debug("Hermetic CUDA version: %s", args.cuda_version)
        wheel_build_command.append(
            f"--repo_env=HERMETIC_CUDA_VERSION={args.cuda_version}"
        )
      if args.cudnn_version:
        logging.debug("Hermetic cuDNN version: %s", args.cudnn_version)
        wheel_build_command.append(
            f"--repo_env=HERMETIC_CUDNN_VERSION={args.cudnn_version}"
        )
      if args.cuda_compute_capabilities:
        logging.debug(
            "Hermetic CUDA compute capabilities: %s",
            args.cuda_compute_capabilities,
        )
        wheel_build_command.append(
            f"--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES={args.cuda_compute_capabilities}"
        )

    if "rocm" in wheel:
      wheel_build_command.append("--config=rocm_base")
      if args.use_clang:
        wheel_build_command.append("--config=rocm")
        wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"")
      if args.rocm_path:
        # Typo fix in the log message: "tookit" -> "toolkit".
        logging.debug("ROCm toolkit path: %s", args.rocm_path)
        wheel_build_command.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"")
      if args.rocm_amdgpu_targets:
        logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets)
        wheel_build_command.append(
            f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}"
        )

    # Append additional build options at the end to override any options set in
    # .bazelrc or above.
    if args.bazel_options:
      logging.debug(
          "Additional Bazel build options: %s", args.bazel_options
      )
      for option in args.bazel_options:
        wheel_build_command.append(option)

    # Persist the computed Bazel options so `bazel` invocations outside this
    # CLI pick up the same configuration.
    with open(".jax_configure.bazelrc", "w") as f:
      jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command.get_command_as_list())
      if not jax_configure_options:
        logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.")
        sys.exit(1)
      f.write(jax_configure_options)
      logging.info("Bazel options written to .jax_configure.bazelrc")

    if args.configure_only:
      logging.info("--configure_only is set so not running any Bazel commands.")
    else:
      # Append the build target to the Bazel command.
      build_target = WHEEL_BUILD_TARGET_DICT[wheel]
      wheel_build_command.append(build_target)
      # Everything after "--" is forwarded to the wheel-builder script, not
      # interpreted by Bazel.
      wheel_build_command.append("--")

      output_path = args.output_path
      logger.debug("Artifacts output directory: %s", output_path)

      if args.editable:
        logger.info("Building an editable build")
        output_path = os.path.join(output_path, wheel)
        wheel_build_command.append("--editable")

      wheel_build_command.append(f'--output_path="{output_path}"')
      wheel_build_command.append(f"--cpu={target_cpu}")

      if "cuda" in wheel:
        wheel_build_command.append("--enable-cuda=True")
        # Derive the wheel's CUDA major tag from --cuda_version when given;
        # otherwise fall back to the explicit --cuda_major_version flag.
        if args.cuda_version:
          cuda_major_version = args.cuda_version.split(".")[0]
        else:
          cuda_major_version = args.cuda_major_version
        wheel_build_command.append(f"--platform_version={cuda_major_version}")

      if "rocm" in wheel:
        wheel_build_command.append("--enable-rocm=True")
        wheel_build_command.append(f"--platform_version={args.rocm_version}")

      wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")

      await executor.run(wheel_build_command.get_command_as_string(), args.dry_run)
if __name__ == "__main__":
  # Run the async CLI entry point on a fresh asyncio event loop.
  asyncio.run(main())