mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic] apply_vector_layout C++ rewrite (9): tpu.repeat
PiperOrigin-RevId: 569078893
This commit is contained in:
parent
b1b81ecc60
commit
fb90d3ee31
@ -82,6 +82,25 @@ FailureOr<xla::Array<Value>> disassemble(RewriteContext &ctx,
|
||||
const VectorLayout &layout, Value val);
|
||||
namespace {
|
||||
|
||||
// Models Numpy's np.repeat, repeating each element `repeats` times along the
|
||||
// specified axis. For example, if `src` is [1, 2], `axis` is 0 and `repeats` is
|
||||
// 3, this will return [1, 1, 1, 2, 2, 2].
|
||||
xla::Array<Value> repeat(const xla::Array<Value> &src, const int repeats,
|
||||
const int64_t axis) {
|
||||
SmallVector<int64_t> dims(toArrayRef(src.dimensions()));
|
||||
dims[axis] *= repeats;
|
||||
xla::Array<Value> res(dims);
|
||||
src.Each([&](absl::Span<const int64_t> idx, const Value v) {
|
||||
SmallVector<int64_t> res_idx(toArrayRef(idx));
|
||||
res_idx[axis] *= repeats;
|
||||
for (int i = 0; i < repeats; ++i) {
|
||||
res(res_idx) = v;
|
||||
++res_idx[axis];
|
||||
}
|
||||
});
|
||||
return res;
|
||||
}
|
||||
|
||||
FailureOr<TypedAttr> getZeroIntOrFloatAttr(Type ty) {
|
||||
if (isa<FloatType>(ty)) {
|
||||
return TypedAttr(FloatAttr::get(ty, 0));
|
||||
@ -849,6 +868,49 @@ LogicalResult tpu_gather_rule(RewriteContext &ctx, Operation &op,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult tpu_repeat_rule(RewriteContext &ctx, Operation &op,
|
||||
const ArrayRef<Layout> layouts_in,
|
||||
const ArrayRef<Layout> layouts_out) {
|
||||
CHECK_EQ(layouts_in.size(), 1);
|
||||
CHECK_EQ(layouts_out.size(), 1);
|
||||
if (!layouts_in.front().has_value()) {
|
||||
return op.emitOpError("Expected non-null input layout");
|
||||
}
|
||||
if (!layouts_out.front().has_value()) {
|
||||
return op.emitOpError("Expected non-null output layout");
|
||||
}
|
||||
const VectorLayout &layout_in = *layouts_in.front();
|
||||
const VectorLayout &layout_out = *layouts_out.front();
|
||||
if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
|
||||
return op.emitOpError("Not implemented: Only 2D layouts supported");
|
||||
}
|
||||
if (layout_in != layout_out) {
|
||||
return op.emitOpError("Not implemented: Changing layout mid-repeat");
|
||||
}
|
||||
if (!layout_in.hasNaturalTopology(ctx.target_shape) ||
|
||||
layout_in.offsets() != LayoutOffsets{0, 0}) {
|
||||
return op.emitOpError("Not implemented: Non-trivial layouts unsupported");
|
||||
}
|
||||
tpu::RepeatOp repeat_op = cast<tpu::RepeatOp>(op);
|
||||
VectorType src_ty = repeat_op.getSource().getType();
|
||||
const uint32_t dim = repeat_op.getDimension();
|
||||
if (dim != src_ty.getRank() - 1) {
|
||||
return op.emitOpError(
|
||||
"Not implemented: Only repeats along the last dim supported");
|
||||
}
|
||||
if (src_ty.getShape().back() % ctx.target_shape.back() != 0) {
|
||||
return op.emitOpError("Not implemented: Only free repeats are suppported");
|
||||
}
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
const xla::Array<Value> &in_vregs,
|
||||
disassemble(ctx, layout_in, repeat_op.getSource()));
|
||||
xla::Array<Value> out_vregs = repeat(in_vregs, repeat_op.getTimes(), dim);
|
||||
repeat_op->replaceAllUsesWith(
|
||||
assemble(ctx, repeat_op.getResult().getType(), layout_out, out_vregs));
|
||||
repeat_op->erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
|
||||
const ArrayRef<Layout> layouts_in,
|
||||
const ArrayRef<Layout> layouts_out) {
|
||||
@ -1263,6 +1325,7 @@ const llvm::StringMap<rule_type> &rules() {
|
||||
{tpu::IotaOp::getOperationName(), tpu_iota_rule},
|
||||
{tpu::GatherOp::getOperationName(), tpu_gather_rule},
|
||||
{tpu::LoadOp::getOperationName(), tpu_load_rule},
|
||||
{tpu::RepeatOp::getOperationName(), tpu_repeat_rule},
|
||||
{tpu::StoreOp::getOperationName(), tpu_store_rule},
|
||||
{tpu::TraceOp::getOperationName(), tpu_trace_rule},
|
||||
{vector::LoadOp::getOperationName(), vector_load_rule},
|
||||
|
Loading…
x
Reference in New Issue
Block a user