This change enables testing the wheels produced by the build rules in the presubmit using one `bazel test` command only.
There are three options for running the tests:
1) `build_jaxlib=true`: the tests depend on JAX targets.
2) `build_jaxlib=false`: the tests depend on the wheel files located in the `dist` folder.
3) `build_jaxlib=wheel`: the tests depend on the py_import targets.
PiperOrigin-RevId: 735765819
List of changes:
1. Allow us to build a RC wheel when building release artifacts. This is done by modifying the build CLI to use the new JAX build rule and passing in the build options that control the wheel tag. A new build argument `use_new_wheel_build_rule` is introduced to the build CLI to avoid breaking anyone that uses the CLI and the old build rule. Note that this option will go way in the future when the build CLI migrates fully to the new build rule.
2. Change the upload script to upload both rc and release tagged wheels (changes internal)
PiperOrigin-RevId: 733464219
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126)
Previously `jax` wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).
You can still build the `jax` wheel with `python3 -m build` command.
Bazel `jax` wheel target: `//:jax_wheel`
Environment variables combinations for creating wheels with different versions:
* self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
* release: `--repo_env=ML_WHEEL_TYPE=release`
* release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
* nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`
PiperOrigin-RevId: 730916743
Build.py shouldn't be used for building the wheels with real CUDA libraries in the dependencies. This change prevents overriding the default configuration.
PiperOrigin-RevId: 725326252
With this configuration the same cache is used both for `bazel build` and `bazel test` commands (provided the same target is specified).
Add `--config=no_cuda_libs` for building targets with CUDA libraries from stubs.
PiperOrigin-RevId: 720334587
This was causing an issue when building multiple wheels in editable mode.
i.e instead of wheels being stored as:
```
# jax-cuda12-pjrt 0.4.36.dev20241125 ./dist/jax-cuda-pjrt
# jax-cuda12-plugin 0.4.36.dev20241125 ./dist/jax-cuda-plugin
# jaxlib 0.4.36.dev20241125 ./dist/jaxlib
```
they were being stored as:
```
# jaxlib 0.4.36.dev20241125 ./dist/jaxlib
# jax-cuda12-pjrt 0.4.36.dev20241125 ./dist/jaxlib/jax-cuda-pjrt
# jax-cuda12-plugin 0.4.36.dev20241125 ./dist/jaxlib/jax-cuda-plugin
```
PiperOrigin-RevId: 708468522
This commit modifies the behavior of the build CLI when building jaxlib and GPU plugin artifacts together (for instance `python build --wheels=jaxlib,jax-cuda-plugin`.
Before, CUDA/ROCm build options were only passed when building the CUDA/ROCm artifacts. However, this leads to inefficient use of the build cache as it looks like Bazel tries to rebuild some targets that has already been built in the previous run. This seems to be because the GPU plugin artifacts have a different set of build options compared to `jaxlib` which for some reason causes Bazel to invalidate/ignore certain cache hits. Therefore, this commit makes it so that the build options remain the same when the `jaxlib` and GPU artifacts are being built together so that we can better utilize the build cache.
As an example, this means that if `python build --wheels=jaxlib,jax-cuda-plugin` is run, the following build options will apply to both `jaxlib` and `jax-cuda-plugin` builds:
```
/usr/local/bin/bazel run --repo_env=HERMETIC_PYTHON_VERSION=3.10 \
--verbose_failures=true --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--repo_env=CC="/usr/lib/llvm-16/bin/clang" \
--repo_env=BAZEL_COMPILER="/usr/lib/llvm-16/bin/clang" \
--config=clang --config=mkl_open_source_only --config=avx_posix \
--config=cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--config=build_cuda_with_nvcc
```
Note, this commit shouldn't affect the content of the wheel it self. It is only meant to give a performance boost when building `jalxib`+plugin aritfacts together.
Also, this removes code that was used to build (now deprecated) monolithic `jaxlib` build from `build_wheel.py`
PiperOrigin-RevId: 708035062
The `build.py` script uses Clang compiler by default, and JAX doesn't support building with GCC officially. However, experimental GCC support is still present.
Command examples:
```
python build/build.py build --wheels=jaxlib,jax-cuda-plugin --use_clang=false
python build/build.py build --wheels=jaxlib,jax-cuda-plugin --use_clang=false --gcc_path=/use/bin/gcc
```
This change addresses the request in https://github.com/jax-ml/jax/issues/25488.
PiperOrigin-RevId: 707930913
These flags disable Clang extensions that do things such as reject type definitions within offsetof or reject unknown arguments which does not seem to be needed on versions older than Clang 16
Also, fix a syntax error
Fixes https://github.com/jax-ml/jax/issues/25530
PiperOrigin-RevId: 707555651
This adds a new command-line flag, `--detailed_timestamped_log`, that enables detailed logging of Bazel build commands. When disabled (the default), logging mirrors the output you'd see when running the command directly in your terminal.
When this flag is enabled:
- Bazel's output is captured line by line.
- Each line is timestamped for improved traceability.
- The complete log is stored for potential use as an artifact.
The flag is disabled by default and only enabled in the CI builds. If you're running locally and enable `detailed_timestamped_log`, you might notice that Bazel's output is not colored. To force a color output, include `--bazel_options=--color=yes` in your command.
PiperOrigin-RevId: 703581368
This commit reworks the JAX build CLI to a subcommand based approach where CLI use cases are now defined as subcommands. Two subcommands are defined: build and requirements_update. "build" is to be used when wanting to build a JAX wheel package. "requirements_update" is to be used when wanting to update the requirements_lock.txt files. The new structure offers a clear and organized CLI that enables users to execute specific build tasks without having to navigate through a monolithic script.
Each subcommand has specific arguments that apply to its respective build process. In addition, arguments are separated into groups to achieve a cleaner separation and improves the readability when the CLI subcommands are run with `--help`. It also makes it clear as to which parts of the build they affect. E.g: CUDA arguments only apply to CUDA builds, ROCM arguments only apply to ROCM builds, etc. This reduces the complexity and the potential for errors during the build process. Segregating functionalities into distinct subcommands also simplifies the code which should help with the maintenance and future extensions.
There is also a transition from using `subprocess.check_output` to `asyncio.create_subprocess_shell` for executing the build commands which allows for streaming logs and helps in showing the build progress in real time.
Usage:
* Building `jaxlib`:
```
python build/build.py build --wheels=jaxlib --python_version=3.10
```
* Building `jax-cuda-plugin`:
```
python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building multiple packages:
```
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building `jax-rocm-pjrt`:
```
python build/build.py build --wheels=jax-rocm-pjrt --rocm_version=60 --rocm_path=/path/to/rocm
```
* Using a local XLA path:
```
python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla
```
* Updating requirements_lock.txt files:
```
python build/build.py requirements_update --python_version=3.10
```
For more details on each argument and to see available options, run:
```
python build/build.py build --help
```
or
```
python build/build.py requirements_update --help
```
PiperOrigin-RevId: 700075411
Temporary fixes an issue with `python -m build` that fails when python 3.8 is used because `matplotlib~=3.8.4` is unavailable for this python version.
We are working on creating Bazel build rule with the hermetic Python for JAX wheel ([we already have Jaxlib and plugins build rules ready](https://github.com/jax-ml/jax/pull/23276)). The required python modules are provided in requirements.in file, so when we implement Bazel build rule for JAX wheel, requirements.in will be the only source of dependencies, and test-requirements.txt won't be needed for building JAX wheel.
PiperOrigin-RevId: 692260046