mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-24 04:26:07 +00:00
[mlir] Move casting calls from methods to function calls
The MLIR classes Type/Attribute/Operation/Op/Value support cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast functionality in addition to defining methods with the same name. This change begins the migration of uses of the method to the corresponding function call as has been decided as more consistent. Note that there still exist classes that only define methods directly, such as AffineExpr, and this does not include work currently to support a functional cast/isa call. Context: - https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…" - Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443 Implementation: This patch updates all remaining uses of the deprecated functionality in mlir/. This was done with clang-tidy as described below and further modifications to GPUBase.td and OpenMPOpsInterfaces.td. Steps are described per line, as comments are removed by git: 0. Retrieve the change from the following to build clang-tidy with an additional check: main...tpopp:llvm-project:tidy-cast-check 1. Build clang-tidy 2. Run clang-tidy over your entire codebase while disabling all checks and enabling the one relevant one. Run on all header files also. 3. Delete .inc files that were also modified, so the next build rebuilds them to a pure state. ``` ninja -C $BUILD_DIR clang-tidy run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\ -header-filter=mlir/ mlir/* -fix rm -rf $BUILD_DIR/tools/mlir/**/*.inc ``` Differential Revision: https://reviews.llvm.org/D151542
This commit is contained in:
parent
7c52520c8d
commit
68f58812e3
@ -54,7 +54,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
|
||||
|
||||
// If the type is a function type, it contains the input and result types of
|
||||
// this operation.
|
||||
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
|
||||
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
|
||||
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
||||
result.operands))
|
||||
return mlir::failure();
|
||||
@ -133,13 +133,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
|
||||
mlir::LogicalResult ConstantOp::verify() {
|
||||
// If the return type of the constant is not an unranked tensor, the shape
|
||||
// must match the shape of the attribute holding the data.
|
||||
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
||||
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
|
||||
if (!resultType)
|
||||
return success();
|
||||
|
||||
// Check that the rank of the attribute type matches the rank of the constant
|
||||
// result type.
|
||||
auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
|
||||
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
|
||||
if (attrType.getRank() != resultType.getRank()) {
|
||||
return emitOpError("return type must match the one of the attached value "
|
||||
"attribute: ")
|
||||
@ -269,8 +269,8 @@ mlir::LogicalResult ReturnOp::verify() {
|
||||
auto resultType = results.front();
|
||||
|
||||
// Check that the result type of the function matches the operand type.
|
||||
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
|
||||
resultType.isa<mlir::UnrankedTensorType>())
|
||||
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
|
||||
llvm::isa<mlir::UnrankedTensorType>(resultType))
|
||||
return mlir::success();
|
||||
|
||||
return emitError() << "type of return operand (" << inputType
|
||||
@ -289,8 +289,8 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
||||
}
|
||||
|
||||
mlir::LogicalResult TransposeOp::verify() {
|
||||
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
||||
auto resultType = getType().dyn_cast<RankedTensorType>();
|
||||
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
|
||||
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
|
||||
if (!inputType || !resultType)
|
||||
return mlir::success();
|
||||
|
||||
|
@ -54,7 +54,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
|
||||
|
||||
// If the type is a function type, it contains the input and result types of
|
||||
// this operation.
|
||||
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
|
||||
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
|
||||
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
||||
result.operands))
|
||||
return mlir::failure();
|
||||
@ -133,13 +133,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
|
||||
mlir::LogicalResult ConstantOp::verify() {
|
||||
// If the return type of the constant is not an unranked tensor, the shape
|
||||
// must match the shape of the attribute holding the data.
|
||||
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
||||
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
|
||||
if (!resultType)
|
||||
return success();
|
||||
|
||||
// Check that the rank of the attribute type matches the rank of the constant
|
||||
// result type.
|
||||
auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
|
||||
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
|
||||
if (attrType.getRank() != resultType.getRank()) {
|
||||
return emitOpError("return type must match the one of the attached value "
|
||||
"attribute: ")
|
||||
@ -269,8 +269,8 @@ mlir::LogicalResult ReturnOp::verify() {
|
||||
auto resultType = results.front();
|
||||
|
||||
// Check that the result type of the function matches the operand type.
|
||||
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
|
||||
resultType.isa<mlir::UnrankedTensorType>())
|
||||
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
|
||||
llvm::isa<mlir::UnrankedTensorType>(resultType))
|
||||
return mlir::success();
|
||||
|
||||
return emitError() << "type of return operand (" << inputType
|
||||
@ -289,8 +289,8 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
||||
}
|
||||
|
||||
mlir::LogicalResult TransposeOp::verify() {
|
||||
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
||||
auto resultType = getType().dyn_cast<RankedTensorType>();
|
||||
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
|
||||
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
|
||||
if (!inputType || !resultType)
|
||||
return mlir::success();
|
||||
|
||||
|
@ -114,7 +114,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
|
||||
|
||||
// If the type is a function type, it contains the input and result types of
|
||||
// this operation.
|
||||
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
|
||||
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
|
||||
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
||||
result.operands))
|
||||
return mlir::failure();
|
||||
@ -193,13 +193,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
|
||||
mlir::LogicalResult ConstantOp::verify() {
|
||||
// If the return type of the constant is not an unranked tensor, the shape
|
||||
// must match the shape of the attribute holding the data.
|
||||
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
||||
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
|
||||
if (!resultType)
|
||||
return success();
|
||||
|
||||
// Check that the rank of the attribute type matches the rank of the constant
|
||||
// result type.
|
||||
auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
|
||||
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
|
||||
if (attrType.getRank() != resultType.getRank()) {
|
||||
return emitOpError("return type must match the one of the attached value "
|
||||
"attribute: ")
|
||||
@ -254,8 +254,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
// The inputs must be Tensors with the same element type.
|
||||
TensorType input = inputs.front().dyn_cast<TensorType>();
|
||||
TensorType output = outputs.front().dyn_cast<TensorType>();
|
||||
TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
|
||||
TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
|
||||
if (!input || !output || input.getElementType() != output.getElementType())
|
||||
return false;
|
||||
// The shape is required to match if both types are ranked.
|
||||
@ -397,8 +397,8 @@ mlir::LogicalResult ReturnOp::verify() {
|
||||
auto resultType = results.front();
|
||||
|
||||
// Check that the result type of the function matches the operand type.
|
||||
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
|
||||
resultType.isa<mlir::UnrankedTensorType>())
|
||||
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
|
||||
llvm::isa<mlir::UnrankedTensorType>(resultType))
|
||||
return mlir::success();
|
||||
|
||||
return emitError() << "type of return operand (" << inputType
|
||||
@ -417,14 +417,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
||||
}
|
||||
|
||||
void TransposeOp::inferShapes() {
|
||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
||||
auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
|
||||
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
}
|
||||
|
||||
mlir::LogicalResult TransposeOp::verify() {
|
||||
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
||||
auto resultType = getType().dyn_cast<RankedTensorType>();
|
||||
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
|
||||
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
|
||||
if (!inputType || !resultType)
|
||||
return mlir::success();
|
||||
|
||||
|
@ -94,7 +94,7 @@ struct ShapeInferencePass
|
||||
/// operands inferred.
|
||||
static bool allOperandsInferred(Operation *op) {
|
||||
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
|
||||
return operandType.isa<RankedTensorType>();
|
||||
return llvm::isa<RankedTensorType>(operandType);
|
||||
});
|
||||
}
|
||||
|
||||
@ -102,7 +102,7 @@ struct ShapeInferencePass
|
||||
/// shaped result.
|
||||
static bool returnsDynamicShape(Operation *op) {
|
||||
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
|
||||
return !resultType.isa<RankedTensorType>();
|
||||
return !llvm::isa<RankedTensorType>(resultType);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
@ -114,7 +114,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
|
||||
|
||||
// If the type is a function type, it contains the input and result types of
|
||||
// this operation.
|
||||
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
|
||||
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
|
||||
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
||||
result.operands))
|
||||
return mlir::failure();
|
||||
@ -193,13 +193,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
|
||||
mlir::LogicalResult ConstantOp::verify() {
|
||||
// If the return type of the constant is not an unranked tensor, the shape
|
||||
// must match the shape of the attribute holding the data.
|
||||
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
||||
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
|
||||
if (!resultType)
|
||||
return success();
|
||||
|
||||
// Check that the rank of the attribute type matches the rank of the constant
|
||||
// result type.
|
||||
auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
|
||||
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
|
||||
if (attrType.getRank() != resultType.getRank()) {
|
||||
return emitOpError("return type must match the one of the attached value "
|
||||
"attribute: ")
|
||||
@ -254,8 +254,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
// The inputs must be Tensors with the same element type.
|
||||
TensorType input = inputs.front().dyn_cast<TensorType>();
|
||||
TensorType output = outputs.front().dyn_cast<TensorType>();
|
||||
TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
|
||||
TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
|
||||
if (!input || !output || input.getElementType() != output.getElementType())
|
||||
return false;
|
||||
// The shape is required to match if both types are ranked.
|
||||
@ -397,8 +397,8 @@ mlir::LogicalResult ReturnOp::verify() {
|
||||
auto resultType = results.front();
|
||||
|
||||
// Check that the result type of the function matches the operand type.
|
||||
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
|
||||
resultType.isa<mlir::UnrankedTensorType>())
|
||||
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
|
||||
llvm::isa<mlir::UnrankedTensorType>(resultType))
|
||||
return mlir::success();
|
||||
|
||||
return emitError() << "type of return operand (" << inputType
|
||||
@ -417,14 +417,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
||||
}
|
||||
|
||||
void TransposeOp::inferShapes() {
|
||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
||||
auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
|
||||
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
}
|
||||
|
||||
mlir::LogicalResult TransposeOp::verify() {
|
||||
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
||||
auto resultType = getType().dyn_cast<RankedTensorType>();
|
||||
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
|
||||
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
|
||||
if (!inputType || !resultType)
|
||||
return mlir::success();
|
||||
|
||||
|
@ -62,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
|
||||
static void lowerOpToLoops(Operation *op, ValueRange operands,
|
||||
PatternRewriter &rewriter,
|
||||
LoopIterationFn processIteration) {
|
||||
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
|
||||
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
@ -144,7 +144,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
||||
|
||||
// When lowering the constant operation, we allocate and assign the constant
|
||||
// values to a corresponding memref allocation.
|
||||
auto tensorType = op.getType().cast<RankedTensorType>();
|
||||
auto tensorType = llvm::cast<RankedTensorType>(op.getType());
|
||||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
||||
|
||||
@ -342,7 +342,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
|
||||
target.addIllegalDialect<toy::ToyDialect>();
|
||||
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
|
||||
return llvm::none_of(op->getOperandTypes(),
|
||||
[](Type type) { return type.isa<TensorType>(); });
|
||||
[](Type type) { return llvm::isa<TensorType>(type); });
|
||||
});
|
||||
|
||||
// Now that the conversion target has been defined, we just need to provide
|
||||
|
@ -94,7 +94,7 @@ struct ShapeInferencePass
|
||||
/// operands inferred.
|
||||
static bool allOperandsInferred(Operation *op) {
|
||||
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
|
||||
return operandType.isa<RankedTensorType>();
|
||||
return llvm::isa<RankedTensorType>(operandType);
|
||||
});
|
||||
}
|
||||
|
||||
@ -102,7 +102,7 @@ struct ShapeInferencePass
|
||||
/// shaped result.
|
||||
static bool returnsDynamicShape(Operation *op) {
|
||||
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
|
||||
return !resultType.isa<RankedTensorType>();
|
||||
return !llvm::isa<RankedTensorType>(resultType);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
@ -114,7 +114,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
|
||||
|
||||
// If the type is a function type, it contains the input and result types of
|
||||
// this operation.
|
||||
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
|
||||
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
|
||||
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
||||
result.operands))
|
||||
return mlir::failure();
|
||||
@ -193,13 +193,13 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
|
||||
mlir::LogicalResult ConstantOp::verify() {
|
||||
// If the return type of the constant is not an unranked tensor, the shape
|
||||
// must match the shape of the attribute holding the data.
|
||||
auto resultType = getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
||||
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
|
||||
if (!resultType)
|
||||
return success();
|
||||
|
||||
// Check that the rank of the attribute type matches the rank of the constant
|
||||
// result type.
|
||||
auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
|
||||
auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
|
||||
if (attrType.getRank() != resultType.getRank()) {
|
||||
return emitOpError("return type must match the one of the attached value "
|
||||
"attribute: ")
|
||||
@ -254,8 +254,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
// The inputs must be Tensors with the same element type.
|
||||
TensorType input = inputs.front().dyn_cast<TensorType>();
|
||||
TensorType output = outputs.front().dyn_cast<TensorType>();
|
||||
TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
|
||||
TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
|
||||
if (!input || !output || input.getElementType() != output.getElementType())
|
||||
return false;
|
||||
// The shape is required to match if both types are ranked.
|
||||
@ -397,8 +397,8 @@ mlir::LogicalResult ReturnOp::verify() {
|
||||
auto resultType = results.front();
|
||||
|
||||
// Check that the result type of the function matches the operand type.
|
||||
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
|
||||
resultType.isa<mlir::UnrankedTensorType>())
|
||||
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
|
||||
llvm::isa<mlir::UnrankedTensorType>(resultType))
|
||||
return mlir::success();
|
||||
|
||||
return emitError() << "type of return operand (" << inputType
|
||||
@ -417,14 +417,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
||||
}
|
||||
|
||||
void TransposeOp::inferShapes() {
|
||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
||||
auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
|
||||
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
}
|
||||
|
||||
mlir::LogicalResult TransposeOp::verify() {
|
||||
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
||||
auto resultType = getType().dyn_cast<RankedTensorType>();
|
||||
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
|
||||
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
|
||||
if (!inputType || !resultType)
|
||||
return mlir::success();
|
||||
|
||||
|
@ -62,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
|
||||
static void lowerOpToLoops(Operation *op, ValueRange operands,
|
||||
PatternRewriter &rewriter,
|
||||
LoopIterationFn processIteration) {
|
||||
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
|
||||
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
@ -144,7 +144,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
||||
|
||||
// When lowering the constant operation, we allocate and assign the constant
|
||||
// values to a corresponding memref allocation.
|
||||
auto tensorType = op.getType().cast<RankedTensorType>();
|
||||
auto tensorType = llvm::cast<RankedTensorType>(op.getType());
|
||||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
||||
|
||||
@ -342,7 +342,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
|
||||
target.addIllegalDialect<toy::ToyDialect>();
|
||||
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
|
||||
return llvm::none_of(op->getOperandTypes(),
|
||||
[](Type type) { return type.isa<TensorType>(); });
|
||||
[](Type type) { return llvm::isa<TensorType>(type); });
|
||||
});
|
||||
|
||||
// Now that the conversion target has been defined, we just need to provide
|
||||
|
@ -61,7 +61,7 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
|
||||
auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
|
||||
auto memRefShape = memRefType.getShape();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
|
@ -94,7 +94,7 @@ struct ShapeInferencePass
|
||||
/// operands inferred.
|
||||
static bool allOperandsInferred(Operation *op) {
|
||||
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
|
||||
return operandType.isa<RankedTensorType>();
|
||||
return llvm::isa<RankedTensorType>(operandType);
|
||||
});
|
||||
}
|
||||
|
||||
@ -102,7 +102,7 @@ struct ShapeInferencePass
|
||||
/// shaped result.
|
||||
static bool returnsDynamicShape(Operation *op) {
|
||||
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
|
||||
return !resultType.isa<RankedTensorType>();
|
||||
return !llvm::isa<RankedTensorType>(resultType);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
@ -101,7 +101,7 @@ static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
|
||||
|
||||
// If the type is a function type, it contains the input and result types of
|
||||
// this operation.
|
||||
if (FunctionType funcType = type.dyn_cast<FunctionType>()) {
|
||||
if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
|
||||
if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
|
||||
result.operands))
|
||||
return mlir::failure();
|
||||
@ -179,9 +179,9 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) {
|
||||
static mlir::LogicalResult verifyConstantForType(mlir::Type type,
|
||||
mlir::Attribute opaqueValue,
|
||||
mlir::Operation *op) {
|
||||
if (type.isa<mlir::TensorType>()) {
|
||||
if (llvm::isa<mlir::TensorType>(type)) {
|
||||
// Check that the value is an elements attribute.
|
||||
auto attrValue = opaqueValue.dyn_cast<mlir::DenseFPElementsAttr>();
|
||||
auto attrValue = llvm::dyn_cast<mlir::DenseFPElementsAttr>(opaqueValue);
|
||||
if (!attrValue)
|
||||
return op->emitError("constant of TensorType must be initialized by "
|
||||
"a DenseFPElementsAttr, got ")
|
||||
@ -189,13 +189,13 @@ static mlir::LogicalResult verifyConstantForType(mlir::Type type,
|
||||
|
||||
// If the return type of the constant is not an unranked tensor, the shape
|
||||
// must match the shape of the attribute holding the data.
|
||||
auto resultType = type.dyn_cast<mlir::RankedTensorType>();
|
||||
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(type);
|
||||
if (!resultType)
|
||||
return success();
|
||||
|
||||
// Check that the rank of the attribute type matches the rank of the
|
||||
// constant result type.
|
||||
auto attrType = attrValue.getType().cast<mlir::RankedTensorType>();
|
||||
auto attrType = llvm::cast<mlir::RankedTensorType>(attrValue.getType());
|
||||
if (attrType.getRank() != resultType.getRank()) {
|
||||
return op->emitOpError("return type must match the one of the attached "
|
||||
"value attribute: ")
|
||||
@ -213,11 +213,11 @@ static mlir::LogicalResult verifyConstantForType(mlir::Type type,
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
auto resultType = type.cast<StructType>();
|
||||
auto resultType = llvm::cast<StructType>(type);
|
||||
llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes();
|
||||
|
||||
// Verify that the initializer is an Array.
|
||||
auto attrValue = opaqueValue.dyn_cast<ArrayAttr>();
|
||||
auto attrValue = llvm::dyn_cast<ArrayAttr>(opaqueValue);
|
||||
if (!attrValue || attrValue.getValue().size() != resultElementTypes.size())
|
||||
return op->emitError("constant of StructType must be initialized by an "
|
||||
"ArrayAttr with the same number of elements, got ")
|
||||
@ -283,8 +283,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1)
|
||||
return false;
|
||||
// The inputs must be Tensors with the same element type.
|
||||
TensorType input = inputs.front().dyn_cast<TensorType>();
|
||||
TensorType output = outputs.front().dyn_cast<TensorType>();
|
||||
TensorType input = llvm::dyn_cast<TensorType>(inputs.front());
|
||||
TensorType output = llvm::dyn_cast<TensorType>(outputs.front());
|
||||
if (!input || !output || input.getElementType() != output.getElementType())
|
||||
return false;
|
||||
// The shape is required to match if both types are ranked.
|
||||
@ -426,8 +426,8 @@ mlir::LogicalResult ReturnOp::verify() {
|
||||
auto resultType = results.front();
|
||||
|
||||
// Check that the result type of the function matches the operand type.
|
||||
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
|
||||
resultType.isa<mlir::UnrankedTensorType>())
|
||||
if (inputType == resultType || llvm::isa<mlir::UnrankedTensorType>(inputType) ||
|
||||
llvm::isa<mlir::UnrankedTensorType>(resultType))
|
||||
return mlir::success();
|
||||
|
||||
return emitError() << "type of return operand (" << inputType
|
||||
@ -442,7 +442,7 @@ mlir::LogicalResult ReturnOp::verify() {
|
||||
void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
|
||||
mlir::Value input, size_t index) {
|
||||
// Extract the result type from the input type.
|
||||
StructType structTy = input.getType().cast<StructType>();
|
||||
StructType structTy = llvm::cast<StructType>(input.getType());
|
||||
assert(index < structTy.getNumElementTypes());
|
||||
mlir::Type resultType = structTy.getElementTypes()[index];
|
||||
|
||||
@ -451,7 +451,7 @@ void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
|
||||
}
|
||||
|
||||
mlir::LogicalResult StructAccessOp::verify() {
|
||||
StructType structTy = getInput().getType().cast<StructType>();
|
||||
StructType structTy = llvm::cast<StructType>(getInput().getType());
|
||||
size_t indexValue = getIndex();
|
||||
if (indexValue >= structTy.getNumElementTypes())
|
||||
return emitOpError()
|
||||
@ -474,14 +474,14 @@ void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
||||
}
|
||||
|
||||
void TransposeOp::inferShapes() {
|
||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
||||
auto arrayTy = llvm::cast<RankedTensorType>(getOperand().getType());
|
||||
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
}
|
||||
|
||||
mlir::LogicalResult TransposeOp::verify() {
|
||||
auto inputType = getOperand().getType().dyn_cast<RankedTensorType>();
|
||||
auto resultType = getType().dyn_cast<RankedTensorType>();
|
||||
auto inputType = llvm::dyn_cast<RankedTensorType>(getOperand().getType());
|
||||
auto resultType = llvm::dyn_cast<RankedTensorType>(getType());
|
||||
if (!inputType || !resultType)
|
||||
return mlir::success();
|
||||
|
||||
@ -598,7 +598,7 @@ mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
|
||||
return nullptr;
|
||||
|
||||
// Check that the type is either a TensorType or another StructType.
|
||||
if (!elementType.isa<mlir::TensorType, StructType>()) {
|
||||
if (!llvm::isa<mlir::TensorType, StructType>(elementType)) {
|
||||
parser.emitError(typeLoc, "element type for a struct must either "
|
||||
"be a TensorType or a StructType, got: ")
|
||||
<< elementType;
|
||||
@ -619,7 +619,7 @@ mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
|
||||
void ToyDialect::printType(mlir::Type type,
|
||||
mlir::DialectAsmPrinter &printer) const {
|
||||
// Currently the only toy type is a struct type.
|
||||
StructType structType = type.cast<StructType>();
|
||||
StructType structType = llvm::cast<StructType>(type);
|
||||
|
||||
// Print the struct type according to the parser format.
|
||||
printer << "struct<";
|
||||
@ -653,9 +653,9 @@ mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
|
||||
mlir::Attribute value,
|
||||
mlir::Type type,
|
||||
mlir::Location loc) {
|
||||
if (type.isa<StructType>())
|
||||
if (llvm::isa<StructType>(type))
|
||||
return builder.create<StructConstantOp>(loc, type,
|
||||
value.cast<mlir::ArrayAttr>());
|
||||
llvm::cast<mlir::ArrayAttr>(value));
|
||||
return builder.create<ConstantOp>(loc, type,
|
||||
value.cast<mlir::DenseElementsAttr>());
|
||||
llvm::cast<mlir::DenseElementsAttr>(value));
|
||||
}
|
||||
|
@ -62,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
|
||||
static void lowerOpToLoops(Operation *op, ValueRange operands,
|
||||
PatternRewriter &rewriter,
|
||||
LoopIterationFn processIteration) {
|
||||
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
|
||||
auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
@ -144,7 +144,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
|
||||
|
||||
// When lowering the constant operation, we allocate and assign the constant
|
||||
// values to a corresponding memref allocation.
|
||||
auto tensorType = op.getType().cast<RankedTensorType>();
|
||||
auto tensorType = llvm::cast<RankedTensorType>(op.getType());
|
||||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
||||
|
||||
@ -342,7 +342,7 @@ void ToyToAffineLoweringPass::runOnOperation() {
|
||||
target.addIllegalDialect<toy::ToyDialect>();
|
||||
target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
|
||||
return llvm::none_of(op->getOperandTypes(),
|
||||
[](Type type) { return type.isa<TensorType>(); });
|
||||
[](Type type) { return llvm::isa<TensorType>(type); });
|
||||
});
|
||||
|
||||
// Now that the conversion target has been defined, we just need to provide
|
||||
|
@ -61,7 +61,7 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto memRefType = (*op->operand_type_begin()).cast<MemRefType>();
|
||||
auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
|
||||
auto memRefShape = memRefType.getShape();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
|
@ -94,7 +94,7 @@ struct ShapeInferencePass
|
||||
/// operands inferred.
|
||||
static bool allOperandsInferred(Operation *op) {
|
||||
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
|
||||
return operandType.isa<RankedTensorType>();
|
||||
return llvm::isa<RankedTensorType>(operandType);
|
||||
});
|
||||
}
|
||||
|
||||
@ -102,7 +102,7 @@ struct ShapeInferencePass
|
||||
/// shaped result.
|
||||
static bool returnsDynamicShape(Operation *op) {
|
||||
return llvm::any_of(op->getResultTypes(), [](Type resultType) {
|
||||
return !resultType.isa<RankedTensorType>();
|
||||
return !llvm::isa<RankedTensorType>(resultType);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
@ -31,7 +31,8 @@ OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
|
||||
|
||||
/// Fold simple struct access operations that access into a constant.
|
||||
OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) {
|
||||
auto structAttr = adaptor.getInput().dyn_cast_or_null<mlir::ArrayAttr>();
|
||||
auto structAttr =
|
||||
llvm::dyn_cast_if_present<mlir::ArrayAttr>(adaptor.getInput());
|
||||
if (!structAttr)
|
||||
return nullptr;
|
||||
|
||||
|
@ -62,19 +62,19 @@ class FileLineColLocBreakpointManager
|
||||
public:
|
||||
Breakpoint *match(const Action &action) const override {
|
||||
for (const IRUnit &unit : action.getContextIRUnits()) {
|
||||
if (auto *op = unit.dyn_cast<Operation *>()) {
|
||||
if (auto *op = llvm::dyn_cast_if_present<Operation *>(unit)) {
|
||||
if (auto match = matchFromLocation(op->getLoc()))
|
||||
return *match;
|
||||
continue;
|
||||
}
|
||||
if (auto *block = unit.dyn_cast<Block *>()) {
|
||||
if (auto *block = llvm::dyn_cast_if_present<Block *>(unit)) {
|
||||
for (auto &op : block->getOperations()) {
|
||||
if (auto match = matchFromLocation(op.getLoc()))
|
||||
return *match;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (Region *region = unit.dyn_cast<Region *>()) {
|
||||
if (Region *region = llvm::dyn_cast_if_present<Region *>(unit)) {
|
||||
if (auto match = matchFromLocation(region->getLoc()))
|
||||
return *match;
|
||||
continue;
|
||||
|
@ -110,27 +110,27 @@ class MMAMatrixOf<list<Type> allowedTypes> :
|
||||
"gpu.mma_matrix", "::mlir::gpu::MMAMatrixType">;
|
||||
|
||||
// Types for all sparse handles.
|
||||
def GPU_SparseEnvHandle :
|
||||
DialectType<GPU_Dialect,
|
||||
CPred<"$_self.isa<::mlir::gpu::SparseEnvHandleType>()">,
|
||||
"sparse environment handle type">,
|
||||
def GPU_SparseEnvHandle :
|
||||
DialectType<GPU_Dialect,
|
||||
CPred<"llvm::isa<::mlir::gpu::SparseEnvHandleType>($_self)">,
|
||||
"sparse environment handle type">,
|
||||
BuildableType<"mlir::gpu::SparseEnvHandleType::get($_builder.getContext())">;
|
||||
|
||||
def GPU_SparseDnVecHandle :
|
||||
DialectType<GPU_Dialect,
|
||||
CPred<"$_self.isa<::mlir::gpu::SparseDnVecHandleType>()">,
|
||||
def GPU_SparseDnVecHandle :
|
||||
DialectType<GPU_Dialect,
|
||||
CPred<"llvm::isa<::mlir::gpu::SparseDnVecHandleType>($_self)">,
|
||||
"dense vector handle type">,
|
||||
BuildableType<"mlir::gpu::SparseDnVecHandleType::get($_builder.getContext())">;
|
||||
|
||||
def GPU_SparseDnMatHandle :
|
||||
DialectType<GPU_Dialect,
|
||||
CPred<"$_self.isa<::mlir::gpu::SparseDnMatHandleType>()">,
|
||||
def GPU_SparseDnMatHandle :
|
||||
DialectType<GPU_Dialect,
|
||||
CPred<"llvm::isa<::mlir::gpu::SparseDnMatHandleType>($_self)">,
|
||||
"dense matrix handle type">,
|
||||
BuildableType<"mlir::gpu::SparseDnMatHandleType::get($_builder.getContext())">;
|
||||
|
||||
def GPU_SparseSpMatHandle :
|
||||
DialectType<GPU_Dialect,
|
||||
CPred<"$_self.isa<::mlir::gpu::SparseSpMatHandleType>()">,
|
||||
def GPU_SparseSpMatHandle :
|
||||
DialectType<GPU_Dialect,
|
||||
CPred<"llvm::isa<::mlir::gpu::SparseSpMatHandleType>($_self)">,
|
||||
"sparse matrix handle type">,
|
||||
BuildableType<"mlir::gpu::SparseSpMatHandleType::get($_builder.getContext())">;
|
||||
|
||||
|
@ -95,7 +95,7 @@ def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
|
||||
/*methodName=*/"getDeclareTargetDeviceType",
|
||||
(ins), [{}], [{
|
||||
if (mlir::Attribute dTar = $_op->getAttr("omp.declare_target"))
|
||||
if (auto dAttr = dTar.dyn_cast_or_null<mlir::omp::DeclareTargetAttr>())
|
||||
if (auto dAttr = llvm::dyn_cast_or_null<mlir::omp::DeclareTargetAttr>(dTar))
|
||||
return dAttr.getDeviceType().getValue();
|
||||
return {};
|
||||
}]>,
|
||||
@ -108,7 +108,7 @@ def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
|
||||
/*methodName=*/"getDeclareTargetCaptureClause",
|
||||
(ins), [{}], [{
|
||||
if (mlir::Attribute dTar = $_op->getAttr("omp.declare_target"))
|
||||
if (auto dAttr = dTar.dyn_cast_or_null<mlir::omp::DeclareTargetAttr>())
|
||||
if (auto dAttr = llvm::dyn_cast_or_null<mlir::omp::DeclareTargetAttr>(dTar))
|
||||
return dAttr.getCaptureClause().getValue();
|
||||
return {};
|
||||
}]>
|
||||
|
@ -115,7 +115,7 @@ public:
|
||||
static bool classof(Type type);
|
||||
|
||||
/// Allow implicit conversion to ShapedType.
|
||||
operator ShapedType() const { return cast<ShapedType>(); }
|
||||
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -169,7 +169,7 @@ public:
|
||||
unsigned getMemorySpaceAsInt() const;
|
||||
|
||||
/// Allow implicit conversion to ShapedType.
|
||||
operator ShapedType() const { return cast<ShapedType>(); }
|
||||
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -217,13 +217,15 @@ private:
|
||||
}
|
||||
|
||||
static bool isEmptyKey(mlir::TypeRange range) {
|
||||
if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
|
||||
if (const auto *type =
|
||||
llvm::dyn_cast_if_present<const mlir::Type *>(range.getBase()))
|
||||
return type == getEmptyKeyPointer();
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool isTombstoneKey(mlir::TypeRange range) {
|
||||
if (const auto *type = range.getBase().dyn_cast<const mlir::Type *>())
|
||||
if (const auto *type =
|
||||
llvm::dyn_cast_if_present<const mlir::Type *>(range.getBase()))
|
||||
return type == getTombstoneKeyPointer();
|
||||
return false;
|
||||
}
|
||||
|
@ -163,12 +163,12 @@ public:
|
||||
|
||||
/// Return the value the effect is applied on, or nullptr if there isn't a
|
||||
/// known value being affected.
|
||||
Value getValue() const { return value ? value.dyn_cast<Value>() : Value(); }
|
||||
Value getValue() const { return value ? llvm::dyn_cast_if_present<Value>(value) : Value(); }
|
||||
|
||||
/// Return the symbol reference the effect is applied on, or nullptr if there
|
||||
/// isn't a known smbol being affected.
|
||||
SymbolRefAttr getSymbolRef() const {
|
||||
return value ? value.dyn_cast<SymbolRefAttr>() : SymbolRefAttr();
|
||||
return value ? llvm::dyn_cast_if_present<SymbolRefAttr>(value) : SymbolRefAttr();
|
||||
}
|
||||
|
||||
/// Return the resource that the effect applies to.
|
||||
|
@ -254,7 +254,7 @@ struct NestedAnalysisMap {
|
||||
/// Returns the parent analysis map for this analysis map, or null if this is
|
||||
/// the top-level map.
|
||||
const NestedAnalysisMap *getParent() const {
|
||||
return parentOrInstrumentor.dyn_cast<NestedAnalysisMap *>();
|
||||
return llvm::dyn_cast_if_present<NestedAnalysisMap *>(parentOrInstrumentor);
|
||||
}
|
||||
|
||||
/// Returns a pass instrumentation object for the current operation. This
|
||||
|
@ -89,7 +89,7 @@ void SparseConstantPropagation::visitOperation(
|
||||
|
||||
// Merge in the result of the fold, either a constant or a value.
|
||||
OpFoldResult foldResult = std::get<1>(it);
|
||||
if (Attribute attr = foldResult.dyn_cast<Attribute>()) {
|
||||
if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
|
||||
propagateIfChanged(lattice,
|
||||
lattice->join(ConstantValue(attr, op->getDialect())));
|
||||
|
@ -31,7 +31,7 @@ void Executable::print(raw_ostream &os) const {
|
||||
}
|
||||
|
||||
void Executable::onUpdate(DataFlowSolver *solver) const {
|
||||
if (auto *block = point.dyn_cast<Block *>()) {
|
||||
if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
|
||||
// Re-invoke the analyses on the block itself.
|
||||
for (DataFlowAnalysis *analysis : subscribers)
|
||||
solver->enqueue({block, analysis});
|
||||
@ -39,7 +39,7 @@ void Executable::onUpdate(DataFlowSolver *solver) const {
|
||||
for (DataFlowAnalysis *analysis : subscribers)
|
||||
for (Operation &op : *block)
|
||||
solver->enqueue({&op, analysis});
|
||||
} else if (auto *programPoint = point.dyn_cast<GenericProgramPoint *>()) {
|
||||
} else if (auto *programPoint = llvm::dyn_cast_if_present<GenericProgramPoint *>(point)) {
|
||||
// Re-invoke the analysis on the successor block.
|
||||
if (auto *edge = dyn_cast<CFGEdge>(programPoint)) {
|
||||
for (DataFlowAnalysis *analysis : subscribers)
|
||||
@ -219,7 +219,7 @@ void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
|
||||
LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
|
||||
if (point.is<Block *>())
|
||||
return success();
|
||||
auto *op = point.dyn_cast<Operation *>();
|
||||
auto *op = llvm::dyn_cast_if_present<Operation *>(point);
|
||||
if (!op)
|
||||
return emitError(point.getLoc(), "unknown program point kind");
|
||||
|
||||
|
@ -33,9 +33,9 @@ LogicalResult AbstractDenseDataFlowAnalysis::initialize(Operation *top) {
|
||||
}
|
||||
|
||||
LogicalResult AbstractDenseDataFlowAnalysis::visit(ProgramPoint point) {
|
||||
if (auto *op = point.dyn_cast<Operation *>())
|
||||
if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
|
||||
processOperation(op);
|
||||
else if (auto *block = point.dyn_cast<Block *>())
|
||||
else if (auto *block = llvm::dyn_cast_if_present<Block *>(point))
|
||||
visitBlock(block);
|
||||
else
|
||||
return failure();
|
||||
|
@ -181,7 +181,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
|
||||
if (auto bound =
|
||||
dyn_cast_or_null<IntegerAttr>(loopBound->get<Attribute>()))
|
||||
return bound.getValue();
|
||||
} else if (auto value = loopBound->dyn_cast<Value>()) {
|
||||
} else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
|
||||
const IntegerValueRangeLattice *lattice =
|
||||
getLatticeElementFor(op, value);
|
||||
if (lattice != nullptr)
|
||||
|
@ -66,9 +66,9 @@ AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) {
|
||||
}
|
||||
|
||||
LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) {
|
||||
if (Operation *op = point.dyn_cast<Operation *>())
|
||||
if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
|
||||
visitOperation(op);
|
||||
else if (Block *block = point.dyn_cast<Block *>())
|
||||
else if (Block *block = llvm::dyn_cast_if_present<Block *>(point))
|
||||
visitBlock(block);
|
||||
else
|
||||
return failure();
|
||||
@ -238,7 +238,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
|
||||
|
||||
unsigned firstIndex = 0;
|
||||
if (inputs.size() != lattices.size()) {
|
||||
if (point.dyn_cast<Operation *>()) {
|
||||
if (llvm::dyn_cast_if_present<Operation *>(point)) {
|
||||
if (!inputs.empty())
|
||||
firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
|
||||
visitNonControlFlowArgumentsImpl(
|
||||
@ -316,9 +316,9 @@ AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
|
||||
|
||||
LogicalResult
|
||||
AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
|
||||
if (Operation *op = point.dyn_cast<Operation *>())
|
||||
if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
|
||||
visitOperation(op);
|
||||
else if (point.dyn_cast<Block *>())
|
||||
else if (llvm::dyn_cast_if_present<Block *>(point))
|
||||
// For backward dataflow, we don't have to do any work for the blocks
|
||||
// themselves. CFG edges between blocks are processed by the BranchOp
|
||||
// logic in `visitOperation`, and entry blocks for functions are tied
|
||||
|
@ -39,21 +39,21 @@ void ProgramPoint::print(raw_ostream &os) const {
|
||||
os << "<NULL POINT>";
|
||||
return;
|
||||
}
|
||||
if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
|
||||
if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
|
||||
return programPoint->print(os);
|
||||
if (auto *op = dyn_cast<Operation *>())
|
||||
if (auto *op = llvm::dyn_cast<Operation *>(*this))
|
||||
return op->print(os);
|
||||
if (auto value = dyn_cast<Value>())
|
||||
if (auto value = llvm::dyn_cast<Value>(*this))
|
||||
return value.print(os);
|
||||
return get<Block *>()->print(os);
|
||||
}
|
||||
|
||||
Location ProgramPoint::getLoc() const {
|
||||
if (auto *programPoint = dyn_cast<GenericProgramPoint *>())
|
||||
if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
|
||||
return programPoint->getLoc();
|
||||
if (auto *op = dyn_cast<Operation *>())
|
||||
if (auto *op = llvm::dyn_cast<Operation *>(*this))
|
||||
return op->getLoc();
|
||||
if (auto value = dyn_cast<Value>())
|
||||
if (auto value = llvm::dyn_cast<Value>(*this))
|
||||
return value.getLoc();
|
||||
return get<Block *>()->getParent()->getLoc();
|
||||
}
|
||||
|
@ -2060,7 +2060,7 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
|
||||
if (parseToken(Token::r_paren, "expected ')' in location"))
|
||||
return failure();
|
||||
|
||||
if (auto *op = opOrArgument.dyn_cast<Operation *>())
|
||||
if (auto *op = llvm::dyn_cast_if_present<Operation *>(opOrArgument))
|
||||
op->setLoc(directLoc);
|
||||
else
|
||||
opOrArgument.get<BlockArgument>().setLoc(directLoc);
|
||||
|
@ -47,7 +47,7 @@ SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
|
||||
DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
|
||||
DictionaryAttr attributeDict;
|
||||
if (!mlirAttributeIsNull(attributes))
|
||||
attributeDict = unwrap(attributes).cast<DictionaryAttr>();
|
||||
attributeDict = llvm::cast<DictionaryAttr>(unwrap(attributes));
|
||||
return attributeDict;
|
||||
}
|
||||
|
||||
|
@ -1190,9 +1190,9 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
|
||||
// TODO: safer and more flexible to store data type in actual op instead?
|
||||
static Type getSpMatElemType(Value spMat) {
|
||||
if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
|
||||
return op.getValues().getType().cast<MemRefType>().getElementType();
|
||||
return llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
|
||||
if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
|
||||
return op.getValues().getType().cast<MemRefType>().getElementType();
|
||||
return llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
|
||||
llvm_unreachable("cannot find spmat def");
|
||||
}
|
||||
|
||||
@ -1235,7 +1235,7 @@ LogicalResult ConvertCreateDnVecOpToGpuRuntimeCallPattern::matchAndRewrite(
|
||||
MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
|
||||
if (!getTypeConverter()->useOpaquePointers())
|
||||
pVec = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pVec);
|
||||
Type dType = op.getMemref().getType().cast<MemRefType>().getElementType();
|
||||
Type dType = llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
|
||||
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
|
||||
dType.getIntOrFloatBitWidth());
|
||||
auto handle =
|
||||
@ -1271,7 +1271,7 @@ LogicalResult ConvertCreateDnMatOpToGpuRuntimeCallPattern::matchAndRewrite(
|
||||
MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
|
||||
if (!getTypeConverter()->useOpaquePointers())
|
||||
pMat = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pMat);
|
||||
Type dType = op.getMemref().getType().cast<MemRefType>().getElementType();
|
||||
Type dType = llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
|
||||
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
|
||||
dType.getIntOrFloatBitWidth());
|
||||
auto handle =
|
||||
@ -1315,8 +1315,8 @@ LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
|
||||
pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
|
||||
pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
|
||||
}
|
||||
Type iType = op.getColIdxs().getType().cast<MemRefType>().getElementType();
|
||||
Type dType = op.getValues().getType().cast<MemRefType>().getElementType();
|
||||
Type iType = llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
|
||||
Type dType = llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
|
||||
auto iw = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmInt32Type, iType.isIndex() ? 64 : iType.getIntOrFloatBitWidth());
|
||||
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
|
||||
@ -1350,9 +1350,9 @@ LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
|
||||
pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
|
||||
pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
|
||||
}
|
||||
Type pType = op.getRowPos().getType().cast<MemRefType>().getElementType();
|
||||
Type iType = op.getColIdxs().getType().cast<MemRefType>().getElementType();
|
||||
Type dType = op.getValues().getType().cast<MemRefType>().getElementType();
|
||||
Type pType = llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
|
||||
Type iType = llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
|
||||
Type dType = llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
|
||||
auto pw = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmInt32Type, pType.isIndex() ? 64 : pType.getIntOrFloatBitWidth());
|
||||
auto iw = rewriter.create<LLVM::ConstantOp>(
|
||||
|
@ -405,7 +405,7 @@ LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) {
|
||||
return failure();
|
||||
if (!(*converted)) // Conversion to default is 0.
|
||||
return 0;
|
||||
if (auto explicitSpace = converted->dyn_cast_or_null<IntegerAttr>())
|
||||
if (auto explicitSpace = llvm::dyn_cast_if_present<IntegerAttr>(*converted))
|
||||
return explicitSpace.getInt();
|
||||
return failure();
|
||||
}
|
||||
|
@ -671,7 +671,7 @@ struct GlobalMemrefOpLowering
|
||||
|
||||
Attribute initialValue = nullptr;
|
||||
if (!global.isExternal() && !global.isUninitialized()) {
|
||||
auto elementsAttr = global.getInitialValue()->cast<ElementsAttr>();
|
||||
auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
|
||||
initialValue = elementsAttr;
|
||||
|
||||
// For scalar memrefs, the global variable created is of the element type,
|
||||
|
@ -412,10 +412,10 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock,
|
||||
auto *ans = cast<TypeAnswer>(answer);
|
||||
if (isa<pdl::RangeType>(val.getType()))
|
||||
builder.create<pdl_interp::CheckTypesOp>(
|
||||
loc, val, ans->getValue().cast<ArrayAttr>(), success, failure);
|
||||
loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure);
|
||||
else
|
||||
builder.create<pdl_interp::CheckTypeOp>(
|
||||
loc, val, ans->getValue().cast<TypeAttr>(), success, failure);
|
||||
loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure);
|
||||
break;
|
||||
}
|
||||
case Predicates::AttributeQuestion: {
|
||||
|
@ -300,7 +300,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
|
||||
return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
|
||||
|
||||
// tosa::ErfOp
|
||||
if (isa<tosa::ErfOp>(op) && elementTy.isa<FloatType>())
|
||||
if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
|
||||
return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
|
||||
|
||||
// tosa::GreaterOp
|
||||
@ -1885,7 +1885,7 @@ public:
|
||||
|
||||
auto addDynamicDimension = [&](Value source, int64_t dim) {
|
||||
auto dynamicDim = tensor::createDimValue(builder, loc, source, dim);
|
||||
if (auto dimValue = dynamicDim.value().dyn_cast<Value>())
|
||||
if (auto dimValue = llvm::dyn_cast_if_present<Value>(dynamicDim.value()))
|
||||
results.push_back(dimValue);
|
||||
};
|
||||
|
||||
|
@ -121,11 +121,11 @@ void mlirDebuggerCursorSelectParentIRUnit() {
|
||||
return;
|
||||
}
|
||||
IRUnit *unit = &state.cursor;
|
||||
if (auto *op = unit->dyn_cast<Operation *>()) {
|
||||
if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
|
||||
state.cursor = op->getBlock();
|
||||
} else if (auto *region = unit->dyn_cast<Region *>()) {
|
||||
} else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
|
||||
state.cursor = region->getParentOp();
|
||||
} else if (auto *block = unit->dyn_cast<Block *>()) {
|
||||
} else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
|
||||
state.cursor = block->getParent();
|
||||
} else {
|
||||
llvm::outs() << "Current cursor is not a valid IRUnit";
|
||||
@ -142,14 +142,14 @@ void mlirDebuggerCursorSelectChildIRUnit(int index) {
|
||||
return;
|
||||
}
|
||||
IRUnit *unit = &state.cursor;
|
||||
if (auto *op = unit->dyn_cast<Operation *>()) {
|
||||
if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
|
||||
if (index < 0 || index >= static_cast<int>(op->getNumRegions())) {
|
||||
llvm::outs() << "Index invalid, op has " << op->getNumRegions()
|
||||
<< " but got " << index << "\n";
|
||||
return;
|
||||
}
|
||||
state.cursor = &op->getRegion(index);
|
||||
} else if (auto *region = unit->dyn_cast<Region *>()) {
|
||||
} else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
|
||||
auto block = region->begin();
|
||||
int count = 0;
|
||||
while (block != region->end() && count != index) {
|
||||
@ -163,7 +163,7 @@ void mlirDebuggerCursorSelectChildIRUnit(int index) {
|
||||
return;
|
||||
}
|
||||
state.cursor = &*block;
|
||||
} else if (auto *block = unit->dyn_cast<Block *>()) {
|
||||
} else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
|
||||
auto op = block->begin();
|
||||
int count = 0;
|
||||
while (op != block->end() && count != index) {
|
||||
@ -192,14 +192,14 @@ void mlirDebuggerCursorSelectPreviousIRUnit() {
|
||||
return;
|
||||
}
|
||||
IRUnit *unit = &state.cursor;
|
||||
if (auto *op = unit->dyn_cast<Operation *>()) {
|
||||
if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
|
||||
Operation *previous = op->getPrevNode();
|
||||
if (!previous) {
|
||||
llvm::outs() << "No previous operation in the current block\n";
|
||||
return;
|
||||
}
|
||||
state.cursor = previous;
|
||||
} else if (auto *region = unit->dyn_cast<Region *>()) {
|
||||
} else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
|
||||
llvm::outs() << "Has region\n";
|
||||
Operation *parent = region->getParentOp();
|
||||
if (!parent) {
|
||||
@ -212,7 +212,7 @@ void mlirDebuggerCursorSelectPreviousIRUnit() {
|
||||
}
|
||||
state.cursor =
|
||||
®ion->getParentOp()->getRegion(region->getRegionNumber() - 1);
|
||||
} else if (auto *block = unit->dyn_cast<Block *>()) {
|
||||
} else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
|
||||
Block *previous = block->getPrevNode();
|
||||
if (!previous) {
|
||||
llvm::outs() << "No previous block in the current region\n";
|
||||
@ -234,14 +234,14 @@ void mlirDebuggerCursorSelectNextIRUnit() {
|
||||
return;
|
||||
}
|
||||
IRUnit *unit = &state.cursor;
|
||||
if (auto *op = unit->dyn_cast<Operation *>()) {
|
||||
if (auto *op = llvm::dyn_cast_if_present<Operation *>(*unit)) {
|
||||
Operation *next = op->getNextNode();
|
||||
if (!next) {
|
||||
llvm::outs() << "No next operation in the current block\n";
|
||||
return;
|
||||
}
|
||||
state.cursor = next;
|
||||
} else if (auto *region = unit->dyn_cast<Region *>()) {
|
||||
} else if (auto *region = llvm::dyn_cast_if_present<Region *>(*unit)) {
|
||||
Operation *parent = region->getParentOp();
|
||||
if (!parent) {
|
||||
llvm::outs() << "No parent operation for the current region\n";
|
||||
@ -253,7 +253,7 @@ void mlirDebuggerCursorSelectNextIRUnit() {
|
||||
}
|
||||
state.cursor =
|
||||
®ion->getParentOp()->getRegion(region->getRegionNumber() + 1);
|
||||
} else if (auto *block = unit->dyn_cast<Block *>()) {
|
||||
} else if (auto *block = llvm::dyn_cast_if_present<Block *>(*unit)) {
|
||||
Block *next = block->getNextNode();
|
||||
if (!next) {
|
||||
llvm::outs() << "No next block in the current region\n";
|
||||
|
@ -1212,7 +1212,7 @@ static void materializeConstants(OpBuilder &b, Location loc,
|
||||
actualValues.reserve(values.size());
|
||||
auto *dialect = b.getContext()->getLoadedDialect<AffineDialect>();
|
||||
for (OpFoldResult ofr : values) {
|
||||
if (auto value = ofr.dyn_cast<Value>()) {
|
||||
if (auto value = llvm::dyn_cast_if_present<Value>(ofr)) {
|
||||
actualValues.push_back(value);
|
||||
continue;
|
||||
}
|
||||
@ -4599,7 +4599,7 @@ void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
|
||||
if (staticDim.has_value())
|
||||
return builder.create<arith::ConstantIndexOp>(result.location,
|
||||
*staticDim);
|
||||
return ofr.dyn_cast<Value>();
|
||||
return llvm::dyn_cast_if_present<Value>(ofr);
|
||||
});
|
||||
result.addOperands(basisValues);
|
||||
}
|
||||
|
@ -808,7 +808,7 @@ OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
|
||||
if (matchPattern(getRhs(), m_Zero()))
|
||||
return getLhs();
|
||||
/// or(x, <all ones>) -> <all ones>
|
||||
if (auto rhsAttr = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>())
|
||||
if (auto rhsAttr = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs()))
|
||||
if (rhsAttr.getValue().isAllOnes())
|
||||
return rhsAttr;
|
||||
|
||||
@ -1249,7 +1249,7 @@ LogicalResult arith::ExtSIOp::verify() {
|
||||
|
||||
/// Always fold extension of FP constants.
|
||||
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
|
||||
auto constOperand = adaptor.getIn().dyn_cast_or_null<FloatAttr>();
|
||||
auto constOperand = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getIn());
|
||||
if (!constOperand)
|
||||
return {};
|
||||
|
||||
@ -1702,7 +1702,7 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
|
||||
|
||||
// We are moving constants to the right side; So if lhs is constant rhs is
|
||||
// guaranteed to be a constant.
|
||||
if (auto lhs = adaptor.getLhs().dyn_cast_or_null<TypedAttr>()) {
|
||||
if (auto lhs = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs())) {
|
||||
return constFoldBinaryOp<IntegerAttr>(
|
||||
adaptor.getOperands(), getI1SameShape(lhs.getType()),
|
||||
[pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
|
||||
@ -1772,8 +1772,8 @@ bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
|
||||
}
|
||||
|
||||
OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
|
||||
auto lhs = adaptor.getLhs().dyn_cast_or_null<FloatAttr>();
|
||||
auto rhs = adaptor.getRhs().dyn_cast_or_null<FloatAttr>();
|
||||
auto lhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
|
||||
auto rhs = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());
|
||||
|
||||
// If one operand is NaN, making them both NaN does not change the result.
|
||||
if (lhs && lhs.getValue().isNaN())
|
||||
@ -2193,11 +2193,11 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
|
||||
// Constant-fold constant operands over non-splat constant condition.
|
||||
// select %cst_vec, %cst0, %cst1 => %cst2
|
||||
if (auto cond =
|
||||
adaptor.getCondition().dyn_cast_or_null<DenseElementsAttr>()) {
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getCondition())) {
|
||||
if (auto lhs =
|
||||
adaptor.getTrueValue().dyn_cast_or_null<DenseElementsAttr>()) {
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getTrueValue())) {
|
||||
if (auto rhs =
|
||||
adaptor.getFalseValue().dyn_cast_or_null<DenseElementsAttr>()) {
|
||||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getFalseValue())) {
|
||||
SmallVector<Attribute> results;
|
||||
results.reserve(static_cast<size_t>(cond.getNumElements()));
|
||||
auto condVals = llvm::make_range(cond.value_begin<BoolAttr>(),
|
||||
|
@ -184,7 +184,7 @@ struct SelectOpInterface
|
||||
|
||||
// If the buffers have different types, they differ only in their layout
|
||||
// map.
|
||||
auto memrefType = trueType->cast<MemRefType>();
|
||||
auto memrefType = llvm::cast<MemRefType>(*trueType);
|
||||
return getMemRefTypeWithFullyDynamicLayout(
|
||||
RankedTensorType::get(memrefType.getShape(),
|
||||
memrefType.getElementType()),
|
||||
|
@ -33,8 +33,8 @@ LogicalResult mlir::foldDynamicIndexList(Builder &b,
|
||||
if (ofr.is<Attribute>())
|
||||
continue;
|
||||
// Newly static, move from Value to constant.
|
||||
if (auto cstOp =
|
||||
ofr.dyn_cast<Value>().getDefiningOp<arith::ConstantIndexOp>()) {
|
||||
if (auto cstOp = llvm::dyn_cast_if_present<Value>(ofr)
|
||||
.getDefiningOp<arith::ConstantIndexOp>()) {
|
||||
ofr = b.getIndexAttr(cstOp.value());
|
||||
valuesChanged = true;
|
||||
}
|
||||
@ -56,9 +56,9 @@ llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
|
||||
|
||||
Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
|
||||
OpFoldResult ofr) {
|
||||
if (auto value = ofr.dyn_cast<Value>())
|
||||
if (auto value = llvm::dyn_cast_if_present<Value>(ofr))
|
||||
return value;
|
||||
auto attr = dyn_cast<IntegerAttr>(ofr.dyn_cast<Attribute>());
|
||||
auto attr = dyn_cast<IntegerAttr>(llvm::dyn_cast_if_present<Attribute>(ofr));
|
||||
assert(attr && "expect the op fold result casts to an integer attribute");
|
||||
return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
|
||||
}
|
||||
|
@ -179,7 +179,7 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
|
||||
populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
|
||||
}
|
||||
FailureOr<Value> alloc = options.createAlloc(
|
||||
rewriter, loc, allocType->cast<MemRefType>(), dynamicDims);
|
||||
rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
|
||||
if (failed(alloc))
|
||||
return failure();
|
||||
|
||||
|
@ -59,7 +59,8 @@ static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
|
||||
|
||||
/// Return the func::FuncOp called by `callOp`.
|
||||
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
|
||||
SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
|
||||
SymbolRefAttr sym =
|
||||
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
|
||||
if (!sym)
|
||||
return nullptr;
|
||||
return dyn_cast_or_null<func::FuncOp>(
|
||||
|
@ -80,7 +80,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
|
||||
|
||||
/// Return the FuncOp called by `callOp`.
|
||||
static FuncOp getCalledFunction(CallOpInterface callOp) {
|
||||
SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
|
||||
SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
|
||||
if (!sym)
|
||||
return nullptr;
|
||||
return dyn_cast_or_null<FuncOp>(
|
||||
|
@ -995,7 +995,7 @@ static void annotateOpsWithAliasSets(Operation *op,
|
||||
op->walk([&](Operation *op) {
|
||||
SmallVector<Attribute> aliasSets;
|
||||
for (OpResult opResult : op->getOpResults()) {
|
||||
if (opResult.getType().isa<TensorType>()) {
|
||||
if (llvm::isa<TensorType>(opResult.getType())) {
|
||||
SmallVector<Attribute> aliases;
|
||||
state.applyOnAliases(opResult, [&](Value alias) {
|
||||
std::string buffer;
|
||||
|
@ -238,7 +238,7 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
|
||||
|
||||
/// Return the func::FuncOp called by `callOp`.
|
||||
static func::FuncOp getCalledFunction(func::CallOp callOp) {
|
||||
SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
|
||||
SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
|
||||
if (!sym)
|
||||
return nullptr;
|
||||
return dyn_cast_or_null<func::FuncOp>(
|
||||
|
@ -90,7 +90,8 @@ OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
|
||||
ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
|
||||
ArrayAttr arrayAttr =
|
||||
llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
|
||||
if (arrayAttr && arrayAttr.size() == 2)
|
||||
return arrayAttr[1];
|
||||
if (auto createOp = getOperand().getDefiningOp<CreateOp>())
|
||||
@ -103,7 +104,8 @@ OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
|
||||
ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null<ArrayAttr>();
|
||||
ArrayAttr arrayAttr =
|
||||
llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
|
||||
if (arrayAttr && arrayAttr.size() == 2)
|
||||
return arrayAttr[0];
|
||||
if (auto createOp = getOperand().getDefiningOp<CreateOp>())
|
||||
|
@ -94,7 +94,7 @@ DataLayoutEntryAttr DataLayoutEntryAttr::parse(AsmParser &parser) {
|
||||
|
||||
void DataLayoutEntryAttr::print(AsmPrinter &os) const {
|
||||
os << DataLayoutEntryAttr::kAttrKeyword << "<";
|
||||
if (auto type = getKey().dyn_cast<Type>())
|
||||
if (auto type = llvm::dyn_cast_if_present<Type>(getKey()))
|
||||
os << type;
|
||||
else
|
||||
os << "\"" << getKey().get<StringAttr>().strref() << "\"";
|
||||
@ -151,7 +151,7 @@ DataLayoutSpecAttr::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
DenseSet<Type> types;
|
||||
DenseSet<StringAttr> ids;
|
||||
for (DataLayoutEntryInterface entry : entries) {
|
||||
if (auto type = entry.getKey().dyn_cast<Type>()) {
|
||||
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
|
||||
if (!types.insert(type).second)
|
||||
return emitError() << "repeated layout entry key: " << type;
|
||||
} else {
|
||||
|
@ -493,7 +493,7 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
|
||||
// error. All other canonicalization is done in the fold method.
|
||||
bool requiresConst = !rawConstantIndices.empty() &&
|
||||
currType.isa_and_nonnull<LLVMStructType>();
|
||||
if (Value val = iter.dyn_cast<Value>()) {
|
||||
if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
|
||||
APInt intC;
|
||||
if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
|
||||
intC.isSignedIntN(kGEPConstantBitWidth)) {
|
||||
@ -598,7 +598,7 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
|
||||
llvm::interleaveComma(
|
||||
GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), printer,
|
||||
[&](PointerUnion<IntegerAttr, Value> cst) {
|
||||
if (Value val = cst.dyn_cast<Value>())
|
||||
if (Value val = llvm::dyn_cast_if_present<Value>(cst))
|
||||
printer.printOperand(val);
|
||||
else
|
||||
printer << cst.get<IntegerAttr>().getInt();
|
||||
@ -2495,7 +2495,7 @@ OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
|
||||
!integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {
|
||||
|
||||
PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
|
||||
if (Value val = existing.dyn_cast<Value>())
|
||||
if (Value val = llvm::dyn_cast_if_present<Value>(existing))
|
||||
gepArgs.emplace_back(val);
|
||||
else
|
||||
gepArgs.emplace_back(existing.get<IntegerAttr>().getInt());
|
||||
|
@ -261,7 +261,7 @@ DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
|
||||
|
||||
static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
|
||||
return llvm::all_of(gepOp.getIndices(), [](auto index) {
|
||||
auto indexAttr = index.template dyn_cast<IntegerAttr>();
|
||||
auto indexAttr = llvm::dyn_cast_if_present<IntegerAttr>(index);
|
||||
return indexAttr && indexAttr.getValue() == 0;
|
||||
});
|
||||
}
|
||||
@ -289,7 +289,7 @@ static Type computeReachedGEPType(LLVM::GEPOp gep) {
|
||||
// Ensures all indices are static and fetches them.
|
||||
SmallVector<IntegerAttr> indices;
|
||||
for (auto index : gep.getIndices()) {
|
||||
IntegerAttr indexInt = index.dyn_cast<IntegerAttr>();
|
||||
IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
|
||||
if (!indexInt)
|
||||
return {};
|
||||
indices.push_back(indexInt);
|
||||
@ -310,7 +310,7 @@ static Type computeReachedGEPType(LLVM::GEPOp gep) {
|
||||
for (IntegerAttr index : llvm::drop_begin(indices)) {
|
||||
// Ensure the structure of the type being indexed can be reasoned about.
|
||||
// This includes rejecting any potential typed pointer.
|
||||
auto destructurable = selectedType.dyn_cast<DestructurableTypeInterface>();
|
||||
auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(selectedType);
|
||||
if (!destructurable)
|
||||
return {};
|
||||
|
||||
@ -343,7 +343,7 @@ LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
|
||||
bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
|
||||
SmallPtrSetImpl<Attribute> &usedIndices,
|
||||
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
|
||||
auto basePtrType = getBase().getType().dyn_cast<LLVM::LLVMPointerType>();
|
||||
auto basePtrType = llvm::dyn_cast<LLVM::LLVMPointerType>(getBase().getType());
|
||||
if (!basePtrType)
|
||||
return false;
|
||||
|
||||
@ -359,7 +359,7 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
|
||||
return false;
|
||||
auto firstLevelIndex = cast<IntegerAttr>(getIndices()[1]);
|
||||
assert(slot.elementPtrs.contains(firstLevelIndex));
|
||||
if (!slot.elementPtrs.at(firstLevelIndex).isa<LLVM::LLVMPointerType>())
|
||||
if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex)))
|
||||
return false;
|
||||
mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
|
||||
usedIndices.insert(firstLevelIndex);
|
||||
@ -369,7 +369,7 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
|
||||
DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
|
||||
DenseMap<Attribute, MemorySlot> &subslots,
|
||||
RewriterBase &rewriter) {
|
||||
IntegerAttr firstLevelIndex = getIndices()[1].dyn_cast<IntegerAttr>();
|
||||
IntegerAttr firstLevelIndex = llvm::dyn_cast_if_present<IntegerAttr>(getIndices()[1]);
|
||||
const MemorySlot &newSlot = subslots.at(firstLevelIndex);
|
||||
|
||||
ArrayRef<int32_t> remainingIndices = getRawConstantIndices().slice(2);
|
||||
@ -414,7 +414,7 @@ LLVM::LLVMStructType::getSubelementIndexMap() {
|
||||
}
|
||||
|
||||
Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) {
|
||||
auto indexAttr = index.dyn_cast<IntegerAttr>();
|
||||
auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
|
||||
if (!indexAttr || !indexAttr.getType().isInteger(32))
|
||||
return {};
|
||||
int32_t indexInt = indexAttr.getInt();
|
||||
@ -439,7 +439,7 @@ LLVM::LLVMArrayType::getSubelementIndexMap() const {
|
||||
}
|
||||
|
||||
Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const {
|
||||
auto indexAttr = index.dyn_cast<IntegerAttr>();
|
||||
auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
|
||||
if (!indexAttr || !indexAttr.getType().isInteger(32))
|
||||
return {};
|
||||
int32_t indexInt = indexAttr.getInt();
|
||||
|
@ -354,7 +354,7 @@ bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
|
||||
auto newType = llvm::cast<LLVMPointerType>(newEntry.getKey().get<Type>());
|
||||
const auto *it =
|
||||
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
|
||||
if (auto type = entry.getKey().dyn_cast<Type>()) {
|
||||
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
|
||||
return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
|
||||
newType.getAddressSpace();
|
||||
}
|
||||
@ -362,7 +362,7 @@ bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
|
||||
});
|
||||
if (it == oldLayout.end()) {
|
||||
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
|
||||
if (auto type = entry.getKey().dyn_cast<Type>()) {
|
||||
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
|
||||
return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
|
||||
}
|
||||
return false;
|
||||
|
@ -2368,7 +2368,7 @@ transform::TileOp::apply(TransformResults &transformResults,
|
||||
sizes.reserve(tileSizes.size());
|
||||
unsigned dynamicIdx = 0;
|
||||
for (OpFoldResult ofr : getMixedSizes()) {
|
||||
if (auto attr = ofr.dyn_cast<Attribute>()) {
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
|
||||
sizes.push_back(b.create<arith::ConstantIndexOp>(
|
||||
getLoc(), cast<IntegerAttr>(attr).getInt()));
|
||||
continue;
|
||||
@ -2794,7 +2794,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
|
||||
sizes.reserve(tileSizes.size());
|
||||
unsigned dynamicIdx = 0;
|
||||
for (OpFoldResult ofr : getMixedSizes()) {
|
||||
if (auto attr = ofr.dyn_cast<Attribute>()) {
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
|
||||
sizes.push_back(b.create<arith::ConstantIndexOp>(
|
||||
getLoc(), cast<IntegerAttr>(attr).getInt()));
|
||||
} else {
|
||||
|
@ -1447,7 +1447,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
|
||||
cast<LinalgOp>(genericOp.getOperation())
|
||||
.createLoopRanges(rewriter, genericOp.getLoc());
|
||||
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
|
||||
if (auto attr = ofr.dyn_cast<Attribute>())
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
|
||||
return cast<IntegerAttr>(attr).getInt() == value;
|
||||
llvm::APInt actual;
|
||||
return matchPattern(ofr.get<Value>(), m_ConstantInt(&actual)) &&
|
||||
|
@ -229,7 +229,7 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
|
||||
// to look for the bound.
|
||||
LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
|
||||
Value size;
|
||||
if (auto attr = rangeValue.size.dyn_cast<Attribute>()) {
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
|
||||
size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
|
||||
} else {
|
||||
Value materializedSize =
|
||||
|
@ -92,7 +92,7 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
|
||||
rewriter, op.getLoc(), d0 + d1 - d2,
|
||||
{iterationSpace[dimension].offset, iterationSpace[dimension].size,
|
||||
minSplitPoint});
|
||||
if (auto attr = remainingSize.dyn_cast<Attribute>()) {
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(remainingSize)) {
|
||||
if (cast<IntegerAttr>(attr).getValue().isZero())
|
||||
return {op, TilingInterface()};
|
||||
}
|
||||
|
@ -48,7 +48,7 @@ using namespace mlir::scf;
|
||||
static bool isZero(OpFoldResult v) {
|
||||
if (!v)
|
||||
return false;
|
||||
if (auto attr = v.dyn_cast<Attribute>()) {
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(v)) {
|
||||
IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
|
||||
return intAttr && intAttr.getValue().isZero();
|
||||
}
|
||||
@ -104,7 +104,7 @@ void mlir::linalg::transformIndexOps(
|
||||
/// checked at runtime.
|
||||
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
|
||||
OpFoldResult value) {
|
||||
if (auto attr = value.dyn_cast<Attribute>()) {
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
|
||||
assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
|
||||
"expected strictly positive tile size and divisor");
|
||||
return;
|
||||
|
@ -1135,7 +1135,7 @@ GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
// Given an OpFoldResult, return an index-typed value.
|
||||
auto getIdxValue = [&](OpFoldResult ofr) {
|
||||
if (auto val = ofr.dyn_cast<Value>())
|
||||
if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
|
||||
return val;
|
||||
return rewriter
|
||||
.create<arith::ConstantIndexOp>(
|
||||
|
@ -1646,7 +1646,7 @@ static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
|
||||
ArrayRef<OpFoldResult> ofrs) {
|
||||
SmallVector<Value> result;
|
||||
for (auto o : ofrs) {
|
||||
if (auto val = o.template dyn_cast<Value>()) {
|
||||
if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
|
||||
result.push_back(val);
|
||||
} else {
|
||||
result.push_back(rewriter.create<arith::ConstantIndexOp>(
|
||||
@ -1954,8 +1954,8 @@ struct PadOpVectorizationWithTransferWritePattern
|
||||
continue;
|
||||
|
||||
// Other cases: Take a deeper look at defining ops of values.
|
||||
auto v1 = size1.dyn_cast<Value>();
|
||||
auto v2 = size2.dyn_cast<Value>();
|
||||
auto v1 = llvm::dyn_cast_if_present<Value>(size1);
|
||||
auto v2 = llvm::dyn_cast_if_present<Value>(size2);
|
||||
if (!v1 || !v2)
|
||||
return false;
|
||||
|
||||
|
@ -970,7 +970,7 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
|
||||
auto dim = it.index();
|
||||
auto size = it.value();
|
||||
curr.push_back(dim);
|
||||
auto attr = size.dyn_cast<Attribute>();
|
||||
auto attr = llvm::dyn_cast_if_present<Attribute>(size);
|
||||
if (attr && cast<IntegerAttr>(attr).getInt() == 1)
|
||||
continue;
|
||||
reassociation.emplace_back(ReassociationIndices{});
|
||||
|
@ -64,7 +64,7 @@ static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static bool isSupportedElementType(Type type) {
|
||||
return type.isa<MemRefType>() ||
|
||||
return llvm::isa<MemRefType>(type) ||
|
||||
OpBuilder(type.getContext()).getZeroAttr(type);
|
||||
}
|
||||
|
||||
@ -110,7 +110,7 @@ void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
|
||||
SmallVector<DestructurableMemorySlot>
|
||||
memref::AllocaOp::getDestructurableSlots() {
|
||||
MemRefType memrefType = getType();
|
||||
auto destructurable = memrefType.dyn_cast<DestructurableTypeInterface>();
|
||||
auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
|
||||
if (!destructurable)
|
||||
return {};
|
||||
|
||||
@ -134,7 +134,7 @@ memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
|
||||
|
||||
DenseMap<Attribute, MemorySlot> slotMap;
|
||||
|
||||
auto memrefType = getType().cast<DestructurableTypeInterface>();
|
||||
auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
|
||||
for (Attribute usedIndex : usedIndices) {
|
||||
Type elemType = memrefType.getTypeAtIndex(usedIndex);
|
||||
MemRefType elemPtr = MemRefType::get({}, elemType);
|
||||
@ -281,7 +281,7 @@ struct MemRefDestructurableTypeExternalModel
|
||||
MemRefDestructurableTypeExternalModel, MemRefType> {
|
||||
std::optional<DenseMap<Attribute, Type>>
|
||||
getSubelementIndexMap(Type type) const {
|
||||
auto memrefType = type.cast<MemRefType>();
|
||||
auto memrefType = llvm::cast<MemRefType>(type);
|
||||
constexpr int64_t maxMemrefSizeForDestructuring = 16;
|
||||
if (!memrefType.hasStaticShape() ||
|
||||
memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
|
||||
@ -298,15 +298,15 @@ struct MemRefDestructurableTypeExternalModel
|
||||
}
|
||||
|
||||
Type getTypeAtIndex(Type type, Attribute index) const {
|
||||
auto memrefType = type.cast<MemRefType>();
|
||||
auto coordArrAttr = index.dyn_cast<ArrayAttr>();
|
||||
auto memrefType = llvm::cast<MemRefType>(type);
|
||||
auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
|
||||
if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
|
||||
return {};
|
||||
|
||||
Type indexType = IndexType::get(memrefType.getContext());
|
||||
for (const auto &[coordAttr, dimSize] :
|
||||
llvm::zip(coordArrAttr, memrefType.getShape())) {
|
||||
auto coord = coordAttr.dyn_cast<IntegerAttr>();
|
||||
auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
|
||||
if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
|
||||
coord.getInt() >= dimSize)
|
||||
return {};
|
||||
|
@ -970,7 +970,7 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
|
||||
return unusedDims;
|
||||
|
||||
for (const auto &dim : llvm::enumerate(sizes))
|
||||
if (auto attr = dim.value().dyn_cast<Attribute>())
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
|
||||
if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
|
||||
unusedDims.set(dim.index());
|
||||
|
||||
@ -1042,7 +1042,7 @@ llvm::SmallBitVector SubViewOp::getDroppedDims() {
|
||||
|
||||
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
|
||||
// All forms of folding require a known index.
|
||||
auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
|
||||
auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
|
||||
if (!index)
|
||||
return {};
|
||||
|
||||
|
@ -56,7 +56,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
|
||||
// Because we only support input strides of 1, the output stride is also
|
||||
// always 1.
|
||||
if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
|
||||
Attribute attr = valueOrAttr.dyn_cast<Attribute>();
|
||||
Attribute attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr);
|
||||
return attr && cast<IntegerAttr>(attr).getInt() == 1;
|
||||
})) {
|
||||
strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
|
||||
@ -86,8 +86,9 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
|
||||
}
|
||||
|
||||
sizes.push_back(opSize);
|
||||
Attribute opOffsetAttr = opOffset.dyn_cast<Attribute>(),
|
||||
sourceOffsetAttr = sourceOffset.dyn_cast<Attribute>();
|
||||
Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
|
||||
sourceOffsetAttr =
|
||||
llvm::dyn_cast_if_present<Attribute>(sourceOffset);
|
||||
|
||||
if (opOffsetAttr && sourceOffsetAttr) {
|
||||
// If both offsets are static we can simply calculate the combined
|
||||
@ -101,7 +102,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
|
||||
AffineExpr expr = rewriter.getAffineConstantExpr(0);
|
||||
SmallVector<Value> affineApplyOperands;
|
||||
for (auto valueOrAttr : {opOffset, sourceOffset}) {
|
||||
if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr)) {
|
||||
expr = expr + cast<IntegerAttr>(attr).getInt();
|
||||
} else {
|
||||
expr =
|
||||
|
@ -520,7 +520,7 @@ checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
|
||||
<< operandName << " operand appears more than once";
|
||||
|
||||
mlir::Type varType = operand.getType();
|
||||
auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
|
||||
auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
|
||||
auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
|
||||
if (!decl)
|
||||
return op->emitOpError()
|
||||
|
@ -802,10 +802,10 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange map_operands,
|
||||
for (const auto &mapTypeOp : *map_types) {
|
||||
int64_t mapTypeBits = 0x00;
|
||||
|
||||
if (!mapTypeOp.isa<mlir::IntegerAttr>())
|
||||
if (!llvm::isa<mlir::IntegerAttr>(mapTypeOp))
|
||||
return failure();
|
||||
|
||||
mapTypeBits = mapTypeOp.cast<mlir::IntegerAttr>().getInt();
|
||||
mapTypeBits = llvm::cast<mlir::IntegerAttr>(mapTypeOp).getInt();
|
||||
|
||||
bool to =
|
||||
bitAnd(mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
|
||||
|
@ -381,7 +381,7 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
|
||||
// map.
|
||||
auto yieldedRanked = cast<MemRefType>(yieldedValueBufferType);
|
||||
#ifndef NDEBUG
|
||||
auto iterRanked = initArgBufferType->cast<MemRefType>();
|
||||
auto iterRanked = llvm::cast<MemRefType>(*initArgBufferType);
|
||||
assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
|
||||
"expected same shape");
|
||||
assert(yieldedRanked.getMemorySpace() == iterRanked.getMemorySpace() &&
|
||||
@ -802,7 +802,7 @@ struct WhileOpInterface
|
||||
if (!isa<TensorType>(bbArg.getType()))
|
||||
return bbArg.getType();
|
||||
// TODO: error handling
|
||||
return bufferization::getBufferType(bbArg, options)->cast<Type>();
|
||||
return llvm::cast<Type>(*bufferization::getBufferType(bbArg, options));
|
||||
}));
|
||||
|
||||
// Construct a new scf.while op with memref instead of tensor values.
|
||||
|
@ -88,10 +88,10 @@ LogicalResult scf::addLoopRangeConstraints(FlatAffineValueConstraints &cstr,
|
||||
return failure();
|
||||
|
||||
unsigned dimIv = cstr.appendDimVar(iv);
|
||||
auto lbv = lb.dyn_cast<Value>();
|
||||
auto lbv = llvm::dyn_cast_if_present<Value>(lb);
|
||||
unsigned symLb =
|
||||
lbv ? cstr.appendSymbolVar(lbv) : cstr.appendSymbolVar(/*num=*/1);
|
||||
auto ubv = ub.dyn_cast<Value>();
|
||||
auto ubv = llvm::dyn_cast_if_present<Value>(ub);
|
||||
unsigned symUb =
|
||||
ubv ? cstr.appendSymbolVar(ubv) : cstr.appendSymbolVar(/*num=*/1);
|
||||
|
||||
|
@ -152,7 +152,7 @@ OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
|
||||
auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
|
||||
if (getIndices().size() == 1 &&
|
||||
constructOp.getConstituents().size() == type.getNumElements()) {
|
||||
auto i = getIndices().begin()->cast<IntegerAttr>();
|
||||
auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
|
||||
return constructOp.getConstituents()[i.getValue().getSExtValue()];
|
||||
}
|
||||
}
|
||||
|
@ -1562,8 +1562,8 @@ LogicalResult spirv::BitcastOp::verify() {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult spirv::ConvertPtrToUOp::verify() {
|
||||
auto operandType = getPointer().getType().cast<spirv::PointerType>();
|
||||
auto resultType = getResult().getType().cast<spirv::ScalarType>();
|
||||
auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
|
||||
auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType());
|
||||
if (!resultType || !resultType.isSignlessInteger())
|
||||
return emitError("result must be a scalar type of unsigned integer");
|
||||
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
|
||||
@ -1583,8 +1583,8 @@ LogicalResult spirv::ConvertPtrToUOp::verify() {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult spirv::ConvertUToPtrOp::verify() {
|
||||
auto operandType = getOperand().getType().cast<spirv::ScalarType>();
|
||||
auto resultType = getResult().getType().cast<spirv::PointerType>();
|
||||
auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType());
|
||||
auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
|
||||
if (!operandType || !operandType.isSignlessInteger())
|
||||
return emitError("result must be a scalar type of unsigned integer");
|
||||
auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
|
||||
|
@ -125,23 +125,23 @@ Type CompositeType::getElementType(unsigned index) const {
|
||||
}
|
||||
|
||||
unsigned CompositeType::getNumElements() const {
|
||||
if (auto arrayType = dyn_cast<ArrayType>())
|
||||
if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
|
||||
return arrayType.getNumElements();
|
||||
if (auto matrixType = dyn_cast<MatrixType>())
|
||||
if (auto matrixType = llvm::dyn_cast<MatrixType>(*this))
|
||||
return matrixType.getNumColumns();
|
||||
if (auto structType = dyn_cast<StructType>())
|
||||
if (auto structType = llvm::dyn_cast<StructType>(*this))
|
||||
return structType.getNumElements();
|
||||
if (auto vectorType = dyn_cast<VectorType>())
|
||||
if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
|
||||
return vectorType.getNumElements();
|
||||
if (isa<CooperativeMatrixNVType>()) {
|
||||
if (llvm::isa<CooperativeMatrixNVType>(*this)) {
|
||||
llvm_unreachable(
|
||||
"invalid to query number of elements of spirv::CooperativeMatrix type");
|
||||
}
|
||||
if (isa<JointMatrixINTELType>()) {
|
||||
if (llvm::isa<JointMatrixINTELType>(*this)) {
|
||||
llvm_unreachable(
|
||||
"invalid to query number of elements of spirv::JointMatrix type");
|
||||
}
|
||||
if (isa<RuntimeArrayType>()) {
|
||||
if (llvm::isa<RuntimeArrayType>(*this)) {
|
||||
llvm_unreachable(
|
||||
"invalid to query number of elements of spirv::RuntimeArray type");
|
||||
}
|
||||
@ -149,8 +149,8 @@ unsigned CompositeType::getNumElements() const {
|
||||
}
|
||||
|
||||
bool CompositeType::hasCompileTimeKnownNumElements() const {
|
||||
return !isa<CooperativeMatrixNVType, JointMatrixINTELType,
|
||||
RuntimeArrayType>();
|
||||
return !llvm::isa<CooperativeMatrixNVType, JointMatrixINTELType,
|
||||
RuntimeArrayType>(*this);
|
||||
}
|
||||
|
||||
void CompositeType::getExtensions(
|
||||
@ -188,11 +188,11 @@ void CompositeType::getCapabilities(
|
||||
}
|
||||
|
||||
std::optional<int64_t> CompositeType::getSizeInBytes() {
|
||||
if (auto arrayType = dyn_cast<ArrayType>())
|
||||
if (auto arrayType = llvm::dyn_cast<ArrayType>(*this))
|
||||
return arrayType.getSizeInBytes();
|
||||
if (auto structType = dyn_cast<StructType>())
|
||||
if (auto structType = llvm::dyn_cast<StructType>(*this))
|
||||
return structType.getSizeInBytes();
|
||||
if (auto vectorType = dyn_cast<VectorType>()) {
|
||||
if (auto vectorType = llvm::dyn_cast<VectorType>(*this)) {
|
||||
std::optional<int64_t> elementSize =
|
||||
llvm::cast<ScalarType>(vectorType.getElementType()).getSizeInBytes();
|
||||
if (!elementSize)
|
||||
@ -680,7 +680,7 @@ void ScalarType::getCapabilities(
|
||||
capabilities.push_back(ref); \
|
||||
} break
|
||||
|
||||
if (auto intType = dyn_cast<IntegerType>()) {
|
||||
if (auto intType = llvm::dyn_cast<IntegerType>(*this)) {
|
||||
switch (bitwidth) {
|
||||
WIDTH_CASE(Int, 8);
|
||||
WIDTH_CASE(Int, 16);
|
||||
@ -692,7 +692,7 @@ void ScalarType::getCapabilities(
|
||||
llvm_unreachable("invalid bitwidth to getCapabilities");
|
||||
}
|
||||
} else {
|
||||
assert(isa<FloatType>());
|
||||
assert(llvm::isa<FloatType>(*this));
|
||||
switch (bitwidth) {
|
||||
WIDTH_CASE(Float, 16);
|
||||
WIDTH_CASE(Float, 64);
|
||||
@ -735,22 +735,22 @@ bool SPIRVType::classof(Type type) {
|
||||
}
|
||||
|
||||
bool SPIRVType::isScalarOrVector() {
|
||||
return isIntOrFloat() || isa<VectorType>();
|
||||
return isIntOrFloat() || llvm::isa<VectorType>(*this);
|
||||
}
|
||||
|
||||
void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
|
||||
std::optional<StorageClass> storage) {
|
||||
if (auto scalarType = dyn_cast<ScalarType>()) {
|
||||
if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
|
||||
scalarType.getExtensions(extensions, storage);
|
||||
} else if (auto compositeType = dyn_cast<CompositeType>()) {
|
||||
} else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
|
||||
compositeType.getExtensions(extensions, storage);
|
||||
} else if (auto imageType = dyn_cast<ImageType>()) {
|
||||
} else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
|
||||
imageType.getExtensions(extensions, storage);
|
||||
} else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
|
||||
} else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
|
||||
sampledImageType.getExtensions(extensions, storage);
|
||||
} else if (auto matrixType = dyn_cast<MatrixType>()) {
|
||||
} else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
|
||||
matrixType.getExtensions(extensions, storage);
|
||||
} else if (auto ptrType = dyn_cast<PointerType>()) {
|
||||
} else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
|
||||
ptrType.getExtensions(extensions, storage);
|
||||
} else {
|
||||
llvm_unreachable("invalid SPIR-V Type to getExtensions");
|
||||
@ -760,17 +760,17 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
|
||||
void SPIRVType::getCapabilities(
|
||||
SPIRVType::CapabilityArrayRefVector &capabilities,
|
||||
std::optional<StorageClass> storage) {
|
||||
if (auto scalarType = dyn_cast<ScalarType>()) {
|
||||
if (auto scalarType = llvm::dyn_cast<ScalarType>(*this)) {
|
||||
scalarType.getCapabilities(capabilities, storage);
|
||||
} else if (auto compositeType = dyn_cast<CompositeType>()) {
|
||||
} else if (auto compositeType = llvm::dyn_cast<CompositeType>(*this)) {
|
||||
compositeType.getCapabilities(capabilities, storage);
|
||||
} else if (auto imageType = dyn_cast<ImageType>()) {
|
||||
} else if (auto imageType = llvm::dyn_cast<ImageType>(*this)) {
|
||||
imageType.getCapabilities(capabilities, storage);
|
||||
} else if (auto sampledImageType = dyn_cast<SampledImageType>()) {
|
||||
} else if (auto sampledImageType = llvm::dyn_cast<SampledImageType>(*this)) {
|
||||
sampledImageType.getCapabilities(capabilities, storage);
|
||||
} else if (auto matrixType = dyn_cast<MatrixType>()) {
|
||||
} else if (auto matrixType = llvm::dyn_cast<MatrixType>(*this)) {
|
||||
matrixType.getCapabilities(capabilities, storage);
|
||||
} else if (auto ptrType = dyn_cast<PointerType>()) {
|
||||
} else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
|
||||
ptrType.getCapabilities(capabilities, storage);
|
||||
} else {
|
||||
llvm_unreachable("invalid SPIR-V Type to getCapabilities");
|
||||
@ -778,9 +778,9 @@ void SPIRVType::getCapabilities(
|
||||
}
|
||||
|
||||
std::optional<int64_t> SPIRVType::getSizeInBytes() {
|
||||
if (auto scalarType = dyn_cast<ScalarType>())
|
||||
if (auto scalarType = llvm::dyn_cast<ScalarType>(*this))
|
||||
return scalarType.getSizeInBytes();
|
||||
if (auto compositeType = dyn_cast<CompositeType>())
|
||||
if (auto compositeType = llvm::dyn_cast<CompositeType>(*this))
|
||||
return compositeType.getSizeInBytes();
|
||||
return std::nullopt;
|
||||
}
|
||||
|
@ -856,9 +856,9 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
|
||||
if (!adaptor.getLhs() || !adaptor.getRhs())
|
||||
return nullptr;
|
||||
auto lhsShape = llvm::to_vector<6>(
|
||||
adaptor.getLhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
|
||||
llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
|
||||
auto rhsShape = llvm::to_vector<6>(
|
||||
adaptor.getRhs().cast<DenseIntElementsAttr>().getValues<int64_t>());
|
||||
llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
|
||||
SmallVector<int64_t, 6> resultShape;
|
||||
resultShape.append(lhsShape.begin(), lhsShape.end());
|
||||
resultShape.append(rhsShape.begin(), rhsShape.end());
|
||||
@ -989,7 +989,7 @@ OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
|
||||
if (!operand)
|
||||
return false;
|
||||
extents.push_back(llvm::to_vector<6>(
|
||||
operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
|
||||
llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
|
||||
}
|
||||
return OpTrait::util::staticallyKnownBroadcastable(extents);
|
||||
}())
|
||||
@ -1132,10 +1132,10 @@ LogicalResult mlir::shape::DimOp::verify() {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
|
||||
auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
|
||||
auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
|
||||
if (!lhs)
|
||||
return nullptr;
|
||||
auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
|
||||
auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
|
||||
if (!rhs)
|
||||
return nullptr;
|
||||
|
||||
@ -1346,7 +1346,7 @@ std::optional<int64_t> GetExtentOp::getConstantDim() {
|
||||
}
|
||||
|
||||
OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
|
||||
auto elements = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
|
||||
auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
|
||||
if (!elements)
|
||||
return nullptr;
|
||||
std::optional<int64_t> dim = getConstantDim();
|
||||
@ -1490,7 +1490,7 @@ bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
|
||||
auto shape = adaptor.getShape().dyn_cast_or_null<DenseIntElementsAttr>();
|
||||
auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
|
||||
if (!shape)
|
||||
return {};
|
||||
int64_t rank = shape.getNumElements();
|
||||
@ -1671,10 +1671,10 @@ bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
|
||||
auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>();
|
||||
auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
|
||||
if (!lhs)
|
||||
return nullptr;
|
||||
auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>();
|
||||
auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
|
||||
if (!rhs)
|
||||
return nullptr;
|
||||
APInt folded = lhs.getValue() * rhs.getValue();
|
||||
@ -1864,9 +1864,9 @@ LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
|
||||
if (!adaptor.getOperand() || !adaptor.getIndex())
|
||||
return failure();
|
||||
auto shapeVec = llvm::to_vector<6>(
|
||||
adaptor.getOperand().cast<DenseIntElementsAttr>().getValues<int64_t>());
|
||||
llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
|
||||
auto shape = llvm::ArrayRef(shapeVec);
|
||||
auto splitPoint = adaptor.getIndex().cast<IntegerAttr>().getInt();
|
||||
auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
|
||||
// Verify that the split point is in the correct range.
|
||||
// TODO: Constant fold to an "error".
|
||||
int64_t rank = shape.size();
|
||||
@ -1889,7 +1889,7 @@ OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
|
||||
return OpFoldResult();
|
||||
Builder builder(getContext());
|
||||
auto shape = llvm::to_vector<6>(
|
||||
adaptor.getInput().cast<DenseIntElementsAttr>().getValues<int64_t>());
|
||||
llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
|
||||
auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
|
||||
builder.getIndexType());
|
||||
return DenseIntElementsAttr::get(type, shape);
|
||||
|
@ -815,7 +815,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
|
||||
Level cooStartLvl = getCOOStart(stt.getEncoding());
|
||||
if (cooStartLvl < stt.getLvlRank()) {
|
||||
// We only supports trailing COO for now, must be the last input.
|
||||
auto cooTp = lvlTps.back().cast<ShapedType>();
|
||||
auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
|
||||
// The coordinates should be in shape of <? x rank>
|
||||
unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
|
||||
if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
|
||||
@ -844,7 +844,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
|
||||
inputTp = lvlTps[idx++];
|
||||
}
|
||||
// The input element type and expected element type should match.
|
||||
Type inpElemTp = inputTp.cast<TensorType>().getElementType();
|
||||
Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
|
||||
Type expElemTp = getFieldElemType(stt, fKind);
|
||||
if (inpElemTp != expElemTp) {
|
||||
misMatch = true;
|
||||
|
@ -188,7 +188,7 @@ static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
|
||||
/// Generates a memref from tensor operation.
|
||||
static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
|
||||
Value tensor) {
|
||||
auto tensorType = tensor.getType().cast<ShapedType>();
|
||||
auto tensorType = llvm::cast<ShapedType>(tensor.getType());
|
||||
auto memrefType =
|
||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||
return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
|
||||
|
@ -414,7 +414,7 @@ public:
|
||||
/// TODO: better unord/not-unique; also generalize, optimize, specialize!
|
||||
SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
|
||||
OpBuilder &builder, Location loc) {
|
||||
const SparseTensorType stt(rtp.cast<RankedTensorType>());
|
||||
const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
|
||||
const Level lvlRank = stt.getLvlRank();
|
||||
// Extract fields and coordinates from args.
|
||||
SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
|
||||
@ -466,7 +466,7 @@ public:
|
||||
// The mangled name of the function has this format:
|
||||
// <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
|
||||
constexpr const char kInsertFuncNamePrefix[] = "_insert_";
|
||||
const SparseTensorType stt(rtp.cast<RankedTensorType>());
|
||||
const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
|
||||
|
||||
SmallString<32> nameBuffer;
|
||||
llvm::raw_svector_ostream nameOstream(nameBuffer);
|
||||
@ -541,14 +541,14 @@ static void genEndInsert(OpBuilder &builder, Location loc,
|
||||
|
||||
static TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
|
||||
Value tensor) {
|
||||
auto tTp = tensor.getType().cast<TensorType>();
|
||||
auto tTp = llvm::cast<TensorType>(tensor.getType());
|
||||
auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
|
||||
return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, Value sz) {
|
||||
auto elemTp = mem.getType().cast<MemRefType>().getElementType();
|
||||
auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
|
||||
return builder
|
||||
.create<memref::SubViewOp>(
|
||||
loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
|
||||
|
@ -180,7 +180,7 @@ struct ReifyPadOp
|
||||
AffineExpr expr = b.getAffineDimExpr(0);
|
||||
unsigned numSymbols = 0;
|
||||
auto addOpFoldResult = [&](OpFoldResult valueOrAttr) {
|
||||
if (Value v = valueOrAttr.dyn_cast<Value>()) {
|
||||
if (Value v = llvm::dyn_cast_if_present<Value>(valueOrAttr)) {
|
||||
expr = expr + b.getAffineSymbolExpr(numSymbols++);
|
||||
mapOperands.push_back(v);
|
||||
return;
|
||||
|
@ -501,7 +501,7 @@ Speculation::Speculatability DimOp::getSpeculatability() {
|
||||
|
||||
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
|
||||
// All forms of folding require a known index.
|
||||
auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
|
||||
auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
|
||||
if (!index)
|
||||
return {};
|
||||
|
||||
@ -764,7 +764,7 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
|
||||
OpFoldResult currDim = std::get<1>(it);
|
||||
// Case 1: The empty tensor dim is static. Check that the tensor cast
|
||||
// result dim matches.
|
||||
if (auto attr = currDim.dyn_cast<Attribute>()) {
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
|
||||
if (ShapedType::isDynamic(newDim) ||
|
||||
newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
|
||||
// Something is off, the cast result shape cannot be more dynamic
|
||||
@ -2106,7 +2106,7 @@ static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
|
||||
}
|
||||
|
||||
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
|
||||
if (auto splat = adaptor.getSource().dyn_cast_or_null<SplatElementsAttr>()) {
|
||||
if (auto splat = llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
|
||||
auto resultType = llvm::cast<ShapedType>(getResult().getType());
|
||||
if (resultType.hasStaticShape())
|
||||
return splat.resizeSplat(resultType);
|
||||
@ -3558,7 +3558,7 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
|
||||
SmallVector<int64_t> result;
|
||||
for (auto o : ofrs) {
|
||||
// Have to do this first, as getConstantIntValue special-cases constants.
|
||||
if (o.dyn_cast<Value>())
|
||||
if (llvm::dyn_cast_if_present<Value>(o))
|
||||
result.push_back(ShapedType::kDynamic);
|
||||
else
|
||||
result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
|
||||
|
@ -76,7 +76,7 @@ struct CastOpInterface
|
||||
auto rankedResultType = cast<RankedTensorType>(castOp.getType());
|
||||
return MemRefType::get(
|
||||
rankedResultType.getShape(), rankedResultType.getElementType(),
|
||||
maybeSrcBufferType->cast<MemRefType>().getLayout(), memorySpace);
|
||||
llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
@ -139,7 +139,7 @@ struct CollapseShapeOpInterface
|
||||
collapseShapeOp.getSrc(), options, fixedTypes);
|
||||
if (failed(maybeSrcBufferType))
|
||||
return failure();
|
||||
auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
|
||||
auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
|
||||
bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
|
||||
srcBufferType, collapseShapeOp.getReassociationIndices());
|
||||
|
||||
@ -303,7 +303,7 @@ struct ExpandShapeOpInterface
|
||||
expandShapeOp.getSrc(), options, fixedTypes);
|
||||
if (failed(maybeSrcBufferType))
|
||||
return failure();
|
||||
auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
|
||||
auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
|
||||
auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
|
||||
srcBufferType, expandShapeOp.getResultType().getShape(),
|
||||
expandShapeOp.getReassociationIndices());
|
||||
@ -369,7 +369,7 @@ struct ExtractSliceOpInterface
|
||||
if (failed(resultMemrefType))
|
||||
return failure();
|
||||
Value subView = rewriter.create<memref::SubViewOp>(
|
||||
loc, resultMemrefType->cast<MemRefType>(), *srcMemref, mixedOffsets,
|
||||
loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
|
||||
mixedSizes, mixedStrides);
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, subView);
|
||||
@ -389,7 +389,7 @@ struct ExtractSliceOpInterface
|
||||
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
|
||||
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
|
||||
return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
|
||||
extractSliceOp.getType().getShape(), srcMemrefType->cast<MemRefType>(),
|
||||
extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
|
||||
mixedOffsets, mixedSizes, mixedStrides));
|
||||
}
|
||||
};
|
||||
|
@ -548,8 +548,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
|
||||
return {};
|
||||
|
||||
auto resultETy = resultTy.getElementType();
|
||||
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
|
||||
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
|
||||
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
|
||||
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
|
||||
return getInput1();
|
||||
@ -573,8 +573,8 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
|
||||
return {};
|
||||
|
||||
auto resultETy = resultTy.getElementType();
|
||||
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
|
||||
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
|
||||
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
if (lhsAttr && lhsAttr.isSplat()) {
|
||||
if (llvm::isa<IntegerType>(resultETy) &&
|
||||
lhsAttr.getSplatValue<APInt>().isZero())
|
||||
@ -642,8 +642,8 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
|
||||
return {};
|
||||
|
||||
auto resultETy = resultTy.getElementType();
|
||||
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
|
||||
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
|
||||
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
|
||||
const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
|
||||
if (rhsTy == resultTy) {
|
||||
@ -670,8 +670,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
|
||||
return {};
|
||||
|
||||
auto resultETy = resultTy.getElementType();
|
||||
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
|
||||
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
|
||||
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
|
||||
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
|
||||
return getInput1();
|
||||
@ -713,8 +713,8 @@ struct APIntFoldGreaterEqual {
|
||||
|
||||
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
|
||||
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
|
||||
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
|
||||
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
|
||||
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
|
||||
if (!lhsAttr || !rhsAttr)
|
||||
return {};
|
||||
@ -725,8 +725,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
|
||||
|
||||
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
|
||||
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
|
||||
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
|
||||
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
|
||||
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
|
||||
if (!lhsAttr || !rhsAttr)
|
||||
return {};
|
||||
@ -738,8 +738,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
|
||||
|
||||
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
|
||||
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
|
||||
auto lhsAttr = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
|
||||
auto rhsAttr = adaptor.getInput2().dyn_cast_or_null<DenseElementsAttr>();
|
||||
auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
|
||||
Value lhs = getInput1();
|
||||
Value rhs = getInput2();
|
||||
auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
|
||||
@ -763,7 +763,7 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
|
||||
if (getInput().getType() == getType())
|
||||
return getInput();
|
||||
|
||||
auto operand = adaptor.getInput().dyn_cast_or_null<ElementsAttr>();
|
||||
auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
|
||||
if (!operand)
|
||||
return {};
|
||||
|
||||
@ -852,7 +852,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
|
||||
if (inputTy == outputTy)
|
||||
return getInput1();
|
||||
|
||||
auto operand = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>();
|
||||
auto operand = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
|
||||
if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
|
||||
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
|
||||
}
|
||||
@ -863,7 +863,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
|
||||
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
|
||||
// If the pad is all zeros we can fold this operation away.
|
||||
if (adaptor.getPadding()) {
|
||||
auto densePad = adaptor.getPadding().cast<DenseElementsAttr>();
|
||||
auto densePad = llvm::cast<DenseElementsAttr>(adaptor.getPadding());
|
||||
if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
|
||||
return getInput1();
|
||||
}
|
||||
@ -907,7 +907,7 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
|
||||
auto operand = getInput();
|
||||
auto operandTy = llvm::cast<ShapedType>(operand.getType());
|
||||
auto axis = getAxis();
|
||||
auto operandAttr = adaptor.getInput().dyn_cast_or_null<SplatElementsAttr>();
|
||||
auto operandAttr = llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput());
|
||||
if (operandAttr)
|
||||
return operandAttr;
|
||||
|
||||
@ -936,7 +936,7 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
|
||||
!outputTy.getElementType().isIntOrIndexOrFloat())
|
||||
return {};
|
||||
|
||||
auto operand = adaptor.getInput().cast<ElementsAttr>();
|
||||
auto operand = llvm::cast<ElementsAttr>(adaptor.getInput());
|
||||
if (operand.isSplat() && outputTy.hasStaticShape()) {
|
||||
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
|
||||
}
|
||||
@ -955,7 +955,7 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
|
||||
if (getOnTrue() == getOnFalse())
|
||||
return getOnTrue();
|
||||
|
||||
auto predicate = adaptor.getPred().dyn_cast_or_null<DenseIntElementsAttr>();
|
||||
auto predicate = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
|
||||
if (!predicate)
|
||||
return {};
|
||||
|
||||
@ -977,7 +977,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
|
||||
auto resultTy = llvm::cast<ShapedType>(getType());
|
||||
|
||||
// Transposing splat values just means reshaping.
|
||||
if (auto input = adaptor.getInput1().dyn_cast_or_null<DenseElementsAttr>()) {
|
||||
if (auto input = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
|
||||
if (input.isSplat() && resultTy.hasStaticShape() &&
|
||||
inputTy.getElementType() == resultTy.getElementType())
|
||||
return input.reshape(resultTy);
|
||||
|
@ -63,9 +63,9 @@ LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
|
||||
// Verify the rank agrees with the output type if the output type is ranked.
|
||||
if (outputType) {
|
||||
if (outputType.getRank() !=
|
||||
input1_copy.getType().cast<RankedTensorType>().getRank() ||
|
||||
llvm::cast<RankedTensorType>(input1_copy.getType()).getRank() ||
|
||||
outputType.getRank() !=
|
||||
input2_copy.getType().cast<RankedTensorType>().getRank())
|
||||
llvm::cast<RankedTensorType>(input2_copy.getType()).getRank())
|
||||
return rewriter.notifyMatchFailure(
|
||||
loc, "the reshaped type doesn't agrees with the ranked output type");
|
||||
}
|
||||
|
@ -103,8 +103,8 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
|
||||
|
||||
LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
|
||||
Value &input1, Value &input2) {
|
||||
auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
|
||||
auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
|
||||
auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
|
||||
auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
|
||||
|
||||
if (!input1Ty || !input2Ty) {
|
||||
return failure();
|
||||
@ -126,9 +126,9 @@ LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> higherRankShape =
|
||||
higherTensorValue.getType().cast<RankedTensorType>().getShape();
|
||||
llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape();
|
||||
ArrayRef<int64_t> lowerRankShape =
|
||||
lowerTensorValue.getType().cast<RankedTensorType>().getShape();
|
||||
llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
|
||||
|
||||
SmallVector<int64_t, 4> reshapeOutputShape;
|
||||
|
||||
@ -136,7 +136,8 @@ LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
|
||||
.failed())
|
||||
return failure();
|
||||
|
||||
auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
|
||||
auto reshapeInputType =
|
||||
llvm::cast<RankedTensorType>(lowerTensorValue.getType());
|
||||
auto reshapeOutputType = RankedTensorType::get(
|
||||
ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
|
||||
|
||||
|
@ -118,7 +118,7 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
|
||||
SmallVector<Operation *> operations;
|
||||
operations.reserve(values.size());
|
||||
for (transform::MappedValue value : values) {
|
||||
if (auto *op = value.dyn_cast<Operation *>()) {
|
||||
if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
|
||||
operations.push_back(op);
|
||||
continue;
|
||||
}
|
||||
@ -135,7 +135,7 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
|
||||
SmallVector<Value> payloadValues;
|
||||
payloadValues.reserve(values.size());
|
||||
for (transform::MappedValue value : values) {
|
||||
if (auto v = value.dyn_cast<Value>()) {
|
||||
if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
|
||||
payloadValues.push_back(v);
|
||||
continue;
|
||||
}
|
||||
@ -152,7 +152,7 @@ static DiagnosedSilenceableFailure dispatchMappedValues(
|
||||
SmallVector<transform::Param> parameters;
|
||||
parameters.reserve(values.size());
|
||||
for (transform::MappedValue value : values) {
|
||||
if (auto attr = value.dyn_cast<Attribute>()) {
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
|
||||
parameters.push_back(attr);
|
||||
continue;
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ namespace mlir {
|
||||
bool isZeroIndex(OpFoldResult v) {
|
||||
if (!v)
|
||||
return false;
|
||||
if (auto attr = v.dyn_cast<Attribute>()) {
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(v)) {
|
||||
IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
|
||||
return intAttr && intAttr.getValue().isZero();
|
||||
}
|
||||
@ -51,7 +51,7 @@ getOffsetsSizesAndStrides(ArrayRef<Range> ranges) {
|
||||
void dispatchIndexOpFoldResult(OpFoldResult ofr,
|
||||
SmallVectorImpl<Value> &dynamicVec,
|
||||
SmallVectorImpl<int64_t> &staticVec) {
|
||||
auto v = ofr.dyn_cast<Value>();
|
||||
auto v = llvm::dyn_cast_if_present<Value>(ofr);
|
||||
if (!v) {
|
||||
APInt apInt = cast<IntegerAttr>(ofr.get<Attribute>()).getValue();
|
||||
staticVec.push_back(apInt.getSExtValue());
|
||||
@ -116,14 +116,14 @@ SmallVector<OpFoldResult> getAsIndexOpFoldResult(MLIRContext *ctx,
|
||||
/// If ofr is a constant integer or an IntegerAttr, return the integer.
|
||||
std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
|
||||
// Case 1: Check for Constant integer.
|
||||
if (auto val = ofr.dyn_cast<Value>()) {
|
||||
if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
|
||||
APSInt intVal;
|
||||
if (matchPattern(val, m_ConstantInt(&intVal)))
|
||||
return intVal.getSExtValue();
|
||||
return std::nullopt;
|
||||
}
|
||||
// Case 2: Check for IntegerAttr.
|
||||
Attribute attr = ofr.dyn_cast<Attribute>();
|
||||
Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
|
||||
if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
|
||||
return intAttr.getValue().getSExtValue();
|
||||
return std::nullopt;
|
||||
@ -143,7 +143,8 @@ bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
|
||||
auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
|
||||
if (cst1 && cst2 && *cst1 == *cst2)
|
||||
return true;
|
||||
auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
|
||||
auto v1 = llvm::dyn_cast_if_present<Value>(ofr1),
|
||||
v2 = llvm::dyn_cast_if_present<Value>(ofr2);
|
||||
return v1 && v1 == v2;
|
||||
}
|
||||
|
||||
|
@ -1154,7 +1154,7 @@ ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
|
||||
OpaqueProperties properties, RegionRange,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
ExtractOp::Adaptor op(operands, attributes, properties);
|
||||
auto vectorType = op.getVector().getType().cast<VectorType>();
|
||||
auto vectorType = llvm::cast<VectorType>(op.getVector().getType());
|
||||
if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
|
||||
inferredReturnTypes.push_back(vectorType.getElementType());
|
||||
} else {
|
||||
@ -2003,9 +2003,9 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
|
||||
if (!adaptor.getSource())
|
||||
return {};
|
||||
auto vectorType = getResultVectorType();
|
||||
if (adaptor.getSource().isa<IntegerAttr, FloatAttr>())
|
||||
if (llvm::isa<IntegerAttr, FloatAttr>(adaptor.getSource()))
|
||||
return DenseElementsAttr::get(vectorType, adaptor.getSource());
|
||||
if (auto attr = adaptor.getSource().dyn_cast<SplatElementsAttr>())
|
||||
if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
|
||||
return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
|
||||
return {};
|
||||
}
|
||||
@ -2090,7 +2090,7 @@ ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
|
||||
OpaqueProperties properties, RegionRange,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
ShuffleOp::Adaptor op(operands, attributes, properties);
|
||||
auto v1Type = op.getV1().getType().cast<VectorType>();
|
||||
auto v1Type = llvm::cast<VectorType>(op.getV1().getType());
|
||||
auto v1Rank = v1Type.getRank();
|
||||
// Construct resulting type: leading dimension matches mask
|
||||
// length, all trailing dimensions match the operands.
|
||||
@ -4951,7 +4951,7 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
|
||||
|
||||
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
|
||||
// Eliminate splat constant transpose ops.
|
||||
if (auto attr = adaptor.getVector().dyn_cast_or_null<DenseElementsAttr>())
|
||||
if (auto attr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
|
||||
if (attr.isSplat())
|
||||
return attr.reshape(getResultVectorType());
|
||||
|
||||
|
@ -3642,7 +3642,7 @@ void Value::print(raw_ostream &os, const OpPrintingFlags &flags) {
|
||||
if (auto *op = getDefiningOp())
|
||||
return op->print(os, flags);
|
||||
// TODO: Improve BlockArgument print'ing.
|
||||
BlockArgument arg = this->cast<BlockArgument>();
|
||||
BlockArgument arg = llvm::cast<BlockArgument>(*this);
|
||||
os << "<block argument> of type '" << arg.getType()
|
||||
<< "' at index: " << arg.getArgNumber();
|
||||
}
|
||||
@ -3656,7 +3656,7 @@ void Value::print(raw_ostream &os, AsmState &state) {
|
||||
return op->print(os, state);
|
||||
|
||||
// TODO: Improve BlockArgument print'ing.
|
||||
BlockArgument arg = this->cast<BlockArgument>();
|
||||
BlockArgument arg = llvm::cast<BlockArgument>(*this);
|
||||
os << "<block argument> of type '" << arg.getType()
|
||||
<< "' at index: " << arg.getArgNumber();
|
||||
}
|
||||
@ -3693,10 +3693,10 @@ static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
|
||||
|
||||
void Value::printAsOperand(raw_ostream &os, const OpPrintingFlags &flags) {
|
||||
Operation *op;
|
||||
if (auto result = dyn_cast<OpResult>()) {
|
||||
if (auto result = llvm::dyn_cast<OpResult>(*this)) {
|
||||
op = result.getOwner();
|
||||
} else {
|
||||
op = cast<BlockArgument>().getOwner()->getParentOp();
|
||||
op = llvm::cast<BlockArgument>(*this).getOwner()->getParentOp();
|
||||
if (!op) {
|
||||
os << "<<UNKNOWN SSA VALUE>>";
|
||||
return;
|
||||
|
@ -347,14 +347,14 @@ BlockRange::BlockRange(SuccessorRange successors)
|
||||
|
||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||
BlockRange::OwnerT BlockRange::offset_base(OwnerT object, ptrdiff_t index) {
|
||||
if (auto *operand = object.dyn_cast<BlockOperand *>())
|
||||
if (auto *operand = llvm::dyn_cast_if_present<BlockOperand *>(object))
|
||||
return {operand + index};
|
||||
return {object.dyn_cast<Block *const *>() + index};
|
||||
return {llvm::dyn_cast_if_present<Block *const *>(object) + index};
|
||||
}
|
||||
|
||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||
Block *BlockRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
|
||||
if (const auto *operand = object.dyn_cast<BlockOperand *>())
|
||||
if (const auto *operand = llvm::dyn_cast_if_present<BlockOperand *>(object))
|
||||
return operand[index].get();
|
||||
return object.dyn_cast<Block *const *>()[index];
|
||||
return llvm::dyn_cast_if_present<Block *const *>(object)[index];
|
||||
}
|
||||
|
@ -483,7 +483,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
|
||||
Type expectedType = std::get<1>(it);
|
||||
|
||||
// Normal values get pushed back directly.
|
||||
if (auto value = std::get<0>(it).dyn_cast<Value>()) {
|
||||
if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
|
||||
if (value.getType() != expectedType)
|
||||
return cleanupFailure();
|
||||
|
||||
|
@ -1247,12 +1247,12 @@ DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
|
||||
DenseElementsAttr
|
||||
DenseElementsAttr::mapValues(Type newElementType,
|
||||
function_ref<APInt(const APInt &)> mapping) const {
|
||||
return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
|
||||
return llvm::cast<DenseIntElementsAttr>(*this).mapValues(newElementType, mapping);
|
||||
}
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::mapValues(
|
||||
Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
|
||||
return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
|
||||
return llvm::cast<DenseFPElementsAttr>(*this).mapValues(newElementType, mapping);
|
||||
}
|
||||
|
||||
ShapedType DenseElementsAttr::getType() const {
|
||||
|
@ -88,45 +88,45 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
unsigned FloatType::getWidth() {
|
||||
if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
|
||||
Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
|
||||
if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
|
||||
Float8E4M3FNUZType, Float8E4M3B11FNUZType>(*this))
|
||||
return 8;
|
||||
if (isa<Float16Type, BFloat16Type>())
|
||||
if (llvm::isa<Float16Type, BFloat16Type>(*this))
|
||||
return 16;
|
||||
if (isa<Float32Type>())
|
||||
if (llvm::isa<Float32Type>(*this))
|
||||
return 32;
|
||||
if (isa<Float64Type>())
|
||||
if (llvm::isa<Float64Type>(*this))
|
||||
return 64;
|
||||
if (isa<Float80Type>())
|
||||
if (llvm::isa<Float80Type>(*this))
|
||||
return 80;
|
||||
if (isa<Float128Type>())
|
||||
if (llvm::isa<Float128Type>(*this))
|
||||
return 128;
|
||||
llvm_unreachable("unexpected float type");
|
||||
}
|
||||
|
||||
/// Returns the floating semantics for the given type.
|
||||
const llvm::fltSemantics &FloatType::getFloatSemantics() {
|
||||
if (isa<Float8E5M2Type>())
|
||||
if (llvm::isa<Float8E5M2Type>(*this))
|
||||
return APFloat::Float8E5M2();
|
||||
if (isa<Float8E4M3FNType>())
|
||||
if (llvm::isa<Float8E4M3FNType>(*this))
|
||||
return APFloat::Float8E4M3FN();
|
||||
if (isa<Float8E5M2FNUZType>())
|
||||
if (llvm::isa<Float8E5M2FNUZType>(*this))
|
||||
return APFloat::Float8E5M2FNUZ();
|
||||
if (isa<Float8E4M3FNUZType>())
|
||||
if (llvm::isa<Float8E4M3FNUZType>(*this))
|
||||
return APFloat::Float8E4M3FNUZ();
|
||||
if (isa<Float8E4M3B11FNUZType>())
|
||||
if (llvm::isa<Float8E4M3B11FNUZType>(*this))
|
||||
return APFloat::Float8E4M3B11FNUZ();
|
||||
if (isa<BFloat16Type>())
|
||||
if (llvm::isa<BFloat16Type>(*this))
|
||||
return APFloat::BFloat();
|
||||
if (isa<Float16Type>())
|
||||
if (llvm::isa<Float16Type>(*this))
|
||||
return APFloat::IEEEhalf();
|
||||
if (isa<Float32Type>())
|
||||
if (llvm::isa<Float32Type>(*this))
|
||||
return APFloat::IEEEsingle();
|
||||
if (isa<Float64Type>())
|
||||
if (llvm::isa<Float64Type>(*this))
|
||||
return APFloat::IEEEdouble();
|
||||
if (isa<Float80Type>())
|
||||
if (llvm::isa<Float80Type>(*this))
|
||||
return APFloat::x87DoubleExtended();
|
||||
if (isa<Float128Type>())
|
||||
if (llvm::isa<Float128Type>(*this))
|
||||
return APFloat::IEEEquad();
|
||||
llvm_unreachable("non-floating point type used");
|
||||
}
|
||||
@ -269,21 +269,21 @@ Type TensorType::getElementType() const {
|
||||
[](auto type) { return type.getElementType(); });
|
||||
}
|
||||
|
||||
bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }
|
||||
bool TensorType::hasRank() const { return !llvm::isa<UnrankedTensorType>(*this); }
|
||||
|
||||
ArrayRef<int64_t> TensorType::getShape() const {
|
||||
return cast<RankedTensorType>().getShape();
|
||||
return llvm::cast<RankedTensorType>(*this).getShape();
|
||||
}
|
||||
|
||||
TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
|
||||
Type elementType) const {
|
||||
if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
|
||||
if (auto unrankedTy = llvm::dyn_cast<UnrankedTensorType>(*this)) {
|
||||
if (shape)
|
||||
return RankedTensorType::get(*shape, elementType);
|
||||
return UnrankedTensorType::get(elementType);
|
||||
}
|
||||
|
||||
auto rankedTy = cast<RankedTensorType>();
|
||||
auto rankedTy = llvm::cast<RankedTensorType>(*this);
|
||||
if (!shape)
|
||||
return RankedTensorType::get(rankedTy.getShape(), elementType,
|
||||
rankedTy.getEncoding());
|
||||
@ -356,15 +356,15 @@ Type BaseMemRefType::getElementType() const {
|
||||
[](auto type) { return type.getElementType(); });
|
||||
}
|
||||
|
||||
bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }
|
||||
bool BaseMemRefType::hasRank() const { return !llvm::isa<UnrankedMemRefType>(*this); }
|
||||
|
||||
ArrayRef<int64_t> BaseMemRefType::getShape() const {
|
||||
return cast<MemRefType>().getShape();
|
||||
return llvm::cast<MemRefType>(*this).getShape();
|
||||
}
|
||||
|
||||
BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
|
||||
Type elementType) const {
|
||||
if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
|
||||
if (auto unrankedTy = llvm::dyn_cast<UnrankedMemRefType>(*this)) {
|
||||
if (!shape)
|
||||
return UnrankedMemRefType::get(elementType, getMemorySpace());
|
||||
MemRefType::Builder builder(*shape, elementType);
|
||||
@ -372,7 +372,7 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
|
||||
return builder;
|
||||
}
|
||||
|
||||
MemRefType::Builder builder(cast<MemRefType>());
|
||||
MemRefType::Builder builder(llvm::cast<MemRefType>(*this));
|
||||
if (shape)
|
||||
builder.setShape(*shape);
|
||||
builder.setElementType(elementType);
|
||||
@ -389,15 +389,15 @@ MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
|
||||
}
|
||||
|
||||
Attribute BaseMemRefType::getMemorySpace() const {
|
||||
if (auto rankedMemRefTy = dyn_cast<MemRefType>())
|
||||
if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
|
||||
return rankedMemRefTy.getMemorySpace();
|
||||
return cast<UnrankedMemRefType>().getMemorySpace();
|
||||
return llvm::cast<UnrankedMemRefType>(*this).getMemorySpace();
|
||||
}
|
||||
|
||||
unsigned BaseMemRefType::getMemorySpaceAsInt() const {
|
||||
if (auto rankedMemRefTy = dyn_cast<MemRefType>())
|
||||
if (auto rankedMemRefTy = llvm::dyn_cast<MemRefType>(*this))
|
||||
return rankedMemRefTy.getMemorySpaceAsInt();
|
||||
return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
|
||||
return llvm::cast<UnrankedMemRefType>(*this).getMemorySpaceAsInt();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -626,17 +626,17 @@ ValueRange::ValueRange(ResultRange values)
|
||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||
ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
|
||||
ptrdiff_t index) {
|
||||
if (const auto *value = owner.dyn_cast<const Value *>())
|
||||
if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
|
||||
return {value + index};
|
||||
if (auto *operand = owner.dyn_cast<OpOperand *>())
|
||||
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
|
||||
return {operand + index};
|
||||
return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
|
||||
}
|
||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||
Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
|
||||
if (const auto *value = owner.dyn_cast<const Value *>())
|
||||
if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
|
||||
return value[index];
|
||||
if (auto *operand = owner.dyn_cast<OpOperand *>())
|
||||
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
|
||||
return operand[index].get();
|
||||
return owner.get<detail::OpResultImpl *>()->getNextResultAtOffset(index);
|
||||
}
|
||||
|
@ -267,18 +267,18 @@ RegionRange::RegionRange(ArrayRef<Region *> regions)
|
||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||
RegionRange::OwnerT RegionRange::offset_base(const OwnerT &owner,
|
||||
ptrdiff_t index) {
|
||||
if (auto *region = owner.dyn_cast<const std::unique_ptr<Region> *>())
|
||||
if (auto *region = llvm::dyn_cast_if_present<const std::unique_ptr<Region> *>(owner))
|
||||
return region + index;
|
||||
if (auto **region = owner.dyn_cast<Region **>())
|
||||
if (auto **region = llvm::dyn_cast_if_present<Region **>(owner))
|
||||
return region + index;
|
||||
return &owner.get<Region *>()[index];
|
||||
}
|
||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||
Region *RegionRange::dereference_iterator(const OwnerT &owner,
|
||||
ptrdiff_t index) {
|
||||
if (auto *region = owner.dyn_cast<const std::unique_ptr<Region> *>())
|
||||
if (auto *region = llvm::dyn_cast_if_present<const std::unique_ptr<Region> *>(owner))
|
||||
return region[index].get();
|
||||
if (auto **region = owner.dyn_cast<Region **>())
|
||||
if (auto **region = llvm::dyn_cast_if_present<Region **>(owner))
|
||||
return region[index];
|
||||
return &owner.get<Region *>()[index];
|
||||
}
|
||||
|
@ -551,7 +551,7 @@ struct SymbolScope {
|
||||
typename llvm::function_traits<CallbackT>::result_t,
|
||||
void>::value> * = nullptr>
|
||||
std::optional<WalkResult> walk(CallbackT cback) {
|
||||
if (Region *region = limit.dyn_cast<Region *>())
|
||||
if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
|
||||
return walkSymbolUses(*region, cback);
|
||||
return walkSymbolUses(limit.get<Operation *>(), cback);
|
||||
}
|
||||
@ -571,7 +571,7 @@ struct SymbolScope {
|
||||
/// traversing into any nested symbol tables.
|
||||
template <typename CallbackT>
|
||||
std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
|
||||
if (Region *region = limit.dyn_cast<Region *>())
|
||||
if (Region *region = llvm::dyn_cast_if_present<Region *>(limit))
|
||||
return ::walkSymbolTable(*region, cback);
|
||||
return ::walkSymbolTable(limit.get<Operation *>(), cback);
|
||||
}
|
||||
|
@ -27,9 +27,9 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
|
||||
if (count == 0)
|
||||
return;
|
||||
ValueRange::OwnerT owner = values.begin().getBase();
|
||||
if (auto *result = owner.dyn_cast<detail::OpResultImpl *>())
|
||||
if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(owner))
|
||||
this->base = result;
|
||||
else if (auto *operand = owner.dyn_cast<OpOperand *>())
|
||||
else if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
|
||||
this->base = operand;
|
||||
else
|
||||
this->base = owner.get<const Value *>();
|
||||
@ -37,22 +37,22 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
|
||||
|
||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||
TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
|
||||
if (const auto *value = object.dyn_cast<const Value *>())
|
||||
if (const auto *value = llvm::dyn_cast_if_present<const Value *>(object))
|
||||
return {value + index};
|
||||
if (auto *operand = object.dyn_cast<OpOperand *>())
|
||||
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(object))
|
||||
return {operand + index};
|
||||
if (auto *result = object.dyn_cast<detail::OpResultImpl *>())
|
||||
if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
|
||||
return {result->getNextResultAtOffset(index)};
|
||||
return {object.dyn_cast<const Type *>() + index};
|
||||
return {llvm::dyn_cast_if_present<const Type *>(object) + index};
|
||||
}
|
||||
|
||||
/// See `llvm::detail::indexed_accessor_range_base` for details.
|
||||
Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
|
||||
if (const auto *value = object.dyn_cast<const Value *>())
|
||||
if (const auto *value = llvm::dyn_cast_if_present<const Value *>(object))
|
||||
return (value + index)->getType();
|
||||
if (auto *operand = object.dyn_cast<OpOperand *>())
|
||||
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(object))
|
||||
return (operand + index)->get().getType();
|
||||
if (auto *result = object.dyn_cast<detail::OpResultImpl *>())
|
||||
if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
|
||||
return result->getNextResultAtOffset(index)->getType();
|
||||
return object.dyn_cast<const Type *>()[index];
|
||||
return llvm::dyn_cast_if_present<const Type *>(object)[index];
|
||||
}
|
||||
|
@ -34,84 +34,94 @@ Type AbstractType::replaceImmediateSubElements(Type type,
|
||||
|
||||
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
|
||||
|
||||
bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
|
||||
bool Type::isFloat8E4M3FN() const { return isa<Float8E4M3FNType>(); }
|
||||
bool Type::isFloat8E5M2FNUZ() const { return isa<Float8E5M2FNUZType>(); }
|
||||
bool Type::isFloat8E4M3FNUZ() const { return isa<Float8E4M3FNUZType>(); }
|
||||
bool Type::isFloat8E4M3B11FNUZ() const { return isa<Float8E4M3B11FNUZType>(); }
|
||||
bool Type::isBF16() const { return isa<BFloat16Type>(); }
|
||||
bool Type::isF16() const { return isa<Float16Type>(); }
|
||||
bool Type::isF32() const { return isa<Float32Type>(); }
|
||||
bool Type::isF64() const { return isa<Float64Type>(); }
|
||||
bool Type::isF80() const { return isa<Float80Type>(); }
|
||||
bool Type::isF128() const { return isa<Float128Type>(); }
|
||||
bool Type::isFloat8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); }
|
||||
bool Type::isFloat8E4M3FN() const { return llvm::isa<Float8E4M3FNType>(*this); }
|
||||
bool Type::isFloat8E5M2FNUZ() const {
|
||||
return llvm::isa<Float8E5M2FNUZType>(*this);
|
||||
}
|
||||
bool Type::isFloat8E4M3FNUZ() const {
|
||||
return llvm::isa<Float8E4M3FNUZType>(*this);
|
||||
}
|
||||
bool Type::isFloat8E4M3B11FNUZ() const {
|
||||
return llvm::isa<Float8E4M3B11FNUZType>(*this);
|
||||
}
|
||||
bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
|
||||
bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
|
||||
bool Type::isF32() const { return llvm::isa<Float32Type>(*this); }
|
||||
bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
|
||||
bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
|
||||
bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }
|
||||
|
||||
bool Type::isIndex() const { return isa<IndexType>(); }
|
||||
bool Type::isIndex() const { return llvm::isa<IndexType>(*this); }
|
||||
|
||||
/// Return true if this is an integer type with the specified width.
|
||||
bool Type::isInteger(unsigned width) const {
|
||||
if (auto intTy = dyn_cast<IntegerType>())
|
||||
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
|
||||
return intTy.getWidth() == width;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Type::isSignlessInteger() const {
|
||||
if (auto intTy = dyn_cast<IntegerType>())
|
||||
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
|
||||
return intTy.isSignless();
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Type::isSignlessInteger(unsigned width) const {
|
||||
if (auto intTy = dyn_cast<IntegerType>())
|
||||
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
|
||||
return intTy.isSignless() && intTy.getWidth() == width;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Type::isSignedInteger() const {
|
||||
if (auto intTy = dyn_cast<IntegerType>())
|
||||
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
|
||||
return intTy.isSigned();
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Type::isSignedInteger(unsigned width) const {
|
||||
if (auto intTy = dyn_cast<IntegerType>())
|
||||
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
|
||||
return intTy.isSigned() && intTy.getWidth() == width;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Type::isUnsignedInteger() const {
|
||||
if (auto intTy = dyn_cast<IntegerType>())
|
||||
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
|
||||
return intTy.isUnsigned();
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Type::isUnsignedInteger(unsigned width) const {
|
||||
if (auto intTy = dyn_cast<IntegerType>())
|
||||
if (auto intTy = llvm::dyn_cast<IntegerType>(*this))
|
||||
return intTy.isUnsigned() && intTy.getWidth() == width;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Type::isSignlessIntOrIndex() const {
|
||||
return isSignlessInteger() || isa<IndexType>();
|
||||
return isSignlessInteger() || llvm::isa<IndexType>(*this);
|
||||
}
|
||||
|
||||
bool Type::isSignlessIntOrIndexOrFloat() const {
|
||||
return isSignlessInteger() || isa<IndexType, FloatType>();
|
||||
return isSignlessInteger() || llvm::isa<IndexType, FloatType>(*this);
|
||||
}
|
||||
|
||||
bool Type::isSignlessIntOrFloat() const {
|
||||
return isSignlessInteger() || isa<FloatType>();
|
||||
return isSignlessInteger() || llvm::isa<FloatType>(*this);
|
||||
}
|
||||
|
||||
bool Type::isIntOrIndex() const { return isa<IntegerType>() || isIndex(); }
|
||||
bool Type::isIntOrIndex() const {
|
||||
return llvm::isa<IntegerType>(*this) || isIndex();
|
||||
}
|
||||
|
||||
bool Type::isIntOrFloat() const { return isa<IntegerType, FloatType>(); }
|
||||
bool Type::isIntOrFloat() const {
|
||||
return llvm::isa<IntegerType, FloatType>(*this);
|
||||
}
|
||||
|
||||
bool Type::isIntOrIndexOrFloat() const { return isIntOrFloat() || isIndex(); }
|
||||
|
||||
unsigned Type::getIntOrFloatBitWidth() const {
|
||||
assert(isIntOrFloat() && "only integers and floats have a bitwidth");
|
||||
if (auto intType = dyn_cast<IntegerType>())
|
||||
if (auto intType = llvm::dyn_cast<IntegerType>(*this))
|
||||
return intType.getWidth();
|
||||
return cast<FloatType>().getWidth();
|
||||
return llvm::cast<FloatType>(*this).getWidth();
|
||||
}
|
||||
|
@ -48,11 +48,11 @@ static void printBlock(llvm::raw_ostream &os, Block *block,
|
||||
}
|
||||
|
||||
void mlir::IRUnit::print(llvm::raw_ostream &os, OpPrintingFlags flags) const {
|
||||
if (auto *op = this->dyn_cast<Operation *>())
|
||||
if (auto *op = llvm::dyn_cast_if_present<Operation *>(*this))
|
||||
return printOp(os, op, flags);
|
||||
if (auto *region = this->dyn_cast<Region *>())
|
||||
if (auto *region = llvm::dyn_cast_if_present<Region *>(*this))
|
||||
return printRegion(os, region, flags);
|
||||
if (auto *block = this->dyn_cast<Block *>())
|
||||
if (auto *block = llvm::dyn_cast_if_present<Block *>(*this))
|
||||
return printBlock(os, block, flags);
|
||||
llvm_unreachable("unknown IRUnit");
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ using namespace mlir::detail;
|
||||
/// If this value is the result of an Operation, return the operation that
|
||||
/// defines it.
|
||||
Operation *Value::getDefiningOp() const {
|
||||
if (auto result = dyn_cast<OpResult>())
|
||||
if (auto result = llvm::dyn_cast<OpResult>(*this))
|
||||
return result.getOwner();
|
||||
return nullptr;
|
||||
}
|
||||
@ -27,28 +27,28 @@ Location Value::getLoc() const {
|
||||
if (auto *op = getDefiningOp())
|
||||
return op->getLoc();
|
||||
|
||||
return cast<BlockArgument>().getLoc();
|
||||
return llvm::cast<BlockArgument>(*this).getLoc();
|
||||
}
|
||||
|
||||
void Value::setLoc(Location loc) {
|
||||
if (auto *op = getDefiningOp())
|
||||
return op->setLoc(loc);
|
||||
|
||||
return cast<BlockArgument>().setLoc(loc);
|
||||
return llvm::cast<BlockArgument>(*this).setLoc(loc);
|
||||
}
|
||||
|
||||
/// Return the Region in which this Value is defined.
|
||||
Region *Value::getParentRegion() {
|
||||
if (auto *op = getDefiningOp())
|
||||
return op->getParentRegion();
|
||||
return cast<BlockArgument>().getOwner()->getParent();
|
||||
return llvm::cast<BlockArgument>(*this).getOwner()->getParent();
|
||||
}
|
||||
|
||||
/// Return the Block in which this Value is defined.
|
||||
Block *Value::getParentBlock() {
|
||||
if (Operation *op = getDefiningOp())
|
||||
return op->getBlock();
|
||||
return cast<BlockArgument>().getOwner();
|
||||
return llvm::cast<BlockArgument>(*this).getOwner();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -241,7 +241,7 @@ mlir::detail::filterEntriesForType(DataLayoutEntryListRef entries,
|
||||
TypeID typeID) {
|
||||
return llvm::to_vector<4>(llvm::make_filter_range(
|
||||
entries, [typeID](DataLayoutEntryInterface entry) {
|
||||
auto type = entry.getKey().dyn_cast<Type>();
|
||||
auto type = llvm::dyn_cast_if_present<Type>(entry.getKey());
|
||||
return type && type.getTypeID() == typeID;
|
||||
}));
|
||||
}
|
||||
@ -521,7 +521,7 @@ void DataLayoutSpecInterface::bucketEntriesByType(
|
||||
DenseMap<TypeID, DataLayoutEntryList> &types,
|
||||
DenseMap<StringAttr, DataLayoutEntryInterface> &ids) {
|
||||
for (DataLayoutEntryInterface entry : getEntries()) {
|
||||
if (auto type = entry.getKey().dyn_cast<Type>())
|
||||
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey()))
|
||||
types[type.getTypeID()].push_back(entry);
|
||||
else
|
||||
ids[entry.getKey().get<StringAttr>()] = entry;
|
||||
|
@ -68,7 +68,7 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
|
||||
bool ShapeAdaptor::hasRank() const {
|
||||
if (val.isNull())
|
||||
return false;
|
||||
if (auto t = val.dyn_cast<Type>())
|
||||
if (auto t = llvm::dyn_cast_if_present<Type>(val))
|
||||
return cast<ShapedType>(t).hasRank();
|
||||
if (val.is<Attribute>())
|
||||
return true;
|
||||
@ -78,7 +78,7 @@ bool ShapeAdaptor::hasRank() const {
|
||||
Type ShapeAdaptor::getElementType() const {
|
||||
if (val.isNull())
|
||||
return nullptr;
|
||||
if (auto t = val.dyn_cast<Type>())
|
||||
if (auto t = llvm::dyn_cast_if_present<Type>(val))
|
||||
return cast<ShapedType>(t).getElementType();
|
||||
if (val.is<Attribute>())
|
||||
return nullptr;
|
||||
@ -87,10 +87,10 @@ Type ShapeAdaptor::getElementType() const {
|
||||
|
||||
void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
|
||||
assert(hasRank());
|
||||
if (auto t = val.dyn_cast<Type>()) {
|
||||
if (auto t = llvm::dyn_cast_if_present<Type>(val)) {
|
||||
ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape();
|
||||
res.assign(vals.begin(), vals.end());
|
||||
} else if (auto attr = val.dyn_cast<Attribute>()) {
|
||||
} else if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
|
||||
auto dattr = cast<DenseIntElementsAttr>(attr);
|
||||
res.clear();
|
||||
res.reserve(dattr.size());
|
||||
@ -110,9 +110,9 @@ void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
|
||||
|
||||
int64_t ShapeAdaptor::getDimSize(int index) const {
|
||||
assert(hasRank());
|
||||
if (auto t = val.dyn_cast<Type>())
|
||||
if (auto t = llvm::dyn_cast_if_present<Type>(val))
|
||||
return cast<ShapedType>(t).getDimSize(index);
|
||||
if (auto attr = val.dyn_cast<Attribute>())
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
|
||||
return cast<DenseIntElementsAttr>(attr)
|
||||
.getValues<APInt>()[index]
|
||||
.getSExtValue();
|
||||
@ -122,9 +122,9 @@ int64_t ShapeAdaptor::getDimSize(int index) const {
|
||||
|
||||
int64_t ShapeAdaptor::getRank() const {
|
||||
assert(hasRank());
|
||||
if (auto t = val.dyn_cast<Type>())
|
||||
if (auto t = llvm::dyn_cast_if_present<Type>(val))
|
||||
return cast<ShapedType>(t).getRank();
|
||||
if (auto attr = val.dyn_cast<Attribute>())
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
|
||||
return cast<DenseIntElementsAttr>(attr).size();
|
||||
return val.get<ShapedTypeComponents *>()->getDims().size();
|
||||
}
|
||||
@ -133,9 +133,9 @@ bool ShapeAdaptor::hasStaticShape() const {
|
||||
if (!hasRank())
|
||||
return false;
|
||||
|
||||
if (auto t = val.dyn_cast<Type>())
|
||||
if (auto t = llvm::dyn_cast_if_present<Type>(val))
|
||||
return cast<ShapedType>(t).hasStaticShape();
|
||||
if (auto attr = val.dyn_cast<Attribute>()) {
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
|
||||
auto dattr = cast<DenseIntElementsAttr>(attr);
|
||||
for (auto index : dattr.getValues<APInt>())
|
||||
if (ShapedType::isDynamic(index.getSExtValue()))
|
||||
@ -149,10 +149,10 @@ bool ShapeAdaptor::hasStaticShape() const {
|
||||
int64_t ShapeAdaptor::getNumElements() const {
|
||||
assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
|
||||
|
||||
if (auto t = val.dyn_cast<Type>())
|
||||
if (auto t = llvm::dyn_cast_if_present<Type>(val))
|
||||
return cast<ShapedType>(t).getNumElements();
|
||||
|
||||
if (auto attr = val.dyn_cast<Attribute>()) {
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
|
||||
auto dattr = cast<DenseIntElementsAttr>(attr);
|
||||
int64_t num = 1;
|
||||
for (auto index : dattr.getValues<APInt>()) {
|
||||
|
@ -26,14 +26,14 @@ namespace mlir {
|
||||
/// If ofr is a constant integer or an IntegerAttr, return the integer.
|
||||
static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
|
||||
// Case 1: Check for Constant integer.
|
||||
if (auto val = ofr.dyn_cast<Value>()) {
|
||||
if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
|
||||
APSInt intVal;
|
||||
if (matchPattern(val, m_ConstantInt(&intVal)))
|
||||
return intVal.getSExtValue();
|
||||
return std::nullopt;
|
||||
}
|
||||
// Case 2: Check for IntegerAttr.
|
||||
Attribute attr = ofr.dyn_cast<Attribute>();
|
||||
Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
|
||||
if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
|
||||
return intAttr.getValue().getSExtValue();
|
||||
return std::nullopt;
|
||||
@ -99,7 +99,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
|
||||
}
|
||||
|
||||
AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
|
||||
if (Value value = ofr.dyn_cast<Value>())
|
||||
if (Value value = llvm::dyn_cast_if_present<Value>(ofr))
|
||||
return getExpr(value, /*dim=*/std::nullopt);
|
||||
auto constInt = getConstantIntValue(ofr);
|
||||
assert(constInt.has_value() && "expected Integer constant");
|
||||
|
@ -26,7 +26,8 @@ struct PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
|
||||
const Pass &getPass() const { return pass; }
|
||||
Operation *getOp() const {
|
||||
ArrayRef<IRUnit> irUnits = getContextIRUnits();
|
||||
return irUnits.empty() ? nullptr : irUnits[0].dyn_cast<Operation *>();
|
||||
return irUnits.empty() ? nullptr
|
||||
: llvm::dyn_cast_if_present<Operation *>(irUnits[0]);
|
||||
}
|
||||
|
||||
public:
|
||||
|
@ -384,7 +384,7 @@ void Operator::populateTypeInferenceInfo(
|
||||
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
|
||||
// Check for a non-variable length operand to use as the type anchor.
|
||||
auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
|
||||
NamedTypeConstraint *operand = arg.dyn_cast<NamedTypeConstraint *>();
|
||||
NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
|
||||
return operand && !operand->isVariableLength();
|
||||
});
|
||||
if (operandI == arguments.end())
|
||||
@ -824,7 +824,7 @@ StringRef Operator::getAssemblyFormat() const {
|
||||
void Operator::print(llvm::raw_ostream &os) const {
|
||||
os << "op '" << getOperationName() << "'\n";
|
||||
for (Argument arg : arguments) {
|
||||
if (auto *attr = arg.dyn_cast<NamedAttribute *>())
|
||||
if (auto *attr = llvm::dyn_cast_if_present<NamedAttribute *>(arg))
|
||||
os << "[attribute] " << attr->name << '\n';
|
||||
else
|
||||
os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
|
||||
|
@ -131,7 +131,7 @@ convertBranchWeights(std::optional<ElementsAttr> weights,
|
||||
return nullptr;
|
||||
SmallVector<uint32_t> weightValues;
|
||||
weightValues.reserve(weights->size());
|
||||
for (APInt weight : weights->cast<DenseIntElementsAttr>())
|
||||
for (APInt weight : llvm::cast<DenseIntElementsAttr>(*weights))
|
||||
weightValues.push_back(weight.getLimitedValue());
|
||||
return llvm::MDBuilder(moduleTranslation.getLLVMContext())
|
||||
.createBranchWeights(weightValues);
|
||||
@ -330,7 +330,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
|
||||
auto *ty = llvm::cast<llvm::IntegerType>(
|
||||
moduleTranslation.convertType(switchOp.getValue().getType()));
|
||||
for (auto i :
|
||||
llvm::zip(switchOp.getCaseValues()->cast<DenseIntElementsAttr>(),
|
||||
llvm::zip(llvm::cast<DenseIntElementsAttr>(*switchOp.getCaseValues()),
|
||||
switchOp.getCaseDestinations()))
|
||||
switchInst->addCase(
|
||||
llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user