Justin Fu
1f6152d11e
[Pallas] Use Pallas cost estimator for flash attention.
...
PiperOrigin-RevId: 698573265
2024-11-20 17:12:37 -08:00
Jevin Jiang
9d2f62f811
[Pallas TPU] Support masked store
...
PiperOrigin-RevId: 698514079
2024-11-20 14:03:56 -08:00
Sergei Lebedev
9584ee3bb9
[pallas:mosaic_gpu] Avoid using multiple indexers in the parallel grid test
...
Turns out we can mix parallel grid with `plgpu.emit_pipeline` without doing
indexing at all!
PiperOrigin-RevId: 698442820
2024-11-20 10:42:02 -08:00
Christos Perivolaropoulos
8d84f28373
[pallas mgpu] Lowering for while loops as long as they are secretly for loops.
...
PiperOrigin-RevId: 698427307
2024-11-20 10:00:14 -08:00
Sergei Lebedev
1df4b5f798
[pallas] Do not skip vmap tests on GPU when x64 is enabled
...
PiperOrigin-RevId: 698351984
2024-11-20 05:08:23 -08:00
Peter Buchlovsky
14da7ebb76
[pallas:mosaic_gpu] Add Pallas Mosaic GPU lowering for jax.lax.bitcast_convert_type.
...
Only handles the case where operand type and target type have the same bitwidth.
PiperOrigin-RevId: 698332564
2024-11-20 03:41:19 -08:00
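Note on the same-bitwidth restriction above: a bitcast reinterprets the bits of a value without changing them, which maps one element to one element only when the operand and target types have the same width. A minimal sketch in plain Python, using the standard `struct` module as a stand-in. This is not the Pallas lowering itself, and `bitcast_f32_to_u32` / `bitcast_u32_to_f32` are hypothetical helper names:

```python
import struct

def bitcast_f32_to_u32(x: float) -> int:
    """Reinterpret the 32 bits of a float32 as an unsigned 32-bit integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bitcast_u32_to_f32(bits: int) -> float:
    """Inverse direction: reinterpret a 32-bit pattern as a float32."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# Both types are 32 bits wide, so the conversion is a pure reinterpretation.
one_bits = bitcast_f32_to_u32(1.0)
round_trip = bitcast_u32_to_f32(one_bits)
```

Because no bits are added or dropped, the round trip returns the original value exactly.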
Sergei Lebedev
1bf70fbbc4
[pallas:mosaic_gpu] `copy_gmem_to_smem` no longer requires `barrier` to be a keyword argument
...
... because there really isn't any reason to require that.
PiperOrigin-RevId: 698116984
2024-11-19 13:02:35 -08:00
Justin Fu
c44f11d15e
Add alternate implementation of threefry as a pallas kernel.
...
Current restrictions:
1) Dynamic grid sizes are not supported yet. This could in theory allow us to not recompile the kernel for different shapes.
2) fold_in and split still use the original rules. But there isn't a huge benefit to using the kernel right now since the input is so small and we can't avoid re-compilation due to (1).
3) Currently doesn't support high bits on the counter, meaning we can generate at max 4B numbers in one call. This is a fringe use-case since we only support 32-bit, and generating 4B 32-bit numbers would consume 16GB of HBM (an entire TPU v5p worth of HBM).
PiperOrigin-RevId: 698086352
2024-11-19 11:26:30 -08:00
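The figures in restriction (3) above can be checked with a short calculation. This is a sketch that assumes "4B" means the 2^32 values addressable by a 32-bit counter and that each generated value occupies 4 bytes; the constant names are illustrative only:

```python
# Back-of-the-envelope check of restriction (3); no JAX required.
COUNTER_BITS = 32      # assumption: only the low 32 bits of the counter are used
BYTES_PER_VALUE = 4    # each generated value is a 32-bit word

max_values = 2 ** COUNTER_BITS                # about 4.29 billion values per call
total_bytes = max_values * BYTES_PER_VALUE    # memory needed to hold all of them
total_gib = total_bytes / 2 ** 30             # expressed in GiB
```

Under these assumptions a single call that exhausts the counter would need 16 GiB of output memory, which matches the 16GB figure in the commit message.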
jax authors
d397dd9684
Implement lax.pad in Pallas.
...
PiperOrigin-RevId: 697897093
2024-11-18 23:59:20 -08:00
Jevin Jiang
0fe77bc9f0
[Mosaic TPU] Support relayout for mask vector
...
We cast i1 vector (mask) to i32 vector before relayout and then cast back to i1 vector (mask) after relayout is finished.
PiperOrigin-RevId: 697823543
2024-11-18 18:07:15 -08:00
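The cast, relayout, cast-back sequence described above can be sketched in plain Python. This is only an illustration of the round trip, not the Mosaic implementation: `relayout_mask` is a hypothetical helper, and list reversal stands in for a real relayout that only knows how to move 32-bit integers:

```python
from typing import Callable, List

def relayout_mask(
    mask: List[bool],
    relayout_i32: Callable[[List[int]], List[int]],
) -> List[bool]:
    # Step 1: widen the i1 mask to i32 (True -> 1, False -> 0).
    as_i32 = [1 if bit else 0 for bit in mask]
    # Step 2: reuse the relayout that already exists for i32 vectors.
    moved = relayout_i32(as_i32)
    # Step 3: narrow back to a mask by comparing against zero.
    return [value != 0 for value in moved]

# Stand-in "relayout": any pure data movement works for the illustration.
def reverse(values: List[int]) -> List[int]:
    return values[::-1]

result = relayout_mask([True, False, False], reverse)
```

Since the relayout only moves elements, widening and narrowing around it preserves every mask bit.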
Tzu-Wei Sung
a60ef6e9bb
[Pallas] Increase test coverage of pl.dot.
...
PiperOrigin-RevId: 697752355
2024-11-18 14:09:25 -08:00
Sergei Lebedev
aefe6215ca
[pallas:mosaic_gpu] Ported two pipelining optimizations to emit_pipeline
...
* Skip SMEM->GMEM copy if the destination buffer is being revisited
* Skip SMEM->GMEM copy if the corresponding index map does not use grid indices
PiperOrigin-RevId: 696448043
2024-11-14 02:37:42 -08:00
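The first optimization above (skipping the write-back when the destination is revisited) can be sketched as follows. This is a simplified model with a 1D grid, assuming a write-back is needed only when the next step targets a different output block; `copies_needed` is a hypothetical helper, not part of `emit_pipeline`:

```python
from typing import Callable, Hashable, List

def copies_needed(index_map: Callable[[int], Hashable], grid_size: int) -> List[bool]:
    """For each grid step, report whether the output block must be copied out."""
    indices = [index_map(step) for step in range(grid_size)]
    needed = []
    for step in range(grid_size):
        is_last = step == grid_size - 1
        # If the next step writes to the same block, the copy can be deferred.
        needed.append(is_last or indices[step + 1] != indices[step])
    return needed

constant_map = copies_needed(lambda i: 0, 4)       # index map ignores the grid index
identity_map = copies_needed(lambda i: i, 3)       # every step hits a new block
paired_map = copies_needed(lambda i: i // 2, 4)    # blocks revisited in pairs
```

An index map that ignores the grid index always yields the same block, so in this model only the final step copies, which corresponds to the second optimization.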
Peter Buchlovsky
558ebb9fb1
Add Pallas Triton lowering for jax.lax.bitcast_convert_type.
...
Only handles the case where operand type and target type have the same bitwidth.
PiperOrigin-RevId: 696184251
2024-11-13 10:25:53 -08:00
Jevin Jiang
38d062dbee
[Mosaic TPU] Support dynamic DMA and ref slice on the 2nd minor when memref is untiled
...
* Generalize any untiled memref to have tiling (packing, 128)
* Support dynamic index on 2nd minor.
* Support dynamic shape on 2nd minor.
PiperOrigin-RevId: 695516124
2024-11-11 16:14:27 -08:00
Justin Fu
0e611e5cac
[Pallas] Add a cost estimator for Pallas/JAX functions.
...
Helps resolve the following issue, where invoking HLO's cost analysis triggers compilation which can OOM for larger inputs: https://github.com/jax-ml/jax/issues/24539 . This cost estimator uses only abstract evaluation which should work for all input sizes.
PiperOrigin-RevId: 695415760
2024-11-11 11:13:58 -08:00
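The idea of estimating cost from shapes alone, without compiling or allocating anything, can be illustrated for a single matrix multiplication. This sketch is not the estimator added by the commit: `matmul_cost` is a hypothetical helper, and the formulas (2·m·n·k FLOPs, each operand and the result touched once) are simplifying assumptions:

```python
from typing import Dict, Tuple

def matmul_cost(
    lhs_shape: Tuple[int, int],
    rhs_shape: Tuple[int, int],
    bytes_per_element: int = 4,
) -> Dict[str, int]:
    """Estimate the cost of lhs @ rhs from shapes only; no arrays are created."""
    m, k = lhs_shape
    k_rhs, n = rhs_shape
    if k != k_rhs:
        raise ValueError(f"contracting dimensions differ: {k} vs {k_rhs}")
    flops = 2 * m * n * k  # one multiply and one add per inner-product term
    bytes_accessed = bytes_per_element * (m * k + k * n + m * n)
    return {"flops": flops, "bytes_accessed": bytes_accessed}

cost = matmul_cost((1024, 512), (512, 256))
```

Because only integers are manipulated, the estimate costs the same for any input size, which is the property the commit message relies on to avoid the OOM.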
jax authors
a889a95aa1
Merge pull request #24839 from andportnoy:aportnoy/mosaic-gpu-hopper-tests
...
PiperOrigin-RevId: 695388748
2024-11-11 10:12:29 -08:00
Andrey Portnoy
24af8a676b
[Mosaic GPU] Only run tests requiring sm90a on Hopper
2024-11-11 12:02:48 -05:00
jax authors
a041ea152e
Skip test_jnp_einsum_grad_y_pallas on gpu due to ooms
...
PiperOrigin-RevId: 695143627
2024-11-10 16:38:06 -08:00
jax authors
0cc1747873
Add tests for jnp.einsum in Pallas
...
PiperOrigin-RevId: 694622626
2024-11-08 13:35:38 -08:00
Peter Hawkins
8f169e7fb5
Disable the paged_attention test on TPU v5p.
...
This test is failing in CI.
PiperOrigin-RevId: 694574616
2024-11-08 11:20:31 -08:00
jax authors
927d7fc205
Skip flaky test on tpuv4
...
PiperOrigin-RevId: 694372268
2024-11-07 23:00:33 -08:00
Peter Hawkins
88a62a45d3
Reverts 1a544b6f363fbb03edc40e03d759cd42a6b64733
...
PiperOrigin-RevId: 694223298
2024-11-07 13:09:54 -08:00
Ayaka
1a544b6f36
[Pallas] Fix lowering tests for reduction ops
...
Remove unnecessary skip statements. Also added tests for bf16 types.
PiperOrigin-RevId: 694130207
2024-11-07 08:37:24 -08:00
Adam Paszke
f8dba3c8a4
[Pallas:MGPU] Add support for multiple heads in attention
...
PiperOrigin-RevId: 694104006
2024-11-07 07:03:35 -08:00
jax authors
37af1002c7
Merge pull request #24602 from rdyro:rdyro-decode-attention-mask
...
PiperOrigin-RevId: 693835080
2024-11-06 13:05:29 -08:00
Tzu-Wei Sung
b6f5c95a5a
[Pallas:TPU] Fix some stale/wrong skip conditions.
...
Surprised that we didn't test f32 dot_general on TPU (?) Even tpu_ops_test doesn't exercise it.
PiperOrigin-RevId: 693777426
2024-11-06 10:29:36 -08:00
Robert Dyro
d62510bfae
Adding start index and kv_seq_len to decode kernel
2024-11-05 15:52:21 -08:00
Dougal Maclaurin
478b750c29
Reverts f281c6f46475270a57a02416469226315377592c
...
PiperOrigin-RevId: 693339094
2024-11-05 07:17:14 -08:00
Sergei Lebedev
d2bbd56405
[pallas:mosaic_gpu] `lax.fori_loop` lowering now promotes the carry to `mgpu.FragmentedArray`s
...
PiperOrigin-RevId: 692976037
2024-11-04 08:29:00 -08:00
Dougal Maclaurin
f281c6f464
Reverts ec39b592f7c096b0b8183723feaab2ed0d001041
...
PiperOrigin-RevId: 692949053
2024-11-04 06:54:06 -08:00
Sergei Lebedev
c52b3227d1
[pallas:mosaic_gpu] Added a 2D test for emit_pipeline
...
PiperOrigin-RevId: 692945663
2024-11-04 06:38:12 -08:00
Dougal Maclaurin
ec39b592f7
Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat)
...
PiperOrigin-RevId: 692557993
2024-11-02 17:03:50 -07:00
George Necula
292a00b35a
[export] Cleanup in the export module.
...
With jax.experimental.export gone we can now do some cleanup in the export module.
In particular we remove the `export.args_spec` API, and the `lowering_platforms` arg for `export.export`. These were deprecated in June 2024.
PiperOrigin-RevId: 692398132
2024-11-01 22:56:44 -07:00
Ayaka
f60b97cea1
[Pallas TPU] Add lowering for lax.nextafter
...
Also improved the corresponding test cases to ensure better coverage and accuracy.
This PR is similar to https://github.com/jax-ml/jax/pull/22283 , which adds lowering for `lax.sign`.
PiperOrigin-RevId: 691988164
2024-10-31 17:34:38 -07:00
Tzu-Wei Sung
7af7a60dcc
[Pallas:TPU] Use arith.divui for uint32 div.
...
PiperOrigin-RevId: 691939453
2024-10-31 14:37:47 -07:00
Sergei Lebedev
85662f6dd8
[pallas:mosaic_gpu] `plgpu.copy_smem_to_gmem` no longer transparently commits SMEM
...
Users are expected to call `plgpu.commit_smem` manually instead.
PiperOrigin-RevId: 691724662
2024-10-31 02:21:10 -07:00
Bart Chrzaszcz
44158ab0e4
#sdy add shardy CPU config for all JAX tests, disabling any known failing test cases.
...
Only test cases breaking on CPU are related to:
- pure callbacks
- export
- shard alike
Note that `layout_test` is broken on TPU, leaving a comment saying to enable it.
Also fixed `shard_map_test` test that was broken when running Shardy on one TPU, and `aot_test` which was breaking due to calling a different C++ StableHLO compilation function.
PiperOrigin-RevId: 691496997
2024-10-30 11:40:20 -07:00
jax authors
3904ced255
[Mosaic] Test only cl - add triu test, skip bf16 due to select being native bitwidth only
...
PiperOrigin-RevId: 691477248
2024-10-30 10:48:44 -07:00
Sergei Lebedev
409517fcbc
[pallas:mosaic_gpu] Disabled verbose lowering errors in Mosaic GPU tests
...
PiperOrigin-RevId: 691472782
2024-10-30 10:37:32 -07:00
Tzu-Wei Sung
d2f5804449
[Pallas] Add test cases for var + constant.
...
PiperOrigin-RevId: 691450143
2024-10-30 09:37:50 -07:00
Ayaka
8f96e9082a
[Pallas TPU] Add lowerings for scalar absi
...
This PR is a follow-up of https://github.com/jax-ml/jax/pull/24504 , which adds lowerings for scalar `absf` and `rsqrt`.
PiperOrigin-RevId: 691402430
2024-10-30 06:55:34 -07:00
Jake VanderPlas
b65fdcc612
pallas: remove build dependency on jax.experimental.export
...
jax.experimental.export is deprecated, and it looks like the build rule is unused.
PiperOrigin-RevId: 691205626
2024-10-29 16:41:50 -07:00
jax authors
5ad066eeaa
[TPU][Mosaic] Replace tpu lowering (at canonicalization) for repeat with concat (which handles far more cases)
...
PiperOrigin-RevId: 691192121
2024-10-29 15:57:44 -07:00
Adam Paszke
8b21614973
[Pallas:MGPU] Add FlashAttention3 as an example
...
PiperOrigin-RevId: 690977852
2024-10-29 05:21:43 -07:00
Ayaka
a8d1048cb6
[Pallas] Add tests for jnp.logical_not
...
PiperOrigin-RevId: 690825419
2024-10-28 18:53:24 -07:00
Adam Paszke
36c56fa19b
[Pallas:MGPU] Fix flaky debug_print tests
...
Turns out that waiting for the kernel to finish is not enough, since the
prints also need to be processed by the CUDA runtime. Using a test-only
function that synchronizes all the devices seems to suffice.
PiperOrigin-RevId: 690624999
2024-10-28 08:42:02 -07:00
Sergei Lebedev
dfa6fcd56b
[pallas:mosaic_gpu] Extracted a basic `emit_pipeline` API from the in-kernel pipelining test
...
PiperOrigin-RevId: 690619853
2024-10-28 08:25:47 -07:00
Adam Paszke
343cf18e09
[Pallas:MGPU] Wire up the Mosaic GPU profiler into Pallas
...
PiperOrigin-RevId: 690574747
2024-10-28 05:40:08 -07:00
Ayaka
5c614470ad
[Pallas TPU] Add lowerings for scalar `absf` and `rsqrt`
...
This PR is similar to https://github.com/jax-ml/jax/pull/24284
PiperOrigin-RevId: 689546724
2024-10-24 15:59:34 -07:00
Adam Paszke
bb2e2303d7
[Pallas:MGPU] Treat each warpgroup as a single logical thread.
...
As an extra minor change, we now disallow specifying the predicate when uniform is
unset, as that implies that we're going to use two different mechanisms to select
a single thread.
PiperOrigin-RevId: 689289365
2024-10-24 01:54:10 -07:00