mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 12:26:07 +00:00)

[Mosaic] Add support for sublane strided load.

PiperOrigin-RevId: 552581319

Commit 9d62d867bc (parent e4955ecd23)
@@ -37,6 +37,19 @@ class TPU_Attr<string name, string mnemonic_, list<Trait> traits = []>
   let mnemonic = mnemonic_;
 }
 
+def TPU_Vreg : Type<
+  And<[IsVectorTypePred,
+    Or<[
+      And<[
+        CPred<"llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{8, 128}">,
+        CPred<"llvm::cast<::mlir::VectorType>($_self).getElementType().getIntOrFloatBitWidth() == 32">
+      ]>,
+      CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{"
+            "8, 128, 32 / ::llvm::cast<::mlir::VectorType>($_self).getElementType().getIntOrFloatBitWidth()}">,
+    ]>
+  ]>,
+  "native-sized vreg", "::mlir::VectorType">;
+
 class TPU_Type<string name, string mnemonic_, list<Trait> traits = []>
     : TypeDef<TPU_Dialect, name, traits> {
   let mnemonic = mnemonic_;
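Note: the TPU_Vreg constraint added above accepts exactly the native vreg shapes: an 8x128 vector of 32-bit elements, or an 8x128x(32/bitwidth) vector for narrower, packed element types. A minimal Python sketch of the same check (is_native_vreg is a hypothetical helper, not part of this change):

def is_native_vreg(shape: tuple, bitwidth: int) -> bool:
  # Mirrors the TableGen predicate: plain 8x128 for 32-bit data, or a trailing
  # packed dimension of 32 // bitwidth elements for narrower types.
  return (shape == (8, 128) and bitwidth == 32) or shape == (8, 128, 32 // bitwidth)

assert is_native_vreg((8, 128), 32)      # e.g. vector<8x128xf32>
assert is_native_vreg((8, 128, 2), 16)   # e.g. vector<8x128x2xbf16>
assert not is_native_vreg((8, 128), 16)  # 16-bit data needs the packed dimension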
@@ -135,10 +148,9 @@ def TPU_AllReduceOp : TPU_Op<"all_reduce", [Pure, SameOperandsAndResultType]> {
   }];
 }
 
 // TODO(apaszke): Verify that stores are of native size and are aligned.
 def TPU_StoreOp : TPU_Op<"store", [AttrSizedOperandSegments]> {
   let arguments = (ins
-    AnyType:$valueToStore,
+    TPU_Vreg:$valueToStore,
     AnyType:$base,
     Variadic<Index>:$indices,
     DenseBoolArrayAttr:$sublane_mask,
@@ -150,16 +162,16 @@ def TPU_StoreOp : TPU_Op<"store", [AttrSizedOperandSegments]> {
   }];
 }
 
 // TODO(apaszke): Verify that loads are of native size and are aligned.
 def TPU_LoadOp : TPU_Op<"load"> {
   let arguments = (ins
     AnyType:$base,
     Variadic<Index>:$indices,
-    DenseBoolArrayAttr:$sublane_mask
+    DenseBoolArrayAttr:$sublane_mask,
+    OptionalAttr<I32Attr>:$sublane_stride  // In sublane-sized units
   );
-  let results = (outs AnyType:$result);
+  let results = (outs TPU_Vreg:$result);
   let assemblyFormat = [{
-    $base `[` $indices `]` `sublanes` $sublane_mask attr-dict `:` type($base) `,` type($result)
+    $base `[` $indices `]` `sublanes` $sublane_mask (`sublane_stride` $sublane_stride^)? attr-dict `:` type($base) `,` type($result)
   }];
 }
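Note: the new sublane_stride attribute is measured in sublane-sized units and, when absent, presumably means a dense (unit-stride) load. As a rough NumPy model of the intended addressing, inferred from the attribute comment and the commit title rather than taken from this diff, a load with stride s fills vreg sublane i from memory row row + i * s:

import numpy as np

def strided_sublane_load(mem, row, col, sublane_stride=1):
  # Hypothetical model: gather 8 sublanes spaced `sublane_stride` rows apart,
  # each 128 lanes wide, into one native vreg.
  rows = row + sublane_stride * np.arange(8)
  return mem[rows, col:col + 128]

mem = np.arange(64 * 128, dtype=np.int32).reshape(64, 128)
vreg = strided_sublane_load(mem, row=0, col=0, sublane_stride=2)  # rows 0, 2, ..., 14
assert vreg.shape == (8, 128)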
@@ -208,6 +208,10 @@ class VectorLayoutInferer {
       if (infer(op).failed()) {
         return failure();
       }
+    } else if (auto op = dyn_cast<tpu::LoadOp>(any_op)) {
+      if (infer(op).failed()) {
+        return failure();
+      }
     } else if (auto op = dyn_cast<tpu::EraseLayoutOp>(any_op)) {
       if (infer(op).failed()) {
         return failure();
@@ -529,6 +533,21 @@ class VectorLayoutInferer {
     return success();
   }
+
+  LogicalResult infer(tpu::LoadOp op) {
+    auto res_ty = op.getResult().getType();
+    int8_t bitwidth = res_ty.getElementTypeBitWidth();
+
+    // We expect the result is already a native-sized vreg.
+    TPU_CHECK_OP(bitwidth == 32 && res_ty.getShape()[0] == target_shape_[0] &&
+                     res_ty.getShape()[1] == target_shape_[1],
+                 "Only 32-bit loads supported");
+    SmallVector<Layout, 4> in_layout(op->getNumOperands(), kNoLayout);
+    auto out_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
+                                   ImplicitDim::kNone);
+    setLayout(op, in_layout, out_layout);
+    return success();
+  }
 
   LogicalResult infer(tpu::EraseLayoutOp op) {
     setLayout(op, kNoLayout, kNoLayout);
     return success();
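Note: the result layout above uses zero offsets and nativeTiling(bitwidth), while the operands (base ref and scalar indices) get no layout. A plausible reading of nativeTiling, stated here as an assumption rather than something shown in this diff, is that the native tile packs 32 / bitwidth rows per sublane on top of the (8, 128) target shape, so 32-bit data (the only case this rule accepts) gets an (8, 128) tile:

def native_tiling(bitwidth, target_shape=(8, 128)):
  # Assumed definition: packing factor of 32 // bitwidth along the sublane axis.
  packing = 32 // bitwidth
  return (target_shape[0] * packing, target_shape[1])

assert native_tiling(32) == (8, 128)
assert native_tiling(16) == (16, 128)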
@@ -1697,6 +1697,36 @@ def _scf_yield_rule(  # pylint: disable=missing-function-docstring
   return ctx.set_operands(op.operation, unrolled)
 
 
+@_register_rule("tpu.load")
+def _tpu_load_rule(
+    ctx: RewriteContext,
+    op: tpu.LoadOp,
+    layout_in: Sequence[Layout],
+    layout_out: VectorLayout,
+):
+  assert all(li is None for li in layout_in)
+  ty = ir.VectorType(op.result.type)
+
+  # We expect the result is already a native-sized vreg.
+  if layout_out.bitwidth != 32:
+    raise NotImplementedError("Only 32-bit loads supported")
+  assert layout_out == VectorLayout(32, (0, 0), TARGET_SHAPE, None)
+
+  indices = [get_int_const(v, "tpu.load index") for v in op.indices]
+  if indices[1] % TARGET_SHAPE.lanes:
+    raise NotImplementedError(
+        f"Lane index is not a multiple of {TARGET_SHAPE.lanes}"
+    )
+
+  tile = tpu.LoadOp(
+      ty,
+      op.base,
+      op.indices,
+      op.sublane_mask,
+      sublane_stride=op.sublane_stride,
+  )
+  return ctx.replace(op, assemble(ty, layout_out, np.asarray([[tile]])))
+
 @_register_rule("tpu.trace")
 def _tpu_trace_rule(ctx: RewriteContext, op: tpu.TraceOp,  # pylint: disable=missing-function-docstring
                     layout_in: Layout, layout_out: Layout):
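Note: the apply-layout rule re-emits the load unchanged (same sublane_mask, same sublane_stride) and registers the single result vreg as a 1x1 grid of tiles; it only handles the aligned 32-bit case. A standalone sketch of just those two guards (SUBLANES/LANES stand in for TARGET_SHAPE and are assumed to be the usual (8, 128) geometry):

SUBLANES, LANES = 8, 128  # assumed native vreg geometry

def check_tpu_load(bitwidth, lane_index):
  # Same two guards as _tpu_load_rule above.
  if bitwidth != 32:
    raise NotImplementedError("Only 32-bit loads supported")
  if lane_index % LANES:
    raise NotImplementedError(f"Lane index is not a multiple of {LANES}")

check_tpu_load(32, 256)   # fine: lane offset covers two whole vregs
# check_tpu_load(32, 64)  # would raise: not vreg-aligned in the lane dimension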