134 Commits

Author SHA1 Message Date
jax authors
5691010d2f Copybara import of the project:
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:

JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
2022-02-10 19:08:29 -08:00
Jake VanderPlas
6324577a63 JaxTestCase: set numpy_rank_promotion='raise' by default 2022-02-10 16:54:31 -08:00
Jake VanderPlas
e376df29be disable implicit rank promotion in a number of remaining tests 2022-01-28 08:16:30 -08:00
Matthew Johnson
c555f5f0e4 handle trivial case for ppermute batching rule
fixes #8688
2021-12-14 10:42:05 -08:00
Jake VanderPlas
df0969961b Testing: avoid hard-coding random seeds 2021-12-10 10:32:09 -08:00
Matthew Johnson
2cb235809a make vmap ppermute consistent with pmap/docstring
This was a bad bug! Unfortunately our tests didn't catch it, in part
because permutations on size-two axes are either trivial or not. The
simplest test might have a size-three axis.
2021-11-18 14:02:49 -08:00
Matthew Johnson
50e7e952bd add internal vmappable interface (part 1)
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2021-11-04 15:01:54 -07:00
Adam Paszke
49d9affce0 Enable batcher and batched collective rules for tiled all gathers
Fixes #8221.
2021-10-15 14:37:38 +00:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
jax authors
c365d7f91c Merge pull request #7908 from hawkinsp:api3
PiperOrigin-RevId: 396578038
2021-09-14 06:04:28 -07:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
jax authors
9045672aea Merge pull request #7906 from sharadmv:pdot-precision
PiperOrigin-RevId: 396481634
2021-09-13 17:40:11 -07:00
Peter Hawkins
a84426cb8f Switch internal users of jax.ops.index_... to use x.at[x].set() APIs. 2021-09-13 19:48:29 -04:00
Sharad Vikram
ebd8d95847 Add precision param for pdot 2021-09-13 16:28:31 -07:00
Adam Paszke
1c1ec79edd Clarify the error message for out-of-bounds in_axes in pmap and vmap
Fixes #5201.
2021-07-14 12:11:06 +00:00
Adam Paszke
8df502aeb2 Use the axis names attached to a primitive when selecting the top trace
This is useful e.g. for handling psums of values that are not sharded,
but are also not statically known constants that we can fold.
2021-04-28 09:46:24 +00:00
Adam Paszke
d0606463e4 Fix the batching rule for named reductions
PiperOrigin-RevId: 370505998
2021-04-26 11:41:58 -07:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Roy Frostig
7427991819 skip scalars when broadcasting for batch dimension agreement 2021-03-19 21:47:16 -07:00
Adam Paszke
2c7c86a4ba Reenable multi-axis all_to_all 2021-03-08 12:45:03 +00:00
Adam Paszke
8a4f0a8931 Make all_to_all primitive match XLA semantics
This has the benefit of limiting the insane axis arithmetic (with some
axes getting removed, and others introduced with their positions offset
by the removals) to the all_to_all user-facing function, but all the
collective rules should now be simpler to write. This should be a no-op
from the point of view of the users, but should make enabling all_to_all
splitting easier.
2021-03-05 18:18:49 +00:00
Peter Hawkins
ff3b402ec0 Improve error messages for invalid JAX types returned by batched functions. 2021-02-16 20:02:11 -05:00
Matthew Johnson
ffb3873e5a add pargmax, pargmin wrappers 2021-02-09 19:04:46 -08:00
Adam Paszke
1361ae1247 Add positional axis handling to the psum transpose rule
I must have forgotten to do that in one of the previous patches and
apparently we didn't have any tests for it (at least in the `vmap`
case)!
2021-02-05 10:59:41 +00:00
Daniel Johnson
15b95e3ff5 Use np.shape instead of assuming argument has a shape attr 2021-01-25 18:11:38 -05:00
Daniel Johnson
c6a1bba308 Add evaluation rule for all_gather.
This should only be called when an all_gather runs on arguments that
are not batch tracers, for instance when all_gather-ing a constant.
2021-01-25 17:27:39 -05:00
Daniel Johnson
7865043341 Improve batched collective rule for all_gather_p
When an all_gather references a vmapped axis, there is a particularly
simple way of implementing it: simply "forget" that the axis was mapped,
and return the full array. Conveniently, this doesn't require any
explicit broadcasting, and makes it possible to use out_axes=None with
the results.
2021-01-25 16:52:38 -05:00
Matthew Johnson
304685a152 allow vmapped function to accept kwargs
Arguments passed as keywords are always batched along their leading
axis. The in_tree specification must correspond to arguments passed
positionally.

This brings vmap in line with pmap. That is, pmap already followed this
convention for arguments passed via keywords. Consistency is good!

I had to adapt some utility functions so as not to change the error
messages raised. In particular, we have tests for vmap error messages
which report the in_axes and argument tree structure; naively including
keyword arguments changed those error messages. The error messages are
worth preserving. This change also brought the pmap error messages in
line with the vmap ones.

I also did some 80char wrapping of lines and docstring updating.

Fixes #912. Another user had the same issue and reported the same
expected behavior.
2021-01-12 20:13:23 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
David Majnemer
a87978f094 Enable more TPU tests
PiperOrigin-RevId: 351210210
2021-01-11 12:23:36 -08:00
Jake VanderPlas
1a83bb6f90 Cleanup: remove remaining instances of rng_factory boilerplate 2020-12-11 13:47:46 -08:00
Jake VanderPlas
f74235cdae X32 tests: fail on dtype warnings 2020-12-08 13:03:30 -08:00
Joan Puigcerver
85fbc6d790 Add axis_index_groups argument to all_to_all. 2020-12-07 11:52:42 +00:00
Matthew Johnson
58e441bed7 add experimental pdot primitive, basic tests 2020-11-27 11:18:01 -08:00
Matthew Johnson
8057cf919e simplify vmap collectives from two sets of rules to one
Specifically we:
1. remove the need for split_axis rules in batching.py, and instead just
rely on collective rules (namely to handle vectorizing over a single
named axis even if the collective is applied over multiple named axes)
2. simplify BatchTrace.process_primitive so that we don't pass tracers
into rules and rely on a subtle recursion

This change breaks all_to_all when used with multiple axis names, and in
particular it breaks all_to_all given the current gmap/xmap lowering
strategy of substituting multiple axis names in place of single axis
names. We believe we can replicate the previous logic with the new rule
organization, but we're leaving that for follow-up work because it's
tricky, and because we might end up changing lowering strategies not to
require axis substitution in the same way.
2020-11-25 10:15:21 -08:00
Peter Hawkins
575a8e0668 Move lax linear algebra routines into a jax.lax.linalg module.
PiperOrigin-RevId: 340717634
2020-11-04 13:36:28 -08:00
Peter Hawkins
c57bbb3cea [JAX] Move jax/lax_linalg.py to jax/_src/lax/linalg.py.
Because we now have a facade around the lax library, we can expose the lax_linalg primitives directly in lax without creating circular dependency problems.

Leave a few forwarding stubs to be removed later.

PiperOrigin-RevId: 340658800
2020-11-04 08:59:36 -08:00
Peter Hawkins
e1adbcd4e5 Fix batching_test flakiness on GPU. 2020-10-12 09:35:39 -04:00
Adam Paszke
e61ca91360 Implement split_axis for all_to_all
This allows us to use `all_to_all` over a mix of vmapped and pmapped
dimensions, which will be useful for `gmap`.
2020-09-30 08:25:36 +00:00
Adam Paszke
fa38f250e3 Add support for all_to_all over vmapped axes 2020-09-30 08:25:36 +00:00
Srijan Saurav
40e20242db
Fix code quality issues (#4302)
Changes:
- Fix unnecessary generator
- Iterate dictionary directly instead of calling .keys()
- Remove global statement at the module level
- Use list() instead of a list comprehension
- Use with statement to open the file
- Merge isinstance calls
2020-09-17 09:21:18 -07:00
Adam Paszke
a33f4dd8c8
Add support for axis_index inside vmap (#4168)
Also, reorganize the code to put all `axis_index` related functions in
`lax_parallel.py`, next to all other parallel collectives.
2020-08-28 20:03:39 +02:00
Adam Paszke
7210d6f5d0 Add support for binding axis_name in gmap
This allows executing collectives over the gmapped axes. This requires
some extra manipulation of the gmapped jaxpr, since gmap exposes a
single logical axis name, but evaluates the program using multiple
"physical" axes.

This also fixes some bugs around handling `multiple_returns` in
vmap collective implementation.
2020-08-28 14:42:01 +02:00
Peter Hawkins
e06a6ab6bf
Add support for negative axes to vmap. (#4111)
* Add support for negative axes to vmap.

* Add workaround for out-of-range vmap axes.
2020-08-24 20:21:19 -04:00
Adam Paszke
36f3a369c7
Separate axis splitting from collective handling (#4082)
This makes the vmap collective handling a bit more flexible and allowed
me to add ppermute support.
2020-08-18 12:02:28 +02:00
Adam Paszke
b75bae6437
Initial version of vmap collectives (#4005)
This adds support for the basic (associative and commutative)
collectives to vmap. Supporting more complex collectives will
require some more complicated rules. Also, at the moment it is not
possible to use collectives inside `custom_vjp` rules which we might
want to fix in the future.

This feature is also omnistaging-only.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-08-14 18:22:04 +02:00
Jamie Townsend
e28db33b01
Fix dynamic_slice, dynamic_update_slice scalar batching, fixes #3883 (#3888)
* Add test for issue 3883

* Fix dynamic_slice, dynamic_update_slice scalar batching, fixes #3883
2020-07-28 18:39:32 -07:00
Peter Hawkins
e4d5eade54 Use iteration over equations to test for "transpose" and "broadcast". 2020-07-17 08:44:47 -04:00
Peter Hawkins
165f31ef28 Also test for transpose in dot vmap test. 2020-07-17 08:38:33 -04:00