This adds a new command-line flag, `--detailed_timestamped_log`, that enables detailed logging of Bazel build commands. When disabled (the default), logging mirrors the output you'd see when running the command directly in your terminal.
When this flag is enabled:
- Bazel's output is captured line by line.
- Each line is timestamped for improved traceability.
- The complete log is stored for potential use as an artifact.
The flag is disabled by default and enabled only in CI builds. If you're running locally with `--detailed_timestamped_log` enabled, you might notice that Bazel's output is no longer colored. To force colored output, include `--bazel_options=--color=yes` in your command.
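A minimal sketch of the line-by-line capture and timestamping described above (a hypothetical helper, not the actual implementation):
```python
import datetime
import subprocess

def run_with_timestamps(cmd: list[str]) -> list[str]:
    """Hypothetical helper: capture a command's output line by line,
    prefixing each line with a timestamp for traceability."""
    proc = subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
    )
    log = []
    for line in proc.stdout:
        stamp = datetime.datetime.now().isoformat(timespec="seconds")
        log.append(f"[{stamp}] {line.rstrip()}")
        print(log[-1])
    proc.wait()
    return log  # the complete log can be stored as an artifact
```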
PiperOrigin-RevId: 703581368
This commit reworks the JAX build CLI into a subcommand-based approach, where CLI use cases are defined as subcommands. Two subcommands are defined: `build` and `requirements_update`. Use `build` to build a JAX wheel package, and `requirements_update` to update the requirements_lock.txt files. The new structure offers a clear, organized CLI that lets users execute specific build tasks without navigating a monolithic script.
Each subcommand has specific arguments that apply to its respective build process. In addition, arguments are separated into groups to achieve a cleaner separation and improve readability when the CLI subcommands are run with `--help`. This also makes it clear which parts of the build they affect: CUDA arguments only apply to CUDA builds, ROCm arguments only apply to ROCm builds, etc., which reduces complexity and the potential for errors during the build process. Segregating functionality into distinct subcommands also simplifies the code, which should help with maintenance and future extensions.
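A minimal argparse sketch of this layout (the argument and group names are illustrative, drawn from the usage examples below, not the exact build.py definitions):
```python
import argparse

parser = argparse.ArgumentParser(prog="build.py")
subparsers = parser.add_subparsers(dest="command", required=True)

# "build" subcommand, with arguments grouped by the part of the build
# they affect.
build = subparsers.add_parser("build", help="Build a JAX wheel package.")
build.add_argument("--wheels", help="Comma-separated list of wheels to build.")
cuda = build.add_argument_group("CUDA options")  # only apply to CUDA builds
cuda.add_argument("--cuda_version")
rocm = build.add_argument_group("ROCm options")  # only apply to ROCm builds
rocm.add_argument("--rocm_version")

# "requirements_update" subcommand.
req = subparsers.add_parser(
    "requirements_update", help="Update the requirements_lock.txt files."
)
req.add_argument("--python_version")

args = parser.parse_args()
```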
There is also a transition from `subprocess.check_output` to `asyncio.create_subprocess_shell` for executing the build commands, which allows streaming logs and showing build progress in real time.
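A minimal sketch of this streaming approach (hypothetical code, not the actual build.py implementation):
```python
import asyncio

async def stream_command(command: str) -> int:
    """Run a shell command and echo its output line by line as it arrives,
    instead of buffering it all as subprocess.check_output does."""
    proc = await asyncio.create_subprocess_shell(
        command,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
    )
    while True:
        line = await proc.stdout.readline()
        if not line:
            break
        print(line.decode().rstrip())
    return await proc.wait()

# Example: asyncio.run(stream_command("bazel build //jaxlib/..."))
```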
Usage:
* Building `jaxlib`:
```
python build/build.py build --wheels=jaxlib --python_version=3.10
```
* Building `jax-cuda-plugin`:
```
python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building multiple packages:
```
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building `jax-rocm-pjrt`:
```
python build/build.py build --wheels=jax-rocm-pjrt --rocm_version=60 --rocm_path=/path/to/rocm
```
* Using a local XLA path:
```
python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla
```
* Updating requirements_lock.txt files:
```
python build/build.py requirements_update --python_version=3.10
```
For more details on each argument and to see available options, run:
```
python build/build.py build --help
```
or
```
python build/build.py requirements_update --help
```
PiperOrigin-RevId: 700075411
This commit is the first step towards reworking the build CLI. It moves all the auxiliary functions used by the CLI into a separate script to improve maintainability and readability.
PiperOrigin-RevId: 691458051
Set `--action_env=CLANG_CUDA_COMPILER_PATH` after the cuda_nvcc configuration.
Add `--config=cuda_clang` when the `--nouse_cuda_nvcc` flag is set.
PiperOrigin-RevId: 678873849
If the `--use_cuda_nvcc` flag is set, the NVCC compiler driver is used to build the CUDA code (the default behavior). Otherwise, if the `--nouse_cuda_nvcc` flag is set, only the Clang compiler is used to build the CUDA code (effectively disabling NVCC).
Mark the `--use_clang` flag as deprecated.
Refactor the `.bazelrc` configs to match the new flag and to clean up the previously confusing names.
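For example, a hypothetical invocation that builds the CUDA code with Clang only, assuming the flag is passed alongside the existing CUDA options:
```
python3 build/build.py --enable_cuda --nouse_cuda_nvcc
```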
PiperOrigin-RevId: 678332548
1) Hermetic CUDA rules allow building wheels with GPU support on a machine without GPUs, as well as running Bazel GPU tests on a machine that has GPUs and only the NVIDIA driver installed. When `--config=cuda` is provided in Bazel options, Bazel downloads the CUDA, CUDNN and NCCL redistributions into the cache and uses them during the build and test phases. An example command follows the links below.
[Default location of CUDNN redistributions](https://developer.download.nvidia.com/compute/cudnn/redist/)
[Default location of CUDA redistributions](https://developer.download.nvidia.com/compute/cuda/redist/)
[Default location of NCCL redistributions](https://pypi.org/project/nvidia-nccl-cu12/#history)
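For example, a hermetic GPU test run might look like this (the target label is a placeholder):
```
bazel test --config=cuda //path/to:gpu_test
```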
2) To include hermetic CUDA rules in your project, add the following to the WORKSPACE file of the downstream project that depends on XLA.
Note: use `@local_tsl` instead of `@tsl` in the TensorFlow project.
```
load(
    "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
    "cuda_json_init_repository",
)

cuda_json_init_repository()

load(
    "@cuda_redist_json//:distributions.bzl",
    "CUDA_REDISTRIBUTIONS",
    "CUDNN_REDISTRIBUTIONS",
)
load(
    "@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
    "cuda_redist_init_repositories",
    "cudnn_redist_init_repository",
)

cuda_redist_init_repositories(
    cuda_redistributions = CUDA_REDISTRIBUTIONS,
)

cudnn_redist_init_repository(
    cudnn_redistributions = CUDNN_REDISTRIBUTIONS,
)

load(
    "@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
    "cuda_configure",
)

cuda_configure(name = "local_config_cuda")

load(
    "@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
    "nccl_redist_init_repository",
)

nccl_redist_init_repository()

load(
    "@tsl//third_party/nccl/hermetic:nccl_configure.bzl",
    "nccl_configure",
)

nccl_configure(name = "local_config_nccl")
```
PiperOrigin-RevId: 662981325
When `build_gpu_plugin` is true, three wheels are produced (jaxlib, jax-cuda-pjrt and jax-cuda-plugin). If they are editable, they need to be placed in subdirectories to avoid overwriting one another.
Tested on GPU. After the editable wheels are built, they can be installed with `pip install -e /jax/dist/jax_gpu_pjrt /jax/dist/jaxlib /jax/dist/jax_gpu_plugin`.
PiperOrigin-RevId: 660984311
The plugin is released, so the flag is no longer needed.
Also set the default value of `enable_gpu` to False; `enable_gpu` will be removed in the next change.
PiperOrigin-RevId: 660059432
This lets us avoid bundling an entire second copy of LLVM with JAX packages,
and so we can finally start building Mosaic GPU by default.
PiperOrigin-RevId: 638569750
With this change, `python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` will generate three wheels:
| contents | size | wheel name |
|---|---|---|
| jaxlib w/o cuda kernels | 76M | jaxlib-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl |
| cuda pjrt | 73M | jax_cuda12_pjrt-0.4.20.dev20231101-py3-none-manylinux2014_x86_64.whl |
| cuda kernels | 6.6M | jax_cuda12_plugin-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl |
The size of jaxlib with cuda kernels and pjrt is 119M.
The cuda kernel wheel contains all the cuda kernels. A `plugin_setup.py` and a `plugin_pyproject.toml` are added for this new package.
PiperOrigin-RevId: 579861480
Add a --verbose option that logs all shell() commands run by the script (see the sketch below).
Remove some Python 2 backward compatibility logic related to urllib and shutil.
Enable debug logging on Windows wheel builds.
Also include setuptools in the build requirements and test for its presence in build.py.
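A minimal sketch of what such a verbose `shell()` helper could look like (hypothetical, not the actual build.py code):
```python
import logging
import subprocess

def shell(cmd: list[str], verbose: bool = False) -> str:
    """Hypothetical helper: optionally log the exact command before
    running it, then return its decoded output."""
    if verbose:
        logging.info("Running: %s", " ".join(cmd))
    return subprocess.check_output(cmd).decode("utf-8")
```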
- Add a py extension to call the custom call C API.
- Change the implementation of register_custom_call_target to store handlers for the custom call targets and delay the registration until the handler for an XLA platform is registered (see the sketch after this list).
- Change register_plugin to load the PJRT plugin when register_plugin is called (instead of when a client is created), and let it return the PJRT_Api* loaded.
- Delay calling discover_pjrt_plugins() and register_pjrt_plugin_factories_from_env() until the first time backends() is called.
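A minimal sketch of the delayed-registration pattern (illustrative Python only, with a hypothetical `_do_register` placeholder; the real implementation lives in the C API bindings):
```python
# Handlers are queued per platform until that platform is registered.
_pending_targets: dict[str, list[tuple[str, object]]] = {}
_ready_platforms: set[str] = set()

def register_custom_call_target(name: str, fn, platform: str) -> None:
    if platform not in _ready_platforms:
        # The platform's handler is not registered yet: queue for later.
        _pending_targets.setdefault(platform, []).append((name, fn))
        return
    _do_register(name, fn, platform)

def register_platform_handler(platform: str) -> None:
    # Once the platform's handler is registered, flush the queued targets.
    _ready_platforms.add(platform)
    for name, fn in _pending_targets.pop(platform, []):
        _do_register(name, fn, platform)

def _do_register(name: str, fn, platform: str) -> None:
    # Placeholder for the actual low-level registration call.
    print(f"registered {name} for platform {platform}")
```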
PiperOrigin-RevId: 568265745
Add a build wheel, pyproject.toml and setup.py.
The directory structure in the jax repo is:
```
jax/
└── plugins/
    └── cuda/
        ├── __init__.py
        ├── pyproject.toml
        └── setup.py
```
The installed package structure is:
```
jax_plugins/
└── xla_cuda_cu12/
    ├── __init__.py
    └── xla_cuda_plugin.so
```
The major CUDA version will be part of the package name.
The plugin wheel can be built with the command:
```
python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 --bazel_options="--override_repository=xla=$HOME/xla"
```
PiperOrigin-RevId: 565187954
The jaxlib wheel build currently uses `python setup.py bdist_wheel` to construct the wheel. Change it to use `python -m build -w` instead.
To avoid Python getting confused between the directory named `build` in the bazel tree and the Python `build` module, move `build_wheel.py` into `jaxlib/tools`.
PiperOrigin-RevId: 548133811