mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-27 18:46:04 +00:00
[ODS] Use Adaptor Traits for Type Inference
Author inferReturnTypes methods with the Op Adaptor by using the InferTypeOpAdaptor. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D155115
This commit is contained in:
parent
74c0bdff7d
commit
5267ed05bc
@ -903,7 +903,7 @@ def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", [
|
||||
Pure,
|
||||
SameVariadicResultSize,
|
||||
ViewLikeOpInterface,
|
||||
InferTypeOpInterfaceAdaptor]> {
|
||||
InferTypeOpAdaptor]> {
|
||||
let summary = "Extracts a buffer base with offset and strides";
|
||||
let description = [{
|
||||
Extracts a base buffer, offset and strides. This op allows additional layers
|
||||
|
@ -713,9 +713,8 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
|
||||
|
||||
def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
|
||||
"getNumRegionInvocations", "getRegionInvocationBounds"]>,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveMemoryEffects,
|
||||
NoRegionArguments]> {
|
||||
InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">,
|
||||
RecursiveMemoryEffects, NoRegionArguments]> {
|
||||
let summary = "if-then-else operation";
|
||||
let description = [{
|
||||
The `scf.if` operation represents an if-then-else construct for
|
||||
|
@ -32,8 +32,7 @@ class Shape_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<ShapeDialect, mnemonic, traits>;
|
||||
|
||||
def Shape_AddOp : Shape_Op<"add",
|
||||
[Commutative, Pure,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
[Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
|
||||
let summary = "Addition of sizes and indices";
|
||||
let description = [{
|
||||
Adds two sizes or indices. If either operand is an error it will be
|
||||
@ -51,12 +50,6 @@ def Shape_AddOp : Shape_Op<"add",
|
||||
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
@ -109,7 +102,7 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, Pure]> {
|
||||
}
|
||||
|
||||
def Shape_ConstShapeOp : Shape_Op<"const_shape",
|
||||
[ConstantLike, Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
[ConstantLike, Pure, InferTypeOpAdaptorWithIsCompatible]> {
|
||||
let summary = "Creates a constant shape or extent tensor";
|
||||
let description = [{
|
||||
Creates a constant shape or extent tensor. The individual extents are given
|
||||
@ -128,11 +121,6 @@ def Shape_ConstShapeOp : Shape_Op<"const_shape",
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// InferTypeOpInterface:
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_ConstSizeOp : Shape_Op<"const_size", [
|
||||
@ -158,8 +146,7 @@ def Shape_ConstSizeOp : Shape_Op<"const_size", [
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Shape_DivOp : Shape_Op<"div", [Pure,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
def Shape_DivOp : Shape_Op<"div", [Pure, InferTypeOpAdaptorWithIsCompatible]> {
|
||||
let summary = "Division of sizes and indices";
|
||||
let description = [{
|
||||
Divides two sizes or indices. If either operand is an error it will be
|
||||
@ -187,12 +174,6 @@ def Shape_DivOp : Shape_Op<"div", [Pure,
|
||||
|
||||
let hasFolder = 1;
|
||||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Pure, Commutative]> {
|
||||
@ -287,7 +268,7 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", [Commutative]> {
|
||||
}
|
||||
|
||||
def Shape_RankOp : Shape_Op<"rank",
|
||||
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
[Pure, InferTypeOpAdaptorWithIsCompatible]> {
|
||||
let summary = "Gets the rank of a shape";
|
||||
let description = [{
|
||||
Returns the rank of the shape or extent tensor, i.e. the number of extents.
|
||||
@ -301,12 +282,6 @@ def Shape_RankOp : Shape_Op<"rank",
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
|
||||
@ -330,7 +305,7 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
|
||||
}
|
||||
|
||||
def Shape_DimOp : Shape_Op<"dim",
|
||||
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
[Pure, InferTypeOpAdaptorWithIsCompatible]> {
|
||||
let summary = "Gets the specified extent from the shape of a shaped input";
|
||||
let description = [{
|
||||
Gets the extent indexed by `dim` from the shape of the `value` operand. If
|
||||
@ -354,17 +329,13 @@ def Shape_DimOp : Shape_Op<"dim",
|
||||
let extraClassDeclaration = [{
|
||||
/// Get the `index` value as integer if it is constant.
|
||||
std::optional<int64_t> getConstantIndex();
|
||||
|
||||
/// Returns when two result types are compatible for this op; method used
|
||||
/// by InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Shape_GetExtentOp : Shape_Op<"get_extent",
|
||||
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
[Pure, InferTypeOpAdaptorWithIsCompatible]> {
|
||||
let summary = "Gets the specified extent from a shape or extent tensor";
|
||||
let description = [{
|
||||
Gets the extent indexed by `dim` from the `shape` operand. If the shape is
|
||||
@ -384,9 +355,6 @@ def Shape_GetExtentOp : Shape_Op<"get_extent",
|
||||
let extraClassDeclaration = [{
|
||||
/// Get the `dim` value as integer if it is constant.
|
||||
std::optional<int64_t> getConstantDim();
|
||||
/// Returns when two result types are compatible for this op; method used
|
||||
/// by InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
@ -413,8 +381,7 @@ def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [Pure]> {
|
||||
}
|
||||
|
||||
def Shape_MaxOp : Shape_Op<"max",
|
||||
[Commutative, Pure,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
[Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
|
||||
let summary = "Elementwise maximum";
|
||||
let description = [{
|
||||
Computes the elementwise maximum of two sizes or shapes with equal ranks.
|
||||
@ -431,16 +398,10 @@ def Shape_MaxOp : Shape_Op<"max",
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_MeetOp : Shape_Op<"meet",
|
||||
[Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
[Commutative, InferTypeOpAdaptorWithIsCompatible]> {
|
||||
let summary = "Returns the least general shape or size of its operands";
|
||||
let description = [{
|
||||
An operation that computes the least general shape or dim of input operands.
|
||||
@ -478,17 +439,10 @@ def Shape_MeetOp : Shape_Op<"meet",
|
||||
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
|
||||
type($arg0) `,` type($arg1) `->` type($result)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_MinOp : Shape_Op<"min",
|
||||
[Commutative, Pure,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
[Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
|
||||
let summary = "Elementwise minimum";
|
||||
let description = [{
|
||||
Computes the elementwise minimum of two sizes or shapes with equal ranks.
|
||||
@ -505,17 +459,10 @@ def Shape_MinOp : Shape_Op<"min",
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_MulOp : Shape_Op<"mul",
|
||||
[Commutative, Pure,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
[Commutative, Pure, InferTypeOpAdaptorWithIsCompatible]> {
|
||||
let summary = "Multiplication of sizes and indices";
|
||||
let description = [{
|
||||
Multiplies two sizes or indices. If either operand is an error it will be
|
||||
@ -535,16 +482,10 @@ def Shape_MulOp : Shape_Op<"mul",
|
||||
|
||||
let hasFolder = 1;
|
||||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_NumElementsOp : Shape_Op<"num_elements",
|
||||
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
[Pure, InferTypeOpAdaptorWithIsCompatible]> {
|
||||
let summary = "Returns the number of elements for a given shape";
|
||||
let description = [{
|
||||
Returns the number of elements for a given shape which is the product of
|
||||
@ -561,11 +502,6 @@ def Shape_NumElementsOp : Shape_Op<"num_elements",
|
||||
|
||||
let hasFolder = 1;
|
||||
let hasVerifier = 1;
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_ReduceOp : Shape_Op<"reduce",
|
||||
@ -616,7 +552,7 @@ def Shape_ReduceOp : Shape_Op<"reduce",
|
||||
}
|
||||
|
||||
def Shape_ShapeOfOp : Shape_Op<"shape_of",
|
||||
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
[Pure, InferTypeOpAdaptorWithIsCompatible]> {
|
||||
let summary = "Returns shape of a value or shaped type operand";
|
||||
|
||||
let description = [{
|
||||
@ -632,12 +568,6 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_ValueOfOp : Shape_Op<"value_of", [Pure]> {
|
||||
|
@ -463,7 +463,7 @@ def Vector_ShuffleOp :
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>,
|
||||
PredOpTrait<"second operand v2 and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 1>>,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
|
||||
InferTypeOpAdaptor]>,
|
||||
Arguments<(ins AnyVectorOfAnyRank:$v1, AnyVectorOfAnyRank:$v2,
|
||||
I64ArrayAttr:$mask)>,
|
||||
Results<(outs AnyVector:$vector)> {
|
||||
@ -572,7 +572,7 @@ def Vector_ExtractOp :
|
||||
Vector_Op<"extract", [Pure,
|
||||
PredOpTrait<"operand and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
|
||||
InferTypeOpAdaptorWithIsCompatible]>,
|
||||
Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$position)>,
|
||||
Results<(outs AnyType)> {
|
||||
let summary = "extract operation";
|
||||
@ -598,7 +598,6 @@ def Vector_ExtractOp :
|
||||
VectorType getSourceVectorType() {
|
||||
return ::llvm::cast<VectorType>(getVector().getType());
|
||||
}
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
let assemblyFormat = "$vector `` $position attr-dict `:` type($vector)";
|
||||
let hasCanonicalizer = 1;
|
||||
|
@ -259,8 +259,8 @@ namespace mlir {
|
||||
namespace OpTrait {
|
||||
|
||||
template <typename ConcreteType>
|
||||
class InferTypeOpInterfaceAdaptor
|
||||
: public TraitBase<ConcreteType, InferTypeOpInterfaceAdaptor> {};
|
||||
class InferTypeOpAdaptor : public TraitBase<ConcreteType, InferTypeOpAdaptor> {
|
||||
};
|
||||
|
||||
/// Tensor type inference trait that constructs a tensor from the inferred
|
||||
/// shape and elemental types.
|
||||
|
@ -186,35 +186,42 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
|
||||
|
||||
// Convenient trait to define a wrapper to inferReturnTypes that passes in the
|
||||
// Op Adaptor directly
|
||||
def InferTypeOpInterfaceAdaptor : TraitList<
|
||||
class InferTypeOpAdaptorBase<code additionalDecls = [{}]> : TraitList<
|
||||
[
|
||||
// Op implements infer type op interface.
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
NativeOpTrait<
|
||||
/*name=*/"InferTypeOpInterfaceAdaptor",
|
||||
/*name=*/"InferTypeOpAdaptor",
|
||||
/*traits=*/[],
|
||||
/*extraOpDeclaration=*/[{
|
||||
static LogicalResult
|
||||
inferReturnTypesAdaptor(MLIRContext *context,
|
||||
std::optional<Location> location,
|
||||
static ::mlir::LogicalResult
|
||||
inferReturnTypes(::mlir::MLIRContext *context,
|
||||
std::optional<::mlir::Location> location,
|
||||
Adaptor adaptor,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes);
|
||||
}],
|
||||
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes);
|
||||
}] # additionalDecls,
|
||||
/*extraOpDefinition=*/[{
|
||||
LogicalResult
|
||||
$cppClass::inferReturnTypes(MLIRContext *context,
|
||||
std::optional<Location> location,
|
||||
ValueRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
::mlir::LogicalResult
|
||||
$cppClass::inferReturnTypes(::mlir::MLIRContext *context,
|
||||
std::optional<::mlir::Location> location,
|
||||
::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
|
||||
::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
|
||||
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
|
||||
$cppClass::Adaptor adaptor(operands, attributes, properties, regions);
|
||||
return $cppClass::inferReturnTypesAdaptor(context,
|
||||
return $cppClass::inferReturnTypes(context,
|
||||
location, adaptor, inferredReturnTypes);
|
||||
}
|
||||
}]
|
||||
>
|
||||
]>;
|
||||
|
||||
def InferTypeOpAdaptor : InferTypeOpAdaptorBase;
|
||||
def InferTypeOpAdaptorWithIsCompatible : InferTypeOpAdaptorBase<
|
||||
[{
|
||||
static bool isCompatibleReturnTypes(::mlir::TypeRange l, ::mlir::TypeRange r);
|
||||
}]
|
||||
>;
|
||||
|
||||
// Convenience class grouping together type and shaped type op interfaces for
|
||||
// ops that have tensor return types.
|
||||
class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
|
||||
@ -231,13 +238,13 @@ class InferTensorTypeBase<list<string> overridenMethods = []> : TraitList<
|
||||
/*traits=*/[],
|
||||
/*extraOpDeclaration=*/[{}],
|
||||
/*extraOpDefinition=*/[{
|
||||
LogicalResult
|
||||
$cppClass::inferReturnTypes(MLIRContext *context,
|
||||
std::optional<Location> location,
|
||||
ValueRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
SmallVector<ShapedTypeComponents, 2> retComponents;
|
||||
::mlir::LogicalResult
|
||||
$cppClass::inferReturnTypes(::mlir::MLIRContext *context,
|
||||
std::optional<::mlir::Location> location,
|
||||
::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
|
||||
::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
|
||||
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
|
||||
::llvm::SmallVector<::mlir::ShapedTypeComponents, 2> retComponents;
|
||||
if (failed($cppClass::inferReturnTypeComponents(context, location,
|
||||
operands, attributes, properties, regions,
|
||||
retComponents)))
|
||||
|
@ -1354,7 +1354,7 @@ void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
|
||||
|
||||
/// The number and type of the results are inferred from the
|
||||
/// shape of the source.
|
||||
LogicalResult ExtractStridedMetadataOp::inferReturnTypesAdaptor(
|
||||
LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location,
|
||||
ExtractStridedMetadataOp::Adaptor adaptor,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
|
@ -1841,12 +1841,11 @@ bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
|
||||
|
||||
LogicalResult
|
||||
IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
|
||||
ValueRange operands, DictionaryAttr attrs,
|
||||
OpaqueProperties properties, RegionRange regions,
|
||||
IfOp::Adaptor adaptor,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (regions.empty())
|
||||
if (adaptor.getRegions().empty())
|
||||
return failure();
|
||||
Region *r = regions.front();
|
||||
Region *r = &adaptor.getThenRegion();
|
||||
if (r->empty())
|
||||
return failure();
|
||||
Block &b = r->front();
|
||||
|
@ -394,11 +394,10 @@ void AssumingOp::build(
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult mlir::shape::AddOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (llvm::isa<SizeType>(operands[0].getType()) ||
|
||||
llvm::isa<SizeType>(operands[1].getType()))
|
||||
MLIRContext *context, std::optional<Location> location,
|
||||
AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
|
||||
llvm::isa<SizeType>(adaptor.getRhs().getType()))
|
||||
inferredReturnTypes.assign({SizeType::get(context)});
|
||||
else
|
||||
inferredReturnTypes.assign({IndexType::get(context)});
|
||||
@ -916,18 +915,17 @@ void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
}
|
||||
|
||||
LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
MLIRContext *context, std::optional<Location> location,
|
||||
ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
Builder b(context);
|
||||
Properties *prop = properties.as<Properties *>();
|
||||
const Properties *prop = &adaptor.getProperties();
|
||||
DenseIntElementsAttr shape;
|
||||
// TODO: this is only exercised by the Python bindings codepath which does not
|
||||
// support properties
|
||||
if (prop)
|
||||
shape = prop->shape;
|
||||
else
|
||||
shape = attributes.getAs<DenseIntElementsAttr>("shape");
|
||||
shape = adaptor.getAttributes().getAs<DenseIntElementsAttr>("shape");
|
||||
if (!shape)
|
||||
return emitOptionalError(location, "missing shape attribute");
|
||||
inferredReturnTypes.assign({RankedTensorType::get(
|
||||
@ -1104,11 +1102,9 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
|
||||
}
|
||||
|
||||
LogicalResult mlir::shape::DimOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
DimOpAdaptor dimOp(operands);
|
||||
inferredReturnTypes.assign({dimOp.getIndex().getType()});
|
||||
MLIRContext *context, std::optional<Location> location,
|
||||
DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
inferredReturnTypes.assign({adaptor.getIndex().getType()});
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -1141,11 +1137,10 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
|
||||
}
|
||||
|
||||
LogicalResult mlir::shape::DivOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (llvm::isa<SizeType>(operands[0].getType()) ||
|
||||
llvm::isa<SizeType>(operands[1].getType()))
|
||||
MLIRContext *context, std::optional<Location> location,
|
||||
DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
|
||||
llvm::isa<SizeType>(adaptor.getRhs().getType()))
|
||||
inferredReturnTypes.assign({SizeType::get(context)});
|
||||
else
|
||||
inferredReturnTypes.assign({IndexType::get(context)});
|
||||
@ -1361,9 +1356,8 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
|
||||
}
|
||||
|
||||
LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
MLIRContext *context, std::optional<Location> location,
|
||||
GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
inferredReturnTypes.assign({IndexType::get(context)});
|
||||
return success();
|
||||
}
|
||||
@ -1399,10 +1393,9 @@ OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (operands.empty())
|
||||
MLIRContext *context, std::optional<Location> location,
|
||||
MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (adaptor.getOperands().empty())
|
||||
return failure();
|
||||
|
||||
auto isShapeType = [](Type arg) {
|
||||
@ -1411,7 +1404,7 @@ LogicalResult mlir::shape::MeetOp::inferReturnTypes(
|
||||
return isExtentTensorType(arg);
|
||||
};
|
||||
|
||||
ValueRange::type_range types = operands.getTypes();
|
||||
ValueRange::type_range types = adaptor.getOperands().getTypes();
|
||||
Type acc = types.front();
|
||||
for (auto t : drop_begin(types)) {
|
||||
Type l = acc, r = t;
|
||||
@ -1535,10 +1528,9 @@ void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
}
|
||||
|
||||
LogicalResult mlir::shape::RankOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (llvm::isa<ShapeType>(operands[0].getType()))
|
||||
MLIRContext *context, std::optional<Location> location,
|
||||
RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
|
||||
inferredReturnTypes.assign({SizeType::get(context)});
|
||||
else
|
||||
inferredReturnTypes.assign({IndexType::get(context)});
|
||||
@ -1571,10 +1563,10 @@ OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
|
||||
}
|
||||
|
||||
LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
MLIRContext *context, std::optional<Location> location,
|
||||
NumElementsOp::Adaptor adaptor,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (llvm::isa<ShapeType>(operands[0].getType()))
|
||||
if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
|
||||
inferredReturnTypes.assign({SizeType::get(context)});
|
||||
else
|
||||
inferredReturnTypes.assign({IndexType::get(context)});
|
||||
@ -1603,11 +1595,10 @@ OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
|
||||
}
|
||||
|
||||
LogicalResult mlir::shape::MaxOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (operands[0].getType() == operands[1].getType())
|
||||
inferredReturnTypes.assign({operands[0].getType()});
|
||||
MLIRContext *context, std::optional<Location> location,
|
||||
MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
|
||||
inferredReturnTypes.assign({adaptor.getLhs().getType()});
|
||||
else
|
||||
inferredReturnTypes.assign({SizeType::get(context)});
|
||||
return success();
|
||||
@ -1635,11 +1626,10 @@ OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
|
||||
}
|
||||
|
||||
LogicalResult mlir::shape::MinOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (operands[0].getType() == operands[1].getType())
|
||||
inferredReturnTypes.assign({operands[0].getType()});
|
||||
MLIRContext *context, std::optional<Location> location,
|
||||
MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
|
||||
inferredReturnTypes.assign({adaptor.getLhs().getType()});
|
||||
else
|
||||
inferredReturnTypes.assign({SizeType::get(context)});
|
||||
return success();
|
||||
@ -1672,11 +1662,10 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
|
||||
}
|
||||
|
||||
LogicalResult mlir::shape::MulOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (llvm::isa<SizeType>(operands[0].getType()) ||
|
||||
llvm::isa<SizeType>(operands[1].getType()))
|
||||
MLIRContext *context, std::optional<Location> location,
|
||||
MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
|
||||
llvm::isa<SizeType>(adaptor.getRhs().getType()))
|
||||
inferredReturnTypes.assign({SizeType::get(context)});
|
||||
else
|
||||
inferredReturnTypes.assign({IndexType::get(context)});
|
||||
@ -1759,13 +1748,12 @@ void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
}
|
||||
|
||||
LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (llvm::isa<ValueShapeType>(operands[0].getType()))
|
||||
MLIRContext *context, std::optional<Location> location,
|
||||
ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
|
||||
inferredReturnTypes.assign({ShapeType::get(context)});
|
||||
else {
|
||||
auto shapedTy = llvm::cast<ShapedType>(operands[0].getType());
|
||||
auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
|
||||
int64_t rank =
|
||||
shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
|
||||
Type indexTy = IndexType::get(context);
|
||||
|
@ -1146,15 +1146,15 @@ void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
|
||||
|
||||
LogicalResult
|
||||
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
|
||||
ValueRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange,
|
||||
ExtractOp::Adaptor adaptor,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
ExtractOp::Adaptor op(operands, attributes, properties);
|
||||
auto vectorType = llvm::cast<VectorType>(op.getVector().getType());
|
||||
if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
|
||||
auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
|
||||
if (static_cast<int64_t>(adaptor.getPosition().size()) ==
|
||||
vectorType.getRank()) {
|
||||
inferredReturnTypes.push_back(vectorType.getElementType());
|
||||
} else {
|
||||
auto n = std::min<size_t>(op.getPosition().size(), vectorType.getRank());
|
||||
auto n =
|
||||
std::min<size_t>(adaptor.getPosition().size(), vectorType.getRank());
|
||||
inferredReturnTypes.push_back(VectorType::get(
|
||||
vectorType.getShape().drop_front(n), vectorType.getElementType()));
|
||||
}
|
||||
@ -2114,17 +2114,15 @@ LogicalResult ShuffleOp::verify() {
|
||||
|
||||
LogicalResult
|
||||
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
|
||||
ValueRange operands, DictionaryAttr attributes,
|
||||
OpaqueProperties properties, RegionRange,
|
||||
ShuffleOp::Adaptor adaptor,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
ShuffleOp::Adaptor op(operands, attributes, properties);
|
||||
auto v1Type = llvm::cast<VectorType>(op.getV1().getType());
|
||||
auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
|
||||
auto v1Rank = v1Type.getRank();
|
||||
// Construct resulting type: leading dimension matches mask
|
||||
// length, all trailing dimensions match the operands.
|
||||
SmallVector<int64_t, 4> shape;
|
||||
shape.reserve(v1Rank);
|
||||
shape.push_back(std::max<size_t>(1, op.getMask().size()));
|
||||
shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
|
||||
// In the 0-D case there is no trailing shape to append.
|
||||
if (v1Rank > 0)
|
||||
llvm::append_range(shape, v1Type.getShape().drop_front());
|
||||
|
@ -1385,6 +1385,19 @@ LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
|
||||
MLIRContext *, std::optional<Location> location,
|
||||
OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (adaptor.getX().getType() != adaptor.getY().getType()) {
|
||||
return emitOptionalError(location, "operand type mismatch ",
|
||||
adaptor.getX().getType(), " vs ",
|
||||
adaptor.getY().getType());
|
||||
}
|
||||
inferredReturnTypes.assign({adaptor.getX().getType()});
|
||||
return success();
|
||||
}
|
||||
|
||||
// TODO: We should be able to only define either inferReturnType or
|
||||
// refineReturnType, currently only refineReturnType can be omitted.
|
||||
LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
|
||||
|
@ -761,6 +761,12 @@ def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
|
||||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
||||
def OpWithInferTypeAdaptorInterfaceOp : TEST_Op<"op_with_infer_type_adaptor_if", [
|
||||
InferTypeOpAdaptor]> {
|
||||
let arguments = (ins AnyTensor:$x, AnyTensor:$y);
|
||||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
||||
def OpWithRefineTypeInterfaceOp : TEST_Op<"op_with_refine_type_if", [
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface,
|
||||
["refineReturnTypes"]>]> {
|
||||
|
@ -485,6 +485,8 @@ struct TestReturnTypeDriver
|
||||
// output would be in reverse order underneath `op` from which
|
||||
// the attributes and regions are used.
|
||||
invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op);
|
||||
invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>(
|
||||
op);
|
||||
invokeCreateWithInferredReturnType<
|
||||
OpWithShapedTypeInferTypeInterfaceOp>(op);
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user