[Mosaic] Always use 32-bit selects while retiling

Retiling never needs to use packed masks, and those aren't supported on all TPUs.

PiperOrigin-RevId: 627692517
This commit is contained in:
Adam Paszke 2024-04-24 05:11:16 -07:00 committed by jax authors
parent 5a2d7a2df4
commit a72a204c39

View File

@ -3932,6 +3932,7 @@ Value selectTilesFromRotatedRowVregs(
Value left_partial_vreg = selectTilesFromRotatedRowVregs(
builder, rotated_row_vregs, start_src_col, mid_src_col,
first_dst_tile_sublane_offset, dst_layout, target_shape);
Location loc = left_partial_vreg.getLoc();
const int64_t left_tiles_count = mid_src_col - start_src_col + 1;
const int64_t right_first_dst_tile_sublane_offset =
@ -3944,12 +3945,23 @@ Value selectTilesFromRotatedRowVregs(
right_first_dst_tile_sublane_offset, dst_layout, target_shape);
const IntegerType i1 = builder.getI1Type();
const auto mask_vreg_ty =
dst_layout.packing() > 1
? VectorType::get(ArrayRef<int64_t>{target_shape[0], target_shape[1],
dst_layout.packing()},
i1)
: VectorType::get(target_shape, i1);
// We never need to select partial sublanes, even for packed data.
const auto mask_vreg_ty = VectorType::get(target_shape, i1);
auto i32_vreg = VectorType::get(target_shape, builder.getI32Type());
auto select_32bit = [&](Value sublane_mask, Value left, Value right) {
// Always do the selects on 32-bit granularity for maximum HW compatibility.
Type vreg_ty = left.getType();
if (dst_layout.packing() != 1) {
left = builder.create<tpu::BitcastVregOp>(loc, i32_vreg, left);
right = builder.create<tpu::BitcastVregOp>(loc, i32_vreg, right);
}
Value result =
builder.create<arith::SelectOp>(loc, sublane_mask, left, right);
if (dst_layout.packing() != 1) {
result = builder.create<tpu::BitcastVregOp>(loc, vreg_ty, result);
}
return result;
};
auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder,
left_partial_vreg.getLoc());
@ -3977,9 +3989,7 @@ Value selectTilesFromRotatedRowVregs(
boundIdxConst(0)},
ArrayRef<Value>{boundIdxConst(right_first_dst_tile_sublane_offset),
boundIdxConst(target_shape[1])});
return builder.create<arith::SelectOp>(left_partial_vreg.getLoc(),
sublanes_mask, left_partial_vreg,
right_partial_vreg);
return select_32bit(sublanes_mask, left_partial_vreg, right_partial_vreg);
}
auto sublanes_mask = builder.create<tpu::CreateMaskOp>(
@ -3988,9 +3998,7 @@ Value selectTilesFromRotatedRowVregs(
boundIdxConst(0)},
ArrayRef<Value>{boundIdxConst(first_dst_tile_sublane_offset),
boundIdxConst(target_shape[1])});
return builder.create<arith::SelectOp>(left_partial_vreg.getLoc(),
sublanes_mask, right_partial_vreg,
left_partial_vreg);
return select_32bit(sublanes_mask, right_partial_vreg, left_partial_vreg);
}
// Retiles across vregs to match the destination layout when the sublane tiling