527 Commits

Author SHA1 Message Date
vfdev-5
5a340a9781 Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true
Description:
- Disable second order vjp tests in RunStateHypothesisTest.test_vjp if JAX_SKIP_SLOW_TESTS=true to reduce the test execution time
  - especially for TSAN CI job where this test takes ~700 seconds to pass with the recent 3.13 cpython
- Removed optional deps for 3.14
2025-04-08 21:02:55 +00:00
jax authors
0039d13295 Merge pull request #27698 from hawkinsp:win
PiperOrigin-RevId: 743611388
2025-04-03 10:21:18 -07:00
jax authors
91b0884ad1 Restrict the regex for copying the wheels.
The change is made to address the case when bazel dir has multiple wheels with different version suffixes. We need to copy only those wheels that were created by the current execution of build.py script.

PiperOrigin-RevId: 743566122
2025-04-03 08:06:37 -07:00
Peter Hawkins
8d59902e73 Fix problem finding clang++ when building JAX via build.py on windows.
It's important we use the un-stemmed name because on Windows there is an .exe suffix.
2025-04-03 15:00:29 +01:00
vfdev-5
8e2c1a18c7 Updates for 3.14
Added tsan ci cpython 3.14 job
2025-04-02 08:32:25 +00:00
jax authors
d974b09056 Fix error in build.py when trying to build aarch64 jaxlib wheel.
PiperOrigin-RevId: 741534342
2025-03-28 08:30:08 -07:00
jax authors
358c55d066 Update instructions for usage of :build_jaxlib=false flag.
By adding [jax wheel testing](https://github.com/jax-ml/jax/pull/27113) functionality, we need to have pre-built jax and jaxlib wheels.

PiperOrigin-RevId: 741249718
2025-03-27 12:54:47 -07:00
jax authors
e342f2dd60 Update the minimum supported CuDNN version to 9.8 (previously 9.1).
Announce maximum supported CUDA version 12.8 (previously 12.3).

PiperOrigin-RevId: 741188737
2025-03-27 09:54:00 -07:00
jax authors
1b7c8e8d08 Add editable jax wheel target.
The set of editable wheels (`jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt`) was used as dependencies in `requirements.in` file together with `:build_jaxlib=false` flag.

After [adding `jax` wheel dependencies](f5a4d1a85c) to the tests when `:build_jaxlib=false` is used, we need an editable `jax` wheel target as well to get the tests passing.

PiperOrigin-RevId: 740840736
2025-03-26 11:25:52 -07:00
jax authors
55318d5824 build/build.py changes: copy the wheels created by the new build wheel targets into the path specified by --output_path.
PiperOrigin-RevId: 740829299
2025-03-26 10:56:19 -07:00
Charles Hofer
4f9571eb2b Fix auditwheel 2025-03-25 14:49:21 +00:00
jax authors
014cf3084a Merge pull request #26775 from ROCm:rocm-fix-numalib
PiperOrigin-RevId: 739968243
2025-03-24 09:32:39 -07:00
jax authors
16dc0ad1dd Add jax_source_package macros and target to generate a source package .tar.gz.
Refactor `jax_wheel` macros, so it outputs a `.whl` file only.

When the macros returns one output object only, it allows all downstream dependencies consume it easily without the need to filter the macros outputs.

The previous implementation design (when `jax_wheel` returned `.tar.gz` and `.whl` files) required one of two options: either create a new target that produces `.whl` only, or to implement filename filtering in the downstream rules. With the new implementation we can just depend on `//:jax_wheel` target that produces the `.whl`.

PiperOrigin-RevId: 738547491
2025-03-19 14:48:36 -07:00
Nitin Srinivasan
031614c22b Pin numpy~=2.1.0 in workflow file instead of test-requirements.txt
PiperOrigin-RevId: 737632771
2025-03-17 08:59:06 -07:00
Nitin Srinivasan
5944c9ed65 Install test dependencies from test-requirements.txt instead of requirements.in
PiperOrigin-RevId: 736878834
2025-03-14 08:57:20 -07:00
jax authors
13eb8d3ae7 Upgrade ml-dtypes version in py3.10-py3.13 hermetic python lock files.
This change is needed to add testing of int2/uint2 dtypes via bazel in presubmit (see https://github.com/jax-ml/jax/pull/21395).

PiperOrigin-RevId: 735895293
2025-03-11 14:41:34 -07:00
jax authors
0db14aa342 Add NVIDIA wheel requirements only for Linux builds.
PiperOrigin-RevId: 735850240
2025-03-11 12:33:54 -07:00
jax authors
1aca76fc13 Update :build_jaxlib flag to control whether we should add py_import dependencies to the test targets.
This change enables testing the wheels produced by the build rules in the presubmit using one `bazel test` command only.

There are three options for running the tests:

1) `build_jaxlib=true`: the tests depend on JAX targets.
2) `build_jaxlib=false`: the tests depend on the wheel files located in the `dist` folder.
3) `build_jaxlib=wheel`: the tests depend on the py_import targets.

PiperOrigin-RevId: 735765819
2025-03-11 08:31:43 -07:00
Charles Hofer
132f88e8d5 Fix ROCm builds not finding numa library 2025-03-11 14:57:48 +00:00
jax authors
007fc7a6f1 Remove version limit for setuptools dependency.
PiperOrigin-RevId: 735453796
2025-03-10 11:36:17 -07:00
Nitin Srinivasan
721d1a3211 Add functionality to allow promoting RC wheels during release
List of changes:
1. Allow us to build a RC wheel when building release artifacts. This is done by modifying the build CLI to use the new JAX build rule and passing in the build options that control the wheel tag. A new build argument `use_new_wheel_build_rule` is introduced to the build CLI to avoid breaking anyone that uses the CLI and the old build rule. Note that this option will go way in the future when the build CLI migrates fully to the new build rule.
2. Change the upload script to upload both rc and release tagged wheels (changes internal)

PiperOrigin-RevId: 733464219
2025-03-04 14:21:12 -08:00
jax authors
615219b1f6 Remove tensorstore dependency from //jax/experimental/array_serialization:serialization in OSS (see https://github.com/google/tensorstore/issues/218)
Disable serialization_test in OSS.

PiperOrigin-RevId: 731463136
2025-02-26 14:47:16 -08:00
jax authors
c9c7250dd4 Upgrade to Bazel 7.4.1
PiperOrigin-RevId: 731278247
2025-02-26 05:33:24 -08:00
jax authors
eb912ad0d9 Create jax wheel build target.
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126)

Previously `jax` wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).

You can still build the `jax` wheel with `python3 -m build` command.

Bazel `jax` wheel target: `//:jax_wheel`

Environment variables combinations for creating wheels with different versions:
  * self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * release: `--repo_env=ML_WHEEL_TYPE=release`
  * release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
  * nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`

PiperOrigin-RevId: 730916743
2025-02-25 09:30:08 -08:00
H. Vetinari
dd4aa79d6a fix getting gcc major version 2025-02-24 08:03:57 +11:00
jax authors
e64650e2ba Add --config=cuda_libraries_from_stubs in the end of all additional bazel options for CUDA wheels.
Build.py shouldn't be used for building the wheels with real CUDA libraries in the dependencies. This change prevents overriding the default configuration.

PiperOrigin-RevId: 725326252
2025-02-10 13:24:47 -08:00
jax authors
7ffb613b8f Merge pull request #26409 from hawkinsp:fstring
PiperOrigin-RevId: 724390055
2025-02-07 10:23:50 -08:00
Peter Hawkins
d01520c63f Fix a missing "f" on an f-string.
While I'm here, reword the text a bit.
2025-02-07 12:53:24 -05:00
jax authors
1ca8807dca Merge pull request #25810 from ROCm:gh-9948-add-gpu-ci-upstream
PiperOrigin-RevId: 724378710
2025-02-07 09:51:59 -08:00
charleshofer
ebf4a54f4f Add AMD ROCm GPU CI post-build check (#137) 2025-02-06 21:41:53 +00:00
Kanglan Tang
59a3552ae6 Remove portpicker for free threaded python 3.13t in test-requirements.txt
PiperOrigin-RevId: 722776783
2025-02-03 13:30:01 -08:00
jax authors
727d0367a4 Update --config=cuda to add direct dependencies on CUDA libraries both for bazel build and bazel test phases.
With this configuration the same cache is used both for `bazel build` and `bazel test` commands (provided the same target is specified).

Add `--config=no_cuda_libs` for building targets with CUDA libraries from stubs.

PiperOrigin-RevId: 720334587
2025-01-27 15:46:17 -08:00
jax authors
9a60e6fce4 Merge pull request #25917 from ROCm:ci_fix_multi_gpu_test_logic-upstream
PiperOrigin-RevId: 716153760
2025-01-16 02:45:54 -08:00
Ruturaj4
8e88adcd3f Fix run_multi_gpu script multi-gpu issue and refactor code 2025-01-15 22:33:03 +00:00
Ruturaj4
435edf1f8c Add gfx12xx archs 2025-01-15 16:14:40 +00:00
vfdev-5
00806ddaf5 Added 3.13 ft requirements lock file and updated WORKSPACE 2025-01-08 22:47:29 +01:00
jax authors
6e1f060ad3 Merge pull request #25527 from vfdev-5:single-python-version-build-py
PiperOrigin-RevId: 713365267
2025-01-08 11:49:59 -08:00
Vladimir Belitskiy
f2e210b315 Disable avxvnniint8 when building with Clang version < 19, or GCC < 13.
PiperOrigin-RevId: 712516025
2025-01-06 07:06:09 -08:00
Ruturaj4
20b75ab82f Update package indentation fix 2025-01-01 18:50:47 -06:00
vfdev-5
70e06c2dbe Avoid adding conflicting --repo_env=HERMETIC_PYTHON_VERSION= to bazel command 2024-12-21 03:33:33 +01:00
Nitin Srinivasan
0159bead97 Move output path to be inside the wheel build command execution loop
This was causing an issue when building multiple wheels in editable mode.

i.e instead of wheels being stored as:
```
# jax-cuda12-pjrt   0.4.36.dev20241125           ./dist/jax-cuda-pjrt
# jax-cuda12-plugin 0.4.36.dev20241125           ./dist/jax-cuda-plugin
# jaxlib            0.4.36.dev20241125           ./dist/jaxlib
```

they were being stored as:
```
# jaxlib            0.4.36.dev20241125           ./dist/jaxlib
# jax-cuda12-pjrt   0.4.36.dev20241125           ./dist/jaxlib/jax-cuda-pjrt
# jax-cuda12-plugin 0.4.36.dev20241125           ./dist/jaxlib/jax-cuda-plugin
```

PiperOrigin-RevId: 708468522
2024-12-20 17:27:51 -08:00
Nitin Srinivasan
6b096b0cb0 Use common set of build options when building jaxlib+plugin artifacts together
This commit modifies the behavior of the build CLI when building jaxlib and GPU plugin artifacts together (for instance `python build --wheels=jaxlib,jax-cuda-plugin`.

Before, CUDA/ROCm build options were only passed when building the CUDA/ROCm artifacts. However, this leads to inefficient use of the build cache as it looks like Bazel tries to rebuild some targets that has already been built in the previous run. This seems to be because the GPU plugin artifacts have a different set of build options compared to `jaxlib` which for some reason causes Bazel to invalidate/ignore certain cache hits. Therefore, this commit makes it so that the build options remain the same when the `jaxlib` and GPU artifacts are being built together so that we can better utilize the build cache.

As an example, this means that if `python build --wheels=jaxlib,jax-cuda-plugin` is run, the following build options will apply to both `jaxlib` and `jax-cuda-plugin` builds:
```
 /usr/local/bin/bazel run --repo_env=HERMETIC_PYTHON_VERSION=3.10 \
--verbose_failures=true --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--repo_env=CC="/usr/lib/llvm-16/bin/clang" \
--repo_env=BAZEL_COMPILER="/usr/lib/llvm-16/bin/clang" \
--config=clang --config=mkl_open_source_only --config=avx_posix \
--config=cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--config=build_cuda_with_nvcc
```

Note, this commit shouldn't affect the content of the wheel it self. It is only meant to give a performance boost when building `jalxib`+plugin aritfacts together.

Also, this removes code that was used to build (now deprecated) monolithic `jaxlib` build from `build_wheel.py`

PiperOrigin-RevId: 708035062
2024-12-19 14:29:24 -08:00
jax authors
bad6f0f0e9 Merge pull request #25440 from ROCm:jax_pypi-upstream
PiperOrigin-RevId: 707934347
2024-12-19 09:17:46 -08:00
jax authors
25713f95f2 Merge pull request #25148 from ROCm:ci_install_sys_libs-upstream
PiperOrigin-RevId: 707934273
2024-12-19 09:15:55 -08:00
jax authors
c78ca042e7 Add experimental support for building JAX CPU and GPU wheels with GCC.
The `build.py` script uses Clang compiler by default, and JAX doesn't support building with GCC officially. However, experimental GCC support is still present.

Command examples:

```
python build/build.py build --wheels=jaxlib,jax-cuda-plugin --use_clang=false
python build/build.py build --wheels=jaxlib,jax-cuda-plugin --use_clang=false --gcc_path=/use/bin/gcc
```

This change addresses the request in https://github.com/jax-ml/jax/issues/25488.

PiperOrigin-RevId: 707930913
2024-12-19 09:03:25 -08:00
Ruturaj4
fefebebea4 Update documentation and add setup.py pypi bindings 2024-12-19 10:21:28 -06:00
Ruturaj4
bfcace4933 [ROCm] ci build and dockerfile changes 2024-12-19 09:02:55 -06:00
Nitin Srinivasan
bcca77cd8e Enable --config=clang only on newer Clang versions
These flags disable Clang extensions that do things such as reject type definitions within offsetof or reject unknown arguments which does not seem to be needed on versions older than Clang 16

Also, fix a syntax error

Fixes https://github.com/jax-ml/jax/issues/25530

PiperOrigin-RevId: 707555651
2024-12-18 08:12:04 -08:00
Adam J. Stewart
1afed917fb
get_githash: fix support for missing git 2024-12-13 10:02:05 -05:00
Sunita Nadampalli
e370deee0f add mkldnn+acl build config for aarch64 platform 2024-12-09 16:03:14 +00:00