mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic:TPU] Add shuffled load and store.
We also emulate shuffled store using (store + shuffled load + store) for previous generations.
PiperOrigin-RevId: 662657663
This commit is contained in:
parent
d2b85a48af
commit
2dea3d6a0c
@ -233,6 +233,37 @@ def TPU_StridedStoreOp : TPU_Op<"strided_store"> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
// tpu.shuffled_load: loads a vreg from `base` at `indices`, parameterized by
// a per-sublane mask and a per-sublane offset list.
//
// Invariants enforced by the op verifier:
//   * there is exactly one index per dimension of the `base` memref;
//   * `sublane_mask` and `sublane_offsets` each have as many entries as the
//     leading dimension of the result type.
//
// The canonicalizer turns this op into a plain tpu.load (same base, indices
// and sublane mask) when `sublane_offsets[i] == i` for every i.
def TPU_ShuffledLoadOp : TPU_Op<"shuffled_load"> {
  let results = (outs TPU_Vreg:$result);
  let arguments = (ins
    AnyMemRef:$base,
    Variadic<Index>:$indices,
    DenseBoolArrayAttr:$sublane_mask,
    DenseI32ArrayAttr:$sublane_offsets
  );
  let assemblyFormat = [{
    $base `[` $indices `]` attr-dict `:` type($base) `,` type($result)
  }];
  let hasCanonicalizeMethod = 1;
  let hasVerifier = 1;
}
|
||||
|
||||
// tpu.shuffled_store: stores the vreg `valueToStore` into `base` at
// `indices`, parameterized by a per-sublane mask and a per-sublane offset
// list. The op produces no results.
//
// Invariants enforced by the op verifier:
//   * there is exactly one index per dimension of the `base` memref;
//   * the rank of `valueToStore` equals the number of indices;
//   * `sublane_mask` and `sublane_offsets` each have as many entries as the
//     leading dimension of `valueToStore`.
//
// The canonicalizer turns this op into a plain tpu.store (same value, base,
// indices and sublane mask) when `sublane_offsets[i] == i` for every i.
def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store"> {
  let results = (outs);
  let arguments = (ins
    TPU_Vreg:$valueToStore,
    AnyMemRef:$base,
    Variadic<Index>:$indices,
    DenseBoolArrayAttr:$sublane_mask,
    DenseI32ArrayAttr:$sublane_offsets
  );
  let assemblyFormat = [{
    $base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore)
  }];
  let hasCanonicalizeMethod = 1;
  let hasVerifier = 1;
}
|
||||
|
||||
// TODO(jevinjiang): deprecate to use dynamic_rotate.
|
||||
def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> {
|
||||
let arguments = (ins
|
||||
|
@ -488,6 +488,81 @@ LogicalResult RegionOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
// Checks the structural invariants of tpu.shuffled_load:
//   * one index operand per dimension of the base memref;
//   * `sublane_mask` and `sublane_offsets` each contain exactly as many
//     entries as the leading dimension of the result type.
// Emits an op error describing the first violated invariant, otherwise
// returns success.
LogicalResult ShuffledLoadOp::verify() {
  const auto base_rank = getBase().getType().getRank();
  const auto index_count = getIndices().size();
  if (base_rank != index_count) {
    return emitOpError("Base memref's rank and indices size do not match: ")
           << base_rank << " vs " << index_count;
  }

  // Both per-sublane attributes are sized by the result's leading dimension.
  const auto leading_dim = getType().getShape()[0];

  const auto mask_size = getSublaneMask().size();
  if (mask_size != leading_dim) {
    return emitOpError("Expected sublane mask size equals to ")
           << leading_dim << " but got " << mask_size;
  }

  const auto offsets_size = getSublaneOffsets().size();
  if (offsets_size != leading_dim) {
    return emitOpError("Expected sublane offsets size equals to ")
           << leading_dim << " but got " << offsets_size;
  }
  return success();
}
|
||||
|
||||
// Canonicalizes a tpu.shuffled_load whose sublane offsets are the identity
// sequence (sublane_offsets[i] == i for every i) into a plain tpu::LoadOp
// with the same result type, base, indices and sublane mask, and no sublane
// stride.
//
// Returns success() only when the op was actually replaced. When any offset
// differs from its position the IR is left untouched and failure() is
// returned, as the pattern-rewriter contract requires: reporting success
// without modifying the IR misleads the greedy driver into believing
// progress was made.
LogicalResult ShuffledLoadOp::canonicalize(ShuffledLoadOp op,
                                           PatternRewriter &rewriter) {
  const auto offsets = op.getSublaneOffsets();
  const int32_t num_offsets = static_cast<int32_t>(offsets.size());
  for (int32_t i = 0; i < num_offsets; ++i) {
    if (offsets[i] != i) {
      // A real shuffle: nothing to simplify.
      return failure();
    }
  }
  rewriter.replaceOpWithNewOp<tpu::LoadOp>(
      op, op.getType(), op.getBase(), op.getIndices(), op.getSublaneMask(),
      /*sublane_stride=*/nullptr);
  return success();
}
|
||||
|
||||
// Checks the structural invariants of tpu.shuffled_store:
//   * one index operand per dimension of the base memref;
//   * the rank of the value to store equals the number of indices;
//   * `sublane_mask` and `sublane_offsets` each contain exactly as many
//     entries as the leading dimension of the value to store.
// Emits an op error describing the first violated invariant, otherwise
// returns success.
LogicalResult ShuffledStoreOp::verify() {
  if (getBase().getType().getRank() != getIndices().size()) {
    return emitOpError("Base memref's rank and indices size do not match: ")
           << getBase().getType().getRank() << " vs " << getIndices().size();
  }
  if (getValueToStore().getType().getRank() != getIndices().size()) {
    // Report the rank that was actually compared (the value's rank, not the
    // base memref's) so the diagnostic matches the failed check.
    return emitOpError(
               "The rank of value to store and indices size do not match: ")
           << getValueToStore().getType().getRank() << " vs "
           << getIndices().size();
  }
  if (getSublaneMask().size() != getValueToStore().getType().getShape()[0]) {
    return emitOpError("Expected sublane mask size equals to ")
           << getValueToStore().getType().getShape()[0] << " but got "
           << getSublaneMask().size();
  }
  if (getSublaneOffsets().size() != getValueToStore().getType().getShape()[0]) {
    return emitOpError("Expected sublane offsets size equals to ")
           << getValueToStore().getType().getShape()[0] << " but got "
           << getSublaneOffsets().size();
  }
  return success();
}
|
||||
|
||||
// Canonicalizes a tpu.shuffled_store whose sublane offsets are the identity
// sequence (sublane_offsets[i] == i for every i) into a plain tpu::StoreOp
// with the same value, base, indices and sublane mask, and with no mask and
// no sublane stride.
//
// Returns success() only when the op was actually replaced. When any offset
// differs from its position the IR is left untouched and failure() is
// returned, as the pattern-rewriter contract requires: reporting success
// without modifying the IR misleads the greedy driver into believing
// progress was made.
LogicalResult ShuffledStoreOp::canonicalize(ShuffledStoreOp op,
                                            PatternRewriter &rewriter) {
  const auto offsets = op.getSublaneOffsets();
  const int32_t num_offsets = static_cast<int32_t>(offsets.size());
  for (int32_t i = 0; i < num_offsets; ++i) {
    if (offsets[i] != i) {
      // A real shuffle: nothing to simplify.
      return failure();
    }
  }
  rewriter.replaceOpWithNewOp<tpu::StoreOp>(op, op.getValueToStore(),
                                            op.getBase(), op.getIndices(),
                                            op.getSublaneMask(),
                                            /*mask=*/nullptr,
                                            /*sublane_stride=*/nullptr);
  return success();
}
|
||||
} // namespace tpu
|
||||
} // namespace mlir
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user