20 Commits

Author SHA1 Message Date
jax authors
62e89cbdf8 Merge pull request #5213 from malmaud:changelist/346683050
PiperOrigin-RevId: 352800932
2021-01-20 08:45:06 -08:00
Jonathan Malmaud
c0c4843b93 Add support for 'preferred_element_type' keyword arg in dot and dot_general.
XLA recently added support for this parameter to xops.DotGeneral. It's an optional parameter that controls the accumulation type used by the dot operation.

This is useful for eg quantized ANNs, where you might want to do matrix multiples with int8 tensors and get back an int32 tensor instead of an int8 tensor that suffers from severe overflow. Note it's not sufficient in this case to cast the inputs to 'dot' to int32 beforehand and rely on the default output dtype inference, since backend devices might have an accelerated path for int8*int8->int32 matmuls and we want that explicitly represented in the XLA.

Note because XLA still doesn't support integer dots on the CPU backend, that use case can't tested with a CPU-only test at the moment.
2021-01-19 18:56:46 +00:00
Chris Jones
4b48c7f42b Use XLA AllGather op for GPU (attempt 2).
This is an expansion of the first, rolled-back attempt (https://github.com/google/jax/pull/5260), this time with auto-diff and batching rules that some users are relying on.

My benchmarks suggest a speed-up of ~2-2.5x for larger inputs.
2021-01-19 11:16:25 +00:00
Anselm Levskaya
2ca247f43e Fix pdot translation rule.
This concerns the direct pdot translation rule, which is not used
during spmd lowering.
2021-01-13 12:52:28 -08:00
jax authors
47f254d252 Copybara import of the project:
--
474fdfcde05b4e5f17cfcb087a832d37c41ddffe by Chris Jones <cjfj@google.com>:

[JAX] Use XLA AllGather op for GPU (when supported).

PiperOrigin-RevId: 351440599
2021-01-12 13:47:20 -08:00
Chris Jones
474fdfcde0 [JAX] Use XLA AllGather op for GPU (when supported). 2021-01-12 14:35:01 +00:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Jake VanderPlas
98aac23d92 Change from deflinear to deflinear2 2021-01-05 09:03:33 -08:00
Joan Puigcerver
85fbc6d790 Add axis_index_groups argument to all_to_all. 2020-12-07 11:52:42 +00:00
Matthew Johnson
58e441bed7 add experimental pdot primitive, basic tests 2020-11-27 11:18:01 -08:00
Matthew Johnson
8d884e2480 silence weird type error 2020-11-25 14:17:27 -08:00
Matthew Johnson
0965bc40c4 fix bug 2020-11-25 14:15:06 -08:00
Matthew Johnson
3053f4bc1e add link to XLA bug about complex dtype allreduce 2020-11-25 10:18:02 -08:00
Matthew Johnson
809731859b cleanup: unify pmin/pmax implementations with psum 2020-11-25 10:18:00 -08:00
Matthew Johnson
ebd51e12fb address reviewer comments 2020-11-25 10:15:21 -08:00
Matthew Johnson
8057cf919e simplify vmap collectives from two sets of rules to one
Specifically we:
1. remove the need for split_axis rules in batching.py, and instead just
rely on collective rules (namely to handle vectorizing over a single
named axis even if the collective is applied over multiple named axes)
2. simplify BatchTrace.process_primitive so that we don't pass tracers
into rules and rely on a subtle recursion

This change breaks all_to_all when used with multiple axis names, and in
particular it breaks all_to_all given the current gmap/xmap lowering
strategy of substituting multiple axis names in place of single axis
names. We believe we can replicate the previous logic with the new rule
organization, but we're leaving that for follow-up work because it's
tricky, and because we might end up changing lowering strategies not to
require axis substitution in the same way.
2020-11-25 10:15:21 -08:00
Peter Hawkins
424594feb2 Short-circuit references to jax.core via jax.abstract_arrays. 2020-11-19 14:15:28 -05:00
Peter Hawkins
7efc1dbc94 [JAX] Move source_info_util into jax._src.
TFP uses source_info_util, so we leave a forwarding stub until we can update TFP.

PiperOrigin-RevId: 340698612
2020-11-04 11:54:24 -08:00
Adam Paszke
6348a99fb4 Add support for vmap collectives in control flow primitives
All initial style primitives currently use `batch_jaxpr` in their
batching rules, but that function hasn't been updated to support
axis_name when I added support for vmap collectives.
2020-10-26 12:09:18 +00:00
Peter Hawkins
10b7d7d7c2 Move implementation of jax.lax into jax._src.lax.
Remove lax_ prefixes from jax/_src/lax filenames, since they aren't needed any longer to avoid name conflicts.
2020-10-17 16:09:21 -04:00