rocm_jax/benchmarks
Yash Katariya f4637c364d Fix the gda_xla_sharding_match benchmark, which was regressing. The regression occurred because the benchmark function was executed from top to bottom several times, and each run created a new mesh object, so the cache that an earlier run had already populated was never hit. Repeatedly creating new meshes like this doesn't happen in real-world usage.
```
name                                                         old time/op             new time/op  delta
gda_xla_sharding_match_(256, 8)_PartitionSpec('x', 'y')     21.8ms ± 2%              1.3ms ± 2%  -93.80%          (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec(None,)        21.8ms ± 4%              1.3ms ± 1%  -93.92%          (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec('x',)         21.8ms ± 3%              1.3ms ± 1%  -94.11%          (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec('y',)         21.8ms ± 3%              1.3ms ± 0%  -94.12%          (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec(('x', 'y'),)  21.8ms ± 3%              1.3ms ± 1%  -94.07%          (p=0.008 n=5+5)
gda_xla_sharding_match_(128, 8)_PartitionSpec('x', 'y')     13.9ms ± 6%              1.3ms ± 1%  -90.85%          (p=0.008 n=5+5)
gda_xla_sharding_match_(4, 2)_PartitionSpec('x', 'y')       5.72ms ±10%             1.25ms ± 1%  -78.15%          (p=0.008 n=5+5)
gda_xla_sharding_match_(16, 4)_PartitionSpec('x', 'y')      6.17ms ±11%             1.25ms ± 1%  -79.71%          (p=0.008 n=5+5)
gda_xla_sharding_match_(16, 4)_PartitionSpec(('x', 'y'),)   6.17ms ±10%             1.26ms ± 2%  -79.61%          (p=0.008 n=5+5)
```
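The caching problem described in the commit message can be sketched in plain Python. This is a hypothetical illustration, not JAX's actual code: `Mesh`, `expensive_sharding_lookup`, and `run_benchmark` are made-up names, and the sketch assumes the cache keys on the mesh object's identity (the default for Python objects that don't override `__hash__`/`__eq__`). Creating a fresh mesh on every iteration then misses the cache every time, while reusing one mesh hits it after the first call.

```python
import functools


class Mesh:
    """Hypothetical stand-in for a device mesh.

    It does not override __hash__/__eq__, so it hashes and compares by
    identity: two meshes with the same layout are distinct cache keys.
    """

    def __init__(self, shape, axis_names):
        self.shape = shape
        self.axis_names = axis_names


@functools.lru_cache(maxsize=None)
def expensive_sharding_lookup(mesh, spec):
    """Hypothetical expensive computation, cached per (mesh, spec) pair."""
    return (mesh.shape, spec)  # placeholder result


def run_benchmark(iterations, reuse_mesh):
    """Run the hypothetical lookup repeatedly and report cache statistics."""
    expensive_sharding_lookup.cache_clear()
    shared_mesh = Mesh((256, 8), ("x", "y"))
    for _ in range(iterations):
        # Regressing pattern: a brand-new Mesh per iteration is a new cache key.
        # Fixed pattern: reuse one Mesh so later iterations hit the cache.
        mesh = shared_mesh if reuse_mesh else Mesh((256, 8), ("x", "y"))
        expensive_sharding_lookup(mesh, ("x", "y"))
    return expensive_sharding_lookup.cache_info()


print(run_benchmark(5, reuse_mesh=False))
print(run_benchmark(5, reuse_mesh=True))
```

With `reuse_mesh=False` every one of the 5 calls is a cache miss; with `reuse_mesh=True` only the first call misses and the remaining 4 are hits, which mirrors why the benchmark timings drop once the mesh is no longer recreated.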

PiperOrigin-RevId: 463760534
2022-07-27 23:08:55 -07:00