--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:
JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
This was a bad bug! Unfortunately our tests didn't catch it, in part
because permutations on size-two axes are either trivial or not. The
simplest test might have a size-three axis.
This has the benefit of limiting the insane axis arithmetic (with some
axes getting removed, and others introduced with their positions offset
by the removals) to the all_to_all user-facing function, but all the
collective rules should now be simpler to write. This should be a no-op
from the point of view of the users, but should make enabling all_to_all
splitting easier.
When an all_gather references a vmapped axis, there is a particularly
simple way of implementing it: simply "forget" that the axis was mapped,
and return the full array. Conveniently, this doesn't require any
explicit broadcasting, and makes it possible to use out_axes=None with
the results.
Arguments passed as keywords are always batched along their leading
axis. The in_tree specification must correspond to arguments passed
positionally.
This brings vmap in line with pmap. That is, pmap already followed this
convention for arguments passed via keywords. Consistency is good!
I had to adapt some utility functions so as not to change the error
messages raised. In particular, we have tests for vmap error messages
which report the in_axes and argument tree structure; naively including
keyword arguments changed those error messages. The error messages are
worth preserving. This change also brought the pmap error messages in
line with the vmap ones.
I also did some 80char wrapping of lines and docstring updating.
Fixes#912. Another user had the same issue and reported the same
expected behavior.
Specifically we:
1. remove the need for split_axis rules in batching.py, and instead just
rely on collective rules (namely to handle vectorizing over a single
named axis even if the collective is applied over multiple named axes)
2. simplify BatchTrace.process_primitive so that we don't pass tracers
into rules and rely on a subtle recursion
This change breaks all_to_all when used with multiple axis names, and in
particular it breaks all_to_all given the current gmap/xmap lowering
strategy of substituting multiple axis names in place of single axis
names. We believe we can replicate the previous logic with the new rule
organization, but we're leaving that for follow-up work because it's
tricky, and because we might end up changing lowering strategies not to
require axis substitution in the same way.
Because we now have a facade around the lax library, we can expose the lax_linalg primitives directly in lax without creating circular dependency problems.
Leave a few forwarding stubs to be removed later.
PiperOrigin-RevId: 340658800
Changes:
- Fix unnecessary generator
- Iterate dictionary directly instead of calling .keys()
- Remove global statement at the module level
- Use list() instead of a list comprehension
- Use with statement to open the file
- Merge isinstance calls
This allows executing collectives over the gmapped axes. This requires
some extra manipulation of the gmapped jaxpr, since gmap exposes a
single logical axis name, but evaluates the program using multiple
"physical" axes.
This also fixes some bugs around handling `multiple_returns` in
vmap collective implementation.
This adds support for the basic (associative and commutative)
collectives to vmap. Supporting more complex collectives will
require some more complicated rules. Also, at the moment it is not
possible to use collectives inside `custom_vjp` rules which we might
want to fix in the future.
This feature is also omnistaging-only.
Co-authored-by: Matthew Johnson <mattjj@google.com>