[mlir] Move casting calls from methods to function calls

The MLIR classes Type/Attribute/Operation/Op/Value support
cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast
mechanism, in addition to defining methods with the same names.
This change begins migrating uses of those methods to the
corresponding free-function calls, which has been decided to be the
more consistent style.

Note that there still exist classes that only define methods directly,
such as AffineExpr, and this does not include work currently to support
a functional cast/isa call.

Context:
- https://mlir.llvm.org/deprecation/ at "Use the free function variants
  for dyn_cast/cast/isa/…"
- Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443

Implementation:
This patch updates all remaining uses of the deprecated functionality in
mlir/. This was done with clang-tidy as described below and further
modifications to GPUBase.td and OpenMPOpsInterfaces.td.

Steps are described per line, as comments are removed by git:
0. Retrieve the change from the following to build clang-tidy with an
   additional check:
   main...tpopp:llvm-project:tidy-cast-check
1. Build clang-tidy
2. Run clang-tidy over your entire codebase with all checks disabled
   except the one relevant check. Run it on all header files as well.
3. Delete .inc files that were also modified, so the next build rebuilds
   them to a pure state.

```
ninja -C $BUILD_DIR clang-tidy

run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\
               -header-filter=mlir/ mlir/* -fix

rm -rf $BUILD_DIR/tools/mlir/**/*.inc
```

Differential Revision: https://reviews.llvm.org/D151542
This commit is contained in:
Tres Popp 2023-05-26 10:17:47 +02:00
parent 7c52520c8d
commit 68f58812e3
117 changed files with 508 additions and 488 deletions

View File

@ -54,7 +54,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
// If the type is a function type, it contains the input and result types of
// this operation.
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
@ -133,13 +133,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
mlir::LogicalResult ConstantOp::verify() {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
@ -269,8 +269,8 @@ mlir::LogicalResult ReturnOp::verify() {
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
resultType.isa<mlir::UnrankedTensorType>())
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
@ -289,8 +289,8 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
}
mlir::LogicalResult TransposeOp::verify() {
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
auto resultType = getType().dyn_cast<RankedTensorType>();
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();

View File

@ -54,7 +54,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
// If the type is a function type, it contains the input and result types of
// this operation.
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
@ -133,13 +133,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
mlir::LogicalResult ConstantOp::verify() {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
@ -269,8 +269,8 @@ mlir::LogicalResult ReturnOp::verify() {
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
resultType.isa<mlir::UnrankedTensorType>())
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
@ -289,8 +289,8 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
}
mlir::LogicalResult TransposeOp::verify() {
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
auto resultType = getType().dyn_cast<RankedTensorType>();
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();

View File

@ -114,7 +114,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
// If the type is a function type, it contains the input and result types of
// this operation.
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
@ -193,13 +193,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
mlir::LogicalResult ConstantOp::verify() {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
@ -254,8 +254,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
// The inputs must be Tensors with the same element type.
TensorType input = inputs.front().dyn_cast<TensorType>();
TensorType output = outputs.front().dyn_cast<TensorType>();
TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
if (!input || !output || input.getElementType() != output.getElementType())
return false;
// The shape is required to match if both types are ranked.
@ -397,8 +397,8 @@ mlir::LogicalResult ReturnOp::verify() {
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
resultType.isa<mlir::UnrankedTensorType>())
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
@ -417,14 +417,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
}
void TransposeOp::inferShapes() {
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
mlir::LogicalResult TransposeOp::verify() {
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
auto resultType = getType().dyn_cast<RankedTensorType>();
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();

View File

@ -94,7 +94,7 @@ struct ShapeInferencePass
/// operands inferred.
static bool allOperandsInferred(Operation *op) {
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
return operandType.isa<RankedTensorType>();
return llvm::isa<RankedTensorType>(operandType);
});
}
@ -102,7 +102,7 @@ struct ShapeInferencePass
/// shaped result.
static bool returnsDynamicShape(Operation *op) {
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
return !resultType.isa<RankedTensorType>();
return !llvm::isa<RankedTensorType>(resultType);
});
}
};

View File

@ -114,7 +114,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
// If the type is a function type, it contains the input and result types of
// this operation.
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
@ -193,13 +193,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
mlir::LogicalResult ConstantOp::verify() {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
@ -254,8 +254,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
// The inputs must be Tensors with the same element type.
TensorType input = inputs.front().dyn_cast<TensorType>();
TensorType output = outputs.front().dyn_cast<TensorType>();
TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
if (!input || !output || input.getElementType() != output.getElementType())
return false;
// The shape is required to match if both types are ranked.
@ -397,8 +397,8 @@ mlir::LogicalResult ReturnOp::verify() {
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
resultType.isa<mlir::UnrankedTensorType>())
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
@ -417,14 +417,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
}
void TransposeOp::inferShapes() {
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
mlir::LogicalResult TransposeOp::verify() {
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
auto resultType = getType().dyn_cast<RankedTensorType>();
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();

View File

@ -62,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
static void lowerOpToLoops(Operation *op, ValueRange operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
@ -144,7 +144,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
auto tensorType = op.getType().cast<RankedTensorType>();
auto tensorType = llvm::cast<RankedTensorType>(op.getType());
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
@ -342,7 +342,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
target.addIllegalDialect<toy::ToyDialect>();
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
return llvm::none_of(op->getOperandTypes(),
[](Type type) { return type.isa<TensorType>(); });
[](Type type) { return llvm::isa<TensorType>(type); });
});
// Now that the conversion target has been defined, we just need to provide

View File

@ -94,7 +94,7 @@ struct ShapeInferencePass
/// operands inferred.
static bool allOperandsInferred(Operation *op) {
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
return operandType.isa<RankedTensorType>();
return llvm::isa<RankedTensorType>(operandType);
});
}
@ -102,7 +102,7 @@ struct ShapeInferencePass
/// shaped result.
static bool returnsDynamicShape(Operation *op) {
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
return !resultType.isa<RankedTensorType>();
return !llvm::isa<RankedTensorType>(resultType);
});
}
};

View File

@ -114,7 +114,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
// If the type is a function type, it contains the input and result types of
// this operation.
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
@ -193,13 +193,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
mlir::LogicalResult ConstantOp::verify() {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the constant
// result type.
auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
@ -254,8 +254,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
// The inputs must be Tensors with the same element type.
TensorType input = inputs.front().dyn_cast<TensorType>();
TensorType output = outputs.front().dyn_cast<TensorType>();
TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
if (!input || !output || input.getElementType() != output.getElementType())
return false;
// The shape is required to match if both types are ranked.
@ -397,8 +397,8 @@ mlir::LogicalResult ReturnOp::verify() {
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
resultType.isa<mlir::UnrankedTensorType>())
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
@ -417,14 +417,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
}
void TransposeOp::inferShapes() {
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
mlir::LogicalResult TransposeOp::verify() {
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
auto resultType = getType().dyn_cast<RankedTensorType>();
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();

View File

@ -62,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
static void lowerOpToLoops(Operation *op, ValueRange operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
@ -144,7 +144,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
auto tensorType = op.getType().cast<RankedTensorType>();
auto tensorType = llvm::cast<RankedTensorType>(op.getType());
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
@ -342,7 +342,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
target.addIllegalDialect<toy::ToyDialect>();
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
return llvm::none_of(op->getOperandTypes(),
[](Type type) { return type.isa<TensorType>(); });
[](Type type) { return llvm::isa<TensorType>(type); });
});
// Now that the conversion target has been defined, we just need to provide

View File

@ -61,7 +61,7 @@ public:
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
auto memRefShape = memRefType.getShape();
auto loc = op->getLoc();

View File

@ -94,7 +94,7 @@ struct ShapeInferencePass
/// operands inferred.
static bool allOperandsInferred(Operation *op) {
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
return operandType.isa<RankedTensorType>();
return llvm::isa<RankedTensorType>(operandType);
});
}
@ -102,7 +102,7 @@ struct ShapeInferencePass
/// shaped result.
static bool returnsDynamicShape(Operation *op) {
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
return !resultType.isa<RankedTensorType>();
return !llvm::isa<RankedTensorType>(resultType);
});
}
};

View File

@ -101,7 +101,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
// If the type is a function type, it contains the input and result types of
// this operation.
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
result.operands))
return mlir::failure();
@ -179,9 +179,9 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
static mlir::LogicalResult verifyConstantForType(mlir::Type type,
mlir::Attribute opaqueValue,
mlir::Operation *op) {
if (type.isa<mlir::TensorType>()) {
if (llvm::isa<mlir::TensorType>(type)) {
// Check that the value is an elements attribute.
auto attrValue = opaqueValue.dyn_cast<mlir::DenseFPElementsAttr>();
auto attrValue = llvm::dyn_cast<mlir::DenseFPElementsAttr>(opaqueValue);
if (!attrValue)
return op->emitError("constant of TensorType must be initialized by "
"a DenseFPElementsAttr, got ")
@ -189,13 +189,13 @@ static mlir::LogicalResult verifyConstantForType(mlir::Type type,
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
auto resultType = type.dyn_cast<mlir::RankedTensorType>();
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(type);
if (!resultType)
return success();
// Check that the rank of the attribute type matches the rank of the
// constant result type.
auto attrType = attrValue.getType().cast<mlir::RankedTensorType>();
auto attrType = llvm::cast<mlir::RankedTensorType>(attrValue.getType());
if (attrType.getRank() != resultType.getRank()) {
return op->emitOpError("return type must match the one of the attached "
"value attribute: ")
@ -213,11 +213,11 @@ static mlir::LogicalResult verifyConstantForType(mlir::Type type,
}
return mlir::success();
}
auto resultType = type.cast<StructType>();
auto resultType = llvm::cast<StructType>(type);
llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes();
// Verify that the initializer is an Array.
auto attrValue = opaqueValue.dyn_cast<ArrayAttr>();
auto attrValue = llvm::dyn_cast<ArrayAttr>(opaqueValue);
if (!attrValue || attrValue.getValue().size() != resultElementTypes.size())
return op->emitError("constant of StructType must be initialized by an "
"ArrayAttr with the same number of elements, got ")
@ -283,8 +283,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
// The inputs must be Tensors with the same element type.
TensorType input = inputs.front().dyn_cast<TensorType>();
TensorType output = outputs.front().dyn_cast<TensorType>();
TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
if (!input || !output || input.getElementType() != output.getElementType())
return false;
// The shape is required to match if both types are ranked.
@ -426,8 +426,8 @@ mlir::LogicalResult ReturnOp::verify() {
auto resultType = results.front();
// Check that the result type of the function matches the operand type.
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
resultType.isa<mlir::UnrankedTensorType>())
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
llvm::isa<mlir::UnrankedTensorType>(resultType))
return mlir::success();
return emitError() << "type of return operand (" << inputType
@ -442,7 +442,7 @@ mlir::LogicalResult ReturnOp::verify() {
void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
mlir::Value input, size_t index) {
// Extract the result type from the input type.
StructType structTy = input.getType().cast<StructType>();
StructType structTy = llvm::cast<StructType>(input.getType());
assert(index < structTy.getNumElementTypes());
mlir::Type resultType = structTy.getElementTypes()[index];
@ -451,7 +451,7 @@ void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
}
mlir::LogicalResult StructAccessOp::verify() {
StructType structTy = getInput().getType().cast<StructType>();
StructType structTy = llvm::cast<StructType>(getInput().getType());
size_t indexValue = getIndex();
if (indexValue >= structTy.getNumElementTypes())
return emitOpError()
@ -474,14 +474,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
}
void TransposeOp::inferShapes() {
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
mlir::LogicalResult TransposeOp::verify() {
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
auto resultType = getType().dyn_cast<RankedTensorType>();
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
if (!inputType || !resultType)
return mlir::success();
@ -598,7 +598,7 @@ mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
return nullptr;
// Check that the type is either a TensorType or another StructType.
if (!elementType.isa<mlir::TensorType, StructType>()) {
if (!llvm::isa<mlir::TensorType, StructType>(elementType)) {
parser.emitError(typeLoc, "element type for a struct must either "
"be a TensorType or a StructType, got: ")
<< elementType;
@ -619,7 +619,7 @@ mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
void ToyDialect::printType(mlir::Type type,
mlir::DialectAsmPrinter &printer) const {
// Currently the only toy type is a struct type.
StructType structType = type.cast<StructType>();
StructType structType = llvm::cast<StructType>(type);
// Print the struct type according to the parser format.
printer << "struct<";
@ -653,9 +653,9 @@ mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
mlir::Attribute value,
mlir::Type type,
mlir::Location loc) {
if (type.isa<StructType>())
if (llvm::isa<StructType>(type))
return builder.create<StructConstantOp>(loc, type,
value.cast<mlir::ArrayAttr>());
llvm::cast<mlir::ArrayAttr>(value));
return builder.create<ConstantOp>(loc, type,
value.cast<mlir::DenseElementsAttr>());
llvm::cast<mlir::DenseElementsAttr>(value));
}

View File

@ -62,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
static void lowerOpToLoops(Operation *op, ValueRange operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
@ -144,7 +144,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
auto tensorType = op.getType().cast<RankedTensorType>();
auto tensorType = llvm::cast<RankedTensorType>(op.getType());
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
@ -342,7 +342,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
target.addIllegalDialect<toy::ToyDialect>();
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
return llvm::none_of(op->getOperandTypes(),
[](Type type) { return type.isa<TensorType>(); });
[](Type type) { return llvm::isa<TensorType>(type); });
});
// Now that the conversion target has been defined, we just need to provide

View File

@ -61,7 +61,7 @@ public:
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
auto memRefShape = memRefType.getShape();
auto loc = op->getLoc();

View File

@ -94,7 +94,7 @@ struct ShapeInferencePass
/// operands inferred.
static bool allOperandsInferred(Operation *op) {
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
return operandType.isa<RankedTensorType>();
return llvm::isa<RankedTensorType>(operandType);
});
}
@ -102,7 +102,7 @@ struct ShapeInferencePass
/// shaped result.
static bool returnsDynamicShape(Operation *op) {
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
return !resultType.isa<RankedTensorType>();
return !llvm::isa<RankedTensorType>(resultType);
});
}
};

View File

@ -31,7 +31,8 @@ OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
/// Fold simple struct access operations that access into a constant.
OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) {
auto structAttr = adaptor.getInput().dyn_cast_or_null<mlir::ArrayAttr>();
auto structAttr =
llvm::dyn_cast_if_present<mlir::ArrayAttr>(adaptor.getInput());
if (!structAttr)
return nullptr;

View File

@ -62,19 +62,19 @@ class FileLineColLocBreakpointManager
public:
Breakpoint *match(const Action &action) const override {
for (const IRUnit &unit : action.getContextIRUnits()) {
if (auto *op = unit.dyn_cast<Operation *>()) {
if (auto *op = llvm::dyn_cast_if_present<Operation *>(unit)) {
if (auto match = matchFromLocation(op->getLoc()))
return *match;
continue;
}
if (auto *block = unit.dyn_cast<Block *>()) {
if (auto *block = llvm::dyn_cast_if_present<Block *>(unit)) {
for (auto &op : block->getOperations()) {
if (auto match = matchFromLocation(op.getLoc()))
return *match;
}
continue;
}
if (Region *region = unit.dyn_cast<Region *>()) {
if (Region *region = llvm::dyn_cast_if_present<Region *>(unit)) {
if (auto match = matchFromLocation(region->getLoc()))
return *match;
continue;

View File

@ -110,27 +110,27 @@ class MMAMatrixOf<list<Type> allowedTypes> :
"gpu.mma_matrix", "::mlir::gpu::MMAMatrixType">;
// Types for all sparse handles.
def GPU_SparseEnvHandle :
DialectType<GPU_Dialect,
CPred<"$_self.isa<::mlir::gpu::SparseEnvHandleType>()">,
"sparse environment handle type">,
def GPU_SparseEnvHandle :
DialectType<GPU_Dialect,
CPred<"llvm::isa<::mlir::gpu::SparseEnvHandleType>($_self)">,
"sparse environment handle type">,
BuildableType<"mlir::gpu::SparseEnvHandleType::get($_builder.getContext())">;
def GPU_SparseDnVecHandle :
DialectType<GPU_Dialect,
CPred<"$_self.isa<::mlir::gpu::SparseDnVecHandleType>()">,
def GPU_SparseDnVecHandle :
DialectType<GPU_Dialect,
CPred<"llvm::isa<::mlir::gpu::SparseDnVecHandleType>($_self)">,
"dense vector handle type">,
BuildableType<"mlir::gpu::SparseDnVecHandleType::get($_builder.getContext())">;
def GPU_SparseDnMatHandle :
DialectType<GPU_Dialect,
CPred<"$_self.isa<::mlir::gpu::SparseDnMatHandleType>()">,
def GPU_SparseDnMatHandle :
DialectType<GPU_Dialect,
CPred<"llvm::isa<::mlir::gpu::SparseDnMatHandleType>($_self)">,
"dense matrix handle type">,
BuildableType<"mlir::gpu::SparseDnMatHandleType::get($_builder.getContext())">;
def GPU_SparseSpMatHandle :
DialectType<GPU_Dialect,
CPred<"$_self.isa<::mlir::gpu::SparseSpMatHandleType>()">,
def GPU_SparseSpMatHandle :
DialectType<GPU_Dialect,
CPred<"llvm::isa<::mlir::gpu::SparseSpMatHandleType>($_self)">,
"sparse matrix handle type">,
BuildableType<"mlir::gpu::SparseSpMatHandleType::get($_builder.getContext())">;

View File

@ -95,7 +95,7 @@ def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
/*methodName=*/"getDeclareTargetDeviceType",
(ins), [{}], [{
if (mlir::Attribute dTar = $_op->getAttr("omp.declare_target"))
if (auto dAttr = dTar.dyn_cast_or_null<mlir::omp::DeclareTargetAttr>())
if (auto dAttr = llvm::dyn_cast_or_null<mlir::omp::DeclareTargetAttr>(dTar))
return dAttr.getDeviceType().getValue();
return {};
}]>,
@ -108,7 +108,7 @@ def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
/*methodName=*/"getDeclareTargetCaptureClause",
(ins), [{}], [{
if (mlir::Attribute dTar = $_op->getAttr("omp.declare_target"))
if (auto dAttr = dTar.dyn_cast_or_null<mlir::omp::DeclareTargetAttr>())
if (auto dAttr = llvm::dyn_cast_or_null<mlir::omp::DeclareTargetAttr>(dTar))
return dAttr.getCaptureClause().getValue();
return {};
}]>

View File

@ -115,7 +115,7 @@ public:
static bool classof(Type type);
/// Allow implicit conversion to ShapedType.
operator ShapedType() const { return cast<ShapedType>(); }
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};
//===----------------------------------------------------------------------===//
@ -169,7 +169,7 @@ public:
unsigned getMemorySpaceAsInt() const;
/// Allow implicit conversion to ShapedType.
operator ShapedType() const { return cast<ShapedType>(); }
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};
} // namespace mlir

View File

@ -217,13 +217,15 @@ private:
}
static bool isEmptyKey(mlir::TypeRange range) {
if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
if (const auto *type =
llvm::dyn_cast_if_present<const mlir::Type *>(range.getBase()))
return type == getEmptyKeyPointer();
return false;
}
static bool isTombstoneKey(mlir::TypeRange range) {
if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
if (const auto *type =
llvm::dyn_cast_if_present<const mlir::Type *>(range.getBase()))
return type == getTombstoneKeyPointer();
return false;
}

View File

@ -163,12 +163,12 @@ public:
/// Return the value the effect is applied on, or nullptr if there isn't a
/// known value being affected.
Value getValue() const { return value ? value.dyn_cast<Value>() : Value(); }
Value getValue() const { return value ? llvm::dyn_cast_if_present<Value>(value) : Value(); }
/// Return the symbol reference the effect is applied on, or nullptr if there
/// isn't a known smbol being affected.
SymbolRefAttr getSymbolRef() const {
return value ? value.dyn_cast<SymbolRefAttr>() : SymbolRefAttr();
return value ? llvm::dyn_cast_if_present<SymbolRefAttr>(value) : SymbolRefAttr();
}
/// Return the resource that the effect applies to.

View File

@ -254,7 +254,7 @@ struct NestedAnalysisMap {
/// Returns the parent analysis map for this analysis map, or null if this is
/// the top-level map.
const NestedAnalysisMap *getParent() const {
return parentOrInstrumentor.dyn_cast<NestedAnalysisMap *>();
return llvm::dyn_cast_if_present<NestedAnalysisMap *>(parentOrInstrumentor);
}
/// Returns a pass instrumentation object for the current operation. This

View File

@ -89,7 +89,7 @@ void SparseConstantPropagation::visitOperation(
// Merge in the result of the fold, either a constant or a value.
OpFoldResult foldResult = std::get<1>(it);
if (Attribute attr = foldResult.dyn_cast<Attribute>()) {
if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) {
LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
propagateIfChanged(lattice,
lattice->join(ConstantValue(attr, op->getDialect())));

View File

@ -31,7 +31,7 @@ void Executable::print(raw_ostream &os) const {
}
void Executable::onUpdate(DataFlowSolver *solver) const {
if (auto *block = point.dyn_cast<Block *>()) {
if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
// Re-invoke the analyses on the block itself.
for (DataFlowAnalysis *analysis : subscribers)
solver->enqueue({block, analysis});
@ -39,7 +39,7 @@ void Executable::onUpdate(DataFlowSolver *solver) const {
for (DataFlowAnalysis *analysis : subscribers)
for (Operation &op : *block)
solver->enqueue({&op, analysis});
} else if (auto *programPoint = point.dyn_cast<GenericProgramPoint *>()) {
} else if (auto *programPoint = llvm::dyn_cast_if_present<GenericProgramPoint *>(point)) {
// Re-invoke the analysis on the successor block.
if (auto *edge = dyn_cast<CFGEdge>(programPoint)) {
for (DataFlowAnalysis *analysis : subscribers)
@ -219,7 +219,7 @@ void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
if (point.is<Block *>())
return success();
auto *op = point.dyn_cast<Operation *>();
auto *op = llvm::dyn_cast_if_present<Operation *>(point);
if (!op)
return emitError(point.getLoc(), "unknown program point kind");

View File

@ -33,9 +33,9 @@ LogicalResult AbstractDenseDataFlowAnalysis::initialize(Operation *top) {
}
LogicalResult AbstractDenseDataFlowAnalysis::visit(ProgramPoint point) {
if (auto *op = point.dyn_cast<Operation *>())
if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
processOperation(op);
else if (auto *block = point.dyn_cast<Block *>())
else if (auto *block = llvm::dyn_cast_if_present<Block *>(point))
visitBlock(block);
else
return failure();

View File

@ -181,7 +181,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
if (auto bound =
dyn_cast_or_null<IntegerAttr>(loopBound->get<Attribute>()))
return bound.getValue();
} else if (auto value = loopBound->dyn_cast<Value>()) {
} else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
const IntegerValueRangeLattice *lattice =
getLatticeElementFor(op, value);
if (lattice != nullptr)

View File

@ -66,9 +66,9 @@ AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) {
}
LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) {
if (Operation *op = point.dyn_cast<Operation *>())
if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
visitOperation(op);
else if (Block *block = point.dyn_cast<Block *>())
else if (Block *block = llvm::dyn_cast_if_present<Block *>(point))
visitBlock(block);
else
return failure();
@ -238,7 +238,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
unsigned firstIndex = 0;
if (inputs.size() != lattices.size()) {
if (point.dyn_cast<Operation *>()) {
if (llvm::dyn_cast_if_present<Operation *>(point)) {
if (!inputs.empty())
firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
visitNonControlFlowArgumentsImpl(
@ -316,9 +316,9 @@ AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
LogicalResult
AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
if (Operation *op = point.dyn_cast<Operation *>())
if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
visitOperation(op);
else if (point.dyn_cast<Block *>())
else if (llvm::dyn_cast_if_present<Block *>(point))
// For backward dataflow, we don't have to do any work for the blocks
// themselves. CFG edges between blocks are processed by the BranchOp
// logic in `visitOperation`, and entry blocks for functions are tied

View File

@ -39,21 +39,21 @@ void ProgramPoint::print(raw_ostream &os) const {
os << "<NULL POINT>";
return;
}
if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
return programPoint->print(os);
if (auto *op = dyn_cast<Operation *>())
if (auto *op = llvm::dyn_cast<Operation *>(*this))
return op->print(os);
if (auto value = dyn_cast<Value>())
if (auto value = llvm::dyn_cast<Value>(*this))
return value.print(os);
return get<Block *>()->print(os);
}
Location ProgramPoint::getLoc() const {
if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
return programPoint->getLoc();
if (auto *op = dyn_cast<Operation *>())
if (auto *op = llvm::dyn_cast<Operation *>(*this))
return op->getLoc();
if (auto value = dyn_cast<Value>())
if (auto value = llvm::dyn_cast<Value>(*this))
return value.getLoc();
return get<Block *>()->getParent()->getLoc();
}

View File

@ -2060,7 +2060,7 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
if (parseToken(Token::r_paren, "expected ')' in location"))
return failure();
if (auto *op = opOrArgument.dyn_cast<Operation *>())
if (auto *op = llvm::dyn_cast_if_present<Operation *>(opOrArgument))
op->setLoc(directLoc);
else
opOrArgument.get<BlockArgument>().setLoc(directLoc);

View File

@ -47,7 +47,7 @@ SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
DictionaryAttr attributeDict;
if (!mlirAttributeIsNull(attributes))
attributeDict = unwrap(attributes).cast<DictionaryAttr>();
attributeDict = llvm::cast<DictionaryAttr>(unwrap(attributes));
return attributeDict;
}

View File

@ -1190,9 +1190,9 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
// TODO: safer and more flexible to store data type in actual op instead?
static Type getSpMatElemType(Value spMat) {
if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
return op.getValues().getType().cast<MemRefType>().getElementType();
return llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
return op.getValues().getType().cast<MemRefType>().getElementType();
return llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
llvm_unreachable("cannot find spmat def");
}
@ -1235,7 +1235,7 @@ LogicalResult ConvertCreateDnVecOpToGpuRuntimeCallPattern::matchAndRewrite(
MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
if (!getTypeConverter()->useOpaquePointers())
pVec = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pVec);
Type dType = op.getMemref().getType().cast<MemRefType>().getElementType();
Type dType = llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
dType.getIntOrFloatBitWidth());
auto handle =
@ -1271,7 +1271,7 @@ LogicalResult ConvertCreateDnMatOpToGpuRuntimeCallPattern::matchAndRewrite(
MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
if (!getTypeConverter()->useOpaquePointers())
pMat = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pMat);
Type dType = op.getMemref().getType().cast<MemRefType>().getElementType();
Type dType = llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
dType.getIntOrFloatBitWidth());
auto handle =
@ -1315,8 +1315,8 @@ LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
}
Type iType = op.getColIdxs().getType().cast<MemRefType>().getElementType();
Type dType = op.getValues().getType().cast<MemRefType>().getElementType();
Type iType = llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
Type dType = llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
auto iw = rewriter.create<LLVM::ConstantOp>(
loc, llvmInt32Type, iType.isIndex() ? 64 : iType.getIntOrFloatBitWidth());
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
@ -1350,9 +1350,9 @@ LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
}
Type pType = op.getRowPos().getType().cast<MemRefType>().getElementType();
Type iType = op.getColIdxs().getType().cast<MemRefType>().getElementType();
Type dType = op.getValues().getType().cast<MemRefType>().getElementType();
Type pType = llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
Type iType = llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
Type dType = llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
auto pw = rewriter.create<LLVM::ConstantOp>(
loc, llvmInt32Type, pType.isIndex() ? 64 : pType.getIntOrFloatBitWidth());
auto iw = rewriter.create<LLVM::ConstantOp>(

View File

@ -405,7 +405,7 @@ LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) {
return failure();
if (!(*converted)) // Conversion to default is 0.
return 0;
if (auto explicitSpace = converted->dyn_cast_or_null<IntegerAttr>())
if (auto explicitSpace = llvm::dyn_cast_if_present<IntegerAttr>(*converted))
return explicitSpace.getInt();
return failure();
}

View File

@ -671,7 +671,7 @@ struct GlobalMemrefOpLowering
Attribute initialValue = nullptr;
if (!global.isExternal() && !global.isUninitialized()) {
auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
initialValue = elementsAttr;
// For scalar memrefs, the global variable created is of the element type,

View File

@ -412,10 +412,10 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
auto *ans = cast<TypeAnswer>(answer);
if (isa<pdl::RangeType>(val.getType()))
builder.create<pdl_interp::CheckTypesOp>(
loc, val, ans->getValue().cast<ArrayAttr>(), success, failure);
loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure);
else
builder.create<pdl_interp::CheckTypeOp>(
loc, val, ans->getValue().cast<TypeAttr>(), success, failure);
loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure);
break;
}
case Predicates::AttributeQuestion: {

View File

@ -300,7 +300,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
// tosa::ErfOp
if (isa<tosa::ErfOp>(op) && elementTy.isa<FloatType>())
if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
// tosa::GreaterOp
@ -1885,7 +1885,7 @@ public:
auto addDynamicDimension = [&](Value source, int64_t dim) {
auto dynamicDim = tensor::createDimValue(builder, loc, source, dim);
if (auto dimValue = dynamicDim.value().dyn_cast<Value>())
if (auto dimValue = llvm::dyn_cast_if_present<Value>(dynamicDim.value()))
results.push_back(dimValue);
};

View File

@ -121,11 +121,11 @@ void mlirDebuggerCursorSelectParentIRUnit() {
return;
}
IRUnit *unit = &state.cursor;
if (auto *op = unit->dyn_cast<Operation *>()) {
if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
state.cursor = op->getBlock();
} else if (auto *region = unit->dyn_cast<Region *>()) {
} else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
state.cursor = region->getParentOp();
} else if (auto *block = unit->dyn_cast<Block *>()) {
} else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
state.cursor = block->getParent();
} else {
llvm::outs() << "Current cursor is not a valid IRUnit";
@ -142,14 +142,14 @@ void mlirDebuggerCursorSelectChildIRUnit(int index) {
return;
}
IRUnit *unit = &state.cursor;
if (auto *op = unit->dyn_cast<Operation *>()) {
if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
if (index < 0 || index >= static_cast<int>(op->getNumRegions())) {
llvm::outs() << "Index invalid, op has " << op->getNumRegions()
<< " but got " << index << "\n";
return;
}
state.cursor = &op->getRegion(index);
} else if (auto *region = unit->dyn_cast<Region *>()) {
} else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
auto block = region->begin();
int count = 0;
while (block != region->end() && count != index) {
@ -163,7 +163,7 @@ void mlirDebuggerCursorSelectChildIRUnit(int index) {
return;
}
state.cursor = &*block;
} else if (auto *block = unit->dyn_cast<Block *>()) {
} else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
auto op = block->begin();
int count = 0;
while (op != block->end() && count != index) {
@ -192,14 +192,14 @@ void mlirDebuggerCursorSelectPreviousIRUnit() {
return;
}
IRUnit *unit = &state.cursor;
if (auto *op = unit->dyn_cast<Operation *>()) {
if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
Operation *previous = op->getPrevNode();
if (!previous) {
llvm::outs() << "No previous operation in the current block\n";
return;
}
state.cursor = previous;
} else if (auto *region = unit->dyn_cast<Region *>()) {
} else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
llvm::outs() << "Has region\n";
Operation *parent = region->getParentOp();
if (!parent) {
@ -212,7 +212,7 @@ void mlirDebuggerCursorSelectPreviousIRUnit() {
}
state.cursor =
&region->getParentOp()->getRegion(region->getRegionNumber() - 1);
} else if (auto *block = unit->dyn_cast<Block *>()) {
} else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
Block *previous = block->getPrevNode();
if (!previous) {
llvm::outs() << "No previous block in the current region\n";
@ -234,14 +234,14 @@ void mlirDebuggerCursorSelectNextIRUnit() {
return;
}
IRUnit *unit = &state.cursor;
if (auto *op = unit->dyn_cast<Operation *>()) {
if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
Operation *next = op->getNextNode();
if (!next) {
llvm::outs() << "No next operation in the current block\n";
return;
}
state.cursor = next;
} else if (auto *region = unit->dyn_cast<Region *>()) {
} else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
Operation *parent = region->getParentOp();
if (!parent) {
llvm::outs() << "No parent operation for the current region\n";
@ -253,7 +253,7 @@ void mlirDebuggerCursorSelectNextIRUnit() {
}
state.cursor =
&region->getParentOp()->getRegion(region->getRegionNumber() + 1);
} else if (auto *block = unit->dyn_cast<Block *>()) {
} else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
Block *next = block->getNextNode();
if (!next) {
llvm::outs() << "No next block in the current region\n";

View File

@ -1212,7 +1212,7 @@ static void materializeConstants(OpBuilder &b, Location loc,
actualValues.reserve(values.size());
auto *dialect = b.getContext()->getLoadedDialect<AffineDialect>();
for (OpFoldResult ofr : values) {
if (auto value = ofr.dyn_cast<Value>()) {
if (auto value = llvm::dyn_cast_if_present<Value>(ofr)) {
actualValues.push_back(value);
continue;
}
@ -4599,7 +4599,7 @@ void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
if (staticDim.has_value())
return builder.create<arith::ConstantIndexOp>(result.location,
*staticDim);
return ofr.dyn_cast<Value>();
return llvm::dyn_cast_if_present<Value>(ofr);
});
result.addOperands(basisValues);
}

View File

@ -808,7 +808,7 @@ OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
if (matchPattern(getRhs(), m_Zero()))
return getLhs();
/// or(x, <all ones>) -> <all ones>
if (auto rhsAttr = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>())
if (auto rhsAttr = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs()))
if (rhsAttr.getValue().isAllOnes())
return rhsAttr;
@ -1249,7 +1249,7 @@ LogicalResult arith::ExtSIOp::verify() {
/// Always fold extension of FP constants.
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
auto constOperand = adaptor.getIn().dyn_cast_or_null<FloatAttr>();
auto constOperand = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getIn());
if (!constOperand)
return {};
@ -1702,7 +1702,7 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
// We are moving constants to the right side; So if lhs is constant rhs is
// guaranteed to be a constant.
if (auto lhs = adaptor.getLhs().dyn_cast_or_null<TypedAttr>()) {
if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
return constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), getI1SameShape(lhs.getType()),
[pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
@ -1772,8 +1772,8 @@ bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
}
OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
auto lhs = adaptor.getLhs().dyn_cast_or_null<FloatAttr>();
auto rhs = adaptor.getRhs().dyn_cast_or_null<FloatAttr>();
auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
// If one operand is NaN, making them both NaN does not change the result.
if (lhs && lhs.getValue().isNaN())
@ -2193,11 +2193,11 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
// Constant-fold constant operands over non-splat constant condition.
// select %cst_vec, %cst0, %cst1 => %cst2
if (auto cond =
adaptor.getCondition().dyn_cast_or_null<DenseElementsAttr>()) {
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
if (auto lhs =
adaptor.getTrueValue().dyn_cast_or_null<DenseElementsAttr>()) {
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
if (auto rhs =
adaptor.getFalseValue().dyn_cast_or_null<DenseElementsAttr>()) {
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
SmallVector<Attribute> results;
results.reserve(static_cast<size_t>(cond.getNumElements()));
auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),

View File

@ -184,7 +184,7 @@ struct SelectOpInterface
// If the buffers have different types, they differ only in their layout
// map.
auto memrefType = trueType->cast<MemRefType>();
auto memrefType = llvm::cast<MemRefType>(*trueType);
return getMemRefTypeWithFullyDynamicLayout(
RankedTensorType::get(memrefType.getShape(),
memrefType.getElementType()),

View File

@ -33,8 +33,8 @@ LogicalResult mlir::foldDynamicIndexList(Builder &b,
if (ofr.is<Attribute>())
continue;
// Newly static, move from Value to constant.
if (auto cstOp =
ofr.dyn_cast<Value>().getDefiningOp<arith::ConstantIndexOp>()) {
if (auto cstOp = llvm::dyn_cast_if_present<Value>(ofr)
.getDefiningOp<arith::ConstantIndexOp>()) {
ofr = b.getIndexAttr(cstOp.value());
valuesChanged = true;
}
@ -56,9 +56,9 @@ llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
OpFoldResult ofr) {
if (auto value = ofr.dyn_cast<Value>())
if (auto value = llvm::dyn_cast_if_present<Value>(ofr))
return value;
auto attr = dyn_cast<IntegerAttr>(ofr.dyn_cast<Attribute>());
auto attr = dyn_cast<IntegerAttr>(llvm::dyn_cast_if_present<Attribute>(ofr));
assert(attr && "expect the op fold result casts to an integer attribute");
return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
}

View File

@ -179,7 +179,7 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
}
FailureOr<Value> alloc = options.createAlloc(
rewriter, loc, allocType->cast<MemRefType>(), dynamicDims);
rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
if (failed(alloc))
return failure();

View File

@ -59,7 +59,8 @@ static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
SymbolRefAttr sym =
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<func::FuncOp>(

View File

@ -80,7 +80,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<FuncOp>(

View File

@ -995,7 +995,7 @@ static void annotateOpsWithAliasSets(Operation *op,
op->walk([&](Operation *op) {
SmallVector<Attribute> aliasSets;
for (OpResult opResult : op->getOpResults()) {
if (opResult.getType().isa<TensorType>()) {
if (llvm::isa<TensorType>(opResult.getType())) {
SmallVector<Attribute> aliases;
state.applyOnAliases(opResult, [&](Value alias) {
std::string buffer;

View File

@ -238,7 +238,7 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(func::CallOp callOp) {
SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
if (!sym)
return nullptr;
return dyn_cast_or_null<func::FuncOp>(

View File

@ -90,7 +90,8 @@ OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
ArrayAttr arrayAttr =
llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
if (arrayAttr && arrayAttr.size() == 2)
return arrayAttr[1];
if (auto createOp = getOperand().getDefiningOp<CreateOp>())
@ -103,7 +104,8 @@ OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
ArrayAttr arrayAttr =
llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
if (arrayAttr && arrayAttr.size() == 2)
return arrayAttr[0];
if (auto createOp = getOperand().getDefiningOp<CreateOp>())

View File

@ -94,7 +94,7 @@ DataLayoutEntryAttr DataLayoutEntryAttr::parse(AsmParser &parser) {
void DataLayoutEntryAttr::print(AsmPrinter &os) const {
os << DataLayoutEntryAttr::kAttrKeyword << "<";
if (auto type = getKey().dyn_cast<Type>())
if (auto type = llvm::dyn_cast_if_present<Type>(getKey()))
os << type;
else
os << "\"" << getKey().get<StringAttr>().strref() << "\"";
@ -151,7 +151,7 @@ DataLayoutSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
DenseSet<Type> types;
DenseSet<StringAttr> ids;
for (DataLayoutEntryInterface entry : entries) {
if (auto type = entry.getKey().dyn_cast<Type>()) {
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
if (!types.insert(type).second)
return emitError() << "repeated layout entry key: " << type;
} else {

View File

@ -493,7 +493,7 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
// error. All other canonicalization is done in the fold method.
bool requiresConst = !rawConstantIndices.empty() &&
currType.isa_and_nonnull<LLVMStructType>();
if (Value val = iter.dyn_cast<Value>()) {
if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
APInt intC;
if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
intC.isSignedIntN(kGEPConstantBitWidth)) {
@ -598,7 +598,7 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
llvm::interleaveComma(
GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), printer,
[&](PointerUnion<IntegerAttr, Value> cst) {
if (Value val = cst.dyn_cast<Value>())
if (Value val = llvm::dyn_cast_if_present<Value>(cst))
printer.printOperand(val);
else
printer << cst.get<IntegerAttr>().getInt();
@ -2495,7 +2495,7 @@ OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
!integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {
PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
if (Value val = existing.dyn_cast<Value>())
if (Value val = llvm::dyn_cast_if_present<Value>(existing))
gepArgs.emplace_back(val);
else
gepArgs.emplace_back(existing.get<IntegerAttr>().getInt());

View File

@ -261,7 +261,7 @@ DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
return llvm::all_of(gepOp.getIndices(), [](auto index) {
auto indexAttr = index.template dyn_cast<IntegerAttr>();
auto indexAttr = llvm::dyn_cast_if_present<IntegerAttr>(index);
return indexAttr && indexAttr.getValue() == 0;
});
}
@ -289,7 +289,7 @@ static Type computeReachedGEPType(LLVM::GEPOp gep) {
// Ensures all indices are static and fetches them.
SmallVector<IntegerAttr> indices;
for (auto index : gep.getIndices()) {
IntegerAttr indexInt = index.dyn_cast<IntegerAttr>();
IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
if (!indexInt)
return {};
indices.push_back(indexInt);
@ -310,7 +310,7 @@ static Type computeReachedGEPType(LLVM::GEPOp gep) {
for (IntegerAttr index : llvm::drop_begin(indices)) {
// Ensure the structure of the type being indexed can be reasoned about.
// This includes rejecting any potential typed pointer.
auto destructurable = selectedType.dyn_cast<DestructurableTypeInterface>();
auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(selectedType);
if (!destructurable)
return {};
@ -343,7 +343,7 @@ LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
auto basePtrType = getBase().getType().dyn_cast<LLVM::LLVMPointerType>();
auto basePtrType = llvm::dyn_cast<LLVM::LLVMPointerType>(getBase().getType());
if (!basePtrType)
return false;
@ -359,7 +359,7 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
return false;
auto firstLevelIndex = cast<IntegerAttr>(getIndices()[1]);
assert(slot.elementPtrs.contains(firstLevelIndex));
if (!slot.elementPtrs.at(firstLevelIndex).isa<LLVM::LLVMPointerType>())
if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex)))
return false;
mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
usedIndices.insert(firstLevelIndex);
@ -369,7 +369,7 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
RewriterBase &rewriter) {
IntegerAttr firstLevelIndex = getIndices()[1].dyn_cast<IntegerAttr>();
IntegerAttr firstLevelIndex = llvm::dyn_cast_if_present<IntegerAttr>(getIndices()[1]);
const MemorySlot &newSlot = subslots.at(firstLevelIndex);
ArrayRef<int32_t> remainingIndices = getRawConstantIndices().slice(2);
@ -414,7 +414,7 @@ LLVM::LLVMStructType::getSubelementIndexMap() {
}
Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) {
auto indexAttr = index.dyn_cast<IntegerAttr>();
auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
if (!indexAttr || !indexAttr.getType().isInteger(32))
return {};
int32_t indexInt = indexAttr.getInt();
@ -439,7 +439,7 @@ LLVM::LLVMArrayType::getSubelementIndexMap() const {
}
Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const {
auto indexAttr = index.dyn_cast<IntegerAttr>();
auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
if (!indexAttr || !indexAttr.getType().isInteger(32))
return {};
int32_t indexInt = indexAttr.getInt();

View File

@ -354,7 +354,7 @@ bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
auto newType = llvm::cast<LLVMPointerType>(newEntry.getKey().get<Type>());
const auto *it =
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
if (auto type = entry.getKey().dyn_cast<Type>()) {
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
newType.getAddressSpace();
}
@ -362,7 +362,7 @@ bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
});
if (it == oldLayout.end()) {
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
if (auto type = entry.getKey().dyn_cast<Type>()) {
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
}
return false;

View File

@ -2368,7 +2368,7 @@ transform::TileOp::apply(TransformResults &transformResults,
sizes.reserve(tileSizes.size());
unsigned dynamicIdx = 0;
for (OpFoldResult ofr : getMixedSizes()) {
if (auto attr = ofr.dyn_cast<Attribute>()) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
sizes.push_back(b.create<arith::ConstantIndexOp>(
getLoc(), cast<IntegerAttr>(attr).getInt()));
continue;
@ -2794,7 +2794,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
sizes.reserve(tileSizes.size());
unsigned dynamicIdx = 0;
for (OpFoldResult ofr : getMixedSizes()) {
if (auto attr = ofr.dyn_cast<Attribute>()) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
sizes.push_back(b.create<arith::ConstantIndexOp>(
getLoc(), cast<IntegerAttr>(attr).getInt()));
} else {

View File

@ -1447,7 +1447,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
cast<LinalgOp>(genericOp.getOperation())
.createLoopRanges(rewriter, genericOp.getLoc());
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
if (auto attr = ofr.dyn_cast<Attribute>())
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
return cast<IntegerAttr>(attr).getInt() == value;
llvm::APInt actual;
return matchPattern(ofr.get<Value>(), m_ConstantInt(&actual)) &&

View File

@ -229,7 +229,7 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
// to look for the bound.
LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
Value size;
if (auto attr = rangeValue.size.dyn_cast<Attribute>()) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
} else {
Value materializedSize =

View File

@ -92,7 +92,7 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
rewriter, op.getLoc(), d0 + d1 - d2,
{iterationSpace[dimension].offset, iterationSpace[dimension].size,
minSplitPoint});
if (auto attr = remainingSize.dyn_cast<Attribute>()) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(remainingSize)) {
if (cast<IntegerAttr>(attr).getValue().isZero())
return {op, TilingInterface()};
}

View File

@ -48,7 +48,7 @@ using namespace mlir::scf;
static bool isZero(OpFoldResult v) {
if (!v)
return false;
if (auto attr = v.dyn_cast<Attribute>()) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(v)) {
IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
return intAttr && intAttr.getValue().isZero();
}
@ -104,7 +104,7 @@ void mlir::linalg::transformIndexOps(
/// checked at runtime.
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
OpFoldResult value) {
if (auto attr = value.dyn_cast<Attribute>()) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
"expected strictly positive tile size and divisor");
return;

View File

@ -1135,7 +1135,7 @@ GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const {
// Given an OpFoldResult, return an index-typed value.
auto getIdxValue = [&](OpFoldResult ofr) {
if (auto val = ofr.dyn_cast<Value>())
if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
return val;
return rewriter
.create<arith::ConstantIndexOp>(

View File

@ -1646,7 +1646,7 @@ static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
ArrayRef<OpFoldResult> ofrs) {
SmallVector<Value> result;
for (auto o : ofrs) {
if (auto val = o.template dyn_cast<Value>()) {
if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
result.push_back(val);
} else {
result.push_back(rewriter.create<arith::ConstantIndexOp>(
@ -1954,8 +1954,8 @@ struct PadOpVectorizationWithTransferWritePattern
continue;
// Other cases: Take a deeper look at defining ops of values.
auto v1 = size1.dyn_cast<Value>();
auto v2 = size2.dyn_cast<Value>();
auto v1 = llvm::dyn_cast_if_present<Value>(size1);
auto v2 = llvm::dyn_cast_if_present<Value>(size2);
if (!v1 || !v2)
return false;

View File

@ -970,7 +970,7 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
auto dim = it.index();
auto size = it.value();
curr.push_back(dim);
auto attr = size.dyn_cast<Attribute>();
auto attr = llvm::dyn_cast_if_present<Attribute>(size);
if (attr && cast<IntegerAttr>(attr).getInt() == 1)
continue;
reassociation.emplace_back(ReassociationIndices{});

View File

@ -64,7 +64,7 @@ static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
//===----------------------------------------------------------------------===//
static bool isSupportedElementType(Type type) {
return type.isa<MemRefType>() ||
return llvm::isa<MemRefType>(type) ||
OpBuilder(type.getContext()).getZeroAttr(type);
}
@ -110,7 +110,7 @@ void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
SmallVector<DestructurableMemorySlot>
memref::AllocaOp::getDestructurableSlots() {
MemRefType memrefType = getType();
auto destructurable = memrefType.dyn_cast<DestructurableTypeInterface>();
auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
if (!destructurable)
return {};
@ -134,7 +134,7 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> slotMap;
auto memrefType = getType().cast<DestructurableTypeInterface>();
auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
for (Attribute usedIndex : usedIndices) {
Type elemType = memrefType.getTypeAtIndex(usedIndex);
MemRefType elemPtr = MemRefType::get({}, elemType);
@ -281,7 +281,7 @@ struct MemRefDestructurableTypeExternalModel
MemRefDestructurableTypeExternalModel, MemRefType> {
std::optional<DenseMap<Attribute, Type>>
getSubelementIndexMap(Type type) const {
auto memrefType = type.cast<MemRefType>();
auto memrefType = llvm::cast<MemRefType>(type);
constexpr int64_t maxMemrefSizeForDestructuring = 16;
if (!memrefType.hasStaticShape() ||
memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
@ -298,15 +298,15 @@ struct MemRefDestructurableTypeExternalModel
}
Type getTypeAtIndex(Type type, Attribute index) const {
auto memrefType = type.cast<MemRefType>();
auto coordArrAttr = index.dyn_cast<ArrayAttr>();
auto memrefType = llvm::cast<MemRefType>(type);
auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
return {};
Type indexType = IndexType::get(memrefType.getContext());
for (const auto &[coordAttr, dimSize] :
llvm::zip(coordArrAttr, memrefType.getShape())) {
auto coord = coordAttr.dyn_cast<IntegerAttr>();
auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
coord.getInt() >= dimSize)
return {};

View File

@ -970,7 +970,7 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
return unusedDims;
for (const auto &dim : llvm::enumerate(sizes))
if (auto attr = dim.value().dyn_cast<Attribute>())
if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
unusedDims.set(dim.index());
@ -1042,7 +1042,7 @@ llvm::SmallBitVector SubViewOp::getDroppedDims() {
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
// All forms of folding require a known index.
auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
if (!index)
return {};

View File

@ -56,7 +56,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
// Because we only support input strides of 1, the output stride is also
// always 1.
if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
Attribute attr = valueOrAttr.dyn_cast<Attribute>();
Attribute attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr);
return attr && cast<IntegerAttr>(attr).getInt() == 1;
})) {
strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
@ -86,8 +86,9 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
}
sizes.push_back(opSize);
Attribute opOffsetAttr = opOffset.dyn_cast<Attribute>(),
sourceOffsetAttr = sourceOffset.dyn_cast<Attribute>();
Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
sourceOffsetAttr =
llvm::dyn_cast_if_present<Attribute>(sourceOffset);
if (opOffsetAttr && sourceOffsetAttr) {
// If both offsets are static we can simply calculate the combined
@ -101,7 +102,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
AffineExpr expr = rewriter.getAffineConstantExpr(0);
SmallVector<Value> affineApplyOperands;
for (auto valueOrAttr : {opOffset, sourceOffset}) {
if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr)) {
expr = expr + cast<IntegerAttr>(attr).getInt();
} else {
expr =

View File

@ -520,7 +520,7 @@ checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
<< operandName << " operand appears more than once";
mlir::Type varType = operand.getType();
auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
if (!decl)
return op->emitOpError()

View File

@ -802,10 +802,10 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange map_operands,
for (const auto &mapTypeOp : *map_types) {
int64_t mapTypeBits = 0x00;
if (!mapTypeOp.isa<mlir::IntegerAttr>())
if (!llvm::isa<mlir::IntegerAttr>(mapTypeOp))
return failure();
mapTypeBits = mapTypeOp.cast<mlir::IntegerAttr>().getInt();
mapTypeBits = llvm::cast<mlir::IntegerAttr>(mapTypeOp).getInt();
bool to =
bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);

View File

@ -381,7 +381,7 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
// map.
auto yieldedRanked = cast<MemRefType>(yieldedValueBufferType);
#ifndef NDEBUG
auto iterRanked = initArgBufferType->cast<MemRefType>();
auto iterRanked = llvm::cast<MemRefType>(*initArgBufferType);
assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
"expected same shape");
assert(yieldedRanked.getMemorySpace() == iterRanked.getMemorySpace() &&
@ -802,7 +802,7 @@ struct WhileOpInterface
if (!isa<TensorType>(bbArg.getType()))
return bbArg.getType();
// TODO: error handling
return bufferization::getBufferType(bbArg, options)->cast<Type>();
return llvm::cast<Type>(*bufferization::getBufferType(bbArg, options));
}));
// Construct a new scf.while op with memref instead of tensor values.

View File

@ -88,10 +88,10 @@ LogicalResult scf::addLoopRangeConstraints(FlatAffineValueConstraints &cstr,
return failure();
unsigned dimIv = cstr.appendDimVar(iv);
auto lbv = lb.dyn_cast<Value>();
auto lbv = llvm::dyn_cast_if_present<Value>(lb);
unsigned symLb =
lbv ? cstr.appendSymbolVar(lbv) : cstr.appendSymbolVar(/*num=*/1);
auto ubv = ub.dyn_cast<Value>();
auto ubv = llvm::dyn_cast_if_present<Value>(ub);
unsigned symUb =
ubv ? cstr.appendSymbolVar(ubv) : cstr.appendSymbolVar(/*num=*/1);

View File

@ -152,7 +152,7 @@ OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
if (getIndices().size() == 1 &&
constructOp.getConstituents().size() == type.getNumElements()) {
auto i = getIndices().begin()->cast<IntegerAttr>();
auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
return constructOp.getConstituents()[i.getValue().getSExtValue()];
}
}

View File

@ -1562,8 +1562,8 @@ LogicalResult spirv::BitcastOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::ConvertPtrToUOp::verify() {
auto operandType = getPointer().getType().cast<spirv::PointerType>();
auto resultType = getResult().getType().cast<spirv::ScalarType>();
auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType());
if (!resultType || !resultType.isSignlessInteger())
return emitError("result must be a scalar type of unsigned integer");
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
@ -1583,8 +1583,8 @@ LogicalResult spirv::ConvertPtrToUOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::ConvertUToPtrOp::verify() {
auto operandType = getOperand().getType().cast<spirv::ScalarType>();
auto resultType = getResult().getType().cast<spirv::PointerType>();
auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType());
auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
if (!operandType || !operandType.isSignlessInteger())
return emitError("result must be a scalar type of unsigned integer");
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();

View File

@ -125,23 +125,23 @@ Type CompositeType::getElementType(unsigned index) const {
}
unsigned CompositeType::getNumElements() const {
if (auto arrayType = dyn_cast<ArrayType>())
if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
return arrayType.getNumElements();
if (auto matrixType = dyn_cast<MatrixType>())
if (auto matrixType = llvm::dyn_cast<MatrixType>(*this))
return matrixType.getNumColumns();
if (auto structType = dyn_cast<StructType>())
if (auto structType = llvm::dyn_cast<StructType>(*this))
return structType.getNumElements();
if (auto vectorType = dyn_cast<VectorType>())
if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
return vectorType.getNumElements();
if (isa<CooperativeMatrixNVType>()) {
if (llvm::isa<CooperativeMatrixNVType>(*this)) {
llvm_unreachable(
"invalid to query number of elements of spirv::CooperativeMatrix type");
}
if (isa<JointMatrixINTELType>()) {
if (llvm::isa<JointMatrixINTELType>(*this)) {
llvm_unreachable(
"invalid to query number of elements of spirv::JointMatrix type");
}
if (isa<RuntimeArrayType>()) {
if (llvm::isa<RuntimeArrayType>(*this)) {
llvm_unreachable(
"invalid to query number of elements of spirv::RuntimeArray type");
}
@ -149,8 +149,8 @@ unsigned CompositeType::getNumElements() const {
}
bool CompositeType::hasCompileTimeKnownNumElements() const {
return !isa<CooperativeMatrixNVType, JointMatrixINTELType,
RuntimeArrayType>();
return !llvm::isa<CooperativeMatrixNVType, JointMatrixINTELType,
RuntimeArrayType>(*this);
}
void CompositeType::getExtensions(
@ -188,11 +188,11 @@ void CompositeType::getCapabilities(
}
std::optional<int64_t> CompositeType::getSizeInBytes() {
if (auto arrayType = dyn_cast<ArrayType>())
if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
return arrayType.getSizeInBytes();
if (auto structType = dyn_cast<StructType>())
if (auto structType = llvm::dyn_cast<StructType>(*this))
return structType.getSizeInBytes();
if (auto vectorType = dyn_cast<VectorType>()) {
if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
std::optional<int64_t> elementSize =
llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
if (!elementSize)
@ -680,7 +680,7 @@ void ScalarType::getCapabilities(
capabilities.push_back(ref); \
} break
if (auto intType = dyn_cast<IntegerType>()) {
if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
switch (bitwidth) {
WIDTH_CASE(Int, 8);
WIDTH_CASE(Int, 16);
@ -692,7 +692,7 @@ void ScalarType::getCapabilities(
llvm_unreachable("invalid bitwidth to getCapabilities");
}
} else {
assert(isa<FloatType>());
assert(llvm::isa<FloatType>(*this));
switch (bitwidth) {
WIDTH_CASE(Float, 16);
WIDTH_CASE(Float, 64);
@ -735,22 +735,22 @@ bool SPIRVType::classof(Type type) {
}
bool SPIRVType::isScalarOrVector() {
return isIntOrFloat() || isa<VectorType>();
return isIntOrFloat() || llvm::isa<VectorType>(*this);
}
void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
std::optional<StorageClass> storage) {
if (auto scalarType = dyn_cast<ScalarType>()) {
if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
scalarType.getExtensions(extensions, storage);
} else if (auto compositeType = dyn_cast<CompositeType>()) {
} else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
compositeType.getExtensions(extensions, storage);
} else if (auto imageType = dyn_cast<ImageType>()) {
} else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
imageType.getExtensions(extensions, storage);
} else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
} else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
sampledImageType.getExtensions(extensions, storage);
} else if (auto matrixType = dyn_cast<MatrixType>()) {
} else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
matrixType.getExtensions(extensions, storage);
} else if (auto ptrType = dyn_cast<PointerType>()) {
} else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
ptrType.getExtensions(extensions, storage);
} else {
llvm_unreachable("invalid SPIR-V Type to getExtensions");
@ -760,17 +760,17 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
void SPIRVType::getCapabilities(
SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage) {
if (auto scalarType = dyn_cast<ScalarType>()) {
if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
scalarType.getCapabilities(capabilities, storage);
} else if (auto compositeType = dyn_cast<CompositeType>()) {
} else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
compositeType.getCapabilities(capabilities, storage);
} else if (auto imageType = dyn_cast<ImageType>()) {
} else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
imageType.getCapabilities(capabilities, storage);
} else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
} else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
sampledImageType.getCapabilities(capabilities, storage);
} else if (auto matrixType = dyn_cast<MatrixType>()) {
} else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
matrixType.getCapabilities(capabilities, storage);
} else if (auto ptrType = dyn_cast<PointerType>()) {
} else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
ptrType.getCapabilities(capabilities, storage);
} else {
llvm_unreachable("invalid SPIR-V Type to getCapabilities");
@ -778,9 +778,9 @@ void SPIRVType::getCapabilities(
}
std::optional<int64_t> SPIRVType::getSizeInBytes() {
if (auto scalarType = dyn_cast<ScalarType>())
if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
return scalarType.getSizeInBytes();
if (auto compositeType = dyn_cast<CompositeType>())
if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
return compositeType.getSizeInBytes();
return std::nullopt;
}

View File

@ -856,9 +856,9 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getLhs() || !adaptor.getRhs())
return nullptr;
auto lhsShape = llvm::to_vector<6>(
adaptor.getLhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
auto rhsShape = llvm::to_vector<6>(
adaptor.getRhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
SmallVector<int64_t, 6> resultShape;
resultShape.append(lhsShape.begin(), lhsShape.end());
resultShape.append(rhsShape.begin(), rhsShape.end());
@ -989,7 +989,7 @@ OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
if (!operand)
return false;
extents.push_back(llvm::to_vector<6>(
operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
}
return OpTrait::util::staticallyKnownBroadcastable(extents);
}())
@ -1132,10 +1132,10 @@ LogicalResult mlir::shape::DimOp::verify() {
//===----------------------------------------------------------------------===//
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
if (!lhs)
return nullptr;
auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
if (!rhs)
return nullptr;
@ -1346,7 +1346,7 @@ std::optional<int64_t> GetExtentOp::getConstantDim() {
}
OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
auto elements = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
if (!elements)
return nullptr;
std::optional<int64_t> dim = getConstantDim();
@ -1490,7 +1490,7 @@ bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
//===----------------------------------------------------------------------===//
OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
auto shape = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
if (!shape)
return {};
int64_t rank = shape.getNumElements();
@ -1671,10 +1671,10 @@ bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
//===----------------------------------------------------------------------===//
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
if (!lhs)
return nullptr;
auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
if (!rhs)
return nullptr;
APInt folded = lhs.getValue() * rhs.getValue();
@ -1864,9 +1864,9 @@ LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
if (!adaptor.getOperand() || !adaptor.getIndex())
return failure();
auto shapeVec = llvm::to_vector<6>(
adaptor.getOperand().cast<DenseIntElementsAttr>().getValues<int64_t>());
llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
auto shape = llvm::ArrayRef(shapeVec);
auto splitPoint = adaptor.getIndex().cast<IntegerAttr>().getInt();
auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
// Verify that the split point is in the correct range.
// TODO: Constant fold to an "error".
int64_t rank = shape.size();
@ -1889,7 +1889,7 @@ OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
return OpFoldResult();
Builder builder(getContext());
auto shape = llvm::to_vector<6>(
adaptor.getInput().cast<DenseIntElementsAttr>().getValues<int64_t>());
llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
builder.getIndexType());
return DenseIntElementsAttr::get(type, shape);

View File

@ -815,7 +815,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
Level cooStartLvl = getCOOStart(stt.getEncoding());
if (cooStartLvl < stt.getLvlRank()) {
// We only supports trailing COO for now, must be the last input.
auto cooTp = lvlTps.back().cast<ShapedType>();
auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
// The coordinates should be in shape of <? x rank>
unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
@ -844,7 +844,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
inputTp = lvlTps[idx++];
}
// The input element type and expected element type should match.
Type inpElemTp = inputTp.cast<TensorType>().getElementType();
Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
Type expElemTp = getFieldElemType(stt, fKind);
if (inpElemTp != expElemTp) {
misMatch = true;

View File

@ -188,7 +188,7 @@ static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
/// Generates a memref from tensor operation.
static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
Value tensor) {
auto tensorType = tensor.getType().cast<ShapedType>();
auto tensorType = llvm::cast<ShapedType>(tensor.getType());
auto memrefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);

View File

@ -414,7 +414,7 @@ public:
/// TODO: better unord/not-unique; also generalize, optimize, specialize!
SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
OpBuilder &builder, Location loc) {
const SparseTensorType stt(rtp.cast<RankedTensorType>());
const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
const Level lvlRank = stt.getLvlRank();
// Extract fields and coordinates from args.
SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
@ -466,7 +466,7 @@ public:
// The mangled name of the function has this format:
// <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
constexpr const char kInsertFuncNamePrefix[] = "_insert_";
const SparseTensorType stt(rtp.cast<RankedTensorType>());
const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
SmallString<32> nameBuffer;
llvm::raw_svector_ostream nameOstream(nameBuffer);
@ -541,14 +541,14 @@ static void genEndInsert(OpBuilder &builder, Location loc,
static TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
Value tensor) {
auto tTp = tensor.getType().cast<TensorType>();
auto tTp = llvm::cast<TensorType>(tensor.getType());
auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
.getResult();
}
Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, Value sz) {
auto elemTp = mem.getType().cast<MemRefType>().getElementType();
auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
return builder
.create<memref::SubViewOp>(
loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,

View File

@ -180,7 +180,7 @@ struct ReifyPadOp
AffineExpr expr = b.getAffineDimExpr(0);
unsigned numSymbols = 0;
auto addOpFoldResult = [&](OpFoldResult valueOrAttr) {
if (Value v = valueOrAttr.dyn_cast<Value>()) {
if (Value v = llvm::dyn_cast_if_present<Value>(valueOrAttr)) {
expr = expr + b.getAffineSymbolExpr(numSymbols++);
mapOperands.push_back(v);
return;

View File

@ -501,7 +501,7 @@ Speculation::Speculatability DimOp::getSpeculatability() {
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
// All forms of folding require a known index.
auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
if (!index)
return {};
@ -764,7 +764,7 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
OpFoldResult currDim = std::get<1>(it);
// Case 1: The empty tensor dim is static. Check that the tensor cast
// result dim matches.
if (auto attr = currDim.dyn_cast<Attribute>()) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
if (ShapedType::isDynamic(newDim) ||
newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
// Something is off, the cast result shape cannot be more dynamic
@ -2106,7 +2106,7 @@ static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
}
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
if (auto splat = adaptor.getSource().dyn_cast_or_null<SplatElementsAttr>()) {
if (auto splat = llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
auto resultType = llvm::cast<ShapedType>(getResult().getType());
if (resultType.hasStaticShape())
return splat.resizeSplat(resultType);
@ -3558,7 +3558,7 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
SmallVector<int64_t> result;
for (auto o : ofrs) {
// Have to do this first, as getConstantIntValue special-cases constants.
if (o.dyn_cast<Value>())
if (llvm::dyn_cast_if_present<Value>(o))
result.push_back(ShapedType::kDynamic);
else
result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));

View File

@ -76,7 +76,7 @@ struct CastOpInterface
auto rankedResultType = cast<RankedTensorType>(castOp.getType());
return MemRefType::get(
rankedResultType.getShape(), rankedResultType.getElementType(),
maybeSrcBufferType->cast<MemRefType>().getLayout(), memorySpace);
llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@ -139,7 +139,7 @@ struct CollapseShapeOpInterface
collapseShapeOp.getSrc(), options, fixedTypes);
if (failed(maybeSrcBufferType))
return failure();
auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
srcBufferType, collapseShapeOp.getReassociationIndices());
@ -303,7 +303,7 @@ struct ExpandShapeOpInterface
expandShapeOp.getSrc(), options, fixedTypes);
if (failed(maybeSrcBufferType))
return failure();
auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
srcBufferType, expandShapeOp.getResultType().getShape(),
expandShapeOp.getReassociationIndices());
@ -369,7 +369,7 @@ struct ExtractSliceOpInterface
if (failed(resultMemrefType))
return failure();
Value subView = rewriter.create<memref::SubViewOp>(
loc, resultMemrefType->cast<MemRefType>(), *srcMemref, mixedOffsets,
loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
mixedSizes, mixedStrides);
replaceOpWithBufferizedValues(rewriter, op, subView);
@ -389,7 +389,7 @@ struct ExtractSliceOpInterface
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
extractSliceOp.getType().getShape(), srcMemrefType->cast<MemRefType>(),
extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
mixedOffsets, mixedSizes, mixedStrides));
}
};

View File

@ -548,8 +548,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
return {};
auto resultETy = resultTy.getElementType();
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
return getInput1();
@ -573,8 +573,8 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
return {};
auto resultETy = resultTy.getElementType();
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (lhsAttr && lhsAttr.isSplat()) {
if (llvm::isa<IntegerType>(resultETy) &&
lhsAttr.getSplatValue<APInt>().isZero())
@ -642,8 +642,8 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
return {};
auto resultETy = resultTy.getElementType();
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
if (rhsTy == resultTy) {
@ -670,8 +670,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
return {};
auto resultETy = resultTy.getElementType();
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
return getInput1();
@ -713,8 +713,8 @@ struct APIntFoldGreaterEqual {
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (!lhsAttr || !rhsAttr)
return {};
@ -725,8 +725,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
if (!lhsAttr || !rhsAttr)
return {};
@ -738,8 +738,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
Value lhs = getInput1();
Value rhs = getInput2();
auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
@ -763,7 +763,7 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
if (getInput().getType() == getType())
return getInput();
auto operand = adaptor.getInput().dyn_cast_or_null<ElementsAttr>();
auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
if (!operand)
return {};
@ -852,7 +852,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
if (inputTy == outputTy)
return getInput1();
auto operand = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}
@ -863,7 +863,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
// If the pad is all zeros we can fold this operation away.
if (adaptor.getPadding()) {
auto densePad = adaptor.getPadding().cast<DenseElementsAttr>();
auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
return getInput1();
}
@ -907,7 +907,7 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
auto operand = getInput();
auto operandTy = llvm::cast<ShapedType>(operand.getType());
auto axis = getAxis();
auto operandAttr = adaptor.getInput().dyn_cast_or_null<SplatElementsAttr>();
auto operandAttr = llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
if (operandAttr)
return operandAttr;
@ -936,7 +936,7 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
!outputTy.getElementType().isIntOrIndexOrFloat())
return {};
auto operand = adaptor.getInput().cast<ElementsAttr>();
auto operand = llvm::cast<ElementsAttr>(adaptor.getInput());
if (operand.isSplat() && outputTy.hasStaticShape()) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}
@ -955,7 +955,7 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
if (getOnTrue() == getOnFalse())
return getOnTrue();
auto predicate = adaptor.getPred().dyn_cast_or_null<DenseIntElementsAttr>();
auto predicate = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
if (!predicate)
return {};
@ -977,7 +977,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::cast<ShapedType>(getType());
// Transposing splat values just means reshaping.
if (auto input = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>()) {
if (auto input = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
if (input.isSplat() && resultTy.hasStaticShape() &&
inputTy.getElementType() == resultTy.getElementType())
return input.reshape(resultTy);

View File

@ -63,9 +63,9 @@ LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
// Verify the rank agrees with the output type if the output type is ranked.
if (outputType) {
if (outputType.getRank() !=
input1_copy.getType().cast<RankedTensorType>().getRank() ||
llvm::cast<RankedTensorType>(input1_copy.getType()).getRank() ||
outputType.getRank() !=
input2_copy.getType().cast<RankedTensorType>().getRank())
llvm::cast<RankedTensorType>(input2_copy.getType()).getRank())
return rewriter.notifyMatchFailure(
loc, "the reshaped type doesn't agrees with the ranked output type");
}

View File

@ -103,8 +103,8 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
Value &input1, Value &input2) {
auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
if (!input1Ty || !input2Ty) {
return failure();
@ -126,9 +126,9 @@ LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
}
ArrayRef<int64_t> higherRankShape =
higherTensorValue.getType().cast<RankedTensorType>().getShape();
llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape();
ArrayRef<int64_t> lowerRankShape =
lowerTensorValue.getType().cast<RankedTensorType>().getShape();
llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
SmallVector<int64_t, 4> reshapeOutputShape;
@ -136,7 +136,8 @@ LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
.failed())
return failure();
auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
auto reshapeInputType =
llvm::cast<RankedTensorType>(lowerTensorValue.getType());
auto reshapeOutputType = RankedTensorType::get(
ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());

View File

@ -118,7 +118,7 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
SmallVector<Operation *> operations;
operations.reserve(values.size());
for (transform::MappedValue value : values) {
if (auto *op = value.dyn_cast<Operation *>()) {
if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
operations.push_back(op);
continue;
}
@ -135,7 +135,7 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
SmallVector<Value> payloadValues;
payloadValues.reserve(values.size());
for (transform::MappedValue value : values) {
if (auto v = value.dyn_cast<Value>()) {
if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
payloadValues.push_back(v);
continue;
}
@ -152,7 +152,7 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
SmallVector<transform::Param> parameters;
parameters.reserve(values.size());
for (transform::MappedValue value : values) {
if (auto attr = value.dyn_cast<Attribute>()) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
parameters.push_back(attr);
continue;
}

View File

@ -18,7 +18,7 @@ namespace mlir {
bool isZeroIndex(OpFoldResult v) {
if (!v)
return false;
if (auto attr = v.dyn_cast<Attribute>()) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(v)) {
IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
return intAttr && intAttr.getValue().isZero();
}
@ -51,7 +51,7 @@ getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {
void dispatchIndexOpFoldResult(OpFoldResult ofr,
SmallVectorImpl<Value> &dynamicVec,
SmallVectorImpl<int64_t> &staticVec) {
auto v = ofr.dyn_cast<Value>();
auto v = llvm::dyn_cast_if_present<Value>(ofr);
if (!v) {
APInt apInt = cast<IntegerAttr>(ofr.get<Attribute>()).getValue();
staticVec.push_back(apInt.getSExtValue());
@ -116,14 +116,14 @@ SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
/// If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
// Case 1: Check for Constant integer.
if (auto val = ofr.dyn_cast<Value>()) {
if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
APSInt intVal;
if (matchPattern(val, m_ConstantInt(&intVal)))
return intVal.getSExtValue();
return std::nullopt;
}
// Case 2: Check for IntegerAttr.
Attribute attr = ofr.dyn_cast<Attribute>();
Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
return intAttr.getValue().getSExtValue();
return std::nullopt;
@ -143,7 +143,8 @@ bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
if (cst1 && cst2 && *cst1 == *cst2)
return true;
auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
v2 = llvm::dyn_cast_if_present<Value>(ofr2);
return v1 && v1 == v2;
}

View File

@ -1154,7 +1154,7 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
OpaqueProperties properties, RegionRange,
SmallVectorImpl<Type> &inferredReturnTypes) {
ExtractOp::Adaptor op(operands, attributes, properties);
auto vectorType = op.getVector().getType().cast<VectorType>();
auto vectorType = llvm::cast<VectorType>(op.getVector().getType());
if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
inferredReturnTypes.push_back(vectorType.getElementType());
} else {
@ -2003,9 +2003,9 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
if (!adaptor.getSource())
return {};
auto vectorType = getResultVectorType();
if (adaptor.getSource().isa<IntegerAttr, FloatAttr>())
if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
return DenseElementsAttr::get(vectorType, adaptor.getSource());
if (auto attr = adaptor.getSource().dyn_cast<SplatElementsAttr>())
if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
return {};
}
@ -2090,7 +2090,7 @@ ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
OpaqueProperties properties, RegionRange,
SmallVectorImpl<Type> &inferredReturnTypes) {
ShuffleOp::Adaptor op(operands, attributes, properties);
auto v1Type = op.getV1().getType().cast<VectorType>();
auto v1Type = llvm::cast<VectorType>(op.getV1().getType());
auto v1Rank = v1Type.getRank();
// Construct resulting type: leading dimension matches mask
// length, all trailing dimensions match the operands.
@ -4951,7 +4951,7 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// Eliminate splat constant transpose ops.
if (auto attr = adaptor.getVector().dyn_cast_or_null<DenseElementsAttr>())
if (auto attr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
if (attr.isSplat())
return attr.reshape(getResultVectorType());

View File

@ -3642,7 +3642,7 @@ void Value::print(raw_ostream &os, const OpPrintingFlags &flags) {
if (auto *op = getDefiningOp())
return op->print(os, flags);
// TODO: Improve BlockArgument print'ing.
BlockArgument arg = this->cast<BlockArgument>();
BlockArgument arg = llvm::cast<BlockArgument>(*this);
os << "<block argument> of type '" << arg.getType()
<< "' at index: " << arg.getArgNumber();
}
@ -3656,7 +3656,7 @@ void Value::print(raw_ostream &os, AsmState &state) {
return op->print(os, state);
// TODO: Improve BlockArgument print'ing.
BlockArgument arg = this->cast<BlockArgument>();
BlockArgument arg = llvm::cast<BlockArgument>(*this);
os << "<block argument> of type '" << arg.getType()
<< "' at index: " << arg.getArgNumber();
}
@ -3693,10 +3693,10 @@ static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
void Value::printAsOperand(raw_ostream &os, const OpPrintingFlags &flags) {
Operation *op;
if (auto result = dyn_cast<OpResult>()) {
if (auto result = llvm::dyn_cast<OpResult>(*this)) {
op = result.getOwner();
} else {
op = cast<BlockArgument>().getOwner()->getParentOp();
op = llvm::cast<BlockArgument>(*this).getOwner()->getParentOp();
if (!op) {
os << "<<UNKNOWN SSA VALUE>>";
return;

View File

@ -347,14 +347,14 @@ BlockRange::BlockRange(SuccessorRange successors)
/// See `llvm::detail::indexed_accessor_range_base` for details.
BlockRange::OwnerT BlockRange::offset_base(OwnerT object, ptrdiff_t index) {
if (auto *operand = object.dyn_cast<BlockOperand *>())
if (auto *operand = llvm::dyn_cast_if_present<BlockOperand *>(object))
return {operand + index};
return {object.dyn_cast<Block *const *>() + index};
return {llvm::dyn_cast_if_present<Block *const *>(object) + index};
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
Block *BlockRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
if (const auto *operand = object.dyn_cast<BlockOperand *>())
if (const auto *operand = llvm::dyn_cast_if_present<BlockOperand *>(object))
return operand[index].get();
return object.dyn_cast<Block *const *>()[index];
return llvm::dyn_cast_if_present<Block *const *>(object)[index];
}

View File

@ -483,7 +483,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
Type expectedType = std::get<1>(it);
// Normal values get pushed back directly.
if (auto value = std::get<0>(it).dyn_cast<Value>()) {
if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
if (value.getType() != expectedType)
return cleanupFailure();

View File

@ -1247,12 +1247,12 @@ DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
DenseElementsAttr
DenseElementsAttr::mapValues(Type newElementType,
function_ref<APInt(const APInt &)> mapping) const {
return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
return llvm::cast<DenseIntElementsAttr>(*this).mapValues(newElementType, mapping);
}
DenseElementsAttr DenseElementsAttr::mapValues(
Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
return llvm::cast<DenseFPElementsAttr>(*this).mapValues(newElementType, mapping);
}
ShapedType DenseElementsAttr::getType() const {

View File

@ -88,45 +88,45 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
//===----------------------------------------------------------------------===//
unsigned FloatType::getWidth() {
if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType>(*this))
return 8;
if (isa<Float16Type, BFloat16Type>())
if (llvm::isa<Float16Type, BFloat16Type>(*this))
return 16;
if (isa<Float32Type>())
if (llvm::isa<Float32Type>(*this))
return 32;
if (isa<Float64Type>())
if (llvm::isa<Float64Type>(*this))
return 64;
if (isa<Float80Type>())
if (llvm::isa<Float80Type>(*this))
return 80;
if (isa<Float128Type>())
if (llvm::isa<Float128Type>(*this))
return 128;
llvm_unreachable("unexpected float type");
}
/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
if (isa<Float8E5M2Type>())
if (llvm::isa<Float8E5M2Type>(*this))
return APFloat::Float8E5M2();
if (isa<Float8E4M3FNType>())
if (llvm::isa<Float8E4M3FNType>(*this))
return APFloat::Float8E4M3FN();
if (isa<Float8E5M2FNUZType>())
if (llvm::isa<Float8E5M2FNUZType>(*this))
return APFloat::Float8E5M2FNUZ();
if (isa<Float8E4M3FNUZType>())
if (llvm::isa<Float8E4M3FNUZType>(*this))
return APFloat::Float8E4M3FNUZ();
if (isa<Float8E4M3B11FNUZType>())
if (llvm::isa<Float8E4M3B11FNUZType>(*this))
return APFloat::Float8E4M3B11FNUZ();
if (isa<BFloat16Type>())
if (llvm::isa<BFloat16Type>(*this))
return APFloat::BFloat();
if (isa<Float16Type>())
if (llvm::isa<Float16Type>(*this))
return APFloat::IEEEhalf();
if (isa<Float32Type>())
if (llvm::isa<Float32Type>(*this))
return APFloat::IEEEsingle();
if (isa<Float64Type>())
if (llvm::isa<Float64Type>(*this))
return APFloat::IEEEdouble();
if (isa<Float80Type>())
if (llvm::isa<Float80Type>(*this))
return APFloat::x87DoubleExtended();
if (isa<Float128Type>())
if (llvm::isa<Float128Type>(*this))
return APFloat::IEEEquad();
llvm_unreachable("non-floating point type used");
}
@ -269,21 +269,21 @@ Type TensorType::getElementType() const {
[](auto type) { return type.getElementType(); });
}
bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }
bool TensorType::hasRank() const { return !llvm::isa<UnrankedTensorType>(*this); }
ArrayRef<int64_t> TensorType::getShape() const {
return cast<RankedTensorType>().getShape();
return llvm::cast<RankedTensorType>(*this).getShape();
}
TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
if (auto unrankedTy = llvm::dyn_cast<UnrankedTensorType>(*this)) {
if (shape)
return RankedTensorType::get(*shape, elementType);
return UnrankedTensorType::get(elementType);
}
auto rankedTy = cast<RankedTensorType>();
auto rankedTy = llvm::cast<RankedTensorType>(*this);
if (!shape)
return RankedTensorType::get(rankedTy.getShape(), elementType,
rankedTy.getEncoding());
@ -356,15 +356,15 @@ Type BaseMemRefType::getElementType() const {
[](auto type) { return type.getElementType(); });
}
bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }
bool BaseMemRefType::hasRank() const { return !llvm::isa<UnrankedMemRefType>(*this); }
ArrayRef<int64_t> BaseMemRefType::getShape() const {
return cast<MemRefType>().getShape();
return llvm::cast<MemRefType>(*this).getShape();
}
BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
if (auto unrankedTy = llvm::dyn_cast<UnrankedMemRefType>(*this)) {
if (!shape)
return UnrankedMemRefType::get(elementType, getMemorySpace());
MemRefType::Builder builder(*shape, elementType);
@ -372,7 +372,7 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
return builder;
}
MemRefType::Builder builder(cast<MemRefType>());
MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
if (shape)
builder.setShape(*shape);
builder.setElementType(elementType);
@ -389,15 +389,15 @@ MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
}
Attribute BaseMemRefType::getMemorySpace() const {
if (auto rankedMemRefTy = dyn_cast<MemRefType>())
if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
return rankedMemRefTy.getMemorySpace();
return cast<UnrankedMemRefType>().getMemorySpace();
return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
}
unsigned BaseMemRefType::getMemorySpaceAsInt() const {
if (auto rankedMemRefTy = dyn_cast<MemRefType>())
if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
return rankedMemRefTy.getMemorySpaceAsInt();
return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
}
//===----------------------------------------------------------------------===//

View File

@ -626,17 +626,17 @@ ValueRange::ValueRange(ResultRange values)
/// See `llvm::detail::indexed_accessor_range_base` for details.
ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
ptrdiff_t index) {
if (const auto *value = owner.dyn_cast<const Value *>())
if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
return {value + index};
if (auto *operand = owner.dyn_cast<OpOperand *>())
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
return {operand + index};
return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
if (const auto *value = owner.dyn_cast<const Value *>())
if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
return value[index];
if (auto *operand = owner.dyn_cast<OpOperand *>())
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
return operand[index].get();
return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
}

View File

@ -267,18 +267,18 @@ RegionRange::RegionRange(ArrayRef<Region *> regions)
/// See `llvm::detail::indexed_accessor_range_base` for details.
RegionRange::OwnerT RegionRange::offset_base(const OwnerT &owner,
ptrdiff_t index) {
if (auto *region = owner.dyn_cast<const std::unique_ptr<Region> *>())
if (auto *region = llvm::dyn_cast_if_present<const std::unique_ptr<Region> *>(owner))
return region + index;
if (auto **region = owner.dyn_cast<Region **>())
if (auto **region = llvm::dyn_cast_if_present<Region **>(owner))
return region + index;
return &owner.get<Region *>()[index];
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
Region *RegionRange::dereference_iterator(const OwnerT &owner,
ptrdiff_t index) {
if (auto *region = owner.dyn_cast<const std::unique_ptr<Region> *>())
if (auto *region = llvm::dyn_cast_if_present<const std::unique_ptr<Region> *>(owner))
return region[index].get();
if (auto **region = owner.dyn_cast<Region **>())
if (auto **region = llvm::dyn_cast_if_present<Region **>(owner))
return region[index];
return &owner.get<Region *>()[index];
}

View File

@ -551,7 +551,7 @@ struct SymbolScope {
typename llvm::function_traits<CallbackT>::result_t,
void>::value> * = nullptr>
std::optional<WalkResult> walk(CallbackT cback) {
if (Region *region = limit.dyn_cast<Region *>())
if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
return walkSymbolUses(*region, cback);
return walkSymbolUses(limit.get<Operation *>(), cback);
}
@ -571,7 +571,7 @@ struct SymbolScope {
/// traversing into any nested symbol tables.
template <typename CallbackT>
std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
if (Region *region = limit.dyn_cast<Region *>())
if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
return ::walkSymbolTable(*region, cback);
return ::walkSymbolTable(limit.get<Operation *>(), cback);
}

View File

@ -27,9 +27,9 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
if (count == 0)
return;
ValueRange::OwnerT owner = values.begin().getBase();
if (auto *result = owner.dyn_cast<detail::OpResultImpl *>())
if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(owner))
this->base = result;
else if (auto *operand = owner.dyn_cast<OpOperand *>())
else if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
this->base = operand;
else
this->base = owner.get<const Value *>();
@ -37,22 +37,22 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
/// See `llvm::detail::indexed_accessor_range_base` for details.
TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
if (const auto *value = object.dyn_cast<const Value *>())
if (const auto *value = llvm::dyn_cast_if_present<const Value *>(object))
return {value + index};
if (auto *operand = object.dyn_cast<OpOperand *>())
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(object))
return {operand + index};
if (auto *result = object.dyn_cast<detail::OpResultImpl *>())
if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
return {result->getNextResultAtOffset(index)};
return {object.dyn_cast<const Type *>() + index};
return {llvm::dyn_cast_if_present<const Type *>(object) + index};
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
if (const auto *value = object.dyn_cast<const Value *>())
if (const auto *value = llvm::dyn_cast_if_present<const Value *>(object))
return (value + index)->getType();
if (auto *operand = object.dyn_cast<OpOperand *>())
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(object))
return (operand + index)->get().getType();
if (auto *result = object.dyn_cast<detail::OpResultImpl *>())
if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
return result->getNextResultAtOffset(index)->getType();
return object.dyn_cast<const Type *>()[index];
return llvm::dyn_cast_if_present<const Type *>(object)[index];
}

View File

@ -34,84 +34,94 @@ Type AbstractType::replaceImmediateSubElements(Type type,
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
bool Type::isFloat8E5M2FNUZ() const { return isa<Float8E5M2FNUZType>(); }
bool Type::isFloat8E4M3FNUZ() const { return isa<Float8E4M3FNUZType>(); }
bool Type::isFloat8E4M3B11FNUZ() const { return isa<Float8E4M3B11FNUZType>(); }
bool Type::isBF16() const { return isa<BFloat16Type>(); }
bool Type::isF16() const { return isa<Float16Type>(); }
bool Type::isF32() const { return isa<Float32Type>(); }
bool Type::isF64() const { return isa<Float64Type>(); }
bool Type::isF80() const { return isa<Float80Type>(); }
bool Type::isF128() const { return isa<Float128Type>(); }
bool Type::isFloat8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); }
bool Type::isFloat8E4M3FN() const { return llvm::isa<Float8E4M3FNType>(*this); }
bool Type::isFloat8E5M2FNUZ() const {
return llvm::isa<Float8E5M2FNUZType>(*this);
}
bool Type::isFloat8E4M3FNUZ() const {
return llvm::isa<Float8E4M3FNUZType>(*this);
}
bool Type::isFloat8E4M3B11FNUZ() const {
return llvm::isa<Float8E4M3B11FNUZType>(*this);
}
bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
bool Type::isF32() const { return llvm::isa<Float32Type>(*this); }
bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }
bool Type::isIndex() const { return isa<IndexType>(); }
bool Type::isIndex() const { return llvm::isa<IndexType>(*this); }
/// Return true if this is an integer type with the specified width.
bool Type::isInteger(unsigned width) const {
if (auto intTy = dyn_cast<IntegerType>())
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.getWidth() == width;
return false;
}
bool Type::isSignlessInteger() const {
if (auto intTy = dyn_cast<IntegerType>())
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isSignless();
return false;
}
bool Type::isSignlessInteger(unsigned width) const {
if (auto intTy = dyn_cast<IntegerType>())
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isSignless() && intTy.getWidth() == width;
return false;
}
bool Type::isSignedInteger() const {
if (auto intTy = dyn_cast<IntegerType>())
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isSigned();
return false;
}
bool Type::isSignedInteger(unsigned width) const {
if (auto intTy = dyn_cast<IntegerType>())
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isSigned() && intTy.getWidth() == width;
return false;
}
bool Type::isUnsignedInteger() const {
if (auto intTy = dyn_cast<IntegerType>())
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isUnsigned();
return false;
}
bool Type::isUnsignedInteger(unsigned width) const {
if (auto intTy = dyn_cast<IntegerType>())
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
return intTy.isUnsigned() && intTy.getWidth() == width;
return false;
}
bool Type::isSignlessIntOrIndex() const {
return isSignlessInteger() || isa<IndexType>();
return isSignlessInteger() || llvm::isa<IndexType>(*this);
}
bool Type::isSignlessIntOrIndexOrFloat() const {
return isSignlessInteger() || isa<IndexType, FloatType>();
return isSignlessInteger() || llvm::isa<IndexType, FloatType>(*this);
}
bool Type::isSignlessIntOrFloat() const {
return isSignlessInteger() || isa<FloatType>();
return isSignlessInteger() || llvm::isa<FloatType>(*this);
}
bool Type::isIntOrIndex() const { return isa<IntegerType>() || isIndex(); }
bool Type::isIntOrIndex() const {
return llvm::isa<IntegerType>(*this) || isIndex();
}
bool Type::isIntOrFloat() const { return isa<IntegerType, FloatType>(); }
bool Type::isIntOrFloat() const {
return llvm::isa<IntegerType, FloatType>(*this);
}
bool Type::isIntOrIndexOrFloat() const { return isIntOrFloat() || isIndex(); }
unsigned Type::getIntOrFloatBitWidth() const {
assert(isIntOrFloat() && "only integers and floats have a bitwidth");
if (auto intType = dyn_cast<IntegerType>())
if (auto intType = llvm::dyn_cast<IntegerType>(*this))
return intType.getWidth();
return cast<FloatType>().getWidth();
return llvm::cast<FloatType>(*this).getWidth();
}

View File

@ -48,11 +48,11 @@ static void printBlock(llvm::raw_ostream &os, Block *block,
}
void mlir::IRUnit::print(llvm::raw_ostream &os, OpPrintingFlags flags) const {
if (auto *op = this->dyn_cast<Operation *>())
if (auto *op = llvm::dyn_cast_if_present<Operation *>(*this))
return printOp(os, op, flags);
if (auto *region = this->dyn_cast<Region *>())
if (auto *region = llvm::dyn_cast_if_present<Region *>(*this))
return printRegion(os, region, flags);
if (auto *block = this->dyn_cast<Block *>())
if (auto *block = llvm::dyn_cast_if_present<Block *>(*this))
return printBlock(os, block, flags);
llvm_unreachable("unknown IRUnit");
}

View File

@ -18,7 +18,7 @@ using namespace mlir::detail;
/// If this value is the result of an Operation, return the operation that
/// defines it.
Operation *Value::getDefiningOp() const {
if (auto result = dyn_cast<OpResult>())
if (auto result = llvm::dyn_cast<OpResult>(*this))
return result.getOwner();
return nullptr;
}
@ -27,28 +27,28 @@ Location Value::getLoc() const {
if (auto *op = getDefiningOp())
return op->getLoc();
return cast<BlockArgument>().getLoc();
return llvm::cast<BlockArgument>(*this).getLoc();
}
void Value::setLoc(Location loc) {
if (auto *op = getDefiningOp())
return op->setLoc(loc);
return cast<BlockArgument>().setLoc(loc);
return llvm::cast<BlockArgument>(*this).setLoc(loc);
}
/// Return the Region in which this Value is defined.
Region *Value::getParentRegion() {
if (auto *op = getDefiningOp())
return op->getParentRegion();
return cast<BlockArgument>().getOwner()->getParent();
return llvm::cast<BlockArgument>(*this).getOwner()->getParent();
}
/// Return the Block in which this Value is defined.
Block *Value::getParentBlock() {
if (Operation *op = getDefiningOp())
return op->getBlock();
return cast<BlockArgument>().getOwner();
return llvm::cast<BlockArgument>(*this).getOwner();
}
//===----------------------------------------------------------------------===//

View File

@ -241,7 +241,7 @@ mlir::detail::filterEntriesForType(DataLayoutEntryListRef entries,
TypeID typeID) {
return llvm::to_vector<4>(llvm::make_filter_range(
entries, [typeID](DataLayoutEntryInterface entry) {
auto type = entry.getKey().dyn_cast<Type>();
auto type = llvm::dyn_cast_if_present<Type>(entry.getKey());
return type && type.getTypeID() == typeID;
}));
}
@ -521,7 +521,7 @@ void DataLayoutSpecInterface::bucketEntriesByType(
DenseMap<TypeID, DataLayoutEntryList> &types,
DenseMap<StringAttr, DataLayoutEntryInterface> &ids) {
for (DataLayoutEntryInterface entry : getEntries()) {
if (auto type = entry.getKey().dyn_cast<Type>())
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()))
types[type.getTypeID()].push_back(entry);
else
ids[entry.getKey().get<StringAttr>()] = entry;

View File

@ -68,7 +68,7 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
bool ShapeAdaptor::hasRank() const {
if (val.isNull())
return false;
if (auto t = val.dyn_cast<Type>())
if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).hasRank();
if (val.is<Attribute>())
return true;
@ -78,7 +78,7 @@ bool ShapeAdaptor::hasRank() const {
Type ShapeAdaptor::getElementType() const {
if (val.isNull())
return nullptr;
if (auto t = val.dyn_cast<Type>())
if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).getElementType();
if (val.is<Attribute>())
return nullptr;
@ -87,10 +87,10 @@ Type ShapeAdaptor::getElementType() const {
void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
assert(hasRank());
if (auto t = val.dyn_cast<Type>()) {
if (auto t = llvm::dyn_cast_if_present<Type>(val)) {
ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
res.assign(vals.begin(), vals.end());
} else if (auto attr = val.dyn_cast<Attribute>()) {
} else if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
auto dattr = cast<DenseIntElementsAttr>(attr);
res.clear();
res.reserve(dattr.size());
@ -110,9 +110,9 @@ void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
int64_t ShapeAdaptor::getDimSize(int index) const {
assert(hasRank());
if (auto t = val.dyn_cast<Type>())
if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).getDimSize(index);
if (auto attr = val.dyn_cast<Attribute>())
if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
return cast<DenseIntElementsAttr>(attr)
.getValues<APInt>()[index]
.getSExtValue();
@ -122,9 +122,9 @@ int64_t ShapeAdaptor::getDimSize(int index) const {
int64_t ShapeAdaptor::getRank() const {
assert(hasRank());
if (auto t = val.dyn_cast<Type>())
if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).getRank();
if (auto attr = val.dyn_cast<Attribute>())
if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
return cast<DenseIntElementsAttr>(attr).size();
return val.get<ShapedTypeComponents *>()->getDims().size();
}
@ -133,9 +133,9 @@ bool ShapeAdaptor::hasStaticShape() const {
if (!hasRank())
return false;
if (auto t = val.dyn_cast<Type>())
if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).hasStaticShape();
if (auto attr = val.dyn_cast<Attribute>()) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
auto dattr = cast<DenseIntElementsAttr>(attr);
for (auto index : dattr.getValues<APInt>())
if (ShapedType::isDynamic(index.getSExtValue()))
@ -149,10 +149,10 @@ bool ShapeAdaptor::hasStaticShape() const {
int64_t ShapeAdaptor::getNumElements() const {
assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
if (auto t = val.dyn_cast<Type>())
if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).getNumElements();
if (auto attr = val.dyn_cast<Attribute>()) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
auto dattr = cast<DenseIntElementsAttr>(attr);
int64_t num = 1;
for (auto index : dattr.getValues<APInt>()) {

View File

@ -26,14 +26,14 @@ namespace mlir {
/// If ofr is a constant integer or an IntegerAttr, return the integer.
static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
// Case 1: Check for Constant integer.
if (auto val = ofr.dyn_cast<Value>()) {
if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
APSInt intVal;
if (matchPattern(val, m_ConstantInt(&intVal)))
return intVal.getSExtValue();
return std::nullopt;
}
// Case 2: Check for IntegerAttr.
Attribute attr = ofr.dyn_cast<Attribute>();
Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
return intAttr.getValue().getSExtValue();
return std::nullopt;
@ -99,7 +99,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
}
AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
if (Value value = ofr.dyn_cast<Value>())
if (Value value = llvm::dyn_cast_if_present<Value>(ofr))
return getExpr(value, /*dim=*/std::nullopt);
auto constInt = getConstantIntValue(ofr);
assert(constInt.has_value() && "expected Integer constant");

View File

@ -26,7 +26,8 @@ struct PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
const Pass &getPass() const { return pass; }
Operation *getOp() const {
ArrayRef<IRUnit> irUnits = getContextIRUnits();
return irUnits.empty() ? nullptr : irUnits[0].dyn_cast<Operation *>();
return irUnits.empty() ? nullptr
: llvm::dyn_cast_if_present<Operation *>(irUnits[0]);
}
public:

View File

@ -384,7 +384,7 @@ void Operator::populateTypeInferenceInfo(
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
// Check for a non-variable length operand to use as the type anchor.
auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
NamedTypeConstraint *operand = arg.dyn_cast<NamedTypeConstraint *>();
NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
return operand && !operand->isVariableLength();
});
if (operandI == arguments.end())
@ -824,7 +824,7 @@ StringRef Operator::getAssemblyFormat() const {
void Operator::print(llvm::raw_ostream &os) const {
os << "op '" << getOperationName() << "'\n";
for (Argument arg : arguments) {
if (auto *attr = arg.dyn_cast<NamedAttribute *>())
if (auto *attr = llvm::dyn_cast_if_present<NamedAttribute *>(arg))
os << "[attribute] " << attr->name << '\n';
else
os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';

View File

@ -131,7 +131,7 @@ convertBranchWeights(std::optional<ElementsAttr> weights,
return nullptr;
SmallVector<uint32_t> weightValues;
weightValues.reserve(weights->size());
for (APInt weight : weights->cast<DenseIntElementsAttr>())
for (APInt weight : llvm::cast<DenseIntElementsAttr>(*weights))
weightValues.push_back(weight.getLimitedValue());
return llvm::MDBuilder(moduleTranslation.getLLVMContext())
.createBranchWeights(weightValues);
@ -330,7 +330,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
auto *ty = llvm::cast<llvm::IntegerType>(
moduleTranslation.convertType(switchOp.getValue().getType()));
for (auto i :
llvm::zip(switchOp.getCaseValues()->cast<DenseIntElementsAttr>(),
llvm::zip(llvm::cast<DenseIntElementsAttr>(*switchOp.getCaseValues()),
switchOp.getCaseDestinations()))
switchInst->addCase(
llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),

Some files were not shown because too many files have changed in this diff Show More