Mirror of https://github.com/ROCm/jax.git, synced 2025-04-17 20:36:05 +00:00

* Add source location information for the index map function to `BlockMapping`.
* Remove the `compute_index` wrapper around the `index_map`, so that the recorded location is that of the `index_map` itself, not of the wrapper.
* Add the source location to errors related to index map functions.
* Raise an error if the index map returns anything other than integer scalars.
* Construct `BlockSpec` origins for arguments using JAX helper functions that look up argument names.
* Remove redundant API error tests from tpu_pallas_test.py.
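
The checks described above can be illustrated with a minimal, standalone sketch in plain Python. It is not the Pallas implementation: `describe_origin` and `check_index_map` are hypothetical helpers invented for this illustration, and the error wording is an assumption. It only shows why calling the user's function directly (rather than through a wrapper) lets the error point at the user's own source line.

```python
import numbers


def describe_origin(fn):
    """Return a 'name at file:line' string for a Python function.

    Hypothetical helper: reads the location from the function's code object.
    """
    code = fn.__code__
    return f"{fn.__name__} at {code.co_filename}:{code.co_firstlineno}"


def check_index_map(index_map, grid_indices):
    """Call index_map and verify that it returns only integer scalars.

    Hypothetical helper, not part of JAX or Pallas. Because index_map is
    inspected directly, the reported location is the user's function.
    """
    origin = describe_origin(index_map)
    result = index_map(*grid_indices)
    if not isinstance(result, (tuple, list)):
        result = (result,)
    for i, r in enumerate(result):
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(r, bool) or not isinstance(r, numbers.Integral):
            raise ValueError(
                f"Index map function {origin} must return integer scalars, "
                f"but output[{i}] has type {type(r).__name__}."
            )
    return tuple(result)


def good_index_map(i, j):
    return (i, j)


def bad_index_map(i, j):
    # Returns a float in the second position, which should be rejected.
    return (i, j * 0.5)
```

Calling `check_index_map(bad_index_map, (1, 2))` raises a `ValueError` whose message names `bad_index_map` together with its file and line. Had the function been wrapped before inspection, the location would have been that of the wrapper instead.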