1264 Commits

Author SHA1 Message Date
jax authors
a16fbffc13 [Mosaic][TPU] Add a compatibility mode to Mosaic's canonicalization pass, skipping over elementwise and matmul op insertions and/or type compat casts.
PiperOrigin-RevId: 714132282
2025-01-10 12:12:54 -08:00
Dan Foreman-Mackey
39ce7916f1 Activate FFI implementation of tridiagonal reduction on GPU.
PiperOrigin-RevId: 714078036
2025-01-10 09:28:15 -08:00
Dan Foreman-Mackey
c1de7c733d Add LAPACK lowering for lax.linalg.tridiagonal_solve on CPU.
In implementing https://github.com/jax-ml/jax/pull/25787, I realized that while we lower `tridiagonal_solve` to cuSPARSE on GPU, we were using an explicit implementation of the Thomas algorithm on CPU. We should instead lower to LAPACK's `gtsv` on CPU because it should be more numerically stable and faster.

PiperOrigin-RevId: 714069225
2025-01-10 08:56:46 -08:00
jax authors
564b6b0d72 Merge pull request #20282 from tttc3:pivoted-qr
PiperOrigin-RevId: 714053620
2025-01-10 08:02:02 -08:00
Adam Paszke
d2a5e8d072 [Mosaic TPU] Add support for integer truncation from packed types
PiperOrigin-RevId: 714048232
2025-01-10 07:40:55 -08:00
jax authors
061408aca3 Merge pull request #25803 from sergachev:fix_rnn_desc
PiperOrigin-RevId: 713789106
2025-01-09 14:05:30 -08:00
tttc3
c89be05b5b Enable pivoted QR on CPU devices.
A pivoted QR factorization is possible in `scipy.linalg.qr`, thanks
to the `geqp3` routine of LAPACK. To provide the same functionality
in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK
routine via the FFI on CPU devices.

Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support the
use of column-pivoting on CPU devices.

To provide a GPU implementation of `geqp3` may require using MAGMA,
due to the lack of a `geqp3` implementation in `cuSolver` -  see
ccb331707e80b16d89de6e5c9f2f89b87c1682ed (`jax.lax.linalg.eig`) for
an example of using MAGMA in GPU lowerings. Such a GPU implementation
can be considered in the future.
2025-01-09 20:44:45 +00:00
Adam Paszke
07f4fd3e51 [Mosaic TPU] Fix a bug in the impl of sublane broadcasts for int8 and int4
PiperOrigin-RevId: 713675029
2025-01-09 08:05:25 -08:00
Ilia Sergachev
f0e1c3cf36 Fix struct string encoding non-determinism in the RNN descriptor.
Boolean fields in the descriptor struct led to padding, which let random
bytes in the string representation of the struct and variance in HLO
from run to run.
2025-01-09 12:57:09 +00:00
Peter Hawkins
0389d617c8 Add a unittest test extension that runs test cases in parallel using threads.
This change does not yet do the work necessary to make any tests pass with threading enabled, which will come in future changes.

This approach is broadly inspired by a6d205dd4c/testtools/testsuite.py (L113) and by unittest-ft.

We add a custom TestResult class that batches up any test result actions and applies them under a lock. We also add a custom TestSuite class that runs individual test cases in parallel using a thread-pool.

We need a reader-writer lock to implement a `@jtu.thread_hostile_test` decorator, which we do by adding bindings around absl::Mutex to jaxlib.

PiperOrigin-RevId: 713312937
2025-01-08 09:11:47 -08:00
Adam Paszke
f96339be1e [Mosaic TPU] Be much more aggressive in inferring large 2nd minor layouts for 16-bit types on v6
This often lets us avoid ambiguities between selecting the (8, 128) and (16, 128) tiling,
by biasing the layout inference to prefer the latter.

PiperOrigin-RevId: 713270421
2025-01-08 06:30:36 -08:00
Adam Paszke
5fd1b2f825 [Mosaic TPU] Add support for second minor broadcasts with packed types
PiperOrigin-RevId: 713259707
2025-01-08 05:45:02 -08:00
Adam Paszke
e954930eaf [Mosaic TPU] Add support for true divide in bf16 on TPUv6
PiperOrigin-RevId: 713247480
2025-01-08 04:49:22 -08:00
Tzu-Wei Sung
bf94389b08 [Mosaic] Use tpu::CreateMask for getX32VmaskByPaddingEnd.
It was cmp + iota before.

PiperOrigin-RevId: 713240888
2025-01-08 04:18:53 -08:00
Peter Hawkins
392a851769 Increase the minimum SciPy version to 1.11.1.
(1.11.0 was yanked from PyPi because of licensing problems, so 1.11.1 is the oldest 1.11 release.)

PiperOrigin-RevId: 713073731
2025-01-07 16:10:45 -08:00
Dan Foreman-Mackey
a7f384cc6e Add a register_custom_type_id function to the GPU plugins.
This enables dynamic registration of custom FFI types on the appropriate platform via PJRT.

PiperOrigin-RevId: 712904085
2025-01-07 07:29:38 -08:00
Sharad Vikram
4caa263a94 [Mosaic TPU] Add some elementwise canonicalizations
PiperOrigin-RevId: 712671502
2025-01-06 15:10:02 -08:00
Peter Hawkins
90d8f37863 Rename pybind_extension to nanobind_extension.
We have no remaining uses of pybind11 outside a GPU custom call example.

PiperOrigin-RevId: 712608834
2025-01-06 11:53:44 -08:00
Peter Hawkins
61dd041225 Suppress MSAN warnings from SVD that are showing up in CI.
In our MSAN CI, the copy of LAPACK we use is not MSAN-instrumented, leading to false positives. Suppress those false-positives via annotations.

PiperOrigin-RevId: 712607044
2025-01-06 11:49:05 -08:00
Jevin Jiang
9f842909ce [Mosaic TPU] Validate inserted layout in relayout-insertion pass.
PiperOrigin-RevId: 712595778
2025-01-06 11:15:47 -08:00
John QiangZhang
c39e38fe5a bazel: export serialization.fbs for downstream usage
PiperOrigin-RevId: 712587802
2025-01-06 10:57:35 -08:00
Tzu-Wei Sung
57b21541a2 [Mosaic] NFC: Pull out vreg related functions to util.
These functions are related to vreg manipulation and are used in different rules.

PiperOrigin-RevId: 711484002
2025-01-02 11:50:19 -08:00
jax authors
68483b8ed6 Merge pull request #25710 from apaszke:mgpu_dialect_fix
PiperOrigin-RevId: 711430610
2025-01-02 08:23:28 -08:00
Adam Paszke
64433435ff Fix OSS build for the Mosaic GPU dialect 2025-01-02 15:55:03 +00:00
Tomás Longeri
ac817b48ca [Mosaic:TPU][NFC] Clean up unused variable
PiperOrigin-RevId: 711412888
2025-01-02 06:57:38 -08:00
Tomás Longeri
4452960947 [Mosaic:TPU] In infer ext rule, avoid assigning offsets outside of dst first tile
Note that offsets outside of first tile are still disabled (for both infer and apply), and once we support it we will want to assign offsets differently, this is mostly to avoid assigning invalid layouts (that may not just be outside the first tile, but outside the vreg slice)

PiperOrigin-RevId: 709168368
2024-12-23 15:49:39 -08:00
jax authors
b8091a437a Switch mlir bindings from pybind11 to nanobind
PiperOrigin-RevId: 709161113
2024-12-23 15:10:11 -08:00
Tomás Longeri
3c79b98cd9 [Mosaic:TPU] Vreg-slice-aligned offset changes with scratch retiling
PiperOrigin-RevId: 709133729
2024-12-23 13:05:14 -08:00
Sergei Lebedev
68ec202d45 Use the right include for gmock and gtest
PiperOrigin-RevId: 709058082
2024-12-23 07:34:36 -08:00
Sergei Lebedev
8987867faa [mosaic_gpu] Include Mosaic GPU dialect fiels into jaxlib 2024-12-23 13:46:25 +00:00
Tomás Longeri
7ecc947184 [Mosaic:TPU] Roll forward of cl/708011538 (expanded trunc support), minus changes in infer-vector-layout
We can enable them later but at least this way the support is available to build on
(e.g. in the new insert relayouts pass)

Reverts 05f3a701e769748ff1ec51d50324a3595c4aff0d

PiperOrigin-RevId: 708397219
2024-12-20 12:33:30 -08:00
Peter Hawkins
0ff3f144e5 Migrate _mlir Python binding target to nanobind.
PiperOrigin-RevId: 708390390
2024-12-20 12:07:29 -08:00
Tomás Longeri
05f3a701e7 [Mosaic:TPU] Roll back cl/708011538 and cl/708112341
Reverts 307c8d3af81f16142fd4c64f501b05a5b69f815e

PiperOrigin-RevId: 708173083
2024-12-19 21:51:44 -08:00
Jevin Jiang
2faf540203 [Mosaic TPU] Add relayout-insertion pass and support bitwidth change for i1 vector relayout
We can use relayout-insertion pass to insert necessary ops and their layouts for relayout before unrolling in apply-vector-layout pass.

PiperOrigin-RevId: 708143852
2024-12-19 19:56:40 -08:00
Tomás Longeri
8b02884c3c [Mosaic:TPU] Fix trunc infer rule after cl/708011538
For (1, 128) tiling 32-bit input, it assigns (1, 128) tiling at output, which can be invalid (e.g. it should be (1, 256) for bf16)

PiperOrigin-RevId: 708112341
2024-12-19 18:14:12 -08:00
Tzu-Wei Sung
60ebde89e6 [Mosaic] Extend macros to handle parentheses.
PiperOrigin-RevId: 708045694
2024-12-19 15:00:12 -08:00
Tzu-Wei Sung
77f3c114d0 [Mosaic] Remove TODOs that are already addressed or obsolete.
PiperOrigin-RevId: 708045439
2024-12-19 14:58:15 -08:00
Nitin Srinivasan
6b096b0cb0 Use common set of build options when building jaxlib+plugin artifacts together
This commit modifies the behavior of the build CLI when building jaxlib and GPU plugin artifacts together (for instance `python build --wheels=jaxlib,jax-cuda-plugin`.

Before, CUDA/ROCm build options were only passed when building the CUDA/ROCm artifacts. However, this leads to inefficient use of the build cache as it looks like Bazel tries to rebuild some targets that has already been built in the previous run. This seems to be because the GPU plugin artifacts have a different set of build options compared to `jaxlib` which for some reason causes Bazel to invalidate/ignore certain cache hits. Therefore, this commit makes it so that the build options remain the same when the `jaxlib` and GPU artifacts are being built together so that we can better utilize the build cache.

As an example, this means that if `python build --wheels=jaxlib,jax-cuda-plugin` is run, the following build options will apply to both `jaxlib` and `jax-cuda-plugin` builds:
```
 /usr/local/bin/bazel run --repo_env=HERMETIC_PYTHON_VERSION=3.10 \
--verbose_failures=true --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--repo_env=CC="/usr/lib/llvm-16/bin/clang" \
--repo_env=BAZEL_COMPILER="/usr/lib/llvm-16/bin/clang" \
--config=clang --config=mkl_open_source_only --config=avx_posix \
--config=cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--config=build_cuda_with_nvcc
```

Note, this commit shouldn't affect the content of the wheel it self. It is only meant to give a performance boost when building `jalxib`+plugin aritfacts together.

Also, this removes code that was used to build (now deprecated) monolithic `jaxlib` build from `build_wheel.py`

PiperOrigin-RevId: 708035062
2024-12-19 14:29:24 -08:00
Tomás Longeri
307c8d3af8 [Mosaic:TPU] For trunc, expand supported tilings, offsets and bitwidths
infer-vector-layout won't use the full generality anytime soon, but we could reuse this logic for relayouts

PiperOrigin-RevId: 708011538
2024-12-19 13:31:59 -08:00
Peter Hawkins
9c3365fb95 Migrate shardy dialect extension to nanobind.
PiperOrigin-RevId: 707991933
2024-12-19 12:25:14 -08:00
Benjamin Chetioui
3915f4a147 [Mosaic GPU] Commit to using Vectors everywhere (and no Tensors).
PiperOrigin-RevId: 707912637
2024-12-19 07:51:58 -08:00
Sergei Lebedev
c4fae4a7c2 [jaxlib] Added a missing pytype_dep to :_triton_ext
PiperOrigin-RevId: 707907056
2024-12-19 07:28:28 -08:00
Benjamin Chetioui
66ad2082ba [Mosaic GPU] Replace the dialect's layout enum with layouts holding the proper
sub-attributes.

PiperOrigin-RevId: 707846907
2024-12-19 02:59:26 -08:00
Tomás Longeri
8188c57475 [Mosaic:TPU][NFC] Small cleanup of extui rule in apply-vector-layout
Removed some duplicate variables, changed dyn_cast to cast, and used in/out consistently instead of source/dst

PiperOrigin-RevId: 707836363
2024-12-19 02:13:18 -08:00
Jevin Jiang
3a5c4da4ef [Mosaic TPU] Support i32 vector multi reduction except cross lane.
PiperOrigin-RevId: 707708236
2024-12-18 16:49:07 -08:00
Naums Mogers
6bcec910f2 [Mosaic] Improve error verbosity of tpu.memref_slice verification
Breaks down the compound verification conditional into smaller checks with verbose error messages.

PiperOrigin-RevId: 707699990
2024-12-18 16:18:45 -08:00
Naums Mogers
de359f5ce0 [Mosaic] Verify that the target IDs are provided in remote DMAs
Adds an extra verification check. Since the source semaphore is used only for remote DMAs, we should check that device or core IDs are also provided when source semaphore is provided.

PiperOrigin-RevId: 707675228
2024-12-18 14:49:59 -08:00
Tomás Longeri
13e721a25e [Mosaic:TPU][NFC] Delete unused functions
PiperOrigin-RevId: 707660214
2024-12-18 14:00:22 -08:00
Jevin Jiang
bf692efbfb [Mosaic TPU] Support direct cast i8 vector to mask
PiperOrigin-RevId: 707617318
2024-12-18 11:35:14 -08:00
jax authors
464e5a270f Merge pull request #25569 from hawkinsp:numpyver
PiperOrigin-RevId: 707570246
2024-12-18 09:05:19 -08:00