mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[Mosaic TPU] Use vmask pack when possible for mask bitwidth changes, and introduce a relayout op.
PiperOrigin-RevId: 719089676
parent 3864512b72
commit 8e1f956804
@@ -407,6 +407,7 @@ def _lower_tpu_kernel(
         (
             "func.func(tpu-relayout-insertion{"
             f" sublane-count={sl_cnt} lane-count={l_cnt}"
+            f" hardware-generation={hardware_generation}"
             "})"
         ),
     ]
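Note: the pipeline entry above is the textual form of the pass configured later in this commit. A minimal sketch of driving the same pipeline string from C++, assuming an 8x128 vreg shape and hardware generation 5 (both values are illustrative, not from this commit):

// Sketch: parse and run the textual pipeline shown above. Assumes a
// ModuleOp `module`; the option values are illustrative.
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

mlir::LogicalResult runRelayoutInsertion(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  if (mlir::failed(mlir::parsePassPipeline(
          "func.func(tpu-relayout-insertion{sublane-count=8 lane-count=128"
          " hardware-generation=5})",
          pm)))
    return mlir::failure();
  return pm.run(module);
}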
@@ -389,6 +389,21 @@ def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure, SameTypeOperands]>
   let hasVerifier = 1;
 }
 
+def TPU_RelayoutOp : TPU_Op<"relayout", [SameOperandsAndResultType]> {
+  let arguments = (ins AnyType:$input);
+  let results = (outs AnyType:$output);
+  let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
+}
+
+def TPU_PackMaskOp : TPU_Op<"pack_vmsk", [Pure, SameTypeOperands]> {
+  let arguments = (ins
+    VectorOfNonZeroRankOf<[I1]>:$low,
+    VectorOfNonZeroRankOf<[I1]>:$high
+  );
+  let results = (outs VectorOfNonZeroRankOf<[I1]>:$output);
+  let assemblyFormat = [{ $low `,` $high `,` attr-dict `:` type($low) `,` type($high) `->` type($output) }];
+}
+
 def TPU_GatherOp : TPU_Op<"gather", [Pure]> {
   let arguments = (ins
     AnyVectorOfNonZeroRank:$source,
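Note: tpu.relayout is an identity-typed placeholder; the layout change it requests is carried in the in/out layout attributes set by the relayout-insertion pass, and apply-vector-layout later lowers it. tpu.pack_vmsk fuses two i1 vregs of 32-bit layout into one of 16-bit layout. A minimal sketch of building both ops with the generated C++ builders; `op`, `v`, `low`, `high`, and `mask16_ty` are illustrative stand-ins:

// Sketch: create the two new ops. Only the op names and operand/result
// shapes come from the definitions above; all variables are assumptions.
mlir::ImplicitLocOpBuilder b(op.getLoc(), &op);
mlir::Value relayed = b.create<mlir::tpu::RelayoutOp>(v.getType(), v);
mlir::Value packed = b.create<mlir::tpu::PackMaskOp>(mask16_ty, low, high);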
@@ -891,6 +906,9 @@ def RelayoutInsertionPass : Pass<"tpu-relayout-insertion", "::mlir::func::FuncOp
   ];
   let constructor = "::mlir::tpu::createRelayoutInsertionPass()";
   let options = [
+    // If hardware_generation is not set, the default value of -1 will make
+    // runOnOperation fail.
+    Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
     Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
     Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
   ];
@@ -84,6 +84,7 @@ std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
     const TpuTilingFlags &tpu_tiling_flags = {});
 
 std::unique_ptr<OperationPass<func::FuncOp>> createRelayoutInsertionPass(
+    int hardware_generation = -1,
    std::array<int64_t, 2> target_shape = {8, 128});
 
 std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
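Note: a sketch of constructing the pass programmatically with the new argument, mirroring the declaration above (the pipeline setup and option values are illustrative):

// Sketch: add the pass to a pipeline with an explicit hardware generation.
mlir::PassManager pm(&context);
pm.addNestedPass<mlir::func::FuncOp>(
    mlir::tpu::createRelayoutInsertionPass(/*hardware_generation=*/5,
                                           /*target_shape=*/{8, 128}));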
@@ -2172,6 +2172,74 @@ LogicalResult tpu_assume_layout_rule(RewriteContext &ctx, Operation &op,
   return success();
 }
 
+LogicalResult tpu_relayout_rule(RewriteContext &ctx, Operation &op,
+                                const ArrayRef<Layout> layouts_in,
+                                const ArrayRef<Layout> layouts_out) {
+  TPU_ASSERT_EQ_OP(op.getNumOperands(), 1);
+  TPU_ASSERT_EQ_OP(op.getNumResults(), 1);
+  TPU_ASSERT_EQ_OP(layouts_in.size(), 1);
+  TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
+  TPU_ASSERT_OP(layouts_in[0].has_value());
+  TPU_ASSERT_OP(layouts_out[0].has_value());
+  const auto &in_layout = *layouts_in[0];
+  const auto &out_layout = *layouts_out[0];
+  auto relayout_op = cast<tpu::RelayoutOp>(op);
+  auto in_bitwidth = in_layout.bitwidth();
+  auto out_bitwidth = out_layout.bitwidth();
+  auto vty = cast<VectorType>(relayout_op.getType());
+  ImplicitLocOpBuilder builder(op.getLoc(), &op);
+  if (in_layout == out_layout) {
+    relayout_op.replaceAllUsesWith(relayout_op.getInput());
+    relayout_op.erase();
+    return success();
+  }
+  FAILUREOR_ASSIGN_OR_RETURN(
+      xla::Array<Value> vals,
+      disassemble(builder, in_layout,
+                  cast<TypedValue<VectorType>>(relayout_op.getInput()),
+                  ctx.target_shape,
+                  /*use_implicit_shape=*/true));
+  // Packing vector masks from 32-bit to 16-bit.
+  if (vty.getElementType() == builder.getI1Type() && in_bitwidth == 32 &&
+      out_bitwidth == 16 &&
+      in_layout.tiling()[0] == in_layout.packing() * ctx.target_shape[0] &&
+      in_layout.tiling()[1] == ctx.target_shape[1] &&
+      in_layout.tiling() == out_layout.tiling() &&
+      in_layout.offsets() == out_layout.offsets() &&
+      in_layout.implicit_dim() == out_layout.implicit_dim()) {
+    std::vector<int64_t> vmsks_shape(vals.dimensions().begin(),
+                                     vals.dimensions().end());
+    *(vmsks_shape.end() - 1) = llvm::divideCeil(vmsks_shape.back(), 2);
+    xla::Array<Value> out_vmsks(vmsks_shape, nullptr);
+    SmallVector<int64_t> val_idx;
+    Value default_val =
+        getFullLikeVector(builder, cast<TypedValue<VectorType>>(*vals.begin()),
+                          IntegerAttr::get(builder.getI1Type(), 0));
+    out_vmsks.Each([&](absl::Span<const int64_t> idx, Value *v) {
+      val_idx.assign(idx.begin(), idx.end());
+      // TODO(jevinjiang): can be simplified when offset is replicated.
+      *(val_idx.end() - 1) *= 2;
+      Value low_part = *(val_idx.end() - 1) < *(vals.dimensions().end() - 1)
+                           ? vals(val_idx)
+                           : default_val;
+      *(val_idx.end() - 1) += 1;
+      Value high_part = *(val_idx.end() - 1) < *(vals.dimensions().end() - 1)
+                            ? vals(val_idx)
+                            : default_val;
+      const VectorType mask_ty = getNativeVregOrVmaskType(
+          builder.getI1Type(), in_bitwidth / 2, ctx.target_shape);
+      *v = builder.create<PackMaskOp>(mask_ty, low_part, high_part);
+    });
+    const RollVectorsOp rolled_op =
+        assemble(builder, vty, out_layout, out_vmsks, ctx.target_shape,
+                 /*use_implicit_shape=*/true);
+    op.replaceAllUsesWith(rolled_op);
+    op.erase();
+    return success();
+  }
+  return op.emitOpError("Not implemented: unsupported layout change");
+}
+
 // TODO(b/347016737): Deprecate tpu.rotate and only use tpu.dynamic_rotate, so
 // we do not need a template for the op type or to explicitly force the amount
 // argument to dynamic.
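Note: the packing loop above walks the output vreg array: along the minor dimension, output index j takes input vregs 2*j and 2*j+1 as the low and high halves of one pack_vmsk, padding an odd tail with the all-false `default_val`. A standalone illustration (not code from this commit):

// Illustration of the pairing arithmetic used in the rule above.
#include <cstdio>

int main() {
  const int in_vregs = 5;                    // input vmask vregs on the minor dim
  const int out_vregs = (in_vregs + 1) / 2;  // llvm::divideCeil(in_vregs, 2)
  for (int j = 0; j < out_vregs; ++j) {
    const int low = 2 * j, high = 2 * j + 1;
    if (high < in_vregs)
      std::printf("out[%d] = pack_vmsk(in[%d], in[%d])\n", j, low, high);
    else
      std::printf("out[%d] = pack_vmsk(in[%d], <all-false pad>)\n", j, low);
  }
  return 0;
}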
@@ -4644,9 +4712,9 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
   return success();
 }
 
-LogicalResult prng_random_bits_rule(RewriteContext &ctx, Operation &op,
-                                    const ArrayRef<Layout> layouts_in,
-                                    const ArrayRef<Layout> layouts_out) {
+LogicalResult tpu_prng_random_bits_rule(RewriteContext &ctx, Operation &op,
+                                        const ArrayRef<Layout> layouts_in,
+                                        const ArrayRef<Layout> layouts_out) {
   TPU_ASSERT_EQ_OP(layouts_in.size(), 0);
   TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
   TPU_ASSERT_OP(layouts_out.front().has_value());
@@ -4711,7 +4779,8 @@ const llvm::StringMap<rule_type> &rules() {
       {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule},
       {tpu::TraceOp::getOperationName(), tpu_trace_rule},
       {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule},
-      {tpu::PRNGRandomBitsOp::getOperationName(), prng_random_bits_rule},
+      {tpu::PRNGRandomBitsOp::getOperationName(), tpu_prng_random_bits_rule},
+      {tpu::RelayoutOp::getOperationName(), tpu_relayout_rule},
      {tpu::FPToSIOp::getOperationName(), tpu_fptosi_rule},
      {vector::BroadcastOp::getOperationName(), vector_broadcast_rule},
      {vector::ExtractOp::getOperationName(), vector_extract_rule},
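Note: apply-vector-layout looks rules up by operation name, so the two new entries are all that is needed to route tpu.relayout (and the renamed PRNG rule) through the pass. A sketch of the dispatch, assuming the surrounding rewrite context:

// Sketch: dispatch an op to its layout rule via the map above.
const llvm::StringMap<rule_type> &rule_map = rules();
if (auto it = rule_map.find(op.getName().getStringRef());
    it != rule_map.end()) {
  if (failed(it->second(ctx, op, layouts_in, layouts_out)))
    return failure();
}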
@@ -31,12 +31,40 @@ namespace {
 
 FailureOr<TypedValue<VectorType>> relayout(
     OpBuilder &builder, TypedValue<VectorType> v, VectorLayout src,
-    VectorLayout dst, const std::array<int64_t, 2> target_shape) {
+    VectorLayout dst, int hardware_generation,
+    const std::array<int64_t, 2> target_shape) {
   // change bitwidth
   if (v.getType().getElementType() == builder.getI1Type() &&
       // TODO(jevinjiang): for other relayout changes (tiling, offsets, implicit
       // dim), we currently rely on the apply-vector-layout pass to do the relayout.
       src.bitwidth() != dst.bitwidth()) {
+    auto vreg_slice = src.vregSlice(target_shape, dst.bitwidth(), src.tiling());
+    auto dst_bitwidth_layout = VectorLayout(
+        dst.bitwidth(),
+        {
+            src.offsets()[0].has_value() ? *src.offsets()[0] % vreg_slice[0]
+                                         : LayoutOffset(),
+            src.offsets()[1].has_value() ? *src.offsets()[1] % vreg_slice[1]
+                                         : LayoutOffset(),
+        },
+        src.tiling(), src.implicit_dim());
+    if (!dst_bitwidth_layout.isValid(target_shape)) {
+      return emitError(v.getLoc(),
+                       "Not implemented: failed to infer valid layout during "
+                       "relayout, got ")
+             << dst_bitwidth_layout;
+    }
+    // We might be able to pack the mask directly.
+    // TODO(jevinjiang): Add support for 16bit -> 8bit mask packing.
+    if (src.bitwidth() == 32 && dst.bitwidth() == 16 &&
+        // TODO(jevinjiang): support mask packing for non-native source tiling.
+        src.tiling()[0] == src.packing() * target_shape[0] &&
+        src.tiling()[1] == target_shape[1]) {
+      auto relayout_op =
+          builder.create<tpu::RelayoutOp>(v.getLoc(), v.getType(), v);
+      setLayout(relayout_op, src, dst_bitwidth_layout);
+      return cast<TypedValue<VectorType>>(relayout_op.getResult());
+    }
     CHECK(llvm::isPowerOf2_32(src.bitwidth()));
     CHECK(llvm::isPowerOf2_32(dst.bitwidth()));
     auto make_vty = [&](int bitwidth) {
@@ -56,25 +84,9 @@ FailureOr<TypedValue<VectorType>> relayout(
     };
     auto src_int_vty = make_vty(src.bitwidth());
     auto dst_int_vty = make_vty(dst.bitwidth());
-    auto vreg_slice = src.vregSlice(target_shape, dst.bitwidth(), src.tiling());
+    // TODO(jevinjiang): Since dst_bitwidth_layout will be first used in the
+    // extUI or truncI below, we can reuse the inferExt and inferTrunc from the
+    // infer-vector-layout pass.
-    auto dst_bitwidth_layout = VectorLayout(
-        dst.bitwidth(),
-        {
-            src.offsets()[0].has_value() ? *src.offsets()[0] % vreg_slice[0]
-                                         : LayoutOffset(),
-            src.offsets()[1].has_value() ? *src.offsets()[1] % vreg_slice[1]
-                                         : LayoutOffset(),
-        },
-        src.tiling(), src.implicit_dim());
-    if (!dst_bitwidth_layout.isValid(target_shape)) {
-      return emitError(v.getLoc(),
-                       "Not implemented: failed to infer valid layout during "
-                       "relayout, got ")
-             << dst_bitwidth_layout;
-    }
     auto ext_op = builder.create<arith::ExtUIOp>(v.getLoc(), src_int_vty, v);
     setLayout(ext_op, src, src);
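Note: when the direct vmask pack above does not apply, relayout() falls back to materializing the mask as integers. Only the ExtUIOp is visible in this hunk; the narrowing step is an assumption about the elided remainder of the function:

// Hedged sketch of the integer fallback. The TruncIOp step is an assumption;
// only the two lines creating and laying out ext_op appear in this hunk.
auto ext_op = builder.create<arith::ExtUIOp>(v.getLoc(), src_int_vty, v);
setLayout(ext_op, src, src);
auto trunc_op =
    builder.create<arith::TruncIOp>(v.getLoc(), dst_int_vty, ext_op);
setLayout(trunc_op, src, dst_bitwidth_layout);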
@@ -98,7 +110,7 @@ FailureOr<TypedValue<VectorType>> relayout(
 
 // TODO(jevinjiang): make relayout an op so we don't need to decide when to
 // relayout in the apply-vector-layout pass.
-LogicalResult insertRelayout(Operation &op,
+LogicalResult insertRelayout(Operation &op, int hardware_generation,
                              const std::array<int64_t, 2> target_shape) {
   FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> in_layouts,
                              getInLayouts(op, target_shape));
@@ -136,9 +148,9 @@ LogicalResult insertRelayout(Operation &op,
       continue;
     }
     OpBuilder builder(&op);
-    FAILUREOR_ASSIGN_OR_RETURN(Value new_v,
-                               relayout(builder, vector_operand, /*src=*/*lo,
-                                        /*dst=*/*li, target_shape));
+    FAILUREOR_ASSIGN_OR_RETURN(
+        Value new_v, relayout(builder, vector_operand, /*src=*/*lo,
+                              /*dst=*/*li, hardware_generation, target_shape));
     op.setOperand(idx, new_v);
   }
   return success();
@@ -146,14 +158,22 @@ LogicalResult insertRelayout(Operation &op,
 
 struct RelayoutInsertionPass
     : public impl::RelayoutInsertionPassBase<RelayoutInsertionPass> {
-  RelayoutInsertionPass(std::array<int64_t, 2> target_shape) {
+  RelayoutInsertionPass(int generation, std::array<int64_t, 2> target_shape) {
+    this->hardware_generation = generation;
     this->sublane_count = target_shape[0];
     this->lane_count = target_shape[1];
   }
   void runOnOperation() override {
+    // Fail if hardware_generation was left at its invalid default of -1.
+    if (hardware_generation < 0) {
+      getOperation().emitError("hardware_generation must be set");
+      signalPassFailure();
+      return;
+    }
     func::FuncOp func = getOperation();
     auto result = func.walk([&](Operation *op) {
-      if (insertRelayout(*op, {sublane_count, lane_count}).failed()) {
+      if (insertRelayout(*op, hardware_generation, {sublane_count, lane_count})
+              .failed()) {
         return WalkResult::interrupt();
       }
       return WalkResult::advance();
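Note: with the option's tablegen default of -1, forgetting to thread a real generation through now fails fast instead of silently assuming a hardware model:

// Sketch: constructing the pass with defaults leaves hardware_generation at
// -1, so running it emits "hardware_generation must be set" and fails.
auto pass = mlir::tpu::createRelayoutInsertionPass();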
@@ -168,8 +188,9 @@ struct RelayoutInsertionPass
 } // namespace
 
 std::unique_ptr<OperationPass<func::FuncOp>> createRelayoutInsertionPass(
-    std::array<int64_t, 2> target_shape) {
-  return std::make_unique<RelayoutInsertionPass>(target_shape);
+    int hardware_generation, std::array<int64_t, 2> target_shape) {
+  return std::make_unique<RelayoutInsertionPass>(hardware_generation,
+                                                 target_shape);
 }
 
 } // namespace mlir::tpu
@@ -332,22 +332,17 @@ class OpsTest(PallasBaseTest):
       dtype=[jnp.float32, jnp.bfloat16],
   )
   def test_i1_relayout_with_bitwidth_change(self, msk_dtype, dtype):
     # TODO(jevinjiang): Remove after 12 weeks have passed.
-    if not jtu.if_cloud_tpu_at_least(2024, 12, 19):
-      self.skipTest("Requires libtpu built after 2024-12-19")
+    if not jtu.if_cloud_tpu_at_least(2025, 1, 25):
+      self.skipTest("Requires libtpu built after 2025-01-25")
     shape = (129, 129)
     msk_bitwidth = pallas_utils.dtype_bitwidth(msk_dtype)
     bitwidth = pallas_utils.dtype_bitwidth(dtype)
-    if (
-        (jtu.get_tpu_version() > 5 and msk_bitwidth < 8)
-        or (jtu.get_tpu_version() == 5 and msk_bitwidth not in (8, 32))
-        or (jtu.get_tpu_version() < 5 and msk_bitwidth < 32)
-    ):
+    if jtu.get_tpu_version() < 5 and msk_bitwidth < 32:
       self.skipTest(
           "Not implemented: cast vector to mask with bitwidth =="
           f" {msk_bitwidth}"
       )
-    if jtu.get_tpu_version() <= 5 and bitwidth < 32:
+    if jtu.get_tpu_version() < 5 and bitwidth < 32:
       self.skipTest(f"Not implemented: comparison with bitwidth == {bitwidth}")
 
   @functools.partial(