mirror of
https://github.com/ROCm/jax.git
synced 2025-04-13 02:16:06 +00:00

To prepare for the upcoming `BatchedDevicePut` implementation changes, this change makes the `make_array_from_callback_*` benchmark code more homogeneous. It also adds a variant that uses a partially replicated sharding.

PiperOrigin-RevId: 736665856
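The following is a minimal sketch of what "homogeneous" benchmark variants could look like. It is not the benchmark code from this commit. It uses one shared helper, parameterized only by the sharding, so that the fully sharded, partially replicated, and fully replicated cases all exercise the same `jax.make_array_from_callback` path. The helper names `make_mesh` and `bench_make_array_from_callback`, the mesh axis names `"x"`/`"y"`, and the array shape are hypothetical. To run without accelerators, it assumes CPU execution with 8 simulated host devices, set through `XLA_FLAGS` before JAX is imported. In the partially replicated variant, `P("x")` shards rows over `"x"` and replicates each shard across the `"y"` axis.

```python
import os

# Assumption: simulate 8 CPU devices so the sketch runs without accelerators.
# This only takes effect if set before JAX is first imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import timeit

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def make_mesh():
    """Hypothetical helper: a 2-D ("x", "y") mesh over all available devices."""
    devices = np.array(jax.devices())
    n = devices.size
    y = 2 if n % 2 == 0 else 1  # "y" collapses to size 1 on odd device counts
    return Mesh(devices.reshape(n // y, y), ("x", "y"))


def bench_make_array_from_callback(sharding, shape, number=10):
    """Hypothetical helper: time jax.make_array_from_callback for one sharding."""
    data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

    def cb(index):
        # `index` is a tuple of slices selecting the shard a device needs.
        return data[index]

    def run():
        arr = jax.make_array_from_callback(shape, sharding, cb)
        arr.block_until_ready()
        return arr

    arr = run()  # warm-up call, also kept for correctness checks
    seconds = timeit.timeit(run, number=number) / number
    return arr, data, seconds


mesh = make_mesh()
# Shape chosen so that both mesh axes divide it evenly.
shape = (mesh.shape["x"] * 4, mesh.shape["y"] * 4)

variants = {
    "fully_sharded": NamedSharding(mesh, P("x", "y")),
    "partially_replicated": NamedSharding(mesh, P("x")),  # replicated over "y"
    "fully_replicated": NamedSharding(mesh, P()),
}

results = {}
for name, sharding in variants.items():
    results[name] = bench_make_array_from_callback(sharding, shape)
    print(f"{name}: {results[name][2] * 1e6:.1f} us per call")
```

Keeping every variant behind a single helper means a change to how device buffers are created, such as batching the per-device puts, affects all variants identically, so their timings stay comparable.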