mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-18 09:06:43 +00:00
[mlir] Add two clone methods about encoding to RankedTensorType. (#127709)
There are clone methods for shape and element type, but not for encodings. The revision adds two clone method to RankedTensorType: - dropEncoding(): Return a clone of this type without the encoding. - cloneWithEncoding(Attribute encoding): Return a clone of this type with the given new encoding and the same shape and element type as this type. Signed-off-by: hanhanW <hanhan0912@gmail.com>
This commit is contained in:
parent
fb191efa70
commit
28d7671471
@ -1035,6 +1035,17 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
|
||||
RankedTensorType clone(::mlir::Type elementType) {
|
||||
return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
|
||||
}
|
||||
|
||||
/// Return a clone of this type without the encoding.
|
||||
RankedTensorType dropEncoding() {
|
||||
return RankedTensorType::get(getShape(), getElementType());
|
||||
}
|
||||
|
||||
/// Return a clone of this type with the given new encoding and the same
|
||||
/// shape and element type as this type.
|
||||
RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
|
||||
return RankedTensorType::get(getShape(), getElementType(), encoding);
|
||||
}
|
||||
}];
|
||||
let skipDefaultBuilders = 1;
|
||||
let genVerifyDecl = 1;
|
||||
|
@ -282,6 +282,20 @@ TEST(ShapedTypeTest, RankedTensorTypeView) {
|
||||
ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
|
||||
view = mlir::cast<TensorWithString>(viewCreated);
|
||||
EXPECT_EQ(view.getName(), "bob");
|
||||
|
||||
// Verify encoding clone methods.
|
||||
EXPECT_EQ(unitEncodingRankedTensorType,
|
||||
cast<RankedTensorType>(noEncodingRankedTensorType)
|
||||
.cloneWithEncoding(unitAttr));
|
||||
EXPECT_EQ(stringEncodingRankedTensorType,
|
||||
cast<RankedTensorType>(noEncodingRankedTensorType)
|
||||
.cloneWithEncoding(stringAttr));
|
||||
EXPECT_EQ(
|
||||
noEncodingRankedTensorType,
|
||||
cast<RankedTensorType>(unitEncodingRankedTensorType).dropEncoding());
|
||||
EXPECT_EQ(
|
||||
noEncodingRankedTensorType,
|
||||
cast<RankedTensorType>(stringEncodingRankedTensorType).dropEncoding());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
x
Reference in New Issue
Block a user