//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
using namespace mlir;
using namespace mlir::tensor;
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `subtensor` ops and are canonicalized
/// to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the source tensor type has more static information than the result type.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  RankedTensorType sourceType =
      castOp.source().getType().dyn_cast<RankedTensorType>();
  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();

  // Requires RankedTensorType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    if (ShapedType::isDynamic(std::get<0>(t)) &&
        !ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}
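/// Two tensor types are cast compatible when they have the same element type
/// and compatible shapes: the ranks match (or at least one side is unranked)
/// and every dimension that is static in both types has the same size. For
/// instance (illustrative examples), tensor<?x?xf32> and tensor<4x?xf32> are
/// cast compatible, while tensor<4xf32> and tensor<8xf32> are not.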
bool CastOp::areCastCompatible(Type a, Type b) {
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return impl::foldCastOp(*this);
}
/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
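/// For example (illustrative): joining tensor<?x8xf32> with tensor<4x?xf32>
/// yields tensor<4x8xf32>, joining an unranked tensor with a ranked one
/// yields the ranked type, and joining tensor<4xf32> with tensor<8xf32> has
/// no join, so a null type is returned.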
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}
namespace {
/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
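///
/// For example (illustrative), the chain
///
/// ```mlir
///   %1 = tensor.cast %0 : tensor<4x4xf32> to tensor<?x?xf32>
///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<4x?xf32>
/// ```
///
/// can be replaced by a single cast from %0 to tensor<4x?xf32>, whereas a
/// chain that first refines a dimension and then relaxes it again must keep
/// the intermediate cast, since dropping it would remove a runtime shape
/// check.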
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    auto resultType = tensorCast.getType().cast<TensorType>();

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();
    // If the above join exists, so does newJoin; it might just contain less
    // information. If so, we cannot drop the intermediate cast, as doing so
    // would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<ChainedTensorCast>(context);
}
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(ExtractOp op) {
  // Verify the # indices match if we have a ranked type.
  if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
    if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices for extract_element");

  return success();
}
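/// For example (illustrative), extracting from a constant tensor at a
/// constant index folds to the corresponding scalar constant:
///
/// ```mlir
///   %cst = constant dense<[1, 2, 3]> : tensor<3xi32>
///   %c1 = constant 1 : index
///   %e = tensor.extract %cst[%c1] : tensor<3xi32>  // folds to the constant 2
/// ```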
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // The tensor operand must be a known constant.
  Attribute tensor = operands.front();
  if (!tensor)
    return {};

  // If this is a splat elements attribute, simply return the value. All of the
  // elements of a splat attribute are the same.
  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
    return splatTensor.getSplatValue();

  // Otherwise, collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : llvm::drop_begin(operands, 1)) {
    if (!indice || !indice.isa<IntegerAttr>())
      return {};
    indices.push_back(indice.cast<IntegerAttr>().getInt());
  }

  // If this is an elements attribute, query the value at the given indices.
  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValue(indices);
  return {};
}
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           Type elementType, ValueRange elements) {
  Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
                                        elementType);
  result.addOperands(elements);
  result.addTypes(resultTy);
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  build(builder, result, elements.front().getType(), elements);
}
namespace {
// Canonicalizes the pattern of the form
//
// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
//
// to just %element.
struct ExtractElementFromTensorFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    if (extract.indices().size() != 1)
      return failure();

    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
    if (tensorFromElements == nullptr)
      return failure();

    APInt index;
    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
      return failure();
    rewriter.replaceOp(extract,
                       tensorFromElements.getOperand(index.getZExtValue()));
    return success();
  }
};
} // namespace
void FromElementsOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ExtractElementFromTensorFromElements>(context);
}
//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//
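// An illustrative well-formed tensor.generate, matching the constraints
// checked by the verifier below (one index operand per dynamic extent, one
// index block argument per result dimension, and a yield of the element
// type):
//
//   %tensor = tensor.generate %m, %n {
//   ^bb0(%i: index, %j: index):
//     <computation defining %elem : f32>
//     yield %elem : f32
//   } : tensor<?x?xf32>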
static LogicalResult verify(GenerateOp op) {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
  if (op.getNumOperands() != resultTy.getNumDynamicDims())
    return op.emitError("must have as many index operands as dynamic extents "
                        "in the result type");

  // Ensure that region arguments span the index space.
  if (!llvm::all_of(op.body().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return op.emitError("all body arguments must be index");
  if (op.body().getNumArguments() != resultTy.getRank())
    return op.emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp =
      llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
  if (yieldOp.value().getType() != resultTy.getElementType())
    return op.emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = resultTy.cast<RankedTensorType>().getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}
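// A minimal usage sketch of the builder above (illustrative; assumes an
// OpBuilder `b`, a Location `loc`, and an index-typed Value `dynSize` for the
// single dynamic extent). The body builder is responsible for creating the
// terminating yield:
//
//   auto type =
//       RankedTensorType::get({ShapedType::kDynamicSize}, b.getIndexType());
//   Value gen = b.create<GenerateOp>(
//       loc, type, ValueRange{dynSize},
//       [](OpBuilder &nested, Location nestedLoc, ValueRange args) {
//         // Yield the index itself as the element value.
//         nested.create<YieldOp>(nestedLoc, args[0]);
//       });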
namespace {
/// Canonicalizes tensor.generate operations whose dynamic extent operands are
/// constants into the equivalent operation with those extents folded into the
/// result type. A tensor.cast is also inserted so that the resulting IR
/// remains well-typed.
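///
/// For example (illustrative), if %c16 is a `constant 16 : index`, then
///
/// ```mlir
///   %0 = tensor.generate %c16, %n { ... } : tensor<?x?xf32>
/// ```
///
/// becomes a tensor.generate with result type tensor<16x?xf32> that keeps
/// only %n as a dynamic extent, followed by a tensor.cast back to
/// tensor<?x?xf32>.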
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        tensorFromElements.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      return failure();

    SmallVector<Value, 4> newOperands;
    SmallVector<int64_t, 4> newShape;
    auto operandsIt = tensorFromElements.dynamicExtents().begin();

    for (int64_t dim : resultType.getShape()) {
      if (dim != RankedTensorType::kDynamicSize) {
        newShape.push_back(dim);
        continue;
      }
      APInt index;
      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
        newShape.push_back(RankedTensorType::kDynamicSize);
        newOperands.push_back(*operandsIt++);
        continue;
      }
      newShape.push_back(index.getSExtValue());
      operandsIt++;
    }

    if (newOperands.size() == tensorFromElements.dynamicExtents().size())
      return failure();

    auto loc = tensorFromElements.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
                                newOp.body().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
                                                newOp);
    return success();
  }
};
/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):  // no predecessors
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.tensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = tensorFromElements.getBody();
    mapping.map(body->getArguments(), extract.indices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
    return success();
  }
};
/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.source(), extract.indices());
    return success();
  }
};
} // namespace
void GenerateOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
  // TODO: Move extract patterns to tensor::ExtractOp.
  results.insert<ExtractFromTensorGenerate, ExtractFromTensorCast,
                 StaticTensorGenerate>(context);
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"