[Mosaic TPU] Use vmask packing when possible for mask bitwidth changes and introduce a relayout op.

PiperOrigin-RevId: 719089676
Jevin Jiang 2025-01-23 18:14:22 -08:00 committed by jax authors
parent 3864512b72
commit 8e1f956804
6 changed files with 143 additions and 38 deletions
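As a rough sketch of the IR involved (illustrative only, not taken from the commit; SSA names are made up and the shapes assume the default 8x128 target shape), tpu-relayout-insertion now wraps a mask whose layout bitwidth changes from 32 to 16 in the new tpu.relayout op, and apply-vector-layout lowers that op by packing pairs of 32-bit vmask vregs with the new tpu.pack_vmsk op, instead of relayouting the mask through integer ext/trunc:

// Emitted by tpu-relayout-insertion; the source and destination layouts are
// attached as attributes, so the operand and result types are identical.
%m16 = tpu.relayout %m32 : vector<8x256xi1> -> vector<8x256xi1>

// Emitted by apply-vector-layout for each pair of 32-bit vmask vregs (the
// comma before the colon follows the op's declared assembly format).
%vmsk = tpu.pack_vmsk %low, %high, : vector<8x128xi1>, vector<8x128xi1> -> vector<8x128x2xi1>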

View File

@@ -407,6 +407,7 @@ def _lower_tpu_kernel(
(
"func.func(tpu-relayout-insertion{"
f" sublane-count={sl_cnt} lane-count={l_cnt}"
f" hardware-generation={hardware_generation}"
"})"
),
]

View File

@@ -389,6 +389,21 @@ def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure, SameTypeOperands]>
let hasVerifier = 1;
}
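// tpu.relayout leaves its operand unchanged (operand and result types match);
// it only marks a point where the vector layout assigned by the layout passes
// changes, and apply-vector-layout rewrites it away.
// tpu.pack_vmsk packs two vmask vregs produced under a wider-bitwidth layout
// into a single vmask vreg for the half-width layout (e.g. two 32-bit-layout
// masks into one 16-bit-layout mask).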
def TPU_RelayoutOp : TPU_Op<"relayout", [SameOperandsAndResultType]> {
let arguments = (ins AnyType:$input);
let results = (outs AnyType:$output);
let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
}
def TPU_PackMaskOp : TPU_Op<"pack_vmsk", [Pure, SameTypeOperands]> {
let arguments = (ins
VectorOfNonZeroRankOf<[I1]>: $low,
VectorOfNonZeroRankOf<[I1]>: $high
);
let results = (outs VectorOfNonZeroRankOf<[I1]>:$output);
let assemblyFormat = [{ $low `,` $high `,` attr-dict `:` type($low) `,` type($high) `->` type($output) }];
}
def TPU_GatherOp : TPU_Op<"gather", [Pure]> {
let arguments = (ins
AnyVectorOfNonZeroRank:$source,
@@ -891,6 +906,9 @@ def RelayoutInsertionPass : Pass<"tpu-relayout-insertion", "::mlir::func::FuncOp
];
let constructor = "::mlir::tpu::createRelayoutInsertionPass()";
let options = [
// If hardware_generation is not set, the default value of -1 will make
// runOnOperation fail.
Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
];

View File

@@ -84,6 +84,7 @@ std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
const TpuTilingFlags &tpu_tiling_flags = {});
std::unique_ptr<OperationPass<func::FuncOp>> createRelayoutInsertionPass(
int hardware_generation = -1,
std::array<int64_t, 2> target_shape = {8, 128});
std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(

View File

@@ -2172,6 +2172,74 @@ LogicalResult tpu_assume_layout_rule(RewriteContext &ctx, Operation &op,
return success();
}
LogicalResult tpu_relayout_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
TPU_ASSERT_EQ_OP(op.getNumOperands(), 1);
TPU_ASSERT_EQ_OP(op.getNumResults(), 1);
TPU_ASSERT_EQ_OP(layouts_in.size(), 1);
TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
TPU_ASSERT_OP(layouts_in[0].has_value());
TPU_ASSERT_OP(layouts_out[0].has_value());
const auto& in_layout = *layouts_in[0];
const auto& out_layout = *layouts_out[0];
auto relayout_op = cast<tpu::RelayoutOp>(op);
auto in_bitwidth = in_layout.bitwidth();
auto out_bitwidth = out_layout.bitwidth();
auto vty = cast<VectorType>(relayout_op.getType());
ImplicitLocOpBuilder builder(op.getLoc(), &op);
if (in_layout == out_layout) {
relayout_op.replaceAllUsesWith(relayout_op.getInput());
relayout_op.erase();
return success();
}
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> vals,
disassemble(builder, in_layout,
cast<TypedValue<VectorType>>(relayout_op.getInput()),
ctx.target_shape,
/*use_implicit_shape=*/true));
// Packing vector masks from 32-bit to 16-bit.
if (vty.getElementType() == builder.getI1Type() && in_bitwidth == 32 &&
out_bitwidth == 16 &&
in_layout.tiling()[0] == in_layout.packing() * ctx.target_shape[0] &&
in_layout.tiling()[1] == ctx.target_shape[1] &&
in_layout.tiling() == out_layout.tiling() &&
in_layout.offsets() == out_layout.offsets() &&
in_layout.implicit_dim() == out_layout.implicit_dim()) {
std::vector<int64_t> vmsks_shape(vals.dimensions().begin(),
vals.dimensions().end());
*(vmsks_shape.end() - 1) = llvm::divideCeil(vmsks_shape.back(), 2);
xla::Array<Value> out_vmsks(vmsks_shape, nullptr);
SmallVector<int64_t> val_idx;
Value default_val =
getFullLikeVector(builder, cast<TypedValue<VectorType>>(*vals.begin()),
IntegerAttr::get(builder.getI1Type(), 0));
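// Example of the index mapping below (illustrative): with 5 input vmask vregs
// along the last dimension, the output has ceil(5/2) = 3 vregs; output vreg j
// packs input vregs 2*j (low half) and 2*j+1 (high half), substituting the
// all-false default_val for any input index that falls past the end.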
out_vmsks.Each([&](absl::Span<const int64_t> idx, Value *v) {
val_idx.assign(idx.begin(), idx.end());
// TODO(jevinjiang): can be simplified when offset is replicated.
*(val_idx.end() - 1) *= 2;
Value low_part = *(val_idx.end() - 1) < *(vals.dimensions().end() - 1)
? vals(val_idx)
: default_val;
*(val_idx.end() - 1) += 1;
Value high_part = *(val_idx.end() - 1) < *(vals.dimensions().end() - 1)
? vals(val_idx)
: default_val;
const VectorType mask_ty = getNativeVregOrVmaskType(
builder.getI1Type(), in_bitwidth / 2, ctx.target_shape);
*v = builder.create<PackMaskOp>(mask_ty, low_part, high_part);
});
const RollVectorsOp rolled_op =
assemble(builder, vty, out_layout, out_vmsks, ctx.target_shape,
/*use_implicit_shape=*/true);
op.replaceAllUsesWith(rolled_op);
op.erase();
return success();
}
return op.emitOpError("Not implemented: unsupported layout change");
}
// TODO(b/347016737): Deprecate tpu.rotate and only use tpu.dynamic_rotate, so
// we do not need a template for the op type or to explicitly force the amount
// argument to be dynamic.
@@ -4644,9 +4712,9 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
return success();
}
LogicalResult prng_random_bits_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
LogicalResult tpu_prng_random_bits_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
TPU_ASSERT_EQ_OP(layouts_in.size(), 0);
TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
TPU_ASSERT_OP(layouts_out.front().has_value());
@@ -4711,7 +4779,8 @@ const llvm::StringMap<rule_type> &rules() {
{tpu::BitcastOp::getOperationName(), tpu_bitcast_rule},
{tpu::TraceOp::getOperationName(), tpu_trace_rule},
{tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule},
{tpu::PRNGRandomBitsOp::getOperationName(), prng_random_bits_rule},
{tpu::PRNGRandomBitsOp::getOperationName(), tpu_prng_random_bits_rule},
{tpu::RelayoutOp::getOperationName(), tpu_relayout_rule},
{tpu::FPToSIOp::getOperationName(), tpu_fptosi_rule},
{vector::BroadcastOp::getOperationName(), vector_broadcast_rule},
{vector::ExtractOp::getOperationName(), vector_extract_rule},

View File

@@ -31,12 +31,40 @@ namespace {
FailureOr<TypedValue<VectorType>> relayout(
OpBuilder &builder, TypedValue<VectorType> v, VectorLayout src,
VectorLayout dst, const std::array<int64_t, 2> target_shape) {
VectorLayout dst, int hardware_generation,
const std::array<int64_t, 2> target_shape) {
// change bitwidth
if (v.getType().getElementType() == builder.getI1Type() &&
// TODO(jevinjiang): for other relayout changes (tiling, offsets, implicit
// dim), we currently rely on apply-vector-layout pass to do the relayout.
src.bitwidth() != dst.bitwidth()) {
auto vreg_slice = src.vregSlice(target_shape, dst.bitwidth(), src.tiling());
auto dst_bitwidth_layout = VectorLayout(
dst.bitwidth(),
{
src.offsets()[0].has_value() ? *src.offsets()[0] % vreg_slice[0]
: LayoutOffset(),
src.offsets()[1].has_value() ? *src.offsets()[1] % vreg_slice[1]
: LayoutOffset(),
},
src.tiling(), src.implicit_dim());
if (!dst_bitwidth_layout.isValid(target_shape)) {
return emitError(v.getLoc(),
"Not implemented: failed to infer valid layout during "
"relayout, got ")
<< dst_bitwidth_layout;
}
// We might be able to pack the mask directly.
// TODO(jevinjiang): Add support for 16bit -> 8bit mask packing.
if (src.bitwidth() == 32 && dst.bitwidth() == 16 &&
// TODO(jevinjiang): support mask packing for non-native source tiling.
src.tiling()[0] == src.packing() * target_shape[0] &&
src.tiling()[1] == target_shape[1]) {
auto relayout_op =
builder.create<tpu::RelayoutOp>(v.getLoc(), v.getType(), v);
setLayout(relayout_op, src, dst_bitwidth_layout);
return cast<TypedValue<VectorType>>(relayout_op.getResult());
}
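// Otherwise, fall back to relayouting the mask through integer vectors via
// the ext/trunc path below.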
CHECK(llvm::isPowerOf2_32(src.bitwidth()));
CHECK(llvm::isPowerOf2_32(dst.bitwidth()));
auto make_vty = [&](int bitwidth) {
@@ -56,25 +84,9 @@ FailureOr<TypedValue<VectorType>> relayout(
};
auto src_int_vty = make_vty(src.bitwidth());
auto dst_int_vty = make_vty(dst.bitwidth());
auto vreg_slice = src.vregSlice(target_shape, dst.bitwidth(), src.tiling());
// TODO(jevinjiang): Since dst_bitwidth_layout will be firstly used in the
// extSI or truncI below, we can reuse the inferExt and inferTrunc from
// infer-vector-layout pass.
auto dst_bitwidth_layout = VectorLayout(
dst.bitwidth(),
{
src.offsets()[0].has_value() ? *src.offsets()[0] % vreg_slice[0]
: LayoutOffset(),
src.offsets()[1].has_value() ? *src.offsets()[1] % vreg_slice[1]
: LayoutOffset(),
},
src.tiling(), src.implicit_dim());
if (!dst_bitwidth_layout.isValid(target_shape)) {
return emitError(v.getLoc(),
"Not implemented: failed to infer valid layout during "
"relayout, got ")
<< dst_bitwidth_layout;
}
auto ext_op = builder.create<arith::ExtUIOp>(v.getLoc(), src_int_vty, v);
setLayout(ext_op, src, src);
@@ -98,7 +110,7 @@ FailureOr<TypedValue<VectorType>> relayout(
// TODO(jevinjiang): make relayout an op so we don't need to decide when to
// relayout in the apply-vector-layout pass.
LogicalResult insertRelayout(Operation &op,
LogicalResult insertRelayout(Operation &op, int hardware_generation,
const std::array<int64_t, 2> target_shape) {
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> in_layouts,
getInLayouts(op, target_shape));
@@ -136,9 +148,9 @@ LogicalResult insertRelayout(Operation &op,
continue;
}
OpBuilder builder(&op);
FAILUREOR_ASSIGN_OR_RETURN(Value new_v,
relayout(builder, vector_operand, /*src=*/*lo,
/*dst=*/*li, target_shape));
FAILUREOR_ASSIGN_OR_RETURN(
Value new_v, relayout(builder, vector_operand, /*src=*/*lo,
/*dst=*/*li, hardware_generation, target_shape));
op.setOperand(idx, new_v);
}
return success();
@@ -146,14 +158,22 @@ LogicalResult insertRelayout(Operation &op,
struct RelayoutInsertionPass
: public impl::RelayoutInsertionPassBase<RelayoutInsertionPass> {
RelayoutInsertionPass(std::array<int64_t, 2> target_shape) {
RelayoutInsertionPass(int generation, std::array<int64_t, 2> target_shape) {
this->hardware_generation = generation;
this->sublane_count = target_shape[0];
this->lane_count = target_shape[1];
}
void runOnOperation() override {
// Fail if hardware_generation was not set and still has the default value.
if (hardware_generation < 0) {
getOperation().emitError("hardware_generation must be set");
signalPassFailure();
return;
}
func::FuncOp func = getOperation();
auto result = func.walk([&](Operation *op) {
if (insertRelayout(*op, {sublane_count, lane_count}).failed()) {
if (insertRelayout(*op, hardware_generation, {sublane_count, lane_count})
.failed()) {
return WalkResult::interrupt();
}
return WalkResult::advance();
@@ -168,8 +188,9 @@ struct RelayoutInsertionPass
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>> createRelayoutInsertionPass(
std::array<int64_t, 2> target_shape) {
return std::make_unique<RelayoutInsertionPass>(target_shape);
int hardware_generation, std::array<int64_t, 2> target_shape) {
return std::make_unique<RelayoutInsertionPass>(hardware_generation,
target_shape);
}
} // namespace mlir::tpu

View File

@@ -332,22 +332,17 @@ class OpsTest(PallasBaseTest):
dtype=[jnp.float32, jnp.bfloat16],
)
def test_i1_relayout_with_bitwidth_change(self, msk_dtype, dtype):
# TODO(jevinjiang): Remove after 12 weeks have passed.
if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
self.skipTest("Requires libtpu built after 2024-12-19")
if not jtu.if_cloud_tpu_at_least(2025, 1, 25):
self.skipTest("Requires libtpu built after 2025-01-25")
shape = (129, 129)
msk_bitwidth = pallas_utils.dtype_bitwidth(msk_dtype)
bitwidth = pallas_utils.dtype_bitwidth(dtype)
if (
(jtu.get_tpu_version() > 5 and msk_bitwidth < 8)
or (jtu.get_tpu_version() == 5 and msk_bitwidth not in (8, 32))
or (jtu.get_tpu_version() < 5 and msk_bitwidth < 32)
):
if jtu.get_tpu_version() < 5 and msk_bitwidth < 32:
self.skipTest(
"Not implemented: cast vector to mask with bitwidth =="
f" {msk_bitwidth}"
)
if jtu.get_tpu_version() <= 5 and bitwidth < 32:
if jtu.get_tpu_version() < 5 and bitwidth < 32:
self.skipTest(f"Not implemented: comparison with bitwidth == {bitwidth}")
@functools.partial(