mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[XLA:Mosaic] Use join to find a compatible output layout in scf.if.
PiperOrigin-RevId: 605442646
This commit is contained in:
parent
d29c86eb52
commit
9b320f23f0
@ -582,28 +582,19 @@ class VectorLayoutInferer {
|
||||
if (!isa<VectorType>(operand.getType())) {
|
||||
continue;
|
||||
}
|
||||
auto shape = dyn_cast<VectorType>(operand.getType()).getShape();
|
||||
auto layout = getLayout(operand);
|
||||
CHECK(result_layout[i].has_value() && layout.has_value());
|
||||
auto shape = dyn_cast<VectorType>(operand.getType()).getShape();
|
||||
if (result_layout[i].value().generalizes(layout.value(), shape,
|
||||
target_shape_)) {
|
||||
result_layout[i] = layout;
|
||||
} else if (layout.value().generalizes(result_layout[i].value(), shape,
|
||||
target_shape_)) {
|
||||
// No change.
|
||||
} else {
|
||||
// TODO(jevinjiang): ideally we can try to find a compatible layout
|
||||
// which can be generalized from both then and else branch when then
|
||||
// yield layout and else yield layout can not generalize each other. For
|
||||
// example, if then yields offset (*, 0) and else yields offset (0, *),
|
||||
// the compatible offset could be (0, 0). But it is too complex to
|
||||
// handle for now. We can add the support when there is a use case.
|
||||
result_layout[i] =
|
||||
VectorLayout::join(result_layout[i].value(), layout.value(), shape);
|
||||
if (!result_layout[i].has_value()) {
|
||||
op.emitOpError(
|
||||
"failed to find a compatible layout for then and else branch");
|
||||
"failed to find a compatible layout in then and else branch for "
|
||||
"output ")
|
||||
<< i;
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
setInLayout(then_yield, result_layout);
|
||||
setInLayout(else_yield, result_layout);
|
||||
setOutLayout(op, result_layout);
|
||||
|
Loading…
x
Reference in New Issue
Block a user