mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-25 23:56:05 +00:00
[mlir] (NFC) run clang-format on all files
This commit is contained in:
parent
edee61b55c
commit
b7f93c2809
@ -60,7 +60,8 @@ private:
|
||||
} // namespace
|
||||
|
||||
/// Return a formatted string for the location of any node
|
||||
template <typename T> static std::string loc(T *node) {
|
||||
template <typename T>
|
||||
static std::string loc(T *node) {
|
||||
const auto &loc = node->loc();
|
||||
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
|
||||
llvm::Twine(loc.col))
|
||||
|
@ -60,7 +60,8 @@ private:
|
||||
} // namespace
|
||||
|
||||
/// Return a formatted string for the location of any node
|
||||
template <typename T> static std::string loc(T *node) {
|
||||
template <typename T>
|
||||
static std::string loc(T *node) {
|
||||
const auto &loc = node->loc();
|
||||
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
|
||||
llvm::Twine(loc.col))
|
||||
|
@ -60,7 +60,8 @@ private:
|
||||
} // namespace
|
||||
|
||||
/// Return a formatted string for the location of any node
|
||||
template <typename T> static std::string loc(T *node) {
|
||||
template <typename T>
|
||||
static std::string loc(T *node) {
|
||||
const auto &loc = node->loc();
|
||||
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
|
||||
llvm::Twine(loc.col))
|
||||
|
@ -60,7 +60,8 @@ private:
|
||||
} // namespace
|
||||
|
||||
/// Return a formatted string for the location of any node
|
||||
template <typename T> static std::string loc(T *node) {
|
||||
template <typename T>
|
||||
static std::string loc(T *node) {
|
||||
const auto &loc = node->loc();
|
||||
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
|
||||
llvm::Twine(loc.col))
|
||||
|
@ -104,25 +104,26 @@ struct BinaryOpLowering : public ConversionPattern {
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
lowerOpToLoops(
|
||||
op, operands, rewriter,
|
||||
[loc](OpBuilder &builder, ValueRange memRefOperands,
|
||||
ValueRange loopIvs) {
|
||||
// Generate an adaptor for the remapped operands of the BinaryOp. This
|
||||
// allows for using the nice named accessors that are generated by the
|
||||
// ODS.
|
||||
typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
|
||||
lowerOpToLoops(op, operands, rewriter,
|
||||
[loc](OpBuilder &builder, ValueRange memRefOperands,
|
||||
ValueRange loopIvs) {
|
||||
// Generate an adaptor for the remapped operands of the
|
||||
// BinaryOp. This allows for using the nice named accessors
|
||||
// that are generated by the ODS.
|
||||
typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
|
||||
|
||||
// Generate loads for the element of 'lhs' and 'rhs' at the inner
|
||||
// loop.
|
||||
auto loadedLhs = builder.create<AffineLoadOp>(
|
||||
loc, binaryAdaptor.getLhs(), loopIvs);
|
||||
auto loadedRhs = builder.create<AffineLoadOp>(
|
||||
loc, binaryAdaptor.getRhs(), loopIvs);
|
||||
// Generate loads for the element of 'lhs' and 'rhs' at the
|
||||
// inner loop.
|
||||
auto loadedLhs = builder.create<AffineLoadOp>(
|
||||
loc, binaryAdaptor.getLhs(), loopIvs);
|
||||
auto loadedRhs = builder.create<AffineLoadOp>(
|
||||
loc, binaryAdaptor.getRhs(), loopIvs);
|
||||
|
||||
// Create the binary operation performed on the loaded values.
|
||||
return builder.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
|
||||
});
|
||||
// Create the binary operation performed on the loaded
|
||||
// values.
|
||||
return builder.create<LoweredBinaryOp>(loc, loadedLhs,
|
||||
loadedRhs);
|
||||
});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -60,7 +60,8 @@ private:
|
||||
} // namespace
|
||||
|
||||
/// Return a formatted string for the location of any node
|
||||
template <typename T> static std::string loc(T *node) {
|
||||
template <typename T>
|
||||
static std::string loc(T *node) {
|
||||
const auto &loc = node->loc();
|
||||
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
|
||||
llvm::Twine(loc.col))
|
||||
|
@ -104,25 +104,26 @@ struct BinaryOpLowering : public ConversionPattern {
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
lowerOpToLoops(
|
||||
op, operands, rewriter,
|
||||
[loc](OpBuilder &builder, ValueRange memRefOperands,
|
||||
ValueRange loopIvs) {
|
||||
// Generate an adaptor for the remapped operands of the BinaryOp. This
|
||||
// allows for using the nice named accessors that are generated by the
|
||||
// ODS.
|
||||
typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
|
||||
lowerOpToLoops(op, operands, rewriter,
|
||||
[loc](OpBuilder &builder, ValueRange memRefOperands,
|
||||
ValueRange loopIvs) {
|
||||
// Generate an adaptor for the remapped operands of the
|
||||
// BinaryOp. This allows for using the nice named accessors
|
||||
// that are generated by the ODS.
|
||||
typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
|
||||
|
||||
// Generate loads for the element of 'lhs' and 'rhs' at the inner
|
||||
// loop.
|
||||
auto loadedLhs = builder.create<AffineLoadOp>(
|
||||
loc, binaryAdaptor.getLhs(), loopIvs);
|
||||
auto loadedRhs = builder.create<AffineLoadOp>(
|
||||
loc, binaryAdaptor.getRhs(), loopIvs);
|
||||
// Generate loads for the element of 'lhs' and 'rhs' at the
|
||||
// inner loop.
|
||||
auto loadedLhs = builder.create<AffineLoadOp>(
|
||||
loc, binaryAdaptor.getLhs(), loopIvs);
|
||||
auto loadedRhs = builder.create<AffineLoadOp>(
|
||||
loc, binaryAdaptor.getRhs(), loopIvs);
|
||||
|
||||
// Create the binary operation performed on the loaded values.
|
||||
return builder.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
|
||||
});
|
||||
// Create the binary operation performed on the loaded
|
||||
// values.
|
||||
return builder.create<LoweredBinaryOp>(loc, loadedLhs,
|
||||
loadedRhs);
|
||||
});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -60,7 +60,8 @@ private:
|
||||
} // namespace
|
||||
|
||||
/// Return a formatted string for the location of any node
|
||||
template <typename T> static std::string loc(T *node) {
|
||||
template <typename T>
|
||||
static std::string loc(T *node) {
|
||||
const auto &loc = node->loc();
|
||||
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
|
||||
llvm::Twine(loc.col))
|
||||
|
@ -104,25 +104,26 @@ struct BinaryOpLowering : public ConversionPattern {
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
lowerOpToLoops(
|
||||
op, operands, rewriter,
|
||||
[loc](OpBuilder &builder, ValueRange memRefOperands,
|
||||
ValueRange loopIvs) {
|
||||
// Generate an adaptor for the remapped operands of the BinaryOp. This
|
||||
// allows for using the nice named accessors that are generated by the
|
||||
// ODS.
|
||||
typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
|
||||
lowerOpToLoops(op, operands, rewriter,
|
||||
[loc](OpBuilder &builder, ValueRange memRefOperands,
|
||||
ValueRange loopIvs) {
|
||||
// Generate an adaptor for the remapped operands of the
|
||||
// BinaryOp. This allows for using the nice named accessors
|
||||
// that are generated by the ODS.
|
||||
typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
|
||||
|
||||
// Generate loads for the element of 'lhs' and 'rhs' at the inner
|
||||
// loop.
|
||||
auto loadedLhs = builder.create<AffineLoadOp>(
|
||||
loc, binaryAdaptor.getLhs(), loopIvs);
|
||||
auto loadedRhs = builder.create<AffineLoadOp>(
|
||||
loc, binaryAdaptor.getRhs(), loopIvs);
|
||||
// Generate loads for the element of 'lhs' and 'rhs' at the
|
||||
// inner loop.
|
||||
auto loadedLhs = builder.create<AffineLoadOp>(
|
||||
loc, binaryAdaptor.getLhs(), loopIvs);
|
||||
auto loadedRhs = builder.create<AffineLoadOp>(
|
||||
loc, binaryAdaptor.getRhs(), loopIvs);
|
||||
|
||||
// Create the binary operation performed on the loaded values.
|
||||
return builder.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
|
||||
});
|
||||
// Create the binary operation performed on the loaded
|
||||
// values.
|
||||
return builder.create<LoweredBinaryOp>(loc, loadedLhs,
|
||||
loadedRhs);
|
||||
});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -62,7 +62,8 @@ private:
|
||||
} // namespace
|
||||
|
||||
/// Return a formatted string for the location of any node
|
||||
template <typename T> static std::string loc(T *node) {
|
||||
template <typename T>
|
||||
static std::string loc(T *node) {
|
||||
const auto &loc = node->loc();
|
||||
return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" +
|
||||
llvm::Twine(loc.col))
|
||||
|
@ -193,7 +193,8 @@ struct AliasAnalysisTraits {
|
||||
/// `ImplT`. A model is instantiated for each alias analysis implementation
|
||||
/// to implement the `Concept` without the need for the derived
|
||||
/// implementation to inherit from the `Concept` class.
|
||||
template <typename ImplT> class Model final : public Concept {
|
||||
template <typename ImplT>
|
||||
class Model final : public Concept {
|
||||
public:
|
||||
explicit Model(ImplT &&impl) : impl(std::forward<ImplT>(impl)) {}
|
||||
~Model() override = default;
|
||||
|
@ -221,7 +221,8 @@ private:
|
||||
namespace llvm {
|
||||
// Provide graph traits for traversing call graphs using standard graph
|
||||
// traversals.
|
||||
template <> struct GraphTraits<const mlir::CallGraphNode *> {
|
||||
template <>
|
||||
struct GraphTraits<const mlir::CallGraphNode *> {
|
||||
using NodeRef = mlir::CallGraphNode *;
|
||||
static NodeRef getEntryNode(NodeRef node) { return node; }
|
||||
|
||||
|
@ -252,7 +252,7 @@ public:
|
||||
}
|
||||
|
||||
template <typename Func, typename... Extra>
|
||||
pure_subclass &def(const char *name, Func &&f, const Extra &... extra) {
|
||||
pure_subclass &def(const char *name, Func &&f, const Extra &...extra) {
|
||||
py::cpp_function cf(
|
||||
std::forward<Func>(f), py::name(name), py::is_method(thisClass),
|
||||
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
|
||||
@ -262,7 +262,7 @@ public:
|
||||
|
||||
template <typename Func, typename... Extra>
|
||||
pure_subclass &def_property_readonly(const char *name, Func &&f,
|
||||
const Extra &... extra) {
|
||||
const Extra &...extra) {
|
||||
py::cpp_function cf(
|
||||
std::forward<Func>(f), py::name(name), py::is_method(thisClass),
|
||||
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
|
||||
@ -274,7 +274,7 @@ public:
|
||||
|
||||
template <typename Func, typename... Extra>
|
||||
pure_subclass &def_staticmethod(const char *name, Func &&f,
|
||||
const Extra &... extra) {
|
||||
const Extra &...extra) {
|
||||
static_assert(!std::is_member_function_pointer<Func>::value,
|
||||
"def_staticmethod(...) called with a non-static member "
|
||||
"function pointer");
|
||||
@ -287,7 +287,7 @@ public:
|
||||
|
||||
template <typename Func, typename... Extra>
|
||||
pure_subclass &def_classmethod(const char *name, Func &&f,
|
||||
const Extra &... extra) {
|
||||
const Extra &...extra) {
|
||||
static_assert(!std::is_member_function_pointer<Func>::value,
|
||||
"def_classmethod(...) called with a non-static member "
|
||||
"function pointer");
|
||||
|
@ -21,7 +21,8 @@
|
||||
namespace mlir {
|
||||
|
||||
class ModuleOp;
|
||||
template <typename T> class OperationPass;
|
||||
template <typename T>
|
||||
class OperationPass;
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertVulkanLaunchFuncToVulkanCallsPass();
|
||||
|
@ -49,7 +49,8 @@ public:
|
||||
/// Allow the given dialects.
|
||||
///
|
||||
/// This function adds one or multiple ALLOW entries.
|
||||
template <typename... DialectTs> void allowDialect() {
|
||||
template <typename... DialectTs>
|
||||
void allowDialect() {
|
||||
// The following expands a call to allowDialectImpl for each dialect
|
||||
// in 'DialectTs'. This magic is necessary due to a limitation in the places
|
||||
// that a parameter pack can be expanded in c++11.
|
||||
@ -60,7 +61,8 @@ public:
|
||||
/// Deny the given dialects.
|
||||
///
|
||||
/// This function adds one or multiple DENY entries.
|
||||
template <typename... DialectTs> void denyDialect() {
|
||||
template <typename... DialectTs>
|
||||
void denyDialect() {
|
||||
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
|
||||
(void)std::initializer_list<int>{0, (denyDialectImpl<DialectTs>(), 0)...};
|
||||
}
|
||||
@ -78,7 +80,8 @@ public:
|
||||
/// Allow the given ops.
|
||||
///
|
||||
/// This function adds one or multiple ALLOW entries.
|
||||
template <typename... OpTys> void allowOperation() {
|
||||
template <typename... OpTys>
|
||||
void allowOperation() {
|
||||
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
|
||||
(void)std::initializer_list<int>{0, (allowOperationImpl<OpTys>(), 0)...};
|
||||
}
|
||||
@ -86,7 +89,8 @@ public:
|
||||
/// Deny the given ops.
|
||||
///
|
||||
/// This function adds one or multiple DENY entries.
|
||||
template <typename... OpTys> void denyOperation() {
|
||||
template <typename... OpTys>
|
||||
void denyOperation() {
|
||||
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
|
||||
(void)std::initializer_list<int>{0, (denyOperationImpl<OpTys>(), 0)...};
|
||||
}
|
||||
@ -135,22 +139,26 @@ private:
|
||||
}
|
||||
|
||||
/// Allow a dialect.
|
||||
template <typename DialectT> void allowDialectImpl() {
|
||||
template <typename DialectT>
|
||||
void allowDialectImpl() {
|
||||
allowDialect(DialectT::getDialectNamespace());
|
||||
}
|
||||
|
||||
/// Deny a dialect.
|
||||
template <typename DialectT> void denyDialectImpl() {
|
||||
template <typename DialectT>
|
||||
void denyDialectImpl() {
|
||||
denyDialect(DialectT::getDialectNamespace());
|
||||
}
|
||||
|
||||
/// Allow an op.
|
||||
template <typename OpTy> void allowOperationImpl() {
|
||||
template <typename OpTy>
|
||||
void allowOperationImpl() {
|
||||
allowOperation(OpTy::getOperationName());
|
||||
}
|
||||
|
||||
/// Deny an op.
|
||||
template <typename OpTy> void denyOperationImpl() {
|
||||
template <typename OpTy>
|
||||
void denyOperationImpl() {
|
||||
denyOperation(OpTy::getOperationName());
|
||||
}
|
||||
|
||||
|
@ -24,7 +24,8 @@
|
||||
#include <memory>
|
||||
|
||||
namespace llvm {
|
||||
template <typename T> class Expected;
|
||||
template <typename T>
|
||||
class Expected;
|
||||
class Module;
|
||||
class ExecutionEngine;
|
||||
class JITEventListener;
|
||||
|
@ -182,8 +182,8 @@ FOREVERY_O(DECL_SPARSEINDICES)
|
||||
/// Coordinate-scheme method for adding a new element.
|
||||
#define DECL_ADDELT(VNAME, V) \
|
||||
MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_addElt##VNAME( \
|
||||
void *coo, \
|
||||
StridedMemRefType<V, 0> *vref, StridedMemRefType<index_type, 1> *iref, \
|
||||
void *coo, StridedMemRefType<V, 0> *vref, \
|
||||
StridedMemRefType<index_type, 1> *iref, \
|
||||
StridedMemRefType<index_type, 1> *pref);
|
||||
FOREVERY_V(DECL_ADDELT)
|
||||
#undef DECL_ADDELT
|
||||
|
@ -65,7 +65,8 @@ namespace mlir {
|
||||
/// just as efficient as having your own switch instruction over the instruction
|
||||
/// opcode.
|
||||
|
||||
template <typename SubClass, typename RetTy = void> class AffineExprVisitor {
|
||||
template <typename SubClass, typename RetTy = void>
|
||||
class AffineExprVisitor {
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Interface code - This is the public interface of the AffineExprVisitor
|
||||
// that you use to visit affine expressions...
|
||||
|
@ -46,14 +46,18 @@ public:
|
||||
|
||||
bool operator!() const { return impl == nullptr; }
|
||||
|
||||
template <typename U> bool isa() const;
|
||||
template <typename U>
|
||||
bool isa() const;
|
||||
template <typename First, typename Second, typename... Rest>
|
||||
bool isa() const;
|
||||
template <typename First, typename... Rest>
|
||||
bool isa_and_nonnull() const;
|
||||
template <typename U> U dyn_cast() const;
|
||||
template <typename U> U dyn_cast_or_null() const;
|
||||
template <typename U> U cast() const;
|
||||
template <typename U>
|
||||
U dyn_cast() const;
|
||||
template <typename U>
|
||||
U dyn_cast_or_null() const;
|
||||
template <typename U>
|
||||
U cast() const;
|
||||
|
||||
// Support dyn_cast'ing Attribute to itself.
|
||||
static bool classof(Attribute) { return true; }
|
||||
@ -106,7 +110,8 @@ inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
|
||||
return os;
|
||||
}
|
||||
|
||||
template <typename U> bool Attribute::isa() const {
|
||||
template <typename U>
|
||||
bool Attribute::isa() const {
|
||||
assert(impl && "isa<> used on a null attribute.");
|
||||
return U::classof(*this);
|
||||
}
|
||||
@ -121,13 +126,16 @@ bool Attribute::isa_and_nonnull() const {
|
||||
return impl && isa<First, Rest...>();
|
||||
}
|
||||
|
||||
template <typename U> U Attribute::dyn_cast() const {
|
||||
template <typename U>
|
||||
U Attribute::dyn_cast() const {
|
||||
return isa<U>() ? U(impl) : U(nullptr);
|
||||
}
|
||||
template <typename U> U Attribute::dyn_cast_or_null() const {
|
||||
template <typename U>
|
||||
U Attribute::dyn_cast_or_null() const {
|
||||
return (impl && isa<U>()) ? U(impl) : U(nullptr);
|
||||
}
|
||||
template <typename U> U Attribute::cast() const {
|
||||
template <typename U>
|
||||
U Attribute::cast() const {
|
||||
assert(isa<U>());
|
||||
return U(impl);
|
||||
}
|
||||
@ -248,7 +256,8 @@ using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
|
||||
namespace llvm {
|
||||
|
||||
// Attribute hash just like pointers.
|
||||
template <> struct DenseMapInfo<mlir::Attribute> {
|
||||
template <>
|
||||
struct DenseMapInfo<mlir::Attribute> {
|
||||
static mlir::Attribute getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
|
||||
@ -280,7 +289,8 @@ struct DenseMapInfo<
|
||||
};
|
||||
|
||||
/// Allow LLVM to steal the low bits of Attributes.
|
||||
template <> struct PointerLikeTypeTraits<mlir::Attribute> {
|
||||
template <>
|
||||
struct PointerLikeTypeTraits<mlir::Attribute> {
|
||||
static inline void *getAsVoidPointer(mlir::Attribute attr) {
|
||||
return const_cast<void *>(attr.getAsOpaquePointer());
|
||||
}
|
||||
@ -291,7 +301,8 @@ template <> struct PointerLikeTypeTraits<mlir::Attribute> {
|
||||
mlir::AttributeStorage *>::NumLowBitsAvailable;
|
||||
};
|
||||
|
||||
template <> struct DenseMapInfo<mlir::NamedAttribute> {
|
||||
template <>
|
||||
struct DenseMapInfo<mlir::NamedAttribute> {
|
||||
static mlir::NamedAttribute getEmptyKey() {
|
||||
auto emptyAttr = llvm::DenseMapInfo<mlir::Attribute>::getEmptyKey();
|
||||
return mlir::NamedAttribute(emptyAttr, emptyAttr);
|
||||
|
@ -63,7 +63,8 @@ public:
|
||||
|
||||
/// Lookup a mapped value within the map. This asserts the provided value
|
||||
/// exists within the map.
|
||||
template <typename T> T lookup(T from) const {
|
||||
template <typename T>
|
||||
T lookup(T from) const {
|
||||
auto result = lookupOrNull(from);
|
||||
assert(result && "expected 'from' to be contained within the map");
|
||||
return result;
|
||||
|
@ -52,7 +52,6 @@ class PredecessorIterator final
|
||||
static Block *unwrap(BlockOperand &value);
|
||||
|
||||
public:
|
||||
|
||||
/// Initializes the operand type iterator to the specified operand iterator.
|
||||
PredecessorIterator(ValueUseIterator<BlockOperand> it)
|
||||
: llvm::mapped_iterator<ValueUseIterator<BlockOperand>,
|
||||
@ -162,7 +161,6 @@ class op_iterator
|
||||
static OpT unwrap(Operation &op) { return cast<OpT>(op); }
|
||||
|
||||
public:
|
||||
|
||||
/// Initializes the iterator to the specified filter iterator.
|
||||
op_iterator(op_filter_iterator<OpT, IteratorT> it)
|
||||
: llvm::mapped_iterator<op_filter_iterator<OpT, IteratorT>,
|
||||
@ -225,7 +223,8 @@ protected:
|
||||
};
|
||||
} // namespace ilist_detail
|
||||
|
||||
template <> struct ilist_traits<::mlir::Operation> {
|
||||
template <>
|
||||
struct ilist_traits<::mlir::Operation> {
|
||||
using Operation = ::mlir::Operation;
|
||||
using op_iterator = simple_ilist<Operation>::iterator;
|
||||
|
||||
|
@ -76,7 +76,8 @@ public:
|
||||
}
|
||||
|
||||
/// Access the element at the given index.
|
||||
template <typename T> T at(uint64_t index) const {
|
||||
template <typename T>
|
||||
T at(uint64_t index) const {
|
||||
if (isSplat)
|
||||
index = 0;
|
||||
return isContiguous ? conState.at<T>(index) : nonConState.at<T>(index);
|
||||
@ -93,7 +94,8 @@ private:
|
||||
ContiguousState(const void *firstEltPtr) : firstEltPtr(firstEltPtr) {}
|
||||
|
||||
/// Access the element at the given index.
|
||||
template <typename T> const T &at(uint64_t index) const {
|
||||
template <typename T>
|
||||
const T &at(uint64_t index) const {
|
||||
return *(reinterpret_cast<const T *>(firstEltPtr) + index);
|
||||
}
|
||||
|
||||
@ -171,7 +173,8 @@ private:
|
||||
NonContiguousState(NonContiguousState &&other) = default;
|
||||
|
||||
/// Access the element at the given index.
|
||||
template <typename T> T at(uint64_t index) const {
|
||||
template <typename T>
|
||||
T at(uint64_t index) const {
|
||||
auto *valueIt = static_cast<OpaqueIteratorValueBase<T> *>(iterator.get());
|
||||
return valueIt->at(index);
|
||||
}
|
||||
|
@ -61,7 +61,8 @@ protected:
|
||||
};
|
||||
|
||||
/// Type trait detector that checks if a given type T is a complex type.
|
||||
template <typename T> struct is_complex_t : public std::false_type {};
|
||||
template <typename T>
|
||||
struct is_complex_t : public std::false_type {};
|
||||
template <typename T>
|
||||
struct is_complex_t<std::complex<T>> : public std::true_type {};
|
||||
} // namespace detail
|
||||
@ -81,7 +82,8 @@ public:
|
||||
/// floating point type that can be used to access the underlying element
|
||||
/// types of a DenseElementsAttr.
|
||||
// TODO: Use std::disjunction when C++17 is supported.
|
||||
template <typename T> struct is_valid_cpp_fp_type {
|
||||
template <typename T>
|
||||
struct is_valid_cpp_fp_type {
|
||||
/// The type is a valid floating point type if it is a builtin floating
|
||||
/// point type, or is a potentially user defined floating point type. The
|
||||
/// latter allows for supporting users that have custom types defined for
|
||||
|
@ -201,9 +201,7 @@ public:
|
||||
|
||||
/// Stream in an Operation.
|
||||
Diagnostic &operator<<(Operation &val);
|
||||
Diagnostic &operator<<(Operation *val) {
|
||||
return *this << *val;
|
||||
}
|
||||
Diagnostic &operator<<(Operation *val) { return *this << *val; }
|
||||
/// Append an operation with the given printing flags.
|
||||
Diagnostic &appendOp(Operation &val, const OpPrintingFlags &flags);
|
||||
|
||||
@ -229,12 +227,13 @@ public:
|
||||
|
||||
/// Append arguments to the diagnostic.
|
||||
template <typename Arg1, typename Arg2, typename... Args>
|
||||
Diagnostic &append(Arg1 &&arg1, Arg2 &&arg2, Args &&... args) {
|
||||
Diagnostic &append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args) {
|
||||
append(std::forward<Arg1>(arg1));
|
||||
return append(std::forward<Arg2>(arg2), std::forward<Args>(args)...);
|
||||
}
|
||||
/// Append one argument to the diagnostic.
|
||||
template <typename Arg> Diagnostic &append(Arg &&arg) {
|
||||
template <typename Arg>
|
||||
Diagnostic &append(Arg &&arg) {
|
||||
*this << std::forward<Arg>(arg);
|
||||
return *this;
|
||||
}
|
||||
@ -323,21 +322,25 @@ public:
|
||||
}
|
||||
|
||||
/// Stream operator for new diagnostic arguments.
|
||||
template <typename Arg> InFlightDiagnostic &operator<<(Arg &&arg) & {
|
||||
template <typename Arg>
|
||||
InFlightDiagnostic &operator<<(Arg &&arg) & {
|
||||
return append(std::forward<Arg>(arg));
|
||||
}
|
||||
template <typename Arg> InFlightDiagnostic &&operator<<(Arg &&arg) && {
|
||||
template <typename Arg>
|
||||
InFlightDiagnostic &&operator<<(Arg &&arg) && {
|
||||
return std::move(append(std::forward<Arg>(arg)));
|
||||
}
|
||||
|
||||
/// Append arguments to the diagnostic.
|
||||
template <typename... Args> InFlightDiagnostic &append(Args &&... args) & {
|
||||
template <typename... Args>
|
||||
InFlightDiagnostic &append(Args &&...args) & {
|
||||
assert(isActive() && "diagnostic not active");
|
||||
if (isInFlight())
|
||||
impl->append(std::forward<Args>(args)...);
|
||||
return *this;
|
||||
}
|
||||
template <typename... Args> InFlightDiagnostic &&append(Args &&... args) && {
|
||||
template <typename... Args>
|
||||
InFlightDiagnostic &&append(Args &&...args) && {
|
||||
return std::move(append(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
@ -483,19 +486,19 @@ InFlightDiagnostic emitRemark(Location loc, const Twine &message);
|
||||
/// the diagnostic arguments directly instead of relying on the returned
|
||||
/// InFlightDiagnostic.
|
||||
template <typename... Args>
|
||||
LogicalResult emitOptionalError(Optional<Location> loc, Args &&... args) {
|
||||
LogicalResult emitOptionalError(Optional<Location> loc, Args &&...args) {
|
||||
if (loc)
|
||||
return emitError(*loc).append(std::forward<Args>(args)...);
|
||||
return failure();
|
||||
}
|
||||
template <typename... Args>
|
||||
LogicalResult emitOptionalWarning(Optional<Location> loc, Args &&... args) {
|
||||
LogicalResult emitOptionalWarning(Optional<Location> loc, Args &&...args) {
|
||||
if (loc)
|
||||
return emitWarning(*loc).append(std::forward<Args>(args)...);
|
||||
return failure();
|
||||
}
|
||||
template <typename... Args>
|
||||
LogicalResult emitOptionalRemark(Optional<Location> loc, Args &&... args) {
|
||||
LogicalResult emitOptionalRemark(Optional<Location> loc, Args &&...args) {
|
||||
if (loc)
|
||||
return emitRemark(*loc).append(std::forward<Args>(args)...);
|
||||
return failure();
|
||||
@ -520,7 +523,8 @@ public:
|
||||
|
||||
protected:
|
||||
/// Set the handler to manage via RAII.
|
||||
template <typename FuncTy> void setHandler(FuncTy &&handler) {
|
||||
template <typename FuncTy>
|
||||
void setHandler(FuncTy &&handler) {
|
||||
auto &diagEngine = ctx->getDiagEngine();
|
||||
if (handlerID)
|
||||
diagEngine.eraseHandler(handlerID);
|
||||
|
@ -161,7 +161,8 @@ public:
|
||||
auto it = registeredInterfaces.find(interfaceID);
|
||||
return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
|
||||
}
|
||||
template <typename InterfaceT> const InterfaceT *getRegisteredInterface() {
|
||||
template <typename InterfaceT>
|
||||
const InterfaceT *getRegisteredInterface() {
|
||||
return static_cast<const InterfaceT *>(
|
||||
getRegisteredInterface(InterfaceT::getInterfaceID()));
|
||||
}
|
||||
@ -201,13 +202,15 @@ protected:
|
||||
|
||||
/// This method is used by derived classes to add their operations to the set.
|
||||
///
|
||||
template <typename... Args> void addOperations() {
|
||||
template <typename... Args>
|
||||
void addOperations() {
|
||||
(void)std::initializer_list<int>{
|
||||
0, (RegisteredOperationName::insert<Args>(*this), 0)...};
|
||||
}
|
||||
|
||||
/// Register a set of type classes with this dialect.
|
||||
template <typename... Args> void addTypes() {
|
||||
template <typename... Args>
|
||||
void addTypes() {
|
||||
(void)std::initializer_list<int>{0, (addType<Args>(), 0)...};
|
||||
}
|
||||
|
||||
@ -217,7 +220,8 @@ protected:
|
||||
void addType(TypeID typeID, AbstractType &&typeInfo);
|
||||
|
||||
/// Register a set of attribute classes with this dialect.
|
||||
template <typename... Args> void addAttributes() {
|
||||
template <typename... Args>
|
||||
void addAttributes() {
|
||||
(void)std::initializer_list<int>{0, (addAttribute<Args>(), 0)...};
|
||||
}
|
||||
|
||||
@ -237,14 +241,16 @@ private:
|
||||
void operator=(Dialect &) = delete;
|
||||
|
||||
/// Register an attribute instance with this dialect.
|
||||
template <typename T> void addAttribute() {
|
||||
template <typename T>
|
||||
void addAttribute() {
|
||||
// Add this attribute to the dialect and register it with the uniquer.
|
||||
addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this));
|
||||
detail::AttributeUniquer::registerAttribute<T>(context);
|
||||
}
|
||||
|
||||
/// Register a type instance with this dialect.
|
||||
template <typename T> void addType() {
|
||||
template <typename T>
|
||||
void addType() {
|
||||
// Add this type to the dialect and register it with the uniquer.
|
||||
addType(T::getTypeID(), AbstractType::get<T>(*this));
|
||||
detail::TypeUniquer::registerType<T>(context);
|
||||
|
@ -126,10 +126,12 @@ protected:
|
||||
};
|
||||
|
||||
/// Iterator access to the held interfaces.
|
||||
template <typename InterfaceT> iterator<InterfaceT> interface_begin() const {
|
||||
template <typename InterfaceT>
|
||||
iterator<InterfaceT> interface_begin() const {
|
||||
return iterator<InterfaceT>(orderedInterfaces.begin());
|
||||
}
|
||||
template <typename InterfaceT> iterator<InterfaceT> interface_end() const {
|
||||
template <typename InterfaceT>
|
||||
iterator<InterfaceT> interface_end() const {
|
||||
return iterator<InterfaceT>(orderedInterfaces.end());
|
||||
}
|
||||
|
||||
|
@ -128,7 +128,8 @@ inline ::llvm::hash_code hash_value(IntegerSet arg) {
|
||||
namespace llvm {
|
||||
|
||||
// IntegerSet hash just like pointers.
|
||||
template <> struct DenseMapInfo<mlir::IntegerSet> {
|
||||
template <>
|
||||
struct DenseMapInfo<mlir::IntegerSet> {
|
||||
static mlir::IntegerSet getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlir::IntegerSet(static_cast<mlir::IntegerSet::ImplType *>(pointer));
|
||||
|
@ -62,9 +62,18 @@ public:
|
||||
LocationAttr *operator->() const { return const_cast<LocationAttr *>(&impl); }
|
||||
|
||||
/// Type casting utilities on the underlying location.
|
||||
template <typename U> bool isa() const { return impl.isa<U>(); }
|
||||
template <typename U> U dyn_cast() const { return impl.dyn_cast<U>(); }
|
||||
template <typename U> U cast() const { return impl.cast<U>(); }
|
||||
template <typename U>
|
||||
bool isa() const {
|
||||
return impl.isa<U>();
|
||||
}
|
||||
template <typename U>
|
||||
U dyn_cast() const {
|
||||
return impl.dyn_cast<U>();
|
||||
}
|
||||
template <typename U>
|
||||
U cast() const {
|
||||
return impl.cast<U>();
|
||||
}
|
||||
|
||||
/// Comparison operators.
|
||||
bool operator==(Location rhs) const { return impl == rhs.impl; }
|
||||
@ -128,7 +137,8 @@ inline OpaqueLoc OpaqueLoc::get(T underlyingLocation, MLIRContext *context) {
|
||||
namespace llvm {
|
||||
|
||||
// Type hash just like pointers.
|
||||
template <> struct DenseMapInfo<mlir::Location> {
|
||||
template <>
|
||||
struct DenseMapInfo<mlir::Location> {
|
||||
static mlir::Location getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlir::Location::getFromOpaquePointer(pointer);
|
||||
@ -146,7 +156,8 @@ template <> struct DenseMapInfo<mlir::Location> {
|
||||
};
|
||||
|
||||
/// We align LocationStorage by 8, so allow LLVM to steal the low bits.
|
||||
template <> struct PointerLikeTypeTraits<mlir::Location> {
|
||||
template <>
|
||||
struct PointerLikeTypeTraits<mlir::Location> {
|
||||
public:
|
||||
static inline void *getAsVoidPointer(mlir::Location I) {
|
||||
return const_cast<void *>(I.getAsOpaquePointer());
|
||||
|
@ -1732,14 +1732,16 @@ private:
|
||||
struct InterfaceTargetOrOpT {
|
||||
using type = typename T::ConcreteEntity;
|
||||
};
|
||||
template <typename T> struct InterfaceTargetOrOpT<T, false> {
|
||||
template <typename T>
|
||||
struct InterfaceTargetOrOpT<T, false> {
|
||||
using type = ConcreteType;
|
||||
};
|
||||
|
||||
/// A hook for static assertion that the external interface model T is
|
||||
/// targeting the concrete type of this op. The model can also be a fallback
|
||||
/// model that works for every op.
|
||||
template <typename T> static void checkInterfaceTarget() {
|
||||
template <typename T>
|
||||
static void checkInterfaceTarget() {
|
||||
static_assert(std::is_same<typename InterfaceTargetOrOpT<T>::type,
|
||||
ConcreteType>::value,
|
||||
"attaching an interface to the wrong op kind");
|
||||
|
@ -151,10 +151,9 @@ public:
|
||||
std::enable_if_t<detect_has_print_method<AttrOrType>::value>
|
||||
*sfinae = nullptr>
|
||||
void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) {
|
||||
llvm::interleaveComma(attrOrTypes, getStream(),
|
||||
[this](AttrOrType attrOrType) {
|
||||
printStrippedAttrOrType(attrOrType);
|
||||
});
|
||||
llvm::interleaveComma(
|
||||
attrOrTypes, getStream(),
|
||||
[this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); });
|
||||
}
|
||||
|
||||
/// SFINAE for printing the provided attribute in the context of an operation
|
||||
|
@ -141,7 +141,8 @@ public:
|
||||
/// Returns true if the operation was registered with a particular trait, e.g.
|
||||
/// hasTrait<OperandsAreSignlessIntegerLike>(). Returns false if the operation
|
||||
/// is unregistered.
|
||||
template <template <typename T> class Trait> bool hasTrait() const {
|
||||
template <template <typename T> class Trait>
|
||||
bool hasTrait() const {
|
||||
return hasTrait(TypeID::get<Trait>());
|
||||
}
|
||||
bool hasTrait(TypeID traitID) const {
|
||||
@ -151,7 +152,8 @@ public:
|
||||
/// Returns true if the operation *might* have the provided trait. This
|
||||
/// means that either the operation is unregistered, or it was registered with
|
||||
/// the provide trait.
|
||||
template <template <typename T> class Trait> bool mightHaveTrait() const {
|
||||
template <template <typename T> class Trait>
|
||||
bool mightHaveTrait() const {
|
||||
return mightHaveTrait(TypeID::get<Trait>());
|
||||
}
|
||||
bool mightHaveTrait(TypeID traitID) const {
|
||||
@ -161,12 +163,14 @@ public:
|
||||
/// Returns an instance of the concept object for the given interface if it
|
||||
/// was registered to this operation, null otherwise. This should not be used
|
||||
/// directly.
|
||||
template <typename T> typename T::Concept *getInterface() const {
|
||||
template <typename T>
|
||||
typename T::Concept *getInterface() const {
|
||||
return impl->interfaceMap.lookup<T>();
|
||||
}
|
||||
|
||||
/// Returns true if this operation has the given interface registered to it.
|
||||
template <typename T> bool hasInterface() const {
|
||||
template <typename T>
|
||||
bool hasInterface() const {
|
||||
return hasInterface(TypeID::get<T>());
|
||||
}
|
||||
bool hasInterface(TypeID interfaceID) const {
|
||||
@ -345,7 +349,8 @@ public:
|
||||
}
|
||||
|
||||
/// Returns true if the operation has a particular trait.
|
||||
template <template <typename T> class Trait> bool hasTrait() const {
|
||||
template <template <typename T> class Trait>
|
||||
bool hasTrait() const {
|
||||
return hasTrait(TypeID::get<Trait>());
|
||||
}
|
||||
|
||||
|
@ -271,7 +271,7 @@ public:
|
||||
/// This method provides a convenient interface for creating and initializing
|
||||
/// derived rewrite patterns of the given type `T`.
|
||||
template <typename T, typename... Args>
|
||||
static std::unique_ptr<T> create(Args &&... args) {
|
||||
static std::unique_ptr<T> create(Args &&...args) {
|
||||
std::unique_ptr<T> pattern =
|
||||
std::make_unique<T>(std::forward<Args>(args)...);
|
||||
initializePattern<T>(*pattern);
|
||||
@ -1410,7 +1410,7 @@ public:
|
||||
template <typename... Ts, typename ConstructorArg,
|
||||
typename... ConstructorArgs,
|
||||
typename = std::enable_if_t<sizeof...(Ts) != 0>>
|
||||
RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&... args) {
|
||||
RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&...args) {
|
||||
// The following expands a call to emplace_back for each of the pattern
|
||||
// types 'Ts'. This magic is necessary due to a limitation in the places
|
||||
// that a parameter pack can be expanded in c++11.
|
||||
@ -1428,7 +1428,7 @@ public:
|
||||
typename = std::enable_if_t<sizeof...(Ts) != 0>>
|
||||
RewritePatternSet &addWithLabel(ArrayRef<StringRef> debugLabels,
|
||||
ConstructorArg &&arg,
|
||||
ConstructorArgs &&... args) {
|
||||
ConstructorArgs &&...args) {
|
||||
// The following expands a call to emplace_back for each of the pattern
|
||||
// types 'Ts'. This magic is necessary due to a limitation in the places
|
||||
// that a parameter pack can be expanded in c++11.
|
||||
@ -1493,7 +1493,7 @@ public:
|
||||
template <typename... Ts, typename ConstructorArg,
|
||||
typename... ConstructorArgs,
|
||||
typename = std::enable_if_t<sizeof...(Ts) != 0>>
|
||||
RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
|
||||
RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&...args) {
|
||||
// The following expands a call to emplace_back for each of the pattern
|
||||
// types 'Ts'. This magic is necessary due to a limitation in the places
|
||||
// that a parameter pack can be expanded in c++11.
|
||||
@ -1553,7 +1553,7 @@ private:
|
||||
/// chaining insertions.
|
||||
template <typename T, typename... Args>
|
||||
std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
|
||||
addImpl(ArrayRef<StringRef> debugLabels, Args &&... args) {
|
||||
addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
|
||||
std::unique_ptr<T> pattern =
|
||||
RewritePattern::create<T>(std::forward<Args>(args)...);
|
||||
pattern->addDebugLabels(debugLabels);
|
||||
@ -1561,7 +1561,7 @@ private:
|
||||
}
|
||||
template <typename T, typename... Args>
|
||||
std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
|
||||
addImpl(ArrayRef<StringRef> debugLabels, Args &&... args) {
|
||||
addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
|
||||
// TODO: Add the provided labels to the PDL pattern when PDL supports
|
||||
// labels.
|
||||
pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
|
||||
|
@ -19,7 +19,8 @@
|
||||
#include "llvm/ADT/GraphTraits.h"
|
||||
|
||||
namespace llvm {
|
||||
template <> struct GraphTraits<mlir::Block *> {
|
||||
template <>
|
||||
struct GraphTraits<mlir::Block *> {
|
||||
using ChildIteratorType = mlir::Block::succ_iterator;
|
||||
using Node = mlir::Block;
|
||||
using NodeRef = Node *;
|
||||
@ -32,7 +33,8 @@ template <> struct GraphTraits<mlir::Block *> {
|
||||
static ChildIteratorType child_end(NodeRef node) { return node->succ_end(); }
|
||||
};
|
||||
|
||||
template <> struct GraphTraits<Inverse<mlir::Block *>> {
|
||||
template <>
|
||||
struct GraphTraits<Inverse<mlir::Block *>> {
|
||||
using ChildIteratorType = mlir::Block::pred_iterator;
|
||||
using Node = mlir::Block;
|
||||
using NodeRef = Node *;
|
||||
|
@ -105,7 +105,8 @@ public:
|
||||
|
||||
/// Provide an implementation of 'classof' that compares the type id of the
|
||||
/// provided value with that of the concrete type.
|
||||
template <typename T> static bool classof(T val) {
|
||||
template <typename T>
|
||||
static bool classof(T val) {
|
||||
static_assert(std::is_convertible<ConcreteT, T>::value,
|
||||
"casting from a non-convertible type");
|
||||
return val.getTypeID() == getTypeID();
|
||||
@ -182,7 +183,8 @@ public:
|
||||
protected:
|
||||
/// Mutate the current storage instance. This will not change the unique key.
|
||||
/// The arguments are forwarded to 'ConcreteT::mutate'.
|
||||
template <typename... Args> LogicalResult mutate(Args &&...args) {
|
||||
template <typename... Args>
|
||||
LogicalResult mutate(Args &&...args) {
|
||||
static_assert(std::is_base_of<StorageUserTrait::IsMutable<ConcreteT>,
|
||||
ConcreteT>::value,
|
||||
"The `mutate` function expects mutable trait "
|
||||
@ -192,7 +194,8 @@ protected:
|
||||
}
|
||||
|
||||
/// Default implementation that just returns success.
|
||||
template <typename... Args> static LogicalResult verify(Args... args) {
|
||||
template <typename... Args>
|
||||
static LogicalResult verify(Args... args) {
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -59,7 +59,8 @@ public:
|
||||
/// Returns an instance of the concept object for the given interface if it
|
||||
/// was registered to this type, null otherwise. This should not be used
|
||||
/// directly.
|
||||
template <typename T> typename T::Concept *getInterface() const {
|
||||
template <typename T>
|
||||
typename T::Concept *getInterface() const {
|
||||
return interfaceMap.lookup<T>();
|
||||
}
|
||||
|
||||
|
@ -94,11 +94,16 @@ public:
|
||||
|
||||
bool operator!() const { return impl == nullptr; }
|
||||
|
||||
template <typename U> bool isa() const;
|
||||
template <typename First, typename Second, typename... Rest> bool isa() const;
|
||||
template <typename U> U dyn_cast() const;
|
||||
template <typename U> U dyn_cast_or_null() const;
|
||||
template <typename U> U cast() const;
|
||||
template <typename U>
|
||||
bool isa() const;
|
||||
template <typename First, typename Second, typename... Rest>
|
||||
bool isa() const;
|
||||
template <typename U>
|
||||
U dyn_cast() const;
|
||||
template <typename U>
|
||||
U dyn_cast_or_null() const;
|
||||
template <typename U>
|
||||
U cast() const;
|
||||
|
||||
// Support type casting Type to itself.
|
||||
static bool classof(Type) { return true; }
|
||||
@ -243,7 +248,8 @@ inline ::llvm::hash_code hash_value(Type arg) {
|
||||
return DenseMapInfo<const Type::ImplType *>::getHashValue(arg.impl);
|
||||
}
|
||||
|
||||
template <typename U> bool Type::isa() const {
|
||||
template <typename U>
|
||||
bool Type::isa() const {
|
||||
assert(impl && "isa<> used on a null type.");
|
||||
return U::classof(*this);
|
||||
}
|
||||
@ -253,13 +259,16 @@ bool Type::isa() const {
|
||||
return isa<First>() || isa<Second, Rest...>();
|
||||
}
|
||||
|
||||
template <typename U> U Type::dyn_cast() const {
|
||||
template <typename U>
|
||||
U Type::dyn_cast() const {
|
||||
return isa<U>() ? U(impl) : U(nullptr);
|
||||
}
|
||||
template <typename U> U Type::dyn_cast_or_null() const {
|
||||
template <typename U>
|
||||
U Type::dyn_cast_or_null() const {
|
||||
return (impl && isa<U>()) ? U(impl) : U(nullptr);
|
||||
}
|
||||
template <typename U> U Type::cast() const {
|
||||
template <typename U>
|
||||
U Type::cast() const {
|
||||
assert(isa<U>());
|
||||
return U(impl);
|
||||
}
|
||||
@ -269,7 +278,8 @@ template <typename U> U Type::cast() const {
|
||||
namespace llvm {
|
||||
|
||||
// Type hash just like pointers.
|
||||
template <> struct DenseMapInfo<mlir::Type> {
|
||||
template <>
|
||||
struct DenseMapInfo<mlir::Type> {
|
||||
static mlir::Type getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
|
||||
@ -296,7 +306,8 @@ struct DenseMapInfo<T, std::enable_if_t<std::is_base_of<mlir::Type, T>::value &&
|
||||
};
|
||||
|
||||
/// We align TypeStorage by 8, so allow LLVM to steal the low bits.
|
||||
template <> struct PointerLikeTypeTraits<mlir::Type> {
|
||||
template <>
|
||||
struct PointerLikeTypeTraits<mlir::Type> {
|
||||
public:
|
||||
static inline void *getAsVoidPointer(mlir::Type I) {
|
||||
return const_cast<void *>(I.getAsOpaquePointer());
|
||||
|
@ -20,9 +20,12 @@
|
||||
namespace mlir {
|
||||
|
||||
class Operation;
|
||||
template <typename OperandType> class ValueUseIterator;
|
||||
template <typename OperandType> class FilteredValueUseIterator;
|
||||
template <typename UseIteratorT, typename OperandType> class ValueUserIterator;
|
||||
template <typename OperandType>
|
||||
class ValueUseIterator;
|
||||
template <typename OperandType>
|
||||
class FilteredValueUseIterator;
|
||||
template <typename UseIteratorT, typename OperandType>
|
||||
class ValueUserIterator;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IROperand
|
||||
@ -77,7 +80,8 @@ protected:
|
||||
}
|
||||
|
||||
/// Insert this operand into the given use list.
|
||||
template <typename UseListT> void insertInto(UseListT *useList) {
|
||||
template <typename UseListT>
|
||||
void insertInto(UseListT *useList) {
|
||||
back = &useList->firstUse;
|
||||
nextUse = useList->firstUse;
|
||||
if (nextUse)
|
||||
@ -164,7 +168,8 @@ private:
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents a single IR object that contains a use list.
|
||||
template <typename OperandType> class IRObjectWithUseList {
|
||||
template <typename OperandType>
|
||||
class IRObjectWithUseList {
|
||||
public:
|
||||
~IRObjectWithUseList() {
|
||||
assert(use_empty() && "Cannot destroy a value that still has uses!");
|
||||
|
@ -53,7 +53,8 @@ public:
|
||||
TypeID getEffectID() const { return id; }
|
||||
|
||||
/// Returns a unique instance for the given effect class.
|
||||
template <typename DerivedEffect> static DerivedEffect *get() {
|
||||
template <typename DerivedEffect>
|
||||
static DerivedEffect *get() {
|
||||
static_assert(std::is_base_of<Effect, DerivedEffect>::value,
|
||||
"expected DerivedEffect to inherit from Effect");
|
||||
|
||||
@ -134,7 +135,8 @@ struct AutomaticAllocationScopeResource
|
||||
/// applied, and an optional symbol reference or value(either operand, result,
|
||||
/// or region entry argument) that the effect is applied to, and an optional
|
||||
/// parameters attribute further specifying the details of the effect.
|
||||
template <typename EffectT> class EffectInstance {
|
||||
template <typename EffectT>
|
||||
class EffectInstance {
|
||||
public:
|
||||
EffectInstance(EffectT *effect, Resource *resource = DefaultResource::get())
|
||||
: effect(effect), resource(resource) {}
|
||||
|
@ -43,7 +43,8 @@ public:
|
||||
bool isNone() const { return preservedIDs.empty(); }
|
||||
|
||||
/// Preserve the given analyses.
|
||||
template <typename AnalysisT> void preserve() {
|
||||
template <typename AnalysisT>
|
||||
void preserve() {
|
||||
preserve(TypeID::get<AnalysisT>());
|
||||
}
|
||||
template <typename AnalysisT, typename AnalysisT2, typename... OtherAnalysesT>
|
||||
@ -56,7 +57,8 @@ public:
|
||||
/// Returns true if the given analysis has been marked as preserved. Note that
|
||||
/// this simply checks for the presence of a given analysis ID and should not
|
||||
/// be used as a general preservation checker.
|
||||
template <typename AnalysisT> bool isPreserved() const {
|
||||
template <typename AnalysisT>
|
||||
bool isPreserved() const {
|
||||
return isPreserved(TypeID::get<AnalysisT>());
|
||||
}
|
||||
bool isPreserved(TypeID id) const { return preservedIDs.count(id); }
|
||||
@ -110,7 +112,8 @@ struct AnalysisConcept {
|
||||
};
|
||||
|
||||
/// A derived analysis model used to hold a specific analysis object.
|
||||
template <typename AnalysisT> struct AnalysisModel : public AnalysisConcept {
|
||||
template <typename AnalysisT>
|
||||
struct AnalysisModel : public AnalysisConcept {
|
||||
template <typename... Args>
|
||||
explicit AnalysisModel(Args &&...args)
|
||||
: analysis(std::forward<Args>(args)...) {}
|
||||
@ -135,7 +138,8 @@ class AnalysisMap {
|
||||
using ConceptMap = llvm::MapVector<TypeID, std::unique_ptr<AnalysisConcept>>;
|
||||
|
||||
/// Utility to return the name of the given analysis class.
|
||||
template <typename AnalysisT> static StringRef getAnalysisName() {
|
||||
template <typename AnalysisT>
|
||||
static StringRef getAnalysisName() {
|
||||
StringRef name = llvm::getTypeName<AnalysisT>();
|
||||
if (!name.consume_front("mlir::"))
|
||||
name.consume_front("(anonymous namespace)::");
|
||||
@ -309,7 +313,8 @@ public:
|
||||
}
|
||||
|
||||
/// Query for the given analysis for the current operation.
|
||||
template <typename AnalysisT> AnalysisT &getAnalysis() {
|
||||
template <typename AnalysisT>
|
||||
AnalysisT &getAnalysis() {
|
||||
return impl->analyses.getAnalysis<AnalysisT>(getPassInstrumentor(), *this);
|
||||
}
|
||||
|
||||
@ -328,7 +333,8 @@ public:
|
||||
}
|
||||
|
||||
/// Query for an analysis of a child operation, constructing it if necessary.
|
||||
template <typename AnalysisT> AnalysisT &getChildAnalysis(Operation *op) {
|
||||
template <typename AnalysisT>
|
||||
AnalysisT &getChildAnalysis(Operation *op) {
|
||||
return nest(op).template getAnalysis<AnalysisT>();
|
||||
}
|
||||
|
||||
|
@ -212,7 +212,8 @@ protected:
|
||||
void signalPassFailure() { getPassState().irAndPassFailed.setInt(true); }
|
||||
|
||||
/// Query an analysis for the current ir unit.
|
||||
template <typename AnalysisT> AnalysisT &getAnalysis() {
|
||||
template <typename AnalysisT>
|
||||
AnalysisT &getAnalysis() {
|
||||
return getAnalysisManager().getAnalysis<AnalysisT>();
|
||||
}
|
||||
|
||||
@ -236,7 +237,8 @@ protected:
|
||||
}
|
||||
|
||||
/// Mark the provided analyses as preserved.
|
||||
template <typename... AnalysesT> void markAnalysesPreserved() {
|
||||
template <typename... AnalysesT>
|
||||
void markAnalysesPreserved() {
|
||||
getPassState().preservedAnalyses.preserve<AnalysesT...>();
|
||||
}
|
||||
void markAnalysesPreserved(TypeID id) {
|
||||
@ -266,7 +268,8 @@ protected:
|
||||
|
||||
/// Returns the analysis for the given child operation, or creates it if it
|
||||
/// doesn't exist.
|
||||
template <typename AnalysisT> AnalysisT &getChildAnalysis(Operation *child) {
|
||||
template <typename AnalysisT>
|
||||
AnalysisT &getChildAnalysis(Operation *child) {
|
||||
return getAnalysisManager().getChildAnalysis<AnalysisT>(child);
|
||||
}
|
||||
|
||||
@ -343,7 +346,8 @@ private:
|
||||
/// - A 'void runOnOperation()' method.
|
||||
/// - A 'StringRef getName() const' method.
|
||||
/// - A 'std::unique_ptr<Pass> clonePass() const' method.
|
||||
template <typename OpT = void> class OperationPass : public Pass {
|
||||
template <typename OpT = void>
|
||||
class OperationPass : public Pass {
|
||||
protected:
|
||||
OperationPass(TypeID passID) : Pass(passID, OpT::getOperationName()) {}
|
||||
OperationPass(const OperationPass &) = default;
|
||||
@ -381,7 +385,8 @@ protected:
|
||||
/// - A 'void runOnOperation()' method.
|
||||
/// - A 'StringRef getName() const' method.
|
||||
/// - A 'std::unique_ptr<Pass> clonePass() const' method.
|
||||
template <> class OperationPass<void> : public Pass {
|
||||
template <>
|
||||
class OperationPass<void> : public Pass {
|
||||
protected:
|
||||
OperationPass(TypeID passID) : Pass(passID) {}
|
||||
OperationPass(const OperationPass &) = default;
|
||||
@ -431,7 +436,8 @@ protected:
|
||||
/// several necessary utility methods. This should only be used for passes that
|
||||
/// are not suitably represented using the declarative pass specification(i.e.
|
||||
/// tablegen backend).
|
||||
template <typename PassT, typename BaseT> class PassWrapper : public BaseT {
|
||||
template <typename PassT, typename BaseT>
|
||||
class PassWrapper : public BaseT {
|
||||
public:
|
||||
/// Support isa/dyn_cast functionality for the derived pass class.
|
||||
static bool classof(const Pass *pass) {
|
||||
|
@ -125,7 +125,8 @@ private:
|
||||
} // namespace mlir
|
||||
|
||||
namespace llvm {
|
||||
template <> struct DenseMapInfo<mlir::PassInstrumentation::PipelineParentInfo> {
|
||||
template <>
|
||||
struct DenseMapInfo<mlir::PassInstrumentation::PipelineParentInfo> {
|
||||
using T = mlir::PassInstrumentation::PipelineParentInfo;
|
||||
using PairInfo = DenseMapInfo<std::pair<uint64_t, void *>>;
|
||||
|
||||
|
@ -98,7 +98,8 @@ public:
|
||||
/// pass manager.
|
||||
OpPassManager &nest(OperationName nestedName);
|
||||
OpPassManager &nest(StringRef nestedName);
|
||||
template <typename OpT> OpPassManager &nest() {
|
||||
template <typename OpT>
|
||||
OpPassManager &nest() {
|
||||
return nest(OpT::getOperationName());
|
||||
}
|
||||
|
||||
@ -115,7 +116,8 @@ public:
|
||||
|
||||
/// Add the given pass to a nested pass manager for the given operation kind
|
||||
/// `OpT`.
|
||||
template <typename OpT> void addNestedPass(std::unique_ptr<Pass> pass) {
|
||||
template <typename OpT>
|
||||
void addNestedPass(std::unique_ptr<Pass> pass) {
|
||||
nest<OpT>().addPass(std::move(pass));
|
||||
}
|
||||
|
||||
|
@ -140,7 +140,8 @@ void registerPass(const PassAllocatorFunction &function);
|
||||
/// /// At namespace scope.
|
||||
/// static PassRegistration<MyPass> reg;
|
||||
///
|
||||
template <typename ConcretePass> struct PassRegistration {
|
||||
template <typename ConcretePass>
|
||||
struct PassRegistration {
|
||||
PassRegistration(const PassAllocatorFunction &constructor) {
|
||||
registerPass(constructor);
|
||||
}
|
||||
@ -184,7 +185,8 @@ struct PassPipelineRegistration {
|
||||
|
||||
/// Convenience specialization of PassPipelineRegistration for EmptyPassOptions
|
||||
/// that does not pass an empty options struct to the pass builder function.
|
||||
template <> struct PassPipelineRegistration<EmptyPipelineOptions> {
|
||||
template <>
|
||||
struct PassPipelineRegistration<EmptyPipelineOptions> {
|
||||
PassPipelineRegistration(
|
||||
StringRef arg, StringRef description,
|
||||
const std::function<void(OpPassManager &)> &builder) {
|
||||
|
@ -107,13 +107,13 @@ public:
|
||||
/// `Args` are a set of parameters used by handlers of `ActionType` to
|
||||
/// determine if the action should be executed.
|
||||
template <typename ActionType, typename... Args>
|
||||
bool shouldExecute(Args &&... args) {
|
||||
bool shouldExecute(Args &&...args) {
|
||||
// The manager is always disabled if built without debug.
|
||||
#if !LLVM_ENABLE_ABI_BREAKING_CHECKS
|
||||
return true;
|
||||
#else
|
||||
// Invoke the `shouldExecute` method on the provided handler.
|
||||
auto shouldExecuteFn = [&](auto *handler, auto &&... handlerParams) {
|
||||
auto shouldExecuteFn = [&](auto *handler, auto &&...handlerParams) {
|
||||
return handler->shouldExecute(
|
||||
std::forward<decltype(handlerParams)>(handlerParams)...);
|
||||
};
|
||||
@ -139,7 +139,7 @@ private:
|
||||
template <typename ActionType, typename ResultT, typename HandlerCallbackT,
|
||||
typename... Args>
|
||||
FailureOr<ResultT> dispatchToHandler(HandlerCallbackT &&handlerCallback,
|
||||
Args &&... args) {
|
||||
Args &&...args) {
|
||||
static_assert(ActionType::template canHandleWith<Args...>(),
|
||||
"cannot execute action with the given set of parameters");
|
||||
|
||||
@ -189,7 +189,8 @@ private:
|
||||
/// This class provides a handler class that can be derived from to handle
|
||||
/// instances of this action. The parameters to its query methods map 1-1 to the
|
||||
/// types on the action type.
|
||||
template <typename... ParameterTs> class DebugAction {
|
||||
template <typename... ParameterTs>
|
||||
class DebugAction {
|
||||
public:
|
||||
class Handler : public DebugActionManager::HandlerBase {
|
||||
public:
|
||||
|
@ -71,7 +71,8 @@ template <typename ConcreteType, typename ValueT, typename Traits,
|
||||
class Interface : public BaseType {
|
||||
public:
|
||||
using Concept = typename Traits::Concept;
|
||||
template <typename T> using Model = typename Traits::template Model<T>;
|
||||
template <typename T>
|
||||
using Model = typename Traits::template Model<T>;
|
||||
template <typename T>
|
||||
using FallbackModel = typename Traits::template FallbackModel<T>;
|
||||
using InterfaceBase =
|
||||
@ -205,7 +206,8 @@ public:
|
||||
|
||||
/// Returns an instance of the concept object for the given interface if it
|
||||
/// was registered to this map, null otherwise.
|
||||
template <typename T> typename T::Concept *lookup() const {
|
||||
template <typename T>
|
||||
typename T::Concept *lookup() const {
|
||||
return reinterpret_cast<typename T::Concept *>(lookup(T::getInterfaceID()));
|
||||
}
|
||||
|
||||
|
@ -35,41 +35,60 @@
|
||||
// Forward declarations.
|
||||
namespace llvm {
|
||||
// String types
|
||||
template <unsigned N> class SmallString;
|
||||
template <unsigned N>
|
||||
class SmallString;
|
||||
class StringRef;
|
||||
class StringLiteral;
|
||||
class Twine;
|
||||
|
||||
// Containers.
|
||||
template <typename T> class ArrayRef;
|
||||
template <typename T>
|
||||
class ArrayRef;
|
||||
class BitVector;
|
||||
namespace detail {
|
||||
template <typename KeyT, typename ValueT> struct DenseMapPair;
|
||||
template <typename KeyT, typename ValueT>
|
||||
struct DenseMapPair;
|
||||
} // namespace detail
|
||||
template <typename KeyT, typename ValueT, typename KeyInfoT, typename BucketT>
|
||||
class DenseMap;
|
||||
template <typename T, typename Enable> struct DenseMapInfo;
|
||||
template <typename ValueT, typename ValueInfoT> class DenseSet;
|
||||
template <typename T, typename Enable>
|
||||
struct DenseMapInfo;
|
||||
template <typename ValueT, typename ValueInfoT>
|
||||
class DenseSet;
|
||||
class MallocAllocator;
|
||||
template <typename T> class MutableArrayRef;
|
||||
template <typename T> class Optional;
|
||||
template <typename... PT> class PointerUnion;
|
||||
template <typename T, typename Vector, typename Set> class SetVector;
|
||||
template <typename T, unsigned N> class SmallPtrSet;
|
||||
template <typename T> class SmallPtrSetImpl;
|
||||
template <typename T, unsigned N> class SmallVector;
|
||||
template <typename T> class SmallVectorImpl;
|
||||
template <typename AllocatorTy> class StringSet;
|
||||
template <typename T, typename R> class StringSwitch;
|
||||
template <typename T> class TinyPtrVector;
|
||||
template <typename T, typename ResultT> class TypeSwitch;
|
||||
template <typename T>
|
||||
class MutableArrayRef;
|
||||
template <typename T>
|
||||
class Optional;
|
||||
template <typename... PT>
|
||||
class PointerUnion;
|
||||
template <typename T, typename Vector, typename Set>
|
||||
class SetVector;
|
||||
template <typename T, unsigned N>
|
||||
class SmallPtrSet;
|
||||
template <typename T>
|
||||
class SmallPtrSetImpl;
|
||||
template <typename T, unsigned N>
|
||||
class SmallVector;
|
||||
template <typename T>
|
||||
class SmallVectorImpl;
|
||||
template <typename AllocatorTy>
|
||||
class StringSet;
|
||||
template <typename T, typename R>
|
||||
class StringSwitch;
|
||||
template <typename T>
|
||||
class TinyPtrVector;
|
||||
template <typename T, typename ResultT>
|
||||
class TypeSwitch;
|
||||
|
||||
// Other common classes.
|
||||
class APInt;
|
||||
class APSInt;
|
||||
class APFloat;
|
||||
template <typename Fn> class function_ref;
|
||||
template <typename IteratorT> class iterator_range;
|
||||
template <typename Fn>
|
||||
class function_ref;
|
||||
template <typename IteratorT>
|
||||
class iterator_range;
|
||||
class raw_ostream;
|
||||
class SMLoc;
|
||||
class SMRange;
|
||||
@ -126,7 +145,8 @@ using TypeSwitch = llvm::TypeSwitch<T, ResultT>;
|
||||
using llvm::APFloat;
|
||||
using llvm::APInt;
|
||||
using llvm::APSInt;
|
||||
template <typename Fn> using function_ref = llvm::function_ref<Fn>;
|
||||
template <typename Fn>
|
||||
using function_ref = llvm::function_ref<Fn>;
|
||||
using llvm::iterator_range;
|
||||
using llvm::raw_ostream;
|
||||
using llvm::SMLoc;
|
||||
|
@ -74,7 +74,8 @@ inline bool failed(LogicalResult result) { return result.failed(); }
|
||||
/// This class provides support for representing a failure result, or a valid
|
||||
/// value of type `T`. This allows for integrating with LogicalResult, while
|
||||
/// also providing a value on the success path.
|
||||
template <typename T> class LLVM_NODISCARD FailureOr : public Optional<T> {
|
||||
template <typename T>
|
||||
class LLVM_NODISCARD FailureOr : public Optional<T> {
|
||||
public:
|
||||
/// Allow constructing from a LogicalResult. The result *must* be a failure.
|
||||
/// Success results should use a proper instance of type `T`.
|
||||
|
@ -94,7 +94,8 @@ public:
|
||||
public:
|
||||
/// Copy the specified array of elements into memory managed by our bump
|
||||
/// pointer allocator. This assumes the elements are all PODs.
|
||||
template <typename T> ArrayRef<T> copyInto(ArrayRef<T> elements) {
|
||||
template <typename T>
|
||||
ArrayRef<T> copyInto(ArrayRef<T> elements) {
|
||||
if (elements.empty())
|
||||
return llvm::None;
|
||||
auto result = allocator.Allocate<T>(elements.size());
|
||||
@ -115,7 +116,10 @@ public:
|
||||
}
|
||||
|
||||
/// Allocate an instance of the provided type.
|
||||
template <typename T> T *allocate() { return allocator.Allocate<T>(); }
|
||||
template <typename T>
|
||||
T *allocate() {
|
||||
return allocator.Allocate<T>();
|
||||
}
|
||||
|
||||
/// Allocate 'size' bytes of 'alignment' aligned memory.
|
||||
void *allocate(size_t size, size_t alignment) {
|
||||
@ -141,7 +145,8 @@ public:
|
||||
/// Register a new parametric storage class, this is necessary to create
|
||||
/// instances of this class type. `id` is the type identifier that will be
|
||||
/// used to identify this type when creating instances of it via 'get'.
|
||||
template <typename Storage> void registerParametricStorageType(TypeID id) {
|
||||
template <typename Storage>
|
||||
void registerParametricStorageType(TypeID id) {
|
||||
// If the storage is trivially destructible, we don't need a destructor
|
||||
// function.
|
||||
if (std::is_trivially_destructible<Storage>::value)
|
||||
@ -151,7 +156,8 @@ public:
|
||||
});
|
||||
}
|
||||
/// Utility override when the storage type represents the type id.
|
||||
template <typename Storage> void registerParametricStorageType() {
|
||||
template <typename Storage>
|
||||
void registerParametricStorageType() {
|
||||
registerParametricStorageType<Storage>(TypeID::get<Storage>());
|
||||
}
|
||||
/// Register a new singleton storage class, this is necessary to get the
|
||||
@ -170,7 +176,8 @@ public:
|
||||
};
|
||||
registerSingletonImpl(id, ctorFn);
|
||||
}
|
||||
template <typename Storage> void registerSingletonStorageType(TypeID id) {
|
||||
template <typename Storage>
|
||||
void registerSingletonStorageType(TypeID id) {
|
||||
registerSingletonStorageType<Storage>(id, llvm::None);
|
||||
}
|
||||
/// Utility override when the storage type represents the type id.
|
||||
@ -219,11 +226,13 @@ public:
|
||||
|
||||
/// Gets a uniqued instance of 'Storage' which is a singleton storage type.
|
||||
/// 'id' is the type id used when registering the storage instance.
|
||||
template <typename Storage> Storage *get(TypeID id) {
|
||||
template <typename Storage>
|
||||
Storage *get(TypeID id) {
|
||||
return static_cast<Storage *>(getSingletonImpl(id));
|
||||
}
|
||||
/// Utility override when the storage type represents the type id.
|
||||
template <typename Storage> Storage *get() {
|
||||
template <typename Storage>
|
||||
Storage *get() {
|
||||
return get<Storage>(TypeID::get<Storage>());
|
||||
}
|
||||
|
||||
|
@ -334,7 +334,8 @@ public:
|
||||
MLIR_DECLARE_EXPLICIT_TYPE_ID(void)
|
||||
|
||||
namespace llvm {
|
||||
template <> struct DenseMapInfo<mlir::TypeID> {
|
||||
template <>
|
||||
struct DenseMapInfo<mlir::TypeID> {
|
||||
static inline mlir::TypeID getEmptyKey() {
|
||||
void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlir::TypeID::getFromOpaquePointer(pointer);
|
||||
@ -350,7 +351,8 @@ template <> struct DenseMapInfo<mlir::TypeID> {
|
||||
};
|
||||
|
||||
/// We align TypeID::Storage by 8, so allow LLVM to steal the low bits.
|
||||
template <> struct PointerLikeTypeTraits<mlir::TypeID> {
|
||||
template <>
|
||||
struct PointerLikeTypeTraits<mlir::TypeID> {
|
||||
static inline void *getAsVoidPointer(mlir::TypeID info) {
|
||||
return const_cast<void *>(info.getAsOpaquePointer());
|
||||
}
|
||||
|
@ -216,15 +216,16 @@ private:
|
||||
std::string escapeString(StringRef value);
|
||||
|
||||
namespace detail {
|
||||
template <typename> struct stringifier {
|
||||
template <typename T> static std::string apply(T &&t) {
|
||||
template <typename>
|
||||
struct stringifier {
|
||||
template <typename T>
|
||||
static std::string apply(T &&t) {
|
||||
return std::string(std::forward<T>(t));
|
||||
}
|
||||
};
|
||||
template <> struct stringifier<Twine> {
|
||||
static std::string apply(const Twine &twine) {
|
||||
return twine.str();
|
||||
}
|
||||
template <>
|
||||
struct stringifier<Twine> {
|
||||
static std::string apply(const Twine &twine) { return twine.str(); }
|
||||
};
|
||||
template <typename OptionalT>
|
||||
struct stringifier<Optional<OptionalT>> {
|
||||
@ -235,7 +236,8 @@ struct stringifier<Optional<OptionalT>> {
|
||||
} // namespace detail
|
||||
|
||||
/// Generically convert a value to a std::string.
|
||||
template <typename T> std::string stringify(T &&t) {
|
||||
template <typename T>
|
||||
std::string stringify(T &&t) {
|
||||
return detail::stringifier<std::remove_reference_t<std::remove_const_t<T>>>::
|
||||
apply(std::forward<T>(t));
|
||||
}
|
||||
|
@ -165,19 +165,24 @@ public:
|
||||
return s.str();
|
||||
}
|
||||
|
||||
template <unsigned N> SmallString<N> sstr() const {
|
||||
template <unsigned N>
|
||||
SmallString<N> sstr() const {
|
||||
SmallString<N> result;
|
||||
llvm::raw_svector_ostream s(result);
|
||||
format(s);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <unsigned N> operator SmallString<N>() const { return sstr<N>(); }
|
||||
template <unsigned N>
|
||||
operator SmallString<N>() const {
|
||||
return sstr<N>();
|
||||
}
|
||||
|
||||
operator std::string() const { return str(); }
|
||||
};
|
||||
|
||||
template <typename Tuple> class FmtObject : public FmtObjectBase {
|
||||
template <typename Tuple>
|
||||
class FmtObject : public FmtObjectBase {
|
||||
// Storage for the parameter adapters. Since the base class erases the type
|
||||
// of the parameters, we have to own the storage for the parameters here, and
|
||||
// have the base class store type-erased pointers into this tuple.
|
||||
|
@ -34,8 +34,7 @@ class VariableDecl;
|
||||
/// This class provides a convenient API for interacting with source names. It
|
||||
/// contains a string name as well as the source location for that name.
|
||||
struct Name {
|
||||
static const Name &create(Context &ctx, StringRef name,
|
||||
SMRange location);
|
||||
static const Name &create(Context &ctx, StringRef name, SMRange location);
|
||||
|
||||
/// Return the raw string name.
|
||||
StringRef getName() const { return name; }
|
||||
@ -47,8 +46,7 @@ private:
|
||||
Name() = delete;
|
||||
Name(const Name &) = delete;
|
||||
Name &operator=(const Name &) = delete;
|
||||
Name(StringRef name, SMRange location)
|
||||
: name(name), location(location) {}
|
||||
Name(StringRef name, SMRange location) : name(name), location(location) {}
|
||||
|
||||
/// The string name of the decl.
|
||||
StringRef name;
|
||||
@ -80,13 +78,15 @@ public:
|
||||
/// Lookup a decl with the given name starting from this scope. Returns
|
||||
/// nullptr if no decl could be found.
|
||||
Decl *lookup(StringRef name);
|
||||
template <typename T> T *lookup(StringRef name) {
|
||||
template <typename T>
|
||||
T *lookup(StringRef name) {
|
||||
return dyn_cast_or_null<T>(lookup(name));
|
||||
}
|
||||
const Decl *lookup(StringRef name) const {
|
||||
return const_cast<DeclScope *>(this)->lookup(name);
|
||||
}
|
||||
template <typename T> const T *lookup(StringRef name) const {
|
||||
template <typename T>
|
||||
const T *lookup(StringRef name) const {
|
||||
return dyn_cast_or_null<T>(lookup(name));
|
||||
}
|
||||
|
||||
@ -107,7 +107,8 @@ private:
|
||||
class Node {
|
||||
public:
|
||||
/// This CRTP class provides several utilies when defining new AST nodes.
|
||||
template <typename T, typename BaseT> class NodeBase : public BaseT {
|
||||
template <typename T, typename BaseT>
|
||||
class NodeBase : public BaseT {
|
||||
public:
|
||||
using Base = NodeBase<T, BaseT>;
|
||||
|
||||
@ -208,15 +209,13 @@ private:
|
||||
/// to define variables.
|
||||
class LetStmt final : public Node::NodeBase<LetStmt, Stmt> {
|
||||
public:
|
||||
static LetStmt *create(Context &ctx, SMRange loc,
|
||||
VariableDecl *varDecl);
|
||||
static LetStmt *create(Context &ctx, SMRange loc, VariableDecl *varDecl);
|
||||
|
||||
/// Return the variable defined by this statement.
|
||||
VariableDecl *getVarDecl() const { return varDecl; }
|
||||
|
||||
private:
|
||||
LetStmt(SMRange loc, VariableDecl *varDecl)
|
||||
: Base(loc), varDecl(varDecl) {}
|
||||
LetStmt(SMRange loc, VariableDecl *varDecl) : Base(loc), varDecl(varDecl) {}
|
||||
|
||||
/// The variable defined by this statement.
|
||||
VariableDecl *varDecl;
|
||||
@ -351,8 +350,7 @@ public:
|
||||
static bool classof(const Node *node);
|
||||
|
||||
protected:
|
||||
Expr(TypeID typeID, SMRange loc, Type type)
|
||||
: Stmt(typeID, loc), type(type) {}
|
||||
Expr(TypeID typeID, SMRange loc, Type type) : Stmt(typeID, loc), type(type) {}
|
||||
|
||||
private:
|
||||
/// The type of this expression.
|
||||
@ -367,8 +365,7 @@ private:
|
||||
/// textual assembly format of that attribute.
|
||||
class AttributeExpr : public Node::NodeBase<AttributeExpr, Expr> {
|
||||
public:
|
||||
static AttributeExpr *create(Context &ctx, SMRange loc,
|
||||
StringRef value);
|
||||
static AttributeExpr *create(Context &ctx, SMRange loc, StringRef value);
|
||||
|
||||
/// Get the raw value of this expression. This is the textual assembly format
|
||||
/// of the MLIR Attribute.
|
||||
@ -426,8 +423,7 @@ private:
|
||||
/// This expression represents a reference to a Decl node.
|
||||
class DeclRefExpr : public Node::NodeBase<DeclRefExpr, Expr> {
|
||||
public:
|
||||
static DeclRefExpr *create(Context &ctx, SMRange loc, Decl *decl,
|
||||
Type type);
|
||||
static DeclRefExpr *create(Context &ctx, SMRange loc, Decl *decl, Type type);
|
||||
|
||||
/// Get the decl referenced by this expression.
|
||||
Decl *getDecl() const { return decl; }
|
||||
@ -459,8 +455,8 @@ public:
|
||||
StringRef getMemberName() const { return memberName; }
|
||||
|
||||
private:
|
||||
MemberAccessExpr(SMRange loc, const Expr *parentExpr,
|
||||
StringRef memberName, Type type)
|
||||
MemberAccessExpr(SMRange loc, const Expr *parentExpr, StringRef memberName,
|
||||
Type type)
|
||||
: Base(loc, type), parentExpr(parentExpr), memberName(memberName) {}
|
||||
|
||||
/// The parent expression of this access.
|
||||
@ -578,8 +574,7 @@ private:
|
||||
class TupleExpr final : public Node::NodeBase<TupleExpr, Expr>,
|
||||
private llvm::TrailingObjects<TupleExpr, Expr *> {
|
||||
public:
|
||||
static TupleExpr *create(Context &ctx, SMRange loc,
|
||||
ArrayRef<Expr *> elements,
|
||||
static TupleExpr *create(Context &ctx, SMRange loc, ArrayRef<Expr *> elements,
|
||||
ArrayRef<StringRef> elementNames);
|
||||
|
||||
/// Return the element expressions of this tuple.
|
||||
@ -697,8 +692,7 @@ public:
|
||||
static bool classof(const Node *node);
|
||||
|
||||
protected:
|
||||
CoreConstraintDecl(TypeID typeID, SMRange loc,
|
||||
const Name *name = nullptr)
|
||||
CoreConstraintDecl(TypeID typeID, SMRange loc, const Name *name = nullptr)
|
||||
: ConstraintDecl(typeID, loc, name) {}
|
||||
};
|
||||
|
||||
@ -786,8 +780,7 @@ protected:
|
||||
class ValueConstraintDecl
|
||||
: public Node::NodeBase<ValueConstraintDecl, CoreConstraintDecl> {
|
||||
public:
|
||||
static ValueConstraintDecl *create(Context &ctx, SMRange loc,
|
||||
Expr *typeExpr);
|
||||
static ValueConstraintDecl *create(Context &ctx, SMRange loc, Expr *typeExpr);
|
||||
|
||||
/// Return the optional type the value is constrained to.
|
||||
Expr *getTypeExpr() { return typeExpr; }
|
||||
@ -996,8 +989,8 @@ private:
|
||||
/// This Decl represents a single Pattern.
|
||||
class PatternDecl : public Node::NodeBase<PatternDecl, Decl> {
|
||||
public:
|
||||
static PatternDecl *create(Context &ctx, SMRange location,
|
||||
const Name *name, Optional<uint16_t> benefit,
|
||||
static PatternDecl *create(Context &ctx, SMRange location, const Name *name,
|
||||
Optional<uint16_t> benefit,
|
||||
bool hasBoundedRecursion,
|
||||
const CompoundStmt *body);
|
||||
|
||||
@ -1249,8 +1242,7 @@ private:
|
||||
class Module final : public Node::NodeBase<Module, Node>,
|
||||
private llvm::TrailingObjects<Module, Decl *> {
|
||||
public:
|
||||
static Module *create(Context &ctx, SMLoc loc,
|
||||
ArrayRef<Decl *> children);
|
||||
static Module *create(Context &ctx, SMLoc loc, ArrayRef<Decl *> children);
|
||||
|
||||
/// Return the children of this module.
|
||||
MutableArrayRef<Decl *> getChildren() {
|
||||
|
@ -62,20 +62,25 @@ public:
|
||||
explicit operator bool() const { return impl; }
|
||||
|
||||
/// Provide type casting support.
|
||||
template <typename U> bool isa() const {
|
||||
template <typename U>
|
||||
bool isa() const {
|
||||
assert(impl && "isa<> used on a null type.");
|
||||
return U::classof(*this);
|
||||
}
|
||||
template <typename U, typename V, typename... Others> bool isa() const {
|
||||
template <typename U, typename V, typename... Others>
|
||||
bool isa() const {
|
||||
return isa<U>() || isa<V, Others...>();
|
||||
}
|
||||
template <typename U> U dyn_cast() const {
|
||||
template <typename U>
|
||||
U dyn_cast() const {
|
||||
return isa<U>() ? U(impl) : U(nullptr);
|
||||
}
|
||||
template <typename U> U dyn_cast_or_null() const {
|
||||
template <typename U>
|
||||
U dyn_cast_or_null() const {
|
||||
return (impl && isa<U>()) ? U(impl) : U(nullptr);
|
||||
}
|
||||
template <typename U> U cast() const {
|
||||
template <typename U>
|
||||
U cast() const {
|
||||
assert(isa<U>());
|
||||
return U(impl);
|
||||
}
|
||||
@ -99,7 +104,8 @@ public:
|
||||
protected:
|
||||
/// Return the internal storage instance of this type reinterpreted as the
|
||||
/// given derived storage type.
|
||||
template <typename T> const T *getImplAs() const {
|
||||
template <typename T>
|
||||
const T *getImplAs() const {
|
||||
return static_cast<const T *>(impl);
|
||||
}
|
||||
|
||||
@ -296,7 +302,8 @@ MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::TypeTypeStorage)
|
||||
MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::ValueTypeStorage)
|
||||
|
||||
namespace llvm {
|
||||
template <> struct DenseMapInfo<mlir::pdll::ast::Type> {
|
||||
template <>
|
||||
struct DenseMapInfo<mlir::pdll::ast::Type> {
|
||||
static mlir::pdll::ast::Type getEmptyKey() {
|
||||
void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlir::pdll::ast::Type(
|
||||
|
@ -67,7 +67,7 @@ public:
|
||||
/// the results after folding the operation.
|
||||
template <typename OpTy, typename... Args>
|
||||
void create(OpBuilder &builder, SmallVectorImpl<Value> &results,
|
||||
Location location, Args &&... args) {
|
||||
Location location, Args &&...args) {
|
||||
// The op needs to be inserted only if the fold (below) fails, or the number
|
||||
// of results produced by the successful folding is zero (which is treated
|
||||
// as an in-place fold). Using create methods of the builder will insert the
|
||||
@ -88,7 +88,7 @@ public:
|
||||
template <typename OpTy, typename... Args>
|
||||
typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(),
|
||||
Value>::type
|
||||
create(OpBuilder &builder, Location location, Args &&... args) {
|
||||
create(OpBuilder &builder, Location location, Args &&...args) {
|
||||
SmallVector<Value, 1> results;
|
||||
create<OpTy>(builder, results, location, std::forward<Args>(args)...);
|
||||
return results.front();
|
||||
@ -98,7 +98,7 @@ public:
|
||||
template <typename OpTy, typename... Args>
|
||||
typename std::enable_if<OpTy::template hasTrait<OpTrait::ZeroResults>(),
|
||||
OpTy>::type
|
||||
create(OpBuilder &builder, Location location, Args &&... args) {
|
||||
create(OpBuilder &builder, Location location, Args &&...args) {
|
||||
auto op = builder.create<OpTy>(location, std::forward<Args>(args)...);
|
||||
SmallVector<Value, 0> unused;
|
||||
(void)tryToFold(op.getOperation(), unused);
|
||||
|
@ -20,8 +20,7 @@ namespace mlir {
|
||||
class Pass;
|
||||
|
||||
/// Creates a pass to print op graphs.
|
||||
std::unique_ptr<Pass>
|
||||
createPrintOpGraphPass(raw_ostream &os = llvm::errs());
|
||||
std::unique_ptr<Pass> createPrintOpGraphPass(raw_ostream &os = llvm::errs());
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -11,8 +11,8 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
|
@ -766,8 +766,8 @@ struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
|
||||
Value absArg = b.create<complex::AbsOp>(elementType, arg);
|
||||
Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
|
||||
|
||||
Value half = b.create<arith::ConstantOp>(
|
||||
elementType, b.getFloatAttr(elementType, 0.5));
|
||||
Value half = b.create<arith::ConstantOp>(elementType,
|
||||
b.getFloatAttr(elementType, 0.5));
|
||||
Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half);
|
||||
Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs);
|
||||
|
||||
|
@ -877,9 +877,9 @@ mlir::createGpuToLLVMConversionPass() {
|
||||
return std::make_unique<GpuToLLVMConversionPass>();
|
||||
}
|
||||
|
||||
void mlir::populateGpuToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns,
|
||||
StringRef gpuBinaryAnnotation) {
|
||||
void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns,
|
||||
StringRef gpuBinaryAnnotation) {
|
||||
converter.addConversion(
|
||||
[context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
|
||||
return LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
|
||||
|
@ -146,7 +146,6 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
|
||||
|
||||
// TODO: Load to Workgroup storage class first.
|
||||
|
||||
|
||||
// Get the input element accessed by this invocation.
|
||||
Value inputElementPtr = spirv::getElementPtr(
|
||||
*typeConverter, originalInputType, convertedInput, {x}, loc, rewriter);
|
||||
|
@ -111,7 +111,6 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
/// Populate the given list with patterns that convert from Linalg to Standard.
|
||||
void mlir::linalg::populateLinalgToStandardConversionPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
|
@ -2753,7 +2753,8 @@ LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
|
||||
// AffineMinMaxOpBase
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename T> static LogicalResult verifyAffineMinMaxOp(T op) {
|
||||
template <typename T>
|
||||
static LogicalResult verifyAffineMinMaxOp(T op) {
|
||||
// Verify that operand count matches affine map dimension and symbol count.
|
||||
if (op.getNumOperands() !=
|
||||
op.getMap().getNumDims() + op.getMap().getNumSymbols())
|
||||
@ -2762,7 +2763,8 @@ template <typename T> static LogicalResult verifyAffineMinMaxOp(T op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename T> static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
|
||||
template <typename T>
|
||||
static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
|
||||
p << ' ' << op->getAttr(T::getMapAttrStrName());
|
||||
auto operands = op.getOperands();
|
||||
unsigned numDims = op.getMap().getNumDims();
|
||||
@ -2870,7 +2872,8 @@ struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> {
|
||||
///
|
||||
/// %1 = affine.min affine_map<
|
||||
/// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
|
||||
template <typename T> struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
|
||||
template <typename T>
|
||||
struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
|
||||
using OpRewritePattern<T>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(T affineOp,
|
||||
|
@ -209,8 +209,8 @@ OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
|
||||
operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
|
||||
}
|
||||
|
||||
void arith::AddIOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
|
||||
context);
|
||||
}
|
||||
@ -231,8 +231,8 @@ OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
|
||||
operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
|
||||
}
|
||||
|
||||
void arith::SubIOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns
|
||||
.add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
|
||||
SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
|
||||
@ -539,8 +539,8 @@ OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
|
||||
operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
|
||||
}
|
||||
|
||||
void arith::XOrIOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add<XOrINotCmpI>(context);
|
||||
}
|
||||
|
||||
@ -921,8 +921,8 @@ bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
|
||||
}
|
||||
|
||||
void arith::ExtSIOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add<ExtSIOfExtUI>(context);
|
||||
}
|
||||
|
||||
@ -1017,8 +1017,8 @@ LogicalResult arith::TruncFOp::verify() {
|
||||
// AndIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void arith::AndIOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add<AndOfExtUI, AndOfExtSI>(context);
|
||||
}
|
||||
|
||||
@ -1026,8 +1026,8 @@ void arith::AndIOp::getCanonicalizationPatterns(
|
||||
// OrIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void arith::OrIOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add<OrOfExtUI, OrOfExtSI>(context);
|
||||
}
|
||||
|
||||
@ -1226,8 +1226,8 @@ OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
|
||||
return IntegerAttr::get(resType, bits);
|
||||
}
|
||||
|
||||
void arith::BitcastOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add<BitcastOfBitcast>(context);
|
||||
}
|
||||
|
||||
|
@ -159,7 +159,7 @@ public:
|
||||
Location loc = op.getLoc();
|
||||
// If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
|
||||
static_assert(pred == arith::CmpFPredicate::UGT ||
|
||||
pred == arith::CmpFPredicate::ULT,
|
||||
pred == arith::CmpFPredicate::ULT,
|
||||
"pred must be either UGT or ULT");
|
||||
Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
|
||||
Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
|
||||
|
@ -212,15 +212,14 @@ struct OneShotBufferizePass
|
||||
};
|
||||
|
||||
// Configure op filter.
|
||||
OpFilter::Entry::FilterFn filterFn =
|
||||
[&](Operation *op) {
|
||||
// Filter may be specified via options.
|
||||
if (this->dialectFilter.hasValue())
|
||||
return llvm::is_contained(this->dialectFilter,
|
||||
op->getDialect()->getNamespace());
|
||||
// No filter specified: All other ops are allowed.
|
||||
return true;
|
||||
};
|
||||
OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
|
||||
// Filter may be specified via options.
|
||||
if (this->dialectFilter.hasValue())
|
||||
return llvm::is_contained(this->dialectFilter,
|
||||
op->getDialect()->getNamespace());
|
||||
// No filter specified: All other ops are allowed.
|
||||
return true;
|
||||
};
|
||||
opt.opFilter.allowOperation(filterFn);
|
||||
} else {
|
||||
opt = *options;
|
||||
|
@ -133,16 +133,15 @@ SerializeToCubinPass::serializeISA(const std::string &isa) {
|
||||
|
||||
// Register pass to serialize GPU kernel functions to a CUBIN binary annotation.
|
||||
void mlir::registerGpuSerializeToCubinPass() {
|
||||
PassRegistration<SerializeToCubinPass> registerSerializeToCubin(
|
||||
[] {
|
||||
// Initialize LLVM NVPTX backend.
|
||||
LLVMInitializeNVPTXTarget();
|
||||
LLVMInitializeNVPTXTargetInfo();
|
||||
LLVMInitializeNVPTXTargetMC();
|
||||
LLVMInitializeNVPTXAsmPrinter();
|
||||
PassRegistration<SerializeToCubinPass> registerSerializeToCubin([] {
|
||||
// Initialize LLVM NVPTX backend.
|
||||
LLVMInitializeNVPTXTarget();
|
||||
LLVMInitializeNVPTXTargetInfo();
|
||||
LLVMInitializeNVPTXTargetMC();
|
||||
LLVMInitializeNVPTXAsmPrinter();
|
||||
|
||||
return std::make_unique<SerializeToCubinPass>();
|
||||
});
|
||||
return std::make_unique<SerializeToCubinPass>();
|
||||
});
|
||||
}
|
||||
#else // MLIR_GPU_TO_CUBIN_PASS_ENABLE
|
||||
void mlir::registerGpuSerializeToCubinPass() {}
|
||||
|
@ -360,8 +360,7 @@ SerializeToHsacoPass::assembleIsa(const std::string &isa) {
|
||||
}
|
||||
|
||||
llvm::SourceMgr srcMgr;
|
||||
srcMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(isa),
|
||||
SMLoc());
|
||||
srcMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(isa), SMLoc());
|
||||
|
||||
const llvm::MCTargetOptions mcOptions;
|
||||
std::unique_ptr<llvm::MCRegisterInfo> mri(
|
||||
@ -469,18 +468,17 @@ SerializeToHsacoPass::serializeISA(const std::string &isa) {
|
||||
|
||||
// Register pass to serialize GPU kernel functions to a HSACO binary annotation.
|
||||
void mlir::registerGpuSerializeToHsacoPass() {
|
||||
PassRegistration<SerializeToHsacoPass> registerSerializeToHSACO(
|
||||
[] {
|
||||
// Initialize LLVM AMDGPU backend.
|
||||
LLVMInitializeAMDGPUAsmParser();
|
||||
LLVMInitializeAMDGPUAsmPrinter();
|
||||
LLVMInitializeAMDGPUTarget();
|
||||
LLVMInitializeAMDGPUTargetInfo();
|
||||
LLVMInitializeAMDGPUTargetMC();
|
||||
PassRegistration<SerializeToHsacoPass> registerSerializeToHSACO([] {
|
||||
// Initialize LLVM AMDGPU backend.
|
||||
LLVMInitializeAMDGPUAsmParser();
|
||||
LLVMInitializeAMDGPUAsmPrinter();
|
||||
LLVMInitializeAMDGPUTarget();
|
||||
LLVMInitializeAMDGPUTargetInfo();
|
||||
LLVMInitializeAMDGPUTargetMC();
|
||||
|
||||
return std::make_unique<SerializeToHsacoPass>("amdgcn-amd-amdhsa", "",
|
||||
"", 2);
|
||||
});
|
||||
return std::make_unique<SerializeToHsacoPass>("amdgcn-amd-amdhsa", "", "",
|
||||
2);
|
||||
});
|
||||
}
|
||||
|
||||
/// Create an instance of the GPU kernel function to HSAco binary serialization
|
||||
|
@ -485,7 +485,8 @@ recordStructIndices(Type baseGEPType, unsigned indexPos,
|
||||
unsigned dynamicIndexPos = indexPos;
|
||||
if (!isStaticIndex)
|
||||
dynamicIndexPos = llvm::count(structIndices.take_front(indexPos + 1),
|
||||
LLVM::GEPOp::kDynamicIndex) - 1;
|
||||
LLVM::GEPOp::kDynamicIndex) -
|
||||
1;
|
||||
|
||||
return llvm::TypeSwitch<Type, llvm::Error>(baseGEPType)
|
||||
.Case<LLVMStructType>([&](LLVMStructType structType) -> llvm::Error {
|
||||
|
@ -311,8 +311,7 @@ static LLVMArrayType parseArrayType(AsmParser &parser) {
|
||||
/// error at `subtypesLoc` in case of failure.
|
||||
static LLVMStructType trySetStructBody(LLVMStructType type,
|
||||
ArrayRef<Type> subtypes, bool isPacked,
|
||||
AsmParser &parser,
|
||||
SMLoc subtypesLoc) {
|
||||
AsmParser &parser, SMLoc subtypesLoc) {
|
||||
for (Type t : subtypes) {
|
||||
if (!LLVMStructType::isValidElementType(t)) {
|
||||
parser.emitError(subtypesLoc)
|
||||
|
@ -127,7 +127,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::memref::populateComposeSubViewPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
void mlir::memref::populateComposeSubViewPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add<ComposeSubViewOpPattern>(context);
|
||||
}
|
||||
|
@ -483,7 +483,8 @@ namespace {
|
||||
// parseAndVerify does the actual parsing and verification of individual
|
||||
// elements. This is a functor since parsing the last element of the list
|
||||
// (termination condition) needs partial specialization.
|
||||
template <typename ParseType, typename... Args> struct ParseCommaSeparatedList {
|
||||
template <typename ParseType, typename... Args>
|
||||
struct ParseCommaSeparatedList {
|
||||
Optional<std::tuple<ParseType, Args...>>
|
||||
operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
|
||||
auto parseVal = parseAndVerify<ParseType>(dialect, parser);
|
||||
@ -503,7 +504,8 @@ template <typename ParseType, typename... Args> struct ParseCommaSeparatedList {
|
||||
|
||||
// Partial specialization of the function to parse a comma separated list of
|
||||
// specs to parse the last element of the list.
|
||||
template <typename ParseType> struct ParseCommaSeparatedList<ParseType> {
|
||||
template <typename ParseType>
|
||||
struct ParseCommaSeparatedList<ParseType> {
|
||||
Optional<std::tuple<ParseType>> operator()(SPIRVDialect const &dialect,
|
||||
DialectAsmParser &parser) const {
|
||||
if (auto value = parseAndVerify<ParseType>(dialect, parser))
|
||||
|
@ -259,33 +259,42 @@ void CooperativeMatrixNVType::getCapabilities(
|
||||
// ImageType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename T> static constexpr unsigned getNumBits() { return 0; }
|
||||
template <> constexpr unsigned getNumBits<Dim>() {
|
||||
template <typename T>
|
||||
static constexpr unsigned getNumBits() {
|
||||
return 0;
|
||||
}
|
||||
template <>
|
||||
constexpr unsigned getNumBits<Dim>() {
|
||||
static_assert((1 << 3) > getMaxEnumValForDim(),
|
||||
"Not enough bits to encode Dim value");
|
||||
return 3;
|
||||
}
|
||||
template <> constexpr unsigned getNumBits<ImageDepthInfo>() {
|
||||
template <>
|
||||
constexpr unsigned getNumBits<ImageDepthInfo>() {
|
||||
static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
|
||||
"Not enough bits to encode ImageDepthInfo value");
|
||||
return 2;
|
||||
}
|
||||
template <> constexpr unsigned getNumBits<ImageArrayedInfo>() {
|
||||
template <>
|
||||
constexpr unsigned getNumBits<ImageArrayedInfo>() {
|
||||
static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
|
||||
"Not enough bits to encode ImageArrayedInfo value");
|
||||
return 1;
|
||||
}
|
||||
template <> constexpr unsigned getNumBits<ImageSamplingInfo>() {
|
||||
template <>
|
||||
constexpr unsigned getNumBits<ImageSamplingInfo>() {
|
||||
static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
|
||||
"Not enough bits to encode ImageSamplingInfo value");
|
||||
return 1;
|
||||
}
|
||||
template <> constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
|
||||
template <>
|
||||
constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
|
||||
static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
|
||||
"Not enough bits to encode ImageSamplerUseInfo value");
|
||||
return 2;
|
||||
}
|
||||
template <> constexpr unsigned getNumBits<ImageFormat>() {
|
||||
template <>
|
||||
constexpr unsigned getNumBits<ImageFormat>() {
|
||||
static_assert((1 << 6) > getMaxEnumValForImageFormat(),
|
||||
"Not enough bits to encode ImageFormat value");
|
||||
return 6;
|
||||
|
@ -8,8 +8,8 @@
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "AffineExprDetail.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineExprVisitor.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
|
@ -930,8 +930,7 @@ private:
|
||||
};
|
||||
} // namespace
|
||||
|
||||
SSANameState::SSANameState(
|
||||
Operation *op, const OpPrintingFlags &printerFlags)
|
||||
SSANameState::SSANameState(Operation *op, const OpPrintingFlags &printerFlags)
|
||||
: printerFlags(printerFlags) {
|
||||
llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
|
||||
llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
|
||||
|
@ -319,11 +319,11 @@ unsigned PredecessorIterator::getSuccessorIndex() const {
|
||||
SuccessorRange::SuccessorRange() : SuccessorRange(nullptr, 0) {}
|
||||
|
||||
SuccessorRange::SuccessorRange(Block *block) : SuccessorRange() {
|
||||
if (block->empty() || llvm::hasSingleElement(*block->getParent()))
|
||||
return;
|
||||
Operation *term = &block->back();
|
||||
if ((count = term->getNumSuccessors()))
|
||||
base = term->getBlockOperands().data();
|
||||
if (block->empty() || llvm::hasSingleElement(*block->getParent()))
|
||||
return;
|
||||
Operation *term = &block->back();
|
||||
if ((count = term->getNumSuccessors()))
|
||||
base = term->getBlockOperands().data();
|
||||
}
|
||||
|
||||
SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange() {
|
||||
|
@ -373,8 +373,7 @@ struct SourceMgrDiagnosticHandlerImpl {
|
||||
|
||||
// Otherwise, try to load the source file.
|
||||
std::string ignored;
|
||||
unsigned id =
|
||||
mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored);
|
||||
unsigned id = mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored);
|
||||
filenameToBufId[filename] = id;
|
||||
return id;
|
||||
}
|
||||
|
@ -261,10 +261,8 @@ TypeRange mlir::function_interface_impl::insertTypesInto(
|
||||
return storage;
|
||||
}
|
||||
|
||||
TypeRange
|
||||
mlir::function_interface_impl::filterTypesOut(TypeRange types,
|
||||
const BitVector &indices,
|
||||
SmallVectorImpl<Type> &storage) {
|
||||
TypeRange mlir::function_interface_impl::filterTypesOut(
|
||||
TypeRange types, const BitVector &indices, SmallVectorImpl<Type> &storage) {
|
||||
if (indices.none())
|
||||
return types;
|
||||
|
||||
|
@ -292,8 +292,7 @@ void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
|
||||
operands[numOperands + i].~OpOperand();
|
||||
}
|
||||
|
||||
void detail::OperandStorage::eraseOperands(
|
||||
const BitVector &eraseIndices) {
|
||||
void detail::OperandStorage::eraseOperands(const BitVector &eraseIndices) {
|
||||
MutableArrayRef<OpOperand> operands = getOperands();
|
||||
assert(eraseIndices.size() == operands.size());
|
||||
|
||||
|
@ -212,4 +212,3 @@ unsigned BlockOperand::getOperandNumber() {
|
||||
unsigned OpOperand::getOperandNumber() {
|
||||
return this - &getOwner()->getOpOperands()[0];
|
||||
}
|
||||
|
||||
|
@ -51,9 +51,7 @@ public:
|
||||
|
||||
/// Get the location of the next token and store it into the argument. This
|
||||
/// always succeeds.
|
||||
SMLoc getCurrentLocation() override {
|
||||
return parser.getToken().getLoc();
|
||||
}
|
||||
SMLoc getCurrentLocation() override { return parser.getToken().getLoc(); }
|
||||
|
||||
/// Re-encode the given source location as an MLIR location and return it.
|
||||
Location getEncodedSourceLoc(SMLoc loc) override {
|
||||
|
@ -259,8 +259,7 @@ void AsmParserState::addDefinition(Block *block, SMLoc location) {
|
||||
impl->blocks[it->second]->definition.loc = convertIdLocToRange(location);
|
||||
}
|
||||
|
||||
void AsmParserState::addDefinition(BlockArgument blockArg,
|
||||
SMLoc location) {
|
||||
void AsmParserState::addDefinition(BlockArgument blockArg, SMLoc location) {
|
||||
auto it = impl->blocksToIdx.find(blockArg.getOwner());
|
||||
assert(it != impl->blocksToIdx.end() &&
|
||||
"expected owner block to have an entry");
|
||||
|
@ -525,8 +525,7 @@ ParseResult TensorLiteralParser::parse(bool allowHex) {
|
||||
|
||||
/// Build a dense attribute instance with the parsed elements and the given
|
||||
/// shaped type.
|
||||
DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc,
|
||||
ShapedType type) {
|
||||
DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
|
||||
Type eltType = type.getElementType();
|
||||
|
||||
// Check to see if we parse the literal from a hex string.
|
||||
@ -676,8 +675,7 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
|
||||
}
|
||||
|
||||
/// Build a Dense String attribute for the given type.
|
||||
DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc,
|
||||
ShapedType type,
|
||||
DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type,
|
||||
Type eltTy) {
|
||||
if (hexStorage.has_value()) {
|
||||
auto stringValue = hexStorage.value().getStringValue();
|
||||
@ -698,8 +696,7 @@ DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc,
|
||||
}
|
||||
|
||||
/// Build a Dense attribute with hex data for the given type.
|
||||
DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc,
|
||||
ShapedType type) {
|
||||
DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) {
|
||||
Type elementType = type.getElementType();
|
||||
if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) {
|
||||
p.emitError(loc)
|
||||
|
@ -62,7 +62,8 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T> void addDataToHash(llvm::SHA1 &hasher, const T &data) {
|
||||
template <typename T>
|
||||
void addDataToHash(llvm::SHA1 &hasher, const T &data) {
|
||||
hasher.update(
|
||||
ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
|
||||
}
|
||||
|
@ -8,9 +8,9 @@
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/Format.h"
|
||||
#include "llvm/Support/ManagedStatic.h"
|
||||
|
@ -33,9 +33,9 @@ PatternApplicator::~PatternApplicator() = default;
|
||||
#ifndef NDEBUG
|
||||
/// Log a message for a pattern that is impossible to match.
|
||||
static void logImpossibleToMatch(const Pattern &pattern) {
|
||||
llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
|
||||
<< "' because it is impossible to match or cannot lead "
|
||||
"to legal IR (by cost model)\n";
|
||||
llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
|
||||
<< "' because it is impossible to match or cannot lead "
|
||||
"to legal IR (by cost model)\n";
|
||||
}
|
||||
|
||||
/// Log IR after pattern application.
|
||||
|
@ -341,10 +341,9 @@ static LogicalResult convertDataOp(acc::DataOp &op,
|
||||
mapperAllocas)))
|
||||
return failure();
|
||||
|
||||
if (failed(processOperands(builder, moduleTranslation, op,
|
||||
op.createOperands(), totalNbOperand,
|
||||
kCreateFlag | kHoldFlag, flags, names, index,
|
||||
mapperAllocas)))
|
||||
if (failed(processOperands(
|
||||
builder, moduleTranslation, op, op.createOperands(), totalNbOperand,
|
||||
kCreateFlag | kHoldFlag, flags, names, index, mapperAllocas)))
|
||||
return failure();
|
||||
|
||||
// TODO create zero currenlty handled as create. Update when extension
|
||||
@ -355,10 +354,9 @@ static LogicalResult convertDataOp(acc::DataOp &op,
|
||||
mapperAllocas)))
|
||||
return failure();
|
||||
|
||||
if (failed(processOperands(builder, moduleTranslation, op,
|
||||
op.presentOperands(), totalNbOperand,
|
||||
kPresentFlag | kHoldFlag, flags, names, index,
|
||||
mapperAllocas)))
|
||||
if (failed(processOperands(
|
||||
builder, moduleTranslation, op, op.presentOperands(), totalNbOperand,
|
||||
kPresentFlag | kHoldFlag, flags, names, index, mapperAllocas)))
|
||||
return failure();
|
||||
|
||||
llvm::GlobalVariable *maptypes =
|
||||
|
@ -467,7 +467,8 @@ private:
|
||||
/// Method to deserialize an operation in the SPIR-V dialect that is a mirror
|
||||
/// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode
|
||||
/// == 1 and autogenSerialization == 1 in ODS.
|
||||
template <typename OpTy> LogicalResult processOp(ArrayRef<uint32_t> words) {
|
||||
template <typename OpTy>
|
||||
LogicalResult processOp(ArrayRef<uint32_t> words) {
|
||||
return emitError(unknownLoc, "unsupported deserialization for ")
|
||||
<< OpTy::getOperationName() << " op";
|
||||
}
|
||||
|
@ -297,7 +297,8 @@ private:
|
||||
/// Serializes an operation in the SPIR-V dialect that is a mirror of an
|
||||
/// instruction in the SPIR-V spec. This is auto generated if hasOpcode == 1
|
||||
/// and autogenSerialization == 1 in ODS.
|
||||
template <typename OpTy> LogicalResult processOp(OpTy op) {
|
||||
template <typename OpTy>
|
||||
LogicalResult processOp(OpTy op) {
|
||||
return op.emitError("unsupported op serialization");
|
||||
}
|
||||
|
||||
|
@ -199,8 +199,7 @@ CompoundStmt *CompoundStmt::create(Context &ctx, SMRange loc,
|
||||
// LetStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LetStmt *LetStmt::create(Context &ctx, SMRange loc,
|
||||
VariableDecl *varDecl) {
|
||||
LetStmt *LetStmt::create(Context &ctx, SMRange loc, VariableDecl *varDecl) {
|
||||
return new (ctx.getAllocator().Allocate<LetStmt>()) LetStmt(loc, varDecl);
|
||||
}
|
||||
|
||||
@ -394,8 +393,7 @@ Optional<StringRef> OpConstraintDecl::getName() const {
|
||||
// TypeConstraintDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx,
|
||||
SMRange loc) {
|
||||
TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx, SMRange loc) {
|
||||
return new (ctx.getAllocator().Allocate<TypeConstraintDecl>())
|
||||
TypeConstraintDecl(loc);
|
||||
}
|
||||
@ -414,8 +412,8 @@ TypeRangeConstraintDecl *TypeRangeConstraintDecl::create(Context &ctx,
|
||||
// ValueConstraintDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ValueConstraintDecl *
|
||||
ValueConstraintDecl::create(Context &ctx, SMRange loc, Expr *typeExpr) {
|
||||
ValueConstraintDecl *ValueConstraintDecl::create(Context &ctx, SMRange loc,
|
||||
Expr *typeExpr) {
|
||||
return new (ctx.getAllocator().Allocate<ValueConstraintDecl>())
|
||||
ValueConstraintDecl(loc, typeExpr);
|
||||
}
|
||||
@ -424,9 +422,8 @@ ValueConstraintDecl::create(Context &ctx, SMRange loc, Expr *typeExpr) {
|
||||
// ValueRangeConstraintDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx,
|
||||
SMRange loc,
|
||||
Expr *typeExpr) {
|
||||
ValueRangeConstraintDecl *
|
||||
ValueRangeConstraintDecl::create(Context &ctx, SMRange loc, Expr *typeExpr) {
|
||||
return new (ctx.getAllocator().Allocate<ValueRangeConstraintDecl>())
|
||||
ValueRangeConstraintDecl(loc, typeExpr);
|
||||
}
|
||||
@ -498,8 +495,8 @@ OpNameDecl *OpNameDecl::create(Context &ctx, SMRange loc) {
|
||||
// PatternDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternDecl *PatternDecl::create(Context &ctx, SMRange loc,
|
||||
const Name *name, Optional<uint16_t> benefit,
|
||||
PatternDecl *PatternDecl::create(Context &ctx, SMRange loc, const Name *name,
|
||||
Optional<uint16_t> benefit,
|
||||
bool hasBoundedRecursion,
|
||||
const CompoundStmt *body) {
|
||||
return new (ctx.getAllocator().Allocate<PatternDecl>())
|
||||
@ -554,8 +551,7 @@ VariableDecl *VariableDecl::create(Context &ctx, const Name &name, Type type,
|
||||
// Module
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Module *Module::create(Context &ctx, SMLoc loc,
|
||||
ArrayRef<Decl *> children) {
|
||||
Module *Module::create(Context &ctx, SMLoc loc, ArrayRef<Decl *> children) {
|
||||
unsigned allocSize = Module::totalSizeToAlloc<Decl *>(children.size());
|
||||
void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(Module));
|
||||
|
||||
|
@ -44,18 +44,18 @@ std::string Token::getStringValue() const {
|
||||
assert(i + 1 <= e && "invalid string should be caught by lexer");
|
||||
auto c1 = bytes[i++];
|
||||
switch (c1) {
|
||||
case '"':
|
||||
case '\\':
|
||||
result.push_back(c1);
|
||||
continue;
|
||||
case 'n':
|
||||
result.push_back('\n');
|
||||
continue;
|
||||
case 't':
|
||||
result.push_back('\t');
|
||||
continue;
|
||||
default:
|
||||
break;
|
||||
case '"':
|
||||
case '\\':
|
||||
result.push_back(c1);
|
||||
continue;
|
||||
case 'n':
|
||||
result.push_back('\n');
|
||||
continue;
|
||||
case 't':
|
||||
result.push_back('\t');
|
||||
continue;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
assert(i + 1 <= e && "invalid string should be caught by lexer");
|
||||
@ -101,7 +101,8 @@ Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine,
|
||||
}
|
||||
|
||||
Lexer::~Lexer() {
|
||||
if (addedHandlerToDiagEngine) diagEngine.setHandlerFn(nullptr);
|
||||
if (addedHandlerToDiagEngine)
|
||||
diagEngine.setHandlerFn(nullptr);
|
||||
}
|
||||
|
||||
LogicalResult Lexer::pushInclude(StringRef filename, SMRange includeLoc) {
|
||||
@ -121,39 +122,39 @@ Token Lexer::emitError(SMRange loc, const Twine &msg) {
|
||||
diagEngine.emitError(loc, msg);
|
||||
return formToken(Token::error, loc.Start.getPointer());
|
||||
}
|
||||
Token Lexer::emitErrorAndNote(SMRange loc, const Twine &msg,
|
||||
SMRange noteLoc, const Twine ¬e) {
|
||||
Token Lexer::emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
|
||||
const Twine ¬e) {
|
||||
diagEngine.emitError(loc, msg)->attachNote(note, noteLoc);
|
||||
return formToken(Token::error, loc.Start.getPointer());
|
||||
}
|
||||
Token Lexer::emitError(const char *loc, const Twine &msg) {
|
||||
return emitError(SMRange(SMLoc::getFromPointer(loc),
|
||||
SMLoc::getFromPointer(loc + 1)),
|
||||
msg);
|
||||
return emitError(
|
||||
SMRange(SMLoc::getFromPointer(loc), SMLoc::getFromPointer(loc + 1)), msg);
|
||||
}
|
||||
|
||||
int Lexer::getNextChar() {
|
||||
char curChar = *curPtr++;
|
||||
switch (curChar) {
|
||||
default:
|
||||
return static_cast<unsigned char>(curChar);
|
||||
case 0: {
|
||||
// A nul character in the stream is either the end of the current buffer
|
||||
// or a random nul in the file. Disambiguate that here.
|
||||
if (curPtr - 1 != curBuffer.end()) return 0;
|
||||
default:
|
||||
return static_cast<unsigned char>(curChar);
|
||||
case 0: {
|
||||
// A nul character in the stream is either the end of the current buffer
|
||||
// or a random nul in the file. Disambiguate that here.
|
||||
if (curPtr - 1 != curBuffer.end())
|
||||
return 0;
|
||||
|
||||
// Otherwise, return end of file.
|
||||
--curPtr;
|
||||
return EOF;
|
||||
}
|
||||
case '\n':
|
||||
case '\r':
|
||||
// Handle the newline character by ignoring it and incrementing the line
|
||||
// count. However, be careful about 'dos style' files with \n\r in them.
|
||||
// Only treat a \n\r or \r\n as a single line.
|
||||
if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
|
||||
++curPtr;
|
||||
return '\n';
|
||||
// Otherwise, return end of file.
|
||||
--curPtr;
|
||||
return EOF;
|
||||
}
|
||||
case '\n':
|
||||
case '\r':
|
||||
// Handle the newline character by ignoring it and incrementing the line
|
||||
// count. However, be careful about 'dos style' files with \n\r in them.
|
||||
// Only treat a \n\r or \r\n as a single line.
|
||||
if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
|
||||
++curPtr;
|
||||
return '\n';
|
||||
}
|
||||
}
|
||||
|
||||
@ -168,99 +169,100 @@ Token Lexer::lexToken() {
|
||||
// This always consumes at least one character.
|
||||
int curChar = getNextChar();
|
||||
switch (curChar) {
|
||||
default:
|
||||
// Handle identifiers: [a-zA-Z_]
|
||||
if (isalpha(curChar) || curChar == '_') return lexIdentifier(tokStart);
|
||||
default:
|
||||
// Handle identifiers: [a-zA-Z_]
|
||||
if (isalpha(curChar) || curChar == '_')
|
||||
return lexIdentifier(tokStart);
|
||||
|
||||
// Unknown character, emit an error.
|
||||
return emitError(tokStart, "unexpected character");
|
||||
case EOF: {
|
||||
// Return EOF denoting the end of lexing.
|
||||
Token eof = formToken(Token::eof, tokStart);
|
||||
// Unknown character, emit an error.
|
||||
return emitError(tokStart, "unexpected character");
|
||||
case EOF: {
|
||||
// Return EOF denoting the end of lexing.
|
||||
Token eof = formToken(Token::eof, tokStart);
|
||||
|
||||
// Check to see if we are in an included file.
|
||||
SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID);
|
||||
if (parentIncludeLoc.isValid()) {
|
||||
curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc);
|
||||
curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
|
||||
curPtr = parentIncludeLoc.getPointer();
|
||||
}
|
||||
|
||||
return eof;
|
||||
// Check to see if we are in an included file.
|
||||
SMLoc parentIncludeLoc = srcMgr.getParentIncludeLoc(curBufferID);
|
||||
if (parentIncludeLoc.isValid()) {
|
||||
curBufferID = srcMgr.FindBufferContainingLoc(parentIncludeLoc);
|
||||
curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
|
||||
curPtr = parentIncludeLoc.getPointer();
|
||||
}
|
||||
|
||||
// Lex punctuation.
|
||||
case '-':
|
||||
if (*curPtr == '>') {
|
||||
++curPtr;
|
||||
return formToken(Token::arrow, tokStart);
|
||||
}
|
||||
return emitError(tokStart, "unexpected character");
|
||||
case ':':
|
||||
return formToken(Token::colon, tokStart);
|
||||
case ',':
|
||||
return formToken(Token::comma, tokStart);
|
||||
case '.':
|
||||
return formToken(Token::dot, tokStart);
|
||||
case '=':
|
||||
if (*curPtr == '>') {
|
||||
++curPtr;
|
||||
return formToken(Token::equal_arrow, tokStart);
|
||||
}
|
||||
return formToken(Token::equal, tokStart);
|
||||
case ';':
|
||||
return formToken(Token::semicolon, tokStart);
|
||||
case '[':
|
||||
if (*curPtr == '{') {
|
||||
++curPtr;
|
||||
return lexString(tokStart, /*isStringBlock=*/true);
|
||||
}
|
||||
return formToken(Token::l_square, tokStart);
|
||||
case ']':
|
||||
return formToken(Token::r_square, tokStart);
|
||||
return eof;
|
||||
}
|
||||
|
||||
case '<':
|
||||
return formToken(Token::less, tokStart);
|
||||
case '>':
|
||||
return formToken(Token::greater, tokStart);
|
||||
case '{':
|
||||
return formToken(Token::l_brace, tokStart);
|
||||
case '}':
|
||||
return formToken(Token::r_brace, tokStart);
|
||||
case '(':
|
||||
return formToken(Token::l_paren, tokStart);
|
||||
case ')':
|
||||
return formToken(Token::r_paren, tokStart);
|
||||
case '/':
|
||||
if (*curPtr == '/') {
|
||||
lexComment();
|
||||
continue;
|
||||
}
|
||||
return emitError(tokStart, "unexpected character");
|
||||
// Lex punctuation.
|
||||
case '-':
|
||||
if (*curPtr == '>') {
|
||||
++curPtr;
|
||||
return formToken(Token::arrow, tokStart);
|
||||
}
|
||||
return emitError(tokStart, "unexpected character");
|
||||
case ':':
|
||||
return formToken(Token::colon, tokStart);
|
||||
case ',':
|
||||
return formToken(Token::comma, tokStart);
|
||||
case '.':
|
||||
return formToken(Token::dot, tokStart);
|
||||
case '=':
|
||||
if (*curPtr == '>') {
|
||||
++curPtr;
|
||||
return formToken(Token::equal_arrow, tokStart);
|
||||
}
|
||||
return formToken(Token::equal, tokStart);
|
||||
case ';':
|
||||
return formToken(Token::semicolon, tokStart);
|
||||
case '[':
|
||||
if (*curPtr == '{') {
|
||||
++curPtr;
|
||||
return lexString(tokStart, /*isStringBlock=*/true);
|
||||
}
|
||||
return formToken(Token::l_square, tokStart);
|
||||
case ']':
|
||||
return formToken(Token::r_square, tokStart);
|
||||
|
||||
// Ignore whitespace characters.
|
||||
case 0:
|
||||
case ' ':
|
||||
case '\t':
|
||||
case '\n':
|
||||
return lexToken();
|
||||
case '<':
|
||||
return formToken(Token::less, tokStart);
|
||||
case '>':
|
||||
return formToken(Token::greater, tokStart);
|
||||
case '{':
|
||||
return formToken(Token::l_brace, tokStart);
|
||||
case '}':
|
||||
return formToken(Token::r_brace, tokStart);
|
||||
case '(':
|
||||
return formToken(Token::l_paren, tokStart);
|
||||
case ')':
|
||||
return formToken(Token::r_paren, tokStart);
|
||||
case '/':
|
||||
if (*curPtr == '/') {
|
||||
lexComment();
|
||||
continue;
|
||||
}
|
||||
return emitError(tokStart, "unexpected character");
|
||||
|
||||
case '#':
|
||||
return lexDirective(tokStart);
|
||||
case '"':
|
||||
return lexString(tokStart, /*isStringBlock=*/false);
|
||||
// Ignore whitespace characters.
|
||||
case 0:
|
||||
case ' ':
|
||||
case '\t':
|
||||
case '\n':
|
||||
return lexToken();
|
||||
|
||||
case '0':
|
||||
case '1':
|
||||
case '2':
|
||||
case '3':
|
||||
case '4':
|
||||
case '5':
|
||||
case '6':
|
||||
case '7':
|
||||
case '8':
|
||||
case '9':
|
||||
return lexNumber(tokStart);
|
||||
case '#':
|
||||
return lexDirective(tokStart);
|
||||
case '"':
|
||||
return lexString(tokStart, /*isStringBlock=*/false);
|
||||
|
||||
case '0':
|
||||
case '1':
|
||||
case '2':
|
||||
case '3':
|
||||
case '4':
|
||||
case '5':
|
||||
case '6':
|
||||
case '7':
|
||||
case '8':
|
||||
case '9':
|
||||
return lexNumber(tokStart);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -273,27 +275,28 @@ void Lexer::lexComment() {
|
||||
|
||||
while (true) {
|
||||
switch (*curPtr++) {
|
||||
case '\n':
|
||||
case '\r':
|
||||
// Newline is end of comment.
|
||||
case '\n':
|
||||
case '\r':
|
||||
// Newline is end of comment.
|
||||
return;
|
||||
case 0:
|
||||
// If this is the end of the buffer, end the comment.
|
||||
if (curPtr - 1 == curBuffer.end()) {
|
||||
--curPtr;
|
||||
return;
|
||||
case 0:
|
||||
// If this is the end of the buffer, end the comment.
|
||||
if (curPtr - 1 == curBuffer.end()) {
|
||||
--curPtr;
|
||||
return;
|
||||
}
|
||||
LLVM_FALLTHROUGH;
|
||||
default:
|
||||
// Skip over other characters.
|
||||
break;
|
||||
}
|
||||
LLVM_FALLTHROUGH;
|
||||
default:
|
||||
// Skip over other characters.
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Token Lexer::lexDirective(const char *tokStart) {
|
||||
// Match the rest with an identifier regex: [0-9a-zA-Z_]*
|
||||
while (isalnum(*curPtr) || *curPtr == '_') ++curPtr;
|
||||
while (isalnum(*curPtr) || *curPtr == '_')
|
||||
++curPtr;
|
||||
|
||||
StringRef str(tokStart, curPtr - tokStart);
|
||||
return Token(Token::directive, str);
|
||||
@ -301,7 +304,8 @@ Token Lexer::lexDirective(const char *tokStart) {
|
||||
|
||||
Token Lexer::lexIdentifier(const char *tokStart) {
|
||||
// Match the rest of the identifier regex: [0-9a-zA-Z_]*
|
||||
while (isalnum(*curPtr) || *curPtr == '_') ++curPtr;
|
||||
while (isalnum(*curPtr) || *curPtr == '_')
|
||||
++curPtr;
|
||||
|
||||
// Check to see if this identifier is a keyword.
|
||||
StringRef str(tokStart, curPtr - tokStart);
|
||||
@ -334,7 +338,8 @@ Token Lexer::lexNumber(const char *tokStart) {
|
||||
assert(isdigit(curPtr[-1]));
|
||||
|
||||
// Handle the normal decimal case.
|
||||
while (isdigit(*curPtr)) ++curPtr;
|
||||
while (isdigit(*curPtr))
|
||||
++curPtr;
|
||||
|
||||
return formToken(Token::integer, tokStart);
|
||||
}
|
||||
@ -352,54 +357,54 @@ Token Lexer::lexString(const char *tokStart, bool isStringBlock) {
|
||||
}
|
||||
|
||||
switch (*curPtr++) {
|
||||
case '"':
|
||||
// If this is a string block, we only end the string when we encounter a
|
||||
// `}]`.
|
||||
if (!isStringBlock)
|
||||
return formToken(Token::string, tokStart);
|
||||
case '"':
|
||||
// If this is a string block, we only end the string when we encounter a
|
||||
// `}]`.
|
||||
if (!isStringBlock)
|
||||
return formToken(Token::string, tokStart);
|
||||
continue;
|
||||
case '}':
|
||||
// If this is a string block, we only end the string when we encounter a
|
||||
// `}]`.
|
||||
if (!isStringBlock || *curPtr != ']')
|
||||
continue;
|
||||
case '}':
|
||||
// If this is a string block, we only end the string when we encounter a
|
||||
// `}]`.
|
||||
if (!isStringBlock || *curPtr != ']')
|
||||
continue;
|
||||
++curPtr;
|
||||
return formToken(Token::string_block, tokStart);
|
||||
case 0: {
|
||||
// If this is a random nul character in the middle of a string, just
|
||||
// include it. If it is the end of file, then it is an error.
|
||||
if (curPtr - 1 != curBuffer.end())
|
||||
continue;
|
||||
--curPtr;
|
||||
|
||||
StringRef expectedEndStr = isStringBlock ? "}]" : "\"";
|
||||
return emitError(curPtr - 1,
|
||||
"expected '" + expectedEndStr + "' in string literal");
|
||||
}
|
||||
|
||||
case '\n':
|
||||
case '\v':
|
||||
case '\f':
|
||||
// String blocks allow multiple lines.
|
||||
if (!isStringBlock)
|
||||
return emitError(curPtr - 1, "expected '\"' in string literal");
|
||||
continue;
|
||||
|
||||
case '\\':
|
||||
// Handle explicitly a few escapes.
|
||||
if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' ||
|
||||
*curPtr == 't') {
|
||||
++curPtr;
|
||||
return formToken(Token::string_block, tokStart);
|
||||
case 0: {
|
||||
// If this is a random nul character in the middle of a string, just
|
||||
// include it. If it is the end of file, then it is an error.
|
||||
if (curPtr - 1 != curBuffer.end())
|
||||
continue;
|
||||
--curPtr;
|
||||
|
||||
StringRef expectedEndStr = isStringBlock ? "}]" : "\"";
|
||||
return emitError(curPtr - 1,
|
||||
"expected '" + expectedEndStr + "' in string literal");
|
||||
} else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) {
|
||||
// Support \xx for two hex digits.
|
||||
curPtr += 2;
|
||||
} else {
|
||||
return emitError(curPtr - 1, "unknown escape in string literal");
|
||||
}
|
||||
continue;
|
||||
|
||||
case '\n':
|
||||
case '\v':
|
||||
case '\f':
|
||||
// String blocks allow multiple lines.
|
||||
if (!isStringBlock)
|
||||
return emitError(curPtr - 1, "expected '\"' in string literal");
|
||||
continue;
|
||||
|
||||
case '\\':
|
||||
// Handle explicitly a few escapes.
|
||||
if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' ||
|
||||
*curPtr == 't') {
|
||||
++curPtr;
|
||||
} else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1])) {
|
||||
// Support \xx for two hex digits.
|
||||
curPtr += 2;
|
||||
} else {
|
||||
return emitError(curPtr - 1, "unknown escape in string literal");
|
||||
}
|
||||
continue;
|
||||
|
||||
default:
|
||||
continue;
|
||||
default:
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -133,7 +133,8 @@ public:
|
||||
|
||||
/// Return if the token does not have the given kind.
|
||||
bool isNot(Kind k) const { return k != kind; }
|
||||
template <typename... T> bool isNot(Kind k1, Kind k2, T... others) const {
|
||||
template <typename... T>
|
||||
bool isNot(Kind k1, Kind k2, T... others) const {
|
||||
return !isAny(k1, k2, others...);
|
||||
}
|
||||
|
||||
@ -141,17 +142,13 @@ public:
|
||||
bool is(Kind k) const { return kind == k; }
|
||||
|
||||
/// Return a location for the start of this token.
|
||||
SMLoc getStartLoc() const {
|
||||
return SMLoc::getFromPointer(spelling.data());
|
||||
}
|
||||
SMLoc getStartLoc() const { return SMLoc::getFromPointer(spelling.data()); }
|
||||
/// Return a location at the end of this token.
|
||||
SMLoc getEndLoc() const {
|
||||
return SMLoc::getFromPointer(spelling.data() + spelling.size());
|
||||
}
|
||||
/// Return a location for the range of this token.
|
||||
SMRange getLoc() const {
|
||||
return SMRange(getStartLoc(), getEndLoc());
|
||||
}
|
||||
SMRange getLoc() const { return SMRange(getStartLoc(), getEndLoc()); }
|
||||
|
||||
private:
|
||||
/// Discriminator that indicates the kind of token this is.
|
||||
@ -193,8 +190,8 @@ public:
|
||||
/// Emit an error to the lexer with the given location and message.
|
||||
Token emitError(SMRange loc, const Twine &msg);
|
||||
Token emitError(const char *loc, const Twine &msg);
|
||||
Token emitErrorAndNote(SMRange loc, const Twine &msg,
|
||||
SMRange noteLoc, const Twine ¬e);
|
||||
Token emitErrorAndNote(SMRange loc, const Twine &msg, SMRange noteLoc,
|
||||
const Twine ¬e);
|
||||
|
||||
private:
|
||||
Token formToken(Token::Kind kind, const char *tokStart) {
|
||||
|
@ -314,8 +314,7 @@ private:
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass>
|
||||
mlir::createPrintOpGraphPass(raw_ostream &os) {
|
||||
std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
|
||||
return std::make_unique<PrintOpPass>(os);
|
||||
}
|
||||
|
||||
|
@ -2147,7 +2147,8 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
|
||||
|
||||
/// Function to find an element within the given range that has the same name as
|
||||
/// 'name'.
|
||||
template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
|
||||
template <typename RangeT>
|
||||
static auto findArg(RangeT &&range, StringRef name) {
|
||||
auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
|
||||
return it != range.end() ? &*it : nullptr;
|
||||
}
|
||||
|
@ -70,9 +70,7 @@ TEST(DebugActionTest, DebugCounterHandler) {
|
||||
|
||||
// Handler that uses the number of action executions as the decider.
|
||||
struct DebugCounterHandler : public SimpleAction::Handler {
|
||||
FailureOr<bool> shouldExecute() final {
|
||||
return numExecutions++ < 3;
|
||||
}
|
||||
FailureOr<bool> shouldExecute() final { return numExecutions++ < 3; }
|
||||
unsigned numExecutions = 0;
|
||||
};
|
||||
manager.registerActionHandler<DebugCounterHandler>();
|
||||
|
@ -77,8 +77,8 @@ protected:
|
||||
verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
|
||||
|
||||
// Test collective params build method.
|
||||
op =
|
||||
builder.create<OpTy>(loc, TypeRange{i32Ty}, ValueRange{*cstI32, *cstI32});
|
||||
op = builder.create<OpTy>(loc, TypeRange{i32Ty},
|
||||
ValueRange{*cstI32, *cstI32});
|
||||
verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
|
||||
|
||||
// Test build method with no result types, default value of attributes.
|
||||
|
Loading…
x
Reference in New Issue
Block a user