1314 Commits

Author SHA1 Message Date
Sergei Lebedev
194884d311 Migrated to mypy 1.14.1 with --allow_redefinition
I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,

   def f(x: int) -> str: ...
   def g(x: int) -> str: ...

   callback = f if ... else g  # has type object!
2025-02-13 15:38:28 +00:00
Jevin Jiang
876668faa1 [Mosaic TPU] Support bf16 div if HW does not directly support.
PiperOrigin-RevId: 726212286
2025-02-12 15:04:09 -08:00
tttc3
b1b56ea0b0 Enable pivoted QR on GPU via MAGMA.
Originally noted in #20282, this commit provides a GPU compatible
implementation of `geqp3` via MAGMA.
2025-02-12 16:12:42 +00:00
Dimitar (Mitko) Asenov
6fc1c61520 [Mosaic GPU] Use the memref layout to encode transforms (only swizzle for now).
Tile and Transpose transforms to follow.

PiperOrigin-RevId: 725716812
2025-02-11 11:51:25 -08:00
jax authors
ffd3faad72 [TPU[Mosaic] Fix missing sfences in smem DMAs
PiperOrigin-RevId: 725376627
2025-02-10 15:51:35 -08:00
Dan Foreman-Mackey
154e4506c0 Some lax.linalg housekeeping.
The main aim here is to clean up lax.linalg to make it a bit easier to maintain and update with new features (e.g. batch partitioning - coming soon!). In this change, I removes some code duplication by consolidate most of the lowering logic into a helper function, and identifying some other common patterns. As part of this, I moved the remaining lowering rules from `jaxlib.lapack` into `lax.linalg`.

PiperOrigin-RevId: 725223882
2025-02-10 08:27:18 -08:00
Peter Hawkins
f6ca686641 Bump the minimum Mac OS X version for x86 builds to 11.0.
The x86 build stopped building completely due to a use of std::filesystem::path, which was added in 10.15.
We've dropped x86 support, but this is an easy enough fix to make and moves x86 to parity with ARM.
2025-02-10 08:51:32 -05:00
jax authors
6740165e4f [Pallas] Add pipeline mode to pltpu
PiperOrigin-RevId: 725133131
2025-02-10 02:36:44 -08:00
Dan Foreman-Mackey
5bc17f7ec3 Remove the unused cu_cholesky_update kernel in favor of the FFI version.
This kernel wasn't allowed in export, so no backwards compatibility period is required. Even so, the FFI kernels were added 6 months ago.

PiperOrigin-RevId: 724359996
2025-02-07 08:48:15 -08:00
Dan Foreman-Mackey
c6e83903de Update RNN kernels to use FFI.
PiperOrigin-RevId: 724151647
2025-02-06 18:27:58 -08:00
Dan Foreman-Mackey
5e915d3307 Update the sparse GPU kernels in jaxlib to use the FFI.
Unlike the other more detailed ports, this version doesn't take full advantage of the features provided by the FFI. For example, it would be possible to update the kernels to use the ScratchAllocator instead of querying the workspace size during lowering. However, since these kernels are really only meant to be experimental, it's not obvious to me that it's worth the extra work to do anything more sophisticated.

PiperOrigin-RevId: 724016331
2025-02-06 11:45:57 -08:00
Michael Hudgins
2e808f2836 Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
jax authors
d424f5b5b3 Refactor JAX wheel build rules to control the wheel filename and maintain reproducible wheel content and filename results.
This change is a part of the initiative to test the JAX wheels in the presubmit properly.

The list of the changes:
1. JAX wheel build rule verifies that `--@local_config_cuda//cuda:include_cuda_libs=false` during the wheel build. There is a way to pass the restriction by providing `--@local_config_cuda//cuda:override_include_cuda_libs=true`.

2. The JAX version number (which is also used in the wheel filenames) is stored in `_version` variable in the file [version.py](https://github.com/jax-ml/jax/blob/main/jax/version.py). The custom repository rule `jax_python_wheel_version_repository` saves this value in `wheel_version.bzl`, so it becomes available in Bazel build phase.

3. The version suffix of the wheel in the build rule output depends on the environment variables.

   The version suffix chunks that are not reproducible shouldn’t be calculated as a part of the wheel binary: for example, the current date changes every day, thus the wheels built today and tomorrow on the same code version will be technically different. To maintain reproducible wheel content, we need to pass suffix chunks in a form of environment variables.

4. Environment variables combinations for creating wheels with different versions:
  * `0.5.1.dev0+selfbuilt` (local build, default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * `0.5.1` (release): `--repo_env=ML_WHEEL_TYPE=release`
  * `0.5.1rc1` (release candidate): `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=rc1`
  * `0.5.1.dev20250128+3e75e20c7` (nightly build): `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=20250128 --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`

PiperOrigin-RevId: 723552265
2025-02-05 10:01:23 -08:00
Adam Paszke
e7a4f89343 [Mosaic TPU] Add optimized casts for bf16->s4 in TPUv6
PiperOrigin-RevId: 723455843
2025-02-05 04:21:55 -08:00
George Necula
9f797990b5 Remove old backward compatibility mode for old PRGN custom call on GPU
The backend support for the new custom call was added on June 28th, 2024 (#20997).

PiperOrigin-RevId: 723077990
2025-02-04 07:34:52 -08:00
Sergei Lebedev
7929cd8410 [pallas:triton] The lowering now uses PTX instead of Triton IR
This change improves the stability and backward compatibility of Pallas Triton
calls, because unlike PTX, the Triton dialect has no stability guarantees
and does change in practice.

See #25196.

A few notes

* Pallas Triton no longer delegates compilation to PTX to XLA:GPU. Instead,
  compilation is done via a new PjRt extension, which uses its own compilation
  pipeline mirrored after the one in the Triton Python bindings.
* The implementation of the old custom call used by Pallas Triton is
  deprecated and will be removed after 6 months as per
  [compatibility guarantees] [*]

[*]: https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees

PiperOrigin-RevId: 722773884
2025-02-03 13:21:40 -08:00
Jevin Jiang
d8b9211359 [Mosaic TPU] Support dynamic gather along axis 0 or 1 for 32-bit vreg-sized vector.
PiperOrigin-RevId: 721980453
2025-01-31 18:47:25 -08:00
Jevin Jiang
785a63ad0f [Mosaic TPU] Support non-32 bit mask relayout
PiperOrigin-RevId: 721552594
2025-01-30 16:13:23 -08:00
Tzu-Wei Sung
d4758b6d5e [Mosaic][NFC] Factor out xla-array related utils in a separate file.
Also added tests.

PiperOrigin-RevId: 721424194
2025-01-30 09:49:41 -08:00
Benjamin Chetioui
d8f3b33ae4 [Mosaic GPU] Eliminate the arrive attribute from mosaic_gpu.async_load.
We plan to explicitly issue an `expect_tx` operation all the time when using
the dialect.

PiperOrigin-RevId: 721411949
2025-01-30 09:08:45 -08:00
Dimitar (Mitko) Asenov
6214c25a6d [Mosaic GPU] Add ArriveExpect and Wait ops on dialect barriers with explicit handling of parities
This makes dialect tests in mgpu_test.py truly express the entire computation at the warpgroup level.

PiperOrigin-RevId: 721371327
2025-01-30 06:44:32 -08:00
Adam Paszke
29b658b358 [Mosaic TPU] Optimize clipping impelmentation in arith.fptosi
We can use maxf/minf to avoid extra comparisons

PiperOrigin-RevId: 720601304
2025-01-28 09:20:16 -08:00
Dimitar (Mitko) Asenov
a3a285dddc [Mosaic GPU] Handle the swizzle attribute in the lowering of async_store and async_load
PiperOrigin-RevId: 720129408
2025-01-27 05:18:16 -08:00
Sergei Lebedev
9ee7123c39 [mosaic_gpu] Fixed mosaic_gpu-serde pass registration
We previously registered the pass in the :_mosaic_gpu_ext which didn't work
because the extension has its own pass registry. The fix instead is to move
the registration to :register_jax_dialects in jaxlib.

PiperOrigin-RevId: 719280601
2025-01-24 06:35:54 -08:00
Adam Paszke
7043b852ec [Mosaic GPU] Add basic support for TMA with sub-byte types
PiperOrigin-RevId: 719240287
2025-01-24 03:54:12 -08:00
Jevin Jiang
8e1f956804 [Mosaic TPU] Use vmask pack if possible for mask's bitwidth change and introduce relayout op.
PiperOrigin-RevId: 719089676
2025-01-23 18:15:08 -08:00
Dimitar (Mitko) Asenov
f57d603c45 [Mosaic GPU] Simplify enums in the MLIR Mosaic GPU dialect.
This enables us to use them more simply in the current and upcoming Python code. The Python bindings for enum and enum attributes leave much to be desired.

PiperOrigin-RevId: 718795667
2025-01-23 03:38:26 -08:00
Dimitar (Mitko) Asenov
6b747b4109 [Mosaic GPU] Add a result to the WGMMA op definition in the MLIR dialect
PiperOrigin-RevId: 718788390
2025-01-23 03:10:07 -08:00
jax authors
6c76cc4e36 Integrate LLVM at llvm/llvm-project@d33e33fde7
Updates LLVM usage to match
[d33e33fde770](https://github.com/llvm/llvm-project/commit/d33e33fde770)

PiperOrigin-RevId: 718414171
2025-01-22 09:22:07 -08:00
jax authors
54bb7f5ddb Remove meaningless template keywords.
This will fix -Wmissing-template-arg-list-after-template-kw warnings.
This warning is error-by-default in Clang.

PiperOrigin-RevId: 718133601
2025-01-21 17:22:04 -08:00
Tzu-Wei Sung
79bd72e2e8 [Mosaic] Remove hardcoded TARGET_SHAPE and align Python/C++ APIs.
PiperOrigin-RevId: 717973752
2025-01-21 10:24:10 -08:00
Dimitar (Mitko) Asenov
f89accc56a [Mosaic GPU] Add support for converting all fragmented layouts to ir and back.
This will be used in the layout inference and lowering of the dialect WGMMA op

PiperOrigin-RevId: 717836648
2025-01-21 03:27:03 -08:00
Adam Paszke
543dd94762 [Mosaic TPU] Add a faster implementation for packing b16 to s8 in TPUv6
PiperOrigin-RevId: 717583425
2025-01-20 11:18:22 -08:00
Peter Hawkins
034e967e11 Remove CUDA rpaths from jaxlib build.
These are also set in the TSL build rules as part of the CUDA stub libraries, which these libraries depend on, so these copies of the rpath settings are redundant.

PiperOrigin-RevId: 716844265
2025-01-17 17:09:30 -08:00
jax authors
a527aba646 Reverts f1b894d14a28ac22a037fb79177b991275c75a18
PiperOrigin-RevId: 716653711
2025-01-17 07:00:31 -08:00
Benjamin Chetioui
d3be190efb [Mosaic GPU] Delete unused declarations of mosaic_gpu_memcpy_async_h2d.
PiperOrigin-RevId: 716616807
2025-01-17 04:34:48 -08:00
Sergei Lebedev
d34c40f6b6 [mosaic_gpu] Added a serialization pass
The pass adds versioning to the Mosaic GPU IR in the lowered custom calls
and can apply forward/backward migration rules. Currently, no rules are
necessary since we are at version 1.

PiperOrigin-RevId: 716596848
2025-01-17 03:12:51 -08:00
Adam Paszke
bd22bfef71 [Mosaic TPU] Use large to compact 2nd minor retiling for conversions going both ways
This specific retiling is its own inverse and it faster than alternatives.

PiperOrigin-RevId: 716360070
2025-01-16 13:35:26 -08:00
Tzu-Wei Sung
5c020ee317 [Mosaic] Fix infer/apply extensions.
1. For apply, llvm::StringMap()::insert(MapEntryTy*) will cause dangling reference if not constructing mlir::tpu::extensions::rules() with const-reference. However, if we do construct it with const-reference, the signature is not const-qualified and fails to compile. Hence, change it to llvm::StringMap()::insert(std::pair<...>) and get extension rules by const-reference.
2. Pass default tiling to infer rule, we need it to infer single op. See infer of tpu::MatmulOp.

PiperOrigin-RevId: 716274818
2025-01-16 09:57:14 -08:00
Sergei Lebedev
4221f109d1 [mosaic] Extracted serialization pass traversal logic into a reusable function
I will use it to implement Mosaic GPU serialization pass in a follow up.

PiperOrigin-RevId: 716156650
2025-01-16 02:58:06 -08:00
Tzu-Wei Sung
4a9cc9ffc1 [Mosaic] Allow passing ApplyVectorLayoutCtx to tpu.apply_layout_op.
To make it the same with C++ API. While I'm here, fix a bug in test_concatenate.

PiperOrigin-RevId: 716016244
2025-01-15 17:47:36 -08:00
Naums Mogers
d3ba1eb339 [Mosaic] Add a macro to convert abseil StatusOr to LLVM FailureOr
PiperOrigin-RevId: 715943314
2025-01-15 14:19:29 -08:00
jax authors
41993fdb24 Merge pull request #25755 from ROCm:ci_rnn_final-upstream
PiperOrigin-RevId: 715856939
2025-01-15 10:40:54 -08:00
Ruturaj4
fe68eb8b25 [ROCm] Implement RNN support 2025-01-14 19:04:49 -06:00
George Necula
f1b894d14a Reverts 391bad8ff59c07c8fad7b8ce05cd0e29dee4cf1a
PiperOrigin-RevId: 715435319
2025-01-14 10:31:59 -08:00
Ayaka
9ba1fd2801 [Pallas TPU] Add vector support to pl.debug_print
PiperOrigin-RevId: 715085454
2025-01-13 13:22:21 -08:00
Adam Paszke
391bad8ff5 [Mosaic TPU] Add support for arith.fptosi with non-32bit source and target types
This effectively moves some of the Pallas logic to the layer below.

PiperOrigin-RevId: 714965374
2025-01-13 07:49:13 -08:00
Peter Hawkins
91ffb640a8 Use thread-safe initialization of LAPACK kernels.
Use absl::call_once instead of a GIL-protected global initialization.

In passing, also remove an unused function.

PiperOrigin-RevId: 714892175
2025-01-13 02:51:38 -08:00
Tomás Longeri
7852045582 [Mosaic TPU] Enable non-sublane-aligned bf16 2D load/stores for earlier TPU gens
It is still not efficiently implemented, this is mostly to clean up some logic. We may be able to fuse the creation of masks for different tiles into the creation of a single one. But this is also a problem for the later gens.

This also cleans up an unreachable return statement.

PiperOrigin-RevId: 714847066
2025-01-12 23:58:40 -08:00
Tomás Longeri
0930289997 [Mosaic TPU][NFC] Remove redundant num_subelems attribute from CreateSubelementMaskOp
PiperOrigin-RevId: 714795856
2025-01-12 19:34:25 -08:00