``custom_implicit_solve`` is a helper function designed help library authors
(most notably for ``jax.scipy.optimize``) define derivatives for functions that
perform an implicit solve.
There are two real changes here:
1. In api.py, improve the handling of the axis environment in
`xla_computation` so that `xla_computation(pmap(lambda x: x))(x)` works,
by checking for pmap's included in the jaxpr to be staged out (analogous
to how jit-of-pmap works).
2. In pxla.py, handle as a special case the pmapping of computations for
which the output does not depend on the input. The purpose here is to
enable `xla_computation(pmap(lambda x: x))(x)` when `x = np.arange(8)`
yet only one XLA device is available. Evaluating that expression leads
to the (partial) evaluation of a trivial pmap (unit / empty-tuple inputs and
outputs), which would cause an error when we attempt to compile an XLA
computation for more replicas than available hardware devices. We don't
know the computation is trivial until after we've run the function, i.e.
until we're in the xla_pmap impl, so this is the right place to do it.
The other changes are unrelated miscellania.
This change is essentially de-duplicating the XLA lowering logic between
xla.py and pxla.py. Only the latter was capable of handling collectives
(aka pmap primitives), which meant that these didn't work:
1. some compositions of jit and pmap, like jit-of-pmap
2. collectives inside initial-style control flow like scan
3. jax.xla_computation on a function involving collectives
By merging the logic into xla.py, now all the lowering machinery works
with everything. Woo!
The pxla.py file still exists and contains mostly dynamic/runtime
components for pmap and functions used only by pmap and collectives
translations. In particular, pxla.py has
* the pmap impl, particularly the dispatching logic for top-level pmaps,
including argument sharding and lazy sharded result persistence
* the ShardedDeviceArray / ShardedDeviceTuple classes
* the dynamic (trace-time) axis environment data structures and logic
and the special axis_index primitive
* the split-axis transformation for soft_pmap
* the PmapPrimitive (just a tagged version of Primitive)
* the static sharding/unsharding logic for pmap-inside-jit/pmap
These things moved over to xla.py
* the logic for lowering pmap primitives, especially the static axis
environment used during xla lowering
This change refactors the translation rule tables a bit. Instead of just
having one table, there are now four, and they contain rules with
slightly different type signatures:
* the `translations` table has rules with the same signatures as always,
i.e. `CompBuilder -> [XlaOperands] -> ParamsDict -> XlaOperandOut`
* the `backend_specific_translations` table is keyed by platform name
strings and has dict values that each have the same type as `translations`
* the `parallel_translations` table is used for primitives modeling
parallel collectives, and so it has rules with signature
`CompBuilder -> [XlaOperands] -> ReplicaGroups -> ParamsDict -> XlaOpOut`
* the `initial_style_translations` table is for the initial-style
control flow primitives (like `scan`), for which the translation rules
themselves lower jaxprs to XLA computations and thus require the static axis
env to be passed in; the rules there have signature
`CompBuilder -> AxisEnv -> [XlaOperands] -> ParamsDict -> XlaOpOut`
* the `call_translations` table is sued for `xla_call` and `xla_pmap`,
i.e. the primitives underlying `jit` and `pmap` respectively, and has
rules with signature
`CompBuilder -> Jaxpr -> AxisEnv -> [XlaOp] -> [XlaOp] -> ParamsDict -> XlaOp`
Having these as separate tables is an uninteresting implementation
detail. The lowering function `_jaxpr_computation` just does a case analysis
on whether the primitive being translated has an entry in any table
(where the `backend_specific_translations` table must be checked before
the `translations` table, since some primitives may be entered in both).
This change fixes#804 also addresses #852, in that the lax control flow
impls for those primitives are now based on Python-level jaxpr
interpreters rather than XLA compilation, but we should probably wait to
close the latter issue until we benchmark and improve things more. This
change at least seems not to be a performance regression: on my machine
the lax control flow tests go from running in ~20s to running in ~14s.
This change also adds a docstring for `jax.xla_computation` and some
basic tests.
This implementation copies tensors to device without building an XLA computation. XLA compilation may take time superlinear in the number of arguments, but there's no good reason for us to build a computation at all, it was merely a convenient way to implement `device_put` for free. Instead, when the argument to `device_put` isn't a tracer, call `xla.device_put` directly. If it is a tracer, fall back to the old implementation.
Timing for benchmark in #947:
In [4]: %time x = jax.device_put([onp.random.randn(10,5) for _ in range(700)])
CPU times: user 45.3 ms, sys: 3.79 ms, total: 49.1 ms
Wall time: 33.8 ms
where the timing was previously 43.4s.
Fixes#947.
The serial_pmap transformation was a placeholder and is now replaced by
soft_pmap. The papply tests that used serial_pmap now use soft_pmap,
which means they can run on parallel hardware when available.
The papply transform had some unused features (e.g. in_axes, out_axes)
that won't be needed by parallelize, so those are removed. It is also
now only needed for testing now, since parallelize (which essentially
composes a soft_pmap with a papply) is likely to be the primary
user-facing API.
This commit adds the parallelize transformation and some tests for it,
including exhaustive transpose tests.
Misc changes:
* simplified the transpose papply rule and made it lazy (so that it
doesn't need to perform communication)
* misc bugs encountered
* a few lines cherry-picked from frostig@ branch, namely the fixed
broadcasting_papply rule and plumbing the `size` argument to papply
rules
* remove psplit primitive and psplit_like primitives and replace it with
calls to all_to_all where needed