1158 Commits

Author SHA1 Message Date
jax authors
1471702adc [Mosaic TPU] Support 1D concat: set implicit_dim to kSecondMinor to treat 1D (N,) as (1, N) and then tile it as (1, 128)
PiperOrigin-RevId: 696870258
2024-11-15 06:41:57 -08:00
jax authors
a8464ce761 [Mosaic][TPU] Omit short circuiting of relayout (we should always relayout!) and implement product mismatch case for where we relayout from replicated to offset, and the number of vregs changes.
PiperOrigin-RevId: 696557463
2024-11-14 09:53:25 -08:00
Naums Mogers
c32db46e6c [Mosaic] Add parameter names to tpu.sem_signal and add tests
This CLs adds parameter names to the optional parameters of `tpu.sem_signal` -- `device_id`, `core_id` -- to remove the ambiguity upon deserialization.
Adds LIT tests of signalling on TC with parameter names.

PiperOrigin-RevId: 695875037
2024-11-12 14:37:47 -08:00
Sergei Lebedev
d304025a41 [mosaic_gpu] The profiler now uses FFI calls for creating events and computing elapsed time
PiperOrigin-RevId: 695798787
2024-11-12 11:01:59 -08:00
jax authors
1221da8467 [Mosaic] Fix mask creation for packed sublanes
Unaligned concat used to be f32 only, but implicitly protected via unimplemented support for multi-row-shift in sub32 types. When this was added, we started invoking unaligned concat flow w/ sub32 types, but the masking code that assumed full rows (unpacked types) was no longer sufficient - we need better granularity for these cases. This only affects sublanes, as that is where we pack, we don't have partial lanes.

This CL, as a small benefit, also adds better error messages to the ops involved in lower_to_llo.cc.

PiperOrigin-RevId: 695796095
2024-11-12 10:55:19 -08:00
Dan Foreman-Mackey
21e98b5ce4 Fix overflow error in GPU batched linear algebra kernels.
As reported in https://github.com/jax-ml/jax/issues/24843, our LU decomposition on GPU hits overflow errors when the batch size approaches int32 max. This was caused by an issue in how we were constructing the batched pointers used by cuBLAS.

PiperOrigin-RevId: 695694648
2024-11-12 05:33:49 -08:00
Jevin Jiang
38d062dbee [Mosaic TPU] Support dynamic DMA and ref slice on the 2nd minor when memref is untiled
* Generalize any untiled memref to have tiling (packing, 128)
* Support dynamic index on 2nd minor.
* Support dynamic shape on 2nd minor.

PiperOrigin-RevId: 695516124
2024-11-11 16:14:27 -08:00
Benjamin Chetioui
da89c9e38c [Mosaic GPU] Add base_pointer argument to InitializeBarrierOp.
This corresponds to what's implemented in `BarrierRef`, and ultimately makes it
easier to allocate barriers at a specific address in dynamic shared memory.

PiperOrigin-RevId: 695308297
2024-11-11 06:18:26 -08:00
Dimitar (Mitko) Asenov
d833066a1f [MOSAIC:GPU] Add async_load, async_store, and supporting attributes to the MLIR Mosaic GPU Dialect.
PiperOrigin-RevId: 694643777
2024-11-08 14:34:23 -08:00
jax authors
4d1a1264f0 Merge pull request #24778 from cainmagi:fix-pr-23852
PiperOrigin-RevId: 694565904
2024-11-08 11:00:19 -08:00
Adam Paszke
ce3826d098 [Mosaic GPU] Make sure to free the cloned MLIR module when debugging
We only recently started using this in tests and it has caused ASAN
to report a bunch of leaks.

PiperOrigin-RevId: 694510867
2024-11-08 08:35:10 -08:00
Yuchen Jin
218f763255
(follow-up of PR #23852) add missing typename keyword to work with gcc
This update is a follow-up of PR #23852. In the previous PR, there was one missing place where the `typename` was not added.
2024-11-07 23:55:38 -06:00
Tomás Longeri
04a6652243 [Mosaic] Fix handling of i1 splat constants
PiperOrigin-RevId: 694248723
2024-11-07 14:28:59 -08:00
Tzu-Wei Sung
8b7bcadebe [Mosaic] Fix canonicalize_extract op name.
PiperOrigin-RevId: 694236671
2024-11-07 13:51:52 -08:00
Naums Mogers
3df204a457 [Mosaic] Verify that tpu.sem_wait semaphore rank is zero
Since we only wait on one semaphore, we should enforce this in the verifier.

PiperOrigin-RevId: 693770055
2024-11-06 10:10:15 -08:00
Peter Hawkins
ea1e879577 Include mpmath as a bazel dependency of lax_test.
This test has additional test cases that require mpmath.

PiperOrigin-RevId: 693464078
2024-11-05 13:43:06 -08:00
Sergei Lebedev
34b4787e2e [mosaic_gpu] Check the return code of gpuEventCreate and gpuEventDestroy
PiperOrigin-RevId: 693260326
2024-11-05 01:59:58 -08:00
Benjamin Chetioui
63e59c5fd7 [Mosaic GPU] Ensure that the dialect module can be loaded successfully.
This requires that the file providing the bindings has the same name as the
dialect it defines, since dialect search looks for a module path of the form
`<prefix>.<dialect namespace>`.

PiperOrigin-RevId: 693241875
2024-11-05 00:47:21 -08:00
Praveen Batra
8296f6e0ba [Mosaic] Add extension files for infer/apply vector layout.
PiperOrigin-RevId: 691868278
2024-10-31 11:08:37 -07:00
Praveen Batra
7d9f565647 [Mosaic] Fix some imports.
PiperOrigin-RevId: 691830491
2024-10-31 09:25:34 -07:00
Benjamin Chetioui
c708a04c6e [Mosaic GPU] Add Python bindings for the Mosaic GPU MLIR dialect.
Also start moving the existing C++ tests to Python.

PiperOrigin-RevId: 691729887
2024-10-31 02:47:30 -07:00
Dimitar (Mitko) Asenov
7d504cd95a [MOSAIC:GPU] Extend the mosaic mlir dialect with fragmented layouts.
PiperOrigin-RevId: 691712579
2024-10-31 01:29:22 -07:00
jax authors
5aeffde707 [Mosaic] Extend tpu matmulop to have dimension dims. Add support for batching and simple transposition.
PiperOrigin-RevId: 691706218
2024-10-31 00:59:13 -07:00
Naums Mogers
242e6634ff [Mosaic] Add the core type enum
The new attribute allows differentiating compilation by target core.

PiperOrigin-RevId: 691531726
2024-10-30 13:23:34 -07:00
jax authors
99ea4c1a4a [Fix] Put * packing into reshape no-op condition (Bug in my original CL)
PiperOrigin-RevId: 691476663
2024-10-30 10:47:23 -07:00
jax authors
5ad066eeaa [TPU][Mosaic] Replace tpu lowering (at canonicalization) for repeat with concat (which handles far more cases)
PiperOrigin-RevId: 691192121
2024-10-29 15:57:44 -07:00
Peter Hawkins
bee2bc443a Remove some dead code from gpu_prng.py 2024-10-29 09:29:56 -04:00
jax authors
de68018473 [NFC][Mosaic TPU] Clarify layout comment block
PiperOrigin-RevId: 690977672
2024-10-29 05:20:08 -07:00
jax authors
12d26053e3 [TPU][Mosaic] Add support for a no-op reshape where sublane_tiling = 1 and the res_tiled and src_tiled shapes both fill a full vreg (1024)
PiperOrigin-RevId: 690796348
2024-10-28 16:57:51 -07:00
Adam Paszke
36c56fa19b [Pallas:MGPU] Fix flaky debug_print tests
Turns out that waiting for the kernel to finish it not enough, since the
prints also need to be processed by the CUDA runtime. Using a test-only
function that synchronizes all the devices seems to suffice.

PiperOrigin-RevId: 690624999
2024-10-28 08:42:02 -07:00
Sergei Lebedev
04bdd07f66 [mosaic_gpu] mgpu.FragmentedArray now supports //
This is needed to compute grid index from the iteration step counter in `emit_pipeline`.

PiperOrigin-RevId: 690608581
2024-10-28 07:52:22 -07:00
Jevin Jiang
2a671e25a7 [Mosaic TPU] Remove extra check
PiperOrigin-RevId: 689852989
2024-10-25 11:22:17 -07:00
Tzu-Wei Sung
4972f84c94 [Mosaic] Use max sublane offset per shuffled load to decide whether to avoid bank conflict.
PiperOrigin-RevId: 689809024
2024-10-25 09:09:14 -07:00
jax authors
63c1699ed0 Fix a use-after-free bug in third_party/py/jax/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc
The backing array of the initializer_list is destroyed at the end of the full expression.

PiperOrigin-RevId: 689783482
2024-10-25 07:40:12 -07:00
Kanglan Tang
af28595909 Add a jax_wheel Bazel rule to build jax pip packages
PiperOrigin-RevId: 689514531
2024-10-24 14:20:46 -07:00
Adam Paszke
6634f5a348 [Mosaic GPU] Use absl::StrCat instead std::string::operator+
Repeated string addition is apparently a bit of an anti-pattern. Not that it matters
much in this place, but why not do it properly.

PiperOrigin-RevId: 689416587
2024-10-24 09:49:51 -07:00
Andrey Portnoy
14e0f0e7fa [Mosaic GPU] Query SM and PTX ISA dynamically using driver and LLVM
Originally proposed in #24021. Slightly rewritter to make testing with internal LLVM toolchains better.

Use CUDA driver API to query major and minor compute capabilities, thus arriving at a "base" SM string (e.g. `sm_90`).
Then use LLVM to see if we can "upgrade" the base SM string to one that enables architecture-specific capabilities (e.g. `sm_90a`).
Then use LLVM to map the SM string to a PTX ISA version that supports the SM.

Co-authored-by: Andrey Portnoy <aportnoy@nvidia.com>
PiperOrigin-RevId: 689286774
2024-10-24 01:46:29 -07:00
Jevin Jiang
b8bacda2d9 [Mosaic TPU] Use native vector tiling to load and store with untiled memref.
PiperOrigin-RevId: 689142734
2024-10-23 16:22:16 -07:00
jax authors
48bddc6f6c Adds arith.select to the op patters in order to canonicalize non 32 bit selects.
PiperOrigin-RevId: 687635492
2024-10-19 09:09:06 -07:00
Benjamin Chetioui
ade480ff05 Add a dialect for Mosaic GPU.
PiperOrigin-RevId: 687325692
2024-10-18 09:11:31 -07:00
Dan Foreman-Mackey
8361eb58e1 Activate the FFI implementation of SVD on GPU.
Alongside activating this new implementation, this change adds a new `algorithm` parameter to `jax.lax.svd`. Previously the choice of algorithm was made based on heuristics in the lowering rule, but it probably also makes sense to expose an option for users to specify the algorithm explicitly because our heuristics are not very carefully optimized.

This change updates the implementation of SVD in `lax` to use the FFI version which was added to jaxlib in https://github.com/jax-ml/jax/pull/23794. This comes with a few benefits:

1. When running on a CUDA platform, the 64-bit API will be used for the algorithm based on QR decomposition. (Note that it looks like the 64-bit API isn't available on ROCm.) This addresses part of the feature request in https://github.com/jax-ml/jax/issues/23413, although there's still work to do to port the rest of the GPU calls to the 64-bit API.

2. This implementation supports shape polymorphism in all dimensions with some caveats. By default, we do use some heuristics to based on the matrix sizes to select the algorithm that is used, and the three different algorithms (QR, Jacobi, and batched Jacobi) have sufficiently different behavior (QR returns V^H, whereas Jacobi returns V; batched Jacobi doesn't support `full_matrices=False`) that I couldn't work out a simple way to push this logic into the kernel. If the symbolic constraints are not sufficient to concretely determine the heuristics, we always use the QR algorithm. But, I've also exposed the algorithm selection in the user API, so it's possible to bypass the heuristics and get consistent behavior alongside shape polymorphism if needed.

Besides these core changes, I removed the forward compatibility checks from the CPU lowering, since we're well outside of the forward compatibility window now.

PiperOrigin-RevId: 687106965
2024-10-17 17:57:06 -07:00
jax authors
6c2649fdf2 Rewrite mosaic concat to support operand shapes that do not align with native shapes, Expand tests to cover multi operand, batch dim concat, etc.
PiperOrigin-RevId: 687003778
2024-10-17 12:24:51 -07:00
Ionel Gog
ec279f9c54 Add config option to log or fatal when jax.Arrays are GCed.
Introduces `jax.config.array_garbage_collection_guard`, which is a tristate config for setting up a `jax.Array` garbage collection guard. The possible configs are:
* allow: `jax.Array`s are allowed to be garbage collected. This is the default value.
* log: whenever a `jax.Array` is GCed a log entry is generated with the array's traceback.
* fatal: fatal crash when a `jax.Array` is GCed. This is meant to be used for mature code bases that do tight memory management, and are reference cycle free.

PiperOrigin-RevId: 687003464
2024-10-17 12:23:16 -07:00
jax authors
9027fb38fe Fix segfault
PiperOrigin-RevId: 686821923
2024-10-17 01:52:44 -07:00
Jevin Jiang
a47b755619 [Mosaic TPU] Support native int4 @ int4
PiperOrigin-RevId: 686179715
2024-10-15 11:35:23 -07:00
Yash Katariya
824ccd7183 [Shardy] Inline meshes when using shardy and get rid of global meshes from the MLIR body.
Also do a couple of cleanups.

PiperOrigin-RevId: 685746298
2024-10-14 10:08:04 -07:00
Bart Chrzaszcz
75e22f2ccd #sdy Run inlined mesh lifter pass at the end of JAX lowering.
PiperOrigin-RevId: 685728692
2024-10-14 09:13:12 -07:00
jax authors
57ef7a4a59 Merge pull request #24274 from ROCm:ci_linalg_fix
PiperOrigin-RevId: 685717437
2024-10-14 08:33:33 -07:00
Paweł Paruzel
23fdb91252 Port Schur Decomposition to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 685689593
2024-10-14 06:46:42 -07:00
Paweł Paruzel
ec68d420fe Port Tridiagonal Reduction to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 685679646
2024-10-14 06:02:59 -07:00