mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic] Also check bitwidth in apply-vector-layout's layoutIsValidForValue
.
PiperOrigin-RevId: 635595321
This commit is contained in:
parent
118ca21b5b
commit
b197ae527e
@ -548,8 +548,21 @@ bool layoutIsValidForValue(const Layout &l, const Value v,
|
||||
const std::array<int64_t, 2> target_shape) {
|
||||
// l must be non-null iff v is of vector type
|
||||
if (const auto vty = dyn_cast<VectorType>(v.getType())) {
|
||||
return l.has_value() && l->isValid(target_shape) &&
|
||||
l->layout_rank() <= vty.getRank();
|
||||
if (!l.has_value()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Vector type should have the same bitwidth as the layout, except for the
|
||||
// i1 special case, used for vmasks (see comment for VectorLayout class).
|
||||
if (!vty.getElementType().isIntOrFloat()) {
|
||||
return false;
|
||||
}
|
||||
const int8_t bitwidth = vty.getElementTypeBitWidth();
|
||||
if (bitwidth != l->bitwidth() && bitwidth != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return l->isValid(target_shape) && l->layout_rank() <= vty.getRank();
|
||||
}
|
||||
return !l.has_value();
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user