Introduce a config flag for upgrading to a world of custom PRNGs. The
flag defaults to off, so that custom PRNGs can land in the codebase while
downstream libraries are given time to upgrade.
Backwards-compatible behavior here is meant in an external sense: the
public behavior is unchanged, not that our code is internally the same.
A PRNG implementation is determined by a key shape and a set of basic
functions on such a key: seed, split, random_bits, and fold_in.
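As a rough illustration, such an implementation can be thought of as a bundle
of a key shape plus those four functions (the names below are illustrative,
not the exact internal definitions):
```
from typing import Callable, NamedTuple, Tuple

class PRNGImpl(NamedTuple):
  """Sketch of a PRNG implementation bundle; field names are assumptions."""
  key_shape: Tuple[int, ...]   # shape of a single raw key, e.g. (2,) for threefry
  seed: Callable               # int seed -> key
  split: Callable              # key, num -> array of `num` new keys
  random_bits: Callable        # key, bit_width, shape -> uint array of random bits
  fold_in: Callable            # key, data -> new key mixed with `data`
```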
A PRNG implementation can then be lifted to an array-of-keys-like
object. Namely, a new internal pytree class PRNGKeyArray wraps the
implementation and maintains an array of keys of the right shape. This
array-like object is the new "key" that gets passed around the various
functions in the public random API (e.g. `random.uniform`,
`random.normal`, ...). So the PRNGKeyArray class really serves two
purposes at once:
1. To adapt key implementations into "arrays" of such keys.
2. To carry a reference to the PRNG implementation around and delegate
back to it from the functions in random.
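For a sense of the call pattern (a minimal sketch; the public API itself is
unchanged, only the key object's internal representation differs once the
flag is on):
```
import jax
from jax import random

key = random.PRNGKey(0)        # a PRNGKeyArray once the flag is enabled
k1, k2 = random.split(key)     # splitting delegates to the wrapped PRNG impl
x = random.uniform(k1, (3,))   # so do the samplers in jax.random
```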
When checking the data types of the dynamic arguments in jax.value_and_grad, the
pytree is unflattened with `None` (the output of `_check_input_dtype_grad`) as the
value for each leaf. This causes an issue if a custom pytree does not accept
`None` as a leaf value (issue #7546), even though the tree that is
returned from the data type check is never used.
This commit fixes the issue by iterating over tree_leaves when checking data
types rather than using tree_map.
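A minimal sketch of the change, using a hypothetical simplified checker in
place of the internal `_check_input_dtype_grad`:
```
from jax import tree_util

def _check_leaf_dtype(leaf):
  """Stand-in for the internal dtype check; only its side effects matter."""
  # ... raise a TypeError here for invalid dtypes ...

def check_arg_dtypes(args):
  # Before: tree_util.tree_map(_check_leaf_dtype, args) rebuilt the pytree with
  # None at every leaf, which custom pytrees that reject None fail to unflatten.
  # After: iterate over the leaves only, so no tree is ever reconstructed.
  for leaf in tree_util.tree_leaves(args):
    _check_leaf_dtype(leaf)
```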
By wrapping common operators in `jit`, we get a number of benefits:
* `jit` has a faster, more optimized dispatch path compared to the primitive dispatch path in JAX. It's faster to dispatch a `jit` computation than a single primitive.
* `jit` allows us to cache and reuse logic such as broadcasting and type promotion.
One downside is that we now report an error when large Python integer scalars (e.g. `2**32 - 1`) are passed as arguments to JAX array operators. The workaround is to use explicitly typed constants instead of Python scalars.
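For example (a sketch of the workaround, not taken from the change itself):
```
import numpy as np
import jax.numpy as jnp

x = jnp.arange(4, dtype=jnp.uint32)
# jnp.add(x, 2**32 - 1)               # now errors: the bare Python int does not
#                                     # fit the default integer scalar type
y = jnp.add(x, np.uint32(2**32 - 1))  # OK: explicitly typed constant
```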
On my laptop, this benchmark improves from 95us to 4us:
```
In [1]: import jax.numpy as jnp, jax
In [2]: x = jax.device_put(7)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
In [3]: %timeit jnp.add(x, x).block_until_ready()
4.18 µs ± 159 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
PiperOrigin-RevId: 389871450
* Test closure conversion with mixed values in the closure, one
participating in AD and the other not (see the sketch below).
* Simplify the basic closure_convert test and give its intermediates
more descriptive names.
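A rough sketch of the scenario under test, assuming `jax.closure_convert` and
illustrative values (not the test's actual intermediates):
```
import jax
import jax.numpy as jnp

c = jnp.arange(3.0)   # closed-over array: hoisted out, participates in AD
n = 2                 # closed-over Python int: stays baked into the function

def f(x):
  return jnp.sum(c * x ** n)

x0 = jnp.ones(3)
converted_f, hoisted = jax.closure_convert(f, x0)
# converted_f(x0, *hoisted) computes the same value as f(x0), with the traced
# closure value(s) passed explicitly so gradients can flow through them.
```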
Previously, reverse-mode AD operators inside JAX maps always meant "compute
a gradient (or VJP, etc.) for each axis index in the map". For instance,
`vmap(grad(f))` is the standard JAX spelling of the per-example gradient of `f`.
In batching tracer terms, this "elementwise" behavior means that, if any inputs
to a function being transposed are mapped, the cotangents of all inputs, even
unmapped ones, are also mapped. But a user might want them to be unmapped
(if, for instance, they're interested in a total gradient rather than a
per-example gradient). They could always reduce (`psum`) the cotangents
afterwards, but computing mapped cotangents in the first place would likely be
an unacceptable waste of memory and can't necessarily be optimized away.
If we want to fuse these reductions into reverse-mode autodiff itself, we need
the backward_pass logic and/or transpose rules to know about whether primal
values are mapped or unmapped. This is made possible by avals-with-names,
which encodes that information in the avals of the primal jaxpr.
Putting things together, **this change adds an option to reverse-mode AD APIs
that indicates which named axes should be reduced over in the backward pass in
situations where they were broadcasted over in the forward pass**. All other
named axes will be treated in the current elementwise way. This has the effect
of making APIs like `grad` behave akin to collectives like `psum`: they act
collectively over axes that are named explicitly, and elementwise otherwise.
Since avals-with-names is currently enabled only in `xmap`, this behavior is
only available in that context for now. It's also missing some optimizations:
- reductions aren't fused into any first-order primitives (e.g. a `pdot`
should have a named contracting axis added rather than being followed by a
`psum`; this can be implemented by putting these primitives into
`reducing_transposes`)
- reductions are performed eagerly, even over axes that are mapped to
hardware resources (the optimal thing to do would be to reduce eagerly
over any vectorized axis component while delaying the reduction over any
hardware-mapped component until the end of the overall backward pass; this
would require a way to represent these partially-reduced values)
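A minimal sketch of the intended usage, assuming the option is surfaced as a
`reduce_axes` argument to `grad` and that the named axis comes from `xmap`:
```
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

def loss(w, x):
  return jnp.sum((x * w) ** 2)

w = jnp.ones(3)
xs = jnp.ones((8, 3))

# Per-example gradient: w's cotangent stays mapped over 'batch' (current behavior).
per_example = xmap(jax.grad(loss),
                   in_axes=([...], ['batch', ...]),
                   out_axes=['batch', ...])(w, xs)          # shape (8, 3)

# Total gradient: 'batch' is psum-reduced inside the backward pass itself.
total = xmap(jax.grad(loss, reduce_axes=('batch',)),
             in_axes=([...], ['batch', ...]),
             out_axes=[...])(w, xs)                         # shape (3,)
```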
PiperOrigin-RevId: 383685336
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:
[jax2tf] Change the conversion of dot_general to use XLA op.
Instead of converting dot_general to a sea of TF ops, when
`enable_xla` is set we just use the XLA op. This has the advantage
that it also supports `preferred_element_type`.
Fixed a bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.
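For context, a hedged sketch of the conversion setting involved (assuming
jax2tf.convert's `enable_xla` flag):
```
import jax.numpy as jnp
from jax.experimental import jax2tf

def dot(x, y):
  return jnp.dot(x, y)   # lowers to lax.dot_general

# With enable_xla set, dot_general converts to a single XLA dot op on the TF
# side rather than a composition of plain TF ops, so options such as
# preferred_element_type and precision carry over.
tf_dot = jax2tf.convert(dot, enable_xla=True)
```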
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
--
3c400a3e588abf9e2259119c50343cba6f3477f1 by Matthew Johnson <mattjj@google.com>:
add 'inline' option to xla_call for jaxpr inlining
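A small sketch, assuming the option is surfaced as jit's `inline` keyword:
```
import jax

# With inline=True, the jitted function's jaxpr is spliced into any enclosing
# jit-traced computation instead of staying behind a nested xla_call.
inner = jax.jit(lambda x: x + 1, inline=True)

@jax.jit
def outer(x):
  return inner(x) * 2
```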
--
fe297e39ca37896b75d7943b9b77c0b53fad13ee by Matthew Johnson <mattjj@google.com>:
add 'inline' to jit docstring
--
ff6866c4b3757cde66fe659c2f27d8aeff024e8f by Matthew Johnson <mattjj@google.com>:
new_sublevel in jax2tf
PiperOrigin-RevId: 371542778