* `BlockSpec`, `GridSpec` and `PrefetchScalarGridSpec` are now simple dataclasses that just store the parameters passed from the API. They are then canonicalized and converted to `BlockMapping` and `GridMapping`, which contain less optional metadata. In particular, `BlockMapping` is never `None`. This consolidates the code that preprocesses the block and grid parameters, and simplifies the code downstream.
* `grid` now defaults to `()` instead of `None`.
* Added more fields to `BlockMapping` (`block_aval`, `array_shape_dtype`, and `source`). The `source` field is used in error messages. The `array_shape_dtype` field makes it unnecessary to process `BlockMapping`s zipped with `in_shapes`. With these fields, we can now add a `check_invariants` method that is called during testing or when `config.enable_checks` is true.
* Added more fields and a `check_invariants` method to `GridMapping`, since it is such an important data structure. The new fields are: `index_map_avals`, `index_map_tree` (to encode the calling convention for the index map functions), `num_inputs`, and `num_outputs`. The last two make it possible to recover the `in_shapes` and `out_shapes` from the `GridMapping`. Previously there was some redundancy of information between `in_shapes` and `out_shapes`. Now we do not need the `in_shapes` and `out_shapes` parameters to `pallas_call_p`, since it already has `grid_mapping`.
* Moved some of the logic for handling scalar prefetch and scratch shapes from `PrefetchScalarGridSpec.get_grid_mapping` to `GridSpec.get_grid_mapping`, thus removing code duplication.
* Removed some dead code for implementing the interpret mode.
* The previous handling of hoisted consts did not account for them in `in_shapes`. This is now fixed, since we no longer keep track of `in_shapes` separately.
* Renamed `GridMapping.mapped_dims` to `GridMapping.vmapped_dims` to avoid confusion with the use of "mapped" in block shapes.
* Added a test for the calling convention, including dynamic grid dimensions.

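To illustrate the spec-to-mapping split described above, here is a minimal sketch in plain Python. It is not the Pallas implementation: `ToyBlockSpec`, `ToyBlockMapping`, and `canonicalize` are hypothetical names, and the defaulting rules (whole-array block, all-zeros index map) are assumptions made for the example.

```python
import dataclasses
from typing import Callable, Optional, Tuple


@dataclasses.dataclass(frozen=True)
class ToyBlockSpec:
    """Hypothetical user-facing spec: stores what the caller passed, may hold None."""
    block_shape: Optional[Tuple[int, ...]] = None
    index_map: Optional[Callable[..., Tuple[int, ...]]] = None


@dataclasses.dataclass(frozen=True)
class ToyBlockMapping:
    """Hypothetical canonical form: every field is filled in, nothing is optional."""
    block_shape: Tuple[int, ...]
    index_map: Callable[..., Tuple[int, ...]]
    array_shape: Tuple[int, ...]
    source: str  # e.g. "in_specs[0]"; only used to make error messages readable

    def check_invariants(self) -> None:
        if len(self.block_shape) != len(self.array_shape):
            raise ValueError(
                f"{self.source}: block rank {len(self.block_shape)} does not "
                f"match array rank {len(self.array_shape)}")
        for b, a in zip(self.block_shape, self.array_shape):
            if not 0 < b <= a:
                raise ValueError(
                    f"{self.source}: block dim {b} out of range for array dim {a}")


def canonicalize(spec: ToyBlockSpec, array_shape: Tuple[int, ...],
                 source: str) -> ToyBlockMapping:
    """Fill in defaults once, so downstream code never has to handle None."""
    # Assumption for this sketch: a missing block shape means "the whole array".
    block_shape = spec.block_shape if spec.block_shape is not None else array_shape
    # Assumption for this sketch: a missing index map always selects block 0.
    if spec.index_map is not None:
        index_map = spec.index_map
    else:
        index_map = lambda *grid_indices: (0,) * len(array_shape)
    mapping = ToyBlockMapping(block_shape, index_map, array_shape, source)
    mapping.check_invariants()
    return mapping
```

For instance, `canonicalize(ToyBlockSpec(), (8, 128), source="in_specs[0]")` yields a mapping whose `block_shape` is `(8, 128)`, while a spec whose block rank differs from the array rank fails early with an error that names its `source`.
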
There is more work to be done: with the new information in `GridMapping`, it should be possible to clean up the code throughout that extracts various parts of the inputs and outputs. This should be a series of local changes, which I will make separately once this large global change is merged.
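
As a rough sketch of why `num_inputs` and `num_outputs` are enough to recover the input and output shapes, consider the toy structure below. The names `ToyShapeDtype` and `ToyGridMapping` are hypothetical, and the layout (one flat tuple with inputs stored before outputs) is an assumption made for the example, not a description of the real `GridMapping`.

```python
import dataclasses
from typing import Tuple


@dataclasses.dataclass(frozen=True)
class ToyShapeDtype:
    """Hypothetical stand-in for a shape/dtype record of one operand."""
    shape: Tuple[int, ...]
    dtype: str


@dataclasses.dataclass(frozen=True)
class ToyGridMapping:
    """Hypothetical mapping that owns all operand metadata in one place."""
    grid: Tuple[int, ...]
    # Assumption for this sketch: inputs come first, then outputs.
    array_shape_dtypes: Tuple[ToyShapeDtype, ...]
    num_inputs: int
    num_outputs: int

    def check_invariants(self) -> None:
        expected = self.num_inputs + self.num_outputs
        if expected != len(self.array_shape_dtypes):
            raise ValueError(
                f"expected {expected} operands, "
                f"found {len(self.array_shape_dtypes)}")

    @property
    def in_shapes(self) -> Tuple[ToyShapeDtype, ...]:
        # Recovered by slicing; no separate in_shapes parameter is needed.
        return self.array_shape_dtypes[:self.num_inputs]

    @property
    def out_shapes(self) -> Tuple[ToyShapeDtype, ...]:
        return self.array_shape_dtypes[self.num_inputs:]


x = ToyShapeDtype((8, 128), "float32")
y = ToyShapeDtype((8, 128), "float32")
out = ToyShapeDtype((8, 128), "float32")
toy_mapping = ToyGridMapping(grid=(), array_shape_dtypes=(x, y, out),
                             num_inputs=2, num_outputs=1)
toy_mapping.check_invariants()
```

Keeping a single source of truth in this way means there is no second list that can drift out of sync with the mapping, which is the kind of redundancy the change above removes.
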