rocm_jax/benchmarks
Peter Hawkins d6c67c97db Remove redundant dtype canonicalization from jax.device_put().
Gives a small improvement to the included jax.device_put() benchmark on my VM:

```
name        old cpu/op  new cpu/op  delta
device_put  91.3µs ± 5%  80.1µs ± 3%  -12.29%  (p=0.008 n=5+5)

name        old time/op             new time/op             delta
device_put  91.4µs ± 5%             80.1µs ± 3%  -12.29%          (p=0.008 n=5+5)
```

jax.device_put() has not yet been heavily optimized, and there is still plenty of room for further improvement.
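For a rough per-call timing outside the benchmark harness, the sketch below uses the standard-library `timeit` module. It is an approximation and not the repository's benchmark: the input (a small int32 scalar array), the warm-up call, and the iteration count are all illustrative assumptions. Each call waits on `block_until_ready()` so that asynchronous dispatch does not hide the transfer cost.

```python
import timeit

import numpy as np
import jax

# Hypothetical input: a small host-side int32 scalar array, chosen for illustration.
x = np.array(1, np.int32)

def put_once():
    # Transfer x to the default device and block until the transfer completes.
    return jax.device_put(x).block_until_ready()

put_once()  # warm-up call, so one-time setup is not counted

n = 1000  # illustrative iteration count
seconds = timeit.timeit(put_once, number=n)
per_call_us = seconds / n * 1e6
print(f"device_put: {per_call_us:.1f} µs/call")
```

Absolute numbers depend heavily on the machine and backend, so comparisons like the one above are only meaningful on the same VM before and after a change.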

PiperOrigin-RevId: 491727173
2022-11-29 13:47:36 -08:00