mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic] Support scf.while and scf.condition.
This allows lowering while loops of a more general form than "for i" loops. Improving generality here allows us to implement more interesting dynamic looping behaviors, such as progressive scans in VMEM. PiperOrigin-RevId: 625411151
This commit is contained in:
parent
1a650cdc00
commit
5bd6013e76
@ -1000,6 +1000,185 @@ LogicalResult scf_for_rule(RewriteContext &ctx, Operation &op,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult scf_while_rule(RewriteContext &ctx, Operation &op,
|
||||
const ArrayRef<Layout> layouts_in,
|
||||
const ArrayRef<Layout> layouts_out) {
|
||||
scf::WhileOp while_op = cast<scf::WhileOp>(op);
|
||||
TPU_ASSERT_EQ_OP(layouts_in.size(), while_op->getNumOperands());
|
||||
TPU_ASSERT_EQ_OP(layouts_out.size(), while_op->getNumResults());
|
||||
TPU_ASSERT_EQ_OP(layouts_in.size(), layouts_out.size());
|
||||
|
||||
// The terminator for the before region is the condition op.
|
||||
// It takes multiple arguments -- the first being the decision to execute the
|
||||
// after region or branch to the exit.
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
const SmallVector<Layout> condition_in_layouts,
|
||||
getInLayouts(*while_op.getBeforeBody()->getTerminator(),
|
||||
ctx.target_shape));
|
||||
if (!llvm::equal(ArrayRef<Layout>(condition_in_layouts).drop_front(1),
|
||||
layouts_out)) {
|
||||
return op.emitOpError(
|
||||
"Mismatched layouts between scf.while result and its before region "
|
||||
"condition.");
|
||||
}
|
||||
|
||||
if (failed(applyLayoutBlock(ctx, *while_op.getBeforeBody()))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
const SmallVector<Layout> after_yield_in_layouts,
|
||||
getInLayouts(*while_op.getYieldOp(), ctx.target_shape));
|
||||
if (!layouts_out.empty() &&
|
||||
ArrayRef<Layout>(after_yield_in_layouts) != layouts_out) {
|
||||
return op.emitOpError(
|
||||
"Not implemented: different layouts while's yield's operands and "
|
||||
"results");
|
||||
}
|
||||
|
||||
if (failed(applyLayoutBlock(ctx, *while_op.getAfterBody()))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (op.getNumResults() == 0) {
|
||||
return success();
|
||||
}
|
||||
|
||||
OpBuilder builder(&op);
|
||||
SmallVector<Value> unrolled_args;
|
||||
for (int i = 0; i < layouts_in.size(); ++i) {
|
||||
auto layout = layouts_in[i];
|
||||
auto operand = while_op.getOperand(i);
|
||||
if (auto vector_operand = dyn_cast<TypedValue<VectorType>>(operand)) {
|
||||
if (!layout.has_value()) {
|
||||
return op.emitOpError("Expected layout for vector operand");
|
||||
}
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
const xla::Array<Value> tiles,
|
||||
disassemble(builder, *layout, vector_operand, ctx.target_shape));
|
||||
unrolled_args.append(tiles.begin(), tiles.end());
|
||||
} else {
|
||||
if (layout.has_value()) {
|
||||
return op.emitOpError("Expected no layout for scalar operand");
|
||||
}
|
||||
unrolled_args.push_back(operand);
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new scf::WhileOp with unrolled args.
|
||||
auto new_op = builder.create<scf::WhileOp>(
|
||||
while_op->getLoc(),
|
||||
TypeRange(while_op.getConditionOp().getOperands().drop_front(1)),
|
||||
unrolled_args, nullptr, nullptr);
|
||||
|
||||
const auto tile_body_args = [&](::mlir::Block *old_body,
|
||||
::mlir::Block *new_body,
|
||||
const ArrayRef<Layout> layouts) {
|
||||
TPU_ASSERT_OP(old_body != nullptr);
|
||||
TPU_ASSERT_OP(new_body != nullptr);
|
||||
int num_old_args = old_body->getNumArguments();
|
||||
SmallVector<Location> locs(new_body->getNumArguments(), while_op.getLoc());
|
||||
old_body->addArguments(TypeRange(new_body->getArguments()), locs);
|
||||
builder.setInsertionPointToStart(old_body);
|
||||
auto arg_idx = num_old_args;
|
||||
for (auto [old_arg, layout] : llvm::zip_equal(
|
||||
old_body->getArguments().take_front(num_old_args), layouts)) {
|
||||
if (const auto vty = dyn_cast<VectorType>(old_arg.getType())) {
|
||||
TPU_ASSERT_OP(layout.has_value());
|
||||
const SmallVector<int64_t> tiles_shape =
|
||||
layout->tileArrayShape(vty.getShape(), ctx.target_shape);
|
||||
const int64_t num_vectors = ShapedType::getNumElements(tiles_shape);
|
||||
xla::Array<Value> tiles(tiles_shape);
|
||||
TPU_ASSERT_LE_OP(arg_idx + num_vectors, old_body->getNumArguments());
|
||||
tiles.SetValues(llvm::make_range(
|
||||
old_body->getArguments().begin() + arg_idx,
|
||||
old_body->getArguments().begin() + arg_idx + num_vectors));
|
||||
arg_idx += num_vectors;
|
||||
RollVectorsOp rolled_op =
|
||||
assemble(builder, vty, *layout, tiles, ctx.target_shape);
|
||||
old_arg.replaceUsesWithIf(rolled_op, [&](OpOperand &operand) {
|
||||
return operand.getOwner() != rolled_op;
|
||||
});
|
||||
} else {
|
||||
TPU_ASSERT_OP(!layout.has_value());
|
||||
old_arg.replaceAllUsesWith(old_body->getArgument(arg_idx));
|
||||
++arg_idx;
|
||||
}
|
||||
}
|
||||
old_body->eraseArguments(0, num_old_args);
|
||||
return success();
|
||||
};
|
||||
|
||||
const auto before_status = tile_body_args(while_op.getBeforeBody(),
|
||||
new_op.getBeforeBody(), layouts_in);
|
||||
if (before_status.failed()) return before_status;
|
||||
new_op.getBefore().takeBody(while_op.getBefore());
|
||||
|
||||
const auto after_status = tile_body_args(while_op.getAfterBody(),
|
||||
new_op.getAfterBody(), layouts_out);
|
||||
if (after_status.failed()) return after_status;
|
||||
new_op.getAfter().takeBody(while_op.getAfter());
|
||||
|
||||
builder.setInsertionPointAfter(new_op);
|
||||
int64_t res_idx = 0;
|
||||
SmallVector<Value> rolled_results;
|
||||
for (auto [result, layout] :
|
||||
llvm::zip_equal(while_op.getResults(), layouts_out)) {
|
||||
if (const auto vty = dyn_cast<VectorType>(result.getType())) {
|
||||
TPU_ASSERT_OP(layout.has_value());
|
||||
const SmallVector<int64_t> tiles_shape =
|
||||
layout->tileArrayShape(vty.getShape(), ctx.target_shape);
|
||||
const int64_t num_vectors = ShapedType::getNumElements(tiles_shape);
|
||||
xla::Array<Value> tiles(tiles_shape);
|
||||
TPU_ASSERT_LE_OP(res_idx + num_vectors, new_op.getResults().size());
|
||||
tiles.SetValues(llvm::make_range(
|
||||
new_op.getResults().begin() + res_idx,
|
||||
new_op.getResults().begin() + res_idx + num_vectors));
|
||||
res_idx += num_vectors;
|
||||
RollVectorsOp rolled_op =
|
||||
assemble(builder, vty, *layout, tiles, ctx.target_shape);
|
||||
rolled_results.push_back(rolled_op);
|
||||
} else {
|
||||
TPU_ASSERT_OP(!layout.has_value());
|
||||
rolled_results.push_back(new_op.getResult(res_idx));
|
||||
++res_idx;
|
||||
}
|
||||
}
|
||||
|
||||
while_op.replaceAllUsesWith(rolled_results);
|
||||
while_op.erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult scf_condition_rule(RewriteContext &ctx, Operation &op,
|
||||
const ArrayRef<Layout> layouts_in,
|
||||
const ArrayRef<Layout> layouts_out) {
|
||||
OpBuilder builder(&op);
|
||||
auto condition_op = cast<scf::ConditionOp>(op);
|
||||
TPU_ASSERT_EQ_OP(layouts_in.size(), condition_op.getNumOperands());
|
||||
TPU_ASSERT_EQ_OP(layouts_out.size(), 0);
|
||||
SmallVector<Value> unrolled;
|
||||
|
||||
for (auto [operand, layout] :
|
||||
llvm::zip_equal(condition_op.getOperands(), layouts_in)) {
|
||||
if (auto vector_operand = dyn_cast<TypedValue<VectorType>>(operand)) {
|
||||
// When the operand has vector type, disassemble the operand.
|
||||
TPU_ASSERT_OP(layout.has_value());
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
const xla::Array<Value> tiles,
|
||||
disassemble(builder, *layout, vector_operand, ctx.target_shape));
|
||||
unrolled.append(tiles.begin(), tiles.end());
|
||||
} else {
|
||||
TPU_ASSERT_OP(!layout.has_value());
|
||||
unrolled.push_back(operand);
|
||||
}
|
||||
}
|
||||
|
||||
// Replace the old operands with unrolled operands.
|
||||
condition_op->setOperands(unrolled);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult scf_if_rule(RewriteContext &ctx, Operation &op,
|
||||
const ArrayRef<Layout> layouts_in,
|
||||
const ArrayRef<Layout> layouts_out) {
|
||||
@ -3634,6 +3813,8 @@ const llvm::StringMap<rule_type> &rules() {
|
||||
{arith::TruncIOp::getOperationName(), arith_trunci_rule},
|
||||
{func::ReturnOp::getOperationName(), func_return_rule},
|
||||
{scf::ForOp::getOperationName(), scf_for_rule},
|
||||
{scf::WhileOp::getOperationName(), scf_while_rule},
|
||||
{scf::ConditionOp::getOperationName(), scf_condition_rule},
|
||||
{scf::IfOp::getOperationName(), scf_if_rule},
|
||||
{scf::YieldOp::getOperationName(), scf_yield_rule},
|
||||
{tpu::RotateOp::getOperationName(), tpu_rotate_rule},
|
||||
|
@ -220,6 +220,14 @@ class VectorLayoutInferer {
|
||||
if (infer(op).failed()) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto op = dyn_cast<scf::WhileOp>(any_op)) {
|
||||
if (infer(op).failed()) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto op = dyn_cast<scf::ConditionOp>(any_op)) {
|
||||
if (infer(op).failed()) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto op = dyn_cast<tpu::RotateOp>(any_op)) {
|
||||
if (infer(op).failed()) {
|
||||
return failure();
|
||||
@ -536,6 +544,118 @@ class VectorLayoutInferer {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult infer(scf::WhileOp op) {
|
||||
static LogicalResult (*match_condition)(Operation *) = [](Operation *op) {
|
||||
TPU_CHECK_OP(isa<scf::ConditionOp>(op), "expected condition terminator");
|
||||
return success();
|
||||
};
|
||||
static LogicalResult (*match_yield)(Operation *) = [](Operation *op) {
|
||||
TPU_CHECK_OP(isa<scf::YieldOp>(op), "expected yield terminator");
|
||||
return success();
|
||||
};
|
||||
TPU_CHECK_OP(op.getNumRegions() == 2, "expected two blocks for scf.while");
|
||||
|
||||
const auto layout_for_type = [&op, this](const ::mlir::Value &arg,
|
||||
SmallVector<Layout> *layouts) {
|
||||
if (arg.getType().isSignlessIntOrIndexOrFloat()) {
|
||||
layouts->push_back(kNoLayout);
|
||||
} else if (isa<VectorType>(arg.getType())) {
|
||||
auto layout = getLayout(arg);
|
||||
layouts->push_back(layout);
|
||||
} else {
|
||||
op.emitOpError() << "unsupported arg type " << arg.getType()
|
||||
<< " in scf.while";
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
};
|
||||
|
||||
SmallVector<Layout> in_layouts;
|
||||
in_layouts.reserve(op->getNumOperands());
|
||||
for (const auto &arg : op.getInits()) {
|
||||
const auto status = layout_for_type(arg, &in_layouts);
|
||||
if (status.failed()) return status;
|
||||
}
|
||||
|
||||
// Formally, the types and layouts of the results should follow the layout
|
||||
// of the condition op in the Before region, rather than mimicking the input
|
||||
// layouts. In practice these are constrained to be the same for our current
|
||||
// pipelines, but doesn't represent the full expressiveness of scf.while.
|
||||
// TODO(hmckenzie): Base output layout on ConditionOp, not inputs.
|
||||
SmallVector<Layout> out_layouts = in_layouts;
|
||||
|
||||
// Use tpu.assume_layout to annotate every block argument with the layout of
|
||||
// the corresponding operand in WhileOp and replace all uses of the block
|
||||
// argument with the result of tpu.assume_layout.
|
||||
ImplicitLocOpBuilder builder =
|
||||
ImplicitLocOpBuilder::atBlockBegin(op.getLoc(), op.getBeforeBody());
|
||||
for (auto [iter_arg, layout] :
|
||||
llvm::zip_equal(op.getBeforeBody()->getArguments(), in_layouts)) {
|
||||
if (!dyn_cast<VectorType>(iter_arg.getType())) {
|
||||
continue;
|
||||
}
|
||||
auto assume_layout_op =
|
||||
builder.create<AssumeLayoutOp>(iter_arg.getType(), iter_arg);
|
||||
setLayout(assume_layout_op, layout, layout);
|
||||
iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) {
|
||||
return operand.getOwner() != assume_layout_op;
|
||||
});
|
||||
}
|
||||
if (inferBlock(*op.getBeforeBody(), match_condition).failed()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
builder =
|
||||
ImplicitLocOpBuilder::atBlockBegin(op.getLoc(), op.getAfterBody());
|
||||
for (auto [iter_arg, layout] :
|
||||
llvm::zip_equal(op.getAfterBody()->getArguments(), out_layouts)) {
|
||||
if (!dyn_cast<VectorType>(iter_arg.getType())) {
|
||||
continue;
|
||||
}
|
||||
auto assume_layout_op =
|
||||
builder.create<AssumeLayoutOp>(iter_arg.getType(), iter_arg);
|
||||
setLayout(assume_layout_op, layout, layout);
|
||||
iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) {
|
||||
return operand.getOwner() != assume_layout_op;
|
||||
});
|
||||
}
|
||||
|
||||
if (inferBlock(*op.getAfterBody(), match_yield).failed()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto *condition_op = op.getBeforeBody()->getTerminator();
|
||||
SmallVector<Layout> cond_layout;
|
||||
cond_layout.reserve(out_layouts.size() + 1);
|
||||
cond_layout.push_back(kNoLayout);
|
||||
cond_layout.append(out_layouts);
|
||||
setInLayout(condition_op, cond_layout);
|
||||
|
||||
auto *yield_op = op.getAfterBody()->getTerminator();
|
||||
setInLayout(yield_op, in_layouts);
|
||||
|
||||
setLayout(op, in_layouts, out_layouts);
|
||||
return success();
|
||||
}
|
||||
LogicalResult infer(scf::ConditionOp op) {
|
||||
SmallVector<Layout> in_layouts;
|
||||
in_layouts.reserve(op->getNumOperands());
|
||||
for (const auto &arg : op.getOperands()) {
|
||||
if (arg.getType().isSignlessIntOrIndexOrFloat()) {
|
||||
in_layouts.push_back(kNoLayout);
|
||||
} else if (isa<VectorType>(arg.getType())) {
|
||||
auto layout = getLayout(arg);
|
||||
in_layouts.push_back(layout);
|
||||
} else {
|
||||
op.emitOpError() << "unsupported arg type " << arg.getType()
|
||||
<< " in scf::condition";
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
setLayout(op, in_layouts, ArrayRef<Layout>(in_layouts).drop_front(1));
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult infer(tpu::RotateOp op) {
|
||||
auto bitwidth = op.getType().getElementTypeBitWidth();
|
||||
if (bitwidth != 32) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user