14 Commits

Author SHA1 Message Date
Jiyoun (Jen) Ha
e6f38b7fb3 [pallas] fix typo
PiperOrigin-RevId: 725244824
2025-02-10 09:32:07 -08:00
Justin Fu
e05afefc97 [Pallas] Pallas documentation cleanup 2024-12-04 15:19:32 -08:00
Ayaka
e79d77aa47 [Pallas] [Docs] Replace full urls with label-based cross references
This PR uses the same method to add cross references as the previous PR https://github.com/jax-ml/jax/pull/23889.

---

The content below is for future references.

#### Useful commands

Build documentation:

```sh
sphinx-build -b html -D nb_execution_mode=off docs docs/build/html -j auto
```

Create a label in *.md:

```md
(pallas_block_specs_by_example)=
```

Create a label in *.rst:

```rst
.. _pallas_tpu_noteworthy_properties:
```

Reference a label in *.md:

```md
{ref}`pallas_block_specs_by_example`
```

Sync changes from *.md to *.ipynb:

```sh
jupytext --sync docs/pallas/tpu/distributed.md
```

PiperOrigin-RevId: 682034607
2024-10-03 14:35:51 -07:00
George Necula
9b35b760ce [pallas] Enable check for GPU lowering that tensor sizes are power of 2
Triton has a restriction that all operations have arguments and results
that are tensor whose size is a power of 2. Added a lowering check
for this. Without this, when we violate the condition we get an
unfriendly crash.

PiperOrigin-RevId: 659483450
2024-08-05 02:34:21 -07:00
George Necula
43163ff2e3 [pallas] Add error message for block_shapes of rank less than 1.
PiperOrigin-RevId: 658424421
2024-08-01 09:15:07 -07:00
George Necula
4063373b22 Reverts 0d058ce86f04a44a51abba1261768fb46edf69d9
PiperOrigin-RevId: 655871052
2024-07-25 01:50:36 -07:00
George Necula
0d058ce86f Reverts 0e17d26b6d81a6b281f55bdd81a6f0ab45efeafe
PiperOrigin-RevId: 655552768
2024-07-24 07:09:16 -07:00
George Necula
c5871331ba [pallas] Simplify handling of BlockMapping and GridMapping
`BlockSpec`, `GridSpec` and `PrefetchScalarGridSpec` are now simple
dataclasses that just store the parameters passed
from the API. They are then canonicalized and coverted
to `BlockMapping` and `GridMapping`, which contains fewer
optional metadata. In particular, `BlockMapping` is never
`None`. This consolidates the code to preprocess the
block and grid parameters, and simplifies the code downstream.

`grid` now defaults to `()` instead of `None`.

Added more fields to `BlockMapping` (`block_aval`,
`array_shape_dtype`, and `source`). The `source` field
is used in error messages. The `array_shape_dtype` makes
it unnecessary to process BlockMappings zipped with
`in_shapes`. With these fields, we can now add
a `check_invariants` method that is called during testing
or when `config.enable_checks` is true.

Added more fields and a `check_invariants` to `GridMapping`, since it is
such an important data structure.
The new fields are: `index_map_avals`, `index_map_tree` (to encode
the calling convention for the index map functions),
`num_inputs`, `num_outputs`. The latter make it possible to
recover the `in_shapes` and `out_shapes` from the GridMapping.
Previously there was some redundancy of information between
`in_shapes` and `out_shapes`.

Now we do not need the `in_shapes` and `out_shapes` parameters to
`pallas_call_p`, since it already has `grid_mapping`.

Moved some of the logic for handling scalar prefetch and
scratch shapes from `PrefetchScalarGridSpec.get_grid_mapping` to
`GridSpec.get_grid_mapping`, and thus removed code duplication.

Removed some dead code for implementing the interpret mode.

Previous handling of hoisted consts did not account for them in
`in_shapes`. Now, this is fixed since we do not keep track of
`in_shapes` separately.

Renamed `GridMapping.mapped_dims` to `GridMapping.vmapped_dims` to
avoid confusion with the use of mapped in block shapes.

Added test for the calling convention, including dynamic grid dimensions.

There is more work to be done: with the new information in
`GridMapping` it should be possible to clean the code throughout
that extract various parts of the inputs and outputs. This
should be a bunch of local changes, which I will do separately
once I merge this large global change.
2024-07-24 14:48:08 +03:00
George Necula
6c5583d6aa [pallas] Document and test valid block shapes.
I have only added tests and documentation, will improve error
reporting separately.

For TPU we get a mix of errors from either the Pallas lowering or
from Mosaic. I plan to add lowering exception for all unsupported
cases, so that we have a better Python stack trace available.

For GPU, we get a RET_CHECK instead of a Python exception,
so I had to add skipTest. Will fix the error message separately.

In order to be able to put the test in pallas_test::PallasCallTest, I
moved the skipTest for TPU from the setUp to the individual tests that
need this.

PiperOrigin-RevId: 653195289
2024-07-17 05:17:20 -07:00
George Necula
7c059d4630 [pallas] Document the indexing_mode=Unblocked()
In the process discovered that the padding in the interpreter
mode was with 0s. I changed it to NaN/minint to match the
padding for the blocked mode.
2024-07-12 12:39:10 +03:00
George Necula
ea548e7c86 [pallas] Add more documentation and tests for BlockSpec.
This PR deals with the default values for the parameters
of the `BlockSpec` constructor, and the mapped block dimensions.

Fix a bug where previously a missing block_shape while the
index_map was present was resulting in a crash.
2024-07-10 19:16:53 +03:00
George Necula
f02d32c680 [pallas] Fix the interpreter for block_shape not dividing the overall shape
Before this change, the interpreter was failing with an MLIR
verification error because the body of the while loop returned
a padded output array.

This change allows us to expand the documentation of block specs
with the case for when block_shape does not divide the overall shape.
2024-07-09 16:10:22 +03:00
Sergei Lebedev
a2a5068e5e Changed `pl.BlockSpec to accept block_shape before index_map`
So, instead of

    pl.BlockSpec(lambda i, j: ..., (42, 24))

``pl.BlockSpec`` now expects

    pl.BlockSpec((42, 24), lambda i, j: ...)

I will update Pallas tests in a follow up.

PiperOrigin-RevId: 648486321
2024-07-01 14:26:08 -07:00
George Necula
bfdf8f4bd3 [pallas] Added more documentation for grid and BlockSpec.
The starting point was the text in pipelining.md, where I
replaced it now with a reference to the separate grid and BlockSpec
documentation.

The grids and BlockSpecs are also documented in the quickstart.md,
which I mostly left alone because it was good enough for a
simple example.

I have also attempted to add a few docstrings.
2024-06-29 14:43:48 +03:00