[Mosaic] Support relayout from (1, 128) to (8, 128).

PiperOrigin-RevId: 563534657
This commit is contained in:
Jevin Jiang 2023-09-07 13:49:05 -07:00 committed by jax authors
parent a9410e5547
commit 8b700fa75d
2 changed files with 39 additions and 0 deletions

View File

@ -659,6 +659,18 @@ class VectorLayoutInferer {
TPU_CHECK_OP(res_ty.getRank() >= 2, "result rank below 2D unsupported");
auto some_layout = getLayout(op.getSource());
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
// We want to force the layout to be (8, 128) instead of (1, 128) if we
// are broadcasting sublane dim from 1 to at least 8.
//
// The override below only fires when all of the following hold:
//   * the source layout has the native bitwidth and no implicit dimension,
//   * its tiling is (1, default_tiling_[1]),
//   * the source's second-to-last dimension has size 1, and
//   * the result's second-to-last dimension has size >= 8.
// NOTE(review): the conditions are evaluated left to right with
// short-circuiting, so the getRank() - 2 accesses are only reached once the
// layout checks pass; this presumably relies on src_ty having rank >= 2 by
// then -- confirm against the checks earlier in this function.
if (some_layout->bitwidth() == kNativeBitwidth &&
some_layout->implicit_dim() == ImplicitDim::kNone &&
some_layout->tiling()[0] == 1 &&
some_layout->tiling()[1] == default_tiling_[1] &&
src_ty.getDimSize(src_ty.getRank() - 2) == 1 &&
res_ty.getDimSize(res_ty.getRank() - 2) >= 8) {
// Replace the layout in place: bitwidth, implicit dim and the second (minor)
// offset are carried over, the tiling becomes default_tiling_, and the first
// offset is set to std::nullopt (presumably "replicated" along that
// dimension -- TODO confirm against VectorLayout's offset convention).
*some_layout = VectorLayout(
some_layout->bitwidth(), {std::nullopt, some_layout->offsets()[1]},
default_tiling_, some_layout->implicit_dim());
}
auto &layout = *some_layout;
if (layout.implicit_dim() == ImplicitDim::kSecondMinor &&
src_ty.getDimSize(src_ty.getRank() - 2) == 1) {

View File

@ -892,6 +892,33 @@ def relayout(
if dst.implicit_dim is None and vty.shape[src.implicit_dim] == 1:
src = VectorLayout(src.bitwidth, src.offsets, src.tiling, None)
# Handle retiling from (1, 128) to (8, 128) for 32-bit data.
# This path only applies when neither layout has an implicit dimension, the
# source starts at offsets (0, 0), the destination is replicated along the
# first tiled dimension with a zero second offset, and the source tile array
# has extent 1 along its second-to-last axis.
if (
src.implicit_dim is None
and dst.implicit_dim is None
and src.bitwidth == 32
and src.offsets == (0, 0)
and dst.offsets == (REPLICATED, 0)
and src.tiling == (1, 128)
and dst.tiling == (8, 128)
and src_tiles.shape[-2] == 1
):
# Object array that will hold the retiled values, shaped like the tile array
# that the destination layout prescribes for this vector shape.
src_tiles_retiled = np.empty(
dst.tile_array_shape(vty.shape), dtype=object
)
# Visit every (batch..., last-axis) position of the destination array; the
# second-to-last axis is skipped here and filled wholesale below.
for *batch_idx, dst_col in np.ndindex(
src_tiles_retiled.shape[:-2] + src_tiles_retiled.shape[-1:]
):
# Eight consecutive destination columns map onto one source tile:
# src_col picks the source tile and slane_idx the position inside it.
# NOTE(review): presumably a (1, 128)-tiled source tile packs eight
# 128-element chunks, one per sublane -- confirm against the layout docs.
src_col = dst_col // 8
slane_idx = dst_col % 8
# Eight copies of the same index, so every output position of the gather
# reads from the same source position.
gather_indices = ir.DenseI32ArrayAttr.get([slane_idx] * 8)
src_tile = src_tiles[(*batch_idx, 0, src_col)]
# The same gather is stored at every index of the second-to-last axis for
# this destination column. The trailing 0 argument is presumably the
# dimension the gather operates along -- TODO confirm with tpu.GatherOp.
src_tiles_retiled[(*batch_idx, slice(None), dst_col)] = tpu.GatherOp(
src_tile.type, src_tile, gather_indices, 0
)
# From here on the source is treated as already being in the destination
# layout.
src = dst
src_tiles = src_tiles_retiled
# TODO(apaszke): Generalize retiling to general 16-bit types (might need to
# use a different unpacking op).
# (8,128) -> (16,128) tiling change for packed 16-bit types.