Description:
- Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true to reduce the test execution time
- especially for TSAN CI job where this test takes ~700 seconds to pass with the recent 3.13 cpython
- Removed optional deps for 3.14
The change is made to address the case when bazel dir has multiple wheels with different version suffixes. We need to copy only those wheels that were created by the current execution of build.py script.
PiperOrigin-RevId: 743566122
By adding [jax wheel testing](https://github.com/jax-ml/jax/pull/27113) functionality, we need to have pre-built jax and jaxlib wheels.
PiperOrigin-RevId: 741249718
The set of editable wheels (`jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt`) was used as dependencies in `requirements.in` file together with `:build_jaxlib=false` flag.
After [adding `jax` wheel dependencies](f5a4d1a85c) to the tests when `:build_jaxlib=false` is used, we need an editable `jax` wheel target as well to get the tests passing.
PiperOrigin-RevId: 740840736
Refactor `jax_wheel` macros, so it outputs a `.whl` file only.
When the macros returns one output object only, it allows all downstream dependencies consume it easily without the need to filter the macros outputs.
The previous implementation design (when `jax_wheel` returned `.tar.gz` and `.whl` files) required one of two options: either create a new target that produces `.whl` only, or to implement filename filtering in the downstream rules. With the new implementation we can just depend on `//:jax_wheel` target that produces the `.whl`.
PiperOrigin-RevId: 738547491
This change enables testing the wheels produced by the build rules in the presubmit using one `bazel test` command only.
There are three options for running the tests:
1) `build_jaxlib=true`: the tests depend on JAX targets.
2) `build_jaxlib=false`: the tests depend on the wheel files located in the `dist` folder.
3) `build_jaxlib=wheel`: the tests depend on the py_import targets.
PiperOrigin-RevId: 735765819
List of changes:
1. Allow us to build a RC wheel when building release artifacts. This is done by modifying the build CLI to use the new JAX build rule and passing in the build options that control the wheel tag. A new build argument `use_new_wheel_build_rule` is introduced to the build CLI to avoid breaking anyone that uses the CLI and the old build rule. Note that this option will go way in the future when the build CLI migrates fully to the new build rule.
2. Change the upload script to upload both rc and release tagged wheels (changes internal)
PiperOrigin-RevId: 733464219
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126)
Previously `jax` wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).
You can still build the `jax` wheel with `python3 -m build` command.
Bazel `jax` wheel target: `//:jax_wheel`
Environment variables combinations for creating wheels with different versions:
* self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
* release: `--repo_env=ML_WHEEL_TYPE=release`
* release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
* nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`
PiperOrigin-RevId: 730916743
Build.py shouldn't be used for building the wheels with real CUDA libraries in the dependencies. This change prevents overriding the default configuration.
PiperOrigin-RevId: 725326252
With this configuration the same cache is used both for `bazel build` and `bazel test` commands (provided the same target is specified).
Add `--config=no_cuda_libs` for building targets with CUDA libraries from stubs.
PiperOrigin-RevId: 720334587
This was causing an issue when building multiple wheels in editable mode.
i.e instead of wheels being stored as:
```
# jax-cuda12-pjrt 0.4.36.dev20241125 ./dist/jax-cuda-pjrt
# jax-cuda12-plugin 0.4.36.dev20241125 ./dist/jax-cuda-plugin
# jaxlib 0.4.36.dev20241125 ./dist/jaxlib
```
they were being stored as:
```
# jaxlib 0.4.36.dev20241125 ./dist/jaxlib
# jax-cuda12-pjrt 0.4.36.dev20241125 ./dist/jaxlib/jax-cuda-pjrt
# jax-cuda12-plugin 0.4.36.dev20241125 ./dist/jaxlib/jax-cuda-plugin
```
PiperOrigin-RevId: 708468522
This commit modifies the behavior of the build CLI when building jaxlib and GPU plugin artifacts together (for instance `python build --wheels=jaxlib,jax-cuda-plugin`.
Before, CUDA/ROCm build options were only passed when building the CUDA/ROCm artifacts. However, this leads to inefficient use of the build cache as it looks like Bazel tries to rebuild some targets that has already been built in the previous run. This seems to be because the GPU plugin artifacts have a different set of build options compared to `jaxlib` which for some reason causes Bazel to invalidate/ignore certain cache hits. Therefore, this commit makes it so that the build options remain the same when the `jaxlib` and GPU artifacts are being built together so that we can better utilize the build cache.
As an example, this means that if `python build --wheels=jaxlib,jax-cuda-plugin` is run, the following build options will apply to both `jaxlib` and `jax-cuda-plugin` builds:
```
/usr/local/bin/bazel run --repo_env=HERMETIC_PYTHON_VERSION=3.10 \
--verbose_failures=true --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--repo_env=CC="/usr/lib/llvm-16/bin/clang" \
--repo_env=BAZEL_COMPILER="/usr/lib/llvm-16/bin/clang" \
--config=clang --config=mkl_open_source_only --config=avx_posix \
--config=cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--config=build_cuda_with_nvcc
```
Note, this commit shouldn't affect the content of the wheel it self. It is only meant to give a performance boost when building `jalxib`+plugin aritfacts together.
Also, this removes code that was used to build (now deprecated) monolithic `jaxlib` build from `build_wheel.py`
PiperOrigin-RevId: 708035062
The `build.py` script uses Clang compiler by default, and JAX doesn't support building with GCC officially. However, experimental GCC support is still present.
Command examples:
```
python build/build.py build --wheels=jaxlib,jax-cuda-plugin --use_clang=false
python build/build.py build --wheels=jaxlib,jax-cuda-plugin --use_clang=false --gcc_path=/use/bin/gcc
```
This change addresses the request in https://github.com/jax-ml/jax/issues/25488.
PiperOrigin-RevId: 707930913
These flags disable Clang extensions that do things such as reject type definitions within offsetof or reject unknown arguments which does not seem to be needed on versions older than Clang 16
Also, fix a syntax error
Fixes https://github.com/jax-ml/jax/issues/25530
PiperOrigin-RevId: 707555651