This is the first step in relanding the larger refactoring in #9724, which had to be rolled back due to downstream breakages.
PiperOrigin-RevId: 431999528
Also:
* fix `jnp.concatenate` and `jnp.append` for PRNGKeyArrays
* add `ndim` property to PRNGKeyArrays
* minor fix to `lax.expand_dims` with duplicate dimensions
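As a minimal sketch of the PRNGKey-related fixes above (assuming new-style typed key arrays via jax.random.key, available in recent releases; at the time of this change the same behavior sat behind an experimental custom-PRNG flag):

    import jax
    import jax.numpy as jnp

    keys = jax.random.split(jax.random.key(0), 3)   # a batch of 3 keys
    more = jax.random.split(jax.random.key(1), 2)   # a batch of 2 keys

    combined = jnp.concatenate([keys, more])        # now works on key arrays
    print(combined.ndim)                            # new ndim property -> 1
    appended = jnp.append(keys, more)               # jnp.append likewise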
This change makes ndarray a bit easier for tooling to handle: de facto, all of these methods are supposed to return *something*, but the return type inferrable from their default implementations is None.
As a hand-wavy aside, in a type stub

    def f(): ...

could be treated as equivalent to

    def f() -> Any: ...

because there is no body to infer the return type from, and Any is a reasonable
fallback type. In a .py file, however, f is no longer just a function *type*
(as opposed to a function *implementation*), and so it has an inferrable
return type.
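To make the aside concrete, here is a minimal illustration with placeholder names (not JAX's actual ndarray definition):

    from typing import Optional

    # With an ellipsis body in a .py file, a type checker infers the return
    # type as None unless an explicit annotation is given.
    class ndarray:
        def sum(self, axis: Optional[int] = None):
            ...  # inferred return type: None

        def mean(self, axis: Optional[int] = None) -> "ndarray":
            ...  # annotated, so tools see the intended ndarray return type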
* support and test the edge case where the axis argument is an empty tuple ()
* replace the swapaxes + reshape approach with a single call to lax.reshape, for efficiency
* add a check for repeated axes and raise a ValueError
* introduce and adjust the corresponding numpy code to swap and reshape the axes being quantiled
* introduce code to accommodate the reintroduction of those axes when keepdims=True
* add test cases (a usage sketch follows below)
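A hypothetical usage sketch of the resulting behavior (shapes and error handling are my reading of the bullets above, not copied from the actual tests):

    import jax.numpy as jnp

    x = jnp.arange(24.0).reshape(2, 3, 4)

    # axis may now be a tuple; the named axes are collapsed before quantiling.
    print(jnp.quantile(x, 0.5, axis=(0, 2)).shape)                  # (3,)

    # keepdims=True reintroduces the reduced axes with size 1.
    print(jnp.quantile(x, 0.5, axis=(0, 2), keepdims=True).shape)   # (1, 3, 1)

    # Repeated axes are rejected.
    try:
        jnp.quantile(x, 0.5, axis=(1, 1))
    except ValueError as err:
        print("repeated axis:", err)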
This change avoids the memory regression described below by gating the offending code under a flag which no one has enabled.
#9316 is part of an ongoing experiment in adding dynamic shape support. The
experiment is meant not to perturb existing users. So any changes which may not
be innocuous should be behind the jax_dynamic_shapes flag.
But one of the changes in #9316 was not innocuous! (And I knew it might not be
at the time, but I'm an idiot and was optimistic that no one would notice.)
It has to do with the broadcasting logic in jax.numpy, specifically in
lax_numpy.py:_promote_shapes. Like NumPy, jax.numpy supports rank promotion,
e.g. `jnp.add(x:f32[4], y:f32[2,3,4])` is valid and results in the first
argument being logically promoted to shape `f32[2,3,4]` before the operation is
applied.
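Concretely, a minimal sketch of that example:

    import jax.numpy as jnp

    x = jnp.ones((4,), dtype=jnp.float32)        # f32[4]
    y = jnp.ones((2, 3, 4), dtype=jnp.float32)   # f32[2,3,4]
    print(jnp.add(x, y).shape)                   # (2, 3, 4): x is rank-promoted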
Our implementation of that rank promotion was to reduce it to an instance of
singleton-axis broadcasting: in the jax.numpy layer we would promote the shape
of the first argument to `f32[1,1,4]`, and then we could rely on lax.py's
singleton-axis broadcasting (copied from XLA HLO) to handle the rest. I
implemented it that way because, at least in eager mode (i.e. not staging out
with `jax.jit`), it could avoid broadcasting out a large temporary value. (I
thought reverse-mode AD would end up introducing this large intermediate
anyway, but maybe the `jit`s applied to `jax.numpy` functions avoid that...)
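Roughly, that strategy looks like the following sketch, using lax.broadcast_in_dim as a stand-in for what lax_numpy.py:_promote_shapes does (the real code path differs in detail):

    import jax.numpy as jnp
    from jax import lax

    x = jnp.ones((4,), dtype=jnp.float32)        # f32[4]
    y = jnp.ones((2, 3, 4), dtype=jnp.float32)   # f32[2,3,4]

    # Promote only the rank of x, to f32[1,1,4]; the size-1 axes are then
    # expanded by singleton-axis broadcasting in the underlying op, so no
    # full-size copy of x is materialized up front.
    x_promoted = lax.broadcast_in_dim(x, (1, 1, 4), broadcast_dimensions=(2,))
    print((x_promoted + y).shape)                # (2, 3, 4)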
The way this relates to dynamic shapes is that we don't (and may not ever)
support singleton-axis broadcasting with dynamic shapes, like
`jnp.add(x:f32[n,4], y:f32[1,4])`. So when adding dynamic shape support, I
changed the rank promotion path not to rely on singleton-axis broadcasting. In
other words, instead of promoting the first argument in the example to
`f32[1,1,4]`, after #9316 we'd broadcast it to `f32[2,3,4]`. That could use
extra memory!
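By contrast, the flag-gated path does something morally like this (again a sketch, not the literal implementation):

    import jax.numpy as jnp
    from jax import lax

    x = jnp.ones((4,), dtype=jnp.float32)        # f32[4]
    y = jnp.ones((2, 3, 4), dtype=jnp.float32)   # f32[2,3,4]

    # Broadcast x all the way out to the result shape before the op, which
    # materializes a full f32[2,3,4] intermediate for x.
    x_full = lax.broadcast_in_dim(x, (2, 3, 4), broadcast_dimensions=(2,))
    print((x_full + y).shape)                    # (2, 3, 4)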
It turns out that some memory-sensitive users _do_ rely on this memory savings.
So we should hide this alternative implementation of rank promotion behind a
flag. (All these details around dynamic shapes are subject to change.)
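For completeness, opting in looks something like this (assuming the usual jax.config mechanism; the flag is experimental and off by default):

    import jax

    # Nothing changes unless the experimental flag is enabled explicitly.
    jax.config.update("jax_dynamic_shapes", True)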
PiperOrigin-RevId: 426201099