Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)

Previously, extracting the batch size always took two steps: (1) check that the buffer has enough dimensions, then (2) read the shape. In a few places the first check was missing. These steps are now combined into a single function that returns a StatusOr, so a caller cannot read the shape without the dimension check. As part of this change, I fixed our implementation of the `ASSIGN_OR_RETURN` macro so that it handles parentheses correctly.

PiperOrigin-RevId: 664803225
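A minimal sketch of the pattern described above, under stated assumptions. `Status` and `StatusOr` are tiny stand-ins for `absl::StatusOr` so the example compiles without Abseil. `SplitBatch2D`, `NumElements`, `BatchSize`, and the `SK_` helper macros are hypothetical names, and this `ASSIGN_OR_RETURN` is not the jaxlib implementation. The example also assumes one plausible reason the macro must handle parentheses: a structured-binding left-hand side such as `auto [batch, rows, cols]` contains commas that would otherwise split the macro arguments. Wrapping it in parentheses and stripping them with the well-known "DEPAREN" preprocessor idiom solves this. That idiom is a swapped-in technique, not necessarily the one used in the actual fix.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>

// Stand-in for absl::Status: an empty message means OK (assumption).
struct Status {
  std::string message;
  bool ok() const { return message.empty(); }
};

// Stand-in for absl::StatusOr<T>: holds either a value or an error Status.
template <typename T>
class StatusOr {
 public:
  StatusOr(T value) : data_(std::move(value)) {}
  StatusOr(Status status) : data_(std::move(status)) {}
  bool ok() const { return std::holds_alternative<T>(data_); }
  Status status() const { return ok() ? Status{} : std::get<Status>(data_); }
  T value() && {
    assert(ok() && "value() called on an error");
    return std::get<T>(std::move(data_));
  }

 private:
  std::variant<T, Status> data_;
};

// Hypothetical combined helper: the rank check and the shape extraction happen
// in one call. All leading dimensions are folded into the batch size; the last
// two are (rows, cols).
StatusOr<std::tuple<int64_t, int64_t, int64_t>> SplitBatch2D(
    const std::vector<int64_t>& dims) {
  if (dims.size() < 2) {
    return Status{"expected a buffer with at least 2 dimensions"};
  }
  int64_t batch = 1;
  for (std::size_t i = 0; i + 2 < dims.size(); ++i) batch *= dims[i];
  return std::make_tuple(batch, dims[dims.size() - 2], dims.back());
}

// Strip one optional layer of parentheses:
//   SK_DEPAREN((auto [a, b])) -> auto [a, b]
//   SK_DEPAREN(int64_t n)     -> int64_t n
#define SK_DEPAREN(x) SK_ESC(SK_ISH x)
#define SK_ISH(...) SK_ISH __VA_ARGS__
#define SK_ESC(...) SK_ESC_(__VA_ARGS__)
#define SK_ESC_(...) SK_VAN##__VA_ARGS__
#define SK_VANSK_ISH

#define SK_CONCAT_INNER(a, b) a##b
#define SK_CONCAT(a, b) SK_CONCAT_INNER(a, b)

// Evaluate rhs; on error return its Status, otherwise assign the value to lhs.
// lhs may be wrapped in parentheses when it contains top-level commas.
#define ASSIGN_OR_RETURN(lhs, rhs) \
  ASSIGN_OR_RETURN_IMPL(SK_CONCAT(status_or_, __LINE__), lhs, rhs)
#define ASSIGN_OR_RETURN_IMPL(tmp, lhs, rhs) \
  auto tmp = (rhs);                          \
  if (!tmp.ok()) return tmp.status();        \
  SK_DEPAREN(lhs) = std::move(tmp).value()

// Parenthesized lhs: the structured binding's commas stay in one argument.
StatusOr<int64_t> NumElements(const std::vector<int64_t>& dims) {
  ASSIGN_OR_RETURN((auto [batch, rows, cols]), SplitBatch2D(dims));
  return batch * rows * cols;
}

// Unparenthesized lhs still works as before.
StatusOr<int64_t> BatchSize(const std::vector<int64_t>& dims) {
  ASSIGN_OR_RETURN(auto shape, SplitBatch2D(dims));
  return std::get<0>(shape);
}
```

Because the dimension check lives inside `SplitBatch2D`, a caller that only wants the batch size (as in `BatchSize`) cannot forget it: a rank-1 input yields an error `Status` instead of an out-of-bounds read.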