From ff614a5729e9a4fc32465ad5ff3b87e044429c2d Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 1 Nov 2023 11:29:00 +0900 Subject: [PATCH] [mlir][Interfaces] LISH: Add helpers for hyperrectangular subsets (#70628) The majority of subset ops operate on hyperrectangular subsets. This commit adds a new optional interface method (`getAccessedHyperrectangularSlice`) that can be implemented by such subset ops. If implemented, the other `operatesOn...` interface methods of the `SubsetOpInterface` do not have to be implemented anymore. The comparison logic for hyperrectangular subsets (is disjoint/equivalent) is implemented with `ValueBoundsOpInterface`. This makes the subset hoisting more powerful: simple cases where two different SSA values always have the same runtime value can now be supported. --- .../Bufferization/IR/BufferizationOps.td | 3 +- mlir/include/mlir/IR/OpDefinition.h | 5 + .../mlir/Interfaces/SubsetOpInterface.h | 16 ++- .../mlir/Interfaces/SubsetOpInterface.td | 53 +++++++-- .../mlir/Interfaces/ValueBoundsOpInterface.h | 53 ++++++++- .../SubsetInsertionOpInterfaceImpl.cpp | 82 +------------ mlir/lib/Interfaces/CMakeLists.txt | 2 + mlir/lib/Interfaces/SubsetOpInterface.cpp | 49 ++++++++ .../lib/Interfaces/ValueBoundsOpInterface.cpp | 109 ++++++++++++++++-- .../loop-invariant-subset-hoisting.mlir | 9 +- .../llvm-project-overlay/mlir/BUILD.bazel | 1 + 11 files changed, 285 insertions(+), 97 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td index e6b6d052df96..9dc6afcaab31 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -220,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp AllElementTypesMatch<["source", "dest"]>, BufferizableOpInterface, DestinationStyleOpInterface, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 8ab37c1d51d6..bd68c2744574 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -268,6 +268,11 @@ class OpFoldResult : public PointerUnion { public: void dump() const { llvm::errs() << *this << "\n"; } + + MLIRContext *getContext() const { + return is() ? get().getContext() + : get().getContext(); + } }; // Temporarily exit the MLIR namespace to add casting support as later code in diff --git a/mlir/include/mlir/Interfaces/SubsetOpInterface.h b/mlir/include/mlir/Interfaces/SubsetOpInterface.h index 049cf2456a9c..98c33ec65012 100644 --- a/mlir/include/mlir/Interfaces/SubsetOpInterface.h +++ b/mlir/include/mlir/Interfaces/SubsetOpInterface.h @@ -10,6 +10,7 @@ #define MLIR_INTERFACES_SUBSETOPINTERFACE_H_ #include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" namespace mlir { class SubsetOpInterface; @@ -27,10 +28,23 @@ OpOperand &defaultGetDestinationOperand(Operation *op); /// `DestinationStyleOpInterface`. OpResult defaultGetUpdatedDestination(Operation *op); -/// Default implementation of `isEquivalentSubset`. +/// Default implementation of `SubsetInsertionOpInterface::isEquivalentSubset`. bool defaultIsEquivalentSubset(Operation *op, Value candidate, function_ref equivalenceFn); +/// Default implementation of `SubsetOpInterface::operatesOnEquivalentSubset`. +bool defaultOperatesOnEquivalentSubset( + Operation *op, SubsetOpInterface candidate, + function_ref equivalenceFn); + +/// Default implementation of `SubsetOpInterface::operatesOnDisjointSubset`. +bool defaultOperatesOnDisjointSubset( + Operation *op, SubsetOpInterface candidate, + function_ref equivalenceFn); + +/// Return the container that the given subset op is operating on. +Value getTensorContainer(Operation *op); + /// Verify `SubsetOpInterface`. LogicalResult verifySubsetOpInterface(SubsetOpInterface op); diff --git a/mlir/include/mlir/Interfaces/SubsetOpInterface.td b/mlir/include/mlir/Interfaces/SubsetOpInterface.td index 9ebed2c94818..7000e7dfc89c 100644 --- a/mlir/include/mlir/Interfaces/SubsetOpInterface.td +++ b/mlir/include/mlir/Interfaces/SubsetOpInterface.td @@ -32,11 +32,6 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> { hyperrectangular slice. - `tensor.gather/scatter` describe the subset as list of indices. (Not implemented yet.) - - Note: This interface does not expose any interface methods to get a - description of the accessed subset. That is because there is currently no - efficient way to describe arbitrary subsets. This interface merely provides - interface methods to check if two subsets are equivalent or disjoint. }]; let cppNamespace = "::mlir"; @@ -46,24 +41,59 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> { Return "true" if this op and the given candidate subset op operate on equivalent subsets. Return "false" if the two subsets are disjoint or cannot be proven to be equivalent. + + This interface method does not have to be implemented if + `getAccessedHyperrectangularSlice` is implemented. }], /*retType=*/"bool", /*methodName=*/"operatesOnEquivalentSubset", /*args=*/(ins "::mlir::SubsetOpInterface":$candidate, - "::llvm::function_ref":$equivalenceFn) + "::llvm::function_ref":$equivalenceFn), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::mlir::detail::defaultOperatesOnEquivalentSubset( + $_op, candidate, equivalenceFn); + }] >, InterfaceMethod< /*desc=*/[{ Return "true" if this op and the given candidate subset op operate on disjoint subsets. Return "false" if the two subsets are equivalent, overlapping or cannot be proven to be disjoint. + + This interface method does not have to be implemented if + `getAccessedHyperrectangularSlice` is implemented. }], /*retType=*/"bool", /*methodName=*/"operatesOnDisjointSubset", /*args=*/(ins "::mlir::SubsetOpInterface":$candidate, - "::llvm::function_ref":$equivalenceFn) + "::llvm::function_ref":$equivalenceFn), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::mlir::detail::defaultOperatesOnDisjointSubset( + $_op, candidate, equivalenceFn); + }] + >, + InterfaceMethod< + /*desc=*/[{ + If this op operates on a hyperrectangular subset, return a + description of the subset in terms of offsets, sizes and strides. + Otherwise, return "failure". + + This interface method is a convenience method for the most common case + of hyperrectangular subset ops. It is optional. If it is implemented, + `operatesOnEquivalentSubset` and `operatesOnDisjointSubset` do not + have to be implemented. + }], + /*retType=*/"::mlir::FailureOr<::mlir::HyperrectangularSlice>", + /*methodName=*/"getAccessedHyperrectangularSlice", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::mlir::failure(); + }] >, ]; @@ -71,6 +101,15 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> { return ::mlir::detail::verifySubsetOpInterface( ::mlir::cast<::mlir::SubsetOpInterface>($_op)); }]; + + let extraClassDeclaration = [{ + /// Return the container that this operation is operating on. In case of an + /// extraction op, the container is the source tensor. In case of an + /// insertion op, the container is the destination tensor. + Value getTensorContainer() { + return ::mlir::detail::getTensorContainer(getOperation()); + } + }]; } def SubsetExtractionOpInterface diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h index 8e2986a2d1f0..28dadfb9ecf8 100644 --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -21,6 +21,31 @@ namespace mlir { class OffsetSizeAndStrideOpInterface; +/// A hyperrectangular slice, represented as a list of offsets, sizes and +/// strides. +class HyperrectangularSlice { +public: + HyperrectangularSlice(ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides); + + /// Create a hyperrectangular slice with unit strides. + HyperrectangularSlice(ArrayRef offsets, + ArrayRef sizes); + + /// Infer a hyperrectangular slice from `OffsetSizeAndStrideOpInterface`. + HyperrectangularSlice(OffsetSizeAndStrideOpInterface op); + + ArrayRef getMixedOffsets() const { return mixedOffsets; } + ArrayRef getMixedSizes() const { return mixedSizes; } + ArrayRef getMixedStrides() const { return mixedStrides; } + +private: + SmallVector mixedOffsets; + SmallVector mixedSizes; + SmallVector mixedStrides; +}; + using ValueDimList = SmallVector>>; /// A helper class to be used with `ValueBoundsOpInterface`. This class stores a @@ -182,12 +207,34 @@ public: std::optional dim1 = std::nullopt, std::optional dim2 = std::nullopt); + /// Compute whether the given values/attributes are equal. Return "failure" if + /// equality could not be determined. + /// + /// `ofr1`/`ofr2` must be of index type. + static FailureOr areEqual(OpFoldResult ofr1, OpFoldResult ofr2); + /// Return "true" if the given slices are guaranteed to be overlapping. /// Return "false" if the given slices are guaranteed to be non-overlapping. /// Return "failure" if unknown. - static FailureOr - areOverlappingSlices(OffsetSizeAndStrideOpInterface slice1, - OffsetSizeAndStrideOpInterface slice2); + /// + /// Slices are overlapping if for all dimensions: + /// * offset1 + size1 * stride1 <= offset2 + /// * and offset2 + size2 * stride2 <= offset1 + /// + /// Slice are non-overlapping if the above constraint is not satisfied for + /// at least one dimension. + static FailureOr areOverlappingSlices(MLIRContext *ctx, + HyperrectangularSlice slice1, + HyperrectangularSlice slice2); + + /// Return "true" if the given slices are guaranteed to be equivalent. + /// Return "false" if the given slices are guaranteed to be non-equivalent. + /// Return "failure" if unknown. + /// + /// Slices are equivalent if their offsets, sizes and strices are equal. + static FailureOr areEquivalentSlices(MLIRContext *ctx, + HyperrectangularSlice slice1, + HyperrectangularSlice slice2); /// Add a bound for the given index-typed value or shaped value. This function /// returns a builder that adds the bound. diff --git a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp index 7a1bafd409ee..d50d7c62b789 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp @@ -17,73 +17,12 @@ using namespace mlir::tensor; namespace { -/// Return the tensor that the given subset op operates on. -Value getContainerOperand(SubsetOpInterface op) { - if (auto extractionOp = - dyn_cast(op.getOperation())) - return extractionOp.getSourceOperand().get(); - if (auto insertionOp = - dyn_cast(op.getOperation())) - return insertionOp.getDestinationOperand().get(); - llvm_unreachable("expected SubsetExtraction/InsertionOpInterface"); -} - -/// Return "true" if the two ops operate on an equivalent subset. -/// `equivalenceFn` is used to determine equivalence of tensors. Return "false" -/// if the two ops operate non-equivalent subsets, if equivalence cannot be -/// determined or if `op1` is not a subset op. -template -bool operateOnEquivalentSubsets( - OpTy op1, SubsetOpInterface op2, - function_ref equivalenceFn) { - auto offsetsSizesAndStrides2 = - dyn_cast(op2.getOperation()); - if (!offsetsSizesAndStrides2) - return false; - if (!sameOffsetsSizesAndStrides(op1, offsetsSizesAndStrides2, - isEqualConstantIntOrValue)) - return false; - return equivalenceFn( - getContainerOperand(cast(op1.getOperation())), - getContainerOperand(op2)); -} - -/// Return "true" if the two ops operate on a disjoint subsets. -/// `equivalenceFn` is used to determine equivalence of tensors. Return "false" -/// if the two ops operate non-disjoint subsets, if disjointness cannot be -/// determined or if `op1` is not a subset op. -template -bool operateOnDisjointSubsets(OpTy op1, SubsetOpInterface op2, - function_ref equivalenceFn) { - auto offsetsSizesAndStrides2 = - dyn_cast(op2.getOperation()); - if (!offsetsSizesAndStrides2) - return false; - FailureOr overlappingSlices = - ValueBoundsConstraintSet::areOverlappingSlices(op1, - offsetsSizesAndStrides2); - if (failed(overlappingSlices) || *overlappingSlices) - return false; - return equivalenceFn( - getContainerOperand(cast(op1.getOperation())), - getContainerOperand(op2)); -} - struct ExtractSliceOpSubsetOpInterface : public SubsetOpInterface::ExternalModel { - bool operatesOnEquivalentSubset( - Operation *op, SubsetOpInterface candidate, - function_ref equivalenceFn) const { - auto extractSliceOp = cast(op); - return operateOnEquivalentSubsets(extractSliceOp, candidate, equivalenceFn); - } - - bool operatesOnDisjointSubset( - Operation *op, SubsetOpInterface candidate, - function_ref equivalenceFn) const { - auto extractSliceOp = cast(op); - return operateOnDisjointSubsets(extractSliceOp, candidate, equivalenceFn); + FailureOr + getAccessedHyperrectangularSlice(Operation *op) const { + return HyperrectangularSlice(cast(op)); } }; @@ -99,18 +38,9 @@ template struct InsertSliceLikeOpSubsetOpInterface : public SubsetOpInterface::ExternalModel< InsertSliceLikeOpSubsetOpInterface, OpTy> { - bool operatesOnEquivalentSubset( - Operation *op, SubsetOpInterface candidate, - function_ref equivalenceFn) const { - auto insertSliceOp = cast(op); - return operateOnEquivalentSubsets(insertSliceOp, candidate, equivalenceFn); - } - - bool operatesOnDisjointSubset( - Operation *op, SubsetOpInterface candidate, - function_ref equivalenceFn) const { - auto insertSliceOp = cast(op); - return operateOnDisjointSubsets(insertSliceOp, candidate, equivalenceFn); + FailureOr + getAccessedHyperrectangularSlice(Operation *op) const { + return HyperrectangularSlice(cast(op)); } }; diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt index 2652d261f480..e7c76e70ed6b 100644 --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -93,10 +93,12 @@ add_mlir_library(MLIRSubsetOpInterface DEPENDS MLIRDestinationStyleOpInterface MLIRSubsetOpInterfaceIncGen + MLIRValueBoundsOpInterface LINK_LIBS PUBLIC MLIRDestinationStyleOpInterface MLIRIR + MLIRValueBoundsOpInterface ) add_mlir_interface_library(TilingInterface) diff --git a/mlir/lib/Interfaces/SubsetOpInterface.cpp b/mlir/lib/Interfaces/SubsetOpInterface.cpp index 7245ab20c499..d0bdadf500f6 100644 --- a/mlir/lib/Interfaces/SubsetOpInterface.cpp +++ b/mlir/lib/Interfaces/SubsetOpInterface.cpp @@ -8,6 +8,7 @@ #include "mlir/Interfaces/SubsetOpInterface.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Interfaces/SubsetOpInterface.cpp.inc" @@ -40,6 +41,54 @@ bool detail::defaultIsEquivalentSubset( candidate.getDefiningOp(), equivalenceFn); } +bool detail::defaultOperatesOnEquivalentSubset( + Operation *op, SubsetOpInterface candidate, + function_ref equivalenceFn) { + auto subsetOp = cast(op); + FailureOr slice = + subsetOp.getAccessedHyperrectangularSlice(); + assert(succeeded(slice) && + "operatesOnEquivalentSubset must be implemented if " + "getAccessedHyperrectangularSlice is not implemented"); + FailureOr otherSlice = + candidate.getAccessedHyperrectangularSlice(); + if (failed(otherSlice)) + return false; + if (!equivalenceFn(subsetOp.getTensorContainer(), + candidate.getTensorContainer())) + return false; + FailureOr equivalent = ValueBoundsConstraintSet::areEquivalentSlices( + op->getContext(), *slice, *otherSlice); + return succeeded(equivalent) && *equivalent; +} + +bool detail::defaultOperatesOnDisjointSubset( + Operation *op, SubsetOpInterface candidate, + function_ref equivalenceFn) { + auto subsetOp = cast(op); + FailureOr slice = + subsetOp.getAccessedHyperrectangularSlice(); + assert(succeeded(slice) && + "defaultOperatesOnDisjointSubset must be implemented if " + "getAccessedHyperrectangularSlice is not implemented"); + FailureOr otherSlice = + candidate.getAccessedHyperrectangularSlice(); + if (failed(otherSlice)) + return false; + if (!equivalenceFn(subsetOp.getTensorContainer(), + candidate.getTensorContainer())) + return false; + FailureOr overlapping = ValueBoundsConstraintSet::areOverlappingSlices( + op->getContext(), *slice, *otherSlice); + return succeeded(overlapping) && !*overlapping; +} + +Value detail::getTensorContainer(Operation *op) { + if (auto insertionOp = dyn_cast<::mlir::SubsetInsertionOpInterface>(op)) + return insertionOp.getDestinationOperand().get(); + return cast<::mlir::SubsetExtractionOpInterface>(op).getSourceOperand().get(); +} + LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) { if (!(isa(op.getOperation()) ^ isa(op.getOperation()))) diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index f0c37c872e6d..62ba63402925 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -25,6 +25,32 @@ namespace mlir { #include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc" } // namespace mlir +HyperrectangularSlice::HyperrectangularSlice(ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) + : mixedOffsets(offsets), mixedSizes(sizes), mixedStrides(strides) { + assert(offsets.size() == sizes.size() && + "expected same number of offsets, sizes, strides"); + assert(offsets.size() == strides.size() && + "expected same number of offsets, sizes, strides"); +} + +HyperrectangularSlice::HyperrectangularSlice(ArrayRef offsets, + ArrayRef sizes) + : mixedOffsets(offsets), mixedSizes(sizes) { + assert(offsets.size() == sizes.size() && + "expected same number of offsets and sizes"); + // Assume that all strides are 1. + if (offsets.empty()) + return; + MLIRContext *ctx = offsets.front().getContext(); + mixedStrides.append(offsets.size(), Builder(ctx).getIndexAttr(1)); +} + +HyperrectangularSlice::HyperrectangularSlice(OffsetSizeAndStrideOpInterface op) + : HyperrectangularSlice(op.getMixedOffsets(), op.getMixedSizes(), + op.getMixedStrides()) {} + /// If ofr is a constant integer or an IntegerAttr, return the integer. static std::optional getConstantIntValue(OpFoldResult ofr) { // Case 1: Check for Constant integer. @@ -524,19 +550,44 @@ ValueBoundsConstraintSet::areEqual(Value value1, Value value2, return *delta == 0; } -FailureOr ValueBoundsConstraintSet::areOverlappingSlices( - OffsetSizeAndStrideOpInterface slice1, - OffsetSizeAndStrideOpInterface slice2) { - assert(slice1.getStaticOffsets().size() == slice1.getStaticOffsets().size() && +FailureOr ValueBoundsConstraintSet::areEqual(OpFoldResult ofr1, + OpFoldResult ofr2) { + Builder b(ofr1.getContext()); + AffineMap map = + AffineMap::get(/*dimCount=*/0, /*symbolCount=*/2, + b.getAffineSymbolExpr(0) - b.getAffineSymbolExpr(1)); + SmallVector ofrOperands; + ofrOperands.push_back(ofr1); + ofrOperands.push_back(ofr2); + SmallVector valueOperands; + AffineMap foldedMap = + foldAttributesIntoMap(b, map, ofrOperands, valueOperands); + ValueDimList valueDims; + for (Value v : valueOperands) { + assert(v.getType().isIndex() && "expected index type"); + valueDims.emplace_back(v, std::nullopt); + } + FailureOr delta = + computeConstantBound(presburger::BoundType::EQ, foldedMap, valueDims); + if (failed(delta)) + return failure(); + return *delta == 0; +} + +FailureOr +ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx, + HyperrectangularSlice slice1, + HyperrectangularSlice slice2) { + assert(slice1.getMixedOffsets().size() == slice1.getMixedOffsets().size() && "expected slices of same rank"); - assert(slice1.getStaticSizes().size() == slice1.getStaticSizes().size() && + assert(slice1.getMixedSizes().size() == slice1.getMixedSizes().size() && "expected slices of same rank"); - assert(slice1.getStaticStrides().size() == slice1.getStaticStrides().size() && + assert(slice1.getMixedStrides().size() == slice1.getMixedStrides().size() && "expected slices of same rank"); - Builder b(slice1.getContext()); + Builder b(ctx); bool foundUnknownBound = false; - for (int64_t i = 0, e = slice1.getStaticOffsets().size(); i < e; ++i) { + for (int64_t i = 0, e = slice1.getMixedOffsets().size(); i < e; ++i) { AffineMap map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4, b.getAffineSymbolExpr(0) + @@ -588,6 +639,48 @@ FailureOr ValueBoundsConstraintSet::areOverlappingSlices( return true; } +FailureOr +ValueBoundsConstraintSet::areEquivalentSlices(MLIRContext *ctx, + HyperrectangularSlice slice1, + HyperrectangularSlice slice2) { + assert(slice1.getMixedOffsets().size() == slice1.getMixedOffsets().size() && + "expected slices of same rank"); + assert(slice1.getMixedSizes().size() == slice1.getMixedSizes().size() && + "expected slices of same rank"); + assert(slice1.getMixedStrides().size() == slice1.getMixedStrides().size() && + "expected slices of same rank"); + + // The two slices are equivalent if all of their offsets, sizes and strides + // are equal. If equality cannot be determined for at least one of those + // values, equivalence cannot be determined and this function returns + // "failure". + for (auto [offset1, offset2] : + llvm::zip_equal(slice1.getMixedOffsets(), slice2.getMixedOffsets())) { + FailureOr equal = areEqual(offset1, offset2); + if (failed(equal)) + return failure(); + if (!equal.value()) + return false; + } + for (auto [size1, size2] : + llvm::zip_equal(slice1.getMixedSizes(), slice2.getMixedSizes())) { + FailureOr equal = areEqual(size1, size2); + if (failed(equal)) + return failure(); + if (!equal.value()) + return false; + } + for (auto [stride1, stride2] : + llvm::zip_equal(slice1.getMixedStrides(), slice2.getMixedStrides())) { + FailureOr equal = areEqual(stride1, stride2); + if (failed(equal)) + return failure(); + if (!equal.value()) + return false; + } + return true; +} + ValueBoundsConstraintSet::BoundBuilder & ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) { assert(!this->dim.has_value() && "dim was already set"); diff --git a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir index b9161f4e20d1..bb60eeaba524 100644 --- a/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir +++ b/mlir/test/Transforms/loop-invariant-subset-hoisting.mlir @@ -7,6 +7,11 @@ func.func @hoist_matching_extract_insert(%arg: tensor) -> tensor { %ub = "test.foo"() : () -> (index) %step = "test.foo"() : () -> (index) + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %add = arith.addi %c0, %c1 : index + %sub = arith.subi %add, %c1 : index + // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]] // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]]) %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor) { @@ -17,7 +22,9 @@ func.func @hoist_matching_extract_insert(%arg: tensor) -> tensor { %1 = tensor.extract_slice %t[0][5][1] : tensor to tensor<5xf32> // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]]) %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>) - %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor + // Obfuscate the IR by inserting at offset %sub instead of 0; both of them + // have the same value. + %3 = tensor.insert_slice %2 into %t[%sub][5][1] : tensor<5xf32> into tensor // CHECK: scf.yield %[[t]], %[[foo]] scf.yield %3 : tensor } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 0a2ae427169a..7109b6c43905 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -10230,6 +10230,7 @@ cc_library( ":IR", ":SubsetOpInterfaceIncGen", ":Support", + ":ValueBoundsOpInterface", "//llvm:Support", ], )