diff --git a/build/build.py b/build/build.py index 4ebd0678b..0b640dd12 100755 --- a/build/build.py +++ b/build/build.py @@ -423,6 +423,8 @@ async def main(): else: sys.exit(0) + wheel_build_command_base = copy.deepcopy(bazel_command_base) + wheel_cpus = { "darwin_arm64": "arm64", "darwin_x86_64": "x86_64", @@ -435,177 +437,181 @@ async def main(): if args.local_xla_path: logging.debug("Local XLA path: %s", args.local_xla_path) - bazel_command_base.append(f"--override_repository=xla=\"{args.local_xla_path}\"") + wheel_build_command_base.append(f"--override_repository=xla=\"{args.local_xla_path}\"") if args.target_cpu: logging.debug("Target CPU: %s", args.target_cpu) - bazel_command_base.append(f"--cpu={args.target_cpu}") + wheel_build_command_base.append(f"--cpu={args.target_cpu}") if args.disable_nccl: logging.debug("Disabling NCCL") - bazel_command_base.append("--config=nonccl") + wheel_build_command_base.append("--config=nonccl") git_hash = utils.get_githash() - # Wheel build command execution - for wheel in args.wheels.split(","): - # Allow CUDA/ROCm wheels without the "jax-" prefix. 
- if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel: - wheel = "jax-" + wheel - - if wheel not in WHEEL_BUILD_TARGET_DICT.keys(): - logging.error( - "Incorrect wheel name provided, valid choices are jaxlib," - " jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt," - " jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt" - ) - sys.exit(1) - - wheel_build_command = copy.deepcopy(bazel_command_base) - print("\n") - logger.info( - "Building %s for %s %s...", - wheel, - os_name, - arch, + clang_path = "" + if args.use_clang: + clang_path = args.clang_path or utils.get_clang_path_or_exit() + clang_major_version = utils.get_clang_major_version(clang_path) + logging.debug( + "Using Clang as the compiler, clang path: %s, clang version: %s", + clang_path, + clang_major_version, ) - clang_path = "" + # Use double quotes around clang path to avoid path issues on Windows. + wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") + wheel_build_command_base.append(f"--repo_env=CC=\"{clang_path}\"") + wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") + + if clang_major_version >= 16: + # Enable clang settings that are needed for the build to work with newer + # versions of Clang. 
+ wheel_build_command_base.append("--config=clang") + else: + gcc_path = args.gcc_path or utils.get_gcc_path_or_exit() + logging.debug( + "Using GCC as the compiler, gcc path: %s", + gcc_path, + ) + wheel_build_command_base.append(f"--repo_env=CC=\"{gcc_path}\"") + wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{gcc_path}\"") + + if not args.disable_mkl_dnn: + logging.debug("Enabling MKL DNN") + if target_cpu == "aarch64": + wheel_build_command_base.append("--config=mkl_aarch64_threadpool") + else: + wheel_build_command_base.append("--config=mkl_open_source_only") + + if args.target_cpu_features == "release": + if arch in ["x86_64", "AMD64"]: + logging.debug( + "Using release cpu features: --config=avx_%s", + "windows" if os_name == "windows" else "posix", + ) + wheel_build_command_base.append( + "--config=avx_windows" + if os_name == "windows" + else "--config=avx_posix" + ) + elif args.target_cpu_features == "native": + if os_name == "windows": + logger.warning( + "--target_cpu_features=native is not supported on Windows;" + " ignoring." + ) + else: + logging.debug("Using native cpu features: --config=native_arch_posix") + wheel_build_command_base.append("--config=native_arch_posix") + else: + logging.debug("Using default cpu features") + + if "cuda" in args.wheels and "rocm" in args.wheels: + logging.error("CUDA and ROCm cannot be enabled at the same time.") + sys.exit(1) + + if "cuda" in args.wheels: + wheel_build_command_base.append("--config=cuda") if args.use_clang: - clang_path = args.clang_path or utils.get_clang_path_or_exit() - clang_major_version = utils.get_clang_major_version(clang_path) - logging.debug( - "Using Clang as the compiler, clang path: %s, clang version: %s", - clang_path, - clang_major_version, + wheel_build_command_base.append( + f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" ) - - # Use double quotes around clang path to avoid path issues on Windows. 
- wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") - wheel_build_command.append(f"--repo_env=CC=\"{clang_path}\"") - wheel_build_command.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"") - - if clang_major_version >= 16: - # Enable clang settings that are needed for the build to work with newer - # versions of Clang. - wheel_build_command.append("--config=clang") - else: - gcc_path = args.gcc_path or utils.get_gcc_path_or_exit() - logging.debug( - "Using GCC as the compiler, gcc path: %s", - gcc_path, - ) - wheel_build_command.append(f"--repo_env=CC=\"{gcc_path}\"") - wheel_build_command.append(f"--repo_env=BAZEL_COMPILER=\"{gcc_path}\"") - - if not args.disable_mkl_dnn: - logging.debug("Enabling MKL DNN") - if target_cpu == "aarch64": - wheel_build_command.append("--config=mkl_aarch64_threadpool") - else: - wheel_build_command.append("--config=mkl_open_source_only") - - if args.target_cpu_features == "release": - if arch in ["x86_64", "AMD64"]: - logging.debug( - "Using release cpu features: --config=avx_%s", - "windows" if os_name == "windows" else "posix", - ) - wheel_build_command.append( - "--config=avx_windows" - if os_name == "windows" - else "--config=avx_posix" - ) - elif args.target_cpu_features == "native": - if os_name == "windows": - logger.warning( - "--target_cpu_features=native is not supported on Windows;" - " ignoring." 
- ) - else: - logging.debug("Using native cpu features: --config=native_arch_posix") - wheel_build_command.append("--config=native_arch_posix") - else: - logging.debug("Using default cpu features") - - if "cuda" in wheel: - wheel_build_command.append("--config=cuda") - if args.use_clang: - wheel_build_command.append( - f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\"" - ) - if args.build_cuda_with_clang: - logging.debug("Building CUDA with Clang") - wheel_build_command.append("--config=build_cuda_with_clang") - else: - logging.debug("Building CUDA with NVCC") - wheel_build_command.append("--config=build_cuda_with_nvcc") + if args.build_cuda_with_clang: + logging.debug("Building CUDA with Clang") + wheel_build_command_base.append("--config=build_cuda_with_clang") else: logging.debug("Building CUDA with NVCC") - wheel_build_command.append("--config=build_cuda_with_nvcc") - - if args.cuda_version: - logging.debug("Hermetic CUDA version: %s", args.cuda_version) - wheel_build_command.append( - f"--repo_env=HERMETIC_CUDA_VERSION={args.cuda_version}" - ) - if args.cudnn_version: - logging.debug("Hermetic cuDNN version: %s", args.cudnn_version) - wheel_build_command.append( - f"--repo_env=HERMETIC_CUDNN_VERSION={args.cudnn_version}" - ) - if args.cuda_compute_capabilities: - logging.debug( - "Hermetic CUDA compute capabilities: %s", - args.cuda_compute_capabilities, - ) - wheel_build_command.append( - f"--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES={args.cuda_compute_capabilities}" - ) - - if "rocm" in wheel: - wheel_build_command.append("--config=rocm_base") - if args.use_clang: - wheel_build_command.append("--config=rocm") - wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") - if args.rocm_path: - logging.debug("ROCm tookit path: %s", args.rocm_path) - wheel_build_command.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"") - if args.rocm_amdgpu_targets: - logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets) - 
wheel_build_command.append( - f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}" - ) - - # Append additional build options at the end to override any options set in - # .bazelrc or above. - if args.bazel_options: - logging.debug( - "Additional Bazel build options: %s", args.bazel_options - ) - for option in args.bazel_options: - wheel_build_command.append(option) - - with open(".jax_configure.bazelrc", "w") as f: - jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command.get_command_as_list()) - if not jax_configure_options: - logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.") - sys.exit(1) - f.write(jax_configure_options) - logging.info("Bazel options written to .jax_configure.bazelrc") - - if args.configure_only: - logging.info("--configure_only is set so not running any Bazel commands.") + wheel_build_command_base.append("--config=build_cuda_with_nvcc") else: + logging.debug("Building CUDA with NVCC") + wheel_build_command_base.append("--config=build_cuda_with_nvcc") + + if args.cuda_version: + logging.debug("Hermetic CUDA version: %s", args.cuda_version) + wheel_build_command_base.append( + f"--repo_env=HERMETIC_CUDA_VERSION={args.cuda_version}" + ) + if args.cudnn_version: + logging.debug("Hermetic cuDNN version: %s", args.cudnn_version) + wheel_build_command_base.append( + f"--repo_env=HERMETIC_CUDNN_VERSION={args.cudnn_version}" + ) + if args.cuda_compute_capabilities: + logging.debug( + "Hermetic CUDA compute capabilities: %s", + args.cuda_compute_capabilities, + ) + wheel_build_command_base.append( + f"--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES={args.cuda_compute_capabilities}" + ) + + if "rocm" in args.wheels: + wheel_build_command_base.append("--config=rocm_base") + if args.use_clang: + wheel_build_command_base.append("--config=rocm") + wheel_build_command_base.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"") + if args.rocm_path: + 
 logging.debug("ROCm toolkit path: %s", args.rocm_path) + wheel_build_command_base.append(f"--action_env=ROCM_PATH=\"{args.rocm_path}\"") + if args.rocm_amdgpu_targets: + logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets) + wheel_build_command_base.append( + f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}" + ) + + # Append additional build options at the end to override any options set in + # .bazelrc or above. + if args.bazel_options: + logging.debug( + "Additional Bazel build options: %s", args.bazel_options + ) + for option in args.bazel_options: + wheel_build_command_base.append(option) + + with open(".jax_configure.bazelrc", "w") as f: + jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command_base.get_command_as_list()) + if not jax_configure_options: + logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.") + sys.exit(1) + f.write(jax_configure_options) + logging.info("Bazel options written to .jax_configure.bazelrc") + + if args.configure_only: + logging.info("--configure_only is set so not running any Bazel commands.") + else: + output_path = args.output_path + logger.debug("Artifacts output directory: %s", output_path) + + # Wheel build command execution + for wheel in args.wheels.split(","): + # Allow CUDA/ROCm wheels without the "jax-" prefix. + if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel: + wheel = "jax-" + wheel + + if wheel not in WHEEL_BUILD_TARGET_DICT.keys(): + logging.error( + "Incorrect wheel name provided, valid choices are jaxlib," + " jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt," + " jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt" + ) + sys.exit(1) + + wheel_build_command = copy.deepcopy(wheel_build_command_base) + print("\n") + logger.info( + "Building %s for %s %s...", + wheel, + os_name, + arch, + ) + # Append the build target to the Bazel command.
build_target = WHEEL_BUILD_TARGET_DICT[wheel] wheel_build_command.append(build_target) wheel_build_command.append("--") - output_path = args.output_path - logger.debug("Artifacts output directory: %s", output_path) - if args.editable: logger.info("Building an editable build") output_path = os.path.join(output_path, wheel) diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 48dc03cfb..63f2643fe 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -39,11 +39,6 @@ py_binary( "@xla//xla/python:xla_extension", ] + if_windows([ "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll", - ]) + if_cuda([ - "//jaxlib/cuda:cuda_gpu_support", - "@local_config_cuda//cuda:cuda-nvvm", - ]) + if_rocm([ - "//jaxlib/rocm:rocm_gpu_support", ]), deps = [ "//jax/tools:build_utils", diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 4db36fa0e..b46a50961 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -56,13 +56,6 @@ parser.add_argument( action="store_true", help="Create an 'editable' jaxlib build instead of a wheel.", ) -parser.add_argument( - "--skip_gpu_kernels", - # args.skip_gpu_kernels is True when - # --skip_gpu_kernels is in the command - action="store_true", - help="Whether to skip gpu kernels in jaxlib.", -) args = parser.parse_args() r = runfiles.Create() @@ -169,7 +162,7 @@ plat_name={tag} ) -def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): +def prepare_wheel(sources_path: pathlib.Path, *, cpu): """Assembles a source tree for the wheel in `sources_path`.""" copy_runfiles = functools.partial(build_utils.copy_file, runfiles=r) @@ -220,35 +213,6 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): ], ) - if exists(f"__main__/jaxlib/cuda/_solver.{pyext}") and not skip_gpu_kernels: - copy_runfiles( - dst_dir=jaxlib_dir / "cuda", - src_files=[ - f"__main__/jaxlib/cuda/_solver.{pyext}", - f"__main__/jaxlib/cuda/_blas.{pyext}", - 
f"__main__/jaxlib/cuda/_linalg.{pyext}", - f"__main__/jaxlib/cuda/_prng.{pyext}", - f"__main__/jaxlib/cuda/_rnn.{pyext}", - f"__main__/jaxlib/cuda/_sparse.{pyext}", - f"__main__/jaxlib/cuda/_triton.{pyext}", - f"__main__/jaxlib/cuda/_hybrid.{pyext}", - f"__main__/jaxlib/cuda/_versions.{pyext}", - ], - ) - if exists(f"__main__/jaxlib/rocm/_solver.{pyext}") and not skip_gpu_kernels: - copy_runfiles( - dst_dir=jaxlib_dir / "rocm", - src_files=[ - f"__main__/jaxlib/rocm/_solver.{pyext}", - f"__main__/jaxlib/rocm/_blas.{pyext}", - f"__main__/jaxlib/rocm/_linalg.{pyext}", - f"__main__/jaxlib/rocm/_prng.{pyext}", - f"__main__/jaxlib/rocm/_sparse.{pyext}", - f"__main__/jaxlib/rocm/_triton.{pyext}", - f"__main__/jaxlib/rocm/_hybrid.{pyext}", - ], - ) - mosaic_python_dir = jaxlib_dir / "mosaic" / "python" copy_runfiles( dst_dir=mosaic_python_dir, @@ -406,7 +370,6 @@ try: prepare_wheel( pathlib.Path(sources_path), cpu=args.cpu, - skip_gpu_kernels=args.skip_gpu_kernels, ) package_name = "jaxlib" if args.editable: