Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)

This is often useful when a kernel uses statistics tensors that are constant across the minormost dimensions. Right now the only way to use them is to force XLA to insert the extra dimension before the kernel, but that turns out to be very inefficient.

PiperOrigin-RevId: 561903222
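To illustrate the shapes involved, here is a minimal sketch in plain jax.numpy. It is not the code from this change and it does not use any kernel API. The function names and the (rows, cols) layout are assumptions made for illustration only. The first function mimics the workaround described above, where the per-row statistics are expanded to the full shape before they are consumed. The second keeps the statistics compact and broadcasts them at the point of use. Under jit, XLA may well fuse both versions into the same computation, so this only shows the difference in shapes, not the cost that arises at a kernel boundary.

```python
import jax.numpy as jnp


def subtract_stats_materialized(x, stats):
    # Hypothetical workaround: expand the per-row statistics from
    # (rows,) to the full (rows, cols) shape before using them.
    stats_full = jnp.broadcast_to(stats[:, None], x.shape)
    return x - stats_full


def subtract_stats_compact(x, stats):
    # Hypothetical alternative: keep the statistics as (rows,) and
    # broadcast across the minormost dimension only where they are used.
    return x - stats[:, None]


x = jnp.arange(12.0).reshape(3, 4)
stats = x.max(axis=1)  # one value per row, constant across the last axis

out_materialized = subtract_stats_materialized(x, stats)
out_compact = subtract_stats_compact(x, stats)
```

Both functions return the same values. The difference is only whether a tensor of the full (rows, cols) shape has to exist for the statistics before the computation starts.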