1281 Commits

Author SHA1 Message Date
Jevin Jiang
2faf540203 [Mosaic TPU] Add relayout-insertion pass and support bitwidth change for i1 vector relayout
We can use the relayout-insertion pass to insert the ops needed for a relayout, together with their layouts, before unrolling in the apply-vector-layout pass.

PiperOrigin-RevId: 708143852
2024-12-19 19:56:40 -08:00
Tomás Longeri
8b02884c3c [Mosaic:TPU] Fix trunc infer rule after cl/708011538
For (1, 128) tiling 32-bit input, it assigns (1, 128) tiling at output, which can be invalid (e.g. it should be (1, 256) for bf16)

PiperOrigin-RevId: 708112341
2024-12-19 18:14:12 -08:00
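
The tiling arithmetic behind 8b02884c3c can be sketched roughly as follows. This is only an illustration, not the actual infer rule: the helper `single_row_tiling` is hypothetical, and it assumes that a single-row tile holds 128 lanes of 32-bit words, so narrower types pack proportionally more elements per row. The commit itself confirms only the bf16 case.

```python
def single_row_tiling(bitwidth: int, lanes: int = 128) -> tuple[int, int]:
    """Hypothetical helper: single-row tiling for a given element bitwidth.

    Assumes one row of `lanes` 32-bit words, with narrower elements packed
    32 // bitwidth per word. This scaling is an assumption for illustration.
    """
    if bitwidth <= 0 or 32 % bitwidth != 0:
        raise ValueError(f"unsupported bitwidth: {bitwidth}")
    packing = 32 // bitwidth  # elements packed into one 32-bit word
    return (1, lanes * packing)


tiling_f32 = single_row_tiling(32)   # (1, 128) for 32-bit input
tiling_bf16 = single_row_tiling(16)  # (1, 256) for bf16, as in the commit message
```

Under this assumption, reusing the input's (1, 128) tiling for a truncated bf16 output would describe only half a row, which matches the invalid case the commit describes.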
Tzu-Wei Sung
60ebde89e6 [Mosaic] Extend macros to handle parentheses.
PiperOrigin-RevId: 708045694
2024-12-19 15:00:12 -08:00
Tzu-Wei Sung
77f3c114d0 [Mosaic] Remove TODOs that are already addressed or obsolete.
PiperOrigin-RevId: 708045439
2024-12-19 14:58:15 -08:00
Nitin Srinivasan
6b096b0cb0 Use common set of build options when building jaxlib+plugin artifacts together
This commit modifies the behavior of the build CLI when building jaxlib and GPU plugin artifacts together (for instance `python build --wheels=jaxlib,jax-cuda-plugin`).

Before, CUDA/ROCm build options were only passed when building the CUDA/ROCm artifacts. However, this leads to inefficient use of the build cache, because Bazel appears to rebuild some targets that have already been built in the previous run. The likely cause is that the GPU plugin artifacts have a different set of build options than `jaxlib`, which for some reason makes Bazel invalidate or ignore certain cache hits. Therefore, this commit keeps the build options the same when the `jaxlib` and GPU artifacts are built together, so that we can better utilize the build cache.

As an example, this means that if `python build --wheels=jaxlib,jax-cuda-plugin` is run, the following build options will apply to both `jaxlib` and `jax-cuda-plugin` builds:
```
 /usr/local/bin/bazel run --repo_env=HERMETIC_PYTHON_VERSION=3.10 \
--verbose_failures=true --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--repo_env=CC="/usr/lib/llvm-16/bin/clang" \
--repo_env=BAZEL_COMPILER="/usr/lib/llvm-16/bin/clang" \
--config=clang --config=mkl_open_source_only --config=avx_posix \
--config=cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-16/bin/clang" \
--config=build_cuda_with_nvcc
```

Note that this commit shouldn't affect the content of the wheel itself. It is only meant to give a performance boost when building `jaxlib`+plugin artifacts together.

Also, this removes the code in `build_wheel.py` that was used to build the (now deprecated) monolithic `jaxlib`.

PiperOrigin-RevId: 708035062
2024-12-19 14:29:24 -08:00
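
The idea described in 6b096b0cb0 can be sketched as below. This is a minimal illustration and not the code of the build CLI: the function `shared_build_options` and its arguments are hypothetical, and the option lists are shortened from the example command in the commit message. The point is that the GPU options are merged once and then applied to every wheel, so Bazel sees one configuration.

```python
def shared_build_options(wheels, base_options, gpu_options):
    """Hypothetical helper: one option list shared by all requested wheels.

    `gpu_options` maps a plugin wheel name to the extra options it needs.
    If a plugin is requested, its options are applied to every wheel
    (including jaxlib) so that all builds use the same configuration.
    """
    options = list(base_options)
    for wheel in wheels:
        for opt in gpu_options.get(wheel, []):
            if opt not in options:  # keep order, avoid duplicates
                options.append(opt)
    # Every wheel gets its own copy of the same merged list.
    return {wheel: list(options) for wheel in wheels}


base = ["--config=clang", "--config=mkl_open_source_only", "--config=avx_posix"]
gpu = {"jax-cuda-plugin": ["--config=cuda", "--config=build_cuda_with_nvcc"]}

per_wheel = shared_build_options(["jaxlib", "jax-cuda-plugin"], base, gpu)
```

With identical options for both wheels, targets built for `jaxlib` can be reused from the cache when the plugin is built, which is the effect the commit is after.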
Tomás Longeri
307c8d3af8 [Mosaic:TPU] For trunc, expand supported tilings, offsets and bitwidths
infer-vector-layout won't use the full generality anytime soon, but we could reuse this logic for relayouts

PiperOrigin-RevId: 708011538
2024-12-19 13:31:59 -08:00
Peter Hawkins
9c3365fb95 Migrate shardy dialect extension to nanobind.
PiperOrigin-RevId: 707991933
2024-12-19 12:25:14 -08:00
Benjamin Chetioui
3915f4a147 [Mosaic GPU] Commit to using Vectors everywhere (and no Tensors).
PiperOrigin-RevId: 707912637
2024-12-19 07:51:58 -08:00
Sergei Lebedev
c4fae4a7c2 [jaxlib] Added a missing pytype_dep to :_triton_ext
PiperOrigin-RevId: 707907056
2024-12-19 07:28:28 -08:00
Benjamin Chetioui
66ad2082ba [Mosaic GPU] Replace the dialect's layout enum with layouts holding the proper
sub-attributes.

PiperOrigin-RevId: 707846907
2024-12-19 02:59:26 -08:00
Tomás Longeri
8188c57475 [Mosaic:TPU][NFC] Small cleanup of extui rule in apply-vector-layout
Removed some duplicate variables, changed dyn_cast to cast, and used in/out consistently instead of source/dst

PiperOrigin-RevId: 707836363
2024-12-19 02:13:18 -08:00
Jevin Jiang
3a5c4da4ef [Mosaic TPU] Support i32 vector multi reduction except cross lane.
PiperOrigin-RevId: 707708236
2024-12-18 16:49:07 -08:00
Naums Mogers
6bcec910f2 [Mosaic] Improve error verbosity of tpu.memref_slice verification
Breaks down the compound verification conditional into smaller checks with verbose error messages.

PiperOrigin-RevId: 707699990
2024-12-18 16:18:45 -08:00
Naums Mogers
de359f5ce0 [Mosaic] Verify that the target IDs are provided in remote DMAs
Adds an extra verification check. Since the source semaphore is used only for remote DMAs, we should check that device or core IDs are also provided when source semaphore is provided.

PiperOrigin-RevId: 707675228
2024-12-18 14:49:59 -08:00
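
The extra check added in de359f5ce0 can be sketched in Python as follows. The real verifier is part of the Mosaic dialect and is not shown in this log. The `DmaOp` class, its field names, and the error text are all hypothetical stand-ins. Only the rule itself comes from the commit message: a source semaphore implies a remote DMA, so a device or core ID must also be present.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DmaOp:
    """Hypothetical stand-in for a DMA op's optional operands."""
    source_semaphore: Optional[str] = None
    device_id: Optional[int] = None
    core_id: Optional[int] = None


def verify_dma(op: DmaOp) -> Optional[str]:
    """Return an error message if the op is inconsistent, else None."""
    has_target = op.device_id is not None or op.core_id is not None
    if op.source_semaphore is not None and not has_target:
        # The source semaphore is only used for remote DMAs,
        # so a target must be named when it is given.
        return "source semaphore provided without a device or core ID"
    return None


local_ok = verify_dma(DmaOp())
remote_ok = verify_dma(DmaOp(source_semaphore="sem0", device_id=1))
remote_bad = verify_dma(DmaOp(source_semaphore="sem0"))
```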
Tomás Longeri
13e721a25e [Mosaic:TPU][NFC] Delete unused functions
PiperOrigin-RevId: 707660214
2024-12-18 14:00:22 -08:00
Jevin Jiang
bf692efbfb [Mosaic TPU] Support direct cast i8 vector to mask
PiperOrigin-RevId: 707617318
2024-12-18 11:35:14 -08:00
jax authors
464e5a270f Merge pull request #25569 from hawkinsp:numpyver
PiperOrigin-RevId: 707570246
2024-12-18 09:05:19 -08:00
Peter Hawkins
3f24dfd234 Migrate mhlo dialect extension to nanobind.
PiperOrigin-RevId: 707562235
2024-12-18 08:36:12 -08:00
Adam Paszke
6edfe9eae5 [Mosaic TPU] Add support for bf16 second minor reductions in TPUv6
PiperOrigin-RevId: 707557416
2024-12-18 08:17:43 -08:00
Peter Hawkins
3d54d03529 Migrate StableHLO Python extension to nanobind.
PiperOrigin-RevId: 707543869
2024-12-18 07:28:52 -08:00
Peter Hawkins
ee45718457 Increase the minimum NumPy version to v1.25.
Per SPEC 0, we drop NumPy v1.24 support on Dec 18, 2024.
2024-12-18 08:18:57 -05:00
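
The date in ee45718457 follows from SPEC 0's two-year support window for core packages. The sketch below works through that arithmetic. The helper `spec0_drop_date` is hypothetical, and the NumPy v1.24 release date of 2022-12-18 is an assumption taken from memory of the release history, not from this log.

```python
from datetime import date


def spec0_drop_date(release: date, support_years: int = 2) -> date:
    """Hypothetical helper: date on which support for a release may be dropped."""
    try:
        return release.replace(year=release.year + support_years)
    except ValueError:
        # A Feb 29 release with a non-leap target year: fall back to Feb 28.
        return release.replace(year=release.year + support_years, day=28)


numpy_1_24_release = date(2022, 12, 18)  # assumed release date of NumPy v1.24
drop_date = spec0_drop_date(numpy_1_24_release)
```

Under that assumption, `drop_date` is 2024-12-18, the day named in the commit message.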
Tomás Longeri
dc0b77470e [Mosaic:TPU] Allow null parts for tpu.pack_subelements, meaning "don't care"
PiperOrigin-RevId: 707439259
2024-12-18 00:56:41 -08:00
Tomás Longeri
f9737b957e [Mosaic:TPU] Fix bug after cl/707025084
`tile_masks` was updated to use the implicit shape, but we skipped the reshape for `tiles`.

Seems like there was even a bug before cl/707025084: `tile_masks` was never reshaped, so if the shape was 1D and a store mask was specified, there would be a mismatch in dimensions.

PiperOrigin-RevId: 707368670
2024-12-17 20:31:29 -08:00
Adam Paszke
4911a396b2 [Mosaic TPU] Add support for the interleaved pack format to tpu.unpack_subelements
PiperOrigin-RevId: 707142562
2024-12-17 09:58:07 -08:00
Benjamin Chetioui
36b12d58f4 [Mosaic GPU] Add end-to-end lowering example for a pointwise kernel using the dialect and layout inference.
Also implement a lowering rule for `arith.AddFOp`.

PiperOrigin-RevId: 707131747
2024-12-17 09:28:05 -08:00
Tomás Longeri
c6e6b11e1e [Mosaic:TPU][NFC] use_implicit_shape instead of Reshape + don't create unused constant
PiperOrigin-RevId: 707025084
2024-12-17 03:18:04 -08:00
Tomás Longeri
7e6c52dc21 [Mosaic:TPU][NFC] Clean up local variable
PiperOrigin-RevId: 707013166
2024-12-17 02:26:07 -08:00
Peter Hawkins
11e0fdf3e7 Add Python.h include to fix Windows build.
PiperOrigin-RevId: 706700133
2024-12-16 07:21:01 -08:00
Tomás Longeri
5493fabb53 [Mosaic:TPU] Replicating retilings with increasing tile size for (a) replicated 2nd minor or (b) 32-bit single-row
This is a generalization of the (1, 128) -> (8, 128) 32-bit replicated retiling

PiperOrigin-RevId: 706266247
2024-12-14 12:23:11 -08:00
Peter Hawkins
64eae324ee Migrate JAX MLIR Python dialect extensions to nanobind.
Now that https://github.com/llvm/llvm-project/pull/117922 has landed upstream, we can work towards removing our last uses of pybind11.

PiperOrigin-RevId: 705872751
2024-12-13 07:08:28 -08:00
Sergei Lebedev
a14e6968bf [mosaic] Migrated the serialization pass from codegen to pass_boilerplate.h
This prepares for generalizing the serialization pass to handle both
Mosaic TPU and GPU.

PiperOrigin-RevId: 705628923
2024-12-12 14:19:36 -08:00
Jevin Jiang
3ff5706051 [Mosaic TPU][NFC] Create local namespace to prevent function name duplication error under global namespace mlir::tpu
PiperOrigin-RevId: 705538965
2024-12-12 09:53:39 -08:00
Tzu-Wei Sung
21f6b401dd [Mosaic] Pad trailing transposes chunks with zeros.
PiperOrigin-RevId: 705310340
2024-12-11 18:20:05 -08:00
jax authors
5fe8bcc734 Merge pull request #25407 from ROCm:remove-cuda-import-in-plugin-upstream
PiperOrigin-RevId: 705168796
2024-12-11 11:07:19 -08:00
Charles Hofer
8d42fa0b0b Remove cuda include from gpu plugin extension and BUILD
2024-12-11 11:55:51 -06:00
jax authors
0d7eaeb5d8 Merge pull request #24805 from andportnoy:aportnoy/mosaic-gpu-cupti-profiler
PiperOrigin-RevId: 705071782
2024-12-11 05:29:10 -08:00
Paweł Paruzel
1256153200 Activate Triangular Solve to XLA's FFI
PiperOrigin-RevId: 705029286
2024-12-11 02:22:37 -08:00
Dimitar (Mitko) Asenov
66f45d039f [Mosaic GPU] Add WGMMA to the Mosaic GPU MLIR Dialect.
The op API is still in flux so I'm leaving some of the verification code untested.

PiperOrigin-RevId: 705020066
2024-12-11 01:47:29 -08:00
Ayaka
e88b578356 [Pallas TPU] Add WeirdOp to TPU dialect and add lowering for lax.is_finite
PiperOrigin-RevId: 704888940
2024-12-10 16:38:04 -08:00
Dan Foreman-Mackey
593143e17e Deduplicate some GPU plugin definition code.
The `jaxlib/cuda_plugin_extension.cc` and `jaxlib/rocm_plugin_extension.cc` files were nearly identical so this change consolidates the shared implementation into a single target.

PiperOrigin-RevId: 704785926
2024-12-10 11:32:06 -08:00
Dan Foreman-Mackey
32df37e6e4 Port symmetric tridiagonal reduction GPU kernel to FFI.
PiperOrigin-RevId: 704382200
2024-12-09 12:41:23 -08:00
Andrey Portnoy
cc22334c21 [Mosaic GPU] Add CUPTI profiler alongside events-based implementation
2024-12-09 14:31:20 -05:00
Tomás Longeri
b76d264fe7 [Mosaic:TPU][NFC] In ext and trunc rules, avoid vreg array reshape by always using implicit shapes
PiperOrigin-RevId: 704297805
2024-12-09 08:35:14 -08:00
jax authors
cc258f5f61 Merge pull request #25320 from ROCm:gh-9948-fix-kernel-build-upstream
PiperOrigin-RevId: 704279150
2024-12-09 07:29:31 -08:00
Paweł Paruzel
d474feda9e Activate Tridiagonal Reduction to XLA's FFI
Additionally, created a missing backward compatibility test for the old LAPACK kernels of Tridiagonal Reduction.

PiperOrigin-RevId: 704234350
2024-12-09 04:36:59 -08:00
Adam Paszke
adb2bf629c [Mosaic TPU] Allow downgrading the IR during serialization for forward compat
This is to uphold the monthly stability promise made by jax.export.

PiperOrigin-RevId: 704233290
2024-12-09 04:32:41 -08:00
Charles Hofer
0c6b967e86 Don't look for CUDA files when building the ROCm wheel
2024-12-06 17:23:15 +00:00
Paweł Paruzel
9081e85d68 Activate Schur Decomposition to XLA's FFI
PiperOrigin-RevId: 703484916
2024-12-06 06:49:53 -08:00
Tomás Longeri
651ab18874 [Mosaic:TPU] Fix elementwise inference with i1s
PiperOrigin-RevId: 703263310
2024-12-05 15:00:57 -08:00
Tomás Longeri
23d5c10ff0 [Mosaic:TPU] Fix fully replicated relayout
The previous behavior was incorrect because batch dims are not replicated.

PiperOrigin-RevId: 703189919
2024-12-05 11:38:21 -08:00