[flang] add hlfir.all intrinsic

Adds a new HLFIR operation for the ALL intrinsic according to the
design set out in flang/docs/HighLevel.md

Differential Revision: https://reviews.llvm.org/D151090
This commit is contained in:
Jacob Crawley 2023-05-22 13:07:28 +00:00
parent 544a240ff7
commit 206b8538a6
4 changed files with 207 additions and 16 deletions

View File

@ -317,6 +317,27 @@ def hlfir_ConcatOp : hlfir_Op<"concat", []> {
let hasVerifier = 1;
}
def hlfir_AllOp : hlfir_Op<"all", []> {
let summary = "ALL transformational intrinsic";
let description = [{
Takes a logical array MASK as argument, optionally along a particular dimension,
and returns true if all elements of MASK are true.
}];
let arguments = (ins
AnyFortranLogicalArrayObject:$mask,
Optional<AnyIntegerType>:$dim
);
let results = (outs AnyFortranValue);
let assemblyFormat = [{
$mask (`dim` $dim^)? attr-dict `:` functional-type(operands, results)
}];
let hasVerifier = 1;
}
def hlfir_AnyOp : hlfir_Op<"any", []> {
let summary = "ANY transformational intrinsic";
let description = [{

View File

@ -442,16 +442,19 @@ mlir::LogicalResult hlfir::ParentComponentOp::verify() {
}
//===----------------------------------------------------------------------===//
// AnyOp
// LogicalReductionOp
//===----------------------------------------------------------------------===//
mlir::LogicalResult hlfir::AnyOp::verify() {
mlir::Operation *op = getOperation();
template <typename LogicalReductionOp>
static mlir::LogicalResult
verifyLogicalReductionOp(LogicalReductionOp reductionOp) {
mlir::Operation *op = reductionOp->getOperation();
auto results = op->getResultTypes();
assert(results.size() == 1);
mlir::Value mask = getMask();
mlir::Value dim = getDim();
mlir::Value mask = reductionOp->getMask();
mlir::Value dim = reductionOp->getDim();
fir::SequenceType maskTy =
hlfir::getFortranElementOrSequenceType(mask.getType())
.cast<fir::SequenceType>();
@ -462,7 +465,7 @@ mlir::LogicalResult hlfir::AnyOp::verify() {
if (mlir::isa<fir::LogicalType>(resultType)) {
// Result is of the same type as MASK
if (resultType != logicalTy)
return emitOpError(
return reductionOp->emitOpError(
"result must have the same element type as MASK argument");
} else if (auto resultExpr =
@ -470,25 +473,42 @@ mlir::LogicalResult hlfir::AnyOp::verify() {
// Result should only be in hlfir.expr form if it is an array
if (maskShape.size() > 1 && dim != nullptr) {
if (!resultExpr.isArray())
return emitOpError("result must be an array");
return reductionOp->emitOpError("result must be an array");
if (resultExpr.getEleTy() != logicalTy)
return emitOpError(
return reductionOp->emitOpError(
"result must have the same element type as MASK argument");
llvm::ArrayRef<int64_t> resultShape = resultExpr.getShape();
// Result has rank n-1
if (resultShape.size() != (maskShape.size() - 1))
return emitOpError("result rank must be one less than MASK");
return reductionOp->emitOpError(
"result rank must be one less than MASK");
} else {
return emitOpError("result must be of logical type");
return reductionOp->emitOpError("result must be of logical type");
}
} else {
return emitOpError("result must be of logical type");
return reductionOp->emitOpError("result must be of logical type");
}
return mlir::success();
}
//===----------------------------------------------------------------------===//
// AllOp
//===----------------------------------------------------------------------===//
mlir::LogicalResult hlfir::AllOp::verify() {
return verifyLogicalReductionOp<hlfir::AllOp *>(this);
}
//===----------------------------------------------------------------------===//
// AnyOp
//===----------------------------------------------------------------------===//
mlir::LogicalResult hlfir::AnyOp::verify() {
return verifyLogicalReductionOp<hlfir::AnyOp *>(this);
}
//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//
@ -537,11 +557,12 @@ void hlfir::ConcatOp::build(mlir::OpBuilder &builder,
}
//===----------------------------------------------------------------------===//
// ReductionOp
// NumericalReductionOp
//===----------------------------------------------------------------------===//
template <typename ReductionOp>
static mlir::LogicalResult verifyReductionOp(ReductionOp reductionOp) {
template <typename NumericalReductionOp>
static mlir::LogicalResult
verifyNumericalReductionOp(NumericalReductionOp reductionOp) {
mlir::Operation *op = reductionOp->getOperation();
auto results = op->getResultTypes();
@ -619,7 +640,7 @@ static mlir::LogicalResult verifyReductionOp(ReductionOp reductionOp) {
//===----------------------------------------------------------------------===//
mlir::LogicalResult hlfir::ProductOp::verify() {
return verifyReductionOp<hlfir::ProductOp *>(this);
return verifyNumericalReductionOp<hlfir::ProductOp *>(this);
}
//===----------------------------------------------------------------------===//
@ -645,7 +666,7 @@ void hlfir::SetLengthOp::build(mlir::OpBuilder &builder,
//===----------------------------------------------------------------------===//
mlir::LogicalResult hlfir::SumOp::verify() {
return verifyReductionOp<hlfir::SumOp *>(this);
return verifyNumericalReductionOp<hlfir::SumOp *>(this);
}
//===----------------------------------------------------------------------===//

113
flang/test/HLFIR/all.fir Normal file
View File

@ -0,0 +1,113 @@
// Test hlfir.all operation parse, verify (no errors), and unparse
// RUN: fir-opt %s | fir-opt | FileCheck %s
// mask is an expression of known shape
func.func @all0(%arg0: !hlfir.expr<2x!fir.logical<4>>) {
%all = hlfir.all %arg0 : (!hlfir.expr<2x!fir.logical<4>>) -> !fir.logical<4>
return
}
// CHECK: func.func @all0(%[[ARRAY:.*]]: !hlfir.expr<2x!fir.logical<4>>) {
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!hlfir.expr<2x!fir.logical<4>>) -> !fir.logical<4>
// CHECK-NEXT: return
// CHECK-NEXT: }
// mask is an expression of assumed shape
func.func @all1(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
%all = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
return
}
// CHECK: func.func @all1(%[[ARRAY:.*]]: !hlfir.expr<?x!fir.logical<4>>) {
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
// CHECK-NEXT: return
// CHECK-NEXT: }
// mask is a boxed array
func.func @all2(%arg0: !fir.box<!fir.array<2x!fir.logical<4>>>) {
%all = hlfir.all %arg0 : (!fir.box<!fir.array<2x!fir.logical<4>>>) -> !fir.logical<4>
return
}
// CHECK: func.func @all2(%[[ARRAY:.*]]: !fir.box<!fir.array<2x!fir.logical<4>>>) {
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<2x!fir.logical<4>>>) -> !fir.logical<4>
// CHECK-NEXT: return
// CHECK-NEXT: }
// mask is an assumed shape boxed array
func.func @all3(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>>){
%all = hlfir.all %arg0 : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
return
}
// CHECK: func.func @all3(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>) {
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
// CHECK-NEXT: return
// CHECK-NEXT: }
// mask is a 2-dimensional array
func.func @all4(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>>){
%all = hlfir.all %arg0 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>) -> !fir.logical<4>
return
}
// CHECK: func.func @all4(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>) {
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>) -> !fir.logical<4>
// CHECK-NEXT: return
// CHECK-NEXT: }
// mask and dim argument
func.func @all5(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>>, %arg1: i32) {
%all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<?x!fir.logical<4>>>, i32) -> !fir.logical<4>
return
}
// CHECK: func.func @all5(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>, %[[DIM:.*]]: i32) {
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x!fir.logical<4>>>, i32) -> !fir.logical<4>
// CHECK-NEXT: return
// CHECK-NEXT: }
// hlfir.all with dim argument with an unusual type
func.func @all6(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>>, %arg1: index) {
%all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<?x!fir.logical<4>>>, index) ->!fir.logical<4>
return
}
// CHECK: func.func @all6(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>, %[[DIM:.*]]: index) {
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x!fir.logical<4>>>, index) -> !fir.logical<4>
// CHECK-NEXT: return
// CHECK-NEXT: }
// mask is a 2 dimensional array with dim
func.func @all7(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>>, %arg1: i32) {
%all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i32) -> !hlfir.expr<?x!fir.logical<4>>
return
}
// CHECK: func.func @all7(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>, %[[DIM:.*]]: i32) {
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i32) -> !hlfir.expr<?x!fir.logical<4>>
// CHECK-NEXT: return
// CHECK-NEXT: }
// known shape expr return
func.func @all8(%arg0: !fir.box<!fir.array<2x2x!fir.logical<4>>>, %arg1: i32) {
%all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<2x2x!fir.logical<4>>>, i32) -> !hlfir.expr<2x!fir.logical<4>>
return
}
// CHECK: func.func @all8(%[[ARRAY:.*]]: !fir.box<!fir.array<2x2x!fir.logical<4>>>, %[[DIM:.*]]: i32) {
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<2x2x!fir.logical<4>>>, i32) -> !hlfir.expr<2x!fir.logical<4>>
// CHECK-NEXT: return
// CHECK-NEXT: }
// hlfir.all with mask argument of ref<array<>> type
func.func @all9(%arg0: !fir.ref<!fir.array<?x!fir.logical<4>>>) {
%all = hlfir.all %arg0 : (!fir.ref<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
return
}
// CHECK: func.func @all9(%[[ARRAY:.*]]: !fir.ref<!fir.array<?x!fir.logical<4>>>) {
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.ref<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
// CHECK-NEXT: return
// CHECK-NEXT: }
// hlfir.all with fir.logical<8> type
func.func @all10(%arg0: !fir.box<!fir.array<?x!fir.logical<8>>>) {
%all = hlfir.all %arg0 : (!fir.box<!fir.array<?x!fir.logical<8>>>) -> !fir.logical<8>
return
}
// CHECK: func.func @all10(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<8>>>) {
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<?x!fir.logical<8>>>) -> !fir.logical<8>
// CHECK-NEXT: return
// CHECK-NEXT: }

View File

@ -332,6 +332,42 @@ func.func @bad_any6(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
%0 = hlfir.any %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<!fir.logical<4>>
}
// -----
func.func @bad_all1(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
// expected-error@+1 {{'hlfir.all' op result must have the same element type as MASK argument}}
%0 = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<8>
}
// -----
func.func @bad_all2(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: i32) {
// expected-error@+1 {{'hlfir.all' op result must have the same element type as MASK argument}}
%0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, i32) -> !hlfir.expr<?x!fir.logical<8>>
}
// -----
func.func @bad_all3(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: i32){
// expected-error@+1 {{'hlfir.all' op result rank must be one less than MASK}}
%0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, i32) -> !hlfir.expr<?x?x!fir.logical<4>>
}
// -----
func.func @bad_all4(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: i32) {
// expected-error@+1 {{'hlfir.all' op result must be an array}}
%0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, i32) -> !hlfir.expr<!fir.logical<4>>
}
// -----
func.func @bad_all5(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
// expected-error@+1 {{'hlfir.all' op result must be of logical type}}
%0 = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> i32
}
// -----
func.func @bad_all6(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
// expected-error@+1 {{'hlfir.all' op result must be of logical type}}
%0 = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<!fir.logical<4>>
}
// -----
func.func @bad_product1(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
// expected-error@+1 {{'hlfir.product' op result must have the same element type as ARRAY argument}}