22 Commits

Author SHA1 Message Date
James Bradbury
8e86952ee4 AWN-enabled reduction over named axes in reverse-mode AD
Previously, reverse-mode AD operators inside JAX maps always meant "compute
a gradient (or VJP, etc.) for each axis index in the map". For instance,
`vmap(grad(f))` is the standard JAX spelling of the per-example gradient of `f`.

In batching tracer terms, this "elementwise" behavior means that, if any inputs
to a function being transposed are mapped, the cotangents of all inputs, even
unmapped ones, would also be mapped. But a user might want them to be unmapped
(if, for instance, they're interested in a total gradient rather than a
per-example gradient). They could always reduce (`psum`) the cotangents
afterwards, but computing mapped cotangents in the first place would likely be
an unacceptable waste of memory and can't necessarily be optimized away.

If we want to fuse these reductions into reverse-mode autodiff itself, we need
the backward_pass logic and/or transpose rules to know about whether primal
values are mapped or unmapped. This is made possible by avals-with-names,
which encodes that information in the avals of the primal jaxpr.

Putting things together, **this change adds an option to reverse-mode AD APIs
that indicates which named axes should be reduced over in the backward pass in
situations where they were broadcasted over in the forward pass**. All other
named axes will be treated in the current elementwise way. This has the effect
of making APIs like `grad` behave akin to collectives like `psum`: they act
collectively over axes that are named explicitly, and elementwise otherwise.

Since avals-with-names is currently enabled only in `xmap`, this behavior is
only available in that context for now. It's also missing some optimizations:
  - reductions aren't fused into any first-order primitives (e.g. a `pdot`
    should have a named contracting axis added rather than being followed by a
    `psum`; this can be implemented by putting these primitives into
    `reducing_transposes`)
  - reductions are performed eagerly, even over axes that are mapped to
    hardware resources (the optimal thing to do would be to reduce eagerly
    over any vectorized axis component while delaying the reduction over any
    hardware-mapped component until the end of the overall backward pass; this
    would require a way to represent these partially-reduced values)

PiperOrigin-RevId: 383685336
2021-07-08 12:06:29 -07:00
James Martens
f925b62ea0 Clarifying docstring for devices argument of pmap.
PiperOrigin-RevId: 383486168
2021-07-07 13:51:11 -07:00
Matthew Johnson
a0eb1126e4 remat: don't apply cse-foiling widget to primal 2021-06-30 09:29:47 -07:00
Clemens Giuliani
3041c18250 turn lifted_jvp into a PyTree 2021-06-24 23:55:49 +02:00
Peter Hawkins
e9611eb090 Move jax.ad_util to jax._src.ad_util.
Expose ad_util.stop_gradient_p as jax.lax.stop_gradient_p. stop_gradient() is already under the external lax namespace.

PiperOrigin-RevId: 378011152
2021-06-07 14:51:34 -07:00
jax authors
3c6a41eb9c Merge pull request #6612 from google:tracer-errors
PiperOrigin-RevId: 372211269
2021-05-05 14:45:57 -07:00
Matthew Johnson
7ec0b40173 Roll-forward of #6584, which broke internal tests.
PiperOrigin-RevId: 371839298
2021-05-03 21:41:23 -07:00
Matthew Johnson
b9d72a480f improve concreteness error from arguments
also tweak some error message wording
2021-05-03 17:37:34 -07:00
Qiao Zhang
850bd66242 [JAX] Prune unused inputs in jit.
- Python part based on: https://github.com/google/jax/pull/6567
- Added cpp_jit path to handle pruned args

PiperOrigin-RevId: 371743277
2021-05-03 11:41:29 -07:00
jax authors
75b00a1235 Copybara import of the project:
--
3c400a3e588abf9e2259119c50343cba6f3477f1 by Matthew Johnson <mattjj@google.com>:

add 'inline' option to xla_call for jaxpr inlining

--
fe297e39ca37896b75d7943b9b77c0b53fad13ee by Matthew Johnson <mattjj@google.com>:

add 'inline' to jit docstring

--
ff6866c4b3757cde66fe659c2f27d8aeff024e8f by Matthew Johnson <mattjj@google.com>:

new_sublevel in jax2tf

PiperOrigin-RevId: 371542778
2021-05-01 22:18:39 -07:00
Matthew Johnson
fe297e39ca add 'inline' to jit docstring 2021-05-01 12:32:44 -07:00
Matthew Johnson
3c400a3e58 add 'inline' option to xla_call for jaxpr inlining 2021-04-28 19:38:15 -07:00
Lena Martens
b244e2b8c8 Add eval_shape to the UnexpectedTracerError too. 2021-04-23 14:46:34 +01:00
Lena Martens
deb2227f4a Make sure the out_axes in the HashableFunction closure are hashable.
By flattening them before putting them in the closure.
2021-04-21 12:32:19 +01:00
Skye Wanderman-Milne
9128ba0c74 Replace host_id with process_index terminology, take 2.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.

This was originally commited in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the local_devices argument name. This change is
identical except it also adds staging for the argument name change.
2021-04-20 18:13:34 -07:00
jax authors
bbc7be064c Merge pull request #6239 from j-towns:lt-allow-integers
PiperOrigin-RevId: 369467931
2021-04-20 10:23:10 -07:00
jax authors
14acd070c2 Internal change
PiperOrigin-RevId: 369345279
2021-04-19 18:23:07 -07:00
Skye Wanderman-Milne
b77ef5138b Replace host_id with process_index terminology.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
2021-04-19 14:09:19 -07:00
Peter Hawkins
14d991dd90 Move jax.config to jax._src.config.
PiperOrigin-RevId: 369230109
2021-04-19 08:53:12 -07:00
Jamie Townsend
90f7e06bcc Allow integer inputs/outputs of linear_transpose
Refine check to only allow float/complex -> float/complex or int -> int
functions (i.e. no mixing float/int inputs/outputs).
2021-04-15 21:17:50 +02:00
Peter Hawkins
8a221d3d4d [JAX] Add support for sharing an LRU cache between multiple C++ jit-ted functions.
Adds a new CompiledFunctionCache object that can be passed to the CompiledFunction constructor. Multiple CompiledFunctions can share the same cache capacity.

This change is in preparation for adding jit decorators to many standard library functions. We do not want to drastically increase the number of cached computations, and having a cache shared between functions allows us to avoid this.

Also allow cache entries to persist past the lifetime of the enclosing jax.jit(f) call so long as `f` remains alive. This mirrors the behavior of the existing linear_util cache that JAX uses in Python.

PiperOrigin-RevId: 368664536
2021-04-15 10:16:26 -07:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00