Jevin Jiang 9e5edb7015 [Mosaic TPU] Support packed type matmul with arbitrary shapes.
This CL removes all shape constraints on matmul for all types.

We only need to mask out sub-elements along the contracting dimension. Instead of unpacking the data and applying masks, we create a vreg-sized i32 "mask" that encodes the sub-element mask and is combined with the target vreg by a logical AND. As a result, masking sub-elements costs each target vreg only one op (logical_and) instead of three (unpack + select + pack).
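The masking trick can be illustrated with a minimal NumPy sketch. This is not the Mosaic TPU implementation: it works on a flat array of 32-bit words instead of a 2-D vreg, and it assumes a packing layout in which sub-element `j` of each word occupies bits `[j*bits, (j+1)*bits)`. The helpers `pack_words`, `unpack_words`, and `subelement_mask` are hypothetical names used only for illustration. The sketch builds a word-sized mask that is all-ones for sub-elements inside the contracting dimension and all-zeros outside it, applies it with a single bitwise AND, and checks the result against the unpack/select/pack path.

```python
import numpy as np

# Hypothetical helpers; not Mosaic TPU APIs. Assumed layout: sub-element j of
# each 32-bit word occupies bits [j*bits, (j+1)*bits), so the logical element
# index is word_index * packing + j.

def pack_words(elems: np.ndarray, bits: int) -> np.ndarray:
    """Pack unsigned `bits`-wide values into uint32 words."""
    packing = 32 // bits
    groups = elems.astype(np.uint32).reshape(-1, packing)
    words = np.zeros(groups.shape[0], dtype=np.uint32)
    for j in range(packing):
        words |= groups[:, j] << np.uint32(j * bits)
    return words

def unpack_words(words: np.ndarray, bits: int) -> np.ndarray:
    """Inverse of pack_words: one value per sub-element, in logical order."""
    packing = 32 // bits
    lane = np.uint32((1 << bits) - 1)
    cols = [(words >> np.uint32(j * bits)) & lane for j in range(packing)]
    return np.stack(cols, axis=1).reshape(-1)

def subelement_mask(num_words: int, bits: int, valid: int) -> np.ndarray:
    """Word-sized mask: all-ones in a sub-element's bits if its logical index
    along the contracting dim is < valid, all-zeros otherwise. The bit pattern
    is the same whether it is viewed as uint32 or i32."""
    packing = 32 // bits
    lane = np.uint32((1 << bits) - 1)
    word_idx = np.arange(num_words)
    mask = np.zeros(num_words, dtype=np.uint32)
    for j in range(packing):
        keep = word_idx * packing + j < valid
        mask |= np.where(keep, lane << np.uint32(j * bits), np.uint32(0)).astype(np.uint32)
    return mask

# 8 logical 16-bit elements (stand-ins for bf16 bit patterns) packed 2 per word;
# only the first 5 lie inside the contracting dimension.
elems = np.arange(1, 9, dtype=np.uint16)
vreg = pack_words(elems, bits=16)
mask = subelement_mask(vreg.size, bits=16, valid=5)

# One bitwise AND per word, with no unpack/select/pack round trip.
masked = vreg & mask

# Reference: select on the unpacked elements, then repack.
reference = pack_words(np.where(np.arange(elems.size) < 5, elems, 0), bits=16)
assert np.array_equal(masked, reference)
print(unpack_words(masked, bits=16))
```

The mask depends only on the shape and the packing, not on the data, so under these assumptions it can be built once and reused for every target vreg that straddles the contracting-dimension boundary.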

PiperOrigin-RevId: 702480077
2024-12-03 14:58:42 -08:00

jaxlib: support library for JAX

jaxlib is the support library for JAX. While JAX itself is a pure Python package, jaxlib contains the binary (C/C++) parts of the library, including Python bindings, the XLA compiler, the PJRT runtime, and a handful of handwritten kernels. For more information, including installation and build instructions, refer to the main JAX README: https://github.com/jax-ml/jax/.