We don't support Windows GPU builds right now and skip all the tests,
but at the moment they can't even be skipped because of the import failure.
PiperOrigin-RevId: 726917651
The corresponding Triton op is restricted to `jnp.stack([x, y], axis=-1)`,
so the lowering only supports that case for now.
See #25321.
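For reference, a minimal sketch of the supported pattern inside a Pallas Triton kernel (assuming a GPU with the Triton backend; shapes are illustrative):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def stack_kernel(x_ref, y_ref, o_ref):
    # Two operands stacked along a new minor axis: the only case the lowering handles.
    o_ref[...] = jnp.stack([x_ref[...], y_ref[...]], axis=-1)

x = jnp.arange(128, dtype=jnp.float32)
y = x + 1.0
out = pl.pallas_call(
    stack_kernel,
    out_shape=jax.ShapeDtypeStruct((128, 2), jnp.float32),
)(x, y)
```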
PiperOrigin-RevId: 726881284
The helpers `_fragmented_array_to_ir` and `_fragmented_array_from_ir` in
`dialect_lowering.py` have been modified so that a fragmented array's
signedness no longer appears in its IR representation.
This is because signedness is a reflection of how we make use of the value,
and not an inherent property of it. The appropriate signedness value to use
to reload a fragmented array from IR must be provided by the caller.
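As an illustration of that reasoning (a standalone sketch, not the Mosaic GPU helpers themselves): the same bit pattern can be read as signed or unsigned, so signedness lives with the reader of the value rather than with the bits.

```python
import jax
import jax.numpy as jnp

# The same 8-bit patterns, reinterpreted rather than converted.
bits = jnp.array([0xFF, 0x80, 0x01], dtype=jnp.uint8)
as_unsigned = bits                                        # [255, 128, 1]
as_signed = jax.lax.bitcast_convert_type(bits, jnp.int8)  # [-1, -128, 1]
print(as_unsigned, as_signed)
```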
PiperOrigin-RevId: 726030853
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.
These are some leftover changes, in particular those needed when
running with `JAX_USE_DIRECT_LINEARIZE=1`.
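For example, a minimal way to exercise that configuration (the flag must be set before JAX is imported so the config picks it up):

```python
import os

# Enable direct linearization via the environment flag mentioned above.
os.environ["JAX_USE_DIRECT_LINEARIZE"] = "1"

import jax

# Differentiation should now go through the direct-linearize path.
print(jax.grad(lambda x: x * x)(3.0))
```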
This will allow us to lower Pallas kernels using the Mosaic GPU dialect, and
in turn to perform layout inference and optimization automatically.
The change contains lowering rules for `get` and `swap` (which are necessary
to get a basic example to run), as well as for `add`.
The new lowering path can be used by specifying the `Warpgroup` thread
semantics as part of `pallas_call`'s compiler params.
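A sketch of how opting in might look; the exact spellings of the compiler-params class, field, and enum (`GPUCompilerParams`, `thread_semantics`, `ThreadSemantics.Warpgroup`) are assumptions, not a confirmed API:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def add_one_kernel(x_ref, o_ref):
    # The load and store below exercise `get`/`swap`, and the `+` exercises `add`.
    o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((128,), jnp.float32)
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    # Hypothetical spelling of the opt-in; see the note about thread semantics above.
    compiler_params=plgpu.GPUCompilerParams(
        thread_semantics=plgpu.ThreadSemantics.Warpgroup
    ),
)(x)
```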
PiperOrigin-RevId: 725958027
Some Pallas kernels shouldn't be CSEd even if they share the same inputs.
For example, in async Pallas scenarios where a kernel starts some DMAs
that are awaited by the user of the kernel (to perform async copies), we can't CSE,
or kernels might wait multiple times on a DMA that happens only once.
PiperOrigin-RevId: 725752913
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.
Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
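The pattern being rolled out looks roughly like the following (internal APIs; the exact `api_util.debug_info` and `lu.wrap_init` signatures shown here are assumptions based on this series):

```python
from jax._src import api_util
from jax._src import linear_util as lu

def f(x, y):
    return x * y

# Build a DebugInfo from the traced function and its arguments, then attach it
# when wrapping, instead of calling lu.wrap_init(f) with no debug info.
dbg = api_util.debug_info("example_transform", f, (1.0, 2.0), {})
wrapped = lu.wrap_init(f, debug_info=dbg)
```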
The goal of this interpret mode is to run a Pallas TPU kernel on CPU,
while simulating a TPU's shared memory, multiple devices/cores, remote
DMAs, and synchronization.
The basic approach is to execute the kernel's Jaxpr on CPU, but to
replace all load/store, DMA, and synchronization primitives with
io_callbacks to Python functions that simulate these primitives.
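A minimal sketch of that mechanism (not the actual interpreter): a "load" becomes an io_callback into Python code that reads a simulated memory space held outside the traced computation.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

# buffer_id -> simulated contents of a memory buffer, held in ordinary Python.
simulated_memory = {0: np.arange(8, dtype=np.float32)}

def _simulated_load(buffer_id):
    return simulated_memory[int(buffer_id)]

def load(buffer_id):
    # ordered=True keeps simulated loads/stores in program order.
    return io_callback(
        _simulated_load,
        jax.ShapeDtypeStruct((8,), jnp.float32),
        buffer_id,
        ordered=True,
    )

print(jax.jit(load)(jnp.int32(0)))
```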
When this interpret mode is run inside of shard_map and jit, the
shards will run in parallel, simulating the parallel execution of the
kernel on multiple TPU devices.
The initial version in this PR can successfully interpret the examples
in https://jax.readthedocs.io/en/latest/pallas/tpu/distributed.html ,
but is still missing a lot of functionality, including:
- Executing DMAs asynchronously.
- Padding in pallas_call.
- Propagating source info.
This change improves the stability and backward compatibility of Pallas Triton
calls, because unlike PTX, the Triton dialect has no stability guarantees
and does change in practice.
See #25196.
A few notes:
* Pallas Triton no longer delegates the Triton-to-PTX compilation to XLA:GPU. Instead,
compilation is done via a new PjRt extension, which uses its own compilation
pipeline mirroring the one in the Triton Python bindings.
* The implementation of the old custom call used by Pallas Triton is
deprecated and will be removed after 6 months, as per the
[compatibility guarantees][*].
[*]: https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees
PiperOrigin-RevId: 722773884
The user has access only to accumulator references and can't pass them as carries to loops. However, when they are discharged, these accumulators become values and become part of the carry. Before this CL, this would surprise the loop lowering code.
This was never a problem for Pallas MGPU until we added pipelining loops instead of sequential block axes.
PiperOrigin-RevId: 722495749