Adam Paszke 9c5e3f7ecc Verify that slices are trivial before discarding them in state primitives
At the moment, if `r` is a JAX ref then `r[0:1] = a` works, but it silently ignores the slices
and performs `r[:] = a` instead...

PiperOrigin-RevId: 529385973
2023-05-04 05:59:47 -07:00
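The fix described above amounts to checking an indexer before a state primitive drops it. The sketch below shows one way such a check could look, assuming NumPy-style basic indexing on a ref of known `shape`. It is pure Python and does not reproduce JAX's actual code. The helpers `is_trivial_indexer` and `discard_trivial_indexer` are hypothetical names, and raising `NotImplementedError` for a non-trivial indexer is an assumed policy.

```python
# Hypothetical sketch, not JAX's implementation: only discard an indexer
# when it provably selects the whole ref, so that `r[0:1] = a` can no longer
# be silently treated as `r[:] = a`.

def is_trivial_indexer(idx, shape):
    """Return True if `idx` selects every element of an array of `shape`."""
    if not isinstance(idx, tuple):
        idx = (idx,)
    # `r[...]` always selects the whole ref.
    if idx == (Ellipsis,):
        return True
    if len(idx) > len(shape):
        return False
    for s, dim in zip(idx, shape):
        if not isinstance(s, slice):
            # Integers, arrays, Ellipsis mixed with other entries, etc. are
            # conservatively treated as non-trivial.
            return False
        # slice.indices normalizes None and negative bounds against `dim`.
        if s.indices(dim) != (0, dim, 1):
            return False
    # Trailing axes not mentioned in `idx` are implicitly taken whole.
    return True


def discard_trivial_indexer(idx, shape):
    """Hypothetical guard: fail loudly instead of ignoring a real slice."""
    if not is_trivial_indexer(idx, shape):
        raise NotImplementedError(
            f"indexer {idx!r} on a ref of shape {shape} is not trivial "
            "and cannot be discarded"
        )
    return None  # Safe to drop: the indexer covers the whole ref.


if __name__ == "__main__":
    print(is_trivial_indexer(slice(None), (4,)))    # r[:] on shape (4,)
    print(is_trivial_indexer(slice(0, 1), (4,)))    # r[0:1] on shape (4,)
    print(is_trivial_indexer(slice(0, 1), (1,)))    # r[0:1] on shape (1,)
```

With this check, `r[0:1]` on a length-4 ref is rejected, while the same slice on a length-1 ref still counts as trivial because it covers every element.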