diff --git a/.bazelrc b/.bazelrc index 5e386c8b1..503738084 100644 --- a/.bazelrc +++ b/.bazelrc @@ -124,9 +124,10 @@ build:cuda --@local_config_cuda//:enable_cuda # Default hermetic CUDA and CUDNN versions. build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" +build:cuda --@local_config_cuda//cuda:include_cuda_libs=true -# This flag is needed to include CUDA libraries for bazel tests. -test:cuda --@local_config_cuda//cuda:include_cuda_libs=true +# This config is used for building targets with CUDA libraries from stubs. +build:cuda_libraries_from_stubs --@local_config_cuda//cuda:include_cuda_libs=false # Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries, # ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to diff --git a/build/build.py b/build/build.py index d2f68f80e..c2933416d 100755 --- a/build/build.py +++ b/build/build.py @@ -532,6 +532,7 @@ async def main(): if "cuda" in args.wheels: wheel_build_command_base.append("--config=cuda") + wheel_build_command_base.append("--config=cuda_libraries_from_stubs") if args.use_clang: wheel_build_command_base.append( f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\""