684 Commits

Author SHA1 Message Date
Jacob Burnim
ac74857d27 [Pallas] Support dynamic grids in the new TPU interpret mode
PiperOrigin-RevId: 728786896
2025-02-19 13:09:23 -08:00
Jevin Jiang
bb68124c33 [Mosaic TPU] Support mask concat
PiperOrigin-RevId: 728349788
2025-02-18 14:03:46 -08:00
Sergei Lebedev
d4559ba404 [pallas] Skip OpsTest.test_concat_constant if TPU is not available
PiperOrigin-RevId: 728269127
2025-02-18 10:38:35 -08:00
jax authors
eaceac3bf9 [Pallas] Reductions with replicated axes.
PiperOrigin-RevId: 727292293
2025-02-15 07:41:16 -08:00
Marcello Maggioni
9a8c9a56cf [JAX] Allow pallas to accept scalar shape semaphores.
PiperOrigin-RevId: 727198066
2025-02-14 23:20:47 -08:00
Adam Paszke
b287c3924a Ignore ImportError for Triton on Windows
We don't support Windows GPU builds right now and skip all the tests,
but at the moment they can't even skip because of the import failure.

PiperOrigin-RevId: 726917651
2025-02-14 07:17:49 -08:00
Christos Perivolaropoulos
49ad24152c [pallas:mgpu] Change FA3 kernel bc lax.div doesn't like mixed types anymore.
PiperOrigin-RevId: 726883573
2025-02-14 05:10:49 -08:00
Sergei Lebedev
3162cc4d0d [pallas:triton] Added basic support for lax.concatenate
The corresponding Triton op is restricted to `jnp.stack([x, y], axis=-1)`,
so the lowering only supports that case for now.

See #25321.

PiperOrigin-RevId: 726881284
2025-02-14 05:02:53 -08:00
jax authors
f0cd1686ec Merge pull request #26509 from andportnoy:aportnoy/pallas-mosaic-gpu-test-sm90a
PiperOrigin-RevId: 726624339
2025-02-13 13:52:31 -08:00
Adam Paszke
b0b1fa7dad Skip pipeline mode args in tests with older libTPU
PiperOrigin-RevId: 726480896
2025-02-13 07:39:16 -08:00
Andrey Portnoy
54fa1b9aa5 [Mosaic GPU] Factor out arch specific Pallas Mosaic GPU tests 2025-02-13 10:29:22 -05:00
Christos Perivolaropoulos
305e55f323 [pallas:mgpu] Fix and test multiple indexers where one is a dynamic selection index.
PiperOrigin-RevId: 726447413
2025-02-13 05:53:08 -08:00
Jevin Jiang
876668faa1 [Mosaic TPU] Support bf16 div if HW does not directly support.
PiperOrigin-RevId: 726212286
2025-02-12 15:04:09 -08:00
Benjamin Chetioui
c7199fe8a5 [Pallas/Mosaic GPU] Enable progressive lowering for integer addition.
The helpers `_fragmented_array_to_ir` and `_fragmented_array_from_ir` in
`dialect_lowering.py` have been modified, such that a fragmented array's
signedness no longer appears in its IR representation.

This is because signedness is a reflection of how we make use of the value,
and not an inherent property of it. The appropriate signedness value to use
to reload a fragmented array from IR must be provided by the caller.

PiperOrigin-RevId: 726030853
2025-02-12 06:29:25 -08:00
jax authors
1e2a5770c9 Merge pull request #26455 from gnecula:debug_info_jaxpr_8
PiperOrigin-RevId: 726023315
2025-02-12 06:03:32 -08:00
George Necula
faa0ad6f33 [better_errors] Continue adding debug info to Jaxprs (step 8)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

These are some leftover changes, in particular those needed when
running with `JAX_USE_DIRECT_LINEARIZE=1`.
2025-02-12 14:23:52 +01:00
Benjamin Chetioui
5ad89006c3 [Pallas/Mosaic GPU] Add initial support for warpgroup semantics in lowering.
This will allow us to lower Pallas kernels using the Mosaic GPU dialect, and
in turn to perform layout inference and optimization automatically.

The change contains lowering rules for `get` and `swap` (which are necessary
to get a basic example to run), as well as for `add`.

The new lowering path can be used by specifying the `Warpgroup` thread
semantics as part of `pallas_call`'s compiler params.

PiperOrigin-RevId: 725958027
2025-02-12 01:47:49 -08:00
Marcello Maggioni
6c6b5ec582 [JAX/Pallas] Add has_side_effect parameter to CompilerParams to stop CSE of operations.
Some pallas kernels shouldn't be CSEd even if they share the same inputs.
For example in async pallas scenarios like when you have a kernel starting some DMAs
that are waited in the user of the kernel (to perform async copies) we can't CSE or kernels
might wait multiple times on a DMA that happens only one.

PiperOrigin-RevId: 725752913
2025-02-11 13:33:01 -08:00
Charles Hofer
3745591d68 Merge branch 'rocm-main' into ci-upstream-sync-112_1 2025-02-11 20:36:19 +00:00
jax authors
d3ed6ca0cc Re-enable oss paged attn kernel
PiperOrigin-RevId: 725411244
2025-02-10 17:47:22 -08:00
jax authors
b7d012281e Merge pull request #26423 from gnecula:debug_info_jaxpr_7
PiperOrigin-RevId: 725317552
2025-02-10 12:58:26 -08:00
Sergei Lebedev
700298fd82 [pallas:triton] Updated :export_pallas_test following the changes in Pallas Triton lowering
PiperOrigin-RevId: 725293815
2025-02-10 11:48:36 -08:00
Charles Hofer
31c1f25425 Merge branch 'rocm-main' into ci-upstream-sync-110_1 2025-02-10 18:02:25 +00:00
jax authors
6740165e4f [Pallas] Add pipeline mode to pltpu
PiperOrigin-RevId: 725133131
2025-02-10 02:36:44 -08:00
George Necula
817b3e5757 [better_errors] Continue adding debug info to Jaxprs (step 7)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
2025-02-09 18:14:33 +02:00
Gleb Pobudzey
cd0753751c Increase the absolute error tolerance to fix flaky tests.
PiperOrigin-RevId: 724424293
2025-02-07 11:59:13 -08:00
Sergei Lebedev
e5058079c9 [pallas:mosaic_gpu] Fixed a bug in how delay_release is handled in emit_pipeline
PiperOrigin-RevId: 724395676
2025-02-07 10:37:21 -08:00
Sergei Lebedev
35351f95e4 [pallas:triton] Really revert to the lowering using Triton IR
PiperOrigin-RevId: 724329911
2025-02-07 06:55:14 -08:00
jax authors
4b86ff22e9 Merge pull request #25097 from jburnim:jburnim_pallas_interpret_mode
PiperOrigin-RevId: 724073443
2025-02-06 14:22:49 -08:00
Jacob Burnim
1c82484c9b Start a new TPU interpret mode for Pallas.
The goal of this interpret mode is to run a Pallas TPU kernel on CPU,
while simulating a TPU's shared memory, multiple devices/cores, remote
DMAs, and synchronization.

The basic approach is to execute the kernel's Jaxpr on CPU, but to
replace all load/store, DMA, and synchronization primitives with
io_callbacks to a Python functions that simulate these primitives.
When this interpret mode is run inside of shard_map and jit, the
shards will run in parallel, simulating the parallel execution of the
kernel on multiple TPU devices.

The initial version in this PR can successfully interpret the examples
in https://jax.readthedocs.io/en/latest/pallas/tpu/distributed.html ,
but is still missing a lot of functionality, including:

 - Executing DMAs asynchronously.

 - Padding in pallas_call.

 - Propagating source info.
2025-02-06 13:04:14 -08:00
Ayaka
afad924de7 [Pallas TPU] Remove obsolete skip condition
PiperOrigin-RevId: 723963888
2025-02-06 09:23:21 -08:00
Peter Buchlovsky
9f53dfae0b [pallas_mgpu] Fix emit_pipeline_with_wgmma test.
PiperOrigin-RevId: 723547617
2025-02-05 09:47:50 -08:00
Charles Hofer
c3e27f86bc Merge branch 'rocm-main' into ci-upstream-sync-106_1 2025-02-04 17:28:34 +00:00
Peter Buchlovsky
c7d535d3c9 [pallas_mgpu] Add a test for emit_pipeline with wgmma.
PiperOrigin-RevId: 723012611
2025-02-04 03:25:02 -08:00
Jevin Jiang
124e123946 [Pallas] Support promise_in_bounds mode in jnp.take_along_axis.
Change is also applied to jax because we don't need to normalize index if the mode is already "promise_in_bounds".

PiperOrigin-RevId: 722930215
2025-02-03 22:06:19 -08:00
Sergei Lebedev
7929cd8410 [pallas:triton] The lowering now uses PTX instead of Triton IR
This change improves the stability and backward compatibility of Pallas Triton
calls, because unlike PTX, the Triton dialect has no stability guarantees
and does change in practice.

See #25196.

A few notes

* Pallas Triton no longer delegates compilation to PTX to XLA:GPU. Instead,
  compilation is done via a new PjRt extension, which uses its own compilation
  pipeline mirrored after the one in the Triton Python bindings.
* The implementation of the old custom call used by Pallas Triton is
  deprecated and will be removed after 6 months as per
  [compatibility guarantees] [*]

[*]: https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees

PiperOrigin-RevId: 722773884
2025-02-03 13:21:40 -08:00
Charles Hofer
ec2c3050d3 Merge branch 'rocm-main' into ci-upstream-sync-104_1 2025-02-03 16:20:36 +00:00
Christos Perivolaropoulos
b48d15d788 [pallas_mgpu] For loops can have **non-ref** accumulators for carries.
The user has access only to accumulator references and they can't pass them as caries to loops. However when they are discharged these accumulators become values and become part of the carry. Before this CL this would surprise the loop lowering code.

This was never a problem for pallas mgpu until we added pipelining loops instead of sequential bloc axes.

PiperOrigin-RevId: 722495749
2025-02-02 21:03:26 -08:00
Jevin Jiang
ed952c8e65 [Pallas TPU] Support jnp.take_along_axis for 32-bit vreg-sized vector.
PiperOrigin-RevId: 722015152
2025-01-31 21:27:08 -08:00
Adam Paszke
c1e136058c Re-enable Pallas bf16 exp2 tests on TPU
PiperOrigin-RevId: 721784841
2025-01-31 08:36:57 -08:00
Gleb Pobudzey
8c02731a06 Increasing shard count and removing asan builds to prevent timeouts.
PiperOrigin-RevId: 721038112
2025-01-29 10:58:51 -08:00
Anselm Levskaya
23c8607bab disable kernel test due to races.
PiperOrigin-RevId: 720715578
2025-01-28 14:47:36 -08:00
Charles Hofer
47580efda5 Merge branch 'rocm-main' into ci-upstream-sync-98_1 2025-01-28 21:18:47 +00:00
Gleb Pobudzey
7a4a53ad9e Add win32 guard to fix imports on Windows
PiperOrigin-RevId: 720625818
2025-01-28 10:32:19 -08:00
jax authors
24987a90dc Merge pull request #26134 from justinjfu:pallas_accum_bugfix
PiperOrigin-RevId: 720374819
2025-01-27 18:05:57 -08:00
Justin Fu
54ac172b4c [Pallas] Refactor Pallas HLO interpret mode to a standalone file.
Also replaces the interpreter context (used only for handling extended dtypes) with a physicalize Jaxpr pass.

PiperOrigin-RevId: 720371033
2025-01-27 17:52:27 -08:00
jax authors
bc130c7ba6 Merge pull request #25213 from Rifur13:dynamic_mask
PiperOrigin-RevId: 720361301
2025-01-27 17:16:12 -08:00
Gleb Pobudzey
4fe937683e Fix import for Windows platforms
PiperOrigin-RevId: 720348679
2025-01-27 16:33:37 -08:00
Gleb Pobudzey
c0d23af42c Support dynamic masks in splash attention 2025-01-28 00:14:53 +00:00
Justin Fu
7ace72fb3a [Pallas] Be explicit about accumulation dtype in reference implementations 2025-01-27 22:09:29 +00:00