mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-26 15:06:09 +00:00
[mlir][Interfaces] LISH: Add helpers for hyperrectangular subsets (#70628)
The majority of subset ops operate on hyperrectangular subsets. This commit adds a new optional interface method (`getAccessedHyperrectangularSlice`) that can be implemented by such subset ops. If implemented, the other `operatesOn...` interface methods of the `SubsetOpInterface` do not have to be implemented anymore. The comparison logic for hyperrectangular subsets (is disjoint/equivalent) is implemented with `ValueBoundsOpInterface`. This makes the subset hoisting more powerful: simple cases where two different SSA values always have the same runtime value can now be supported.
This commit is contained in:
parent
5b6ceaf8c3
commit
ff614a5729
@ -220,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp
|
||||
AllElementTypesMatch<["source", "dest"]>,
|
||||
BufferizableOpInterface, DestinationStyleOpInterface,
|
||||
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
|
||||
DeclareOpInterfaceMethods<SubsetOpInterface>,
|
||||
DeclareOpInterfaceMethods<SubsetOpInterface,
|
||||
["operatesOnEquivalentSubset", "operatesOnDisjointSubset"]>,
|
||||
DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
|
||||
["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
|
||||
"buildSubsetExtraction", "isEquivalentSubset"]>,
|
||||
|
@ -268,6 +268,11 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
|
||||
|
||||
public:
|
||||
void dump() const { llvm::errs() << *this << "\n"; }
|
||||
|
||||
MLIRContext *getContext() const {
|
||||
return is<Attribute>() ? get<Attribute>().getContext()
|
||||
: get<Value>().getContext();
|
||||
}
|
||||
};
|
||||
|
||||
// Temporarily exit the MLIR namespace to add casting support as later code in
|
||||
|
@ -10,6 +10,7 @@
|
||||
#define MLIR_INTERFACES_SUBSETOPINTERFACE_H_
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
|
||||
|
||||
namespace mlir {
|
||||
class SubsetOpInterface;
|
||||
@ -27,10 +28,23 @@ OpOperand &defaultGetDestinationOperand(Operation *op);
|
||||
/// `DestinationStyleOpInterface`.
|
||||
OpResult defaultGetUpdatedDestination(Operation *op);
|
||||
|
||||
/// Default implementation of `isEquivalentSubset`.
|
||||
/// Default implementation of `SubsetInsertionOpInterface::isEquivalentSubset`.
|
||||
bool defaultIsEquivalentSubset(Operation *op, Value candidate,
|
||||
function_ref<bool(Value, Value)> equivalenceFn);
|
||||
|
||||
/// Default implementation of `SubsetOpInterface::operatesOnEquivalentSubset`.
|
||||
bool defaultOperatesOnEquivalentSubset(
|
||||
Operation *op, SubsetOpInterface candidate,
|
||||
function_ref<bool(Value, Value)> equivalenceFn);
|
||||
|
||||
/// Default implementation of `SubsetOpInterface::operatesOnDisjointSubset`.
|
||||
bool defaultOperatesOnDisjointSubset(
|
||||
Operation *op, SubsetOpInterface candidate,
|
||||
function_ref<bool(Value, Value)> equivalenceFn);
|
||||
|
||||
/// Return the container that the given subset op is operating on.
|
||||
Value getTensorContainer(Operation *op);
|
||||
|
||||
/// Verify `SubsetOpInterface`.
|
||||
LogicalResult verifySubsetOpInterface(SubsetOpInterface op);
|
||||
|
||||
|
@ -32,11 +32,6 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
|
||||
hyperrectangular slice.
|
||||
- `tensor.gather/scatter` describe the subset as list of indices. (Not
|
||||
implemented yet.)
|
||||
|
||||
Note: This interface does not expose any interface methods to get a
|
||||
description of the accessed subset. That is because there is currently no
|
||||
efficient way to describe arbitrary subsets. This interface merely provides
|
||||
interface methods to check if two subsets are equivalent or disjoint.
|
||||
}];
|
||||
|
||||
let cppNamespace = "::mlir";
|
||||
@ -46,24 +41,59 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
|
||||
Return "true" if this op and the given candidate subset op operate on
|
||||
equivalent subsets. Return "false" if the two subsets are disjoint
|
||||
or cannot be proven to be equivalent.
|
||||
|
||||
This interface method does not have to be implemented if
|
||||
`getAccessedHyperrectangularSlice` is implemented.
|
||||
}],
|
||||
/*retType=*/"bool",
|
||||
/*methodName=*/"operatesOnEquivalentSubset",
|
||||
/*args=*/(ins
|
||||
"::mlir::SubsetOpInterface":$candidate,
|
||||
"::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
|
||||
"::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return ::mlir::detail::defaultOperatesOnEquivalentSubset(
|
||||
$_op, candidate, equivalenceFn);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Return "true" if this op and the given candidate subset op operate on
|
||||
disjoint subsets. Return "false" if the two subsets are equivalent,
|
||||
overlapping or cannot be proven to be disjoint.
|
||||
|
||||
This interface method does not have to be implemented if
|
||||
`getAccessedHyperrectangularSlice` is implemented.
|
||||
}],
|
||||
/*retType=*/"bool",
|
||||
/*methodName=*/"operatesOnDisjointSubset",
|
||||
/*args=*/(ins
|
||||
"::mlir::SubsetOpInterface":$candidate,
|
||||
"::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
|
||||
"::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return ::mlir::detail::defaultOperatesOnDisjointSubset(
|
||||
$_op, candidate, equivalenceFn);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
If this op operates on a hyperrectangular subset, return a
|
||||
description of the subset in terms of offsets, sizes and strides.
|
||||
Otherwise, return "failure".
|
||||
|
||||
This interface method is a convenience method for the most common case
|
||||
of hyperrectangular subset ops. It is optional. If it is implemented,
|
||||
`operatesOnEquivalentSubset` and `operatesOnDisjointSubset` do not
|
||||
have to be implemented.
|
||||
}],
|
||||
/*retType=*/"::mlir::FailureOr<::mlir::HyperrectangularSlice>",
|
||||
/*methodName=*/"getAccessedHyperrectangularSlice",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return ::mlir::failure();
|
||||
}]
|
||||
>,
|
||||
];
|
||||
|
||||
@ -71,6 +101,15 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
|
||||
return ::mlir::detail::verifySubsetOpInterface(
|
||||
::mlir::cast<::mlir::SubsetOpInterface>($_op));
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return the container that this operation is operating on. In case of an
|
||||
/// extraction op, the container is the source tensor. In case of an
|
||||
/// insertion op, the container is the destination tensor.
|
||||
Value getTensorContainer() {
|
||||
return ::mlir::detail::getTensorContainer(getOperation());
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def SubsetExtractionOpInterface
|
||||
|
@ -21,6 +21,31 @@
|
||||
namespace mlir {
|
||||
class OffsetSizeAndStrideOpInterface;
|
||||
|
||||
/// A hyperrectangular slice, represented as a list of offsets, sizes and
|
||||
/// strides.
|
||||
class HyperrectangularSlice {
|
||||
public:
|
||||
HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
|
||||
ArrayRef<OpFoldResult> sizes,
|
||||
ArrayRef<OpFoldResult> strides);
|
||||
|
||||
/// Create a hyperrectangular slice with unit strides.
|
||||
HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
|
||||
ArrayRef<OpFoldResult> sizes);
|
||||
|
||||
/// Infer a hyperrectangular slice from `OffsetSizeAndStrideOpInterface`.
|
||||
HyperrectangularSlice(OffsetSizeAndStrideOpInterface op);
|
||||
|
||||
ArrayRef<OpFoldResult> getMixedOffsets() const { return mixedOffsets; }
|
||||
ArrayRef<OpFoldResult> getMixedSizes() const { return mixedSizes; }
|
||||
ArrayRef<OpFoldResult> getMixedStrides() const { return mixedStrides; }
|
||||
|
||||
private:
|
||||
SmallVector<OpFoldResult> mixedOffsets;
|
||||
SmallVector<OpFoldResult> mixedSizes;
|
||||
SmallVector<OpFoldResult> mixedStrides;
|
||||
};
|
||||
|
||||
using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
|
||||
|
||||
/// A helper class to be used with `ValueBoundsOpInterface`. This class stores a
|
||||
@ -182,12 +207,34 @@ public:
|
||||
std::optional<int64_t> dim1 = std::nullopt,
|
||||
std::optional<int64_t> dim2 = std::nullopt);
|
||||
|
||||
/// Compute whether the given values/attributes are equal. Return "failure" if
|
||||
/// equality could not be determined.
|
||||
///
|
||||
/// `ofr1`/`ofr2` must be of index type.
|
||||
static FailureOr<bool> areEqual(OpFoldResult ofr1, OpFoldResult ofr2);
|
||||
|
||||
/// Return "true" if the given slices are guaranteed to be overlapping.
|
||||
/// Return "false" if the given slices are guaranteed to be non-overlapping.
|
||||
/// Return "failure" if unknown.
|
||||
static FailureOr<bool>
|
||||
areOverlappingSlices(OffsetSizeAndStrideOpInterface slice1,
|
||||
OffsetSizeAndStrideOpInterface slice2);
|
||||
///
|
||||
/// Slices are overlapping if for all dimensions:
|
||||
/// * offset1 + size1 * stride1 <= offset2
|
||||
/// * and offset2 + size2 * stride2 <= offset1
|
||||
///
|
||||
/// Slice are non-overlapping if the above constraint is not satisfied for
|
||||
/// at least one dimension.
|
||||
static FailureOr<bool> areOverlappingSlices(MLIRContext *ctx,
|
||||
HyperrectangularSlice slice1,
|
||||
HyperrectangularSlice slice2);
|
||||
|
||||
/// Return "true" if the given slices are guaranteed to be equivalent.
|
||||
/// Return "false" if the given slices are guaranteed to be non-equivalent.
|
||||
/// Return "failure" if unknown.
|
||||
///
|
||||
/// Slices are equivalent if their offsets, sizes and strices are equal.
|
||||
static FailureOr<bool> areEquivalentSlices(MLIRContext *ctx,
|
||||
HyperrectangularSlice slice1,
|
||||
HyperrectangularSlice slice2);
|
||||
|
||||
/// Add a bound for the given index-typed value or shaped value. This function
|
||||
/// returns a builder that adds the bound.
|
||||
|
@ -17,73 +17,12 @@ using namespace mlir::tensor;
|
||||
|
||||
namespace {
|
||||
|
||||
/// Return the tensor that the given subset op operates on.
|
||||
Value getContainerOperand(SubsetOpInterface op) {
|
||||
if (auto extractionOp =
|
||||
dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
|
||||
return extractionOp.getSourceOperand().get();
|
||||
if (auto insertionOp =
|
||||
dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
|
||||
return insertionOp.getDestinationOperand().get();
|
||||
llvm_unreachable("expected SubsetExtraction/InsertionOpInterface");
|
||||
}
|
||||
|
||||
/// Return "true" if the two ops operate on an equivalent subset.
|
||||
/// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
|
||||
/// if the two ops operate non-equivalent subsets, if equivalence cannot be
|
||||
/// determined or if `op1` is not a subset op.
|
||||
template <typename OpTy>
|
||||
bool operateOnEquivalentSubsets(
|
||||
OpTy op1, SubsetOpInterface op2,
|
||||
function_ref<bool(Value, Value)> equivalenceFn) {
|
||||
auto offsetsSizesAndStrides2 =
|
||||
dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
|
||||
if (!offsetsSizesAndStrides2)
|
||||
return false;
|
||||
if (!sameOffsetsSizesAndStrides(op1, offsetsSizesAndStrides2,
|
||||
isEqualConstantIntOrValue))
|
||||
return false;
|
||||
return equivalenceFn(
|
||||
getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
|
||||
getContainerOperand(op2));
|
||||
}
|
||||
|
||||
/// Return "true" if the two ops operate on a disjoint subsets.
|
||||
/// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
|
||||
/// if the two ops operate non-disjoint subsets, if disjointness cannot be
|
||||
/// determined or if `op1` is not a subset op.
|
||||
template <typename OpTy>
|
||||
bool operateOnDisjointSubsets(OpTy op1, SubsetOpInterface op2,
|
||||
function_ref<bool(Value, Value)> equivalenceFn) {
|
||||
auto offsetsSizesAndStrides2 =
|
||||
dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
|
||||
if (!offsetsSizesAndStrides2)
|
||||
return false;
|
||||
FailureOr<bool> overlappingSlices =
|
||||
ValueBoundsConstraintSet::areOverlappingSlices(op1,
|
||||
offsetsSizesAndStrides2);
|
||||
if (failed(overlappingSlices) || *overlappingSlices)
|
||||
return false;
|
||||
return equivalenceFn(
|
||||
getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
|
||||
getContainerOperand(op2));
|
||||
}
|
||||
|
||||
struct ExtractSliceOpSubsetOpInterface
|
||||
: public SubsetOpInterface::ExternalModel<ExtractSliceOpSubsetOpInterface,
|
||||
tensor::ExtractSliceOp> {
|
||||
bool operatesOnEquivalentSubset(
|
||||
Operation *op, SubsetOpInterface candidate,
|
||||
function_ref<bool(Value, Value)> equivalenceFn) const {
|
||||
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
|
||||
return operateOnEquivalentSubsets(extractSliceOp, candidate, equivalenceFn);
|
||||
}
|
||||
|
||||
bool operatesOnDisjointSubset(
|
||||
Operation *op, SubsetOpInterface candidate,
|
||||
function_ref<bool(Value, Value)> equivalenceFn) const {
|
||||
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
|
||||
return operateOnDisjointSubsets(extractSliceOp, candidate, equivalenceFn);
|
||||
FailureOr<HyperrectangularSlice>
|
||||
getAccessedHyperrectangularSlice(Operation *op) const {
|
||||
return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
|
||||
}
|
||||
};
|
||||
|
||||
@ -99,18 +38,9 @@ template <typename OpTy>
|
||||
struct InsertSliceLikeOpSubsetOpInterface
|
||||
: public SubsetOpInterface::ExternalModel<
|
||||
InsertSliceLikeOpSubsetOpInterface<OpTy>, OpTy> {
|
||||
bool operatesOnEquivalentSubset(
|
||||
Operation *op, SubsetOpInterface candidate,
|
||||
function_ref<bool(Value, Value)> equivalenceFn) const {
|
||||
auto insertSliceOp = cast<OpTy>(op);
|
||||
return operateOnEquivalentSubsets(insertSliceOp, candidate, equivalenceFn);
|
||||
}
|
||||
|
||||
bool operatesOnDisjointSubset(
|
||||
Operation *op, SubsetOpInterface candidate,
|
||||
function_ref<bool(Value, Value)> equivalenceFn) const {
|
||||
auto insertSliceOp = cast<OpTy>(op);
|
||||
return operateOnDisjointSubsets(insertSliceOp, candidate, equivalenceFn);
|
||||
FailureOr<HyperrectangularSlice>
|
||||
getAccessedHyperrectangularSlice(Operation *op) const {
|
||||
return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -93,10 +93,12 @@ add_mlir_library(MLIRSubsetOpInterface
|
||||
DEPENDS
|
||||
MLIRDestinationStyleOpInterface
|
||||
MLIRSubsetOpInterfaceIncGen
|
||||
MLIRValueBoundsOpInterface
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRDestinationStyleOpInterface
|
||||
MLIRIR
|
||||
MLIRValueBoundsOpInterface
|
||||
)
|
||||
|
||||
add_mlir_interface_library(TilingInterface)
|
||||
|
@ -8,6 +8,7 @@
|
||||
|
||||
#include "mlir/Interfaces/SubsetOpInterface.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
|
||||
|
||||
#include "mlir/Interfaces/SubsetOpInterface.cpp.inc"
|
||||
|
||||
@ -40,6 +41,54 @@ bool detail::defaultIsEquivalentSubset(
|
||||
candidate.getDefiningOp<SubsetOpInterface>(), equivalenceFn);
|
||||
}
|
||||
|
||||
bool detail::defaultOperatesOnEquivalentSubset(
|
||||
Operation *op, SubsetOpInterface candidate,
|
||||
function_ref<bool(Value, Value)> equivalenceFn) {
|
||||
auto subsetOp = cast<SubsetOpInterface>(op);
|
||||
FailureOr<HyperrectangularSlice> slice =
|
||||
subsetOp.getAccessedHyperrectangularSlice();
|
||||
assert(succeeded(slice) &&
|
||||
"operatesOnEquivalentSubset must be implemented if "
|
||||
"getAccessedHyperrectangularSlice is not implemented");
|
||||
FailureOr<HyperrectangularSlice> otherSlice =
|
||||
candidate.getAccessedHyperrectangularSlice();
|
||||
if (failed(otherSlice))
|
||||
return false;
|
||||
if (!equivalenceFn(subsetOp.getTensorContainer(),
|
||||
candidate.getTensorContainer()))
|
||||
return false;
|
||||
FailureOr<bool> equivalent = ValueBoundsConstraintSet::areEquivalentSlices(
|
||||
op->getContext(), *slice, *otherSlice);
|
||||
return succeeded(equivalent) && *equivalent;
|
||||
}
|
||||
|
||||
bool detail::defaultOperatesOnDisjointSubset(
|
||||
Operation *op, SubsetOpInterface candidate,
|
||||
function_ref<bool(Value, Value)> equivalenceFn) {
|
||||
auto subsetOp = cast<SubsetOpInterface>(op);
|
||||
FailureOr<HyperrectangularSlice> slice =
|
||||
subsetOp.getAccessedHyperrectangularSlice();
|
||||
assert(succeeded(slice) &&
|
||||
"defaultOperatesOnDisjointSubset must be implemented if "
|
||||
"getAccessedHyperrectangularSlice is not implemented");
|
||||
FailureOr<HyperrectangularSlice> otherSlice =
|
||||
candidate.getAccessedHyperrectangularSlice();
|
||||
if (failed(otherSlice))
|
||||
return false;
|
||||
if (!equivalenceFn(subsetOp.getTensorContainer(),
|
||||
candidate.getTensorContainer()))
|
||||
return false;
|
||||
FailureOr<bool> overlapping = ValueBoundsConstraintSet::areOverlappingSlices(
|
||||
op->getContext(), *slice, *otherSlice);
|
||||
return succeeded(overlapping) && !*overlapping;
|
||||
}
|
||||
|
||||
Value detail::getTensorContainer(Operation *op) {
|
||||
if (auto insertionOp = dyn_cast<::mlir::SubsetInsertionOpInterface>(op))
|
||||
return insertionOp.getDestinationOperand().get();
|
||||
return cast<::mlir::SubsetExtractionOpInterface>(op).getSourceOperand().get();
|
||||
}
|
||||
|
||||
LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) {
|
||||
if (!(isa<SubsetExtractionOpInterface>(op.getOperation()) ^
|
||||
isa<SubsetInsertionOpInterface>(op.getOperation())))
|
||||
|
@ -25,6 +25,32 @@ namespace mlir {
|
||||
#include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
|
||||
} // namespace mlir
|
||||
|
||||
HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
|
||||
ArrayRef<OpFoldResult> sizes,
|
||||
ArrayRef<OpFoldResult> strides)
|
||||
: mixedOffsets(offsets), mixedSizes(sizes), mixedStrides(strides) {
|
||||
assert(offsets.size() == sizes.size() &&
|
||||
"expected same number of offsets, sizes, strides");
|
||||
assert(offsets.size() == strides.size() &&
|
||||
"expected same number of offsets, sizes, strides");
|
||||
}
|
||||
|
||||
HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
|
||||
ArrayRef<OpFoldResult> sizes)
|
||||
: mixedOffsets(offsets), mixedSizes(sizes) {
|
||||
assert(offsets.size() == sizes.size() &&
|
||||
"expected same number of offsets and sizes");
|
||||
// Assume that all strides are 1.
|
||||
if (offsets.empty())
|
||||
return;
|
||||
MLIRContext *ctx = offsets.front().getContext();
|
||||
mixedStrides.append(offsets.size(), Builder(ctx).getIndexAttr(1));
|
||||
}
|
||||
|
||||
HyperrectangularSlice::HyperrectangularSlice(OffsetSizeAndStrideOpInterface op)
|
||||
: HyperrectangularSlice(op.getMixedOffsets(), op.getMixedSizes(),
|
||||
op.getMixedStrides()) {}
|
||||
|
||||
/// If ofr is a constant integer or an IntegerAttr, return the integer.
|
||||
static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
|
||||
// Case 1: Check for Constant integer.
|
||||
@ -524,19 +550,44 @@ ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
|
||||
return *delta == 0;
|
||||
}
|
||||
|
||||
FailureOr<bool> ValueBoundsConstraintSet::areOverlappingSlices(
|
||||
OffsetSizeAndStrideOpInterface slice1,
|
||||
OffsetSizeAndStrideOpInterface slice2) {
|
||||
assert(slice1.getStaticOffsets().size() == slice1.getStaticOffsets().size() &&
|
||||
FailureOr<bool> ValueBoundsConstraintSet::areEqual(OpFoldResult ofr1,
|
||||
OpFoldResult ofr2) {
|
||||
Builder b(ofr1.getContext());
|
||||
AffineMap map =
|
||||
AffineMap::get(/*dimCount=*/0, /*symbolCount=*/2,
|
||||
b.getAffineSymbolExpr(0) - b.getAffineSymbolExpr(1));
|
||||
SmallVector<OpFoldResult> ofrOperands;
|
||||
ofrOperands.push_back(ofr1);
|
||||
ofrOperands.push_back(ofr2);
|
||||
SmallVector<Value> valueOperands;
|
||||
AffineMap foldedMap =
|
||||
foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
|
||||
ValueDimList valueDims;
|
||||
for (Value v : valueOperands) {
|
||||
assert(v.getType().isIndex() && "expected index type");
|
||||
valueDims.emplace_back(v, std::nullopt);
|
||||
}
|
||||
FailureOr<int64_t> delta =
|
||||
computeConstantBound(presburger::BoundType::EQ, foldedMap, valueDims);
|
||||
if (failed(delta))
|
||||
return failure();
|
||||
return *delta == 0;
|
||||
}
|
||||
|
||||
FailureOr<bool>
|
||||
ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
|
||||
HyperrectangularSlice slice1,
|
||||
HyperrectangularSlice slice2) {
|
||||
assert(slice1.getMixedOffsets().size() == slice1.getMixedOffsets().size() &&
|
||||
"expected slices of same rank");
|
||||
assert(slice1.getStaticSizes().size() == slice1.getStaticSizes().size() &&
|
||||
assert(slice1.getMixedSizes().size() == slice1.getMixedSizes().size() &&
|
||||
"expected slices of same rank");
|
||||
assert(slice1.getStaticStrides().size() == slice1.getStaticStrides().size() &&
|
||||
assert(slice1.getMixedStrides().size() == slice1.getMixedStrides().size() &&
|
||||
"expected slices of same rank");
|
||||
|
||||
Builder b(slice1.getContext());
|
||||
Builder b(ctx);
|
||||
bool foundUnknownBound = false;
|
||||
for (int64_t i = 0, e = slice1.getStaticOffsets().size(); i < e; ++i) {
|
||||
for (int64_t i = 0, e = slice1.getMixedOffsets().size(); i < e; ++i) {
|
||||
AffineMap map =
|
||||
AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4,
|
||||
b.getAffineSymbolExpr(0) +
|
||||
@ -588,6 +639,48 @@ FailureOr<bool> ValueBoundsConstraintSet::areOverlappingSlices(
|
||||
return true;
|
||||
}
|
||||
|
||||
FailureOr<bool>
|
||||
ValueBoundsConstraintSet::areEquivalentSlices(MLIRContext *ctx,
|
||||
HyperrectangularSlice slice1,
|
||||
HyperrectangularSlice slice2) {
|
||||
assert(slice1.getMixedOffsets().size() == slice1.getMixedOffsets().size() &&
|
||||
"expected slices of same rank");
|
||||
assert(slice1.getMixedSizes().size() == slice1.getMixedSizes().size() &&
|
||||
"expected slices of same rank");
|
||||
assert(slice1.getMixedStrides().size() == slice1.getMixedStrides().size() &&
|
||||
"expected slices of same rank");
|
||||
|
||||
// The two slices are equivalent if all of their offsets, sizes and strides
|
||||
// are equal. If equality cannot be determined for at least one of those
|
||||
// values, equivalence cannot be determined and this function returns
|
||||
// "failure".
|
||||
for (auto [offset1, offset2] :
|
||||
llvm::zip_equal(slice1.getMixedOffsets(), slice2.getMixedOffsets())) {
|
||||
FailureOr<bool> equal = areEqual(offset1, offset2);
|
||||
if (failed(equal))
|
||||
return failure();
|
||||
if (!equal.value())
|
||||
return false;
|
||||
}
|
||||
for (auto [size1, size2] :
|
||||
llvm::zip_equal(slice1.getMixedSizes(), slice2.getMixedSizes())) {
|
||||
FailureOr<bool> equal = areEqual(size1, size2);
|
||||
if (failed(equal))
|
||||
return failure();
|
||||
if (!equal.value())
|
||||
return false;
|
||||
}
|
||||
for (auto [stride1, stride2] :
|
||||
llvm::zip_equal(slice1.getMixedStrides(), slice2.getMixedStrides())) {
|
||||
FailureOr<bool> equal = areEqual(stride1, stride2);
|
||||
if (failed(equal))
|
||||
return failure();
|
||||
if (!equal.value())
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
ValueBoundsConstraintSet::BoundBuilder &
|
||||
ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
|
||||
assert(!this->dim.has_value() && "dim was already set");
|
||||
|
@ -7,6 +7,11 @@ func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%ub = "test.foo"() : () -> (index)
|
||||
%step = "test.foo"() : () -> (index)
|
||||
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%add = arith.addi %c0, %c1 : index
|
||||
%sub = arith.subi %add, %c1 : index
|
||||
|
||||
// CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
|
||||
// CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
|
||||
%0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
|
||||
@ -17,7 +22,9 @@ func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
|
||||
// CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
|
||||
%2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
|
||||
%3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
|
||||
// Obfuscate the IR by inserting at offset %sub instead of 0; both of them
|
||||
// have the same value.
|
||||
%3 = tensor.insert_slice %2 into %t[%sub][5][1] : tensor<5xf32> into tensor<?xf32>
|
||||
// CHECK: scf.yield %[[t]], %[[foo]]
|
||||
scf.yield %3 : tensor<?xf32>
|
||||
}
|
||||
|
@ -10230,6 +10230,7 @@ cc_library(
|
||||
":IR",
|
||||
":SubsetOpInterfaceIncGen",
|
||||
":Support",
|
||||
":ValueBoundsOpInterface",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user