[Mosaic:TPU] Add shuffled load and store.

we also emulate shuffled store using (store + shuffled load + store) for previous generations.

PiperOrigin-RevId: 662657663
This commit is contained in:
Jevin Jiang 2024-08-13 14:40:37 -07:00 committed by jax authors
parent d2b85a48af
commit 2dea3d6a0c
2 changed files with 106 additions and 0 deletions

View File

@ -233,6 +233,37 @@ def TPU_StridedStoreOp : TPU_Op<"strided_store"> {
let hasVerifier = 1;
}
def TPU_ShuffledLoadOp : TPU_Op<"shuffled_load"> {
let arguments = (ins
AnyMemRef:$base,
Variadic<Index>:$indices,
DenseBoolArrayAttr:$sublane_mask,
DenseI32ArrayAttr:$sublane_offsets
);
let results = (outs TPU_Vreg:$result);
let assemblyFormat = [{
$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)
}];
let hasVerifier = 1;
let hasCanonicalizeMethod = 1;
}
def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store"> {
let arguments = (ins
TPU_Vreg:$valueToStore,
AnyMemRef:$base,
Variadic<Index>:$indices,
DenseBoolArrayAttr:$sublane_mask,
DenseI32ArrayAttr:$sublane_offsets
);
let results = (outs);
let assemblyFormat = [{
$base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore)
}];
let hasVerifier = 1;
let hasCanonicalizeMethod = 1;
}
// TODO(jevinjiang): deprecate to use dynamic_rotate.
def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> {
let arguments = (ins

View File

@ -488,6 +488,81 @@ LogicalResult RegionOp::verify() {
return success();
}
LogicalResult ShuffledLoadOp::verify() {
if (getBase().getType().getRank() != getIndices().size()) {
return emitOpError("Base memref's rank and indices size do not match: ")
<< getBase().getType().getRank() << " vs " << getIndices().size();
}
if (getSublaneMask().size() != getType().getShape()[0]) {
return emitOpError("Expected sublane mask size equals to ")
<< getType().getShape()[0] << " but got " << getSublaneMask().size();
}
if (getSublaneOffsets().size() != getType().getShape()[0]) {
return emitOpError("Expected sublane offsets size equals to ")
<< getType().getShape()[0] << " but got "
<< getSublaneOffsets().size();
}
return success();
}
LogicalResult ShuffledLoadOp::canonicalize(ShuffledLoadOp op,
PatternRewriter &rewriter) {
bool can_convert_to_simple_load = true;
for (int i = 0; i < op.getSublaneOffsets().size(); ++i) {
if (op.getSublaneOffsets()[i] != i) {
can_convert_to_simple_load = false;
break;
};
}
if (can_convert_to_simple_load) {
rewriter.replaceOpWithNewOp<tpu::LoadOp>(
op, op.getType(), op.getBase(), op.getIndices(), op.getSublaneMask(),
/*sublane_stride=*/nullptr);
}
return success();
}
LogicalResult ShuffledStoreOp::verify() {
if (getBase().getType().getRank() != getIndices().size()) {
return emitOpError("Base memref's rank and indices size do not match: ")
<< getBase().getType().getRank() << " vs " << getIndices().size();
}
if (getValueToStore().getType().getRank() != getIndices().size()) {
return emitOpError(
"The rank of value to store and indices size do not match: ")
<< getBase().getType().getRank() << " vs " << getIndices().size();
}
if (getSublaneMask().size() != getValueToStore().getType().getShape()[0]) {
return emitOpError("Expected sublane mask size equals to ")
<< getValueToStore().getType().getShape()[0] << " but got "
<< getSublaneMask().size();
}
if (getSublaneOffsets().size() != getValueToStore().getType().getShape()[0]) {
return emitOpError("Expected sublane offsets size equals to ")
<< getValueToStore().getType().getShape()[0] << " but got "
<< getSublaneOffsets().size();
}
return success();
}
LogicalResult ShuffledStoreOp::canonicalize(ShuffledStoreOp op,
PatternRewriter &rewriter) {
bool can_convert_to_simple_store = true;
for (int i = 0; i < op.getSublaneOffsets().size(); ++i) {
if (op.getSublaneOffsets()[i] != i) {
can_convert_to_simple_store = false;
break;
};
}
if (can_convert_to_simple_store) {
rewriter.replaceOpWithNewOp<tpu::StoreOp>(op, op.getValueToStore(),
op.getBase(), op.getIndices(),
op.getSublaneMask(),
/*mask=*/nullptr,
/*sublane_stride=*/nullptr);
}
return success();
}
} // namespace tpu
} // namespace mlir