mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Support relayout of tiles in register when the layout tiling changes.
PiperOrigin-RevId: 563570338
This commit is contained in:
parent
a6eed40f24
commit
bfd79b84e4
@ -854,6 +854,322 @@ def print_layout(layout: Layout) -> str:
|
||||
return f'#tpu.vpad<"{layout.bitwidth},{{{o[0]},{o[1]}}},({layout.tiling[0]},{layout.tiling[1]}){implicit_dim}">'
|
||||
|
||||
|
||||
def select_tiles_from_rotated_row_vregs(
    rotated_row_vregs: np.ndarray,
    start_src_col: int,
    end_src_col: int,
    first_dst_tile_sublane_offset: int,
    dst_layout: VectorLayout,
    hw_generation: int,
) -> ValueLike:
  """Assembles a destination tile using partial data from rotated vregs using a divide-and-conquer strategy.

  Arguments:
    rotated_row_vregs: A row of rotated vregs, from which destination tile(s)
      is/are to be selected to assemble a new vreg.
    start_src_col: The first rotated vreg in the row of rotated vregs to
      process.
    end_src_col: The last rotated vreg in the row of rotated vregs to process.
    first_dst_tile_sublane_offset: Sublane offset where the first dst tile to be
      selected starts.
    dst_layout: Destination layout, based on which retiling is being performed.
    hw_generation: The generation of a target hardware. Only threaded through
      the recursive calls here; not read directly by this function.

  Returns:
    A new vreg assembled from dst tiles stored in given rotated vregs.

  Raises:
    ValueError: If start_src_col > end_src_col.
  """
  if start_src_col > end_src_col:
    raise ValueError("Invalid values for start and end column.")
  if start_src_col == end_src_col:
    # Base case: a single rotated vreg already holds all requested tiles.
    return rotated_row_vregs[start_src_col]

  mid_src_col = start_src_col + (end_src_col - start_src_col) // 2

  left_partial_vreg = select_tiles_from_rotated_row_vregs(
      rotated_row_vregs,
      start_src_col,
      mid_src_col,
      first_dst_tile_sublane_offset,
      dst_layout,
      hw_generation,
  )

  # The right half starts where the left half's tiles end, modulo the vreg
  # sublane count (tiles may wrap around within a vreg).
  left_tiles_count = mid_src_col - start_src_col + 1
  right_first_dst_tile_sublane_offset = (
      first_dst_tile_sublane_offset
      + left_tiles_count * dst_layout.sublanes_per_tile
  ) % TARGET_SHAPE.sublanes

  right_partial_vreg = select_tiles_from_rotated_row_vregs(
      rotated_row_vregs,
      mid_src_col + 1,
      end_src_col,
      right_first_dst_tile_sublane_offset,
      dst_layout,
      hw_generation,
  )

  i1 = ir.IntegerType.get_signless(1)
  # For packed (2 elements per 32-bit word) layouts the mask carries an extra
  # trailing dimension of size 2.
  mask_vreg_ty = (
      ir.VectorType.get((*TARGET_SHAPE, 2), i1)
      if dst_layout.packing == 2
      else ir.VectorType.get(TARGET_SHAPE, i1)
  )

  if first_dst_tile_sublane_offset < right_first_dst_tile_sublane_offset:
    # The useful data sublanes in left vregs do not wrap around in vreg.
    # For e.g. consider (2,128) destination tiling and we are trying to merge
    # two vregs as follows:
    #
    #   vreg 0:        vreg 1:
    #   x x x x x      dst_tile_2
    #   x x x x x      dst_tile_3
    #   dst_tile_4     x x x x x
    #   dst_tile_5     x x x x x
    #   dst_tile_6     x x x x x
    #   dst_tile_7     x x x x x
    #   x x x x x      dst_tile_0
    #   x x x x x      dst_tile_1
    #
    # In the above case, the data we want to select from vreg 1 wraps around,
    # whereas vreg 0 useful data is contiguous. It is easier to create '1' mask
    # for vreg 0.
    sublanes_mask = tpu.CreateMaskOp(
        mask_vreg_ty,
        map(ix_cst, [first_dst_tile_sublane_offset, 0]),
        map(ix_cst, [right_first_dst_tile_sublane_offset, TARGET_SHAPE.lanes]),
    )
    return arith.SelectOp(sublanes_mask, left_partial_vreg, right_partial_vreg)

  # Otherwise the right vreg's useful sublanes are the contiguous ones; mask
  # selects from the right vreg and falls through to the left one.
  sublanes_mask = tpu.CreateMaskOp(
      mask_vreg_ty,
      map(ix_cst, [right_first_dst_tile_sublane_offset, 0]),
      map(ix_cst, [first_dst_tile_sublane_offset, TARGET_SHAPE.lanes]),
  )
  return arith.SelectOp(sublanes_mask, right_partial_vreg, left_partial_vreg)
|
||||
|
||||
|
||||
def is_positive_pow_2(x: int) -> bool:
  """Returns true iff the integer argument is positive and a power of 2.

  Arguments:
    x: an integer argument.
  """
  if x <= 0:
    return False
  # A positive power of two has exactly one set bit, so clearing the lowest
  # set bit must leave zero.
  return x & (x - 1) == 0
|
||||
|
||||
|
||||
def retile_to_reduced_sublanes(
    value_shape: tuple[int, ...],
    src_layout: VectorLayout,
    src_vreg_array: np.ndarray,
    dst_layout: VectorLayout,
    hw_generation: int,
) -> np.ndarray:
  """Retiles across vregs to match the destination layout when the sublane tiling dimension is reduced.

  Arguments:
    value_shape: The shape of the value which needs to be retiled in vregs.
    src_layout: The source layout.
    src_vreg_array: An array of vregs storing source tiles.
    dst_layout: The destination layout, with reduced sublane dimension, based on
      which the retiling will be performed.
    hw_generation: The generation of a target hardware.

  Returns:
    A new array of vregs that store tiles based on the destination layout.
  """
  # Preconditions: dst sublane tiling must be a positive power of 2 strictly
  # smaller than the src sublane tiling, and the lane tiling must not change.
  dst_tiling_sublane = dst_layout.tiling[-2]
  assert dst_tiling_sublane > 0 and dst_tiling_sublane < src_layout.tiling[-2] and is_positive_pow_2(dst_tiling_sublane)
  assert src_layout.tiling[-1] == dst_layout.tiling[-1]

  dst_vreg_array = np.empty(
      dst_layout.tile_array_shape(value_shape), dtype=object
  )

  # We need to rotate each src tile in each src vreg once so that that they can
  # be merged to form new vregs. If a src vreg contains more than one src tile,
  # it will be rotated once per src tile. Consider (8,512) tensor stored with
  # layout (8,128) in a vreg array of shape (1, 4). Each src vreg
  # contains one src tile in this case. Given, the destination layout is
  # (2,128), each src tile is divided into 4 destination tiles as shown below:
  #
  # src_vreg_0_0: src_vreg_0_1: src_vreg_0_2: src_vreg_0_3:
  # dst_tile_0_0_0 dst_tile_0_0_1 dst_tile_0_0_2 dst_tile_0_0_3
  # dst_tile_1_0_0 dst_tile_1_0_1 dst_tile_1_0_2 dst_tile_1_0_3
  # dst_tile_2_0_0 dst_tile_2_0_1 dst_tile_2_0_2 dst_tile_2_0_3
  # dst_tile_3_0_0 dst_tile_3_0_1 dst_tile_3_0_2 dst_tile_3_0_3

  # In this example, each src tile in the src vreg is rotated by
  # col * sublanes_per_tile to produce the following rotated src vregs:
  #
  # rot_src_vreg_0_0: rot_src_vreg_0_1: rot_src_vreg_0_2: rot_src_vreg_0_3:
  # dst_tile_0_0_0 dst_tile_3_0_1 dst_tile_2_0_2 dst_tile_1_0_3
  # dst_tile_1_0_0 dst_tile_0_0_1 dst_tile_3_0_2 dst_tile_2_0_3
  # dst_tile_2_0_0 dst_tile_1_0_1 dst_tile_0_0_2 dst_tile_3_0_3
  # dst_tile_3_0_0 dst_tile_2_0_1 dst_tile_1_0_2 dst_tile_0_0_3

  # If there were 2 src tiles in the src vreg, we would have rotated each src
  # vreg twice, producing 2 rotated src vreg per src vreg. The rotation amount
  # is calculated from the src and the dest tiling.

  rotated_src_vregs_array = np.empty(
      (
          *(src_vreg_array.shape[:-1]),
          # Each vreg may store more than one src tile. We may have to rotate a
          # vreg, once for every src tile in the vreg.
          src_vreg_array.shape[-1] * src_layout.tiles_per_vreg,
      ),
      dtype=object,
  )

  for *other_dims, row, idx in np.ndindex(rotated_src_vregs_array.shape):
    # Which dst tile slot (within a dst vreg) this rotation targets.
    tile_idx = idx % dst_layout.tiles_per_vreg
    dst_sublane = tile_idx * dst_layout.sublanes_per_tile
    # Map the flattened column index back to (src vreg column, tile offset
    # within that src vreg).
    src_col, src_tile_offset = divmod(idx, src_layout.tiles_per_vreg)
    src_vreg = src_vreg_array[(*other_dims, row, src_col)]
    src_sublane = src_tile_offset * src_layout.sublanes_per_tile
    rotate_amt = dst_sublane - src_sublane
    if rotate_amt == 0:
      # Already at the target sublane offset — reuse the vreg as-is.
      rotated_src_vregs_array[(*other_dims, row, idx)] = src_vreg
      continue
    if rotate_amt < 0:
      # Normalize to a non-negative rotation amount in
      # [0, TARGET_SHAPE.sublanes). NOTE(review): presumably tpu.RotateOp
      # requires a non-negative amount — confirm against the op definition.
      rotate_amt = TARGET_SHAPE.sublanes + rotate_amt
    rotated_src_vregs_array[(*other_dims, row, idx)] = tpu.RotateOp(
        src_vreg, amount=rotate_amt, dimension=0
    )
  # Assemble output vregs using tiles from rotated vregs using select.
  # Given, above example, destination vregs are then assembled as follows:
  # dst_vreg_0_0:
  # dst_tile_0_0_0
  # dst_tile_0_0_1
  # dst_tile_0_0_2
  # dst_tile_0_0_3

  # dst_vreg_1_0: (Notice dst tiles are not in correct offset!)
  # dst_tile_1_0_3
  # dst_tile_1_0_0
  # dst_tile_1_0_1
  # dst_tile_1_0_2

  # dst_vreg_2_0: (Notice dst tiles are not in correct offset!)
  # dst_tile_2_0_2
  # dst_tile_2_0_3
  # dst_tile_2_0_0
  # dst_tile_2_0_1

  # dst_vreg_3_0: (Notice dst tiles are not in correct offset!)
  # dst_tile_3_0_1
  # dst_tile_3_0_2
  # dst_tile_3_0_3
  # dst_tile_3_0_0

  # Each destination vreg is assembled from destination tiles in multiple
  # rotated src vregs. In the above example, if we wanted each destination tile
  # to be in correct sublane offset in a rotated vreg, say rot_src_vreg_0_1,
  # before assembling the destination tiles, we would have had to rotate
  # src_vreg_0_1 four times, creating 4 rotated vregs (instead of 1) for each
  # src vreg. In the above example, we instead rotated a src vreg src_vreg_0_1
  # only once to obtain rot_src_vreg_0_1 where the dst_tile_0_0_1 is in correct
  # final sublane offset, i.e. 2. But notice the sublane offset of
  # dst_tile_1_0_1 in the same rotated vreg. Its correct final destination
  # sublane offset is 2, but in rot_src_vreg_0_1, its offset is 4. Its sublane
  # offset is off by 2. We need to correct these sublane offsets in the final
  # assembled dst vregs. A single rotation of each assembled dst vreg is needed
  # to correct such sublane offsets. This strategy reduces the number of sublane
  # rotations required. See comments below.
  tile_sublane_change_factor = src_layout.tiling[-2] // dst_layout.tiling[-2]
  for *other_dims, row, col in np.ndindex(dst_vreg_array.shape):
    # Each row of rotated src vregs feeds tile_sublane_change_factor dst rows;
    # first_dst_tile_offset is this dst row's position within that group.
    rotated_vreg_row, first_dst_tile_offset = divmod(
        row, tile_sublane_change_factor
    )
    first_dst_tile_sublane_offset = (
        first_dst_tile_offset * dst_layout.sublanes_per_tile
    )
    # The span of rotated-vreg columns contributing tiles to this dst vreg.
    # The end is clamped for the (possibly partial) last dst column.
    src_vreg_array_col_start = col * dst_layout.tiles_per_vreg
    src_vreg_array_col_end = (
        min(
            ((col + 1) * dst_layout.tiles_per_vreg),
            rotated_src_vregs_array.shape[-1],
        )
        - 1
    )
    dst_tile = select_tiles_from_rotated_row_vregs(
        rotated_row_vregs=rotated_src_vregs_array[
            (*other_dims, rotated_vreg_row, slice(None))
        ],
        start_src_col=src_vreg_array_col_start,
        end_src_col=src_vreg_array_col_end,
        first_dst_tile_sublane_offset=first_dst_tile_sublane_offset,
        dst_layout=dst_layout,
        hw_generation=hw_generation,
    )
    if first_dst_tile_sublane_offset == 0:
      # No need to rotate. First dst tile is already at offset 0, which means
      # rest of the dst tiles are also at correct sublane offset.
      dst_vreg_array[(*other_dims, row, col)] = dst_tile
    else:
      # Fix the destination tile sublane offset by rotating assembled dest vreg
      # once (See comments above). The dst vregs are fixed as follows:
      # No rotation needed.
      # dst_tile_0_0_0
      # dst_tile_0_0_1
      # dst_tile_0_0_2
      # dst_tile_0_0_3

      # Rotated by -1 * (sublanes_per_tile=2) * (row=1):
      # dst_tile_1_0_0
      # dst_tile_1_0_1
      # dst_tile_1_0_2
      # dst_tile_1_0_3

      # Rotated by -1 * (sublanes_per_tile=2) * (row=2):
      # dst_tile_2_0_0
      # dst_tile_2_0_1
      # dst_tile_2_0_2
      # dst_tile_2_0_3

      # Rotated by -1 * (sublanes_per_tile=2) * (row=3):
      # dst_tile_3_0_0
      # dst_tile_3_0_1
      # dst_tile_3_0_2
      # dst_tile_3_0_3
      dst_vreg_array[(*other_dims, row, col)] = tpu.RotateOp(
          dst_tile,
          amount=TARGET_SHAPE.sublanes - first_dst_tile_sublane_offset,
          dimension=0,
      )
  return dst_vreg_array
|
||||
|
||||
|
||||
def is_supported_reduced_sublanes_retile(
    src_layout: VectorLayout, dst_layout: VectorLayout
) -> bool:
  """Returns true iff the layout changes involve reduced sublanes per tile.

  Arguments:
    src_layout: The existing layout.
    dst_layout: The new layout based on which the retiling is to be carried out.
  """
  # Implicit dims are not handled by the reduced-sublane retiling path.
  if src_layout.implicit_dim is not None or dst_layout.implicit_dim is not None:
    return False
  # Offsets must agree, with a missing (None) offset treated as 0.
  for src_offset, dst_offset in zip(src_layout.offsets, dst_layout.offsets):
    if (src_offset or 0) != (dst_offset or 0):
      return False
  # TODO (kumudbhandari): We have not tested any tile size where
  # tile[-1] != TARGET_SHAPE.lanes. It should work but needs to be tested.
  return (
      src_layout.tiling[-1] == dst_layout.tiling[-1] == TARGET_SHAPE.lanes
      and dst_layout.tiling[-2] < src_layout.tiling[-2]
      and src_layout.bitwidth == dst_layout.bitwidth
      and is_positive_pow_2(src_layout.tiling[-2])
      and is_positive_pow_2(dst_layout.tiling[-2])
  )
|
||||
|
||||
|
||||
# TODO(apaszke): Test this function properly
|
||||
def relayout(
|
||||
v: ir.Value, src: VectorLayout, dst: VectorLayout, hw_generation: int
|
||||
@ -947,136 +1263,16 @@ def relayout(
|
||||
)
|
||||
src = new_src
|
||||
src_tiles = src_tiles_retiled
|
||||
# (16,128) -> (8,128) tiling change for packed 16-bit types.
|
||||
if (
|
||||
src.implicit_dim is None
|
||||
and dst.implicit_dim is None
|
||||
and src.offsets == dst.offsets
|
||||
and ir.BF16Type.isinstance(vty.element_type)
|
||||
and src.tiling == (16, 128)
|
||||
and dst.tiling == (8, 128)
|
||||
):
|
||||
new_src = VectorLayout(src.bitwidth, src.offsets, dst.tiling, None)
|
||||
src_tiles_retiled = np.empty(
|
||||
new_src.tile_array_shape(vty.shape), dtype=object)
|
||||
for (*batch_idx, dst_row, dst_col) in np.ndindex(src_tiles_retiled.shape):
|
||||
src_row1 = src_tiles[(*batch_idx, dst_row // 2, dst_col * 2)]
|
||||
src_row2_col = min(dst_col * 2 + 1, src_tiles.shape[-1] - 1)
|
||||
src_row2 = src_tiles[(*batch_idx, dst_row // 2, src_row2_col)]
|
||||
|
||||
vreg_part = dst_row % 2
|
||||
vreg_f32 = ir.VectorType.get(TARGET_SHAPE, ir.F32Type.get())
|
||||
half_row1 = tpu.UnpackSubelementsOp(vreg_f32, src_row1, vreg_part)
|
||||
half_row2 = tpu.UnpackSubelementsOp(vreg_f32, src_row2, vreg_part)
|
||||
src_tiles_retiled[(*batch_idx, dst_row, dst_col)] = tpu.PackSubelementsOp(
|
||||
src_row1.type, [half_row1, half_row2]
|
||||
)
|
||||
src = new_src
|
||||
src_tiles = src_tiles_retiled
|
||||
|
||||
# Handle retiling from (8, 128) to (1, 128) for 32 bits data.
|
||||
if (
|
||||
src.implicit_dim is None
|
||||
and dst.implicit_dim is None
|
||||
and type_bitwidth(vty.element_type) == 32
|
||||
and all((o or 0) == 0 for o in src.offsets)
|
||||
and dst.offsets == (0, 0)
|
||||
and src.tiling == (8, 128)
|
||||
and dst.tiling == (1, 128)
|
||||
):
|
||||
new_src = VectorLayout(src.bitwidth, dst.offsets, dst.tiling, None)
|
||||
src_tiles_retiled = np.empty(
|
||||
new_src.tile_array_shape(vty.shape), dtype=object
|
||||
if is_supported_reduced_sublanes_retile(src, dst):
|
||||
src_tiles = retile_to_reduced_sublanes(
|
||||
value_shape=vty.shape,
|
||||
src_layout=src,
|
||||
src_vreg_array=src_tiles,
|
||||
dst_layout=dst,
|
||||
hw_generation=hw_generation,
|
||||
)
|
||||
for *batch_idx, dst_row, dst_col in np.ndindex(src_tiles_retiled.shape):
|
||||
src_row = dst_row // 8
|
||||
src_col_0 = dst_col * 8
|
||||
src_tile_vreg_count = (
|
||||
8 if src_col_0 + 8 <= src_tiles.shape[-1] else src_tiles.shape[-1] % 8
|
||||
)
|
||||
src_tiles_retiled[(*batch_idx, dst_row, dst_col)] = src_tiles[
|
||||
(*batch_idx, src_row, src_col_0)
|
||||
]
|
||||
for i in range(1, src_tile_vreg_count):
|
||||
bounds = RectangularVRegBounds(
|
||||
TargetTuple(slice(i, i + 1), slice(0, TARGET_SHAPE.lanes))
|
||||
)
|
||||
mask = bounds.get_vector_mask(hw_generation)
|
||||
src_tile = src_tiles[(*batch_idx, src_row, src_col_0 + i)]
|
||||
src_tile = tpu.RotateOp(src_tile, amount=i, dimension=0)
|
||||
src_tiles_retiled[(*batch_idx, dst_row, dst_col)] = arith.SelectOp(
|
||||
mask, src_tile, src_tiles_retiled[(*batch_idx, dst_row, dst_col)]
|
||||
)
|
||||
src = new_src
|
||||
src_tiles = src_tiles_retiled
|
||||
|
||||
# TODO(kumudbhandari): Generalize the logic below to handle retiling from
|
||||
# (8,128) to (x, 128) where x=1, 2, or 4.
|
||||
if (
|
||||
src.implicit_dim is None
|
||||
and dst.implicit_dim is None
|
||||
and src.offsets == dst.offsets
|
||||
and src.bitwidth != 16
|
||||
and src.tiling == (8, 128)
|
||||
and dst.tiling == (4, 128)
|
||||
):
|
||||
retiled_src_layout = VectorLayout(
|
||||
src.bitwidth, src.offsets, dst.tiling, None
|
||||
)
|
||||
retiled_src_tiles = np.empty(
|
||||
retiled_src_layout.tile_array_shape(vty.shape), dtype=object
|
||||
)
|
||||
|
||||
# Consider a value of type and shape: f32(8, 256). Retiling from (8,128) to
|
||||
# (4,128):
|
||||
# vreg (tile) array shape (1, 2), with original (8,128) tiling:
|
||||
# vreg_0_0: slice:[0:7, 0:127] vreg_0_1: slice:[0:7, 128:255]
|
||||
#
|
||||
# vreg (tile) array shape: (2, 1), with (4,128) retiling:
|
||||
# vreg_0_0: slice: [0:3, 0:127], slice: [0:3, 128:255]
|
||||
# vreg_1_0: slice:[4:7, 0:127], slice: [4:7, 128:255]
|
||||
for *other_idices, retiled_row_idx, retiled_col_idx in np.ndindex(
|
||||
retiled_src_tiles.shape
|
||||
):
|
||||
# The first src tile, half of which forms the first half of the retiled
|
||||
# tile(retiled_row_idx, retiled_col_idx).
|
||||
src_tile_row_idx = retiled_row_idx // 2
|
||||
src_tile_col_idx_1 = retiled_col_idx * 2
|
||||
|
||||
src_tile_1 = src_tiles[
|
||||
(*other_idices, src_tile_row_idx, src_tile_col_idx_1)
|
||||
]
|
||||
|
||||
# The second src tile, half of which forms the second half of the retiled
|
||||
# tile(retiled_row_idx, retiled_col_idx).
|
||||
src_tile_col_idx_2 = min(src_tile_col_idx_1 + 1, src_tiles.shape[-1] - 1)
|
||||
src_tile_2 = src_tiles[
|
||||
(*other_idices, src_tile_row_idx, src_tile_col_idx_2)
|
||||
]
|
||||
|
||||
# Each (retiled_row_idx)th tile is formed from 2 top or 2 bottom half
|
||||
# sublanes of the original tile.
|
||||
# We need to rotate sublanes of one of the two tiles to push either a top
|
||||
# half to the bottom or vice-versa.
|
||||
tile_to_merge_1, tile_to_merge_2 = (
|
||||
(src_tile_1, tpu.RotateOp(src_tile_2, amount=4, dimension=0))
|
||||
if retiled_row_idx % 2 == 0
|
||||
else (tpu.RotateOp(src_tile_1, amount=4, dimension=0), src_tile_2)
|
||||
)
|
||||
# Create a mask to select first half from tile 1 and second half of data
|
||||
# from tile 2 to be merged.
|
||||
vreg_half_bound = RectangularVRegBounds(
|
||||
TargetTuple(
|
||||
slice(0, TARGET_SHAPE.sublanes // 2), slice(0, TARGET_SHAPE.lanes)
|
||||
)
|
||||
)
|
||||
vreg_select_mask = vreg_half_bound.get_vector_mask(hw_generation)
|
||||
retiled_src_tiles[(*other_idices, retiled_row_idx, retiled_col_idx)] = (
|
||||
arith.SelectOp(vreg_select_mask, tile_to_merge_1, tile_to_merge_2)
|
||||
)
|
||||
|
||||
src = retiled_src_layout
|
||||
src_tiles = retiled_src_tiles
|
||||
src = dst
|
||||
|
||||
# Fix up the offsets, assuming everything else matches between src and dst.
|
||||
if src.tiling == dst.tiling and src.implicit_dim == dst.implicit_dim:
|
||||
|
Loading…
x
Reference in New Issue
Block a user