mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Mosaic TPU] Support relayout (8,128),-2 -> (8,128) for non-zero and replicated 2nd minor offset
Also fix a bug where relayouts for a fully replicated source were assumed to be a no-op without checking implicit dims.

PiperOrigin-RevId: 655746766
This commit is contained in:
parent
f1cfd99fe8
commit
220ec2aa69
@ -288,7 +288,7 @@ class VectorLayout {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void insertImplicit(SmallVector<T> &vec, T value) const {
|
||||
void insertImplicit(SmallVectorImpl<T> &vec, T value) const {
|
||||
CHECK_GE(vec.size(), layout_rank());
|
||||
switch (implicit_dim_) {
|
||||
case ImplicitDim::kNone:
|
||||
@ -302,7 +302,7 @@ class VectorLayout {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void eraseImplicit(SmallVector<T> &vec) const {
|
||||
void eraseImplicit(SmallVectorImpl<T> &vec) const {
|
||||
CHECK_GE(vec.size(), 2);
|
||||
switch (implicit_dim_) {
|
||||
case ImplicitDim::kNone:
|
||||
|
@ -5021,8 +5021,37 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
|
||||
if (bitwidth != dst.bitwidth()) {
|
||||
return emitError(v.getLoc(), "Can't change bitwidth during a relayout");
|
||||
}
|
||||
const int packing = src.packing();
|
||||
VectorType vty = v.getType();
|
||||
{
|
||||
// Replication imposes a replication constraint on the *logical* value of
|
||||
// the vector: When moving along a replicated axis, all elements must be
|
||||
// equal. Note that when the axis is a singleton, there is effectively no
|
||||
// added *logical* constraint.
|
||||
// For example, a vector<2x2xf32> v with no implicit dims and layout offsets
|
||||
// {*, 0} is expected to satisfy v[0, 0] == v[1, 0] and v[0, 1] == v[1, 1].
|
||||
// Relayout does not change the logical value of the vector. Any replication
|
||||
// constraints in the result must be guaranteed by the source layout.
|
||||
SmallVector<LayoutOffset, 2> src_offsets(ArrayRef(src.offsets()));
|
||||
SmallVector<LayoutOffset, 2> dst_offsets(ArrayRef(dst.offsets()));
|
||||
// Remove implicit dims to get offsets for trailing logical dims.
|
||||
src.eraseImplicit(src_offsets);
|
||||
dst.eraseImplicit(dst_offsets);
|
||||
for (int i = dst_offsets.size(); i > 0; --i) {
|
||||
const int64_t dim_size = *(vty.getShape().end() - i);
|
||||
const bool dim_replicated_in_dst = !*(dst_offsets.end() - i);
|
||||
// If the dim is untiled in the src layout, then there is no guarantee of
|
||||
// replication, because we don't track replication for untiled dims.
|
||||
const bool dim_replicated_in_src =
|
||||
i <= src_offsets.size() && !*(src_offsets.end() - i);
|
||||
if (dim_replicated_in_dst && !dim_replicated_in_src && dim_size != 1) {
|
||||
return emitError(v.getLoc(),
|
||||
"Invalid relayout: Non-singleton logical dimension is "
|
||||
"replicated in destination but not in source for ")
|
||||
<< vty << ": " << src << " -> " << dst;
|
||||
}
|
||||
}
|
||||
}
|
||||
const int packing = src.packing();
|
||||
|
||||
// Save the original value of dst to use it at the end. It determines the
|
||||
// out_layout of the result of assemble.
|
||||
@ -5054,8 +5083,8 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
|
||||
/*use_implicit_shape=*/true)
|
||||
.getResult();
|
||||
}
|
||||
if (!src.offsets()[0].has_value() && !src.offsets()[1].has_value() &&
|
||||
src.tilesPerVreg(target_shape) == 1) {
|
||||
if (src.layout_rank() >= dst.layout_rank() && !src.offsets()[0].has_value() &&
|
||||
!src.offsets()[1].has_value() && src.tilesPerVreg(target_shape) == 1) {
|
||||
// A fully replicated value is always easy to relayout
|
||||
// It would be nice to be able to assert this here, but given replicated
|
||||
// values our rules can introduce equivalent expressions.
|
||||
@ -5258,25 +5287,29 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
|
||||
// This drops the implicit second minor dimension.
|
||||
src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor &&
|
||||
dst.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
|
||||
src.bitwidth() == 32 && src.offsets() == dst.offsets() &&
|
||||
src.offsets() == LayoutOffsets{0, 0} && src.tiling() == dst.tiling() &&
|
||||
src.bitwidth() == 32 && dst.offsets()[0] &&
|
||||
src.offsets()[1] == dst.offsets()[1] && src.tiling() == dst.tiling() &&
|
||||
src.tiling() == std::array<int64_t, 2>{8, 128}) {
|
||||
xla::Array<Value> src_tiles_retiled(
|
||||
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
|
||||
src_tiles_retiled.Each(
|
||||
[&](const absl::Span<const int64_t> idx, Value *tile) {
|
||||
for (int dst_sl_idx = 0; dst_sl_idx < 8; ++dst_sl_idx) {
|
||||
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
|
||||
src.insertImplicit<int64_t>(src_idx, 0);
|
||||
auto second_minor_idx = idx.size() - 2;
|
||||
src_idx[second_minor_idx] = 8 * idx[second_minor_idx] + dst_sl_idx;
|
||||
if (src_idx[second_minor_idx] >= src_tiles.dim(second_minor_idx)) {
|
||||
break;
|
||||
}
|
||||
*tile = copy_one_sublane(builder, src_tiles(src_idx), 0, *tile,
|
||||
dst_sl_idx, target_shape);
|
||||
}
|
||||
});
|
||||
src_tiles_retiled.Each([&](const absl::Span<const int64_t> idx,
|
||||
Value *tile) {
|
||||
const int64_t dst_2nd_minor_idx = idx.size() - 2;
|
||||
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
|
||||
src.insertImplicit<int64_t>(src_idx, 0);
|
||||
const int dst_sl_start =
|
||||
idx[dst_2nd_minor_idx] == 0 ? *dst.offsets()[0] : 0;
|
||||
src_idx[dst_2nd_minor_idx] = target_shape[0] * idx[dst_2nd_minor_idx] +
|
||||
dst_sl_start - *dst.offsets()[0];
|
||||
for (int dst_sl_idx = dst_sl_start;
|
||||
dst_sl_idx < target_shape[0] &&
|
||||
src_idx[dst_2nd_minor_idx] < src_tiles.dim(dst_2nd_minor_idx);
|
||||
++dst_sl_idx, ++src_idx[dst_2nd_minor_idx]) {
|
||||
*tile = copy_one_sublane(builder, src_tiles(src_idx),
|
||||
src.offsets()[0].value_or(dst_sl_idx), *tile,
|
||||
dst_sl_idx, target_shape);
|
||||
}
|
||||
});
|
||||
src = dst;
|
||||
src_tiles = std::move(src_tiles_retiled);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user