mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic TPU] Remove special handling of implicit dim in relayout
Now all changes happen inside the dedicated functions. PiperOrigin-RevId: 658763465
This commit is contained in:
parent
80560663d3
commit
959657a489
@ -5380,6 +5380,14 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
|
||||
if (src.implicit_dim() == dst_implicit_dim) {
|
||||
return std::make_pair(src, std::move(vregs));
|
||||
}
|
||||
// It's possible that the implicit dim change is a no-op.
|
||||
VectorLayout src_candidate(src.bitwidth(), src.offsets(), src.tiling(),
|
||||
dst_implicit_dim);
|
||||
if (src_candidate.equivalentTo(src, vty.getShape(), target_shape)) {
|
||||
vregs.Reshape(
|
||||
src_candidate.tileArrayImplicitShape(vty.getShape(), target_shape));
|
||||
return std::make_pair(src_candidate, vregs);
|
||||
}
|
||||
// Remove second minor implicit dim, for values that have (8, 128) tiling.
|
||||
// TODO(apaszke): We should allow replicated dst_offset_hints[0].
|
||||
if (src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor &&
|
||||
@ -5457,25 +5465,6 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
|
||||
}
|
||||
}
|
||||
}
|
||||
// Save the original value of dst to use it at the end. It determines the
|
||||
// out_layout of the result of assemble.
|
||||
// TODO(apaszke): Retiling should not care about the implicit dim. Move
|
||||
// implicit dim adjustment to the end of this function.
|
||||
const VectorLayout original_dst = dst;
|
||||
// Try to reconcile differences in implicit dim.
|
||||
if (src.implicit_dim() != dst.implicit_dim()) {
|
||||
VectorLayout src_candidate(src.bitwidth(), src.offsets(), src.tiling(),
|
||||
dst.implicit_dim());
|
||||
if (src_candidate.equivalentTo(src, vty.getShape(), target_shape)) {
|
||||
src = src_candidate;
|
||||
} else {
|
||||
VectorLayout dst_candidate(dst.bitwidth(), dst.offsets(), dst.tiling(),
|
||||
src.implicit_dim());
|
||||
if (dst_candidate.equivalentTo(dst, vty.getShape(), target_shape)) {
|
||||
dst = dst_candidate;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
xla::Array<Value> src_tiles,
|
||||
@ -5499,6 +5488,22 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
|
||||
.getResult();
|
||||
}
|
||||
|
||||
// Consider (1,128),-2 -> (8,128). In this case we can change the implicit
|
||||
// dim for free before we change the tiling, but not after.
|
||||
// TODO(apaszke): In general the number of vregs necessary to represent a
|
||||
// value for different implicit dims satisfies kNone < kSecondMinor < kMinor.
|
||||
// We should use this property to decide if we should change the implicit dim
|
||||
// before or after changing the tiling and offsets.
|
||||
if (src.implicit_dim() != dst.implicit_dim()) {
|
||||
VectorLayout src_candidate(src.bitwidth(), src.offsets(), src.tiling(),
|
||||
dst.implicit_dim());
|
||||
if (src_candidate.equivalentTo(src, vty.getShape(), target_shape)) {
|
||||
src = src_candidate;
|
||||
src_tiles.Reshape(
|
||||
src.tileArrayImplicitShape(vty.getShape(), target_shape));
|
||||
}
|
||||
}
|
||||
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
std::tie(src, src_tiles),
|
||||
changeTiling(builder, ctx.target_shape, v.getLoc(), vty, src,
|
||||
@ -5518,10 +5523,9 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
|
||||
std::move(src_tiles), dst.offsets()));
|
||||
|
||||
CHECK_EQ(src, dst); // At this point we've should be done.
|
||||
src_tiles.Reshape(
|
||||
original_dst.tileArrayImplicitShape(vty.getShape(), target_shape));
|
||||
return assemble(builder, vty, original_dst, std::move(src_tiles),
|
||||
target_shape, /*use_implicit_shape=*/true)
|
||||
src_tiles.Reshape(dst.tileArrayImplicitShape(vty.getShape(), target_shape));
|
||||
return assemble(builder, vty, dst, std::move(src_tiles), target_shape,
|
||||
/*use_implicit_shape=*/true)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user