From 37641dd4fade625563321b7e1e87165df23cf4a8 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Wed, 25 Sep 2024 10:56:30 -0700 Subject: [PATCH] [Mosaic TPU] Support bitcast without forcing retiling. PiperOrigin-RevId: 678765762 --- .../tpu/transforms/infer_vector_layout.cc | 44 ++++++++++--------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 2894b0797..408731e89 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -938,16 +938,16 @@ class VectorLayoutInferer { auto out_ty = cast<VectorType>(op.getOutput().getType()); auto in_bitwidth = in_ty.getElementTypeBitWidth(); auto out_bitwidth = out_ty.getElementTypeBitWidth(); - auto src_layout = getLayout(op.getInput()); - LayoutOffsets src_offsets = src_layout->offsets(); - auto implicit_dim = src_layout->implicit_dim(); - if (src_offsets[0].value_or(0) * in_bitwidth % out_bitwidth != 0) { + auto in_layout = getLayout(op.getInput()); + LayoutOffsets in_offsets = in_layout->offsets(); + auto implicit_dim = in_layout->implicit_dim(); + if (in_offsets[0].value_or(0) * in_bitwidth % out_bitwidth != 0) { // Force offset to zero if the input offset on the second minor dimension // is not a multiple of the ratio of output and input bitwidth. - src_offsets[0] = 0; - } else if (!src_offsets[0].has_value() && in_bitwidth > out_bitwidth) { + in_offsets[0] = 0; + } else if (!in_offsets[0].has_value() && in_bitwidth > out_bitwidth) { // We can't preserve replicated offset for decreasing bitwidth. - src_offsets[0] = 0; + in_offsets[0] = 0; } // Force implicit dim to None if the bitwidth changes. Because we expect 2nd // minor dim size ratio matches the bitwidth ratio in input and output. 
@@ -959,20 +959,24 @@ class VectorLayoutInferer { } implicit_dim = ImplicitDim::kNone; } - // TODO(b/348485035): Instead of forcing to native tiling, bitcast should - // keep the input tiling and infer bitcastable tiling for output. For - // example, it is valid to bitcast vector<8x128xi32> with tile (1, 128) to - // vector<8x128xbf16> with tile (2, 128). + auto in_tiling = in_layout->tiling(); + auto out_tiling = in_tiling; + auto out_offsets = in_offsets; + if (in_offsets[0].has_value()) { + out_offsets[0] = in_offsets[0].value() * in_bitwidth / out_bitwidth; + } + if ((in_tiling[0] * in_bitwidth) % out_bitwidth == 0) { + out_tiling[0] = out_tiling[0] * in_bitwidth / out_bitwidth; + } else { + // If the input sublane tiling is not bitcastable to output, we use native + // tiling for both input and output. + in_tiling = nativeTiling(in_bitwidth); + out_tiling = nativeTiling(out_bitwidth); + } + setLayout( - op, - VectorLayout(in_bitwidth, src_offsets, nativeTiling(in_bitwidth), - implicit_dim), - VectorLayout(out_bitwidth, - {src_offsets[0].has_value() - ? src_offsets[0].value() * in_bitwidth / out_bitwidth - : src_offsets[0], - src_offsets[1]}, - nativeTiling(out_bitwidth), implicit_dim)); + op, VectorLayout(in_bitwidth, in_offsets, in_tiling, implicit_dim), + VectorLayout(out_bitwidth, out_offsets, out_tiling, implicit_dim)); return success(); }