mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic] Always use 32-bit selects while retiling
Retiling never needs to use packed masks, and those aren't supported on all TPUs.

PiperOrigin-RevId: 627692517
This commit is contained in:
parent
5a2d7a2df4
commit
a72a204c39
@ -3932,6 +3932,7 @@ Value selectTilesFromRotatedRowVregs(
|
||||
Value left_partial_vreg = selectTilesFromRotatedRowVregs(
|
||||
builder, rotated_row_vregs, start_src_col, mid_src_col,
|
||||
first_dst_tile_sublane_offset, dst_layout, target_shape);
|
||||
Location loc = left_partial_vreg.getLoc();
|
||||
|
||||
const int64_t left_tiles_count = mid_src_col - start_src_col + 1;
|
||||
const int64_t right_first_dst_tile_sublane_offset =
|
||||
@ -3944,12 +3945,23 @@ Value selectTilesFromRotatedRowVregs(
|
||||
right_first_dst_tile_sublane_offset, dst_layout, target_shape);
|
||||
|
||||
const IntegerType i1 = builder.getI1Type();
|
||||
const auto mask_vreg_ty =
|
||||
dst_layout.packing() > 1
|
||||
? VectorType::get(ArrayRef<int64_t>{target_shape[0], target_shape[1],
|
||||
dst_layout.packing()},
|
||||
i1)
|
||||
: VectorType::get(target_shape, i1);
|
||||
// We never need to select partial sublanes, even for packed data.
|
||||
const auto mask_vreg_ty = VectorType::get(target_shape, i1);
|
||||
auto i32_vreg = VectorType::get(target_shape, builder.getI32Type());
|
||||
auto select_32bit = [&](Value sublane_mask, Value left, Value right) {
|
||||
// Always do the selects on 32-bit granularity for maximum HW compatibility.
|
||||
Type vreg_ty = left.getType();
|
||||
if (dst_layout.packing() != 1) {
|
||||
left = builder.create<tpu::BitcastVregOp>(loc, i32_vreg, left);
|
||||
right = builder.create<tpu::BitcastVregOp>(loc, i32_vreg, right);
|
||||
}
|
||||
Value result =
|
||||
builder.create<arith::SelectOp>(loc, sublane_mask, left, right);
|
||||
if (dst_layout.packing() != 1) {
|
||||
result = builder.create<tpu::BitcastVregOp>(loc, vreg_ty, result);
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder,
|
||||
left_partial_vreg.getLoc());
|
||||
@ -3977,9 +3989,7 @@ Value selectTilesFromRotatedRowVregs(
|
||||
boundIdxConst(0)},
|
||||
ArrayRef<Value>{boundIdxConst(right_first_dst_tile_sublane_offset),
|
||||
boundIdxConst(target_shape[1])});
|
||||
return builder.create<arith::SelectOp>(left_partial_vreg.getLoc(),
|
||||
sublanes_mask, left_partial_vreg,
|
||||
right_partial_vreg);
|
||||
return select_32bit(sublanes_mask, left_partial_vreg, right_partial_vreg);
|
||||
}
|
||||
|
||||
auto sublanes_mask = builder.create<tpu::CreateMaskOp>(
|
||||
@ -3988,9 +3998,7 @@ Value selectTilesFromRotatedRowVregs(
|
||||
boundIdxConst(0)},
|
||||
ArrayRef<Value>{boundIdxConst(first_dst_tile_sublane_offset),
|
||||
boundIdxConst(target_shape[1])});
|
||||
return builder.create<arith::SelectOp>(left_partial_vreg.getLoc(),
|
||||
sublanes_mask, right_partial_vreg,
|
||||
left_partial_vreg);
|
||||
return select_32bit(sublanes_mask, right_partial_vreg, left_partial_vreg);
|
||||
}
|
||||
|
||||
// Retiles across vregs to match the destination layout when the sublane tiling
|
||||
|
Loading…
x
Reference in New Issue
Block a user