[Mosaic] apply_vector_layout C++ rewrite (17): for.op (from cl/568376871)

PiperOrigin-RevId: 571155269
This commit is contained in:
Tomás Longeri 2023-10-05 16:01:37 -07:00 committed by jax authors
parent 8f911e1512
commit f4bb1c0c62

View File

@ -798,6 +798,20 @@ LogicalResult func_return_rule(RewriteContext &ctx, Operation &op,
return success();
}
LogicalResult scf_for_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
scf::ForOp for_op = cast<scf::ForOp>(op);
CHECK_EQ(layouts_in.size(), 3 + for_op.getInitArgs().size());
CHECK_EQ(layouts_out.size(), for_op.getResults().size());
if (!for_op.getInitArgs().empty() || !for_op.getResults().empty()) {
return for_op.emitOpError("Not implemented: inputs and outputs in scf.for");
}
// It is an invariant that scf::ForOp should have a single region with a
// single block (checked by MLIR verifier).
return applyLayoutBlock(ctx, for_op.getRegion().getBlocks().front());
}
LogicalResult scf_if_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
@ -2554,6 +2568,7 @@ const llvm::StringMap<rule_type> &rules() {
rules_elementwise_op_entry<math::RsqrtOp, 1>(),
rules_elementwise_op_entry<math::TanhOp, 1>(),
{func::ReturnOp::getOperationName(), func_return_rule},
{scf::ForOp::getOperationName(), scf_for_rule},
{scf::IfOp::getOperationName(), scf_if_rule},
{scf::YieldOp::getOperationName(), scf_yield_rule},
{tpu::IotaOp::getOperationName(), tpu_iota_rule},