rocm_jax/jax/interpreters
George Necula 27f5bd057c Improves handling of opaque types for dynamic shapes
The immediate motivation is to support lowering to StableHLO
for programs with polymorphic shapes, which requires mixing
dynamic shapes with opaque types.

The general strategy is to push the actual selection of the MHLO ops
down into the mlir module (e.g., mlir.slice_op, mlir.broadcast_in_dim)
so that there is one place where we pick between the dynamic
and the static ops. These routines can also handle opaque types.
Handling an opaque type results in a recursive
call to, e.g., mlir.slice_op, but the inner call uses
the physical avals, which should no longer be opaque.
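
The strategy above can be illustrated with a toy sketch. This is not
the code in this change: Aval, opaque_trailing, physical_aval and the
returned op names are hypothetical stand-ins, and the helper returns a
tuple instead of building a real MLIR op. It only shows the shape of
the idea: one helper picks the static or dynamic variant, and for an
opaque element type it recurses once on the physical aval.

```python
from dataclasses import dataclass
from typing import Tuple, Union

# A dimension is either a concrete int or a str naming a symbolic
# (dynamic) size. This encoding is an assumption made for the sketch.
Dim = Union[int, str]


@dataclass(frozen=True)
class Aval:
    """Hypothetical stand-in for an abstract value."""
    shape: Tuple[Dim, ...]
    # Trailing dims that an opaque element type expands to in its
    # physical representation; () means the element type is not opaque.
    opaque_trailing: Tuple[int, ...] = ()


def is_static(dims: Tuple[Dim, ...]) -> bool:
    return all(isinstance(d, int) for d in dims)


def physical_aval(aval: Aval) -> Aval:
    # The physical aval appends the trailing dims and is never opaque.
    return Aval(aval.shape + aval.opaque_trailing)


def slice_op(aval: Aval, start: Tuple[Dim, ...], limit: Tuple[Dim, ...]):
    """Single place that chooses between the static and dynamic op."""
    if aval.opaque_trailing:
        trailing = aval.opaque_trailing
        # Take the full extent of the trailing dims, then recurse on
        # the physical aval; the inner call sees no opaque type.
        return slice_op(physical_aval(aval),
                        start + (0,) * len(trailing),
                        limit + trailing)
    if is_static(aval.shape) and is_static(start) and is_static(limit):
        return ("static_slice", start, limit)
    return ("dynamic_slice", start, limit)
```

Callers never choose the op themselves; for example,
slice_op(Aval(("b",), (2,)), (0,), ("b",)) goes through the opaque
branch first and then selects the dynamic variant in the inner call.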

While making this change I was confused by the fact that the
custom KeyTyRules in prng.py have lowerings that return multiple
MHLO ops (see https://github.com/google/jax/pull/11768#issuecomment-1342349102),
so I changed the rules to return a single op.

2022-12-12 05:19:04 +01:00