mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00

We will implement a more efficient relayout based on the configs in the rewrite ctx, such as `hardware_generation`, `max_sublanes_in_scratch`, and so on, so it makes sense to change the relayout interface to take ctx (including the Python bindings).

The rewrite ctx can now be defined in `apply_vector_layout_test` as well, which makes it easier to test more advanced cases (e.g., an `mxu_shape` change, or a `max_sublanes_in_scratch` change for rotate and relayout).

PiperOrigin-RevId: 655350013
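
A minimal sketch of the idea of passing a context into relayout so that tests can vary its configs. This is not the Mosaic API: `RewriteContext`, its default values, and `choose_relayout_strategy` are hypothetical names invented for illustration, and the decision rule is a toy. Only the config names `hardware_generation`, `max_sublanes_in_scratch`, and `mxu_shape` come from the message above.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class RewriteContext:
    # Hypothetical stand-in for the rewrite ctx; the defaults are made up.
    hardware_generation: int = 4
    max_sublanes_in_scratch: int = 0
    mxu_shape: Tuple[int, int] = (128, 128)


def choose_relayout_strategy(ctx: RewriteContext, sublanes_needed: int) -> str:
    """Toy rule (an assumption, not the real logic): take a scratch-based
    path only when the context allows enough sublanes in scratch."""
    if sublanes_needed <= ctx.max_sublanes_in_scratch:
        return "scratch"
    return "generic"


# A test can build its own context instead of relying on fixed defaults.
default_ctx = RewriteContext()
scratch_ctx = RewriteContext(max_sublanes_in_scratch=8)
```

Because the function takes the context as an argument, a test can exercise both paths by constructing two contexts, which is the kind of flexibility the interface change is meant to provide.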