[Mosaic TPU] Use large to compact 2nd minor retiling for conversions going both ways

This specific retiling is its own inverse, and it is faster than the alternatives.

PiperOrigin-RevId: 716360070

commit bd22bfef71
parent aa9cea0a55
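
Note on the shapes involved: `packing` is the number of packed elements per native 32-bit lane, i.e. 32 / bitwidth, so this retiling switches between the native (8, 128) vreg tiling and the large-2nd-minor (8 * packing, 128) tiling. A minimal standalone sketch of just that arithmetic (plain C++, no Mosaic dependencies; `large_tiling` is a made-up helper name, not an API from this file):

    #include <array>
    #include <cstdint>
    #include <cstdio>

    // Hypothetical helper: the large-2nd-minor tiling that pairs with the
    // native (8, 128) tiling for a packed element type of `bitwidth` bits.
    std::array<int64_t, 2> large_tiling(int64_t bitwidth) {
      const int64_t packing = 32 / bitwidth;  // packed elements per 32-bit lane
      return {8 * packing, 128};
    }

    int main() {
      const int64_t bws[] = {16, 8, 4};
      for (int64_t bw : bws) {
        const auto t = large_tiling(bw);
        std::printf("bitwidth %2lld: (8,128) <-> (%lld,%lld)\n",
                    (long long)bw, (long long)t[0], (long long)t[1]);
      }
      // Prints (16,128) for bf16, (32,128) for int8, (64,128) for int4.
    }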
@@ -6256,12 +6256,17 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
     });
     return std::pair(dst, std::move(retiled));
   }
-  // (8,128) -> (8 * packing,128) tiling change for packed type.
-  if (src_offsets[0].value_or(0) < dst_vreg_slice[0] &&
+  // (8,128) <-> (8 * packing,128) tiling change for packed type.
+  if (ctx.hardware_generation >= 4 &&
+      src_offsets[0].value_or(0) < dst_vreg_slice[0] &&
       src_offsets[1].value_or(0) < dst_vreg_slice[1] && bitwidth < 32 &&
-      32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
-      dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * packing,
-                                           ctx.target_shape[1]}) {
+      32 % bitwidth == 0 &&
+      ((src_tiling == ctx.target_shape &&
+        dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * packing,
+                                             ctx.target_shape[1]}) ||
+       (dst_tiling == ctx.target_shape &&
+        src_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * packing,
+                                             ctx.target_shape[1]}))) {
     // TODO(tlongeri): This relayout is just ext + trunc. Refactor.
     // Note: for int4, retiling with scratch is always faster.
     if (bitwidth != 4 || !has_enough_scratch) {
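The heart of this hunk is the widened guard: previously only (8,128) -> (8 * packing,128) qualified, while the new disjunction also accepts the reverse direction. Pulled out as a standalone predicate for readability (a sketch in plain C++; the function name and the omission of the offset and hardware-generation checks are mine, not Mosaic's):

    #include <array>
    #include <cassert>
    #include <cstdint>

    using Tiling = std::array<int64_t, 2>;

    // Illustrative restatement of the new tiling condition (the offset and
    // hardware-generation checks from the real guard are omitted).
    bool packed_2nd_minor_retiling(Tiling src, Tiling dst, Tiling target,
                                   int64_t bitwidth) {
      if (bitwidth >= 32 || 32 % bitwidth != 0) return false;
      const Tiling large = {target[0] * (32 / bitwidth), target[1]};
      // New in this commit: the same fast path fires in both directions.
      return (src == target && dst == large) || (dst == target && src == large);
    }

    int main() {
      const Tiling target = {8, 128};
      // bf16 (bitwidth 16): (8,128) <-> (16,128) now matches either way.
      assert(packed_2nd_minor_retiling({8, 128}, {16, 128}, target, 16));
      assert(packed_2nd_minor_retiling({16, 128}, {8, 128}, target, 16));
      // A 32-bit type never takes this path.
      assert(!packed_2nd_minor_retiling({8, 128}, {8, 128}, target, 32));
    }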
@@ -6277,26 +6282,32 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
           vty.getElementType().isSignlessInteger()
               ? VectorType::get(target_shape, builder.getI32Type())
               : VectorType::get(target_shape, builder.getF32Type());
+      // For each output vreg we collect `packing` registers from the moving dim
+      // (sublanes or lanes), while using the other vreg dim to determine which
+      // part of each register to use (the parts dim).
+      const int parts_dim = src_tiling[0] < dst_tiling[0] ? 1 : 2;
+      const int moving_dim = src_tiling[0] < dst_tiling[0] ? 2 : 1;
       retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
-        const int vreg_part = idx.back() % packing;
+        const int vreg_part = *(idx.end() - parts_dim) % packing;
         SmallVector<Value, 8> parts;
         parts.reserve(packing);
         SmallVector<int64_t> src_idx(idx.begin(), idx.end());
-        *(src_idx.end() - 1) /= packing;
-        if (!dst.offsets()[0].has_value()) {
-          *(src_idx.end() - 2) = 0;
+        *(src_idx.end() - parts_dim) /= packing;
+        if (!dst.offsets()[2 - moving_dim].has_value()) {
+          *(src_idx.end() - moving_dim) = 0;
           // Make sure we set all parts of the output vreg to make it replicated
           parts.append(packing, builder.create<tpu::UnpackSubelementsOp>(
                                     loc, vreg_x32, vregs(src_idx), vreg_part,
                                     tpu::PackFormat::kCompressed));
         } else {
-          *(src_idx.end() - 2) *= packing;
+          *(src_idx.end() - moving_dim) *= packing;
           for (int i = 0; i < packing; ++i) {
-            if (*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2)) {
+            if (*(src_idx.end() - moving_dim) <
+                *(vregs.dimensions().end() - moving_dim)) {
               parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
                   loc, vreg_x32, vregs(src_idx), vreg_part,
                   tpu::PackFormat::kCompressed));
-              ++*(src_idx.end() - 2);
+              ++*(src_idx.end() - moving_dim);
             } else {
               parts.push_back(nullptr);
             }
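This hunk generalizes the old hard-coded `- 1` / `- 2` trailing-index offsets into `parts_dim` and `moving_dim`, which swap roles with the direction; `2 - moving_dim` translates trailing-dim numbering (1 = minormost, 2 = second minor) into the position in the offsets array. The arithmetic can be checked with plain integers (an illustrative C++ sketch of the non-replicated case, not Mosaic code; `map_dst_to_src` and `SrcAccess` are names I made up):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // For a destination vreg index (trailing two dims shown), compute the
    // packed part to unpack and the first source vreg index, mirroring the
    // rewritten loop's index math.
    struct SrcAccess {
      std::vector<int64_t> src_idx;  // first gathered source vreg
      int vreg_part;                 // packed part of that vreg to unpack
      int moving_dim;                // trailing dim walked to gather `packing` vregs
    };

    SrcAccess map_dst_to_src(std::vector<int64_t> idx, int64_t packing,
                             int64_t src_tile0, int64_t dst_tile0) {
      const int parts_dim = src_tile0 < dst_tile0 ? 1 : 2;
      const int moving_dim = src_tile0 < dst_tile0 ? 2 : 1;
      const int vreg_part = *(idx.end() - parts_dim) % packing;
      std::vector<int64_t> src_idx = idx;
      *(src_idx.end() - parts_dim) /= packing;
      *(src_idx.end() - moving_dim) *= packing;  // non-replicated case
      return {src_idx, vreg_part, moving_dim};
    }

    int main() {
      // (8,128) -> (16,128), bf16 (packing == 2): destination vreg (1, 3)
      // unpacks part 3 % 2 == 1 and gathers source vregs (2, 1) and (3, 1).
      auto a = map_dst_to_src({1, 3}, 2, 8, 16);
      assert(a.vreg_part == 1 && a.src_idx[0] == 2 && a.src_idx[1] == 1);
      // Reverse, (16,128) -> (8,128): the two trailing dims swap roles, so
      // the same destination vreg reads part 1 starting from vreg (0, 6).
      auto b = map_dst_to_src({1, 3}, 2, 16, 8);
      assert(b.vreg_part == 1 && b.src_idx[0] == 0 && b.src_idx[1] == 6);
    }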