This is especially convenient when using JAX as an HLO generator, because the
HLO AllGather defaults to the tiling behavior.
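A minimal sketch of how the tiling behavior surfaces on the JAX side (assuming the `tiled` keyword of `lax.all_gather` is the option this note refers to): with `tiled=True` the gathered shards are concatenated along an existing axis, matching HLO AllGather, instead of being stacked along a new leading axis.
```
import jax
import jax.numpy as jnp
from jax import lax

def gather_tiled(x):
    # Concatenate the per-device shards along axis 0 rather than stacking.
    return lax.all_gather(x, axis_name="i", tiled=True)

n = jax.local_device_count()
x = jnp.arange(n * 2.0).reshape(n, 2)
print(jax.pmap(gather_tiled, axis_name="i")(x).shape)  # (n, 2 * n)
```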
PiperOrigin-RevId: 384897270
The previous conversion for argmin/argmax simply used tf.argmin and tf.argmax, but
those ops behave differently from JAX when the inputs contain NaN or Inf. Added
a few test cases in primitive_harness to expose the failures.
In order to implement an accurate conversion of argmin/argmax, we need to use the
XLA Reduce op.
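A small illustration of the kind of discrepancy the new tests exercise (the exact harness cases may differ): JAX follows NumPy in treating NaN as the maximum, so argmax returns the index of the NaN, which a plain tf.argmax conversion does not reproduce.
```
import numpy as np
import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, 3.0])
# JAX matches NumPy: the index of the NaN is returned, since NaN propagates
# as the maximum value.
print(jnp.argmax(x), np.argmax(np.asarray(x)))  # both print 1
```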
Also tightened the shape checks for lax.argmin and lax.argmax to ensure they are
not used with an empty reduced dimension. E.g., with axis=-1 we previously got
an internal error:
```
RuntimeError: Invalid argument: Reducing out-of-bounds dimension -1 in shape f32[2,0,3].:
This is a bug in JAX's shape-checking rules; please report it!
```
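A sketch of the tightened check (the exact error type and message raised by the new check are assumptions here):
```
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((2, 0, 3))
try:
    # Reducing over the empty dimension 1 is now rejected up front instead of
    # surfacing the internal XLA error shown above.
    lax.argmin(x, axis=1, index_dtype=jnp.int32)
except (TypeError, ValueError) as e:
    print(e)
```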
PiperOrigin-RevId: 384182794
Previously, reverse-mode AD operators inside JAX maps always meant "compute
a gradient (or VJP, etc.) for each axis index in the map". For instance,
`vmap(grad(f))` is the standard JAX spelling of the per-example gradient of `f`.
In batching tracer terms, this "elementwise" behavior means that, if any inputs
to a function being transposed are mapped, the cotangents of all inputs, even
unmapped ones, would also be mapped. But a user might want them to be unmapped
(if, for instance, they're interested in a total gradient rather than a
per-example gradient). They could always reduce (`psum`) the cotangents
afterwards, but computing mapped cotangents in the first place would likely be
an unacceptable waste of memory and can't necessarily be optimized away.
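Concretely, a small sketch of the two spellings discussed above, using vmap for the mapped axis:
```
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((x * w) ** 2)

w = jnp.ones(3)
xs = jnp.arange(12.0).reshape(4, 3)

# Per-example gradients: one cotangent of w per mapped index.
per_example = jax.vmap(jax.grad(loss), in_axes=(None, 0))(w, xs)  # shape (4, 3)

# Total gradient via the "reduce afterwards" approach described above: all
# per-example cotangents are materialized first, then summed.
total = per_example.sum(axis=0)                                   # shape (3,)
```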
If we want to fuse these reductions into reverse-mode autodiff itself, we need
the backward_pass logic and/or transpose rules to know about whether primal
values are mapped or unmapped. This is made possible by avals-with-names,
which encodes that information in the avals of the primal jaxpr.
Putting things together, **this change adds an option to reverse-mode AD APIs
that indicates which named axes should be reduced over in the backward pass in
situations where they were broadcasted over in the forward pass**. All other
named axes will be treated in the current elementwise way. This has the effect
of making APIs like `grad` behave similarly to collectives like `psum`: they act
collectively over axes that are named explicitly, and elementwise otherwise.
Since avals-with-names is currently enabled only in `xmap`, this behavior is
only available in that context for now. It's also missing some optimizations:
- reductions aren't fused into any first-order primitives (e.g. a `pdot`
should have a named contracting axis added rather than being followed by a
`psum`; this can be implemented by putting these primitives into
`reducing_transposes`)
- reductions are performed eagerly, even over axes that are mapped to
hardware resources (the optimal thing to do would be to reduce eagerly
over any vectorized axis component while delaying the reduction over any
hardware-mapped component until the end of the overall backward pass; this
would require a way to represent these partially-reduced values)
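A rough sketch of what the new option enables inside `xmap` (assuming it is exposed as the `reduce_axes` parameter of `jax.grad`, and using the experimental `xmap` API of this era):
```
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

def loss(w, x):
    # x is a single example here; w is broadcast over the named 'batch' axis.
    return jnp.sum((x * w) ** 2)

w = jnp.ones(3)
xs = jnp.arange(12.0).reshape(4, 3)

# Default (elementwise) behavior: the cotangent of w is mapped over 'batch'.
per_example = xmap(jax.grad(loss),
                   in_axes=([...], ['batch', ...]),
                   out_axes=['batch', ...])(w, xs)         # shape (4, 3)

# With 'batch' named in reduce_axes, the reduction is fused into the backward
# pass and the result is unmapped, i.e. a total gradient.
total = xmap(jax.grad(loss, reduce_axes=('batch',)),
             in_axes=([...], ['batch', ...]),
             out_axes=[...])(w, xs)                         # shape (3,)
```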
PiperOrigin-RevId: 383685336
Compute unique/sorted information on indexed accesses and indexed updates. Non-advanced indexes are always sorted and unique.
XLA itself does not consume these flags, but they can be propagated onto scatter() when computing gradients.
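For indexed updates the same information can also be asserted explicitly by the caller; a hedged sketch using the `.at[]` update API (the keyword names here are assumptions about that API):
```
import jax.numpy as jnp

x = jnp.zeros(8)
idx = jnp.array([1, 3, 5])      # already sorted, no duplicates
y = jnp.ones(3)

# Non-advanced (slice) indexing like x[2:5] is always sorted/unique, so it
# needs no hints. For advanced indexing the caller can assert the same
# properties, letting the resulting gather/scatter carry the flags.
z = x.at[idx].add(y, indices_are_sorted=True, unique_indices=True)
```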
This is especially useful because it makes it easy to implement
"memory-limited vmaps". It might also come in handy for pipelining,
as that could represent the microbatching loop.
Note that, for now, the xmapped computation has to be a pure map along all axes
that are assigned to loop resources. No collectives are supported.
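A rough sketch of the "memory-limited vmap" idea (the `serial_loop` helper and the `axis_resources` spelling are assumptions based on the experimental `xmap` API of this era):
```
import jax.numpy as jnp
from jax.experimental.maps import xmap, serial_loop

def per_example(x):
    # A pure map: no collectives over the 'batch' axis.
    return jnp.sin(x) * 2.0

xs = jnp.arange(32.0).reshape(8, 4)

# Assigning 'batch' to a loop resource evaluates the mapped function in
# sequential chunks instead of one big vectorized call, bounding memory use.
with serial_loop('l', 4):
    f = xmap(per_example,
             in_axes=['batch', ...],
             out_axes=['batch', ...],
             axis_resources={'batch': 'l'})
    ys = f(xs)
```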
Previously we allowed a dimension variable in lieu of a dimension size. Now we
allow multi-variate dimension polynomials. These polynomials overload addition,
subtraction, and multiplication. They also partially support equality and
inequality checking: a comparison is decided only when its result is the same
for all valuations of the variables greater than 0. For example, `a == a`,
`a * b + 1 == 1 + b * a`, `a >= 1`, `2 * a + b >= 3`, and `a >= a` all succeed,
whereas `a == b` and `a >= 2` raise a `core.InconclusiveDimensionOperation`.
Division is supported only when either there is no remainder or the divisor is
a constant.
This change allows us to support more general cases of `jnp.reshape(-1)`,
such as those used in the internal implementation of `random_gamma`:
```
y = x.reshape((2, -1))
z = ... y ...
return z.reshape(x.shape)
```
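A hedged sketch of how this surfaces through `jax2tf` shape polymorphism (assuming the `polymorphic_shapes` argument of this era; TensorFlow is used only to build a concrete function):
```
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def f(x):
    # With x of shape (b, 4), the -1 below resolves to the dimension
    # polynomial 2*b: the 4*b elements divided by the constant 2 leave no
    # remainder, so the division is supported.
    y = x.reshape((2, -1))
    return (y * 2.0).reshape(x.shape)

# "(b, 4)" declares the first dimension as the shape variable b.
f_tf = tf.function(jax2tf.convert(f, polymorphic_shapes=["(b, 4)"]))
concrete = f_tf.get_concrete_function(tf.TensorSpec([None, 4], tf.float32))
```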
Part of the discrepancies was due to JAX using a workaround for missing complex
convolutions on CPU/GPU, while jax2tf was not using it. We now apply the same
lowering as JAX on all platforms.
This allows us to remove custom numeric tolerances and enables complex
convolutions on GPU.
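A sketch of that kind of workaround, expressing a complex convolution through real convolutions on the real and imaginary parts (the exact decomposition JAX emits may differ):
```
import jax.numpy as jnp
from jax import lax

def complex_conv(x, k, window_strides=(1, 1), padding="SAME"):
    # (xr + i*xi) conv (kr + i*ki)
    #   = (xr conv kr - xi conv ki) + i*(xr conv ki + xi conv kr)
    conv = lambda a, b: lax.conv_general_dilated(a, b, window_strides, padding)
    xr, xi = jnp.real(x), jnp.imag(x)
    kr, ki = jnp.real(k), jnp.imag(k)
    return (conv(xr, kr) - conv(xi, ki)) + 1j * (conv(xr, ki) + conv(xi, kr))

x = jnp.ones((1, 1, 8, 8), dtype=jnp.complex64)   # NCHW
k = jnp.ones((1, 1, 3, 3), dtype=jnp.complex64)   # OIHW
print(complex_conv(x, k).shape)  # (1, 1, 8, 8)
```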
PiperOrigin-RevId: 374199441
If it doesn't, trying to run `lu` with a custom CPU backend when a GPU is
present results in an `Unable to resolve runtime symbol:
cuda_lu_pivots_to_permutation` fatal error.
Previously we simply converted integer_pow to tf.math.pow. JAX instead uses
a series of multiplications. We now use the same lowering strategy as JAX, so
that we have the same numerical result.
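A sketch of the general strategy (repeated squaring; the exact multiplication sequence JAX emits may differ):
```
import jax.numpy as jnp

def integer_pow(x, n):
    # Lower x ** n (n >= 1) to a chain of multiplications via repeated
    # squaring, so the converted graph performs the same arithmetic as JAX.
    acc = None
    while n > 0:
        if n & 1:
            acc = x if acc is None else acc * x
        x = x * x
        n >>= 1
    return acc

print(integer_pow(jnp.float32(3.0), 5))  # 243.0
```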
Also improved the error messages for assertion failures.
PiperOrigin-RevId: 373351147
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:
[jax2tf] Change the conversion of dot_general to use XLA op.
Instead of converting dot_general to a sea of TF ops, when enable_xla is set
we just use the XLA op. This has the advantage that it also supports
preferred_element_type.
Fixed a bug with passing the precision parameter to TF.
Also improved the tests to print the HLO in case of numerical errors.
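A hedged sketch of the kind of case this enables (the `enable_xla` flag and the `preferred_element_type` parameter are as in the jax2tf/lax APIs of this era):
```
import jax.numpy as jnp
from jax import lax
from jax.experimental import jax2tf

def matmul_i8_to_i32(x, y):
    # A plain matmul whose accumulation type is widened via
    # preferred_element_type, which the XLA-op based conversion preserves.
    return lax.dot_general(x, y, dimension_numbers=(((1,), (0,)), ((), ())),
                           preferred_element_type=jnp.int32)

f_tf = jax2tf.convert(matmul_i8_to_i32, enable_xla=True)

x = jnp.ones((2, 3), dtype=jnp.int8)
y = jnp.ones((3, 4), dtype=jnp.int8)
print(matmul_i8_to_i32(x, y).dtype)  # int32
```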
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655