[mlir][Interfaces] LISH: Add helpers for hyperrectangular subsets (#70628)

The majority of subset ops operate on hyperrectangular subsets. This
commit adds a new optional interface method
(`getAccessedHyperrectangularSlice`) that can be implemented by such
subset ops. If implemented, the other `operatesOn...` interface methods
of the `SubsetOpInterface` do not have to be implemented anymore.

The comparison logic for hyperrectangular subsets (is
disjoint/equivalent) is implemented with `ValueBoundsOpInterface`. This
makes the subset hoisting more powerful: simple cases where two
different SSA values always have the same runtime value can now be
supported.
This commit is contained in:
Matthias Springer 2023-11-01 11:29:00 +09:00 committed by GitHub
parent 5b6ceaf8c3
commit ff614a5729
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 285 additions and 97 deletions

View File

@ -220,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp
AllElementTypesMatch<["source", "dest"]>, AllElementTypesMatch<["source", "dest"]>,
BufferizableOpInterface, DestinationStyleOpInterface, BufferizableOpInterface, DestinationStyleOpInterface,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<SubsetOpInterface>, DeclareOpInterfaceMethods<SubsetOpInterface,
["operatesOnEquivalentSubset", "operatesOnDisjointSubset"]>,
DeclareOpInterfaceMethods<SubsetInsertionOpInterface, DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
["getSourceOperand", "getValuesNeededToBuildSubsetExtraction", ["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
"buildSubsetExtraction", "isEquivalentSubset"]>, "buildSubsetExtraction", "isEquivalentSubset"]>,

View File

@ -268,6 +268,11 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
public: public:
void dump() const { llvm::errs() << *this << "\n"; } void dump() const { llvm::errs() << *this << "\n"; }
MLIRContext *getContext() const {
return is<Attribute>() ? get<Attribute>().getContext()
: get<Value>().getContext();
}
}; };
// Temporarily exit the MLIR namespace to add casting support as later code in // Temporarily exit the MLIR namespace to add casting support as later code in

View File

@ -10,6 +10,7 @@
#define MLIR_INTERFACES_SUBSETOPINTERFACE_H_ #define MLIR_INTERFACES_SUBSETOPINTERFACE_H_
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
namespace mlir { namespace mlir {
class SubsetOpInterface; class SubsetOpInterface;
@ -27,10 +28,23 @@ OpOperand &defaultGetDestinationOperand(Operation *op);
/// `DestinationStyleOpInterface`. /// `DestinationStyleOpInterface`.
OpResult defaultGetUpdatedDestination(Operation *op); OpResult defaultGetUpdatedDestination(Operation *op);
/// Default implementation of `isEquivalentSubset`. /// Default implementation of `SubsetInsertionOpInterface::isEquivalentSubset`.
bool defaultIsEquivalentSubset(Operation *op, Value candidate, bool defaultIsEquivalentSubset(Operation *op, Value candidate,
function_ref<bool(Value, Value)> equivalenceFn); function_ref<bool(Value, Value)> equivalenceFn);
/// Default implementation of `SubsetOpInterface::operatesOnEquivalentSubset`.
bool defaultOperatesOnEquivalentSubset(
Operation *op, SubsetOpInterface candidate,
function_ref<bool(Value, Value)> equivalenceFn);
/// Default implementation of `SubsetOpInterface::operatesOnDisjointSubset`.
bool defaultOperatesOnDisjointSubset(
Operation *op, SubsetOpInterface candidate,
function_ref<bool(Value, Value)> equivalenceFn);
/// Return the container that the given subset op is operating on.
Value getTensorContainer(Operation *op);
/// Verify `SubsetOpInterface`. /// Verify `SubsetOpInterface`.
LogicalResult verifySubsetOpInterface(SubsetOpInterface op); LogicalResult verifySubsetOpInterface(SubsetOpInterface op);

View File

@ -32,11 +32,6 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
hyperrectangular slice. hyperrectangular slice.
- `tensor.gather/scatter` describe the subset as list of indices. (Not - `tensor.gather/scatter` describe the subset as list of indices. (Not
implemented yet.) implemented yet.)
Note: This interface does not expose any interface methods to get a
description of the accessed subset. That is because there is currently no
efficient way to describe arbitrary subsets. This interface merely provides
interface methods to check if two subsets are equivalent or disjoint.
}]; }];
let cppNamespace = "::mlir"; let cppNamespace = "::mlir";
@ -46,24 +41,59 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
Return "true" if this op and the given candidate subset op operate on Return "true" if this op and the given candidate subset op operate on
equivalent subsets. Return "false" if the two subsets are disjoint equivalent subsets. Return "false" if the two subsets are disjoint
or cannot be proven to be equivalent. or cannot be proven to be equivalent.
This interface method does not have to be implemented if
`getAccessedHyperrectangularSlice` is implemented.
}], }],
/*retType=*/"bool", /*retType=*/"bool",
/*methodName=*/"operatesOnEquivalentSubset", /*methodName=*/"operatesOnEquivalentSubset",
/*args=*/(ins /*args=*/(ins
"::mlir::SubsetOpInterface":$candidate, "::mlir::SubsetOpInterface":$candidate,
"::llvm::function_ref<bool(Value, Value)>":$equivalenceFn) "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return ::mlir::detail::defaultOperatesOnEquivalentSubset(
$_op, candidate, equivalenceFn);
}]
>, >,
InterfaceMethod< InterfaceMethod<
/*desc=*/[{ /*desc=*/[{
Return "true" if this op and the given candidate subset op operate on Return "true" if this op and the given candidate subset op operate on
disjoint subsets. Return "false" if the two subsets are equivalent, disjoint subsets. Return "false" if the two subsets are equivalent,
overlapping or cannot be proven to be disjoint. overlapping or cannot be proven to be disjoint.
This interface method does not have to be implemented if
`getAccessedHyperrectangularSlice` is implemented.
}], }],
/*retType=*/"bool", /*retType=*/"bool",
/*methodName=*/"operatesOnDisjointSubset", /*methodName=*/"operatesOnDisjointSubset",
/*args=*/(ins /*args=*/(ins
"::mlir::SubsetOpInterface":$candidate, "::mlir::SubsetOpInterface":$candidate,
"::llvm::function_ref<bool(Value, Value)>":$equivalenceFn) "::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return ::mlir::detail::defaultOperatesOnDisjointSubset(
$_op, candidate, equivalenceFn);
}]
>,
InterfaceMethod<
/*desc=*/[{
If this op operates on a hyperrectangular subset, return a
description of the subset in terms of offsets, sizes and strides.
Otherwise, return "failure".
This interface method is a convenience method for the most common case
of hyperrectangular subset ops. It is optional. If it is implemented,
`operatesOnEquivalentSubset` and `operatesOnDisjointSubset` do not
have to be implemented.
}],
/*retType=*/"::mlir::FailureOr<::mlir::HyperrectangularSlice>",
/*methodName=*/"getAccessedHyperrectangularSlice",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return ::mlir::failure();
}]
>, >,
]; ];
@ -71,6 +101,15 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
return ::mlir::detail::verifySubsetOpInterface( return ::mlir::detail::verifySubsetOpInterface(
::mlir::cast<::mlir::SubsetOpInterface>($_op)); ::mlir::cast<::mlir::SubsetOpInterface>($_op));
}]; }];
let extraClassDeclaration = [{
/// Return the container that this operation is operating on. In case of an
/// extraction op, the container is the source tensor. In case of an
/// insertion op, the container is the destination tensor.
Value getTensorContainer() {
return ::mlir::detail::getTensorContainer(getOperation());
}
}];
} }
def SubsetExtractionOpInterface def SubsetExtractionOpInterface

View File

@ -21,6 +21,31 @@
namespace mlir { namespace mlir {
class OffsetSizeAndStrideOpInterface; class OffsetSizeAndStrideOpInterface;
/// A hyperrectangular slice, represented as a list of offsets, sizes and
/// strides.
class HyperrectangularSlice {
public:
HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides);
/// Create a hyperrectangular slice with unit strides.
HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes);
/// Infer a hyperrectangular slice from `OffsetSizeAndStrideOpInterface`.
HyperrectangularSlice(OffsetSizeAndStrideOpInterface op);
ArrayRef<OpFoldResult> getMixedOffsets() const { return mixedOffsets; }
ArrayRef<OpFoldResult> getMixedSizes() const { return mixedSizes; }
ArrayRef<OpFoldResult> getMixedStrides() const { return mixedStrides; }
private:
SmallVector<OpFoldResult> mixedOffsets;
SmallVector<OpFoldResult> mixedSizes;
SmallVector<OpFoldResult> mixedStrides;
};
using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>; using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
/// A helper class to be used with `ValueBoundsOpInterface`. This class stores a /// A helper class to be used with `ValueBoundsOpInterface`. This class stores a
@ -182,12 +207,34 @@ public:
std::optional<int64_t> dim1 = std::nullopt, std::optional<int64_t> dim1 = std::nullopt,
std::optional<int64_t> dim2 = std::nullopt); std::optional<int64_t> dim2 = std::nullopt);
/// Compute whether the given values/attributes are equal. Return "failure" if
/// equality could not be determined.
///
/// `ofr1`/`ofr2` must be of index type.
static FailureOr<bool> areEqual(OpFoldResult ofr1, OpFoldResult ofr2);
/// Return "true" if the given slices are guaranteed to be overlapping. /// Return "true" if the given slices are guaranteed to be overlapping.
/// Return "false" if the given slices are guaranteed to be non-overlapping. /// Return "false" if the given slices are guaranteed to be non-overlapping.
/// Return "failure" if unknown. /// Return "failure" if unknown.
static FailureOr<bool> ///
areOverlappingSlices(OffsetSizeAndStrideOpInterface slice1, /// Slices are overlapping if for all dimensions:
OffsetSizeAndStrideOpInterface slice2); /// * offset1 + size1 * stride1 <= offset2
/// * and offset2 + size2 * stride2 <= offset1
///
/// Slice are non-overlapping if the above constraint is not satisfied for
/// at least one dimension.
static FailureOr<bool> areOverlappingSlices(MLIRContext *ctx,
HyperrectangularSlice slice1,
HyperrectangularSlice slice2);
/// Return "true" if the given slices are guaranteed to be equivalent.
/// Return "false" if the given slices are guaranteed to be non-equivalent.
/// Return "failure" if unknown.
///
/// Slices are equivalent if their offsets, sizes and strices are equal.
static FailureOr<bool> areEquivalentSlices(MLIRContext *ctx,
HyperrectangularSlice slice1,
HyperrectangularSlice slice2);
/// Add a bound for the given index-typed value or shaped value. This function /// Add a bound for the given index-typed value or shaped value. This function
/// returns a builder that adds the bound. /// returns a builder that adds the bound.

View File

@ -17,73 +17,12 @@ using namespace mlir::tensor;
namespace { namespace {
/// Return the tensor that the given subset op operates on.
Value getContainerOperand(SubsetOpInterface op) {
if (auto extractionOp =
dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
return extractionOp.getSourceOperand().get();
if (auto insertionOp =
dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
return insertionOp.getDestinationOperand().get();
llvm_unreachable("expected SubsetExtraction/InsertionOpInterface");
}
/// Return "true" if the two ops operate on an equivalent subset.
/// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
/// if the two ops operate non-equivalent subsets, if equivalence cannot be
/// determined or if `op1` is not a subset op.
template <typename OpTy>
bool operateOnEquivalentSubsets(
OpTy op1, SubsetOpInterface op2,
function_ref<bool(Value, Value)> equivalenceFn) {
auto offsetsSizesAndStrides2 =
dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
if (!offsetsSizesAndStrides2)
return false;
if (!sameOffsetsSizesAndStrides(op1, offsetsSizesAndStrides2,
isEqualConstantIntOrValue))
return false;
return equivalenceFn(
getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
getContainerOperand(op2));
}
/// Return "true" if the two ops operate on a disjoint subsets.
/// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
/// if the two ops operate non-disjoint subsets, if disjointness cannot be
/// determined or if `op1` is not a subset op.
template <typename OpTy>
bool operateOnDisjointSubsets(OpTy op1, SubsetOpInterface op2,
function_ref<bool(Value, Value)> equivalenceFn) {
auto offsetsSizesAndStrides2 =
dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
if (!offsetsSizesAndStrides2)
return false;
FailureOr<bool> overlappingSlices =
ValueBoundsConstraintSet::areOverlappingSlices(op1,
offsetsSizesAndStrides2);
if (failed(overlappingSlices) || *overlappingSlices)
return false;
return equivalenceFn(
getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
getContainerOperand(op2));
}
struct ExtractSliceOpSubsetOpInterface struct ExtractSliceOpSubsetOpInterface
: public SubsetOpInterface::ExternalModel<ExtractSliceOpSubsetOpInterface, : public SubsetOpInterface::ExternalModel<ExtractSliceOpSubsetOpInterface,
tensor::ExtractSliceOp> { tensor::ExtractSliceOp> {
bool operatesOnEquivalentSubset( FailureOr<HyperrectangularSlice>
Operation *op, SubsetOpInterface candidate, getAccessedHyperrectangularSlice(Operation *op) const {
function_ref<bool(Value, Value)> equivalenceFn) const { return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
return operateOnEquivalentSubsets(extractSliceOp, candidate, equivalenceFn);
}
bool operatesOnDisjointSubset(
Operation *op, SubsetOpInterface candidate,
function_ref<bool(Value, Value)> equivalenceFn) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
return operateOnDisjointSubsets(extractSliceOp, candidate, equivalenceFn);
} }
}; };
@ -99,18 +38,9 @@ template <typename OpTy>
struct InsertSliceLikeOpSubsetOpInterface struct InsertSliceLikeOpSubsetOpInterface
: public SubsetOpInterface::ExternalModel< : public SubsetOpInterface::ExternalModel<
InsertSliceLikeOpSubsetOpInterface<OpTy>, OpTy> { InsertSliceLikeOpSubsetOpInterface<OpTy>, OpTy> {
bool operatesOnEquivalentSubset( FailureOr<HyperrectangularSlice>
Operation *op, SubsetOpInterface candidate, getAccessedHyperrectangularSlice(Operation *op) const {
function_ref<bool(Value, Value)> equivalenceFn) const { return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
auto insertSliceOp = cast<OpTy>(op);
return operateOnEquivalentSubsets(insertSliceOp, candidate, equivalenceFn);
}
bool operatesOnDisjointSubset(
Operation *op, SubsetOpInterface candidate,
function_ref<bool(Value, Value)> equivalenceFn) const {
auto insertSliceOp = cast<OpTy>(op);
return operateOnDisjointSubsets(insertSliceOp, candidate, equivalenceFn);
} }
}; };

View File

@ -93,10 +93,12 @@ add_mlir_library(MLIRSubsetOpInterface
DEPENDS DEPENDS
MLIRDestinationStyleOpInterface MLIRDestinationStyleOpInterface
MLIRSubsetOpInterfaceIncGen MLIRSubsetOpInterfaceIncGen
MLIRValueBoundsOpInterface
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRDestinationStyleOpInterface MLIRDestinationStyleOpInterface
MLIRIR MLIRIR
MLIRValueBoundsOpInterface
) )
add_mlir_interface_library(TilingInterface) add_mlir_interface_library(TilingInterface)

View File

@ -8,6 +8,7 @@
#include "mlir/Interfaces/SubsetOpInterface.h" #include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Interfaces/SubsetOpInterface.cpp.inc" #include "mlir/Interfaces/SubsetOpInterface.cpp.inc"
@ -40,6 +41,54 @@ bool detail::defaultIsEquivalentSubset(
candidate.getDefiningOp<SubsetOpInterface>(), equivalenceFn); candidate.getDefiningOp<SubsetOpInterface>(), equivalenceFn);
} }
bool detail::defaultOperatesOnEquivalentSubset(
Operation *op, SubsetOpInterface candidate,
function_ref<bool(Value, Value)> equivalenceFn) {
auto subsetOp = cast<SubsetOpInterface>(op);
FailureOr<HyperrectangularSlice> slice =
subsetOp.getAccessedHyperrectangularSlice();
assert(succeeded(slice) &&
"operatesOnEquivalentSubset must be implemented if "
"getAccessedHyperrectangularSlice is not implemented");
FailureOr<HyperrectangularSlice> otherSlice =
candidate.getAccessedHyperrectangularSlice();
if (failed(otherSlice))
return false;
if (!equivalenceFn(subsetOp.getTensorContainer(),
candidate.getTensorContainer()))
return false;
FailureOr<bool> equivalent = ValueBoundsConstraintSet::areEquivalentSlices(
op->getContext(), *slice, *otherSlice);
return succeeded(equivalent) && *equivalent;
}
bool detail::defaultOperatesOnDisjointSubset(
Operation *op, SubsetOpInterface candidate,
function_ref<bool(Value, Value)> equivalenceFn) {
auto subsetOp = cast<SubsetOpInterface>(op);
FailureOr<HyperrectangularSlice> slice =
subsetOp.getAccessedHyperrectangularSlice();
assert(succeeded(slice) &&
"defaultOperatesOnDisjointSubset must be implemented if "
"getAccessedHyperrectangularSlice is not implemented");
FailureOr<HyperrectangularSlice> otherSlice =
candidate.getAccessedHyperrectangularSlice();
if (failed(otherSlice))
return false;
if (!equivalenceFn(subsetOp.getTensorContainer(),
candidate.getTensorContainer()))
return false;
FailureOr<bool> overlapping = ValueBoundsConstraintSet::areOverlappingSlices(
op->getContext(), *slice, *otherSlice);
return succeeded(overlapping) && !*overlapping;
}
Value detail::getTensorContainer(Operation *op) {
if (auto insertionOp = dyn_cast<::mlir::SubsetInsertionOpInterface>(op))
return insertionOp.getDestinationOperand().get();
return cast<::mlir::SubsetExtractionOpInterface>(op).getSourceOperand().get();
}
LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) { LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) {
if (!(isa<SubsetExtractionOpInterface>(op.getOperation()) ^ if (!(isa<SubsetExtractionOpInterface>(op.getOperation()) ^
isa<SubsetInsertionOpInterface>(op.getOperation()))) isa<SubsetInsertionOpInterface>(op.getOperation())))

View File

@ -25,6 +25,32 @@ namespace mlir {
#include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc" #include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
} // namespace mlir } // namespace mlir
HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides)
: mixedOffsets(offsets), mixedSizes(sizes), mixedStrides(strides) {
assert(offsets.size() == sizes.size() &&
"expected same number of offsets, sizes, strides");
assert(offsets.size() == strides.size() &&
"expected same number of offsets, sizes, strides");
}
HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes)
: mixedOffsets(offsets), mixedSizes(sizes) {
assert(offsets.size() == sizes.size() &&
"expected same number of offsets and sizes");
// Assume that all strides are 1.
if (offsets.empty())
return;
MLIRContext *ctx = offsets.front().getContext();
mixedStrides.append(offsets.size(), Builder(ctx).getIndexAttr(1));
}
HyperrectangularSlice::HyperrectangularSlice(OffsetSizeAndStrideOpInterface op)
: HyperrectangularSlice(op.getMixedOffsets(), op.getMixedSizes(),
op.getMixedStrides()) {}
/// If ofr is a constant integer or an IntegerAttr, return the integer. /// If ofr is a constant integer or an IntegerAttr, return the integer.
static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) { static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
// Case 1: Check for Constant integer. // Case 1: Check for Constant integer.
@ -524,19 +550,44 @@ ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
return *delta == 0; return *delta == 0;
} }
FailureOr<bool> ValueBoundsConstraintSet::areOverlappingSlices( FailureOr<bool> ValueBoundsConstraintSet::areEqual(OpFoldResult ofr1,
OffsetSizeAndStrideOpInterface slice1, OpFoldResult ofr2) {
OffsetSizeAndStrideOpInterface slice2) { Builder b(ofr1.getContext());
assert(slice1.getStaticOffsets().size() == slice1.getStaticOffsets().size() && AffineMap map =
AffineMap::get(/*dimCount=*/0, /*symbolCount=*/2,
b.getAffineSymbolExpr(0) - b.getAffineSymbolExpr(1));
SmallVector<OpFoldResult> ofrOperands;
ofrOperands.push_back(ofr1);
ofrOperands.push_back(ofr2);
SmallVector<Value> valueOperands;
AffineMap foldedMap =
foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
ValueDimList valueDims;
for (Value v : valueOperands) {
assert(v.getType().isIndex() && "expected index type");
valueDims.emplace_back(v, std::nullopt);
}
FailureOr<int64_t> delta =
computeConstantBound(presburger::BoundType::EQ, foldedMap, valueDims);
if (failed(delta))
return failure();
return *delta == 0;
}
FailureOr<bool>
ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
HyperrectangularSlice slice1,
HyperrectangularSlice slice2) {
assert(slice1.getMixedOffsets().size() == slice1.getMixedOffsets().size() &&
"expected slices of same rank"); "expected slices of same rank");
assert(slice1.getStaticSizes().size() == slice1.getStaticSizes().size() && assert(slice1.getMixedSizes().size() == slice1.getMixedSizes().size() &&
"expected slices of same rank"); "expected slices of same rank");
assert(slice1.getStaticStrides().size() == slice1.getStaticStrides().size() && assert(slice1.getMixedStrides().size() == slice1.getMixedStrides().size() &&
"expected slices of same rank"); "expected slices of same rank");
Builder b(slice1.getContext()); Builder b(ctx);
bool foundUnknownBound = false; bool foundUnknownBound = false;
for (int64_t i = 0, e = slice1.getStaticOffsets().size(); i < e; ++i) { for (int64_t i = 0, e = slice1.getMixedOffsets().size(); i < e; ++i) {
AffineMap map = AffineMap map =
AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4, AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4,
b.getAffineSymbolExpr(0) + b.getAffineSymbolExpr(0) +
@ -588,6 +639,48 @@ FailureOr<bool> ValueBoundsConstraintSet::areOverlappingSlices(
return true; return true;
} }
FailureOr<bool>
ValueBoundsConstraintSet::areEquivalentSlices(MLIRContext *ctx,
HyperrectangularSlice slice1,
HyperrectangularSlice slice2) {
assert(slice1.getMixedOffsets().size() == slice1.getMixedOffsets().size() &&
"expected slices of same rank");
assert(slice1.getMixedSizes().size() == slice1.getMixedSizes().size() &&
"expected slices of same rank");
assert(slice1.getMixedStrides().size() == slice1.getMixedStrides().size() &&
"expected slices of same rank");
// The two slices are equivalent if all of their offsets, sizes and strides
// are equal. If equality cannot be determined for at least one of those
// values, equivalence cannot be determined and this function returns
// "failure".
for (auto [offset1, offset2] :
llvm::zip_equal(slice1.getMixedOffsets(), slice2.getMixedOffsets())) {
FailureOr<bool> equal = areEqual(offset1, offset2);
if (failed(equal))
return failure();
if (!equal.value())
return false;
}
for (auto [size1, size2] :
llvm::zip_equal(slice1.getMixedSizes(), slice2.getMixedSizes())) {
FailureOr<bool> equal = areEqual(size1, size2);
if (failed(equal))
return failure();
if (!equal.value())
return false;
}
for (auto [stride1, stride2] :
llvm::zip_equal(slice1.getMixedStrides(), slice2.getMixedStrides())) {
FailureOr<bool> equal = areEqual(stride1, stride2);
if (failed(equal))
return failure();
if (!equal.value())
return false;
}
return true;
}
ValueBoundsConstraintSet::BoundBuilder & ValueBoundsConstraintSet::BoundBuilder &
ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) { ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
assert(!this->dim.has_value() && "dim was already set"); assert(!this->dim.has_value() && "dim was already set");

View File

@ -7,6 +7,11 @@ func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
%ub = "test.foo"() : () -> (index) %ub = "test.foo"() : () -> (index)
%step = "test.foo"() : () -> (index) %step = "test.foo"() : () -> (index)
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%add = arith.addi %c0, %c1 : index
%sub = arith.subi %add, %c1 : index
// CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]] // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
// CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]]) // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
%0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) { %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
@ -17,7 +22,9 @@ func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
%1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32> %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
// CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]]) // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
%2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>) %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
%3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32> // Obfuscate the IR by inserting at offset %sub instead of 0; both of them
// have the same value.
%3 = tensor.insert_slice %2 into %t[%sub][5][1] : tensor<5xf32> into tensor<?xf32>
// CHECK: scf.yield %[[t]], %[[foo]] // CHECK: scf.yield %[[t]], %[[foo]]
scf.yield %3 : tensor<?xf32> scf.yield %3 : tensor<?xf32>
} }

View File

@ -10230,6 +10230,7 @@ cc_library(
":IR", ":IR",
":SubsetOpInterfaceIncGen", ":SubsetOpInterfaceIncGen",
":Support", ":Support",
":ValueBoundsOpInterface",
"//llvm:Support", "//llvm:Support",
], ],
) )