mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic] Allow vector.shape_cast to (un)fold the sublane dim, for as long as it remains a multiple of sublane tiling
The old guards were overly restrictive, and we can actually treat a much larger class of reshapes as no-ops. PiperOrigin-RevId: 570049016
This commit is contained in:
parent
77d11e4dfd
commit
c9851ac7f3
@ -972,6 +972,16 @@ class VectorLayoutInferer {
|
||||
setLayout(op, layout, layout);
|
||||
return success();
|
||||
}
|
||||
// Sublane (un)tiling
|
||||
if (res_rank >= 2 && layout.offsets() == LayoutOffsets{0, 0} &&
|
||||
layout.tiling()[1] == target_shape_[1] &&
|
||||
src_ty.getDimSize(src_ty.getRank() - 1) ==
|
||||
res_shape[res_shape.size() - 1] &&
|
||||
src_ty.getDimSize(src_ty.getRank() - 2) % layout.tiling()[0] == 0 &&
|
||||
res_shape[res_shape.size() - 2] % layout.tiling()[0] == 0) {
|
||||
setLayout(op, layout, layout);
|
||||
return success();
|
||||
}
|
||||
unsigned bitwidth = src_ty.getElementTypeBitWidth();
|
||||
auto native_tiling = nativeTiling(bitwidth);
|
||||
if (layout.tiling() != native_tiling) {
|
||||
@ -991,16 +1001,6 @@ class VectorLayoutInferer {
|
||||
ImplicitDim::kSecondMinor));
|
||||
return success();
|
||||
}
|
||||
// Sublane (un)tiling
|
||||
if (src_ty.getElementTypeBitWidth() == kNativeBitwidth &&
|
||||
src_ty.getDimSize(src_ty.getRank() - 1) ==
|
||||
res_shape[res_shape.size() - 1] &&
|
||||
layout.offsets() == LayoutOffsets{0, 0} &&
|
||||
src_ty.getDimSize(src_ty.getRank() - 2) % target_shape_[0] == 0 &&
|
||||
res_shape[res_shape.size() - 2] % target_shape_[0] == 0) {
|
||||
setLayout(op, layout, layout);
|
||||
return success();
|
||||
}
|
||||
// Insert a singleton lane dimension. The old lane dimension ends up
|
||||
// in the sublane dimension. Other axes can be reshaped arbitrarily.
|
||||
if (src_ty.getElementTypeBitWidth() == kNativeBitwidth &&
|
||||
|
@ -2653,8 +2653,8 @@ def _vector_shape_cast_rule(ctx: RewriteContext, op: vector.ShapeCastOp, # pyli
|
||||
layout_in.implicit_dim is None
|
||||
and layout_out.implicit_dim is None
|
||||
and layout_in.offsets == layout_out.offsets == (0, 0)
|
||||
and layout_in.has_native_tiling
|
||||
and layout_in.tiling == layout_out.tiling
|
||||
and layout_in.tiling[-1] == TARGET_SHAPE.lanes
|
||||
and dst_ty.shape[-1] == src_ty.shape[-1]
|
||||
and dst_ty.shape[-2] % layout_in.tiling[-2] == 0
|
||||
and src_ty.shape[-2] % layout_in.tiling[-2] == 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user