mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-19 07:46:49 +00:00
[mlir][ODS] Verify type constraints in Types and Attributes (#102326)
When a type/attribute is defined in TableGen, a type constraint can be used for parameters, but the type constraint verification was missing. Example: ``` def TestTypeVerification : Test_Type<"TestTypeVerification"> { let parameters = (ins AnyTypeOf<[I16, I32]>:$param); // ... } ``` No verification code was generated to ensure that `$param` is I16 or I32. When type constraints a present, a new method will generated for types and attributes: `verifyInvariantsImpl`. (The naming is similar to op verifiers.) The user-provided verifier is called `verify` (no change). There is now a new entry point to type/attribute verification: `verifyInvariants`. This function calls both `verifyInvariantsImpl` and `verify`. If neither of those two verifications are present, the `verifyInvariants` function is not generated. When a type/attribute is not defined in TableGen, but a verifier is needed, users can implement the `verifyInvariants` function. (This function was previously called `verify`.) Note for LLVM integration: If you have an attribute/type that is not defined in TableGen (i.e., just C++), you have to rename the verification function from `verify` to `verifyInvariants`. (Most attributes/types have no verification, in which case there is nothing to do.) Depends on #102657.
This commit is contained in:
parent
74e4694b8c
commit
7359a6b799
@ -148,9 +148,10 @@ public:
|
||||
|
||||
/// Verify that shape and elementType are actually allowed for the
|
||||
/// MMAMatrixType.
|
||||
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
ArrayRef<int64_t> shape, Type elementType,
|
||||
StringRef operand);
|
||||
static LogicalResult
|
||||
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
|
||||
ArrayRef<int64_t> shape, Type elementType,
|
||||
StringRef operand);
|
||||
|
||||
/// Get number of dims.
|
||||
unsigned getNumDims() const;
|
||||
|
@ -180,11 +180,13 @@ public:
|
||||
ArrayRef<Type> getBody() const;
|
||||
|
||||
/// Verifies that the type about to be constructed is well-formed.
|
||||
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
StringRef, bool);
|
||||
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
ArrayRef<Type> types, bool);
|
||||
using Base::verify;
|
||||
static LogicalResult
|
||||
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, StringRef,
|
||||
bool);
|
||||
static LogicalResult
|
||||
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
|
||||
ArrayRef<Type> types, bool);
|
||||
using Base::verifyInvariants;
|
||||
|
||||
/// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a
|
||||
/// DataLayout instance and query it instead.
|
||||
|
@ -54,10 +54,10 @@ public:
|
||||
/// The maximum number of bits supported for storage types.
|
||||
static constexpr unsigned MaxStorageBits = 32;
|
||||
|
||||
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
unsigned flags, Type storageType,
|
||||
Type expressedType, int64_t storageTypeMin,
|
||||
int64_t storageTypeMax);
|
||||
static LogicalResult
|
||||
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
|
||||
Type storageType, Type expressedType, int64_t storageTypeMin,
|
||||
int64_t storageTypeMax);
|
||||
|
||||
/// Support method to enable LLVM-style type casting.
|
||||
static bool classof(Type type);
|
||||
@ -214,10 +214,10 @@ public:
|
||||
int64_t storageTypeMax);
|
||||
|
||||
/// Verifies construction invariants and issues errors/warnings.
|
||||
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
unsigned flags, Type storageType,
|
||||
Type expressedType, int64_t storageTypeMin,
|
||||
int64_t storageTypeMax);
|
||||
static LogicalResult
|
||||
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
|
||||
Type storageType, Type expressedType, int64_t storageTypeMin,
|
||||
int64_t storageTypeMax);
|
||||
};
|
||||
|
||||
/// Represents a family of uniform, quantized types.
|
||||
@ -276,11 +276,11 @@ public:
|
||||
int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
|
||||
|
||||
/// Verifies construction invariants and issues errors/warnings.
|
||||
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
unsigned flags, Type storageType,
|
||||
Type expressedType, double scale,
|
||||
int64_t zeroPoint, int64_t storageTypeMin,
|
||||
int64_t storageTypeMax);
|
||||
static LogicalResult
|
||||
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
|
||||
Type storageType, Type expressedType, double scale,
|
||||
int64_t zeroPoint, int64_t storageTypeMin,
|
||||
int64_t storageTypeMax);
|
||||
|
||||
/// Gets the scale term. The scale designates the difference between the real
|
||||
/// values corresponding to consecutive quantized values differing by 1.
|
||||
@ -338,12 +338,12 @@ public:
|
||||
int64_t storageTypeMin, int64_t storageTypeMax);
|
||||
|
||||
/// Verifies construction invariants and issues errors/warnings.
|
||||
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
unsigned flags, Type storageType,
|
||||
Type expressedType, ArrayRef<double> scales,
|
||||
ArrayRef<int64_t> zeroPoints,
|
||||
int32_t quantizedDimension,
|
||||
int64_t storageTypeMin, int64_t storageTypeMax);
|
||||
static LogicalResult
|
||||
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
|
||||
Type storageType, Type expressedType,
|
||||
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
|
||||
int32_t quantizedDimension, int64_t storageTypeMin,
|
||||
int64_t storageTypeMax);
|
||||
|
||||
/// Gets the quantization scales. The scales designate the difference between
|
||||
/// the real values corresponding to consecutive quantized values differing
|
||||
@ -403,8 +403,9 @@ public:
|
||||
double min, double max);
|
||||
|
||||
/// Verifies construction invariants and issues errors/warnings.
|
||||
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
Type expressedType, double min, double max);
|
||||
static LogicalResult
|
||||
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
|
||||
Type expressedType, double min, double max);
|
||||
double getMin() const;
|
||||
double getMax() const;
|
||||
};
|
||||
|
@ -76,9 +76,10 @@ public:
|
||||
/// Returns `spirv::StorageClass`.
|
||||
std::optional<StorageClass> getStorageClass();
|
||||
|
||||
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
IntegerAttr descriptorSet, IntegerAttr binding,
|
||||
IntegerAttr storageClass);
|
||||
static LogicalResult
|
||||
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
|
||||
IntegerAttr descriptorSet, IntegerAttr binding,
|
||||
IntegerAttr storageClass);
|
||||
|
||||
static constexpr StringLiteral name = "spirv.interface_var_abi";
|
||||
};
|
||||
@ -128,9 +129,10 @@ public:
|
||||
/// Returns the capabilities as an integer array attribute.
|
||||
ArrayAttr getCapabilitiesAttr();
|
||||
|
||||
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
IntegerAttr version, ArrayAttr capabilities,
|
||||
ArrayAttr extensions);
|
||||
static LogicalResult
|
||||
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
|
||||
IntegerAttr version, ArrayAttr capabilities,
|
||||
ArrayAttr extensions);
|
||||
|
||||
static constexpr StringLiteral name = "spirv.ver_cap_ext";
|
||||
};
|
||||
|
@ -258,8 +258,9 @@ public:
|
||||
static SampledImageType
|
||||
getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);
|
||||
|
||||
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
Type imageType);
|
||||
static LogicalResult
|
||||
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
|
||||
Type imageType);
|
||||
|
||||
Type getImageType() const;
|
||||
|
||||
@ -462,8 +463,9 @@ public:
|
||||
static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
|
||||
Type columnType, uint32_t columnCount);
|
||||
|
||||
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
Type columnType, uint32_t columnCount);
|
||||
static LogicalResult
|
||||
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
|
||||
Type columnType, uint32_t columnCount);
|
||||
|
||||
/// Returns true if the matrix elements are vectors of float elements.
|
||||
static bool isValidColumnType(Type columnType);
|
||||
|
@ -176,8 +176,8 @@ public:
|
||||
template <typename... Args>
|
||||
static ConcreteT get(MLIRContext *ctx, Args &&...args) {
|
||||
// Ensure that the invariants are correct for construction.
|
||||
assert(
|
||||
succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
|
||||
assert(succeeded(
|
||||
ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...)));
|
||||
return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
@ -198,7 +198,7 @@ public:
|
||||
static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
|
||||
MLIRContext *ctx, Args... args) {
|
||||
// If the construction invariants fail then we return a null attribute.
|
||||
if (failed(ConcreteT::verify(emitErrorFn, args...)))
|
||||
if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...)))
|
||||
return ConcreteT();
|
||||
return UniquerT::template get<ConcreteT>(ctx, args...);
|
||||
}
|
||||
@ -226,7 +226,9 @@ protected:
|
||||
|
||||
/// Default implementation that just returns success.
|
||||
template <typename... Args>
|
||||
static LogicalResult verify(Args... args) {
|
||||
static LogicalResult
|
||||
verifyInvariants(function_ref<InFlightDiagnostic()> emitErrorFn,
|
||||
Args... args) {
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -34,7 +34,7 @@ class AsmState;
|
||||
/// Derived type classes are expected to implement several required
|
||||
/// implementation hooks:
|
||||
/// * Optional:
|
||||
/// - static LogicalResult verify(
|
||||
/// - static LogicalResult verifyInvariants(
|
||||
/// function_ref<InFlightDiagnostic()> emitError,
|
||||
/// Args... args)
|
||||
/// * This method is invoked when calling the 'TypeBase::get/getChecked'
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/TableGen/Builder.h"
|
||||
#include "mlir/TableGen/Constraint.h"
|
||||
#include "mlir/TableGen/Trait.h"
|
||||
|
||||
namespace llvm {
|
||||
@ -85,6 +86,9 @@ public:
|
||||
/// Get an optional C++ parameter parser.
|
||||
std::optional<StringRef> getParser() const;
|
||||
|
||||
/// If this is a type constraint, return it.
|
||||
std::optional<Constraint> getConstraint() const;
|
||||
|
||||
/// Get an optional C++ parameter printer.
|
||||
std::optional<StringRef> getPrinter() const;
|
||||
|
||||
@ -198,6 +202,10 @@ public:
|
||||
/// method.
|
||||
bool genVerifyDecl() const;
|
||||
|
||||
/// Return true if we need to generate any type constraint verification and
|
||||
/// the getChecked method.
|
||||
bool genVerifyInvariantsImpl() const;
|
||||
|
||||
/// Returns the def's extra class declaration code.
|
||||
std::optional<StringRef> getExtraDecls() const;
|
||||
|
||||
|
@ -67,6 +67,8 @@ public:
|
||||
|
||||
/// Get the C++ type.
|
||||
StringRef getType() const { return type; }
|
||||
/// Get the C++ parameter name.
|
||||
StringRef getName() const { return name; }
|
||||
/// Returns true if the parameter has a default value.
|
||||
bool hasDefaultValue() const { return !defaultValue.empty(); }
|
||||
|
||||
|
@ -148,9 +148,9 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
ArrayRef<int64_t> shape, Type elementType,
|
||||
StringRef operand) {
|
||||
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
|
||||
ArrayRef<int64_t> shape, Type elementType,
|
||||
StringRef operand) {
|
||||
if (operand != "AOp" && operand != "BOp" && operand != "COp")
|
||||
return emitError() << "operand expected to be one of AOp, BOp or COp";
|
||||
|
||||
|
@ -418,8 +418,7 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
|
||||
|
||||
bool LLVMStructType::isValidElementType(Type type) {
|
||||
return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
|
||||
LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
|
||||
type);
|
||||
LLVMFunctionType, LLVMTokenType>(type);
|
||||
}
|
||||
|
||||
LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
|
||||
@ -492,14 +491,15 @@ ArrayRef<Type> LLVMStructType::getBody() const {
|
||||
: getImpl()->getTypeList();
|
||||
}
|
||||
|
||||
LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
|
||||
StringRef, bool) {
|
||||
LogicalResult
|
||||
LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
|
||||
bool) {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
ArrayRef<Type> types, bool) {
|
||||
LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
|
||||
ArrayRef<Type> types, bool) {
|
||||
for (Type t : types)
|
||||
if (!isValidElementType(t))
|
||||
return emitError() << "invalid LLVM structure element type: " << t;
|
||||
|
@ -29,9 +29,10 @@ bool QuantizedType::classof(Type type) {
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
unsigned flags, Type storageType, Type expressedType,
|
||||
int64_t storageTypeMin, int64_t storageTypeMax) {
|
||||
QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
|
||||
unsigned flags, Type storageType,
|
||||
Type expressedType, int64_t storageTypeMin,
|
||||
int64_t storageTypeMax) {
|
||||
// Verify that the storage type is integral.
|
||||
// This restriction may be lifted at some point in favor of using bf16
|
||||
// or f16 as exact representations on hardware where that is advantageous.
|
||||
@ -233,11 +234,13 @@ AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
unsigned flags, Type storageType, Type expressedType,
|
||||
int64_t storageTypeMin, int64_t storageTypeMax) {
|
||||
if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
|
||||
storageTypeMin, storageTypeMax))) {
|
||||
AnyQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
|
||||
unsigned flags, Type storageType,
|
||||
Type expressedType, int64_t storageTypeMin,
|
||||
int64_t storageTypeMax) {
|
||||
if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
|
||||
expressedType, storageTypeMin,
|
||||
storageTypeMax))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
@ -268,12 +271,13 @@ UniformQuantizedType UniformQuantizedType::getChecked(
|
||||
storageTypeMin, storageTypeMax);
|
||||
}
|
||||
|
||||
LogicalResult UniformQuantizedType::verify(
|
||||
LogicalResult UniformQuantizedType::verifyInvariants(
|
||||
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
|
||||
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
|
||||
int64_t storageTypeMin, int64_t storageTypeMax) {
|
||||
if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
|
||||
storageTypeMin, storageTypeMax))) {
|
||||
if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
|
||||
expressedType, storageTypeMin,
|
||||
storageTypeMax))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
@ -321,13 +325,14 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
|
||||
quantizedDimension, storageTypeMin, storageTypeMax);
|
||||
}
|
||||
|
||||
LogicalResult UniformQuantizedPerAxisType::verify(
|
||||
LogicalResult UniformQuantizedPerAxisType::verifyInvariants(
|
||||
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
|
||||
Type storageType, Type expressedType, ArrayRef<double> scales,
|
||||
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
|
||||
int64_t storageTypeMin, int64_t storageTypeMax) {
|
||||
if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
|
||||
storageTypeMin, storageTypeMax))) {
|
||||
if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
|
||||
expressedType, storageTypeMin,
|
||||
storageTypeMax))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
@ -380,9 +385,9 @@ CalibratedQuantizedType CalibratedQuantizedType::getChecked(
|
||||
min, max);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
CalibratedQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
Type expressedType, double min, double max) {
|
||||
LogicalResult CalibratedQuantizedType::verifyInvariants(
|
||||
function_ref<InFlightDiagnostic()> emitError, Type expressedType,
|
||||
double min, double max) {
|
||||
// Verify that the expressed type is floating point.
|
||||
// If this restriction is ever eliminated, the parser/printer must be
|
||||
// extended.
|
||||
|
@ -162,7 +162,7 @@ spirv::InterfaceVarABIAttr::getStorageClass() {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
LogicalResult spirv::InterfaceVarABIAttr::verify(
|
||||
LogicalResult spirv::InterfaceVarABIAttr::verifyInvariants(
|
||||
function_ref<InFlightDiagnostic()> emitError, IntegerAttr descriptorSet,
|
||||
IntegerAttr binding, IntegerAttr storageClass) {
|
||||
if (!descriptorSet.getType().isSignlessInteger(32))
|
||||
@ -257,10 +257,9 @@ ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
|
||||
return llvm::cast<ArrayAttr>(getImpl()->capabilities);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
IntegerAttr version, ArrayAttr capabilities,
|
||||
ArrayAttr extensions) {
|
||||
LogicalResult spirv::VerCapExtAttr::verifyInvariants(
|
||||
function_ref<InFlightDiagnostic()> emitError, IntegerAttr version,
|
||||
ArrayAttr capabilities, ArrayAttr extensions) {
|
||||
if (!version.getType().isSignlessInteger(32))
|
||||
return emitError() << "expected 32-bit integer for version";
|
||||
|
||||
|
@ -817,8 +817,8 @@ SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError,
|
||||
Type SampledImageType::getImageType() const { return getImpl()->imageType; }
|
||||
|
||||
LogicalResult
|
||||
SampledImageType::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
Type imageType) {
|
||||
SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
|
||||
Type imageType) {
|
||||
if (!llvm::isa<ImageType>(imageType))
|
||||
return emitError() << "expected image type";
|
||||
|
||||
@ -1181,8 +1181,9 @@ MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
|
||||
columnCount);
|
||||
}
|
||||
|
||||
LogicalResult MatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
Type columnType, uint32_t columnCount) {
|
||||
LogicalResult
|
||||
MatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
|
||||
Type columnType, uint32_t columnCount) {
|
||||
if (columnCount < 2 || columnCount > 4)
|
||||
return emitError() << "matrix can have 2, 3, or 4 columns only";
|
||||
|
||||
|
@ -1227,6 +1227,14 @@ StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
|
||||
return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
|
||||
}
|
||||
|
||||
StorageSpecifierType
|
||||
StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
|
||||
MLIRContext *ctx,
|
||||
SparseTensorEncodingAttr encoding) {
|
||||
return Base::getChecked(emitError, ctx,
|
||||
getNormalizedEncodingForSpecifier(encoding));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SparseTensorDialect Operations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -184,6 +184,12 @@ bool AttrOrTypeDef::genVerifyDecl() const {
|
||||
return def->getValueAsBit("genVerifyDecl");
|
||||
}
|
||||
|
||||
bool AttrOrTypeDef::genVerifyInvariantsImpl() const {
|
||||
return any_of(parameters, [](const AttrOrTypeParameter &p) {
|
||||
return p.getConstraint() != std::nullopt;
|
||||
});
|
||||
}
|
||||
|
||||
std::optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
|
||||
auto value = def->getValueAsString("extraClassDeclaration");
|
||||
return value.empty() ? std::optional<StringRef>() : value;
|
||||
@ -331,6 +337,13 @@ std::optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
|
||||
|
||||
llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
|
||||
|
||||
std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
|
||||
if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
|
||||
if (param->getDef()->isSubClassOf("Constraint"))
|
||||
return Constraint(param->getDef());
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AttributeSelfTypeParameter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
9
mlir/test/IR/test-verifiers-type.mlir
Normal file
9
mlir/test/IR/test-verifiers-type.mlir
Normal file
@ -0,0 +1,9 @@
|
||||
// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK: "test.type_producer"() : () -> !test.type_verification<i16>
|
||||
"test.type_producer"() : () -> !test.type_verification<i16>
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @below{{failed to verify 'param': 16-bit signless integer or 32-bit signless integer}}
|
||||
"test.type_producer"() : () -> !test.type_verification<f16>
|
@ -270,7 +270,6 @@ def TestOverrideBuilderAttr : Test_Attr<"TestOverrideBuilder"> {
|
||||
let assemblyFormat = "`<` $a `>`";
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
let genVerifyDecl = 1;
|
||||
let builders = [AttrBuilder<(ins "int":$a), [{
|
||||
return ::mlir::IntegerAttr::get(::mlir::IndexType::get($_ctxt), a);
|
||||
}], "::mlir::Attribute">];
|
||||
|
@ -392,4 +392,10 @@ def TestRecursiveAlias
|
||||
}];
|
||||
}
|
||||
|
||||
def TestTypeVerification : Test_Type<"TestTypeVerification"> {
|
||||
let parameters = (ins AnyTypeOf<[I16, I32]>:$param);
|
||||
let mnemonic = "type_verification";
|
||||
let assemblyFormat = "`<` $param `>`";
|
||||
}
|
||||
|
||||
#endif // TEST_TYPEDEFS
|
||||
|
@ -1,8 +1,9 @@
|
||||
// RUN: sed 's/DEFAULT_TYPE_PARSER/0/' %s | mlir-tblgen -gen-attrdef-defs -I %S/../../include | FileCheck %s --check-prefix=ATTR
|
||||
// RUN: sed 's/DEFAULT_TYPE_PARSER/0/' %s | mlir-tblgen -gen-typedef-defs -I %S/../../include | FileCheck %s --check-prefix=TYPE
|
||||
// RUN: sed 's/DEFAULT_TYPE_PARSER/1/' %s | mlir-tblgen -gen-typedef-defs -I %S/../../include | FileCheck %s --check-prefix=TYPE --check-prefix=DEFAULT_TYPE_PARSER
|
||||
// RUN: sed 's/DEFAULT_TYPE_PARSER/0/' %s | mlir-tblgen -gen-attrdef-defs -attrdefs-dialect=TestDialect -I %S/../../include | FileCheck %s --check-prefix=ATTR
|
||||
// RUN: sed 's/DEFAULT_TYPE_PARSER/0/' %s | mlir-tblgen -gen-typedef-defs -typedefs-dialect=TestDialect -I %S/../../include | FileCheck %s --check-prefix=TYPE
|
||||
// RUN: sed 's/DEFAULT_TYPE_PARSER/1/' %s | mlir-tblgen -gen-typedef-defs -typedefs-dialect=TestDialect -I %S/../../include | FileCheck %s --check-prefix=TYPE --check-prefix=DEFAULT_TYPE_PARSER
|
||||
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "mlir/IR/BuiltinAttributes.td"
|
||||
include "mlir/IR/EnumAttr.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
@ -663,6 +664,78 @@ def TypeO : TestType<"TestQ"> {
|
||||
let assemblyFormat = "(custom<AB>($a)^ `x`) : (`y`)?";
|
||||
}
|
||||
|
||||
// Test attr / type verification.
|
||||
|
||||
// TYPE: ::llvm::LogicalResult TestPType::verifyInvariantsImpl(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::Type a) {
|
||||
// TYPE: if (!(((a.isSignlessInteger(16))) || ((a.isSignlessInteger(32))))) {
|
||||
// TYPE: emitError() << "failed to verify 'a': 16-bit signless integer or 32-bit signless integer";
|
||||
// TYPE: return ::mlir::failure();
|
||||
// TYPE: }
|
||||
// TYPE: return ::mlir::success();
|
||||
// TYPE: }
|
||||
|
||||
// TYPE: ::llvm::LogicalResult TestPType::verifyInvariants(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::Type a) {
|
||||
// TYPE: if (::mlir::failed(verifyInvariantsImpl(emitError, a)))
|
||||
// TYPE: return ::mlir::failure();
|
||||
// TYPE: if (::mlir::failed(verify(emitError, a)))
|
||||
// TYPE: return ::mlir::failure();
|
||||
// TYPE: return ::mlir::success();
|
||||
// TYPE: }
|
||||
|
||||
def TypeP : TestType<"TestP"> {
|
||||
let parameters = (ins AnyTypeOf<[I16, I32]>:$a);
|
||||
let mnemonic = "type_p";
|
||||
let genVerifyDecl = 1;
|
||||
let assemblyFormat = "$a";
|
||||
}
|
||||
|
||||
// ATTR: ::llvm::LogicalResult TestRAttr::verifyInvariantsImpl(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::IntegerType a) {
|
||||
// ATTR: if (!((a.isSignlessInteger(32)))) {
|
||||
// ATTR: emitError() << "failed to verify 'a': 32-bit signless integer";
|
||||
// ATTR: return ::mlir::failure();
|
||||
// ATTR: }
|
||||
// ATTR: return ::mlir::success();
|
||||
// ATTR: }
|
||||
|
||||
// ATTR: ::llvm::LogicalResult TestRAttr::verifyInvariants(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::IntegerType a) {
|
||||
// ATTR: if (::mlir::failed(verifyInvariantsImpl(emitError, a)))
|
||||
// ATTR: return ::mlir::failure();
|
||||
// ATTR: if (::mlir::failed(verify(emitError, a)))
|
||||
// ATTR: return ::mlir::failure();
|
||||
// ATTR: return ::mlir::success();
|
||||
// ATTR: }
|
||||
|
||||
def AttrR : TestAttr<"TestR"> {
|
||||
let parameters = (ins I32:$a);
|
||||
let mnemonic = "attr_r";
|
||||
let genVerifyDecl = 1;
|
||||
let assemblyFormat = "$a";
|
||||
}
|
||||
|
||||
// TYPE: ::llvm::LogicalResult TestSType::verifyInvariantsImpl(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::ArrayAttr a) {
|
||||
// TYPE: if (!((::llvm::isa<::mlir::ArrayAttr>(a)))) {
|
||||
// TYPE: emitError() << "failed to verify 'a': A collection of other Attribute values";
|
||||
// TYPE: return ::mlir::failure();
|
||||
// TYPE: }
|
||||
// TYPE: return ::mlir::success();
|
||||
// TYPE: }
|
||||
|
||||
// TYPE: ::llvm::LogicalResult TestSType::verifyInvariants(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::ArrayAttr a) {
|
||||
// TYPE: if (::mlir::failed(verifyInvariantsImpl(emitError, a)))
|
||||
// TYPE: return ::mlir::failure();
|
||||
// TYPE: if (::mlir::failed(verify(emitError, a)))
|
||||
// TYPE: return ::mlir::failure();
|
||||
// TYPE: return ::mlir::success();
|
||||
// TYPE: }
|
||||
|
||||
def TypeS : TestType<"TestS"> {
|
||||
// TODO: Support attribute constraints as parameters.
|
||||
let parameters = (ins Builtin_ArrayAttr:$a);
|
||||
let mnemonic = "type_s";
|
||||
let genVerifyDecl = 1;
|
||||
let assemblyFormat = "$a";
|
||||
}
|
||||
|
||||
// DEFAULT_TYPE_PARSER: TestDialect::parseType(::mlir::DialectAsmParser &parser)
|
||||
// DEFAULT_TYPE_PARSER: auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
|
||||
// DEFAULT_TYPE_PARSER: if (parseResult.has_value()) {
|
||||
|
@ -93,8 +93,14 @@ private:
|
||||
void emitDialectName();
|
||||
/// Emit attribute or type builders.
|
||||
void emitBuilders();
|
||||
/// Emit a verifier for the def.
|
||||
void emitVerifier();
|
||||
/// Emit a verifier declaration for custom verification (impl. provided by
|
||||
/// the users).
|
||||
void emitVerifierDecl();
|
||||
/// Emit a verifier that checks type constraints.
|
||||
void emitInvariantsVerifierImpl();
|
||||
/// Emit an entry poiunt for verification that calls the invariants and
|
||||
/// custom verifier.
|
||||
void emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier);
|
||||
/// Emit parsers and printers.
|
||||
void emitParserPrinter();
|
||||
/// Emit parameter accessors, if required.
|
||||
@ -188,9 +194,17 @@ DefGen::DefGen(const AttrOrTypeDef &def)
|
||||
emitName();
|
||||
// Emit the dialect name.
|
||||
emitDialectName();
|
||||
// Emit the verifier.
|
||||
if (storageCls && def.genVerifyDecl())
|
||||
emitVerifier();
|
||||
// Emit verification of type constraints.
|
||||
bool genVerifyInvariantsImpl = def.genVerifyInvariantsImpl();
|
||||
if (storageCls && genVerifyInvariantsImpl)
|
||||
emitInvariantsVerifierImpl();
|
||||
// Emit the custom verifier (written by the user).
|
||||
bool genVerifyDecl = def.genVerifyDecl();
|
||||
if (storageCls && genVerifyDecl)
|
||||
emitVerifierDecl();
|
||||
// Emit the "verifyInvariants" function if there is any verification at all.
|
||||
if (storageCls)
|
||||
emitInvariantsVerifier(genVerifyInvariantsImpl, genVerifyDecl);
|
||||
// Emit the mnemonic, if there is one, and any associated parser and printer.
|
||||
if (def.getMnemonic())
|
||||
emitParserPrinter();
|
||||
@ -295,24 +309,88 @@ void DefGen::emitDialectName() {
|
||||
void DefGen::emitBuilders() {
|
||||
if (!def.skipDefaultBuilders()) {
|
||||
emitDefaultBuilder();
|
||||
if (def.genVerifyDecl())
|
||||
if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
|
||||
emitCheckedBuilder();
|
||||
}
|
||||
for (auto &builder : def.getBuilders()) {
|
||||
emitCustomBuilder(builder);
|
||||
if (def.genVerifyDecl())
|
||||
if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
|
||||
emitCheckedCustomBuilder(builder);
|
||||
}
|
||||
}
|
||||
|
||||
void DefGen::emitVerifier() {
|
||||
defCls.declare<UsingDeclaration>("Base::getChecked");
|
||||
void DefGen::emitVerifierDecl() {
|
||||
defCls.declareStaticMethod(
|
||||
"::llvm::LogicalResult", "verify",
|
||||
getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
|
||||
"emitError"}}));
|
||||
}
|
||||
|
||||
static const char *const patternParameterVerificationCode = R"(
|
||||
if (!({0})) {
|
||||
emitError() << "failed to verify '{1}': {2}";
|
||||
return ::mlir::failure();
|
||||
}
|
||||
)";
|
||||
|
||||
void DefGen::emitInvariantsVerifierImpl() {
|
||||
SmallVector<MethodParameter> builderParams = getBuilderParams(
|
||||
{{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
|
||||
Method *verifier =
|
||||
defCls.addMethod("::llvm::LogicalResult", "verifyInvariantsImpl",
|
||||
Method::Static, builderParams);
|
||||
verifier->body().indent();
|
||||
|
||||
// Generate verification for each parameter that is a type constraint.
|
||||
for (auto it : llvm::enumerate(def.getParameters())) {
|
||||
const AttrOrTypeParameter ¶m = it.value();
|
||||
std::optional<Constraint> constraint = param.getConstraint();
|
||||
// No verification needed for parameters that are not type constraints.
|
||||
if (!constraint.has_value())
|
||||
continue;
|
||||
FmtContext ctx;
|
||||
// Note: Skip over the first method parameter (`emitError`).
|
||||
ctx.withSelf(builderParams[it.index() + 1].getName());
|
||||
std::string condition = tgfmt(constraint->getConditionTemplate(), &ctx);
|
||||
verifier->body() << formatv(patternParameterVerificationCode, condition,
|
||||
param.getName(), constraint->getSummary())
|
||||
<< "\n";
|
||||
}
|
||||
verifier->body() << "return ::mlir::success();";
|
||||
}
|
||||
|
||||
void DefGen::emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier) {
|
||||
if (!hasImpl && !hasCustomVerifier)
|
||||
return;
|
||||
defCls.declare<UsingDeclaration>("Base::getChecked");
|
||||
SmallVector<MethodParameter> builderParams = getBuilderParams(
|
||||
{{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
|
||||
Method *verifier =
|
||||
defCls.addMethod("::llvm::LogicalResult", "verifyInvariants",
|
||||
Method::Static, builderParams);
|
||||
verifier->body().indent();
|
||||
|
||||
auto emitVerifierCall = [&](StringRef name) {
|
||||
verifier->body() << strfmt("if (::mlir::failed({0}(", name);
|
||||
llvm::interleaveComma(
|
||||
llvm::map_range(builderParams,
|
||||
[](auto ¶m) { return param.getName(); }),
|
||||
verifier->body());
|
||||
verifier->body() << ")))\n";
|
||||
verifier->body() << " return ::mlir::failure();\n";
|
||||
};
|
||||
|
||||
if (hasImpl) {
|
||||
// Call the verifier that checks the type constraints.
|
||||
emitVerifierCall("verifyInvariantsImpl");
|
||||
}
|
||||
if (hasCustomVerifier) {
|
||||
// Call the custom verifier that is provided by the user.
|
||||
emitVerifierCall("verify");
|
||||
}
|
||||
verifier->body() << "return ::mlir::success();";
|
||||
}
|
||||
|
||||
void DefGen::emitParserPrinter() {
|
||||
auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
|
||||
"::llvm::StringLiteral", "getMnemonic");
|
||||
|
@ -323,7 +323,7 @@ void DefFormat::genParser(MethodBody &os) {
|
||||
|
||||
// Generate call to the attribute or type builder. Use the checked getter
|
||||
// if one was generated.
|
||||
if (def.genVerifyDecl()) {
|
||||
if (def.genVerifyDecl() || def.genVerifyInvariantsImpl()) {
|
||||
os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
|
||||
&ctx, def.getCppClassName());
|
||||
} else {
|
||||
|
Loading…
x
Reference in New Issue
Block a user