mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic TPU] Add support for tpu.iota over untiled dimensions
PiperOrigin-RevId: 636567090
This commit is contained in:
parent
2fceaf05d4
commit
63a13f516d
@ -2260,7 +2260,24 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
|
||||
op.erase();
|
||||
return success();
|
||||
}
|
||||
return op.emitOpError("Not implemented: Unsupported dimension");
|
||||
// We take the iota over an untiled dimension.
|
||||
CHECK_LT(*dimension, vty.getRank());
|
||||
SmallVector<Value> tiles;
|
||||
tiles.reserve(vty.getDimSize(*dimension));
|
||||
for (int64_t i = 0; i < vty.getDimSize(*dimension); ++i) {
|
||||
tiles.push_back(builder.create<arith::ConstantOp>(
|
||||
native_vreg_ty,
|
||||
DenseElementsAttr::get(native_vreg_ty,
|
||||
IntegerAttr::get(vty.getElementType(), i))));
|
||||
}
|
||||
xla::Array<Value> out_tiles(tile_array_shape);
|
||||
out_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
|
||||
*v = tiles[idxs[*dimension]];
|
||||
});
|
||||
op.replaceAllUsesWith(
|
||||
assemble(builder, vty, layout_out, out_tiles, ctx.target_shape));
|
||||
op.erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult tpu_gather_rule(RewriteContext &ctx, Operation &op,
|
||||
|
Loading…
x
Reference in New Issue
Block a user