mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 00:06:06 +00:00
[Mosaic] Use implicit shape over directly using implicit_dim()
in layout.{cc, h}
This CL also fixes a bug in `VectorLayout::join` where `ImplicitDim::kMinor` was considered equivalent to `ImplicitDim::kNone` when the shape's minor dimension is 1 (also needed to check that the second-minor dimension is 1). Often handling every implicit dim case separately is more complex and error-prone. There will be more follow-up changes to do this consistently elsewhere the code. This is also a first step towards 0D layout support. PiperOrigin-RevId: 636250759
This commit is contained in:
parent
0453f4400a
commit
5ae2491853
@ -441,8 +441,7 @@ bool VectorLayout::hasNativeTiling(
|
||||
SmallVector<int64_t> VectorLayout::implicitShape(
|
||||
ArrayRef<int64_t> shape) const {
|
||||
SmallVector<int64_t> implicit_shape(shape);
|
||||
const int64_t num_implicit_dims = 2 - layout_rank();
|
||||
implicit_shape.reserve(shape.size() + num_implicit_dims);
|
||||
implicit_shape.reserve(shape.size() + num_implicit_dims());
|
||||
insertImplicit(implicit_shape, 1);
|
||||
return implicit_shape;
|
||||
}
|
||||
@ -478,29 +477,17 @@ std::unique_ptr<VRegDataBounds> VectorLayout::tileDataBounds(
|
||||
// TODO(apaszke): allow_replicated could have been generalized to specify
|
||||
// what action should be taken when a REPLICATED offset is encountered.
|
||||
// Right now it either disallows replication, or selects the whole dimension.
|
||||
int64_t s, l;
|
||||
switch (implicit_dim_) {
|
||||
case ImplicitDim::kNone:
|
||||
s = idxs[idxs.size() - 2];
|
||||
l = idxs[idxs.size() - 1];
|
||||
break;
|
||||
case ImplicitDim::kMinor:
|
||||
s = idxs[idxs.size() - 1];
|
||||
l = 0;
|
||||
break;
|
||||
case ImplicitDim::kSecondMinor:
|
||||
s = 0;
|
||||
l = idxs[idxs.size() - 1];
|
||||
break;
|
||||
}
|
||||
|
||||
const std::array<int64_t, 2> tiled_idxs = getImplicitTiledDims(idxs, 0);
|
||||
const int64_t s = tiled_idxs[0];
|
||||
const int64_t l = tiled_idxs[1];
|
||||
const SmallVector<int64_t> tiles_implicit_shape =
|
||||
tileArrayImplicitShape(full_shape, target_shape);
|
||||
const int64_t ns = tiles_implicit_shape[tiles_implicit_shape.size() - 2];
|
||||
const int64_t nl = tiles_implicit_shape[tiles_implicit_shape.size() - 1];
|
||||
const SmallVector<int64_t> implicit_shape = implicitShape(full_shape);
|
||||
const int64_t is = implicit_shape[implicit_shape.size() - 2];
|
||||
const int64_t il = implicit_shape[implicit_shape.size() - 1];
|
||||
const int64_t ns = *(tiles_implicit_shape.end() - 2);
|
||||
const int64_t nl = *(tiles_implicit_shape.end() - 1);
|
||||
const std::array<int64_t, 2> shape_tiled_dims =
|
||||
getImplicitTiledDims(full_shape, 1);
|
||||
const int64_t is = shape_tiled_dims[0];
|
||||
const int64_t il = shape_tiled_dims[1];
|
||||
|
||||
if (!hasNaturalTopology(target_shape)) {
|
||||
if (!offsets_[0].has_value() || !offsets_[1].has_value()) {
|
||||
@ -588,27 +575,12 @@ bool VectorLayout::generalizes(
|
||||
if (shape.data() == nullptr) {
|
||||
return false;
|
||||
}
|
||||
// If the second-minor dimension is of size 1, then it does not matter
|
||||
// whether we have a second minor implicit dim or not.
|
||||
bool ok = false;
|
||||
if (((implicit_dim_ == ImplicitDim::kSecondMinor &&
|
||||
other.implicit_dim_ == ImplicitDim::kNone) ||
|
||||
(other.implicit_dim_ == ImplicitDim::kSecondMinor &&
|
||||
implicit_dim_ == ImplicitDim::kNone)) &&
|
||||
shape[shape.size() - 2] == 1) {
|
||||
ok = true;
|
||||
}
|
||||
// If sufficiently many trailing dimensions are of size 1, then it does not
|
||||
// matter if we use implicit dims to insert more.
|
||||
int max_rank = std::max(layout_rank(), other.layout_rank());
|
||||
CHECK_GE(max_rank, 1);
|
||||
CHECK_LE(max_rank, 2);
|
||||
if (*(shape.end() - 1) == 1 && (max_rank == 1 || *(shape.end() - 2) == 1)) {
|
||||
ok = true;
|
||||
}
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
// Since we do not reorder axes, if the shapes resulting from inserting
|
||||
// implicit dimensions resulting are the same in the 2 minormost dimensions
|
||||
// for both layouts, then the elements must be laid out the same way (i.e.
|
||||
// layouts are equivalent).
|
||||
return getImplicitTiledDims(shape, 1) ==
|
||||
other.getImplicitTiledDims(shape, 1);
|
||||
}
|
||||
if (tiling_ != other.tiling_) {
|
||||
// Don't fail yet!
|
||||
@ -658,26 +630,8 @@ std::optional<VectorLayout> VectorLayout::join(const VectorLayout& l,
|
||||
if (l.bitwidth_ != r.bitwidth_ || l.tiling_ != r.tiling_) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (l.implicit_dim_ != r.implicit_dim_) {
|
||||
if (shape.size() < 2) {
|
||||
return std::nullopt;
|
||||
}
|
||||
ImplicitDim dim;
|
||||
if (l.implicit_dim_ == ImplicitDim::kNone) {
|
||||
dim = r.implicit_dim_;
|
||||
} else if (r.implicit_dim_ == ImplicitDim::kNone) {
|
||||
dim = l.implicit_dim_;
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (dim == ImplicitDim::kMinor && shape[shape.size() - 1] == 1) {
|
||||
// OK, they are equivalent.
|
||||
} else if (dim == ImplicitDim::kSecondMinor &&
|
||||
shape[shape.size() - 2] == 1) {
|
||||
// OK, they are equivalent.
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (l.getImplicitTiledDims(shape, 1) != r.getImplicitTiledDims(shape, 1)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
LayoutOffsets offsets;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
|
@ -245,8 +245,17 @@ class VectorLayout {
|
||||
const std::array<int64_t, 2> &tiling() const { return tiling_; }
|
||||
ImplicitDim implicit_dim() const { return implicit_dim_; }
|
||||
int packing() const { return 32 / bitwidth_; }
|
||||
int num_implicit_dims() const {
|
||||
switch (implicit_dim_) {
|
||||
case ImplicitDim::kNone:
|
||||
return 0;
|
||||
case ImplicitDim::kMinor:
|
||||
case ImplicitDim::kSecondMinor:
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
// The number of minormost dimensions tiled by this layout.
|
||||
int layout_rank() const { return 1 + (implicit_dim_ == ImplicitDim::kNone); }
|
||||
int layout_rank() const { return 2 - num_implicit_dims(); }
|
||||
|
||||
bool operator==(const VectorLayout &other) const;
|
||||
bool operator!=(const VectorLayout &other) const {
|
||||
@ -302,6 +311,27 @@ class VectorLayout {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the value of the tiled (2 minormost) dimensions of the given array
|
||||
// with implicit dims inserted.
|
||||
//
|
||||
// Roughly equivalent to the following (but avoids vector allocation):
|
||||
//
|
||||
// SmallVector<int64_t> vec = arr;
|
||||
// insertImplicit(arr, implicit_value);
|
||||
// return {*(vec.end() - 2), *(vec.end() - 1)};
|
||||
std::array<int64_t, 2> getImplicitTiledDims(
|
||||
const ArrayRef<int64_t> arr, const int64_t implicit_value) const {
|
||||
CHECK_GE(arr.size(), layout_rank());
|
||||
switch (implicit_dim_) {
|
||||
case ImplicitDim::kNone:
|
||||
return {*(arr.end() - 2), *(arr.end() - 1)};
|
||||
case ImplicitDim::kMinor:
|
||||
return {*(arr.end() - 1), implicit_value};
|
||||
case ImplicitDim::kSecondMinor:
|
||||
return {implicit_value, *(arr.end() - 1)};
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<int64_t> implicitShape(ArrayRef<int64_t> shape) const;
|
||||
|
||||
SmallVector<int64_t> tileArrayImplicitShape(
|
||||
|
@ -2960,15 +2960,14 @@ FailureOr<xla::Array<Value>> vector_extract_slice_impl(
|
||||
TPU_ASSERT_EQ_OP(num_indices, sizes.size());
|
||||
|
||||
SmallVector<int64_t> full_sizes;
|
||||
const int64_t num_implicit_dims = 2 - layout_in.layout_rank();
|
||||
full_sizes.reserve(src_vector_rank + num_implicit_dims);
|
||||
full_sizes.reserve(src_vector_rank + layout_in.num_implicit_dims());
|
||||
full_sizes.append(sizes.begin(), sizes.end());
|
||||
full_sizes.append(src_vector_shape.begin() + num_indices,
|
||||
src_vector_shape.end());
|
||||
layout_in.insertImplicit(full_sizes, 1); /* */
|
||||
layout_in.insertImplicit(full_sizes, 1);
|
||||
|
||||
SmallVector<int64_t> full_offsets;
|
||||
full_offsets.reserve(src_vector_rank + num_implicit_dims);
|
||||
full_offsets.reserve(src_vector_rank + layout_in.num_implicit_dims());
|
||||
full_offsets.append(offsets.begin(), offsets.end());
|
||||
full_offsets.append(src_vector_rank - num_indices, 0);
|
||||
layout_in.insertImplicit(full_offsets, 0);
|
||||
|
Loading…
x
Reference in New Issue
Block a user