diff --git a/mlir/include/mlir/StandardOps/StandardOps.h b/mlir/include/mlir/StandardOps/StandardOps.h index 4d8867ffd4d2..bb368bd2ef82 100644 --- a/mlir/include/mlir/StandardOps/StandardOps.h +++ b/mlir/include/mlir/StandardOps/StandardOps.h @@ -199,6 +199,8 @@ public: static bool parse(OpAsmParser *parser, OperationState *result); void print(OpAsmPrinter *p) const; bool verify() const; + Attribute constantFold(ArrayRef operands, + MLIRContext *context) const; private: friend class OperationInst; diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index f0ce38aaddee..3a12a9c4b8db 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -571,6 +571,50 @@ bool CmpIOp::verify() const { return false; } +// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer +// comparison predicates. +static bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, + const APInt &rhs) { + switch (predicate) { + case CmpIPredicate::EQ: + return lhs.eq(rhs); + case CmpIPredicate::NE: + return lhs.ne(rhs); + case CmpIPredicate::SLT: + return lhs.slt(rhs); + case CmpIPredicate::SLE: + return lhs.sle(rhs); + case CmpIPredicate::SGT: + return lhs.sgt(rhs); + case CmpIPredicate::SGE: + return lhs.sge(rhs); + case CmpIPredicate::ULT: + return lhs.ult(rhs); + case CmpIPredicate::ULE: + return lhs.ule(rhs); + case CmpIPredicate::UGT: + return lhs.ugt(rhs); + case CmpIPredicate::UGE: + return lhs.uge(rhs); + default: + llvm_unreachable("unknown comparison predicate"); + } +} + +// Constant folding hook for comparisons. +Attribute CmpIOp::constantFold(ArrayRef operands, + MLIRContext *context) const { + assert(operands.size() == 2 && "cmpi takes two arguments"); + + auto lhs = operands.front().dyn_cast_or_null(); + auto rhs = operands.back().dyn_cast_or_null(); + if (!lhs || !rhs) + return {}; + + auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); + return IntegerAttr::get(IntegerType::get(1, context), APInt(1, val)); +} + //===----------------------------------------------------------------------===// // DeallocOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir index be65cc44d970..78f65c17b52c 100644 --- a/mlir/test/Transforms/constant-fold.mlir +++ b/mlir/test/Transforms/constant-fold.mlir @@ -227,3 +227,29 @@ func @dim(%x : tensor<8x4xf32>) -> index { return %0 : index } +// CHECK-LABEL: func @cmpi +func @cmpi() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) { + %c42 = constant 42 : i32 + %cm1 = constant -1 : i32 +// CHECK-NEXT: %false = constant 0 : i1 + %0 = cmpi "eq", %c42, %cm1 : i32 +// CHECK-NEXT: %true = constant 1 : i1 + %1 = cmpi "ne", %c42, %cm1 : i32 +// CHECK-NEXT: %false_0 = constant 0 : i1 + %2 = cmpi "slt", %c42, %cm1 : i32 +// CHECK-NEXT: %false_1 = constant 0 : i1 + %3 = cmpi "sle", %c42, %cm1 : i32 +// CHECK-NEXT: %true_2 = constant 1 : i1 + %4 = cmpi "sgt", %c42, %cm1 : i32 +// CHECK-NEXT: %true_3 = constant 1 : i1 + %5 = cmpi "sge", %c42, %cm1 : i32 +// CHECK-NEXT: %true_4 = constant 1 : i1 + %6 = cmpi "ult", %c42, %cm1 : i32 +// CHECK-NEXT: %true_5 = constant 1 : i1 + %7 = cmpi "ule", %c42, %cm1 : i32 +// CHECK-NEXT: %false_6 = constant 0 : i1 + %8 = cmpi "ugt", %c42, %cm1 : i32 +// CHECK-NEXT: %false_7 = constant 0 : i1 + %9 = cmpi "uge", %c42, %cm1 : i32 + return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1 +}