vmap tries hard to give nice error messages when the mapped axes
for different arguments have different sizes, but the code to
compute the error message can run into InconsistentDimensionOperation
in presence of dimension polynomials. Ensure that the comparisons
are done symbolically.
Pass `None` as the resource_env via `jit` because `jit(pjit)` will ignore the outer mesh because `jit` will set the resource env to empty mesh.
This does not make `jit` and `pjit` the same API but it shares all the code between both the APIs (cpp and python) while preserving the current semantics of both `jit` and `pjit`.
PiperOrigin-RevId: 501707496
Add an "explicit_global_axis_size" arg. `global_axis` used to be set to `None`
when the user did not provide an explicit axis size. After this change,
`global_axis` should never be set to `None` internally, and always contain the
size of the global axis. It's still useful to thread the information that the
user has provided an explicit axis size so we can throw explicit errors in
`pxla` when explicit axis sizes are not allowed.
Why do we need to do this? We only go down the lowering path when calling
`pmap`s impl rule (while executing or final-style transforming), but not when
initial-style transforming. The global_axis size should be computed earlier,
such that it is available for initial-style transformations/primitives, e.g. if
we round-trip a multi-host pmap computation through make_jaxpr and eval_jaxpr.
We have tests for "initial-style transform of a `pmap`", but no such test for
_multi-host_ `pmap`! Alors, this bug went unnoticed.
#13545 makes `checkify` initial-style, and because `checkify-of-pmap` is a
valid way to check a `pmap`, an internal multi-host test uncovered this bug.
PiperOrigin-RevId: 499877003
... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782
In presence of dimension variables we conservatively do not drop unused inputs
because we may drop the only inputs from whose shape we can infer the
values of the dimension variables.
See b/261971607.
Also look at the shardings inside the jaxpr for `sharding_constraint_p` and `pjit_p` primitives since with `jax.Array`, each `with_sharding_constraint`/`pjit` inside a computation can contain a different sharding (so we need to check if the device_assignment is the same).
Also the output is `committed` if there are jaxpr shardings inside the computation via `with_sharding_constraint`/`pjit` or if any of the inputs are committed or `output_sharding` is specified.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 480256796