Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
[Mosaic] Expand support of vector.broadcast
- Enable it for minor or second-minor implicit dims for the non-no-op case.
- Don't allow output offsets for broadcasted dimensions to be non-replicated. Make sure to assign them as replicated in infer-vector-layout for all cases.
- Don't fail when both tiled dimensions are logically broadcasted but only one of them requires actual broadcasting (before, it would hit the unimplemented sublane + lane broadcast case).

PiperOrigin-RevId: 637772134
This commit is contained in:
parent e8a1113072
commit 3fb9acf01a
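For the second bullet, here is a minimal standalone sketch (std types only; the function name and signature are hypothetical, not from this patch) of what "replicated offsets for broadcasted dimensions" means: a tiled dimension whose size changes across the broadcast gets a std::nullopt offset, marking every sublane/lane as holding valid data.

#include <array>
#include <cstdint>
#include <optional>

// Hypothetical model: offsets into the (sublane, lane) tile; std::nullopt
// marks a replicated dimension, mirroring LayoutOffsets in the Mosaic code.
using Offsets = std::array<std::optional<int64_t>, 2>;

Offsets replicateBroadcastedDims(Offsets offsets,
                                 std::array<int64_t, 2> src_tiled_shape,
                                 std::array<int64_t, 2> dst_tiled_shape) {
  for (int i = 0; i < 2; ++i) {
    if (src_tiled_shape[i] != dst_tiled_shape[i]) {
      // The dimension is broadcast, so no single offset is meaningful:
      // mark it replicated, as infer-vector-layout now does in all cases.
      offsets[i] = std::nullopt;
    }
  }
  return offsets;
}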
@@ -442,7 +442,7 @@ SmallVector<int64_t> VectorLayout::implicitShape(
     ArrayRef<int64_t> shape) const {
   SmallVector<int64_t> implicit_shape(shape);
   implicit_shape.reserve(shape.size() + num_implicit_dims());
-  insertImplicit(implicit_shape, 1);
+  insertImplicit<int64_t>(implicit_shape, 1);
   return implicit_shape;
 }

@@ -286,7 +286,8 @@ class VectorLayout {
     return {tiling_[0], tilesPerVreg(target_shape) * tiling_[1]};
   }

-  void insertImplicit(SmallVector<int64_t> &vec, int64_t value) const {
+  template <typename T>
+  void insertImplicit(SmallVector<T> &vec, T value) const {
     CHECK_GE(vec.size(), layout_rank());
     switch (implicit_dim_) {
       case ImplicitDim::kNone:
@@ -299,7 +300,8 @@ class VectorLayout {
     }
   }

-  void eraseImplicit(SmallVector<int64_t> &vec) const {
+  template <typename T>
+  void eraseImplicit(SmallVector<T> &vec) const {
     CHECK_GE(vec.size(), 2);
     switch (implicit_dim_) {
       case ImplicitDim::kNone:
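The two hunks above only templatize existing helpers so they can also operate on SmallVector<bool>, which the broadcast rule below needs. As a rough standalone model (std::vector instead of SmallVector; the insertion positions are inferred from how implicitShape() pads shapes to at least 2D, so treat them as an assumption):

#include <cstdint>
#include <vector>

enum class ImplicitDim { kNone, kMinor, kSecondMinor };

// Sketch: insert `value` where the layout's implicit dimension sits.
template <typename T>
void insertImplicit(std::vector<T> &vec, T value, ImplicitDim dim) {
  switch (dim) {
    case ImplicitDim::kNone:
      break;  // No implicit dim; nothing to insert.
    case ImplicitDim::kMinor:
      vec.insert(vec.end(), value);  // Implicit minor-most dimension.
      break;
    case ImplicitDim::kSecondMinor:
      vec.insert(vec.end() - 1, value);  // Implicit second-minor dimension.
      break;
  }
}
// eraseImplicit would be the inverse: drop the entry at the same position.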
@@ -2664,8 +2664,11 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
   ImplicitLocOpBuilder builder(op.getLoc(), &op);
   vector::BroadcastOp broadcast_op = cast<vector::BroadcastOp>(op);
   const VectorType dst_ty = broadcast_op.getResult().getType();
+  const ArrayRef<int64_t> dst_shape = dst_ty.getShape();
   const SmallVector<int64_t> dst_tiles_shape =
-      layout_out.tileArrayShape(dst_ty.getShape(), ctx.target_shape);
+      layout_out.tileArrayShape(dst_shape, ctx.target_shape);
+  const SmallVector<int64_t> dst_tiles_implicit_shape =
+      layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape);
   if (auto src = dyn_cast<TypedValue<VectorType>>(broadcast_op.getSource())) {
     VectorType src_ty = src.getType();
     TPU_ASSERT_OP(maybe_layout_in.has_value());
@@ -2674,8 +2677,6 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
       return op.emitOpError(
           "Not implemented: Changing implicit dims mid-broadcast");
     }
-    const VectorLayout::ImplicitDim implicit_dim = layout_in.implicit_dim();
-    const int layout_rank = layout_in.layout_rank();
     const LayoutOffsets offsets_in = layout_in.offsets();
     const LayoutOffsets offsets_out = layout_out.offsets();
     if (layout_in.tiling() != layout_out.tiling()) {
@@ -2684,78 +2685,76 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
     auto tiling = layout_in.tiling();

     const int64_t expand_rank = dst_ty.getRank() - src_ty.getRank();
-    SmallVector<int64_t> src_shape_padded(expand_rank, -1);
     const ArrayRef<int64_t> src_shape = src_ty.getShape();
-    src_shape_padded.append(src_shape.begin(), src_shape.end());
-    const SmallVector<bool> dim_eq = llvm::map_to_vector(
-        llvm::zip(src_shape_padded, dst_ty.getShape()), [](auto tup) {
-          auto [i, o] = tup;
-          return i == o;
-        });
-
-    bool no_op = false;
-    switch (implicit_dim) {
-      case VectorLayout::ImplicitDim::kNone: {
-        const ArrayRef<bool> tiled_dim_eq = ArrayRef<bool>(dim_eq).take_back(2);
-        for (auto [in_off, out_off, eq] :
-             llvm::zip(offsets_in, offsets_out, tiled_dim_eq)) {
-          if (eq && in_off != out_off) {
-            return op.emitOpError(
-                "Not implemented: Changing offsets mid-broadcast");
-          }
-        }
-        no_op = layout_in.hasNaturalTopology(ctx.target_shape) &&
-                layout_out.hasNaturalTopology(ctx.target_shape) &&
-                llvm::all_of(llvm::zip_equal(offsets_in, tiled_dim_eq),
-                             [](auto tup) {
-                               auto [o, eq] = tup;
-                               return eq || !o.has_value();
-                             });
-      } break;
-      case VectorLayout::ImplicitDim::kMinor:
-      case VectorLayout::ImplicitDim::kSecondMinor:
-        if (dim_eq.back()) {
-          if (offsets_in != offsets_out) {
-            return op.emitOpError(
-                "Not implemented: Changing offsets mid-broadcast");
-          }
-          no_op = true;
-        } else if (implicit_dim == VectorLayout::ImplicitDim::kSecondMinor &&
-                   !offsets_in[1].has_value()) {
-          no_op = true;
-        } else if (implicit_dim == VectorLayout::ImplicitDim::kMinor &&
-                   !offsets_in[0].has_value()) {
-          no_op = true;
-        }
-        break;
-    }
-    TPU_ASSERT_OP(layout_rank);
-    if (src_ty.getShape().take_back(layout_rank) ==
-        dst_ty.getShape().take_back(layout_rank)) {
-      if (offsets_in != offsets_out) {
-        op.emitOpError("Not implemented: Changing offsets mid-broadcast");
-      }
-      no_op = true;
-    }
+
+    SmallVector<int64_t> src_implicit_shape_padded;
+    // `is_logical_broadcast` stores whether each dimension of the implicit
+    // shape of the result is a broadcast. E.g. if the implicit shape goes from
+    // (2, 1, 3) to (4, 2, 5, 3) it's (true, false, true, false).
+    SmallVector<bool> is_logical_broadcast;
+    src_implicit_shape_padded.reserve(dst_shape.size() +
+                                      layout_in.num_implicit_dims());
+    is_logical_broadcast.reserve(dst_shape.size() +
+                                 layout_in.num_implicit_dims());
+    src_implicit_shape_padded.append(expand_rank, 1);
+    src_implicit_shape_padded.append(src_shape.begin(), src_shape.end());
+    for (auto [i, o] : llvm::zip(src_implicit_shape_padded, dst_shape)) {
+      TPU_ASSERT_OP(i == o || i == 1);  // Verifier should guarantee this.
+      is_logical_broadcast.push_back(i != o);
+    }
+    layout_in.insertImplicit<int64_t>(src_implicit_shape_padded, 1);
+    layout_in.insertImplicit<bool>(is_logical_broadcast, false);
+
+    // Verify that the offsets are valid.
+    for (auto [is_logical_broadcast_on_dim, in_off, out_off] :
+         llvm::zip_equal(ArrayRef(is_logical_broadcast).take_back(2),
+                         offsets_in, offsets_out)) {
+      if (is_logical_broadcast_on_dim) {
+        if (out_off.has_value()) {
+          // There's no reason to ever assign a non-replicated offset to a
+          // broadcasted dimension in the output.
+          return op.emitOpError(
+              // TODO(tlongeri): This should never be implemented but the fuzzed
+              // tests expect a NotImplementedError, which
+              // is raised with a "Not implemented" (see
+              // NotImplementedDetector in tpu_ext.cc). Fix.
+              "Not implemented: Broadcast output expected to have replicated "
+              "offsets.");
+        }
+      } else {  // !is_logical_broadcast_on_dim
+        if (in_off != out_off) {
+          return op.emitOpError(
+              "Not implemented: Changing offsets mid-broadcast");
+        }
+      }
+    }

+    // `needs_physical_broadcast` specifies whether we need to broadcast
+    // vregs in the sublane and lane dimensions. We only need to do this if the
+    // corresponding dimension of the implicit shape is logically broadcast and
+    // if the input vregs are not already replicated along this dimension.
+    const std::array<bool, 2> needs_physical_broadcast{
+        *(is_logical_broadcast.end() - 2) && offsets_in[0].has_value(),
+        *(is_logical_broadcast.end() - 1) && offsets_in[1].has_value()};
     FAILUREOR_ASSIGN_OR_RETURN(
         xla::Array<Value> src_tiles,
-        disassemble(builder, layout_in, src, ctx.target_shape));
-    xla::Array<Value> dst_tiles(dst_tiles_shape);
-    if (no_op) {
+        disassemble(builder, layout_in, src, ctx.target_shape,
+                    /*use_implicit_shape=*/true));
+    xla::Array<Value> dst_tiles(dst_tiles_implicit_shape);
+    if (needs_physical_broadcast == std::array{false, false}) {  // No-op
       SmallVector<int64_t> reshape_dims(expand_rank, 1);
       const absl::Span<const int64_t> src_tiles_dims = src_tiles.dimensions();
       reshape_dims.append(src_tiles_dims.begin(), src_tiles_dims.end());
       src_tiles.Reshape(reshape_dims);
       dst_tiles.Each([&](const absl::Span<const int64_t> dst_idx, Value *tile) {
-        const SmallVector<int64_t> src_idx =
-            llvm::map_to_vector(llvm::zip_equal(dst_idx, dim_eq), [](auto tup) {
-              auto [i, eq] = tup;
-              return eq ? i : 0;
+        const SmallVector<int64_t> src_idx = llvm::map_to_vector(
+            llvm::zip_equal(dst_idx, is_logical_broadcast), [](auto tup) {
+              auto [i, is_logical_broadcast_on_dim] = tup;
+              return is_logical_broadcast_on_dim ? 0 : i;
             });
         *tile = src_tiles(src_idx);
       });
-    } else if (implicit_dim == VectorLayout::ImplicitDim::kNone) {
+    } else {
       if (layout_in.bitwidth() != 32) {
         return op.emitOpError(
             "Not implemented: Only 32-bit broadcast supported");
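To see what the hunk above computes, here is a self-contained walkthrough (plain std types, implicit dims omitted for simplicity; the (2, 1, 3) -> (4, 2, 5, 3) example comes from the comment in the diff):

#include <array>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

int main() {
  // Source shape (2, 1, 3) broadcast to (4, 2, 5, 3).
  std::vector<int64_t> src{2, 1, 3}, dst{4, 2, 5, 3};
  std::vector<int64_t> src_padded(dst.size() - src.size(), 1);
  src_padded.insert(src_padded.end(), src.begin(), src.end());

  std::vector<bool> is_logical_broadcast;
  for (size_t i = 0; i < dst.size(); ++i) {
    is_logical_broadcast.push_back(src_padded[i] != dst[i]);
  }
  assert((is_logical_broadcast == std::vector<bool>{true, false, true, false}));

  // Input offsets for the two tiled dims; nullopt means already replicated.
  std::array<std::optional<int64_t>, 2> offsets_in{0, std::nullopt};
  // Physical work is needed only where the dim is logically broadcast AND
  // the input is not already replicated there. This is the case the commit
  // fixes: both tiled dims may be logically broadcast while only one needs
  // an actual sublane/lane broadcast.
  std::array<bool, 2> needs_physical_broadcast{
      *(is_logical_broadcast.end() - 2) && offsets_in[0].has_value(),
      *(is_logical_broadcast.end() - 1) && offsets_in[1].has_value()};
  assert(needs_physical_broadcast == (std::array{true, false}));
}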
@@ -2764,8 +2763,8 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
         return op.emitOpError("Not implemented: unsupported tiling");
       }
       int64_t num_tiles = layout_in.tilesPerVreg(ctx.target_shape);
-      TPU_ASSERT_OP(!*(dim_eq.end() - 1) || !*(dim_eq.end() - 2));
-      if (*(dim_eq.end() - 1)) {  // Sublane broadcast
+      if (needs_physical_broadcast ==
+          std::array{true, false}) {  // Sublane broadcast
         if (num_tiles != 1) {
           return op.emitOpError(
               "Not implemented: Only native tiling supported");
@@ -2777,12 +2776,12 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
             SmallVector<int32_t>(ctx.target_shape[0], offset));
         src_tiles.Each([&](const absl::Span<const int64_t> src_idx,
                            Value *const src_tile) {
-          SmallVector<int64_t> dst_starts(dst_tiles_shape.size());
-          SmallVector<int64_t> dst_limits(dst_tiles_shape.size());
+          SmallVector<int64_t> dst_starts(dst_tiles_implicit_shape.size());
+          SmallVector<int64_t> dst_limits(dst_tiles_implicit_shape.size());
           for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) {
-            if (i < expand_rank || !dim_eq[i]) {
+            if (i < expand_rank || is_logical_broadcast[i]) {
               dst_starts[i] = 0;
-              dst_limits[i] = dst_tiles_shape[i];
+              dst_limits[i] = dst_tiles_implicit_shape[i];
             } else {
               dst_starts[i] = src_idx[i - expand_rank];
               dst_limits[i] = dst_starts[i] + 1;
@@ -2793,7 +2792,8 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
                              src_tile->getType(), *src_tile, indices, 0),
               dst_starts, dst_limits);
         });
-      } else if (*(dim_eq.end() - 2)) {  // Lane broadcast
+      } else if (needs_physical_broadcast ==
+                 std::array{false, true}) {  // Lane broadcast
         TPU_ASSERT_EQ_OP(*(src_tiles.dimensions().end() - 1), 1);
         TPU_ASSERT_OP(offsets_in[1].has_value());
         const int64_t offset = *offsets_in[1];
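For intuition about the sublane-broadcast path above (a gather over sublanes whose indices all equal the source sublane offset), a standalone toy model with a hypothetical 8x4 vreg:

#include <array>

int main() {
  constexpr int kSublanes = 8, kLanes = 4;
  using Vreg = std::array<std::array<int, kLanes>, kSublanes>;

  Vreg src{};
  src[3] = {1, 2, 3, 4};  // valid data lives in sublane `offset` = 3

  // Mirrors SmallVector<int32_t>(ctx.target_shape[0], offset): every result
  // sublane gathers from the same source sublane.
  std::array<int, kSublanes> indices;
  indices.fill(3);

  Vreg dst{};
  for (int s = 0; s < kSublanes; ++s) {
    dst[s] = src[indices[s]];  // afterwards, all sublanes hold {1, 2, 3, 4}
  }
}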
@@ -2816,12 +2816,12 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
         }
         src_tiles.Each([&](const absl::Span<const int64_t> src_idx,
                            Value *const src_tile) {
-          SmallVector<int64_t> dst_starts(dst_tiles_shape.size());
-          SmallVector<int64_t> dst_limits(dst_tiles_shape.size());
+          SmallVector<int64_t> dst_starts(dst_tiles_implicit_shape.size());
+          SmallVector<int64_t> dst_limits(dst_tiles_implicit_shape.size());
           for (int64_t i = 0; i < dst_tiles.num_dimensions(); ++i) {
-            if (i < expand_rank || !dim_eq[i]) {
+            if (i < expand_rank || is_logical_broadcast[i]) {
               dst_starts[i] = 0;
-              dst_limits[i] = dst_tiles_shape[i];
+              dst_limits[i] = dst_tiles_implicit_shape[i];
             } else {
               dst_starts[i] = src_idx[i - expand_rank];
               dst_limits[i] = dst_starts[i] + 1;
@@ -2838,14 +2838,15 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
           updateSlice<Value>(dst_tiles, res_vreg, dst_starts, dst_limits);
         });
       } else {
-        return op.emitOpError("Not implemented");
+        TPU_ASSERT_OP((needs_physical_broadcast == std::array{true, true}));
+        return op.emitOpError(
+            "Not implemented: Broadcast in both sublanes and lanes");
       }
-    } else {
-      return op.emitOpError("Not implemented");
     }
-    broadcast_op.replaceAllUsesWith(
-        assemble(builder, dst_ty, layout_out, dst_tiles, ctx.target_shape)
-            .getOperation());
+    broadcast_op.replaceAllUsesWith(assemble(builder, dst_ty, layout_out,
+                                             dst_tiles, ctx.target_shape,
+                                             /*use_implicit_shape=*/true)
+                                        .getOperation());
     broadcast_op.erase();
     return success();
   } else if (layout_out.bitwidth() == 32 &&
@@ -2981,13 +2982,13 @@ FailureOr<xla::Array<Value>> vector_extract_slice_impl(
   full_sizes.append(sizes.begin(), sizes.end());
   full_sizes.append(src_vector_shape.begin() + num_indices,
                     src_vector_shape.end());
-  layout_in.insertImplicit(full_sizes, 1);
+  layout_in.insertImplicit<int64_t>(full_sizes, 1);

   SmallVector<int64_t> full_offsets;
   full_offsets.reserve(src_vector_rank + layout_in.num_implicit_dims());
   full_offsets.append(offsets.begin(), offsets.end());
   full_offsets.append(src_vector_rank - num_indices, 0);
-  layout_in.insertImplicit(full_offsets, 0);
+  layout_in.insertImplicit<int64_t>(full_offsets, 0);

   // We currently only support no-op cases - that is, those where we effectively
   // just extract a slice of vregs without doing any operations (e.g. shifts) on
@@ -4036,9 +4037,16 @@ const llvm::StringMap<rule_type> &rules() {
 RollVectorsOp assemble(OpBuilder &builder, VectorType vty,
                        const VectorLayout &layout,
                        const xla::Array<Value> &vals,
-                       const std::array<int64_t, 2> target_shape) {
-  CHECK(vals.dimensions() ==
-        layout.tileArrayShape(vty.getShape(), target_shape));
+                       const std::array<int64_t, 2> target_shape,
+                       const bool use_implicit_shape) {
+  // TODO(tlongeri): Maybe just add a parameter to tileArrayShape instead of
+  // having `tileArrayShape` and `tileArrayImplicitShape`.
+  SmallVector<int64_t> vreg_array_shape =
+      layout.tileArrayImplicitShape(vty.getShape(), target_shape);
+  if (!use_implicit_shape) {
+    layout.eraseImplicit(vreg_array_shape);
+  }
+  CHECK(vals.dimensions() == vreg_array_shape);
   CHECK_GT(vals.num_elements(), 0);
   Location loc = vals.begin()->getLoc();
   auto op =
@@ -4059,8 +4067,8 @@ RollVectorsOp assemble(OpBuilder &builder, VectorType vty,
 // An ndarray of MLIR values representing the tiling of val given by layout.
 FailureOr<xla::Array<Value>> disassemble(
     OpBuilder &builder, const VectorLayout &layout,
-    const TypedValue<VectorType> val,
-    const std::array<int64_t, 2> target_shape) {
+    const TypedValue<VectorType> val, const std::array<int64_t, 2> target_shape,
+    const bool use_implicit_shape) {  // TODO(tlongeri): Remove default
   const auto vty = val.getType();
   const auto op_result = dyn_cast<OpResult>(val);
   if (op_result == nullptr) {
@@ -4074,8 +4082,13 @@ FailureOr<xla::Array<Value>> disassemble(
   TPU_ASSERT_LOC(val.getLoc(), def_layout.has_value());
   TPU_ASSERT_LOC(val.getLoc(),
                  def_layout->generalizes(layout, vty.getShape(), target_shape));
+  // TODO(tlongeri): Maybe just add a parameter to tileArrayShape instead of
+  // having `tileArrayShape` and `tileArrayImplicitShape`.
   SmallVector<int64_t> layout_shape =
-      layout.tileArrayShape(vty.getShape(), target_shape);
+      layout.tileArrayImplicitShape(vty.getShape(), target_shape);
+  if (!use_implicit_shape) {
+    layout.eraseImplicit(layout_shape);
+  }
   if (auto roll_vectors_op = dyn_cast<RollVectorsOp>(op)) {
     return XlaArrayFromShapeAndValues<Value>(layout_shape,
                                              roll_vectors_op->getOperands());
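The assemble/disassemble changes above thread a use_implicit_shape flag through so the broadcast rule can index vregs by the implicit (at least 2D) shape. A tiny standalone model of the shape round trip (assuming a second-minor implicit dim, i.e. a unit axis in the second-to-last position; shapes here are illustrative only):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // What tileArrayImplicitShape might yield for a 1D value: vregs indexed
  // by an implicit 2D shape with a unit second-minor axis (assumption).
  std::vector<int64_t> implicit_shape{1, 4};
  // eraseImplicit drops the implicit axis to recover the explicit shape...
  std::vector<int64_t> explicit_shape = implicit_shape;
  explicit_shape.erase(explicit_shape.end() - 2);
  assert((explicit_shape == std::vector<int64_t>{4}));
  // ...and callers pick a view via use_implicit_shape, with `false` as the
  // default so existing call sites keep the explicit-shape behavior.
}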
@@ -25,14 +25,17 @@ struct RewriteContext {
   MLIRContext *getMLIRContext() { return func.getContext(); }
 };

+// TODO(tlongeri): Remove default values for use_implicit_shape.
 RollVectorsOp assemble(OpBuilder &builder, VectorType vty,
                        const VectorLayout &layout,
                        const xla::Array<Value> &vals,
-                       std::array<int64_t, 2> target_shape);
+                       std::array<int64_t, 2> target_shape,
+                       bool use_implicit_shape = false);
 FailureOr<xla::Array<Value>> disassemble(OpBuilder &builder,
                                          const VectorLayout &layout,
                                          TypedValue<VectorType> val,
-                                         std::array<int64_t, 2> target_shape);
+                                         std::array<int64_t, 2> target_shape,
+                                         bool use_implicit_shape = false);

 // Rewrites the operation according to its layout annotations.
 //
@@ -893,48 +893,42 @@ class VectorLayoutInferer {
       TPU_CHECK_OP(res_ty.getRank() >= 2, "result rank below 2D unsupported");
       auto some_layout = getLayout(op.getSource());
       TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
-      if (src_ty.getDimSize(src_ty.getRank() - 2) !=
-          res_ty.getDimSize(res_ty.getRank() - 2)) {
-        if (some_layout->bitwidth() != kNativeBitwidth) {
+      auto &layout = *some_layout;
+      // Since we can only do sublane broadcasts in the (8, 128) tiling, we
+      // should always use that when sublane broadcasting is required.
+      if (*(src_ty.getShape().end() - 2) != *(res_ty.getShape().end() - 2)) {
+        if (layout.bitwidth() != kNativeBitwidth) {
           NYI("Only 32-bit broadcasts supported");
         }
-        LayoutOffsets offsets = some_layout->offsets();
+        LayoutOffsets offsets = layout.offsets();
         // At the moment relayout can only produce replicated sublanes when
         // converting to (8, 128) if the input was in (1, 128) tiling
-        if (some_layout->tiling()[0] == 1) {
+        if (layout.tiling()[0] == 1) {
           offsets[0] = std::nullopt;
         }
-        *some_layout =
-            VectorLayout(some_layout->bitwidth(), offsets, default_tiling_,
-                         some_layout->implicit_dim());
+        layout = VectorLayout(layout.bitwidth(), offsets, default_tiling_,
+                              layout.implicit_dim());
       }
-      auto &layout = *some_layout;
-      if (layout.implicit_dim() != ImplicitDim::kNone) {
-        VectorLayout layout_2d(layout.bitwidth(), layout.offsets(),
-                               layout.tiling(), ImplicitDim::kNone);
-        if (layout_2d.equivalentTo(layout, src_ty.getShape(), target_shape_)) {
-          // TODO(b/342237796): Stop preferring 2D layouts (if given the choice)
-          // and defer the work, if any, to relayout.
-          layout = layout_2d;
-        } else {
-          op.emitOpError() << "Only 2D layouts supported";
-          return failure();
-        }
-      }
       auto src_tiled_shape = src_ty.getShape().take_back(2);
       auto dst_tiled_shape = res_ty.getShape().take_back(2);
       LayoutOffsets offsets = layout.offsets();
-      if (layout.bitwidth() == kNativeBitwidth &&
-          layout.tiling() == default_tiling_) {
-        for (int i = 0; i < 2; ++i) {
-          if (src_tiled_shape[i] != dst_tiled_shape[i]) {
-            offsets[i] = std::nullopt;
-          }
+      for (int i = 0; i < 2; ++i) {
+        if (src_tiled_shape[i] != dst_tiled_shape[i]) {
+          offsets[i] = std::nullopt;
         }
       }
-      setLayout(op, some_layout,
+      setLayout(op, layout,
                 VectorLayout(layout.bitwidth(), offsets, layout.tiling(),
-                             ImplicitDim::kNone));
+                             layout.implicit_dim()));
       return success();
     }
     op.emitOpError("unsupported broadcast source type");
@@ -1122,7 +1116,7 @@ class VectorLayoutInferer {
     auto offsets = llvm::map_to_vector(offsets_attr, [](auto attr) {
       return cast<IntegerAttr>(attr).getInt();
     });
-    input_layout->insertImplicit(offsets, 0);
+    input_layout->insertImplicit<int64_t>(offsets, 0);
     auto vreg_slice = input_layout->vregSlice(target_shape_);
     LayoutOffsets new_layout_offsets;
     if (input_layout->offsets()[0].has_value()) {