mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add --use_clang
and --clang_path
options to build.py
PiperOrigin-RevId: 603837975
This commit is contained in:
parent
a7a6b40b55
commit
6928465b87
7
.bazelrc
7
.bazelrc
@ -72,6 +72,13 @@ build:cuda --@xla//xla/python:enable_gpu=true
|
||||
build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
|
||||
build:cuda --define=xla_python_enable_gpu=true
|
||||
|
||||
# Build with nvcc for CUDA and clang for host
|
||||
build:nvcc_clang --config=cuda
|
||||
# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang
|
||||
build:nvcc_clang --action_env=TF_CUDA_CLANG="1"
|
||||
build:nvcc_clang --action_env=TF_NVCC_CLANG="1"
|
||||
build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc
|
||||
|
||||
# Later Bazel flag values override earlier values.
|
||||
# TODO(jieying): remove enable_gpu and xla_python_enable_gpu from build:cuda
|
||||
# after the pluin is released.
|
||||
|
@ -20,7 +20,9 @@
|
||||
import argparse
|
||||
import collections
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
@ -29,7 +31,6 @@ import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
import urllib.request
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -228,13 +229,41 @@ def get_bazel_version(bazel_path):
|
||||
return tuple(int(x) for x in match.group(1).split("."))
|
||||
|
||||
|
||||
def get_clang_path_or_exit():
|
||||
which_clang_output = shutil.which("clang")
|
||||
if which_clang_output:
|
||||
# If we've found a clang on the path, need to get the fully resolved path
|
||||
# to ensure that system headers are found.
|
||||
return str(pathlib.Path(which_clang_output).resolve())
|
||||
else:
|
||||
print(
|
||||
"--use_clang set, but --clang_path is unset and clang cannot be found"
|
||||
" on the PATH. Please pass --clang_path directly."
|
||||
)
|
||||
sys.exit(-1)
|
||||
|
||||
def get_clang_major_version(clang_path):
|
||||
clang_version_proc = subprocess.run(
|
||||
[clang_path, "-E", "-P", "-"],
|
||||
input="__clang_major__",
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
major_version = int(clang_version_proc.stdout)
|
||||
|
||||
return major_version
|
||||
|
||||
|
||||
|
||||
def write_bazelrc(*, python_bin_path, remote_build,
|
||||
cuda_toolkit_path, cudnn_install_path,
|
||||
cuda_version, cudnn_version, rocm_toolkit_path,
|
||||
cpu, cuda_compute_capabilities,
|
||||
rocm_amdgpu_targets, bazel_options, target_cpu_features,
|
||||
wheel_cpu, enable_mkl_dnn, enable_cuda, enable_nccl,
|
||||
enable_rocm, build_gpu_plugin):
|
||||
wheel_cpu, enable_mkl_dnn, use_clang, clang_path,
|
||||
clang_major_version, enable_cuda, enable_nccl, enable_rocm,
|
||||
build_gpu_plugin):
|
||||
tf_cuda_paths = []
|
||||
|
||||
with open("../.jax_configure.bazelrc", "w") as f:
|
||||
@ -246,6 +275,16 @@ def write_bazelrc(*, python_bin_path, remote_build,
|
||||
build --python_path="{python_bin_path}"
|
||||
""").format(python_bin_path=python_bin_path))
|
||||
|
||||
if use_clang:
|
||||
f.write(f'build --action_env CLANG_COMPILER_PATH="{clang_path}"\n')
|
||||
f.write(f'build --repo_env CC="{clang_path}"\n')
|
||||
f.write(f'build --repo_env BAZEL_COMPILER="{clang_path}"\n')
|
||||
bazel_options.append("--copt=-Wno-error=unused-command-line-argument\n")
|
||||
if clang_major_version in (16, 17):
|
||||
# Necessary due to XLA's old version of upb. See:
|
||||
# https://github.com/openxla/xla/blob/c4277a076e249f5b97c8e45c8cb9d1f554089d76/.bazelrc#L505
|
||||
bazel_options.append("--copt=-Wno-gnu-offsetof-extensions\n")
|
||||
|
||||
if cuda_toolkit_path:
|
||||
tf_cuda_paths.append(cuda_toolkit_path)
|
||||
f.write("build --action_env CUDA_TOOLKIT_PATH=\"{cuda_toolkit_path}\"\n"
|
||||
@ -295,6 +334,9 @@ def write_bazelrc(*, python_bin_path, remote_build,
|
||||
f.write("build --config=cuda\n")
|
||||
if not enable_nccl:
|
||||
f.write("build --config=nonccl\n")
|
||||
if use_clang:
|
||||
f.write("build --config=nvcc_clang\n")
|
||||
f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n")
|
||||
if enable_rocm:
|
||||
f.write("build --config=rocm\n")
|
||||
if not enable_nccl:
|
||||
@ -302,6 +344,7 @@ def write_bazelrc(*, python_bin_path, remote_build,
|
||||
if build_gpu_plugin:
|
||||
f.write("build --config=cuda_plugin\n")
|
||||
|
||||
|
||||
BANNER = r"""
|
||||
_ _ __ __
|
||||
| | / \ \ \/ /
|
||||
@ -374,11 +417,28 @@ def main():
|
||||
"features of the current machine. 'default' means don't opt-in "
|
||||
"to any architectural features and use whatever the C compiler "
|
||||
"generates by default.")
|
||||
add_boolean_argument(
|
||||
parser,
|
||||
"use_clang",
|
||||
help_str=(
|
||||
"Should we build using clang as the host compiler? Requires "
|
||||
"clang to be findable via the PATH, or a path to be given via "
|
||||
"--clang_path."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clang_path",
|
||||
help=(
|
||||
"Path to clang binary to use if --use_clang is set. The default is "
|
||||
"to find clang via the PATH."
|
||||
),
|
||||
)
|
||||
add_boolean_argument(
|
||||
parser,
|
||||
"enable_mkl_dnn",
|
||||
default=True,
|
||||
help_str="Should we build with MKL-DNN enabled?")
|
||||
help_str="Should we build with MKL-DNN enabled?",
|
||||
)
|
||||
add_boolean_argument(
|
||||
parser,
|
||||
"enable_cuda",
|
||||
@ -534,6 +594,15 @@ def main():
|
||||
check_package_is_installed(python_bin_path, python_version, "build")
|
||||
check_package_is_installed(python_bin_path, python_version, "setuptools")
|
||||
|
||||
print("Use clang: {}".format("yes" if args.use_clang else "no"))
|
||||
clang_path = args.clang_path
|
||||
clang_major_version = None
|
||||
if args.use_clang:
|
||||
if not clang_path:
|
||||
clang_path = get_clang_path_or_exit()
|
||||
print(f"clang path: {clang_path}")
|
||||
clang_major_version = get_clang_major_version(clang_path)
|
||||
|
||||
print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no"))
|
||||
print(f"Target CPU: {wheel_cpu}")
|
||||
print(f"Target CPU features: {args.target_cpu_features}")
|
||||
@ -576,6 +645,9 @@ def main():
|
||||
target_cpu_features=args.target_cpu_features,
|
||||
wheel_cpu=wheel_cpu,
|
||||
enable_mkl_dnn=args.enable_mkl_dnn,
|
||||
use_clang=args.use_clang,
|
||||
clang_path=clang_path,
|
||||
clang_major_version=clang_major_version,
|
||||
enable_cuda=args.enable_cuda,
|
||||
enable_nccl=args.enable_nccl,
|
||||
enable_rocm=args.enable_rocm,
|
||||
|
Loading…
x
Reference in New Issue
Block a user