mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00

For the common special case of indexing with ints, or with slices whose bounds are ints/None, we can generate a lax.slice rather than a lax.gather. That makes compilation a little faster. It also makes the generated jaxpr a bit more wieldy, both to process in transformations and to read when pretty-printed.
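The sketch below shows the kind of eligibility check this rewrite implies, in pure Python. It is not JAX's actual code. `static_slice_params` is a hypothetical helper, and it makes three assumptions: it handles only ints and positive-step slices with int/None bounds, it rejects out-of-bounds ints, and it leaves newaxis, Ellipsis, negative steps and array indices to the general gather path.

```python
from typing import Optional, Sequence, Tuple

# (start_indices, limit_indices, strides, squeeze_dims)
SliceParams = Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]


def _is_plain_int(x) -> bool:
    # bool is a subclass of int, but NumPy-style indexing treats True/False as a mask.
    return isinstance(x, int) and not isinstance(x, bool)


def static_slice_params(shape: Sequence[int], idx) -> Optional[SliceParams]:
    """Hypothetical helper: map a basic index onto static-slice parameters.

    Returns parameters for a strided static slice followed by a squeeze of the
    int-indexed dimensions, or None when the index needs the general gather path.
    """
    if not isinstance(idx, tuple):
        idx = (idx,)
    if len(idx) > len(shape):
        return None
    # Trailing dimensions that are not indexed are taken whole.
    idx = idx + (slice(None),) * (len(shape) - len(idx))

    start, limit, strides, squeeze_dims = [], [], [], []
    for dim, (i, size) in enumerate(zip(idx, shape)):
        if _is_plain_int(i):
            if not -size <= i < size:
                return None  # out of bounds: defer to the general path
            i = i + size if i < 0 else i  # normalize negative indices
            start.append(i)
            limit.append(i + 1)
            strides.append(1)
            squeeze_dims.append(dim)  # this dimension is dropped after slicing
        elif isinstance(i, slice):
            parts = (i.start, i.stop, i.step)
            if not all(p is None or _is_plain_int(p) for p in parts):
                return None  # e.g. array-valued bounds
            # Clamp and normalize exactly as Python does (raises on step == 0).
            b, e, s = i.indices(size)
            if s < 1:
                return None  # a negative step would also need a reversal
            start.append(b)
            limit.append(max(b, e))  # empty slice when stop <= start
            strides.append(s)
        else:
            return None  # arrays, lists, None/newaxis, Ellipsis: not handled here
    return tuple(start), tuple(limit), tuple(strides), tuple(squeeze_dims)
```

Each returned tuple lines up with the arguments of `jax.lax.slice(operand, start_indices, limit_indices, strides)`. The int-indexed dimensions would then be removed with `jax.lax.squeeze(result, squeeze_dims)`. Any index that returns `None` would stay on the `lax.gather` path.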