[XLA:Mosaic] Fix infer/apply vector layout rule for terminators (scf::yieldOp, scf::conditionOp).

We should infer layout for each terminator inside its own region and find a compatible layout for a final result if the result is based on terminators from multiple regions like scf::ifOp, scf::whileOp, scf::forOp. If no compatible layout is found, we will fall back to a normalized layout. Finally we also need to ensure the layouts in input, terminator and output are consistent across loops.

PiperOrigin-RevId: 639122434
This commit is contained in:
Jevin Jiang 2024-05-31 12:46:49 -07:00 committed by jax authors
parent d9f07d0350
commit 389bf93abf
2 changed files with 339 additions and 155 deletions

View File

@ -913,16 +913,35 @@ LogicalResult scf_for_rule(RewriteContext &ctx, Operation &op,
scf::ForOp for_op = cast<scf::ForOp>(op);
TPU_ASSERT_EQ_OP(layouts_in.size(), for_op->getNumOperands());
TPU_ASSERT_EQ_OP(layouts_out.size(), for_op->getNumResults());
if (!llvm::equal(layouts_in.drop_front(3), layouts_out)) {
return op.emitOpError(
"Expected matched layouts in scf.for's inputs and outputs");
}
FAILUREOR_ASSIGN_OR_RETURN(
const SmallVector<Layout> yield_in_layouts,
getInLayouts(*for_op.getBody()->getTerminator(), ctx.target_shape));
if (!llvm::equal(ArrayRef<Layout>(yield_in_layouts), layouts_out)) {
return op.emitOpError(
"Expected matched layouts in scf.yield operands and scf.for's results");
int out_idx = 0;
for (auto [in_layout, yield_layout, out_layout, result] :
llvm::zip_equal(layouts_in.drop_front(3), yield_in_layouts, layouts_out,
op.getResults())) {
if (auto vty = dyn_cast<VectorType>(result.getType())) {
TPU_ASSERT_OP(in_layout.has_value());
TPU_ASSERT_OP(yield_layout.has_value());
TPU_ASSERT_OP(out_layout.has_value());
if (in_layout.value() != yield_layout.value()) {
return op.emitOpError(
"Not implemented: for loop input layout does not match with "
"yield layout ")
<< out_idx;
}
if (in_layout.value() != out_layout.value()) {
return op.emitOpError(
"Not implemented: for loop input layout does not match with "
"out layout ")
<< out_idx;
}
} else {
TPU_ASSERT_EQ_OP(in_layout, kNoLayout);
TPU_ASSERT_EQ_OP(yield_layout, kNoLayout);
TPU_ASSERT_EQ_OP(out_layout, kNoLayout);
}
++out_idx;
}
if (failed(applyLayoutBlock(ctx, *for_op.getBody()))) {
@ -1047,30 +1066,53 @@ LogicalResult scf_while_rule(RewriteContext &ctx, Operation &op,
// It takes multiple arguments -- the first being the decision to execute the
// after region or branch to the exit.
FAILUREOR_ASSIGN_OR_RETURN(
const SmallVector<Layout> condition_in_layouts,
const SmallVector<Layout> cond_in_layouts,
getInLayouts(*while_op.getBeforeBody()->getTerminator(),
ctx.target_shape));
if (!llvm::equal(ArrayRef<Layout>(condition_in_layouts).drop_front(1),
layouts_out)) {
return op.emitOpError(
"Mismatched layouts between scf.while result and its before region "
"condition.");
FAILUREOR_ASSIGN_OR_RETURN(
const SmallVector<Layout> yield_in_layouts,
getInLayouts(*while_op.getYieldOp(), ctx.target_shape));
int out_idx = 0;
for (auto [in_layout, cond_layout, yield_layout, out_layout, result] :
llvm::zip_equal(layouts_in,
ArrayRef<Layout>(cond_in_layouts).drop_front(1),
yield_in_layouts, layouts_out, op.getResults())) {
if (auto vty = dyn_cast<VectorType>(result.getType())) {
TPU_ASSERT_OP(in_layout.has_value());
TPU_ASSERT_OP(yield_layout.has_value());
TPU_ASSERT_OP(out_layout.has_value());
if (in_layout.value() != cond_layout.value()) {
return op.emitOpError(
"Not implemented: while loop input layout does not match "
"with condition layout ")
<< out_idx;
}
if (in_layout.value() != yield_layout.value()) {
return op.emitOpError(
"Not implemented: while loop input layout does not match "
"with yield layout ")
<< out_idx;
}
if (in_layout.value() != out_layout.value()) {
return op.emitOpError(
"Not implemented: while loop input layout does not match "
"with output layout ")
<< out_idx;
}
} else {
TPU_ASSERT_EQ_OP(in_layout, kNoLayout);
TPU_ASSERT_EQ_OP(cond_layout, kNoLayout);
TPU_ASSERT_EQ_OP(yield_layout, kNoLayout);
TPU_ASSERT_EQ_OP(out_layout, kNoLayout);
}
++out_idx;
}
if (failed(applyLayoutBlock(ctx, *while_op.getBeforeBody()))) {
return failure();
}
FAILUREOR_ASSIGN_OR_RETURN(
const SmallVector<Layout> after_yield_in_layouts,
getInLayouts(*while_op.getYieldOp(), ctx.target_shape));
if (!layouts_out.empty() &&
ArrayRef<Layout>(after_yield_in_layouts) != layouts_out) {
return op.emitOpError(
"Not implemented: different layouts while's yield's operands and "
"results");
}
if (failed(applyLayoutBlock(ctx, *while_op.getAfterBody()))) {
return failure();
}
@ -1221,17 +1263,42 @@ LogicalResult scf_if_rule(RewriteContext &ctx, Operation &op,
TPU_ASSERT_OP(!layouts_in.front().has_value());
ImplicitLocOpBuilder builder(op.getLoc(), &op);
scf::IfOp if_op = cast<scf::IfOp>(op);
SmallVector<Layout, 4> then_yield_in_layouts;
SmallVector<Layout, 4> else_yield_in_layouts;
FAILUREOR_ASSIGN_OR_RETURN(
const SmallVector<Layout> then_yield_in_layouts,
then_yield_in_layouts,
getInLayouts(*if_op.thenYield(), ctx.target_shape));
// TODO(tlongeri): ArrayRef<Layout> conversion should not be necessary, fix
// after LLVM adds const qualifiers to ==/!= operators. Also
// applies to else_yield_in_layouts comparison below.
if (!layouts_out.empty() &&
ArrayRef<Layout>(then_yield_in_layouts) != layouts_out) {
return op.emitOpError(
"Not implemented: different layouts in then yield's operands and if's "
"results");
if (!if_op.getElseRegion().empty()) {
FAILUREOR_ASSIGN_OR_RETURN(
else_yield_in_layouts,
getInLayouts(*if_op.elseYield(), ctx.target_shape));
}
int out_idx = 0;
for (auto [then_layout, else_layout, result_layout, result] :
llvm::zip_equal(then_yield_in_layouts, else_yield_in_layouts,
layouts_out, op.getResults())) {
if (auto vty = dyn_cast<VectorType>(result.getType())) {
TPU_ASSERT_OP(then_layout.has_value());
TPU_ASSERT_OP(else_layout.has_value());
TPU_ASSERT_OP(result_layout.has_value());
if (result_layout.value() != then_layout.value()) {
return op.emitOpError(
"Not implemented: yield layout from then branch does not "
"match with output layout ")
<< out_idx;
}
if (result_layout.value() != else_layout.value()) {
return op.emitOpError(
"Not implemented: yield layout from else branch does not "
"match with output layout ")
<< out_idx;
}
} else {
TPU_ASSERT_EQ_OP(then_layout, kNoLayout);
TPU_ASSERT_EQ_OP(else_layout, kNoLayout);
TPU_ASSERT_EQ_OP(result_layout, kNoLayout);
}
++out_idx;
}
if (failed(applyLayoutBlock(ctx, *if_op.thenBlock()))) {
return failure();
@ -1241,15 +1308,6 @@ LogicalResult scf_if_rule(RewriteContext &ctx, Operation &op,
TPU_ASSERT_EQ_OP(layouts_out.size(), 0);
return success();
}
FAILUREOR_ASSIGN_OR_RETURN(
const SmallVector<Layout> else_yield_in_layouts,
getInLayouts(*if_op.elseYield(), ctx.target_shape));
if (!layouts_out.empty() &&
ArrayRef<Layout>(else_yield_in_layouts) != layouts_out) {
return op.emitOpError(
"Not implemented: different layouts in else yield's operands and if's "
"results");
}
if (failed(applyLayoutBlock(ctx, *if_op.elseBlock()))) {
return failure();
}

View File

@ -43,6 +43,7 @@ limitations under the License.
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
@ -121,7 +122,10 @@ class VectorLayoutInferer {
LogicalResult inferBlock(
Block &block,
const std::function<LogicalResult(Operation *)> &match_terminator) {
const std::function<LogicalResult(Operation *)> &match_terminator,
// TODO(jevinjiang): Propagate this flag deeper because it won't work when
// there is an op with blocks inside this block.
bool override_layout = false) {
for (Operation &any_op : block.without_terminator()) {
VLOG(kLayoutLog) << Print(&any_op);
if (any_op.hasAttr("in_layout") || any_op.hasAttr("out_layout")) {
@ -130,6 +134,8 @@ class VectorLayoutInferer {
any_op.hasAttr("in_layout") && any_op.hasAttr("out_layout"),
"expect layout attributes in tpu::AssumeLayoutOp");
continue;
} else if (override_layout) {
// Intend to override the layouts attribute.
} else {
any_op.emitOpError("layout attributes already attached");
return failure();
@ -220,10 +226,6 @@ class VectorLayoutInferer {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<scf::ConditionOp>(any_op)) {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<tpu::RotateOp>(any_op)) {
if (infer(op).failed()) {
return failure();
@ -427,19 +429,7 @@ class VectorLayoutInferer {
auto then_yield = op.thenBlock()->getTerminator();
TPU_CHECK_OP(then_yield->getOperandTypes() == op->getResultTypes(),
"scf if results and then branch yield operands do not match");
SmallVector<Layout, 4> result_layout;
result_layout.reserve(then_yield->getNumOperands());
for (const auto &operand : then_yield->getOperands()) {
if (operand.getType().isSignlessIntOrIndexOrFloat()) {
result_layout.push_back(kNoLayout);
} else if (isa<VectorType>(operand.getType())) {
result_layout.push_back(getLayout(operand));
} else {
op.emitOpError("unsupported scf.yield type");
return failure();
}
}
auto then_yield_in_layouts = getLayoutFromOperands(then_yield);
if (auto else_block = op.elseBlock()) {
if (inferBlock(*else_block, match_yield).failed()) {
op.emitOpError("failed to infer layout for else branch");
@ -454,32 +444,53 @@ class VectorLayoutInferer {
auto else_yield = op.elseBlock()->getTerminator();
TPU_CHECK_OP(else_yield->getOperandTypes() == op->getResultTypes(),
"scf if results and else branch yield operands do not match");
// Check each layout of the yield in else branch and override the
// result_layout if else branch's yield layout is less general. For example,
// if we yield offset (*, *) in then branch and offset (*, 0) in else
// branch, the result offset should be (*, 0).
for (int i = 0; i < else_yield->getNumOperands(); ++i) {
const auto &operand = else_yield->getOperand(i);
if (!isa<VectorType>(operand.getType())) {
continue;
}
auto shape = dyn_cast<VectorType>(operand.getType()).getShape();
auto layout = getLayout(operand);
CHECK(result_layout[i].has_value() && layout.has_value());
result_layout[i] =
VectorLayout::join(result_layout[i].value(), layout.value(), shape);
if (!result_layout[i].has_value()) {
op.emitOpError(
"failed to find a compatible layout in then and else branch for "
"output ")
<< i;
return failure();
auto else_yield_in_layouts = getLayoutFromOperands(else_yield);
// Find a compatible layout from then and else branches for each reuslt. For
// example, if we yield offset (*, *) in then branch and offset (*, 0) in
// else branch, the result offset should be (*, 0).
SmallVector<Layout, 4> out_layouts;
out_layouts.reserve(op->getNumResults());
int out_idx = 0;
for (auto [then_layout, else_layout, result] : llvm::zip_equal(
then_yield_in_layouts, else_yield_in_layouts, op.getResults())) {
if (auto vty = dyn_cast<VectorType>(result.getType())) {
if (!then_layout.has_value()) {
return op.emitOpError(
"expected a vector layout for then yield input ")
<< out_idx;
}
if (!else_layout.has_value()) {
return op.emitOpError(
"expected a vector layout for else yield input ")
<< out_idx;
}
auto compatible_layout = VectorLayout::join(
then_layout.value(), else_layout.value(), vty.getShape());
// If no compatible layout is found in layouts for then and else
// branches, the output layout falls back to a normalized layout which
// has offsets 0 and the native tiling.
if (!compatible_layout.has_value()) {
compatible_layout = VectorLayout(
then_layout->bitwidth(), {0, 0},
nativeTiling(then_layout->bitwidth()), ImplicitDim::kNone);
}
out_layouts.push_back(compatible_layout);
} else {
if (then_layout.has_value()) {
return op.emitOpError("expected no layout for then yield input ")
<< out_idx;
}
if (else_layout.has_value()) {
return op.emitOpError("expected no layout for else yield input ")
<< out_idx;
}
out_layouts.push_back(kNoLayout);
}
++out_idx;
}
setInLayout(then_yield, result_layout);
setInLayout(else_yield, result_layout);
setOutLayout(op, result_layout);
setInLayout(then_yield, out_layouts);
setInLayout(else_yield, out_layouts);
setOutLayout(op, out_layouts);
return success();
}
@ -497,24 +508,11 @@ class VectorLayoutInferer {
op->getNumOperands() == 3 + op.getNumResults(),
"expected num_operands is equal to 3 + num_results in scf.for");
SmallVector<Layout, 4> in_layouts;
in_layouts.reserve(op->getNumOperands());
in_layouts.push_back(kNoLayout); // Lower bound.
in_layouts.push_back(kNoLayout); // Upper bound.
in_layouts.push_back(kNoLayout); // Step.
for (const auto &arg : op.getInitArgs()) {
if (arg.getType().isSignlessIntOrIndexOrFloat()) {
in_layouts.push_back(kNoLayout);
} else if (isa<VectorType>(arg.getType())) {
auto layout = getLayout(arg);
in_layouts.push_back(layout);
} else {
op.emitOpError() << "unsupported arg type " << arg.getType()
<< " in scf::for";
return failure();
}
}
ArrayRef<Layout> out_layouts = ArrayRef<Layout>(in_layouts).drop_front(3);
SmallVector<Layout, 4> in_layouts = getLayoutFromOperands(op);
// Drop the first 3 layouts for lower bound, upper bound and step.
ArrayRef<Layout> arg_layouts = ArrayRef<Layout>(in_layouts).drop_front(3);
SmallVector<tpu::AssumeLayoutOp, 4> assume_layout_ops;
assume_layout_ops.reserve(arg_layouts.size());
// Use tpu.assume_layout to annotate every block argument with the layout of
// the corresponding operand in forOp and replace all uses of the block
// argument with the result of tpu.assume_layout.
@ -523,13 +521,15 @@ class VectorLayoutInferer {
// Drop the induction_variable and layouts of bounds+step (respectively).
for (auto [iter_arg, layout] : llvm::zip_equal(
op.getBody()->getArguments().drop_front(1), out_layouts)) {
op.getBody()->getArguments().drop_front(1), arg_layouts)) {
if (!dyn_cast<VectorType>(iter_arg.getType())) {
assume_layout_ops.push_back(nullptr);
continue;
}
auto assume_layout_op =
builder.create<AssumeLayoutOp>(iter_arg.getType(), iter_arg);
setLayout(assume_layout_op, layout, layout);
assume_layout_ops.push_back(assume_layout_op);
iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) {
return operand.getOwner() != assume_layout_op;
});
@ -539,6 +539,72 @@ class VectorLayoutInferer {
return failure();
}
auto yield_op = op.getBody()->getTerminator();
auto yield_in_layouts = getLayoutFromOperands(yield_op);
SmallVector<Layout, 4> out_layouts;
out_layouts.reserve(op->getNumResults());
int out_idx = 0;
bool require_reinfer = false;
for (auto [in_layout, yield_layout, result] :
llvm::zip_equal(ArrayRef<Layout>(in_layouts).drop_front(3),
yield_in_layouts, op.getResults())) {
if (auto vty = dyn_cast<VectorType>(result.getType())) {
if (!in_layout.has_value()) {
return op.emitOpError("expected a vector layout for input ")
<< out_idx;
}
if (!yield_layout.has_value()) {
return op.emitOpError("expected a vector layout for yield input ")
<< out_idx;
}
auto compatible_layout = VectorLayout::join(
in_layout.value(), yield_layout.value(), vty.getShape());
// If no compatible layout is found in layouts for input and
// yield, the output layout falls back to a normalized layout which
// has offsets 0 and the native tiling.
if (!compatible_layout.has_value()) {
compatible_layout = VectorLayout(in_layout->bitwidth(), {0, 0},
nativeTiling(in_layout->bitwidth()),
ImplicitDim::kNone);
}
if (!require_reinfer &&
(compatible_layout.value() != in_layout.value() ||
compatible_layout.value() != yield_layout.value())) {
require_reinfer = true;
}
out_layouts.push_back(compatible_layout);
} else {
if (in_layout.has_value()) {
return op.emitOpError("expected no layout for input ") << out_idx;
}
if (yield_layout.has_value()) {
return op.emitOpError("expected no layout for yield input ")
<< out_idx;
}
out_layouts.push_back(kNoLayout);
}
++out_idx;
}
if (require_reinfer) {
// Terminator in the loop will carry layouts to the next loop but
// the loop's block args' layouts are determined by the initial inputs. We
// need to force the same layouts for all in order to make layouts be
// consistent across all branches. To ensure that, we need to reprocess
// layout inference for the entire body with the final consolidated
// layout.
for (int64_t i = 0; i < out_layouts.size(); ++i) {
if (assume_layout_ops[i]) {
setLayout(assume_layout_ops[i], out_layouts[i], out_layouts[i]);
}
}
if (inferBlock(*op.getBody(), match_yield, /*override_layout=*/true)
.failed()) {
return op.emitOpError("failed to infer layout for scf.for op");
}
std::copy(out_layouts.begin(), out_layouts.end(),
in_layouts.begin() + 3); // Skip first 3 layouts for lower
// bound, upper bound and step.
}
setInLayout(yield_op, out_layouts);
setLayout(op, in_layouts, out_layouts);
return success();
@ -555,34 +621,11 @@ class VectorLayoutInferer {
};
TPU_CHECK_OP(op.getNumRegions() == 2, "expected two blocks for scf.while");
const auto layout_for_type = [&op, this](const ::mlir::Value &arg,
SmallVector<Layout> *layouts) {
if (arg.getType().isSignlessIntOrIndexOrFloat()) {
layouts->push_back(kNoLayout);
} else if (isa<VectorType>(arg.getType())) {
auto layout = getLayout(arg);
layouts->push_back(layout);
} else {
op.emitOpError() << "unsupported arg type " << arg.getType()
<< " in scf.while";
return failure();
}
return success();
};
SmallVector<Layout> in_layouts;
in_layouts.reserve(op->getNumOperands());
for (const auto &arg : op.getInits()) {
const auto status = layout_for_type(arg, &in_layouts);
if (status.failed()) return status;
}
// Formally, the types and layouts of the results should follow the layout
// of the condition op in the Before region, rather than mimicking the input
// layouts. In practice these are constrained to be the same for our current
// pipelines, but doesn't represent the full expressiveness of scf.while.
// TODO(hmckenzie): Base output layout on ConditionOp, not inputs.
SmallVector<Layout> out_layouts = in_layouts;
SmallVector<Layout, 4> in_layouts = getLayoutFromOperands(op);
SmallVector<tpu::AssumeLayoutOp, 4> before_assume_layout_ops;
before_assume_layout_ops.reserve(in_layouts.size());
SmallVector<tpu::AssumeLayoutOp, 4> after_assume_layout_ops;
after_assume_layout_ops.reserve(in_layouts.size());
// Use tpu.assume_layout to annotate every block argument with the layout of
// the corresponding operand in WhileOp and replace all uses of the block
@ -592,11 +635,13 @@ class VectorLayoutInferer {
for (auto [iter_arg, layout] :
llvm::zip_equal(op.getBeforeBody()->getArguments(), in_layouts)) {
if (!dyn_cast<VectorType>(iter_arg.getType())) {
before_assume_layout_ops.push_back(nullptr);
continue;
}
auto assume_layout_op =
builder.create<AssumeLayoutOp>(iter_arg.getType(), iter_arg);
setLayout(assume_layout_op, layout, layout);
before_assume_layout_ops.push_back(assume_layout_op);
iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) {
return operand.getOwner() != assume_layout_op;
});
@ -608,13 +653,15 @@ class VectorLayoutInferer {
builder =
ImplicitLocOpBuilder::atBlockBegin(op.getLoc(), op.getAfterBody());
for (auto [iter_arg, layout] :
llvm::zip_equal(op.getAfterBody()->getArguments(), out_layouts)) {
llvm::zip_equal(op.getAfterBody()->getArguments(), in_layouts)) {
if (!dyn_cast<VectorType>(iter_arg.getType())) {
after_assume_layout_ops.push_back(nullptr);
continue;
}
auto assume_layout_op =
builder.create<AssumeLayoutOp>(iter_arg.getType(), iter_arg);
setLayout(assume_layout_op, layout, layout);
after_assume_layout_ops.push_back(assume_layout_op);
iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) {
return operand.getOwner() != assume_layout_op;
});
@ -624,35 +671,101 @@ class VectorLayoutInferer {
return failure();
}
auto *condition_op = op.getBeforeBody()->getTerminator();
SmallVector<Layout> cond_layout;
cond_layout.reserve(out_layouts.size() + 1);
cond_layout.push_back(kNoLayout);
cond_layout.append(out_layouts);
setInLayout(condition_op, cond_layout);
auto *cond_op = op.getBeforeBody()->getTerminator();
auto cond_in_layouts = getLayoutFromOperands(cond_op);
auto *yield_op = op.getAfterBody()->getTerminator();
setInLayout(yield_op, in_layouts);
auto yield_in_layouts = getLayoutFromOperands(yield_op);
setLayout(op, in_layouts, out_layouts);
return success();
}
LogicalResult infer(scf::ConditionOp op) {
SmallVector<Layout> in_layouts;
in_layouts.reserve(op->getNumOperands());
for (const auto &arg : op.getOperands()) {
if (arg.getType().isSignlessIntOrIndexOrFloat()) {
in_layouts.push_back(kNoLayout);
} else if (isa<VectorType>(arg.getType())) {
auto layout = getLayout(arg);
in_layouts.push_back(layout);
// Find a compatible layout from condition body and loop body for each
// reuslt. For example, if we yield offset (*, *) in condition body and
// offset (*, 0) in loop body, the result offset should be (*, 0).
SmallVector<Layout, 4> out_layouts;
out_layouts.reserve(op->getNumResults());
int out_idx = 0;
bool require_reinfer = false;
for (auto [in_layout, cond_layout, yield_layout, result] : llvm::zip_equal(
in_layouts, ArrayRef<Layout>(cond_in_layouts).drop_front(1),
yield_in_layouts, op.getResults())) {
if (auto vty = dyn_cast<VectorType>(result.getType())) {
if (!in_layout.has_value()) {
return op.emitOpError("expected a vector layout for whileOp input ")
<< out_idx;
}
if (!cond_layout.has_value()) {
return op.emitOpError("expected a vector layout for condition input ")
<< out_idx + 1; // ConditionOp's first input is 1 bit bool.
}
if (!yield_layout.has_value()) {
return op.emitOpError("expected a vector layout for yield input ")
<< out_idx;
}
auto compatible_layout = VectorLayout::join(
cond_layout.value(), yield_layout.value(), vty.getShape());
if (compatible_layout.has_value()) {
compatible_layout = VectorLayout::join(
in_layout.value(), compatible_layout.value(), vty.getShape());
}
// If no compatible layout is found in layouts for input, condition and
// yield, the output layout falls back to a normalized layout which
// has offsets 0 and the native tiling.
if (!compatible_layout.has_value()) {
compatible_layout = VectorLayout(in_layout->bitwidth(), {0, 0},
nativeTiling(in_layout->bitwidth()),
ImplicitDim::kNone);
}
if (!require_reinfer &&
(compatible_layout.value() != in_layout.value() ||
compatible_layout.value() != cond_layout.value() ||
compatible_layout.value() != yield_layout.value())) {
require_reinfer = true;
}
out_layouts.push_back(compatible_layout);
} else {
op.emitOpError() << "unsupported arg type " << arg.getType()
<< " in scf::condition";
return failure();
if (in_layout.has_value()) {
return op.emitOpError("expected no layout for whileOp input ")
<< out_idx;
}
if (cond_layout.has_value()) {
return op.emitOpError("expected no layout for condition input ")
<< out_idx + 1; // ConditionOp's first input is 1 bit bool.
}
if (yield_layout.has_value()) {
return op.emitOpError("expected no layout for yield input ")
<< out_idx;
}
out_layouts.push_back(kNoLayout);
}
++out_idx;
}
if (require_reinfer) {
// Terminator in the loop will carry layouts to the next loop but
// the loop's block args' layouts are determined by the initial inputs. We
// need to force the same layouts for all in order to make layouts be
// consistent across all branches. To ensure that, we need to reprocess
// layout inference for the entire body with the final consolidated
// layout.
for (int64_t i = 0; i < out_layouts.size(); ++i) {
if (before_assume_layout_ops[i]) {
setLayout(before_assume_layout_ops[i], out_layouts[i],
out_layouts[i]);
}
if (after_assume_layout_ops[i]) {
setLayout(after_assume_layout_ops[i], out_layouts[i], out_layouts[i]);
}
}
if (inferBlock(*op.getBeforeBody(), match_condition,
/*override_layout=*/true)
.failed() ||
inferBlock(*op.getAfterBody(), match_yield, /*override_layout=*/true)
.failed()) {
return op.emitOpError("failed to infer layout for scf.while op");
}
}
setLayout(op, in_layouts, ArrayRef<Layout>(in_layouts).drop_front(1));
std::copy(out_layouts.begin(), out_layouts.end(),
cond_in_layouts.begin() + 1); // Skip the first 1 bit bool.
setInLayout(cond_op, cond_in_layouts);
setInLayout(yield_op, out_layouts);
setLayout(op, out_layouts, out_layouts);
return success();
}
@ -1807,6 +1920,19 @@ class VectorLayoutInferer {
return cast<VectorLayoutAttr>(out_attrs[result_index]).getLayout();
}
SmallVector<Layout, 4> getLayoutFromOperands(Operation *op) {
SmallVector<Layout, 4> layouts;
layouts.reserve(op->getNumOperands());
for (const auto &operand : op->getOperands()) {
if (isa<VectorType>(operand.getType())) {
layouts.push_back(getLayout(operand));
} else {
layouts.push_back(kNoLayout);
}
}
return layouts;
}
private:
std::optional<absl::Span<const int64_t>> verifyMemoryTiling(
Operation *op, ArrayRef<xla::Tile> mem_tiling, int64_t rank,