rocm-repo-management-api-2[bot]
b505df9973
Merge pull request #299 from ROCm/ci-upstream-sync-152_1
...
CI: 03/19/25 upstream sync
2025-03-19 07:20:19 -05:00
jax authors
e9ce8fb92d
Merge pull request #27227 from jburnim:jburnim_pallas_interpret_mode4
...
PiperOrigin-RevId: 738235363
2025-03-18 20:22:27 -07:00
jax authors
01a110c4c9
Better mosaic lowering for dynamic shapes, extend an interpreter into shape_poly dimexpr and lower them alongside the graph if we are in a dynamic export regime.
...
PiperOrigin-RevId: 738171437
2025-03-18 15:51:15 -07:00
Jacob Burnim
47e8effdce
Adds option to initialize buffers to NaNs or zeros in TPU interpret mode.
2025-03-18 12:24:45 -07:00
Charles Hofer
c7b407c9f0
Merge branch 'rocm-main' into ci-upstream-sync-151_1
2025-03-18 15:27:35 +00:00
Sergei Lebedev
0ff234049b
Removed trivial docstrings from JAX tests
...
These docstrings do not make the tests any more clear and typically just duplicate the test module name.
PiperOrigin-RevId: 737611977
2025-03-17 07:49:37 -07:00
Sergei Lebedev
a7e5eaee56
[pallas:mosaic_gpu] jnp.reduce_sum
now works for >1D arrays
...
PiperOrigin-RevId: 737578598
2025-03-17 05:32:07 -07:00
GitHub Actions
e275d5cf6c
Merge remote-tracking branch 'origin/rocm-main' into ci-upstream-sync-147_1
2025-03-14 22:42:07 +00:00
Sergei Lebedev
64230d1c93
[pallas:mosaic_gpu] WG lowering now supports while_p
...
PiperOrigin-RevId: 736996154
2025-03-14 14:59:29 -07:00
Tzu-Wei Sung
21f5f2d45e
[Pallas] Increase #rows when casting to x2.
...
There is a bug in XLA on v5p.
PiperOrigin-RevId: 736987667
2025-03-14 14:32:33 -07:00
Justin Fu
dbd8d92075
[Pallas] Add legacy PRNG key support to Pallas PRNG
...
PiperOrigin-RevId: 736949584
2025-03-14 12:30:04 -07:00
jax authors
cbece0b00b
Add explicit support for float8_e4m3b11fnuz in pl.dot
...
PiperOrigin-RevId: 736798315
2025-03-14 02:51:55 -07:00
Tzu-Wei Sung
e235fb9760
[Mosaic] Allow part of x2 int casts.
...
This should at least allow int2 -> int4 for native tiling vregs. Skip many tests due to XLA compatibility.
PiperOrigin-RevId: 736710186
2025-03-13 18:57:36 -07:00
jax authors
47bf22e37d
[pallas][Mosaic][Easy] Add batch dot dim test, remove check
...
PiperOrigin-RevId: 736623531
2025-03-13 13:38:44 -07:00
GitHub Actions
a0edd3fbb2
Merge remote-tracking branch 'origin/rocm-main' into ci-upstream-sync-144_1
2025-03-12 16:57:18 +00:00
Sergei Lebedev
e33f3fc48b
[pallas:mosaic_gpu] Added support for reductions to the WG lowering
...
Note that
* we have no easy way of testing multi-reductions at the moment;
* `reduce_max` assumes WGMMA_ROW layout which is not currently supported by
the dialect lowering AFAICT.
PiperOrigin-RevId: 736138554
2025-03-12 08:18:31 -07:00
Jevin Jiang
29bfd00f9c
[Pallas TPU] Fix preferred_element_type propagation in dot_general with const
...
PiperOrigin-RevId: 735903687
2025-03-11 15:06:07 -07:00
Jevin Jiang
eff612a3b6
Fix the assumption that pages_per_seq is already a multiple of num_kv_pages_per_blk.
...
PiperOrigin-RevId: 735851301
2025-03-11 12:36:33 -07:00
Charles Hofer
fb89a4b427
Merge branch 'rocm-main' into ci-upstream-sync-142_1
2025-03-11 16:33:59 +00:00
Benjamin Chetioui
7fd32ecc04
[Pallas/Mosaic GPU] Explicitly disable ops_test
on Mosaic GPU pre-Hopper.
...
PiperOrigin-RevId: 735744473
2025-03-11 07:11:09 -07:00
jax authors
aceae84fab
[Pallas] Enable skipping of floating-point operations when interpreting Pallas TPU kernels on CPU.
...
PiperOrigin-RevId: 735527650
2025-03-10 15:14:00 -07:00
Jacob Burnim
802cb33bf8
[Pallas] Increase tolerance in PallasOutOfBoundsInterpretTest.
...
PiperOrigin-RevId: 735519526
2025-03-10 14:49:34 -07:00
Jacob Burnim
73d20cd62a
[Pallas] Small fix to TPU interpret mode (input_output_aliases + scalar args).
...
PiperOrigin-RevId: 735455671
2025-03-10 11:40:10 -07:00
Sergei Lebedev
91340ea0a7
[pallas:mosaic_gpu] Added support for math functions to the WG lowering
...
PiperOrigin-RevId: 735333893
2025-03-10 05:08:19 -07:00
Benjamin Chetioui
75d8702023
[Pallas/Mosaic GPU] Add lowerings/layout inference for all the necessary conversion ops when using Warpgroup semantics.
...
Enable some of the pre-existing Pallas `ops_test`s for testing.
PiperOrigin-RevId: 735293084
2025-03-10 02:14:39 -07:00
Jevin Jiang
0f0636afab
[Mosaic TPU][Pallas] Add pl.reciprocal
...
PiperOrigin-RevId: 734749577
2025-03-07 18:29:30 -08:00
Jevin Jiang
041f575747
Support MHA in ragged paged attention for packed type
...
PiperOrigin-RevId: 734695213
2025-03-07 14:47:04 -08:00
Christos Perivolaropoulos
eeccc67c0b
[mgpu] Debug print arrays.
...
PiperOrigin-RevId: 734576543
2025-03-07 08:58:25 -08:00
Adam Paszke
402389290c
[Mosaic TPU] Enable all conversions involving fp8 types on TPUv5+
...
PiperOrigin-RevId: 734558364
2025-03-07 07:59:31 -08:00
Jevin Jiang
ff4310f640
[Mosaic TPU] Support fp8 upcast to f32
...
PiperOrigin-RevId: 734345644
2025-03-06 17:19:15 -08:00
Jevin Jiang
4b49c03523
Open source TPU-friendly ragged paged attention kernel.
...
Key features:
* ***Support mixed prefill and decode*** to increase throughput for inference. (eg., ***5x*** speedup compared to padded Muti-Queries Paged Attention implementation for llama-3-8b.)
* ***No explicit `swapaxes`*** for `seq_len` and `num_head` in pre/post kernel. The kernel takes `num_head` in 2nd minor as it naturally was. We fold swapaxes to strided load/store in the kernel and apply transpose on the fly.
* ***No GMM (Grouped Matmul) Metadata required!*** We calculate the metadata on the fly in the kernel. This can speed up ***10%***!
* ***Increase MXU utilization 8x in GQA*** by grouping shared q heads for MXU in decode.
* ***Minimize recompilation:*** The only factors can cause recompilation are model specs, `max_num_batched_tokens` and `max_num_seqs` in the setting of mixed engine.
PiperOrigin-RevId: 734269519
2025-03-06 13:36:45 -08:00
Sergei Lebedev
2a34019388
[pallas:mosaic_gpu] Added WG lowering rule for lax.bitcast_convert_type_p
...
PiperOrigin-RevId: 734081448
2025-03-06 04:09:55 -08:00
Chris Jones
d6b97c2026
[pallas] Add support for pl.dot
with int8
inputs.
...
PiperOrigin-RevId: 734081057
2025-03-06 04:08:04 -08:00
Benjamin Chetioui
fe577b5dc4
[Pallas/Mosaic GPU] Enable ops_test
for Mosaic GPU.
...
For now, most of the tests are skipped.
PiperOrigin-RevId: 734026728
2025-03-06 00:45:05 -08:00
Jacob Burnim
016b351f00
[Pallas] Adds a simple dynamic race detector for TPU interpret mode.
...
PiperOrigin-RevId: 733885890
2025-03-05 15:15:21 -08:00
Gleb Pobudzey
43b6be0e81
[Mosaic GPU] Add lowering for log
, and a fast path using log2.
...
PiperOrigin-RevId: 733411276
2025-03-04 11:50:50 -08:00
Charles Hofer
07cd809ba8
Merge branch 'rocm-main' into ci-upstream-sync-135_1
2025-03-04 16:22:17 +00:00
Sergei Lebedev
155839bb4d
[pallas:triton] Emit a better error message for matmul with non-2D operands
...
Triton seems to support both 2D and 3D operands now, the latter case being a
batched matmul. We need more changes in the lowering to support 3D, so I will
leave it out of scope here.
Fixes #26013 .
PiperOrigin-RevId: 733293299
2025-03-04 05:46:29 -08:00
Sharad Vikram
00d9f4529d
[Pallas/Fuser] Add support for custom_call_jvp/pjit to push_block_spec
...
PiperOrigin-RevId: 733122108
2025-03-03 17:43:13 -08:00
Sharad Vikram
0b6c355083
[Pallas] Add experimental (private for now) API for manual fusion into Pallas kernels
...
PiperOrigin-RevId: 733112191
2025-03-03 17:05:51 -08:00
jax authors
d8953e5311
Remove spurious zero_to_zero conversion used exclusively for backcompat types as a way of supporting (best effort) unsupported types on hardware to make it easier to debug.
...
In general, this is a good feature, but it assumed that the packing type utilized here was exclusively for backcompat, and so always applied the adjustment.
PiperOrigin-RevId: 731954456
2025-02-27 19:21:37 -08:00
jax authors
3450e2cee0
Disable certain tests on V4 and below.
...
PiperOrigin-RevId: 731812726
2025-02-27 12:02:52 -08:00
Chris Jones
d6752e9267
[pallas:triton] Generate more efficient code for loading contiguous slices of int4
values.
...
The existing `int4` loading code is very generic. When reading contiguous data, it will read with offsets like `0, 0, 1, 1, ...`. Triton doesn't consider these to be contiguous in memory and emits much less efficient code than when reading contiguous blocks.
PiperOrigin-RevId: 731635736
2025-02-27 01:57:47 -08:00
Adam Paszke
99a12ef9ea
[Mosaic GPU] Add support for warpgroup lowering of loops with vector carries
...
PiperOrigin-RevId: 731260912
2025-02-26 04:29:36 -08:00
Jacob Burnim
4c7140fa03
[Pallas] Add option for async DMAs in the new TPU interpret mode
...
When dma_execution_mode='on_wait', we wait to execute DMAs until we are interpreting a `dma_wait` instruction. In particular, while a device is waiting on a DMA semaphore, we will (partially) execute DMAs that signal that semaphore until the wait operation can succeed.
PiperOrigin-RevId: 731103569
2025-02-25 18:19:20 -08:00
Charles Hofer
45e2060b90
Merge branch 'rocm-main' into ci-upstream-sync-127_1
2025-02-25 19:30:08 +00:00
Gleb Pobudzey
a35494e020
Allow query and keys that aren’t multiples of 128
2025-02-25 19:13:24 +00:00
Adam Paszke
3d87a01bea
[Pallas:MGPU] Adjust warpgroup lowering to the recent emit_pipeline changes
...
The Pallas-level pipelining generates a number of ops we haven't had to deal with before
like conditionals, scans, etc.
PiperOrigin-RevId: 730899808
2025-02-25 08:39:44 -08:00
George Necula
c4e0db6f8a
[better_errors] Port the Pallas debug info mechanisms to the new JAX DebugInfo.
...
Now that we carry debug informatiion in Jaxpr we can remove the Pallas-specific
tracking of the `func_src_info`, e.g., `NameAndSrcInfo`.
2025-02-25 14:43:17 +01:00
Sergei Lebedev
c13a2f95d5
[pallas:mosaic_gpu] Use emit_pipeline
for pipelining in the lowering
...
This shaves off a lot of complexity from our lowering code, while retaining
all of the functionality, except the arrive_tx optimization: `emit_pipeline`
arrives once per buffer, whereas the pipelining in the lowering used to
arrive once for all buffers.
PiperOrigin-RevId: 730824239
2025-02-25 04:14:10 -08:00