7 Commits

Author SHA1 Message Date
George Necula
4063373b22 Reverts 0d058ce86f04a44a51abba1261768fb46edf69d9
PiperOrigin-RevId: 655871052
2024-07-25 01:50:36 -07:00
George Necula
0d058ce86f Reverts 0e17d26b6d81a6b281f55bdd81a6f0ab45efeafe
PiperOrigin-RevId: 655552768
2024-07-24 07:09:16 -07:00
George Necula
c5871331ba [pallas] Simplify handling of BlockMapping and GridMapping
`BlockSpec`, `GridSpec` and `PrefetchScalarGridSpec` are now simple
dataclasses that just store the parameters passed
from the API. They are then canonicalized and coverted
to `BlockMapping` and `GridMapping`, which contains fewer
optional metadata. In particular, `BlockMapping` is never
`None`. This consolidates the code to preprocess the
block and grid parameters, and simplifies the code downstream.

`grid` now defaults to `()` instead of `None`.

Added more fields to `BlockMapping` (`block_aval`,
`array_shape_dtype`, and `source`). The `source` field
is used in error messages. The `array_shape_dtype` makes
it unnecessary to process BlockMappings zipped with
`in_shapes`. With these fields, we can now add
a `check_invariants` method that is called during testing
or when `config.enable_checks` is true.

Added more fields and a `check_invariants` to `GridMapping`, since it is
such an important data structure.
The new fields are: `index_map_avals`, `index_map_tree` (to encode
the calling convention for the index map functions),
`num_inputs`, `num_outputs`. The latter make it possible to
recover the `in_shapes` and `out_shapes` from the GridMapping.
Previously there was some redundancy of information between
`in_shapes` and `out_shapes`.

Now we do not need the `in_shapes` and `out_shapes` parameters to
`pallas_call_p`, since it already has `grid_mapping`.

Moved some of the logic for handling scalar prefetch and
scratch shapes from `PrefetchScalarGridSpec.get_grid_mapping` to
`GridSpec.get_grid_mapping`, and thus removed code duplication.

Removed some dead code for implementing the interpret mode.

Previous handling of hoisted consts did not account for them in
`in_shapes`. Now, this is fixed since we do not keep track of
`in_shapes` separately.

Renamed `GridMapping.mapped_dims` to `GridMapping.vmapped_dims` to
avoid confusion with the use of mapped in block shapes.

Added test for the calling convention, including dynamic grid dimensions.

There is more work to be done: with the new information in
`GridMapping` it should be possible to clean the code throughout
that extract various parts of the inputs and outputs. This
should be a bunch of local changes, which I will do separately
once I merge this large global change.
2024-07-24 14:48:08 +03:00
George Necula
a18872aa13 Reverts d7b821b04d8fec543f570faaece7572a50a75eb6
PiperOrigin-RevId: 655019101
2024-07-22 22:05:30 -07:00
Vladimir Belitskiy
d7b821b04d The newly added test class is failing, and blocking presubmits
Reverts 09523adf7dd5b5b1099780785a73a12bf6664c53

PiperOrigin-RevId: 654842341
2024-07-22 11:52:24 -07:00
George Necula
28d4caefb0 [pallas] Thread the interpret= parmeter to pallas_call.
This aligns more tests to use the same testing structure: tests
can run on CPU (in interpreter mode) or TPU/GPU, and for each
test class MyTest we have a sibling test class MyInterpreterTest.

This is useful when developing on a machine without accelerators.
2024-07-21 20:17:12 +03:00
George Necula
6f79925d61 [pallas] Renamed platform-specific tests.
Previously I have moved the platform-specific tests in their own `tpu` and `gpu` subirectories, with the multi-platform tests at the top level in
the `tests/pallas` directory.

It turns out that `pytest` wants every test base file name to be unique when it is loading tests, and in order to be able to run `pytest tests/pallas` I sometimes had to add platform names to the test file name, even though it was already in a platform-specific directory, e.g., `gpu/gpu_ops_test.py`.

Here we delete the `tpu` and `gpu` test subdirectories and we prepend the platform name to the test file name.

Additionally, the old `tpu/pallas_call_test.py` is now renamed `tpu_pallas_test.py` (similar to the multi-platform test `pallas_test.py`).

PiperOrigin-RevId: 651029357
2024-07-10 08:23:06 -07:00