This adds support for the basic (associative and commutative)
collectives to vmap. Supporting more complex collectives will
require more complicated rules. Also, at the moment it is not possible
to use collectives inside `custom_vjp` rules, which we might want to
fix in the future.
This feature is also omnistaging-only.
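For illustration, a minimal sketch of the kind of program this enables
(the axis name `i` and the choice of `lax.psum` are just examples):

```python
import jax
import jax.numpy as jnp
from jax import lax

# Normalize each element by the sum over the vmapped axis.
def normalize(x):
    return x / lax.psum(x, axis_name='i')

xs = jnp.arange(1.0, 5.0)
print(jax.vmap(normalize, axis_name='i')(xs))  # [0.1 0.2 0.3 0.4]
```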
Co-authored-by: Matthew Johnson <mattjj@google.com>
This simplifies the implementation significantly, as we can piggyback
off of all the logic for custom derivatives. For example, the previous
implementation didn't support differentiating with respect to a subset
of function parameters, but the new one does.
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.
See https://github.com/google/jax/pull/3370 for more information.
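As a sketch of the memory point above (shape chosen arbitrarily): with
omnistaging, the large constant below is staged into the jitted
computation instead of being instantiated as a concrete array at trace
time.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Under omnistaging this call is recorded in the traced computation
    # rather than materialized as a device array while tracing.
    big = jnp.ones((4096, 4096))
    return x + big.sum()

f(1.0)
```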
This reverts commit dbc3f83f6d14d491a06137f698aca92f7f3c572d.
This is breaking some google-internal users of xla_computation. Reverting while I investigate.
* Refactored host_callback to use the C++ runtime.
* The new runtime makes it unnecessary to start the outfeed_receiver
  in the user's code (see the usage sketch after this list).
* We don't need msgpack anymore.
* There is an interaction between host_callback and using lax.outfeed.
  I am trying to solve this by (a) making host_callback_test stop the
  outfeed receiver on finish and infeed_test stop it on start, and (b)
  telling pytest-xdist to run all the tests from one file in a single
  worker.
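A minimal usage sketch under the new runtime (the `what` keyword is
just an identifying tag):

```python
import jax
from jax.experimental import host_callback as hcb

@jax.jit
def f(x):
    # The traced value is sent to the host and printed there; no explicit
    # outfeed_receiver start is needed in user code anymore.
    x = hcb.id_print(x, what="x")
    return x * 2

f(3.0)
```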
The docstring for pmap does not currently mention that any "non-data"
arguments need to be indicated in `static_broadcasted_argnums`. This
commit updates the docs to parallel those for `jax.jit`, which does
explain this. Additionally, a remark is added to the `static_*` argument
descriptions on both jit and pmap, so that this point can be understood
without reading the whole docstring.
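For reference, a small sketch of the point being documented (function
and argument names are made up):

```python
import jax
import jax.numpy as jnp

def step(x, is_training):
    # `is_training` drives Python-level control flow, so it is "non-data".
    return x * 2.0 if is_training else x

# Non-data arguments must be listed in static_broadcasted_argnums;
# they cannot be passed as mapped array arguments.
pstep = jax.pmap(step, static_broadcasted_argnums=(1,))
xs = jnp.ones((jax.device_count(), 3))
pstep(xs, True)
```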
This is a prototype implementation of the memory-efficient VJP method
for invertible functions. The general idea is that, thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with the gradient
computation. The API is such that the only thing a user has to do is
decorate a function with `@invertible`, which will make AD apply a more
efficient transpose than usual.
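A usage sketch (the `from jax import invertible` spelling is an
assumption; this commit only names the decorator `@invertible`):

```python
import jax
import jax.numpy as jnp

# Assumed spelling: the commit only says the decorator is called
# `invertible`, so the exact import location may differ.
from jax import invertible

@invertible
def f(x):
    # Each step can be recomputed from its output (log, then divide by 4),
    # so no primal intermediates need to be stored for the backward pass.
    return jnp.exp(x) * 4.0

print(jax.grad(lambda x: jnp.sum(f(x)))(jnp.ones(3)))
```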
The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- The definition of an "invertible" function is one that produces a
  jaxpr that can be inverted correctly simply by iterating over its
  equations in reverse (see the sketch after this list). This is a bit
  strict, because users generally don't have much control over that,
  and there are functions that produce jaxprs which will be treated as
  invertible under one topological ordering of equations, while being
  considered non-invertible under other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
that zero argument pruning in JVPTrace makes it pretty much impossible
to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- The invertible reverse-mode implementation (`rev_backward_pass`)
  assumes that all the VJPs of primal primitives are jittable (not sure
  if that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
`custom_ivjp` inefficient if it is being staged out.
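To make the first caveat concrete, here is the kind of function whose
jaxpr can be inverted by walking its equations in reverse (each
equation's input is recoverable from its output):

```python
import jax
import jax.numpy as jnp

def g(x):
    return jnp.exp(x) * 4.0

# The jaxpr has two equations (exp, mul); each output determines its
# input (log, divide by 4), so scanning the equations in reverse
# recovers all primal values without memoizing them.
print(jax.make_jaxpr(g)(1.0))
```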
This is useful for the remat transpose rule submitted in #3162 and,
e.g., allowed me to catch a slight overuse of defjvp2 for
`random_gamma_p` (it was unnecessarily declared as having multiple
outputs).
Linearized functions are supposed to take tangent types to tangent
types, and so all primal arguments are unused and primal results get
replaced by units.
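For reference, this is the user-visible shape of linearization (a small
sketch using `jax.linearize`):

```python
import jax
import jax.numpy as jnp

# The linearized function maps tangents to tangents: it closes over the
# primal point and neither consumes nor produces primal values.
y, f_jvp = jax.linearize(jnp.sin, 1.0)
print(y)           # primal output, computed once: sin(1.0)
print(f_jvp(2.0))  # tangent output: cos(1.0) * 2.0
```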
* Implement mask for slice, conv, pad, transpose, where
* Remove tentative mask(jit)
* Add explanatory comment to dot_general masking rule
* Remove reshape from select masking rule
* Remove unnecessary check from lax slice abstract_eval rule
* Revert to standard indentation in masking_test.py
* Begin simplifying masking tests
* Finish drafting masking check function
* More progress simplifying tests
* Add conv masking in batch dim
* Finish fixing up tests
* Revert to the old API, making out_shape compulsory again (see the
  usage sketch after this list)
* More efficient conv masking rule
* Tidy up masking_test imports
* Check that out tree is preserved by masking
* Fix flake errors
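A usage sketch of the masking API these changes touch (the polymorphic
shape syntax and the `jax.mask` spelling are assumptions; note that
out_shape is compulsory, per the API revert above):

```python
from functools import partial

import jax
import jax.numpy as jnp

# Assumed spelling of the experimental masking transform; `n` is a
# polymorphic dimension and out_shape ('' means scalar) is required.
@partial(jax.mask, in_shapes=['n'], out_shape='')
def padded_sum(x):
    return jnp.sum(x)

# Sum only the first 3 entries of a length-5 padded buffer.
print(padded_sum([jnp.arange(5)], dict(n=3)))
```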
Co-authored-by: Jamie Townsend <jamestownsend@google.com>
Co-authored-by: Jamie Townsend <jamiehntownsend@gmail.com>
Co-authored-by: Matthew Johnson <mattjj@google.com>