281 Commits

Author SHA1 Message Date
George Necula
3e5e947542 Move some backwards compatibility tests from jax_triton to jax/pallas.
While doing this I moved `matmul.py` to `jax/experimental/pallas/ops/tpu`

PiperOrigin-RevId: 660341331
2024-08-07 05:00:29 -07:00
jax authors
8b9ceb598b Handle bool comparisons.
PiperOrigin-RevId: 659919931
2024-08-06 05:37:35 -07:00
jax authors
9762ac53c8 Move CostEstimate from pltu to pl
* Move CostEstimate from TPU-specific `compiler_params` to a platform-independent argument of `pallas_call`.
Passing a CostEstimate in `compiler_params` is now deprecated and will be removed in 3 months' time.
* Update the CostEstimate when batching a kernel by scaling it by the size of the batch axis.
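
The batching rule in the second bullet can be sketched in plain Python. This is only an illustration: the `CostEstimate` class below is a hypothetical stand-in with assumed field names, and `scale_for_batch` is an invented helper, not the code from this change.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class CostEstimate:
    # Hypothetical stand-in for illustration; field names are assumptions.
    flops: int
    transcendentals: int
    bytes_accessed: int


def scale_for_batch(cost: CostEstimate, batch_size: int) -> CostEstimate:
    """Scale every cost component by the size of the batch axis."""
    return replace(
        cost,
        flops=cost.flops * batch_size,
        transcendentals=cost.transcendentals * batch_size,
        bytes_accessed=cost.bytes_accessed * batch_size,
    )


# Cost of one kernel invocation, then of the same kernel batched 8 times.
single = CostEstimate(flops=2_000, transcendentals=10, bytes_accessed=4_096)
batched = scale_for_batch(single, batch_size=8)
```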

PiperOrigin-RevId: 659560330
2024-08-05 08:18:01 -07:00
George Necula
252032a368 [pallas] Improve error and debugging messages with source locations
Document the `name` argument to `pallas_call` and supplement it with source location information for the kernel function.
Pass all this as the `name_and_src_info` parameter to the `pallas_call_p` primitive.

Added some more information to the `if debug` prints.

Set the MLIR module names so that the debug dumps are named properly.

I changed `import pallas.core as pl_core` to `... as pallas_core` for consistency, in a couple of modules.

PiperOrigin-RevId: 659506675
2024-08-05 04:23:55 -07:00
George Necula
9b35b760ce [pallas] Enable check for GPU lowering that tensor sizes are power of 2
Triton requires that the arguments and results of all operations
be tensors whose size is a power of 2. Added a lowering check
for this. Without it, violating the condition causes an
unfriendly crash.
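
A check of this kind might look roughly like the sketch below. The function names and the error wording are assumptions made for illustration; the actual lowering check may be structured differently.

```python
import math


def is_power_of_2(n: int) -> bool:
    """True for 1, 2, 4, 8, ...; uses the single-set-bit trick."""
    return n > 0 and (n & (n - 1)) == 0


def check_tensor_size(shape: tuple[int, ...]) -> None:
    """Raise a readable error if the total element count is not a power of 2."""
    size = math.prod(shape)  # math.prod(()) == 1, so scalars pass
    if not is_power_of_2(size):
        raise ValueError(
            f"Tensor of shape {shape} has size {size}, "
            "which is not a power of 2")


check_tensor_size((8, 16))  # 128 elements: accepted
```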

PiperOrigin-RevId: 659483450
2024-08-05 02:34:21 -07:00
Eugene Zhulenev
ac52890e3d [jax] Shard pallas_vmap_test
PiperOrigin-RevId: 658834942
2024-08-02 10:41:22 -07:00
Adam Paszke
86c9903067 [Pallas TPU] Make sure that the bug repros actually fail
One of them was fixed in the meantime but we didn't realize it.

PiperOrigin-RevId: 658799901
2024-08-02 08:37:22 -07:00
George Necula
43163ff2e3 [pallas] Add error message for block_shapes of rank less than 1.
PiperOrigin-RevId: 658424421
2024-08-01 09:15:07 -07:00
jax authors
6870d37822 Added test cases for more TPU Mosaic bugs.
PiperOrigin-RevId: 658067355
2024-07-31 10:58:27 -07:00
jax authors
d696813b1f Merge pull request #22746 from gnecula:pallas_consts
PiperOrigin-RevId: 658050734
2024-07-31 10:13:34 -07:00
Justin Fu
3c7c9ffbbc [Pallas] Correctly handle asymmetrical remote DMA dst_ref indexing in interpret mode.
PiperOrigin-RevId: 658029297
2024-07-31 09:10:53 -07:00
Justin Fu
acacbe297f [Pallas] Move distributed TPU tests to their own file.
PiperOrigin-RevId: 658014149
2024-07-31 08:19:19 -07:00
George Necula
987bf33e85 [pallas] Disallow capturing of consts by kernel functions.
Previously this was allowed, but until recently (#22550) it was
not working correctly in many cases. Now we disallow const
capturing because it can lead to surprises. Instead, the
kernel function must receive all the arrays it needs as explicit
inputs, with proper block specs.
2024-07-31 09:06:29 +02:00
George Necula
6d53aaf7d0 [pallas] Improve the error localization
* Add the source location information for the index map function to
  `BlockMapping`.
* Removed the `compute_index` wrapper around the index_map, so that
  we can get the location information for the index_map, not the wrapper.
* Added source location to the errors related to index map functions.
* Added an error if the index map returns something other than integer
  scalars.
* Construct BlockSpec origins for arguments using JAX helper functions
  to get argument names.
* Removed redundant API error tests from tpu_pallas_test.py.
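
The idea of reporting the index map's own source location, together with the integer-scalar check, can be sketched as follows. This is a simplified illustration on plain Python values; `call_index_map` is a hypothetical helper, and the real check presumably inspects traced values rather than Python ints.

```python
import numbers


def call_index_map(index_map, grid_indices):
    """Call `index_map` and require every returned index to be an integer."""
    result = index_map(*grid_indices)
    if not isinstance(result, (tuple, list)):
        result = (result,)
    # Read the location from the user's function itself, not from a wrapper.
    code = index_map.__code__
    where = f"{index_map.__name__} at {code.co_filename}:{code.co_firstlineno}"
    for position, index in enumerate(result):
        if isinstance(index, bool) or not isinstance(index, numbers.Integral):
            raise ValueError(
                f"Index map {where} must return integer scalars, "
                f"but output {position} is {index!r}")
    return tuple(result)


block_indices = call_index_map(lambda i, j: (i, 0), (2, 3))
```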
2024-07-30 14:11:57 +02:00
jax authors
63303973a2 Merge pull request #22716 from superbobry:pallas
PiperOrigin-RevId: 657333519
2024-07-29 14:50:01 -07:00
Sergei Lebedev
a44265aa73 Added a trivial discharge rule for debug_callback_p
This allows using jax.debug.print with Refs in interpreted Pallas kernels.
2024-07-29 22:26:01 +01:00
George Necula
68972de021 [pallas] Add lowering errors for block shapes that are not supported.
Previously these errors came from Mosaic with less useful stack traces, and in the case of GPU we were getting a crash instead of an exception.

PiperOrigin-RevId: 657184114
2024-07-29 06:49:27 -07:00
Sergei Lebedev
ccc4c42ec9 Reduced the input size in PallasCallInputOutputAliasingTest
This ensures the test doesn't OOM when running on A100 on the CI.

PiperOrigin-RevId: 657165032
2024-07-29 05:29:45 -07:00
Ayaka
bb160cf54e Move TPU ops test to ops_test.py
Move the TPU ops test from `tpu_ops_test.py` to `ops_test.py`. The functions tested in this file are not TPU-specific operations, so we don't need a separate test file.

PiperOrigin-RevId: 656347969
2024-07-26 04:24:13 -07:00
jax authors
2db99e03dd Merge pull request #22283 from ayaka14732:ayx/lowering/sign
PiperOrigin-RevId: 656317943
2024-07-26 02:28:33 -07:00
Ayaka
6cc09173d5 Add lowering for lax.sign
2024-07-26 10:33:42 +08:00
Sergei Lebedev
99a8b92a7f Fixed Pallas Mosaic GPU tests
* Migrated to the new barrier APIs
* Fixed scratch view casting logic, it previously didn't work for >1 view

PiperOrigin-RevId: 655937541
2024-07-25 06:51:13 -07:00
jax authors
f15f9717c3 [Pallas/TPU] Fix bug with LocalMask grid shrinking
LocalMasks can trigger shrinking of the MaskInfo arrays and of the iteration space.
As a consequence, it is important that in the kernel body we use the `global_kv_index`. This is the kv_index in the "global" space without any shrinking of the iteration space.

PiperOrigin-RevId: 655901432
2024-07-25 04:05:57 -07:00
George Necula
4063373b22 Reverts 0d058ce86f04a44a51abba1261768fb46edf69d9
PiperOrigin-RevId: 655871052
2024-07-25 01:50:36 -07:00
Christos Perivolaropoulos
80a193d5db [pallas] Use the same primitive run_scoped_p for both mosaic and mosaic_gpu
PiperOrigin-RevId: 655751205
2024-07-24 17:14:30 -07:00
Vladimir Belitskiy
d9a7cb4490 Skip pallas/gpu_attention_test.py on TPU.
PiperOrigin-RevId: 655575719
2024-07-24 08:24:57 -07:00
George Necula
0d058ce86f Reverts 0e17d26b6d81a6b281f55bdd81a6f0ab45efeafe
PiperOrigin-RevId: 655552768
2024-07-24 07:09:16 -07:00
George Necula
c5871331ba [pallas] Simplify handling of BlockMapping and GridMapping
`BlockSpec`, `GridSpec` and `PrefetchScalarGridSpec` are now simple
dataclasses that just store the parameters passed
from the API. They are then canonicalized and converted
to `BlockMapping` and `GridMapping`, which contain fewer
optional metadata. In particular, `BlockMapping` is never
`None`. This consolidates the code to preprocess the
block and grid parameters, and simplifies the code downstream.

`grid` now defaults to `()` instead of `None`.

Added more fields to `BlockMapping` (`block_aval`,
`array_shape_dtype`, and `source`). The `source` field
is used in error messages. The `array_shape_dtype` makes
it unnecessary to process BlockMappings zipped with
`in_shapes`. With these fields, we can now add
a `check_invariants` method that is called during testing
or when `config.enable_checks` is true.

Added more fields and a `check_invariants` to `GridMapping`, since it is
such an important data structure.
The new fields are: `index_map_avals`, `index_map_tree` (to encode
the calling convention for the index map functions),
`num_inputs`, `num_outputs`. The latter make it possible to
recover the `in_shapes` and `out_shapes` from the GridMapping.
Previously there was some redundancy of information between
`in_shapes` and `out_shapes`.

Now we do not need the `in_shapes` and `out_shapes` parameters to
`pallas_call_p`, since it already has `grid_mapping`.

Moved some of the logic for handling scalar prefetch and
scratch shapes from `PrefetchScalarGridSpec.get_grid_mapping` to
`GridSpec.get_grid_mapping`, and thus removed code duplication.

Removed some dead code for implementing the interpret mode.

Previous handling of hoisted consts did not account for them in
`in_shapes`. Now, this is fixed since we do not keep track of
`in_shapes` separately.

Renamed `GridMapping.mapped_dims` to `GridMapping.vmapped_dims` to
avoid confusion with the use of mapped in block shapes.

Added test for the calling convention, including dynamic grid dimensions.

There is more work to be done: with the new information in
`GridMapping` it should be possible to clean up the code throughout
that extracts various parts of the inputs and outputs. This
should be a bunch of local changes, which I will do separately
once I merge this large global change.
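
The split between a user-facing spec and a canonicalized mapping can be illustrated with a small sketch. All class, field and function names below are hypothetical simplifications, and the defaulting rules (whole-array block, all-zeros index map) are assumptions for the example rather than a description of the real `BlockSpec` and `BlockMapping`.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class BlockSpecSketch:
    # Stores exactly what the caller passed; both fields may be None.
    block_shape: Optional[tuple[int, ...]] = None
    index_map: Optional[Callable[..., tuple[int, ...]]] = None


@dataclass(frozen=True)
class BlockMappingSketch:
    # Canonical form: no optional fields, plus a `source` for error messages.
    block_shape: tuple[int, ...]
    index_map: Callable[..., tuple[int, ...]]
    array_shape: tuple[int, ...]
    source: str

    def check_invariants(self, grid: tuple[int, ...]) -> None:
        assert len(self.block_shape) == len(self.array_shape), self.source
        indices = self.index_map(*(0,) * len(grid))
        assert len(indices) == len(self.block_shape), self.source


def canonicalize(spec, array_shape, source, grid=()):
    """Fill in defaults so downstream code never has to handle None."""
    block_shape = array_shape if spec.block_shape is None else spec.block_shape
    index_map = spec.index_map
    if index_map is None:
        index_map = lambda *_: (0,) * len(array_shape)
    mapping = BlockMappingSketch(block_shape, index_map, array_shape, source)
    mapping.check_invariants(grid)
    return mapping


mapping = canonicalize(BlockSpecSketch(), array_shape=(8, 128), source="x")
```

Keeping the array shape and a `source` string on the mapping is what lets the invariant check run without zipping the mappings against a separate list of input shapes.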
2024-07-24 14:48:08 +03:00
Sharad Vikram
cfd9d8f548 [Pallas/TPU] Allow reading DMA semaphores in Pallas
PiperOrigin-RevId: 655384701
2024-07-23 19:08:45 -07:00
Christos Perivolaropoulos
4186824b34 [pallas:mosaic_gpu] Add support for run_scoped
PiperOrigin-RevId: 655338646
2024-07-23 16:13:00 -07:00
Sharad Vikram
ae8da83357 Shmallas, a.k.a. allow lowering shard_map + run_state to a pallas_call.
This allows code like this:
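```python
# Imports assumed from how these names are used in the JAX source tree.
import jax.numpy as jnp
from jax._src.state import discharge as state_discharge
from jax.experimental import shard_map
from jax.experimental.pallas import tpu as pltpu

def f(x):
  mesh = pltpu.create_tensorcore_mesh('core')
  y = jnp.zeros_like(x)
  @state_discharge.run_state  # turns the (x, y) arrays into Refs
  def inner(refs):
    x_ref, y_ref = refs
    def kernel():
      def alloc(sem):
        # Copy x into y with a DMA and wait for it to complete.
        pltpu.async_copy(x_ref, y_ref, sem).wait()
      pltpu.run_scoped(alloc, pltpu.SemaphoreType.DMA)
    shard_map.shard_map(kernel, mesh, in_specs=(), out_specs=None,
                        check_rep=False)()
  _, y = inner((x, y))
  return y
```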

Why? pallas_call as an API has a lot of responsibilities:
1. Creating Refs out of Arrays
2. Parallelizing execution over cores (via dimension_semantics and grid)
3. Pipelining
4. Allocating scratch spaces
5. Scalar prefetch

This change allows you to express pallas_call *compositionally* using existing APIs.

1. Creating Refs out of arrays -> run_state
2. Parallelizing execution over cores -> shmap w/ a special mesh
3. Pipelining -> emit_pipeline
4. Allocating scratch spaces -> run_scoped (which we could generalize to run_state)
5. Scalar prefetch -> run_scoped + a DMA

The hope is that this allows Pallas to generalize to more backends beyond TPU while becoming more intuitive to write and explain. For now, this lowering path is experimental and not officially exposed but we want to make sure it is possible to support.

PiperOrigin-RevId: 655320587
2024-07-23 15:16:50 -07:00
Peter Hawkins
fd85c78366 Skip some Pallas tests that fail on TPUv6.
PiperOrigin-RevId: 655153366
2024-07-23 07:16:24 -07:00
jax authors
0c09e7949a Merge pull request #22559 from superbobry:pallas-test
PiperOrigin-RevId: 655145718
2024-07-23 06:44:49 -07:00
Sergei Lebedev
b7715e279d Another take at enabling Pallas GPU tests on x64
Note that for_loop_p no longer assumes that the loop index is an int32.

Closes #18847
2024-07-23 09:19:01 +00:00
Sharad Vikram
499ceeeb2c Add support for named grids in pallas_call.
PiperOrigin-RevId: 655036727
2024-07-22 23:25:12 -07:00
George Necula
a18872aa13 Reverts d7b821b04d8fec543f570faaece7572a50a75eb6
PiperOrigin-RevId: 655019101
2024-07-22 22:05:30 -07:00
Vladimir Belitskiy
d7b821b04d The newly added test class is failing, and blocking presubmits
Reverts 09523adf7dd5b5b1099780785a73a12bf6664c53

PiperOrigin-RevId: 654842341
2024-07-22 11:52:24 -07:00
George Necula
b7105ccd19 [pallas] Fix the handling of captured consts
There was an attempt to handle consts captured by the kernel,
but it was incomplete and had errors: the calling convention was
wrong, and the support for handling consts along with scalar
prefetch and scratch values was incomplete.

I expanded the tests: one in pallas_tests.py and two tests
in tpu_pallas_test.py (to handle scalar prefetch, with and
without scratch inputs).

The calling convention now: `*scalar_refs, *consts, *ins, *outs, *scratch`.
This is different from before (`*consts, *scalar_refs, *ins, ...`) so that
it keeps the block arguments (consts, ins, outs) together and makes it
easier to write the lowering.

I will follow up with a cleanup PR for the handling of grid_mapping.
Here I attempted to minimize the changes.
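
To make the argument order concrete, the sketch below splits a flat argument list according to `*scalar_refs, *consts, *ins, *outs, *scratch`. The helper and its parameter names are hypothetical and exist only to illustrate the ordering.

```python
from typing import NamedTuple, Sequence


class KernelArgs(NamedTuple):
    scalar_refs: tuple
    consts: tuple
    ins: tuple
    outs: tuple
    scratch: tuple


def split_kernel_args(args: Sequence, *, num_scalar_refs: int, num_consts: int,
                      num_ins: int, num_outs: int) -> KernelArgs:
    """Split flat kernel arguments; whatever remains at the end is scratch."""
    counts = (num_scalar_refs, num_consts, num_ins, num_outs)
    if sum(counts) > len(args):
        raise ValueError(f"Expected at least {sum(counts)} args, got {len(args)}")
    groups = []
    start = 0
    for n in counts:
        groups.append(tuple(args[start:start + n]))
        start += n
    groups.append(tuple(args[start:]))
    return KernelArgs(*groups)


parts = split_kernel_args(
    ["s0", "c0", "x", "y", "o", "tmp"],
    num_scalar_refs=1, num_consts=1, num_ins=2, num_outs=1)
```

With this order the block arguments (`consts`, `ins`, `outs`) form one contiguous slice, which is the property the message cites as making the lowering easier to write.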
2024-07-22 13:34:32 +03:00
George Necula
28d4caefb0 [pallas] Thread the interpret= parameter to pallas_call.
This aligns more tests to use the same testing structure: tests
can run on CPU (in interpreter mode) or TPU/GPU, and for each
test class MyTest we have a sibling test class MyInterpreterTest.

This is useful when developing on a machine without accelerators.
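
The sibling-class structure can be sketched with plain `unittest`. The `INTERPRET` attribute name and the `run_kernel` stand-in are assumptions for illustration; a real test would pass the flag on as `interpret=` to `pallas_call`.

```python
import unittest


def run_kernel(values, *, interpret):
    # Hypothetical stand-in for a kernel call; a real test would forward
    # `interpret` to pallas_call. Both modes must give the same result.
    return [2 * v for v in values]


class MyTest(unittest.TestCase):
    INTERPRET = False  # run on the accelerator (TPU/GPU)

    def test_double(self):
        out = run_kernel([1, 2, 3], interpret=self.INTERPRET)
        self.assertEqual(out, [2, 4, 6])


class MyInterpreterTest(MyTest):
    INTERPRET = True  # same tests, run in interpreter mode on CPU
```

Because the subclass inherits every test method, each test is written once and exercised in both modes.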
2024-07-21 20:17:12 +03:00
Enrique Piqueras
ca6d6341f5 look mum I can still edit Pallas, AGI! Optimize pipeline emitter scheduler by omitting copies of accumulators during iterations in which they are going to be zeroed out.
Also, add some clarifying comments and set fixed RHS schedules of matmul reduce scatter implementations.

PiperOrigin-RevId: 654015498
2024-07-19 08:30:14 -07:00
Adam Paszke
5d9e715289 [Mosaic TPU] Adjust tolerance in one Pallas test
PiperOrigin-RevId: 653247157
2024-07-17 08:46:02 -07:00
George Necula
6c5583d6aa [pallas] Document and test valid block shapes.
I have only added tests and documentation, will improve error
reporting separately.

For TPU we get a mix of errors from either the Pallas lowering or
from Mosaic. I plan to add lowering exceptions for all unsupported
cases, so that we have a better Python stack trace available.

For GPU, we get a RET_CHECK instead of a Python exception,
so I had to add skipTest. Will fix the error message separately.

In order to be able to put the test in pallas_test::PallasCallTest, I
moved the skipTest for TPU from the setUp to the individual tests that
need this.

PiperOrigin-RevId: 653195289
2024-07-17 05:17:20 -07:00
Justin Fu
6ba889c01c [Pallas] Add support for checkify in TPU execution mode.
PiperOrigin-RevId: 653045818
2024-07-16 18:13:02 -07:00
jax authors
5ddec63a47 Merge pull request #22441 from gnecula:test_clean_hypothesis
PiperOrigin-RevId: 652919414
2024-07-16 11:32:46 -07:00
Sergei Lebedev
7a62b8dd18 Re-enabled PallasCallPrintTest on Cloud TPUs
PiperOrigin-RevId: 652823653
2024-07-16 06:55:20 -07:00
Justin Fu
0690988626 [Pallas] Add limited boolean memref support for scalars.
PiperOrigin-RevId: 652653003
2024-07-15 17:59:05 -07:00
Peter Hawkins
f488c4cc31 Disable some tests that fail on Cloud TPU.
2024-07-15 16:00:58 -04:00
jax authors
7255ab146b Merge pull request #22440 from gnecula:pallas_test_clean
PiperOrigin-RevId: 652513116
2024-07-15 09:55:38 -07:00
Peter Hawkins
a1f69713f5 Disable Pallas vmap test that is very slow under tsan.
PiperOrigin-RevId: 652505878
2024-07-15 09:28:35 -07:00
jax authors
86f4bb4346 Added more Mosaic bug reproducers.
PiperOrigin-RevId: 652498944
2024-07-15 09:02:48 -07:00