[Mosaic] Add support for sublane strided load.

PiperOrigin-RevId: 552581319
This commit is contained in:
Jevin Jiang 2023-07-31 13:47:04 -07:00 committed by jax authors
parent e4955ecd23
commit 9d62d867bc
3 changed files with 67 additions and 6 deletions

View File

@ -37,6 +37,19 @@ class TPU_Attr<string name, string mnemonic_, list<Trait> traits = []>
let mnemonic = mnemonic_;
}
def TPU_Vreg : Type<
And<[IsVectorTypePred,
Or<[
And<[
CPred<"llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{8, 128}">,
CPred<"llvm::cast<::mlir::VectorType>($_self).getElementType().getIntOrFloatBitWidth() == 32">
]>,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{"
"8, 128, 32 / ::llvm::cast<::mlir::VectorType>($_self).getElementType().getIntOrFloatBitWidth()}">,
]>
]>,
"native-sized vreg", "::mlir::VectorType">;
class TPU_Type<string name, string mnemonic_, list<Trait> traits = []>
: TypeDef<TPU_Dialect, name, traits> {
let mnemonic = mnemonic_;
@ -135,10 +148,9 @@ def TPU_AllReduceOp : TPU_Op<"all_reduce", [Pure, SameOperandsAndResultType]> {
}];
}
// TODO(apaszke): Verify that stores are of native size and are aligned.
def TPU_StoreOp : TPU_Op<"store", [AttrSizedOperandSegments]> {
let arguments = (ins
AnyType:$valueToStore,
TPU_Vreg:$valueToStore,
AnyType:$base,
Variadic<Index>:$indices,
DenseBoolArrayAttr:$sublane_mask,
@ -150,16 +162,16 @@ def TPU_StoreOp : TPU_Op<"store", [AttrSizedOperandSegments]> {
}];
}
// TODO(apaszke): Verify that loads are of native size and are aligned.
def TPU_LoadOp : TPU_Op<"load"> {
let arguments = (ins
AnyType:$base,
Variadic<Index>:$indices,
DenseBoolArrayAttr:$sublane_mask
DenseBoolArrayAttr:$sublane_mask,
OptionalAttr<I32Attr>:$sublane_stride // In sublane-sized units
);
let results = (outs AnyType:$result);
let results = (outs TPU_Vreg:$result);
let assemblyFormat = [{
$base `[` $indices `]` `sublanes` $sublane_mask attr-dict `:` type($base) `,` type($result)
$base `[` $indices `]` `sublanes` $sublane_mask (`sublane_stride` $sublane_stride^)? attr-dict `:` type($base) `,` type($result)
}];
}

View File

@ -208,6 +208,10 @@ class VectorLayoutInferer {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<tpu::LoadOp>(any_op)) {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<tpu::EraseLayoutOp>(any_op)) {
if (infer(op).failed()) {
return failure();
@ -529,6 +533,21 @@ class VectorLayoutInferer {
return success();
}
LogicalResult infer(tpu::LoadOp op) {
auto res_ty = op.getResult().getType();
int8_t bitwidth = res_ty.getElementTypeBitWidth();
// We expect the result is already a native-sized vreg.
TPU_CHECK_OP(bitwidth == 32 && res_ty.getShape()[0] == target_shape_[0] &&
res_ty.getShape()[1] == target_shape_[1],
"Only 32-bit loads suppored");
SmallVector<Layout, 4> in_layout(op->getNumOperands(), kNoLayout);
auto out_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
ImplicitDim::kNone);
setLayout(op, in_layout, out_layout);
return success();
}
LogicalResult infer(tpu::EraseLayoutOp op) {
setLayout(op, kNoLayout, kNoLayout);
return success();

View File

@ -1697,6 +1697,36 @@ def _scf_yield_rule( # pylint: disable=missing-function-docstring
return ctx.set_operands(op.operation, unrolled)
@_register_rule("tpu.load")
def _tpu_load_rule(
ctx: RewriteContext,
op: tpu.LoadOp,
layout_in: Sequence[Layout],
layout_out: VectorLayout,
):
assert all(li is None for li in layout_in)
ty = ir.VectorType(op.result.type)
# We expect the result is already a native-sized vreg.
if layout_out.bitwidth != 32:
raise NotImplementedError("Only 32-bit loads supported")
assert layout_out == VectorLayout(32, (0, 0), TARGET_SHAPE, None)
indices = [get_int_const(v, "tpu.load index") for v in op.indices]
if indices[1] % TARGET_SHAPE.lanes:
raise NotImplementedError(
f"Lane index is not a multiple of {TARGET_SHAPE.lanes}"
)
tile = tpu.LoadOp(
ty,
op.base,
op.indices,
op.sublane_mask,
sublane_stride=op.sublane_stride,
)
return ctx.replace(op, assemble(ty, layout_out, np.asarray([[tile]])))
@_register_rule("tpu.trace")
def _tpu_trace_rule(ctx: RewriteContext, op: tpu.TraceOp, # pylint: disable=missing-function-docstring
layout_in: Layout, layout_out: Layout):