Support relayout of tiles in register when the layout tiling changes.

PiperOrigin-RevId: 563570338
jax authors 2023-09-07 16:00:01 -07:00
parent a6eed40f24
commit bfd79b84e4


@@ -854,6 +854,322 @@ def print_layout(layout: Layout) -> str:
  return f'#tpu.vpad<"{layout.bitwidth},{{{o[0]},{o[1]}}},({layout.tiling[0]},{layout.tiling[1]}){implicit_dim}">'
def select_tiles_from_rotated_row_vregs(
    rotated_row_vregs: np.ndarray,
    start_src_col: int,
    end_src_col: int,
    first_dst_tile_sublane_offset: int,
    dst_layout: VectorLayout,
    hw_generation: int,
) -> ValueLike:
  """Assembles a destination vreg from tiles in the given rotated vregs, using a divide-and-conquer strategy.

  Arguments:
    rotated_row_vregs: A row of rotated vregs, from which destination tile(s)
      is/are to be selected to assemble a new vreg.
    start_src_col: The first rotated vreg in the row of rotated vregs to
      process.
    end_src_col: The last rotated vreg in the row of rotated vregs to process.
    first_dst_tile_sublane_offset: Sublane offset at which the first dst tile
      to be selected starts.
    dst_layout: Destination layout, based on which the retiling is performed.
    hw_generation: The generation of the target hardware.

  Returns:
    A new vreg assembled from dst tiles stored in the given rotated vregs.
  """
  if start_src_col > end_src_col:
    raise ValueError("Invalid values for start and end column.")
  if start_src_col == end_src_col:
    return rotated_row_vregs[start_src_col]

  mid_src_col = start_src_col + (end_src_col - start_src_col) // 2

  left_partial_vreg = select_tiles_from_rotated_row_vregs(
      rotated_row_vregs,
      start_src_col,
      mid_src_col,
      first_dst_tile_sublane_offset,
      dst_layout,
      hw_generation,
  )

  left_tiles_count = mid_src_col - start_src_col + 1
  right_first_dst_tile_sublane_offset = (
      first_dst_tile_sublane_offset
      + left_tiles_count * dst_layout.sublanes_per_tile
  ) % TARGET_SHAPE.sublanes

  right_partial_vreg = select_tiles_from_rotated_row_vregs(
      rotated_row_vregs,
      mid_src_col + 1,
      end_src_col,
      right_first_dst_tile_sublane_offset,
      dst_layout,
      hw_generation,
  )

  i1 = ir.IntegerType.get_signless(1)
  mask_vreg_ty = (
      ir.VectorType.get((*TARGET_SHAPE, 2), i1)
      if dst_layout.packing == 2
      else ir.VectorType.get(TARGET_SHAPE, i1)
  )
  if first_dst_tile_sublane_offset < right_first_dst_tile_sublane_offset:
    # The useful data sublanes in the left vregs do not wrap around in the
    # vreg. For example, consider a (2, 128) destination tiling where we are
    # trying to merge two vregs as follows:
    #
    #   vreg 0:        vreg 1:
    #   x x x x x      dst_tile_2
    #   x x x x x      dst_tile_3
    #   dst_tile_4     x x x x x
    #   dst_tile_5     x x x x x
    #   dst_tile_6     x x x x x
    #   dst_tile_7     x x x x x
    #   x x x x x      dst_tile_0
    #   x x x x x      dst_tile_1
    #
    # In this case, the data we want to select from vreg 1 wraps around,
    # whereas the useful data in vreg 0 is contiguous. It is easier to create
    # the '1' mask for vreg 0.
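    # The mask below is set over sublanes [first_dst_tile_sublane_offset,
    # right_first_dst_tile_sublane_offset) across all lanes, which is exactly
    # where the left partial vreg holds its dst tiles, so the select takes the
    # left vreg in that region and the right vreg everywhere else.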
    sublanes_mask = tpu.CreateMaskOp(
        mask_vreg_ty,
        map(ix_cst, [first_dst_tile_sublane_offset, 0]),
        map(ix_cst, [right_first_dst_tile_sublane_offset, TARGET_SHAPE.lanes]),
    )
    return arith.SelectOp(sublanes_mask, left_partial_vreg, right_partial_vreg)

  sublanes_mask = tpu.CreateMaskOp(
      mask_vreg_ty,
      map(ix_cst, [right_first_dst_tile_sublane_offset, 0]),
      map(ix_cst, [first_dst_tile_sublane_offset, TARGET_SHAPE.lanes]),
  )
  return arith.SelectOp(sublanes_mask, right_partial_vreg, left_partial_vreg)
def is_positive_pow_2(x: int) -> bool:
  """Returns true iff the integer argument is positive and a power of 2.

  Arguments:
    x: an integer argument.
  """
  return (x > 0) and (x & (x - 1)) == 0
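# Note: x & (x - 1) clears the lowest set bit of x, so it is zero exactly when
# x has a single bit set, e.g. is_positive_pow_2(8) is True (8 & 7 == 0) while
# is_positive_pow_2(6) is False (6 & 5 == 4).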
def retile_to_reduced_sublanes(
    value_shape: tuple[int, ...],
    src_layout: VectorLayout,
    src_vreg_array: np.ndarray,
    dst_layout: VectorLayout,
    hw_generation: int,
) -> np.ndarray:
  """Retiles across vregs to match the destination layout when the sublane tiling dimension is reduced.

  Arguments:
    value_shape: The shape of the value that needs to be retiled in vregs.
    src_layout: The source layout.
    src_vreg_array: An array of vregs storing source tiles.
    dst_layout: The destination layout, with a reduced sublane dimension, based
      on which the retiling will be performed.
    hw_generation: The generation of the target hardware.

  Returns:
    A new array of vregs that stores tiles based on the destination layout.
  """
  dst_tiling_sublane = dst_layout.tiling[-2]
  assert (
      dst_tiling_sublane > 0
      and dst_tiling_sublane < src_layout.tiling[-2]
      and is_positive_pow_2(dst_tiling_sublane)
  )
  assert src_layout.tiling[-1] == dst_layout.tiling[-1]

  dst_vreg_array = np.empty(
      dst_layout.tile_array_shape(value_shape), dtype=object
  )
  # We need to rotate each src tile in each src vreg once so that they can
  # be merged to form new vregs. If a src vreg contains more than one src tile,
  # it will be rotated once per src tile. Consider an (8, 512) tensor stored
  # with layout (8, 128) in a vreg array of shape (1, 4). Each src vreg
  # contains one src tile in this case. Given that the destination layout is
  # (2, 128), each src tile is divided into 4 destination tiles as shown below:
  #
  #  src_vreg_0_0:    src_vreg_0_1:    src_vreg_0_2:    src_vreg_0_3:
  #  dst_tile_0_0_0   dst_tile_0_0_1   dst_tile_0_0_2   dst_tile_0_0_3
  #  dst_tile_1_0_0   dst_tile_1_0_1   dst_tile_1_0_2   dst_tile_1_0_3
  #  dst_tile_2_0_0   dst_tile_2_0_1   dst_tile_2_0_2   dst_tile_2_0_3
  #  dst_tile_3_0_0   dst_tile_3_0_1   dst_tile_3_0_2   dst_tile_3_0_3
  #
  # In this example, each src tile in the src vreg is rotated by
  # col * sublanes_per_tile to produce the following rotated src vregs:
  #
  #  rot_src_vreg_0_0: rot_src_vreg_0_1: rot_src_vreg_0_2: rot_src_vreg_0_3:
  #  dst_tile_0_0_0    dst_tile_3_0_1    dst_tile_2_0_2    dst_tile_1_0_3
  #  dst_tile_1_0_0    dst_tile_0_0_1    dst_tile_3_0_2    dst_tile_2_0_3
  #  dst_tile_2_0_0    dst_tile_1_0_1    dst_tile_0_0_2    dst_tile_3_0_3
  #  dst_tile_3_0_0    dst_tile_2_0_1    dst_tile_1_0_2    dst_tile_0_0_3
  #
  # If there were 2 src tiles in a src vreg, we would have rotated each src
  # vreg twice, producing 2 rotated src vregs per src vreg. The rotation amount
  # is calculated from the src and the dst tilings.
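  # For the example above: src_layout.tiles_per_vreg == 1,
  # dst_layout.tiles_per_vreg == 4 and dst_layout.sublanes_per_tile == 2, so
  # the rotated vreg at idx == 3 is rotated by dst_sublane - src_sublane
  # = 3 * 2 - 0 = 6 sublanes, placing dst_tile_0_0_3 at sublane offset 6.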
  rotated_src_vregs_array = np.empty(
      (
          *(src_vreg_array.shape[:-1]),
          # Each vreg may store more than one src tile. We may have to rotate a
          # vreg once for every src tile in the vreg.
          src_vreg_array.shape[-1] * src_layout.tiles_per_vreg,
      ),
      dtype=object,
  )
  for *other_dims, row, idx in np.ndindex(rotated_src_vregs_array.shape):
    tile_idx = idx % dst_layout.tiles_per_vreg
    dst_sublane = tile_idx * dst_layout.sublanes_per_tile
    src_col, src_tile_offset = divmod(idx, src_layout.tiles_per_vreg)
    src_vreg = src_vreg_array[(*other_dims, row, src_col)]
    src_sublane = src_tile_offset * src_layout.sublanes_per_tile
    rotate_amt = dst_sublane - src_sublane
    if rotate_amt == 0:
      rotated_src_vregs_array[(*other_dims, row, idx)] = src_vreg
      continue
    if rotate_amt < 0:
      rotate_amt = TARGET_SHAPE.sublanes + rotate_amt
    rotated_src_vregs_array[(*other_dims, row, idx)] = tpu.RotateOp(
        src_vreg, amount=rotate_amt, dimension=0
    )
  # Assemble the output vregs using tiles from the rotated vregs, via selects.
  # In the above example, the destination vregs are then assembled as follows:
  #
  #  dst_vreg_0_0:
  #  dst_tile_0_0_0
  #  dst_tile_0_0_1
  #  dst_tile_0_0_2
  #  dst_tile_0_0_3
  #
  #  dst_vreg_1_0: (Notice the dst tiles are not at the correct offset!)
  #  dst_tile_1_0_3
  #  dst_tile_1_0_0
  #  dst_tile_1_0_1
  #  dst_tile_1_0_2
  #
  #  dst_vreg_2_0: (Notice the dst tiles are not at the correct offset!)
  #  dst_tile_2_0_2
  #  dst_tile_2_0_3
  #  dst_tile_2_0_0
  #  dst_tile_2_0_1
  #
  #  dst_vreg_3_0: (Notice the dst tiles are not at the correct offset!)
  #  dst_tile_3_0_1
  #  dst_tile_3_0_2
  #  dst_tile_3_0_3
  #  dst_tile_3_0_0
  #
  # Each destination vreg is assembled from destination tiles in multiple
  # rotated src vregs. In the above example, if we wanted each destination tile
  # to be at the correct sublane offset in a rotated vreg, say rot_src_vreg_0_1,
  # before assembling the destination tiles, we would have had to rotate
  # src_vreg_0_1 four times, creating 4 rotated vregs (instead of 1) for each
  # src vreg. Instead, we rotated src_vreg_0_1 only once, to obtain
  # rot_src_vreg_0_1, in which dst_tile_0_0_1 is already at its correct final
  # sublane offset, i.e. 2. But notice the sublane offset of dst_tile_1_0_1 in
  # the same rotated vreg: its correct final destination sublane offset is 2,
  # but in rot_src_vreg_0_1 its offset is 4, i.e. it is off by 2. We need to
  # correct these sublane offsets in the final assembled dst vregs. A single
  # rotation of each assembled dst vreg is enough to correct such sublane
  # offsets. This strategy reduces the number of sublane rotations required.
  # See the comments below.
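  # Continuing the example: the assembled dst_vreg_2_0 (row == 2, col == 0)
  # reads from rotated_vreg_row == 0 with first_dst_tile_sublane_offset == 4,
  # i.e. dst_tile_2_0_0 ends up at sublane 4. The final rotation below by
  # TARGET_SHAPE.sublanes - 4 == 4 sublanes moves it back to sublane 0.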
  tile_sublane_change_factor = src_layout.tiling[-2] // dst_layout.tiling[-2]
  for *other_dims, row, col in np.ndindex(dst_vreg_array.shape):
    rotated_vreg_row, first_dst_tile_offset = divmod(
        row, tile_sublane_change_factor
    )
    first_dst_tile_sublane_offset = (
        first_dst_tile_offset * dst_layout.sublanes_per_tile
    )
    src_vreg_array_col_start = col * dst_layout.tiles_per_vreg
    src_vreg_array_col_end = (
        min(
            (col + 1) * dst_layout.tiles_per_vreg,
            rotated_src_vregs_array.shape[-1],
        )
        - 1
    )
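    # The end column is clamped so that the last dst vreg of a row, which may
    # draw on fewer than tiles_per_vreg rotated vregs, does not index past the
    # end of rotated_src_vregs_array.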
    dst_tile = select_tiles_from_rotated_row_vregs(
        rotated_row_vregs=rotated_src_vregs_array[
            (*other_dims, rotated_vreg_row, slice(None))
        ],
        start_src_col=src_vreg_array_col_start,
        end_src_col=src_vreg_array_col_end,
        first_dst_tile_sublane_offset=first_dst_tile_sublane_offset,
        dst_layout=dst_layout,
        hw_generation=hw_generation,
    )
    if first_dst_tile_sublane_offset == 0:
      # No need to rotate. The first dst tile is already at offset 0, which
      # means the rest of the dst tiles are also at the correct sublane
      # offsets.
      dst_vreg_array[(*other_dims, row, col)] = dst_tile
    else:
      # Fix the destination tile sublane offsets by rotating the assembled dst
      # vreg once (see the comment above). The dst vregs are fixed as follows:
      #
      # No rotation needed:
      #  dst_tile_0_0_0
      #  dst_tile_0_0_1
      #  dst_tile_0_0_2
      #  dst_tile_0_0_3
      #
      # Rotated by -1 * (sublanes_per_tile=2) * (row=1):
      #  dst_tile_1_0_0
      #  dst_tile_1_0_1
      #  dst_tile_1_0_2
      #  dst_tile_1_0_3
      #
      # Rotated by -1 * (sublanes_per_tile=2) * (row=2):
      #  dst_tile_2_0_0
      #  dst_tile_2_0_1
      #  dst_tile_2_0_2
      #  dst_tile_2_0_3
      #
      # Rotated by -1 * (sublanes_per_tile=2) * (row=3):
      #  dst_tile_3_0_0
      #  dst_tile_3_0_1
      #  dst_tile_3_0_2
      #  dst_tile_3_0_3
      dst_vreg_array[(*other_dims, row, col)] = tpu.RotateOp(
          dst_tile,
          amount=TARGET_SHAPE.sublanes - first_dst_tile_sublane_offset,
          dimension=0,
      )
  return dst_vreg_array
def is_supported_reduced_sublanes_retile(
    src_layout: VectorLayout, dst_layout: VectorLayout
) -> bool:
  """Returns true iff the layout change is a supported reduction of sublanes per tile.

  Arguments:
    src_layout: The existing layout.
    dst_layout: The new layout based on which the retiling is to be carried
      out.
  """
  return (
      src_layout.implicit_dim is None
      and dst_layout.implicit_dim is None
      and all(
          (os or 0) == (ot or 0)
          for os, ot in zip(src_layout.offsets, dst_layout.offsets)
      )
      # TODO(kumudbhandari): We have not tested any tile size where
      # tile[-1] != TARGET_SHAPE.lanes. It should work but needs to be tested.
      and src_layout.tiling[-1] == dst_layout.tiling[-1] == TARGET_SHAPE.lanes
      and dst_layout.tiling[-2] < src_layout.tiling[-2]
      and src_layout.bitwidth == dst_layout.bitwidth
      and is_positive_pow_2(src_layout.tiling[-2])
      and is_positive_pow_2(dst_layout.tiling[-2])
  )
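# With matching offsets and bitwidth, this predicate covers sublane-reducing
# retilings such as (8, 128) -> (4, 128), (8, 128) -> (2, 128) and
# (8, 128) -> (1, 128), which is what allows the layout-specific retiling
# cases in relayout below to be handled by the single
# retile_to_reduced_sublanes path.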
# TODO(apaszke): Test this function properly
def relayout(
    v: ir.Value, src: VectorLayout, dst: VectorLayout, hw_generation: int
@@ -947,136 +1263,16 @@ def relayout(
      )
    src = new_src
    src_tiles = src_tiles_retiled
  # (16,128) -> (8,128) tiling change for packed 16-bit types.
  if (
      src.implicit_dim is None
      and dst.implicit_dim is None
      and src.offsets == dst.offsets
      and ir.BF16Type.isinstance(vty.element_type)
      and src.tiling == (16, 128)
      and dst.tiling == (8, 128)
  ):
    new_src = VectorLayout(src.bitwidth, src.offsets, dst.tiling, None)
    src_tiles_retiled = np.empty(
        new_src.tile_array_shape(vty.shape), dtype=object)
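    # Each (8, 128) dst vreg is built from two neighboring (16, 128) src vregs
    # in the same row: the (dst_row % 2) half of the subelements of each is
    # unpacked to an f32 vreg, and the two halves are re-packed into a single
    # 16-bit vreg.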
    for (*batch_idx, dst_row, dst_col) in np.ndindex(src_tiles_retiled.shape):
      src_row1 = src_tiles[(*batch_idx, dst_row // 2, dst_col * 2)]
      src_row2_col = min(dst_col * 2 + 1, src_tiles.shape[-1] - 1)
      src_row2 = src_tiles[(*batch_idx, dst_row // 2, src_row2_col)]
      vreg_part = dst_row % 2
      vreg_f32 = ir.VectorType.get(TARGET_SHAPE, ir.F32Type.get())
      half_row1 = tpu.UnpackSubelementsOp(vreg_f32, src_row1, vreg_part)
      half_row2 = tpu.UnpackSubelementsOp(vreg_f32, src_row2, vreg_part)
      src_tiles_retiled[(*batch_idx, dst_row, dst_col)] = tpu.PackSubelementsOp(
          src_row1.type, [half_row1, half_row2]
      )
    src = new_src
    src_tiles = src_tiles_retiled
  # Handle retiling from (8, 128) to (1, 128) for 32-bit data.
  if (
      src.implicit_dim is None
      and dst.implicit_dim is None
      and type_bitwidth(vty.element_type) == 32
      and all((o or 0) == 0 for o in src.offsets)
      and dst.offsets == (0, 0)
      and src.tiling == (8, 128)
      and dst.tiling == (1, 128)
  ):
    new_src = VectorLayout(src.bitwidth, dst.offsets, dst.tiling, None)
    src_tiles_retiled = np.empty(
        new_src.tile_array_shape(vty.shape), dtype=object
    )
  if is_supported_reduced_sublanes_retile(src, dst):
    src_tiles = retile_to_reduced_sublanes(
        value_shape=vty.shape,
        src_layout=src,
        src_vreg_array=src_tiles,
        dst_layout=dst,
        hw_generation=hw_generation,
    )
    for *batch_idx, dst_row, dst_col in np.ndindex(src_tiles_retiled.shape):
      src_row = dst_row // 8
      src_col_0 = dst_col * 8
      src_tile_vreg_count = (
          8 if src_col_0 + 8 <= src_tiles.shape[-1] else src_tiles.shape[-1] % 8
      )
      src_tiles_retiled[(*batch_idx, dst_row, dst_col)] = src_tiles[
          (*batch_idx, src_row, src_col_0)
      ]
      for i in range(1, src_tile_vreg_count):
        bounds = RectangularVRegBounds(
            TargetTuple(slice(i, i + 1), slice(0, TARGET_SHAPE.lanes))
        )
        mask = bounds.get_vector_mask(hw_generation)
        src_tile = src_tiles[(*batch_idx, src_row, src_col_0 + i)]
        src_tile = tpu.RotateOp(src_tile, amount=i, dimension=0)
        src_tiles_retiled[(*batch_idx, dst_row, dst_col)] = arith.SelectOp(
            mask, src_tile, src_tiles_retiled[(*batch_idx, dst_row, dst_col)]
        )
    src = new_src
    src_tiles = src_tiles_retiled
  # TODO(kumudbhandari): Generalize the logic below to handle retiling from
  # (8,128) to (x, 128) where x=1, 2, or 4.
  if (
      src.implicit_dim is None
      and dst.implicit_dim is None
      and src.offsets == dst.offsets
      and src.bitwidth != 16
      and src.tiling == (8, 128)
      and dst.tiling == (4, 128)
  ):
    retiled_src_layout = VectorLayout(
        src.bitwidth, src.offsets, dst.tiling, None
    )
    retiled_src_tiles = np.empty(
        retiled_src_layout.tile_array_shape(vty.shape), dtype=object
    )
    # Consider a value of type and shape f32(8, 256), retiled from (8, 128) to
    # (4, 128):
    #
    # vreg (tile) array shape (1, 2), with the original (8, 128) tiling:
    #   vreg_0_0: slice [0:7, 0:127]    vreg_0_1: slice [0:7, 128:255]
    #
    # vreg (tile) array shape (2, 1), with the (4, 128) retiling:
    #   vreg_0_0: slice [0:3, 0:127], slice [0:3, 128:255]
    #   vreg_1_0: slice [4:7, 0:127], slice [4:7, 128:255]
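    # For instance, retiled vreg_1_0 (retiled_row_idx == 1) takes the bottom
    # half of original vreg_0_0 (rotated by 4 sublanes into sublanes 0-3) and
    # the bottom half of vreg_0_1 (kept in place at sublanes 4-7).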
    for *other_indices, retiled_row_idx, retiled_col_idx in np.ndindex(
        retiled_src_tiles.shape
    ):
      # The first src tile, half of which forms the first half of the retiled
      # tile (retiled_row_idx, retiled_col_idx).
      src_tile_row_idx = retiled_row_idx // 2
      src_tile_col_idx_1 = retiled_col_idx * 2
      src_tile_1 = src_tiles[
          (*other_indices, src_tile_row_idx, src_tile_col_idx_1)
      ]
      # The second src tile, half of which forms the second half of the retiled
      # tile (retiled_row_idx, retiled_col_idx).
      src_tile_col_idx_2 = min(src_tile_col_idx_1 + 1, src_tiles.shape[-1] - 1)
      src_tile_2 = src_tiles[
          (*other_indices, src_tile_row_idx, src_tile_col_idx_2)
      ]
      # Each retiled row is formed from either the top-half or the bottom-half
      # sublanes of the two original tiles. We need to rotate the sublanes of
      # one of the two tiles to push either a top half to the bottom or vice
      # versa.
      tile_to_merge_1, tile_to_merge_2 = (
          (src_tile_1, tpu.RotateOp(src_tile_2, amount=4, dimension=0))
          if retiled_row_idx % 2 == 0
          else (tpu.RotateOp(src_tile_1, amount=4, dimension=0), src_tile_2)
      )
      # Create a mask to select the first half of the data from tile 1 and the
      # second half from tile 2 to be merged.
      vreg_half_bound = RectangularVRegBounds(
          TargetTuple(
              slice(0, TARGET_SHAPE.sublanes // 2), slice(0, TARGET_SHAPE.lanes)
          )
      )
      vreg_select_mask = vreg_half_bound.get_vector_mask(hw_generation)
      retiled_src_tiles[(*other_indices, retiled_row_idx, retiled_col_idx)] = (
          arith.SelectOp(vreg_select_mask, tile_to_merge_1, tile_to_merge_2)
      )
    src = retiled_src_layout
    src_tiles = retiled_src_tiles
    src = dst
  # Fix up the offsets, assuming everything else matches between src and dst.
  if src.tiling == dst.tiling and src.implicit_dim == dst.implicit_dim: