61 Commits

Author SHA1 Message Date
Adam Paszke
64510bd5b6 Add axis and tiled options to lax.all_gather.
This is especially convenient when using JAX as an HLO generator, because the
HLO AllGather defaults to the tiling behavior.

PiperOrigin-RevId: 384897270
2021-07-15 04:22:36 -07:00
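A minimal sketch of what the `axis` and `tiled` options change, using `jax.vmap` with a named axis as a stand-in for a multi-device setup (that substitution, the axis name `"i"`, and the array shapes are assumptions for illustration, not taken from the commit):

```python
import jax
import jax.numpy as jnp
from jax import lax

# 4 "participants" along the named axis "i", each holding 2 values.
x = jnp.arange(8.0).reshape(4, 2)

# Default: every participant gets a new leading axis of size 4.
stacked = jax.vmap(lambda r: lax.all_gather(r, "i"), axis_name="i")(x)

# axis=1: the gathered axis is inserted at position 1 instead.
along_1 = jax.vmap(lambda r: lax.all_gather(r, "i", axis=1), axis_name="i")(x)

# tiled=True: results are concatenated along an existing axis (here axis 0),
# so no new axis is created.
tiled = jax.vmap(lambda r: lax.all_gather(r, "i", tiled=True), axis_name="i")(x)

print(stacked.shape, along_1.shape, tiled.shape)  # (4, 4, 2) (4, 2, 4) (4, 8)
```

The leading dimension of 4 in each printed shape comes from `vmap` itself; the remaining dimensions are the per-participant result of the gather.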
Adam Paszke
490f9778c8 Raise a friendlier error message when using loop axes in collectives 2021-06-08 11:55:03 +00:00
Adam Paszke
ed96e5305f Fix incorrect handling of axis_index_groups in parallel primitive fallbacks
PiperOrigin-RevId: 377139424
2021-06-02 14:03:47 -07:00
Adam Paszke
8df502aeb2 Use the axis names attached to a primitive when selecting the top trace
This is useful e.g. for handling psums of values that are not sharded,
but are also not statically known constants that we can fold.
2021-04-28 09:46:24 +00:00
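The situation described above can be sketched as follows: a value that is traced by `jit` (so not a foldable constant) but not mapped over the named axis is passed to `psum`. The function name `shifted`, the axis name `"i"`, and the use of `vmap` are assumptions chosen for illustration, and the behavior shown is that of a recent JAX release:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def shifted(c, xs):
    # `c` is a jit tracer: its value is not statically known, and it is not
    # mapped over "i". Summing it over "i" amounts to multiplying it by the
    # size of that axis.
    return jax.vmap(lambda x: x + lax.psum(c, "i"), axis_name="i")(xs)

out = shifted(2.0, jnp.zeros(4))
print(out)  # every entry is 4 * 2.0 = 8.0
```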
Adam Paszke
d0606463e4 Fix the batching rule for named reductions
PiperOrigin-RevId: 370505998
2021-04-26 11:41:58 -07:00
Adam Paszke
2d95d5ad2b Small updates to abstract eval rules (AWN related)
I've been reading the AWN-related PRs and have found a few places that
could be improved a little.
2021-04-15 09:50:11 +00:00
Chris Jones
f87d05b48b Simplify all-reduce translation rule.
XLA should now be deterministic on GPU, so we don't need to special-case the GPU backend.
2021-04-13 10:28:32 +01:00
jax authors
db78037f8b Merge pull request #6219 from apaszke:fix-xmap-allgather
PiperOrigin-RevId: 367449098
2021-04-08 09:46:47 -07:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Adam Paszke
ba8430605d Fix lax.all_gather inside xmap
The batching rule didn't properly handle tupled axis names.
2021-04-07 17:02:16 +00:00
Jake VanderPlas
8e789c7380 Run doctest on all source files except jax2tf 2021-04-05 10:39:59 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Jake VanderPlas
4139faf490 Fix dtype for pmap of scalars 2021-03-25 12:43:31 -07:00
Matthew Johnson
2b9ffb1fb3 make axis_index bind respect dynamic traces 2021-03-09 13:51:12 -08:00
James Bradbury
a8b8246554 add some todos 2021-03-09 13:51:09 -08:00
James Bradbury
c622422dad [avals with names] Propagate presence of name (mapped) vs absence (replicated) in abstract eval based on existing batching rules 2021-03-09 13:48:15 -08:00
Adam Paszke
2c7c86a4ba Reenable multi-axis all_to_all 2021-03-08 12:45:03 +00:00
Adam Paszke
8a4f0a8931 Make all_to_all primitive match XLA semantics
This has the benefit of limiting the insane axis arithmetic (with some
axes getting removed, and others introduced with their positions offset
by the removals) to the all_to_all user-facing function, but all the
collective rules should now be simpler to write. This should be a no-op
from the point of view of the users, but should make enabling all_to_all
splitting easier.
2021-03-05 18:18:49 +00:00
Matthew Johnson
9b18135b6e Rollback of #5702 due to internal breakage.
PiperOrigin-RevId: 357943850
2021-02-17 07:32:09 -08:00
Matthew Johnson
7fa4dbb5b5 add some todos 2021-02-16 15:46:14 -08:00
James Bradbury
fb160b8afd [avals with names] Propagate presence of name (mapped) vs absence (replicated) in abstract eval based on existing batching rules 2021-02-16 15:46:14 -08:00
jax authors
34217d0aee Merge pull request #5682 from google:pargmax
PiperOrigin-RevId: 357019508
2021-02-11 11:26:02 -08:00
Adam Paszke
926b2ad03f Minor fixes for xmap docstring, xeinsum parser
The regression loss example from the xmap docstring was broken and
the xeinsum parser didn't accept empty parens while it should.
2021-02-10 10:30:49 +00:00
Matthew Johnson
ffb3873e5a add pargmax, pargmin wrappers 2021-02-09 19:04:46 -08:00
Adam Paszke
b19dd87581 Add a pgather primitive, making it possible to index into mapped axes 2021-02-09 10:44:31 +00:00
Adam Paszke
6ada0b02a8 Hotfix for psum transpose
The previous patch has been causing some failures in the
`is_undefined_primal` assertion in `broadcast_position`, but it looks
like in all of those cases there are no positional axes, so this should
fix them. More debugging underway, but I wanted to make sure they're
unblocked.
2021-02-05 17:12:23 +00:00
Adam Paszke
1361ae1247 Add positional axis handling to the psum transpose rule
I must have forgotten to do that in one of the previous patches and
apparently we didn't have any tests for it (at least in the `vmap`
case)!
2021-02-05 10:59:41 +00:00
jax authors
2a7697858a Merge pull request #5557 from apaszke:trivial-ppermute-batcher
PiperOrigin-RevId: 354945256
2021-02-01 08:30:29 -08:00
Adam Paszke
a7f9b84bf1 Implement a trivial ppermute collective batcher
Splitting a single-dimensional ppermute into multiple permutations is a
hard problem in general, but not when we're splitting a size-1
dimension. More importantly, this is the case that's triggered by any
`xmap` of a `ppermute`, so we'd better have an implementation ready!
2021-02-01 11:49:41 +00:00
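For context, a small sketch of what `ppermute` computes over a named axis. It uses `jax.vmap` rather than `xmap`, with a full cyclic permutation; the axis name `"i"` and the sizes are assumptions for illustration, and this does not reproduce the size-1 splitting case the commit handles:

```python
import jax
import jax.numpy as jnp
from jax import lax

n = 4
x = jnp.arange(float(n))

# Each (src, dst) pair sends the value held at index `src` to index `dst`.
perm = [(i, (i + 1) % n) for i in range(n)]

shifted = jax.vmap(lambda v: lax.ppermute(v, "i", perm), axis_name="i")(x)
print(shifted)  # same as jnp.roll(x, 1): [3. 0. 1. 2.]
```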
Adam Paszke
f86bf12b5a Add support for axis names in jnp.{sum,min,max}
Similarly to `jnp.einsum`, whenever we encounter an extension to the
positional NumPy API (in the case of reductions, the extension is
whenever a non-integer axis is specified), we reroute the call to a
parallel primitive instead of the standard lax reductions.

Note that this makes the parallel primitives implement a strict subset
of functionality of the lax reductions so in the future (when we decide
that we want axes to be truly first class) we can always swap out the
implementation for the parallel version. But, it makes sense to keep
them separate for the ease of prototyping in the near future.
2021-02-01 11:41:05 +00:00
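A rough illustration of the correspondence between named and positional reductions. This sketch calls the parallel primitives (`lax.psum`, `lax.pmax`, `lax.pmin`) directly under `jax.vmap`, rather than relying on the `jnp` extension described in the commit, which may not be available in a given JAX version. The axis name `"rows"` is an assumption:

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12.0).reshape(3, 4)

def over_rows(reduce_fn):
    # Map over axis 0 under the name "rows", then reduce over that name.
    return jax.vmap(lambda r: reduce_fn(r, "rows"), axis_name="rows")(x)

named_sum = over_rows(lax.psum)
named_max = over_rows(lax.pmax)
named_min = over_rows(lax.pmin)

# Each mapped instance sees the same reduced value, which matches the
# positional reduction over axis 0.
print(jnp.allclose(named_sum[0], jnp.sum(x, axis=0)))
```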
Adam Paszke
baf6ed11cf Generalize the access to axis names embedded in primitives
Previously, a few places in our code assumed that all collectives (i.e.
primitives that operate over named axes) keep all of their axes in the
`axis_name` attribute. This was fine for a few simple use cases, but we
are now considering allowing named axes in many more primitives which
can have semantically different attributes where axis names can appear.
2021-01-29 17:31:40 +00:00
Matthew Johnson
014f9a86b4 implement soft_pmap in terms of xmap 2021-01-28 07:59:57 -08:00
James Bradbury
f1918f0b19 [avals with names] Revise aval constructor call sites to use a new aval.update method
PiperOrigin-RevId: 354182876
2021-01-27 15:14:02 -08:00
jax authors
deb2afe3cb Merge pull request #5521 from chr1sj0nes:changelist/345649004
PiperOrigin-RevId: 354049631
2021-01-27 02:06:53 -08:00
jax authors
814c4ad78e Merge pull request #5490 from google:xeinsum
PiperOrigin-RevId: 353929996
2021-01-26 12:56:05 -08:00
Chris Jones
b633898ef9 [JAX-GPU] Use XLA AllToAll op on GPU (for supported configurations). 2021-01-26 16:19:39 +00:00
Daniel Johnson
15b95e3ff5 Use np.shape instead of assuming argument has a shape attr 2021-01-25 18:11:38 -05:00
Daniel Johnson
c6a1bba308 Add evaluation rule for all_gather.
This should only be called when an all_gather runs on arguments that
are not batch tracers, for instance when all_gather-ing a constant.
2021-01-25 17:27:39 -05:00
Daniel Johnson
7865043341 Improve batched collective rule for all_gather_p
When an all_gather references a vmapped axis, there is a particularly
simple way of implementing it: simply "forget" that the axis was mapped,
and return the full array. Conveniently, this doesn't require any
explicit broadcasting, and makes it possible to use out_axes=None with
the results.
2021-01-25 16:52:38 -05:00
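A short sketch of the `out_axes=None` behavior mentioned above, assuming a recent JAX release. The axis name `"i"` and the input shape are illustrative choices:

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0).reshape(3, 2)

# all_gather over the vmapped axis returns the full array to every instance,
# so the result is no longer mapped and out_axes=None is allowed.
full = jax.vmap(
    lambda r: lax.all_gather(r, "i"),
    axis_name="i",
    out_axes=None,
)(x)

print(full.shape)  # (3, 2)
```

With the default `out_axes=0` the same call would instead return the gathered array once per mapped instance, giving shape `(3, 3, 2)`.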
Matthew Johnson
6d2f8320c3 add xeinsum, an einsum for xmap (& einsum easter egg)
Co-authored-by: Adam Paszke <apaszke@google.com>
2021-01-21 14:47:35 -08:00
Matthew Johnson
c02d8041f4 add systematic pdot tests, utility functions
Run lots of tests with e.g.

```
env JAX_NUM_GENERATED_CASES=1000 python tests/xmap_test.py PDotTests
```
2021-01-21 14:06:30 -08:00
jax authors
62e89cbdf8 Merge pull request #5213 from malmaud:changelist/346683050
PiperOrigin-RevId: 352800932
2021-01-20 08:45:06 -08:00
Jonathan Malmaud
c0c4843b93 Add support for 'preferred_element_type' keyword arg in dot and dot_general.
XLA recently added support for this parameter to xops.DotGeneral. It's an optional parameter that controls the accumulation type used by the dot operation.

This is useful for e.g. quantized ANNs, where you might want to do matrix multiplies with int8 tensors and get back an int32 tensor instead of an int8 tensor that suffers from severe overflow. Note it's not sufficient in this case to cast the inputs to 'dot' to int32 beforehand and rely on the default output dtype inference, since backend devices might have an accelerated path for int8*int8->int32 matmuls and we want that explicitly represented in the XLA.

Note that because XLA still doesn't support integer dots on the CPU backend, that use case can't be tested with a CPU-only test at the moment.
2021-01-19 18:56:46 +00:00
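A minimal sketch of the int8 use case described above, assuming a JAX and XLA version in which integer dots are supported on the backend in use (the commit notes that this was not the case for CPU at the time). The shapes and values are illustrative:

```python
import jax.numpy as jnp
from jax import lax

lhs = jnp.full((2, 3), 100, dtype=jnp.int8)
rhs = jnp.full((3, 2), 100, dtype=jnp.int8)

# Ask for int32 accumulation and output; each entry is 3 * 100 * 100 = 30000,
# which does not fit in int8.
out = lax.dot(lhs, rhs, preferred_element_type=jnp.int32)

print(out.dtype)
```

The inputs stay int8, so the lowered operation still records an int8 by int8 to int32 product, which is the point made in the commit message about not casting the inputs beforehand.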
Chris Jones
4b48c7f42b Use XLA AllGather op for GPU (attempt 2).
This is an expansion of the first, rolled-back attempt (https://github.com/google/jax/pull/5260), this time with auto-diff and batching rules that some users are relying on.

My benchmarks suggest a speed-up of ~2-2.5x for larger inputs.
2021-01-19 11:16:25 +00:00
Anselm Levskaya
2ca247f43e Fix pdot translation rule.
This concerns the direct pdot translation rule, which is not used
during spmd lowering.
2021-01-13 12:52:28 -08:00
jax authors
47f254d252 Copybara import of the project:
--
474fdfcde05b4e5f17cfcb087a832d37c41ddffe by Chris Jones <cjfj@google.com>:

[JAX] Use XLA AllGather op for GPU (when supported).

PiperOrigin-RevId: 351440599
2021-01-12 13:47:20 -08:00
Chris Jones
474fdfcde0 [JAX] Use XLA AllGather op for GPU (when supported). 2021-01-12 14:35:01 +00:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src.util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Jake VanderPlas
98aac23d92 Change from deflinear to deflinear2 2021-01-05 09:03:33 -08:00
Joan Puigcerver
85fbc6d790 Add axis_index_groups argument to all_to_all. 2020-12-07 11:52:42 +00:00