mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Mosaic TPU] Support bitcast without forcing retiling.
PiperOrigin-RevId: 678765762
This commit is contained in:
parent
c93b272b78
commit
37641dd4fa
@ -938,16 +938,16 @@ class VectorLayoutInferer {
|
|||||||
auto out_ty = cast<VectorType>(op.getOutput().getType());
|
auto out_ty = cast<VectorType>(op.getOutput().getType());
|
||||||
auto in_bitwidth = in_ty.getElementTypeBitWidth();
|
auto in_bitwidth = in_ty.getElementTypeBitWidth();
|
||||||
auto out_bitwidth = out_ty.getElementTypeBitWidth();
|
auto out_bitwidth = out_ty.getElementTypeBitWidth();
|
||||||
auto src_layout = getLayout(op.getInput());
|
auto in_layout = getLayout(op.getInput());
|
||||||
LayoutOffsets src_offsets = src_layout->offsets();
|
LayoutOffsets in_offsets = in_layout->offsets();
|
||||||
auto implicit_dim = src_layout->implicit_dim();
|
auto implicit_dim = in_layout->implicit_dim();
|
||||||
if (src_offsets[0].value_or(0) * in_bitwidth % out_bitwidth != 0) {
|
if (in_offsets[0].value_or(0) * in_bitwidth % out_bitwidth != 0) {
|
||||||
// Force offset to zero if the input offset on the second minor dimension
|
// Force offset to zero if the input offset on the second minor dimension
|
||||||
// is not a multiple of the ratio of output and input bitwidth.
|
// is not a multiple of the ratio of output and input bitwidth.
|
||||||
src_offsets[0] = 0;
|
in_offsets[0] = 0;
|
||||||
} else if (!src_offsets[0].has_value() && in_bitwidth > out_bitwidth) {
|
} else if (!in_offsets[0].has_value() && in_bitwidth > out_bitwidth) {
|
||||||
// We can't preserve replicated offset for decreasing bitwidth.
|
// We can't preserve replicated offset for decreasing bitwidth.
|
||||||
src_offsets[0] = 0;
|
in_offsets[0] = 0;
|
||||||
}
|
}
|
||||||
// Force implicit dim to None if the bitwidth changes. Because we expect 2nd
|
// Force implicit dim to None if the bitwidth changes. Because we expect 2nd
|
||||||
// minor dim size ratio matches the bitwidth ratio in input and output.
|
// minor dim size ratio matches the bitwidth ratio in input and output.
|
||||||
@ -959,20 +959,24 @@ class VectorLayoutInferer {
|
|||||||
}
|
}
|
||||||
implicit_dim = ImplicitDim::kNone;
|
implicit_dim = ImplicitDim::kNone;
|
||||||
}
|
}
|
||||||
// TODO(b/348485035): Instead of forcing to native tiling, bitcast should
|
auto in_tiling = in_layout->tiling();
|
||||||
// keep the input tiling and infer bitcastable tiling for output. For
|
auto out_tiling = in_tiling;
|
||||||
// example, it is valid to bitcast vector<8x128xi32> with tile (1, 128) to
|
auto out_offsets = in_offsets;
|
||||||
// vector<8x128xbf16> with tile (2, 128).
|
if (in_offsets[0].has_value()) {
|
||||||
|
out_offsets[0] = in_offsets[0].value() * in_bitwidth / out_bitwidth;
|
||||||
|
}
|
||||||
|
if ((in_tiling[0] * in_bitwidth) % out_bitwidth == 0) {
|
||||||
|
out_tiling[0] = out_tiling[0] * in_bitwidth / out_bitwidth;
|
||||||
|
} else {
|
||||||
|
// If the input sublane tiling is not bitcastable to output, we use native
|
||||||
|
// tiling for both input and output.
|
||||||
|
in_tiling = nativeTiling(in_bitwidth);
|
||||||
|
out_tiling = nativeTiling(out_bitwidth);
|
||||||
|
}
|
||||||
|
|
||||||
setLayout(
|
setLayout(
|
||||||
op,
|
op, VectorLayout(in_bitwidth, in_offsets, in_tiling, implicit_dim),
|
||||||
VectorLayout(in_bitwidth, src_offsets, nativeTiling(in_bitwidth),
|
VectorLayout(out_bitwidth, out_offsets, out_tiling, implicit_dim));
|
||||||
implicit_dim),
|
|
||||||
VectorLayout(out_bitwidth,
|
|
||||||
{src_offsets[0].has_value()
|
|
||||||
? src_offsets[0].value() * in_bitwidth / out_bitwidth
|
|
||||||
: src_offsets[0],
|
|
||||||
src_offsets[1]},
|
|
||||||
nativeTiling(out_bitwidth), implicit_dim));
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user