684 Commits

Author SHA1 Message Date
Charles Hofer
41ab12bf8d Skip PallasCallRemoteDMAInterpretTest.test_interpret_remote_dma_ppermute for failing on ROCm 2025-01-27 20:46:14 +00:00
Jake VanderPlas
8e97272332 tpu_pallas_distributed_test: skip unless on TPU 2025-01-27 12:41:28 -08:00
Peter Hawkins
a7a885aa5a Disable a Pallas ops test that fails on TPU v6e.
PiperOrigin-RevId: 720194737
2025-01-27 09:15:14 -08:00
Gleb Pobudzey
e0b38f4e56 Adding GPU paged attention kernel 2025-01-24 17:13:02 +00:00
Adam Paszke
c10b9b88f2 [Pallas:MGPU] Add helpers to make writing core_map kernels less verbose
Also add small "getting started" examples that use the helpers in tests.

PiperOrigin-RevId: 719303512
2025-01-24 07:59:26 -08:00
Jevin Jiang
8e1f956804 [Mosaic TPU] Use vmask pack if possible for mask's bitwidth change and introduce relayout op.
PiperOrigin-RevId: 719089676
2025-01-23 18:15:08 -08:00
Bart Chrzaszcz
db8c8fc37c #sdy unskip JAX Shardy tests that are already passing
PiperOrigin-RevId: 718898708
2025-01-23 09:26:38 -08:00
Justin Fu
10bb38bb79 [Mosaic GPU] Add manual consumed barrier handling to WS pipeline.
PiperOrigin-RevId: 718451678
2025-01-22 10:59:58 -08:00
jax authors
e304e9ea16 Merge pull request #25992 from gnecula:debug_info_arg_names
PiperOrigin-RevId: 718216003
2025-01-21 22:17:08 -08:00
Jevin Jiang
908df65a26 [Mosaic TPU] Emulate converting x16 vector to mask if mask packing is supported.
PiperOrigin-RevId: 718133639
2025-01-21 17:23:33 -08:00
jax authors
96a3ed36c7 Add part (non-quantized K/V pages) of paged_attention_kernel tests back for TPU v6.
The paged_attention_kernel tests for TPU v6 were disabled in the past, but I discovered that all the failing tests have `are_kv_quantized=True`. So we can still test the non-quantized part on TPU v6.

PiperOrigin-RevId: 717969073
2025-01-21 10:12:52 -08:00
jax authors
70a5175d0a Merge pull request #25647 from Rifur13:bwd_pass
PiperOrigin-RevId: 717954065
2025-01-21 09:37:41 -08:00
Adam Paszke
3c8cf3c92e [Pallas] Improve testing for casts from narrow types + test int4
For sub-32bit types there are so few distinct values that we can just
exhaustively test them all.

PiperOrigin-RevId: 717879040
2025-01-21 05:56:10 -08:00
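The exhaustive-testing idea in the commit above can be sketched in plain Python. This is only an illustration, not the Pallas test code: `to_signed`, `widen_bits`, and `exhaustive_check` are hypothetical names, and the "cast" under test is a hand-written sign extension standing in for a real lowering.

```python
def to_signed(bits: int, width: int) -> int:
    """Reference decoder: read `bits` as a two's-complement integer of `width` bits."""
    return bits - (1 << width) if bits >= 1 << (width - 1) else bits


def widen_bits(bits: int, src_width: int, dst_width: int = 32) -> int:
    """Cast under test (hypothetical): sign-extend a narrow bit pattern."""
    if (bits >> (src_width - 1)) & 1:
        # Replicate the sign bit into every position above src_width.
        bits |= ((1 << dst_width) - 1) & ~((1 << src_width) - 1)
    return bits


def exhaustive_check(src_width: int, dst_width: int = 32) -> int:
    """Check every bit pattern of the narrow type; return how many were checked."""
    checked = 0
    for bits in range(1 << src_width):
        expected = to_signed(bits, src_width)
        actual = to_signed(widen_bits(bits, src_width, dst_width), dst_width)
        assert actual == expected, (bits, expected, actual)
        checked += 1
    return checked


print(exhaustive_check(4))  # int4 has only 16 distinct values
print(exhaustive_check(8))  # int8 has only 256 distinct values
```

Because an N-bit type has only 2^N bit patterns, the loop covers the whole input space for narrow types, so no random sampling is needed.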
George Necula
3f73f7b0eb [better_errors] Ensure debug_info.arg_names is never None.
Most places in the code assumed this already, but often
that usage is error reporting code, which is not yet well tested.

When we cannot get the `inspect.Signature` or when the
args and kwargs do not match the signature, we generate
the flattened argument names as: `args[0]`, `args[1]`,
`kwargs['foo']`, ... Previously, in these cases we
returned `None` for `arg_names`, and then the whole
debug_info ended up being `None`, throwing away even
available information.

We also add support for `api_util.fun_sourceinfo` even
for cases when the `fun.__code__` is not available. In
those cases we used to say that `fun_sourceinfo` is
`None`. Now, we use the string representation of `fun`
to get the name of built-in functions, or we use "<unknown>".
2025-01-21 13:38:10 +01:00
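The fallback naming described in the commit above might look roughly like the following. This is a minimal sketch under stated assumptions, not the code in `api_util`: `fallback_arg_names` is a hypothetical helper, and it works on plain positional and keyword arguments rather than flattened pytrees.

```python
import inspect


def fallback_arg_names(fun, args, kwargs):
    """Return argument names, falling back to positional placeholders.

    Hypothetical helper: when the signature is unavailable or does not
    match the call, names are generated as args[0], args[1], kwargs['foo'].
    """
    try:
        bound = inspect.signature(fun).bind(*args, **kwargs)
    except (ValueError, TypeError):
        # ValueError: no signature available (e.g. some built-ins).
        # TypeError: args/kwargs do not match the signature.
        names = [f"args[{i}]" for i in range(len(args))]
        names += [f"kwargs[{k!r}]" for k in sorted(kwargs)]
        return tuple(names)
    return tuple(bound.arguments)


def f(x, y):
    return x + y


print(fallback_arg_names(f, (1, 2), {}))             # names from the signature
print(fallback_arg_names(f, (1, 2, 3), {"foo": 4}))  # mismatch, so placeholders
```

The point of the design is that the result is always a tuple of names, so error-reporting code never has to handle a `None` case.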
Gleb Pobudzey
b8b9f2bc33 Fix the backwards pass and support more block sizes. 2025-01-21 02:48:37 +00:00
George Necula
4fd0bb05b1 [better_errors] Finally remove api_util.debug_info.
Following https://github.com/jax-ml/jax/pull/25916 there were a few TODOs
left in the code to remove api_util.debug_info and replace the
one remaining use with api_util.tracing_debug_info.

PiperOrigin-RevId: 717583667
2025-01-20 11:19:53 -08:00
Adam Paszke
543dd94762 [Mosaic TPU] Add a faster implementation for packing b16 to s8 in TPUv6
PiperOrigin-RevId: 717583425
2025-01-20 11:18:22 -08:00
Peter Hawkins
efab6945ca Remove code that supported jaxlib < 0.5.
The new xla_extension_version is 303 and the new mlir_api_version is 57.
2025-01-17 14:22:27 -05:00
Gleb Pobudzey
2cdd9b7dd9 Fixing bwd attention test tolerance level 2025-01-17 01:41:51 +00:00
Adam Paszke
8954e71d73 [Mosaic TPU] Improve support for int16->int32 casts in TPUv4
PiperOrigin-RevId: 716250236
2025-01-16 08:44:10 -08:00
Adam Paszke
ef4dbd9cb9 [Mosaic TPU] Add support for packing to 16-bit integers on TPUv4
And refactor some test conditions to better match what we really support.
The tests were failing on older TPUs.

PiperOrigin-RevId: 716214098
2025-01-16 06:39:23 -08:00
Zac Mustin
2d72e8de84 Jax: Stop returning a list of cost-analyses.
As it stands, there is only ever one element in this list (see b/384741132) and only the 0th element is ever used so we can simplify.

This is a potentially breaking change for external users, but (as stated in the [documentation](https://jax.readthedocs.io/en/latest/aot.html#debug-information-and-analyses-when-available)) no guarantees are made on this type, which is intended for debugging purposes and not intended to be a reliable public API.

PiperOrigin-RevId: 715837855
2025-01-15 09:53:59 -08:00
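External code that must run against versions from both before and after the commit above could normalize the result with a small shim. This is a sketch under assumptions: `normalize_cost_analysis` is a hypothetical helper, and the `"flops"` key is used only as an illustrative entry of a dict-like cost analysis.

```python
def normalize_cost_analysis(result):
    """Accept either the old list-of-analyses form or the new single form.

    Hypothetical compatibility shim: returns the single analysis, or None
    when no analysis is available.
    """
    if result is None:
        return None
    if isinstance(result, (list, tuple)):
        # Old form: a list that only ever held one element.
        return result[0] if result else None
    # New form: the analysis itself.
    return result


old_style = [{"flops": 2.0}]
new_style = {"flops": 2.0}
print(normalize_cost_analysis(old_style))
print(normalize_cost_analysis(new_style))
```

Since the type is documented as debugging output without stability guarantees, a shim like this is best kept local to diagnostic code.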
Adam Paszke
aa19f9c4c4 [Pallas TPU] Temporarily strengthen restrictions on Pallas tests
Mosaic is now more aggressive in its inference of large 2nd minor layouts,
which causes slight problems for Pallas pipelines. This will be addressed
shortly.

PiperOrigin-RevId: 715714752
2025-01-15 02:32:14 -08:00
jax authors
c18492be65 [pallas][mosaic kernel export] Add initial support for exporting a dynamic-shape (placeholder bound) kernel out of mosaic, via pallas as both MLIR and jaxpr.
PiperOrigin-RevId: 715629439
2025-01-14 20:34:11 -08:00
Justin Fu
cc9f6e7528 [Pallas] Fix GQA triton kernel test.
PiperOrigin-RevId: 715576240
2025-01-14 16:40:55 -08:00
Peter Hawkins
d1810b42cb Temporarily disable GQA attention tests on GPU, which were broken by a Triton integrate.
PiperOrigin-RevId: 715516188
2025-01-14 13:48:37 -08:00
Justin Fu
ff5cb811e6 [Mosaic GPU] Enable x64 tests for mosaic gpu.
PiperOrigin-RevId: 715496496
2025-01-14 13:02:48 -08:00
Peter Hawkins
f122f17b27 Rename test configs to include GPU variants more consistently.
* Include "p100" or "v100" in the default "gpu" config names, matching their current CI configuration.
* Rename "_2gpu" test variants to "x2" variants, since this is more succinct.

This change is intended to be a pure renaming, and it is not intended to alter the set of tests that run.

PiperOrigin-RevId: 715468944
2025-01-14 11:55:45 -08:00
Ayaka
9ba1fd2801 [Pallas TPU] Add vector support to pl.debug_print
PiperOrigin-RevId: 715085454
2025-01-13 13:22:21 -08:00
Adam Paszke
aa51f2af47 [Pallas TPU] Skip cast test incompatible with older libtpu builds
PiperOrigin-RevId: 714975806
2025-01-13 08:24:01 -08:00
Justin Fu
8e86bede9f [Mosaic GPU] Allow multiple gmem indexers on copies.
This is implemented by merging multiple indexers into one.

PiperOrigin-RevId: 714150733
2025-01-10 13:12:50 -08:00
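Merging two indexers into one, as the commit above describes, can be illustrated on a single dimension with Python slices. This is not the Mosaic GPU implementation: `compose_slices` is a hypothetical helper, and the sketch assumes both slices have positive steps.

```python
def compose_slices(outer: slice, inner: slice, length: int) -> slice:
    """Return one slice equivalent to applying `outer`, then `inner`.

    Hypothetical one-dimensional sketch; only positive steps are handled.
    """
    if (outer.step or 1) < 0 or (inner.step or 1) < 0:
        raise ValueError("this sketch only handles positive steps")
    # Slicing a range yields a range, so Python does the index arithmetic.
    composed = range(length)[outer][inner]
    return slice(composed.start, composed.stop, composed.step)


data = list(range(20))
outer, inner = slice(2, 18, 2), slice(1, 6, 2)
merged = compose_slices(outer, inner, len(data))
print(merged)
print(data[outer][inner] == data[merged])
```

With a single merged indexer, a copy only has to apply one index computation instead of chaining several.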
Justin Fu
73b64b8e56 [Mosaic GPU] Enable loop carries in the pipeline emitter.
PiperOrigin-RevId: 714141077
2025-01-10 12:40:42 -08:00
Adam Paszke
d2a5e8d072 [Mosaic TPU] Add support for integer truncation from packed types
PiperOrigin-RevId: 714048232
2025-01-10 07:40:55 -08:00
Sergei Lebedev
18018d9cc9 [pallas:mosaic_gpu] Tests now pass with x64 enabled
PiperOrigin-RevId: 714005603
2025-01-10 04:48:36 -08:00
Adam Paszke
74cf67df9d [Pallas] Improve testing for lowering of dtype conversions + fix uncovered bugs
We previously weren't testing unsigned integer types.

PiperOrigin-RevId: 714002869
2025-01-10 04:35:38 -08:00
Chris Jones
a27566cc7b Reverts dbe9ccd6dccd83c365021677c7e17e843d4559c4
PiperOrigin-RevId: 713989952
2025-01-10 03:40:18 -08:00
Adam Paszke
07f4fd3e51 [Mosaic TPU] Fix a bug in the impl of sublane broadcasts for int8 and int4
PiperOrigin-RevId: 713675029
2025-01-09 08:05:25 -08:00
Justin Fu
d99a637d8b [Mosaic GPU] Allow multiple indexing on refs
PiperOrigin-RevId: 713355813
2025-01-08 11:21:19 -08:00
Sergei Lebedev
f1f98afee8 [pallas:mosaic_gpu] Fix the tests following the changes to pl.core_map
PiperOrigin-RevId: 713283207
2025-01-08 07:24:08 -08:00
Adam Paszke
5fd1b2f825 [Mosaic TPU] Add support for second minor broadcasts with packed types
PiperOrigin-RevId: 713259707
2025-01-08 05:45:02 -08:00
Adam Paszke
e954930eaf [Mosaic TPU] Add support for true divide in bf16 on TPUv6
PiperOrigin-RevId: 713247480
2025-01-08 04:49:22 -08:00
Justin Fu
8c9a539405 [Pallas] Fix pallas_call lowering mutating compiler params during Triton lowering.
Addresses: https://github.com/jax-ml/jax/issues/25714
PiperOrigin-RevId: 712930760
2025-01-07 09:01:51 -08:00
Adam Paszke
7c984c600b Don't use x32 mode for pallas_test
There's no need to, and it caused our GPU tests for this target to only
run nightly.

PiperOrigin-RevId: 711406571
2025-01-02 06:23:32 -08:00
Adam Paszke
dbe9ccd6dc Reverts 83e60a9697ec20023f4e11169edf64e910b93031
PiperOrigin-RevId: 711403091
2025-01-02 06:04:14 -08:00
Tomás Longeri
3c79b98cd9 [Mosaic:TPU] Vreg-slice-aligned offset changes with scratch retiling
PiperOrigin-RevId: 709133729
2024-12-23 13:05:14 -08:00
Chris Jones
83e60a9697 [pallas:triton] Add support for lowering int4 load.
PiperOrigin-RevId: 709032308
2024-12-23 05:12:46 -08:00
jax authors
1719986aaa [Jax][Pallas][Mosaic] Implement platform dependent diag, with branch selection driven by constant prop in mosaic lowering.
This CL builds out a simple sketch of constant prop by construction in mosaic - we walk the graph up from cond, collecting the values and either const propping or failing out of const prop. Failure out of const prop is not a bug, but hitting an unimplemented const prop func is for now, in order to drive better coverage.

This then allows us to pick a single branch, and ignore branches which do not have a viable mosaic implementation.

And, finally, for diag, this means we can replace the initial gather-dependent implementation in lax with a mosaic specific one that avoids gather.

PiperOrigin-RevId: 708752566
2024-12-22 00:50:51 -08:00
Christos Perivolaropoulos
20efbd965f [pallas:mosaic_gpu] Change the fori tests to also take the while_p path and fix the bug.
The bug was that the bounds were dropped from ctx.avals_in and only then
extracted. Extract them before dropping them.

PiperOrigin-RevId: 708266659
2024-12-20 03:50:34 -08:00
Jevin Jiang
2faf540203 [Mosaic TPU] Add relayout-insertion pass and support bitwidth change for i1 vector relayout
We can use the relayout-insertion pass to insert the necessary ops and their layouts for relayout before unrolling in the apply-vector-layout pass.

PiperOrigin-RevId: 708143852
2024-12-19 19:56:40 -08:00
Adam Paszke
de8fa8fd19 [Mosaic TPU] Add support for sqrt and rsqrt in bf16 on TPUv6
PiperOrigin-RevId: 708016513
2024-12-19 13:42:38 -08:00