mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic TPU] Enable (packing, 128) -> (8 * packing, 128) retiling
PiperOrigin-RevId: 654922099
This commit is contained in:
parent
ab811f3ac5
commit
5f18a2e27b
@ -558,6 +558,7 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx,
|
||||
return argument;
|
||||
}
|
||||
|
||||
// TODO(tlongeri): This function and others below never fail, remove FailureOr
|
||||
FailureOr<VectorType> getNativeVregOrVmaskTypeImpl(
|
||||
Type elem_ty, const int8_t bitwidth,
|
||||
const std::array<int64_t, 2> target_shape) {
|
||||
@ -4801,8 +4802,14 @@ Value copy_one_sublane(OpBuilder &builder, Value src_vreg, int src_sl_idx,
|
||||
/*dimension=*/0, /*stride=*/nullptr, /*stride_dimension=*/nullptr);
|
||||
auto boundIdxConst =
|
||||
std::bind(IdxConst, std::placeholders::_1, builder, src_vreg.getLoc());
|
||||
const int bitwidth =
|
||||
cast<VectorType>(src_vreg.getType()).getElementTypeBitWidth();
|
||||
CHECK_EQ(bitwidth,
|
||||
cast<VectorType>(dst_vreg.getType()).getElementTypeBitWidth());
|
||||
const VectorType vmask_ty =
|
||||
*getNativeVregOrVmaskType(builder.getI1Type(), bitwidth, target_shape);
|
||||
auto sublanes_mask = builder.create<tpu::CreateMaskOp>(
|
||||
src_vreg.getLoc(), VectorType::get(target_shape, builder.getI1Type()),
|
||||
src_vreg.getLoc(), vmask_ty,
|
||||
ValueRange{boundIdxConst(dst_sl_idx), boundIdxConst(0)},
|
||||
ValueRange{boundIdxConst(dst_sl_idx + 1),
|
||||
boundIdxConst(target_shape[1])});
|
||||
@ -5050,13 +5057,15 @@ FailureOr<TypedValue<VectorType>> relayout(
|
||||
.getResult();
|
||||
}
|
||||
|
||||
// Handle retiling from (1, 128) to (8, 128) for 32-bit data.
|
||||
if (src.implicit_dim() == dst.implicit_dim() && src.bitwidth() == 32 &&
|
||||
// Handle retiling from (packing, 128) to (8 * packing, 128).
|
||||
if (src.implicit_dim() == dst.implicit_dim() &&
|
||||
src.offsets() == LayoutOffsets{0, 0} &&
|
||||
(dst.offsets()[0] == 0 || (dst.offsets()[0] == std::nullopt &&
|
||||
*(src_tiles.dimensions().end() - 2) == 1)) &&
|
||||
dst.offsets()[1] == 0 && src.tiling() == std::array<int64_t, 2>{1, 128} &&
|
||||
dst.tiling() == std::array<int64_t, 2>{8, 128}) {
|
||||
(dst.offsets()[0] == 0 ||
|
||||
(packing == 1 && dst.offsets()[0] == std::nullopt &&
|
||||
*(src_tiles.dimensions().end() - 2) == 1)) &&
|
||||
dst.offsets()[1] == 0 &&
|
||||
src.tiling() == std::array<int64_t, 2>{packing, 128} &&
|
||||
dst.tiling() == std::array<int64_t, 2>{8 * packing, 128}) {
|
||||
xla::Array<Value> src_tiles_retiled(
|
||||
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
|
||||
src_tiles_retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user