mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[XLA:Mosaic] Support shrinking lane dim to 128 in shapecast.
PiperOrigin-RevId: 574965883
This commit is contained in:
parent
741b71fe85
commit
bb30d3ee9f
@ -2261,6 +2261,7 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
|
||||
if (!layouts_out.front().has_value()) {
|
||||
return op.emitOpError("Expected non-null output layout");
|
||||
}
|
||||
using Tiling = std::array<int64_t, 2>;
|
||||
const VectorLayout &layout_in = *layouts_in.front();
|
||||
const VectorLayout &layout_out = *layouts_out.front();
|
||||
ImplicitLocOpBuilder builder(op.getLoc(), &op);
|
||||
@ -2313,6 +2314,17 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
|
||||
*(dst_shape.end() - 2) % layout_in.tiling()[0] == 0 &&
|
||||
*(src_shape.end() - 2) % layout_in.tiling()[0] == 0) {
|
||||
no_op = true;
|
||||
} else if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
|
||||
layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
|
||||
layout_in.offsets() == layout_out.offsets() &&
|
||||
layout_in.offsets() == LayoutOffsets{0, 0} &&
|
||||
layout_in.tiling() == Tiling{1, ctx.target_shape[1]} &&
|
||||
layout_out.hasNaturalTopology(ctx.target_shape) &&
|
||||
*(dst_shape.end() - 1) != *(src_shape.end() - 1) &&
|
||||
*(dst_shape.end() - 1) == ctx.target_shape[1] &&
|
||||
*(dst_shape.end() - 2) % layout_out.tiling()[0] == 0 &&
|
||||
*(src_shape.end() - 1) % layout_in.tiling()[1] == 0) {
|
||||
no_op = true;
|
||||
}
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
xla::Array<Value> src_vregs,
|
||||
|
@ -995,7 +995,7 @@ class VectorLayoutInferer {
|
||||
setLayout(op, layout, layout);
|
||||
return success();
|
||||
}
|
||||
// Sublane (un)tiling
|
||||
// Sublane (un)tiling.
|
||||
if (res_rank >= 2 && layout.offsets() == LayoutOffsets{0, 0} &&
|
||||
layout.tiling()[1] == target_shape_[1] &&
|
||||
src_ty.getDimSize(src_ty.getRank() - 1) ==
|
||||
@ -1005,6 +1005,42 @@ class VectorLayoutInferer {
|
||||
setLayout(op, layout, layout);
|
||||
return success();
|
||||
}
|
||||
// Lane (un)tiling.
|
||||
if (res_rank >= 2 && layout.offsets() == LayoutOffsets{0, 0} &&
|
||||
layout.tiling()[1] == target_shape_[1] &&
|
||||
src_ty.getDimSize(src_ty.getRank() - 1) !=
|
||||
res_shape[res_shape.size() - 1] &&
|
||||
src_ty.getDimSize(src_ty.getRank() - 1) % layout.tiling()[1] == 0 &&
|
||||
res_shape[res_shape.size() - 1] % layout.tiling()[1] == 0) {
|
||||
// TODO(jevinjiang): support shapecast along lane with any bitwidth.
|
||||
if (src_ty.getElementTypeBitWidth() != kNativeBitwidth) {
|
||||
NYI("Shapecast along lane dimension when bitwidth is not 32");
|
||||
}
|
||||
// Inferring in_layout to have tiling (1, 128) triggers any necessary
|
||||
// relayout before shapecast.
|
||||
setInLayout(op,
|
||||
{VectorLayout(layout.bitwidth(), layout.offsets(),
|
||||
{1, target_shape_[1]}, ImplicitDim::kNone)});
|
||||
// If the input has tiling (1, target_shape_[1]) and the last two dims
|
||||
// of result shape are [n * target_shape_[0], target_shape_[1]], the
|
||||
// reshape becomes a no-op if only we change the tiling to match
|
||||
// target_shape_. For example, reshaping to layouts like the following
|
||||
// will not require any data movement:
|
||||
// 8x8x128 (1, 128)
|
||||
// 4x16x128 (1, 128)
|
||||
// 8x8x128 (8, 128)
|
||||
// ....
|
||||
if (res_shape[res_shape.size() - 1] == target_shape_[1] &&
|
||||
res_shape[res_shape.size() - 2] % target_shape_[0] == 0) {
|
||||
setOutLayout(op, VectorLayout(layout.bitwidth(), layout.offsets(),
|
||||
default_tiling_, ImplicitDim::kNone));
|
||||
return success();
|
||||
}
|
||||
|
||||
// TODO(b/299253805): support shapecast along lane for other cases.
|
||||
op.emitOpError("unsupported shape cast");
|
||||
return failure();
|
||||
}
|
||||
unsigned bitwidth = src_ty.getElementTypeBitWidth();
|
||||
auto native_tiling = nativeTiling(bitwidth);
|
||||
if (layout.tiling() != native_tiling) {
|
||||
|
@ -2748,6 +2748,18 @@ def _vector_shape_cast_rule(ctx: RewriteContext, op: vector.ShapeCastOp, # pyli
|
||||
and src_ty.shape[-2] % layout_in.tiling[-2] == 0
|
||||
):
|
||||
no_op = True
|
||||
elif (
|
||||
layout_in.implicit_dim is None
|
||||
and layout_out.implicit_dim is None
|
||||
and layout_out.offsets == layout_in.offsets == (0, 0)
|
||||
and layout_in.tiling == (1, TARGET_SHAPE.lanes)
|
||||
and layout_out.has_natural_topology
|
||||
and dst_ty.shape[-1] != src_ty.shape[-1]
|
||||
and dst_ty.shape[-1] == TARGET_SHAPE.lanes
|
||||
and dst_ty.shape[-2] % layout_out.tiling[-2] == 0
|
||||
and src_ty.shape[-1] % layout_in.tiling[-1] == 0
|
||||
): # n x (m * 128) -> n x m x 128.
|
||||
no_op = True
|
||||
|
||||
src_vregs = disassemble(layout_in, op.source)
|
||||
if no_op:
|
||||
|
Loading…
x
Reference in New Issue
Block a user