[Mosaic] Expand support of vector.broadcast

- Enable it for minor or second-minor implicit dims in the non-no-op case.
- Don't allow output offsets for broadcast dimensions to be non-replicated, and make sure infer-vector-layout assigns them as replicated in all cases.
- Don't fail when both tiled dimensions are logically broadcast but only one of them requires actual broadcasting; before, this would hit the unimplemented sublane + lane broadcast case (see the example below).
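
To illustrate the last point, here is a minimal MLIR example (hypothetical, not taken from this change's tests; which physical broadcasts are needed depends on the offsets infer-vector-layout assigns):

  func.func @broadcast_example(%src: vector<1x1xf32>) -> vector<8x128xf32> {
    // Both tiled dims are logically broadcast (1 -> 8 sublanes, 1 -> 128
    // lanes). If the source layout carries a replicated sublane offset, only
    // a lane broadcast is physically required, so this can now lower instead
    // of hitting the unimplemented sublane + lane case.
    %dst = vector.broadcast %src : vector<1x1xf32> to vector<8x128xf32>
    return %dst : vector<8x128xf32>
  }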

PiperOrigin-RevId: 637772134
Author: Tomás Longeri (2024-05-27 21:58:52 -07:00), committed by jax authors
parent e8a1113072
commit 3fb9acf01a
5 changed files with 122 additions and 110 deletions

View File

@@ -442,7 +442,7 @@ SmallVector<int64_t> VectorLayout::implicitShape(
     ArrayRef<int64_t> shape) const {
   SmallVector<int64_t> implicit_shape(shape);
   implicit_shape.reserve(shape.size() + num_implicit_dims());
-  insertImplicit(implicit_shape, 1);
+  insertImplicit<int64_t>(implicit_shape, 1);
   return implicit_shape;
 }

View File

@@ -286,7 +286,8 @@ class VectorLayout {
     return {tiling_[0], tilesPerVreg(target_shape) * tiling_[1]};
   }
-  void insertImplicit(SmallVector<int64_t> &vec, int64_t value) const {
+  template <typename T>
+  void insertImplicit(SmallVector<T> &vec, T value) const {
     CHECK_GE(vec.size(), layout_rank());
     switch (implicit_dim_) {
       case ImplicitDim::kNone:
@@ -299,7 +300,8 @@
     }
   }
-  void eraseImplicit(SmallVector<int64_t> &vec) const {
+  template <typename T>
+  void eraseImplicit(SmallVector<T> &vec) const {
     CHECK_GE(vec.size(), 2);
     switch (implicit_dim_) {
       case ImplicitDim::kNone:
View File

@@ -2664,8 +2664,11 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
   ImplicitLocOpBuilder builder(op.getLoc(), &op);
   vector::BroadcastOp broadcast_op = cast<vector::BroadcastOp>(op);
   const VectorType dst_ty = broadcast_op.getResult().getType();
+  const ArrayRef<int64_t> dst_shape = dst_ty.getShape();
   const SmallVector<int64_t> dst_tiles_shape =
-      layout_out.tileArrayShape(dst_ty.getShape(), ctx.target_shape);
+      layout_out.tileArrayShape(dst_shape, ctx.target_shape);
+  const SmallVector<int64_t> dst_tiles_implicit_shape =
+      layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape);
   if (auto src = dyn_cast<TypedValue<VectorType>>(broadcast_op.getSource())) {
     VectorType src_ty = src.getType();
     TPU_ASSERT_OP(maybe_layout_in.has_value());
@@ -2674,8 +2677,6 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
       return op.emitOpError(
           "Not implemented: Changing implicit dims mid-broadcast");
     }
-    const VectorLayout::ImplicitDim implicit_dim = layout_in.implicit_dim();
-    const int layout_rank = layout_in.layout_rank();
     const LayoutOffsets offsets_in = layout_in.offsets();
     const LayoutOffsets offsets_out = layout_out.offsets();
     if (layout_in.tiling() != layout_out.tiling()) {
@@ -2684,78 +2685,76 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
     auto tiling = layout_in.tiling();
     const int64_t expand_rank = dst_ty.getRank() - src_ty.getRank();
-    SmallVector<int64_t> src_shape_padded(expand_rank, -1);
     const ArrayRef<int64_t> src_shape = src_ty.getShape();
-    src_shape_padded.append(src_shape.begin(), src_shape.end());
-    const SmallVector<bool> dim_eq = llvm::map_to_vector(
-        llvm::zip(src_shape_padded, dst_ty.getShape()), [](auto tup) {
-          auto [i, o] = tup;
-          return i == o;
-        });
-    bool no_op = false;
-    switch (implicit_dim) {
-      case VectorLayout::ImplicitDim::kNone: {
-        const ArrayRef<bool> tiled_dim_eq = ArrayRef<bool>(dim_eq).take_back(2);
-        for (auto [in_off, out_off, eq] :
-             llvm::zip(offsets_in, offsets_out, tiled_dim_eq)) {
-          if (eq && in_off != out_off) {
-            return op.emitOpError(
-                "Not implemented: Changing offsets mid-broadcast");
-          }
-        }
-        no_op = layout_in.hasNaturalTopology(ctx.target_shape) &&
-                layout_out.hasNaturalTopology(ctx.target_shape) &&
-                llvm::all_of(llvm::zip_equal(offsets_in, tiled_dim_eq),
-                             [](auto tup) {
-                               auto [o, eq] = tup;
-                               return eq || !o.has_value();
-                             });
-      } break;
-      case VectorLayout::ImplicitDim::kMinor:
-      case VectorLayout::ImplicitDim::kSecondMinor:
-        if (dim_eq.back()) {
-          if (offsets_in != offsets_out) {
-            return op.emitOpError(
-                "Not implemented: Changing offsets mid-broadcast");
-          }
-          no_op = true;
-        } else if (implicit_dim == VectorLayout::ImplicitDim::kSecondMinor &&
-                   !offsets_in[1].has_value()) {
-          no_op = true;
-        } else if (implicit_dim == VectorLayout::ImplicitDim::kMinor &&
-                   !offsets_in[0].has_value()) {
-          no_op = true;
-        }
-        break;
+    SmallVector<int64_t> src_implicit_shape_padded;
+    // `is_logical_broadcast` stores whether each dimension of the implicit
+    // shape of the result is a broadcast. E.g. if the implicit shape goes from
+    // (2, 1, 3) to (4, 2, 5, 3) it's (true, false, true, false).
+    SmallVector<bool> is_logical_broadcast;
+    src_implicit_shape_padded.reserve(dst_shape.size() +
+                                      layout_in.num_implicit_dims());
+    is_logical_broadcast.reserve(dst_shape.size() +
+                                 layout_in.num_implicit_dims());
+    src_implicit_shape_padded.append(expand_rank, 1);
+    src_implicit_shape_padded.append(src_shape.begin(), src_shape.end());
+    for (auto [i, o] : llvm::zip(src_implicit_shape_padded, dst_shape)) {
+      TPU_ASSERT_OP(i == o || i == 1);  // Verifier should guarantee this.
+      is_logical_broadcast.push_back(i != o);
     }
-    TPU_ASSERT_OP(layout_rank);
-    if (src_ty.getShape().take_back(layout_rank) ==
-        dst_ty.getShape().take_back(layout_rank)) {
-      if (offsets_in != offsets_out) {
-        op.emitOpError("Not implemented: Changing offsets mid-broadcast");
+    layout_in.insertImplicit<int64_t>(src_implicit_shape_padded, 1);
+    layout_in.insertImplicit<bool>(is_logical_broadcast, false);
+    // Verify that the offsets are valid.
+    for (auto [is_logical_broadcast_on_dim, in_off, out_off] :
+         llvm::zip_equal(ArrayRef(is_logical_broadcast).take_back(2),
+                         offsets_in, offsets_out)) {
+      if (is_logical_broadcast_on_dim) {
+        if (out_off.has_value()) {
+          // There's no reason to ever assign a non-replicated offset to a
+          // broadcasted dimension in the output.
+          return op.emitOpError(
+              // TODO(tlongeri): This should never be implemented but the fuzzed
+              //                 tests expect a NotImplementedError, which
+              //                 is raised with a "Not implemented" (see
+              //                 NotImplementedDetector in tpu_ext.cc). Fix.
+              "Not implemented: Broadcast output expected to have replicated "
+              "offsets.");
+        }
+      } else {  // !is_logical_broadcast_on_dim
+        if (in_off != out_off) {
+          return op.emitOpError(
+              "Not implemented: Changing offsets mid-broadcast");
+        }
       }
-      no_op = true;
     }
+    // `needs_physical_broadcast` specifies whether we need to broadcast vregs
+    // in the sublane and lane dimensions. We only need to do this if the
+    // corresponding dimension of the implicit shape is logically broadcast and
+    // if the input vregs are not already replicated along this dimension.
+    const std::array<bool, 2> needs_physical_broadcast{
+        *(is_logical_broadcast.end() - 2) && offsets_in[0].has_value(),
+        *(is_logical_broadcast.end() - 1) && offsets_in[1].has_value()};
     FAILUREOR_ASSIGN_OR_RETURN(
         xla::Array<Value> src_tiles,
-        disassemble(builder, layout_in, src, ctx.target_shape));
-    xla::Array<Value> dst_tiles(dst_tiles_shape);
-    if (no_op) {
+        disassemble(builder, layout_in, src, ctx.target_shape,
+                    /*use_implicit_shape=*/true));
+    xla::Array<Value> dst_tiles(dst_tiles_implicit_shape);
+    if (needs_physical_broadcast == std::array{false, false}) {  // No-op
       SmallVector<int64_t> reshape_dims(expand_rank, 1);
       const absl::Span<const int64_t> src_tiles_dims = src_tiles.dimensions();
       reshape_dims.append(src_tiles_dims.begin(), src_tiles_dims.end());
       src_tiles.Reshape(reshape_dims);
       dst_tiles.Each([&](const absl::Span<const int64_t> dst_idx, Value *tile) {
-        const SmallVector<int64_t> src_idx =
-            llvm::map_to_vector(llvm::zip_equal(dst_idx, dim_eq), [](auto tup) {
-              auto [i, eq] = tup;
-              return eq ? i : 0;
+        const SmallVector<int64_t> src_idx = llvm::map_to_vector(
+            llvm::zip_equal(dst_idx, is_logical_broadcast), [](auto tup) {
+              auto [i, is_logical_broadcast_on_dim] = tup;
+              return is_logical_broadcast_on_dim ? 0 : i;
             });
         *tile = src_tiles(src_idx);
       });
-    } else if (implicit_dim == VectorLayout::ImplicitDim::kNone) {
+    } else {
       if (layout_in.bitwidth() != 32) {
         return op.emitOpError(
             "Not implemented: Only 32-bit broadcast supported");
@@ -2764,8 +2763,8 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
         return op.emitOpError("Not implemented: unsupported tiling");
       }
       int64_t num_tiles = layout_in.tilesPerVreg(ctx.target_shape);
-      TPU_ASSERT_OP(!*(dim_eq.end() - 1) || !*(dim_eq.end() - 2));
-      if (*(dim_eq.end() - 1)) {  // Sublane broadcast
+      if (needs_physical_broadcast ==
+          std::array{true, false}) {  // Sublane broadcast
         if (num_tiles != 1) {
           return op.emitOpError(
               "Not implemented: Only native tiling supported");
@@ -2777,12 +2776,12 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
             SmallVector<int32_t>(ctx.target_shape[0], offset));
         src_tiles.Each([&](const absl::Span<const int64_t> src_idx,
                            Value *const src_tile) {
-          SmallVector<int64_t> dst_starts(dst_tiles_shape.size());
-          SmallVector<int64_t> dst_limits(dst_tiles_shape.size());
+          SmallVector<int64_t> dst_starts(dst_tiles_implicit_shape.size());
+          SmallVector<int64_t> dst_limits(dst_tiles_implicit_shape.size());
           for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) {
-            if (i < expand_rank || !dim_eq[i]) {
+            if (i < expand_rank || is_logical_broadcast[i]) {
               dst_starts[i] = 0;
-              dst_limits[i] = dst_tiles_shape[i];
+              dst_limits[i] = dst_tiles_implicit_shape[i];
             } else {
               dst_starts[i] = src_idx[i - expand_rank];
               dst_limits[i] = dst_starts[i] + 1;
@@ -2793,7 +2792,8 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
                                  src_tile->getType(), *src_tile, indices, 0),
                              dst_starts, dst_limits);
         });
-      } else if (*(dim_eq.end() - 2)) {  // Lane broadcast
+      } else if (needs_physical_broadcast ==
+                 std::array{false, true}) {  // Lane broadcast
         TPU_ASSERT_EQ_OP(*(src_tiles.dimensions().end() - 1), 1);
         TPU_ASSERT_OP(offsets_in[1].has_value());
         const int64_t offset = *offsets_in[1];
@@ -2816,12 +2816,12 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
         }
         src_tiles.Each([&](const absl::Span<const int64_t> src_idx,
                            Value *const src_tile) {
-          SmallVector<int64_t> dst_starts(dst_tiles_shape.size());
-          SmallVector<int64_t> dst_limits(dst_tiles_shape.size());
+          SmallVector<int64_t> dst_starts(dst_tiles_implicit_shape.size());
+          SmallVector<int64_t> dst_limits(dst_tiles_implicit_shape.size());
           for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) {
-            if (i < expand_rank || !dim_eq[i]) {
+            if (i < expand_rank || is_logical_broadcast[i]) {
               dst_starts[i] = 0;
-              dst_limits[i] = dst_tiles_shape[i];
+              dst_limits[i] = dst_tiles_implicit_shape[i];
             } else {
               dst_starts[i] = src_idx[i - expand_rank];
               dst_limits[i] = dst_starts[i] + 1;
@@ -2838,14 +2838,15 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
           updateSlice<Value>(dst_tiles, res_vreg, dst_starts, dst_limits);
         });
       } else {
-        return op.emitOpError("Not implemented");
+        TPU_ASSERT_OP((needs_physical_broadcast == std::array{true, true}));
+        return op.emitOpError(
+            "Not implemented: Broadcast in both sublanes and lanes");
       }
-    } else {
-      return op.emitOpError("Not implemented");
     }
-    broadcast_op.replaceAllUsesWith(
-        assemble(builder, dst_ty, layout_out, dst_tiles, ctx.target_shape)
-            .getOperation());
+    broadcast_op.replaceAllUsesWith(assemble(builder, dst_ty, layout_out,
+                                             dst_tiles, ctx.target_shape,
+                                             /*use_implicit_shape=*/true)
+                                        .getOperation());
     broadcast_op.erase();
     return success();
   } else if (layout_out.bitwidth() == 32 &&
@@ -2981,13 +2982,13 @@ FailureOr<xla::Array<Value>> vector_extract_slice_impl(
   full_sizes.append(sizes.begin(), sizes.end());
   full_sizes.append(src_vector_shape.begin() + num_indices,
                     src_vector_shape.end());
-  layout_in.insertImplicit(full_sizes, 1);
+  layout_in.insertImplicit<int64_t>(full_sizes, 1);
   SmallVector<int64_t> full_offsets;
   full_offsets.reserve(src_vector_rank + layout_in.num_implicit_dims());
   full_offsets.append(offsets.begin(), offsets.end());
   full_offsets.append(src_vector_rank - num_indices, 0);
-  layout_in.insertImplicit(full_offsets, 0);
+  layout_in.insertImplicit<int64_t>(full_offsets, 0);
   // We currently only support no-op cases - that is, those where we effectively
   // just extract a slice of vregs without doing any operations (e.g. shifts) on
@@ -4036,9 +4037,16 @@ const llvm::StringMap<rule_type> &rules() {
 RollVectorsOp assemble(OpBuilder &builder, VectorType vty,
                        const VectorLayout &layout,
                        const xla::Array<Value> &vals,
-                       const std::array<int64_t, 2> target_shape) {
-  CHECK(vals.dimensions() ==
-        layout.tileArrayShape(vty.getShape(), target_shape));
+                       const std::array<int64_t, 2> target_shape,
+                       const bool use_implicit_shape) {
+  // TODO(tlongeri): Maybe just add a parameter to tileArrayShape instead of
+  // having `tileArrayShape` and `tileArrayImplicitShape`.
+  SmallVector<int64_t> vreg_array_shape =
+      layout.tileArrayImplicitShape(vty.getShape(), target_shape);
+  if (!use_implicit_shape) {
+    layout.eraseImplicit(vreg_array_shape);
+  }
+  CHECK(vals.dimensions() == vreg_array_shape);
   CHECK_GT(vals.num_elements(), 0);
   Location loc = vals.begin()->getLoc();
   auto op =
@@ -4059,8 +4067,8 @@ RollVectorsOp assemble(OpBuilder &builder, VectorType vty,
 // An ndarray of MLIR values representing the tiling of val given by layout.
 FailureOr<xla::Array<Value>> disassemble(
     OpBuilder &builder, const VectorLayout &layout,
-    const TypedValue<VectorType> val,
-    const std::array<int64_t, 2> target_shape) {
+    const TypedValue<VectorType> val, const std::array<int64_t, 2> target_shape,
+    const bool use_implicit_shape) {  // TODO(tlongeri): Remove default
   const auto vty = val.getType();
   const auto op_result = dyn_cast<OpResult>(val);
   if (op_result == nullptr) {
@@ -4074,8 +4082,13 @@ FailureOr<xla::Array<Value>> disassemble(
   TPU_ASSERT_LOC(val.getLoc(), def_layout.has_value());
   TPU_ASSERT_LOC(val.getLoc(),
                  def_layout->generalizes(layout, vty.getShape(), target_shape));
+  // TODO(tlongeri): Maybe just add a parameter to tileArrayShape instead of
+  // having `tileArrayShape` and `tileArrayImplicitShape`.
   SmallVector<int64_t> layout_shape =
-      layout.tileArrayShape(vty.getShape(), target_shape);
+      layout.tileArrayImplicitShape(vty.getShape(), target_shape);
+  if (!use_implicit_shape) {
+    layout.eraseImplicit(layout_shape);
+  }
   if (auto roll_vectors_op = dyn_cast<RollVectorsOp>(op)) {
     return XlaArrayFromShapeAndValues<Value>(layout_shape,
                                              roll_vectors_op->getOperands());

View File

@@ -25,14 +25,17 @@ struct RewriteContext {
   MLIRContext *getMLIRContext() { return func.getContext(); }
 };
+// TODO(tlongeri): Remove default values for use_implicit_shape.
 RollVectorsOp assemble(OpBuilder &builder, VectorType vty,
                        const VectorLayout &layout,
                        const xla::Array<Value> &vals,
-                       std::array<int64_t, 2> target_shape);
+                       std::array<int64_t, 2> target_shape,
+                       bool use_implicit_shape = false);
 FailureOr<xla::Array<Value>> disassemble(OpBuilder &builder,
                                          const VectorLayout &layout,
                                          TypedValue<VectorType> val,
-                                         std::array<int64_t, 2> target_shape);
+                                         std::array<int64_t, 2> target_shape,
+                                         bool use_implicit_shape = false);
 // Rewrites the operation according to its layout annotations.
 //

View File

@@ -893,48 +893,42 @@ class VectorLayoutInferer {
     TPU_CHECK_OP(res_ty.getRank() >= 2, "result rank below 2D unsupported");
     auto some_layout = getLayout(op.getSource());
     TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
+    auto &layout = *some_layout;
     // Since we can only do sublane broadcasts in the (8, 128) tiling, we
     // should always use that when sublane broadcasting is required.
-    if (src_ty.getDimSize(src_ty.getRank() - 2) !=
-        res_ty.getDimSize(res_ty.getRank() - 2)) {
-      if (some_layout->bitwidth() != kNativeBitwidth) {
+    if (*(src_ty.getShape().end() - 2) != *(res_ty.getShape().end() - 2)) {
+      if (layout.bitwidth() != kNativeBitwidth) {
         NYI("Only 32-bit broadcasts supported");
       }
-      LayoutOffsets offsets = some_layout->offsets();
+      LayoutOffsets offsets = layout.offsets();
       // At the moment relayout can only produce replicated sublanes when
       // converting to (8, 128) if the input was in (1, 128) tiling
-      if (some_layout->tiling()[0] == 1) {
+      if (layout.tiling()[0] == 1) {
         offsets[0] = std::nullopt;
       }
-      *some_layout =
-          VectorLayout(some_layout->bitwidth(), offsets, default_tiling_,
-                       some_layout->implicit_dim());
+      layout = VectorLayout(layout.bitwidth(), offsets, default_tiling_,
+                            layout.implicit_dim());
     }
-    auto &layout = *some_layout;
     if (layout.implicit_dim() != ImplicitDim::kNone) {
       VectorLayout layout_2d(layout.bitwidth(), layout.offsets(),
                              layout.tiling(), ImplicitDim::kNone);
       if (layout_2d.equivalentTo(layout, src_ty.getShape(), target_shape_)) {
+        // TODO(b/342237796): Stop preferring 2D layouts (if given the choice)
+        //                    and defer the work, if any, to relayout.
         layout = layout_2d;
-      } else {
-        op.emitOpError() << "Only 2D layouts supported";
-        return failure();
       }
     }
     auto src_tiled_shape = src_ty.getShape().take_back(2);
     auto dst_tiled_shape = res_ty.getShape().take_back(2);
     LayoutOffsets offsets = layout.offsets();
-    if (layout.bitwidth() == kNativeBitwidth &&
-        layout.tiling() == default_tiling_) {
-      for (int i = 0; i < 2; ++i) {
-        if (src_tiled_shape[i] != dst_tiled_shape[i]) {
-          offsets[i] = std::nullopt;
-        }
-      }
+    for (int i = 0; i < 2; ++i) {
+      if (src_tiled_shape[i] != dst_tiled_shape[i]) {
+        offsets[i] = std::nullopt;
+      }
     }
-    setLayout(op, some_layout,
+    setLayout(op, layout,
               VectorLayout(layout.bitwidth(), offsets, layout.tiling(),
-                           ImplicitDim::kNone));
+                           layout.implicit_dim()));
     return success();
   }
   op.emitOpError("unsupported broadcast source type");
@@ -1122,7 +1116,7 @@
     auto offsets = llvm::map_to_vector(offsets_attr, [](auto attr) {
       return cast<IntegerAttr>(attr).getInt();
     });
-    input_layout->insertImplicit(offsets, 0);
+    input_layout->insertImplicit<int64_t>(offsets, 0);
     auto vreg_slice = input_layout->vregSlice(target_shape_);
     LayoutOffsets new_layout_offsets;
    if (input_layout->offsets()[0].has_value()) {