... and in map primitives in general (which is why the patch touches
most traces).
This also fixes a bug in the transpose rule for map primitives, which
would fail to adjust the aval associated with zeros returned from the
map body.
In preparation for adding support for `in_axes` and `out_axes` to `pmap`.
The only difference in expressivity of the new approach is that the
sharded dimensions can be permuted before ordering/replicating the
indices to match the device assignment. This is necessary if we want to
support `in_axes`, because it may cause some sharded dimensions that are
supposed to get mapped to the "replication" XLA mesh axis to follow the
dimensions mapped to the "partitioning" XLA mesh axis. XLA fixes the
mesh order such that the replicated dimension is always the leading one,
which forces us to decouple the order of data dimensions from the mesh
dimensions.
This patch additionally folds `is_axis_materialized` into the
sharding specification, by wrapping the integers in small ADT-like
wrappers that distinguish the different ways of partitioning dimensions.
The order of replication is also more explicit in the `mesh_mapping`,
as opposed to being represented as a list of replication factors to be
inserted into the sharding details to obtain a mesh mapping.
Note that this doesn't change any existing functionality. It is purely
an internal rewrite that is supposed to lay the groundwork for the next
patches.
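The sharding specification described above can be sketched roughly as follows. This is only an illustration under stated assumptions: the class names (`NoSharding`, `Chunked`, `Unstacked`, `ShardedAxis`, `Replicated`) and the `mesh_shape` helper are hypothetical and are not taken from the patch text. The point is that `mesh_mapping` may list the sharded data dimensions in a different order than they appear in the data, with replication written out explicitly.

```python
from dataclasses import dataclass

# Hypothetical wrappers describing how each *data* dimension is partitioned.
@dataclass(frozen=True)
class NoSharding:
    """The data dimension is kept whole on every device."""

@dataclass(frozen=True)
class Chunked:
    """The data dimension is split into `chunks` contiguous pieces."""
    chunks: int

@dataclass(frozen=True)
class Unstacked:
    """The data dimension of length `size` is split into single elements."""
    size: int

# Hypothetical wrappers describing each *mesh* dimension.
@dataclass(frozen=True)
class ShardedAxis:
    """Mesh dimension that follows the `axis`-th sharded data dimension."""
    axis: int

@dataclass(frozen=True)
class Replicated:
    """Mesh dimension along which the data is replicated `replicas` times."""
    replicas: int

def mesh_shape(sharding, mesh_mapping):
    """Compute the device mesh shape implied by a sharding specification."""
    # Sizes of the sharded data dimensions, in data order.
    sharded_sizes = []
    for s in sharding:
        if isinstance(s, Chunked):
            sharded_sizes.append(s.chunks)
        elif isinstance(s, Unstacked):
            sharded_sizes.append(s.size)
    # The mesh order is free to permute the sharded dimensions.
    shape = []
    for m in mesh_mapping:
        if isinstance(m, Replicated):
            shape.append(m.replicas)
        else:
            shape.append(sharded_sizes[m.axis])
    return tuple(shape)

sharding = (Unstacked(2), NoSharding(), Chunked(4))
mesh_mapping = (ShardedAxis(1), Replicated(3), ShardedAxis(0))
print(mesh_shape(sharding, mesh_mapping))  # (4, 3, 2)
```

In this sketch the second sharded data dimension comes first in the mesh, which is the kind of permutation that a separate list of replication factors could not express.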
This allows pmapping vmapped computations that use `all_to_all` or
`pswapaxes` inside. It also includes a very slow CPU and GPU translation
rule that might be useful for debugging programs locally, since XLA only
implements the `AllToAll` collective on TPUs.
Fixes #4141.
The previous rules assumed that `split_axis == concat_axis` (i.e. that
the used collective is equivalent to `pswapaxes`). Since we expose this
as part of our API, we should probably make sure that we handle other
cases too.
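To make the difference between the two cases concrete, here is a small NumPy reference for the general chunk-exchange pattern. It is a sketch only: `all_to_all_reference` is a hypothetical helper, the per-device shards are modelled as a plain Python list, and the shape conventions are a simplification rather than the exact ones used by `lax.all_to_all`.

```python
import numpy as np

def all_to_all_reference(shards, split_axis, concat_axis):
    """Exchange chunks between `len(shards)` simulated devices.

    Each shard is split into n chunks along `split_axis`; chunk j is sent
    to device j, which concatenates what it receives along `concat_axis`.
    """
    n = len(shards)
    # pieces[i][j] is the chunk that device i sends to device j.
    pieces = [np.split(s, n, axis=split_axis) for s in shards]
    return [
        np.concatenate([pieces[i][j] for i in range(n)], axis=concat_axis)
        for j in range(n)
    ]

# Two simulated devices, each holding a (2, 3) block.
shards = list(np.arange(12).reshape(2, 2, 3))

same = all_to_all_reference(shards, split_axis=0, concat_axis=0)
diff = all_to_all_reference(shards, split_axis=0, concat_axis=1)
print(same[0].shape)  # (2, 3): the per-device shape is preserved
print(diff[0].shape)  # (1, 6): rows are scattered, columns are gathered
```

When the two axes coincide the per-device shape is unchanged, which is the case the previous rules assumed; when they differ, the shape changes and the rules have to account for that.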
Fixes #1332.
The previous translation rule assumed that `axis_index` is always
taken over the outermost axis in the `axis_env`, and always produced
the same output, no matter which axis was specified. This fixes the
translation rule to take the `axis_name` into account.
Additionally, this adds support for querying the index along multiple
axes, which will be useful for `gmap`.
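The index arithmetic involved can be sketched in plain Python. This is an illustration under assumptions, not the actual translation rule: `axis_env` is modelled as a list of `(name, size)` pairs ordered from outermost to innermost, `flat_index` stands in for the flat device index, and the way multiple axes are linearized (first name is the slowest-varying) is one plausible choice.

```python
from math import prod

def axis_index(flat_index, axis_env, axis_name):
    """Index along one named axis, recovered from a flat device index."""
    names = [name for name, _ in axis_env]
    sizes = [size for _, size in axis_env]
    pos = names.index(axis_name)
    # Number of devices spanned by all axes nested inside `axis_name`.
    inner = prod(sizes[pos + 1:])
    return (flat_index // inner) % sizes[pos]

def multi_axis_index(flat_index, axis_env, axis_names):
    """Linearized index along several axes (assumed row-major order)."""
    size_of = dict(axis_env)
    index = 0
    for name in axis_names:
        index = index * size_of[name] + axis_index(flat_index, axis_env, name)
    return index

env = [("i", 2), ("j", 3)]  # "i" is the outermost axis
print(axis_index(4, env, "i"))               # 1
print(axis_index(4, env, "j"))               # 1
print(multi_axis_index(4, env, ("i", "j")))  # 4
```

A rule that ignores `axis_name` would return the same value for `"i"` and `"j"` on every device, which is the bug described above.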
This is normally unnecessary, because the XLA translation usually
doesn't bind any of the primitives in the jaxpr, but this is not true in
the case of scan! Its translation rule reevaluates the jaxpr as a
function, and if it contains collectives such as `axis_index` it can
fail because the axis is missing.
This adds support for the basic (associative and commutative)
collectives to vmap. Supporting more complex collectives will
require some more complicated rules. Also, at the moment it is not
possible to use collectives inside `custom_vjp` rules, which we might
want to fix in the future.
This feature is also omnistaging-only.
Co-authored-by: Matthew Johnson <mattjj@google.com>
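As an illustration of the vmap support for basic collectives described above, here is a minimal sketch that uses `lax.psum` over a named `vmap` axis. It assumes a JAX version in which `jax.vmap` accepts `axis_name`; the function `normalize` and the axis name `"batch"` are made up for the example.

```python
import jax
import jax.numpy as jnp
from jax import lax

def normalize(x):
    # `x` is a single element of the batch; psum adds up the elements
    # across the whole vmapped axis.
    return x / lax.psum(x, axis_name="batch")

xs = jnp.arange(1.0, 5.0)  # [1., 2., 3., 4.]
out = jax.vmap(normalize, axis_name="batch")(xs)
print(out)  # each element divided by the batch total of 10
```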
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.
See https://github.com/google/jax/pull/3370 for more information.
* Support differentiation through jax.lax.all_to_all
Credit to @levskaya for the solution.
* Test gradient of all_to_all
We are testing all_to_all through pswapaxes, since general all_to_all is problematic according to https://github.com/google/jax/issues/1332.
* Removed trailing spaces
Fixes #3440.
Also re-applies the fix in #3439 (i.e. it rolls back the rollback PR #3448) because we're now confident it's correct (and some internal tests are buggy).
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.
Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.
No functional changes intended.
* Fix a number of lax reference implementations to preserve types.
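The keyword-only convention from the first item can be sketched as follows. This is not the `jax.test_util` implementation: `assert_all_close` is a hypothetical scalar-only helper, and the default tolerances are arbitrary. It only shows how the bare `*` forces callers to spell out `check_dtypes`, `atol` and `rtol`.

```python
import math

def assert_all_close(x, y, *, check_dtypes=True, atol=1e-8, rtol=1e-5):
    """Hypothetical scalar comparison with keyword-only options."""
    # Exact type equality is the default; callers opt out explicitly.
    if check_dtypes and type(x) is not type(y):
        raise AssertionError(f"type mismatch: {type(x)} vs {type(y)}")
    if not math.isclose(x, y, rel_tol=rtol, abs_tol=atol):
        raise AssertionError(f"{x} and {y} differ beyond tolerance")

assert_all_close(1.0, 1.0 + 1e-9)           # passes with the defaults
assert_all_close(1, 1.0, check_dtypes=False)  # opting out by keyword
# assert_all_close(1, 1.0, False) would raise TypeError: the options
# can no longer be passed positionally.
```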
* support replica groups in allreduce collectives
* add test and fix jaxpr in docs
* switch from XLA replica IDs to JAX axis indices
* fix psum transpose rule
* test other nesting order + imperfect nesting
* update jaxpr.rst
* handle None case
* add note+check that groups cover the index space
* switch split_axis assert to NotImplementedError
* update CHANGELOG
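The grouped allreduce behaviour listed above, including the check that the groups cover the index space, can be modelled in plain Python. This is a sketch under assumptions: `grouped_psum` is a hypothetical helper, the values along the mapped axis are given as a list, and groups are lists of JAX-style axis indices rather than XLA replica IDs.

```python
def grouped_psum(values, groups=None):
    """Sum `values` within each group of axis indices.

    `groups=None` models the ungrouped case: one group spanning the axis.
    """
    n = len(values)
    if groups is None:
        groups = [list(range(n))]
    # Every axis index must belong to exactly one group.
    covered = sorted(i for group in groups for i in group)
    if covered != list(range(n)):
        raise ValueError("groups must cover every axis index exactly once")
    out = [None] * n
    for group in groups:
        total = sum(values[i] for i in group)
        for i in group:
            out[i] = total
    return out

print(grouped_psum([1, 2, 3, 4], [[0, 1], [2, 3]]))  # [3, 3, 7, 7]
print(grouped_psum([1, 2, 3, 4]))                    # [10, 10, 10, 10]
```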