Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
[XLA:Mosaic] Fix the infer/apply vector layout rules for terminators (scf::YieldOp, scf::ConditionOp).
We should infer a layout for each terminator inside its own region and, when a result depends on terminators from multiple regions (as in scf::IfOp, scf::WhileOp, and scf::ForOp), find a single layout compatible with all of them. If no compatible layout is found, we fall back to a normalized layout. Finally, we also need to ensure that the input, terminator, and output layouts stay consistent across loop iterations.

PiperOrigin-RevId: 639122434
This commit is contained in:
Parent: d9f07d0350
Commit: 389bf93abf
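For illustration, a minimal standalone sketch of the join-and-fallback logic described in the commit message follows. It is not the Mosaic API: a layout is modeled as a (sublane, lane) pair of optional offsets, with std::nullopt standing in for the replicated offset (*), and join/reconcile stand in for VectorLayout::join and the offsets-{0, 0} native-tiling fallback used in the patch.

// A minimal, self-contained model of the per-result layout reconciliation
// described in the commit message. This is an illustrative sketch, not the
// Mosaic API: a "layout" is just a (sublane, lane) pair of optional offsets,
// where std::nullopt plays the role of the replicated offset (*).
#include <array>
#include <cassert>
#include <optional>
#include <vector>

using Offset = std::optional<int>;     // std::nullopt == replicated (*)
using Layout = std::array<Offset, 2>;  // (sublane offset, lane offset)

// Join two offsets: a replicated offset is compatible with anything; two
// concrete offsets are compatible only when they are equal.
bool joinOffset(Offset a, Offset b, Offset& out) {
  if (!a.has_value()) { out = b; return true; }
  if (!b.has_value()) { out = a; return true; }
  if (*a == *b) { out = a; return true; }
  return false;
}

// Join two layouts, standing in for VectorLayout::join in the patch.
// Returns std::nullopt when the layouts are incompatible.
std::optional<Layout> join(const Layout& a, const Layout& b) {
  Layout joined;
  if (!joinOffset(a[0], b[0], joined[0]) ||
      !joinOffset(a[1], b[1], joined[1])) {
    return std::nullopt;
  }
  return joined;
}

// Pick one layout per result from the layouts implied by each region's
// terminator. If no compatible layout exists, fall back to a "normalized"
// layout with offsets {0, 0} (the patch additionally resets to the native
// tiling for the value's bitwidth).
Layout reconcile(const std::vector<Layout>& per_region_layouts) {
  assert(!per_region_layouts.empty());
  std::optional<Layout> acc = per_region_layouts.front();
  for (size_t i = 1; acc.has_value() && i < per_region_layouts.size(); ++i) {
    acc = join(*acc, per_region_layouts[i]);
  }
  return acc.value_or(Layout{Offset{0}, Offset{0}});  // normalized fallback
}

Under this model, reconciling (*, *) with (*, 0) yields (*, 0), matching the example in the patch's comments, while two different concrete offsets have no compatible join and trigger the normalized fallback. In the patch, such a change additionally sets require_reinfer, so the body is re-inferred with inferBlock(..., /*override_layout=*/true) and the block arguments, terminators, and results all end up with the same consolidated layouts.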
@@ -913,16 +913,35 @@ LogicalResult scf_for_rule(RewriteContext &ctx, Operation &op,
scf::ForOp for_op = cast<scf::ForOp>(op);
TPU_ASSERT_EQ_OP(layouts_in.size(), for_op->getNumOperands());
TPU_ASSERT_EQ_OP(layouts_out.size(), for_op->getNumResults());
if (!llvm::equal(layouts_in.drop_front(3), layouts_out)) {
return op.emitOpError(
"Expected matched layouts in scf.for's inputs and outputs");
}
FAILUREOR_ASSIGN_OR_RETURN(
const SmallVector<Layout> yield_in_layouts,
getInLayouts(*for_op.getBody()->getTerminator(), ctx.target_shape));
if (!llvm::equal(ArrayRef<Layout>(yield_in_layouts), layouts_out)) {
return op.emitOpError(
"Expected matched layouts in scf.yield operands and scf.for's results");
int out_idx = 0;
for (auto [in_layout, yield_layout, out_layout, result] :
llvm::zip_equal(layouts_in.drop_front(3), yield_in_layouts, layouts_out,
op.getResults())) {
if (auto vty = dyn_cast<VectorType>(result.getType())) {
TPU_ASSERT_OP(in_layout.has_value());
TPU_ASSERT_OP(yield_layout.has_value());
TPU_ASSERT_OP(out_layout.has_value());
if (in_layout.value() != yield_layout.value()) {
return op.emitOpError(
"Not implemented: for loop input layout does not match with "
"yield layout ")
<< out_idx;
}
if (in_layout.value() != out_layout.value()) {
return op.emitOpError(
"Not implemented: for loop input layout does not match with "
"out layout ")
<< out_idx;
}
} else {
TPU_ASSERT_EQ_OP(in_layout, kNoLayout);
TPU_ASSERT_EQ_OP(yield_layout, kNoLayout);
TPU_ASSERT_EQ_OP(out_layout, kNoLayout);
}
++out_idx;
}

if (failed(applyLayoutBlock(ctx, *for_op.getBody()))) {
@@ -1047,30 +1066,53 @@ LogicalResult scf_while_rule(RewriteContext &ctx, Operation &op,
// It takes multiple arguments -- the first being the decision to execute the
// after region or branch to the exit.
FAILUREOR_ASSIGN_OR_RETURN(
const SmallVector<Layout> condition_in_layouts,
const SmallVector<Layout> cond_in_layouts,
getInLayouts(*while_op.getBeforeBody()->getTerminator(),
ctx.target_shape));
if (!llvm::equal(ArrayRef<Layout>(condition_in_layouts).drop_front(1),
layouts_out)) {
return op.emitOpError(
"Mismatched layouts between scf.while result and its before region "
"condition.");

FAILUREOR_ASSIGN_OR_RETURN(
const SmallVector<Layout> yield_in_layouts,
getInLayouts(*while_op.getYieldOp(), ctx.target_shape));
int out_idx = 0;
for (auto [in_layout, cond_layout, yield_layout, out_layout, result] :
llvm::zip_equal(layouts_in,
ArrayRef<Layout>(cond_in_layouts).drop_front(1),
yield_in_layouts, layouts_out, op.getResults())) {
if (auto vty = dyn_cast<VectorType>(result.getType())) {
TPU_ASSERT_OP(in_layout.has_value());
TPU_ASSERT_OP(yield_layout.has_value());
TPU_ASSERT_OP(out_layout.has_value());
if (in_layout.value() != cond_layout.value()) {
return op.emitOpError(
"Not implemented: while loop input layout does not match "
"with condition layout ")
<< out_idx;
}
if (in_layout.value() != yield_layout.value()) {
return op.emitOpError(
"Not implemented: while loop input layout does not match "
"with yield layout ")
<< out_idx;
}
if (in_layout.value() != out_layout.value()) {
return op.emitOpError(
"Not implemented: while loop input layout does not match "
"with output layout ")
<< out_idx;
}
} else {
TPU_ASSERT_EQ_OP(in_layout, kNoLayout);
TPU_ASSERT_EQ_OP(cond_layout, kNoLayout);
TPU_ASSERT_EQ_OP(yield_layout, kNoLayout);
TPU_ASSERT_EQ_OP(out_layout, kNoLayout);
}
++out_idx;
}

if (failed(applyLayoutBlock(ctx, *while_op.getBeforeBody()))) {
return failure();
}

FAILUREOR_ASSIGN_OR_RETURN(
const SmallVector<Layout> after_yield_in_layouts,
getInLayouts(*while_op.getYieldOp(), ctx.target_shape));
if (!layouts_out.empty() &&
ArrayRef<Layout>(after_yield_in_layouts) != layouts_out) {
return op.emitOpError(
"Not implemented: different layouts while's yield's operands and "
"results");
}

if (failed(applyLayoutBlock(ctx, *while_op.getAfterBody()))) {
return failure();
}
@@ -1221,17 +1263,42 @@ LogicalResult scf_if_rule(RewriteContext &ctx, Operation &op,
TPU_ASSERT_OP(!layouts_in.front().has_value());
ImplicitLocOpBuilder builder(op.getLoc(), &op);
scf::IfOp if_op = cast<scf::IfOp>(op);
SmallVector<Layout, 4> then_yield_in_layouts;
SmallVector<Layout, 4> else_yield_in_layouts;
FAILUREOR_ASSIGN_OR_RETURN(
const SmallVector<Layout> then_yield_in_layouts,
then_yield_in_layouts,
getInLayouts(*if_op.thenYield(), ctx.target_shape));
// TODO(tlongeri): ArrayRef<Layout> conversion should not be necessary, fix
// after LLVM adds const qualifiers to ==/!= operators. Also
// applies to else_yield_in_layouts comparison below.
if (!layouts_out.empty() &&
ArrayRef<Layout>(then_yield_in_layouts) != layouts_out) {
return op.emitOpError(
"Not implemented: different layouts in then yield's operands and if's "
"results");
if (!if_op.getElseRegion().empty()) {
FAILUREOR_ASSIGN_OR_RETURN(
else_yield_in_layouts,
getInLayouts(*if_op.elseYield(), ctx.target_shape));
}
int out_idx = 0;
for (auto [then_layout, else_layout, result_layout, result] :
llvm::zip_equal(then_yield_in_layouts, else_yield_in_layouts,
layouts_out, op.getResults())) {
if (auto vty = dyn_cast<VectorType>(result.getType())) {
TPU_ASSERT_OP(then_layout.has_value());
TPU_ASSERT_OP(else_layout.has_value());
TPU_ASSERT_OP(result_layout.has_value());
if (result_layout.value() != then_layout.value()) {
return op.emitOpError(
"Not implemented: yield layout from then branch does not "
"match with output layout ")
<< out_idx;
}
if (result_layout.value() != else_layout.value()) {
return op.emitOpError(
"Not implemented: yield layout from else branch does not "
"match with output layout ")
<< out_idx;
}
} else {
TPU_ASSERT_EQ_OP(then_layout, kNoLayout);
TPU_ASSERT_EQ_OP(else_layout, kNoLayout);
TPU_ASSERT_EQ_OP(result_layout, kNoLayout);
}
++out_idx;
}
if (failed(applyLayoutBlock(ctx, *if_op.thenBlock()))) {
return failure();
@@ -1241,15 +1308,6 @@ LogicalResult scf_if_rule(RewriteContext &ctx, Operation &op,
TPU_ASSERT_EQ_OP(layouts_out.size(), 0);
return success();
}
FAILUREOR_ASSIGN_OR_RETURN(
const SmallVector<Layout> else_yield_in_layouts,
getInLayouts(*if_op.elseYield(), ctx.target_shape));
if (!layouts_out.empty() &&
ArrayRef<Layout>(else_yield_in_layouts) != layouts_out) {
return op.emitOpError(
"Not implemented: different layouts in else yield's operands and if's "
"results");
}
if (failed(applyLayoutBlock(ctx, *if_op.elseBlock()))) {
return failure();
}

@@ -43,6 +43,7 @@ limitations under the License.
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
@@ -121,7 +122,10 @@ class VectorLayoutInferer {

LogicalResult inferBlock(
Block &block,
const std::function<LogicalResult(Operation *)> &match_terminator) {
const std::function<LogicalResult(Operation *)> &match_terminator,
// TODO(jevinjiang): Propagate this flag deeper because it won't work when
// there is an op with blocks inside this block.
bool override_layout = false) {
for (Operation &any_op : block.without_terminator()) {
VLOG(kLayoutLog) << Print(&any_op);
if (any_op.hasAttr("in_layout") || any_op.hasAttr("out_layout")) {
@@ -130,6 +134,8 @@ class VectorLayoutInferer {
any_op.hasAttr("in_layout") && any_op.hasAttr("out_layout"),
"expect layout attributes in tpu::AssumeLayoutOp");
continue;
} else if (override_layout) {
// Intend to override the layouts attribute.
} else {
any_op.emitOpError("layout attributes already attached");
return failure();
@@ -220,10 +226,6 @@ class VectorLayoutInferer {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<scf::ConditionOp>(any_op)) {
if (infer(op).failed()) {
return failure();
}
} else if (auto op = dyn_cast<tpu::RotateOp>(any_op)) {
if (infer(op).failed()) {
return failure();
@@ -427,19 +429,7 @@ class VectorLayoutInferer {
auto then_yield = op.thenBlock()->getTerminator();
TPU_CHECK_OP(then_yield->getOperandTypes() == op->getResultTypes(),
"scf if results and then branch yield operands do not match");
SmallVector<Layout, 4> result_layout;
result_layout.reserve(then_yield->getNumOperands());
for (const auto &operand : then_yield->getOperands()) {
if (operand.getType().isSignlessIntOrIndexOrFloat()) {
result_layout.push_back(kNoLayout);
} else if (isa<VectorType>(operand.getType())) {
result_layout.push_back(getLayout(operand));
} else {
op.emitOpError("unsupported scf.yield type");
return failure();
}
}

auto then_yield_in_layouts = getLayoutFromOperands(then_yield);
if (auto else_block = op.elseBlock()) {
if (inferBlock(*else_block, match_yield).failed()) {
op.emitOpError("failed to infer layout for else branch");
@@ -454,32 +444,53 @@ class VectorLayoutInferer {
auto else_yield = op.elseBlock()->getTerminator();
TPU_CHECK_OP(else_yield->getOperandTypes() == op->getResultTypes(),
"scf if results and else branch yield operands do not match");

// Check each layout of the yield in else branch and override the
// result_layout if else branch's yield layout is less general. For example,
// if we yield offset (*, *) in then branch and offset (*, 0) in else
// branch, the result offset should be (*, 0).
for (int i = 0; i < else_yield->getNumOperands(); ++i) {
const auto &operand = else_yield->getOperand(i);
if (!isa<VectorType>(operand.getType())) {
continue;
}
auto shape = dyn_cast<VectorType>(operand.getType()).getShape();
auto layout = getLayout(operand);
CHECK(result_layout[i].has_value() && layout.has_value());
result_layout[i] =
VectorLayout::join(result_layout[i].value(), layout.value(), shape);
if (!result_layout[i].has_value()) {
op.emitOpError(
"failed to find a compatible layout in then and else branch for "
"output ")
<< i;
return failure();
auto else_yield_in_layouts = getLayoutFromOperands(else_yield);
// Find a compatible layout from then and else branches for each result. For
// example, if we yield offset (*, *) in then branch and offset (*, 0) in
// else branch, the result offset should be (*, 0).
SmallVector<Layout, 4> out_layouts;
out_layouts.reserve(op->getNumResults());
int out_idx = 0;
for (auto [then_layout, else_layout, result] : llvm::zip_equal(
then_yield_in_layouts, else_yield_in_layouts, op.getResults())) {
if (auto vty = dyn_cast<VectorType>(result.getType())) {
if (!then_layout.has_value()) {
return op.emitOpError(
"expected a vector layout for then yield input ")
<< out_idx;
}
if (!else_layout.has_value()) {
return op.emitOpError(
"expected a vector layout for else yield input ")
<< out_idx;
}
auto compatible_layout = VectorLayout::join(
then_layout.value(), else_layout.value(), vty.getShape());
// If no compatible layout is found in layouts for then and else
// branches, the output layout falls back to a normalized layout which
// has offsets 0 and the native tiling.
if (!compatible_layout.has_value()) {
compatible_layout = VectorLayout(
then_layout->bitwidth(), {0, 0},
nativeTiling(then_layout->bitwidth()), ImplicitDim::kNone);
}
out_layouts.push_back(compatible_layout);
} else {
if (then_layout.has_value()) {
return op.emitOpError("expected no layout for then yield input ")
<< out_idx;
}
if (else_layout.has_value()) {
return op.emitOpError("expected no layout for else yield input ")
<< out_idx;
}
out_layouts.push_back(kNoLayout);
}
++out_idx;
}
setInLayout(then_yield, result_layout);
setInLayout(else_yield, result_layout);
setOutLayout(op, result_layout);
setInLayout(then_yield, out_layouts);
setInLayout(else_yield, out_layouts);
setOutLayout(op, out_layouts);
return success();
}

@@ -497,24 +508,11 @@ class VectorLayoutInferer {
op->getNumOperands() == 3 + op.getNumResults(),
"expected num_operands is equal to 3 + num_results in scf.for");

SmallVector<Layout, 4> in_layouts;
in_layouts.reserve(op->getNumOperands());
in_layouts.push_back(kNoLayout); // Lower bound.
in_layouts.push_back(kNoLayout); // Upper bound.
in_layouts.push_back(kNoLayout); // Step.
for (const auto &arg : op.getInitArgs()) {
if (arg.getType().isSignlessIntOrIndexOrFloat()) {
in_layouts.push_back(kNoLayout);
} else if (isa<VectorType>(arg.getType())) {
auto layout = getLayout(arg);
in_layouts.push_back(layout);
} else {
op.emitOpError() << "unsupported arg type " << arg.getType()
<< " in scf::for";
return failure();
}
}
ArrayRef<Layout> out_layouts = ArrayRef<Layout>(in_layouts).drop_front(3);
SmallVector<Layout, 4> in_layouts = getLayoutFromOperands(op);
// Drop the first 3 layouts for lower bound, upper bound and step.
ArrayRef<Layout> arg_layouts = ArrayRef<Layout>(in_layouts).drop_front(3);
SmallVector<tpu::AssumeLayoutOp, 4> assume_layout_ops;
assume_layout_ops.reserve(arg_layouts.size());
// Use tpu.assume_layout to annotate every block argument with the layout of
// the corresponding operand in forOp and replace all uses of the block
// argument with the result of tpu.assume_layout.
@@ -523,13 +521,15 @@ class VectorLayoutInferer {

// Drop the induction_variable and layouts of bounds+step (respectively).
for (auto [iter_arg, layout] : llvm::zip_equal(
op.getBody()->getArguments().drop_front(1), out_layouts)) {
op.getBody()->getArguments().drop_front(1), arg_layouts)) {
if (!dyn_cast<VectorType>(iter_arg.getType())) {
assume_layout_ops.push_back(nullptr);
continue;
}
auto assume_layout_op =
builder.create<AssumeLayoutOp>(iter_arg.getType(), iter_arg);
setLayout(assume_layout_op, layout, layout);
assume_layout_ops.push_back(assume_layout_op);
iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) {
return operand.getOwner() != assume_layout_op;
});
@@ -539,6 +539,72 @@ class VectorLayoutInferer {
return failure();
}
auto yield_op = op.getBody()->getTerminator();
auto yield_in_layouts = getLayoutFromOperands(yield_op);

SmallVector<Layout, 4> out_layouts;
out_layouts.reserve(op->getNumResults());
int out_idx = 0;
bool require_reinfer = false;
for (auto [in_layout, yield_layout, result] :
llvm::zip_equal(ArrayRef<Layout>(in_layouts).drop_front(3),
yield_in_layouts, op.getResults())) {
if (auto vty = dyn_cast<VectorType>(result.getType())) {
if (!in_layout.has_value()) {
return op.emitOpError("expected a vector layout for input ")
<< out_idx;
}
if (!yield_layout.has_value()) {
return op.emitOpError("expected a vector layout for yield input ")
<< out_idx;
}
auto compatible_layout = VectorLayout::join(
in_layout.value(), yield_layout.value(), vty.getShape());
// If no compatible layout is found in layouts for input and
// yield, the output layout falls back to a normalized layout which
// has offsets 0 and the native tiling.
if (!compatible_layout.has_value()) {
compatible_layout = VectorLayout(in_layout->bitwidth(), {0, 0},
nativeTiling(in_layout->bitwidth()),
ImplicitDim::kNone);
}
if (!require_reinfer &&
(compatible_layout.value() != in_layout.value() ||
compatible_layout.value() != yield_layout.value())) {
require_reinfer = true;
}
out_layouts.push_back(compatible_layout);
} else {
if (in_layout.has_value()) {
return op.emitOpError("expected no layout for input ") << out_idx;
}
if (yield_layout.has_value()) {
return op.emitOpError("expected no layout for yield input ")
<< out_idx;
}
out_layouts.push_back(kNoLayout);
}
++out_idx;
}
if (require_reinfer) {
// Terminator in the loop will carry layouts to the next loop but
// the loop's block args' layouts are determined by the initial inputs. We
// need to force the same layouts for all in order to make layouts be
// consistent across all branches. To ensure that, we need to reprocess
// layout inference for the entire body with the final consolidated
// layout.
for (int64_t i = 0; i < out_layouts.size(); ++i) {
if (assume_layout_ops[i]) {
setLayout(assume_layout_ops[i], out_layouts[i], out_layouts[i]);
}
}
if (inferBlock(*op.getBody(), match_yield, /*override_layout=*/true)
.failed()) {
return op.emitOpError("failed to infer layout for scf.for op");
}
std::copy(out_layouts.begin(), out_layouts.end(),
in_layouts.begin() + 3); // Skip first 3 layouts for lower
// bound, upper bound and step.
}
setInLayout(yield_op, out_layouts);
setLayout(op, in_layouts, out_layouts);
return success();
@@ -555,34 +621,11 @@ class VectorLayoutInferer {
};
TPU_CHECK_OP(op.getNumRegions() == 2, "expected two blocks for scf.while");

const auto layout_for_type = [&op, this](const ::mlir::Value &arg,
SmallVector<Layout> *layouts) {
if (arg.getType().isSignlessIntOrIndexOrFloat()) {
layouts->push_back(kNoLayout);
} else if (isa<VectorType>(arg.getType())) {
auto layout = getLayout(arg);
layouts->push_back(layout);
} else {
op.emitOpError() << "unsupported arg type " << arg.getType()
<< " in scf.while";
return failure();
}
return success();
};

SmallVector<Layout> in_layouts;
in_layouts.reserve(op->getNumOperands());
for (const auto &arg : op.getInits()) {
const auto status = layout_for_type(arg, &in_layouts);
if (status.failed()) return status;
}

// Formally, the types and layouts of the results should follow the layout
// of the condition op in the Before region, rather than mimicking the input
// layouts. In practice these are constrained to be the same for our current
// pipelines, but doesn't represent the full expressiveness of scf.while.
// TODO(hmckenzie): Base output layout on ConditionOp, not inputs.
SmallVector<Layout> out_layouts = in_layouts;
SmallVector<Layout, 4> in_layouts = getLayoutFromOperands(op);
SmallVector<tpu::AssumeLayoutOp, 4> before_assume_layout_ops;
before_assume_layout_ops.reserve(in_layouts.size());
SmallVector<tpu::AssumeLayoutOp, 4> after_assume_layout_ops;
after_assume_layout_ops.reserve(in_layouts.size());

// Use tpu.assume_layout to annotate every block argument with the layout of
// the corresponding operand in WhileOp and replace all uses of the block
@@ -592,11 +635,13 @@ class VectorLayoutInferer {
for (auto [iter_arg, layout] :
llvm::zip_equal(op.getBeforeBody()->getArguments(), in_layouts)) {
if (!dyn_cast<VectorType>(iter_arg.getType())) {
before_assume_layout_ops.push_back(nullptr);
continue;
}
auto assume_layout_op =
builder.create<AssumeLayoutOp>(iter_arg.getType(), iter_arg);
setLayout(assume_layout_op, layout, layout);
before_assume_layout_ops.push_back(assume_layout_op);
iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) {
return operand.getOwner() != assume_layout_op;
});
@@ -608,13 +653,15 @@ class VectorLayoutInferer {
builder =
ImplicitLocOpBuilder::atBlockBegin(op.getLoc(), op.getAfterBody());
for (auto [iter_arg, layout] :
llvm::zip_equal(op.getAfterBody()->getArguments(), out_layouts)) {
llvm::zip_equal(op.getAfterBody()->getArguments(), in_layouts)) {
if (!dyn_cast<VectorType>(iter_arg.getType())) {
after_assume_layout_ops.push_back(nullptr);
continue;
}
auto assume_layout_op =
builder.create<AssumeLayoutOp>(iter_arg.getType(), iter_arg);
setLayout(assume_layout_op, layout, layout);
after_assume_layout_ops.push_back(assume_layout_op);
iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) {
return operand.getOwner() != assume_layout_op;
});
@@ -624,35 +671,101 @@ class VectorLayoutInferer {
return failure();
}

auto *condition_op = op.getBeforeBody()->getTerminator();
SmallVector<Layout> cond_layout;
cond_layout.reserve(out_layouts.size() + 1);
cond_layout.push_back(kNoLayout);
cond_layout.append(out_layouts);
setInLayout(condition_op, cond_layout);

auto *cond_op = op.getBeforeBody()->getTerminator();
auto cond_in_layouts = getLayoutFromOperands(cond_op);
auto *yield_op = op.getAfterBody()->getTerminator();
setInLayout(yield_op, in_layouts);
auto yield_in_layouts = getLayoutFromOperands(yield_op);

setLayout(op, in_layouts, out_layouts);
return success();
}
LogicalResult infer(scf::ConditionOp op) {
SmallVector<Layout> in_layouts;
in_layouts.reserve(op->getNumOperands());
for (const auto &arg : op.getOperands()) {
if (arg.getType().isSignlessIntOrIndexOrFloat()) {
in_layouts.push_back(kNoLayout);
} else if (isa<VectorType>(arg.getType())) {
auto layout = getLayout(arg);
in_layouts.push_back(layout);
// Find a compatible layout from condition body and loop body for each
// result. For example, if we yield offset (*, *) in condition body and
// offset (*, 0) in loop body, the result offset should be (*, 0).
SmallVector<Layout, 4> out_layouts;
out_layouts.reserve(op->getNumResults());
int out_idx = 0;
bool require_reinfer = false;
for (auto [in_layout, cond_layout, yield_layout, result] : llvm::zip_equal(
in_layouts, ArrayRef<Layout>(cond_in_layouts).drop_front(1),
yield_in_layouts, op.getResults())) {
if (auto vty = dyn_cast<VectorType>(result.getType())) {
if (!in_layout.has_value()) {
return op.emitOpError("expected a vector layout for whileOp input ")
<< out_idx;
}
if (!cond_layout.has_value()) {
return op.emitOpError("expected a vector layout for condition input ")
<< out_idx + 1; // ConditionOp's first input is 1 bit bool.
}
if (!yield_layout.has_value()) {
return op.emitOpError("expected a vector layout for yield input ")
<< out_idx;
}
auto compatible_layout = VectorLayout::join(
cond_layout.value(), yield_layout.value(), vty.getShape());
if (compatible_layout.has_value()) {
compatible_layout = VectorLayout::join(
in_layout.value(), compatible_layout.value(), vty.getShape());
}
// If no compatible layout is found in layouts for input, condition and
// yield, the output layout falls back to a normalized layout which
// has offsets 0 and the native tiling.
if (!compatible_layout.has_value()) {
compatible_layout = VectorLayout(in_layout->bitwidth(), {0, 0},
nativeTiling(in_layout->bitwidth()),
ImplicitDim::kNone);
}
if (!require_reinfer &&
(compatible_layout.value() != in_layout.value() ||
compatible_layout.value() != cond_layout.value() ||
compatible_layout.value() != yield_layout.value())) {
require_reinfer = true;
}
out_layouts.push_back(compatible_layout);
} else {
op.emitOpError() << "unsupported arg type " << arg.getType()
<< " in scf::condition";
return failure();
if (in_layout.has_value()) {
return op.emitOpError("expected no layout for whileOp input ")
<< out_idx;
}
if (cond_layout.has_value()) {
return op.emitOpError("expected no layout for condition input ")
<< out_idx + 1; // ConditionOp's first input is 1 bit bool.
}
if (yield_layout.has_value()) {
return op.emitOpError("expected no layout for yield input ")
<< out_idx;
}
out_layouts.push_back(kNoLayout);
}
++out_idx;
}
if (require_reinfer) {
// Terminator in the loop will carry layouts to the next loop but
// the loop's block args' layouts are determined by the initial inputs. We
// need to force the same layouts for all in order to make layouts be
// consistent across all branches. To ensure that, we need to reprocess
// layout inference for the entire body with the final consolidated
// layout.
for (int64_t i = 0; i < out_layouts.size(); ++i) {
if (before_assume_layout_ops[i]) {
setLayout(before_assume_layout_ops[i], out_layouts[i],
out_layouts[i]);
}
if (after_assume_layout_ops[i]) {
setLayout(after_assume_layout_ops[i], out_layouts[i], out_layouts[i]);
}
}
if (inferBlock(*op.getBeforeBody(), match_condition,
/*override_layout=*/true)
.failed() ||
inferBlock(*op.getAfterBody(), match_yield, /*override_layout=*/true)
.failed()) {
return op.emitOpError("failed to infer layout for scf.while op");
}
}
setLayout(op, in_layouts, ArrayRef<Layout>(in_layouts).drop_front(1));
std::copy(out_layouts.begin(), out_layouts.end(),
cond_in_layouts.begin() + 1); // Skip the first 1 bit bool.
setInLayout(cond_op, cond_in_layouts);
setInLayout(yield_op, out_layouts);
setLayout(op, out_layouts, out_layouts);
return success();
}

@@ -1807,6 +1920,19 @@ class VectorLayoutInferer {
return cast<VectorLayoutAttr>(out_attrs[result_index]).getLayout();
}

SmallVector<Layout, 4> getLayoutFromOperands(Operation *op) {
SmallVector<Layout, 4> layouts;
layouts.reserve(op->getNumOperands());
for (const auto &operand : op->getOperands()) {
if (isa<VectorType>(operand.getType())) {
layouts.push_back(getLayout(operand));
} else {
layouts.push_back(kNoLayout);
}
}
return layouts;
}

private:
std::optional<absl::Span<const int64_t>> verifyMemoryTiling(
Operation *op, ArrayRef<xla::Tile> mem_tiling, int64_t rank,