mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic] Support relayout from (1, 128) to (8, 128).
PiperOrigin-RevId: 563534657
This commit is contained in:
parent
a9410e5547
commit
8b700fa75d
@ -659,6 +659,18 @@ class VectorLayoutInferer {
|
||||
TPU_CHECK_OP(res_ty.getRank() >= 2, "result rank below 2D unsupported");
|
||||
auto some_layout = getLayout(op.getSource());
|
||||
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
|
||||
// We want to force the layout to be (8, 128) instead of (1, 128) if we
|
||||
// are broadcasting sublane dim from 1 to at least 8.
|
||||
if (some_layout->bitwidth() == kNativeBitwidth &&
|
||||
some_layout->implicit_dim() == ImplicitDim::kNone &&
|
||||
some_layout->tiling()[0] == 1 &&
|
||||
some_layout->tiling()[1] == default_tiling_[1] &&
|
||||
src_ty.getDimSize(src_ty.getRank() - 2) == 1 &&
|
||||
res_ty.getDimSize(res_ty.getRank() - 2) >= 8) {
|
||||
*some_layout = VectorLayout(
|
||||
some_layout->bitwidth(), {std::nullopt, some_layout->offsets()[1]},
|
||||
default_tiling_, some_layout->implicit_dim());
|
||||
}
|
||||
auto &layout = *some_layout;
|
||||
if (layout.implicit_dim() == ImplicitDim::kSecondMinor &&
|
||||
src_ty.getDimSize(src_ty.getRank() - 2) == 1) {
|
||||
|
@ -892,6 +892,33 @@ def relayout(
|
||||
if dst.implicit_dim is None and vty.shape[src.implicit_dim] == 1:
|
||||
src = VectorLayout(src.bitwidth, src.offsets, src.tiling, None)
|
||||
|
||||
# Handle retiling from (1, 128) to (8, 128) for 32-bit data.
|
||||
if (
|
||||
src.implicit_dim is None
|
||||
and dst.implicit_dim is None
|
||||
and src.bitwidth == 32
|
||||
and src.offsets == (0, 0)
|
||||
and dst.offsets == (REPLICATED, 0)
|
||||
and src.tiling == (1, 128)
|
||||
and dst.tiling == (8, 128)
|
||||
and src_tiles.shape[-2] == 1
|
||||
):
|
||||
src_tiles_retiled = np.empty(
|
||||
dst.tile_array_shape(vty.shape), dtype=object
|
||||
)
|
||||
for *batch_idx, dst_col in np.ndindex(
|
||||
src_tiles_retiled.shape[:-2] + src_tiles_retiled.shape[-1:]
|
||||
):
|
||||
src_col = dst_col // 8
|
||||
slane_idx = dst_col % 8
|
||||
gather_indices = ir.DenseI32ArrayAttr.get([slane_idx] * 8)
|
||||
src_tile = src_tiles[(*batch_idx, 0, src_col)]
|
||||
src_tiles_retiled[(*batch_idx, slice(None), dst_col)] = tpu.GatherOp(
|
||||
src_tile.type, src_tile, gather_indices, 0
|
||||
)
|
||||
src = dst
|
||||
src_tiles = src_tiles_retiled
|
||||
|
||||
# TODO(apaszke): Generalize retiling to general 16-bit types (might need to
|
||||
# use a different unpacking op).
|
||||
# (8,128) -> (16,128) tiling change for packed 16-bit types.
|
||||
|
Loading…
x
Reference in New Issue
Block a user