rocm_jax/benchmarks
Tom Hennigan 9d56552517 Remove special casing on npy_value when indexing sharded arrays.
Before (columns: benchmark, wall time, CPU time, iterations):

```
GPU
sda_index_2                           2912972 ns      2778716 ns          256

TPU
sda_index_1                            769968 ns       751700 ns          921
sda_index_2                           1510841 ns      1489716 ns          465
sda_index_8                           6102259 ns      6027655 ns          117
```

After:

```
GPU
sda_index_2                             28095 ns        27983 ns        25463

TPU
sda_index_1                             10302 ns        10279 ns        67884
sda_index_2                             20010 ns        19947 ns        34628
sda_index_8                             78492 ns        78306 ns         8934
```
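The size of the improvement can be read off the two listings above. The following is a minimal sketch that computes the speedup per benchmark, assuming the first numeric column is wall-clock time per iteration (as in Google Benchmark output). The dictionary names `before_ns` and `after_ns` and the helper `speedup` are illustrative and not part of the commit.

```python
# Wall-time figures (ns) copied from the "Before" and "After" listings.
before_ns = {
    ("GPU", "sda_index_2"): 2912972,
    ("TPU", "sda_index_1"): 769968,
    ("TPU", "sda_index_2"): 1510841,
    ("TPU", "sda_index_8"): 6102259,
}
after_ns = {
    ("GPU", "sda_index_2"): 28095,
    ("TPU", "sda_index_1"): 10302,
    ("TPU", "sda_index_2"): 20010,
    ("TPU", "sda_index_8"): 78492,
}


def speedup(before, after):
    """Return before/after ratios keyed by (platform, benchmark)."""
    return {key: before[key] / after[key] for key in before}


speedups = speedup(before_ns, after_ns)

for (platform, name), ratio in speedups.items():
    print(f"{platform} {name}: {ratio:.1f}x faster")
```

On these figures the ratios come out at roughly 100x on GPU and roughly 75x on TPU.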

PiperOrigin-RevId: 368380864
2021-04-14 01:17:10 -07:00