[Mosaic TPU] Add support for tpu.iota over untiled dimensions

PiperOrigin-RevId: 636567090
This commit is contained in:
Adam Paszke 2024-05-23 08:56:08 -07:00 committed by jax authors
parent 2fceaf05d4
commit 63a13f516d

View File

@ -2260,7 +2260,24 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
op.erase();
return success();
}
return op.emitOpError("Not implemented: Unsupported dimension");
// We take the iota over an untiled dimension.
CHECK_LT(*dimension, vty.getRank());
SmallVector<Value> tiles;
tiles.reserve(vty.getDimSize(*dimension));
for (int64_t i = 0; i < vty.getDimSize(*dimension); ++i) {
tiles.push_back(builder.create<arith::ConstantOp>(
native_vreg_ty,
DenseElementsAttr::get(native_vreg_ty,
IntegerAttr::get(vty.getElementType(), i))));
}
xla::Array<Value> out_tiles(tile_array_shape);
out_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
*v = tiles[idxs[*dimension]];
});
op.replaceAllUsesWith(
assemble(builder, vty, layout_out, out_tiles, ctx.target_shape));
op.erase();
return success();
}
LogicalResult tpu_gather_rule(RewriteContext &ctx, Operation &op,