[mlir][sparse] add sparse tensor type conversion operation

Introduces a conversion from one (sparse) tensor type to another
(sparse) tensor type. See the operation doc for details. Actual
codegen for all cases is still TBD.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D107205
This commit is contained in:
Aart Bik 2021-07-30 17:52:39 -07:00
parent 7f55557765
commit 697ea09d47
5 changed files with 127 additions and 29 deletions

View File

@ -51,6 +51,38 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", []>,
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
}
def SparseTensor_ConvertOp : SparseTensor_Op<"convert", [SameOperandsAndResultType]>,
Arguments<(ins AnyTensor:$source)>,
Results<(outs AnyTensor:$dest)> {
string summary = "Converts between different tensor types";
string description = [{
Converts one sparse or dense tensor type to another tensor type. The rank
and dimensions of the source and destination types must match exactly,
only the sparse encoding of these types may be different. The name `convert`
was preferred over `cast`, since the operation may incur a non-trivial cost.
When converting between two different sparse tensor types, only explicitly
stored values are moved from one underlying sparse storage format to
the other. When converting from an unannotated dense tensor type to a
sparse tensor type, an explicit test for nonzero values is used. When
converting to an unannotated dense tensor type, implicit zeroes in the
sparse storage format are made explicit. Note that the conversions can have
non-trivial costs associated with them, since they may involve elaborate
data structure transformations. Also, conversions from sparse tensor types
into dense tensor types may be infeasible in terms of storage requirements.
Examples:
```mlir
%0 = sparse_tensor.convert %1 : tensor<32x32xf32> to tensor<32x32xf32, #CSR>
%2 = sparse_tensor.convert %3 : tensor<8x8xi32, #CSC> to tensor<8x8xi32, #CSR>
```
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
}
def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
Arguments<(ins AnyTensor:$tensor, Index:$dim)>,
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {

View File

@ -208,11 +208,27 @@ static LogicalResult isMatchingWidth(Value result, unsigned width) {
}
static LogicalResult verify(NewOp op) {
if (!getSparseTensorEncoding(op.getResult().getType()))
if (!getSparseTensorEncoding(op.result().getType()))
return op.emitError("expected a sparse tensor result");
return success();
}
static LogicalResult verify(ConvertOp op) {
if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) {
if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) {
assert(tp1.getRank() == tp2.getRank());
auto shape1 = tp1.getShape();
auto shape2 = tp2.getShape();
for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++)
if (shape1[d] != shape2[d])
return op.emitError()
<< "unexpected conversion mismatch in dimension " << d;
return success();
}
}
return op.emitError("unexpected type in convert");
}
static LogicalResult verify(ToPointersOp op) {
if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
if (failed(isInBounds(op.dim(), op.tensor())))

View File

@ -27,16 +27,23 @@ using namespace mlir::sparse_tensor;
namespace {
/// Internal encoding of primary storage. Keep this enum consistent
/// with the equivalent enum in the sparse runtime support library.
enum PrimaryTypeEnum : uint64_t {
kF64 = 1,
kF32 = 2,
kI64 = 3,
kI32 = 4,
kI16 = 5,
kI8 = 6
};
/// Returns internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getPrimaryTypeEncoding(Type tp) {
if (tp.isF64())
return 1;
if (tp.isF32())
return 2;
if (tp.isInteger(64))
return 3;
if (tp.isInteger(32))
return 4;
if (tp.isInteger(16))
return 5;
if (tp.isInteger(8))
return 6;
return 0;
}
/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
@ -170,20 +177,8 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
// Secondary and primary types encoding.
unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
unsigned primary;
if (eltType.isF64())
primary = kF64;
else if (eltType.isF32())
primary = kF32;
else if (eltType.isInteger(64))
primary = kI64;
else if (eltType.isInteger(32))
primary = kI32;
else if (eltType.isInteger(16))
primary = kI16;
else if (eltType.isInteger(8))
primary = kI8;
else
unsigned primary = getPrimaryTypeEncoding(eltType);
if (!primary)
return failure();
params.push_back(
rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
@ -200,6 +195,17 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
}
};
/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ConvertOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// TODO: implement conversions lowering
return failure();
}
};
/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
: public OpConversionPattern<ToPointersOp> {
@ -324,8 +330,8 @@ public:
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
SparseTensorNewConverter, SparseTensorToPointersConverter,
SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
SparseTensorToTensorConverter>(typeConverter,
patterns.getContext());
SparseTensorNewConverter, SparseTensorConvertConverter,
SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
typeConverter, patterns.getContext());
}

View File

@ -111,3 +111,21 @@ func @sparse_to_unannotated_tensor(%arg0: memref<?xf64>) -> tensor<16x32xf64> {
%0 = sparse_tensor.tensor %arg0 : memref<?xf64> to tensor<16x32xf64>
return %0 : tensor<16x32xf64>
}
// -----
func @sparse_convert_unranked(%arg0: tensor<*xf32>) -> tensor<10xf32> {
// expected-error@+1 {{unexpected type in convert}}
%0 = sparse_tensor.convert %arg0 : tensor<*xf32> to tensor<10xf32>
return %0 : tensor<10xf32>
}
// -----
#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
func @sparse_convert_mismatch(%arg0: tensor<10x10xf32>) -> tensor<10x?xf32, #CSR> {
// expected-error@+1 {{unexpected conversion mismatch in dimension 1}}
%0 = sparse_tensor.convert %arg0 : tensor<10x10xf32> to tensor<10x?xf32, #CSR>
return %0 : tensor<10x?xf32, #CSR>
}

View File

@ -15,6 +15,32 @@ func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
// CHECK-LABEL: func @sparse_convert_1d_to_sparse(
// CHECK-SAME: %[[A:.*]]: tensor<64xf32>)
// CHECK: %[[T:.*]] = sparse_tensor.convert %[[A]] : tensor<64xf32> to tensor<64xf32, #{{.*}}>
// CHECK: return %[[T]] : tensor<64xf32, #{{.*}}>
func @sparse_convert_1d_to_sparse(%arg0: tensor<64xf32>) -> tensor<64xf32, #SparseVector> {
%0 = sparse_tensor.convert %arg0 : tensor<64xf32> to tensor<64xf32, #SparseVector>
return %0 : tensor<64xf32, #SparseVector>
}
// -----
#SparseTensor = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ] }>
// CHECK-LABEL: func @sparse_convert_3d_from_sparse(
// CHECK-SAME: %[[A:.*]]: tensor<8x8x8xf64, #{{.*}}>)
// CHECK: %[[T:.*]] = sparse_tensor.convert %[[A]] : tensor<8x8x8xf64, #{{.*}}> to tensor<8x8x8xf64>
// CHECK: return %[[T]] : tensor<8x8x8xf64>
func @sparse_convert_3d_from_sparse(%arg0: tensor<8x8x8xf64, #SparseTensor>) -> tensor<8x8x8xf64> {
%0 = sparse_tensor.convert %arg0 : tensor<8x8x8xf64, #SparseTensor> to tensor<8x8x8xf64>
return %0 : tensor<8x8x8xf64>
}
// -----
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
// CHECK-LABEL: func @sparse_pointers(
// CHECK-SAME: %[[A:.*]]: tensor<128xf64, #{{.*}}>)
// CHECK: %[[C:.*]] = constant 0 : index