[Mosaic] Add support for sublane strided store.

PiperOrigin-RevId: 552595663
This commit is contained in:
Jevin Jiang 2023-07-31 14:39:17 -07:00 committed by jax authors
parent a40f900e23
commit 6e37c4202d
3 changed files with 59 additions and 2 deletions

View File

@ -154,11 +154,12 @@ def TPU_StoreOp : TPU_Op<"store", [AttrSizedOperandSegments]> {
AnyType:$base,
Variadic<Index>:$indices,
DenseBoolArrayAttr:$sublane_mask,
Optional<AnyType>:$mask
Optional<AnyType>:$mask,
OptionalAttr<I32Attr>:$sublane_stride // In sublane-sized units
);
let results = (outs);
let assemblyFormat = [{
$base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? `sublanes` $sublane_mask attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask)
$base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? `sublanes` $sublane_mask (`sublane_stride` $sublane_stride^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask)
}];
}

View File

@ -212,6 +212,10 @@ class VectorLayoutInferer {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<tpu::StoreOp>(any_op)) {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<tpu::EraseLayoutOp>(any_op)) {
if (infer(op).failed()) {
return failure();
@ -548,6 +552,23 @@ class VectorLayoutInferer {
return success();
}
LogicalResult infer(tpu::StoreOp op) {
auto store_ty = op.getValueToStore().getType();
int8_t bitwidth = store_ty.getElementTypeBitWidth();
// We expect the value to store is already a native-sized vreg.
TPU_CHECK_OP(bitwidth == 32 &&
store_ty.getShape()[0] == target_shape_[0] &&
store_ty.getShape()[1] == target_shape_[1],
"Only 32-bit stores suppored");
auto store_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
ImplicitDim::kNone);
SmallVector<Layout, 5> in_layout{store_layout};
in_layout.insert(in_layout.end(), op.getIndices().size() + 1, kNoLayout);
setInLayout(op, in_layout);
return success();
}
LogicalResult infer(tpu::EraseLayoutOp op) {
setLayout(op, kNoLayout, kNoLayout);
return success();

View File

@ -1727,6 +1727,41 @@ def _tpu_load_rule(
)
return ctx.replace(op, assemble(ty, layout_out, np.asarray([[tile]])))
@_register_rule("tpu.store")
def _tpu_store_rule( # pylint: disable=missing-function-docstring
ctx: RewriteContext,
op: tpu.StoreOp,
layout_in: Sequence[Layout],
layout_out: None, # pylint: disable=unused-argument
):
to_store_layout, *other_layouts = layout_in
assert all(li is None for li in other_layouts)
# We expect the value to store is already a native-sized vreg.
if to_store_layout.bitwidth != 32:
raise NotImplementedError("Only 32-bit stores supported")
assert to_store_layout == VectorLayout(32, (0, 0), TARGET_SHAPE, None)
indices = [get_int_const(v, "tpu.store index") for v in op.indices]
if indices[1] % TARGET_SHAPE.lanes:
raise NotImplementedError(
f"Lane index is not a multiple of {TARGET_SHAPE.lanes}"
)
tiles = disassemble(to_store_layout, op.valueToStore)
assert tiles.shape == (1, 1)
tpu.StoreOp(
tiles[0][0],
op.base,
op.indices,
op.sublane_mask,
mask=op.mask,
sublane_stride=op.sublane_stride,
)
return ctx.erase(op)
@_register_rule("tpu.trace")
def _tpu_trace_rule(ctx: RewriteContext, op: tpu.TraceOp, # pylint: disable=missing-function-docstring
layout_in: Layout, layout_out: Layout):