From 9b320f23f0fbed447f6e68c4b4db4603eea62bba Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Thu, 8 Feb 2024 15:19:30 -0800 Subject: [PATCH] [XLA:Mosaic] Use join to find a compatible output layout in scf.if. PiperOrigin-RevId: 605442646 --- .../tpu/transforms/infer_vector_layout.cc | 23 ++++++------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index a4470efe8..d9ab17b22 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -582,28 +582,19 @@ class VectorLayoutInferer { if (!isa(operand.getType())) { continue; } + auto shape = dyn_cast(operand.getType()).getShape(); auto layout = getLayout(operand); CHECK(result_layout[i].has_value() && layout.has_value()); - auto shape = dyn_cast(operand.getType()).getShape(); - if (result_layout[i].value().generalizes(layout.value(), shape, - target_shape_)) { - result_layout[i] = layout; - } else if (layout.value().generalizes(result_layout[i].value(), shape, - target_shape_)) { - // No change. - } else { - // TODO(jevinjiang): ideally we can try to find a compatible layout - // which can be generalized from both then and else branch when then - // yield layout and else yield layout can not generalize each other. For - // example, if then yields offset (*, 0) and else yields offset (0, *), - // the compatible offset could be (0, 0). But it is too complex to - // handle for now. We can add the support when there is a use case. + result_layout[i] = + VectorLayout::join(result_layout[i].value(), layout.value(), shape); + if (!result_layout[i].has_value()) { op.emitOpError( - "failed to find a compatible layout for then and else branch"); + "failed to find a compatible layout in then and else branch for " + "output ") + << i; return failure(); } } - setInLayout(then_yield, result_layout); setInLayout(else_yield, result_layout); setOutLayout(op, result_layout);