1332 Commits

Author SHA1 Message Date
Adam Paszke
99a12ef9ea [Mosaic GPU] Add support for warpgroup lowering of loops with vector carries
PiperOrigin-RevId: 731260912
2025-02-26 04:29:36 -08:00
Adam Paszke
cb7402f6de Remove MemoryEffects annotations from async_{load/store} ops
The annotation on async_load didn't indicate its write to SMEM, allowing it
to be DCEd by MLIR canonicalization. We don't get much mileage out of those
annotations, so let's just delete them for simplicity.

PiperOrigin-RevId: 731003033
2025-02-25 13:15:00 -08:00
Dan Foreman-Mackey
2ce88c950a Deprecate alpha argument to trsm LAPACK kernel.
(Part of general cleanups of the lax.linalg submodule.)

This is always set to 1 and I don't see any benefit to keeping this argument around. This can be done in a forward and backward compatible way following these docs: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility

We start by updating the FFI handler to remove the explicit alpha argument, but allow it to accept (but ignore) extra input arguments. Then we only pass alpha when lowering in forward compatibility mode, or when the jaxlib version is old (I'm using >0.5.1 as the cutoff assuming that this change doesn't make it into the upcoming release).

Then, the forward compatibility lowering can be removed after at least 21 days, and the kernel can be updated at least 180 days after 0.5.2 is released.

PiperOrigin-RevId: 730928808
2025-02-25 10:04:29 -08:00
jax authors
eb912ad0d9 Create jax wheel build target.
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126)

Previously `jax` wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).

You can still build the `jax` wheel with `python3 -m build` command.

Bazel `jax` wheel target: `//:jax_wheel`

Environment variables combinations for creating wheels with different versions:
  * self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * release: `--repo_env=ML_WHEEL_TYPE=release`
  * release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
  * nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`

PiperOrigin-RevId: 730916743
2025-02-25 09:30:08 -08:00
jax authors
083ffd3717 [Easy][Mosaic] Tiny refactor for clarity in getTypeBitwidth
PiperOrigin-RevId: 730906329
2025-02-25 08:58:19 -08:00
Jan Naumann
e03fe3a06d Implement SVD algorithm based on QR for CPU targets
In a recent jax release the SvdAlgorithm parameter has been added
to the jax.lax.linalg.svd function. Currently, for CPU targets
still only the divide and conquer algorithm from LAPACK is
supported (gesdd).

This commits adds the functionality to select the QR based
algorithm on CPU as well. Mainly it addes the wrapper code
to call the gesvd function of LAPACK using the FFI interface.

Signed-off-by: Jan Naumann <j.naumann@fu-berlin.de>
2025-02-22 15:24:57 +01:00
jax authors
b510127a13 Internal compatibility change
PiperOrigin-RevId: 729428478
2025-02-21 01:21:56 -08:00
jax authors
b7968474c2 [Pallas][Mosaic] Support float8_e4m3b11fnuz
PiperOrigin-RevId: 729169181
2025-02-20 10:44:33 -08:00
jax authors
37af0135b0 [Mosaic] Consider divisibility when doing large tiling
PiperOrigin-RevId: 728980108
2025-02-19 23:56:07 -08:00
Jevin Jiang
bb68124c33 [Mosaic TPU] Support mask concat
PiperOrigin-RevId: 728349788
2025-02-18 14:03:46 -08:00
jax authors
725087e13f Integrate LLVM at llvm/llvm-project@9d24f94379
Updates LLVM usage to match
[9d24f9437944](https://github.com/llvm/llvm-project/commit/9d24f9437944)

PiperOrigin-RevId: 728265165
2025-02-18 10:30:48 -08:00
jax authors
e78a469b42 Integrate LLVM at llvm/llvm-project@912b154f3a
Updates LLVM usage to match
[912b154f3a3f](https://github.com/llvm/llvm-project/commit/912b154f3a3f)

PiperOrigin-RevId: 727895384
2025-02-17 10:08:37 -08:00
Dimitar (Mitko) Asenov
52f8fbeee0 [Mosaic GPU] Implement lowerings for Tile and Transpose transforms from the MLIR dialect.
PiperOrigin-RevId: 727762334
2025-02-17 01:29:47 -08:00
jax authors
a6fcb7415f [TPU][Mosaic][Easy] Add verification for AssumeMultipleOp.
A user must use AssumeMultipleOp to annotate integer constants that are divisible by the given multiple.

PiperOrigin-RevId: 727699186
2025-02-16 21:16:05 -08:00
jax authors
eaceac3bf9 [Pallas] Reductions with replicated axes.
PiperOrigin-RevId: 727292293
2025-02-15 07:41:16 -08:00
Dan Foreman-Mackey
902ebe1bfe Fix segfault when old GPU plugins are installed.
PiperOrigin-RevId: 726919772
2025-02-14 07:26:45 -08:00
jax authors
5889fd0d22 Merge pull request #26486 from superbobry:maint-2
PiperOrigin-RevId: 726490849
2025-02-13 08:07:30 -08:00
Sergei Lebedev
194884d311 Migrated to mypy 1.14.1 with --allow_redefinition
I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,

   def f(x: int) -> str: ...
   def g(x: int) -> str: ...

   callback = f if ... else g  # has type object!
2025-02-13 15:38:28 +00:00
Adam Paszke
a493df4dd8 Fix Windows build for Mosaic GPU extension
We only export symbols that being with `mlir` and a few other prefixes, so this renames our C API functions for consistency with that.

PiperOrigin-RevId: 726468092
2025-02-13 06:58:17 -08:00
Jevin Jiang
876668faa1 [Mosaic TPU] Support bf16 div if HW does not directly support.
PiperOrigin-RevId: 726212286
2025-02-12 15:04:09 -08:00
tttc3
b1b56ea0b0 Enable pivoted QR on GPU via MAGMA.
Originally noted in #20282, this commit provides a GPU compatible
implementation of `geqp3` via MAGMA.
2025-02-12 16:12:42 +00:00
Dimitar (Mitko) Asenov
6fc1c61520 [Mosaic GPU] Use the memref layout to encode transforms (only swizzle for now).
Tile and Transpose transforms to follow.

PiperOrigin-RevId: 725716812
2025-02-11 11:51:25 -08:00
jax authors
ffd3faad72 [TPU[Mosaic] Fix missing sfences in smem DMAs
PiperOrigin-RevId: 725376627
2025-02-10 15:51:35 -08:00
Dan Foreman-Mackey
154e4506c0 Some lax.linalg housekeeping.
The main aim here is to clean up lax.linalg to make it a bit easier to maintain and update with new features (e.g. batch partitioning - coming soon!). In this change, I removes some code duplication by consolidate most of the lowering logic into a helper function, and identifying some other common patterns. As part of this, I moved the remaining lowering rules from `jaxlib.lapack` into `lax.linalg`.

PiperOrigin-RevId: 725223882
2025-02-10 08:27:18 -08:00
Peter Hawkins
f6ca686641 Bump the minimum Mac OS X version for x86 builds to 11.0.
The x86 build stopped building completely due to a use of std::filesystem::path, which was added in 10.15.
We've dropped x86 support, but this is an easy enough fix to make and moves x86 to parity with ARM.
2025-02-10 08:51:32 -05:00
jax authors
6740165e4f [Pallas] Add pipeline mode to pltpu
PiperOrigin-RevId: 725133131
2025-02-10 02:36:44 -08:00
Dan Foreman-Mackey
5bc17f7ec3 Remove the unused cu_cholesky_update kernel in favor of the FFI version.
This kernel wasn't allowed in export, so no backwards compatibility period is required. Even so, the FFI kernels were added 6 months ago.

PiperOrigin-RevId: 724359996
2025-02-07 08:48:15 -08:00
Dan Foreman-Mackey
c6e83903de Update RNN kernels to use FFI.
PiperOrigin-RevId: 724151647
2025-02-06 18:27:58 -08:00
Dan Foreman-Mackey
5e915d3307 Update the sparse GPU kernels in jaxlib to use the FFI.
Unlike the other more detailed ports, this version doesn't take full advantage of the features provided by the FFI. For example, it would be possible to update the kernels to use the ScratchAllocator instead of querying the workspace size during lowering. However, since these kernels are really only meant to be experimental, it's not obvious to me that it's worth the extra work to do anything more sophisticated.

PiperOrigin-RevId: 724016331
2025-02-06 11:45:57 -08:00
Michael Hudgins
2e808f2836 Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
jax authors
d424f5b5b3 Refactor JAX wheel build rules to control the wheel filename and maintain reproducible wheel content and filename results.
This change is a part of the initiative to test the JAX wheels in the presubmit properly.

The list of the changes:
1. JAX wheel build rule verifies that `--@local_config_cuda//cuda:include_cuda_libs=false` during the wheel build. There is a way to pass the restriction by providing `--@local_config_cuda//cuda:override_include_cuda_libs=true`.

2. The JAX version number (which is also used in the wheel filenames) is stored in `_version` variable in the file [version.py](https://github.com/jax-ml/jax/blob/main/jax/version.py). The custom repository rule `jax_python_wheel_version_repository` saves this value in `wheel_version.bzl`, so it becomes available in Bazel build phase.

3. The version suffix of the wheel in the build rule output depends on the environment variables.

   The version suffix chunks that are not reproducible shouldn’t be calculated as a part of the wheel binary: for example, the current date changes every day, thus the wheels built today and tomorrow on the same code version will be technically different. To maintain reproducible wheel content, we need to pass suffix chunks in a form of environment variables.

4. Environment variables combinations for creating wheels with different versions:
  * `0.5.1.dev0+selfbuilt` (local build, default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * `0.5.1` (release): `--repo_env=ML_WHEEL_TYPE=release`
  * `0.5.1rc1` (release candidate): `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=rc1`
  * `0.5.1.dev20250128+3e75e20c7` (nightly build): `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=20250128 --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`

PiperOrigin-RevId: 723552265
2025-02-05 10:01:23 -08:00
Adam Paszke
e7a4f89343 [Mosaic TPU] Add optimized casts for bf16->s4 in TPUv6
PiperOrigin-RevId: 723455843
2025-02-05 04:21:55 -08:00
George Necula
9f797990b5 Remove old backward compatibility mode for old PRGN custom call on GPU
The backend support for the new custom call was added on June 28th, 2024 (#20997).

PiperOrigin-RevId: 723077990
2025-02-04 07:34:52 -08:00
Sergei Lebedev
7929cd8410 [pallas:triton] The lowering now uses PTX instead of Triton IR
This change improves the stability and backward compatibility of Pallas Triton
calls, because unlike PTX, the Triton dialect has no stability guarantees
and does change in practice.

See #25196.

A few notes

* Pallas Triton no longer delegates compilation to PTX to XLA:GPU. Instead,
  compilation is done via a new PjRt extension, which uses its own compilation
  pipeline mirrored after the one in the Triton Python bindings.
* The implementation of the old custom call used by Pallas Triton is
  deprecated and will be removed after 6 months as per
  [compatibility guarantees] [*]

[*]: https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees

PiperOrigin-RevId: 722773884
2025-02-03 13:21:40 -08:00
Jevin Jiang
d8b9211359 [Mosaic TPU] Support dynamic gather along axis 0 or 1 for 32-bit vreg-sized vector.
PiperOrigin-RevId: 721980453
2025-01-31 18:47:25 -08:00
Jevin Jiang
785a63ad0f [Mosaic TPU] Support non-32 bit mask relayout
PiperOrigin-RevId: 721552594
2025-01-30 16:13:23 -08:00
Tzu-Wei Sung
d4758b6d5e [Mosaic][NFC] Factor out xla-array related utils in a separate file.
Also added tests.

PiperOrigin-RevId: 721424194
2025-01-30 09:49:41 -08:00
Benjamin Chetioui
d8f3b33ae4 [Mosaic GPU] Eliminate the arrive attribute from mosaic_gpu.async_load.
We plan to explicitly issue an `expect_tx` operation all the time when using
the dialect.

PiperOrigin-RevId: 721411949
2025-01-30 09:08:45 -08:00
Dimitar (Mitko) Asenov
6214c25a6d [Mosaic GPU] Add ArriveExpect and Wait ops on dialect barriers with explicit handling of parities
This makes dialect tests in mgpu_test.py truly express the entire computation at the warpgroup level.

PiperOrigin-RevId: 721371327
2025-01-30 06:44:32 -08:00
Adam Paszke
29b658b358 [Mosaic TPU] Optimize clipping impelmentation in arith.fptosi
We can use maxf/minf to avoid extra comparisons

PiperOrigin-RevId: 720601304
2025-01-28 09:20:16 -08:00
Dimitar (Mitko) Asenov
a3a285dddc [Mosaic GPU] Handle the swizzle attribute in the lowering of async_store and async_load
PiperOrigin-RevId: 720129408
2025-01-27 05:18:16 -08:00
Sergei Lebedev
9ee7123c39 [mosaic_gpu] Fixed mosaic_gpu-serde pass registration
We previously registered the pass in the :_mosaic_gpu_ext which didn't work
because the extension has its own pass registry. The fix instead is to move
the registration to :register_jax_dialects in jaxlib.

PiperOrigin-RevId: 719280601
2025-01-24 06:35:54 -08:00
Adam Paszke
7043b852ec [Mosaic GPU] Add basic support for TMA with sub-byte types
PiperOrigin-RevId: 719240287
2025-01-24 03:54:12 -08:00
Jevin Jiang
8e1f956804 [Mosaic TPU] Use vmask pack if possible for mask's bitwidth change and introduce relayout op.
PiperOrigin-RevId: 719089676
2025-01-23 18:15:08 -08:00
Dimitar (Mitko) Asenov
f57d603c45 [Mosaic GPU] Simplify enums in the MLIR Mosaic GPU dialect.
This enables us to use them more simply in the current and upcoming Python code. The Python bindings for enum and enum attributes leave much to be desired.

PiperOrigin-RevId: 718795667
2025-01-23 03:38:26 -08:00
Dimitar (Mitko) Asenov
6b747b4109 [Mosaic GPU] Add a result to the WGMMA op definition in the MLIR dialect
PiperOrigin-RevId: 718788390
2025-01-23 03:10:07 -08:00
jax authors
6c76cc4e36 Integrate LLVM at llvm/llvm-project@d33e33fde7
Updates LLVM usage to match
[d33e33fde770](https://github.com/llvm/llvm-project/commit/d33e33fde770)

PiperOrigin-RevId: 718414171
2025-01-22 09:22:07 -08:00
jax authors
54bb7f5ddb Remove meaningless template keywords.
This will fix -Wmissing-template-arg-list-after-template-kw warnings.
This warning is error-by-default in Clang.

PiperOrigin-RevId: 718133601
2025-01-21 17:22:04 -08:00
Tzu-Wei Sung
79bd72e2e8 [Mosaic] Remove hardcoded TARGET_SHAPE and align Python/C++ APIs.
PiperOrigin-RevId: 717973752
2025-01-21 10:24:10 -08:00
Dimitar (Mitko) Asenov
f89accc56a [Mosaic GPU] Add support for converting all fragmented layouts to ir and back.
This will be used in the layout inference and lowering of the dialect WGMMA op

PiperOrigin-RevId: 717836648
2025-01-21 03:27:03 -08:00