[Mosaic TPU] Enable (packing, 128) -> (8 * packing, 128) retiling

PiperOrigin-RevId: 654922099
This commit is contained in:
Tomás Longeri 2024-07-22 15:44:27 -07:00 committed by jax authors
parent ab811f3ac5
commit 5f18a2e27b

View File

@ -558,6 +558,7 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx,
return argument;
}
// TODO(tlongeri): This function and others below never fail, remove FailureOr
FailureOr<VectorType> getNativeVregOrVmaskTypeImpl(
Type elem_ty, const int8_t bitwidth,
const std::array<int64_t, 2> target_shape) {
@ -4801,8 +4802,14 @@ Value copy_one_sublane(OpBuilder &builder, Value src_vreg, int src_sl_idx,
/*dimension=*/0, /*stride=*/nullptr, /*stride_dimension=*/nullptr);
auto boundIdxConst =
std::bind(IdxConst, std::placeholders::_1, builder, src_vreg.getLoc());
const int bitwidth =
cast<VectorType>(src_vreg.getType()).getElementTypeBitWidth();
CHECK_EQ(bitwidth,
cast<VectorType>(dst_vreg.getType()).getElementTypeBitWidth());
const VectorType vmask_ty =
*getNativeVregOrVmaskType(builder.getI1Type(), bitwidth, target_shape);
auto sublanes_mask = builder.create<tpu::CreateMaskOp>(
src_vreg.getLoc(), VectorType::get(target_shape, builder.getI1Type()),
src_vreg.getLoc(), vmask_ty,
ValueRange{boundIdxConst(dst_sl_idx), boundIdxConst(0)},
ValueRange{boundIdxConst(dst_sl_idx + 1),
boundIdxConst(target_shape[1])});
@ -5050,13 +5057,15 @@ FailureOr<TypedValue<VectorType>> relayout(
.getResult();
}
// Handle retiling from (1, 128) to (8, 128) for 32-bit data.
if (src.implicit_dim() == dst.implicit_dim() && src.bitwidth() == 32 &&
// Handle retiling from (packing, 128) to (8 * packing, 128).
if (src.implicit_dim() == dst.implicit_dim() &&
src.offsets() == LayoutOffsets{0, 0} &&
(dst.offsets()[0] == 0 || (dst.offsets()[0] == std::nullopt &&
*(src_tiles.dimensions().end() - 2) == 1)) &&
dst.offsets()[1] == 0 && src.tiling() == std::array<int64_t, 2>{1, 128} &&
dst.tiling() == std::array<int64_t, 2>{8, 128}) {
(dst.offsets()[0] == 0 ||
(packing == 1 && dst.offsets()[0] == std::nullopt &&
*(src_tiles.dimensions().end() - 2) == 1)) &&
dst.offsets()[1] == 0 &&
src.tiling() == std::array<int64_t, 2>{packing, 128} &&
dst.tiling() == std::array<int64_t, 2>{8 * packing, 128}) {
xla::Array<Value> src_tiles_retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
src_tiles_retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {