diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index e54c8a389..ded9171d1 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -2261,6 +2261,7 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, if (!layouts_out.front().has_value()) { return op.emitOpError("Expected non-null output layout"); } + using Tiling = std::array; const VectorLayout &layout_in = *layouts_in.front(); const VectorLayout &layout_out = *layouts_out.front(); ImplicitLocOpBuilder builder(op.getLoc(), &op); @@ -2313,6 +2314,17 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, *(dst_shape.end() - 2) % layout_in.tiling()[0] == 0 && *(src_shape.end() - 2) % layout_in.tiling()[0] == 0) { no_op = true; + } else if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone && + layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone && + layout_in.offsets() == layout_out.offsets() && + layout_in.offsets() == LayoutOffsets{0, 0} && + layout_in.tiling() == Tiling{1, ctx.target_shape[1]} && + layout_out.hasNaturalTopology(ctx.target_shape) && + *(dst_shape.end() - 1) != *(src_shape.end() - 1) && + *(dst_shape.end() - 1) == ctx.target_shape[1] && + *(dst_shape.end() - 2) % layout_out.tiling()[0] == 0 && + *(src_shape.end() - 1) % layout_in.tiling()[1] == 0) { + no_op = true; } FAILUREOR_ASSIGN_OR_RETURN( xla::Array src_vregs, diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 841c73ae8..cb39a81d8 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -995,7 +995,7 @@ class VectorLayoutInferer { setLayout(op, layout, layout); return success(); } - // Sublane (un)tiling + // Sublane (un)tiling. if (res_rank >= 2 && layout.offsets() == LayoutOffsets{0, 0} && layout.tiling()[1] == target_shape_[1] && src_ty.getDimSize(src_ty.getRank() - 1) == @@ -1005,6 +1005,42 @@ class VectorLayoutInferer { setLayout(op, layout, layout); return success(); } + // Lane (un)tiling. + if (res_rank >= 2 && layout.offsets() == LayoutOffsets{0, 0} && + layout.tiling()[1] == target_shape_[1] && + src_ty.getDimSize(src_ty.getRank() - 1) != + res_shape[res_shape.size() - 1] && + src_ty.getDimSize(src_ty.getRank() - 1) % layout.tiling()[1] == 0 && + res_shape[res_shape.size() - 1] % layout.tiling()[1] == 0) { + // TODO(jevinjiang): support shapecast along lane with any bitwidth. + if (src_ty.getElementTypeBitWidth() != kNativeBitwidth) { + NYI("Shapecast along lane dimension when bitwidth is not 32"); + } + // Inferring in_layout to have tiling (1, 128) triggers any necessary + // relayout before shapecast. + setInLayout(op, + {VectorLayout(layout.bitwidth(), layout.offsets(), + {1, target_shape_[1]}, ImplicitDim::kNone)}); + // If the input has tiling (1, target_shape_[1]) and the last two dims + // of result shape are [n * target_shape_[0], target_shape_[1]], the + // reshape becomes a no-op if only we change the tiling to match + // target_shape_. For example, reshaping to layouts like the following + // will not require any data movement: + // 8x8x128 (1, 128) + // 4x16x128 (1, 128) + // 8x8x128 (8, 128) + // .... + if (res_shape[res_shape.size() - 1] == target_shape_[1] && + res_shape[res_shape.size() - 2] % target_shape_[0] == 0) { + setOutLayout(op, VectorLayout(layout.bitwidth(), layout.offsets(), + default_tiling_, ImplicitDim::kNone)); + return success(); + } + + // TODO(b/299253805): support shapecast along lane for other cases. + op.emitOpError("unsupported shape cast"); + return failure(); + } unsigned bitwidth = src_ty.getElementTypeBitWidth(); auto native_tiling = nativeTiling(bitwidth); if (layout.tiling() != native_tiling) { diff --git a/jaxlib/mosaic/python/apply_vector_layout.py b/jaxlib/mosaic/python/apply_vector_layout.py index 0a1f42545..b7a60f087 100644 --- a/jaxlib/mosaic/python/apply_vector_layout.py +++ b/jaxlib/mosaic/python/apply_vector_layout.py @@ -2748,6 +2748,18 @@ def _vector_shape_cast_rule(ctx: RewriteContext, op: vector.ShapeCastOp, # pyli and src_ty.shape[-2] % layout_in.tiling[-2] == 0 ): no_op = True + elif ( + layout_in.implicit_dim is None + and layout_out.implicit_dim is None + and layout_out.offsets == layout_in.offsets == (0, 0) + and layout_in.tiling == (1, TARGET_SHAPE.lanes) + and layout_out.has_natural_topology + and dst_ty.shape[-1] != src_ty.shape[-1] + and dst_ty.shape[-1] == TARGET_SHAPE.lanes + and dst_ty.shape[-2] % layout_out.tiling[-2] == 0 + and src_ty.shape[-1] % layout_in.tiling[-1] == 0 + ): # n x (m * 128) -> n x m x 128. + no_op = True src_vregs = disassemble(layout_in, op.source) if no_op: