`BlockSpec`, `GridSpec` and `PrefetchScalarGridSpec` are now simple
dataclasses that just store the parameters passed
from the API. They are then canonicalized and coverted
to `BlockMapping` and `GridMapping`, which contains fewer
optional metadata. In particular, `BlockMapping` is never
`None`. This consolidates the code to preprocess the
block and grid parameters, and simplifies the code downstream.
`grid` now defaults to `()` instead of `None`.
Added more fields to `BlockMapping` (`block_aval`,
`array_shape_dtype`, and `source`). The `source` field
is used in error messages. The `array_shape_dtype` makes
it unnecessary to process BlockMappings zipped with
`in_shapes`. With these fields, we can now add
a `check_invariants` method that is called during testing
or when `config.enable_checks` is true.
Added more fields and a `check_invariants` to `GridMapping`, since it is
such an important data structure.
The new fields are: `index_map_avals`, `index_map_tree` (to encode
the calling convention for the index map functions),
`num_inputs`, `num_outputs`. The latter make it possible to
recover the `in_shapes` and `out_shapes` from the GridMapping.
Previously there was some redundancy of information between
`in_shapes` and `out_shapes`.
Now we do not need the `in_shapes` and `out_shapes` parameters to
`pallas_call_p`, since it already has `grid_mapping`.
Moved some of the logic for handling scalar prefetch and
scratch shapes from `PrefetchScalarGridSpec.get_grid_mapping` to
`GridSpec.get_grid_mapping`, and thus removed code duplication.
Removed some dead code for implementing the interpret mode.
Previous handling of hoisted consts did not account for them in
`in_shapes`. Now, this is fixed since we do not keep track of
`in_shapes` separately.
Renamed `GridMapping.mapped_dims` to `GridMapping.vmapped_dims` to
avoid confusion with the use of mapped in block shapes.
Added test for the calling convention, including dynamic grid dimensions.
There is more work to be done: with the new information in
`GridMapping` it should be possible to clean the code throughout
that extract various parts of the inputs and outputs. This
should be a bunch of local changes, which I will do separately
once I merge this large global change.
I have only added tests and documentation, will improve error
reporting separately.
For TPU we get a mix of errors from either the Pallas lowering or
from Mosaic. I plan to add lowering exception for all unsupported
cases, so that we have a better Python stack trace available.
For GPU, we get a RET_CHECK instead of a Python exception,
so I had to add skipTest. Will fix the error message separately.
In order to be able to put the test in pallas_test::PallasCallTest, I
moved the skipTest for TPU from the setUp to the individual tests that
need this.
PiperOrigin-RevId: 653195289
This PR deals with the default values for the parameters
of the `BlockSpec` constructor, and the mapped block dimensions.
Fix a bug where previously a missing block_shape while the
index_map was present was resulting in a crash.
Before this change, the interpreter was failing with an MLIR
verification error because the body of the while loop returned
a padded output array.
This change allows us to expand the documentation of block specs
with the case for when block_shape does not divide the overall shape.
So, instead of
pl.BlockSpec(lambda i, j: ..., (42, 24))
``pl.BlockSpec`` now expects
pl.BlockSpec((42, 24), lambda i, j: ...)
I will update Pallas tests in a follow up.
PiperOrigin-RevId: 648486321
The starting point was the text in pipelining.md, where I
replaced it now with a reference to the separate grid and BlockSpec
documentation.
The grids and BlockSpecs are also documented in the quickstart.md,
which I mostly left alone because it was good enough for a
simple example.
I have also attempted to add a few docstrings.
Fixed Typos in JEP doc files
Revert "Fixed Typos in JEP doc files"
This reverts commit c2a16950e0fc1b32971168501d183991e2394b5d.
revert two changes
reverted one change in advanced-autodiff
revert one change in parallelism
sync notebooks
This is a change that makes the API a bit more intuitive and avoids footguns like accidentally passing in `in_spec` instead of `in_specs` because previously kwargs that weren't used by any downstream lowering would be ignored and users would get weird errors as a result.
This change doesn't deprecate the old way of passing in compiler params but it will be deprecated soon after this.
PiperOrigin-RevId: 613239439
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.
PiperOrigin-RevId: 598550225
Running the example as-is gives
```
ValueError: Pytree specs for `out_shape` and `out_specs` must match: PyTreeDef(*) vs. PyTreeDef((*,))
```
Giving a list argument to `out_shape` seems to fix the issue.