mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Mosaic TPU] Support bitcast without forcing retiling.
PiperOrigin-RevId: 678765762
This commit is contained in:
parent
c93b272b78
commit
37641dd4fa
@ -938,16 +938,16 @@ class VectorLayoutInferer {
|
||||
auto out_ty = cast<VectorType>(op.getOutput().getType());
|
||||
auto in_bitwidth = in_ty.getElementTypeBitWidth();
|
||||
auto out_bitwidth = out_ty.getElementTypeBitWidth();
|
||||
auto src_layout = getLayout(op.getInput());
|
||||
LayoutOffsets src_offsets = src_layout->offsets();
|
||||
auto implicit_dim = src_layout->implicit_dim();
|
||||
if (src_offsets[0].value_or(0) * in_bitwidth % out_bitwidth != 0) {
|
||||
auto in_layout = getLayout(op.getInput());
|
||||
LayoutOffsets in_offsets = in_layout->offsets();
|
||||
auto implicit_dim = in_layout->implicit_dim();
|
||||
if (in_offsets[0].value_or(0) * in_bitwidth % out_bitwidth != 0) {
|
||||
// Force offset to zero if the input offset on the second minor dimension
|
||||
// is not a multiple of the ratio of output and input bitwidth.
|
||||
src_offsets[0] = 0;
|
||||
} else if (!src_offsets[0].has_value() && in_bitwidth > out_bitwidth) {
|
||||
in_offsets[0] = 0;
|
||||
} else if (!in_offsets[0].has_value() && in_bitwidth > out_bitwidth) {
|
||||
// We can't preserve replicated offset for decreasing bitwidth.
|
||||
src_offsets[0] = 0;
|
||||
in_offsets[0] = 0;
|
||||
}
|
||||
// Force implicit dim to None if the bitwidth changes, because we expect the
|
||||
// 2nd minor dim size ratio to match the bitwidth ratio of input and output.
|
||||
@ -959,20 +959,24 @@ class VectorLayoutInferer {
|
||||
}
|
||||
implicit_dim = ImplicitDim::kNone;
|
||||
}
|
||||
// TODO(b/348485035): Instead of forcing to native tiling, bitcast should
|
||||
// keep the input tiling and infer bitcastable tiling for output. For
|
||||
// example, it is valid to bitcast vector<8x128xi32> with tile (1, 128) to
|
||||
// vector<8x128xbf16> with tile (2, 128).
|
||||
auto in_tiling = in_layout->tiling();
|
||||
auto out_tiling = in_tiling;
|
||||
auto out_offsets = in_offsets;
|
||||
if (in_offsets[0].has_value()) {
|
||||
out_offsets[0] = in_offsets[0].value() * in_bitwidth / out_bitwidth;
|
||||
}
|
||||
if ((in_tiling[0] * in_bitwidth) % out_bitwidth == 0) {
|
||||
out_tiling[0] = out_tiling[0] * in_bitwidth / out_bitwidth;
|
||||
} else {
|
||||
// If the input sublane tiling is not bitcastable to output, we use native
|
||||
// tiling for both input and output.
|
||||
in_tiling = nativeTiling(in_bitwidth);
|
||||
out_tiling = nativeTiling(out_bitwidth);
|
||||
}
|
||||
|
||||
setLayout(
|
||||
op,
|
||||
VectorLayout(in_bitwidth, src_offsets, nativeTiling(in_bitwidth),
|
||||
implicit_dim),
|
||||
VectorLayout(out_bitwidth,
|
||||
{src_offsets[0].has_value()
|
||||
? src_offsets[0].value() * in_bitwidth / out_bitwidth
|
||||
: src_offsets[0],
|
||||
src_offsets[1]},
|
||||
nativeTiling(out_bitwidth), implicit_dim));
|
||||
op, VectorLayout(in_bitwidth, in_offsets, in_tiling, implicit_dim),
|
||||
VectorLayout(out_bitwidth, out_offsets, out_tiling, implicit_dim));
|
||||
return success();
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user