The advantage (already being realized) is that the batching rules
become much simpler: we just batch along the stacked axis as always,
and when a reduction is about to occur, also mask out the padding
elements, replacing them with the identity element of the reduction.
This commit
- Changes the intended representation of data for piles and the
corresponding BatchTracers.
- Re-defines ConcatAxis as RaggedAxis to represent the metadata.
- Updates `defreducer` to require the identity function (in case
masking is needed), and supplies it everywhere.
- Flushes batching.segment_sum, as it is dead code now.
- Deletes unpack_concat_axes and reassemble_concat_axes, because they
are irrelevant to the padded representation.
... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782
--
06deb73c9be01cedc000efe7b3eb72d68615471a by Matthew Johnson <mattjj@google.com>:
cache initial-style jaxpr transformations
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/9196 from mattjj:issue3847 06deb73c9be01cedc000efe7b3eb72d68615471a
PiperOrigin-RevId: 422604879
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).
The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)
PiperOrigin-RevId: 411565432
Even though `vmap` and `pmap` don't use avals with names, the batching infrastructure
is used to implement xmap and pjit. So while we keep the introduction of names carefully
scoped, forgetting to remove them at the right points leads to extremely confusing errors.
PiperOrigin-RevId: 395423006
This should let us emit good XLA annotations for `xmap(pjit)`. Previously
we might have been overestimating the set of replicated mesh dimensions.
PiperOrigin-RevId: 377259226