[Mosaic] Also check bitwidth in apply-vector-layout's layoutIsValidForValue.

PiperOrigin-RevId: 635595321
This commit is contained in:
Tomás Longeri 2024-05-20 15:56:20 -07:00 committed by jax authors
parent 118ca21b5b
commit b197ae527e

View File

@ -548,8 +548,21 @@ bool layoutIsValidForValue(const Layout &l, const Value v,
const std::array<int64_t, 2> target_shape) {
// l must be non-null iff v is of vector type
if (const auto vty = dyn_cast<VectorType>(v.getType())) {
return l.has_value() && l->isValid(target_shape) &&
l->layout_rank() <= vty.getRank();
if (!l.has_value()) {
return false;
}
// Vector type should have the same bitwidth as the layout, except for the
// i1 special case, used for vmasks (see comment for VectorLayout class).
if (!vty.getElementType().isIntOrFloat()) {
return false;
}
const int8_t bitwidth = vty.getElementTypeBitWidth();
if (bitwidth != l->bitwidth() && bitwidth != 1) {
return false;
}
return l->isValid(target_shape) && l->layout_rank() <= vty.getRank();
}
return !l.has_value();
}