Use matcher sugars for cannonicalization pattern matching

- Added a mechanism for specifying pattern matching more concisely like LLVM.
- Added support for canonicalization of addi/muli over vector/tensor splat
- Added ValueType to Attribute class hierarchy
- Allowed creating constant splat

PiperOrigin-RevId: 219149621
This commit is contained in:
Lei Zhang 2018-10-29 10:22:49 -07:00 committed by jpienaar
parent 1ec77cecf2
commit 582b0761c6
7 changed files with 243 additions and 35 deletions

View File

@ -80,7 +80,8 @@ public:
LAST_ELEMENTS_ATTR = SparseElements,
};
typedef detail::AttributeStorage ImplType;
using ImplType = detail::AttributeStorage;
using ValueType = void;
Attribute() : attr(nullptr) {}
/* implicit */ Attribute(const ImplType *attr)
@ -126,7 +127,9 @@ inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
class BoolAttr : public Attribute {
public:
typedef detail::BoolAttributeStorage ImplType;
using ImplType = detail::BoolAttributeStorage;
using ValueType = bool;
BoolAttr() = default;
/* implicit */ BoolAttr(Attribute::ImplType *ptr);
@ -140,7 +143,9 @@ public:
class IntegerAttr : public Attribute {
public:
typedef detail::IntegerAttributeStorage ImplType;
using ImplType = detail::IntegerAttributeStorage;
using ValueType = int64_t;
IntegerAttr() = default;
/* implicit */ IntegerAttr(Attribute::ImplType *ptr);
@ -154,7 +159,9 @@ public:
class FloatAttr final : public Attribute {
public:
typedef detail::FloatAttributeStorage ImplType;
using ImplType = detail::FloatAttributeStorage;
using ValueType = APFloat;
FloatAttr() = default;
/* implicit */ FloatAttr(Attribute::ImplType *ptr);
@ -171,7 +178,9 @@ public:
class StringAttr : public Attribute {
public:
typedef detail::StringAttributeStorage ImplType;
using ImplType = detail::StringAttributeStorage;
using ValueType = StringRef;
StringAttr() = default;
/* implicit */ StringAttr(Attribute::ImplType *ptr);
@ -187,7 +196,9 @@ public:
/// type homogenous given that attributes don't, in general, carry types.
class ArrayAttr : public Attribute {
public:
typedef detail::ArrayAttributeStorage ImplType;
using ImplType = detail::ArrayAttributeStorage;
using ValueType = ArrayRef<Attribute>;
ArrayAttr() = default;
/* implicit */ ArrayAttr(Attribute::ImplType *ptr);
@ -201,7 +212,9 @@ public:
class AffineMapAttr : public Attribute {
public:
typedef detail::AffineMapAttributeStorage ImplType;
using ImplType = detail::AffineMapAttributeStorage;
using ValueType = AffineMap;
AffineMapAttr() = default;
/* implicit */ AffineMapAttr(Attribute::ImplType *ptr);
@ -215,7 +228,9 @@ public:
class IntegerSetAttr : public Attribute {
public:
typedef detail::IntegerSetAttributeStorage ImplType;
using ImplType = detail::IntegerSetAttributeStorage;
using ValueType = IntegerSet;
IntegerSetAttr() = default;
/* implicit */ IntegerSetAttr(Attribute::ImplType *ptr);
@ -229,7 +244,9 @@ public:
class TypeAttr : public Attribute {
public:
typedef detail::TypeAttributeStorage ImplType;
using ImplType = detail::TypeAttributeStorage;
using ValueType = Type *;
TypeAttr() = default;
/* implicit */ TypeAttr(Attribute::ImplType *ptr);
@ -250,7 +267,9 @@ public:
/// remain in MLIRContext.
class FunctionAttr : public Attribute {
public:
typedef detail::FunctionAttributeStorage ImplType;
using ImplType = detail::FunctionAttributeStorage;
using ValueType = Function *;
FunctionAttr() = default;
/* implicit */ FunctionAttr(Attribute::ImplType *ptr);
@ -288,7 +307,9 @@ public:
/// meaning all of the elements have the same value.
class SplatElementsAttr : public ElementsAttr {
public:
typedef detail::SplatElementsAttributeStorage ImplType;
using ImplType = detail::SplatElementsAttributeStorage;
using ValueType = Attribute;
SplatElementsAttr() = default;
/* implicit */ SplatElementsAttr(Attribute::ImplType *ptr);
@ -305,7 +326,8 @@ public:
/// than 64.
class DenseElementsAttr : public ElementsAttr {
public:
typedef detail::DenseElementsAttributeStorage ImplType;
using ImplType = detail::DenseElementsAttributeStorage;
DenseElementsAttr() = default;
/* implicit */ DenseElementsAttr(Attribute::ImplType *ptr);
@ -335,7 +357,8 @@ public:
/// object.
class DenseIntElementsAttr : public DenseElementsAttr {
public:
typedef detail::DenseIntElementsAttributeStorage ImplType;
using ImplType = detail::DenseIntElementsAttributeStorage;
DenseIntElementsAttr() = default;
/* implicit */ DenseIntElementsAttr(Attribute::ImplType *ptr);
@ -362,7 +385,8 @@ public:
/// object. Each element is stored as a double.
class DenseFPElementsAttr : public DenseElementsAttr {
public:
typedef detail::DenseFPElementsAttributeStorage ImplType;
using ImplType = detail::DenseFPElementsAttributeStorage;
DenseFPElementsAttr() = default;
/* implicit */ DenseFPElementsAttr(Attribute::ImplType *ptr);
@ -380,7 +404,9 @@ public:
/// doesn't need to interpret.
class OpaqueElementsAttr : public ElementsAttr {
public:
typedef detail::OpaqueElementsAttributeStorage ImplType;
using ImplType = detail::OpaqueElementsAttributeStorage;
using ValueType = StringRef;
OpaqueElementsAttr() = default;
/* implicit */ OpaqueElementsAttr(Attribute::ImplType *ptr);
@ -409,7 +435,8 @@ public:
/// [0, 0, 0, 0]].
class SparseElementsAttr : public ElementsAttr {
public:
typedef detail::SparseElementsAttributeStorage ImplType;
using ImplType = detail::SparseElementsAttributeStorage;
SparseElementsAttr() = default;
/* implicit */ SparseElementsAttr(Attribute::ImplType *ptr);

View File

@ -0,0 +1,142 @@
//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file provides a simple and efficient mechanism for performing general
// tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's
// include/llvm/IR/PatternMatch.h.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_MATCHERS_H
#define MLIR_MATCHERS_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h"
#include <type_traits>
namespace mlir {
namespace detail {
/// The matcher that matches a certain kind of Attribute and binds the value
/// inside the Attribute.
template <
typename AttrClass,
// Require AttrClass to be a derived class from Atribute and get its
// value type
typename ValueType =
typename std::enable_if<std::is_base_of<Attribute, AttrClass>::value,
AttrClass>::type::ValueType,
// Require the ValueType is not void
typename = typename std::enable_if<!std::is_void<ValueType>::value>::type>
struct attr_value_binder {
ValueType *bind_value;
/// Creates a matcher instance that binds the value to bv if match succeeds.
attr_value_binder(ValueType *bv) : bind_value(bv) {}
bool match(const Attribute &attr) {
if (auto intAttr = attr.dyn_cast<AttrClass>()) {
*bind_value = intAttr.getValue();
return true;
}
return false;
}
};
/// The matcher that matches a constant scalar / vector splat / tensor splat
/// integer operation and binds the constant integer value.
struct constant_int_op_binder {
IntegerAttr::ValueType *bind_value;
/// Creates a matcher instance that binds the value to bv if match succeeds.
constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {}
bool match(Operation *op) {
if (auto constOp = op->dyn_cast<ConstantOp>()) {
auto *type = constOp->getResult()->getType();
auto attr = constOp->getAttr("value");
if (isa<IntegerType>(type)) {
return attr_value_binder<IntegerAttr>(bind_value).match(attr);
}
if (isa<VectorOrTensorType>(type)) {
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
return attr_value_binder<IntegerAttr>(bind_value)
.match(splatAttr.getValue());
}
}
}
return false;
}
};
// The matcher that matches a given target constant scalar / vector splat /
// tensor splat integer value.
template <IntegerAttr::ValueType TargetValue>
struct constant_int_value_matcher {
bool match(Operation *op) {
IntegerAttr::ValueType value;
return constant_int_op_binder(&value).match(op) && TargetValue == value;
}
};
/// The matcher that matches a certain kind of op.
template <typename OpClass> struct op_matcher {
bool match(Operation *op) { return op->isa<OpClass>(); }
};
} // end namespace detail
/// Entry point for matching a pattern over an SSAValue.
template <typename Pattern>
inline bool match(SSAValue *value, const Pattern &pattern) {
// TODO: handle other cases
if (auto *op = value->getDefiningOperation())
return const_cast<Pattern &>(pattern).match(op);
return false;
}
/// Matches a ConstantIndexOp.
inline detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
return detail::op_matcher<ConstantIndexOp>();
}
/// Matches a ConstantOp holding a scalar/vector/tensor integer (splat) and
/// writes the integer value to bind_value.
inline detail::constant_int_op_binder
m_ConstantInt(IntegerAttr::ValueType *bind_value) {
return detail::constant_int_op_binder(bind_value);
}
/// Matches a constant scalar / vector splat / tensor splat integer one.
inline detail::constant_int_value_matcher<1> m_One() {
return detail::constant_int_value_matcher<1>();
}
/// Matches a constant scalar / vector splat / tensor splat integer zero.
inline detail::constant_int_value_matcher<0> m_Zero() {
return detail::constant_int_value_matcher<0>();
}
} // end namespace mlir
#endif // MLIR_MATCHERS_H

View File

@ -279,7 +279,7 @@ private:
/// The group of patterns that are matched for optimization through this
/// matcher.
std::vector<std::unique_ptr<Pattern>> patterns;
OwningPatternList patterns;
};
//===----------------------------------------------------------------------===//

View File

@ -222,6 +222,12 @@ bool ConstantOp::verify() const {
return false;
}
if (isa<VectorOrTensorType>(type)) {
if (!value.isa<ElementsAttr>())
return emitOpError("requires 'value' to be a vector/tensor constant");
return false;
}
if (type->isTFString()) {
if (!value.isa<StringAttr>())
return emitOpError("requires 'value' to be a string constant");

View File

@ -20,6 +20,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SSAValue.h"
@ -46,6 +47,11 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
//===----------------------------------------------------------------------===//
namespace {
/// Matches a MemRefCastOp.
inline detail::op_matcher<MemRefCastOp> m_MemRefCast() {
return detail::op_matcher<MemRefCastOp>();
}
/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
/// into the root operation directly.
@ -57,9 +63,8 @@ struct MemRefCastFolder : public Pattern {
std::pair<PatternBenefit, std::unique_ptr<PatternState>>
match(Operation *op) const override {
for (auto *operand : op->getOperands())
if (auto *memref = operand->getDefiningOperation())
if (memref->isa<MemRefCastOp>())
return matchSuccess();
if (::match(operand, m_MemRefCast()))
return matchSuccess();
return matchFailure();
}
@ -116,12 +121,9 @@ struct SimplifyAddX0 : public Pattern {
std::pair<PatternBenefit, std::unique_ptr<PatternState>>
match(Operation *op) const override {
auto addi = op->cast<AddIOp>();
if (auto *operandOp = addi->getOperand(1)->getDefiningOperation())
// TODO: Support splatted zero as well. We need a general zero pattern.
if (auto cst = operandOp->dyn_cast<ConstantIntOp>()) {
if (cst->getValue() == 0)
return matchSuccess();
}
if (::match(addi->getOperand(1), m_Zero()))
return matchSuccess();
return matchFailure();
}
@ -230,9 +232,8 @@ struct SimplifyAllocConst : public Pattern {
// Check to see if any dimensions operands are constants. If so, we can
// substitute and drop them.
for (auto *operand : alloc->getOperands())
if (auto *opOperation = operand->getDefiningOperation())
if (opOperation->isa<ConstantIndexOp>())
return matchSuccess();
if (::match(operand, m_ConstantIndex()))
return matchSuccess();
return matchFailure();
}
@ -892,12 +893,9 @@ struct SimplifyMulX1 : public Pattern {
std::pair<PatternBenefit, std::unique_ptr<PatternState>>
match(Operation *op) const override {
auto muli = op->cast<MulIOp>();
if (auto *operandOp = muli->getOperand(1)->getDefiningOperation())
// TODO: Support splatted one as well. We need a general one pattern.
if (auto cst = operandOp->dyn_cast<ConstantIntOp>()) {
if (cst->getValue() == 1)
return matchSuccess();
}
if (::match(muli->getOperand(1), m_One()))
return matchSuccess();
return matchFailure();
}

View File

@ -85,6 +85,9 @@ bb42(%t: tensor<4x4x?xf32>, %f: f32, %i: i32):
// CHECK: %f_2 = constant @affine_apply : () -> ()
%12 = constant @affine_apply : () -> ()
// CHECK: %cst_3 = constant splat<vector<4xi32>, 0> : vector<4xi32>
%13 = constant splat<vector<4 x i32>, 0> : vector<4 x i32>
return
}

View File

@ -58,6 +58,22 @@ mlfunc @addi_zero(%arg0: i32) -> i32 {
return %y: i32
}
// CHECK-LABEL: mlfunc @addi_zero_vector
mlfunc @addi_zero_vector(%arg0: vector<4 x i32>) -> vector<4 x i32> {
// CHECK-NEXT: return %arg0
%c0_v4i32 = constant splat<vector<4 x i32>, 0> : vector<4 x i32>
%y = addi %c0_v4i32, %arg0 : vector<4 x i32>
return %y: vector<4 x i32>
}
// CHECK-LABEL: mlfunc @addi_zero_tensor
mlfunc @addi_zero_tensor(%arg0: tensor<4 x 5 x i32>) -> tensor<4 x 5 x i32> {
// CHECK-NEXT: return %arg0
%c0_t45i32 = constant splat<tensor<4 x 5 x i32>, 0> : tensor<4 x 5 x i32>
%y = addi %arg0, %c0_t45i32 : tensor<4 x 5 x i32>
return %y: tensor<4 x 5 x i32>
}
// CHECK-LABEL: mlfunc @muli_one
mlfunc @muli_one(%arg0: i32) -> i32 {
// CHECK-NEXT: return %arg0
@ -66,6 +82,22 @@ mlfunc @muli_one(%arg0: i32) -> i32 {
return %y: i32
}
// CHECK-LABEL: mlfunc @muli_one_vector
mlfunc @muli_one_vector(%arg0: vector<4 x i32>) -> vector<4 x i32> {
// CHECK-NEXT: return %arg0
%c1_v4i32 = constant splat<vector<4 x i32>, 1> : vector<4 x i32>
%y = muli %c1_v4i32, %arg0 : vector<4 x i32>
return %y: vector<4 x i32>
}
// CHECK-LABEL: mlfunc @muli_one_tensor
mlfunc @muli_one_tensor(%arg0: tensor<4 x 5 x i32>) -> tensor<4 x 5 x i32> {
// CHECK-NEXT: return %arg0
%c1_t45i32 = constant splat<tensor<4 x 5 x i32>, 1> : tensor<4 x 5 x i32>
%y = muli %arg0, %c1_t45i32 : tensor<4 x 5 x i32>
return %y: tensor<4 x 5 x i32>
}
// CHECK-LABEL: mlfunc @memref_cast_folding
mlfunc @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> f32 {
%1 = memref_cast %arg0: memref<4xf32> to memref<?xf32>