78 Commits

Author SHA1 Message Date
jax authors
1aca76fc13 Update :build_jaxlib flag to control whether we should add py_import dependencies to the test targets.
This change enables testing the wheels produced by the build rules in the presubmit with a single `bazel test` command.

There are three options for running the tests:

1) `build_jaxlib=true`: the tests depend on JAX targets.
2) `build_jaxlib=false`: the tests depend on the wheel files located in the `dist` folder.
3) `build_jaxlib=wheel`: the tests depend on the py_import targets.
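
As an illustration, option 3 might be invoked roughly like this (the test target and flag path are illustrative assumptions, not taken from this change):

```
bazel test //tests:cpu_tests --//jax:build_jaxlib=wheel
```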

PiperOrigin-RevId: 735765819
2025-03-11 08:31:43 -07:00
jax authors
c16f37d89d Set USERPROFILE for Windows builds to fix CI issue.
This change fixes https://github.com/jax-ml/jax/actions/runs/13686468791/job/38270929632.

From the [documentation](https://docs.python.org/3/library/os.path.html#os.path.expanduser):
`On Windows, USERPROFILE will be used if set, otherwise a combination of HOMEPATH and HOMEDRIVE will be used.`
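
A minimal sketch of the behavior this change relies on (the home path is an illustrative assumption):

```py
import os

# Per the documentation quoted above, expanduser prefers USERPROFILE on Windows
# and otherwise falls back to HOMEDRIVE + HOMEPATH.
os.environ.setdefault("USERPROFILE", r"C:\Users\runner")  # what the CI fix provides
print(os.path.expanduser("~"))  # on Windows, resolves via USERPROFILE
```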

PiperOrigin-RevId: 733935305
2025-03-05 18:09:14 -08:00
jax authors
3edc068f8c Fix ambiguous cpu definition for JAX wheels.
Should fix the error in https://github.com/jax-ml/jax/actions/runs/13682579939/job/38258344926.

PiperOrigin-RevId: 733838895
2025-03-05 12:59:21 -08:00
David Dunleavy
1a19d5594a Update all uses of @tsl//third_party to @xla//third_party
PiperOrigin-RevId: 733495240
2025-03-04 15:55:23 -08:00
jax authors
8f57b8167b Add build targets for jax-rocm-plugin and jax-rocm-pjrt wheels.
PiperOrigin-RevId: 732149495
2025-02-28 08:36:46 -08:00
Dan Foreman-Mackey
c7ed1bd3a8 Add version check to jaxlib plugin imports.
For the CUDA and ROCM plugins, we only support exact matches between the plugin and jaxlib version, and bad things can happen if we try to load mismatched versions. This change issues a warning and skips importing a plugin when there is a version mismatch.

There are a handful of other places where plugins are imported throughout the JAX codebase (e.g. in lax_numpy, mosaic_gpu, and in the plugins themselves). In a follow up it would be good to add version checking there too, but let's start with just these ones.
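
A minimal sketch of the check described above (function and argument names are illustrative, not the actual jaxlib plugin-discovery API):

```py
import warnings

def plugin_version_ok(plugin_name: str, plugin_version: str, jaxlib_version: str) -> bool:
    """Accept a plugin only when it exactly matches the jaxlib version."""
    if plugin_version != jaxlib_version:
        warnings.warn(
            f"Skipping {plugin_name} {plugin_version}: it does not exactly "
            f"match jaxlib {jaxlib_version}."
        )
        return False
    return True
```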

PiperOrigin-RevId: 731808733
2025-02-27 11:52:17 -08:00
jax authors
401d315091 Add targets for jaxlib, jax-cuda-plugin and jax-cuda-pjrt editable wheels.
PiperOrigin-RevId: 731737119
2025-02-27 08:33:40 -08:00
jax authors
4eb782e402 Update jax_wheel target to produce both wheel and source distribution files.
This change replicates the old method of building the `jax` wheel via `python -m build`, which produced `.tar.gz` and `.whl` files.

PiperOrigin-RevId: 731721522
2025-02-27 07:41:13 -08:00
jax authors
d424f5b5b3 Refactor JAX wheel build rules to control the wheel filename and maintain reproducible wheel content and filename results.
This change is a part of the initiative to test the JAX wheels in the presubmit properly.

The list of the changes:
1. The JAX wheel build rule verifies that `--@local_config_cuda//cuda:include_cuda_libs=false` is set during the wheel build. The restriction can be bypassed by providing `--@local_config_cuda//cuda:override_include_cuda_libs=true`.

2. The JAX version number (which is also used in the wheel filenames) is stored in the `_version` variable in the file [version.py](https://github.com/jax-ml/jax/blob/main/jax/version.py). The custom repository rule `jax_python_wheel_version_repository` saves this value in `wheel_version.bzl`, so it becomes available in the Bazel build phase.

3. The version suffix of the wheel in the build rule output depends on environment variables.

   Version suffix chunks that are not reproducible shouldn't be computed as part of the wheel build: for example, the current date changes every day, so wheels built today and tomorrow from the same code version would technically differ. To keep wheel content reproducible, we pass such suffix chunks in the form of environment variables.

4. Environment variable combinations for creating wheels with different versions:
  * `0.5.1.dev0+selfbuilt` (local build, default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * `0.5.1` (release): `--repo_env=ML_WHEEL_TYPE=release`
  * `0.5.1rc1` (release candidate): `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=rc1`
  * `0.5.1.dev20250128+3e75e20c7` (nightly build): `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=20250128 --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`

PiperOrigin-RevId: 723552265
2025-02-05 10:01:23 -08:00
Ruturaj4
fe68eb8b25 [ROCm] Implement RNN support 2025-01-14 19:04:49 -06:00
Sergei Lebedev
8987867faa [mosaic_gpu] Include Mosaic GPU dialect files into jaxlib 2024-12-23 13:46:25 +00:00
Nitin Srinivasan
6b096b0cb0 Use common set of build options when building jaxlib+plugin artifacts together
This commit modifies the behavior of the build CLI when building jaxlib and GPU plugin artifacts together (for instance `python build --wheels=jaxlib,jax-cuda-plugin`).

Before, CUDA/ROCm build options were only passed when building the CUDA/ROCm artifacts. However, this leads to inefficient use of the build cache, as Bazel tries to rebuild some targets that have already been built in the previous run. This seems to be because the GPU plugin artifacts have a different set of build options compared to `jaxlib`, which for some reason causes Bazel to invalidate/ignore certain cache hits. Therefore, this commit keeps the build options the same when `jaxlib` and the GPU artifacts are built together, so that we can better utilize the build cache.

As an example, this means that if `python build --wheels=jaxlib,jax-cuda-plugin` is run, the following build options will apply to both `jaxlib` and `jax-cuda-plugin` builds:
```
 /usr/local/bin/bazel run --repo_env=HERMETIC_PYTHON_VERSION=3.10 \
--verbose_failures=true --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--repo_env=CC="/usr/lib/llvm-16/bin/clang" \
--repo_env=BAZEL_COMPILER="/usr/lib/llvm-16/bin/clang" \
--config=clang --config=mkl_open_source_only --config=avx_posix \
--config=cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--config=build_cuda_with_nvcc
```

Note, this commit shouldn't affect the content of the wheel itself. It is only meant to give a performance boost when building `jaxlib`+plugin artifacts together.

Also, this removes code from `build_wheel.py` that was used to build the now-deprecated monolithic `jaxlib`.

PiperOrigin-RevId: 708035062
2024-12-19 14:29:24 -08:00
Charles Hofer
0c6b967e86 Don't look for CUDA files when building the ROCm wheel 2024-12-06 17:23:15 +00:00
Dan Foreman-Mackey
ccb331707e Add a GPU implementation of lax.linalg.eig.
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.
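
A hedged usage sketch based on the description above (the library path is an illustrative assumption):

```py
import os

# Optional: point JAX at a non-standard MAGMA installation before importing jax.
os.environ["JAX_GPU_MAGMA_PATH"] = "/opt/magma/lib/libmagma.so"

import jax

# Opt in to the MAGMA-backed eig implementation; the default calls LAPACK on the host.
jax.config.update("jax_gpu_use_magma", "on")
```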

PiperOrigin-RevId: 697631402
2024-11-18 08:11:57 -08:00
Kanglan Tang
af28595909 Add a jax_wheel Bazel rule to build jax pip packages
PiperOrigin-RevId: 689514531
2024-10-24 14:20:46 -07:00
Ionel Gog
ec279f9c54 Add config option to log or fatal when jax.Arrays are GCed.
Introduces `jax.config.array_garbage_collection_guard`, which is a tristate config for setting up a `jax.Array` garbage collection guard. The possible configs are:
* allow: `jax.Array`s are allowed to be garbage collected. This is the default value.
* log: whenever a `jax.Array` is GCed, a log entry is generated with the array's traceback.
* fatal: fatal crash when a `jax.Array` is GCed. This is meant to be used for mature code bases that do tight memory management and are reference-cycle free.
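
A hedged sketch of opting in to the logging mode (the flag spelling follows the description above and may differ in practice):

```py
import jax

# Log a traceback whenever a jax.Array is garbage collected.
jax.config.update("array_garbage_collection_guard", "log")
```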

PiperOrigin-RevId: 687003464
2024-10-17 12:23:16 -07:00
Peter Hawkins
145304a0e0 Remove reference to outfeed_receiver.pyi, which was deleted.
PiperOrigin-RevId: 683195999
2024-10-07 08:37:14 -07:00
Peter Hawkins
5a1d0a6c26 Include the sdy MLIR dialect in jaxlib.
We're seeing test failures from tests assuming that this dialect exists. But given we plan to enable it at some point, we may as well just include it in the build.

The size impact is small (around 400K uncompressed).

PiperOrigin-RevId: 679608092
2024-09-27 08:53:31 -07:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Peter Hawkins
922e652c05 Replace plat-name with plat_name.
The former seems to elicit a deprecation warning from setuptools
recently.
2024-09-18 15:17:49 +00:00
Adam Paszke
611ad63060 Add basic PyTorch integration for Mosaic GPU
We already had most of the relevant pieces and only needed
to connect them together. The most sensitive change is perhaps that
I needed to expose one more symbol from the XLA GPU plugin, but I don't
think it should be a problem.
2024-09-18 12:55:23 +00:00
Bart Chrzaszcz
864178d3a3 #sdy Initial set of changes to allow for lowering to the Shardy dialect.
The OpenXLA project is working on an open-source, MLIR-based, named-axis propagation (and, in the future, SPMD partitioning) system that will be dialect agnostic (it would work for any dialect - MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.

Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing for propagation and partitioning to be completely in MLIR and the first pass in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations

The following test:

```py
def test_sdy_lowering(self):
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  np_inp = np.arange(16).reshape(8, 2)
  s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, s)

  @partial(jax.jit, out_shardings=s)
  def f(x):
    return x * 2

  print(f.lower(arr).as_text())
```

outputs:

```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <"x"=4, "y"=2>
  func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
    %c = stablehlo.constant dense<2> : tensor<i64>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
    %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
    return %1 : tensor<8x2xi64>
  }
}
```

Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.

PiperOrigin-RevId: 655127611
2024-07-23 05:32:06 -07:00
Ruturaj4
332435e028 [ROCM] make mosaic dependency cuda specific 2024-07-02 11:05:42 -05:00
jax authors
00528b9858 libdevice.10.bc is removed from JAX wheels bundle.
The recommended source of JAX wheels is `pip`, and NVIDIA dependencies are installed automatically when JAX is installed via `pip install`. `libdevice` gets installed from the `nvidia-cuda-nvcc-cu12` package.

PiperOrigin-RevId: 647328834
2024-06-27 08:35:59 -07:00
Ruturaj4
99c2b7b4e9 [ROCm] Bring-up pjrt support 2024-06-17 16:49:22 +00:00
Adam Paszke
cfe64cd5ce [Mosaic GPU] Integrate the ExecutionEngine with the jaxlib GPU plugin
This lets us avoid bundling yet another copy of LLVM with JAX packages,
so we can finally start building Mosaic GPU by default.

PiperOrigin-RevId: 638569750
2024-05-30 01:46:23 -07:00
George Necula
3bcb8d6831 Remove DUCC FFT from jaxlib
JAX has stopped generating code that directly uses
the DUCC FFT custom calls.
The 6-month backwards-compatibility window has also expired.

PiperOrigin-RevId: 638132572
2024-05-28 21:12:23 -07:00
Dan Foreman-Mackey
88790711e8 Package XLA FFI headers with jaxlib wheel
The new "typed" API that XLA provides for foreign function calls is
header-only and packaging it as part of jaxlib could simplify the open
source workflow for building custom calls.

It's not completely obvious that we need to include this, because jaxlib
isn't strictly required as a _build_ dependency for FFI calls, although
it typically will be required as a _run time_ dependency. Also, it
probably wouldn't be too painful for external projects to use the
headers directly from the openxla/xla repo.

All that being said, I wanted to figure out how to do this, and it has
been requested a few times.
2024-05-22 12:28:38 -04:00
Vadym Matsishevskyi
517e299a9d Use hermetic Python in JAX, see "Managing hermetic Python" in developer.md for details
PiperOrigin-RevId: 634146391
2024-05-15 18:20:56 -07:00
jax authors
c3cab2e3d3 Reverts 6c425338d20c0c9be3fc69d2f07ababf79c881d3
PiperOrigin-RevId: 632579101
2024-05-10 12:56:10 -07:00
Sergei Lebedev
8ccbebae4b Fixed Mosaic GPU build following #21029 2024-05-07 17:08:00 +01:00
jax authors
e691c19bb2 Merge pull request #21029 from superbobry:jaxlib-mlir-pyi
PiperOrigin-RevId: 629836927
2024-05-01 14:22:21 -07:00
Sergei Lebedev
442526869f Bundle MLIR .pyi files with jaxlib
This allows mypy and pyright to type check the code using MLIR Python APIs.
2024-05-01 19:37:26 +01:00
Adam Paszke
4051ac2a2f [Mosaic GPU] Only call kernel initializer from inside a custom call
XLA:GPU custom call design is far from ideal, as there's apparently no way to figure
out the CUDA context that will be used to run an HLO module before the custom call is
first called. So, we can't preload the kernel onto the GPU, or else we'll get invalid
handle errors due to the load and launch happening in different CUDA contexts...

Also fix up build_wheel.py to match the rename of the runtime lib.

PiperOrigin-RevId: 629401858
2024-04-30 07:10:05 -07:00
Adam Paszke
340b9e3739 Update GPU and NVGPU MLIR bindings to match upstream MLIR changes
Upstream MLIR Python bindings now require two more extension libraries
to work properly. The dialects fail to import without this change.
2024-04-25 11:41:19 +00:00
Adam Paszke
5a2d7a2df4 Switch Mosaic GPU to a custom pass pipeline and improve the lowering of GPU launch
The stock MLIR pipeline was a good way to get the prototype off the ground, but
its default passes can be problematic. In particular, the gpu.launch is compiled
into a sequence of instructions that load the kernel onto the GPU, run the kernel
and immediately unload it again. This has the correct semantics, but loading the
kernel is both expensive and forces a synchronization point, which leads to performance
issues.

To resolve this, I implemented a new MLIR pass that finds the gpu.launch ops and splits
each function that has it into two functions: one that preloads the kernel onto the
GPU, and another one that consumes the handle produced by the previous one. We call
the first function at compile-time, while only the second one is used at run-time.

There are other overheads in MLIR's implementation of kernel launch, but I will
fix those later.

PiperOrigin-RevId: 627670773
2024-04-24 03:27:45 -07:00
Jieying Luo
16b4f69769 Rename arg in build script to be more clear.
The flag skips the GPU plugin extension in jaxlib.

PiperOrigin-RevId: 627203738
2024-04-22 17:22:24 -07:00
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
jax authors
fb55d59143 This CL introduces 'PluginProgram' in IFRT and exposes it in Python via xla_client.compile_ifrt_program().
The IFRT `PluginProgram` is simply a wrapper for arbitrary byte-strings: an IFRT backend that recognizes `PluginProgram` can interpret the byte-string in any way it sees fit.

PiperOrigin-RevId: 621258245
2024-04-02 12:20:35 -07:00
jax authors
df9cefabc1 jaxlib: Add ifrt_proxy.pyi to build_wheel.py.
PiperOrigin-RevId: 617275734
2024-03-19 13:27:39 -07:00
Peter Hawkins
ef40b85c8b Don't build the Triton MLIR dialect on Windows
This dialect doesn't build on Windows, but we don't support GPUs on Windows anyway, so we can simply exclude it from the build.

CI failures look like this:
```
C:\npm\prefix\bazel.CMD run --verbose_failures=true //jaxlib/tools:build_wheel -- --output_path=C:\a\jax\jax\jax\dist --jaxlib_git_hash=5f19f7712b485493ac141c44eea3b3eb1ffdfb59 --cpu=AMD64
b"external/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(70): error C2672: 'mlir::Block::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(70): error C2783: 'RetT mlir::Block::walk(FnT &&)': could not deduce template argument for 'ArgT'\r\nexternal/llvm-project/mlir/include\\mlir/IR/Block.h(289): note: see declaration of 'mlir::Block::walk'\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(110): error C2672: 'mlir::OpState::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(110): error C2783: 'enable_if<llvm::function_traits<decay<FnT>::type,std::is_class<T>::value>::num_args==1,RetT>::type mlir::OpState::walk(FnT &&)': could not deduce template argument for 'RetT'\r\n        with\r\n        [\r\n            T=decay<FnT>::type\r\n        ]\r\nexternal/llvm-project/mlir/include\\mlir/IR/OpDefinition.h(165): note: see declaration of 'mlir::OpState::walk'\r\nexternal/llvm-project/mlir/include\\mlir/IR/PatternMatch.h(357): error C2872: 'detail': ambiguous symbol\r\nexternal/llvm-project/mlir/include\\mlir/Rewrite/FrozenRewritePatternSet.h(15): note: could be 'mlir::detail'\r\nbazel-out/x64_windows-opt/bin/external/triton/include\\triton/Dialect/Triton/IR/Ops.h.inc(5826): note: or       'mlir::triton::detail'\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(712): note: see reference to class template instantiation 'mlir::OpRewritePattern<mlir::scf::ForOp>' being compiled\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(741): error C2672: 'mlir::Block::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(741): error C2783: 'RetT mlir::Block::walk(FnT &&)': could not deduce template argument for 'ArgT'\r\nexternal/llvm-project/mlir/include\\mlir/IR/Block.h(289): note: see declaration of 'mlir::Block::walk'\r\n"
    output = subprocess.check_output(cmd)
```
PiperOrigin-RevId: 609153322
2024-02-21 16:02:54 -08:00
Sergei Lebedev
881436240e Inlined triton.compat
We no longer need a compatibility layer, since Pallas does not use any Triton
IR building APIs.

PiperOrigin-RevId: 606948415
2024-02-14 05:23:15 -08:00
Sergei Lebedev
5e2e609a9b _triton_ext no longer links in MLIR C APIs
I re-used the same trick we do for the TPU dialect. Specifically, _triton_ext no longer depends on :triton_dialect_capi. Instead:

* we include Triton dialect C bindings into :jaxlib_mlir_capi_objects
* and _triton_ext depends on :jaxlib_mlir_capi_objects and a header-only cc_library providing Triton dialect C bindings

This is a fork of #19680 with a few internal-only fixes.

PiperOrigin-RevId: 604929377
2024-02-07 03:39:29 -08:00
Jieying Luo
b0b7c1c186 Fix missing flag definition in the plugin wheel build script.
jaxlib_git_hash was recently added to the build command in build/build.py.

PiperOrigin-RevId: 599931552
2024-01-19 14:06:19 -08:00
Sergei Lebedev
1e9f96a574 Include Triton files into the jaxlib wheel
This PR is based on #19368.
2024-01-16 15:28:12 +00:00
Peter Hawkins
dedd69f323 Add a bazel test that verifies that the jaxlib wheel builds. 2024-01-11 23:22:17 +00:00
Peter Hawkins
858fd52ac0 Fix jaxlib wheel build after removal of mosaic python files.
PiperOrigin-RevId: 597536465
2024-01-11 06:21:07 -08:00
Tomás Longeri
027c24e602 [Mosaic] Remove Python implementation of apply_vector_layout and infer_memref_layout.
PiperOrigin-RevId: 597332393
2024-01-10 13:00:21 -08:00
Jake VanderPlas
326d1d27ef jaxlib: avoid external build-time dependency on ml_dtypes
Currently, the ml_dtypes C++ sources are included in the set of sources at jaxlib build time. This is unnecessary, and can lead to problematic version skew in some cases (e.g. nightly builds).

PiperOrigin-RevId: 595725529
2024-01-04 09:26:05 -08:00