[mlir] Add two clone methods about encoding to RankedTensorType. (#127709)

There are clone methods for shape and element type, but not for
encodings. The revision adds two clone method to RankedTensorType:
- dropEncoding(): Return a clone of this type without the encoding.
- cloneWithEncoding(Attribute encoding): Return a clone of this type
with the given new encoding and the same shape and element type as this
type.

Signed-off-by: hanhanW <hanhan0912@gmail.com>
This commit is contained in:
Han-Chung Wang 2025-02-27 17:59:27 -08:00 committed by GitHub
parent fb191efa70
commit 28d7671471
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 0 deletions

View File

@ -1035,6 +1035,17 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
RankedTensorType clone(::mlir::Type elementType) {
return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
}
/// Return a clone of this type without the encoding.
RankedTensorType dropEncoding() {
return RankedTensorType::get(getShape(), getElementType());
}
/// Return a clone of this type with the given new encoding and the same
/// shape and element type as this type.
RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
return RankedTensorType::get(getShape(), getElementType(), encoding);
}
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;

View File

@ -282,6 +282,20 @@ TEST(ShapedTypeTest, RankedTensorTypeView) {
ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
view = mlir::cast<TensorWithString>(viewCreated);
EXPECT_EQ(view.getName(), "bob");
// Verify encoding clone methods.
EXPECT_EQ(unitEncodingRankedTensorType,
cast<RankedTensorType>(noEncodingRankedTensorType)
.cloneWithEncoding(unitAttr));
EXPECT_EQ(stringEncodingRankedTensorType,
cast<RankedTensorType>(noEncodingRankedTensorType)
.cloneWithEncoding(stringAttr));
EXPECT_EQ(
noEncodingRankedTensorType,
cast<RankedTensorType>(unitEncodingRankedTensorType).dropEncoding());
EXPECT_EQ(
noEncodingRankedTensorType,
cast<RankedTensorType>(stringEncodingRankedTensorType).dropEncoding());
}
} // namespace