Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00.

This change detects cases where a `gmem_memref` is read via `async_load` and the result is used directly by a wgmma. In such cases, we insert a cast before the load that adds tile, transpose, and swizzle transformations.

PiperOrigin-RevId: 732618760
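The pattern match described above can be illustrated with a rough sketch in plain Python over a toy list of SSA ops. Everything in it is hypothetical. The `Op` class, the `insert_transform_casts` function, the `cast` op name and the `transforms` attribute are illustrative stand-ins, not Mosaic GPU's actual MLIR dialect or API. Real transforms also carry parameters (tile shape, swizzle width), and this sketch omits them.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Op:
    """Hypothetical toy SSA op: a name, operand/result value names, attributes."""
    name: str
    operands: list[str]
    results: list[str]
    attrs: dict = field(default_factory=dict)


def insert_transform_casts(ops: list[Op]) -> list[Op]:
    """Hypothetical pass: put a 'cast' before each async_load whose result feeds a wgmma."""
    # Every value that some wgmma consumes directly.
    wgmma_inputs = {v for op in ops if op.name == "wgmma" for v in op.operands}
    out: list[Op] = []
    counter = 0
    for op in ops:
        if op.name == "async_load" and any(r in wgmma_inputs for r in op.results):
            src = op.operands[0]
            cast_result = f"{src}_transformed{counter}"
            counter += 1
            # The cast attaches tile/transpose/swizzle transforms (parameters omitted).
            out.append(Op("cast", [src], [cast_result],
                          {"transforms": ["tile", "transpose", "swizzle"]}))
            # Rewire the load so it reads through the cast instead of the raw memref.
            op = Op(op.name, [cast_result] + op.operands[1:], list(op.results), dict(op.attrs))
        out.append(op)
    return out


if __name__ == "__main__":
    program = [
        Op("async_load", ["gmem_a"], ["a"]),   # result feeds the wgmma below
        Op("async_load", ["gmem_c"], ["c"]),   # result unused by wgmma, left alone
        Op("wgmma", ["acc", "a", "b"], ["acc2"]),
    ]
    for op in insert_transform_casts(program):
        print(op.name, op.operands, "->", op.results)
```

Only the first load gets a cast, because its result is a direct wgmma operand. The second load is left unchanged. A real compiler pass would walk use-def chains in the IR rather than matching value names.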