684 Commits

Author SHA1 Message Date
Justin Fu
1f6152d11e [Pallas] Use Pallas cost estimator for flash attention.
PiperOrigin-RevId: 698573265
2024-11-20 17:12:37 -08:00
Jevin Jiang
9d2f62f811 [Pallas TPU] Support masked store
PiperOrigin-RevId: 698514079
2024-11-20 14:03:56 -08:00
Sergei Lebedev
9584ee3bb9 [pallas:mosaic_gpu] Avoid using multiple indexers in the parallel grid test
Turns out we can mix parallel grid with `plgpu.emit_pipeline` without doing
indexing at all!

PiperOrigin-RevId: 698442820
2024-11-20 10:42:02 -08:00
Christos Perivolaropoulos
8d84f28373 [pallas mgpu] Lowering for while loops as long as they are secretly for loops.
PiperOrigin-RevId: 698427307
2024-11-20 10:00:14 -08:00
Sergei Lebedev
1df4b5f798 [pallas] Do not skip vmap tests on GPU when x64 is enabled
PiperOrigin-RevId: 698351984
2024-11-20 05:08:23 -08:00
Peter Buchlovsky
14da7ebb76 [pallas:mosaic_gpu] Add Pallas Mosaic GPU lowering for jax.lax.bitcast_convert_type.
Only handles the case where operand type and target type have the same bitwidth.

PiperOrigin-RevId: 698332564
2024-11-20 03:41:19 -08:00
Sergei Lebedev
1bf70fbbc4 [pallas:mosaic_gpu] copy_gmem_to_smem no longer requires barrier to be a keyword argument
... because there really isn't any reason to require that.

PiperOrigin-RevId: 698116984
2024-11-19 13:02:35 -08:00
Justin Fu
c44f11d15e Add alternate implementation of threefry as a pallas kernel.
Current restrictions:
1) Dynamic grid sizes are not supported yet. This could in theory allow us to not recompile the kernel for different shapes.
2) fold_in and split still use the original rules. But there isn't a huge benefit to using the kernel right now since the input is so small and we can't avoid re-compilation due to (1).
3) Currently doesn't support high bits on the counter, meaning we can generate at max 4B numbers in one call. This is a fringe use-case since we only support 32-bit, and generating 4B 32-bit numbers would consume 16GB of HBM (an entire TPU v5p worth of HBM).

PiperOrigin-RevId: 698086352
2024-11-19 11:26:30 -08:00
jax authors
d397dd9684 Implement lax.pad in Pallas.
PiperOrigin-RevId: 697897093
2024-11-18 23:59:20 -08:00
Jevin Jiang
0fe77bc9f0 [Mosaic TPU] Support relayout for mask vector
We cast i1 vector (mask) to i32 vector before relayout and then cast back to i1 vector (mask) after relayout is finished.

PiperOrigin-RevId: 697823543
2024-11-18 18:07:15 -08:00
Tzu-Wei Sung
a60ef6e9bb [Pallas] Increase test coverage of pl.dot.
PiperOrigin-RevId: 697752355
2024-11-18 14:09:25 -08:00
Sergei Lebedev
aefe6215ca [pallas:mosaic_gpu] Ported two pipelining optimizations to emit_pipeline
* Skip SMEM->GMEM copy if the destination buffer is being revisited
* Skip SMEM->GMEM copy if the corresponding index map does not use grid indices

PiperOrigin-RevId: 696448043
2024-11-14 02:37:42 -08:00
Peter Buchlovsky
558ebb9fb1 Add Pallas Triton lowering for jax.lax.bitcast_convert_type.
Only handles the case where operand type and target type have the same bitwidth.

PiperOrigin-RevId: 696184251
2024-11-13 10:25:53 -08:00
Jevin Jiang
38d062dbee [Mosaic TPU] Support dynamic DMA and ref slice on the 2nd minor when memref is untiled
* Generalize any untiled memref to have tiling (packing, 128)
* Support dynamic index on 2nd minor.
* Support dynamic shape on 2nd minor.

PiperOrigin-RevId: 695516124
2024-11-11 16:14:27 -08:00
Justin Fu
0e611e5cac [Pallas] Add a cost estimator for Pallas/JAX functions.
Helps resolve the following issue, where invoking HLO's cost analysis triggers compilation which can OOM for larger inputs: https://github.com/jax-ml/jax/issues/24539. This cost estimator uses only abstract evaluation which should work for all input sizes.

PiperOrigin-RevId: 695415760
2024-11-11 11:13:58 -08:00
jax authors
a889a95aa1 Merge pull request #24839 from andportnoy:aportnoy/mosaic-gpu-hopper-tests
PiperOrigin-RevId: 695388748
2024-11-11 10:12:29 -08:00
Andrey Portnoy
24af8a676b [Mosaic GPU] Only run tests requiring sm90a on Hopper 2024-11-11 12:02:48 -05:00
jax authors
a041ea152e Skip test_jnp_einsum_grad_y_pallas on gpu due to ooms
PiperOrigin-RevId: 695143627
2024-11-10 16:38:06 -08:00
jax authors
0cc1747873 Add tests for jnp.einsum in Pallas
PiperOrigin-RevId: 694622626
2024-11-08 13:35:38 -08:00
Peter Hawkins
8f169e7fb5 Disable the paged_attention test on TPU v5p.
This test is failing in CI.

PiperOrigin-RevId: 694574616
2024-11-08 11:20:31 -08:00
jax authors
927d7fc205 Skip flaky test on tpuv4
PiperOrigin-RevId: 694372268
2024-11-07 23:00:33 -08:00
Peter Hawkins
88a62a45d3 Reverts 1a544b6f363fbb03edc40e03d759cd42a6b64733
PiperOrigin-RevId: 694223298
2024-11-07 13:09:54 -08:00
Ayaka
1a544b6f36 [Pallas] Fix lowering tests for reduction ops
Remove unnecessary skip statements. Also added tests for bf16 types.

PiperOrigin-RevId: 694130207
2024-11-07 08:37:24 -08:00
Adam Paszke
f8dba3c8a4 [Pallas:MGPU] Add support for multiple heads in attention
PiperOrigin-RevId: 694104006
2024-11-07 07:03:35 -08:00
jax authors
37af1002c7 Merge pull request #24602 from rdyro:rdyro-decode-attention-mask
PiperOrigin-RevId: 693835080
2024-11-06 13:05:29 -08:00
Tzu-Wei Sung
b6f5c95a5a [Pallas:TPU] Fix some stale/wrong skip conditions.
Surprised that we didn't test f32 dot_general on TPU (?) Even tpu_ops_test doesn't exercise it.

PiperOrigin-RevId: 693777426
2024-11-06 10:29:36 -08:00
Robert Dyro
d62510bfae Adding start index and kv_seq_len to decode kernel 2024-11-05 15:52:21 -08:00
Dougal Maclaurin
478b750c29 Reverts f281c6f46475270a57a02416469226315377592c
PiperOrigin-RevId: 693339094
2024-11-05 07:17:14 -08:00
Sergei Lebedev
d2bbd56405 [pallas:mosaic_gpu] lax.fori_loop lowering now promotes the carry to mgpu.FragmentedArrays
PiperOrigin-RevId: 692976037
2024-11-04 08:29:00 -08:00
Dougal Maclaurin
f281c6f464 Reverts ec39b592f7c096b0b8183723feaab2ed0d001041
PiperOrigin-RevId: 692949053
2024-11-04 06:54:06 -08:00
Sergei Lebedev
c52b3227d1 [pallas:mosaic_gpu] Added a 2D test for emit_pipeline
PiperOrigin-RevId: 692945663
2024-11-04 06:38:12 -08:00
Dougal Maclaurin
ec39b592f7 Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat)
PiperOrigin-RevId: 692557993
2024-11-02 17:03:50 -07:00
George Necula
292a00b35a [export] Cleanup in the export module.
With jax.experimental.export gone we can now do some cleanup in the export module.

In particular we remove the `export.args_spec` API, and the `lowering_platforms` arg for `export.export`. These were deprecated in June 2024.

PiperOrigin-RevId: 692398132
2024-11-01 22:56:44 -07:00
Ayaka
f60b97cea1 [Pallas TPU] Add lowering for lax.nextafter
Also improved the corresponding test cases to ensure better coverage and accuracy.

This PR is similar to https://github.com/jax-ml/jax/pull/22283, which adds lowering for `lax.sign`.

PiperOrigin-RevId: 691988164
2024-10-31 17:34:38 -07:00
Tzu-Wei Sung
7af7a60dcc [Pallas:TPU] Use arith.divui for uint32 div.
PiperOrigin-RevId: 691939453
2024-10-31 14:37:47 -07:00
Sergei Lebedev
85662f6dd8 [pallas:mosaic_gpu] plgpu.copy_smem_to_gmem no longer transparently commits SMEM
Users are expected to call `pltpu.commit_smem` manually instead.

PiperOrigin-RevId: 691724662
2024-10-31 02:21:10 -07:00
Bart Chrzaszcz
44158ab0e4 #sdy add shardy CPU config for all JAX tests, disabling any known failing test cases.
Only test cases breaking on CPU are related to:
- pure callbacks
- export
- shard alike

Note that `layout_test` is broken on TPU, leaving a comment saying to enable it.

Also fixed `shard_map_test` test that was broken when running Shardy on one TPU, and `aot_test` which was breaking due to calling a different C++ StableHLO compilation function.

PiperOrigin-RevId: 691496997
2024-10-30 11:40:20 -07:00
jax authors
3904ced255 [Mosaic] Test only cl - add triu test, skip bf16 due to select being native bitwidth only
PiperOrigin-RevId: 691477248
2024-10-30 10:48:44 -07:00
Sergei Lebedev
409517fcbc [pallas:mosaic_gpu] Disabled verbose lowering errors in Mosaic GPU tests
PiperOrigin-RevId: 691472782
2024-10-30 10:37:32 -07:00
Tzu-Wei Sung
d2f5804449 [Pallas] Add test cases for var + constant.
PiperOrigin-RevId: 691450143
2024-10-30 09:37:50 -07:00
Ayaka
8f96e9082a [Pallas TPU] Add lowerings for scalar absi
This PR is a follow-up of https://github.com/jax-ml/jax/pull/24504, which adds lowerings for scalar `absf` and `rsqrt`.

PiperOrigin-RevId: 691402430
2024-10-30 06:55:34 -07:00
Jake VanderPlas
b65fdcc612 pallas: remove build dependency on jax.experimental.export
jax.experimental.export is deprecated, and it looks like the build rule is unused.

PiperOrigin-RevId: 691205626
2024-10-29 16:41:50 -07:00
jax authors
5ad066eeaa [TPU][Mosaic] Replace tpu lowering (at canonicalization) for repeat with concat (which handles far more cases)
PiperOrigin-RevId: 691192121
2024-10-29 15:57:44 -07:00
Adam Paszke
8b21614973 [Pallas:MGPU] Add FlashAttention3 as an example
PiperOrigin-RevId: 690977852
2024-10-29 05:21:43 -07:00
Ayaka
a8d1048cb6 [Pallas] Add tests for jnp.logical_not
PiperOrigin-RevId: 690825419
2024-10-28 18:53:24 -07:00
Adam Paszke
36c56fa19b [Pallas:MGPU] Fix flaky debug_print tests
Turns out that waiting for the kernel to finish it not enough, since the
prints also need to be processed by the CUDA runtime. Using a test-only
function that synchronizes all the devices seems to suffice.

PiperOrigin-RevId: 690624999
2024-10-28 08:42:02 -07:00
Sergei Lebedev
dfa6fcd56b [pallas:mosaic_gpu] Extracted a basic emit_pipeline API from the in kernel pipelining test
PiperOrigin-RevId: 690619853
2024-10-28 08:25:47 -07:00
Adam Paszke
343cf18e09 [Pallas:MGPU] Wire up the Mosaic GPU profiler into Pallas
PiperOrigin-RevId: 690574747
2024-10-28 05:40:08 -07:00
Ayaka
5c614470ad [Pallas TPU] Add lowerings for scalar absf and rsqrt
This PR is similar to https://github.com/jax-ml/jax/pull/24284

PiperOrigin-RevId: 689546724
2024-10-24 15:59:34 -07:00
Adam Paszke
bb2e2303d7 [Pallas:MGPU] Treat each warpgroup as a single logical thread.
As an extra minor change, we now disallow specifying the predicate when uniform is
unset, as that implies that we're going to use two different mechanisms to select
a single thread.

PiperOrigin-RevId: 689289365
2024-10-24 01:54:10 -07:00