From f1f3612417a89f797ec2d6d405dd30d6890bef9e Mon Sep 17 00:00:00 2001 From: Nick Kreeger Date: Fri, 14 Oct 2022 11:56:35 -0500 Subject: [PATCH] [mlir] Update Values to use new casting infra This allows for using the llvm namespace cast methods instead of the ones on the Value class. The Value class method are kept for now, but we'll want to remove these eventually (with a really long lead time). Related change: https://reviews.llvm.org/D134327 Differential Revision: https://reviews.llvm.org/D135870 --- mlir/include/mlir/IR/Value.h | 41 ++++++++++++++++++++------- mlir/lib/AsmParser/AsmParserState.cpp | 2 +- mlir/lib/AsmParser/Parser.cpp | 2 +- mlir/lib/IR/AsmPrinter.cpp | 2 +- mlir/lib/IR/Dominance.cpp | 2 +- 5 files changed, 35 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index c9ff4f014593..3557d2d0e0d5 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -88,26 +88,22 @@ public: template bool isa() const { - assert(*this && "isa<> used on a null type."); - return U::classof(*this); + return llvm::isa(*this); } - template - bool isa() const { - return isa() || isa(); - } template U dyn_cast() const { - return isa() ? U(impl) : U(nullptr); + return llvm::dyn_cast(*this); } + template U dyn_cast_or_null() const { - return (*this && isa()) ? U(impl) : U(nullptr); + return llvm::dyn_cast_if_present(*this); } + template U cast() const { - assert(isa()); - return U(impl); + return llvm::cast(*this); } explicit operator bool() const { return impl; } @@ -560,6 +556,31 @@ public: } }; +/// Add support for llvm style casts. We provide a cast between To and From if +/// From is mlir::Value or derives from it. +template +struct CastInfo< + To, From, + std::enable_if_t> || + std::is_base_of_v>> + : NullableValueCastFailed, + DefaultDoCastIfPossible> { + /// Arguments are taken as mlir::Value here and not as `From`, because + /// when casting from an intermediate type of the hierarchy to one of its + /// children, the val.getKind() inside T::classof will use the static + /// getKind() of the parent instead of the non-static ValueImpl::getKind() + /// that returns the dynamic type. This means that T::classof would end up + /// comparing the static Kind of the children to the static Kind of its + /// parent, making it impossible to downcast from the parent to the child. + static inline bool isPossible(mlir::Value ty) { + /// Return a constant true instead of a dynamic true when casting to self or + /// up the hierarchy. + return std::is_same_v> || + std::is_base_of_v || To::classof(ty); + } + static inline To doCast(mlir::Value value) { return To(value.getImpl()); } +}; + } // namespace llvm #endif diff --git a/mlir/lib/AsmParser/AsmParserState.cpp b/mlir/lib/AsmParser/AsmParserState.cpp index d94a25017615..4860b101d598 100644 --- a/mlir/lib/AsmParser/AsmParserState.cpp +++ b/mlir/lib/AsmParser/AsmParserState.cpp @@ -273,7 +273,7 @@ void AsmParserState::addDefinition(BlockArgument blockArg, SMLoc location) { void AsmParserState::addUses(Value value, ArrayRef locations) { // Handle the case where the value is an operation result. - if (OpResult result = value.dyn_cast()) { + if (OpResult result = dyn_cast(value)) { // Check to see if a definition for the parent operation has been recorded. // If one hasn't, we treat the provided value as a placeholder value that // will be refined further later. diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index b92d27b37dc2..ac1941254311 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -2255,7 +2255,7 @@ ParseResult OperationParser::codeCompleteSSAUse() { // If the value isn't a forward reference, we also add the name of the op // to the detail. - if (auto result = frontValue.dyn_cast()) { + if (auto result = dyn_cast(frontValue)) { if (!forwardRefPlaceholders.count(result)) detailOS << result.getOwner()->getName() << ": "; } else { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 53da51cf1086..b2de0b3e4d7f 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1009,7 +1009,7 @@ void SSANameState::printValueID(Value value, bool printResultNo, // If this is an operation result, collect the head lookup value of the result // group and the result number of 'result' within that group. - if (OpResult result = value.dyn_cast()) + if (OpResult result = dyn_cast(value)) getResultIDAndNumber(result, lookupValue, resultNo); auto it = valueIDs.find(lookupValue); diff --git a/mlir/lib/IR/Dominance.cpp b/mlir/lib/IR/Dominance.cpp index 68cc2039ab64..beb7f7bfeedf 100644 --- a/mlir/lib/IR/Dominance.cpp +++ b/mlir/lib/IR/Dominance.cpp @@ -297,7 +297,7 @@ bool DominanceInfo::properlyDominatesImpl(Operation *a, Operation *b, bool DominanceInfo::properlyDominates(Value a, Operation *b) const { // block arguments properly dominate all operations in their own block, so // we use a dominates check here, not a properlyDominates check. - if (auto blockArg = a.dyn_cast()) + if (auto blockArg = dyn_cast(a)) return dominates(blockArg.getOwner(), b->getBlock()); // `a` properlyDominates `b` if the operation defining `a` properlyDominates