mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic] Add support for sublane strided store.
PiperOrigin-RevId: 552595663
This commit is contained in:
parent
a40f900e23
commit
6e37c4202d
@ -154,11 +154,12 @@ def TPU_StoreOp : TPU_Op<"store", [AttrSizedOperandSegments]> {
|
||||
AnyType:$base,
|
||||
Variadic<Index>:$indices,
|
||||
DenseBoolArrayAttr:$sublane_mask,
|
||||
Optional<AnyType>:$mask
|
||||
Optional<AnyType>:$mask,
|
||||
OptionalAttr<I32Attr>:$sublane_stride // In sublane-sized units
|
||||
);
|
||||
let results = (outs);
|
||||
let assemblyFormat = [{
|
||||
$base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? `sublanes` $sublane_mask attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask)
|
||||
$base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? `sublanes` $sublane_mask (`sublane_stride` $sublane_stride^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask)
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -212,6 +212,10 @@ class VectorLayoutInferer {
|
||||
if (infer(op).failed()) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto op = dyn_cast<tpu::StoreOp>(any_op)) {
|
||||
if (infer(op).failed()) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto op = dyn_cast<tpu::EraseLayoutOp>(any_op)) {
|
||||
if (infer(op).failed()) {
|
||||
return failure();
|
||||
@ -548,6 +552,23 @@ class VectorLayoutInferer {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult infer(tpu::StoreOp op) {
|
||||
auto store_ty = op.getValueToStore().getType();
|
||||
int8_t bitwidth = store_ty.getElementTypeBitWidth();
|
||||
|
||||
// We expect the value to store is already a native-sized vreg.
|
||||
TPU_CHECK_OP(bitwidth == 32 &&
|
||||
store_ty.getShape()[0] == target_shape_[0] &&
|
||||
store_ty.getShape()[1] == target_shape_[1],
|
||||
"Only 32-bit stores suppored");
|
||||
auto store_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
|
||||
ImplicitDim::kNone);
|
||||
SmallVector<Layout, 5> in_layout{store_layout};
|
||||
in_layout.insert(in_layout.end(), op.getIndices().size() + 1, kNoLayout);
|
||||
setInLayout(op, in_layout);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult infer(tpu::EraseLayoutOp op) {
|
||||
setLayout(op, kNoLayout, kNoLayout);
|
||||
return success();
|
||||
|
@ -1727,6 +1727,41 @@ def _tpu_load_rule(
|
||||
)
|
||||
return ctx.replace(op, assemble(ty, layout_out, np.asarray([[tile]])))
|
||||
|
||||
|
||||
@_register_rule("tpu.store")
def _tpu_store_rule(
    ctx: RewriteContext,
    op: tpu.StoreOp,
    layout_in: Sequence[Layout],
    layout_out: None,  # pylint: disable=unused-argument
):
  """Rewrites a tpu.store whose value is a single native 32-bit vreg.

  Disassembles the stored vector under its inferred layout (expected to be
  exactly one tile), re-emits a tpu.StoreOp on that tile, and erases the
  original op.
  """
  value_layout = layout_in[0]
  # Only the stored vector carries a layout; indices/mask must not.
  assert all(l is None for l in layout_in[1:])

  # We expect the value to store is already a native-sized vreg.
  if value_layout.bitwidth != 32:
    raise NotImplementedError("Only 32-bit stores supported")
  assert value_layout == VectorLayout(32, (0, 0), TARGET_SHAPE, None)

  index_values = [get_int_const(i, "tpu.store index") for i in op.indices]
  if index_values[1] % TARGET_SHAPE.lanes != 0:
    raise NotImplementedError(
        f"Lane index is not a multiple of {TARGET_SHAPE.lanes}"
    )

  tiles = disassemble(value_layout, op.valueToStore)
  assert tiles.shape == (1, 1)  # native vreg => exactly one tile
  tpu.StoreOp(
      tiles[0][0],
      op.base,
      op.indices,
      op.sublane_mask,
      mask=op.mask,
      sublane_stride=op.sublane_stride,
  )
  return ctx.erase(op)
|
||||
|
||||
|
||||
@_register_rule("tpu.trace")
|
||||
def _tpu_trace_rule(ctx: RewriteContext, op: tpu.TraceOp, # pylint: disable=missing-function-docstring
|
||||
layout_in: Layout, layout_out: Layout):
|
||||
|
Loading…
x
Reference in New Issue
Block a user