365 Commits

Author SHA1 Message Date
Peter Hawkins
26632fd344 Replace disable_backends with enable_backends on jax_multiplatform_test.
Most users of disable_backends were actually using it to enable only a single backend. So things are simpler if we negate the sense of the option to say that. Change disable_configs to enable_configs, with a default `None` value meaning "everything is enabled".

We change the relationship between enable_backends, disable_configs, enable_configs to be the following:
* `enable_backends` selects a set of initial test configurations to enable, based off backend only.
* `disable_configs` then prunes that set of test configurations, removing elements from the set.
* `enable_configs` then adds additional configurations to the set.
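
The three steps above can be read as plain set arithmetic. The following is a minimal Python sketch of that reading, not the actual Bazel (Starlark) macro. The configuration names, the `ALL_CONFIGS` mapping, and the function `select_configs` are hypothetical, and it assumes that `enable_backends=None` means every backend is enabled.

```python
# Hypothetical mapping from backend to its test configurations; the real
# names live in the JAX Bazel build files and may differ.
ALL_CONFIGS = {
    "cpu": {"cpu"},
    "gpu": {"gpu_v100", "gpu_h100"},
    "tpu": {"tpu_v4", "tpu_v5e"},
}


def select_configs(enable_backends=None, disable_configs=(), enable_configs=()):
    """Return the set of test configurations to run."""
    # Assumption: None means "every backend is enabled".
    backends = ALL_CONFIGS if enable_backends is None else enable_backends
    # Step 1: start from every configuration of the enabled backends.
    selected = set()
    for backend in backends:
        selected |= ALL_CONFIGS[backend]
    # Step 2: prune explicitly disabled configurations.
    selected -= set(disable_configs)
    # Step 3: add explicitly enabled configurations.
    selected |= set(enable_configs)
    return selected


gpu_only = select_configs(enable_backends=["gpu"], disable_configs=["gpu_v100"])
mixed = select_configs(enable_backends=["cpu"], enable_configs=["tpu_v5e"])
```

Because the additions are applied last, a configuration listed in both `disable_configs` and `enable_configs` ends up enabled in this sketch.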

Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target.

PiperOrigin-RevId: 679563155
2024-09-27 06:15:31 -07:00
Sergei Lebedev
afaf8b823d Run Pallas Mosaic GPU tests on internal CI
PiperOrigin-RevId: 679508320
2024-09-27 02:43:35 -07:00
jax authors
ea86251a60 [Pallas:TPU] Fix lowering of convert_element_type(int32) -> bool.
We need to add a condition on vector type since both operands of arith::CmpIOp must have same type.

PiperOrigin-RevId: 679500783
2024-09-27 02:15:35 -07:00
Justin Fu
9f4e8d0039 [XLA:Mosaic][Pallas] Enable vector.ExtractOp for non-zero indices.
PiperOrigin-RevId: 679283281
2024-09-26 13:57:45 -07:00
Adam Paszke
076287fb5c [Pallas/MGPU] Implement block spec evaluation correctly
The previous implementation made some surprising assumptions about the contents
of the block specs and wasn't correct in general. The new implementation handles
all the cases and seems to be sufficient to finally run the matmul example with
multiple k steps while producing correct results (it's also shorter!).

PiperOrigin-RevId: 679175212
2024-09-26 09:15:12 -07:00
Sergei Lebedev
5cef547eab Added support for lax.cond_p to Pallas Mosaic GPU lowering
PiperOrigin-RevId: 679156819
2024-09-26 08:20:53 -07:00
Adam Paszke
0a66e2d0a4 [Pallas/MGPU] Fix a race in the pipelining code
We never checked if the output windows are done writing before we reused them.
Also, rename num_stages to max_concurrent_steps since we always only have 2 stages,
but might be running multiple iterations at a time.

Also fix the test for this that has been passing for reasons that I don't understand
(it didn't even write to all entries in the output??).

PiperOrigin-RevId: 679148961
2024-09-26 07:57:54 -07:00
Adam Paszke
8599dbc9b2 [Pallas/Mosaic GPU] Implement a more comprehensive matmul kernel to see what we're still missing
I annotated a number of issues in the test. To make the test run I also needed to add support
for the accumulator reference allocation and discharge in the main lowering part. Ideally,
we'd defer it all to run_scoped, but run_scoped can't allocate barriers...

PiperOrigin-RevId: 679143948
2024-09-26 07:40:15 -07:00
Adam Paszke
3c25da2c59 [Pallas/Mosaic GPU] Replace tiling/transpose fields of GPUBlockSpec with a transform list
PiperOrigin-RevId: 679079269
2024-09-26 03:41:22 -07:00
Christos Perivolaropoulos
b6d668e0d7 [pallas::mosaic_gpu] Turn the accumulator into a reference
* Changes the accumulator into a reference
* Creates a discharged flavor of the wgmma op
* run_scoped lowering discharges the input jaxpr
* dereferencing the accumulator ref is done by a new primitive that behaves as expected when discharged
* the deref primitive implies flushing the wgmma pipeline.
* run_scoped does not allow references to be leaked.

PiperOrigin-RevId: 679056765
2024-09-26 02:18:27 -07:00
jax authors
70346bda74 [Pallas] Add scalar f32 downcast test cases.
PiperOrigin-RevId: 678779025
2024-09-25 11:25:59 -07:00
Sergei Lebedev
b49d8b2615 Fixed pl.debug_printing of scalar fragmented arrays under Mosaic GPU
PiperOrigin-RevId: 678726245
2024-09-25 09:10:48 -07:00
Peter Hawkins
a43c7f2ace Enable more H100 tests in CI.
Rename "gpu" config CI tag to "gpu_v100".

PiperOrigin-RevId: 678695003
2024-09-25 07:37:48 -07:00
Christos Perivolaropoulos
390b0ba4a6 [pallas::mosaic_gpu] Support for tiled transpose transforms.
For the time being this feature only supports 2D on the GMEM side and 4D after
tiling on the SMEM side.

PiperOrigin-RevId: 678683983
2024-09-25 07:00:09 -07:00
Sergei Lebedev
cdea3d4050 lax.fori_loop now allows scalars in its carry when lowering to Mosaic GPU
PiperOrigin-RevId: 678677508
2024-09-25 06:35:23 -07:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.
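
To illustrate what `PYTHONWARNINGS=error` does to a test, here is a small in-process sketch. It uses `warnings.simplefilter("error")`, which escalates warnings to exceptions in the same way as the environment variable. The functions `emits_warning` and `run_with_warnings_as_errors` are hypothetical stand-ins and are not part of the JAX test macros.

```python
import warnings


def emits_warning():
    """Stand-in for library code that triggers a warning."""
    warnings.warn("old API", DeprecationWarning)
    return "ok"


def run_with_warnings_as_errors(fn):
    """Call fn with every warning escalated to an exception."""
    with warnings.catch_warnings():
        # Same effect as starting the interpreter with PYTHONWARNINGS=error.
        warnings.simplefilter("error")
        try:
            return fn()
        except Warning as exc:
            return f"failed: {type(exc).__name__}: {exc}"


outcome = run_with_warnings_as_errors(emits_warning)
```

Under the default filters the call would return `"ok"` and only print the warning; with the error filter the test fails at the point where the warning is raised.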

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
Adam Paszke
9114b084fc [Pallas] Update export compatibility tests
The old test was generated before our IR was really stable, which has started
to cause problems when trying to test with Trillium.

PiperOrigin-RevId: 678277755
2024-09-24 09:17:56 -07:00
Adam Paszke
ae86ef16c7 [Mosaic GPU] Add support for input_output_aliases
PiperOrigin-RevId: 678217775
2024-09-24 06:13:28 -07:00
Chris Jones
712e638ca4 [pallas] Add support for unblocked mode (without padding) in Triton lowering.
PiperOrigin-RevId: 677870258
2024-09-23 11:21:54 -07:00
Ayaka
93203c7574 [Pallas] Simplify sign and erf_inv tests
Removed the approach of locally enabling x64 using:

```python
with contextlib.ExitStack() as stack:
  if jnp.dtype(dtype).itemsize == 8:
    stack.enter_context(config.enable_x64(True))
```

This is because we can determine whether a test is running in an x64 environment by checking the value of `jax.config.x64_enabled`. There is no need to enable x64 locally.

PiperOrigin-RevId: 677865574
2024-09-23 11:11:09 -07:00
Christos Perivolaropoulos
3e19a28b09 [pallas:mosaic_gpu] Basic implementation of wgmma.
PiperOrigin-RevId: 677864187
2024-09-23 11:06:17 -07:00
Ayaka
b6fe793909 [Pallas] Skip atomic_cas and atomic_counter tests on GPU in 64-bit mode
These tests are failing on GPU in 64-bit mode.

This fixes test failures introduced by https://github.com/jax-ml/jax/pull/23798

PiperOrigin-RevId: 677583606
2024-09-22 18:55:39 -07:00
Christos Perivolaropoulos
48c29f62e1 [pallas:mosaic_gpu] Fragmented array debug printing.
PiperOrigin-RevId: 677537364
2024-09-22 14:30:53 -07:00
Ayaka
d63afd8438 [Pallas GPU] Enable Pallas OpsExtraTest in 64-bit mode
This is a follow-up of https://github.com/jax-ml/jax/pull/23747, which enables Pallas `OpsTest` in 64-bit mode.

In order to enable Pallas `OpsExtraTest` in 64-bit mode, some of the code in the tests needs to be modified. There are three kinds of modifications:

1. Most of the modifications are just changing `jnp.int32` to `intx` and `jnp.float32` to `floatx`, which uses the same approach as the previous PR https://github.com/jax-ml/jax/pull/23747. `intx` and `floatx` are conventions used in Pallas tests to refer to 64-bit types in 64-bit mode and their 32-bit counterparts in 32-bit mode.
2. For the test `test_array_reduce`, the original code uses a simple approach to determine `out_dtype` from `dtype`, which no longer works in 64-bit mode. Therefore, I modified the code to deduce `out_dtype` by executing the operation on a single element first.
3. For the test `test_masked_load_store`, the `idx` variable is expected to be an `int32` array, which is calculated based on `pl.program_id()` and `block_size`. In 64-bit mode, the computation produces an `int64` array instead. Since `pl.program_id()` always returns an `int32` result, I modified the computation to produce an `int32` result. I also modified the `pl.program_id()` docstring to document that `pl.program_id()` always returns an `int32` result.
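
The `intx`/`floatx` convention from item 1 amounts to a mode-dependent choice of dtype. A minimal sketch follows, using plain strings for the dtype names so that it runs without JAX. The function `default_test_dtypes` is hypothetical; in the tests the flag would presumably come from `jax.config.x64_enabled`.

```python
def default_test_dtypes(x64_enabled):
    """Return (intx, floatx) dtype names for the given mode."""
    if x64_enabled:
        # 64-bit mode: the wide types are the defaults.
        return "int64", "float64"
    # 32-bit mode: fall back to the 32-bit counterparts.
    return "int32", "float32"


intx_32, floatx_32 = default_test_dtypes(x64_enabled=False)
intx_64, floatx_64 = default_test_dtypes(x64_enabled=True)
```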

PiperOrigin-RevId: 677007613
2024-09-20 16:18:31 -07:00
Jevin Jiang
6b93b35842 [Mosaic:TPU] Efficient relayout with internal scratch
We should support all different retilings (x*packing1, 128) <-> (y*packing2, 128) with any dtype in this cl at this moment. The efficient relayout with scratch brings significant improvements on current retiling in <= TPUv4 and retiling with (packing, 128) in TPUv5. All missing retiling supports are added in this cl, including increase sublane retiling and packed type retiling.

PiperOrigin-RevId: 676982957
2024-09-20 15:00:58 -07:00
Adam Paszke
99195ead83 [Mosaic TPU] Try reducing sublane tiling to support more vector.shape_casts
In particular, 32-bit values should now support all reshapes that do not modify the
last dimension.

PiperOrigin-RevId: 676855401
2024-09-20 08:36:22 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Sharad Vikram
1db47fd85d [Pallas] Minor cleanup of memory spaces. Also add ANY as a general memory space
PiperOrigin-RevId: 676650904
2024-09-19 19:08:18 -07:00
Ayaka
de23fdb5ad [Pallas TPU] Add lowering for 64 bit
2024-09-19 16:42:45 +01:00
Sergei Lebedev
22a7c73d27 Added support for lax.fori_loop in the Pallas Mosaic GPU lowering
This, coupled with `plgpu.async_copy` and barriers, should be enough to sketch
a simple pipelined loop in the kernel.

PiperOrigin-RevId: 676374408
2024-09-19 05:30:45 -07:00
Ayaka
3f23866f75 Enable Pallas ops_test on GPU in 64-bit mode.
Previously, the 64-bit tests were skipped in `PallasBaseTest`, which disabled both `OpsTest` and `OpsExtraTest`. This PR enables the 64-bit tests for `OpsTest`, and only disables them for `OpsExtraTest`.

PiperOrigin-RevId: 676373904
2024-09-19 05:29:38 -07:00
Sharad Vikram
9d2e9c688c [Pallas TPU] Add support for passing in and returning semaphores
This change enables writing async ops using Pallas. However, there are *extremely sharp edges* using this API. Please read the design note here: https://jax.readthedocs.io/en/latest/pallas/async_note.html.

Followup CLs will investigate safer APIs for writing async ops.

PiperOrigin-RevId: 676243335
2024-09-18 20:39:43 -07:00
Sergei Lebedev
016c49951f Removed leftover usages of GPUGridSpec from Pallas Mosaic GPU tests
PiperOrigin-RevId: 676029854
2024-09-18 09:57:19 -07:00
Sergei Lebedev
e90336947a Pulled scratch_shapes into GridSpec
It is supported by Mosaic TPU and Mosaic GPU and unsupported by Triton.

PiperOrigin-RevId: 675950199
2024-09-18 05:26:21 -07:00
Sergei Lebedev
3555b2b2c1 Renamed plgpu.wait to plgpu.wait_barrier
This avoids a potential ambiguity with waiting for a WGMMA to complete.

PiperOrigin-RevId: 675528768
2024-09-17 05:35:02 -07:00
Sergei Lebedev
8c39d0373a Added a new primitive for copying GMEM<->SMEM in Pallas Mosaic GPU kernels
The copy is async and needs to be awaited via `plgpu.wait_inflight(...)` for
SMEM->GMEM copies and via `plgpu.wait(barrier)` for GMEM->SMEM copies.

I decided to have distinct functions for SMEM->GMEM and GMEM->SMEM copies
and for the ways to await the result, because the underlying Mosaic GPU
APIs (and PTX ISA) *are* in fact very different.

PiperOrigin-RevId: 675155317
2024-09-16 08:18:46 -07:00
Jevin Jiang
839ce9a11d [Pallas TPU] Refactor ref indexers to transforms and support ref bitcast.
This CL refactors Pallas memref indexers into transforms, which can support different ref transforms: indexing, bitcast (added in this CL), reshape (to be added), and others. As with indexers, users can apply multiple transforms to the same memref, e.g.:
```
ref.bitcast(type1).at[slice1].bitcast(type2).bitcast(type3).at[slice2]...
```

Jaxpr Preview (apply multiple transforms to same ref):
```
{ lambda ; a:MemRef<None>{int32[16,256]} b:MemRef<None>{int32[8,128]}. let
    c:i32[8,128] <- a[:8,:][bitcast(int16[16,256])][bitcast(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:]
    b[:,:] <- c
  in () }
```
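
As a loose analogy only, Python's built-in `memoryview` also allows slicing and bitcasting to be chained on one buffer without copying. The sketch below is not Pallas code and does not model TPU tiling. It assumes a 4-byte native `int` and a 2-byte native `short`.

```python
from array import array

# Eight native ints standing in for an int32 buffer (assumes a 4-byte C int).
buf = array("i", [0] * 8)
ref = memoryview(buf)

# Slice, then bitcast to 16-bit elements. memoryview.cast must go through
# a byte format, so the bitcast is written as two casts.
narrow = ref[:4].cast("B").cast("h")

# Slice the bitcasted view again, then store through it.
window = narrow[2:6]
window[0] = 1

# Number of 16-bit elements that fit in one element of the original buffer.
ratio = buf.itemsize // narrow.itemsize
```

The store through `window` lands in the original buffer, because every view in the chain aliases the same memory.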

Tested:
* DMA with bitcasted ref
* Load from bitcasted ref
* Store to bitcasted ref
* Multiple transforms
* Interpret Mode for ref transforms (updated discharge rules)
PiperOrigin-RevId: 674961388
2024-09-15 17:53:29 -07:00
George Necula
ee6f098fa9 [pallas] Clean up forward-compatibility conditionals in Pallas lowering
In cl/657184114 (July 29th) I made some changes to error reporting for invalid block shapes, but left behind some conditionals to ensure forward compatibility. We are now outside the forward compatibility window, so we clean up those conditionals.

PiperOrigin-RevId: 674603915
2024-09-14 02:32:16 -07:00
Sergei Lebedev
40040e3f69 Added a new approx_math flag to Mosaic GPU params in Pallas
The flag controls the precision of some operations, e.g. `exp`.

PiperOrigin-RevId: 674305430
2024-09-13 08:21:07 -07:00
Sergei Lebedev
e2d7ef2a49 Pallas Mosaic GPU now supports scratch buffers in SMEM
PiperOrigin-RevId: 674173250
2024-09-13 00:09:57 -07:00
jax authors
baf9cc70bd Merge pull request #23534 from ayaka14732:pallas-indexing-3
PiperOrigin-RevId: 673749898
2024-09-12 02:22:42 -07:00
Justin Fu
10057eb739 [Pallas] Fix TPU large array indexing tests.
- On TPU, this test OOMs on some chips. We fix this by forcing a garbage collection before the test.
- In interpret mode, semaphores were overflowing with a large copy size. We cap the inc/dec value at maxint to prevent overflow.
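
Both fixes can be sketched in a few lines of Python. This is an illustration under stated assumptions rather than the code from the commit: the test class name, the helper `clamp_semaphore_delta`, and the choice of a signed 32-bit limit for "maxint" are all hypothetical.

```python
import gc
import unittest

INT32_MAX = 2**31 - 1  # assumed meaning of "maxint" in the commit message


def clamp_semaphore_delta(num_elements):
    """Cap a semaphore increment/decrement so it cannot overflow int32."""
    return min(num_elements, INT32_MAX)


class LargeArrayIndexingTest(unittest.TestCase):  # hypothetical test class
    def setUp(self):
        super().setUp()
        # Release memory held by unreachable objects from earlier tests
        # before allocating the large array.
        gc.collect()

    def test_delta_is_capped(self):
        self.assertEqual(clamp_semaphore_delta(2**33), INT32_MAX)
        self.assertEqual(clamp_semaphore_delta(1024), 1024)
```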

PiperOrigin-RevId: 673451668
2024-09-11 10:44:48 -07:00
Sergei Lebedev
ea68f4569c Internal change
PiperOrigin-RevId: 673409076
2024-09-11 08:47:58 -07:00
Peter Hawkins
49dd6ed8d8 Disable a pallas export compatibility test that fails on TPU v6e.
PiperOrigin-RevId: 673295487
2024-09-11 02:00:42 -07:00
Justin Fu
e3c4b20fa0 [Pallas] Implement tiled and swizzled Memref loads for Mosaic GPU via "GPUBlockSpec"
PiperOrigin-RevId: 673165201
2024-09-10 17:21:20 -07:00
Justin Fu
c659dc9a01 [Pallas] Disable win32 gpu_ops_test.
PiperOrigin-RevId: 673149107
2024-09-10 16:23:17 -07:00
Ayaka
46bcb1e057 [Pallas] Simplify lowering and fix the test for lax.erf_inv_p
This PR is a follow-up of https://github.com/google/jax/pull/23192, which implements the lowering rule for `lax.erf_inv_p`. However, I've realised that the lowering rule can be simplified, and the test for it was moved to the wrong place. This PR resolves the above 2 issues.

After merging this PR, I will continue with https://github.com/google/jax/pull/22310, which adds 64-bit lowering support for `lax.erf_inv_p`.

PiperOrigin-RevId: 673095319
2024-09-10 14:00:28 -07:00
Sergei Lebedev
9fa0164ad2 Estimate the amount of required scratch SMEM automatically in Pallas Mosaic GPU lowering
No estimation is done if `smem_scratch_bytes` was explicitly specified via
`compiler_params=`.
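
A rough sketch of the behavior described above, assuming the estimate is simply the summed size of the scratch buffers. The real lowering may also account for alignment, barriers, or other overheads. The function `scratch_bytes` and the `ITEMSIZE` table are hypothetical.

```python
import math

# Bytes per element for a few dtypes, keyed by name (illustrative subset).
ITEMSIZE = {"float16": 2, "bfloat16": 2, "float32": 4, "int32": 4}


def scratch_bytes(scratch_shapes, smem_scratch_bytes=None):
    """Return the SMEM scratch size in bytes.

    scratch_shapes is a list of (shape, dtype_name) pairs. An explicit
    smem_scratch_bytes disables the estimate and is returned unchanged.
    """
    if smem_scratch_bytes is not None:
        return smem_scratch_bytes
    return sum(math.prod(shape) * ITEMSIZE[dtype] for shape, dtype in scratch_shapes)


buffers = [((64, 64), "float16"), ((128,), "float32")]
estimated = scratch_bytes(buffers)
explicit = scratch_bytes(buffers, smem_scratch_bytes=1024)
```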

PiperOrigin-RevId: 672998660
2024-09-10 09:43:04 -07:00
Ayaka
e0faa596b3 [Pallas] Fix array indexing error when dimension size is not a multiple of stride
2024-09-10 02:52:55 +01:00
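
The commit message does not describe the bug itself. As background, the number of elements selected by a strided slice is a ceiling division, so it differs from plain floor division whenever the dimension size is not a multiple of the stride. The helper `strided_len` below is hypothetical.

```python
def strided_len(dim_size, stride, start=0):
    """Number of elements selected by start::stride along one dimension."""
    # Ceiling division written with floor division on negated operands.
    return max(0, -(-(dim_size - start) // stride))


exact_multiple = strided_len(12, 3)   # indices 0, 3, 6, 9
not_multiple = strided_len(10, 3)     # indices 0, 3, 6, 9
floor_result = 10 // 3                # floor division undercounts here
```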
Ayaka
7d2f0a75c1 [Pallas GPU] Fix the behavior of jnp.sign(jnp.nan) and move the TPU test case for jnp.sign into the general test
This PR is similar to https://github.com/google/jax/pull/23192, which moves the TPU test case for `lax.erf_inv` into the general test.

Fixes https://github.com/google/jax/issues/23504
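
For reference, a pure-Python sketch of a sign function that passes NaN through, assuming the intended behavior matches NumPy, where the sign of NaN is NaN. The function `sign` is a hypothetical stand-in and not the Pallas lowering rule.

```python
import math


def sign(x):
    """Sign of a float: -1.0, 0.0 or 1.0, with NaN passed through."""
    if math.isnan(x):
        # Every comparison with NaN is False, so handle it explicitly.
        return x
    return float((x > 0) - (x < 0))


nan_result = sign(float("nan"))
negative_result = sign(-2.5)
```

Without the explicit NaN branch, the comparison-based expression would return `0.0` for NaN, since both `x > 0` and `x < 0` are false.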

PiperOrigin-RevId: 672682048
2024-09-09 14:49:40 -07:00