mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00

Gives a small improvement to the included jax.device_put() benchmark on my VM:

```
name        old cpu/op   new cpu/op   delta
device_put  91.3µs ± 5%  80.1µs ± 3%  -12.29%  (p=0.008 n=5+5)

name        old time/op  new time/op  delta
device_put  91.4µs ± 5%  80.1µs ± 3%  -12.29%  (p=0.008 n=5+5)
```

jax.device_put() has not been optimized much yet, and there is still plenty of room for further improvement.

PiperOrigin-RevId: 491727173
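For a rough, standalone check of per-call jax.device_put() latency outside the repository's benchmark suite, one could time it with Python's timeit. This is a minimal sketch, not the included benchmark: the helper name time_device_put, the one-element float32 input, and the iteration count are illustrative assumptions, and absolute numbers will vary by machine and backend.

```python
import timeit

import jax
import numpy as np


def time_device_put(n_iter: int = 1000) -> float:
    """Hypothetical helper: mean wall-clock seconds per jax.device_put call."""
    # A tiny host array, so per-call overhead dominates over copy bandwidth.
    x = np.ones((1,), dtype=np.float32)

    # Warm up once so one-time backend initialization is not measured.
    jax.device_put(x).block_until_ready()

    def step() -> None:
        # Wait for the transfer to complete so each iteration measures
        # the full call, not just asynchronous dispatch.
        jax.device_put(x).block_until_ready()

    total = timeit.timeit(step, number=n_iter)
    return total / n_iter


if __name__ == "__main__":
    print(f"device_put: {time_device_put() * 1e6:.1f} µs/op")
```

This measures wall time only; the benchmark above also reports CPU time per op, which this sketch does not attempt to reproduce.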