rocm_jax/build/build.py
Nitin Srinivasan da994d3552 Move utility functions in build.py to utils.py
This commit is the first step towards re-working the build CLI. It moves all the auxiliary functions used by the CLI into a separate script for easier maintenance and readability.

PiperOrigin-RevId: 691458051
2024-10-30 10:00:32 -07:00

533 lines
18 KiB
Python
Executable File

#!/usr/bin/python
#
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Helper script for building JAX's libjax easily.
import argparse
import logging
import os
import platform
import textwrap
from tools import utils
# Module-level logger; main() raises its level to DEBUG when --verbose is set.
logger = logging.getLogger(__name__)
def write_bazelrc(*, remote_build,
                  cuda_version, cudnn_version, rocm_toolkit_path,
                  cpu, cuda_compute_capabilities,
                  rocm_amdgpu_targets, target_cpu_features,
                  wheel_cpu, enable_mkl_dnn, use_clang, clang_path,
                  clang_major_version, python_version,
                  enable_cuda, enable_nccl, enable_rocm,
                  use_cuda_nvcc,
                  bazelrc_path="../.jax_configure.bazelrc"):
    """Writes the Bazel options selected by the build flags to a bazelrc file.

    All arguments are keyword-only. The target file is overwritten on every
    call, and each option is written on its own newline-terminated line
    (including the last one).

    Args:
      remote_build: If false, Genrule actions use the standalone strategy.
      cuda_version: Value for HERMETIC_CUDA_VERSION; only used with enable_cuda.
      cudnn_version: Value for HERMETIC_CUDNN_VERSION; only used with
        enable_cuda.
      rocm_toolkit_path: If truthy, exported to actions as ROCM_PATH.
      cpu: If not None, passed to Bazel as --cpu.
      cuda_compute_capabilities: Value for HERMETIC_CUDA_COMPUTE_CAPABILITIES
        under the `cuda` config; only used with enable_cuda.
      rocm_amdgpu_targets: If truthy, exported as TF_ROCM_AMDGPU_TARGETS under
        the `rocm` config.
      target_cpu_features: "release" or "native" select an architecture
        config; any other value adds nothing.
      wheel_cpu: CPU name of the wheel; "release" only adds an AVX config when
        this is "x86_64".
      enable_mkl_dnn: If true, adds the mkl_open_source_only config.
      use_clang: If true, points the compiler-related variables at clang_path.
      clang_path: Path to the clang binary.
      clang_major_version: Major version of clang; 16, 17 and 18 get an extra
        warning suppression.
      python_version: If truthy, value for HERMETIC_PYTHON_VERSION.
      enable_cuda: If true, adds the CUDA configs and related variables.
      enable_nccl: If false, adds the nonccl config for CUDA/ROCm builds.
      enable_rocm: If true, adds the ROCm configs.
      use_cuda_nvcc: Chooses between the nvcc and clang CUDA build configs.
      bazelrc_path: File to write. Defaults to the location used by the build
        script, relative to the current working directory.
    """
    lines = []

    if not remote_build:
        lines.append("build --strategy=Genrule=standalone")

    if use_clang:
        lines.append(f'build --action_env CLANG_COMPILER_PATH="{clang_path}"')
        lines.append(f'build --repo_env CC="{clang_path}"')
        lines.append(f'build --repo_env BAZEL_COMPILER="{clang_path}"')
        lines.append("build --copt=-Wno-error=unused-command-line-argument")
        if clang_major_version in (16, 17, 18):
            # Necessary due to XLA's old version of upb. See:
            # https://github.com/openxla/xla/blob/c4277a076e249f5b97c8e45c8cb9d1f554089d76/.bazelrc#L505
            lines.append("build --copt=-Wno-gnu-offsetof-extensions")

    if rocm_toolkit_path:
        lines.append(f'build --action_env ROCM_PATH="{rocm_toolkit_path}"')
    if rocm_amdgpu_targets:
        lines.append(
            f'build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="{rocm_amdgpu_targets}"')
    if cpu is not None:
        lines.append(f"build --cpu={cpu}")

    if target_cpu_features == "release":
        if wheel_cpu == "x86_64":
            lines.append("build --config=avx_windows" if utils.is_windows()
                         else "build --config=avx_posix")
    elif target_cpu_features == "native":
        if utils.is_windows():
            print("--target_cpu_features=native is not supported on Windows; ignoring.")
        else:
            lines.append("build --config=native_arch_posix")

    if enable_mkl_dnn:
        lines.append("build --config=mkl_open_source_only")

    if enable_cuda:
        lines.append("build --config=cuda")
        if use_cuda_nvcc:
            lines.append("build --config=build_cuda_with_nvcc")
        else:
            lines.append("build --config=build_cuda_with_clang")
        # Written for both the nvcc and the clang CUDA configurations.
        lines.append(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}")
        if not enable_nccl:
            lines.append("build --config=nonccl")
        if cuda_version:
            lines.append(f'build --repo_env HERMETIC_CUDA_VERSION="{cuda_version}"')
        if cudnn_version:
            lines.append(f'build --repo_env HERMETIC_CUDNN_VERSION="{cudnn_version}"')
        if cuda_compute_capabilities:
            lines.append(
                f'build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"')

    if enable_rocm:
        lines.append("build --config=rocm_base")
        if not enable_nccl:
            lines.append("build --config=nonccl")
        if use_clang:
            lines.append("build --config=rocm")
            lines.append(f"build --action_env=CLANG_COMPILER_PATH={clang_path}")

    if python_version:
        lines.append(f'build --repo_env HERMETIC_PYTHON_VERSION="{python_version}"')

    # Terminate every line, including the last, so the file is a well-formed
    # text file and later appends cannot fuse with the final option.
    with open(bazelrc_path, "w") as f:
        f.writelines(f"{line}\n" for line in lines)
# ASCII-art banner that main() prints once argument parsing has succeeded.
BANNER = r"""
_ _ __ __
| | / \ \ \/ /
_ | |/ _ \ \ /
| |_| / ___ \/ \
\___/_/ \/_/\_\
"""
# Usage note passed as the `epilog` of the argparse.ArgumentParser in main().
EPILOG = """
From the 'build' directory in the JAX repository, run
python build.py
or
python3 build.py
to download and build JAX's XLA (jaxlib) dependency.
"""
def _parse_string_as_bool(s):
"""Parses a string as a boolean argument."""
lower = s.lower()
if lower == "true":
return True
elif lower == "false":
return False
else:
raise ValueError(f"Expected either 'true' or 'false'; got {s}")
def add_boolean_argument(parser, name, default=False, help_str=None):
    """Registers a `--<name>` / `--no<name>` pair of boolean flags on `parser`.

    `--<name>` alone stores True, `--<name>=true|false` stores the parsed
    value, and `--no<name>` stores False. Both options write to the same
    destination and are mutually exclusive.
    """
    exclusive = parser.add_mutually_exclusive_group()
    positive_options = dict(
        nargs="?",
        default=default,
        const=True,
        type=_parse_string_as_bool,
        help=help_str,
    )
    exclusive.add_argument(f"--{name}", **positive_options)
    exclusive.add_argument(f"--no{name}", dest=name, action="store_false")
def _get_editable_output_paths(output_path):
"""Returns the paths to the editable wheels."""
return (
os.path.join(output_path, "jaxlib"),
os.path.join(output_path, "jax_gpu_pjrt"),
os.path.join(output_path, "jax_gpu_plugin"),
)
def _gpu_wheel_command(command_base, target, output_path, git_hash, wheel_cpu,
                       args):
    """Returns the Bazel command line that builds one GPU plugin wheel target.

    Args:
      command_base: Leading part of the command (bazel binary, startup
        options, `run` and its options).
      target: Bazel label of the wheel-building tool to run.
      output_path: Directory passed to the tool as --output_path.
      git_hash: Value passed to the tool as --jaxlib_git_hash.
      wheel_cpu: Value passed to the tool as --cpu.
      args: Parsed command-line namespace from main().

    Raises:
      ValueError: if neither --enable_cuda nor --enable_rocm is set.
    """
    command = [
        *command_base,
        target,
        "--",
        f"--output_path={output_path}",
        f"--jaxlib_git_hash={git_hash}",
        f"--cpu={wheel_cpu}",
    ]
    # The backend is selected by --enable_cuda / --enable_rocm, not by the
    # value given to --build_gpu_kernel_plugin.
    if args.enable_cuda:
        command.append(f"--enable-cuda={args.enable_cuda}")
        command.append(f"--platform_version={args.gpu_plugin_cuda_version}")
    elif args.enable_rocm:
        command.append(f"--enable-rocm={args.enable_rocm}")
        command.append(f"--platform_version={args.gpu_plugin_rocm_version}")
    else:
        raise ValueError(
            "Unsupported GPU plugin backend. Choose either 'cuda' or 'rocm'.")
    if args.editable:
        command.append("--editable")
    return command


def main():
    """Configures and runs the Bazel build of jaxlib and/or the GPU plugins.

    Steps, in order:
      1. Parse and validate the command-line flags.
      2. Resolve the Bazel binary, the hermetic Python version and, when clang
         is used, the clang path and major version through `utils`.
      3. Print the selected configuration and write it with `write_bazelrc`.
      4. Depending on the flags: run a requirements-update target and return;
         return right after configuration; or run the requested wheel build
         target(s) and finally run `bazel shutdown`.

    Raises:
      SystemExit: through `parser.error` when both --enable_cuda and
        --enable_rocm are set, or when --target_cpu is not a known value.
      ValueError: if a GPU plugin wheel is requested while neither
        --enable_cuda nor --enable_rocm is set.
    """
    cwd = os.getcwd()
    parser = argparse.ArgumentParser(
        description="Builds jaxlib from source.", epilog=EPILOG)
    add_boolean_argument(
        parser,
        "verbose",
        default=False,
        help_str="Should we produce verbose debugging output?")
    parser.add_argument(
        "--bazel_path",
        help="Path to the Bazel binary to use. The default is to find bazel via "
        "the PATH; if none is found, downloads a fresh copy of bazel from "
        "GitHub.")
    parser.add_argument(
        "--python_bin_path",
        help="Path to Python binary whose version to match while building with "
        "hermetic python. The default is the Python interpreter used to run the "
        "build script. DEPRECATED: use --python_version instead.")
    parser.add_argument(
        "--target_cpu_features",
        choices=["release", "native", "default"],
        default="release",
        help="What CPU features should we target? 'release' enables CPU "
             "features that should be enabled for a release build, which on "
             "x86-64 architectures enables AVX. 'native' enables "
             "-march=native, which generates code targeted to use all "
             "features of the current machine. 'default' means don't opt-in "
             "to any architectural features and use whatever the C compiler "
             "generates by default.")
    add_boolean_argument(
        parser,
        "use_clang",
        # A real bool, like every other boolean flag; argparse used to convert
        # the former string default "true" to the same value.
        default=True,
        help_str=(
            "DEPRECATED: This flag is redundant because clang is "
            "always used as default compiler."
        ),
    )
    parser.add_argument(
        "--clang_path",
        help=(
            "Path to clang binary to use. The default is "
            "to find clang via the PATH."
        ),
    )
    add_boolean_argument(
        parser,
        "enable_mkl_dnn",
        default=True,
        help_str="Should we build with MKL-DNN enabled?",
    )
    add_boolean_argument(
        parser,
        "enable_cuda",
        help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN."
    )
    add_boolean_argument(
        parser,
        "use_cuda_nvcc",
        default=True,
        help_str=(
            "Should we build CUDA code using NVCC compiler driver? The default value "
            "is true. If --nouse_cuda_nvcc flag is used then CUDA code is built "
            "by clang compiler."
        ),
    )
    add_boolean_argument(
        parser,
        "build_gpu_plugin",
        default=False,
        help_str=(
            "Are we building the gpu plugin in addition to jaxlib? The GPU "
            "plugin is still experimental and is not ready for use yet."
        ),
    )
    parser.add_argument(
        "--build_gpu_kernel_plugin",
        choices=["cuda", "rocm"],
        default="",
        help=(
            "Specify 'cuda' or 'rocm' to build the respective kernel plugin."
            " When this flag is set, jaxlib will not be built."
        ),
    )
    add_boolean_argument(
        parser,
        "build_gpu_pjrt_plugin",
        default=False,
        help_str=(
            "Are we building the cuda/rocm pjrt plugin? jaxlib will not be built "
            "when this flag is True."
        ),
    )
    parser.add_argument(
        "--gpu_plugin_cuda_version",
        choices=["12"],
        default="12",
        help="Which CUDA major version the gpu plugin is for.")
    parser.add_argument(
        "--gpu_plugin_rocm_version",
        choices=["60"],
        default="60",
        help="Which ROCM major version the gpu plugin is for.")
    add_boolean_argument(
        parser,
        "enable_rocm",
        help_str="Should we build with ROCm enabled?")
    add_boolean_argument(
        parser,
        "enable_nccl",
        default=True,
        help_str="Should we build with NCCL enabled? Has no effect for non-CUDA "
                 "builds.")
    add_boolean_argument(
        parser,
        "remote_build",
        default=False,
        help_str="Should we build with RBE (Remote Build Environment)?")
    parser.add_argument(
        "--cuda_version",
        default=None,
        help="CUDA toolkit version, e.g., 12.3.2")
    parser.add_argument(
        "--cudnn_version",
        default=None,
        help="CUDNN version, e.g., 8.9.7.29")
    # Caution: if changing the default list of CUDA capabilities, you should also
    # update the list in .bazelrc, which is used for wheel builds.
    parser.add_argument(
        "--cuda_compute_capabilities",
        default=None,
        help="A comma-separated list of CUDA compute capabilities to support.")
    parser.add_argument(
        "--rocm_amdgpu_targets",
        default="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100",
        help="A comma-separated list of ROCm amdgpu targets to support.")
    parser.add_argument(
        "--rocm_path",
        default=None,
        help="Path to the ROCm toolkit.")
    parser.add_argument(
        "--bazel_startup_options",
        action="append", default=[],
        help="Additional startup options to pass to bazel.")
    parser.add_argument(
        "--bazel_options",
        action="append", default=[],
        help="Additional options to pass to the main Bazel command to be "
             "executed, e.g. `run`.")
    parser.add_argument(
        "--output_path",
        default=os.path.join(cwd, "dist"),
        help="Directory to which the jaxlib wheel should be written")
    parser.add_argument(
        "--target_cpu",
        default=None,
        help="CPU platform to target. Default is the same as the host machine. "
             "Currently supported values are 'darwin_arm64' and 'darwin_x86_64'.")
    parser.add_argument(
        "--editable",
        action="store_true",
        help="Create an 'editable' jaxlib build instead of a wheel.")
    parser.add_argument(
        "--python_version",
        default=None,
        help="hermetic python version, e.g., 3.10")
    add_boolean_argument(
        parser,
        "configure_only",
        default=False,
        help_str="If true, writes a .bazelrc file but does not build jaxlib.")
    add_boolean_argument(
        parser,
        "requirements_update",
        default=False,
        help_str="If true, writes a .bazelrc and updates requirements_lock.txt "
                 "for a corresponding version of Python but does not build "
                 "jaxlib.")
    add_boolean_argument(
        parser,
        "requirements_nightly_update",
        default=False,
        help_str="Same as --requirements_update, but will consider dev, nightly "
                 "and pre-release versions of packages.")

    args = parser.parse_args()

    logging.basicConfig()
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    if args.enable_cuda and args.enable_rocm:
        parser.error("--enable_cuda and --enable_rocm cannot be enabled at the same time.")

    # Map the Bazel --cpu override to the CPU name used for the wheel. This is
    # validated here, before any side effects, so an unknown value produces a
    # usage error instead of a KeyError traceback.
    host_cpu = platform.machine()
    wheel_cpus = {
        "darwin_arm64": "arm64",
        "darwin_x86_64": "x86_64",
        "ppc": "ppc64le",
        "aarch64": "aarch64",
    }
    # TODO(phawkins): support other bazel cpu overrides.
    if args.target_cpu is None:
        wheel_cpu = host_cpu
    elif args.target_cpu in wheel_cpus:
        wheel_cpu = wheel_cpus[args.target_cpu]
    else:
        parser.error(
            f"Unsupported --target_cpu '{args.target_cpu}'; expected one of: "
            f"{', '.join(sorted(wheel_cpus))}.")

    print(BANNER)

    # Resolve the output directory before changing directory so that a relative
    # --output_path stays relative to where the script was invoked.
    output_path = os.path.abspath(args.output_path)
    # `parser.prog` is the fallback; the parsed namespace has no `prog`.
    os.chdir(os.path.dirname(__file__ or parser.prog) or '.')

    # Find a working Bazel.
    bazel_path, bazel_version = utils.get_bazel_path(args.bazel_path)
    print(f"Bazel binary path: {bazel_path}")
    print(f"Bazel version: {bazel_version}")

    if args.python_version:
        python_version = args.python_version
    else:
        python_bin_path = utils.get_python_bin_path(args.python_bin_path)
        print(f"Python binary path: {python_bin_path}")
        python_version = utils.get_python_version(python_bin_path)
        print("Python version: {}".format(".".join(map(str, python_version))))
        utils.check_python_version(python_version)
        python_version = ".".join(map(str, python_version))

    print("Use clang: {}".format("yes" if args.use_clang else "no"))
    clang_path = args.clang_path
    clang_major_version = None
    if args.use_clang:
        if not clang_path:
            clang_path = utils.get_clang_path_or_exit()
        print(f"clang path: {clang_path}")
        clang_major_version = utils.get_clang_major_version(clang_path)

    print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no"))
    print(f"Target CPU: {wheel_cpu}")
    print(f"Target CPU features: {args.target_cpu_features}")

    rocm_toolkit_path = args.rocm_path
    print("CUDA enabled: {}".format("yes" if args.enable_cuda else "no"))
    if args.enable_cuda:
        if args.cuda_compute_capabilities is not None:
            print(f"CUDA compute capabilities: {args.cuda_compute_capabilities}")
        if args.cuda_version:
            print(f"CUDA version: {args.cuda_version}")
        if args.cudnn_version:
            print(f"CUDNN version: {args.cudnn_version}")
        print("NCCL enabled: {}".format("yes" if args.enable_nccl else "no"))

    print("ROCm enabled: {}".format("yes" if args.enable_rocm else "no"))
    if args.enable_rocm:
        if rocm_toolkit_path:
            print(f"ROCm toolkit path: {rocm_toolkit_path}")
        print(f"ROCm amdgpu targets: {args.rocm_amdgpu_targets}")

    write_bazelrc(
        remote_build=args.remote_build,
        cuda_version=args.cuda_version,
        cudnn_version=args.cudnn_version,
        rocm_toolkit_path=rocm_toolkit_path,
        cpu=args.target_cpu,
        cuda_compute_capabilities=args.cuda_compute_capabilities,
        rocm_amdgpu_targets=args.rocm_amdgpu_targets,
        target_cpu_features=args.target_cpu_features,
        wheel_cpu=wheel_cpu,
        enable_mkl_dnn=args.enable_mkl_dnn,
        use_clang=args.use_clang,
        clang_path=clang_path,
        clang_major_version=clang_major_version,
        python_version=python_version,
        enable_cuda=args.enable_cuda,
        enable_nccl=args.enable_nccl,
        enable_rocm=args.enable_rocm,
        use_cuda_nvcc=args.use_cuda_nvcc,
    )

    if args.requirements_update or args.requirements_nightly_update:
        if args.requirements_update:
            task = "//build:requirements.update"
        else:  # args.requirements_nightly_update
            task = "//build:requirements_nightly.update"
        update_command = ([bazel_path] + args.bazel_startup_options +
                          ["run", "--verbose_failures=true", task, *args.bazel_options])
        print(" ".join(update_command))
        utils.shell(update_command)
        return

    if args.configure_only:
        return

    print("\nBuilding XLA and installing it in the jaxlib source tree...")

    command_base = (
        bazel_path,
        *args.bazel_startup_options,
        "run",
        "--verbose_failures=true",
        *args.bazel_options,
    )
    # Every path below runs at least one wheel build, so look the hash up once.
    git_hash = utils.get_githash()

    if args.build_gpu_plugin and args.editable:
        output_path_jaxlib, output_path_jax_pjrt, output_path_jax_kernel = (
            _get_editable_output_paths(output_path)
        )
    else:
        output_path_jaxlib = output_path
        output_path_jax_pjrt = output_path
        output_path_jax_kernel = output_path

    # jaxlib itself is skipped when only a kernel or PJRT plugin was requested.
    if args.build_gpu_kernel_plugin == "" and not args.build_gpu_pjrt_plugin:
        build_cpu_wheel_command = [
            *command_base,
            "//jaxlib/tools:build_wheel",
            "--",
            f"--output_path={output_path_jaxlib}",
            f"--jaxlib_git_hash={git_hash}",
            f"--cpu={wheel_cpu}",
        ]
        if args.build_gpu_plugin:
            build_cpu_wheel_command.append("--skip_gpu_kernels")
        if args.editable:
            build_cpu_wheel_command.append("--editable")
        print(" ".join(build_cpu_wheel_command))
        utils.shell(build_cpu_wheel_command)

    if args.build_gpu_plugin or args.build_gpu_kernel_plugin in ("cuda", "rocm"):
        build_gpu_kernels_command = _gpu_wheel_command(
            command_base,
            "//jaxlib/tools:build_gpu_kernels_wheel",
            output_path_jax_kernel,
            git_hash,
            wheel_cpu,
            args,
        )
        print(" ".join(build_gpu_kernels_command))
        utils.shell(build_gpu_kernels_command)

    if args.build_gpu_plugin or args.build_gpu_pjrt_plugin:
        build_pjrt_plugin_command = _gpu_wheel_command(
            command_base,
            "//jaxlib/tools:build_gpu_plugin_wheel",
            output_path_jax_pjrt,
            git_hash,
            wheel_cpu,
            args,
        )
        print(" ".join(build_pjrt_plugin_command))
        utils.shell(build_pjrt_plugin_command)

    utils.shell([bazel_path] + args.bazel_startup_options + ["shutdown"])
# Run the build only when this file is executed as a script.
if __name__ == "__main__":
    main()