* support and test the edge case where the axis argument is an empty tuple ()
* replace the swapaxes + reshape methodology with a single call to lax.reshape, for computational efficiency (see the sketch after this list)
* add a check for repeated axes and raise a ValueError
* introduce and adapt the corresponding numpy code to swap and reshape the axes being quantiled
* introduce code to accommodate the reintroduction of those axes when keepdims=True
* add test cases
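A minimal sketch (not the library's actual implementation) of the reshape idea above: collapse the requested axes into one trailing axis with a single permuting lax.reshape, then take the quantile along that axis. The helper name and the omission of keepdims handling are just for illustration.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

def quantile_over_axes(x, q, axes):
  # Normalize and validate the axes, rejecting repeats as described above.
  axes = tuple(a % x.ndim for a in axes)
  if len(set(axes)) != len(axes):
    raise ValueError(f"repeated axis in {axes}")
  keep = tuple(a for a in range(x.ndim) if a not in axes)
  # lax.reshape's `dimensions` argument permutes the operand before reshaping,
  # so moving the quantiled axes to the back and merging them is one call.
  merged = int(np.prod([x.shape[a] for a in axes], dtype=np.int64))
  collapsed = lax.reshape(x, tuple(x.shape[a] for a in keep) + (merged,),
                          dimensions=keep + axes)
  # keepdims handling (reintroducing the reduced axes) is omitted here.
  return jnp.quantile(collapsed, q, axis=-1)

x = jnp.arange(24.0).reshape(2, 3, 4)
print(quantile_over_axes(x, 0.5, (0, 2)).shape)  # (3,)
```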
Fix the regression by gating the offending code behind a flag which no one has enabled.
#9316 is part of an ongoing experiment in adding dynamic shape support. The
experiment is meant not to perturb existing users, so any change that might not
be innocuous should sit behind the jax_dynamic_shapes flag.
But one of the changes in #9316 was not innocuous! (And I knew it might not be
at the time, but I'm an idiot and was optimistic that no one would notice.)
It has to do with the broadcasting logic in jax.numpy, specifically in
lax_numpy.py:_promote_shapes. Like NumPy, jax.numpy supports rank promotion,
e.g. `jnp.add(x:f32[4], y:f32[2,3,4])` is valid and results in the first
argument being logically promoted to shape `f32[2,3,4]` before the operation is
applied.
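For concreteness, a minimal example of the rank promotion just described, with shapes mirroring the f32[4] / f32[2,3,4] case above:

```python
import jax.numpy as jnp

x = jnp.arange(4, dtype=jnp.float32)        # f32[4]
y = jnp.ones((2, 3, 4), dtype=jnp.float32)  # f32[2,3,4]

# x is logically promoted to shape (2, 3, 4) before the addition.
out = jnp.add(x, y)
print(out.shape)  # (2, 3, 4)
```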
Our implementation of that rank promotion was to reduce it to an instance of
singleton-axis broadcasting: in the jax.numpy layer we would promote the shape
of the first argument to `f32[1,1,4]`, and then we could rely on lax.py's
singleton-axis broadcasting (copied from XLA HLO) to handle the rest. I
implemented it that way because, at least in eager mode (i.e. not staging out
with `jax.jit`), it could avoid broadcasting out a large temporary value. (I
thought reverse-mode AD would end up introducing this large intermediate
anyway, but maybe the `jit`s applied to `jax.numpy` functions avoid that...)
The way this relates to dynamic shapes is that we don't (and may not ever)
support singleton-axis broadcasting with dynamic shapes, like
`jnp.add(x:f32[n,4], y:f32[1,4])`. So when adding dynamic shape support, I
changed the rank promotion path not to rely on singleton-axis broadcasting. In
other words, instead of promoting the first argument in the example to
`f32[1,1,4]`, after #9316 we'd broadcast it to `f32[2,3,4]`. That could use
extra memory!
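To make the shapes involved concrete, here is a rough sketch of the two strategies using lax.broadcast_in_dim directly; the real logic lives in lax_numpy.py:_promote_shapes, so this is only an illustration:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(4, dtype=jnp.float32)  # f32[4]

# Pre-#9316 path: pad with singleton axes only. The result carries just the
# original four floats; lax-level singleton-axis broadcasting does the rest.
x_promoted = lax.broadcast_in_dim(x, (1, 1, 4), broadcast_dimensions=(2,))

# Post-#9316 path (without the flag): broadcast straight to the target shape,
# materializing 2 * 3 * 4 elements instead of 4 in eager mode.
x_broadcast = lax.broadcast_in_dim(x, (2, 3, 4), broadcast_dimensions=(2,))

print(x_promoted.shape, x_broadcast.shape)  # (1, 1, 4) (2, 3, 4)
```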
It turns out that some memory-sensitive users _do_ rely on this memory savings.
So we should hide this alternative implementation of rank promotion behind a
flag. (All these details around dynamic shapes are subject to change.)
PiperOrigin-RevId: 426201099
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).
The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)
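For orientation, the three pieces roughly correspond to the following imports. These are internal modules, so the exact layout and contents are subject to change; this is only illustrative:

```python
from jax._src import device_array  # DeviceArray definition
from jax._src import dispatch      # primitive execution and xla_call_p's impl logic
from jax.interpreters import xla   # lowering jaxprs into XLA computations
```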
PiperOrigin-RevId: 411565432
Bug: 8367
Small refactoring to jax.image.resize to make it compatible with
shape polymorphism in jax2tf. In the process, also added support for
jnp.arange([dim_poly]). Note that the underlying lax.iota already
supported shape polymorphism.
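A hedged sketch of the kind of program this enables. The functions are real (jax.image.resize, jnp.arange, jax2tf.convert), but the specific shapes, the "b" symbolic dimension name, and the helper below are illustrative assumptions:

```python
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf

def per_example_features(images):
  # images: f32[b, 32, 32, 3], where b is a symbolic batch dimension.
  resized = jax.image.resize(images, (images.shape[0], 64, 64, 3), method="linear")
  # jnp.arange on a symbolic dimension, which this change adds support for;
  # the underlying lax.iota already handled shape polymorphism.
  ids = jnp.arange(images.shape[0])
  return resized, ids

# Convert with a symbolic batch dimension "b" so the TF function accepts any batch size.
tf_fn = jax2tf.convert(per_example_features, polymorphic_shapes=["(b, 32, 32, 3)"])
```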