This lets us perform the all-reduce only once, at the end of a reduction, instead of at every step. The change also bundles two small improvements: layout inference is now less strict for `vector.broadcast`, and an assert in the elementwise rule is relaxed.

PiperOrigin-RevId: 552413179
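The commit does not include the lowering code, so what follows is only a plain-Python model of the idea. It assumes that "all-reduce" means a cross-lane reduction inside one vector register, and that a reduction walks over several such registers. `LANES`, `Counter`, `lane_all_reduce`, `reduce_eager` and `reduce_deferred` are hypothetical names for illustration. The eager version all-reduces every register and then combines the results. The deferred version first accumulates lane by lane, which is cheap and involves no cross-lane traffic, and then all-reduces once at the end.

```python
from typing import List

# Hypothetical model: a "vector register" is a fixed-width list of lanes.
LANES = 8


class Counter:
    """Counts how many cross-lane all-reduces were issued."""

    def __init__(self) -> None:
        self.all_reduces = 0


def lane_all_reduce(vreg: List[float], counter: Counter) -> List[float]:
    """Cross-lane all-reduce: every lane ends up holding the sum of all lanes.

    This stands in for the comparatively expensive step that the change
    avoids repeating.
    """
    counter.all_reduces += 1
    total = sum(vreg)
    return [total] * len(vreg)


def reduce_eager(vregs: List[List[float]], counter: Counter) -> List[float]:
    """All-reduce every register, then add the already-reduced results."""
    acc = [0.0] * LANES
    for v in vregs:
        reduced = lane_all_reduce(v, counter)
        acc = [a + b for a, b in zip(acc, reduced)]
    return acc


def reduce_deferred(vregs: List[List[float]], counter: Counter) -> List[float]:
    """Accumulate lane-wise across registers, then all-reduce once at the end."""
    acc = [0.0] * LANES
    for v in vregs:
        # Elementwise add: lane i only ever combines with lane i.
        acc = [a + b for a, b in zip(acc, v)]
    return lane_all_reduce(acc, counter)
```

With four registers, both versions produce the same per-lane result. However, `reduce_eager` issues four all-reduces while `reduce_deferred` issues one. The rewrite relies on the combining operation being associative and commutative (as sum and max are). For floating-point sums, the reordering can change rounding.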