[Mosaic] apply_vector_layout C++ rewrite (9): tpu.repeat

PiperOrigin-RevId: 569078893
This commit is contained in:
Tomás Longeri 2023-09-27 23:35:29 -07:00 committed by jax authors
parent b1b81ecc60
commit fb90d3ee31

View File

@ -82,6 +82,25 @@ FailureOr<xla::Array<Value>> disassemble(RewriteContext &ctx,
const VectorLayout &layout, Value val);
namespace {
// Models Numpy's np.repeat, repeating each element `repeats` times along the
// specified axis. For example, if `src` is [1, 2], `axis` is 0 and `repeats` is
// 3, this will return [1, 1, 1, 2, 2, 2].
xla::Array<Value> repeat(const xla::Array<Value> &src, const int repeats,
const int64_t axis) {
SmallVector<int64_t> dims(toArrayRef(src.dimensions()));
dims[axis] *= repeats;
xla::Array<Value> res(dims);
src.Each([&](absl::Span<const int64_t> idx, const Value v) {
SmallVector<int64_t> res_idx(toArrayRef(idx));
res_idx[axis] *= repeats;
for (int i = 0; i < repeats; ++i) {
res(res_idx) = v;
++res_idx[axis];
}
});
return res;
}
FailureOr<TypedAttr> getZeroIntOrFloatAttr(Type ty) {
if (isa<FloatType>(ty)) {
return TypedAttr(FloatAttr::get(ty, 0));
@ -849,6 +868,49 @@ LogicalResult tpu_gather_rule(RewriteContext &ctx, Operation &op,
return success();
}
LogicalResult tpu_repeat_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
CHECK_EQ(layouts_in.size(), 1);
CHECK_EQ(layouts_out.size(), 1);
if (!layouts_in.front().has_value()) {
return op.emitOpError("Expected non-null input layout");
}
if (!layouts_out.front().has_value()) {
return op.emitOpError("Expected non-null output layout");
}
const VectorLayout &layout_in = *layouts_in.front();
const VectorLayout &layout_out = *layouts_out.front();
if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
return op.emitOpError("Not implemented: Only 2D layouts supported");
}
if (layout_in != layout_out) {
return op.emitOpError("Not implemented: Changing layout mid-repeat");
}
if (!layout_in.hasNaturalTopology(ctx.target_shape) ||
layout_in.offsets() != LayoutOffsets{0, 0}) {
return op.emitOpError("Not implemented: Non-trivial layouts unsupported");
}
tpu::RepeatOp repeat_op = cast<tpu::RepeatOp>(op);
VectorType src_ty = repeat_op.getSource().getType();
const uint32_t dim = repeat_op.getDimension();
if (dim != src_ty.getRank() - 1) {
return op.emitOpError(
"Not implemented: Only repeats along the last dim supported");
}
if (src_ty.getShape().back() % ctx.target_shape.back() != 0) {
return op.emitOpError("Not implemented: Only free repeats are suppported");
}
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> &in_vregs,
disassemble(ctx, layout_in, repeat_op.getSource()));
xla::Array<Value> out_vregs = repeat(in_vregs, repeat_op.getTimes(), dim);
repeat_op->replaceAllUsesWith(
assemble(ctx, repeat_op.getResult().getType(), layout_out, out_vregs));
repeat_op->erase();
return success();
}
LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
@ -1263,6 +1325,7 @@ const llvm::StringMap<rule_type> &rules() {
{tpu::IotaOp::getOperationName(), tpu_iota_rule},
{tpu::GatherOp::getOperationName(), tpu_gather_rule},
{tpu::LoadOp::getOperationName(), tpu_load_rule},
{tpu::RepeatOp::getOperationName(), tpu_repeat_rule},
{tpu::StoreOp::getOperationName(), tpu_store_rule},
{tpu::TraceOp::getOperationName(), tpu_trace_rule},
{vector::LoadOp::getOperationName(), vector_load_rule},