Use xla.lower_fun() to implement gather/scatter modes so we can share the implementation between the XLA translation and jax2tf.
Add an undocumented "fill" mode to jnp.take() that corresponds to the "fill" mode of `lax.gather`.
PiperOrigin-RevId: 407169324
The shape rule for gather should not allow collapsing size-0 dimensions because it is nonsensical: "collapsing" a size 0 dimension might turn an empty array into a non-empty array. And it's quite unclear what that non-empty array should contain. Forbid such collapsing in the JAX shape rule.
This appears to have arisen in practice when the size of the array is known to be 0 in another dimension, e.g., batching with a size 0 batch dimension. Instead, avoid using a gather to create these arrays. This isn't an ideal solution because it isn't polymorphic in the shape, but I think to do better we would need to change the definition of `gather` more extensively.
PiperOrigin-RevId: 406346374
* give an error for NumPy indexing with slices when the elements
of the slices are not constant. This check existed, but was
throing an error when the elements are dimension polynomials.
* give an error for NumPy indexing with slices when the dimension
size is not constant.
* Improvements in the handling of enable_xla=False for shape
polymorphism.
* Added test cases for the above.
- Add docstring to abstract property
- Add explicit HTML documentation of this property
- Mark index update functions as deprecated, linking to this documentation
This PR changes `jax.numpy.array()` to avoid creating any on-device arrays during tracing. As a consequence, calls to `jnp.array()` in a traced context, such as `jax.jit` will always be staged into the trace.
This change may break code that depends on the current (undocumented and unintentional) behavior of `jnp.array()` to perform shape or index calculations that must be known statically (at trace time). The workaround for such cases is to use classic NumPy to perform shape/index calculations.
PiperOrigin-RevId: 398008511