mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic] apply_vector_layout C++ rewrite (17): for.op (from cl/568376871)
PiperOrigin-RevId: 571155269
This commit is contained in:
parent
8f911e1512
commit
f4bb1c0c62
@ -798,6 +798,20 @@ LogicalResult func_return_rule(RewriteContext &ctx, Operation &op,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult scf_for_rule(RewriteContext &ctx, Operation &op,
|
||||
const ArrayRef<Layout> layouts_in,
|
||||
const ArrayRef<Layout> layouts_out) {
|
||||
scf::ForOp for_op = cast<scf::ForOp>(op);
|
||||
CHECK_EQ(layouts_in.size(), 3 + for_op.getInitArgs().size());
|
||||
CHECK_EQ(layouts_out.size(), for_op.getResults().size());
|
||||
if (!for_op.getInitArgs().empty() || !for_op.getResults().empty()) {
|
||||
return for_op.emitOpError("Not implemented: inputs and outputs in scf.for");
|
||||
}
|
||||
// It is an invariant that scf::ForOp should have a single region with a
|
||||
// single block (checked by MLIR verifier).
|
||||
return applyLayoutBlock(ctx, for_op.getRegion().getBlocks().front());
|
||||
}
|
||||
|
||||
LogicalResult scf_if_rule(RewriteContext &ctx, Operation &op,
|
||||
const ArrayRef<Layout> layouts_in,
|
||||
const ArrayRef<Layout> layouts_out) {
|
||||
@ -2554,6 +2568,7 @@ const llvm::StringMap<rule_type> &rules() {
|
||||
rules_elementwise_op_entry<math::RsqrtOp, 1>(),
|
||||
rules_elementwise_op_entry<math::TanhOp, 1>(),
|
||||
{func::ReturnOp::getOperationName(), func_return_rule},
|
||||
{scf::ForOp::getOperationName(), scf_for_rule},
|
||||
{scf::IfOp::getOperationName(), scf_if_rule},
|
||||
{scf::YieldOp::getOperationName(), scf_yield_rule},
|
||||
{tpu::IotaOp::getOperationName(), tpu_iota_rule},
|
||||
|
Loading…
x
Reference in New Issue
Block a user