List of changes:
1. Allow us to build an RC wheel when building release artifacts. This is done by modifying the build CLI to use the new JAX build rule and passing in the build options that control the wheel tag. A new build argument, `use_new_wheel_build_rule`, is introduced to the build CLI to avoid breaking anyone who uses the CLI with the old build rule. Note that this option will go away in the future, when the build CLI migrates fully to the new build rule.
2. Change the upload script to upload both RC and release tagged wheels (internal changes).
PiperOrigin-RevId: 733464219
`build.py` shouldn't be used for building wheels that have real CUDA libraries in their dependencies. This change prevents overriding the default configuration.
PiperOrigin-RevId: 725326252
With this configuration the same cache is used both for `bazel build` and `bazel test` commands (provided the same target is specified).
Add `--config=no_cuda_libs` for building targets with CUDA libraries from stubs.
PiperOrigin-RevId: 720334587
This was causing an issue when building multiple wheels in editable mode.
That is, instead of wheels being stored as:
```
# jax-cuda12-pjrt 0.4.36.dev20241125 ./dist/jax-cuda-pjrt
# jax-cuda12-plugin 0.4.36.dev20241125 ./dist/jax-cuda-plugin
# jaxlib 0.4.36.dev20241125 ./dist/jaxlib
```
they were being stored as:
```
# jaxlib 0.4.36.dev20241125 ./dist/jaxlib
# jax-cuda12-pjrt 0.4.36.dev20241125 ./dist/jaxlib/jax-cuda-pjrt
# jax-cuda12-plugin 0.4.36.dev20241125 ./dist/jaxlib/jax-cuda-plugin
```
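The intended layout can be sketched in a few lines. This is only an illustration, not the code from `build.py`: the helper name `editable_wheel_dirs` is hypothetical, and the directory names are copied from the first listing above. The point is that every path is derived from the same fixed root, so no wheel can end up nested under another one.

```python
import posixpath


def editable_wheel_dirs(output_root, wheel_dirs):
    """Return one sibling directory under output_root per wheel.

    Hypothetical helper: each path is joined to the fixed root, never to the
    directory that was computed for a previously built wheel.
    """
    return {name: posixpath.join(output_root, name) for name in wheel_dirs}


dirs = editable_wheel_dirs(
    "./dist", ["jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]
)
for name, path in dirs.items():
    print(name, path)
```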
PiperOrigin-RevId: 708468522
This commit modifies the behavior of the build CLI when building jaxlib and GPU plugin artifacts together (for instance, `python build --wheels=jaxlib,jax-cuda-plugin`).
Before, CUDA/ROCm build options were only passed when building the CUDA/ROCm artifacts. However, this leads to inefficient use of the build cache, as it looks like Bazel tries to rebuild some targets that have already been built in the previous run. This seems to be because the GPU plugin artifacts have a different set of build options compared to `jaxlib`, which for some reason causes Bazel to invalidate or ignore certain cache hits. Therefore, this commit keeps the build options the same when `jaxlib` and the GPU artifacts are built together, so that we can make better use of the build cache.
As an example, this means that if `python build --wheels=jaxlib,jax-cuda-plugin` is run, the following build options will apply to both `jaxlib` and `jax-cuda-plugin` builds:
```
/usr/local/bin/bazel run --repo_env=HERMETIC_PYTHON_VERSION=3.10 \
--verbose_failures=true --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--repo_env=CC="/usr/lib/llvm-16/bin/clang" \
--repo_env=BAZEL_COMPILER="/usr/lib/llvm-16/bin/clang" \
--config=clang --config=mkl_open_source_only --config=avx_posix \
--config=cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--config=build_cuda_with_nvcc
```
Note that this commit shouldn't affect the content of the wheel itself. It is only meant to give a performance boost when building `jaxlib` and plugin artifacts together.
Also, this removes the code in `build_wheel.py` that was used for the (now deprecated) monolithic `jaxlib` build.
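The idea of keeping the options identical across wheels can be sketched as follows. This is a simplified illustration under stated assumptions, not the logic in `build.py`: the function name `bazel_options_for` and the `GPU_WHEELS` set are hypothetical, and the option strings are a small subset of the command shown above.

```python
# Hypothetical set of wheel names treated as GPU artifacts in this sketch.
GPU_WHEELS = {"jax-cuda-plugin", "jax-cuda-pjrt"}


def bazel_options_for(wheels, base_options, gpu_options):
    """Return the Bazel options to use for each requested wheel.

    If any GPU artifact is requested, the GPU options are added for every
    wheel, jaxlib included, so all builds share one configuration.
    """
    building_gpu = any(wheel in GPU_WHEELS for wheel in wheels)
    shared = list(base_options)
    if building_gpu:
        shared += list(gpu_options)
    return {wheel: list(shared) for wheel in wheels}


base = ["--config=clang", "--config=avx_posix"]
gpu = ["--config=cuda", "--config=build_cuda_with_nvcc"]

together = bazel_options_for(["jaxlib", "jax-cuda-plugin"], base, gpu)
alone = bazel_options_for(["jaxlib"], base, gpu)
print(together["jaxlib"] == together["jax-cuda-plugin"])
print(alone["jaxlib"])
```

When `jaxlib` is built on its own, the sketch leaves the GPU options out, which mirrors the statement that the options only change when the artifacts are built together.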
PiperOrigin-RevId: 708035062
The `build.py` script uses the Clang compiler by default, and JAX doesn't officially support building with GCC. However, experimental GCC support is still present.
Command examples:
```
python build/build.py build --wheels=jaxlib,jax-cuda-plugin --use_clang=false
python build/build.py build --wheels=jaxlib,jax-cuda-plugin --use_clang=false --gcc_path=/usr/bin/gcc
```
This change addresses the request in https://github.com/jax-ml/jax/issues/25488.
PiperOrigin-RevId: 707930913
These flags disable Clang extensions that, for example, reject type definitions within `offsetof` or reject unknown arguments. The flags do not seem to be needed on versions older than Clang 16.
Also, fix a syntax error.
Fixes https://github.com/jax-ml/jax/issues/25530
PiperOrigin-RevId: 707555651
This adds a new command-line flag, `--detailed_timestamped_log`, that enables detailed logging of Bazel build commands. When disabled (the default), logging mirrors the output you'd see when running the command directly in your terminal.
When this flag is enabled:
- Bazel's output is captured line by line.
- Each line is timestamped for improved traceability.
- The complete log is stored for potential use as an artifact.
The flag is disabled by default and only enabled in the CI builds. If you're running locally and enable `detailed_timestamped_log`, you might notice that Bazel's output is not colored. To force colored output, include `--bazel_options=--color=yes` in your command.
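The line-by-line capture described above can be sketched with the standard library. This is a minimal illustration, not the implementation in the build CLI: the function name `run_with_timestamps` and the timestamp format are assumptions, and a short Python one-liner stands in for the Bazel command so that the sketch runs anywhere.

```python
import datetime
import subprocess
import sys


def run_with_timestamps(cmd):
    """Run cmd, echo each output line with a timestamp, and return the log."""
    log = []
    proc = subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
    )
    # Reading the pipe line by line lets each line be stamped as it arrives.
    for line in proc.stdout:
        stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        entry = f"[{stamp}] {line.rstrip()}"
        print(entry)
        log.append(entry)
    proc.wait()
    return proc.returncode, log


# Stand-in for a Bazel invocation (assumption, for illustration only).
returncode, log = run_with_timestamps(
    [sys.executable, "-c", "print('Analyzing'); print('Build completed')"]
)
```

Because the child process writes to a pipe rather than a terminal in this setup, tools that auto-detect a terminal typically turn colors off, which is consistent with the uncolored output noted above.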
PiperOrigin-RevId: 703581368
This commit reworks the JAX build CLI to a subcommand-based approach, where CLI use cases are defined as subcommands. Two subcommands are defined: `build` and `requirements_update`. Use `build` to build a JAX wheel package, and `requirements_update` to update the requirements_lock.txt files. The new structure offers a clear and organized CLI that lets users execute specific build tasks without having to navigate through a monolithic script.
Each subcommand has specific arguments that apply to its respective build process. In addition, arguments are separated into groups, which gives a cleaner separation and improves readability when the CLI subcommands are run with `--help`. It also makes clear which parts of the build they affect: for example, CUDA arguments only apply to CUDA builds, and ROCm arguments only apply to ROCm builds. This reduces complexity and the potential for errors during the build process. Segregating functionality into distinct subcommands also simplifies the code, which should help with maintenance and future extensions.
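A subcommand layout with grouped arguments can be sketched with `argparse`. This is an illustrative outline rather than the actual parser in `build.py`: the flag names are the ones used in the usage examples in this message, while the help strings, group titles, and defaults are assumptions.

```python
import argparse


def make_parser():
    parser = argparse.ArgumentParser(prog="build.py")
    subparsers = parser.add_subparsers(dest="command", required=True)

    # "build" subcommand: general flags plus per-backend argument groups.
    build = subparsers.add_parser("build", help="Build JAX wheel packages")
    build.add_argument("--wheels", default="jaxlib")
    build.add_argument("--python_version")

    cuda = build.add_argument_group("CUDA options")
    cuda.add_argument("--cuda_version")
    cuda.add_argument("--cudnn_version")

    rocm = build.add_argument_group("ROCm options")
    rocm.add_argument("--rocm_version")
    rocm.add_argument("--rocm_path")

    # "requirements_update" subcommand has its own, smaller argument set.
    update = subparsers.add_parser(
        "requirements_update", help="Update requirements_lock.txt files"
    )
    update.add_argument("--python_version")
    return parser


build_args = make_parser().parse_args(
    ["build", "--wheels=jaxlib,jax-cuda-plugin", "--cuda_version=12.3.2"]
)
update_args = make_parser().parse_args(
    ["requirements_update", "--python_version=3.10"]
)
print(build_args.command, build_args.wheels.split(","))
print(update_args.command, update_args.python_version)
```

Argument groups only affect how `--help` is rendered; the parsed values still land in a single namespace per subcommand.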
There is also a transition from using `subprocess.check_output` to `asyncio.create_subprocess_shell` for executing the build commands which allows for streaming logs and helps in showing the build progress in real time.
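Streaming output with `asyncio.create_subprocess_shell` can be sketched as below. This is a minimal example under assumptions, not the executor used by the build CLI: the function name `run_streaming` is hypothetical, the command is a stand-in for a Bazel invocation, and a POSIX shell is assumed for the quoting.

```python
import asyncio
import shlex
import sys


async def run_streaming(command):
    """Run a shell command and print its output as it is produced."""
    proc = await asyncio.create_subprocess_shell(
        command,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
    )
    lines = []
    while True:
        raw = await proc.stdout.readline()
        if not raw:  # An empty read means the stream has closed.
            break
        line = raw.decode().rstrip()
        print(line)
        lines.append(line)
    returncode = await proc.wait()
    return returncode, lines


# Stand-in command (assumption); a real call would run Bazel instead.
command = f"{shlex.quote(sys.executable)} -c \"print('step 1'); print('step 2')\""
returncode, lines = asyncio.run(run_streaming(command))
```

Unlike `subprocess.check_output`, which returns the output only after the process exits, this loop sees each line as soon as the child writes it.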
Usage:
* Building `jaxlib`:
```
python build/build.py build --wheels=jaxlib --python_version=3.10
```
* Building `jax-cuda-plugin`:
```
python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building multiple packages:
```
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building `jax-rocm-pjrt`:
```
python build/build.py build --wheels=jax-rocm-pjrt --rocm_version=60 --rocm_path=/path/to/rocm
```
* Using a local XLA path:
```
python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla
```
* Updating requirements_lock.txt files:
```
python build/build.py requirements_update --python_version=3.10
```
For more details on each argument and to see available options, run:
```
python build/build.py build --help
```
or
```
python build/build.py requirements_update --help
```
PiperOrigin-RevId: 700075411
This commit is the first step towards re-working the build CLI. It moves all the auxiliary functions used by the CLI into a separate script for easier maintenance and readability.
PiperOrigin-RevId: 691458051
Set `--action_env=CLANG_CUDA_COMPILER_PATH` after the cuda_nvcc configuration.
Add `--config=cuda_clang` when the `--nouse_cuda_nvcc` flag is set.
PiperOrigin-RevId: 678873849
If the `--use_cuda_nvcc` flag is set, the NVCC compiler driver is used to build the CUDA code (the default behavior). If the `--nouse_cuda_nvcc` flag is set instead, only the Clang compiler is used to build the CUDA code, which effectively disables NVCC.
Mark `--use_clang` flag as deprecated.
Refactor `.bazelrc` configs to match the new flag and to cleanup all previous confusing names.
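The paired `--use_cuda_nvcc` / `--nouse_cuda_nvcc` behavior described above can be sketched with plain `argparse`. This is an illustration only and may differ from how `build.py` registers its boolean flags; the helper `cuda_compiler` and its return values are hypothetical.

```python
import argparse

parser = argparse.ArgumentParser()
# Both flags write to the same destination; NVCC is the default.
parser.add_argument(
    "--use_cuda_nvcc", dest="use_cuda_nvcc", action="store_true", default=True
)
parser.add_argument(
    "--nouse_cuda_nvcc", dest="use_cuda_nvcc", action="store_false"
)


def cuda_compiler(argv):
    """Return which compiler would build the CUDA code (hypothetical helper)."""
    args = parser.parse_args(argv)
    return "nvcc" if args.use_cuda_nvcc else "clang"


print(cuda_compiler([]))
print(cuda_compiler(["--nouse_cuda_nvcc"]))
```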
PiperOrigin-RevId: 678332548
1) Hermetic CUDA rules allow building wheels with GPU support on a machine without GPUs, as well as running Bazel GPU tests on a machine that has only GPUs and an NVIDIA driver installed. When `--config=cuda` is provided in the Bazel options, Bazel downloads the CUDA, CUDNN and NCCL redistributions into the cache and uses them during the build and test phases.
[Default location of CUDNN redistributions](https://developer.download.nvidia.com/compute/cudnn/redist/)
[Default location of CUDA redistributions](https://developer.download.nvidia.com/compute/cuda/redist/)
[Default location of NCCL redistributions](https://pypi.org/project/nvidia-nccl-cu12/#history)
2) To include hermetic CUDA rules in your project, add the following to the WORKSPACE of the downstream project that depends on XLA.
Note: use `@local_tsl` instead of `@tsl` in the TensorFlow project.
```
load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"cuda_json_init_repository",
)
cuda_json_init_repository()
load(
"@cuda_redist_json//:distributions.bzl",
"CUDA_REDISTRIBUTIONS",
"CUDNN_REDISTRIBUTIONS",
)
load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
"cuda_redist_init_repositories",
"cudnn_redist_init_repository",
)
cuda_redist_init_repositories(
cuda_redistributions = CUDA_REDISTRIBUTIONS,
)
cudnn_redist_init_repository(
cudnn_redistributions = CUDNN_REDISTRIBUTIONS,
)
load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
"cuda_configure",
)
cuda_configure(name = "local_config_cuda")
load(
"@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
"nccl_redist_init_repository",
)
nccl_redist_init_repository()
load(
"@tsl//third_party/nccl/hermetic:nccl_configure.bzl",
"nccl_configure",
)
nccl_configure(name = "local_config_nccl")
```
PiperOrigin-RevId: 662981325
When `build_gpu_plugin` is true, three wheels are produced (jaxlib, jax-cuda-pjrt and jax-cuda-plugin). If they are editable, they need to be placed in separate subdirectories so that they do not overwrite one another.
Tested on GPU. After the editable wheels are built, they can be installed with `pip install -e /jax/dist/jax_gpu_pjrt /jax/dist/jaxlib /jax/dist/jax_gpu_plugin`.
PiperOrigin-RevId: 660984311
The plugin is released and the flag is no longer needed.
Also set default value of enable_gpu to False. enable_gpu will be removed in the next change.
PiperOrigin-RevId: 660059432
This lets us avoid bundling yet another copy of LLVM with JAX packages,
so we can finally start building Mosaic GPU by default.
PiperOrigin-RevId: 638569750