`BlockSpec`, `GridSpec` and `PrefetchScalarGridSpec` are now simple
dataclasses that just store the parameters passed
from the API. They are then canonicalized and converted
to `BlockMapping` and `GridMapping`, which contain fewer
optional fields. In particular, a `BlockMapping` is never
`None`. This consolidates the code that preprocesses the
block and grid parameters, and simplifies the code downstream.
`grid` now defaults to `()` instead of `None`.
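
To illustrate the spec-to-mapping split described above, here is a minimal sketch in plain Python. It is not the Pallas implementation: `SketchBlockSpec`, `SketchBlockMapping`, and `canonicalize` are hypothetical names, and the defaults chosen (a missing spec covers the whole array, a missing index map always selects block 0) are assumptions made for illustration.

```python
import dataclasses
from typing import Callable, Optional, Tuple


@dataclasses.dataclass
class SketchBlockSpec:
    """Hypothetical user-facing spec: stores its arguments, nothing more."""
    block_shape: Optional[Tuple[int, ...]] = None
    index_map: Optional[Callable[..., Tuple[int, ...]]] = None


@dataclasses.dataclass(frozen=True)
class SketchBlockMapping:
    """Hypothetical canonical form: every field is filled in."""
    block_shape: Tuple[int, ...]
    index_map: Callable[..., Tuple[int, ...]]


def canonicalize(spec: Optional[SketchBlockSpec],
                 array_shape: Tuple[int, ...]) -> SketchBlockMapping:
    # Assumption: a missing spec or block shape means one block
    # covering the whole array.
    if spec is None:
        spec = SketchBlockSpec()
    if spec.block_shape is None:
        block_shape = tuple(array_shape)
    else:
        block_shape = tuple(spec.block_shape)
    index_map = spec.index_map
    if index_map is None:
        rank = len(block_shape)
        # Assumption: the default index map selects block 0 in each dimension.
        index_map = lambda *grid_indices: (0,) * rank
    return SketchBlockMapping(block_shape, index_map)
```

Downstream code in this sketch only ever sees `SketchBlockMapping`, so it never has to test for `None`.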
Added more fields to `BlockMapping` (`block_aval`,
`array_shape_dtype`, and `source`). The `source` field
is used in error messages. The `array_shape_dtype` field makes
it unnecessary to process `BlockMapping`s zipped with
`in_shapes`. With these fields, we can now add
a `check_invariants` method that is called during testing
or when `config.enable_checks` is true.
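
A rough sketch of how such a `check_invariants` method might look follows. The class `CheckedBlockMapping`, the helper `make_block_mapping`, the `ENABLE_CHECKS` flag (standing in for `config.enable_checks`), and the two invariants checked are all assumptions for illustration; the actual invariants in Pallas may differ.

```python
import dataclasses
from typing import Tuple

# Hypothetical flag standing in for `config.enable_checks`.
ENABLE_CHECKS = True


@dataclasses.dataclass(frozen=True)
class CheckedBlockMapping:
    block_shape: Tuple[int, ...]
    array_shape: Tuple[int, ...]  # stands in for `array_shape_dtype`
    source: str                   # e.g. "in_specs[0]", used in error messages

    def check_invariants(self) -> None:
        # Illustrative invariant: the block and the array have the same rank.
        if len(self.block_shape) != len(self.array_shape):
            raise ValueError(
                f"{self.source}: block shape {self.block_shape} and array "
                f"shape {self.array_shape} have different ranks")
        # Illustrative invariant: every block dimension is positive.
        if any(d <= 0 for d in self.block_shape):
            raise ValueError(
                f"{self.source}: non-positive block dimension in "
                f"{self.block_shape}")


def make_block_mapping(block_shape, array_shape, source):
    bm = CheckedBlockMapping(tuple(block_shape), tuple(array_shape), source)
    if ENABLE_CHECKS:
        bm.check_invariants()
    return bm
```

Because each mapping carries its own `source` and array shape, the check needs no extra arguments and the error message can say which spec was at fault.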
Added more fields and a `check_invariants` method to `GridMapping`, since it is
such an important data structure.
The new fields are: `index_map_avals`, `index_map_tree` (to encode
the calling convention for the index map functions),
`num_inputs`, and `num_outputs`. The latter two make it possible to
recover the `in_shapes` and `out_shapes` from the `GridMapping`.
Previously this information was duplicated between the separate
`in_shapes` and `out_shapes` parameters and the grid mapping.
Now we do not need the `in_shapes` and `out_shapes` parameters to
`pallas_call_p`, since it already has `grid_mapping`.
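
As a sketch of how `num_inputs` and `num_outputs` could make separate `in_shapes` and `out_shapes` parameters unnecessary, consider the following. The classes are hypothetical stand-ins, and the layout (one flat tuple of block mappings with inputs first, then outputs) is an assumption made for this example.

```python
import dataclasses
from typing import Tuple


@dataclasses.dataclass(frozen=True)
class ShapeDtype:
    """Hypothetical stand-in for a shape/dtype descriptor."""
    shape: Tuple[int, ...]
    dtype: str


@dataclasses.dataclass(frozen=True)
class ShapedBlockMapping:
    array_shape_dtype: ShapeDtype


@dataclasses.dataclass(frozen=True)
class SketchGridMapping:
    # Assumed layout: input block mappings first, then output block mappings.
    block_mappings: Tuple[ShapedBlockMapping, ...]
    num_inputs: int
    num_outputs: int

    def check_invariants(self) -> None:
        if len(self.block_mappings) != self.num_inputs + self.num_outputs:
            raise ValueError("block_mappings does not match num_inputs + num_outputs")

    @property
    def in_shapes(self) -> Tuple[ShapeDtype, ...]:
        return tuple(bm.array_shape_dtype
                     for bm in self.block_mappings[:self.num_inputs])

    @property
    def out_shapes(self) -> Tuple[ShapeDtype, ...]:
        start = self.num_inputs
        return tuple(bm.array_shape_dtype
                     for bm in self.block_mappings[start:start + self.num_outputs])
```

With the shapes derived from a single structure, there is one source of truth and no second list that can drift out of sync.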
Moved some of the logic for handling scalar prefetch and
scratch shapes from `PrefetchScalarGridSpec.get_grid_mapping` to
`GridSpec.get_grid_mapping`, and thus removed code duplication.
Removed some dead code for implementing the interpret mode.
The previous handling of hoisted consts did not account for them in
`in_shapes`. This is now fixed, since we no longer keep track of
`in_shapes` separately.
Renamed `GridMapping.mapped_dims` to `GridMapping.vmapped_dims` to
avoid confusion with the use of mapped in block shapes.
Added tests for the calling convention, including dynamic grid dimensions.
There is more work to be done: with the new information in
`GridMapping` it should be possible to clean up the code throughout
that extracts various parts of the inputs and outputs. This
should be a bunch of local changes, which I will do separately
once I merge this large global change.
This aligns more tests to use the same testing structure: tests
can run on CPU (in interpreter mode) or TPU/GPU, and for each
test class MyTest we have a sibling test class MyInterpreterTest.
This is useful when developing on a machine without accelerators.
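
The sibling-class structure described above can be sketched with `unittest` as follows. This is an illustration only: `accelerator_available` and `add_one` are hypothetical stand-ins (a real test would query the available devices and call a kernel with an `interpret` flag), and the `INTERPRET` class attribute is an assumed way of switching modes.

```python
import unittest


def accelerator_available() -> bool:
    # Hypothetical helper; a real test would query the available devices.
    return False


def add_one(x, *, interpret: bool):
    # Stand-in for a kernel call that accepts an `interpret` flag.
    return [v + 1 for v in x]


class MyTest(unittest.TestCase):
    INTERPRET = False

    def setUp(self):
        super().setUp()
        # Without interpreter mode the test needs a TPU/GPU.
        if not self.INTERPRET and not accelerator_available():
            self.skipTest("requires TPU/GPU")

    def test_add_one(self):
        self.assertEqual(add_one([1, 2], interpret=self.INTERPRET), [2, 3])


class MyInterpreterTest(MyTest):
    # Same test methods, run in interpreter mode so they work on CPU.
    INTERPRET = True
```

The subclass inherits every test method, so each test is written once and runs in both modes.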
Previously I moved the platform-specific tests into their own `tpu` and `gpu` subdirectories, with the multi-platform tests at the top level in
the `tests/pallas` directory.
It turns out that `pytest` wants every test base file name to be unique when it is loading tests, and in order to be able to run `pytest tests/pallas` I sometimes had to add platform names to the test file name, even though it was already in a platform-specific directory, e.g., `gpu/gpu_ops_test.py`.
Here we delete the `tpu` and `gpu` test subdirectories and we prepend the platform name to the test file name.
Additionally, the old `tpu/pallas_call_test.py` is now renamed `tpu_pallas_test.py` (similar to the multi-platform test `pallas_test.py`).
PiperOrigin-RevId: 651029357