36 Commits

Author SHA1 Message Date
Adam Paszke
6ada0b02a8 Hotfix for psum transpose
The previous patch has been causing some failures in the
`is_undefined_primal` assertion in `broadcast_position`, but it looks
like in all of those cases there are no positional axes, so this should
fix them. More debugging underway, but I wanted to make sure they're
unblocked.
2021-02-05 17:12:23 +00:00
Adam Paszke
1361ae1247 Add positional axis handling to the psum transpose rule
I must have forgotten to do that in one of the previous patches and
apparently we didn't have any tests for it (at least in the `vmap`
case)!
2021-02-05 10:59:41 +00:00
jax authors
2a7697858a Merge pull request #5557 from apaszke:trivial-ppermute-batcher
PiperOrigin-RevId: 354945256
2021-02-01 08:30:29 -08:00
Adam Paszke
a7f9b84bf1 Implement a trivial ppermute collective batcher
Splitting a single-dimensional ppermute into multiple permutations is a
hard problem in general, but not when we're splitting a size-1
dimension. More importantly, this is the case that's triggered by any
`xmap` of a `ppermute`, so we better have an implementation ready!
2021-02-01 11:49:41 +00:00
Adam Paszke
f86bf12b5a Add support for axis names in jnp.{sum,min,max}
Similarly to `jnp.einsum`, whenever we encounter an extension to the
positional NumPy API (in the case of reductions, the extension is
whenever a non-integer axis is specified), we reroute the call to a
parallel primitive instead of the standard lax reductions.

Note that this makes the parallel primitives implement a strict subset
of functionality of the lax reductions so in the future (when we decide
that we want axes to be truly first class) we can always swap out the
implementation for the parallel version. But, it makes sense to keep
them separate for the ease of prototyping in the near future.
2021-02-01 11:41:05 +00:00
Adam Paszke
baf6ed11cf Generalize the access to axis names embedded in primitives
Previously, a few places in our code assumed that all collectives (i.e.
primitives that operate over named axes) keep all of their axes in the
`axis_name` attribute. This was fine for a few simple use cases, but we
are now considering allowing named axes in many more primitives which
can have semantically different attributes where axis names can appear.
2021-01-29 17:31:40 +00:00
Matthew Johnson
014f9a86b4 implement soft_pmap in terms of xmap 2021-01-28 07:59:57 -08:00
James Bradbury
f1918f0b19 [avals with names] Revise aval constructor call sites to use a new aval.update method
PiperOrigin-RevId: 354182876
2021-01-27 15:14:02 -08:00
jax authors
deb2afe3cb Merge pull request #5521 from chr1sj0nes:changelist/345649004
PiperOrigin-RevId: 354049631
2021-01-27 02:06:53 -08:00
jax authors
814c4ad78e Merge pull request #5490 from google:xeinsum
PiperOrigin-RevId: 353929996
2021-01-26 12:56:05 -08:00
Chris Jones
b633898ef9 [JAX-GPU] Use XLA AllToAll op on GPU (for supported configurations). 2021-01-26 16:19:39 +00:00
Daniel Johnson
15b95e3ff5 Use np.shape instead of assuming argument has a shape attr 2021-01-25 18:11:38 -05:00
Daniel Johnson
c6a1bba308 Add evaluation rule for all_gather.
This should only be called when an all_gather runs on arguments that
are not batch tracers, for instance when all_gather-ing a constant.
2021-01-25 17:27:39 -05:00
Daniel Johnson
7865043341 Improve batched collective rule for all_gather_p
When an all_gather references a vmapped axis, there is a particularly
simple way of implementing it: simply "forget" that the axis was mapped,
and return the full array. Conveniently, this doesn't require any
explicit broadcasting, and makes it possible to use out_axes=None with
the results.
2021-01-25 16:52:38 -05:00
Matthew Johnson
6d2f8320c3 add xeinsum, an einsum for xmap (& einsum easter egg)
Co-authored-by: Adam Paszke <apaszke@google.com>
2021-01-21 14:47:35 -08:00
Matthew Johnson
c02d8041f4 add systematic pdot tests, utility functions
Run lots of tests with e.g.

```
env JAX_NUM_GENERATED_CASES=1000 python tests/xmap_test.py PDotTests
```
2021-01-21 14:06:30 -08:00
jax authors
62e89cbdf8 Merge pull request #5213 from malmaud:changelist/346683050
PiperOrigin-RevId: 352800932
2021-01-20 08:45:06 -08:00
Jonathan Malmaud
c0c4843b93 Add support for 'preferred_element_type' keyword arg in dot and dot_general.
XLA recently added support for this parameter to xops.DotGeneral. It's an optional parameter that controls the accumulation type used by the dot operation.

This is useful for eg quantized ANNs, where you might want to do matrix multiples with int8 tensors and get back an int32 tensor instead of an int8 tensor that suffers from severe overflow. Note it's not sufficient in this case to cast the inputs to 'dot' to int32 beforehand and rely on the default output dtype inference, since backend devices might have an accelerated path for int8*int8->int32 matmuls and we want that explicitly represented in the XLA.

Note because XLA still doesn't support integer dots on the CPU backend, that use case can't tested with a CPU-only test at the moment.
2021-01-19 18:56:46 +00:00
Chris Jones
4b48c7f42b Use XLA AllGather op for GPU (attempt 2).
This is an expansion of the first, rolled-back attempt (https://github.com/google/jax/pull/5260), this time with auto-diff and batching rules that some users are relying on.

My benchmarks suggest a speed-up of ~2-2.5x for larger inputs.
2021-01-19 11:16:25 +00:00
Anselm Levskaya
2ca247f43e Fix pdot translation rule.
This concerns the direct pdot translation rule, which is not used
during spmd lowering.
2021-01-13 12:52:28 -08:00
jax authors
47f254d252 Copybara import of the project:
--
474fdfcde05b4e5f17cfcb087a832d37c41ddffe by Chris Jones <cjfj@google.com>:

[JAX] Use XLA AllGather op for GPU (when supported).

PiperOrigin-RevId: 351440599
2021-01-12 13:47:20 -08:00
Chris Jones
474fdfcde0 [JAX] Use XLA AllGather op for GPU (when supported). 2021-01-12 14:35:01 +00:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Jake VanderPlas
98aac23d92 Change from deflinear to deflinear2 2021-01-05 09:03:33 -08:00
Joan Puigcerver
85fbc6d790 Add axis_index_groups argument to all_to_all. 2020-12-07 11:52:42 +00:00
Matthew Johnson
58e441bed7 add experimental pdot primitive, basic tests 2020-11-27 11:18:01 -08:00
Matthew Johnson
8d884e2480 silence weird type error 2020-11-25 14:17:27 -08:00
Matthew Johnson
0965bc40c4 fix bug 2020-11-25 14:15:06 -08:00
Matthew Johnson
3053f4bc1e add link to XLA bug about complex dtype allreduce 2020-11-25 10:18:02 -08:00
Matthew Johnson
809731859b cleanup: unify pmin/pmax implementations with psum 2020-11-25 10:18:00 -08:00
Matthew Johnson
ebd51e12fb address reviewer comments 2020-11-25 10:15:21 -08:00
Matthew Johnson
8057cf919e simplify vmap collectives from two sets of rules to one
Specifically we:
1. remove the need for split_axis rules in batching.py, and instead just
rely on collective rules (namely to handle vectorizing over a single
named axis even if the collective is applied over multiple named axes)
2. simplify BatchTrace.process_primitive so that we don't pass tracers
into rules and rely on a subtle recursion

This change breaks all_to_all when used with multiple axis names, and in
particular it breaks all_to_all given the current gmap/xmap lowering
strategy of substituting multiple axis names in place of single axis
names. We believe we can replicate the previous logic with the new rule
organization, but we're leaving that for follow-up work because it's
tricky, and because we might end up changing lowering strategies not to
require axis substitution in the same way.
2020-11-25 10:15:21 -08:00
Peter Hawkins
424594feb2 Short-circuit references to jax.core via jax.abstract_arrays. 2020-11-19 14:15:28 -05:00
Peter Hawkins
7efc1dbc94 [JAX] Move source_info_util into jax._src.
TFP uses source_info_util, so we leave a forwarding stub until we can update TFP.

PiperOrigin-RevId: 340698612
2020-11-04 11:54:24 -08:00
Adam Paszke
6348a99fb4 Add support for vmap collectives in control flow primitives
All initial style primitives currently use `batch_jaxpr` in their
batching rules, but that function hasn't been updated to support
axis_name when I added support for vmap collectives.
2020-10-26 12:09:18 +00:00
Peter Hawkins
10b7d7d7c2 Move implementation of jax.lax into jax._src.lax.
Remove lax_ prefixes from jax/_src/lax filenames, since they aren't needed any longer to avoid name conflicts.
2020-10-17 16:09:21 -04:00