jax2tf already supports many cases of shape polymorphism, e.g., those where the shapes of all intermediates can be expressed as polynomials in the dimension variables in the input. We want to achieve the same coverage, or more, while using StableHLO as the lowering format, rather than tf.Graph.

For native serialization we will support two lowering implementations:

  * one uses the growing support in JAX for dynamic shapes, of which
    shape polymorphism is a special case. This implementation is enabled
    with the --jax_dynamic_shapes flag. At the moment, the JAX dynamic
    shapes support is still incomplete and over 300 jax2tf shape
    polymorphism tests fail.

  * a new one (added here) in which we form a Jaxpr using abstract values
    that express dimension sizes as dimension polynomials (as for the
    standard jax2tf). Then we lower the Jaxpr to StableHLO. This
    implementation is enabled when --jax_dynamic_shapes is off. With this
    implementation only 50 jax2tf tests fail (to be fixed separately).

The key contribution here is to enable lowering a Jaxpr that contains dimension polynomials in some of the intermediate shapes. Many lowering rules already have some partial support for Jaxprs where the shapes contain `Var`s. To the extent possible, we try to write lowering rules that cover both cases of dynamic shapes: `Var`s or polynomials in shapes.

The lowering convention is that at top level we collect the sorted list of dimension variable names in the inputs, and we store it in `ModuleContext.dim_vars`. All IR functions will take N additional prefix arguments of int32 type containing the values of the dimension variables. This is stored as a list of `ir.Value` in `LoweringContext.dim_var_values`. Note that the Jaxprs are not changed to have extra Vars for the dimension variable values. An alternative implementation could work by transforming the Jaxpr to replace dimension polynomials with Vars.

The key code pattern used in the lowering rules is::

    if not core.is_constant_shape(shape):  # Handles both Var and polynomials
       shape = mlir.eval_dynamic_shape(ctx, shape)
       return mhlo.DynamicXXX(..., shape)
    else:
       return mhlo.XXX(..., shape)

with `mlir.eval_dynamic_shape` handling both cases::

    def eval_dynamic_shape(ctx, shape):
       if config.jax_dynamic_shapes:
          # Using Var
          return ... subst using ctx.axis_size_env ...
       else:
          # Using polynomials
          return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values

In order to support the above, some lowering functions need to take a LoweringContext parameter, e.g., mlir.broadcast_mhlo.

I expect that the changes here will improve the --jax_dynamic_shapes coverage as well.
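
To illustrate the polynomial branch of the pattern above, here is a minimal, self-contained sketch of what "substitute the dimension variable values into a shape containing dimension polynomials" means. The `DimPoly` class and `eval_shape` helper are hypothetical toy stand-ins, not the actual jax2tf internals::

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class DimPoly:
      """Toy linear dimension polynomial: coeff * var + const (hypothetical)."""
      var: str
      coeff: int = 1
      const: int = 0

      def evaluate(self, dim_var_values: dict) -> int:
        return self.coeff * dim_var_values[self.var] + self.const

    def eval_shape(shape, dim_var_values):
      """Substitute dimension variable values into a possibly-polynomial shape."""
      return tuple(d.evaluate(dim_var_values) if isinstance(d, DimPoly) else d
                   for d in shape)

    # A shape (b, 2*b + 1, 4) evaluated with b = 3 gives (3, 7, 4).
    shape = (DimPoly("b"), DimPoly("b", coeff=2, const=1), 4)
    print(eval_shape(shape, {"b": 3}))  # (3, 7, 4)

In the real lowering the substitution produces `ir.Value`s (built from the prefix dimension arguments in `LoweringContext.dim_var_values`) rather than Python ints, but the shape of the computation is the same.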
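
For context, a minimal end-to-end usage sketch of the existing jax2tf shape-polymorphism API (standard usage, not code added by this change): the batch dimension is left symbolic via `polymorphic_shapes`, so the lowering must handle a non-constant dimension as described above::

    import jax.numpy as jnp
    import tensorflow as tf
    from jax.experimental import jax2tf

    def f(x):
      # Input shape (b, 4); output shape (b,).
      return jnp.sum(x * x, axis=1)

    # "b" is a dimension variable; the second dimension is fixed at 4.
    f_tf = tf.function(
        jax2tf.convert(f, polymorphic_shapes=["(b, 4)"]),
        input_signature=[tf.TensorSpec([None, 4], tf.float32)],
        autograph=False)

    print(f_tf(tf.ones((3, 4))))  # shape (3,)
    print(f_tf(tf.ones((7, 4))))  # shape (7,)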