This PR uses the same method to add cross references as the previous PR https://github.com/jax-ml/jax/pull/23889.
---
The content below is for future reference.
#### Useful commands
Build documentation:
```sh
sphinx-build -b html -D nb_execution_mode=off docs docs/build/html -j auto
```
Create a label in *.md:
```md
(pallas_block_specs_by_example)=
```
Create a label in *.rst:
```rst
.. _pallas_tpu_noteworthy_properties:
```
Reference a label in *.md:
```md
{ref}`pallas_block_specs_by_example`
```
Sync changes from *.md to *.ipynb:
```sh
jupytext --sync docs/pallas/tpu/distributed.md
```
PiperOrigin-RevId: 682034607
Triton has a restriction that all operations have arguments and results
that are tensors whose size is a power of 2. Added a lowering check
for this. Without this check, violating the condition causes an
unfriendly crash.
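As a rough illustration of the kind of check described above (not the
actual Pallas lowering code), a size-is-a-power-of-2 test could look like
the sketch below. The helper name `check_power_of_2_size` and its error
message are hypothetical.

```python
import math


def check_power_of_2_size(shape, what="operand"):
    """Return the element count of `shape`, or raise if it is not a power of 2."""
    size = math.prod(shape)  # the empty shape () has size 1
    # A positive integer is a power of 2 iff exactly one bit is set.
    if size <= 0 or size & (size - 1) != 0:
        raise ValueError(
            f"{what} has shape {tuple(shape)} with size {size}, "
            "which is not a power of 2")
    return size
```

Raising a Python exception at lowering time keeps the Python stack trace
available, which is the point of replacing the crash.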
PiperOrigin-RevId: 659483450
`BlockSpec`, `GridSpec` and `PrefetchScalarGridSpec` are now simple
dataclasses that just store the parameters passed
from the API. They are then canonicalized and converted
to `BlockMapping` and `GridMapping`, which contain fewer
optional fields. In particular, `BlockMapping` is never
`None`. This consolidates the code to preprocess the
block and grid parameters, and simplifies the code downstream.
`grid` now defaults to `()` instead of `None`.
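The canonicalization step can be pictured with the simplified sketch
below. `SketchBlockSpec`, `SketchBlockMapping` and `canonicalize` are
hypothetical stand-ins, not the real Pallas classes, and the two defaults
(whole-array block, all-zeros index map) are assumptions made for the
example.

```python
import dataclasses
from typing import Callable, Optional, Tuple


@dataclasses.dataclass
class SketchBlockSpec:
    """Hypothetical user-facing spec: stores only what the caller passed."""
    block_shape: Optional[Tuple[int, ...]] = None
    index_map: Optional[Callable[..., Tuple[int, ...]]] = None


@dataclasses.dataclass(frozen=True)
class SketchBlockMapping:
    """Hypothetical canonical form: every field is filled in."""
    block_shape: Tuple[int, ...]
    index_map: Callable[..., Tuple[int, ...]]
    array_shape: Tuple[int, ...]
    source: str  # names the argument, for error messages


def canonicalize(spec, array_shape, source):
    """Convert an optional spec into a mapping that is never None."""
    if spec is None:
        spec = SketchBlockSpec()
    array_shape = tuple(array_shape)
    # Assumed default: a missing block_shape means one block spanning the array.
    if spec.block_shape is None:
        block_shape = array_shape
    else:
        block_shape = tuple(spec.block_shape)
    if len(block_shape) != len(array_shape):
        raise ValueError(
            f"{source}: block_shape {block_shape} and array shape "
            f"{array_shape} differ in rank")
    # Assumed default: a missing index_map always selects block (0, ..., 0).
    index_map = spec.index_map
    if index_map is None:
        rank = len(array_shape)
        index_map = lambda *grid_indices: (0,) * rank
    return SketchBlockMapping(block_shape, index_map, array_shape, source)
```

Downstream code that receives a `SketchBlockMapping` never has to handle
a `None` spec or a `None` field, which is the simplification the change
is after.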
Added more fields to `BlockMapping` (`block_aval`,
`array_shape_dtype`, and `source`). The `source` field
is used in error messages. The `array_shape_dtype` field makes
it unnecessary to process `BlockMapping`s zipped with
`in_shapes`. With these fields, we can now add
a `check_invariants` method that is called during testing
or when `config.enable_checks` is true.
Added more fields and a `check_invariants` method to `GridMapping`,
since it is such an important data structure.
The new fields are: `index_map_avals`, `index_map_tree` (to encode
the calling convention for the index map functions),
`num_inputs`, `num_outputs`. The latter two make it possible to
recover the `in_shapes` and `out_shapes` from the `GridMapping`.
Previously there was some redundancy of information between
`in_shapes` and `out_shapes`.
Now we do not need the `in_shapes` and `out_shapes` parameters to
`pallas_call_p`, since it already has `grid_mapping`.
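A minimal sketch of how `num_inputs` and `num_outputs` can replace
separate `in_shapes` and `out_shapes` parameters is shown below.
`SketchGridMapping` and its `array_shapes` field are hypothetical, and
the "inputs first, then outputs" ordering is an assumption made for the
example.

```python
import dataclasses
from typing import Tuple

Shape = Tuple[int, ...]


@dataclasses.dataclass(frozen=True)
class SketchGridMapping:
    """Hypothetical, heavily simplified stand-in for GridMapping."""
    grid: Tuple[int, ...]
    array_shapes: Tuple[Shape, ...]  # assumed order: inputs, then outputs
    num_inputs: int
    num_outputs: int

    def check_invariants(self) -> None:
        # The counts must account for every array exactly once.
        assert self.num_inputs >= 0 and self.num_outputs >= 0
        assert self.num_inputs + self.num_outputs == len(self.array_shapes), (
            self.num_inputs, self.num_outputs, len(self.array_shapes))

    @property
    def in_shapes(self) -> Tuple[Shape, ...]:
        return self.array_shapes[:self.num_inputs]

    @property
    def out_shapes(self) -> Tuple[Shape, ...]:
        return self.array_shapes[self.num_inputs:]
```

Because the shapes are derived from a single field, they cannot drift
out of sync with the mapping the way separately passed parameters can.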
Moved some of the logic for handling scalar prefetch and
scratch shapes from `PrefetchScalarGridSpec.get_grid_mapping` to
`GridSpec.get_grid_mapping`, and thus removed code duplication.
Removed some dead code for implementing the interpret mode.
The previous handling of hoisted consts did not account for them in
`in_shapes`. This is now fixed, since we no longer keep track of
`in_shapes` separately.
Renamed `GridMapping.mapped_dims` to `GridMapping.vmapped_dims` to
avoid confusion with the use of mapped in block shapes.
Added test for the calling convention, including dynamic grid dimensions.
There is more work to be done: with the new information in
`GridMapping` it should be possible to clean up the code throughout
that extracts various parts of the inputs and outputs. This
should be a series of local changes, which I will do separately
once I merge this large global change.
I have only added tests and documentation; I will improve error
reporting separately.
For TPU we get a mix of errors from either the Pallas lowering or
from Mosaic. I plan to add lowering exceptions for all unsupported
cases, so that we have a better Python stack trace available.
For GPU, we get a RET_CHECK instead of a Python exception,
so I had to add skipTest. Will fix the error message separately.
In order to be able to put the test in pallas_test::PallasCallTest, I
moved the skipTest for TPU from the setUp to the individual tests that
need this.
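The pattern of moving a skip from `setUp` into individual tests can be
sketched with plain `unittest` as below. The `BACKEND` variable, the test
class and the test names are hypothetical; the real tests query the JAX
backend instead.

```python
import unittest

# Hypothetical stand-in for a real device query.
BACKEND = "tpu"


class PallasCallSketchTest(unittest.TestCase):
    # No blanket skipTest in setUp, so backend-independent tests
    # in this class run everywhere.

    def test_runs_on_all_backends(self):
        self.assertEqual(sum([1, 2, 3]), 6)

    def test_not_supported_on_tpu(self):
        # The skip lives in the one test that needs it.
        if BACKEND == "tpu":
            self.skipTest("Not supported on TPU")
        self.assertTrue(True)
```

With `BACKEND = "tpu"`, the first test runs and passes while the second
is reported as skipped.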
PiperOrigin-RevId: 653195289
This PR deals with the default values for the parameters
of the `BlockSpec` constructor, and the mapped block dimensions.
Fix a bug where a missing `block_shape` with an `index_map` present
resulted in a crash.
Before this change, the interpreter was failing with an MLIR
verification error because the body of the while loop returned
a padded output array.
This change allows us to expand the documentation of block specs
to cover the case where `block_shape` does not divide the overall shape.
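To make the non-dividing case concrete, the sketch below computes, for
one dimension, where each block starts and how many of its elements are
in bounds. The helpers `cdiv` and `block_extents` are written for this
example only and say nothing about what Pallas stores in the
out-of-bounds part of the last block.

```python
def cdiv(a, b):
    """Ceiling division for positive integers."""
    return -(-a // b)


def block_extents(dim_size, block_size):
    """Return (start, valid_length) for each block along one dimension."""
    num_blocks = cdiv(dim_size, block_size)
    return [
        (i * block_size, min(block_size, dim_size - i * block_size))
        for i in range(num_blocks)
    ]


# A dimension of size 10 with blocks of size 4 needs 3 blocks;
# the last one has only 2 in-bounds elements.
print(block_extents(10, 4))  # [(0, 4), (4, 4), (8, 2)]
```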
Instead of

    pl.BlockSpec(lambda i, j: ..., (42, 24))

``pl.BlockSpec`` now expects

    pl.BlockSpec((42, 24), lambda i, j: ...)

I will update Pallas tests in a follow-up.
PiperOrigin-RevId: 648486321
The starting point was the text in pipelining.md, which I have
now replaced with a reference to the separate grid and BlockSpec
documentation.
The grids and BlockSpecs are also documented in quickstart.md,
which I mostly left alone because it was good enough for a
simple example.
I have also attempted to add a few docstrings.