mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic TPU] Normalize inferred layouts to supported ones in matmul rule
Previously the rule would complain if the layouts were unsupported, but that's not the right way to handle that situation. With this change, we simply pick a supported configuration instead (and expect relayout to handle it). PiperOrigin-RevId: 640190248
This commit is contained in:
parent
9939cc9974
commit
6a1fcc6cb2
@ -1756,25 +1756,31 @@ class VectorLayoutInferer {
|
||||
}
|
||||
|
||||
LogicalResult inferMatmul(Operation *op) {
|
||||
auto get_unpadded_layout =
|
||||
[&](Value v, std::optional<int64_t> major_multiple = std::nullopt,
|
||||
auto get_operand_layout =
|
||||
[&](Value v, llvm::StringRef operand_name,
|
||||
std::optional<int64_t> major_multiple = std::nullopt,
|
||||
std::optional<int64_t> minor_multiple =
|
||||
std::nullopt) -> std::optional<VectorLayout> {
|
||||
auto pad = getLayout(v);
|
||||
if (!pad.has_value() || pad->implicit_dim() != ImplicitDim::kNone) {
|
||||
auto layout = getLayout(v);
|
||||
if (!layout.has_value()) {
|
||||
op->emitOpError("Internal error: assert failed: Operand ")
|
||||
<< operand_name << " has no vector layout";
|
||||
return std::nullopt;
|
||||
}
|
||||
auto vty = cast<VectorType>(v.getType());
|
||||
auto tiling = nativeTiling(vty.getElementTypeBitWidth());
|
||||
auto shape = vty.getShape().take_back(2);
|
||||
if (pad->offsets()[0].value_or(0) != 0 ||
|
||||
pad->offsets()[1].value_or(0) != 0 ||
|
||||
shape[0] % major_multiple.value_or(tiling[0]) != 0 ||
|
||||
if (shape[0] % major_multiple.value_or(tiling[0]) != 0 ||
|
||||
shape[1] % minor_multiple.value_or(tiling[1]) != 0) {
|
||||
op->emitOpError("Matmul operand")
|
||||
<< operand_name << " must have a shape divisible by ("
|
||||
<< major_multiple.value_or(tiling[0]) << ", "
|
||||
<< minor_multiple.value_or(tiling[1]) << "), but got: (" << shape[0]
|
||||
<< ", " << shape[1] << ")";
|
||||
return std::nullopt;
|
||||
}
|
||||
// Override tiling to match the native one.
|
||||
return VectorLayout(pad->bitwidth(), pad->offsets(), tiling,
|
||||
return VectorLayout(layout->bitwidth(), {0, 0}, tiling,
|
||||
ImplicitDim::kNone);
|
||||
};
|
||||
auto res_ty = dyn_cast<VectorType>(op->getResult(0).getType());
|
||||
@ -1796,15 +1802,18 @@ class VectorLayoutInferer {
|
||||
rhs_major_multiple = 1;
|
||||
}
|
||||
in_layout[0] =
|
||||
get_unpadded_layout(op->getOperand(0), lhs_major_multiple, 1);
|
||||
get_operand_layout(op->getOperand(0), "lhs", lhs_major_multiple, 1);
|
||||
if (!in_layout[0].has_value()) {
|
||||
return failure();
|
||||
}
|
||||
in_layout[1] =
|
||||
get_unpadded_layout(op->getOperand(1), rhs_major_multiple, 1);
|
||||
in_layout[2] = get_unpadded_layout(op->getOperand(2), 1, 1);
|
||||
for (Layout &l : in_layout) {
|
||||
if (!l.has_value()) {
|
||||
op->emitOpError("unsupported operand shapes or layouts");
|
||||
return failure();
|
||||
}
|
||||
get_operand_layout(op->getOperand(1), "rhs", rhs_major_multiple, 1);
|
||||
if (!in_layout[1].has_value()) {
|
||||
return failure();
|
||||
}
|
||||
in_layout[2] = get_operand_layout(op->getOperand(2), "result", 1, 1);
|
||||
if (!in_layout[2].has_value()) {
|
||||
return failure();
|
||||
}
|
||||
setLayout(op, in_layout,
|
||||
VectorLayout(kNativeBitwidth, {0, 0}, default_tiling_,
|
||||
|
Loading…
x
Reference in New Issue
Block a user