[Mosaic TPU] Normalize inferred layouts to supported ones in matmul rule

Previously the rule would complain if the layouts were unsupported, but that's not
the right way to handle that situation. With this change, we simply pick a supported
configuration instead (and expect relayout to handle it).

PiperOrigin-RevId: 640190248
This commit is contained in:
Adam Paszke 2024-06-04 10:02:03 -07:00 committed by jax authors
parent 9939cc9974
commit 6a1fcc6cb2

View File

@ -1756,25 +1756,31 @@ class VectorLayoutInferer {
}
LogicalResult inferMatmul(Operation *op) {
auto get_unpadded_layout =
[&](Value v, std::optional<int64_t> major_multiple = std::nullopt,
auto get_operand_layout =
[&](Value v, llvm::StringRef operand_name,
std::optional<int64_t> major_multiple = std::nullopt,
std::optional<int64_t> minor_multiple =
std::nullopt) -> std::optional<VectorLayout> {
auto pad = getLayout(v);
if (!pad.has_value() || pad->implicit_dim() != ImplicitDim::kNone) {
auto layout = getLayout(v);
if (!layout.has_value()) {
op->emitOpError("Internal error: assert failed: Operand ")
<< operand_name << " has no vector layout";
return std::nullopt;
}
auto vty = cast<VectorType>(v.getType());
auto tiling = nativeTiling(vty.getElementTypeBitWidth());
auto shape = vty.getShape().take_back(2);
if (pad->offsets()[0].value_or(0) != 0 ||
pad->offsets()[1].value_or(0) != 0 ||
shape[0] % major_multiple.value_or(tiling[0]) != 0 ||
if (shape[0] % major_multiple.value_or(tiling[0]) != 0 ||
shape[1] % minor_multiple.value_or(tiling[1]) != 0) {
op->emitOpError("Matmul operand")
<< operand_name << " must have a shape divisible by ("
<< major_multiple.value_or(tiling[0]) << ", "
<< minor_multiple.value_or(tiling[1]) << "), but got: (" << shape[0]
<< ", " << shape[1] << ")";
return std::nullopt;
}
// Override tiling to match the native one.
return VectorLayout(pad->bitwidth(), pad->offsets(), tiling,
return VectorLayout(layout->bitwidth(), {0, 0}, tiling,
ImplicitDim::kNone);
};
auto res_ty = dyn_cast<VectorType>(op->getResult(0).getType());
@ -1796,15 +1802,18 @@ class VectorLayoutInferer {
rhs_major_multiple = 1;
}
in_layout[0] =
get_unpadded_layout(op->getOperand(0), lhs_major_multiple, 1);
get_operand_layout(op->getOperand(0), "lhs", lhs_major_multiple, 1);
if (!in_layout[0].has_value()) {
return failure();
}
in_layout[1] =
get_unpadded_layout(op->getOperand(1), rhs_major_multiple, 1);
in_layout[2] = get_unpadded_layout(op->getOperand(2), 1, 1);
for (Layout &l : in_layout) {
if (!l.has_value()) {
op->emitOpError("unsupported operand shapes or layouts");
return failure();
}
get_operand_layout(op->getOperand(1), "rhs", rhs_major_multiple, 1);
if (!in_layout[1].has_value()) {
return failure();
}
in_layout[2] = get_operand_layout(op->getOperand(2), "result", 1, 1);
if (!in_layout[2].has_value()) {
return failure();
}
setLayout(op, in_layout,
VectorLayout(kNativeBitwidth, {0, 0}, default_tiling_,