25 Commits

Author SHA1 Message Date
Peter Hawkins
3f91b4b43a Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda/rocm}/
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/

Cleanup only, no functional changes intended.

PiperOrigin-RevId: 738183402
2025-03-18 16:29:37 -07:00
jax authors
1aca76fc13 Update :build_jaxlib flag to control whether we should add py_import dependencies to the test targets.
This change enables testing the wheels produced by the build rules in the presubmit using one `bazel test` command only.

There are three options for running the tests:

1) `build_jaxlib=true`: the tests depend on JAX targets.
2) `build_jaxlib=false`: the tests depend on the wheel files located in the `dist` folder.
3) `build_jaxlib=wheel`: the tests depend on the py_import targets.

PiperOrigin-RevId: 735765819
2025-03-11 08:31:43 -07:00
jax authors
3edc068f8c Fix ambiguous cpu definition for JAX wheels.
Should fix the error in https://github.com/jax-ml/jax/actions/runs/13682579939/job/38258344926.

PiperOrigin-RevId: 733838895
2025-03-05 12:59:21 -08:00
David Dunleavy
1a19d5594a Update all uses of @tsl//third_party to @xla//third_party
PiperOrigin-RevId: 733495240
2025-03-04 15:55:23 -08:00
jax authors
8f57b8167b Add build targets for jax-rocm-plugin and jax-rocm-pjrt wheels.
PiperOrigin-RevId: 732149495
2025-02-28 08:36:46 -08:00
jax authors
401d315091 Add targets for jaxlib, jax-cuda-plugin and jax-cuda-pjrt editable wheels.
PiperOrigin-RevId: 731737119
2025-02-27 08:33:40 -08:00
jax authors
d424f5b5b3 Refactor JAX wheel build rules to control the wheel filename and maintain reproducible wheel content and filename results.
This change is a part of the initiative to test the JAX wheels in the presubmit properly.

The list of the changes:
1. JAX wheel build rule verifies that `--@local_config_cuda//cuda:include_cuda_libs=false` during the wheel build. There is a way to pass the restriction by providing `--@local_config_cuda//cuda:override_include_cuda_libs=true`.

2. The JAX version number (which is also used in the wheel filenames) is stored in `_version` variable in the file [version.py](https://github.com/jax-ml/jax/blob/main/jax/version.py). The custom repository rule `jax_python_wheel_version_repository` saves this value in `wheel_version.bzl`, so it becomes available in Bazel build phase.

3. The version suffix of the wheel in the build rule output depends on the environment variables.

   The version suffix chunks that are not reproducible shouldn’t be calculated as a part of the wheel binary: for example, the current date changes every day, thus the wheels built today and tomorrow on the same code version will be technically different. To maintain reproducible wheel content, we need to pass suffix chunks in a form of environment variables.

4. Environment variables combinations for creating wheels with different versions:
  * `0.5.1.dev0+selfbuilt` (local build, default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * `0.5.1` (release): `--repo_env=ML_WHEEL_TYPE=release`
  * `0.5.1rc1` (release candidate): `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=rc1`
  * `0.5.1.dev20250128+3e75e20c7` (nightly build): `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=20250128 --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`

PiperOrigin-RevId: 723552265
2025-02-05 10:01:23 -08:00
Nitin Srinivasan
6b096b0cb0 Use common set of build options when building jaxlib+plugin artifacts together
This commit modifies the behavior of the build CLI when building jaxlib and GPU plugin artifacts together (for instance `python build --wheels=jaxlib,jax-cuda-plugin`.

Before, CUDA/ROCm build options were only passed when building the CUDA/ROCm artifacts. However, this leads to inefficient use of the build cache as it looks like Bazel tries to rebuild some targets that has already been built in the previous run. This seems to be because the GPU plugin artifacts have a different set of build options compared to `jaxlib` which for some reason causes Bazel to invalidate/ignore certain cache hits. Therefore, this commit makes it so that the build options remain the same when the `jaxlib` and GPU artifacts are being built together so that we can better utilize the build cache.

As an example, this means that if `python build --wheels=jaxlib,jax-cuda-plugin` is run, the following build options will apply to both `jaxlib` and `jax-cuda-plugin` builds:
```
 /usr/local/bin/bazel run --repo_env=HERMETIC_PYTHON_VERSION=3.10 \
--verbose_failures=true --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--repo_env=CC="/usr/lib/llvm-16/bin/clang" \
--repo_env=BAZEL_COMPILER="/usr/lib/llvm-16/bin/clang" \
--config=clang --config=mkl_open_source_only --config=avx_posix \
--config=cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--config=build_cuda_with_nvcc
```

Note, this commit shouldn't affect the content of the wheel it self. It is only meant to give a performance boost when building `jalxib`+plugin aritfacts together.

Also, this removes code that was used to build (now deprecated) monolithic `jaxlib` build from `build_wheel.py`

PiperOrigin-RevId: 708035062
2024-12-19 14:29:24 -08:00
Kanglan Tang
af28595909 Add a jax_wheel Bazel rule to build jax pip packages
PiperOrigin-RevId: 689514531
2024-10-24 14:20:46 -07:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
Adam Paszke
611ad63060 Add basic PyTorch integration for Mosaic GPU
We have already had most of the relevant pieces and we only needed
to connect them together. The most sensitive change is perhaps that
I needed to expose one more symbol from the XLA GPU plugin, but I don't
think it should be a problem.
2024-09-18 12:55:23 +00:00
Ruturaj4
332435e028 [ROCM] make mosaic dependency cuda specific 2024-07-02 11:05:42 -05:00
Ruturaj4
99c2b7b4e9 [ROCm] Bring-up pjrt support 2024-06-17 16:49:22 +00:00
Adam Paszke
cfe64cd5ce [Mosaic GPU] Integrate the ExecutionEngine with the jaxlib GPU plugin
This lets us avoid bundling a whole another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.

PiperOrigin-RevId: 638569750
2024-05-30 01:46:23 -07:00
Dan Foreman-Mackey
88790711e8 Package XLA FFI headers with jaxlib wheel
The new "typed" API that XLA provides for foreign function calls is
header-only and packaging it as part of jaxlib could simplify the open
source workflow for building custom calls.

It's not completely obvious that we need to include this, because jaxlib
isn't strictly required as a _build_ dependency for FFI calls, although
it typically will be required as a _run time_ dependency. Also, it
probably wouldn't be too painful for external projects to use the
headers directly from the openxla/xla repo.

All that being said, I wanted to figure out how to do this, and it has
been requested a few times.
2024-05-22 12:28:38 -04:00
Vadym Matsishevskyi
517e299a9d Use hermetic Python in JAX, see "Managing hermetic Python" in developer.md for details
PiperOrigin-RevId: 634146391
2024-05-15 18:20:56 -07:00
Peter Hawkins
dedd69f323 Add a bazel test that verifies that the jaxlib wheel builds. 2024-01-11 23:22:17 +00:00
Jake VanderPlas
326d1d27ef jaxlib: avoid external build-time dependency on ml_dtypes
Currently, the ml_dtypes C++ sources are included in the set of sources at jaxlib build time. This is unnecessary, and can lead to problematic version skew in some cases (e.g. nightly builds).

PiperOrigin-RevId: 595725529
2024-01-04 09:26:05 -08:00
Jieying Luo
d6c5910105 [PJRT C API] Move cuda_plugin_extension from jaxlib to jax-cuda-plugin (the package for cuda kernels).
PiperOrigin-RevId: 583406466
2023-11-17 09:11:46 -08:00
Jieying Luo
ec21e04201 [PJRT C API] Rename the folder "plugins" to "jax_plugins".
With this change, existing plugin discovery mechanism can discover local plugins without pip install.

Update jax_plugins/cuda/__init__.py to return without registering the plugin if the .so file does not exist.

PiperOrigin-RevId: 582431300
2023-11-14 13:56:13 -08:00
Jieying Luo
462ef165c4 [PJRT C API] Change build wheel script to build a separate package for cuda kernels.
With this change, `python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` will generate three wheels:

|                      |size|wheel name                                                               |
|----------------------|----|-------------------------------------------------------------------------|
|jaxlib w/o cuda kernels|76M |jaxlib-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl           |
|cuda pjrt              |73M|jax_cuda12_pjrt-0.4.20.dev20231101-py3-none-manylinux2014_x86_64.whl                    |
|cuda kernels           |6.6M|jax_cuda12_plugin-0.4.20.dev20231101-cp310-cp310-manylinux2014_x86_64.whl|

The size of jaxlib with cuda kernels and pjrt is 119M.

The cuda kernel wheel contains all the cuda kernels. A plugin_setup.py and plugin_pyproject.toml are added for this new pacakge.

PiperOrigin-RevId: 579861480
2023-11-06 09:13:44 -08:00
Jieying Luo
0e24b90043 [PJRT C API] Register custom callback for xla_python_gpu_callback in plugin module.
PiperOrigin-RevId: 568671822
2023-09-26 15:54:10 -07:00
Jieying Luo
91fbf9da26 [PJRT C API] Set up jax xla cuda package.
Add a build wheel, pyproject.toml and setup.py.

The directory structure in jax repo is:
jax/
└── plugins/
     └── cuda/
          ├── __init__.py
          ├── pyproject.toml
          └── setup.py

Installed package structure is:
jax_plugins/
     └── xla_cuda_cu12/
           ├── __init__.py
           └── xla_cuda_plugin.so

The major cuda version will be part of the package name.

The plugin wheel can be built with command:
python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 --bazel_options="--override_repository=xla=$HOME/xla"

PiperOrigin-RevId: 565187954
2023-09-13 16:03:53 -07:00
Peter Hawkins
88408e13ee Remove stale references to //jaxlib:setup.cfg in Bazel build.
Fixes broken jaxlib wheel build.
2023-09-03 19:18:25 +00:00
Peter Hawkins
f540ae4338 Fix warning about direct invocation of setup.py during jaxlib build.
The jaxlib wheel build currently uses `python setup.py bdist_wheel` to construct the wheel. Change it to use `python -m build -w` instead.

To avoid Python getting confused between the directory named `build` in the bazel tree and the Python `build` module, move `build_wheel.py` into `jaxlib/tools`.

PiperOrigin-RevId: 548133811
2023-07-14 08:31:16 -07:00