Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 13:26:06 +00:00

Fix select's batching rule: the `explicit_mesh_axis` captured in `axis_data` was not propagated properly to the `broadcast` performed in `bdim_at_front`.

PiperOrigin-RevId: 748867490
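For context, here is a minimal sketch of the kind of call that goes through select's batching rule with mixed batch dimensions. It assumes only the public `jax.vmap` and `jax.lax.select` APIs; the function name `pick` and the sample arrays are hypothetical. It does not set up an explicit mesh, so it does not reproduce the bug itself. It only shows the case where unbatched operands have to be broadcast to the batch size, which is presumably the path where `bdim_at_front` is involved.

```python
import jax
import jax.numpy as jnp
from jax import lax


def pick(pred, on_true, on_false):
    # Elementwise select: take on_true where pred is True, else on_false.
    return lax.select(pred, on_true, on_false)


# Hypothetical sample data: only `on_true` carries a batch dimension.
pred = jnp.array([True, False, True])        # unbatched, shape (3,)
on_true = jnp.arange(6.0).reshape(2, 3)      # batched along axis 0, shape (2, 3)
on_false = jnp.zeros(3)                      # unbatched, shape (3,)

# Under vmap, the unbatched operands must be brought to the batched shape
# before the underlying select can run on a (2, 3) array.
out = jax.vmap(pick, in_axes=(None, 0, None))(pred, on_true, on_false)
print(out.shape)
```

With an explicit mesh axis attached to the mapped dimension, the broadcast of `pred` and `on_false` would also need to carry that axis, which is what the fix describes.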