[ODS] Use Adaptor Traits for Type Inference

Author inferReturnTypes methods with the Op Adaptor by using the InferTypeOpAdaptor.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D155115
This commit is contained in:
Amanda Tang 2023-07-12 20:13:25 +00:00
parent 74c0bdff7d
commit 5267ed05bc
13 changed files with 124 additions and 183 deletions

View File

@ -903,7 +903,7 @@ def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", [
Pure,
SameVariadicResultSize,
ViewLikeOpInterface,
InferTypeOpInterfaceAdaptor]> {
InferTypeOpAdaptor]> {
let summary = "Extracts a buffer base with offset and strides";
let description = [{
Extracts a base buffer, offset and strides. This op allows additional layers

View File

@ -713,9 +713,8 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
"getNumRegionInvocations", "getRegionInvocationBounds"]>,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects,
NoRegionArguments]> {
InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">,
RecursiveMemoryEffects, NoRegionArguments]> {
let summary = "if-then-else operation";
let description = [{
The `scf.if` operation represents an if-then-else construct for

View File

@ -32,8 +32,7 @@ class Shape_Op<string mnemonic, list<Trait> traits = []> :
Op<ShapeDialect, mnemonic, traits>;
def Shape_AddOp : Shape_Op<"add",
[Commutative, Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
[Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Addition of sizes and indices";
let description = [{
Adds two sizes or indices. If either operand is an error it will be
@ -51,12 +50,6 @@ def Shape_AddOp : Shape_Op<"add",
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
}];
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
let hasFolder = 1;
let hasVerifier = 1;
}
@ -109,7 +102,7 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, Pure]> {
}
def Shape_ConstShapeOp : Shape_Op<"const_shape",
[ConstantLike, Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
[ConstantLike, Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Creates a constant shape or extent tensor";
let description = [{
Creates a constant shape or extent tensor. The individual extents are given
@ -128,11 +121,6 @@ def Shape_ConstShapeOp : Shape_Op<"const_shape",
let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasCanonicalizer = 1;
let extraClassDeclaration = [{
// InferTypeOpInterface:
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
}
def Shape_ConstSizeOp : Shape_Op<"const_size", [
@ -158,8 +146,7 @@ def Shape_ConstSizeOp : Shape_Op<"const_size", [
let hasFolder = 1;
}
def Shape_DivOp : Shape_Op<"div", [Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
def Shape_DivOp : Shape_Op<"div", [Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Division of sizes and indices";
let description = [{
Divides two sizes or indices. If either operand is an error it will be
@ -187,12 +174,6 @@ def Shape_DivOp : Shape_Op<"div", [Pure,
let hasFolder = 1;
let hasVerifier = 1;
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
}
def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Pure, Commutative]> {
@ -287,7 +268,7 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
}
def Shape_RankOp : Shape_Op<"rank",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
[Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Gets the rank of a shape";
let description = [{
Returns the rank of the shape or extent tensor, i.e. the number of extents.
@ -301,12 +282,6 @@ def Shape_RankOp : Shape_Op<"rank",
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasVerifier = 1;
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
}
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
@ -330,7 +305,7 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
}
def Shape_DimOp : Shape_Op<"dim",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
[Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Gets the specified extent from the shape of a shaped input";
let description = [{
Gets the extent indexed by `dim` from the shape of the `value` operand. If
@ -354,17 +329,13 @@ def Shape_DimOp : Shape_Op<"dim",
let extraClassDeclaration = [{
/// Get the `index` value as integer if it is constant.
std::optional<int64_t> getConstantIndex();
/// Returns when two result types are compatible for this op; method used
/// by InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
let hasFolder = 1;
}
def Shape_GetExtentOp : Shape_Op<"get_extent",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
[Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Gets the specified extent from a shape or extent tensor";
let description = [{
Gets the extent indexed by `dim` from the `shape` operand. If the shape is
@ -384,9 +355,6 @@ def Shape_GetExtentOp : Shape_Op<"get_extent",
let extraClassDeclaration = [{
/// Get the `dim` value as integer if it is constant.
std::optional<int64_t> getConstantDim();
/// Returns when two result types are compatible for this op; method used
/// by InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
let hasFolder = 1;
@ -413,8 +381,7 @@ def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [Pure]> {
}
def Shape_MaxOp : Shape_Op<"max",
[Commutative, Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
[Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Elementwise maximum";
let description = [{
Computes the elementwise maximum of two sizes or shapes with equal ranks.
@ -431,16 +398,10 @@ def Shape_MaxOp : Shape_Op<"max",
}];
let hasFolder = 1;
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
}
def Shape_MeetOp : Shape_Op<"meet",
[Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
[Commutative, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Returns the least general shape or size of its operands";
let description = [{
An operation that computes the least general shape or dim of input operands.
@ -478,17 +439,10 @@ def Shape_MeetOp : Shape_Op<"meet",
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
type($arg0) `,` type($arg1) `->` type($result)
}];
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
}
def Shape_MinOp : Shape_Op<"min",
[Commutative, Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
[Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Elementwise minimum";
let description = [{
Computes the elementwise minimum of two sizes or shapes with equal ranks.
@ -505,17 +459,10 @@ def Shape_MinOp : Shape_Op<"min",
}];
let hasFolder = 1;
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
}
def Shape_MulOp : Shape_Op<"mul",
[Commutative, Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
[Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Multiplication of sizes and indices";
let description = [{
Multiplies two sizes or indices. If either operand is an error it will be
@ -535,16 +482,10 @@ def Shape_MulOp : Shape_Op<"mul",
let hasFolder = 1;
let hasVerifier = 1;
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
}
def Shape_NumElementsOp : Shape_Op<"num_elements",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
[Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Returns the number of elements for a given shape";
let description = [{
Returns the number of elements for a given shape which is the product of
@ -561,11 +502,6 @@ def Shape_NumElementsOp : Shape_Op<"num_elements",
let hasFolder = 1;
let hasVerifier = 1;
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
}
def Shape_ReduceOp : Shape_Op<"reduce",
@ -616,7 +552,7 @@ def Shape_ReduceOp : Shape_Op<"reduce",
}
def Shape_ShapeOfOp : Shape_Op<"shape_of",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
[Pure, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "Returns shape of a value or shaped type operand";
let description = [{
@ -632,12 +568,6 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
}
def Shape_ValueOfOp : Shape_Op<"value_of", [Pure]> {

View File

@ -463,7 +463,7 @@ def Vector_ShuffleOp :
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"second operand v2 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>,
DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
InferTypeOpAdaptor]>,
Arguments<(ins AnyVectorOfAnyRank:$v1, AnyVectorOfAnyRank:$v2,
I64ArrayAttr:$mask)>,
Results<(outs AnyVector:$vector)> {
@ -572,7 +572,7 @@ def Vector_ExtractOp :
Vector_Op<"extract", [Pure,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
InferTypeOpAdaptorWithIsCompatible]>,
Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$position)>,
Results<(outs AnyType)> {
let summary = "extract operation";
@ -598,7 +598,6 @@ def Vector_ExtractOp :
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getVector().getType());
}
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
let assemblyFormat = "$vector `` $position attr-dict `:` type($vector)";
let hasCanonicalizer = 1;

View File

@ -259,8 +259,8 @@ namespace mlir {
namespace OpTrait {
template <typename ConcreteType>
class InferTypeOpInterfaceAdaptor
: public TraitBase<ConcreteType, InferTypeOpInterfaceAdaptor> {};
class InferTypeOpAdaptor : public TraitBase<ConcreteType, InferTypeOpAdaptor> {
};
/// Tensor type inference trait that constructs a tensor from the inferred
/// shape and elemental types.

View File

@ -186,35 +186,42 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
// Convenient trait to define a wrapper to inferReturnTypes that passes in the
// Op Adaptor directly
def InferTypeOpInterfaceAdaptor : TraitList<
class InferTypeOpAdaptorBase<code additionalDecls = [{}]> : TraitList<
[
// Op implements infer type op interface.
DeclareOpInterfaceMethods<InferTypeOpInterface>,
NativeOpTrait<
/*name=*/"InferTypeOpInterfaceAdaptor",
/*name=*/"InferTypeOpAdaptor",
/*traits=*/[],
/*extraOpDeclaration=*/[{
static LogicalResult
inferReturnTypesAdaptor(MLIRContext *context,
std::optional<Location> location,
static ::mlir::LogicalResult
inferReturnTypes(::mlir::MLIRContext *context,
std::optional<::mlir::Location> location,
Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes);
}],
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes);
}] # additionalDecls,
/*extraOpDefinition=*/[{
LogicalResult
$cppClass::inferReturnTypes(MLIRContext *context,
std::optional<Location> location,
ValueRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
::mlir::LogicalResult
$cppClass::inferReturnTypes(::mlir::MLIRContext *context,
std::optional<::mlir::Location> location,
::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
$cppClass::Adaptor adaptor(operands, attributes, properties, regions);
return $cppClass::inferReturnTypesAdaptor(context,
return $cppClass::inferReturnTypes(context,
location, adaptor, inferredReturnTypes);
}
}]
>
]>;
def InferTypeOpAdaptor : InferTypeOpAdaptorBase;
def InferTypeOpAdaptorWithIsCompatible : InferTypeOpAdaptorBase<
[{
static bool isCompatibleReturnTypes(::mlir::TypeRange l, ::mlir::TypeRange r);
}]
>;
// Convenience class grouping together type and shaped type op interfaces for
// ops that have tensor return types.
class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
@ -231,13 +238,13 @@ class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
/*traits=*/[],
/*extraOpDeclaration=*/[{}],
/*extraOpDefinition=*/[{
LogicalResult
$cppClass::inferReturnTypes(MLIRContext *context,
std::optional<Location> location,
ValueRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
SmallVector<ShapedTypeComponents, 2> retComponents;
::mlir::LogicalResult
$cppClass::inferReturnTypes(::mlir::MLIRContext *context,
std::optional<::mlir::Location> location,
::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
::llvm::SmallVector<::mlir::ShapedTypeComponents, 2> retComponents;
if (failed($cppClass::inferReturnTypeComponents(context, location,
operands, attributes, properties, regions,
retComponents)))

View File

@ -1354,7 +1354,7 @@ void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
/// The number and type of the results are inferred from the
/// shape of the source.
LogicalResult ExtractStridedMetadataOp::inferReturnTypesAdaptor(
LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location,
ExtractStridedMetadataOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {

View File

@ -1841,12 +1841,11 @@ bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
LogicalResult
IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
ValueRange operands, DictionaryAttr attrs,
OpaqueProperties properties, RegionRange regions,
IfOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (regions.empty())
if (adaptor.getRegions().empty())
return failure();
Region *r = regions.front();
Region *r = &adaptor.getThenRegion();
if (r->empty())
return failure();
Block &b = r->front();

View File

@ -394,11 +394,10 @@ void AssumingOp::build(
//===----------------------------------------------------------------------===//
LogicalResult mlir::shape::AddOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (llvm::isa<SizeType>(operands[0].getType()) ||
llvm::isa<SizeType>(operands[1].getType()))
MLIRContext *context, std::optional<Location> location,
AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
llvm::isa<SizeType>(adaptor.getRhs().getType()))
inferredReturnTypes.assign({SizeType::get(context)});
else
inferredReturnTypes.assign({IndexType::get(context)});
@ -916,18 +915,17 @@ void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
}
LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
MLIRContext *context, std::optional<Location> location,
ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
Builder b(context);
Properties *prop = properties.as<Properties *>();
const Properties *prop = &adaptor.getProperties();
DenseIntElementsAttr shape;
// TODO: this is only exercised by the Python bindings codepath which does not
// support properties
if (prop)
shape = prop->shape;
else
shape = attributes.getAs<DenseIntElementsAttr>("shape");
shape = adaptor.getAttributes().getAs<DenseIntElementsAttr>("shape");
if (!shape)
return emitOptionalError(location, "missing shape attribute");
inferredReturnTypes.assign({RankedTensorType::get(
@ -1104,11 +1102,9 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
}
LogicalResult mlir::shape::DimOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
DimOpAdaptor dimOp(operands);
inferredReturnTypes.assign({dimOp.getIndex().getType()});
MLIRContext *context, std::optional<Location> location,
DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.assign({adaptor.getIndex().getType()});
return success();
}
@ -1141,11 +1137,10 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
}
LogicalResult mlir::shape::DivOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (llvm::isa<SizeType>(operands[0].getType()) ||
llvm::isa<SizeType>(operands[1].getType()))
MLIRContext *context, std::optional<Location> location,
DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
llvm::isa<SizeType>(adaptor.getRhs().getType()))
inferredReturnTypes.assign({SizeType::get(context)});
else
inferredReturnTypes.assign({IndexType::get(context)});
@ -1361,9 +1356,8 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
}
LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
MLIRContext *context, std::optional<Location> location,
GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.assign({IndexType::get(context)});
return success();
}
@ -1399,10 +1393,9 @@ OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands.empty())
MLIRContext *context, std::optional<Location> location,
MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
if (adaptor.getOperands().empty())
return failure();
auto isShapeType = [](Type arg) {
@ -1411,7 +1404,7 @@ LogicalResult mlir::shape::MeetOp::inferReturnTypes(
return isExtentTensorType(arg);
};
ValueRange::type_range types = operands.getTypes();
ValueRange::type_range types = adaptor.getOperands().getTypes();
Type acc = types.front();
for (auto t : drop_begin(types)) {
Type l = acc, r = t;
@ -1535,10 +1528,9 @@ void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
}
LogicalResult mlir::shape::RankOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (llvm::isa<ShapeType>(operands[0].getType()))
MLIRContext *context, std::optional<Location> location,
RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
inferredReturnTypes.assign({SizeType::get(context)});
else
inferredReturnTypes.assign({IndexType::get(context)});
@ -1571,10 +1563,10 @@ OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
}
LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
MLIRContext *context, std::optional<Location> location,
NumElementsOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (llvm::isa<ShapeType>(operands[0].getType()))
if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
inferredReturnTypes.assign({SizeType::get(context)});
else
inferredReturnTypes.assign({IndexType::get(context)});
@ -1603,11 +1595,10 @@ OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
}
LogicalResult mlir::shape::MaxOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType() == operands[1].getType())
inferredReturnTypes.assign({operands[0].getType()});
MLIRContext *context, std::optional<Location> location,
MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
inferredReturnTypes.assign({adaptor.getLhs().getType()});
else
inferredReturnTypes.assign({SizeType::get(context)});
return success();
@ -1635,11 +1626,10 @@ OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
}
LogicalResult mlir::shape::MinOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType() == operands[1].getType())
inferredReturnTypes.assign({operands[0].getType()});
MLIRContext *context, std::optional<Location> location,
MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
inferredReturnTypes.assign({adaptor.getLhs().getType()});
else
inferredReturnTypes.assign({SizeType::get(context)});
return success();
@ -1672,11 +1662,10 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
}
LogicalResult mlir::shape::MulOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (llvm::isa<SizeType>(operands[0].getType()) ||
llvm::isa<SizeType>(operands[1].getType()))
MLIRContext *context, std::optional<Location> location,
MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
llvm::isa<SizeType>(adaptor.getRhs().getType()))
inferredReturnTypes.assign({SizeType::get(context)});
else
inferredReturnTypes.assign({IndexType::get(context)});
@ -1759,13 +1748,12 @@ void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
}
LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (llvm::isa<ValueShapeType>(operands[0].getType()))
MLIRContext *context, std::optional<Location> location,
ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
inferredReturnTypes.assign({ShapeType::get(context)});
else {
auto shapedTy = llvm::cast<ShapedType>(operands[0].getType());
auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
int64_t rank =
shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
Type indexTy = IndexType::get(context);

View File

@ -1146,15 +1146,15 @@ void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
ValueRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange,
ExtractOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
ExtractOp::Adaptor op(operands, attributes, properties);
auto vectorType = llvm::cast<VectorType>(op.getVector().getType());
if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
if (static_cast<int64_t>(adaptor.getPosition().size()) ==
vectorType.getRank()) {
inferredReturnTypes.push_back(vectorType.getElementType());
} else {
auto n = std::min<size_t>(op.getPosition().size(), vectorType.getRank());
auto n =
std::min<size_t>(adaptor.getPosition().size(), vectorType.getRank());
inferredReturnTypes.push_back(VectorType::get(
vectorType.getShape().drop_front(n), vectorType.getElementType()));
}
@ -2114,17 +2114,15 @@ LogicalResult ShuffleOp::verify() {
LogicalResult
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
ValueRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange,
ShuffleOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
ShuffleOp::Adaptor op(operands, attributes, properties);
auto v1Type = llvm::cast<VectorType>(op.getV1().getType());
auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
auto v1Rank = v1Type.getRank();
// Construct resulting type: leading dimension matches mask
// length, all trailing dimensions match the operands.
SmallVector<int64_t, 4> shape;
shape.reserve(v1Rank);
shape.push_back(std::max<size_t>(1, op.getMask().size()));
shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
// In the 0-D case there is no trailing shape to append.
if (v1Rank > 0)
llvm::append_range(shape, v1Type.getShape().drop_front());

View File

@ -1385,6 +1385,19 @@ LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
return success();
}
LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
MLIRContext *, std::optional<Location> location,
OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (adaptor.getX().getType() != adaptor.getY().getType()) {
return emitOptionalError(location, "operand type mismatch ",
adaptor.getX().getType(), " vs ",
adaptor.getY().getType());
}
inferredReturnTypes.assign({adaptor.getX().getType()});
return success();
}
// TODO: We should be able to only define either inferReturnType or
// refineReturnType, currently only refineReturnType can be omitted.
LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(

View File

@ -761,6 +761,12 @@ def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
let results = (outs AnyTensor);
}
def OpWithInferTypeAdaptorInterfaceOp : TEST_Op<"op_with_infer_type_adaptor_if", [
InferTypeOpAdaptor]> {
let arguments = (ins AnyTensor:$x, AnyTensor:$y);
let results = (outs AnyTensor);
}
def OpWithRefineTypeInterfaceOp : TEST_Op<"op_with_refine_type_if", [
DeclareOpInterfaceMethods<InferTypeOpInterface,
["refineReturnTypes"]>]> {

View File

@ -485,6 +485,8 @@ struct TestReturnTypeDriver
// output would be in reverse order underneath `op` from which
// the attributes and regions are used.
invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
op);
invokeCreateWithInferredReturnType<
OpWithShapedTypeInferTypeInterfaceOp>(op);
};