llvm-project/mlir/lib/IR/Builders.cpp

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

600 lines
20 KiB
C++
Raw Normal View History

//===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Builders.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Locations.
//===----------------------------------------------------------------------===//
/// Returns the result of `UnknownLoc::get` for the builder's `context`.
Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
/// Forwards `locs` and `metadata`, together with the builder's `context`, to
/// `FusedLoc::get` and returns the resulting location.
Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
return FusedLoc::get(locs, metadata, context);
}
//===----------------------------------------------------------------------===//
// Types.
//===----------------------------------------------------------------------===//
// Floating-point type getters: each one forwards the builder's `context` to
// the `get` method of the corresponding builtin float type class.
/// Returns `BFloat16Type::get(context)`.
FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }
/// Returns `Float16Type::get(context)`.
FloatType Builder::getF16Type() { return Float16Type::get(context); }
/// Returns `FloatTF32Type::get(context)`.
FloatType Builder::getTF32Type() { return FloatTF32Type::get(context); }
/// Returns `Float32Type::get(context)`.
FloatType Builder::getF32Type() { return Float32Type::get(context); }
/// Returns `Float64Type::get(context)`.
FloatType Builder::getF64Type() { return Float64Type::get(context); }
/// Returns `Float80Type::get(context)`.
FloatType Builder::getF80Type() { return Float80Type::get(context); }
/// Returns `Float128Type::get(context)`.
FloatType Builder::getF128Type() { return Float128Type::get(context); }
/// Returns the index type for the builder's `context`.
IndexType Builder::getIndexType() { return IndexType::get(context); }
// Fixed-width integer getters: each one calls the two-argument
// `IntegerType::get(context, width)` with the width named in the method.
/// Returns the 1-bit integer type.
IntegerType Builder::getI1Type() { return IntegerType::get(context, 1); }
/// Returns the 2-bit integer type.
IntegerType Builder::getI2Type() { return IntegerType::get(context, 2); }
/// Returns the 4-bit integer type.
IntegerType Builder::getI4Type() { return IntegerType::get(context, 4); }
/// Returns the 8-bit integer type.
IntegerType Builder::getI8Type() { return IntegerType::get(context, 8); }
/// Returns the 16-bit integer type.
IntegerType Builder::getI16Type() { return IntegerType::get(context, 16); }
/// Returns the 32-bit integer type.
IntegerType Builder::getI32Type() { return IntegerType::get(context, 32); }
/// Returns the 64-bit integer type.
IntegerType Builder::getI64Type() { return IntegerType::get(context, 64); }
/// Returns the integer type of `width` bits; no signedness argument is passed
/// to `IntegerType::get`, so its default applies.
IntegerType Builder::getIntegerType(unsigned width) {
return IntegerType::get(context, width);
}
[mlir] Add a signedness semantics bit to IntegerType Thus far IntegerType has been signless: a value of IntegerType does not have a sign intrinsically and it's up to the specific operation to decide how to interpret those bits. For example, std.addi does two's complement arithmetic, and std.divis/std.diviu treats the first bit as a sign. This design choice was made some time ago when we did't have lots of dialects and dialects were more rigid. Today we have much more extensible infrastructure and different dialect may want different modelling over integer signedness. So while we can say we want signless integers in the standard dialect, we cannot dictate for others. Requiring each dialect to model the signedness semantics with another set of custom types is duplicating the functionality everywhere, considering the fundamental role integer types play. This CL extends the IntegerType with a signedness semantics bit. This gives each dialect an option to opt in signedness semantics if that's what they want and helps code sharing. The parser is modified to recognize `si[1-9][0-9]*` and `ui[1-9][0-9]*` as signed and unsigned integer types, respectively, leaving the original `i[1-9][0-9]*` to continue to mean no indication over signedness semantics. All existing dialects are not affected (yet) as this is a feature to opt in. More discussions can be found at: https://groups.google.com/a/tensorflow.org/d/msg/mlir/XmkV8HOPWpo/7O4X0Nb_AQAJ Differential Revision: https://reviews.llvm.org/D72533
2020-01-10 14:48:24 -05:00
IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
return IntegerType::get(
context, width, isSigned ? IntegerType::Signed : IntegerType::Unsigned);
[mlir] Add a signedness semantics bit to IntegerType Thus far IntegerType has been signless: a value of IntegerType does not have a sign intrinsically and it's up to the specific operation to decide how to interpret those bits. For example, std.addi does two's complement arithmetic, and std.divis/std.diviu treats the first bit as a sign. This design choice was made some time ago when we did't have lots of dialects and dialects were more rigid. Today we have much more extensible infrastructure and different dialect may want different modelling over integer signedness. So while we can say we want signless integers in the standard dialect, we cannot dictate for others. Requiring each dialect to model the signedness semantics with another set of custom types is duplicating the functionality everywhere, considering the fundamental role integer types play. This CL extends the IntegerType with a signedness semantics bit. This gives each dialect an option to opt in signedness semantics if that's what they want and helps code sharing. The parser is modified to recognize `si[1-9][0-9]*` and `ui[1-9][0-9]*` as signed and unsigned integer types, respectively, leaving the original `i[1-9][0-9]*` to continue to mean no indication over signedness semantics. All existing dialects are not affected (yet) as this is a feature to opt in. More discussions can be found at: https://groups.google.com/a/tensorflow.org/d/msg/mlir/XmkV8HOPWpo/7O4X0Nb_AQAJ Differential Revision: https://reviews.llvm.org/D72533
2020-01-10 14:48:24 -05:00
}
/// Returns the function type built by `FunctionType::get` from the `inputs`
/// and `results` type ranges in the builder's `context`.
FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
return FunctionType::get(context, inputs, results);
}
/// Returns the tuple type built by `TupleType::get` from `elementTypes` in the
/// builder's `context`.
TupleType Builder::getTupleType(TypeRange elementTypes) {
return TupleType::get(context, elementTypes);
}
/// Returns `NoneType::get(context)`.
NoneType Builder::getNoneType() { return NoneType::get(context); }
//===----------------------------------------------------------------------===//
// Attributes.
//===----------------------------------------------------------------------===//
/// Returns a `NamedAttribute` constructed from `name` and `val`.
NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
return NamedAttribute(name, val);
}
/// Returns `UnitAttr::get(context)`.
UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
/// Returns the boolean attribute for `value` in the builder's `context`.
BoolAttr Builder::getBoolAttr(bool value) {
return BoolAttr::get(context, value);
}
/// Returns the dictionary attribute built by `DictionaryAttr::get` from the
/// named attributes in `value`.
DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
return DictionaryAttr::get(context, value);
}
/// Returns an integer attribute of index type (`getIndexType()`) whose value
/// is `value` stored in a 64-bit `APInt`.
IntegerAttr Builder::getIndexAttr(int64_t value) {
return IntegerAttr::get(getIndexType(), APInt(64, value));
}
/// Returns an integer attribute of type `getIntegerType(64)` whose value is
/// `value` stored in a 64-bit `APInt`.
IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
return IntegerAttr::get(getIntegerType(64), APInt(64, value));
}
/// Returns a dense integer elements attribute holding `values`, typed as a
/// one-dimensional vector of `getI1Type()` with `values.size()` elements.
DenseIntElementsAttr Builder::getBoolVectorAttr(ArrayRef<bool> values) {
  auto vectorType =
      VectorType::get(static_cast<int64_t>(values.size()), getI1Type());
  return DenseIntElementsAttr::get(vectorType, values);
}

/// Returns a dense integer elements attribute holding `values`, typed as a
/// one-dimensional vector of `getIntegerType(32)` with `values.size()`
/// elements.
DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
  auto vectorType =
      VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32));
  return DenseIntElementsAttr::get(vectorType, values);
}

/// Returns a dense integer elements attribute holding `values`, typed as a
/// one-dimensional vector of `getIntegerType(64)` with `values.size()`
/// elements.
DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
  auto vectorType =
      VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(64));
  return DenseIntElementsAttr::get(vectorType, values);
}

/// Returns a dense integer elements attribute holding `values`, typed as a
/// one-dimensional vector of `getIndexType()` with `values.size()` elements.
DenseIntElementsAttr Builder::getIndexVectorAttr(ArrayRef<int64_t> values) {
  auto vectorType =
      VectorType::get(static_cast<int64_t>(values.size()), getIndexType());
  return DenseIntElementsAttr::get(vectorType, values);
}
/// Returns a dense floating-point elements attribute holding `values`, typed
/// as a one-dimensional vector of `getF32Type()` with `values.size()`
/// elements.
DenseFPElementsAttr Builder::getF32VectorAttr(ArrayRef<float> values) {
  // The vector extent is an integer dimension. Convert the element count to
  // int64_t, as the integer vector builders in this file do; routing it
  // through `float` loses precision for large counts.
  return DenseFPElementsAttr::get(
      VectorType::get(static_cast<int64_t>(values.size()), getF32Type()),
      values);
}
/// Returns a dense floating-point elements attribute holding `values`, typed
/// as a one-dimensional vector of `getF64Type()` with `values.size()`
/// elements.
DenseFPElementsAttr Builder::getF64VectorAttr(ArrayRef<double> values) {
  // The vector extent is an integer dimension. Convert the element count to
  // int64_t, as the integer vector builders in this file do, rather than to a
  // floating-point type that must then be converted back to an integer shape.
  return DenseFPElementsAttr::get(
      VectorType::get(static_cast<int64_t>(values.size()), getF64Type()),
      values);
}
// Dense array attribute getters: each one forwards the builder's `context`
// and `values` unchanged to the `get` method of the matching dense array
// attribute class.
/// Returns `DenseBoolArrayAttr::get(context, values)`.
DenseBoolArrayAttr Builder::getDenseBoolArrayAttr(ArrayRef<bool> values) {
return DenseBoolArrayAttr::get(context, values);
}
/// Returns `DenseI8ArrayAttr::get(context, values)`.
DenseI8ArrayAttr Builder::getDenseI8ArrayAttr(ArrayRef<int8_t> values) {
return DenseI8ArrayAttr::get(context, values);
}
/// Returns `DenseI16ArrayAttr::get(context, values)`.
DenseI16ArrayAttr Builder::getDenseI16ArrayAttr(ArrayRef<int16_t> values) {
return DenseI16ArrayAttr::get(context, values);
}
/// Returns `DenseI32ArrayAttr::get(context, values)`.
DenseI32ArrayAttr Builder::getDenseI32ArrayAttr(ArrayRef<int32_t> values) {
return DenseI32ArrayAttr::get(context, values);
}
/// Returns `DenseI64ArrayAttr::get(context, values)`.
DenseI64ArrayAttr Builder::getDenseI64ArrayAttr(ArrayRef<int64_t> values) {
return DenseI64ArrayAttr::get(context, values);
}
/// Returns `DenseF32ArrayAttr::get(context, values)`.
DenseF32ArrayAttr Builder::getDenseF32ArrayAttr(ArrayRef<float> values) {
return DenseF32ArrayAttr::get(context, values);
}
/// Returns `DenseF64ArrayAttr::get(context, values)`.
DenseF64ArrayAttr Builder::getDenseF64ArrayAttr(ArrayRef<double> values) {
return DenseF64ArrayAttr::get(context, values);
}
/// Returns a dense integer elements attribute holding `values`, typed as a
/// rank-1 ranked tensor of `getIntegerType(32)` with `values.size()` elements.
DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
  auto tensorType = RankedTensorType::get(static_cast<int64_t>(values.size()),
                                          getIntegerType(32));
  return DenseIntElementsAttr::get(tensorType, values);
}

/// Returns a dense integer elements attribute holding `values`, typed as a
/// rank-1 ranked tensor of `getIntegerType(64)` with `values.size()` elements.
DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) {
  auto tensorType = RankedTensorType::get(static_cast<int64_t>(values.size()),
                                          getIntegerType(64));
  return DenseIntElementsAttr::get(tensorType, values);
}

/// Returns a dense integer elements attribute holding `values`, typed as a
/// rank-1 ranked tensor of `getIndexType()` with `values.size()` elements.
DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
  auto tensorType = RankedTensorType::get(static_cast<int64_t>(values.size()),
                                          getIndexType());
  return DenseIntElementsAttr::get(tensorType, values);
}
/// Returns an integer attribute of type `getIntegerType(32)` holding `value`.
IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
  // `value` arrives as an int32_t, so the APInt is always constructed with
  // isSigned=true.
  APInt bits(32, value, /*isSigned=*/true);
  return IntegerAttr::get(getIntegerType(32), bits);
}
[mlir] Add a signedness semantics bit to IntegerType Thus far IntegerType has been signless: a value of IntegerType does not have a sign intrinsically and it's up to the specific operation to decide how to interpret those bits. For example, std.addi does two's complement arithmetic, and std.divis/std.diviu treats the first bit as a sign. This design choice was made some time ago when we did't have lots of dialects and dialects were more rigid. Today we have much more extensible infrastructure and different dialect may want different modelling over integer signedness. So while we can say we want signless integers in the standard dialect, we cannot dictate for others. Requiring each dialect to model the signedness semantics with another set of custom types is duplicating the functionality everywhere, considering the fundamental role integer types play. This CL extends the IntegerType with a signedness semantics bit. This gives each dialect an option to opt in signedness semantics if that's what they want and helps code sharing. The parser is modified to recognize `si[1-9][0-9]*` and `ui[1-9][0-9]*` as signed and unsigned integer types, respectively, leaving the original `i[1-9][0-9]*` to continue to mean no indication over signedness semantics. All existing dialects are not affected (yet) as this is a feature to opt in. More discussions can be found at: https://groups.google.com/a/tensorflow.org/d/msg/mlir/XmkV8HOPWpo/7O4X0Nb_AQAJ Differential Revision: https://reviews.llvm.org/D72533
2020-01-10 14:48:24 -05:00
/// Returns an integer attribute of the explicitly signed 32-bit integer type
/// holding `value`, stored as a signed 32-bit APInt.
IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
  IntegerType signedI32 = getIntegerType(32, /*isSigned=*/true);
  APInt bits(32, value, /*isSigned=*/true);
  return IntegerAttr::get(signedI32, bits);
}
IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) {
return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
APInt(32, (uint64_t)value, /*isSigned=*/false));
[mlir] Add a signedness semantics bit to IntegerType Thus far IntegerType has been signless: a value of IntegerType does not have a sign intrinsically and it's up to the specific operation to decide how to interpret those bits. For example, std.addi does two's complement arithmetic, and std.divis/std.diviu treats the first bit as a sign. This design choice was made some time ago when we did't have lots of dialects and dialects were more rigid. Today we have much more extensible infrastructure and different dialect may want different modelling over integer signedness. So while we can say we want signless integers in the standard dialect, we cannot dictate for others. Requiring each dialect to model the signedness semantics with another set of custom types is duplicating the functionality everywhere, considering the fundamental role integer types play. This CL extends the IntegerType with a signedness semantics bit. This gives each dialect an option to opt in signedness semantics if that's what they want and helps code sharing. The parser is modified to recognize `si[1-9][0-9]*` and `ui[1-9][0-9]*` as signed and unsigned integer types, respectively, leaving the original `i[1-9][0-9]*` to continue to mean no indication over signedness semantics. All existing dialects are not affected (yet) as this is a feature to opt in. More discussions can be found at: https://groups.google.com/a/tensorflow.org/d/msg/mlir/XmkV8HOPWpo/7O4X0Nb_AQAJ Differential Revision: https://reviews.llvm.org/D72533
2020-01-10 14:48:24 -05:00
}
/// Returns an integer attribute of type `getIntegerType(16)` holding `value`.
IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
  // The APInt always uses isSigned=true here because we accept the value
  // as int16_t: a negative input is sign-extended when widened to the 64-bit
  // constructor argument and is only a valid 16-bit value under the signed
  // interpretation. This matches the i8 and i32 variants in this file.
  return IntegerAttr::get(getIntegerType(16),
                          APInt(16, value, /*isSigned=*/true));
}
/// Returns an integer attribute of type `getIntegerType(8)` holding `value`.
IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
  // `value` arrives as an int8_t, so the APInt is always constructed with
  // isSigned=true.
  APInt bits(8, value, /*isSigned=*/true);
  return IntegerAttr::get(getIntegerType(8), bits);
}
/// Returns an integer attribute of the given `type` holding `value`.
///
/// Index types store the value in a 64-bit APInt. Any other type uses an
/// APInt of `type.getIntOrFloatBitWidth()` bits, treated as signed when
/// `type.isSignedInteger()` is true, with implicit truncation allowed.
IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
  if (type.isIndex())
    return IntegerAttr::get(type, APInt(64, value));

  // TODO: Avoid implicit trunc?
  // See https://github.com/llvm/llvm-project/issues/112510.
  const unsigned bitWidth = type.getIntOrFloatBitWidth();
  const bool isSigned = type.isSignedInteger();
  APInt bits(bitWidth, value, isSigned, /*implicitTrunc=*/true);
  return IntegerAttr::get(type, bits);
}
/// Returns an integer attribute of the given `type` holding the
/// arbitrary-precision `value`; both are forwarded unchanged to
/// `IntegerAttr::get`.
IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
return IntegerAttr::get(type, value);
}
/// Returns a float attribute of type `getF64Type()` holding `value` wrapped in
/// an `APFloat`.
FloatAttr Builder::getF64FloatAttr(double value) {
return FloatAttr::get(getF64Type(), APFloat(value));
}
/// Returns a float attribute of type `getF32Type()` holding `value` wrapped in
/// an `APFloat`.
FloatAttr Builder::getF32FloatAttr(float value) {
return FloatAttr::get(getF32Type(), APFloat(value));
}
/// Returns a float attribute of type `getF16Type()` for `value`. The value is
/// passed to `FloatAttr::get` as given (a `float`); any conversion to the f16
/// format is presumably done by `FloatAttr::get` -- confirm against its
/// documentation.
FloatAttr Builder::getF16FloatAttr(float value) {
return FloatAttr::get(getF16Type(), value);
}
/// Returns a float attribute of the given `type` for the `double` `value`;
/// both are forwarded unchanged to `FloatAttr::get`.
FloatAttr Builder::getFloatAttr(Type type, double value) {
return FloatAttr::get(type, value);
}
/// Returns a float attribute of the given `type` for the `APFloat` `value`;
/// both are forwarded unchanged to `FloatAttr::get`.
FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
return FloatAttr::get(type, value);
}
/// Returns the string attribute built by `StringAttr::get` from `bytes` in the
/// builder's `context`.
StringAttr Builder::getStringAttr(const Twine &bytes) {
return StringAttr::get(context, bytes);
}
/// Returns the array attribute built by `ArrayAttr::get` from the attributes
/// in `value` in the builder's `context`.
ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
return ArrayAttr::get(context, value);
}
/// Returns an array attribute whose elements are the `getBoolAttr` results for
/// each entry of `values`, in order.
ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
  SmallVector<Attribute, 8> elements;
  elements.reserve(values.size());
  for (bool flag : values)
    elements.push_back(getBoolAttr(flag));
  return getArrayAttr(elements);
}
/// Returns an array attribute whose elements are the `getI32IntegerAttr`
/// results for each entry of `values`, in order.
ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
  SmallVector<Attribute, 8> elements;
  elements.reserve(values.size());
  for (int32_t entry : values)
    elements.push_back(getI32IntegerAttr(entry));
  return getArrayAttr(elements);
}
/// Returns an array attribute whose elements are the `getI64IntegerAttr`
/// results for each entry of `values`, in order.
ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
  SmallVector<Attribute, 8> elements;
  elements.reserve(values.size());
  for (int64_t entry : values)
    elements.push_back(getI64IntegerAttr(entry));
  return getArrayAttr(elements);
}
/// Returns an array attribute whose elements are index-typed integer
/// attributes, one per entry of `values`, in order.
ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
  SmallVector<Attribute, 8> elements;
  elements.reserve(values.size());
  for (int64_t entry : values)
    elements.push_back(getIntegerAttr(IndexType::get(getContext()), entry));
  return getArrayAttr(elements);
}
/// Returns an array attribute whose elements are the `getF32FloatAttr` results
/// for each entry of `values`, in order.
ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
  SmallVector<Attribute, 8> elements;
  elements.reserve(values.size());
  for (float entry : values)
    elements.push_back(getF32FloatAttr(entry));
  return getArrayAttr(elements);
}
/// Returns an array attribute whose elements are the `getF64FloatAttr` results
/// for each entry of `values`, in order.
ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
  SmallVector<Attribute, 8> elements;
  elements.reserve(values.size());
  for (double entry : values)
    elements.push_back(getF64FloatAttr(entry));
  return getArrayAttr(elements);
}
/// Returns an array attribute whose elements are the `getStringAttr` results
/// for each entry of `values`, in order.
ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
  SmallVector<Attribute, 8> elements;
  elements.reserve(values.size());
  for (StringRef text : values)
    elements.push_back(getStringAttr(text));
  return getArrayAttr(elements);
}
[mlir][PDL] Add a PDL Interpreter Dialect The PDL Interpreter dialect provides a lower level abstraction compared to the PDL dialect, and is targeted towards low level optimization and interpreter code generation. The dialect operations encapsulates low-level pattern match and rewrite "primitives", such as navigating the IR (Operation::getOperand), creating new operations (OpBuilder::create), etc. Many of the operations within this dialect also fuse branching control flow with some form of a predicate comparison operation. This type of fusion reduces the amount of work that an interpreter must do when executing. An example of this representation is shown below: ```mlir // The following high level PDL pattern: pdl.pattern : benefit(1) { %resultType = pdl.type %inputOperand = pdl.input %root, %results = pdl.operation "foo.op"(%inputOperand) -> %resultType pdl.rewrite %root { pdl.replace %root with (%inputOperand) } } // May be represented in the interpreter dialect as follows: module { func @matcher(%arg0: !pdl.operation) { pdl_interp.check_operation_name of %arg0 is "foo.op" -> ^bb2, ^bb1 ^bb1: pdl_interp.return ^bb2: pdl_interp.check_operand_count of %arg0 is 1 -> ^bb3, ^bb1 ^bb3: pdl_interp.check_result_count of %arg0 is 1 -> ^bb4, ^bb1 ^bb4: %0 = pdl_interp.get_operand 0 of %arg0 pdl_interp.is_not_null %0 : !pdl.value -> ^bb5, ^bb1 ^bb5: %1 = pdl_interp.get_result 0 of %arg0 pdl_interp.is_not_null %1 : !pdl.value -> ^bb6, ^bb1 ^bb6: pdl_interp.record_match @rewriters::@rewriter(%0, %arg0 : !pdl.value, !pdl.operation) : benefit(1), loc([%arg0]), root("foo.op") -> ^bb1 } module @rewriters { func @rewriter(%arg0: !pdl.value, %arg1: !pdl.operation) { pdl_interp.replace %arg1 with(%arg0) pdl_interp.return } } } ``` Differential Revision: https://reviews.llvm.org/D84579
2020-08-26 05:12:07 -07:00
/// Returns an array attribute whose elements are `TypeAttr`s wrapping each
/// type in `values`, in order.
ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
  SmallVector<Attribute, 8> elements;
  elements.reserve(values.size());
  for (Type elementType : values)
    elements.push_back(TypeAttr::get(elementType));
  return getArrayAttr(elements);
}
Add a generic Linalg op This CL introduces a linalg.generic op to represent generic tensor contraction operations on views. A linalg.generic operation requires a numbers of attributes that are sufficient to emit the computation in scalar form as well as compute the appropriate subviews to enable tiling and fusion. These attributes are very similar to the attributes for existing operations such as linalg.matmul etc and existing operations can be implemented with the generic form. In the future, most existing operations can be implemented using the generic form. This CL starts by splitting out most of the functionality of the linalg::NInputsAndOutputs trait into a ViewTrait that queries the per-instance properties of the op. This allows using the attribute informations. This exposes an ordering of verifiers issue where ViewTrait::verify uses attributes but the verifiers for those attributes have not been run. The desired behavior would be for the verifiers of the attributes specified in the builder to execute first but it is not the case atm. As a consequence, to emit proper error messages and avoid crashing, some of the linalg.generic methods are defensive as such: ``` unsigned getNumInputs() { // This is redundant with the `n_views` attribute verifier but ordering of verifiers // may exhibit cases where we crash instead of emitting an error message. if (!getAttr("n_views") || n_views().getValue().size() != 2) return 0; ``` In pretty-printed form, the specific attributes required for linalg.generic are factored out in an independent dictionary named "_". When parsing its content is flattened and the "_name" is dropped. This allows using aliasing for reducing boilerplate at each linalg.generic invocation while benefiting from the Tablegen'd verifier form for each named attribute in the dictionary. 
For instance, implementing linalg.matmul in terms of linalg.generic resembles: ``` func @mac(%a: f32, %b: f32, %c: f32) -> f32 { %d = mulf %a, %b: f32 %e = addf %c, %d: f32 return %e: f32 } #matmul_accesses = [ (m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n) ] #matmul_trait = { doc = "C(m, n) += A(m, k) * B(k, n)", fun = @mac, indexing_maps = #matmul_accesses, library_call = "linalg_matmul", n_views = [2, 1], n_loop_types = [2, 1, 0] } ``` And can be used in multiple places as: ``` linalg.generic #matmul_trait %A, %B, %C [other-attributes] : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32> ``` In the future it would be great to have a mechanism to alias / register a new linalg.op as a pair of linalg.generic, #trait. Also, note that with one could theoretically only specify the `doc` string and parse all the attributes from it. PiperOrigin-RevId: 261338740
2019-08-02 09:53:08 -07:00
ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
auto attrs = llvm::map_to_vector<8>(
values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); });
Add a generic Linalg op This CL introduces a linalg.generic op to represent generic tensor contraction operations on views. A linalg.generic operation requires a numbers of attributes that are sufficient to emit the computation in scalar form as well as compute the appropriate subviews to enable tiling and fusion. These attributes are very similar to the attributes for existing operations such as linalg.matmul etc and existing operations can be implemented with the generic form. In the future, most existing operations can be implemented using the generic form. This CL starts by splitting out most of the functionality of the linalg::NInputsAndOutputs trait into a ViewTrait that queries the per-instance properties of the op. This allows using the attribute informations. This exposes an ordering of verifiers issue where ViewTrait::verify uses attributes but the verifiers for those attributes have not been run. The desired behavior would be for the verifiers of the attributes specified in the builder to execute first but it is not the case atm. As a consequence, to emit proper error messages and avoid crashing, some of the linalg.generic methods are defensive as such: ``` unsigned getNumInputs() { // This is redundant with the `n_views` attribute verifier but ordering of verifiers // may exhibit cases where we crash instead of emitting an error message. if (!getAttr("n_views") || n_views().getValue().size() != 2) return 0; ``` In pretty-printed form, the specific attributes required for linalg.generic are factored out in an independent dictionary named "_". When parsing its content is flattened and the "_name" is dropped. This allows using aliasing for reducing boilerplate at each linalg.generic invocation while benefiting from the Tablegen'd verifier form for each named attribute in the dictionary. 
For instance, implementing linalg.matmul in terms of linalg.generic resembles: ``` func @mac(%a: f32, %b: f32, %c: f32) -> f32 { %d = mulf %a, %b: f32 %e = addf %c, %d: f32 return %e: f32 } #matmul_accesses = [ (m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n) ] #matmul_trait = { doc = "C(m, n) += A(m, k) * B(k, n)", fun = @mac, indexing_maps = #matmul_accesses, library_call = "linalg_matmul", n_views = [2, 1], n_loop_types = [2, 1, 0] } ``` And can be used in multiple places as: ``` linalg.generic #matmul_trait %A, %B, %C [other-attributes] : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32> ``` In the future it would be great to have a mechanism to alias / register a new linalg.op as a pair of linalg.generic, #trait. Also, note that with one could theoretically only specify the `doc` string and parse all the attributes from it. PiperOrigin-RevId: 261338740
2019-08-02 09:53:08 -07:00
return getArrayAttr(attrs);
}
TypedAttr Builder::getZeroAttr(Type type) {
[mlir] Update method cast calls to function calls The MLIR classes Type/Attribute/Operation/Op/Value support cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast functionality in addition to defining methods with the same name. This change begins the migration of uses of the method to the corresponding function call as has been decided as more consistent. Note that there still exist classes that only define methods directly, such as AffineExpr, and this does not include work currently to support a functional cast/isa call. Context: * https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…" * Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443 Implementation: This follows a previous patch that updated calls `op.cast<T>()-> cast<T>(op)`. However some cases could not handle an unprefixed `cast` call due to occurrences of variables named cast, or occurring inside of class definitions which would resolve to the method. All C++ files that did not work automatically with `cast<T>()` are updated here to `llvm::cast` and similar with the intention that they can be easily updated after the methods are removed through a find-replace. See https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check for the clang-tidy check that is used and then update printed occurrences of the function to include `llvm::` before. One can then run the following: ``` ninja -C $BUILD_DIR clang-tidy run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\ -export-fixes /tmp/cast/casts.yaml mlir/*\ -header-filter=mlir/ -fix rm -rf $BUILD_DIR/tools/mlir/**/*.inc ``` Differential Revision: https://reviews.llvm.org/D150348
2023-05-11 11:10:46 +02:00
if (llvm::isa<FloatType>(type))
return getFloatAttr(type, 0.0);
[mlir] Update method cast calls to function calls The MLIR classes Type/Attribute/Operation/Op/Value support cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast functionality in addition to defining methods with the same name. This change begins the migration of uses of the method to the corresponding function call as has been decided as more consistent. Note that there still exist classes that only define methods directly, such as AffineExpr, and this does not include work currently to support a functional cast/isa call. Context: * https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…" * Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443 Implementation: This follows a previous patch that updated calls `op.cast<T>()-> cast<T>(op)`. However some cases could not handle an unprefixed `cast` call due to occurrences of variables named cast, or occurring inside of class definitions which would resolve to the method. All C++ files that did not work automatically with `cast<T>()` are updated here to `llvm::cast` and similar with the intention that they can be easily updated after the methods are removed through a find-replace. See https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check for the clang-tidy check that is used and then update printed occurrences of the function to include `llvm::` before. One can then run the following: ``` ninja -C $BUILD_DIR clang-tidy run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\ -export-fixes /tmp/cast/casts.yaml mlir/*\ -header-filter=mlir/ -fix rm -rf $BUILD_DIR/tools/mlir/**/*.inc ``` Differential Revision: https://reviews.llvm.org/D150348
2023-05-11 11:10:46 +02:00
if (llvm::isa<IndexType>(type))
return getIndexAttr(0);
if (llvm::dyn_cast<IntegerType>(type))
[mlir] Update method cast calls to function calls The MLIR classes Type/Attribute/Operation/Op/Value support cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast functionality in addition to defining methods with the same name. This change begins the migration of uses of the method to the corresponding function call as has been decided as more consistent. Note that there still exist classes that only define methods directly, such as AffineExpr, and this does not include work currently to support a functional cast/isa call. Context: * https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…" * Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443 Implementation: This follows a previous patch that updated calls `op.cast<T>()-> cast<T>(op)`. However some cases could not handle an unprefixed `cast` call due to occurrences of variables named cast, or occurring inside of class definitions which would resolve to the method. All C++ files that did not work automatically with `cast<T>()` are updated here to `llvm::cast` and similar with the intention that they can be easily updated after the methods are removed through a find-replace. See https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check for the clang-tidy check that is used and then update printed occurrences of the function to include `llvm::` before. One can then run the following: ``` ninja -C $BUILD_DIR clang-tidy run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\ -export-fixes /tmp/cast/casts.yaml mlir/*\ -header-filter=mlir/ -fix rm -rf $BUILD_DIR/tools/mlir/**/*.inc ``` Differential Revision: https://reviews.llvm.org/D150348
2023-05-11 11:10:46 +02:00
return getIntegerAttr(type,
APInt(llvm::cast<IntegerType>(type).getWidth(), 0));
if (llvm::isa<RankedTensorType, VectorType>(type)) {
auto vtType = llvm::cast<ShapedType>(type);
auto element = getZeroAttr(vtType.getElementType());
if (!element)
return {};
return DenseElementsAttr::get(vtType, element);
}
return {};
}
/// Returns an attribute representing one for `type`, or a null attribute if
/// `type` is not handled.
///
/// Handled types: float types (1.0), index (1), integer types (an APInt of the
/// type's width with value 1), and ranked tensors / vectors, whose result is a
/// `DenseElementsAttr` built from the element type's one attribute.
TypedAttr Builder::getOneAttr(Type type) {
  if (llvm::isa<FloatType>(type))
    return getFloatAttr(type, 1.0);
  if (llvm::isa<IndexType>(type))
    return getIndexAttr(1);
  if (auto integerType = llvm::dyn_cast<IntegerType>(type))
    return getIntegerAttr(type, APInt(integerType.getWidth(), 1));

  // Everything else must be a ranked tensor or a vector to be supported.
  if (!llvm::isa<RankedTensorType, VectorType>(type))
    return {};

  auto shapedType = llvm::cast<ShapedType>(type);
  auto elementOne = getOneAttr(shapedType.getElementType());
  // An unsupported element type makes the whole shaped type unsupported.
  if (!elementOne)
    return {};
  return DenseElementsAttr::get(shapedType, elementOne);
}
//===----------------------------------------------------------------------===//
// Affine Expressions, Affine Maps, and Integer Sets.
//===----------------------------------------------------------------------===//
/// Returns the affine dimension expression for `position`, created in this
/// builder's context via the free function mlir::getAffineDimExpr.
AffineExpr Builder::getAffineDimExpr(unsigned position) {
  AffineExpr dimExpr = mlir::getAffineDimExpr(position, context);
  return dimExpr;
}
/// Returns the affine symbol expression for `position`, created in this
/// builder's context via the free function mlir::getAffineSymbolExpr.
AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
  AffineExpr symbolExpr = mlir::getAffineSymbolExpr(position, context);
  return symbolExpr;
}
/// Returns the affine constant expression for `constant`, created in this
/// builder's context via the free function mlir::getAffineConstantExpr.
AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
  AffineExpr constantExpr = mlir::getAffineConstantExpr(constant, context);
  return constantExpr;
}
/// Returns the affine map obtained from AffineMap::get with only this
/// builder's context, i.e. without dimensions, symbols or result expressions
/// being supplied.
AffineMap Builder::getEmptyAffineMap() {
  AffineMap emptyMap = AffineMap::get(context);
  return emptyMap;
}
/// Returns the map `() -> (val)`: no dimensions, no symbols, and a single
/// constant result.
AffineMap Builder::getConstantAffineMap(int64_t val) {
  AffineExpr constantResult = getAffineConstantExpr(val);
  return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0, constantResult);
}
/// Returns the map `(d0) -> (d0)`: one dimension, no symbols, and the
/// dimension itself as the only result.
AffineMap Builder::getDimIdentityMap() {
  AffineExpr firstDim = getAffineDimExpr(0);
  return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, firstDim);
}
/// Returns the identity map over `rank` dimensions and no symbols, i.e.
/// `(d0, ..., d{rank-1}) -> (d0, ..., d{rank-1})`.
AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
  SmallVector<AffineExpr, 4> identityResults;
  identityResults.reserve(rank);
  // Result `dim` is simply dimension `dim`.
  for (unsigned dim = 0; dim != rank; ++dim)
    identityResults.push_back(getAffineDimExpr(dim));

  return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, identityResults,
                        context);
}
/// Returns the map `()[s0] -> (s0)`: no dimensions, one symbol, and the
/// symbol itself as the only result.
AffineMap Builder::getSymbolIdentityMap() {
  AffineExpr firstSymbol = getAffineSymbolExpr(0);
  return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1, firstSymbol);
}
/// Returns the map `(d0) -> (d0 + shift)`: one dimension, no symbols, and a
/// single result that adds `shift` to the dimension.
AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
// expr = d0 + shift.
auto expr = getAffineDimExpr(0) + shift;
return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
Introduce loop body skewing / loop pipelining / loop shifting utility. - loopBodySkew shifts statements of a loop body by stmt-wise delays, and is typically meant to be used to: - allow overlap of non-blocking start/wait until completion operations with other computation - allow shifting of statements (for better register reuse/locality/parallelism) - software pipelining (when applied to the innermost loop) - an additional argument specifies whether to unroll the prologue and epilogue. - add method to check SSA dominance preservation. - add a fake loop pipeline pass to test this utility. Sample input/output are below. While on this, fix/add following: - fix minor bug in getAddMulPureAffineExpr - add additional builder methods for common affine map cases - fix const_operand_iterator's for ForStmt, etc. When there is no such thing as 'const MLValue', the iterator shouldn't be returning const MLValue's. Returning MLValue is const correct. Sample input/output examples: 1) Simplest case: shift second statement by one. Input: for %i = 0 to 7 { %y = "foo"(%i) : (affineint) -> affineint %x = "bar"(%i) : (affineint) -> affineint } Output: #map0 = (d0) -> (d0 - 1) mlfunc @loop_nest_simple1() { %c8 = constant 8 : affineint %c0 = constant 0 : affineint %0 = "foo"(%c0) : (affineint) -> affineint for %i0 = 1 to 7 { %1 = "foo"(%i0) : (affineint) -> affineint %2 = affine_apply #map0(%i0) %3 = "bar"(%2) : (affineint) -> affineint } %4 = affine_apply #map0(%c8) %5 = "bar"(%4) : (affineint) -> affineint return } 2) DMA overlap: shift dma.wait and compute by one. 
Input for %i = 0 to 7 { %pingpong = affine_apply (d0) -> (d0 mod 2) (%i) "dma.enqueue"(%pingpong) : (affineint) -> affineint %pongping = affine_apply (d0) -> (d0 mod 2) (%i) "dma.wait"(%pongping) : (affineint) -> affineint "compute1"(%pongping) : (affineint) -> affineint } Output #map0 = (d0) -> (d0 mod 2) #map1 = (d0) -> (d0 - 1) #map2 = ()[s0] -> (s0 + 7) mlfunc @loop_nest_dma() { %c8 = constant 8 : affineint %c0 = constant 0 : affineint %0 = affine_apply #map0(%c0) %1 = "dma.enqueue"(%0) : (affineint) -> affineint for %i0 = 1 to 7 { %2 = affine_apply #map0(%i0) %3 = "dma.enqueue"(%2) : (affineint) -> affineint %4 = affine_apply #map1(%i0) %5 = affine_apply #map0(%4) %6 = "dma.wait"(%5) : (affineint) -> affineint %7 = "compute1"(%5) : (affineint) -> affineint } %8 = affine_apply #map1(%c8) %9 = affine_apply #map0(%8) %10 = "dma.wait"(%9) : (affineint) -> affineint %11 = "compute1"(%9) : (affineint) -> affineint return } 3) With arbitrary affine bound maps: Shift last two statements by two. 
Input: for %i = %N to ()[s0] -> (s0 + 7)()[%N] { %y = "foo"(%i) : (affineint) -> affineint %x = "bar"(%i) : (affineint) -> affineint %z = "foo_bar"(%i) : (affineint) -> (affineint) "bar_foo"(%i) : (affineint) -> (affineint) } Output #map0 = ()[s0] -> (s0 + 1) #map1 = ()[s0] -> (s0 + 2) #map2 = ()[s0] -> (s0 + 7) #map3 = (d0) -> (d0 - 2) #map4 = ()[s0] -> (s0 + 8) #map5 = ()[s0] -> (s0 + 9) for %i0 = %arg0 to #map0()[%arg0] { %0 = "foo"(%i0) : (affineint) -> affineint %1 = "bar"(%i0) : (affineint) -> affineint } for %i1 = #map1()[%arg0] to #map2()[%arg0] { %2 = "foo"(%i1) : (affineint) -> affineint %3 = "bar"(%i1) : (affineint) -> affineint %4 = affine_apply #map3(%i1) %5 = "foo_bar"(%4) : (affineint) -> affineint %6 = "bar_foo"(%4) : (affineint) -> affineint } for %i2 = #map4()[%arg0] to #map5()[%arg0] { %7 = affine_apply #map3(%i2) %8 = "foo_bar"(%7) : (affineint) -> affineint %9 = "bar_foo"(%7) : (affineint) -> affineint } 4) Shift one by zero, second by one, third by two for %i = 0 to 7 { %y = "foo"(%i) : (affineint) -> affineint %x = "bar"(%i) : (affineint) -> affineint %z = "foobar"(%i) : (affineint) -> affineint } #map0 = (d0) -> (d0 - 1) #map1 = (d0) -> (d0 - 2) #map2 = ()[s0] -> (s0 + 7) %c9 = constant 9 : affineint %c8 = constant 8 : affineint %c1 = constant 1 : affineint %c0 = constant 0 : affineint %0 = "foo"(%c0) : (affineint) -> affineint %1 = "foo"(%c1) : (affineint) -> affineint %2 = affine_apply #map0(%c1) %3 = "bar"(%2) : (affineint) -> affineint for %i0 = 2 to 7 { %4 = "foo"(%i0) : (affineint) -> affineint %5 = affine_apply #map0(%i0) %6 = "bar"(%5) : (affineint) -> affineint %7 = affine_apply #map1(%i0) %8 = "foobar"(%7) : (affineint) -> affineint } %9 = affine_apply #map0(%c8) %10 = "bar"(%9) : (affineint) -> affineint %11 = affine_apply #map1(%c8) %12 = "foobar"(%11) : (affineint) -> affineint %13 = affine_apply #map1(%c9) %14 = "foobar"(%13) : (affineint) -> affineint 5) SSA dominance violated; no shifting if a shift is specified for the 
second statement. for %i = 0 to 7 { %x = "foo"(%i) : (affineint) -> affineint "bar"(%x) : (affineint) -> affineint } PiperOrigin-RevId: 214975731
2018-09-28 12:17:26 -07:00
}
/// Returns a copy of `map` in which `shift` has been added to every result
/// expression. The number of dimensions and symbols is taken from `map`
/// unchanged.
AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
SmallVector<AffineExpr, 4> shiftedResults;
shiftedResults.reserve(map.getNumResults());
// Shift each result expression of the input map individually.
for (auto resultExpr : map.getResults())
shiftedResults.push_back(resultExpr + shift);
return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults,
context);
Introduce loop body skewing / loop pipelining / loop shifting utility. - loopBodySkew shifts statements of a loop body by stmt-wise delays, and is typically meant to be used to: - allow overlap of non-blocking start/wait until completion operations with other computation - allow shifting of statements (for better register reuse/locality/parallelism) - software pipelining (when applied to the innermost loop) - an additional argument specifies whether to unroll the prologue and epilogue. - add method to check SSA dominance preservation. - add a fake loop pipeline pass to test this utility. Sample input/output are below. While on this, fix/add following: - fix minor bug in getAddMulPureAffineExpr - add additional builder methods for common affine map cases - fix const_operand_iterator's for ForStmt, etc. When there is no such thing as 'const MLValue', the iterator shouldn't be returning const MLValue's. Returning MLValue is const correct. Sample input/output examples: 1) Simplest case: shift second statement by one. Input: for %i = 0 to 7 { %y = "foo"(%i) : (affineint) -> affineint %x = "bar"(%i) : (affineint) -> affineint } Output: #map0 = (d0) -> (d0 - 1) mlfunc @loop_nest_simple1() { %c8 = constant 8 : affineint %c0 = constant 0 : affineint %0 = "foo"(%c0) : (affineint) -> affineint for %i0 = 1 to 7 { %1 = "foo"(%i0) : (affineint) -> affineint %2 = affine_apply #map0(%i0) %3 = "bar"(%2) : (affineint) -> affineint } %4 = affine_apply #map0(%c8) %5 = "bar"(%4) : (affineint) -> affineint return } 2) DMA overlap: shift dma.wait and compute by one. 
Input for %i = 0 to 7 { %pingpong = affine_apply (d0) -> (d0 mod 2) (%i) "dma.enqueue"(%pingpong) : (affineint) -> affineint %pongping = affine_apply (d0) -> (d0 mod 2) (%i) "dma.wait"(%pongping) : (affineint) -> affineint "compute1"(%pongping) : (affineint) -> affineint } Output #map0 = (d0) -> (d0 mod 2) #map1 = (d0) -> (d0 - 1) #map2 = ()[s0] -> (s0 + 7) mlfunc @loop_nest_dma() { %c8 = constant 8 : affineint %c0 = constant 0 : affineint %0 = affine_apply #map0(%c0) %1 = "dma.enqueue"(%0) : (affineint) -> affineint for %i0 = 1 to 7 { %2 = affine_apply #map0(%i0) %3 = "dma.enqueue"(%2) : (affineint) -> affineint %4 = affine_apply #map1(%i0) %5 = affine_apply #map0(%4) %6 = "dma.wait"(%5) : (affineint) -> affineint %7 = "compute1"(%5) : (affineint) -> affineint } %8 = affine_apply #map1(%c8) %9 = affine_apply #map0(%8) %10 = "dma.wait"(%9) : (affineint) -> affineint %11 = "compute1"(%9) : (affineint) -> affineint return } 3) With arbitrary affine bound maps: Shift last two statements by two. 
Input: for %i = %N to ()[s0] -> (s0 + 7)()[%N] { %y = "foo"(%i) : (affineint) -> affineint %x = "bar"(%i) : (affineint) -> affineint %z = "foo_bar"(%i) : (affineint) -> (affineint) "bar_foo"(%i) : (affineint) -> (affineint) } Output #map0 = ()[s0] -> (s0 + 1) #map1 = ()[s0] -> (s0 + 2) #map2 = ()[s0] -> (s0 + 7) #map3 = (d0) -> (d0 - 2) #map4 = ()[s0] -> (s0 + 8) #map5 = ()[s0] -> (s0 + 9) for %i0 = %arg0 to #map0()[%arg0] { %0 = "foo"(%i0) : (affineint) -> affineint %1 = "bar"(%i0) : (affineint) -> affineint } for %i1 = #map1()[%arg0] to #map2()[%arg0] { %2 = "foo"(%i1) : (affineint) -> affineint %3 = "bar"(%i1) : (affineint) -> affineint %4 = affine_apply #map3(%i1) %5 = "foo_bar"(%4) : (affineint) -> affineint %6 = "bar_foo"(%4) : (affineint) -> affineint } for %i2 = #map4()[%arg0] to #map5()[%arg0] { %7 = affine_apply #map3(%i2) %8 = "foo_bar"(%7) : (affineint) -> affineint %9 = "bar_foo"(%7) : (affineint) -> affineint } 4) Shift one by zero, second by one, third by two for %i = 0 to 7 { %y = "foo"(%i) : (affineint) -> affineint %x = "bar"(%i) : (affineint) -> affineint %z = "foobar"(%i) : (affineint) -> affineint } #map0 = (d0) -> (d0 - 1) #map1 = (d0) -> (d0 - 2) #map2 = ()[s0] -> (s0 + 7) %c9 = constant 9 : affineint %c8 = constant 8 : affineint %c1 = constant 1 : affineint %c0 = constant 0 : affineint %0 = "foo"(%c0) : (affineint) -> affineint %1 = "foo"(%c1) : (affineint) -> affineint %2 = affine_apply #map0(%c1) %3 = "bar"(%2) : (affineint) -> affineint for %i0 = 2 to 7 { %4 = "foo"(%i0) : (affineint) -> affineint %5 = affine_apply #map0(%i0) %6 = "bar"(%5) : (affineint) -> affineint %7 = affine_apply #map1(%i0) %8 = "foobar"(%7) : (affineint) -> affineint } %9 = affine_apply #map0(%c8) %10 = "bar"(%9) : (affineint) -> affineint %11 = affine_apply #map1(%c8) %12 = "foobar"(%11) : (affineint) -> affineint %13 = affine_apply #map1(%c9) %14 = "foobar"(%13) : (affineint) -> affineint 5) SSA dominance violated; no shifting if a shift is specified for the 
second statement. for %i = 0 to 7 { %x = "foo"(%i) : (affineint) -> affineint "bar"(%x) : (affineint) -> affineint } PiperOrigin-RevId: 214975731
2018-09-28 12:17:26 -07:00
}
//===----------------------------------------------------------------------===//
// OpBuilder
//===----------------------------------------------------------------------===//
/// Inserts `op` at the current insertion point and returns it. When the
/// builder has no insertion block, `op` is returned without being inserted
/// and no listener notification is sent.
Operation *OpBuilder::insert(Operation *op) {
  // Without an insertion block there is nowhere to put the operation.
  if (!block)
    return op;

  block->getOperations().insert(insertPoint, op);
  if (listener)
    listener->notifyOperationInserted(op, /*previous=*/{});
  return op;
}
[mlir] DialectConversion: support block creation in ConversionPatternRewriter PatternRewriter and derived classes provide a set of virtual methods to manipulate blocks, which ConversionPatternRewriter overrides to keep track of the manipulations and undo them in case the conversion fails. However, one can currently create a block only by splitting another block into two. This not only makes the API inconsistent (`splitBlock` is allowed in conversion patterns, but `createBlock` is not), but it also make it impossible for one to create blocks with argument lists different from those of already existing blocks since in-place block updates are not supported either. Such functionality precludes dialect conversion infrastructure from being used more extensively on region-containing ops, for example, for value-returning "if" operations. At the same time, ConversionPatternRewriter already allows one to undo block creation as block creation is one of the primitive operations in already supported region inlining. Support block creation in conversion patterns by hooking `createBlock` on the block action undo mechanism. This requires to make `Builder::createBlock` virtual, similarly to Op insertion. This is a minimal change to the Builder infrastructure that will later help support additional use cases such as block signature changes. `createBlock` now additionally takes the types of the block arguments that are added immediately so as to avoid in-place argument list manipulation that would be illegal in conversion patterns.
2020-04-03 19:53:13 +02:00
/// Creates a new block with arguments of types `argTypes` at locations
/// `locs`, inserts it into `parent` before `insertPt`, and moves the
/// insertion point to the end of the new block. A default-constructed
/// `insertPt` appends the block at the end of the region. The attached
/// listener, if any, is notified of the block insertion.
Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
                              TypeRange argTypes, ArrayRef<Location> locs) {
  assert(parent && "expected valid parent region");
  assert(argTypes.size() == locs.size() && "argument location mismatch");

  // A default-constructed iterator stands for "append to the region".
  Region::iterator position =
      insertPt == Region::iterator() ? parent->end() : insertPt;

  Block *newBlock = new Block();
  newBlock->addArguments(argTypes, locs);
  parent->getBlocks().insert(position, newBlock);
  setInsertionPointToEnd(newBlock);

  if (listener)
    listener->notifyBlockInserted(newBlock, /*previous=*/nullptr,
                                  /*previousIt=*/{});
  return newBlock;
}
[mlir] DialectConversion: support block creation in ConversionPatternRewriter PatternRewriter and derived classes provide a set of virtual methods to manipulate blocks, which ConversionPatternRewriter overrides to keep track of the manipulations and undo them in case the conversion fails. However, one can currently create a block only by splitting another block into two. This not only makes the API inconsistent (`splitBlock` is allowed in conversion patterns, but `createBlock` is not), but it also make it impossible for one to create blocks with argument lists different from those of already existing blocks since in-place block updates are not supported either. Such functionality precludes dialect conversion infrastructure from being used more extensively on region-containing ops, for example, for value-returning "if" operations. At the same time, ConversionPatternRewriter already allows one to undo block creation as block creation is one of the primitive operations in already supported region inlining. Support block creation in conversion patterns by hooking `createBlock` on the block action undo mechanism. This requires to make `Builder::createBlock` virtual, similarly to Op insertion. This is a minimal change to the Builder infrastructure that will later help support additional use cases such as block signature changes. `createBlock` now additionally takes the types of the block arguments that are added immediately so as to avoid in-place argument list manipulation that would be illegal in conversion patterns.
2020-04-03 19:53:13 +02:00
/// Add a new block with 'argTypes' arguments (located at 'locs') and set the
/// insertion point to the end of it. The block is placed before
/// 'insertBefore', in the region that contains 'insertBefore'.
/// 'insertBefore' must not be null.
Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes,
ArrayRef<Location> locs) {
assert(insertBefore && "expected valid insertion block");
[mlir] DialectConversion: support block creation in ConversionPatternRewriter PatternRewriter and derived classes provide a set of virtual methods to manipulate blocks, which ConversionPatternRewriter overrides to keep track of the manipulations and undo them in case the conversion fails. However, one can currently create a block only by splitting another block into two. This not only makes the API inconsistent (`splitBlock` is allowed in conversion patterns, but `createBlock` is not), but it also make it impossible for one to create blocks with argument lists different from those of already existing blocks since in-place block updates are not supported either. Such functionality precludes dialect conversion infrastructure from being used more extensively on region-containing ops, for example, for value-returning "if" operations. At the same time, ConversionPatternRewriter already allows one to undo block creation as block creation is one of the primitive operations in already supported region inlining. Support block creation in conversion patterns by hooking `createBlock` on the block action undo mechanism. This requires to make `Builder::createBlock` virtual, similarly to Op insertion. This is a minimal change to the Builder infrastructure that will later help support additional use cases such as block signature changes. `createBlock` now additionally takes the types of the block arguments that are added immediately so as to avoid in-place argument list manipulation that would be illegal in conversion patterns.
2020-04-03 19:53:13 +02:00
return createBlock(insertBefore->getParent(), Region::iterator(insertBefore),
argTypes, locs);
}
/// Creates an operation from the fields held by `state` and inserts it at the
/// current insertion point.
Operation *OpBuilder::create(const OperationState &state) {
  Operation *createdOp = Operation::create(state);
  return insert(createdOp);
}
/// Creates an operation from the individual fields by packing them into an
/// OperationState and forwarding to the OperationState-based overload.
Operation *OpBuilder::create(Location loc, StringAttr opName,
                             ValueRange operands, TypeRange types,
                             ArrayRef<NamedAttribute> attributes,
                             BlockRange successors,
                             MutableArrayRef<std::unique_ptr<Region>> regions) {
  OperationState opState(loc, opName, operands, types, attributes, successors,
                         regions);
  return create(opState);
}
/// Attempts to fold `op`, filling `results` with the values that replace its
/// results.
///
/// `results` must be empty on entry. Failure is returned, with `results`
/// cleared, when `op` already matches a constant, when folding fails, or when
/// a folded attribute cannot be materialized as a constant operation. Success
/// with `results` left empty indicates an in-place fold. Constant operations
/// materialized for folded attributes are inserted through this builder only
/// once every result has been handled successfully.
LogicalResult OpBuilder::tryFold(Operation *op,
                                 SmallVectorImpl<Value> &results) {
  assert(results.empty() && "expected empty results");
  ResultRange originalResults = op->getResults();
  results.reserve(originalResults.size());

  // Every failure path hands back an empty `results` vector.
  auto failAndReset = [&results] {
    results.clear();
    return failure();
  };

  // An operation that is already a constant has nothing left to fold.
  if (matchPattern(op, m_Constant()))
    return failAndReset();

  // Run the folder.
  SmallVector<OpFoldResult, 4> folded;
  if (failed(op->fold(folded)))
    return failAndReset();

  // An in-place fold produces no replacement values, so no constants are
  // needed.
  if (folded.empty())
    return success();

  // Constants are created with a temporary builder so that nothing is
  // inserted until the whole fold is known to have succeeded.
  Dialect *opDialect = op->getDialect();
  SmallVector<Operation *, 1> pendingConstants;
  OpBuilder constantBuilder(context);

  for (auto [foldedEntry, resultType] :
       llvm::zip_equal(folded, originalResults.getTypes())) {
    // Existing values are forwarded as-is.
    if (auto existingValue = llvm::dyn_cast_if_present<Value>(foldedEntry)) {
      results.push_back(existingValue);
      continue;
    }

    // Attributes need a dialect to turn them into a constant operation.
    if (!opDialect)
      return failAndReset();

    Attribute foldedAttr = cast<Attribute>(foldedEntry);
    Operation *materialized = opDialect->materializeConstant(
        constantBuilder, foldedAttr, resultType, op->getLoc());
    if (!materialized) {
      // Drop the constants created so far; none of them has been inserted.
      for (Operation *pending : pendingConstants)
        pending->erase();
      return failAndReset();
    }
    assert(matchPattern(materialized, m_Constant()));

    pendingConstants.push_back(materialized);
    results.push_back(materialized->getResult(0));
  }

  // The fold succeeded: insert the materialized constants.
  for (Operation *pending : pendingConstants)
    insert(pending);
  return success();
}
/// Sends a block-insertion notification to `listener` for each block that is
/// directly nested in a region of `op`. Blocks nested deeper are not visited.
/// `listener` must not be null.
static void notifyBlockInsertions(Operation *op,
                                  OpBuilder::Listener *listener) {
  for (Region &region : op->getRegions()) {
    for (Block &nestedBlock : region) {
      listener->notifyBlockInserted(&nestedBlock, /*previous=*/nullptr,
                                    /*previousIt=*/{});
    }
  }
}
/// Clones `op` using `mapper`, inserts the clone at the current insertion
/// point and returns it. If a listener is attached, it is additionally
/// notified about every block and operation nested inside the clone.
Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
  // `insert` already notifies the listener about the cloned op itself.
  Operation *clonedOp = insert(op.clone(mapper));
  if (!listener)
    return clonedOp;

  // `insert` reports the op but not the blocks it owns.
  notifyBlockInsertions(clonedOp, listener);

  // Report everything nested below the clone: each nested op, followed by
  // the blocks directly nested in it.
  for (Region &nestedRegion : clonedOp->getRegions()) {
    nestedRegion.walk<WalkOrder::PreOrder>([&](Operation *nestedOp) {
      listener->notifyOperationInserted(nestedOp, /*previous=*/{});
      notifyBlockInsertions(nestedOp, listener);
    });
  }
  return clonedOp;
}
/// Clones `op` with a fresh, initially empty IRMapping and inserts the clone
/// at the current insertion point.
Operation *OpBuilder::clone(Operation &op) {
  IRMapping freshMapping;
  return clone(op, freshMapping);
}
/// Clones the blocks of `region` into `parent` in front of `before`,
/// recording the correspondence in `mapping`. If a listener is attached, it
/// is notified about every cloned block and every operation nested in it.
void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
                                  Region::iterator before, IRMapping &mapping) {
  region.cloneInto(&parent, before, mapping);

  // Fast path: If no listener is attached, there is no more work to do.
  if (!listener)
    return;

  // An empty source region has no front block, so `region.front()` below
  // would be invalid; there is also nothing that could have been cloned and
  // hence nothing to report.
  if (region.empty())
    return;

  // Notify about op/block insertion. The cloned blocks are the ones from the
  // clone of the source's first block up to (but excluding) `before`.
  for (auto it = mapping.lookup(&region.front())->getIterator(); it != before;
       ++it) {
    listener->notifyBlockInserted(&*it, /*previous=*/nullptr,
                                  /*previousIt=*/{});
    it->walk<WalkOrder::PreOrder>([&](Operation *walkedOp) {
      listener->notifyOperationInserted(walkedOp, /*previous=*/{});
      notifyBlockInsertions(walkedOp, listener);
    });
  }
}
/// Clones the blocks of `region` into `parent` in front of `before`, using a
/// fresh, initially empty IRMapping.
void OpBuilder::cloneRegionBefore(Region &region, Region &parent,
                                  Region::iterator before) {
  IRMapping freshMapping;
  cloneRegionBefore(region, parent, before, freshMapping);
}
/// Clones the blocks of `region` in front of the block `before`, into the
/// region that contains `before`.
void OpBuilder::cloneRegionBefore(Region &region, Block *before) {
  Region &destination = *before->getParent();
  cloneRegionBefore(region, destination, before->getIterator());
}