mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-19 00:46:45 +00:00

This PR fixes a bug in a canonicalization pattern for `shape.shape_of` (shape of reshape).

```
// Before
func.func @f(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> {
  %reshape = tensor.reshape %arg0(%arg1) : (tensor<?x1xf32>, tensor<3xi32>) -> tensor<?x1x1xf32>
  %0 = shape.shape_of %reshape : tensor<?x1x1xf32> -> tensor<3xindex>
  return %0 : tensor<3xindex>
}

// This errors out as follows:
error: 'tensor.cast' op operand type 'tensor<3xi32>' and result type 'tensor<3xindex>' are cast incompatible
  %0 = shape.shape_of %reshape : tensor<?x1x1xf32> -> tensor<3xindex>
       ^
note: see current operation: %0 = "tensor.cast"(%arg1) : (tensor<3xi32>) -> tensor<3xindex>
```

```
// After
func.func @f(%arg0: tensor<?x1xf32>, %arg1: tensor<3xi32>) -> tensor<3xindex> {
  %0 = arith.index_cast %arg1 : tensor<3xi32> to tensor<3xindex>
  return %0 : tensor<3xindex>
}
```

See `canonicalize.mlir` in the change list for an example.

For context, this bug was found while running a test on Keras 3: when the batch size is dynamic, the canonicalizer errors out because it creates an invalid `tensor.cast` that converts a `tensor<3xi32>` operand to `tensor<3xindex>`. `tensor.cast` cannot change the element type, so the conversion has to go through `arith.index_cast` instead.

This change is related to a previous PR: https://github.com/llvm/llvm-project/pull/98531

---------

Co-authored-by: Alaa Ali <alaaali@ah-alaaali-l.dhcp.mathworks.com>
Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
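The heart of the fix is picking a cast op whose rules permit the conversion: `tensor.cast` may only change static/dynamic dimension information and requires the same element type, so an `i32` shape operand needs `arith.index_cast` to become `index`. The C++ below is a hypothetical toy model of that decision only. It uses made-up types and names (`ToyTensorType`, `CastKind`, `chooseCast`) rather than the MLIR C++ API, it is not the pattern as implemented in this PR, and it ignores `!shape.shape` results and the case where both the dimensions and the element type differ.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for a ranked tensor type: dimension sizes
// (-1 marks a dynamic size, like `?`) plus an element type name such as "i32"
// or "index". This is NOT the MLIR RankedTensorType API.
struct ToyTensorType {
  std::vector<int64_t> shape;
  std::string elementType;

  bool operator==(const ToyTensorType &other) const {
    return shape == other.shape && elementType == other.elementType;
  }
};

// Which op (if any) a shape_of(reshape) rewrite would insert so that the
// reshape's shape operand can replace the shape_of result.
enum class CastKind { None, TensorCast, IndexCast };

// Toy decision rule: identical types need no cast; tensor.cast is only used
// when the element types already match (it can only change static/dynamic
// dimension information); any element-type change (e.g. i32 -> index) is
// modeled as requiring arith.index_cast.
CastKind chooseCast(const ToyTensorType &shapeOperand,
                    const ToyTensorType &result) {
  if (shapeOperand == result)
    return CastKind::None;
  if (shapeOperand.elementType == result.elementType)
    return CastKind::TensorCast;
  return CastKind::IndexCast;
}
```

For the IR above, `chooseCast({{3}, "i32"}, {{3}, "index"})` returns `CastKind::IndexCast`, which corresponds to the `arith.index_cast` in the "After" output; the buggy behavior amounts to emitting `TensorCast` in that case.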