[Mosaic] Support relayout from (1,128) to (8,128) when dst.offset is (0, 0).

PiperOrigin-RevId: 564882618
This commit is contained in:
Jevin Jiang 2023-09-12 17:33:47 -07:00 committed by jax authors
parent d3950b93cb
commit d4b564a263

View File

@ -1214,7 +1214,7 @@ def relayout(
and dst.implicit_dim is None
and src.bitwidth == 32
and src.offsets == (0, 0)
and dst.offsets == (REPLICATED, 0)
and (dst.offsets == (REPLICATED, 0) or dst.offsets == (0, 0))
and src.tiling == (1, 128)
and dst.tiling == (8, 128)
and src_tiles.shape[-2] == 1