mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic] apply_vector_layout C++ rewrite (8): tpu.gather, tpu.iota, tpu.trace
PiperOrigin-RevId: 569069717
This commit is contained in:
parent
bb4382f0bc
commit
b1b81ecc60
@ -75,6 +75,7 @@ struct RewriteContext {
|
||||
const std::array<int64_t, 2> target_shape;
|
||||
};
|
||||
|
||||
// Forward declaration; the definition is outside this excerpt.
// tpu_trace_rule calls this on the single block of a traced region so that
// the body of the region gets rewritten while the enclosing op is kept as is.
LogicalResult applyLayoutBlock(RewriteContext &ctx, Block &block);
|
||||
// Forward declaration; the definition is outside this excerpt.
// The iota and gather rules pass it the result vector type, the output layout
// and an array of per-tile values, and use the returned RollVectorsOp as the
// replacement for the op being rewritten.
RollVectorsOp assemble(RewriteContext &ctx, VectorType vty,
|
||||
const VectorLayout &layout, xla::Array<Value> vals);
|
||||
FailureOr<xla::Array<Value>> disassemble(RewriteContext &ctx,
|
||||
@ -652,6 +653,202 @@ LogicalResult tpu_store_rule(RewriteContext &ctx, Operation &op,
|
||||
return success();
|
||||
}
|
||||
|
||||
// Layout rule registered for tpu::TraceOp.
//
// The op itself is left untouched; only the operations inside its single
// region are rewritten, by delegating to applyLayoutBlock.
//
// Args:
//   ctx: Rewrite context, forwarded to applyLayoutBlock.
//   op: The operation to process. It must have no operands and no results.
//   layouts_in: Operand layouts; must be empty.
//   layouts_out: Result layouts; must be empty.
//
// Returns:
//   A failure (with a diagnostic attached to `op`) when the op has operands
//   or results; otherwise whatever applyLayoutBlock returns for the body.
LogicalResult tpu_trace_rule(RewriteContext &ctx, Operation &op,
                             const ArrayRef<Layout> layouts_in,
                             const ArrayRef<Layout> layouts_out) {
  if (op.getNumOperands() != 0 || op.getNumResults() != 0) {
    return op.emitOpError(
        "Not implemented: tpu.traced_block with inputs or outputs");
  }
  CHECK_EQ(layouts_in.size(), 0);
  CHECK_EQ(layouts_out.size(), 0);
  // We don't modify the op, but we do rewrite the branch bodies.
  CHECK_EQ(op.getNumRegions(), 1);
  Region &region = op.getRegion(0);
  CHECK(region.hasOneBlock());
  Block &block = region.front();
  return applyLayoutBlock(ctx, block);
}
|
||||
|
||||
// Layout rule registered for tpu::IotaOp.
//
// Replaces a 32-bit integer iota over one of the two minor-most dimensions of
// the result with per-vreg values: a single native-vreg tpu.iota plus, for
// each tile index along the iota dimension, a splat constant equal to
// `tile index * vreg extent along that dimension`. The resulting values are
// repeated across every other tile coordinate and handed to assemble().
//
// Args:
//   ctx: Rewrite context (provides the builder and the target shape).
//   op: The tpu.iota operation; it is erased on success.
//   layouts_in: Operand layouts; must be empty.
//   layouts_out: Result layouts; exactly one, which must be non-null.
//
// Returns:
//   success() after replacing and erasing `op`, or a failure with a
//   diagnostic when the element type is not a 32-bit integer, the layout has
//   an implicit dimension, the dimension attribute is absent or is not one of
//   the two minor-most dimensions, or the layout offset along the iota
//   dimension is not 0.
LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
                            const ArrayRef<Layout> layouts_in,
                            const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 0);
  CHECK_EQ(layouts_out.size(), 1);
  const Layout &maybe_layout = layouts_out.front();
  if (!maybe_layout.has_value()) {
    return op.emitOpError("Expected non-null output layout");
  }
  const VectorLayout &out_layout = *maybe_layout;
  tpu::IotaOp iota_op = cast<tpu::IotaOp>(op);
  VectorType result_ty = iota_op.getResult().getType();
  if (const auto elem_int_ty =
          dyn_cast<IntegerType>(result_ty.getElementType());
      elem_int_ty == nullptr || elem_int_ty.getWidth() != 32) {
    return iota_op.emitOpError("Not implemented: Only 32-bit Iota supported");
  }
  FAILUREOR_ASSIGN_OR_RETURN(
      const auto native_vreg_ty,
      getNativeVregType(result_ty.getElementType(), ctx.target_shape));
  if (out_layout.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
    return op.emitOpError("Not implemented: Only 2D layouts supported");
  }
  const SmallVector<int64_t> tile_array_shape =
      out_layout.tileArrayShape(result_ty.getShape(), ctx.target_shape);
  const std::optional<int32_t> dimension = iota_op.getDimension();
  if (!dimension.has_value()) {
    return op.emitOpError("Not implemented: null dimension");
  }

  // Distance of the iota dimension from the end of the shape: 1 means the
  // minor-most dimension, 2 the second minor. Nothing else is handled.
  const int64_t from_minor = result_ty.getRank() - *dimension;
  if (from_minor != 1 && from_minor != 2) {
    return op.emitOpError("Not implemented: Unsupported dimension");
  }
  // Matching dimension inside a single vreg (1 for minor, 0 for 2nd minor);
  // it also selects which layout offset has to be checked.
  const int32_t vreg_dim = static_cast<int32_t>(2 - from_minor);
  if (out_layout.offsets()[vreg_dim] != 0) {
    return op.emitOpError("Not implemented: Unsupported offset");
  }

  const int64_t tile_count = *(tile_array_shape.end() - from_minor);
  const int64_t vreg_extent = *(native_vreg_ty.getShape().end() - from_minor);
  auto base_iota = ctx.builder.create<tpu::IotaOp>(
      op.getLoc(), native_vreg_ty,
      /*dimension =*/ctx.builder.getI32IntegerAttr(vreg_dim));

  // One value per tile index along the iota dimension: base iota + splat.
  SmallVector<Value> shifted_iotas(tile_count);
  for (int64_t tile = 0; tile < tile_count; ++tile) {
    auto shift = ctx.builder.create<arith::ConstantOp>(
        op.getLoc(), native_vreg_ty,
        DenseElementsAttr::get(
            native_vreg_ty,
            IntegerAttr::get(result_ty.getElementType(),
                             tile * vreg_extent)));
    shifted_iotas[tile] =
        ctx.builder.create<arith::AddIOp>(op.getLoc(), base_iota, shift);
  }

  // Every tile picks its value purely from its coordinate along the iota
  // dimension, which repeats the values across all other coordinates.
  xla::Array<Value> vreg_array(tile_array_shape);
  vreg_array.Each([&](absl::Span<const int64_t> idxs, Value *v) {
    *v = shifted_iotas[*(idxs.end() - from_minor)];
  });
  op.replaceAllUsesWith(assemble(ctx, result_ty, out_layout, vreg_array));
  op.erase();
  return success();
}
|
||||
|
||||
// Layout rule registered for tpu::GatherOp.
//
// Rewrites a gather with static indices along one of the two minor-most
// dimensions into one gather per input tile. The index list is reduced to a
// single per-vreg segment of `width` entries, where `width` is the target
// shape extent along the gathered dimension:
//   * if the number of indices is a multiple of `width`, every segment must
//     stay within its own window and repeat the first segment's pattern
//     (shifted by the segment offset);
//   * if there are fewer than `width` indices, they are padded with zeros;
//   * anything else is reported as not implemented.
// For the minor-most dimension the segment is broadcast to the full target
// shape, materialised through appendConstant + tpu.load and fed to
// tpu.dynamic_gather; for the second-minor dimension a tpu.gather with the
// segment as its static indices is emitted.
//
// Args:
//   ctx: Rewrite context (provides the builder and the target shape).
//   op: The tpu.gather operation; it is erased on success.
//   layouts_in: Operand layouts; exactly one, which must be non-null.
//   layouts_out: Result layouts; exactly one, which must be non-null.
//
// Returns:
//   success() after replacing and erasing `op`, or a failure with a
//   diagnostic for unsupported layouts, dimensions or index patterns.
LogicalResult tpu_gather_rule(RewriteContext &ctx, Operation &op,
                              const ArrayRef<Layout> layouts_in,
                              const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK_EQ(layouts_out.size(), 1);
  if (!layouts_in.front().has_value()) {
    return op.emitOpError("Expected non-null input layout");
  }
  if (!layouts_out.front().has_value()) {
    return op.emitOpError("Expected non-null output layout");
  }
  const VectorLayout &layout_in = *layouts_in.front();
  const VectorLayout &layout_out = *layouts_out.front();
  // Both layouts must have no implicit dimension, must agree on offsets, and
  // every offset that is set must be 0.
  if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone ||
      layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone ||
      layout_in.offsets() != layout_out.offsets() ||
      llvm::any_of(layout_in.offsets(), [&](const LayoutOffset o) {
        return o.has_value() && o != 0;
      })) {
    return op.emitOpError("Not implemented: Only 2D layouts supported");
  }
  auto gather_op = cast<tpu::GatherOp>(op);
  const VectorType vty = gather_op.getResult().getType();
  const uint32_t dimension = gather_op.getDimension();
  const int64_t rank = vty.getRank();
  // Only the two minor-most dimensions are handled. The upper bound keeps the
  // target_shape lookup below inside its two-element array.
  if (dimension + 2 < rank || dimension >= rank) {
    return op.emitOpError("Not implemented: Unsupported dimension");
  }
  FAILUREOR_ASSIGN_OR_RETURN(
      const xla::Array<Value> in_tiles,
      disassemble(ctx, layout_in, gather_op.getSource()));
  // Target shape extent along the gathered dimension.
  const int64_t width = ctx.target_shape[2 - (rank - dimension)];
  const ArrayRef<int32_t> indices(gather_op.getIndices());
  const int64_t num_indices = static_cast<int64_t>(indices.size());
  // An empty list would otherwise take the "multiple of width" path and read
  // indices[0..width) out of bounds.
  if (num_indices == 0) {
    return op.emitOpError("Expected non-empty indices");
  }
  const int64_t num_sections = num_indices / width;
  const int64_t rem = num_indices % width;
  SmallVector<int32_t> segment_indices;
  if (rem == 0) {
    // The first segment must only reference positions inside itself.
    for (int64_t i = 0; i < width; ++i) {
      if (!(0 <= indices[i] && indices[i] < width)) {
        return op.emitOpError("Not implemented: Cross-segment gather");
      }
    }
    // Every later segment must repeat the first one, shifted by its offset.
    for (int64_t i = width; i < num_indices; ++i) {
      const int64_t offset = i - i % width;
      if (indices[i] != indices[i % width] + offset) {
        return op.emitOpError(
            "Not implemented: Indices varying between segments");
      }
    }
    segment_indices.assign(indices.begin(), indices.begin() + width);
  } else if (num_sections == 0) {  // Only one vreg.
    segment_indices.assign(indices.begin(), indices.end());
    segment_indices.append(width - num_indices, 0);
  } else {
    return op.emitOpError("Not implemented: Not a multiple of target length");
  }
  xla::Array<Value> out_tiles(in_tiles.dimensions());
  if (dimension == rank - 1) {
    // TODO(b/265133497): Remove the broadcast once 2nd minor works.
    const auto dyn_ix_ty =
        VectorType::get(ctx.target_shape, ctx.builder.getI32Type());
    // Broadcast indices to target_shape
    SmallVector<int32_t> dyn_ix_val;
    for (int64_t i = 0; i < ctx.target_shape[0]; ++i) {  // Broadcast
      dyn_ix_val.append(segment_indices);
    }
    FAILUREOR_ASSIGN_OR_RETURN(
        const BlockArgument dyn_ix_ref,
        appendConstant(ctx, DenseIntElementsAttr::get(dyn_ix_ty, dyn_ix_val)));
    // NOTE(review): the mask has target_shape[1] entries while the index
    // vector above is built with target_shape[0] rows; confirm which extent
    // the sublane mask is expected to have.
    auto all_sublanes = ctx.builder.getAttr<DenseBoolArrayAttr>(
        SmallVector<bool>(ctx.target_shape[1], true));
    auto dyn_ix = ctx.builder.create<tpu::LoadOp>(
        op.getLoc(), dyn_ix_ty, dyn_ix_ref,
        SmallVector<Value>(2, IdxConst(0, ctx.builder, op.getLoc())),
        /*sublane_mask=*/all_sublanes, /*sublane_stride=*/nullptr);
    out_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
      const Value in_tile = in_tiles(idxs);
      *v = ctx.builder.create<tpu::DynamicGatherOp>(
          op.getLoc(), in_tile.getType(), in_tile, dyn_ix, 1);
    });
  } else {
    CHECK_EQ(dimension, rank - 2);
    const auto segment_indices_attr =
        ctx.builder.getAttr<DenseI32ArrayAttr>(segment_indices);
    out_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
      const Value in_tile = in_tiles(idxs);
      *v = ctx.builder.create<tpu::GatherOp>(op.getLoc(), in_tile.getType(),
                                             in_tile, segment_indices_attr, 0);
    });
  }
  gather_op.replaceAllUsesWith(
      assemble(ctx, vty, layout_out, out_tiles).getOperation());
  gather_op.erase();
  return success();
}
|
||||
|
||||
LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
|
||||
const ArrayRef<Layout> layouts_in,
|
||||
const ArrayRef<Layout> layouts_out) {
|
||||
@ -1063,8 +1260,11 @@ const llvm::StringMap<rule_type> &rules() {
|
||||
rules_elementwise_op_entry<math::PowFOp, 1>(),
|
||||
rules_elementwise_op_entry<math::RsqrtOp, 1>(),
|
||||
rules_elementwise_op_entry<math::TanhOp, 1>(),
|
||||
{tpu::IotaOp::getOperationName(), tpu_iota_rule},
|
||||
{tpu::GatherOp::getOperationName(), tpu_gather_rule},
|
||||
{tpu::LoadOp::getOperationName(), tpu_load_rule},
|
||||
{tpu::StoreOp::getOperationName(), tpu_store_rule},
|
||||
{tpu::TraceOp::getOperationName(), tpu_trace_rule},
|
||||
{vector::LoadOp::getOperationName(), vector_load_rule},
|
||||
{vector::StoreOp::getOperationName(), vector_store_rule}};
|
||||
return *rules;
|
||||
|
Loading…
x
Reference in New Issue
Block a user