mirror of
https://github.com/llvm/llvm-project.git
synced 2025-05-02 23:46:05 +00:00
[mlir] Centralize handling of memref element types.
This also beefs up the test coverage: - Make unranked memref testing consistent with ranked memrefs. - Add testing for the invalid element type cases. This is not quite NFC: index types are now allowed in unranked memrefs. Differential Revision: https://reviews.llvm.org/D85541
This commit is contained in:
parent
a97dfdc30b
commit
b0d76f454d
@ -426,6 +426,11 @@ class BaseMemRefType : public ShapedType {
|
||||
public:
|
||||
using ShapedType::ShapedType;
|
||||
|
||||
/// Return true if the specified element type is ok in a memref.
|
||||
static bool isValidElementType(Type type) {
|
||||
return type.isIntOrIndexOrFloat() || type.isa<VectorType, ComplexType>();
|
||||
}
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(Type type);
|
||||
};
|
||||
|
@ -408,9 +408,7 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
|
||||
Optional<Location> location) {
|
||||
auto *context = elementType.getContext();
|
||||
|
||||
// Check that memref is formed from allowed types.
|
||||
if (!elementType.isIntOrIndexOrFloat() &&
|
||||
!elementType.isa<VectorType, ComplexType>())
|
||||
if (!BaseMemRefType::isValidElementType(elementType))
|
||||
return emitOptionalError(location, "invalid memref element type"),
|
||||
MemRefType();
|
||||
|
||||
@ -486,9 +484,7 @@ unsigned UnrankedMemRefType::getMemorySpace() const {
|
||||
LogicalResult
|
||||
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
|
||||
unsigned memorySpace) {
|
||||
// Check that memref is formed from allowed types.
|
||||
if (!elementType.isIntOrFloat() &&
|
||||
!elementType.isa<VectorType, ComplexType>())
|
||||
if (!BaseMemRefType::isValidElementType(elementType))
|
||||
return emitError(loc, "invalid memref element type");
|
||||
return success();
|
||||
}
|
||||
|
@ -17,6 +17,14 @@ func @nestedtensor(tensor<tensor<i8>>) -> () // expected-error {{invalid tensor
|
||||
|
||||
// -----
|
||||
|
||||
func @illegalmemrefelementtype(memref<?xtensor<i8>>) -> () // expected-error {{invalid memref element type}}
|
||||
|
||||
// -----
|
||||
|
||||
func @illegalunrankedmemrefelementtype(memref<*xtensor<i8>>) -> () // expected-error {{invalid memref element type}}
|
||||
|
||||
// -----
|
||||
|
||||
func @indexvector(vector<4 x index>) -> () // expected-error {{vector elements must be int or float type}}
|
||||
|
||||
// -----
|
||||
|
@ -152,6 +152,12 @@ func @memref_with_vector_elems(memref<1x?xvector<10xf32>>)
|
||||
// CHECK: func @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
|
||||
func @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
|
||||
|
||||
// CHECK: func @unranked_memref_with_index_elems(memref<*xindex>)
|
||||
func @unranked_memref_with_index_elems(memref<*xindex>)
|
||||
|
||||
// CHECK: func @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
|
||||
func @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
|
||||
|
||||
// CHECK: func @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
|
||||
func @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user