jax2tf already supports many cases of shape polymorphism, e.g., those
where the shapes of all intermediates can be expressed as polynomials
in the dimension variables in the input. We want to achieve the same
same coverage, or more, while using StableHLO as the lowering format,
rather than tf.Graph.
For native serialization we will support two lowering implementations:
* one is using the growing support in JAX for dynamic shapes,
of which shape polymorphism is a special case.
This implementation is enabled with the --jax_dynamic_shapes flag.
At the moment, the JAX dynamic shapes support is still
incomplete and over 300 jax2tf shape polymorphism tests fail.
* a new one (added) here in which we form a Jaxpr using abstract
values that express dimension sizes as dimension polynomials
(as for the standard jax2tf). Then we lower the Jaxpr to StableHLO.
This implementation is enabled when --jax_dynamic_shapes is off.
With this implementation only 50 jax2tf tests fail (to be fixed
separately).
The key contribution here is to enable lowering a Jaxpr that contains
dimension polynomials in some of the intermediate shapes. Many lowering rules
already have some partial support for Jaxprs where the shapes contain
`Var`s. To the extent possible, we try to write lowering rules that should
cover both cases of dynamic shapes: Var or polynomials in shapes.
The lowering convention is that at top level we collect the sorted list
of dimension variable names in the inputs, and we store it in ModuleContext.dim_vars.
All IR functions will take N additional prefix arguments of int32 type
containing the values of the dimension variables. This is stored as
a list of `ir.Value` in `LoweringContext.dim_var_values`.
Note that the Jaxprs are not changed to have extra Vars for the dimension
variable values. An alternative implementation could work by transforming
the Jaxpr to replace dimension polynomials into Vars.
The key code pattern used in the lowering rule is::
if not core.is_constant_shape(shape): # Handles both Var, and polynomials
shape = mlir.eval_dynamic_shape(ctx, shape)
return mhlo.DynamicXXX(..., shape)
else:
return mhlo.XXX(..., shape)
with `mlir.eval_dynamic_shape` handling both cases::
def eval_dynamic_shape(ctx, shape):
if config.jax_dynamic_shapes:
# Using Var
return ... subst using ctx.axis_size_env ...
else:
# Using polynomials
return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values
In order to support the above some lowering functions need to take a
LoweringContext parameter, e.g., mlir.broadcast_mhlo.
I expect that the changes here will improve the --jax_dynamic_shapes coverage
as well.
The non-canonical state meant that we were falling back to a more expensive comparison for the first jit-compiled function in the program. I doubt there will be any impact on real benchmarks, but this perturbs the results of running a single microbenchmark in isolation.
PiperOrigin-RevId: 493489154
This function is quite important, since it runs at every JAX primitive bind,
but it included a few redundant conditionals.
PiperOrigin-RevId: 492481837
This function is quite important, since it runs at every JAX primitive bind,
but it included a few redundant conditionals.
PiperOrigin-RevId: 492460102
If e.g. `x : f32[10, n]` then we want to handle Python expressions like `x[0]`.
To do that, we can use a generalized version of `dynamic_slice` which allows
dynamic slice sizes (where the result shape depends on those slice sizes).
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
The main changes here are only indirectly related to gather: we just had to
update some other rules (e.g. for comparison, and squeeze) for a simple
dynamic-batch-shape gather to work.
I also skipped two tests and deleted some old dynamic shape slicing logic
because we want to handle that differently. We didn't have to do that removal
in this PR, but it's just convenient given I'm looking at indexing again.
Now with:
* resetting the `random.PRNGKeyArray` type during Python typechecks
* zeroing JVP rules for random primitives
* temporarily skipping vmap-of-pmap test with keys under `config.jax_array`
PiperOrigin-RevId: 469276609
Specifically:
* Introduce a `physical_avals` view as a custom eltype method. This is
analogous to the existing `aval_to_ir_types`, but where the output
is an aval with a non-custom eltype (and hence a direct
correspondence to TF and to lowerings).
* Change jax2tf to continue tracing with logical avals, but to
maintain TF tensors of corresponding physical shape/dtype, and to
translate to TF operations based on physical avals where relevant.
* Fix up various TF impl rules to follow physical avals. To this end,
add a "physical" mode to jax2tf's `_convert_jax_impl` helper, which
carries out the conversion using physical rather than logical avals.
* Write TF impl rules for `random_{seed,split,fold_in,bits}`
primitives. To this end, factor out the part of these primitives'
impl rules that operates on the base array and convert that, pass it
through `_convert_jax_impl` in physical mode.
* Teach the jax2tf test harness how to unwrap key-array-typed outputs
into physical `uint32` arrays that it can use in comparison tests.