mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic:TPU] In infer ext rule, avoid assigning offsets outside of dst first tile
Note that offsets outside of first tile are still disabled (for both infer and apply), and once we support it we will want to assign offsets differently, this is mostly to avoid assigning invalid layouts (that may not just be outside the first tile, but outside the vreg slice) PiperOrigin-RevId: 709168368
This commit is contained in:
parent
b8091a437a
commit
4452960947
@ -1647,10 +1647,17 @@ class VectorLayoutInferer {
|
||||
Layout dst_layout;
|
||||
if (layout.tiling() == nativeTiling(src_bitwidth)) {
|
||||
// If the source is already in native tiling, we can unpack it directly.
|
||||
src_layout = layout;
|
||||
std::array<int64_t, 2> dst_native_tiling = nativeTiling(dst_bitwidth);
|
||||
LayoutOffsets offsets = {layout.offsets()[0]
|
||||
? *layout.offsets()[0] % dst_native_tiling[0]
|
||||
: LayoutOffset(),
|
||||
layout.offsets()[1]};
|
||||
DCHECK_LT(offsets[1].value_or(0), dst_native_tiling[1]);
|
||||
src_layout = VectorLayout(src_bitwidth, offsets, layout.tiling(),
|
||||
layout.implicit_dim());
|
||||
dst_layout =
|
||||
VectorLayout(dst_bitwidth, layout.offsets(),
|
||||
nativeTiling(dst_bitwidth), layout.implicit_dim());
|
||||
VectorLayout(dst_bitwidth, offsets, dst_native_tiling,
|
||||
layout.implicit_dim());
|
||||
} else if (dst_bitwidth == 32 &&
|
||||
default_tiling_[0] % layout.tiling()[0] == 0 &&
|
||||
default_tiling_[1] == layout.tiling()[1]) {
|
||||
@ -1659,13 +1666,17 @@ class VectorLayoutInferer {
|
||||
// tiling through the op.
|
||||
// TODO(jevinjiang): we can relax this for non-32bit as well.
|
||||
src_layout = layout;
|
||||
dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(),
|
||||
layout.implicit_dim());
|
||||
dst_layout = VectorLayout(dst_bitwidth, layout.offsets(),
|
||||
src_layout->tiling(), layout.implicit_dim());
|
||||
} else {
|
||||
// TODO(b/335863273): we should also reduce offsets.
|
||||
src_layout = VectorLayout(src_bitwidth, layout.offsets(), default_tiling_,
|
||||
LayoutOffsets offsets = {
|
||||
layout.offsets()[0] ? *layout.offsets()[0] % default_tiling_[0]
|
||||
: LayoutOffset(),
|
||||
layout.offsets()[1] ? *layout.offsets()[1] % default_tiling_[1]
|
||||
: LayoutOffset()};
|
||||
src_layout = VectorLayout(src_bitwidth, offsets, default_tiling_,
|
||||
layout.implicit_dim());
|
||||
dst_layout = VectorLayout(dst_bitwidth, layout.offsets(), default_tiling_,
|
||||
dst_layout = VectorLayout(dst_bitwidth, offsets, default_tiling_,
|
||||
layout.implicit_dim());
|
||||
}
|
||||
setLayout(op, src_layout, dst_layout);
|
||||
|
Loading…
x
Reference in New Issue
Block a user