rocm_jax/jax/experimental
Yash Katariya 0532a63261 Optimizations to make creating a GDA (`GlobalDeviceArray`) faster.
* Compute the replica id with arithmetic instead of hashing indices. Using `_hashed_index` (a function, not to be confused with the `_HashableIndex` class, which no longer exists) is 1.5-2x slower than the arithmetic approach. markdaoust@ helped with the math here (going to the office has its own perks :) )

* Remove the `_HashableIndex` class and replace it with a function, `_hashed_index`, because the dataclass was extremely slow.

* Compute `global_mesh.local_devices` only once instead of on every access. Ideally it would be a cached property, but `functools.cached_property` is only available in Python 3.8+.
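The first bullet's idea can be illustrated with a minimal, hypothetical sketch. This is not the GDA code from this commit. It models the mesh as a dict of axis sizes and assumes that a device's shard depends only on its coordinates along the partitioned axes. It then compares a counting approach, which needs a hash table keyed by shard, with a closed-form row-major index over the replicated axes. `replica_ids_by_counting` and `replica_id_by_math` are illustrative names.

```python
import itertools
from collections import Counter
from typing import Dict, Set, Tuple

Coords = Tuple[int, ...]


def replica_ids_by_counting(mesh_shape: Dict[str, int],
                            partitioned_axes: Set[str]) -> Dict[Coords, int]:
    """Hash-based approach: walk the mesh in row-major order and count how
    many earlier devices already hold the same shard."""
    axis_names = list(mesh_shape)
    seen: Counter = Counter()
    replica_ids = {}
    for coords in itertools.product(*(range(mesh_shape[a]) for a in axis_names)):
        # Assumption: a device's shard depends only on its partitioned coordinates.
        shard_key = tuple(c for a, c in zip(axis_names, coords) if a in partitioned_axes)
        replica_ids[coords] = seen[shard_key]
        seen[shard_key] += 1
    return replica_ids


def replica_id_by_math(coords: Coords, mesh_shape: Dict[str, int],
                       partitioned_axes: Set[str]) -> int:
    """Arithmetic approach: row-major linear index over the replicated axes only."""
    replica_id = 0
    for axis, c in zip(mesh_shape, coords):
        if axis not in partitioned_axes:
            replica_id = replica_id * mesh_shape[axis] + c
    return replica_id


# Usage: on a hypothetical 4x2 mesh, both approaches agree for every partitioning.
mesh_shape = {"x": 4, "y": 2}
for partitioned in [set(), {"x"}, {"y"}, {"x", "y"}]:
    counted = replica_ids_by_counting(mesh_shape, partitioned)
    for coords, rid in counted.items():
        assert rid == replica_id_by_math(coords, mesh_shape, partitioned)
```

The two agree because, among devices that share the same partitioned coordinates, row-major order over the full mesh reduces to row-major order over the remaining (replicated) axes. The arithmetic version therefore needs no per-shard bookkeeping, only a few multiplications per device.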
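The second bullet's change, replacing a hashable wrapper class with a plain hashing function, can be sketched as follows. The sketch is illustrative and does not reproduce the commit's code. `HashableIndexSketch` and `hashed_index_sketch` are hypothetical names, and an index is assumed to be a tuple of slices and/or ints. Both compute the same hash, but the function avoids constructing a frozen dataclass instance for every index.

```python
import dataclasses
from typing import Tuple, Union

Index = Tuple[Union[slice, int], ...]


def _slice_free(index: Index) -> tuple:
    # slice objects are unhashable before Python 3.12, so convert each to a tuple.
    return tuple((v.start, v.stop, v.step) if isinstance(v, slice) else v
                 for v in index)


@dataclasses.dataclass(frozen=True)
class HashableIndexSketch:
    """Class-based wrapper, similar in spirit to the removed `_HashableIndex`."""
    val: Index

    def __hash__(self) -> int:
        return hash(_slice_free(self.val))


def hashed_index_sketch(index: Index) -> int:
    """Function-based replacement: no wrapper object is allocated per index."""
    return hash(_slice_free(index))


# Usage: group hypothetical device ids by the shard index they hold.
device_indices = {
    0: (slice(0, 4), slice(None)),
    1: (slice(0, 4), slice(None)),
    2: (slice(4, 8), slice(None)),
}
groups = {}
for device, index in device_indices.items():
    groups.setdefault(hashed_index_sketch(index), []).append(device)
```

Keying the dict on the hash value alone, as the usage does, assumes that distinct indices never collide. Keying on `_slice_free(index)` directly avoids that assumption at the cost of storing the tuples.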
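The third bullet is about not re-evaluating a property in a hot loop. This hypothetical sketch illustrates the cost; `MeshSketch` is a stand-in, not JAX's `Mesh`, and devices are modeled as `(device_id, process_index)` pairs. A plain `@property` reruns its body on every access, so reading it once per shard repeats the filtering. Hoisting the value into a local variable computes it once, and on Python 3.8+ `functools.cached_property` gives the same effect per instance.

```python
import functools


class MeshSketch:
    """Hypothetical mesh whose `local_devices` is recomputed on every access."""

    def __init__(self, devices, process_index):
        self.devices = devices              # list of (device_id, process_index) pairs
        self.process_index = process_index
        self.local_devices_calls = 0        # counts how often the filter runs

    @property
    def local_devices(self):
        self.local_devices_calls += 1
        return [d for d, p in self.devices if p == self.process_index]


class CachedMeshSketch(MeshSketch):
    """Python 3.8+ variant: the value is computed once and stored on the instance."""

    @functools.cached_property
    def local_devices(self):
        self.local_devices_calls += 1
        return [d for d, p in self.devices if p == self.process_index]


def local_shard_ids_repeated(mesh, shard_devices):
    # `mesh.local_devices` is re-evaluated for every shard.
    return [i for i, d in enumerate(shard_devices) if d in mesh.local_devices]


def local_shard_ids_hoisted(mesh, shard_devices):
    local = set(mesh.local_devices)  # computed once, outside the loop
    return [i for i, d in enumerate(shard_devices) if d in local]


# Usage: 8 hypothetical devices, 4 per process; 16 shards spread round-robin.
devices = [(i, i // 4) for i in range(8)]
shard_devices = [i % 8 for i in range(16)]

repeated_mesh = MeshSketch(devices, process_index=0)
hoisted_mesh = MeshSketch(devices, process_index=0)
cached_mesh = CachedMeshSketch(devices, process_index=0)

repeated = local_shard_ids_repeated(repeated_mesh, shard_devices)
hoisted = local_shard_ids_hoisted(hoisted_mesh, shard_devices)
cached = local_shard_ids_repeated(cached_mesh, shard_devices)
```

All three return the same shard ids. The difference is how often the device list is filtered: once per shard for the plain property, and once in total for the hoisted and cached variants.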

```
name                                           old time/op             new time/op             delta
gda_construction_callback_(4, 2)_['x', 'y']    4.77ms ± 5%             4.74ms ± 5%     ~           (p=0.316 n=14+17)
gda_construction_raw_(256, 8)_['x', 'y']       17.9ms ± 5%              9.0ms ± 2%  -49.92%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x', 'y']    11.4ms ± 2%              2.9ms ± 2%  -74.52%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[None]        34.0ms ±20%             30.5ms ± 2%     ~             (p=0.413 n=5+4)
gda_construction_raw_(256, 8)_[None]           15.9ms ± 2%              7.7ms ± 3%  -51.56%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[None]        9.39ms ± 3%             1.74ms ± 2%  -81.44%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['x']         8.87ms ± 2%             8.92ms ± 3%     ~             (p=0.841 n=5+5)
gda_construction_raw_(256, 8)_['x']            16.4ms ± 2%              7.7ms ± 1%  -52.66%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x']         9.85ms ± 1%             1.90ms ± 2%  -80.68%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['y']         15.9ms ± 3%             16.0ms ± 5%     ~             (p=0.690 n=5+5)
gda_construction_raw_(256, 8)_['y']            15.8ms ± 3%              7.6ms ± 1%  -52.04%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['y']         9.29ms ± 1%             1.78ms ± 1%  -80.79%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[('x', 'y')]  4.65ms ± 2%             4.62ms ± 3%     ~            (p=0.440 n=5+10)
gda_construction_raw_(256, 8)_[('x', 'y')]     18.6ms ± 3%              9.7ms ± 5%  -47.76%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[('x', 'y')]  11.8ms ± 4%              3.5ms ± 2%  -70.28%          (p=0.008 n=5+5)
gda_construction_raw_(128, 8)_['x', 'y']       8.54ms ± 1%             4.03ms ± 2%  -52.84%          (p=0.008 n=5+5)
indices_replica_id_calc_(128, 8)_['x', 'y']    5.40ms ± 4%             1.10ms ± 1%  -79.69%          (p=0.008 n=5+5)
gda_construction_raw_(4, 2)_['x', 'y']          173µs ± 1%              193µs ± 3%  +11.63%          (p=0.008 n=5+5)
indices_replica_id_calc_(4, 2)_['x', 'y']       127µs ± 1%              147µs ± 1%  +15.57%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 421623147
2022-01-13 11:53:13 -08:00