[Mosaic:TPU] In infer ext rule, avoid assigning offsets outside of dst first tile

Note that offsets outside of first tile are still disabled (for both infer and apply), and once we support it we will want to assign offsets differently, this is mostly to avoid assigning invalid layouts (that may not just be outside the first tile, but outside the vreg slice)

PiperOrigin-RevId: 709168368
This commit is contained in:
Tomás Longeri 2024-12-23 15:47:29 -08:00 committed by jax authors
parent b8091a437a
commit 4452960947

View File

@ -1647,10 +1647,17 @@ class VectorLayoutInferer {
Layout dst_layout;
if (layout.tiling() == nativeTiling(src_bitwidth)) {
// If the source is already in native tiling, we can unpack it directly.
src_layout = layout;
std::array<int64_t, 2> dst_native_tiling = nativeTiling(dst_bitwidth);
LayoutOffsets offsets = {layout.offsets()[0]
? *layout.offsets()[0] % dst_native_tiling[0]
: LayoutOffset(),
layout.offsets()[1]};
DCHECK_LT(offsets[1].value_or(0), dst_native_tiling[1]);
src_layout = VectorLayout(src_bitwidth, offsets, layout.tiling(),
layout.implicit_dim());
dst_layout =
VectorLayout(dst_bitwidth, layout.offsets(),
nativeTiling(dst_bitwidth), layout.implicit_dim());
VectorLayout(dst_bitwidth, offsets, dst_native_tiling,
layout.implicit_dim());
} else if (dst_bitwidth == 32 &&
default_tiling_[0] % layout.tiling()[0] == 0 &&
default_tiling_[1] == layout.tiling()[1]) {
@ -1659,13 +1666,17 @@ class VectorLayoutInferer {
// tiling through the op.
// TODO(jevinjiang): we can relax this for non-32bit as well.
src_layout = layout;
dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(),
layout.implicit_dim());
dst_layout = VectorLayout(dst_bitwidth, layout.offsets(),
src_layout->tiling(), layout.implicit_dim());
} else {
// TODO(b/335863273): we should also reduce offsets.
src_layout = VectorLayout(src_bitwidth, layout.offsets(), default_tiling_,
LayoutOffsets offsets = {
layout.offsets()[0] ? *layout.offsets()[0] % default_tiling_[0]
: LayoutOffset(),
layout.offsets()[1] ? *layout.offsets()[1] % default_tiling_[1]
: LayoutOffset()};
src_layout = VectorLayout(src_bitwidth, offsets, default_tiling_,
layout.implicit_dim());
dst_layout = VectorLayout(dst_bitwidth, layout.offsets(), default_tiling_,
dst_layout = VectorLayout(dst_bitwidth, offsets, default_tiling_,
layout.implicit_dim());
}
setLayout(op, src_layout, dst_layout);