mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[XLA:Mosaic] Use join to find a compatible output layout in scf.if.
PiperOrigin-RevId: 605442646
This commit is contained in:
parent
d29c86eb52
commit
9b320f23f0
@ -582,28 +582,19 @@ class VectorLayoutInferer {
|
|||||||
if (!isa<VectorType>(operand.getType())) {
|
if (!isa<VectorType>(operand.getType())) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
auto shape = dyn_cast<VectorType>(operand.getType()).getShape();
|
||||||
auto layout = getLayout(operand);
|
auto layout = getLayout(operand);
|
||||||
CHECK(result_layout[i].has_value() && layout.has_value());
|
CHECK(result_layout[i].has_value() && layout.has_value());
|
||||||
auto shape = dyn_cast<VectorType>(operand.getType()).getShape();
|
result_layout[i] =
|
||||||
if (result_layout[i].value().generalizes(layout.value(), shape,
|
VectorLayout::join(result_layout[i].value(), layout.value(), shape);
|
||||||
target_shape_)) {
|
if (!result_layout[i].has_value()) {
|
||||||
result_layout[i] = layout;
|
|
||||||
} else if (layout.value().generalizes(result_layout[i].value(), shape,
|
|
||||||
target_shape_)) {
|
|
||||||
// No change.
|
|
||||||
} else {
|
|
||||||
// TODO(jevinjiang): ideally we can try to find a compatible layout
|
|
||||||
// which can be generalized from both then and else branch when then
|
|
||||||
// yield layout and else yield layout can not generalize each other. For
|
|
||||||
// example, if then yields offset (*, 0) and else yields offset (0, *),
|
|
||||||
// the compatible offset could be (0, 0). But it is too complex to
|
|
||||||
// handle for now. We can add the support when there is a use case.
|
|
||||||
op.emitOpError(
|
op.emitOpError(
|
||||||
"failed to find a compatible layout for then and else branch");
|
"failed to find a compatible layout in then and else branch for "
|
||||||
|
"output ")
|
||||||
|
<< i;
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
setInLayout(then_yield, result_layout);
|
setInLayout(then_yield, result_layout);
|
||||||
setInLayout(else_yield, result_layout);
|
setInLayout(else_yield, result_layout);
|
||||||
setOutLayout(op, result_layout);
|
setOutLayout(op, result_layout);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user