[Mosaic TPU] Remove special handling of implicit dim in relayout

Now all changes happen inside the dedicated functions.

PiperOrigin-RevId: 658763465
This commit is contained in:
Adam Paszke 2024-08-02 05:45:51 -07:00 committed by jax authors
parent 80560663d3
commit 959657a489

View File

@ -5380,6 +5380,14 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
if (src.implicit_dim() == dst_implicit_dim) {
return std::make_pair(src, std::move(vregs));
}
// It's possible that the implicit dim change is a no-op.
VectorLayout src_candidate(src.bitwidth(), src.offsets(), src.tiling(),
dst_implicit_dim);
if (src_candidate.equivalentTo(src, vty.getShape(), target_shape)) {
vregs.Reshape(
src_candidate.tileArrayImplicitShape(vty.getShape(), target_shape));
return std::make_pair(src_candidate, vregs);
}
// Remove second minor implicit dim, for values that have (8, 128) tiling.
// TODO(apaszke): We should allow replicated dst_offset_hints[0].
if (src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor &&
@ -5457,25 +5465,6 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
}
}
}
// Save the original value of dst to use it at the end. It determines the
// out_layout of the result of assemble.
// TODO(apaszke): Retiling should not care about the implicit dim. Move
// implicit dim adjustment to the end of this function.
const VectorLayout original_dst = dst;
// Try to reconcile differences in implicit dim.
if (src.implicit_dim() != dst.implicit_dim()) {
VectorLayout src_candidate(src.bitwidth(), src.offsets(), src.tiling(),
dst.implicit_dim());
if (src_candidate.equivalentTo(src, vty.getShape(), target_shape)) {
src = src_candidate;
} else {
VectorLayout dst_candidate(dst.bitwidth(), dst.offsets(), dst.tiling(),
src.implicit_dim());
if (dst_candidate.equivalentTo(dst, vty.getShape(), target_shape)) {
dst = dst_candidate;
}
}
}
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> src_tiles,
@ -5499,6 +5488,22 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
.getResult();
}
// Consider (1,128),-2 -> (8,128). In this case we can change the implicit
// dim for free before we change the tiling, but not after.
// TODO(apaszke): In general the number of vregs necessary to represent a
// value for different implicit dims satisfies kNone < kSecondMinor < kMinor.
// We should use this property to decide if we should change the implicit dim
// before or after changing the tiling and offsets.
if (src.implicit_dim() != dst.implicit_dim()) {
VectorLayout src_candidate(src.bitwidth(), src.offsets(), src.tiling(),
dst.implicit_dim());
if (src_candidate.equivalentTo(src, vty.getShape(), target_shape)) {
src = src_candidate;
src_tiles.Reshape(
src.tileArrayImplicitShape(vty.getShape(), target_shape));
}
}
FAILUREOR_ASSIGN_OR_RETURN(
std::tie(src, src_tiles),
changeTiling(builder, ctx.target_shape, v.getLoc(), vty, src,
@ -5518,10 +5523,9 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
std::move(src_tiles), dst.offsets()));
CHECK_EQ(src, dst); // At this point we've should be done.
src_tiles.Reshape(
original_dst.tileArrayImplicitShape(vty.getShape(), target_shape));
return assemble(builder, vty, original_dst, std::move(src_tiles),
target_shape, /*use_implicit_shape=*/true)
src_tiles.Reshape(dst.tileArrayImplicitShape(vty.getShape(), target_shape));
return assemble(builder, vty, dst, std::move(src_tiles), target_shape,
/*use_implicit_shape=*/true)
.getResult();
}