This is a prototype implementation of the memory-efficient VJP method
for invertible functions. The general idea is that, thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with the gradient computation.
The API is such that the only thing a user has to do is decorate a
function with `@invertible`, which will make AD apply the more
memory-efficient transpose instead of the usual one.
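To make the idea concrete, here is a minimal conceptual sketch (not the
prototype's actual code) of how a single invertible step's VJP can rebuild
its input from its output instead of storing it; `f` and `f_inv` are
illustrative stand-ins for a layer and its known inverse:
```python
import jax
import jax.numpy as jnp

def f(x):        # an invertible "layer"
    return jnp.exp(x)

def f_inv(y):    # its known inverse
    return jnp.log(y)

def reconstruct_and_vjp(y, y_bar):
    x = f_inv(y)                  # rebuild the primal input from the output
    _, vjp_fn = jax.vjp(f, x)     # re-linearize locally; no stored residuals
    (x_bar,) = vjp_fn(y_bar)
    return x, x_bar               # pass x and its cotangent to the previous layer
```
Chaining this step across layers replays the forward pass backwards, which is
what avoids memoizing intermediate primal values.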
The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- An "invertible" function is defined as one that produces a jaxpr
  that can be inverted correctly simply by iterating over its equations
  in reverse. This is a bit restrictive, because users generally don't
  have much control over that, and there are functions whose jaxprs are
  treated as invertible under one topological ordering of equations,
  but as non-invertible under other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
that zero argument pruning in JVPTrace makes it pretty much impossible
to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- Invertible reverse-mode implementation (`rev_backward_pass`) assumes
that all the VJPs of primal primitives are jittable (not sure if
that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
`custom_ivjp` inefficient if it is being staged out.
This is useful for the remat transpose rule submitted in #3162, and e.g.
allowed me to catch a slight overuse of defjvp2 for `random_gamma_p` (it
was unnecessarily declared as having multiple outputs).
* Implement mask for slice, conv, pad, transpose, where
* Remove tentative mask(jit)
* Add explanatory comment to dot_general masking rule
* Rm reshape from select masking rule
* Rm unnecessary check from lax slice abstract_eval rule
* Revert to standard indentation in masking_test.py
* Begin simplifying masking tests
* Finish drafting masking check function
* More progress simplifying tests
* Add conv masking in batch dim
* Finish fixing up tests
* Revert to old API, making out_shape compulsory again
* More efficient conv masking rule
* Tidy up masking_test imports
* Check that out tree is preserved by masking
* fix flake errors
Co-authored-by: Jamie Townsend <jamestownsend@google.com>
Co-authored-by: Jamie Townsend <jamiehntownsend@gmail.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>
* Improve tracing performance of _dynamic_slice_indices
* More precisely preserve semantics of dynamic_slice_indices
* Use safe_map in dynamic_slice_indices
revert find_top_trace change from #3197
The previous version was written and tested for performance; the revised
version caused at least a 25% slowdown in the dispatch time of
`lax.add(1, 2)` (and so likely a much bigger slowdown for the
find_top_trace timing alone).
Instead, we can just change the error message in xla.abstractify, since
invalid types lead to abstractification errors when we apply primitive
impls.
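For reference, a rough sketch of the kind of dispatch micro-benchmark behind
the 25% figure (the exact harness used for that measurement isn't shown here):
```python
import timeit
from jax import lax

lax.add(1, 2)  # warm up so compilation/caching is excluded from the timing
print(timeit.timeit(lambda: lax.add(1, 2), number=10_000))
```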
* Prefer using expand_dims/broadcast_in_dim to reshape in lax_numpy.py
`reshape()` is quite powerful, but does not necessarily preserve a notion of
axis identity (particularly for axes of length 1). This is problematic for
transformation rules that need to preserve a notion of axis identity, such as
for masking and a new transformation rule I'm exploring for unraveling pytrees.
This PR rewrites these operations in terms of expand_dims / lax.broadcast_in_dim,
when feasible, which have a well-defined mapping between input and output axes.
In particular: `matmul`, various `stack` functions, the `array` constructor,
broadcasting arithmetic, array indexing, `squeeze` and reductions with
`keepdims=True` no longer use `lax.reshape`.
I also implemented support for multiple axes in `expand_dims` (added in NumPy
1.18), since it was convenient for some of these other functions.
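A small hedged illustration of the axis-identity point (the shapes and calls
here are just examples, not code from the PR):
```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((3,))

# reshape only specifies the output shape; it doesn't say which output axis
# corresponds to the original length-3 axis.
y1 = lax.reshape(x, (1, 3, 1))

# expand_dims and broadcast_in_dim name that correspondence explicitly.
y2 = jnp.expand_dims(x, (0, 2))                # insert new axes at positions 0 and 2
y3 = lax.broadcast_in_dim(x, (1, 3, 1), (1,))  # input axis 0 maps to output axis 1
```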
I considered trying to write a masking rule for broadcast_in_dim as well, but
it was trickier than I expected and @JuliusKunze has probably already thought
about it :)
* Remove unnecessary branch
* Add lax.squeeze primitive
* Changes per review
* Fix typing
* Move expand_dims into lax
* Update per review; add comments/documentation
* Type annotations for squeeze/expand_dims
* Added argument check to all primitives.
The issue that inspired this is that `lax.tie_in` is
easy to misuse: if the first argument is not a JAX type, it
silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.
Previously this error would be caught only if core.skip_checks == False,
because then `bind` checks its arguments. I have essentially added
an unconditional argument check to `bind`.
In case this is considered too inefficient, we can add argument
checking to individual primitives, e.g., tie_in. For most primitives,
if a non-JAX array is passed, the `impl` rule would fire and `numpy`
would report the error somehow, perhaps.
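A hedged illustration of the misuse described above (using lax.tie_in as it
existed at the time; the example is mine, not taken from the change itself):
```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    const = jnp.float32(3.0)
    # Misuse: the first argument should be a JAX value, not a tuple of them.
    # Without the check, the tuple is silently dropped and f is just `const`.
    return lax.tie_in((x, x), const)

# With the unconditional argument check in bind, tracing this is rejected at
# trace time instead of silently producing a constant function of x.
jax.make_jaxpr(f)(jnp.float32(1.0))
```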
* Merged find_top_trace with check_args
This was previously merged as #2948 but reverted awaiting the fixes
in some user code.
This is being done to allow the creation of a differentiable segment_max. Segment_max is an important operation for GraphNets and is an open feature request at https://github.com/google/jax/issues/2255
Co-authored-by: Alex Davies <adavies@google.com>
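As a hedged sketch of what a differentiable segment_max can look like (built
here from the functional `.at[...].max()` scatter-style update; this
illustrates the idea, not necessarily how it is implemented in JAX):
```python
import jax
import jax.numpy as jnp

def segment_max(data, segment_ids, num_segments):
    init = jnp.full((num_segments,), -jnp.inf, dtype=data.dtype)
    return init.at[segment_ids].max(data)

data = jnp.array([1.0, 5.0, 2.0, 4.0])
seg = jnp.array([0, 0, 1, 1])
print(segment_max(data, seg, 2))  # [5. 4.]
# Gradients flow to each segment's max element, which is what GraphNets need.
print(jax.grad(lambda d: segment_max(d, seg, 2).sum())(data))
```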
* Add a primitive integer_pow() for values raised to fixed integer scalar.
Use integer_pow() in the RHS JVP of div(). Also use it in square() and reciprocal().
Fixes #3136
```
In [1]: from jax import grad, make_jaxpr
In [2]: def inv(x): return 1/x
In [3]: print(grad(grad(grad(grad(grad(grad(inv))))))(4.))
0.043945312
In [4]: make_jaxpr(grad(grad(grad(grad(grad(grad(inv)))))))(4.)
Out[4]:
{ lambda ; a.
let b = integer_pow[ y=-7 ] a
c = mul -6.0 b
d = mul -120.0 c
in (d,) }
In [5]:
```
* Use x ** 3 in gelu definition.
* Add support for sorting complex values, defaulting to a NumPy-style lexicographic ordering.
Implemented using a custom comparator, since the XLA-level default comparator doesn't impose an ordering on complex values.
* Disable sort test on CPU and TPU.
* Make lax.sort support tuple arguments using a variadic sort.
Change sort_jvp to use a gather of ids to compute the JVP rather than sorting repeatedly.
Remove sort_key_val_p, since it is redundant with a variadic sort_p.
* Fix mypy errors.
* Change JVP rule to use NumPy indexing.
Remove redundant case in batching rule.
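As an illustration of the "gather of ids" idea above (a hedged sketch, not the
actual sort_p JVP rule): sort the primal once together with its indices, then
reuse that permutation for the tangents instead of sorting a second time.
```python
import jax.numpy as jnp

def sort_jvp_sketch(x, x_dot):
    # One sort of the primal yields a permutation of indices ("ids");
    # the tangent of the sorted output is the tangent gathered by those ids.
    ids = jnp.argsort(x)
    return x[ids], x_dot[ids]
```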
* Add decorator for broadcasting at the translation rule layer.
* Fix broadcasting in igamma gradients.
Co-authored-by: Peter Hawkins <phawkins@google.com>
* allow in_axes=None for pmap in api.py
* wire in_axes=None through parallel_callable
* add test
* fix error string
* fixes
* fixes
* add test for nested pmap with in_axes
* test pmap still defaults to (implicit) out_axes=0
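A minimal usage sketch of the in_axes=None support added above (the device
count and values are illustrative):
```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
xs = jnp.arange(n, dtype=jnp.float32)  # mapped over devices along axis 0
w = jnp.float32(2.0)                   # in_axes=None: broadcast to every device

# out_axes still defaults to 0, so the result keeps a leading device axis.
out = jax.pmap(lambda x, w: x * w, in_axes=(0, None))(xs, w)
```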
At head the following fails:
```python
>>> import jax
>>> import jax.numpy as jnp
>>> jax.config.update('jax_numpy_rank_promotion', 'raise')
>>> jax.nn.one_hot(jnp.ones([8]), 512)
...
ValueError: Operands could not be broadcast together for equal on shapes (8, 1) (512,) and with the config option jax_numpy_rank_promotion='raise'. For more information, see https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.
```
* Fixed a few places where device stickiness was lost. Added FAQ for device
placement.
I have also added a new test (multi_device_test.test_computation_follows_data),
written more as part of the documentation. It is shorter than the
old test_computation_follows_data (which is still there, renamed
as test_computation_follows_data_old). I believe there is no
extra coverage in test_computation_follows_data_old w.r.t. all the
other tests we have.
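A tiny sketch of the "computation follows data" behavior the FAQ and test
describe (assumes at least two local devices; purely illustrative):
```python
import jax
import jax.numpy as jnp

devices = jax.devices()
if len(devices) > 1:
    # Committing x to a specific device makes it "sticky": subsequent
    # computations on x run on that device and their results stay there.
    x = jax.device_put(jnp.ones(3), devices[1])
    y = x + 1  # computed on devices[1]; the result remains on devices[1]
```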
* Fix mypy annotations and updates based on comments
* Undid some changes, will make another PR