[Mosaic TPU] Support bitcast without forcing retiling.

PiperOrigin-RevId: 678765762
This commit is contained in:
Jevin Jiang 2024-09-25 10:56:30 -07:00 committed by jax authors
parent c93b272b78
commit 37641dd4fa

View File

@ -938,16 +938,16 @@ class VectorLayoutInferer {
auto out_ty = cast<VectorType>(op.getOutput().getType());
auto in_bitwidth = in_ty.getElementTypeBitWidth();
auto out_bitwidth = out_ty.getElementTypeBitWidth();
auto src_layout = getLayout(op.getInput());
LayoutOffsets src_offsets = src_layout->offsets();
auto implicit_dim = src_layout->implicit_dim();
if (src_offsets[0].value_or(0) * in_bitwidth % out_bitwidth != 0) {
auto in_layout = getLayout(op.getInput());
LayoutOffsets in_offsets = in_layout->offsets();
auto implicit_dim = in_layout->implicit_dim();
if (in_offsets[0].value_or(0) * in_bitwidth % out_bitwidth != 0) {
// Force offset to zero if the input offset on the second minor dimension
// is not a multiple of the ratio of output and input bitwidth.
src_offsets[0] = 0;
} else if (!src_offsets[0].has_value() && in_bitwidth > out_bitwidth) {
in_offsets[0] = 0;
} else if (!in_offsets[0].has_value() && in_bitwidth > out_bitwidth) {
// We can't preserve replicated offset for decreasing bitwidth.
src_offsets[0] = 0;
in_offsets[0] = 0;
}
// Force implicit dim to None if the bitwidth changes, because we expect the
// 2nd minor dim size ratio to match the bitwidth ratio of input and output.
@ -959,20 +959,24 @@ class VectorLayoutInferer {
}
implicit_dim = ImplicitDim::kNone;
}
// TODO(b/348485035): Instead of forcing to native tiling, bitcast should
// keep the input tiling and infer bitcastable tiling for output. For
// example, it is valid to bitcast vector<8x128xi32> with tile (1, 128) to
// vector<8x128xbf16> with tile (2, 128).
auto in_tiling = in_layout->tiling();
auto out_tiling = in_tiling;
auto out_offsets = in_offsets;
if (in_offsets[0].has_value()) {
out_offsets[0] = in_offsets[0].value() * in_bitwidth / out_bitwidth;
}
if ((in_tiling[0] * in_bitwidth) % out_bitwidth == 0) {
out_tiling[0] = out_tiling[0] * in_bitwidth / out_bitwidth;
} else {
// If the input sublane tiling is not bitcastable to output, we use native
// tiling for both input and output.
in_tiling = nativeTiling(in_bitwidth);
out_tiling = nativeTiling(out_bitwidth);
}
setLayout(
op,
VectorLayout(in_bitwidth, src_offsets, nativeTiling(in_bitwidth),
implicit_dim),
VectorLayout(out_bitwidth,
{src_offsets[0].has_value()
? src_offsets[0].value() * in_bitwidth / out_bitwidth
: src_offsets[0],
src_offsets[1]},
nativeTiling(out_bitwidth), implicit_dim));
op, VectorLayout(in_bitwidth, in_offsets, in_tiling, implicit_dim),
VectorLayout(out_bitwidth, out_offsets, out_tiling, implicit_dim));
return success();
}