JAX currently requires that batch dimensions appear first and contiguously in the arguments to dot_general. However, XLA does not require this; relax JAX's checks so that it also allows batch dimensions in arbitrary positions.
Since batch dimensions are now allowed in arbitrary positions, it's not hard to
generalize the dot_general batching rule to avoid performing any transposes
(#2972).
In passing, also move the bool/int dot expansion into the XLA translation rule. The expansion inside the `lax.dot_general()` wrapper predated the existence of (or at least my knowledge of) `xla.lower_fun()`.
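As a rough illustration of batch dimensions in non-leading positions, the sketch below calls `lax.dot_general` with the batch dimension at axis 1 of both operands and checks the result against an equivalent `jnp.einsum`. The shapes and variable names are made up for the example, and it assumes a JAX version that includes this relaxation.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative shapes: the batch dimension (size 5) sits at axis 1 of both
# operands instead of leading. The contracted dimension has size 4.
lhs = jnp.arange(3 * 5 * 4, dtype=jnp.float32).reshape(3, 5, 4)
rhs = jnp.arange(4 * 5 * 6, dtype=jnp.float32).reshape(4, 5, 6)

# dimension_numbers = ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
dimension_numbers = (((2,), (0,)), ((1,), (1,)))
out = lax.dot_general(lhs, rhs, dimension_numbers)

# The output lists batch dimensions first, then the remaining lhs dimensions,
# then the remaining rhs dimensions.
expected = jnp.einsum("ibk,kbj->bij", lhs, rhs)
print(out.shape)
```

No explicit transpose of either operand is needed at the call site; only the dimension numbers change.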
* Generalize reduce-window padding to support (lo, hi) pairs, as XLA does.
This turns out to simplify the code slightly, too.
* Fix select_and_gather_add batching rule and test.
* Fix documentation text to refer to ReduceWindowWithGeneralPadding.
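A small sketch of the generalized reduce-window padding mentioned in the list above: `lax.reduce_window` is given an explicit list of (lo, hi) pairs, and the result is compared with padding the operand by hand and using `'VALID'`. The input values and window size are arbitrary choices for the example.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(5, dtype=jnp.float32)  # [0, 1, 2, 3, 4]

# One (lo, hi) pair per operand dimension: pad one element on the low side
# and none on the high side.
out = lax.reduce_window(
    x, 0.0, lax.add,
    window_dimensions=(2,),
    window_strides=(1,),
    padding=[(1, 0)],
)

# Equivalent formulation: pad explicitly with the init value (0 for a sum),
# then reduce with no implicit padding.
padded = jnp.pad(x, (1, 0))
out_valid = lax.reduce_window(padded, 0.0, lax.add, (2,), (1,), "VALID")
print(out, out_valid)
```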
This is a prototype implementation of the memory-efficient VJP method
for invertible functions. The general idea is that, thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with gradient computation.
The API is such that the only thing a user has to do is decorate a
function with `@invertible`, which will make AD apply a more
memory-efficient transpose than usual.
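The idea of reconstructing primal values instead of storing them can be sketched with the public `jax.custom_vjp` API. This is not the `@invertible` decorator or its implementation. It is a hand-written illustration on a hypothetical additive coupling layer (the kind of block a Reversible ResNet uses), where `f`, `coupling`, `coupling_inverse`, and `coupling_lowmem` are names invented for the example.

```python
import jax
import jax.numpy as jnp

def f(x2):
    # Hypothetical residual branch; any differentiable function would do.
    return 2.0 * jnp.tanh(x2)

def coupling(x1, x2):
    # Additive coupling: invertible in closed form.
    return x1 + f(x2), x2

def coupling_inverse(y1, y2):
    return y1 - f(y2), y2

@jax.custom_vjp
def coupling_lowmem(x1, x2):
    return coupling(x1, x2)

def coupling_fwd(x1, x2):
    y1, y2 = coupling(x1, x2)
    # Keep only the outputs as residuals; the inputs are not stored.
    return (y1, y2), (y1, y2)

def coupling_bwd(residuals, cotangents):
    y1, y2 = residuals
    # Reconstruct the inputs from the outputs, then differentiate as usual.
    x1, x2 = coupling_inverse(y1, y2)
    _, vjp_fn = jax.vjp(coupling, x1, x2)
    return vjp_fn(cotangents)

coupling_lowmem.defvjp(coupling_fwd, coupling_bwd)

def loss(layer, x1, x2):
    y1, y2 = layer(x1, x2)
    return jnp.sum(y1 ** 2) + jnp.sum(jnp.sin(y2))

x1 = jnp.array([0.5, -1.0, 2.0])
x2 = jnp.array([1.5, 0.25, -0.75])
g_ref = jax.grad(lambda a, b: loss(coupling, a, b), argnums=(0, 1))(x1, x2)
g_low = jax.grad(lambda a, b: loss(coupling_lowmem, a, b), argnums=(0, 1))(x1, x2)
```

For a single layer, storing outputs instead of inputs saves nothing by itself. The benefit appears when such layers are chained, because each layer's output is the next layer's input and only the final activations need to be kept.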
The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- An "invertible" function is defined as one that produces a jaxpr
that can be inverted correctly simply by iterating over its equations
in reverse. This is a bit strict, because users generally don't have
too much control over that, and there are functions that produce
jaxprs which will be treated as invertible when one topological
ordering of equations is used, while they will be considered
non-invertible for other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
that zero argument pruning in JVPTrace makes it pretty much impossible
to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- The invertible reverse-mode implementation (`rev_backward_pass`) assumes
that all the VJPs of primal primitives are jittable (not sure if
that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
`custom_ivjp` inefficient if it is being staged out.
This is useful for the remat transpose rule submitted in #3162 and e.g.
allowed me to catch a slight overuse of defjvp2 for `random_gamma_p` (it
was unnecessarily declared as having multiple outputs).