430 Commits

Author SHA1 Message Date
Ayaka
c60bafcc33 [Pallas TPU] Fix lowering for jnp.remainder
Fixes https://github.com/jax-ml/jax/issues/24027

PiperOrigin-RevId: 688614799
2024-10-22 11:01:58 -07:00
Ayaka
2b7b0742a4 [Pallas TPU] Add lowerings for bf16 jnp.ceil and jnp.floor in TPU v6+
This PR is similar to https://github.com/jax-ml/jax/pull/24284

Note that `np.testing.assert_allclose()` is changed to `self.assertAllClose()` because the latter is a wrapper with bfloat16 support.

PiperOrigin-RevId: 688581914
2024-10-22 09:33:53 -07:00
Adam Paszke
2db03ba54b [Pallas:MGPU] Add support for grid dims in GPUMesh
Of course no communication can happen across grid dimensions (unlike over the WG dim),
but we need to be able to launch multiple blocks somehow.

PiperOrigin-RevId: 688488660
2024-10-22 04:10:46 -07:00
Adam Paszke
84a303f32f [Pallas:MGPU] Allow allocating transformed refs in run_scoped
PiperOrigin-RevId: 688448592
2024-10-22 01:38:46 -07:00
Adam Paszke
f833891c87 [Pallas:MGPU] Add support for passing in WGMMA lhs from registers
PiperOrigin-RevId: 688117316
2024-10-21 06:42:18 -07:00
Adam Paszke
f08801b8d6 [Pallas:MGPU] Allow indexing to appear anywhere in the list of transforms
We only need to exchange the transforms preceding the indexer, while
the rest can remain unmodified.

PiperOrigin-RevId: 688112088
2024-10-21 06:22:16 -07:00
Adam Paszke
bbcc3eef3c [Pallas:MGPU] Fix the implementation of WGMMA with transposed RHS
It's not enough that we have the physical transpose between the order
of tiled dimensions; we also need the user to explicitly transpose the
logical dimensions. This fixes a shape error that was previously hidden
because the RHS was square.

PiperOrigin-RevId: 687350270
2024-10-18 10:31:42 -07:00
Christos Perivolaropoulos
f8a3c0366b [pallas] run_scoped now supports partial discharge.
PiperOrigin-RevId: 687347284
2024-10-18 10:22:31 -07:00
Adam Paszke
0ee9531ef2 [Pallas:MGPU] Add support for indexed refs to WGMMA
PiperOrigin-RevId: 687258992
2024-10-18 04:55:34 -07:00
Adam Paszke
f2edc83af3 [Pallas:MGPU] Properly commute indexing with other transforms
Doing so requires us to modify the other transforms when we attempt to
move indexing before them.

PiperOrigin-RevId: 687240515
2024-10-18 03:39:51 -07:00
Adam Paszke
2d78b17226 [Pallas:MGPU] Add support for transforms in user-specified async copies
PiperOrigin-RevId: 687019020
2024-10-17 13:10:45 -07:00
jax authors
6c2649fdf2 Rewrite mosaic concat to support operand shapes that do not align with native shapes; expand tests to cover multi-operand concat, batch-dim concat, etc.
PiperOrigin-RevId: 687003778
2024-10-17 12:24:51 -07:00
Sergei Lebedev
de7beb91a7 [pallas:mosaic_gpu] Added layout_cast
PiperOrigin-RevId: 686917796
2024-10-17 08:08:05 -07:00
Adam Paszke
0519db15ab [Pallas:MGPU] Add lowerings for more ops
PiperOrigin-RevId: 686910947
2024-10-17 07:42:56 -07:00
Adam Paszke
f72376ae0a [Pallas:MGPU] Add support for debug_print of arrays that use the WGMMA layout
PiperOrigin-RevId: 686885229
2024-10-17 06:06:16 -07:00
Adam Paszke
ef361f05a4 [Mosaic GPU] Add support for launching multiple warpgroups using core_map
PiperOrigin-RevId: 686876014
2024-10-17 05:30:48 -07:00
Bart Chrzaszcz
801fe87da6 Do not allow None axis names in meshes.
PiperOrigin-RevId: 686557025
2024-10-16 10:32:25 -07:00
Mantas Pajarskas
1222b4a571 [Pallas TPU] Add a better error message for rank 1 block mappings check.
Currently, the error message refers to "last two dimensions" which is confusing for a rank-1 case; furthermore, the error does not match the check in the code.

PiperOrigin-RevId: 686520781
2024-10-16 08:41:22 -07:00
Sergei Lebedev
4c0d82824f [pallas:mosaic_gpu] Added a few more operations necessary to port Flash Attention
PiperOrigin-RevId: 686451398
2024-10-16 04:05:36 -07:00
Ayaka
5ac2076fb7 [Pallas TPU] Fix boolean comparison
Fixes https://github.com/jax-ml/jax/issues/24030

Also added tests to cover all scalar comparison cases.

PiperOrigin-RevId: 686197357
2024-10-15 12:24:58 -07:00
Jevin Jiang
3a7d9137a4 [Pallas TPU] Support ref reshape.
Jaxpr example:
```
{ lambda ; a:MemRef<None>{int32[32,256]} b:MemRef<None>{int32[8,128]}. let
    c:i32[8,128] <- a[:16,:][bitcast(int16[32,256])][reshape(int16[2,16,256])][bitcast(float16[2,16,256])][1:,:,:][reshape(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:]
    b[:,:] <- c
  in () }
```

Tested:

- DMA with reshaped ref
- Load from reshaped ref
- Store to reshaped ref
- Multiple transforms
- Interpret Mode for ref transforms (updated discharge rules)
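
The shape bookkeeping in the jaxpr example above can be traced with a small pure-Python sketch. The helpers `take`, `bitcast`, and `reshape` are hypothetical names made up for this illustration, not Pallas APIs. The bitcast rule (keep the total number of bits constant by rescaling the second-to-last dimension) is an assumption inferred from the shapes printed in the jaxpr, for example `int32[16,256]` becoming `int16[32,256]`.

```python
from math import prod

# Bit widths of the dtypes that appear in the jaxpr example.
BITS = {"int32": 32, "int16": 16, "float16": 16}


def take(shape, *slices):
    """Shape after applying one slice per dimension."""
    return tuple(len(range(*s.indices(d))) for s, d in zip(slices, shape))


def bitcast(shape, src, dst):
    """Assumed rule: total bits stay constant, second-to-last dim is rescaled."""
    row_bits = shape[-2] * BITS[src]
    assert row_bits % BITS[dst] == 0, "bit widths must divide evenly"
    return shape[:-2] + (row_bits // BITS[dst], shape[-1])


def reshape(shape, new_shape):
    """A reshape must preserve the number of elements."""
    assert prod(shape) == prod(new_shape)
    return tuple(new_shape)


ALL = slice(None)

shape = (32, 256)                                   # a: int32[32,256]
shape = take(shape, slice(None, 16), ALL)           # a[:16,:]
shape = bitcast(shape, "int32", "int16")            # bitcast(int16[32,256])
shape = reshape(shape, (2, 16, 256))                # reshape(int16[2,16,256])
shape = bitcast(shape, "int16", "float16")          # bitcast(float16[2,16,256])
shape = take(shape, slice(1, None), ALL, ALL)       # [1:,:,:]
shape = reshape(shape, (16, 256))                   # reshape(float16[16,256])
shape = take(shape, ALL, slice(None, 128))          # [:,:128]
final_shape = bitcast(shape, "float16", "int32")    # bitcast(int32[8,128])
```

Each intermediate shape matches the annotation in the corresponding transform of the jaxpr, and `final_shape` agrees with the `i32[8,128]` result type.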

PiperOrigin-RevId: 686186426
2024-10-15 11:52:15 -07:00
Sharad Vikram
cd78c653e7 [Pallas] Use core_map instead of shard_map for Shmallas
- core_map is like a shard_map, but it takes no inputs and produces no outputs
- we can use it in Pallas to generalize mapping a function over the cores of a chip (e.g. TensorCores in a TPU or SMs in a GPU)
- we specify how the function will be mapped over the device with a `mesh` object. This is also a convenient mechanism for picking the backend for pallas to target

PiperOrigin-RevId: 686036101
2024-10-15 03:26:58 -07:00
Ayaka
bfc3d3cd18 [Pallas TPU] Add lowerings for scalar sin, cos, tan and tanh
This PR is similar to https://github.com/jax-ml/jax/pull/24238

PiperOrigin-RevId: 685842905
2024-10-14 14:49:11 -07:00
jax authors
1f0b5728a4 Add a memory saving index rewrite step to vmap with ragged inputs over pallas_call.
The approach here is to add a new notion to jax: ragged_prop. Ragged prop is useful for computing the dynamism/raggedness of an output, given a set of inputs. In the limit, if we decide that this is a useful property to have in jax as a first-class citizen, we could fold the raggedness into the type system. At the moment, however, it is just a small set of rules implemented per op.

PiperOrigin-RevId: 685827096
2024-10-14 14:01:42 -07:00
Justin Fu
cff9e93824 [Pallas] Add runtime assert via checkify.check. This check will halt the TPU if triggered, meaning that we would need to restart the program to recover.
PiperOrigin-RevId: 684940271
2024-10-11 13:34:04 -07:00
Sergei Lebedev
59ae2af699 [pallas:mosaic_gpu] Added a test doing manual in kernel pipelining
I think we have most of the primitives necessary, so the next step is to sketch
`emit_pipeline`.

PiperOrigin-RevId: 684840800
2024-10-11 08:08:04 -07:00
Sergei Lebedev
acd0e497af [pallas:mosaic_gpu] GPUBlockSpec no longer accepts swizzle
It was previously possible to pass `swizzle` both directly and via `transforms`.
This change eliminates the ambiguity at a slight downgrade to ergonomics.

PiperOrigin-RevId: 684797980
2024-10-11 05:11:26 -07:00
Ayaka
633cb31577 [Pallas TPU] Add lowering for scalar jnp.log1p
Fixes https://github.com/jax-ml/jax/issues/24239

PiperOrigin-RevId: 684650608
2024-10-10 18:44:41 -07:00
Ayaka
3bd8ca480a [Pallas] Add tests for scalar elementwise operations
On TPU, instructions differentiate between vectors and scalars, and the corresponding lowering paths are different. Existing Pallas tests cover only the vector versions of operations, not the scalar versions. This PR adds tests for scalar elementwise operations.

The structure of the test is similar to the vector version of the tests above.

PiperOrigin-RevId: 684569107
2024-10-10 14:00:21 -07:00
Peter Hawkins
19dbff5326 Move additional CI enabled/disabled configurations into jax BUILD files.
PiperOrigin-RevId: 684457403
2024-10-10 08:41:45 -07:00
Sergei Lebedev
70ee8e1161 [pallas:mosaic_gpu] pl.run_scoped now supports scoped barriers
PiperOrigin-RevId: 684449776
2024-10-10 08:16:13 -07:00
Justin Fu
73418427a8 [Pallas] Add lowering for threefry PRNG.
PiperOrigin-RevId: 684179182
2024-10-09 14:48:26 -07:00
Ayaka
3fc4ba29ea [Pallas TPU] Add lowerings for lax.population_count_p and lax.clz_p
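
For reference, the two operations can be described with a pure-Python sketch, assuming 32-bit operands. The helper names `population_count32` and `clz32` are made up for this illustration; the sketch describes the expected results only and says nothing about how the TPU lowering is implemented.

```python
MASK32 = 0xFFFFFFFF


def population_count32(x: int) -> int:
    """Number of set bits in the 32-bit pattern of x."""
    return bin(x & MASK32).count("1")


def clz32(x: int) -> int:
    """Number of leading zero bits in the 32-bit pattern of x (32 for zero)."""
    return 32 - (x & MASK32).bit_length()


examples = [(v, population_count32(v), clz32(v)) for v in (0, 1, 0b1011, 0x80000000)]
```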
PiperOrigin-RevId: 684158096
2024-10-09 13:46:29 -07:00
Andrey Portnoy
2c731320af Use py_deps("absl/testing") instead of //third_party/py/absl/testing
2024-10-09 15:24:40 -04:00
Ayaka
77613f21aa [Pallas TPU] Fix comparison lowering for unsigned integers
Fixes https://github.com/jax-ml/jax/issues/23972.

In Pallas, we use `i32` for both `jnp.int32` and `jnp.uint32`, but we need to choose the correct operation (e.g. `arith.extUI` vs `arith.extSI`) or the correct attribute (e.g. `sle` vs `ule` for `arith::CmpIOp`).

In this particular issue, we need to use attributes like `ule` for `jnp.uint32`, but it's currently lowered to attributes for `jnp.int32` such as `sle`.

This PR fixes this issue by distinguishing the attributes to use for signed and unsigned types.
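
To illustrate why the choice of predicate matters, here is a pure-Python sketch that compares the same 32-bit patterns under both interpretations. The helper names (`as_signed`, `as_unsigned`, `cmp_sle`, `cmp_ule`) are made up for this illustration and are not part of Pallas or MLIR; they only mimic what the `sle` and `ule` predicates compute.

```python
MASK32 = 0xFFFFFFFF


def as_unsigned(bits: int) -> int:
    """Interpret a 32-bit pattern as an unsigned integer."""
    return bits & MASK32


def as_signed(bits: int) -> int:
    """Interpret a 32-bit pattern as a two's-complement signed integer."""
    bits &= MASK32
    return bits - (1 << 32) if bits & 0x80000000 else bits


def cmp_sle(a: int, b: int) -> bool:
    """Signed less-than-or-equal, like the `sle` predicate."""
    return as_signed(a) <= as_signed(b)


def cmp_ule(a: int, b: int) -> bool:
    """Unsigned less-than-or-equal, like the `ule` predicate."""
    return as_unsigned(a) <= as_unsigned(b)


a = 0xFFFFFFFF  # 4294967295 as uint32, but -1 as int32
b = 0x00000001

signed_result = cmp_sle(a, b)    # True, because -1 <= 1
unsigned_result = cmp_ule(a, b)  # False, because 4294967295 > 1
```

The two predicates disagree whenever exactly one operand has its top bit set, which is the kind of input where using the signed predicate for `jnp.uint32` gives a wrong answer.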

PiperOrigin-RevId: 684065893
2024-10-09 09:28:53 -07:00
Justin Fu
9cf952a535 [Pallas] Add support for runtime checking of grid bounds using checkify.
PiperOrigin-RevId: 683791662
2024-10-08 15:48:16 -07:00
Adam Paszke
25c1519a84 [Pallas/MGPU] Allow delaying the release of pipelined buffers
This is useful so that we don't have to block on the WGMMA immediately after it runs.
`delay_release=n` means that the input/output buffers will not be mutated by the system
for at least `n` sequential steps following the one when they were kernel arguments.

PiperOrigin-RevId: 683629935
2024-10-08 08:17:58 -07:00
Ayaka
6a958b90b3 [Pallas] Simplify OpsTest by skipping 64-bit tests on 32-bit environments
This PR is similar to https://github.com/jax-ml/jax/pull/23814.

Background: We run tests on both 32-bit and 64-bit environments. Currently, when a test encounters a 64-bit dtype on a 32-bit environment, it enters a local 64-bit environment using `stack.enter_context(config.enable_x64(True))`. This is unnecessary, since we also run the same tests on 64-bit environments. This PR skips those tests on 32-bit environments instead.
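
The skip pattern described above can be sketched with plain `unittest`. Everything here is a stand-in: `X64_ENABLED` is a hypothetical flag playing the role of the real 64-bit-mode configuration, and `OpsTestSketch` and `check_dtype` are illustrative names, not the actual test code.

```python
import unittest

# Hypothetical stand-in for the real 64-bit-mode configuration flag.
X64_ENABLED = False
DTYPES_64BIT = {"int64", "uint64", "float64"}


class OpsTestSketch(unittest.TestCase):

    def check_dtype(self, dtype):
        # Skip instead of locally switching into a 64-bit environment:
        # the 64-bit test run already covers these dtypes.
        if dtype in DTYPES_64BIT and not X64_ENABLED:
            self.skipTest("64-bit dtypes are covered by the 64-bit test run")
        return dtype

    def test_float64(self):
        self.check_dtype("float64")

    def test_float32(self):
        self.assertEqual(self.check_dtype("float32"), "float32")


# Run the sketch in-process and collect the outcome.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(OpsTestSketch)
result = unittest.TestResult()
suite.run(result)
```

With the flag off, the `float64` test is reported as skipped and the `float32` test still runs.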
PiperOrigin-RevId: 683405197
2024-10-07 18:41:14 -07:00
Christos Perivolaropoulos
9ac6723561 [pallas:mosaic_gpu] Dereferencing the accumulator now supports slicing
PiperOrigin-RevId: 683235013
2024-10-07 10:33:08 -07:00
Sergei Lebedev
95631a7d92 Added jax.experimental.pallas.mosaic_gpu
I also deprecated `jax.experimental.pallas.gpu` in favor of
`jax.experimental.pallas.triton` to avoid confusion with the Mosaic GPU
backend.

PiperOrigin-RevId: 683119193
2024-10-07 04:05:08 -07:00
Sergei Lebedev
aadb50905c [pallas:mosaic_gpu] Allowed indexing refs with scalars
The transforms do not yet handle this case, so only the basic indexing works.

PiperOrigin-RevId: 682273046
2024-10-04 04:54:37 -07:00
Ayaka
cb2e0e2ced [Pallas TPU] Add lowering for lax.ceil_p
This PR uses exactly the same approach as https://github.com/jax-ml/jax/pull/24083, which adds lowering for `lax.floor_p`.

PiperOrigin-RevId: 682073765
2024-10-03 16:23:18 -07:00
Christos Perivolaropoulos
5800070c36 [pallas:mosaic_gpu] add logistic op and some tests for unary operations
PiperOrigin-RevId: 681889064
2024-10-03 08:25:44 -07:00
Sergei Lebedev
905c83c781 [pallas:mosaic_gpu] Support indexing barriers
A barrier must be indexed via `.at` and not directly. I wish we could emit
an instructive error for the latter case, but I couldn't find a good place
to put it.

PiperOrigin-RevId: 681857034
2024-10-03 06:48:03 -07:00
Sergei Lebedev
5a2e5a5a94 [pallas:mosaic_gpu] Copy primitives now support slices
I decided to

* split `async_copy_p` into multiple primitives to avoid having extra
  control flow in the lowering rule;
* drop the `async_*` prefix from `async_copy_p` and the corresponding
  APIs, because the names felt a bit too long otherwise.

Note that barriers cannot be sliced at the moment. I will address that in
a follow-up CL.

PiperOrigin-RevId: 681793650
2024-10-03 02:54:20 -07:00
Ayaka
b5ce44536b [Pallas TPU] Add lowering for lax.floor_p
This is a follow-up to https://github.com/jax-ml/jax/pull/24056, which adds lowering for `lax.tan_p`.

PiperOrigin-RevId: 681793238
2024-10-03 02:52:26 -07:00
Adam Paszke
c9f946ef57 Only thread a discharged ref value through a cond when it changes in some branch
Otherwise, we can simply pass it in as an argument, but we can avoid updating it
since it will always remain constant. Both programs have equivalent semantics,
but this one can be optimized better since it makes it more apparent that the
cond does not actually modify a ref.

PiperOrigin-RevId: 681482148
2024-10-02 09:29:07 -07:00
Adam Paszke
e2d3bd866a [Pallas/MGPU] Add support for tiled and swizzled loads/stores + support slices
PiperOrigin-RevId: 681370464
2024-10-02 02:44:10 -07:00
Sharad Vikram
c34e25d6f4 [Pallas] Add state discharge rule for pallas_call
This enables us to avoid spurious copies in the cases outlined in [the async operations design note](https://jax.readthedocs.io/en/latest/pallas/async_note.html) but not in general, since JAX and/or XLA could introduce copies because we have value semantics. For a proper solution, we need to introduce some notion of buffer semantics to XLA/HLO and preserve it through the lowering of stateful JAX (maybe by avoiding state discharge altogether).

PiperOrigin-RevId: 681206784
2024-10-01 16:30:56 -07:00
Ayaka
e361868132 [Pallas TPU] Add lowering for lax.tan_p
This is a follow-up to https://github.com/jax-ml/jax/pull/24028, which adds lowering for `lax.cos_p`.

PiperOrigin-RevId: 681180835
2024-10-01 15:09:52 -07:00