[Mosaic TPU] Support bitcast without forcing retiling.

PiperOrigin-RevId: 678765762
This commit is contained in:
Jevin Jiang 2024-09-25 10:56:30 -07:00 committed by jax authors
parent c93b272b78
commit 37641dd4fa

View File

@ -938,16 +938,16 @@ class VectorLayoutInferer {
auto out_ty = cast<VectorType>(op.getOutput().getType()); auto out_ty = cast<VectorType>(op.getOutput().getType());
auto in_bitwidth = in_ty.getElementTypeBitWidth(); auto in_bitwidth = in_ty.getElementTypeBitWidth();
auto out_bitwidth = out_ty.getElementTypeBitWidth(); auto out_bitwidth = out_ty.getElementTypeBitWidth();
auto src_layout = getLayout(op.getInput()); auto in_layout = getLayout(op.getInput());
LayoutOffsets src_offsets = src_layout->offsets(); LayoutOffsets in_offsets = in_layout->offsets();
auto implicit_dim = src_layout->implicit_dim(); auto implicit_dim = in_layout->implicit_dim();
if (src_offsets[0].value_or(0) * in_bitwidth % out_bitwidth != 0) { if (in_offsets[0].value_or(0) * in_bitwidth % out_bitwidth != 0) {
// Force offset to zero if the input offset on the second minor dimension // Force offset to zero if the input offset on the second minor dimension
// is not a multiple of the ratio of output and input bitwidth. // is not a multiple of the ratio of output and input bitwidth.
src_offsets[0] = 0; in_offsets[0] = 0;
} else if (!src_offsets[0].has_value() && in_bitwidth > out_bitwidth) { } else if (!in_offsets[0].has_value() && in_bitwidth > out_bitwidth) {
// We can't preserve replicated offset for decreasing bitwidth. // We can't preserve replicated offset for decreasing bitwidth.
src_offsets[0] = 0; in_offsets[0] = 0;
} }
// Force implicit dim to None if the bitwidth changes. Because we expect 2nd // Force implicit dim to None if the bitwidth changes. Because we expect 2nd
// minor dim size ratio matches the bitwidth ratio in input and output. // minor dim size ratio matches the bitwidth ratio in input and output.
@ -959,20 +959,24 @@ class VectorLayoutInferer {
} }
implicit_dim = ImplicitDim::kNone; implicit_dim = ImplicitDim::kNone;
} }
// TODO(b/348485035): Instead of forcing to native tiling, bitcast should auto in_tiling = in_layout->tiling();
// keep the input tiling and infer bitcastable tiling for output. For auto out_tiling = in_tiling;
// example, it is valid to bitcast vector<8x128xi32> with tile (1, 128) to auto out_offsets = in_offsets;
// vector<8x128xbf16> with tile (2, 128). if (in_offsets[0].has_value()) {
out_offsets[0] = in_offsets[0].value() * in_bitwidth / out_bitwidth;
}
if ((in_tiling[0] * in_bitwidth) % out_bitwidth == 0) {
out_tiling[0] = out_tiling[0] * in_bitwidth / out_bitwidth;
} else {
// If the input sublane tiling is not bitcastable to output, we use native
// tiling for both input and output.
in_tiling = nativeTiling(in_bitwidth);
out_tiling = nativeTiling(out_bitwidth);
}
setLayout( setLayout(
op, op, VectorLayout(in_bitwidth, in_offsets, in_tiling, implicit_dim),
VectorLayout(in_bitwidth, src_offsets, nativeTiling(in_bitwidth), VectorLayout(out_bitwidth, out_offsets, out_tiling, implicit_dim));
implicit_dim),
VectorLayout(out_bitwidth,
{src_offsets[0].has_value()
? src_offsets[0].value() * in_bitwidth / out_bitwidth
: src_offsets[0],
src_offsets[1]},
nativeTiling(out_bitwidth), implicit_dim));
return success(); return success();
} }