mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-25 07:36:06 +00:00
[flang] add hlfir.all intrinsic
Adds a new HLFIR operation for the ALL intrinsic according to the design set out in flang/docs/HighLevel.md Differential Revision: https://reviews.llvm.org/D151090
This commit is contained in:
parent
544a240ff7
commit
206b8538a6
@ -317,6 +317,27 @@ def hlfir_ConcatOp : hlfir_Op<"concat", []> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def hlfir_AllOp : hlfir_Op<"all", []> {
|
||||
let summary = "ALL transformational intrinsic";
|
||||
let description = [{
|
||||
Takes a logical array MASK as argument, optionally along a particular dimension,
|
||||
and returns true if all elements of MASK are true.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyFortranLogicalArrayObject:$mask,
|
||||
Optional<AnyIntegerType>:$dim
|
||||
);
|
||||
|
||||
let results = (outs AnyFortranValue);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$mask (`dim` $dim^)? attr-dict `:` functional-type(operands, results)
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def hlfir_AnyOp : hlfir_Op<"any", []> {
|
||||
let summary = "ANY transformational intrinsic";
|
||||
let description = [{
|
||||
|
@ -442,16 +442,19 @@ mlir::LogicalResult hlfir::ParentComponentOp::verify() {
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AnyOp
|
||||
// LogicalReductionOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
mlir::LogicalResult hlfir::AnyOp::verify() {
|
||||
mlir::Operation *op = getOperation();
|
||||
template <typename LogicalReductionOp>
|
||||
static mlir::LogicalResult
|
||||
verifyLogicalReductionOp(LogicalReductionOp reductionOp) {
|
||||
mlir::Operation *op = reductionOp->getOperation();
|
||||
|
||||
auto results = op->getResultTypes();
|
||||
assert(results.size() == 1);
|
||||
|
||||
mlir::Value mask = getMask();
|
||||
mlir::Value dim = getDim();
|
||||
mlir::Value mask = reductionOp->getMask();
|
||||
mlir::Value dim = reductionOp->getDim();
|
||||
|
||||
fir::SequenceType maskTy =
|
||||
hlfir::getFortranElementOrSequenceType(mask.getType())
|
||||
.cast<fir::SequenceType>();
|
||||
@ -462,7 +465,7 @@ mlir::LogicalResult hlfir::AnyOp::verify() {
|
||||
if (mlir::isa<fir::LogicalType>(resultType)) {
|
||||
// Result is of the same type as MASK
|
||||
if (resultType != logicalTy)
|
||||
return emitOpError(
|
||||
return reductionOp->emitOpError(
|
||||
"result must have the same element type as MASK argument");
|
||||
|
||||
} else if (auto resultExpr =
|
||||
@ -470,25 +473,42 @@ mlir::LogicalResult hlfir::AnyOp::verify() {
|
||||
// Result should only be in hlfir.expr form if it is an array
|
||||
if (maskShape.size() > 1 && dim != nullptr) {
|
||||
if (!resultExpr.isArray())
|
||||
return emitOpError("result must be an array");
|
||||
return reductionOp->emitOpError("result must be an array");
|
||||
|
||||
if (resultExpr.getEleTy() != logicalTy)
|
||||
return emitOpError(
|
||||
return reductionOp->emitOpError(
|
||||
"result must have the same element type as MASK argument");
|
||||
|
||||
llvm::ArrayRef<int64_t> resultShape = resultExpr.getShape();
|
||||
// Result has rank n-1
|
||||
if (resultShape.size() != (maskShape.size() - 1))
|
||||
return emitOpError("result rank must be one less than MASK");
|
||||
return reductionOp->emitOpError(
|
||||
"result rank must be one less than MASK");
|
||||
} else {
|
||||
return emitOpError("result must be of logical type");
|
||||
return reductionOp->emitOpError("result must be of logical type");
|
||||
}
|
||||
} else {
|
||||
return emitOpError("result must be of logical type");
|
||||
return reductionOp->emitOpError("result must be of logical type");
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AllOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
mlir::LogicalResult hlfir::AllOp::verify() {
|
||||
return verifyLogicalReductionOp<hlfir::AllOp *>(this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AnyOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
mlir::LogicalResult hlfir::AnyOp::verify() {
|
||||
return verifyLogicalReductionOp<hlfir::AnyOp *>(this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConcatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -537,11 +557,12 @@ void hlfir::ConcatOp::build(mlir::OpBuilder &builder,
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReductionOp
|
||||
// NumericalReductionOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename ReductionOp>
|
||||
static mlir::LogicalResult verifyReductionOp(ReductionOp reductionOp) {
|
||||
template <typename NumericalReductionOp>
|
||||
static mlir::LogicalResult
|
||||
verifyNumericalReductionOp(NumericalReductionOp reductionOp) {
|
||||
mlir::Operation *op = reductionOp->getOperation();
|
||||
|
||||
auto results = op->getResultTypes();
|
||||
@ -619,7 +640,7 @@ static mlir::LogicalResult verifyReductionOp(ReductionOp reductionOp) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
mlir::LogicalResult hlfir::ProductOp::verify() {
|
||||
return verifyReductionOp<hlfir::ProductOp *>(this);
|
||||
return verifyNumericalReductionOp<hlfir::ProductOp *>(this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -645,7 +666,7 @@ void hlfir::SetLengthOp::build(mlir::OpBuilder &builder,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
mlir::LogicalResult hlfir::SumOp::verify() {
|
||||
return verifyReductionOp<hlfir::SumOp *>(this);
|
||||
return verifyNumericalReductionOp<hlfir::SumOp *>(this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
113
flang/test/HLFIR/all.fir
Normal file
113
flang/test/HLFIR/all.fir
Normal file
@ -0,0 +1,113 @@
|
||||
// Test hlfir.all operation parse, verify (no errors), and unparse
|
||||
|
||||
// RUN: fir-opt %s | fir-opt | FileCheck %s
|
||||
|
||||
// mask is an expression of known shape
|
||||
func.func @all0(%arg0: !hlfir.expr<2x!fir.logical<4>>) {
|
||||
%all = hlfir.all %arg0 : (!hlfir.expr<2x!fir.logical<4>>) -> !fir.logical<4>
|
||||
return
|
||||
}
|
||||
// CHECK: func.func @all0(%[[ARRAY:.*]]: !hlfir.expr<2x!fir.logical<4>>) {
|
||||
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!hlfir.expr<2x!fir.logical<4>>) -> !fir.logical<4>
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// mask is an expression of assumed shape
|
||||
func.func @all1(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
|
||||
%all = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
|
||||
return
|
||||
}
|
||||
// CHECK: func.func @all1(%[[ARRAY:.*]]: !hlfir.expr<?x!fir.logical<4>>) {
|
||||
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// mask is a boxed array
|
||||
func.func @all2(%arg0: !fir.box<!fir.array<2x!fir.logical<4>>>) {
|
||||
%all = hlfir.all %arg0 : (!fir.box<!fir.array<2x!fir.logical<4>>>) -> !fir.logical<4>
|
||||
return
|
||||
}
|
||||
// CHECK: func.func @all2(%[[ARRAY:.*]]: !fir.box<!fir.array<2x!fir.logical<4>>>) {
|
||||
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<2x!fir.logical<4>>>) -> !fir.logical<4>
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// mask is an assumed shape boxed array
|
||||
func.func @all3(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>>){
|
||||
%all = hlfir.all %arg0 : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
|
||||
return
|
||||
}
|
||||
// CHECK: func.func @all3(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>) {
|
||||
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// mask is a 2-dimensional array
|
||||
func.func @all4(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>>){
|
||||
%all = hlfir.all %arg0 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>) -> !fir.logical<4>
|
||||
return
|
||||
}
|
||||
// CHECK: func.func @all4(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>) {
|
||||
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>) -> !fir.logical<4>
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// mask and dim argument
|
||||
func.func @all5(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>>, %arg1: i32) {
|
||||
%all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<?x!fir.logical<4>>>, i32) -> !fir.logical<4>
|
||||
return
|
||||
}
|
||||
// CHECK: func.func @all5(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>, %[[DIM:.*]]: i32) {
|
||||
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x!fir.logical<4>>>, i32) -> !fir.logical<4>
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// hlfir.all with dim argument with an unusual type
|
||||
func.func @all6(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>>, %arg1: index) {
|
||||
%all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<?x!fir.logical<4>>>, index) ->!fir.logical<4>
|
||||
return
|
||||
}
|
||||
// CHECK: func.func @all6(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>, %[[DIM:.*]]: index) {
|
||||
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x!fir.logical<4>>>, index) -> !fir.logical<4>
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// mask is a 2 dimensional array with dim
|
||||
func.func @all7(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>>, %arg1: i32) {
|
||||
%all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i32) -> !hlfir.expr<?x!fir.logical<4>>
|
||||
return
|
||||
}
|
||||
// CHECK: func.func @all7(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>, %[[DIM:.*]]: i32) {
|
||||
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i32) -> !hlfir.expr<?x!fir.logical<4>>
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// known shape expr return
|
||||
func.func @all8(%arg0: !fir.box<!fir.array<2x2x!fir.logical<4>>>, %arg1: i32) {
|
||||
%all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<2x2x!fir.logical<4>>>, i32) -> !hlfir.expr<2x!fir.logical<4>>
|
||||
return
|
||||
}
|
||||
// CHECK: func.func @all8(%[[ARRAY:.*]]: !fir.box<!fir.array<2x2x!fir.logical<4>>>, %[[DIM:.*]]: i32) {
|
||||
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<2x2x!fir.logical<4>>>, i32) -> !hlfir.expr<2x!fir.logical<4>>
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// hlfir.all with mask argument of ref<array<>> type
|
||||
func.func @all9(%arg0: !fir.ref<!fir.array<?x!fir.logical<4>>>) {
|
||||
%all = hlfir.all %arg0 : (!fir.ref<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
|
||||
return
|
||||
}
|
||||
// CHECK: func.func @all9(%[[ARRAY:.*]]: !fir.ref<!fir.array<?x!fir.logical<4>>>) {
|
||||
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.ref<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
||||
|
||||
// hlfir.all with fir.logical<8> type
|
||||
func.func @all10(%arg0: !fir.box<!fir.array<?x!fir.logical<8>>>) {
|
||||
%all = hlfir.all %arg0 : (!fir.box<!fir.array<?x!fir.logical<8>>>) -> !fir.logical<8>
|
||||
return
|
||||
}
|
||||
// CHECK: func.func @all10(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<8>>>) {
|
||||
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<?x!fir.logical<8>>>) -> !fir.logical<8>
|
||||
// CHECK-NEXT: return
|
||||
// CHECK-NEXT: }
|
@ -332,6 +332,42 @@ func.func @bad_any6(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
|
||||
%0 = hlfir.any %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<!fir.logical<4>>
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @bad_all1(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
|
||||
// expected-error@+1 {{'hlfir.all' op result must have the same element type as MASK argument}}
|
||||
%0 = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<8>
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @bad_all2(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: i32) {
|
||||
// expected-error@+1 {{'hlfir.all' op result must have the same element type as MASK argument}}
|
||||
%0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, i32) -> !hlfir.expr<?x!fir.logical<8>>
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @bad_all3(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: i32){
|
||||
// expected-error@+1 {{'hlfir.all' op result rank must be one less than MASK}}
|
||||
%0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, i32) -> !hlfir.expr<?x?x!fir.logical<4>>
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @bad_all4(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: i32) {
|
||||
// expected-error@+1 {{'hlfir.all' op result must be an array}}
|
||||
%0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, i32) -> !hlfir.expr<!fir.logical<4>>
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @bad_all5(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
|
||||
// expected-error@+1 {{'hlfir.all' op result must be of logical type}}
|
||||
%0 = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> i32
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @bad_all6(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
|
||||
// expected-error@+1 {{'hlfir.all' op result must be of logical type}}
|
||||
%0 = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<!fir.logical<4>>
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @bad_product1(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
|
||||
// expected-error@+1 {{'hlfir.product' op result must have the same element type as ARRAY argument}}
|
||||
|
Loading…
x
Reference in New Issue
Block a user