mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 18:36:07 +00:00

This CL removes all the shape constraints in matmul for all types. We only need to mask out subelements on the contracting dim. Instead of unpacking the data and applying masks, we create a VREG-sized i32 "mask" that contains the subelement mask info and logical-AND it with the target vreg. This way, masking sub-elements costs each target vreg only 1 op (logical_and) instead of 3 ops (unpacking + select + packing).

PiperOrigin-RevId: 702480077
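
The masking idea described above can be sketched in plain Python by modeling one 32-bit lane of a vreg as an integer. This is an illustration under stated assumptions, not the Mosaic implementation: the helper names (`pack`, `unpack`, `subelement_mask`, and so on) are hypothetical, and the subelement order within a word (index 0 in the lowest bits) is assumed. It shows that zeroing padded subelements with a single AND against a precomputed mask gives the same word as unpacking, selecting, and repacking. Zeroed subelements contribute nothing to the dot product along the contracting dim.

```python
# Illustrative model only: one Python int stands in for one 32-bit vreg lane.
# The subelement order (index 0 in the lowest bits) is an assumption.
LANE_BITS = 32


def pack(subelems, width):
    """Pack unsigned subelements of `width` bits into one 32-bit word."""
    word = 0
    for i, value in enumerate(subelems):
        word |= (value & ((1 << width) - 1)) << (i * width)
    return word


def unpack(word, width):
    """Split one 32-bit word into its LANE_BITS // width subelements."""
    count = LANE_BITS // width
    return [(word >> (i * width)) & ((1 << width) - 1) for i in range(count)]


def mask_by_unpacking(word, width, valid):
    """Three-step path: unpack, select the first `valid` subelements, repack."""
    elems = unpack(word, width)
    selected = [v if i < valid else 0 for i, v in enumerate(elems)]
    return pack(selected, width)


def subelement_mask(width, valid):
    """Build the i32 mask once: all bits set for the first `valid` subelements."""
    return (1 << (valid * width)) - 1


def mask_by_and(word, mask):
    """One-step path: a single logical AND with the precomputed mask."""
    return word & mask


# Four packed 8-bit values; only the first three are valid data.
word = pack([0x11, 0x22, 0x33, 0x44], width=8)
mask = subelement_mask(width=8, valid=3)
print(hex(mask_by_and(word, mask)))  # 0x332211
```

The mask depends only on the subelement width and the number of valid subelements, so in this sketch it is built once and reused for every word that shares the same padding.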