[XLA:Mosaic] Use join to find a compatible output layout in scf.if.

PiperOrigin-RevId: 605442646
This commit is contained in:
Jevin Jiang 2024-02-08 15:19:30 -08:00 committed by jax authors
parent d29c86eb52
commit 9b320f23f0

View File

@ -582,28 +582,19 @@ class VectorLayoutInferer {
if (!isa<VectorType>(operand.getType())) {
continue;
}
auto shape = dyn_cast<VectorType>(operand.getType()).getShape();
auto layout = getLayout(operand);
CHECK(result_layout[i].has_value() && layout.has_value());
auto shape = dyn_cast<VectorType>(operand.getType()).getShape();
if (result_layout[i].value().generalizes(layout.value(), shape,
target_shape_)) {
result_layout[i] = layout;
} else if (layout.value().generalizes(result_layout[i].value(), shape,
target_shape_)) {
// No change.
} else {
// TODO(jevinjiang): ideally we can try to find a compatible layout
// which can be generalized from both then and else branch when then
// yield layout and else yield layout can not generalize each other. For
// example, if then yields offset (*, 0) and else yields offset (0, *),
// the compatible offset could be (0, 0). But it is too complex to
// handle for now. We can add the support when there is a use case.
result_layout[i] =
VectorLayout::join(result_layout[i].value(), layout.value(), shape);
if (!result_layout[i].has_value()) {
op.emitOpError(
"failed to find a compatible layout for then and else branch");
"failed to find a compatible layout in then and else branch for "
"output ")
<< i;
return failure();
}
}
setInLayout(then_yield, result_layout);
setInLayout(else_yield, result_layout);
setOutLayout(op, result_layout);