mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic] Support relayout from (1,128) to (8,128) when dst.offset is (0, 0).
PiperOrigin-RevId: 564882618
This commit is contained in:
parent
d3950b93cb
commit
d4b564a263
@ -1214,7 +1214,7 @@ def relayout(
|
||||
and dst.implicit_dim is None
|
||||
and src.bitwidth == 32
|
||||
and src.offsets == (0, 0)
|
||||
and dst.offsets == (REPLICATED, 0)
|
||||
and (dst.offsets == (REPLICATED, 0) or dst.offsets == (0, 0))
|
||||
and src.tiling == (1, 128)
|
||||
and dst.tiling == (8, 128)
|
||||
and src_tiles.shape[-2] == 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user