Use common set of build options when building jaxlib+plugin artifacts together

This commit modifies the behavior of the build CLI when building jaxlib and GPU plugin artifacts together (for instance, `python build --wheels=jaxlib,jax-cuda-plugin`).

Previously, CUDA/ROCm build options were passed only when building the CUDA/ROCm artifacts. However, this made inefficient use of the build cache: Bazel appeared to rebuild targets that had already been built in the previous run. This seems to be because the GPU plugin artifacts were built with a different set of build options than `jaxlib`, which causes Bazel to invalidate or ignore certain cache hits. Therefore, this commit keeps the build options identical when `jaxlib` and the GPU artifacts are built together, so that the build cache is used more effectively.

For example, if `python build --wheels=jaxlib,jax-cuda-plugin` is run, the following build options apply to both the `jaxlib` and `jax-cuda-plugin` builds:
```
 /usr/local/bin/bazel run --repo_env=HERMETIC_PYTHON_VERSION=3.10 \
--verbose_failures=true --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--repo_env=CC="/usr/lib/llvm-16/bin/clang" \
--repo_env=BAZEL_COMPILER="/usr/lib/llvm-16/bin/clang" \
--config=clang --config=mkl_open_source_only --config=avx_posix \
--config=cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--config=build_cuda_with_nvcc
```
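
The mechanism is visible in the build CLI diff below: the shared option list is assembled once, and each wheel build starts from a deep copy of it. Here is a minimal, self-contained sketch of that pattern; the option values and variable handling are illustrative stand-ins, not the CLI's full logic:

```python
import copy

# Sketch of the pattern in the build CLI diff below; the real script drives
# this from argparse args and a command-builder object, and appends many more
# options (compiler, MKL-DNN, CPU features, hermetic CUDA/cuDNN versions).
bazel_command_base = ["bazel", "run", "--repo_env=HERMETIC_PYTHON_VERSION=3.10"]
wheels = "jaxlib,jax-cuda-plugin"

# The shared options are now computed once, before the per-wheel loop, and
# CUDA/ROCm options are included whenever any requested wheel needs them.
wheel_build_command_base = copy.deepcopy(bazel_command_base)
wheel_build_command_base += ["--config=clang", "--config=mkl_open_source_only"]
if "cuda" in wheels:
    wheel_build_command_base += ["--config=cuda", "--config=build_cuda_with_nvcc"]

# Every wheel build then starts from an identical copy, so Bazel sees one
# consistent configuration across jaxlib and the plugin wheel.
for wheel in wheels.split(","):
    wheel_build_command = copy.deepcopy(wheel_build_command_base)
    print(wheel, wheel_build_command)
```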

Note: this commit shouldn't affect the contents of the wheels themselves. It is only meant to speed up building `jaxlib`+plugin artifacts together.

This commit also removes the code in `build_wheel.py` that was used to build the now-deprecated monolithic `jaxlib`.
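
As part of that removal, the `--skip_gpu_kernels` flag goes away and `prepare_wheel` takes only `cpu`. A minimal sketch of the new shape (the function body here is a stand-in; the real one assembles the source tree by copying runfiles, as shown in the `build_wheel.py` diff below):

```python
import pathlib

def prepare_wheel(sources_path: pathlib.Path, *, cpu: str) -> None:
    """Stand-in with the new signature; the real function copies runfiles
    into the wheel source tree, with no CUDA/ROCm kernel handling."""
    print(f"assembling jaxlib wheel sources for {cpu} in {sources_path}")

# Callers now pass only the cpu; there is no skip_gpu_kernels argument.
prepare_wheel(pathlib.Path("/tmp/jaxlib_sources"), cpu="x86_64")
```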

PiperOrigin-RevId: 708035062
Author: Nitin Srinivasan, 2024-12-19 14:28:46 -08:00 (committed by jax authors)
parent 16712b5116
commit 6b096b0cb0
3 changed files with 156 additions and 192 deletions


@@ -423,6 +423,8 @@ async def main():
else:
sys.exit(0)
wheel_build_command_base = copy.deepcopy(bazel_command_base)
wheel_cpus = {
"darwin_arm64": "arm64",
"darwin_x86_64": "x86_64",
@@ -435,18 +437,152 @@ async def main():
if args.local_xla_path:
logging.debug("Local XLA path: %s", args.local_xla_path)
bazel_command_base.append(f"--override_repository=xla=\"{args.local_xla_path}\"")
wheel_build_command_base.append(f"--override_repository=xla=\"{args.local_xla_path}\"")
if args.target_cpu:
logging.debug("Target CPU: %s", args.target_cpu)
bazel_command_base.append(f"--cpu={args.target_cpu}")
wheel_build_command_base.append(f"--cpu={args.target_cpu}")
if args.disable_nccl:
logging.debug("Disabling NCCL")
bazel_command_base.append("--config=nonccl")
wheel_build_command_base.append("--config=nonccl")
git_hash = utils.get_githash()
clang_path = ""
if args.use_clang:
clang_path = args.clang_path or utils.get_clang_path_or_exit()
clang_major_version = utils.get_clang_major_version(clang_path)
logging.debug(
"Using Clang as the compiler, clang path: %s, clang version: %s",
clang_path,
clang_major_version,
)
# Use double quotes around clang path to avoid path issues on Windows.
wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"")
wheel_build_command_base.append(f"--repo_env=CC=\"{clang_path}\"")
wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"")
if clang_major_version >= 16:
# Enable clang settings that are needed for the build to work with newer
# versions of Clang.
wheel_build_command_base.append("--config=clang")
else:
gcc_path = args.gcc_path or utils.get_gcc_path_or_exit()
logging.debug(
"Using GCC as the compiler, gcc path: %s",
gcc_path,
)
wheel_build_command_base.append(f"--repo_env=CC=\"{gcc_path}\"")
wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{gcc_path}\"")
if not args.disable_mkl_dnn:
logging.debug("Enabling MKL DNN")
if target_cpu == "aarch64":
wheel_build_command_base.append("--config=mkl_aarch64_threadpool")
else:
wheel_build_command_base.append("--config=mkl_open_source_only")
if args.target_cpu_features == "release":
if arch in ["x86_64", "AMD64"]:
logging.debug(
"Using release cpu features: --config=avx_%s",
"windows" if os_name == "windows" else "posix",
)
wheel_build_command_base.append(
"--config=avx_windows"
if os_name == "windows"
else "--config=avx_posix"
)
elif args.target_cpu_features == "native":
if os_name == "windows":
logger.warning(
"--target_cpu_features=native is not supported on Windows;"
" ignoring."
)
else:
logging.debug("Using native cpu features: --config=native_arch_posix")
wheel_build_command_base.append("--config=native_arch_posix")
else:
logging.debug("Using default cpu features")
if "cuda" in args.wheels and "rocm" in args.wheels:
logging.error("CUDA and ROCm cannot be enabled at the same time.")
sys.exit(1)
if "cuda" in args.wheels:
wheel_build_command_base.append("--config=cuda")
if args.use_clang:
wheel_build_command_base.append(
f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\""
)
if args.build_cuda_with_clang:
logging.debug("Building CUDA with Clang")
wheel_build_command_base.append("--config=build_cuda_with_clang")
else:
logging.debug("Building CUDA with NVCC")
wheel_build_command_base.append("--config=build_cuda_with_nvcc")
else:
logging.debug("Building CUDA with NVCC")
wheel_build_command_base.append("--config=build_cuda_with_nvcc")
if args.cuda_version:
logging.debug("Hermetic CUDA version: %s", args.cuda_version)
wheel_build_command_base.append(
f"--repo_env=HERMETIC_CUDA_VERSION={args.cuda_version}"
)
if args.cudnn_version:
logging.debug("Hermetic cuDNN version: %s", args.cudnn_version)
wheel_build_command_base.append(
f"--repo_env=HERMETIC_CUDNN_VERSION={args.cudnn_version}"
)
if args.cuda_compute_capabilities:
logging.debug(
"Hermetic CUDA compute capabilities: %s",
args.cuda_compute_capabilities,
)
wheel_build_command_base.append(
f"--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES={args.cuda_compute_capabilities}"
)
if "rocm" in args.wheels:
wheel_build_command_base.append("--config=rocm_base")
if args.use_clang:
wheel_build_command_base.append("--config=rocm")
wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"")
if args.rocm_path:
logging.debug("ROCm tookit path: %s", args.rocm_path)
wheel_build_command_base.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"")
if args.rocm_amdgpu_targets:
logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets)
wheel_build_command_base.append(
f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}"
)
# Append additional build options at the end to override any options set in
# .bazelrc or above.
if args.bazel_options:
logging.debug(
"Additional Bazel build options: %s", args.bazel_options
)
for option in args.bazel_options:
wheel_build_command_base.append(option)
with open(".jax_configure.bazelrc", "w") as f:
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list())
if not jax_configure_options:
logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.")
sys.exit(1)
f.write(jax_configure_options)
logging.info("Bazel options written to .jax_configure.bazelrc")
if args.configure_only:
logging.info("--configure_only is set so not running any Bazel commands.")
else:
output_path = args.output_path
logger.debug("Artifacts output directory: %s", output_path)
# Wheel build command execution
for wheel in args.wheels.split(","):
# Allow CUDA/ROCm wheels without the "jax-" prefix.
@@ -461,7 +597,7 @@ async def main():
)
sys.exit(1)
wheel_build_command = copy.deepcopy(bazel_command_base)
wheel_build_command = copy.deepcopy(wheel_build_command_base)
print("\n")
logger.info(
"Building %s for %s %s...",
@@ -470,142 +606,12 @@ async def main():
arch,
)
clang_path = ""
if args.use_clang:
clang_path = args.clang_path or utils.get_clang_path_or_exit()
clang_major_version = utils.get_clang_major_version(clang_path)
logging.debug(
"Using Clang as the compiler, clang path: %s, clang version: %s",
clang_path,
clang_major_version,
)
# Use double quotes around clang path to avoid path issues on Windows.
wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"")
wheel_build_command.append(f"--repo_env=CC=\"{clang_path}\"")
wheel_build_command.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"")
if clang_major_version >= 16:
# Enable clang settings that are needed for the build to work with newer
# versions of Clang.
wheel_build_command.append("--config=clang")
else:
gcc_path = args.gcc_path or utils.get_gcc_path_or_exit()
logging.debug(
"Using GCC as the compiler, gcc path: %s",
gcc_path,
)
wheel_build_command.append(f"--repo_env=CC=\"{gcc_path}\"")
wheel_build_command.append(f"--repo_env=BAZEL_COMPILER=\"{gcc_path}\"")
if not args.disable_mkl_dnn:
logging.debug("Enabling MKL DNN")
if target_cpu == "aarch64":
wheel_build_command.append("--config=mkl_aarch64_threadpool")
else:
wheel_build_command.append("--config=mkl_open_source_only")
if args.target_cpu_features == "release":
if arch in ["x86_64", "AMD64"]:
logging.debug(
"Using release cpu features: --config=avx_%s",
"windows" if os_name == "windows" else "posix",
)
wheel_build_command.append(
"--config=avx_windows"
if os_name == "windows"
else "--config=avx_posix"
)
elif args.target_cpu_features == "native":
if os_name == "windows":
logger.warning(
"--target_cpu_features=native is not supported on Windows;"
" ignoring."
)
else:
logging.debug("Using native cpu features: --config=native_arch_posix")
wheel_build_command.append("--config=native_arch_posix")
else:
logging.debug("Using default cpu features")
if "cuda" in wheel:
wheel_build_command.append("--config=cuda")
if args.use_clang:
wheel_build_command.append(
f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\""
)
if args.build_cuda_with_clang:
logging.debug("Building CUDA with Clang")
wheel_build_command.append("--config=build_cuda_with_clang")
else:
logging.debug("Building CUDA with NVCC")
wheel_build_command.append("--config=build_cuda_with_nvcc")
else:
logging.debug("Building CUDA with NVCC")
wheel_build_command.append("--config=build_cuda_with_nvcc")
if args.cuda_version:
logging.debug("Hermetic CUDA version: %s", args.cuda_version)
wheel_build_command.append(
f"--repo_env=HERMETIC_CUDA_VERSION={args.cuda_version}"
)
if args.cudnn_version:
logging.debug("Hermetic cuDNN version: %s", args.cudnn_version)
wheel_build_command.append(
f"--repo_env=HERMETIC_CUDNN_VERSION={args.cudnn_version}"
)
if args.cuda_compute_capabilities:
logging.debug(
"Hermetic CUDA compute capabilities: %s",
args.cuda_compute_capabilities,
)
wheel_build_command.append(
f"--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES={args.cuda_compute_capabilities}"
)
if "rocm" in wheel:
wheel_build_command.append("--config=rocm_base")
if args.use_clang:
wheel_build_command.append("--config=rocm")
wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"")
if args.rocm_path:
logging.debug("ROCm tookit path: %s", args.rocm_path)
wheel_build_command.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"")
if args.rocm_amdgpu_targets:
logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets)
wheel_build_command.append(
f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}"
)
# Append additional build options at the end to override any options set in
# .bazelrc or above.
if args.bazel_options:
logging.debug(
"Additional Bazel build options: %s", args.bazel_options
)
for option in args.bazel_options:
wheel_build_command.append(option)
with open(".jax_configure.bazelrc", "w") as f:
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command.get_command_as_list())
if not jax_configure_options:
logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.")
sys.exit(1)
f.write(jax_configure_options)
logging.info("Bazel options written to .jax_configure.bazelrc")
if args.configure_only:
logging.info("--configure_only is set so not running any Bazel commands.")
else:
# Append the build target to the Bazel command.
build_target = WHEEL_BUILD_TARGET_DICT[wheel]
wheel_build_command.append(build_target)
wheel_build_command.append("--")
output_path = args.output_path
logger.debug("Artifacts output directory: %s", output_path)
if args.editable:
logger.info("Building an editable build")
output_path = os.path.join(output_path, wheel)


@@ -39,11 +39,6 @@ py_binary(
"@xla//xla/python:xla_extension",
] + if_windows([
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
]) + if_cuda([
"//jaxlib/cuda:cuda_gpu_support",
"@local_config_cuda//cuda:cuda-nvvm",
]) + if_rocm([
"//jaxlib/rocm:rocm_gpu_support",
]),
deps = [
"//jax/tools:build_utils",


@@ -56,13 +56,6 @@ parser.add_argument(
action="store_true",
help="Create an 'editable' jaxlib build instead of a wheel.",
)
parser.add_argument(
"--skip_gpu_kernels",
# args.skip_gpu_kernels is True when
# --skip_gpu_kernels is in the command
action="store_true",
help="Whether to skip gpu kernels in jaxlib.",
)
args = parser.parse_args()
r = runfiles.Create()
@@ -169,7 +162,7 @@ plat_name={tag}
)
def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
def prepare_wheel(sources_path: pathlib.Path, *, cpu):
"""Assembles a source tree for the wheel in `sources_path`."""
copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r)
@@ -220,35 +213,6 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
],
)
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}") and not skip_gpu_kernels:
copy_runfiles(
dst_dir=jaxlib_dir / "cuda",
src_files=[
f"__main__/jaxlib/cuda/_solver.{pyext}",
f"__main__/jaxlib/cuda/_blas.{pyext}",
f"__main__/jaxlib/cuda/_linalg.{pyext}",
f"__main__/jaxlib/cuda/_prng.{pyext}",
f"__main__/jaxlib/cuda/_rnn.{pyext}",
f"__main__/jaxlib/cuda/_sparse.{pyext}",
f"__main__/jaxlib/cuda/_triton.{pyext}",
f"__main__/jaxlib/cuda/_hybrid.{pyext}",
f"__main__/jaxlib/cuda/_versions.{pyext}",
],
)
if exists(f"__main__/jaxlib/rocm/_solver.{pyext}") and not skip_gpu_kernels:
copy_runfiles(
dst_dir=jaxlib_dir / "rocm",
src_files=[
f"__main__/jaxlib/rocm/_solver.{pyext}",
f"__main__/jaxlib/rocm/_blas.{pyext}",
f"__main__/jaxlib/rocm/_linalg.{pyext}",
f"__main__/jaxlib/rocm/_prng.{pyext}",
f"__main__/jaxlib/rocm/_sparse.{pyext}",
f"__main__/jaxlib/rocm/_triton.{pyext}",
f"__main__/jaxlib/rocm/_hybrid.{pyext}",
],
)
mosaic_python_dir = jaxlib_dir / "mosaic" / "python"
copy_runfiles(
dst_dir=mosaic_python_dir,
@@ -406,7 +370,6 @@ try:
prepare_wheel(
pathlib.Path(sources_path),
cpu=args.cpu,
skip_gpu_kernels=args.skip_gpu_kernels,
)
package_name = "jaxlib"
if args.editable: