* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).
The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)
PiperOrigin-RevId: 411565432
Bug: 8367
Small refactoring to jax.image.resize to make it compatible with
shape polymorphismin jax2tf. In the process added also support for
jnp.arange([dim_poly]). Note that the underlying lax.iota already
supported shape polymorphism.
Use xla.lower_fun() to implement gather/scatter modes so we can share the implementation between the XLA translation and jax2tf.
Add an undocumented "fill" mode to jnp.take() that corresponds to the "fill" mode of `lax.gather`.
PiperOrigin-RevId: 407169324
The shape rule for gather should not allow collapsing size-0 dimensions because it is nonsensical: "collapsing" a size 0 dimension might turn an empty array into a non-empty array. And it's quite unclear what that non-empty array should contain. Forbid such collapsing in the JAX shape rule.
This appears to have arisen in practice when the size of the array is known to be 0 in another dimension, e.g., batching with a size 0 batch dimension. Instead, avoid using a gather to create these arrays. This isn't an ideal solution because it isn't polymorphic in the shape, but I think to do better we would need to change the definition of `gather` more extensively.
PiperOrigin-RevId: 406346374