mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[XLA:Mosaic] Support broadcast one row with padded tiling.
PiperOrigin-RevId: 609435269
This commit is contained in:
parent
cf80f574b5
commit
1fcb84dc90
@ -3836,7 +3836,6 @@ FailureOr<Value> relayout(OpBuilder &builder, Value v, VectorLayout src,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(b/306692696) Generalize relayout from tiling (m, 128) to (8, 128).
|
||||
// Handle retiling from (1, 128) to (8, 128) for 32-bit data.
|
||||
if (src.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
|
||||
dst.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
|
||||
@ -3863,29 +3862,34 @@ FailureOr<Value> relayout(OpBuilder &builder, Value v, VectorLayout src,
|
||||
});
|
||||
src = dst;
|
||||
src_tiles = std::move(src_tiles_retiled);
|
||||
} else if ( // Handle retiling from (2, 128) to (8, 128) for 32-bit data.
|
||||
} else if ( // Handle retiling from (m, 128) to (8, 128) for 32-bit data
|
||||
// where m < 8 and m is a power of 2.
|
||||
// TODO(b/306692696) Generalize relayout from tiling (m, 128) to
|
||||
// (8, 128) for any src_tiles.dimensions().
|
||||
src.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
|
||||
dst.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
|
||||
src.bitwidth() == 32 && src.offsets() == LayoutOffsets{0, 0} &&
|
||||
dst.offsets() == LayoutOffsets{0, 0} &&
|
||||
src.tiling() == std::array<int64_t, 2>{2, 128} &&
|
||||
dst.tiling() == std::array<int64_t, 2>{8, 128} &&
|
||||
target_shape[0] % src.tiling()[0] == 0 &&
|
||||
src.tiling()[1] == target_shape[1] && dst.tiling() == target_shape &&
|
||||
*(src_tiles.dimensions().end() - 2) == 1) {
|
||||
xla::Array<Value> src_tiles_retiled(
|
||||
dst.tileArrayShape(vty.getShape(), target_shape));
|
||||
src_tiles_retiled.Each(
|
||||
[&](const absl::Span<const int64_t> idx, Value *const new_src_tile) {
|
||||
const int64_t tiles_per_vreg = src.tilesPerVreg(target_shape);
|
||||
const int64_t dst_col = idx.back();
|
||||
const int64_t src_col = dst_col / 4;
|
||||
const int64_t start_slane_idx = 2 * (dst_col % 4);
|
||||
const int64_t src_col = dst_col / tiles_per_vreg;
|
||||
const int64_t start_slane_idx =
|
||||
src.tiling()[0] * (dst_col % tiles_per_vreg);
|
||||
SmallVector<int64_t> src_idx(toArrayRef(idx));
|
||||
src_idx.back() = src_col;
|
||||
Value src_tile = src_tiles(src_idx);
|
||||
if (start_slane_idx) {
|
||||
SmallVector<int32_t> slane_idxs;
|
||||
slane_idxs.reserve(8);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
slane_idxs.push_back(start_slane_idx + (i % 2));
|
||||
slane_idxs.reserve(target_shape[0]);
|
||||
for (int i = 0; i < target_shape[0]; ++i) {
|
||||
slane_idxs.push_back(start_slane_idx + (i % src.tiling()[0]));
|
||||
}
|
||||
const DenseI32ArrayAttr gather_indices =
|
||||
builder.getDenseI32ArrayAttr(slane_idxs);
|
||||
|
@ -457,10 +457,20 @@ class VectorLayoutInferer {
|
||||
auto &layout = *some_layout;
|
||||
if (layout.implicit_dim() == ImplicitDim::kNone) {
|
||||
// TODO(apaszke): Support native layouts here.
|
||||
auto src_layout = VectorLayout(layout.bitwidth(), layout.offsets(),
|
||||
default_tiling_, ImplicitDim::kNone);
|
||||
auto dst_layout = VectorLayout(32, layout.offsets(), default_tiling_,
|
||||
ImplicitDim::kNone);
|
||||
Layout src_layout;
|
||||
Layout dst_layout;
|
||||
// All layouts that subdivide the rows of the default tiling evenly
|
||||
// can be handled uniformly with the default case, by preserving the
|
||||
// tiling through the op.
|
||||
if (default_tiling_[0] % layout.tiling()[0] == 0 &&
|
||||
default_tiling_[1] == layout.tiling()[1]) {
|
||||
src_layout = layout;
|
||||
} else {
|
||||
src_layout = VectorLayout(layout.bitwidth(), layout.offsets(),
|
||||
default_tiling_, ImplicitDim::kNone);
|
||||
}
|
||||
dst_layout = VectorLayout(32, layout.offsets(), src_layout->tiling(),
|
||||
ImplicitDim::kNone);
|
||||
setLayout(op, src_layout, dst_layout);
|
||||
return success();
|
||||
}
|
||||
@ -860,17 +870,21 @@ class VectorLayoutInferer {
|
||||
TPU_CHECK_OP(res_ty.getRank() >= 2, "result rank below 2D unsupported");
|
||||
auto some_layout = getLayout(op.getSource());
|
||||
TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
|
||||
// We want to force the layout to be (8, 128) instead of (1, 128) if we
|
||||
// are broadcasting sublane dim from 1 to at least 8.
|
||||
if (some_layout->bitwidth() == kNativeBitwidth &&
|
||||
some_layout->implicit_dim() == ImplicitDim::kNone &&
|
||||
some_layout->tiling()[0] == 1 &&
|
||||
some_layout->tiling()[1] == default_tiling_[1] &&
|
||||
src_ty.getDimSize(src_ty.getRank() - 2) == 1 &&
|
||||
res_ty.getDimSize(res_ty.getRank() - 2) >= 8) {
|
||||
*some_layout = VectorLayout(
|
||||
some_layout->bitwidth(), {std::nullopt, some_layout->offsets()[1]},
|
||||
default_tiling_, some_layout->implicit_dim());
|
||||
// Since we can only do sublane broadcasts in the (8, 128) tiling, we
|
||||
// should always use that when sublane broadcasting is required.
|
||||
if (src_ty.getDimSize(src_ty.getRank() - 2) !=
|
||||
res_ty.getDimSize(res_ty.getRank() - 2)) {
|
||||
if (some_layout->bitwidth() != kNativeBitwidth) {
|
||||
NYI("Only 32-bit broadcasts supported");
|
||||
}
|
||||
LayoutOffsets offsets = some_layout->offsets();
|
||||
// At the moment relayout can only produce replicated sublanes when
|
||||
// converting to (8, 128) if the input was in (1, 128) tiling
|
||||
if (some_layout->tiling()[0] == 1) {
|
||||
offsets[0] = std::nullopt;
|
||||
}
|
||||
*some_layout = VectorLayout(some_layout->bitwidth(), offsets,
|
||||
default_tiling_, some_layout->implicit_dim());
|
||||
}
|
||||
auto &layout = *some_layout;
|
||||
if (layout.implicit_dim() != ImplicitDim::kNone) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user