Build jaxlib without PJRT GPU deps when plugin will be built.

PiperOrigin-RevId: 573844805
This commit is contained in:
Jieying Luo 2023-10-16 09:58:30 -07:00 committed by jax authors
parent 110b8d7484
commit 0290150c4c
2 changed files with 10 additions and 1 deletions

View File

@ -72,6 +72,12 @@ build:cuda --@xla//xla/python:enable_gpu=true
build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
build:cuda --define=xla_python_enable_gpu=true
# Later Bazel flag values override earlier values.
# TODO(jieying): remove enable_gpu and xla_python_enable_gpu from build:cuda
# after the pluin is released.
build:cuda_plugin --@xla//xla/python:enable_gpu=false
build:cuda_plugin --define=xla_python_enable_gpu=false
# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
# point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA

View File

@ -225,7 +225,7 @@ def write_bazelrc(*, python_bin_path, remote_build,
cpu, cuda_compute_capabilities,
rocm_amdgpu_targets, bazel_options, target_cpu_features,
wheel_cpu, enable_mkl_dnn, enable_cuda, enable_nccl,
enable_tpu, enable_rocm):
enable_tpu, enable_rocm, build_gpu_plugin):
tf_cuda_paths = []
with open("../.jax_configure.bazelrc", "w") as f:
@ -292,6 +292,8 @@ def write_bazelrc(*, python_bin_path, remote_build,
f.write("build --config=rocm\n")
if not enable_nccl:
f.write("build --config=nonccl\n")
if build_gpu_plugin:
f.write("build --config=cuda_plugin\n")
BANNER = r"""
_ _ __ __
@ -559,6 +561,7 @@ def main():
enable_nccl=args.enable_nccl,
enable_tpu=args.enable_tpu,
enable_rocm=args.enable_rocm,
build_gpu_plugin=args.build_gpu_plugin,
)
if args.configure_only: