688 Commits

Author SHA1 Message Date
rocm-repo-management-api-2[bot]
b505df9973
Merge pull request #299 from ROCm/ci-upstream-sync-152_1
CI: 03/19/25 upstream sync
2025-03-19 07:20:19 -05:00
jax authors
e9ce8fb92d Merge pull request #27227 from jburnim:jburnim_pallas_interpret_mode4
PiperOrigin-RevId: 738235363
2025-03-18 20:22:27 -07:00
jax authors
01a110c4c9 Better mosaic lowering for dynamic shapes, extend an interpreter into shape_poly dimexpr and lower them alongside the graph if we are in a dynamic export regime.
PiperOrigin-RevId: 738171437
2025-03-18 15:51:15 -07:00
Jacob Burnim
47e8effdce Adds option to initialize buffers to NaNs or zeros in TPU interpret mode. 2025-03-18 12:24:45 -07:00
Charles Hofer
c7b407c9f0 Merge branch 'rocm-main' into ci-upstream-sync-151_1 2025-03-18 15:27:35 +00:00
Sergei Lebedev
0ff234049b Removed trivial docstrings from JAX tests
These docstrings do not make the tests any more clear and typically just duplicate the test module name.

PiperOrigin-RevId: 737611977
2025-03-17 07:49:37 -07:00
Sergei Lebedev
a7e5eaee56 [pallas:mosaic_gpu] jnp.reduce_sum now works for >1D arrays
PiperOrigin-RevId: 737578598
2025-03-17 05:32:07 -07:00
GitHub Actions
e275d5cf6c Merge remote-tracking branch 'origin/rocm-main' into ci-upstream-sync-147_1 2025-03-14 22:42:07 +00:00
Sergei Lebedev
64230d1c93 [pallas:mosaic_gpu] WG lowering now supports while_p
PiperOrigin-RevId: 736996154
2025-03-14 14:59:29 -07:00
Tzu-Wei Sung
21f5f2d45e [Pallas] Increase #rows when casting to x2.
There is a bug in XLA on v5p.

PiperOrigin-RevId: 736987667
2025-03-14 14:32:33 -07:00
Justin Fu
dbd8d92075 [Pallas] Add legacy PRNG key support to Pallas PRNG
PiperOrigin-RevId: 736949584
2025-03-14 12:30:04 -07:00
jax authors
cbece0b00b Add explicit support for float8_e4m3b11fnuz in pl.dot
PiperOrigin-RevId: 736798315
2025-03-14 02:51:55 -07:00
Tzu-Wei Sung
e235fb9760 [Mosaic] Allow part of x2 int casts.
This should at least allow int2 -> int4 for native tiling vregs. Skip many tests due to XLA compatibility.

PiperOrigin-RevId: 736710186
2025-03-13 18:57:36 -07:00
jax authors
47bf22e37d [pallas][Mosaic][Easy] Add batch dot dim test, remove check
PiperOrigin-RevId: 736623531
2025-03-13 13:38:44 -07:00
GitHub Actions
a0edd3fbb2 Merge remote-tracking branch 'origin/rocm-main' into ci-upstream-sync-144_1 2025-03-12 16:57:18 +00:00
Sergei Lebedev
e33f3fc48b [pallas:mosaic_gpu] Added support for reductions to the WG lowering
Note that

* we have no easy way of testing multi-reductions at the moment;
* `reduce_max` assumes WGMMA_ROW layout which is not currently supported by
  the dialect lowering AFAICT.

PiperOrigin-RevId: 736138554
2025-03-12 08:18:31 -07:00
Jevin Jiang
29bfd00f9c [Pallas TPU] Fix preferred_element_type propagation in dot_general with const
PiperOrigin-RevId: 735903687
2025-03-11 15:06:07 -07:00
Jevin Jiang
eff612a3b6 Fix the assumption that pages_per_seq is already a multiple of num_kv_pages_per_blk.
PiperOrigin-RevId: 735851301
2025-03-11 12:36:33 -07:00
Charles Hofer
fb89a4b427 Merge branch 'rocm-main' into ci-upstream-sync-142_1 2025-03-11 16:33:59 +00:00
Benjamin Chetioui
7fd32ecc04 [Pallas/Mosaic GPU] Explicitly disable ops_test on Mosaic GPU pre-Hopper.
PiperOrigin-RevId: 735744473
2025-03-11 07:11:09 -07:00
jax authors
aceae84fab [Pallas] Enable skipping of floating-point operations when interpreting Pallas TPU kernels on CPU.
PiperOrigin-RevId: 735527650
2025-03-10 15:14:00 -07:00
Jacob Burnim
802cb33bf8 [Pallas] Increase tolerance in PallasOutOfBoundsInterpretTest.
PiperOrigin-RevId: 735519526
2025-03-10 14:49:34 -07:00
Jacob Burnim
73d20cd62a [Pallas] Small fix to TPU interpret mode (input_output_aliases + scalar args).
PiperOrigin-RevId: 735455671
2025-03-10 11:40:10 -07:00
Sergei Lebedev
91340ea0a7 [pallas:mosaic_gpu] Added support for math functions to the WG lowering
PiperOrigin-RevId: 735333893
2025-03-10 05:08:19 -07:00
Benjamin Chetioui
75d8702023 [Pallas/Mosaic GPU] Add lowerings/layout inference for all the necessary conversion ops when using Warpgroup semantics.
Enable some of the pre-existing Pallas `ops_test`s for testing.

PiperOrigin-RevId: 735293084
2025-03-10 02:14:39 -07:00
Jevin Jiang
0f0636afab [Mosaic TPU][Pallas] Add pl.reciprocal
PiperOrigin-RevId: 734749577
2025-03-07 18:29:30 -08:00
Jevin Jiang
041f575747 Support MHA in ragged paged attention for packed type
PiperOrigin-RevId: 734695213
2025-03-07 14:47:04 -08:00
Christos Perivolaropoulos
eeccc67c0b [mgpu] Debug print arrays.
PiperOrigin-RevId: 734576543
2025-03-07 08:58:25 -08:00
Adam Paszke
402389290c [Mosaic TPU] Enable all conversions involving fp8 types on TPUv5+
PiperOrigin-RevId: 734558364
2025-03-07 07:59:31 -08:00
Jevin Jiang
ff4310f640 [Mosaic TPU] Support fp8 upcast to f32
PiperOrigin-RevId: 734345644
2025-03-06 17:19:15 -08:00
Jevin Jiang
4b49c03523 Open source TPU-friendly ragged paged attention kernel.
Key features:
* ***Support mixed prefill and decode*** to increase throughput for inference. (eg., ***5x*** speedup compared to padded Muti-Queries Paged Attention implementation for llama-3-8b.)
* ***No explicit `swapaxes`*** for `seq_len` and `num_head` in pre/post kernel. The kernel takes `num_head` in 2nd minor as it naturally was. We fold swapaxes to strided load/store in the kernel and apply transpose on the fly.
* ***No GMM (Grouped Matmul) Metadata required!*** We calculate the metadata on the fly in the kernel. This can speed up ***10%***!
* ***Increase MXU utilization 8x in GQA*** by grouping shared q heads for MXU in decode.
* ***Minimize recompilation:*** The only factors can cause recompilation are model specs, `max_num_batched_tokens` and `max_num_seqs` in the setting of mixed engine.

PiperOrigin-RevId: 734269519
2025-03-06 13:36:45 -08:00
Sergei Lebedev
2a34019388 [pallas:mosaic_gpu] Added WG lowering rule for lax.bitcast_convert_type_p
PiperOrigin-RevId: 734081448
2025-03-06 04:09:55 -08:00
Chris Jones
d6b97c2026 [pallas] Add support for pl.dot with int8 inputs.
PiperOrigin-RevId: 734081057
2025-03-06 04:08:04 -08:00
Benjamin Chetioui
fe577b5dc4 [Pallas/Mosaic GPU] Enable ops_test for Mosaic GPU.
For now, most of the tests are skipped.

PiperOrigin-RevId: 734026728
2025-03-06 00:45:05 -08:00
Jacob Burnim
016b351f00 [Pallas] Adds a simple dynamic race detector for TPU interpret mode.
PiperOrigin-RevId: 733885890
2025-03-05 15:15:21 -08:00
Gleb Pobudzey
43b6be0e81 [Mosaic GPU] Add lowering for log, and a fast path using log2.
PiperOrigin-RevId: 733411276
2025-03-04 11:50:50 -08:00
Charles Hofer
07cd809ba8 Merge branch 'rocm-main' into ci-upstream-sync-135_1 2025-03-04 16:22:17 +00:00
Sergei Lebedev
155839bb4d [pallas:triton] Emit a better error message for matmul with non-2D operands
Triton seems to support both 2D and 3D operands now, the latter case being a
batched matmul. We need more changes in the lowering to support 3D, so I will
leave it out of scope here.

Fixes #26013.

PiperOrigin-RevId: 733293299
2025-03-04 05:46:29 -08:00
Sharad Vikram
00d9f4529d [Pallas/Fuser] Add support for custom_call_jvp/pjit to push_block_spec
PiperOrigin-RevId: 733122108
2025-03-03 17:43:13 -08:00
Sharad Vikram
0b6c355083 [Pallas] Add experimental (private for now) API for manual fusion into Pallas kernels
PiperOrigin-RevId: 733112191
2025-03-03 17:05:51 -08:00
jax authors
d8953e5311 Remove spurious zero_to_zero conversion used exclusively for backcompat types as a way of supporting (best effort) unsupported types on hardware to make it easier to debug.
In general, this is a good feature, but it assumed that the packing type utilized here was exclusively for backcompat, and so always applied the adjustment.

PiperOrigin-RevId: 731954456
2025-02-27 19:21:37 -08:00
jax authors
3450e2cee0 Disable certain tests on V4 and below.
PiperOrigin-RevId: 731812726
2025-02-27 12:02:52 -08:00
Chris Jones
d6752e9267 [pallas:triton] Generate more efficient code for loading contiguous slices of int4 values.
The existing `int4` loading code is very generic. When reading contiguous data, it will read with offsets like `0, 0, 1, 1, ...`. Triton doesn't consider these to be contiguous in memory and emits much less efficient code than when reading contiguous blocks.

PiperOrigin-RevId: 731635736
2025-02-27 01:57:47 -08:00
Adam Paszke
99a12ef9ea [Mosaic GPU] Add support for warpgroup lowering of loops with vector carries
PiperOrigin-RevId: 731260912
2025-02-26 04:29:36 -08:00
Jacob Burnim
4c7140fa03 [Pallas] Add option for async DMAs in the new TPU interpret mode
When dma_execution_mode='on_wait', we wait to execute DMAs until we are interpreting a `dma_wait` instruction.  In particular, while a device is waiting on a DMA semaphore, we will (partially) execute DMAs that signal that semaphore until the wait operation can succeed.

PiperOrigin-RevId: 731103569
2025-02-25 18:19:20 -08:00
Charles Hofer
45e2060b90 Merge branch 'rocm-main' into ci-upstream-sync-127_1 2025-02-25 19:30:08 +00:00
Gleb Pobudzey
a35494e020 Allow query and keys that aren’t multiples of 128 2025-02-25 19:13:24 +00:00
Adam Paszke
3d87a01bea [Pallas:MGPU] Adjust warpgroup lowering to the recent emit_pipeline changes
The Pallas-level pipelining generates a number of ops we haven't had to deal with before
like conditionals, scans, etc.

PiperOrigin-RevId: 730899808
2025-02-25 08:39:44 -08:00
George Necula
c4e0db6f8a [better_errors] Port the Pallas debug info mechanisms to the new JAX DebugInfo.
Now that we carry debug informatiion in Jaxpr we can remove the Pallas-specific
tracking of the `func_src_info`, e.g., `NameAndSrcInfo`.
2025-02-25 14:43:17 +01:00
Sergei Lebedev
c13a2f95d5 [pallas:mosaic_gpu] Use emit_pipeline for pipelining in the lowering
This shaves off a lot of complexity from our lowering code, while retaining
all of the functionality, except the arrive_tx optimization: `emit_pipeline`
arrives once per buffer, whereas the pipelining in the lowering used to
arrive once for all buffers.

PiperOrigin-RevId: 730824239
2025-02-25 04:14:10 -08:00