This is especially convenient when using JAX as an HLO generator, because the
HLO AllGather defaults to the tiling behavior.
PiperOrigin-RevId: 384897270
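As a hedged illustration (not part of the original change), here is a minimal sketch of the two `all_gather` output layouts under `pmap`, assuming the current `lax.all_gather` signature with its `tiled` keyword:

```python
import jax
import jax.numpy as jnp
from jax import lax

n = jax.device_count()
xs = jnp.arange(float(n * 4)).reshape(n, 4)  # one row of 4 values per device

def f(x):
    stacked = lax.all_gather(x, 'i')             # stacks shards on a new axis: (n, 4)
    tiled = lax.all_gather(x, 'i', tiled=True)   # concatenates along an existing axis: (4 * n,)
    return stacked, tiled

stacked, tiled = jax.pmap(f, axis_name='i')(xs)
```

The tiled (concatenating) form is the one that matches the default behavior of the HLO AllGather, which is what makes it convenient when JAX is used as an HLO generator.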
This has the benefit of confining the hairy axis arithmetic (some axes
getting removed, and others introduced with their positions offset by
the removals) to the user-facing all_to_all function, while all the
collective rules should now be simpler to write. This should be a no-op
from the users' point of view, but it should make enabling all_to_all
splitting easier.
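For context, a minimal usage sketch of the user-facing `all_to_all` under `pmap` (the shapes and the `split_axis`/`concat_axis` choices here are illustrative assumptions, not part of the change):

```python
import jax
import jax.numpy as jnp
from jax import lax

n = jax.device_count()
# Per-device input of shape (n, 4); the split_axis dimension must match the
# size of the named axis 'i'.
xs = jnp.arange(float(n * n * 4)).reshape(n, n, 4)

out = jax.pmap(lambda x: lax.all_to_all(x, 'i', split_axis=0, concat_axis=0),
               axis_name='i')(xs)
# The split_axis dimension is scattered across devices and the named axis is
# materialized at concat_axis, so the per-device shape stays (n, 4) while the
# data is exchanged across devices (a distributed transpose).
```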
The previous patch caused some failures of the `is_undefined_primal`
assertion in `broadcast_position`, but it looks like none of those
cases have any positional axes, so this should fix them. More debugging
is underway, but I wanted to make sure the affected users are
unblocked.
Splitting a single-dimensional ppermute into multiple permutations is a
hard problem in general, but not when the dimension being split has
size 1. More importantly, this is the case triggered by any `xmap` of a
`ppermute`, so we'd better have an implementation ready!
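For reference, a minimal sketch of a `ppermute` (shown under `pmap` rather than `xmap`, purely for illustration), cyclically shifting values across the named axis:

```python
import jax
import jax.numpy as jnp
from jax import lax

n = jax.device_count()
perm = [(i, (i + 1) % n) for i in range(n)]  # (source, destination) pairs

out = jax.pmap(lambda x: lax.ppermute(x, 'i', perm), axis_name='i')(
    jnp.arange(float(n)))
# out[j] holds the value that started on device (j - 1) % n.
```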
Similarly to `jnp.einsum`, whenever we encounter an extension to the
positional NumPy API (for reductions, that means a non-integer axis was
specified), we reroute the call to a parallel primitive instead of the
standard lax reductions.
Note that this makes the parallel primitives implement a strict subset
of the functionality of the lax reductions, so in the future (when we
decide that we want axes to be truly first class) we can always swap in
the parallel implementation. For now, though, it makes sense to keep
them separate for ease of prototyping.
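As a concrete (hedged) sketch of what the rerouting targets, assuming a `pmap` context with a named axis 'i':

```python
import jax
import jax.numpy as jnp
from jax import lax

xs = jnp.arange(float(jax.device_count()))

# psum is the parallel primitive that a sum reduction over the named axis
# lowers to; per this change, jnp.sum(x, axis='i') inside the mapped
# function is rerouted to the same primitive.
total = jax.pmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs)
```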
Previously, a few places in our code assumed that all collectives (i.e.
primitives that operate over named axes) keep all of their axes in the
`axis_name` attribute. This was fine for a few simple use cases, but we
are now considering allowing named axes in many more primitives, which
can have semantically different attributes in which axis names can
appear.
When an all_gather references a vmapped axis, there is a particularly
simple way of implementing it: "forget" that the axis was mapped, and
return the full array. Conveniently, this doesn't require any explicit
broadcasting, and makes it possible to use out_axes=None with the
results.
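A minimal sketch of that behavior, assuming a `vmap` with a named axis (the names and shapes here are illustrative):

```python
import jax
import jax.numpy as jnp
from jax import lax

xs = jnp.arange(8.0)

# all_gather over the vmapped axis "forgets" the mapping and returns the
# full array; since every instance gets the same value, out_axes=None works.
full = jax.vmap(lambda x: lax.all_gather(x, 'i'),
                axis_name='i', out_axes=None)(xs)
# full has shape (8,) and equals xs.
```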
XLA recently added support for this parameter to xops.DotGeneral. It's an optional parameter that controls the accumulation type used by the dot operation.
This is useful for e.g. quantized ANNs, where you might want to do matrix multiplies with int8 tensors and get back an int32 tensor, instead of an int8 tensor that suffers from severe overflow. Note that it's not sufficient in this case to cast the inputs to `dot` to int32 beforehand and rely on the default output dtype inference, since backend devices might have an accelerated path for int8*int8->int32 matmuls, and we want that explicitly represented in the XLA computation.
Note that because XLA still doesn't support integer dots on the CPU backend, this use case can't be tested with a CPU-only test at the moment.
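A hedged sketch of the use case, assuming the parameter in question is the one exposed as `preferred_element_type` on `lax.dot`/`lax.dot_general`:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((128, 64), dtype=jnp.int8)
y = jnp.ones((64, 32), dtype=jnp.int8)

# int8 inputs with an int32 accumulation/output type, so the int8*int8->int32
# contraction is expressed directly in the lowered dot rather than by casting
# the inputs to int32 first.
z = lax.dot(x, y, preferred_element_type=jnp.int32)
assert z.dtype == jnp.int32
```

(Per the note above, at the time of this change the sketch would not run on the CPU backend, since XLA's CPU backend did not yet support integer dots.)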
This is an expansion of the first, rolled-back attempt (https://github.com/google/jax/pull/5260), this time with auto-diff and batching rules that some users are relying on.
My benchmarks suggest a speed-up of ~2-2.5x for larger inputs.
--
474fdfcde05b4e5f17cfcb087a832d37c41ddffe by Chris Jones <cjfj@google.com>:
[JAX] Use XLA AllGather op for GPU (when supported).
PiperOrigin-RevId: 351440599