mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-27 15:26:06 +00:00
[mlir][Interfaces] LISH: Add helpers for hyperrectangular subsets (#70628)
The majority of subset ops operate on hyperrectangular subsets. This commit adds a new optional interface method (`getAccessedHyperrectangularSlice`) that can be implemented by such subset ops. If implemented, the other `operatesOn...` interface methods of the `SubsetOpInterface` do not have to be implemented anymore. The comparison logic for hyperrectangular subsets (is disjoint/equivalent) is implemented with `ValueBoundsOpInterface`. This makes the subset hoisting more powerful: simple cases where two different SSA values always have the same runtime value can now be supported.
This commit is contained in:
parent
5b6ceaf8c3
commit
ff614a5729
@ -220,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp
|
|||||||
AllElementTypesMatch<["source", "dest"]>,
|
AllElementTypesMatch<["source", "dest"]>,
|
||||||
BufferizableOpInterface, DestinationStyleOpInterface,
|
BufferizableOpInterface, DestinationStyleOpInterface,
|
||||||
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
|
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
|
||||||
DeclareOpInterfaceMethods<SubsetOpInterface>,
|
DeclareOpInterfaceMethods<SubsetOpInterface,
|
||||||
|
["operatesOnEquivalentSubset", "operatesOnDisjointSubset"]>,
|
||||||
DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
|
DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
|
||||||
["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
|
["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
|
||||||
"buildSubsetExtraction", "isEquivalentSubset"]>,
|
"buildSubsetExtraction", "isEquivalentSubset"]>,
|
||||||
|
@ -268,6 +268,11 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
void dump() const { llvm::errs() << *this << "\n"; }
|
void dump() const { llvm::errs() << *this << "\n"; }
|
||||||
|
|
||||||
|
MLIRContext *getContext() const {
|
||||||
|
return is<Attribute>() ? get<Attribute>().getContext()
|
||||||
|
: get<Value>().getContext();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Temporarily exit the MLIR namespace to add casting support as later code in
|
// Temporarily exit the MLIR namespace to add casting support as later code in
|
||||||
|
@ -10,6 +10,7 @@
|
|||||||
#define MLIR_INTERFACES_SUBSETOPINTERFACE_H_
|
#define MLIR_INTERFACES_SUBSETOPINTERFACE_H_
|
||||||
|
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class SubsetOpInterface;
|
class SubsetOpInterface;
|
||||||
@ -27,10 +28,23 @@ OpOperand &defaultGetDestinationOperand(Operation *op);
|
|||||||
/// `DestinationStyleOpInterface`.
|
/// `DestinationStyleOpInterface`.
|
||||||
OpResult defaultGetUpdatedDestination(Operation *op);
|
OpResult defaultGetUpdatedDestination(Operation *op);
|
||||||
|
|
||||||
/// Default implementation of `isEquivalentSubset`.
|
/// Default implementation of `SubsetInsertionOpInterface::isEquivalentSubset`.
|
||||||
bool defaultIsEquivalentSubset(Operation *op, Value candidate,
|
bool defaultIsEquivalentSubset(Operation *op, Value candidate,
|
||||||
function_ref<bool(Value, Value)> equivalenceFn);
|
function_ref<bool(Value, Value)> equivalenceFn);
|
||||||
|
|
||||||
|
/// Default implementation of `SubsetOpInterface::operatesOnEquivalentSubset`.
|
||||||
|
bool defaultOperatesOnEquivalentSubset(
|
||||||
|
Operation *op, SubsetOpInterface candidate,
|
||||||
|
function_ref<bool(Value, Value)> equivalenceFn);
|
||||||
|
|
||||||
|
/// Default implementation of `SubsetOpInterface::operatesOnDisjointSubset`.
|
||||||
|
bool defaultOperatesOnDisjointSubset(
|
||||||
|
Operation *op, SubsetOpInterface candidate,
|
||||||
|
function_ref<bool(Value, Value)> equivalenceFn);
|
||||||
|
|
||||||
|
/// Return the container that the given subset op is operating on.
|
||||||
|
Value getTensorContainer(Operation *op);
|
||||||
|
|
||||||
/// Verify `SubsetOpInterface`.
|
/// Verify `SubsetOpInterface`.
|
||||||
LogicalResult verifySubsetOpInterface(SubsetOpInterface op);
|
LogicalResult verifySubsetOpInterface(SubsetOpInterface op);
|
||||||
|
|
||||||
|
@ -32,11 +32,6 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
|
|||||||
hyperrectangular slice.
|
hyperrectangular slice.
|
||||||
- `tensor.gather/scatter` describe the subset as list of indices. (Not
|
- `tensor.gather/scatter` describe the subset as list of indices. (Not
|
||||||
implemented yet.)
|
implemented yet.)
|
||||||
|
|
||||||
Note: This interface does not expose any interface methods to get a
|
|
||||||
description of the accessed subset. That is because there is currently no
|
|
||||||
efficient way to describe arbitrary subsets. This interface merely provides
|
|
||||||
interface methods to check if two subsets are equivalent or disjoint.
|
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let cppNamespace = "::mlir";
|
let cppNamespace = "::mlir";
|
||||||
@ -46,24 +41,59 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
|
|||||||
Return "true" if this op and the given candidate subset op operate on
|
Return "true" if this op and the given candidate subset op operate on
|
||||||
equivalent subsets. Return "false" if the two subsets are disjoint
|
equivalent subsets. Return "false" if the two subsets are disjoint
|
||||||
or cannot be proven to be equivalent.
|
or cannot be proven to be equivalent.
|
||||||
|
|
||||||
|
This interface method does not have to be implemented if
|
||||||
|
`getAccessedHyperrectangularSlice` is implemented.
|
||||||
}],
|
}],
|
||||||
/*retType=*/"bool",
|
/*retType=*/"bool",
|
||||||
/*methodName=*/"operatesOnEquivalentSubset",
|
/*methodName=*/"operatesOnEquivalentSubset",
|
||||||
/*args=*/(ins
|
/*args=*/(ins
|
||||||
"::mlir::SubsetOpInterface":$candidate,
|
"::mlir::SubsetOpInterface":$candidate,
|
||||||
"::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
|
"::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
return ::mlir::detail::defaultOperatesOnEquivalentSubset(
|
||||||
|
$_op, candidate, equivalenceFn);
|
||||||
|
}]
|
||||||
>,
|
>,
|
||||||
InterfaceMethod<
|
InterfaceMethod<
|
||||||
/*desc=*/[{
|
/*desc=*/[{
|
||||||
Return "true" if this op and the given candidate subset op operate on
|
Return "true" if this op and the given candidate subset op operate on
|
||||||
disjoint subsets. Return "false" if the two subsets are equivalent,
|
disjoint subsets. Return "false" if the two subsets are equivalent,
|
||||||
overlapping or cannot be proven to be disjoint.
|
overlapping or cannot be proven to be disjoint.
|
||||||
|
|
||||||
|
This interface method does not have to be implemented if
|
||||||
|
`getAccessedHyperrectangularSlice` is implemented.
|
||||||
}],
|
}],
|
||||||
/*retType=*/"bool",
|
/*retType=*/"bool",
|
||||||
/*methodName=*/"operatesOnDisjointSubset",
|
/*methodName=*/"operatesOnDisjointSubset",
|
||||||
/*args=*/(ins
|
/*args=*/(ins
|
||||||
"::mlir::SubsetOpInterface":$candidate,
|
"::mlir::SubsetOpInterface":$candidate,
|
||||||
"::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
|
"::llvm::function_ref<bool(Value, Value)>":$equivalenceFn),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
return ::mlir::detail::defaultOperatesOnDisjointSubset(
|
||||||
|
$_op, candidate, equivalenceFn);
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{
|
||||||
|
If this op operates on a hyperrectangular subset, return a
|
||||||
|
description of the subset in terms of offsets, sizes and strides.
|
||||||
|
Otherwise, return "failure".
|
||||||
|
|
||||||
|
This interface method is a convenience method for the most common case
|
||||||
|
of hyperrectangular subset ops. It is optional. If it is implemented,
|
||||||
|
`operatesOnEquivalentSubset` and `operatesOnDisjointSubset` do not
|
||||||
|
have to be implemented.
|
||||||
|
}],
|
||||||
|
/*retType=*/"::mlir::FailureOr<::mlir::HyperrectangularSlice>",
|
||||||
|
/*methodName=*/"getAccessedHyperrectangularSlice",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/"",
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
return ::mlir::failure();
|
||||||
|
}]
|
||||||
>,
|
>,
|
||||||
];
|
];
|
||||||
|
|
||||||
@ -71,6 +101,15 @@ def SubsetOpInterface : OpInterface<"SubsetOpInterface"> {
|
|||||||
return ::mlir::detail::verifySubsetOpInterface(
|
return ::mlir::detail::verifySubsetOpInterface(
|
||||||
::mlir::cast<::mlir::SubsetOpInterface>($_op));
|
::mlir::cast<::mlir::SubsetOpInterface>($_op));
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
/// Return the container that this operation is operating on. In case of an
|
||||||
|
/// extraction op, the container is the source tensor. In case of an
|
||||||
|
/// insertion op, the container is the destination tensor.
|
||||||
|
Value getTensorContainer() {
|
||||||
|
return ::mlir::detail::getTensorContainer(getOperation());
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SubsetExtractionOpInterface
|
def SubsetExtractionOpInterface
|
||||||
|
@ -21,6 +21,31 @@
|
|||||||
namespace mlir {
|
namespace mlir {
|
||||||
class OffsetSizeAndStrideOpInterface;
|
class OffsetSizeAndStrideOpInterface;
|
||||||
|
|
||||||
|
/// A hyperrectangular slice, represented as a list of offsets, sizes and
|
||||||
|
/// strides.
|
||||||
|
class HyperrectangularSlice {
|
||||||
|
public:
|
||||||
|
HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
|
||||||
|
ArrayRef<OpFoldResult> sizes,
|
||||||
|
ArrayRef<OpFoldResult> strides);
|
||||||
|
|
||||||
|
/// Create a hyperrectangular slice with unit strides.
|
||||||
|
HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
|
||||||
|
ArrayRef<OpFoldResult> sizes);
|
||||||
|
|
||||||
|
/// Infer a hyperrectangular slice from `OffsetSizeAndStrideOpInterface`.
|
||||||
|
HyperrectangularSlice(OffsetSizeAndStrideOpInterface op);
|
||||||
|
|
||||||
|
ArrayRef<OpFoldResult> getMixedOffsets() const { return mixedOffsets; }
|
||||||
|
ArrayRef<OpFoldResult> getMixedSizes() const { return mixedSizes; }
|
||||||
|
ArrayRef<OpFoldResult> getMixedStrides() const { return mixedStrides; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
SmallVector<OpFoldResult> mixedOffsets;
|
||||||
|
SmallVector<OpFoldResult> mixedSizes;
|
||||||
|
SmallVector<OpFoldResult> mixedStrides;
|
||||||
|
};
|
||||||
|
|
||||||
using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
|
using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
|
||||||
|
|
||||||
/// A helper class to be used with `ValueBoundsOpInterface`. This class stores a
|
/// A helper class to be used with `ValueBoundsOpInterface`. This class stores a
|
||||||
@ -182,12 +207,34 @@ public:
|
|||||||
std::optional<int64_t> dim1 = std::nullopt,
|
std::optional<int64_t> dim1 = std::nullopt,
|
||||||
std::optional<int64_t> dim2 = std::nullopt);
|
std::optional<int64_t> dim2 = std::nullopt);
|
||||||
|
|
||||||
|
/// Compute whether the given values/attributes are equal. Return "failure" if
|
||||||
|
/// equality could not be determined.
|
||||||
|
///
|
||||||
|
/// `ofr1`/`ofr2` must be of index type.
|
||||||
|
static FailureOr<bool> areEqual(OpFoldResult ofr1, OpFoldResult ofr2);
|
||||||
|
|
||||||
/// Return "true" if the given slices are guaranteed to be overlapping.
|
/// Return "true" if the given slices are guaranteed to be overlapping.
|
||||||
/// Return "false" if the given slices are guaranteed to be non-overlapping.
|
/// Return "false" if the given slices are guaranteed to be non-overlapping.
|
||||||
/// Return "failure" if unknown.
|
/// Return "failure" if unknown.
|
||||||
static FailureOr<bool>
|
///
|
||||||
areOverlappingSlices(OffsetSizeAndStrideOpInterface slice1,
|
/// Slices are overlapping if for all dimensions:
|
||||||
OffsetSizeAndStrideOpInterface slice2);
|
/// * offset1 + size1 * stride1 <= offset2
|
||||||
|
/// * and offset2 + size2 * stride2 <= offset1
|
||||||
|
///
|
||||||
|
/// Slice are non-overlapping if the above constraint is not satisfied for
|
||||||
|
/// at least one dimension.
|
||||||
|
static FailureOr<bool> areOverlappingSlices(MLIRContext *ctx,
|
||||||
|
HyperrectangularSlice slice1,
|
||||||
|
HyperrectangularSlice slice2);
|
||||||
|
|
||||||
|
/// Return "true" if the given slices are guaranteed to be equivalent.
|
||||||
|
/// Return "false" if the given slices are guaranteed to be non-equivalent.
|
||||||
|
/// Return "failure" if unknown.
|
||||||
|
///
|
||||||
|
/// Slices are equivalent if their offsets, sizes and strices are equal.
|
||||||
|
static FailureOr<bool> areEquivalentSlices(MLIRContext *ctx,
|
||||||
|
HyperrectangularSlice slice1,
|
||||||
|
HyperrectangularSlice slice2);
|
||||||
|
|
||||||
/// Add a bound for the given index-typed value or shaped value. This function
|
/// Add a bound for the given index-typed value or shaped value. This function
|
||||||
/// returns a builder that adds the bound.
|
/// returns a builder that adds the bound.
|
||||||
|
@ -17,73 +17,12 @@ using namespace mlir::tensor;
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
/// Return the tensor that the given subset op operates on.
|
|
||||||
Value getContainerOperand(SubsetOpInterface op) {
|
|
||||||
if (auto extractionOp =
|
|
||||||
dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
|
|
||||||
return extractionOp.getSourceOperand().get();
|
|
||||||
if (auto insertionOp =
|
|
||||||
dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
|
|
||||||
return insertionOp.getDestinationOperand().get();
|
|
||||||
llvm_unreachable("expected SubsetExtraction/InsertionOpInterface");
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Return "true" if the two ops operate on an equivalent subset.
|
|
||||||
/// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
|
|
||||||
/// if the two ops operate non-equivalent subsets, if equivalence cannot be
|
|
||||||
/// determined or if `op1` is not a subset op.
|
|
||||||
template <typename OpTy>
|
|
||||||
bool operateOnEquivalentSubsets(
|
|
||||||
OpTy op1, SubsetOpInterface op2,
|
|
||||||
function_ref<bool(Value, Value)> equivalenceFn) {
|
|
||||||
auto offsetsSizesAndStrides2 =
|
|
||||||
dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
|
|
||||||
if (!offsetsSizesAndStrides2)
|
|
||||||
return false;
|
|
||||||
if (!sameOffsetsSizesAndStrides(op1, offsetsSizesAndStrides2,
|
|
||||||
isEqualConstantIntOrValue))
|
|
||||||
return false;
|
|
||||||
return equivalenceFn(
|
|
||||||
getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
|
|
||||||
getContainerOperand(op2));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Return "true" if the two ops operate on a disjoint subsets.
|
|
||||||
/// `equivalenceFn` is used to determine equivalence of tensors. Return "false"
|
|
||||||
/// if the two ops operate non-disjoint subsets, if disjointness cannot be
|
|
||||||
/// determined or if `op1` is not a subset op.
|
|
||||||
template <typename OpTy>
|
|
||||||
bool operateOnDisjointSubsets(OpTy op1, SubsetOpInterface op2,
|
|
||||||
function_ref<bool(Value, Value)> equivalenceFn) {
|
|
||||||
auto offsetsSizesAndStrides2 =
|
|
||||||
dyn_cast<OffsetSizeAndStrideOpInterface>(op2.getOperation());
|
|
||||||
if (!offsetsSizesAndStrides2)
|
|
||||||
return false;
|
|
||||||
FailureOr<bool> overlappingSlices =
|
|
||||||
ValueBoundsConstraintSet::areOverlappingSlices(op1,
|
|
||||||
offsetsSizesAndStrides2);
|
|
||||||
if (failed(overlappingSlices) || *overlappingSlices)
|
|
||||||
return false;
|
|
||||||
return equivalenceFn(
|
|
||||||
getContainerOperand(cast<SubsetOpInterface>(op1.getOperation())),
|
|
||||||
getContainerOperand(op2));
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ExtractSliceOpSubsetOpInterface
|
struct ExtractSliceOpSubsetOpInterface
|
||||||
: public SubsetOpInterface::ExternalModel<ExtractSliceOpSubsetOpInterface,
|
: public SubsetOpInterface::ExternalModel<ExtractSliceOpSubsetOpInterface,
|
||||||
tensor::ExtractSliceOp> {
|
tensor::ExtractSliceOp> {
|
||||||
bool operatesOnEquivalentSubset(
|
FailureOr<HyperrectangularSlice>
|
||||||
Operation *op, SubsetOpInterface candidate,
|
getAccessedHyperrectangularSlice(Operation *op) const {
|
||||||
function_ref<bool(Value, Value)> equivalenceFn) const {
|
return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
|
||||||
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
|
|
||||||
return operateOnEquivalentSubsets(extractSliceOp, candidate, equivalenceFn);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool operatesOnDisjointSubset(
|
|
||||||
Operation *op, SubsetOpInterface candidate,
|
|
||||||
function_ref<bool(Value, Value)> equivalenceFn) const {
|
|
||||||
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
|
|
||||||
return operateOnDisjointSubsets(extractSliceOp, candidate, equivalenceFn);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -99,18 +38,9 @@ template <typename OpTy>
|
|||||||
struct InsertSliceLikeOpSubsetOpInterface
|
struct InsertSliceLikeOpSubsetOpInterface
|
||||||
: public SubsetOpInterface::ExternalModel<
|
: public SubsetOpInterface::ExternalModel<
|
||||||
InsertSliceLikeOpSubsetOpInterface<OpTy>, OpTy> {
|
InsertSliceLikeOpSubsetOpInterface<OpTy>, OpTy> {
|
||||||
bool operatesOnEquivalentSubset(
|
FailureOr<HyperrectangularSlice>
|
||||||
Operation *op, SubsetOpInterface candidate,
|
getAccessedHyperrectangularSlice(Operation *op) const {
|
||||||
function_ref<bool(Value, Value)> equivalenceFn) const {
|
return HyperrectangularSlice(cast<OffsetSizeAndStrideOpInterface>(op));
|
||||||
auto insertSliceOp = cast<OpTy>(op);
|
|
||||||
return operateOnEquivalentSubsets(insertSliceOp, candidate, equivalenceFn);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool operatesOnDisjointSubset(
|
|
||||||
Operation *op, SubsetOpInterface candidate,
|
|
||||||
function_ref<bool(Value, Value)> equivalenceFn) const {
|
|
||||||
auto insertSliceOp = cast<OpTy>(op);
|
|
||||||
return operateOnDisjointSubsets(insertSliceOp, candidate, equivalenceFn);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -93,10 +93,12 @@ add_mlir_library(MLIRSubsetOpInterface
|
|||||||
DEPENDS
|
DEPENDS
|
||||||
MLIRDestinationStyleOpInterface
|
MLIRDestinationStyleOpInterface
|
||||||
MLIRSubsetOpInterfaceIncGen
|
MLIRSubsetOpInterfaceIncGen
|
||||||
|
MLIRValueBoundsOpInterface
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRDestinationStyleOpInterface
|
MLIRDestinationStyleOpInterface
|
||||||
MLIRIR
|
MLIRIR
|
||||||
|
MLIRValueBoundsOpInterface
|
||||||
)
|
)
|
||||||
|
|
||||||
add_mlir_interface_library(TilingInterface)
|
add_mlir_interface_library(TilingInterface)
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
#include "mlir/Interfaces/SubsetOpInterface.h"
|
#include "mlir/Interfaces/SubsetOpInterface.h"
|
||||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||||
|
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
|
||||||
|
|
||||||
#include "mlir/Interfaces/SubsetOpInterface.cpp.inc"
|
#include "mlir/Interfaces/SubsetOpInterface.cpp.inc"
|
||||||
|
|
||||||
@ -40,6 +41,54 @@ bool detail::defaultIsEquivalentSubset(
|
|||||||
candidate.getDefiningOp<SubsetOpInterface>(), equivalenceFn);
|
candidate.getDefiningOp<SubsetOpInterface>(), equivalenceFn);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool detail::defaultOperatesOnEquivalentSubset(
|
||||||
|
Operation *op, SubsetOpInterface candidate,
|
||||||
|
function_ref<bool(Value, Value)> equivalenceFn) {
|
||||||
|
auto subsetOp = cast<SubsetOpInterface>(op);
|
||||||
|
FailureOr<HyperrectangularSlice> slice =
|
||||||
|
subsetOp.getAccessedHyperrectangularSlice();
|
||||||
|
assert(succeeded(slice) &&
|
||||||
|
"operatesOnEquivalentSubset must be implemented if "
|
||||||
|
"getAccessedHyperrectangularSlice is not implemented");
|
||||||
|
FailureOr<HyperrectangularSlice> otherSlice =
|
||||||
|
candidate.getAccessedHyperrectangularSlice();
|
||||||
|
if (failed(otherSlice))
|
||||||
|
return false;
|
||||||
|
if (!equivalenceFn(subsetOp.getTensorContainer(),
|
||||||
|
candidate.getTensorContainer()))
|
||||||
|
return false;
|
||||||
|
FailureOr<bool> equivalent = ValueBoundsConstraintSet::areEquivalentSlices(
|
||||||
|
op->getContext(), *slice, *otherSlice);
|
||||||
|
return succeeded(equivalent) && *equivalent;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool detail::defaultOperatesOnDisjointSubset(
|
||||||
|
Operation *op, SubsetOpInterface candidate,
|
||||||
|
function_ref<bool(Value, Value)> equivalenceFn) {
|
||||||
|
auto subsetOp = cast<SubsetOpInterface>(op);
|
||||||
|
FailureOr<HyperrectangularSlice> slice =
|
||||||
|
subsetOp.getAccessedHyperrectangularSlice();
|
||||||
|
assert(succeeded(slice) &&
|
||||||
|
"defaultOperatesOnDisjointSubset must be implemented if "
|
||||||
|
"getAccessedHyperrectangularSlice is not implemented");
|
||||||
|
FailureOr<HyperrectangularSlice> otherSlice =
|
||||||
|
candidate.getAccessedHyperrectangularSlice();
|
||||||
|
if (failed(otherSlice))
|
||||||
|
return false;
|
||||||
|
if (!equivalenceFn(subsetOp.getTensorContainer(),
|
||||||
|
candidate.getTensorContainer()))
|
||||||
|
return false;
|
||||||
|
FailureOr<bool> overlapping = ValueBoundsConstraintSet::areOverlappingSlices(
|
||||||
|
op->getContext(), *slice, *otherSlice);
|
||||||
|
return succeeded(overlapping) && !*overlapping;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value detail::getTensorContainer(Operation *op) {
|
||||||
|
if (auto insertionOp = dyn_cast<::mlir::SubsetInsertionOpInterface>(op))
|
||||||
|
return insertionOp.getDestinationOperand().get();
|
||||||
|
return cast<::mlir::SubsetExtractionOpInterface>(op).getSourceOperand().get();
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) {
|
LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) {
|
||||||
if (!(isa<SubsetExtractionOpInterface>(op.getOperation()) ^
|
if (!(isa<SubsetExtractionOpInterface>(op.getOperation()) ^
|
||||||
isa<SubsetInsertionOpInterface>(op.getOperation())))
|
isa<SubsetInsertionOpInterface>(op.getOperation())))
|
||||||
|
@ -25,6 +25,32 @@ namespace mlir {
|
|||||||
#include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
|
#include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
|
||||||
|
ArrayRef<OpFoldResult> sizes,
|
||||||
|
ArrayRef<OpFoldResult> strides)
|
||||||
|
: mixedOffsets(offsets), mixedSizes(sizes), mixedStrides(strides) {
|
||||||
|
assert(offsets.size() == sizes.size() &&
|
||||||
|
"expected same number of offsets, sizes, strides");
|
||||||
|
assert(offsets.size() == strides.size() &&
|
||||||
|
"expected same number of offsets, sizes, strides");
|
||||||
|
}
|
||||||
|
|
||||||
|
HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
|
||||||
|
ArrayRef<OpFoldResult> sizes)
|
||||||
|
: mixedOffsets(offsets), mixedSizes(sizes) {
|
||||||
|
assert(offsets.size() == sizes.size() &&
|
||||||
|
"expected same number of offsets and sizes");
|
||||||
|
// Assume that all strides are 1.
|
||||||
|
if (offsets.empty())
|
||||||
|
return;
|
||||||
|
MLIRContext *ctx = offsets.front().getContext();
|
||||||
|
mixedStrides.append(offsets.size(), Builder(ctx).getIndexAttr(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
HyperrectangularSlice::HyperrectangularSlice(OffsetSizeAndStrideOpInterface op)
|
||||||
|
: HyperrectangularSlice(op.getMixedOffsets(), op.getMixedSizes(),
|
||||||
|
op.getMixedStrides()) {}
|
||||||
|
|
||||||
/// If ofr is a constant integer or an IntegerAttr, return the integer.
|
/// If ofr is a constant integer or an IntegerAttr, return the integer.
|
||||||
static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
|
static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
|
||||||
// Case 1: Check for Constant integer.
|
// Case 1: Check for Constant integer.
|
||||||
@ -524,19 +550,44 @@ ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
|
|||||||
return *delta == 0;
|
return *delta == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
FailureOr<bool> ValueBoundsConstraintSet::areOverlappingSlices(
|
FailureOr<bool> ValueBoundsConstraintSet::areEqual(OpFoldResult ofr1,
|
||||||
OffsetSizeAndStrideOpInterface slice1,
|
OpFoldResult ofr2) {
|
||||||
OffsetSizeAndStrideOpInterface slice2) {
|
Builder b(ofr1.getContext());
|
||||||
assert(slice1.getStaticOffsets().size() == slice1.getStaticOffsets().size() &&
|
AffineMap map =
|
||||||
|
AffineMap::get(/*dimCount=*/0, /*symbolCount=*/2,
|
||||||
|
b.getAffineSymbolExpr(0) - b.getAffineSymbolExpr(1));
|
||||||
|
SmallVector<OpFoldResult> ofrOperands;
|
||||||
|
ofrOperands.push_back(ofr1);
|
||||||
|
ofrOperands.push_back(ofr2);
|
||||||
|
SmallVector<Value> valueOperands;
|
||||||
|
AffineMap foldedMap =
|
||||||
|
foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
|
||||||
|
ValueDimList valueDims;
|
||||||
|
for (Value v : valueOperands) {
|
||||||
|
assert(v.getType().isIndex() && "expected index type");
|
||||||
|
valueDims.emplace_back(v, std::nullopt);
|
||||||
|
}
|
||||||
|
FailureOr<int64_t> delta =
|
||||||
|
computeConstantBound(presburger::BoundType::EQ, foldedMap, valueDims);
|
||||||
|
if (failed(delta))
|
||||||
|
return failure();
|
||||||
|
return *delta == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<bool>
|
||||||
|
ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
|
||||||
|
HyperrectangularSlice slice1,
|
||||||
|
HyperrectangularSlice slice2) {
|
||||||
|
assert(slice1.getMixedOffsets().size() == slice1.getMixedOffsets().size() &&
|
||||||
"expected slices of same rank");
|
"expected slices of same rank");
|
||||||
assert(slice1.getStaticSizes().size() == slice1.getStaticSizes().size() &&
|
assert(slice1.getMixedSizes().size() == slice1.getMixedSizes().size() &&
|
||||||
"expected slices of same rank");
|
"expected slices of same rank");
|
||||||
assert(slice1.getStaticStrides().size() == slice1.getStaticStrides().size() &&
|
assert(slice1.getMixedStrides().size() == slice1.getMixedStrides().size() &&
|
||||||
"expected slices of same rank");
|
"expected slices of same rank");
|
||||||
|
|
||||||
Builder b(slice1.getContext());
|
Builder b(ctx);
|
||||||
bool foundUnknownBound = false;
|
bool foundUnknownBound = false;
|
||||||
for (int64_t i = 0, e = slice1.getStaticOffsets().size(); i < e; ++i) {
|
for (int64_t i = 0, e = slice1.getMixedOffsets().size(); i < e; ++i) {
|
||||||
AffineMap map =
|
AffineMap map =
|
||||||
AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4,
|
AffineMap::get(/*dimCount=*/0, /*symbolCount=*/4,
|
||||||
b.getAffineSymbolExpr(0) +
|
b.getAffineSymbolExpr(0) +
|
||||||
@ -588,6 +639,48 @@ FailureOr<bool> ValueBoundsConstraintSet::areOverlappingSlices(
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FailureOr<bool>
|
||||||
|
ValueBoundsConstraintSet::areEquivalentSlices(MLIRContext *ctx,
|
||||||
|
HyperrectangularSlice slice1,
|
||||||
|
HyperrectangularSlice slice2) {
|
||||||
|
assert(slice1.getMixedOffsets().size() == slice1.getMixedOffsets().size() &&
|
||||||
|
"expected slices of same rank");
|
||||||
|
assert(slice1.getMixedSizes().size() == slice1.getMixedSizes().size() &&
|
||||||
|
"expected slices of same rank");
|
||||||
|
assert(slice1.getMixedStrides().size() == slice1.getMixedStrides().size() &&
|
||||||
|
"expected slices of same rank");
|
||||||
|
|
||||||
|
// The two slices are equivalent if all of their offsets, sizes and strides
|
||||||
|
// are equal. If equality cannot be determined for at least one of those
|
||||||
|
// values, equivalence cannot be determined and this function returns
|
||||||
|
// "failure".
|
||||||
|
for (auto [offset1, offset2] :
|
||||||
|
llvm::zip_equal(slice1.getMixedOffsets(), slice2.getMixedOffsets())) {
|
||||||
|
FailureOr<bool> equal = areEqual(offset1, offset2);
|
||||||
|
if (failed(equal))
|
||||||
|
return failure();
|
||||||
|
if (!equal.value())
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (auto [size1, size2] :
|
||||||
|
llvm::zip_equal(slice1.getMixedSizes(), slice2.getMixedSizes())) {
|
||||||
|
FailureOr<bool> equal = areEqual(size1, size2);
|
||||||
|
if (failed(equal))
|
||||||
|
return failure();
|
||||||
|
if (!equal.value())
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (auto [stride1, stride2] :
|
||||||
|
llvm::zip_equal(slice1.getMixedStrides(), slice2.getMixedStrides())) {
|
||||||
|
FailureOr<bool> equal = areEqual(stride1, stride2);
|
||||||
|
if (failed(equal))
|
||||||
|
return failure();
|
||||||
|
if (!equal.value())
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
ValueBoundsConstraintSet::BoundBuilder &
|
ValueBoundsConstraintSet::BoundBuilder &
|
||||||
ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
|
ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
|
||||||
assert(!this->dim.has_value() && "dim was already set");
|
assert(!this->dim.has_value() && "dim was already set");
|
||||||
|
@ -7,6 +7,11 @@ func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
|
|||||||
%ub = "test.foo"() : () -> (index)
|
%ub = "test.foo"() : () -> (index)
|
||||||
%step = "test.foo"() : () -> (index)
|
%step = "test.foo"() : () -> (index)
|
||||||
|
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%c1 = arith.constant 1 : index
|
||||||
|
%add = arith.addi %c0, %c1 : index
|
||||||
|
%sub = arith.subi %add, %c1 : index
|
||||||
|
|
||||||
// CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
|
// CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
|
||||||
// CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
|
// CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
|
||||||
%0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
|
%0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
|
||||||
@ -17,7 +22,9 @@ func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
|
|||||||
%1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
|
%1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
|
||||||
// CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
|
// CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
|
||||||
%2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
|
%2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
|
||||||
%3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
|
// Obfuscate the IR by inserting at offset %sub instead of 0; both of them
|
||||||
|
// have the same value.
|
||||||
|
%3 = tensor.insert_slice %2 into %t[%sub][5][1] : tensor<5xf32> into tensor<?xf32>
|
||||||
// CHECK: scf.yield %[[t]], %[[foo]]
|
// CHECK: scf.yield %[[t]], %[[foo]]
|
||||||
scf.yield %3 : tensor<?xf32>
|
scf.yield %3 : tensor<?xf32>
|
||||||
}
|
}
|
||||||
|
@ -10230,6 +10230,7 @@ cc_library(
|
|||||||
":IR",
|
":IR",
|
||||||
":SubsetOpInterfaceIncGen",
|
":SubsetOpInterfaceIncGen",
|
||||||
":Support",
|
":Support",
|
||||||
|
":ValueBoundsOpInterface",
|
||||||
"//llvm:Support",
|
"//llvm:Support",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user