From 1fcb84dc90981ccab3d95fc525a86919eef66834 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Thu, 22 Feb 2024 11:12:45 -0800 Subject: [PATCH] [XLA:Mosaic] Support broadcast one row with padded tiling. PiperOrigin-RevId: 609435269 --- .../tpu/transforms/apply_vector_layout.cc | 22 ++++++---- .../tpu/transforms/infer_vector_layout.cc | 44 ++++++++++++------- 2 files changed, 42 insertions(+), 24 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 0778c2a67..8fd664a1c 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -3836,7 +3836,6 @@ FailureOr relayout(OpBuilder &builder, Value v, VectorLayout src, } } - // TODO(b/306692696) Generalize relayout from tiling (m, 128) to (8, 128). // Handle retiling from (1, 128) to (8, 128) for 32-bit data. if (src.implicit_dim() == VectorLayout::ImplicitDim::kNone && dst.implicit_dim() == VectorLayout::ImplicitDim::kNone && @@ -3863,29 +3862,34 @@ FailureOr relayout(OpBuilder &builder, Value v, VectorLayout src, }); src = dst; src_tiles = std::move(src_tiles_retiled); - } else if ( // Handle retiling from (2, 128) to (8, 128) for 32-bit data. + } else if ( // Handle retiling from (m, 128) to (8, 128) for 32-bit data + // where m < 8 and m is a power of 2. + // TODO(b/306692696) Generalize relayout from tiling (m, 128) to + // (8, 128) for any src_tiles.dimensions(). src.implicit_dim() == VectorLayout::ImplicitDim::kNone && dst.implicit_dim() == VectorLayout::ImplicitDim::kNone && src.bitwidth() == 32 && src.offsets() == LayoutOffsets{0, 0} && dst.offsets() == LayoutOffsets{0, 0} && - src.tiling() == std::array{2, 128} && - dst.tiling() == std::array{8, 128} && + target_shape[0] % src.tiling()[0] == 0 && + src.tiling()[1] == target_shape[1] && dst.tiling() == target_shape && *(src_tiles.dimensions().end() - 2) == 1) { xla::Array src_tiles_retiled( dst.tileArrayShape(vty.getShape(), target_shape)); src_tiles_retiled.Each( [&](const absl::Span idx, Value *const new_src_tile) { + const int64_t tiles_per_vreg = src.tilesPerVreg(target_shape); const int64_t dst_col = idx.back(); - const int64_t src_col = dst_col / 4; - const int64_t start_slane_idx = 2 * (dst_col % 4); + const int64_t src_col = dst_col / tiles_per_vreg; + const int64_t start_slane_idx = + src.tiling()[0] * (dst_col % tiles_per_vreg); SmallVector src_idx(toArrayRef(idx)); src_idx.back() = src_col; Value src_tile = src_tiles(src_idx); if (start_slane_idx) { SmallVector slane_idxs; - slane_idxs.reserve(8); - for (int i = 0; i < 8; ++i) { - slane_idxs.push_back(start_slane_idx + (i % 2)); + slane_idxs.reserve(target_shape[0]); + for (int i = 0; i < target_shape[0]; ++i) { + slane_idxs.push_back(start_slane_idx + (i % src.tiling()[0])); } const DenseI32ArrayAttr gather_indices = builder.getDenseI32ArrayAttr(slane_idxs); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 0ddccf574..856837b3e 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -457,10 +457,20 @@ class VectorLayoutInferer { auto &layout = *some_layout; if (layout.implicit_dim() == ImplicitDim::kNone) { // TODO(apaszke): Support native layouts here. - auto src_layout = VectorLayout(layout.bitwidth(), layout.offsets(), - default_tiling_, ImplicitDim::kNone); - auto dst_layout = VectorLayout(32, layout.offsets(), default_tiling_, - ImplicitDim::kNone); + Layout src_layout; + Layout dst_layout; + // All layouts that subdivide the rows of the default tiling evenly + // can be handled uniformly with the default case, by preserving the + // tiling through the op. + if (default_tiling_[0] % layout.tiling()[0] == 0 && + default_tiling_[1] == layout.tiling()[1]) { + src_layout = layout; + } else { + src_layout = VectorLayout(layout.bitwidth(), layout.offsets(), + default_tiling_, ImplicitDim::kNone); + } + dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(), + ImplicitDim::kNone); setLayout(op, src_layout, dst_layout); return success(); } @@ -860,17 +870,21 @@ class VectorLayoutInferer { TPU_CHECK_OP(res_ty.getRank() >= 2, "result rank below 2D unsupported"); auto some_layout = getLayout(op.getSource()); TPU_CHECK_OP(some_layout.has_value(), "missing vector layout"); - // We want to force the layout to be (8, 128) instead of (1, 128) if we - // are broadcasting sublane dim from 1 to at least 8. - if (some_layout->bitwidth() == kNativeBitwidth && - some_layout->implicit_dim() == ImplicitDim::kNone && - some_layout->tiling()[0] == 1 && - some_layout->tiling()[1] == default_tiling_[1] && - src_ty.getDimSize(src_ty.getRank() - 2) == 1 && - res_ty.getDimSize(res_ty.getRank() - 2) >= 8) { - *some_layout = VectorLayout( - some_layout->bitwidth(), {std::nullopt, some_layout->offsets()[1]}, - default_tiling_, some_layout->implicit_dim()); + // Since we can only do sublane broadcasts in the (8, 128) tiling, we + // should always use that when sublane broadcasting is required. + if (src_ty.getDimSize(src_ty.getRank() - 2) != + res_ty.getDimSize(res_ty.getRank() - 2)) { + if (some_layout->bitwidth() != kNativeBitwidth) { + NYI("Only 32-bit broadcasts supported"); + } + LayoutOffsets offsets = some_layout->offsets(); + // At the moment relayout can only produce replicated sublanes when + // converting to (8, 128) if the input was in (1, 128) tiling + if (some_layout->tiling()[0] == 1) { + offsets[0] = std::nullopt; + } + *some_layout = VectorLayout(some_layout->bitwidth(), offsets, + default_tiling_, some_layout->implicit_dim()); } auto &layout = *some_layout; if (layout.implicit_dim() != ImplicitDim::kNone) {