mirror of
https://github.com/llvm/llvm-project.git
synced 2025-05-02 23:46:05 +00:00
[mlir] Centralize handling of memref element types.
This also beefs up the test coverage: - Make unranked memref testing consistent with ranked memrefs. - Add testing for the invalid element type cases. This is not quite NFC: index types are now allowed in unranked memrefs. Differential Revision: https://reviews.llvm.org/D85541
This commit is contained in:
parent
a97dfdc30b
commit
b0d76f454d
@ -426,6 +426,11 @@ class BaseMemRefType : public ShapedType {
|
|||||||
public:
|
public:
|
||||||
using ShapedType::ShapedType;
|
using ShapedType::ShapedType;
|
||||||
|
|
||||||
|
/// Return true if the specified element type is ok in a memref.
|
||||||
|
static bool isValidElementType(Type type) {
|
||||||
|
return type.isIntOrIndexOrFloat() || type.isa<VectorType, ComplexType>();
|
||||||
|
}
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(Type type);
|
static bool classof(Type type);
|
||||||
};
|
};
|
||||||
|
@ -408,9 +408,7 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
|
|||||||
Optional<Location> location) {
|
Optional<Location> location) {
|
||||||
auto *context = elementType.getContext();
|
auto *context = elementType.getContext();
|
||||||
|
|
||||||
// Check that memref is formed from allowed types.
|
if (!BaseMemRefType::isValidElementType(elementType))
|
||||||
if (!elementType.isIntOrIndexOrFloat() &&
|
|
||||||
!elementType.isa<VectorType, ComplexType>())
|
|
||||||
return emitOptionalError(location, "invalid memref element type"),
|
return emitOptionalError(location, "invalid memref element type"),
|
||||||
MemRefType();
|
MemRefType();
|
||||||
|
|
||||||
@ -486,9 +484,7 @@ unsigned UnrankedMemRefType::getMemorySpace() const {
|
|||||||
LogicalResult
|
LogicalResult
|
||||||
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
|
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
|
||||||
unsigned memorySpace) {
|
unsigned memorySpace) {
|
||||||
// Check that memref is formed from allowed types.
|
if (!BaseMemRefType::isValidElementType(elementType))
|
||||||
if (!elementType.isIntOrFloat() &&
|
|
||||||
!elementType.isa<VectorType, ComplexType>())
|
|
||||||
return emitError(loc, "invalid memref element type");
|
return emitError(loc, "invalid memref element type");
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,14 @@ func @nestedtensor(tensor<tensor<i8>>) -> () // expected-error {{invalid tensor
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @illegalmemrefelementtype(memref<?xtensor<i8>>) -> () // expected-error {{invalid memref element type}}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @illegalunrankedmemrefelementtype(memref<*xtensor<i8>>) -> () // expected-error {{invalid memref element type}}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @indexvector(vector<4 x index>) -> () // expected-error {{vector elements must be int or float type}}
|
func @indexvector(vector<4 x index>) -> () // expected-error {{vector elements must be int or float type}}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -152,6 +152,12 @@ func @memref_with_vector_elems(memref<1x?xvector<10xf32>>)
|
|||||||
// CHECK: func @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
|
// CHECK: func @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
|
||||||
func @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
|
func @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
|
||||||
|
|
||||||
|
// CHECK: func @unranked_memref_with_index_elems(memref<*xindex>)
|
||||||
|
func @unranked_memref_with_index_elems(memref<*xindex>)
|
||||||
|
|
||||||
|
// CHECK: func @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
|
||||||
|
func @unranked_memref_with_vector_elems(memref<*xvector<10xf32>>)
|
||||||
|
|
||||||
// CHECK: func @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
|
// CHECK: func @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
|
||||||
func @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
|
func @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user