[Mosaic] Use implicit shape over directly using implicit_dim() in layout.{cc, h}

This CL also fixes a bug in `VectorLayout::join` where `ImplicitDim::kMinor` was considered equivalent to `ImplicitDim::kNone` whenever the shape's minor dimension is 1. That check was insufficient: the two are only equivalent when the second-minor dimension is also 1.

Handling every implicit dim case separately is often more complex and error-prone than working with the implicit shape.

There will be more follow-up changes to do this consistently elsewhere in the code. This is also a first step towards 0D layout support.

PiperOrigin-RevId: 636250759
This commit is contained in:
Tomás Longeri 2024-05-22 12:10:40 -07:00 committed by jax authors
parent 0453f4400a
commit 5ae2491853
3 changed files with 52 additions and 69 deletions

View File

@ -441,8 +441,7 @@ bool VectorLayout::hasNativeTiling(
// Returns a copy of `shape` after insertImplicit has been applied to it with
// an implicit value of 1.
SmallVector<int64_t> VectorLayout::implicitShape(
    ArrayRef<int64_t> shape) const {
  SmallVector<int64_t> implicit_shape(shape);
  // Reserve room up front for the entries insertImplicit is about to add.
  // The member num_implicit_dims() is used directly; a local of the same name
  // would shadow it and make the call ill-formed.
  implicit_shape.reserve(shape.size() + num_implicit_dims());
  insertImplicit(implicit_shape, 1);
  return implicit_shape;
}
@ -478,29 +477,17 @@ std::unique_ptr<VRegDataBounds> VectorLayout::tileDataBounds(
// TODO(apaszke): allow_replicated could have been generalized to specify
// what action should be taken when a REPLICATED offset is encountered.
// Right now it either disallows replication, or selects the whole dimension.
int64_t s, l;
switch (implicit_dim_) {
case ImplicitDim::kNone:
s = idxs[idxs.size() - 2];
l = idxs[idxs.size() - 1];
break;
case ImplicitDim::kMinor:
s = idxs[idxs.size() - 1];
l = 0;
break;
case ImplicitDim::kSecondMinor:
s = 0;
l = idxs[idxs.size() - 1];
break;
}
const std::array<int64_t, 2> tiled_idxs = getImplicitTiledDims(idxs, 0);
const int64_t s = tiled_idxs[0];
const int64_t l = tiled_idxs[1];
const SmallVector<int64_t> tiles_implicit_shape =
tileArrayImplicitShape(full_shape, target_shape);
const int64_t ns = tiles_implicit_shape[tiles_implicit_shape.size() - 2];
const int64_t nl = tiles_implicit_shape[tiles_implicit_shape.size() - 1];
const SmallVector<int64_t> implicit_shape = implicitShape(full_shape);
const int64_t is = implicit_shape[implicit_shape.size() - 2];
const int64_t il = implicit_shape[implicit_shape.size() - 1];
const int64_t ns = *(tiles_implicit_shape.end() - 2);
const int64_t nl = *(tiles_implicit_shape.end() - 1);
const std::array<int64_t, 2> shape_tiled_dims =
getImplicitTiledDims(full_shape, 1);
const int64_t is = shape_tiled_dims[0];
const int64_t il = shape_tiled_dims[1];
if (!hasNaturalTopology(target_shape)) {
if (!offsets_[0].has_value() || !offsets_[1].has_value()) {
@ -588,27 +575,12 @@ bool VectorLayout::generalizes(
if (shape.data() == nullptr) {
return false;
}
// If the second-minor dimension is of size 1, then it does not matter
// whether we have a second minor implicit dim or not.
bool ok = false;
if (((implicit_dim_ == ImplicitDim::kSecondMinor &&
other.implicit_dim_ == ImplicitDim::kNone) ||
(other.implicit_dim_ == ImplicitDim::kSecondMinor &&
implicit_dim_ == ImplicitDim::kNone)) &&
shape[shape.size() - 2] == 1) {
ok = true;
}
// If sufficiently many trailing dimensions are of size 1, then it does not
// matter if we use implicit dims to insert more.
int max_rank = std::max(layout_rank(), other.layout_rank());
CHECK_GE(max_rank, 1);
CHECK_LE(max_rank, 2);
if (*(shape.end() - 1) == 1 && (max_rank == 1 || *(shape.end() - 2) == 1)) {
ok = true;
}
if (!ok) {
return false;
}
// Since we do not reorder axes, if the shapes that result from inserting
// implicit dimensions are the same in the 2 minormost dimensions for both
// layouts, then the elements must be laid out the same way (i.e. the
// layouts are equivalent).
return getImplicitTiledDims(shape, 1) ==
other.getImplicitTiledDims(shape, 1);
}
if (tiling_ != other.tiling_) {
// Don't fail yet!
@ -658,26 +630,8 @@ std::optional<VectorLayout> VectorLayout::join(const VectorLayout& l,
if (l.bitwidth_ != r.bitwidth_ || l.tiling_ != r.tiling_) {
return std::nullopt;
}
if (l.implicit_dim_ != r.implicit_dim_) {
if (shape.size() < 2) {
return std::nullopt;
}
ImplicitDim dim;
if (l.implicit_dim_ == ImplicitDim::kNone) {
dim = r.implicit_dim_;
} else if (r.implicit_dim_ == ImplicitDim::kNone) {
dim = l.implicit_dim_;
} else {
return std::nullopt;
}
if (dim == ImplicitDim::kMinor && shape[shape.size() - 1] == 1) {
// OK, they are equivalent.
} else if (dim == ImplicitDim::kSecondMinor &&
shape[shape.size() - 2] == 1) {
// OK, they are equivalent.
} else {
return std::nullopt;
}
if (l.getImplicitTiledDims(shape, 1) != r.getImplicitTiledDims(shape, 1)) {
return std::nullopt;
}
LayoutOffsets offsets;
for (int i = 0; i < 2; ++i) {

View File

@ -245,8 +245,17 @@ class VectorLayout {
// Returns a const reference to the stored two-element tiling (`tiling_`).
const std::array<int64_t, 2> &tiling() const { return tiling_; }
// Returns the stored implicit-dimension setting (`implicit_dim_`).
ImplicitDim implicit_dim() const { return implicit_dim_; }
// Returns 32 / bitwidth_ using integer division.
// NOTE(review): presumably the number of elements packed into one 32-bit
// word -- confirm against the definition of bitwidth_.
int packing() const { return 32 / bitwidth_; }
// Returns how many implicit dimensions this layout has: 0 for
// ImplicitDim::kNone, 1 for ImplicitDim::kMinor and
// ImplicitDim::kSecondMinor.
int num_implicit_dims() const {
  // The switch stays exhaustive (no `default`) so the compiler can flag any
  // newly added enumerator, while the single return after it guarantees that
  // control never flows off the end of this non-void function.
  int count = 0;
  switch (implicit_dim_) {
    case ImplicitDim::kNone:
      break;
    case ImplicitDim::kMinor:
    case ImplicitDim::kSecondMinor:
      count = 1;
      break;
  }
  return count;
}
// The number of minormost dimensions tiled by this layout: 2 minus the number
// of implicit dimensions, so 2 for ImplicitDim::kNone and 1 otherwise.
int layout_rank() const { return 2 - num_implicit_dims(); }
bool operator==(const VectorLayout &other) const;
bool operator!=(const VectorLayout &other) const {
@ -302,6 +311,27 @@ class VectorLayout {
}
}
// Returns the value of the tiled (2 minormost) dimensions of the given array
// with implicit dims inserted.
//
// `arr` must have at least layout_rank() elements. `implicit_value` is the
// value reported for whichever dimension is implicit.
//
// Roughly equivalent to the following (but avoids vector allocation):
//
//   SmallVector<int64_t> vec = arr;
//   insertImplicit(vec, implicit_value);
//   return {*(vec.end() - 2), *(vec.end() - 1)};
std::array<int64_t, 2> getImplicitTiledDims(
    const ArrayRef<int64_t> arr, const int64_t implicit_value) const {
  // Compare as unsigned on both sides to avoid a signed/unsigned mismatch.
  CHECK_GE(arr.size(), static_cast<size_t>(layout_rank()));
  // Start with both dims implicit and overwrite the ones taken from `arr`.
  // Returning once after the switch keeps the switch exhaustive and ensures
  // control never flows off the end of this non-void function.
  std::array<int64_t, 2> tiled_dims = {implicit_value, implicit_value};
  switch (implicit_dim_) {
    case ImplicitDim::kNone:
      tiled_dims[0] = arr[arr.size() - 2];
      tiled_dims[1] = arr.back();
      break;
    case ImplicitDim::kMinor:
      // The minor dim is implicit; the last real dim is second-minor.
      tiled_dims[0] = arr.back();
      break;
    case ImplicitDim::kSecondMinor:
      // The second-minor dim is implicit; the last real dim is minor.
      tiled_dims[1] = arr.back();
      break;
  }
  return tiled_dims;
}
SmallVector<int64_t> implicitShape(ArrayRef<int64_t> shape) const;
SmallVector<int64_t> tileArrayImplicitShape(

View File

@ -2960,15 +2960,14 @@ FailureOr<xla::Array<Value>> vector_extract_slice_impl(
TPU_ASSERT_EQ_OP(num_indices, sizes.size());
SmallVector<int64_t> full_sizes;
const int64_t num_implicit_dims = 2 - layout_in.layout_rank();
full_sizes.reserve(src_vector_rank + num_implicit_dims);
full_sizes.reserve(src_vector_rank + layout_in.num_implicit_dims());
full_sizes.append(sizes.begin(), sizes.end());
full_sizes.append(src_vector_shape.begin() + num_indices,
src_vector_shape.end());
layout_in.insertImplicit(full_sizes, 1); /* */
layout_in.insertImplicit(full_sizes, 1);
SmallVector<int64_t> full_offsets;
full_offsets.reserve(src_vector_rank + num_implicit_dims);
full_offsets.reserve(src_vector_rank + layout_in.num_implicit_dims());
full_offsets.append(offsets.begin(), offsets.end());
full_offsets.append(src_vector_rank - num_indices, 0);
layout_in.insertImplicit(full_offsets, 0);