mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-19 01:46:49 +00:00
[mlir]Add a check to ensure bailing out when reducing to a scalar (#129694)
Fixes issue #64075 Referencing this comment for more detailed view -> https://github.com/llvm/llvm-project/issues/64075#issuecomment-2694112594 **Minimal example crashing :** ``` func.func @multi_reduction(%0: vector<4x2xf32>, %acc1: f32) -> f32 { %2 = vector.multi_reduction <add>, %0, %acc1 [0, 1] : vector<4x2xf32> to f32 return %2 : f32 } ```
This commit is contained in:
parent
46d218d1af
commit
037756242f
@ -355,6 +355,11 @@ struct UnrollMultiReductionPattern
|
||||
|
||||
LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto resultType = reductionOp->getResult(0).getType();
|
||||
if (resultType.isIntOrFloat()) {
|
||||
return rewriter.notifyMatchFailure(reductionOp,
|
||||
"Unrolling scalars is not supported");
|
||||
}
|
||||
std::optional<SmallVector<int64_t>> targetShape =
|
||||
getTargetShape(options, reductionOp);
|
||||
if (!targetShape)
|
||||
|
@ -222,6 +222,15 @@ func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) ->
|
||||
// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[R5]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
|
||||
// CHECK: return %[[V2]] : vector<4xf32>
|
||||
|
||||
// This is a negative test case to ensure that further unrolling is not performed. Since the vector.multi_reduction
|
||||
// operation has already been unrolled, attempting additional unrolling should not be allowed.
|
||||
func.func @negative_vector_multi_reduction(%v: vector<4x2xf32>, %acc: f32) -> f32 {
|
||||
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [0, 1] : vector<4x2xf32> to f32
|
||||
return %0 : f32
|
||||
}
|
||||
// CHECK-LABEL: func @negative_vector_multi_reduction
|
||||
// CHECK-NEXT: %[[R0:.*]] = vector.multi_reduction <add>, %{{.*}}, %{{.*}} [0, 1] : vector<4x2xf32> to f32
|
||||
// CHECK-NEXT: return %[[R0]] : f32
|
||||
|
||||
func.func @vector_reduction(%v : vector<8xf32>) -> f32 {
|
||||
%0 = vector.reduction <add>, %v : vector<8xf32> into f32
|
||||
|
Loading…
x
Reference in New Issue
Block a user