mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00

Currently, we get MLIR verification errors when the inputs and outputs declared to be aliased do not have matching shapes and dtypes. We add a nicer error message that locates the offending inputs and outputs in their corresponding PyTrees. Interestingly, if the output index is out of bounds, there is no MLIR verification error; this seems to be a bug in the StableHLO verification code. In interpreter mode, errors in input_output_aliasing currently surface as a mix of internal assertion errors.
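
As an illustration of the kind of check described above, here is a minimal pure-Python sketch. It is not the code from this change: `Spec`, `flatten_with_path`, and `check_aliases` are hypothetical names, the PyTrees are modeled as plain nested dicts, lists, and tuples, and the alias map is assumed to go from flattened input index to flattened output index.

```python
from typing import NamedTuple


class Spec(NamedTuple):
    """Hypothetical stand-in for an array's shape and dtype."""
    shape: tuple
    dtype: str


def flatten_with_path(tree, prefix=""):
    """Flatten nested dicts/lists/tuples into a list of (path, leaf) pairs."""
    if isinstance(tree, Spec):  # Spec is a tuple, so test for it first
        return [(prefix, tree)]
    if isinstance(tree, dict):
        items = [(f"[{k!r}]", v) for k, v in sorted(tree.items())]
    elif isinstance(tree, (list, tuple)):
        items = [(f"[{i}]", v) for i, v in enumerate(tree)]
    else:
        return [(prefix, tree)]
    flat = []
    for key, subtree in items:
        flat.extend(flatten_with_path(subtree, prefix + key))
    return flat


def check_aliases(inputs, outputs, aliases):
    """Raise ValueError naming the PyTree paths of any invalid alias pair.

    `aliases` maps a flattened input index to a flattened output index
    (an assumption of this sketch).
    """
    flat_in = flatten_with_path(inputs, "inputs")
    flat_out = flatten_with_path(outputs, "outputs")
    for in_idx, out_idx in aliases.items():
        # Bounds are checked explicitly rather than left to a later verifier.
        if not 0 <= in_idx < len(flat_in):
            raise ValueError(
                f"input index {in_idx} is out of bounds "
                f"for {len(flat_in)} flattened inputs")
        if not 0 <= out_idx < len(flat_out):
            raise ValueError(
                f"output index {out_idx} is out of bounds "
                f"for {len(flat_out)} flattened outputs")
        in_path, in_spec = flat_in[in_idx]
        out_path, out_spec = flat_out[out_idx]
        if in_spec != out_spec:
            raise ValueError(
                f"aliased {in_path} has {in_spec} "
                f"but {out_path} has {out_spec}")
```

With `inputs = {"x": Spec((4,), "float32"), "y": Spec((2, 2), "int32")}` and `outputs = [Spec((4,), "float32")]`, the alias map `{0: 0}` passes, while `{1: 0}` raises an error that mentions both `inputs['y']` and `outputs[0]`.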