mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-29 03:06:07 +00:00
Use matcher sugars for cannonicalization pattern matching
- Added a mechanism for specifying pattern matching more concisely like LLVM. - Added support for canonicalization of addi/muli over vector/tensor splat - Added ValueType to Attribute class hierarchy - Allowed creating constant splat PiperOrigin-RevId: 219149621
This commit is contained in:
parent
1ec77cecf2
commit
582b0761c6
@ -80,7 +80,8 @@ public:
|
||||
LAST_ELEMENTS_ATTR = SparseElements,
|
||||
};
|
||||
|
||||
typedef detail::AttributeStorage ImplType;
|
||||
using ImplType = detail::AttributeStorage;
|
||||
using ValueType = void;
|
||||
|
||||
Attribute() : attr(nullptr) {}
|
||||
/* implicit */ Attribute(const ImplType *attr)
|
||||
@ -126,7 +127,9 @@ inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
|
||||
|
||||
class BoolAttr : public Attribute {
|
||||
public:
|
||||
typedef detail::BoolAttributeStorage ImplType;
|
||||
using ImplType = detail::BoolAttributeStorage;
|
||||
using ValueType = bool;
|
||||
|
||||
BoolAttr() = default;
|
||||
/* implicit */ BoolAttr(Attribute::ImplType *ptr);
|
||||
|
||||
@ -140,7 +143,9 @@ public:
|
||||
|
||||
class IntegerAttr : public Attribute {
|
||||
public:
|
||||
typedef detail::IntegerAttributeStorage ImplType;
|
||||
using ImplType = detail::IntegerAttributeStorage;
|
||||
using ValueType = int64_t;
|
||||
|
||||
IntegerAttr() = default;
|
||||
/* implicit */ IntegerAttr(Attribute::ImplType *ptr);
|
||||
|
||||
@ -154,7 +159,9 @@ public:
|
||||
|
||||
class FloatAttr final : public Attribute {
|
||||
public:
|
||||
typedef detail::FloatAttributeStorage ImplType;
|
||||
using ImplType = detail::FloatAttributeStorage;
|
||||
using ValueType = APFloat;
|
||||
|
||||
FloatAttr() = default;
|
||||
/* implicit */ FloatAttr(Attribute::ImplType *ptr);
|
||||
|
||||
@ -171,7 +178,9 @@ public:
|
||||
|
||||
class StringAttr : public Attribute {
|
||||
public:
|
||||
typedef detail::StringAttributeStorage ImplType;
|
||||
using ImplType = detail::StringAttributeStorage;
|
||||
using ValueType = StringRef;
|
||||
|
||||
StringAttr() = default;
|
||||
/* implicit */ StringAttr(Attribute::ImplType *ptr);
|
||||
|
||||
@ -187,7 +196,9 @@ public:
|
||||
/// type homogenous given that attributes don't, in general, carry types.
|
||||
class ArrayAttr : public Attribute {
|
||||
public:
|
||||
typedef detail::ArrayAttributeStorage ImplType;
|
||||
using ImplType = detail::ArrayAttributeStorage;
|
||||
using ValueType = ArrayRef<Attribute>;
|
||||
|
||||
ArrayAttr() = default;
|
||||
/* implicit */ ArrayAttr(Attribute::ImplType *ptr);
|
||||
|
||||
@ -201,7 +212,9 @@ public:
|
||||
|
||||
class AffineMapAttr : public Attribute {
|
||||
public:
|
||||
typedef detail::AffineMapAttributeStorage ImplType;
|
||||
using ImplType = detail::AffineMapAttributeStorage;
|
||||
using ValueType = AffineMap;
|
||||
|
||||
AffineMapAttr() = default;
|
||||
/* implicit */ AffineMapAttr(Attribute::ImplType *ptr);
|
||||
|
||||
@ -215,7 +228,9 @@ public:
|
||||
|
||||
class IntegerSetAttr : public Attribute {
|
||||
public:
|
||||
typedef detail::IntegerSetAttributeStorage ImplType;
|
||||
using ImplType = detail::IntegerSetAttributeStorage;
|
||||
using ValueType = IntegerSet;
|
||||
|
||||
IntegerSetAttr() = default;
|
||||
/* implicit */ IntegerSetAttr(Attribute::ImplType *ptr);
|
||||
|
||||
@ -229,7 +244,9 @@ public:
|
||||
|
||||
class TypeAttr : public Attribute {
|
||||
public:
|
||||
typedef detail::TypeAttributeStorage ImplType;
|
||||
using ImplType = detail::TypeAttributeStorage;
|
||||
using ValueType = Type *;
|
||||
|
||||
TypeAttr() = default;
|
||||
/* implicit */ TypeAttr(Attribute::ImplType *ptr);
|
||||
|
||||
@ -250,7 +267,9 @@ public:
|
||||
/// remain in MLIRContext.
|
||||
class FunctionAttr : public Attribute {
|
||||
public:
|
||||
typedef detail::FunctionAttributeStorage ImplType;
|
||||
using ImplType = detail::FunctionAttributeStorage;
|
||||
using ValueType = Function *;
|
||||
|
||||
FunctionAttr() = default;
|
||||
/* implicit */ FunctionAttr(Attribute::ImplType *ptr);
|
||||
|
||||
@ -288,7 +307,9 @@ public:
|
||||
/// meaning all of the elements have the same value.
|
||||
class SplatElementsAttr : public ElementsAttr {
|
||||
public:
|
||||
typedef detail::SplatElementsAttributeStorage ImplType;
|
||||
using ImplType = detail::SplatElementsAttributeStorage;
|
||||
using ValueType = Attribute;
|
||||
|
||||
SplatElementsAttr() = default;
|
||||
/* implicit */ SplatElementsAttr(Attribute::ImplType *ptr);
|
||||
|
||||
@ -305,7 +326,8 @@ public:
|
||||
/// than 64.
|
||||
class DenseElementsAttr : public ElementsAttr {
|
||||
public:
|
||||
typedef detail::DenseElementsAttributeStorage ImplType;
|
||||
using ImplType = detail::DenseElementsAttributeStorage;
|
||||
|
||||
DenseElementsAttr() = default;
|
||||
/* implicit */ DenseElementsAttr(Attribute::ImplType *ptr);
|
||||
|
||||
@ -335,7 +357,8 @@ public:
|
||||
/// object.
|
||||
class DenseIntElementsAttr : public DenseElementsAttr {
|
||||
public:
|
||||
typedef detail::DenseIntElementsAttributeStorage ImplType;
|
||||
using ImplType = detail::DenseIntElementsAttributeStorage;
|
||||
|
||||
DenseIntElementsAttr() = default;
|
||||
/* implicit */ DenseIntElementsAttr(Attribute::ImplType *ptr);
|
||||
|
||||
@ -362,7 +385,8 @@ public:
|
||||
/// object. Each element is stored as a double.
|
||||
class DenseFPElementsAttr : public DenseElementsAttr {
|
||||
public:
|
||||
typedef detail::DenseFPElementsAttributeStorage ImplType;
|
||||
using ImplType = detail::DenseFPElementsAttributeStorage;
|
||||
|
||||
DenseFPElementsAttr() = default;
|
||||
/* implicit */ DenseFPElementsAttr(Attribute::ImplType *ptr);
|
||||
|
||||
@ -380,7 +404,9 @@ public:
|
||||
/// doesn't need to interpret.
|
||||
class OpaqueElementsAttr : public ElementsAttr {
|
||||
public:
|
||||
typedef detail::OpaqueElementsAttributeStorage ImplType;
|
||||
using ImplType = detail::OpaqueElementsAttributeStorage;
|
||||
using ValueType = StringRef;
|
||||
|
||||
OpaqueElementsAttr() = default;
|
||||
/* implicit */ OpaqueElementsAttr(Attribute::ImplType *ptr);
|
||||
|
||||
@ -409,7 +435,8 @@ public:
|
||||
/// [0, 0, 0, 0]].
|
||||
class SparseElementsAttr : public ElementsAttr {
|
||||
public:
|
||||
typedef detail::SparseElementsAttributeStorage ImplType;
|
||||
using ImplType = detail::SparseElementsAttributeStorage;
|
||||
|
||||
SparseElementsAttr() = default;
|
||||
/* implicit */ SparseElementsAttr(Attribute::ImplType *ptr);
|
||||
|
||||
|
142
mlir/include/mlir/IR/Matchers.h
Normal file
142
mlir/include/mlir/IR/Matchers.h
Normal file
@ -0,0 +1,142 @@
|
||||
//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file provides a simple and efficient mechanism for performing general
|
||||
// tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's
|
||||
// include/llvm/IR/PatternMatch.h.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_MATCHERS_H
|
||||
#define MLIR_MATCHERS_H
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/SSAValue.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include <type_traits>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// The matcher that matches a certain kind of Attribute and binds the value
|
||||
/// inside the Attribute.
|
||||
template <
|
||||
typename AttrClass,
|
||||
// Require AttrClass to be a derived class from Atribute and get its
|
||||
// value type
|
||||
typename ValueType =
|
||||
typename std::enable_if<std::is_base_of<Attribute, AttrClass>::value,
|
||||
AttrClass>::type::ValueType,
|
||||
// Require the ValueType is not void
|
||||
typename = typename std::enable_if<!std::is_void<ValueType>::value>::type>
|
||||
struct attr_value_binder {
|
||||
ValueType *bind_value;
|
||||
|
||||
/// Creates a matcher instance that binds the value to bv if match succeeds.
|
||||
attr_value_binder(ValueType *bv) : bind_value(bv) {}
|
||||
|
||||
bool match(const Attribute &attr) {
|
||||
if (auto intAttr = attr.dyn_cast<AttrClass>()) {
|
||||
*bind_value = intAttr.getValue();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
/// The matcher that matches a constant scalar / vector splat / tensor splat
|
||||
/// integer operation and binds the constant integer value.
|
||||
struct constant_int_op_binder {
|
||||
IntegerAttr::ValueType *bind_value;
|
||||
|
||||
/// Creates a matcher instance that binds the value to bv if match succeeds.
|
||||
constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {}
|
||||
|
||||
bool match(Operation *op) {
|
||||
if (auto constOp = op->dyn_cast<ConstantOp>()) {
|
||||
auto *type = constOp->getResult()->getType();
|
||||
auto attr = constOp->getAttr("value");
|
||||
|
||||
if (isa<IntegerType>(type)) {
|
||||
return attr_value_binder<IntegerAttr>(bind_value).match(attr);
|
||||
}
|
||||
if (isa<VectorOrTensorType>(type)) {
|
||||
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
|
||||
return attr_value_binder<IntegerAttr>(bind_value)
|
||||
.match(splatAttr.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
// The matcher that matches a given target constant scalar / vector splat /
|
||||
// tensor splat integer value.
|
||||
template <IntegerAttr::ValueType TargetValue>
|
||||
struct constant_int_value_matcher {
|
||||
bool match(Operation *op) {
|
||||
IntegerAttr::ValueType value;
|
||||
|
||||
return constant_int_op_binder(&value).match(op) && TargetValue == value;
|
||||
}
|
||||
};
|
||||
|
||||
/// The matcher that matches a certain kind of op.
|
||||
template <typename OpClass> struct op_matcher {
|
||||
bool match(Operation *op) { return op->isa<OpClass>(); }
|
||||
};
|
||||
|
||||
} // end namespace detail
|
||||
|
||||
/// Entry point for matching a pattern over an SSAValue.
|
||||
template <typename Pattern>
|
||||
inline bool match(SSAValue *value, const Pattern &pattern) {
|
||||
// TODO: handle other cases
|
||||
if (auto *op = value->getDefiningOperation())
|
||||
return const_cast<Pattern &>(pattern).match(op);
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Matches a ConstantIndexOp.
|
||||
inline detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
|
||||
return detail::op_matcher<ConstantIndexOp>();
|
||||
}
|
||||
|
||||
/// Matches a ConstantOp holding a scalar/vector/tensor integer (splat) and
|
||||
/// writes the integer value to bind_value.
|
||||
inline detail::constant_int_op_binder
|
||||
m_ConstantInt(IntegerAttr::ValueType *bind_value) {
|
||||
return detail::constant_int_op_binder(bind_value);
|
||||
}
|
||||
|
||||
/// Matches a constant scalar / vector splat / tensor splat integer one.
|
||||
inline detail::constant_int_value_matcher<1> m_One() {
|
||||
return detail::constant_int_value_matcher<1>();
|
||||
}
|
||||
|
||||
/// Matches a constant scalar / vector splat / tensor splat integer zero.
|
||||
inline detail::constant_int_value_matcher<0> m_Zero() {
|
||||
return detail::constant_int_value_matcher<0>();
|
||||
}
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_MATCHERS_H
|
@ -279,7 +279,7 @@ private:
|
||||
|
||||
/// The group of patterns that are matched for optimization through this
|
||||
/// matcher.
|
||||
std::vector<std::unique_ptr<Pattern>> patterns;
|
||||
OwningPatternList patterns;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -222,6 +222,12 @@ bool ConstantOp::verify() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (isa<VectorOrTensorType>(type)) {
|
||||
if (!value.isa<ElementsAttr>())
|
||||
return emitOpError("requires 'value' to be a vector/tensor constant");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (type->isTFString()) {
|
||||
if (!value.isa<StringAttr>())
|
||||
return emitOpError("requires 'value' to be a string constant");
|
||||
|
@ -20,6 +20,7 @@
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/SSAValue.h"
|
||||
@ -46,6 +47,11 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// Matches a MemRefCastOp.
|
||||
inline detail::op_matcher<MemRefCastOp> m_MemRefCast() {
|
||||
return detail::op_matcher<MemRefCastOp>();
|
||||
}
|
||||
|
||||
/// This is a common class used for patterns of the form
|
||||
/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
|
||||
/// into the root operation directly.
|
||||
@ -57,9 +63,8 @@ struct MemRefCastFolder : public Pattern {
|
||||
std::pair<PatternBenefit, std::unique_ptr<PatternState>>
|
||||
match(Operation *op) const override {
|
||||
for (auto *operand : op->getOperands())
|
||||
if (auto *memref = operand->getDefiningOperation())
|
||||
if (memref->isa<MemRefCastOp>())
|
||||
return matchSuccess();
|
||||
if (::match(operand, m_MemRefCast()))
|
||||
return matchSuccess();
|
||||
|
||||
return matchFailure();
|
||||
}
|
||||
@ -116,12 +121,9 @@ struct SimplifyAddX0 : public Pattern {
|
||||
std::pair<PatternBenefit, std::unique_ptr<PatternState>>
|
||||
match(Operation *op) const override {
|
||||
auto addi = op->cast<AddIOp>();
|
||||
if (auto *operandOp = addi->getOperand(1)->getDefiningOperation())
|
||||
// TODO: Support splatted zero as well. We need a general zero pattern.
|
||||
if (auto cst = operandOp->dyn_cast<ConstantIntOp>()) {
|
||||
if (cst->getValue() == 0)
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
if (::match(addi->getOperand(1), m_Zero()))
|
||||
return matchSuccess();
|
||||
|
||||
return matchFailure();
|
||||
}
|
||||
@ -230,9 +232,8 @@ struct SimplifyAllocConst : public Pattern {
|
||||
// Check to see if any dimensions operands are constants. If so, we can
|
||||
// substitute and drop them.
|
||||
for (auto *operand : alloc->getOperands())
|
||||
if (auto *opOperation = operand->getDefiningOperation())
|
||||
if (opOperation->isa<ConstantIndexOp>())
|
||||
return matchSuccess();
|
||||
if (::match(operand, m_ConstantIndex()))
|
||||
return matchSuccess();
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
@ -892,12 +893,9 @@ struct SimplifyMulX1 : public Pattern {
|
||||
std::pair<PatternBenefit, std::unique_ptr<PatternState>>
|
||||
match(Operation *op) const override {
|
||||
auto muli = op->cast<MulIOp>();
|
||||
if (auto *operandOp = muli->getOperand(1)->getDefiningOperation())
|
||||
// TODO: Support splatted one as well. We need a general one pattern.
|
||||
if (auto cst = operandOp->dyn_cast<ConstantIntOp>()) {
|
||||
if (cst->getValue() == 1)
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
if (::match(muli->getOperand(1), m_One()))
|
||||
return matchSuccess();
|
||||
|
||||
return matchFailure();
|
||||
}
|
||||
|
@ -85,6 +85,9 @@ bb42(%t: tensor<4x4x?xf32>, %f: f32, %i: i32):
|
||||
// CHECK: %f_2 = constant @affine_apply : () -> ()
|
||||
%12 = constant @affine_apply : () -> ()
|
||||
|
||||
// CHECK: %cst_3 = constant splat<vector<4xi32>, 0> : vector<4xi32>
|
||||
%13 = constant splat<vector<4 x i32>, 0> : vector<4 x i32>
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -58,6 +58,22 @@ mlfunc @addi_zero(%arg0: i32) -> i32 {
|
||||
return %y: i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @addi_zero_vector
|
||||
mlfunc @addi_zero_vector(%arg0: vector<4 x i32>) -> vector<4 x i32> {
|
||||
// CHECK-NEXT: return %arg0
|
||||
%c0_v4i32 = constant splat<vector<4 x i32>, 0> : vector<4 x i32>
|
||||
%y = addi %c0_v4i32, %arg0 : vector<4 x i32>
|
||||
return %y: vector<4 x i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @addi_zero_tensor
|
||||
mlfunc @addi_zero_tensor(%arg0: tensor<4 x 5 x i32>) -> tensor<4 x 5 x i32> {
|
||||
// CHECK-NEXT: return %arg0
|
||||
%c0_t45i32 = constant splat<tensor<4 x 5 x i32>, 0> : tensor<4 x 5 x i32>
|
||||
%y = addi %arg0, %c0_t45i32 : tensor<4 x 5 x i32>
|
||||
return %y: tensor<4 x 5 x i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @muli_one
|
||||
mlfunc @muli_one(%arg0: i32) -> i32 {
|
||||
// CHECK-NEXT: return %arg0
|
||||
@ -66,6 +82,22 @@ mlfunc @muli_one(%arg0: i32) -> i32 {
|
||||
return %y: i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @muli_one_vector
|
||||
mlfunc @muli_one_vector(%arg0: vector<4 x i32>) -> vector<4 x i32> {
|
||||
// CHECK-NEXT: return %arg0
|
||||
%c1_v4i32 = constant splat<vector<4 x i32>, 1> : vector<4 x i32>
|
||||
%y = muli %c1_v4i32, %arg0 : vector<4 x i32>
|
||||
return %y: vector<4 x i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @muli_one_tensor
|
||||
mlfunc @muli_one_tensor(%arg0: tensor<4 x 5 x i32>) -> tensor<4 x 5 x i32> {
|
||||
// CHECK-NEXT: return %arg0
|
||||
%c1_t45i32 = constant splat<tensor<4 x 5 x i32>, 1> : tensor<4 x 5 x i32>
|
||||
%y = muli %arg0, %c1_t45i32 : tensor<4 x 5 x i32>
|
||||
return %y: tensor<4 x 5 x i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @memref_cast_folding
|
||||
mlfunc @memref_cast_folding(%arg0: memref<4 x f32>, %arg1: f32) -> f32 {
|
||||
%1 = memref_cast %arg0: memref<4xf32> to memref<?xf32>
|
||||
|
Loading…
x
Reference in New Issue
Block a user