rocm_jax/benchmarks
Yash Katariya 928dee415f Optimize host_local_array_to_global_array by caching the local-to-global conversion and the flattening of axis resources. Also take a fast path for device_put that skips abstractify and runs canonicalize_dtype once on the entire array (instead of once per shard).
This results in a roughly 4.5x speedup (3.03 ms down to 0.673 ms per call).
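
As a rough illustration of the caching idea described above (not the actual JAX implementation), the sketch below memoizes two pure helpers with `functools.lru_cache`. The names `flatten_axis_resources` and `local_to_global_shape`, and the `mesh_axis_hosts` argument, are hypothetical and exist only for this example; the scaling rule is a simplified assumption about how a host-local shape maps to a global one.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def flatten_axis_resources(spec):
    """Flatten a nested spec such as ('x', ('y', 'z'), None) into ('x', 'y', 'z').

    Hypothetical helper: `spec` must be hashable (a tuple) so it can be cached.
    """
    flat = []
    for entry in spec:
        if entry is None:
            continue  # dimension is not sharded
        if isinstance(entry, tuple):
            flat.extend(entry)
        else:
            flat.append(entry)
    return tuple(flat)


@lru_cache(maxsize=None)
def local_to_global_shape(local_shape, spec, mesh_axis_hosts):
    """Scale each local dimension by the host count of the axes it is sharded over.

    Hypothetical helper: `mesh_axis_hosts` is a tuple of (axis_name, num_hosts)
    pairs, passed as a tuple so that the whole call is hashable.
    """
    hosts = dict(mesh_axis_hosts)
    # Pad the spec so that trailing dimensions count as unsharded.
    padded = spec + (None,) * (len(local_shape) - len(spec))
    global_shape = []
    for dim, entry in zip(local_shape, padded):
        names = () if entry is None else entry if isinstance(entry, tuple) else (entry,)
        factor = 1
        for name in names:
            factor *= hosts.get(name, 1)
        global_shape.append(dim * factor)
    return tuple(global_shape)


mesh_axis_hosts = (("x", 2), ("y", 1))
first = local_to_global_shape((8, 4), ("x", None), mesh_axis_hosts)
second = local_to_global_shape((8, 4), ("x", None), mesh_axis_hosts)  # served from cache
```

Repeated calls with the same shape, spec, and mesh layout skip the Python-level loop entirely, which is the kind of per-call overhead the commit message says it removes.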

Before:

```
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array       3.03 ms         3.02 ms          220
```

After:

```
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array      0.673 ms        0.671 ms          985
```
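
The speedup can be checked directly from the two wall-clock times reported above; this small snippet only redoes that arithmetic, and its variable names are chosen for illustration.

```python
# Wall-clock times copied from the "Before" and "After" benchmark runs.
before_ms = 3.03
after_ms = 0.673

# Ratio of old to new time per call.
speedup = before_ms / after_ms
print(f"speedup: {speedup:.1f}x")
```

The iteration counts (220 before, 985 after) point to a similar ratio, since the benchmark harness runs more iterations when each one is cheaper.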

PiperOrigin-RevId: 489880547
2022-11-20 20:53:02 -08:00