//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::tensor;

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `subtensor` ops and are canonicalized
/// to preserve the type compatibility of their uses.
///
/// Returns true when all of the following conditions are met:
/// 1. source and result are ranked tensors with the same element type and
///    rank.
/// 2. the source tensor type has more static information than the result
///    type.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  if (!castOp)
    return false;

  RankedTensorType sourceType =
      castOp.source().getType().dyn_cast<RankedTensorType>();
  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();

  // Requires RankedTensorType.
  if (!sourceType || !resultType)
    return false;

  // Requires the same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires the same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // If the cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    if (ShapedType::isDynamic(std::get<0>(t)) &&
        !ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
}

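/// Illustrative examples added for clarity (not from the original source):
/// two tensor types are cast-compatible when their element types match and
/// their shapes agree wherever both are static, e.g.
///   tensor<?xf32>   <-> tensor<4xf32>   : compatible
///   tensor<4x?xf32> <-> tensor<4x4xf32> : compatible
///   tensor<4xf32>   <-> tensor<8xf32>   : incompatible (static sizes differ)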
bool CastOp::areCastCompatible(Type a, Type b) {
  auto aT = a.dyn_cast<TensorType>();
  auto bT = b.dyn_cast<TensorType>();
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}

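/// Folds away casts that change nothing about the type. A note added for
/// clarity (not in the original source): `impl::foldCastOp` returns the
/// cast's operand when the operand type already equals the result type, e.g.
/// `tensor.cast %0 : tensor<4xf32> to tensor<4xf32>` folds to `%0`.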
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return impl::foldCastOp(*this);
}

/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match. Returns a null type
/// if the shape information conflicts (ranks or static sizes differ).
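///
/// Illustrative examples added for clarity (not from the original source):
///   joinShapes(tensor<?x8xf32>, tensor<4x?xf32>) == tensor<4x8xf32>
///   joinShapes(tensor<*xf32>,   tensor<4xf32>)   == tensor<4xf32>
///   joinShapes(tensor<4xf32>,   tensor<8xf32>)   == null (conflict)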
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
}

namespace {

/// Replaces chains of two tensor.cast operations with a single tensor.cast
/// operation if doing so does not remove runtime constraints.
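///
/// Illustrative example added for clarity (not from the original source):
///
///   %1 = tensor.cast %0 : tensor<4xf32> to tensor<?xf32>
///   %2 = tensor.cast %1 : tensor<?xf32> to tensor<4xf32>
///
/// collapses into a single cast of %0, because dropping the intermediate
/// type loses no shape information. By contrast, casting tensor<?xf32> to
/// tensor<4xf32> and back to tensor<?xf32> is not collapsed: the intermediate
/// cast carries a runtime size check that a direct tensor<?xf32>-to-
/// tensor<?xf32> cast would silently drop.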
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    auto resultType = tensorCast.getType().cast<TensorType>();

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    // The join might not exist if the cast sequence would fail at runtime.
    if (!firstJoin)
      return failure();

    // The newJoin always exists if the above join exists; it might just
    // contain less information. If so, we cannot drop the intermediate cast,
    // as doing so would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};

} // namespace

void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<ChainedTensorCast>(context);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(ExtractOp op) {
  // Verify the # indices match if we have a ranked type.
  if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
    if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
      return op.emitOpError("incorrect number of indices for extract_element");

  return success();
}

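/// Folds tensor.extract when the tensor operand is a known constant. An
/// illustrative example (not from the original source): extracting index 1
/// from a constant dense<[10, 20, 30]> : tensor<3xi32> folds to 20.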
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // The tensor operand must be a known constant.
  Attribute tensor = operands.front();
  if (!tensor)
    return {};
  // If this is a splat elements attribute, simply return the value. All of the
  // elements of a splat attribute are the same.
  if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
    return splatTensor.getSplatValue();

  // Otherwise, collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute index : llvm::drop_begin(operands, 1)) {
    if (!index || !index.isa<IntegerAttr>())
      return {};
    indices.push_back(index.cast<IntegerAttr>().getInt());
  }

  // If this is an elements attribute, query the value at the given indices.
  auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValue(indices);
  return {};
}

//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//

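/// Builds a 1-D tensor whose static size equals the number of elements. An
/// illustrative note added for clarity (not in the original source): three
/// i32 operands produce a value of type tensor<3xi32>.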
void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           Type elementType, ValueRange elements) {
  Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
                                        elementType);
  result.addOperands(elements);
  result.addTypes(resultTy);
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  build(builder, result, elements.front().getType(), elements);
}

namespace {

// Canonicalizes the pattern of the form
//
// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
// %extracted_element = tensor.extract %tensor[%c0] : tensor<1xi32>
//
// to just %element.
struct ExtractElementFromTensorFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    if (extract.indices().size() != 1)
      return failure();

    auto tensorFromElements = extract.tensor().getDefiningOp<FromElementsOp>();
    if (tensorFromElements == nullptr)
      return failure();

    APInt index;
    if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
      return failure();
    // Guard against out-of-bounds constant indices, which would otherwise
    // make the getOperand call below assert.
    if (index.getZExtValue() >= tensorFromElements.getNumOperands())
      return failure();
    rewriter.replaceOp(extract,
                       tensorFromElements.getOperand(index.getZExtValue()));
    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<ExtractElementFromTensorFromElements>(context);
}

//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//

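/// Verifies tensor.generate. For reference, an illustrative well-formed op
/// (added here for clarity, not from the original source):
///
///   %t = tensor.generate %m, %n {
///   ^bb0(%i : index, %j : index, %k : index):
///     ...
///     yield %elem : f32
///   } : tensor<?x3x?xf32>
///
/// Two dynamic extents (%m, %n) match the two `?` dimensions, the three body
/// arguments match the rank, and the yielded value matches the element type.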
static LogicalResult verify(GenerateOp op) {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
  if (op.getNumOperands() != resultTy.getNumDynamicDims())
    return op.emitError("must have as many index operands as dynamic extents "
                        "in the result type");

  // Ensure that region arguments span the index space.
  if (!llvm::all_of(op.body().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return op.emitError("all body arguments must be index");
  if (op.body().getNumArguments() != resultTy.getRank())
    return op.emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp =
      llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
  if (yieldOp.value().getType() != resultTy.getElementType())
    return op.emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}

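/// Builds a GenerateOp and populates its body from `bodyBuilder`. A usage
/// sketch added for illustration (names are hypothetical, not from the
/// original source); note that, as the implementation below shows, the
/// callback itself is expected to create the terminating tensor.yield:
///
///   b.create<GenerateOp>(loc, resultTy, extents,
///       [&](OpBuilder &b, Location loc, ValueRange ivs) {
///         Value elem = ...; // compute the element for indices `ivs`
///         b.create<tensor::YieldOp>(loc, elem);
///       });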
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate the body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = resultTy.cast<RankedTensorType>().getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}

namespace {

/// Canonicalizes tensor.generate operations with constant extent operands by
/// expressing those extents statically in the result type instead. A type
/// cast is also inserted so that the resulting IR remains well-typed.
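///
/// Illustrative example added for clarity (not from the original source):
///
///   %c5 = constant 5 : index
///   %0 = tensor.generate %c5 { ... } : tensor<?xf32>
///
/// becomes
///
///   %1 = tensor.generate { ... } : tensor<5xf32>
///   %0 = tensor.cast %1 : tensor<5xf32> to tensor<?xf32>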
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const final {
    auto resultType =
        generateOp.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      return failure();

    SmallVector<Value, 4> newOperands;
    SmallVector<int64_t, 4> newShape;
    auto operandsIt = generateOp.dynamicExtents().begin();

    for (int64_t dim : resultType.getShape()) {
      if (dim != RankedTensorType::kDynamicSize) {
        newShape.push_back(dim);
        continue;
      }
      APInt index;
      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
        newShape.push_back(RankedTensorType::kDynamicSize);
        newOperands.push_back(*operandsIt++);
        continue;
      }
      newShape.push_back(index.getSExtValue());
      operandsIt++;
    }

    if (newOperands.size() == generateOp.dynamicExtents().size())
      return failure();

    auto loc = generateOp.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(generateOp.body(), newOp.body(),
                                newOp.body().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp, resultType,
                                                newOp);
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
///   ^bb0(%arg0: index):  // no predecessors
///   <computation>
///   yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorGenerate = extract.tensor().getDefiningOp<GenerateOp>();
    if (!tensorGenerate || !wouldOpBeTriviallyDead(tensorGenerate))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = tensorGenerate.getBody();
    mapping.map(body->getArguments(), extract.indices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());

    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value()));
    return success();
  }
};

/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.tensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(extract, tensorCast.source(),
                                                   extract.indices());
    return success();
  }
};

} // namespace

void GenerateOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
  // TODO: Move extract patterns to tensor::ExtractOp.
  results.insert<ExtractFromTensorGenerate, ExtractFromTensorCast,
                 StaticTensorGenerate>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"