There are currently two parameters that are used to configure
lowering: lowering_platform (for cross-platform lowering), and
override_lowering_rules. Each of them is passed as a separate argument
through several layers of internal lowering functions. This is tedious
and error-prone. In fact, override_lowering_rules was not plumbed
through everywhere, and because every layer uses default arguments,
the omission fails silently.
We foresee introducing other parameters for lowering: for multi-platform
lowering, for controlling the lowering of effects.
Here we pack all such parameters into a `mlir.LoweringParameters`
dataclass and plumb that through.
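A minimal sketch of the idea, not the actual JAX code: the two fields mirror the parameters named above, while `lower_function`, `_lower_body`, and the shape of the rule table are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Sequence, Tuple


@dataclass(frozen=True)
class LoweringParameters:
    # Hypothetical stand-in for mlir.LoweringParameters.
    lowering_platform: Optional[str] = None
    override_lowering_rules: Optional[Sequence[Tuple[str, Callable]]] = None


def _lower_body(prim: str, params: LoweringParameters) -> str:
    # Innermost layer: consult the overrides first, then fall back to a default.
    for name, rule in params.override_lowering_rules or ():
        if name == prim:
            return rule(params.lowering_platform)
    return f"default:{prim}:{params.lowering_platform}"


def lower_function(prim: str, params: LoweringParameters) -> str:
    # Intermediate layer: forwards a single object. There is no default value,
    # so forgetting to pass it raises TypeError instead of failing silently.
    return _lower_body(prim, params)
```

Because `params` is a required positional argument at every layer, a layer that forgets to forward it fails loudly at call time.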
This change is in preparation for deprecating the XlaBuilder APIs for building non-MLIR HLO. In general JAX would be best served by adding a more user-friendly "custom kernel" API that doesn't require the user to build IR directly, but for the moment the best we can do is migrate users to use MLIR/StableHLO utilities instead of classic HLO utilities.
Since most users of custom kernels probably want to build a custom-call we can get most of the benefit by providing an ergonomic helper function for building the IR for custom calls that can be called by external primitive lowering rules.
This function has two benefits over just building the stablehlo directly:
a) it is a JAX API, and we can be more confident the API won't change because of upstream MLIR changes
b) the Python API to build stablehlo.custom_call generated by the bindings isn't that easy to use (e.g. it doesn't have sensible defaults).
The next step will be to deprecate XlaBuilder and encourage users to switch to lowering rules using this helper.
PiperOrigin-RevId: 561042402
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.
PiperOrigin-RevId: 523146076
Previously for native serialization we could only support polymorphic_shapes
where the specification was a simple dimension variable. E.g., we could not
handle a specification where `polymorphic_shapes="2*b"` because there was
no way to recover the value of `b` from the actual shape. (For non-native
serialization we were supporting some limited equation solving.)
The above is important, e.g., for the gradient of functions like
`jnp.concatenate([x, x])`, where the output shape is `2 * b`.
This is possible because in #15258 we have brought the computation
of the dimension variables into jax_export.
What we do here is to even out the support for native serialization to have
the same power as the non-native one. We do this by reusing the
same `shape_poly.prepare_dim_var_env` that we use for non-native
serialization.
After we land this, we will refactor the shape environment to be cleaner.
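To illustrate the kind of equation solving involved, here is a toy solver for specifications of the form `coeff * b + offset`. It is only a sketch: `solve_linear_dim` is a hypothetical helper, not how `shape_poly.prepare_dim_var_env` is implemented, and it assumes dimension variables take values of at least 1.

```python
def solve_linear_dim(coeff: int, offset: int, actual: int) -> int:
    """Solve coeff * b + offset == actual for an integer b >= 1."""
    if coeff == 0:
        raise ValueError("specification does not mention the dimension variable")
    quotient, remainder = divmod(actual - offset, coeff)
    if remainder != 0 or quotient < 1:
        raise ValueError(
            f"no valid dimension value: {coeff}*b + {offset} == {actual}")
    return quotient
```

For `polymorphic_shapes="2*b"` and an actual dimension of 10, `solve_linear_dim(2, 0, 10)` recovers `b = 5`; an actual dimension of 7 is rejected because it is not a multiple of 2.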
Generate DynamicPadOp instead of PadOp when the padding
sizes are not constant.
Fix the generation of RealDynamicSliceOp.
Exclude some tests that fail due to unimplemented support
for custom calls with polymorphic shapes.
... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.
Unchanged occurrences:
1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
argument value in Lowering.as_text and Lowering.compiler_ir.
2) Documentation (changelog, JEPs, IR examples, etc).
3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
so both are necessary to disambiguate.
PiperOrigin-RevId: 495771153
The immediate motivation for this is to support the lowering
to StableHLO for programs with polymorphic shapes. This requires
mixing of dynamic shapes with opaque types.
The general strategy is to push the actual selection of the MHLO ops
down into the mlir module (e.g., mlir.slice_op, mlir.broadcast_in_dim)
so that we have one place where we pick whether we use the Dynamic
or static ops. These routines can also handle the opaque type.
This will result in a recursive
call to, e.g., mlir.slice_op, but the inner call will be using
the physical avals, which should not be opaque anymore.
While making this change I was confused by the fact that the
custom KeyTyRules in prng.py have lowerings that return multiple
MHLO ops. See https://github.com/google/jax/pull/11768#issuecomment-1342349102.
I changed the rules to return a single op.
When working with George on https://github.com/google/jax/pull/13427, I discovered that modules with verifier errors can happily cross API boundaries and create confusion downstream.
As discussed, this is unintentional - the expectation was that `ctx.module.operation.verify()` will throw an exception when verification fails. This CL addresses that and throws an exception accordingly.
Not sure how to test this, given that passing a module with verifier errors to module_to_string indicates a logic error (i.e. such module shouldn't have been produced by JAX in the first place). As a result, I didn't write any tests, but I'm happy to write them if there's a good way to do that.
PiperOrigin-RevId: 493940591
jax2tf already supports many cases of shape polymorphism, e.g., those
where the shapes of all intermediates can be expressed as polynomials
in the dimension variables in the input. We want to achieve the
same coverage, or more, while using StableHLO as the lowering format,
rather than tf.Graph.
For native serialization we will support two lowering implementations:
* one is using the growing support in JAX for dynamic shapes,
of which shape polymorphism is a special case.
This implementation is enabled with the --jax_dynamic_shapes flag.
At the moment, the JAX dynamic shapes support is still
incomplete and over 300 jax2tf shape polymorphism tests fail.
* a new one (added here) in which we form a Jaxpr using abstract
values that express dimension sizes as dimension polynomials
(as for the standard jax2tf). Then we lower the Jaxpr to StableHLO.
This implementation is enabled when --jax_dynamic_shapes is off.
With this implementation only 50 jax2tf tests fail (to be fixed
separately).
The key contribution here is to enable lowering a Jaxpr that contains
dimension polynomials in some of the intermediate shapes. Many lowering rules
already have some partial support for Jaxprs where the shapes contain
`Var`s. To the extent possible, we try to write lowering rules that should
cover both cases of dynamic shapes: Var or polynomials in shapes.
The lowering convention is that at top level we collect the sorted list
of dimension variable names in the inputs, and we store it in ModuleContext.dim_vars.
All IR functions will take N additional prefix arguments of int32 type
containing the values of the dimension variables. These values are stored as
a list of `ir.Value` in `LoweringContext.dim_var_values`.
Note that the Jaxprs are not changed to have extra Vars for the dimension
variable values. An alternative implementation could work by transforming
the Jaxpr to replace dimension polynomials into Vars.
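The convention above can be mimicked in plain Python. This is a toy model, not the JAX implementation: shapes are tuples whose symbolic dimensions are strings, `collect_dim_vars` and `eval_shape` are hypothetical helpers, and `eval` with empty builtins stands in for real polynomial evaluation.

```python
import re
from typing import List, Sequence, Tuple, Union

Dim = Union[int, str]  # a str dimension is an expression such as "b" or "2*b"


def collect_dim_vars(shapes: Sequence[Sequence[Dim]]) -> List[str]:
    """Return the sorted list of variable names used in the input shapes."""
    names = set()
    for shape in shapes:
        for dim in shape:
            if isinstance(dim, str):
                names.update(re.findall(r"[A-Za-z_]\w*", dim))
    return sorted(names)


def eval_shape(shape: Sequence[Dim], dim_vars: Sequence[str],
               dim_var_values: Sequence[int]) -> Tuple[int, ...]:
    """Evaluate symbolic dimensions using the prefix-argument values."""
    env = dict(zip(dim_vars, dim_var_values))
    # Toy evaluation only: builtins are disabled and env holds just the variables.
    return tuple(
        dim if isinstance(dim, int) else eval(dim, {"__builtins__": {}}, env)
        for dim in shape)
```

With input shapes `("b", 3)` and `("2*b", "a")`, the collected names are `["a", "b"]`, so every function in this model would receive the value of `a` first and the value of `b` second.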
The key code pattern used in the lowering rule is::
    if not core.is_constant_shape(shape):  # Handles both Var, and polynomials
      shape = mlir.eval_dynamic_shape(ctx, shape)
      return mhlo.DynamicXXX(..., shape)
    else:
      return mhlo.XXX(..., shape)
with `mlir.eval_dynamic_shape` handling both cases::
    def eval_dynamic_shape(ctx, shape):
      if config.jax_dynamic_shapes:
        # Using Var
        return ... subst using ctx.axis_size_env ...
      else:
        # Using polynomials
        return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values
In order to support the above some lowering functions need to take a
LoweringContext parameter, e.g., mlir.broadcast_mhlo.
I expect that the changes here will improve the --jax_dynamic_shapes coverage
as well.
See tests/api_test.py for usage examples.
At the moment, stablehlo() works by using the hlo-legalize-to-stablehlo pass, which takes MHLO natively produced by JAX and converts it into StableHLO. This is an intermediate step towards switching JAX to natively produce StableHLO.
This CL adds both mhlo_to_stablehlo and stablehlo_to_mhlo to jaxlib, even though only the former is used at the moment. This is done in anticipation of switching JAX to natively produce StableHLO, where stablehlo_to_mhlo will be needed to provide backward compatibility for XlaLowering::mhlo(). We're adding stablehlo_to_mhlo now, so that in the future we don't have to update jaxlib again which will make deployment easier.
PiperOrigin-RevId: 487144342
Fix a bug in PJRT where if a buffer was not owned (e.g., it aliased a NumPy buffer) it could still be donated and that would lead to a use after free.
PiperOrigin-RevId: 484001545
https://peps.python.org/pep-0657/ means that we now have richer context information, which we can propagate where we use it, for example to the MHLO location in this example:
```
In [2]: jax.jit(lambda x: x + 2).lower(7).compiler_ir().operation.print(enable_debug_info=True)
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
module @jit__lambda_ {
  func.func public @main(%arg0: tensor<i32> loc(unknown)) -> tensor<i32> {
    %0 = mhlo.constant dense<2> : tensor<i32> loc(#loc0)
    %1 = mhlo.add %arg0, %0 : tensor<i32> loc(#loc1)
    return %1 : tensor<i32> loc(#loc0)
  } loc(#loc0)
} loc(#loc0)
#loc1 = loc("jit(<lambda>)/jit(main)/add"("<ipython-input-2-525e569b8960>":1:18))
```
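For reference, the column information that PEP 657 adds can be read from a code object with `co_positions()` (Python 3.11 and later). The sketch below is illustrative and is not the code JAX uses; `current_position` is a hypothetical helper that falls back to the line number alone on older interpreters.

```python
import itertools
import sys
from typing import Optional, Tuple


def current_position(frame) -> Tuple[int, Optional[int]]:
    """Return (line, start_column) for the instruction a frame is executing."""
    code = frame.f_code
    if not hasattr(code, "co_positions"):
        # Before Python 3.11 only the line number is available.
        return frame.f_lineno, None
    # co_positions() yields one (line, end_line, col, end_col) tuple per
    # 2-byte code unit, and f_lasti is a byte offset into the bytecode.
    index = frame.f_lasti // 2
    line, _end_line, col, _end_col = next(
        itertools.islice(code.co_positions(), index, None))
    return (line if line is not None else frame.f_lineno), col


def where_am_i() -> Tuple[int, Optional[int]]:
    # Report the position of this call site inside where_am_i itself.
    return current_position(sys._getframe())
```

The column may still be `None` on 3.11+ when position tracking is disabled, which is why the return type keeps it optional.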
It does not make sense to pass how an input is partitioned to ShardingContext because you can have `n` inputs all partitioned in a different way but all of them should have the same device_assignment. This follows SPMDAxisContext too.
PiperOrigin-RevId: 474808207