The immediate motivation for this change is to support lowering to StableHLO for programs with polymorphic shapes, which requires mixing dynamic shapes with opaque types.

The general strategy is to push the actual selection of MHLO ops down into the mlir module (e.g., mlir.slice_op, mlir.broadcast_in_dim), so that there is a single place where we choose between the dynamic and the static ops. These routines can also handle opaque types. This results in a recursive call to, e.g., mlir.slice_op, but the inner call uses the physical avals, which should no longer be opaque.

While making this change, I was confused by the fact that the custom KeyTyRules in prng.py have lowerings that return multiple MHLO ops (see https://github.com/google/jax/pull/11768#issuecomment-1342349102). I changed the rules to return a single op.
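The dispatch pattern described above can be illustrated with a minimal, self-contained sketch. It does not use JAX and is not the code from this change: Aval, OpaqueDType, physical_aval and this slice_op are hypothetical stand-ins, and the op names are returned as plain string labels instead of building real MHLO ops. It shows one function that (1) rewrites an opaque aval into its physical aval and recurses, and (2) picks the static or dynamic op depending on whether every dimension is known.

```python
from dataclasses import dataclass
from typing import Tuple, Union

# Hypothetical: a dimension is either a static int or a symbolic name (str).
Dim = Union[int, str]


@dataclass(frozen=True)
class OpaqueDType:
    """Hypothetical opaque dtype (e.g. a PRNG key) backed by a physical array."""
    name: str
    physical_trailing: Tuple[int, ...]  # extra trailing dims of the physical repr
    physical_dtype: str


@dataclass(frozen=True)
class Aval:
    shape: Tuple[Dim, ...]
    dtype: Union[str, OpaqueDType]


def physical_aval(aval: Aval) -> Aval:
    """Map an opaque aval to the (non-opaque) aval of its physical representation."""
    assert isinstance(aval.dtype, OpaqueDType)
    return Aval(aval.shape + aval.dtype.physical_trailing, aval.dtype.physical_dtype)


def is_static(dims) -> bool:
    return all(isinstance(d, int) for d in dims)


def slice_op(aval_out: Aval, start, limit):
    """Single place that chooses static vs. dynamic slicing (hypothetical sketch)."""
    if isinstance(aval_out.dtype, OpaqueDType):
        # Opaque dtype: slice the physical array, keeping the trailing dims whole.
        # The recursive call sees a physical aval, so it takes the branch below.
        trailing = aval_out.dtype.physical_trailing
        return slice_op(physical_aval(aval_out),
                        tuple(start) + (0,) * len(trailing),
                        tuple(limit) + trailing)
    static = all(is_static(s) for s in (aval_out.shape, start, limit))
    # String labels only; a real lowering would construct the MLIR op here.
    op = "mhlo.slice" if static else "mhlo.real_dynamic_slice"
    return op, aval_out, tuple(start), tuple(limit)
```

For example, slicing elements 1..3 of an array of hypothetical keys whose physical form adds a trailing dimension of 2 lowers to a static slice over a (2, 2) uint32 array, while a slice whose bound is the symbolic dimension "n" selects the dynamic op. Keeping both decisions in one helper means that callers never need to know about dynamic shapes or opaque dtypes.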