XLA recently added support for this parameter to xops.DotGeneral. It's an optional parameter that controls the accumulation type used by the dot operation.
This is useful for, e.g., quantized ANNs, where you might want to perform matrix multiplications with int8 tensors and get back an int32 tensor, instead of an int8 tensor that suffers from severe overflow. Note that it's not sufficient in this case to cast the inputs to 'dot' to int32 beforehand and rely on the default output dtype inference, since backend devices might have an accelerated path for int8*int8->int32 matmuls, and we want that path explicitly represented in the emitted XLA.
Note that because XLA still doesn't support integer dots on the CPU backend, that use case can't be tested with a CPU-only test at the moment.
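For illustration, a minimal sketch of the intended usage from JAX, assuming the parameter in question is exposed as `preferred_element_type` on `lax.dot`/`lax.dot_general` (the names here are assumptions for the example, and per the note above this may not run on the CPU backend at the moment):

```python
import jax.numpy as jnp
from jax import lax

# Two int8 operands; with int8 accumulation the result would overflow quickly.
x = jnp.arange(16, dtype=jnp.int8).reshape(4, 4)
y = jnp.ones((4, 4), dtype=jnp.int8)

# Requesting an int32 result keeps the int8*int8->int32 matmul explicit in
# the lowered XLA, rather than casting the inputs up to int32 beforehand.
out = lax.dot(x, y, preferred_element_type=jnp.int32)
print(out.dtype)  # int32
```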
This is an expansion of the first, rolled-back attempt (https://github.com/google/jax/pull/5260), this time with auto-diff and batching rules that some users are relying on.
My benchmarks suggest a speed-up of ~2-2.5x for larger inputs.
--
474fdfcde05b4e5f17cfcb087a832d37c41ddffe by Chris Jones <cjfj@google.com>:
[JAX] Use XLA AllGather op for GPU (when supported).
PiperOrigin-RevId: 351440599
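For reference, a minimal sketch of the kind of program this affects: an `all_gather` collective under `pmap`, which on GPU can now lower to XLA's AllGather op. The example below is illustrative, not taken from the change itself:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Each device holds one row; all_gather stacks every device's shard along a
# new leading axis, so every device ends up with the full array.
n = jax.local_device_count()
xs = jnp.arange(n * 2.0).reshape(n, 2)

gathered = jax.pmap(lambda x: lax.all_gather(x, 'devices'),
                    axis_name='devices')(xs)
print(gathered.shape)  # (n, n, 2)
```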
Specifically we:
1. remove the need for split_axis rules in batching.py, and instead just
rely on collective rules (namely, to handle vectorizing over a single
named axis even if the collective is applied over multiple named axes;
see the sketch after this list)
2. simplify BatchTrace.process_primitive so that we don't pass tracers
into rules or rely on a subtle recursion
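As a concrete illustration of point 1, here is a minimal sketch of a collective vectorized over a single named axis with vmap, the kind of program the collective rules now handle directly (the example is illustrative, not from this change):

```python
import jax
import jax.numpy as jnp
from jax import lax

# vmap over a named axis 'i'; the psum collective inside is applied over that
# mapped axis, exercising the collective batching rules directly.
def normalize(x):
    return x / lax.psum(x, axis_name='i')

xs = jnp.arange(1.0, 5.0)
print(jax.vmap(normalize, axis_name='i')(xs))  # each element divided by the total
```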
This change breaks all_to_all when used with multiple axis names, and in
particular it breaks all_to_all given the current gmap/xmap lowering
strategy of substituting multiple axis names in place of single axis
names. We believe we can replicate the previous logic with the new rule
organization, but we're leaving that for follow-up work because it's
tricky, and because we might end up changing lowering strategies not to
require axis substitution in the same way.
All initial-style primitives currently use `batch_jaxpr` in their
batching rules, but that function wasn't updated to support
axis_name when I added support for vmap collectives.
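For concreteness, a sketch of the combination in question: an initial-style primitive (`lax.scan`), whose batching goes through `batch_jaxpr`, vmapped with a named axis and a collective inside. Per the note above, this combination may not work until `batch_jaxpr` threads axis_name through; the snippet is illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    # lax.scan is an initial-style primitive; its batching rule uses batch_jaxpr.
    def body(carry, _):
        return carry + lax.psum(x, axis_name='i'), None
    out, _ = lax.scan(body, 0.0, xs=None, length=3)
    return out

xs = jnp.arange(4.0)
# A vmap collective (psum over 'i') inside scan: the case batch_jaxpr would
# need axis_name support to handle.
out = jax.vmap(f, axis_name='i')(xs)
```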