rocm_jax/benchmarks
Yash Katariya 7f192c1946 Cache the expensive computations in GDA. For example, the result of get_shard_indices_replica_ids can be the same for multiple variables in a neural network, because global_shape, mesh_axes and global_mesh can be the same.
Note that the first call will be a little slow. The timings below show the caching working, because the benchmark runs for multiple iterations and the time is averaged over the number of iterations.

```
name                                                     time/op
gda_construction_callback_(4, 2)_['x', 'y']              4.50ms ±10%
gda_construction_raw_(256, 8)_['x', 'y']                 5.82ms ± 2%
indices_replica_id_calc__uncached_(256, 8)_['x', 'y']    2.95ms ± 6%
indices_replica_id_calc_cached_(256, 8)_['x', 'y']       28.7µs ± 1%
gda_construction_callback_(4, 2)_[None]                  31.9ms ±20%
gda_construction_raw_(256, 8)_[None]                     5.85ms ± 5%
indices_replica_id_calc__uncached_(256, 8)_[None]        1.75ms ± 1%
indices_replica_id_calc_cached_(256, 8)_[None]           29.0µs ± 4%
gda_construction_callback_(4, 2)_['x']                   8.40ms ± 4%
gda_construction_raw_(256, 8)_['x']                      5.48ms ± 2%
indices_replica_id_calc__uncached_(256, 8)_['x']         1.89ms ± 1%
indices_replica_id_calc_cached_(256, 8)_['x']            29.0µs ± 4%
gda_construction_callback_(4, 2)_['y']                   15.3ms ± 6%
gda_construction_raw_(256, 8)_['y']                      5.66ms ± 5%
indices_replica_id_calc__uncached_(256, 8)_['y']         1.82ms ± 2%
indices_replica_id_calc_cached_(256, 8)_['y']            29.4µs ± 3%
gda_construction_callback_(4, 2)_[('x', 'y')]            4.29ms ± 5%
gda_construction_raw_(256, 8)_[('x', 'y')]               5.61ms ± 7%
indices_replica_id_calc__uncached_(256, 8)_[('x', 'y')]  3.81ms ±10%
indices_replica_id_calc_cached_(256, 8)_[('x', 'y')]     29.0µs ± 5%
gda_construction_raw_(128, 8)_['x', 'y']                 2.42ms ± 1%
indices_replica_id_calc__uncached_(128, 8)_['x', 'y']    1.14ms ±11%
indices_replica_id_calc_cached_(128, 8)_['x', 'y']       19.9µs ± 1%
gda_construction_raw_(4, 2)_['x', 'y']                   46.7µs ± 0%
indices_replica_id_calc__uncached_(4, 2)_['x', 'y']       153µs ± 4%
indices_replica_id_calc_cached_(4, 2)_['x', 'y']         11.1µs ± 8%
gda_construction_raw_(16, 4)_['x', 'y']                   164µs ± 2%
indices_replica_id_calc__uncached_(16, 4)_['x', 'y']      212µs ± 3%
indices_replica_id_calc_cached_(16, 4)_['x', 'y']        11.3µs ± 1%
gda_construction_raw_(16, 4)_[('x', 'y')]                 163µs ± 2%
indices_replica_id_calc__uncached_(16, 4)_[('x', 'y')]    210µs ± 2%
indices_replica_id_calc_cached_(16, 4)_[('x', 'y')]      11.6µs ± 8%
```
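
A minimal sketch of the caching idea, not the actual GDA code: it assumes the inputs can be passed as hashable tuples and uses `functools.lru_cache` to memoize the result. The function `shard_index_ranges` and its argument layout are hypothetical stand-ins for illustration, and replica ids are left out.

```python
import functools
import itertools


@functools.lru_cache(maxsize=4096)
def shard_index_ranges(global_shape, mesh_shape, mesh_axes):
    """Hypothetical stand-in for an expensive per-device index computation.

    All arguments are tuples, so they are hashable and usable as cache keys:
      global_shape: array shape, e.g. (8, 2)
      mesh_shape:   (axis_name, size) pairs, e.g. (("x", 4), ("y", 2))
      mesh_axes:    one entry per array dimension, an axis name or None
    Returns one tuple of (start, stop) ranges per device in the mesh.
    """
    sizes = dict(mesh_shape)
    names = [name for name, _ in mesh_shape]
    result = []
    # Walk every device coordinate in the mesh.
    for coords in itertools.product(*(range(sizes[n]) for n in names)):
        position = dict(zip(names, coords))
        ranges = []
        for dim, axis in zip(global_shape, mesh_axes):
            if axis is None:
                # Dimension is not sharded: every device sees all of it.
                ranges.append((0, dim))
            else:
                chunk = dim // sizes[axis]
                start = position[axis] * chunk
                ranges.append((start, start + chunk))
        result.append(tuple(ranges))
    return tuple(result)


# Two variables with the same shape, mesh and axes share one computation:
# the first call does the work, the second is a cache lookup.
first = shard_index_ranges((8, 2), (("x", 4), ("y", 2)), ("x", "y"))
second = shard_index_ranges((8, 2), (("x", 4), ("y", 2)), ("x", "y"))
```

Because only the first call pays the full cost, a benchmark that averages over many iterations mostly measures the cached path, which is consistent with the gap between the `uncached` and `cached` rows above.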

PiperOrigin-RevId: 422639127
2022-01-18 13:53:52 -08:00