This work is an effort to reduce cyclic dependencies in JAX internals.
Move the _global_to_local and _local_to_global methods out of Mesh and into pxla as free functions. This removes the need for jax._src.mesh to depend on things like avals.
PiperOrigin-RevId: 515667671
jax2tf already supports many cases of shape polymorphism, e.g., those
where the shapes of all intermediates can be expressed as polynomials
in the dimension variables in the input. We want to achieve the
same coverage, or more, while using StableHLO as the lowering format,
rather than tf.Graph.
For native serialization we will support two lowering implementations:
* one is using the growing support in JAX for dynamic shapes,
of which shape polymorphism is a special case.
This implementation is enabled with the --jax_dynamic_shapes flag.
At the moment, the JAX dynamic shapes support is still
incomplete and over 300 jax2tf shape polymorphism tests fail.
* a new one (added here) in which we form a Jaxpr using abstract
values that express dimension sizes as dimension polynomials
(as for the standard jax2tf). Then we lower the Jaxpr to StableHLO.
This implementation is enabled when --jax_dynamic_shapes is off.
With this implementation only 50 jax2tf tests fail (to be fixed
separately).
The key contribution here is to enable lowering a Jaxpr that contains
dimension polynomials in some of the intermediate shapes. Many lowering rules
already have some partial support for Jaxprs where the shapes contain
`Var`s. To the extent possible, we try to write lowering rules that should
cover both cases of dynamic shapes: Var or polynomials in shapes.
The lowering convention is that at top level we collect the sorted list
of dimension variable names in the inputs, and we store it in ModuleContext.dim_vars.
All IR functions will take N additional prefix arguments of int32 type
containing the values of the dimension variables. This is stored as
a list of `ir.Value` in `LoweringContext.dim_var_values`.
Note that the Jaxprs are not changed to have extra Vars for the dimension
variable values. An alternative implementation could work by transforming
the Jaxpr to replace dimension polynomials with Vars.
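
The convention above can be illustrated with a toy model in plain Python. This is only a sketch under stated assumptions: the dict-based polynomial encoding and the helper names (`collect_dim_vars`, `eval_dim`, `eval_shape`) are hypothetical and are not the JAX implementation; real lowering emits IR operations on `ir.Value`s rather than computing Python integers.

```python
import math

# Hypothetical toy encoding (not JAX's): a dimension is either an int, or a
# polynomial given as {monomial: coefficient}, where a monomial is a tuple of
# dimension variable names. Example: 2*b + 1 is {("b",): 2, (): 1}.

def collect_dim_vars(arg_shapes):
    """Return the sorted list of dimension variable names in the input shapes."""
    names = set()
    for shape in arg_shapes:
        for dim in shape:
            if isinstance(dim, dict):
                for monomial in dim:
                    names.update(monomial)
    return sorted(names)

def eval_dim(dim, dim_vars, dim_var_values):
    """Evaluate one dimension given the values of the dimension variables."""
    if isinstance(dim, int):
        return dim
    env = dict(zip(dim_vars, dim_var_values))
    return sum(coeff * math.prod(env[v] for v in monomial)
               for monomial, coeff in dim.items())

def eval_shape(shape, dim_vars, dim_var_values):
    """Evaluate every dimension of a shape."""
    return tuple(eval_dim(d, dim_vars, dim_var_values) for d in shape)

# Input shapes (w, b) and (b, 3): the dimension variables are b and w.
arg_shapes = [({("w",): 1}, {("b",): 1}), ({("b",): 1}, 3)]
dim_vars = collect_dim_vars(arg_shapes)

# An intermediate shape (2*b + 1, b*w, 3), evaluated with b=4, w=5. The values
# are passed in the same sorted order as dim_vars, mirroring the prefix
# arguments described above.
intermediate = ({("b",): 2, (): 1}, {("b", "w"): 1}, 3)
result = eval_shape(intermediate, dim_vars, [4, 5])
```

Keeping the variable names sorted gives every function a fixed, predictable order for the extra prefix arguments.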
The key code pattern used in the lowering rule is::

    if not core.is_constant_shape(shape):  # Handles both Var, and polynomials
      shape = mlir.eval_dynamic_shape(ctx, shape)
      return mhlo.DynamicXXX(..., shape)
    else:
      return mhlo.XXX(..., shape)

with `mlir.eval_dynamic_shape` handling both cases::

    def eval_dynamic_shape(ctx, shape):
      if config.jax_dynamic_shapes:
        # Using Var
        return ... subst using ctx.axis_size_env ...
      else:
        # Using polynomials
        return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values

In order to support the above some lowering functions need to take a
LoweringContext parameter, e.g., mlir.broadcast_mhlo.
I expect that the changes here will improve the --jax_dynamic_shapes coverage
as well.
Previously we checked for out axes being a superset of the defined axes,
but that's just not the right relation. In particular, out_axes of {'a'}
are not a superset of defined axes {'b'}, but axis 'a' is undefined. The
correct check is to verify emptiness of their difference.
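
A minimal sketch of the two checks using plain Python sets; the function names are hypothetical and are not taken from the JAX code:

```python
def old_check_flags_error(out_axes, defined_axes):
    # Previous relation: flag an error only when out_axes is a strict
    # superset of the defined axes.
    return set(out_axes) > set(defined_axes)

def undefined_axes(out_axes, defined_axes):
    # Corrected relation: any axis left in the difference is undefined.
    return set(out_axes) - set(defined_axes)

out_axes = {"a"}
defined_axes = {"b"}

missed_by_old_check = not old_check_flags_error(out_axes, defined_axes)
leftover = undefined_axes(out_axes, defined_axes)
```

Here the superset test does not fire, yet the difference is non-empty and names the undefined axis `'a'`.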
1. Add `device_buffer` and `device_buffers` fields to Array as a backwards compatible change for DA and SDA.
2. Support PartitionSpecs as input to in_axis_resources and out_axis_resources when jax_array is enabled as a backwards compatible change since all user code uses this currently. Create a MeshPspecSharding internally.
3. Some test changes to make the tests pass.
PiperOrigin-RevId: 474642889
I have added comments to places to explain things.
Dependence on MeshPspecSharding in Partial eval has been removed. It now depends on OpShardingSharding.
TODO: Fix the round trip through MeshPspecSharding in vmap batching handlers.
PiperOrigin-RevId: 465621165
* Split `lower_mesh_computation` into `lower_mesh_computation` and `lower_sharding_computation`. This is because `lower_mesh_computation` handles 3 paths: `spmd lowering path`, `non-spmd lowering path` and `xmap spmd lowering path`. I didn't want to add a 4th path to it for general shardings.
* `lower_sharding_computation` works in SPMD mode since it's only used in pjit. The majority of the logic is the same. The only difference is that `mesh` does not exist in this function.
* `MeshComputation` is the point where `lower_mesh_computation` and `lower_sharding_computation` merge.
* `AUTO` and `UNSPECIFIED` cannot be used without mesh right now but I have a CL to fix this.
* Rest of the changes are to make all other functions play nicely with sharding instances.
PiperOrigin-RevId: 461260553
Otherwise it can lead to tracer leak errors. I'm not 100% sure how
this works out, because the sublevel counting has changed since I read
it previously. This replicates the changes applied to
DynamicJaxprTrace.process_map since I last looked at it.
* add caching via weakref_lru_cache
* add inst_in argument (needed for fixedpoints for loop primitives, in
follow-up PR), update callers not to over-instantiate inputs (previously I
had used a convention where call primitives would just stage out eqns with
all inputs instantiated, for expedience)
* add ensure_out_unknowns and ensure_out_inst arguments, analogues of
`instantiate` on e.g. partial_eval_jaxpr, jvp_jaxpr, etc (also needed for
fixpoints of loop primitives)
* better dce in remat_partial_eval (e.g. prune unused residuals)