[Mosaic TPU][NFC] Remove redundant num_subelems attribute from CreateSubelementMaskOp

PiperOrigin-RevId: 714795856
This commit is contained in:
Tomás Longeri 2025-01-12 19:33:47 -08:00 committed by jax authors
parent ed349e4544
commit 0930289997
3 changed files with 7 additions and 10 deletions

View File

@ -324,7 +324,7 @@ class TiledRectangularVregBounds : public VRegDataBounds {
end_row = target_shape[0] * layout_.packing();
}
auto submask = builder.create<tpu::CreateSubelementMaskOp>(
loc, mask_vreg_ty, start_row, end_row, layout_.packing());
loc, mask_vreg_ty, start_row, end_row);
tile_mask = builder.create<arith::AndIOp>(loc, tile_mask, submask);
} else { // generation < 4
if (num_tiles_ > 1) {

View File

@ -517,12 +517,10 @@ def TPU_CreateMaskOp : TPU_Op<"create_mask", [Pure, SameVariadicOperandSize]> {
def TPU_CreateSubelementMaskOp : TPU_Op<"create_subelement_mask", [Pure]> {
let summary = "Create a mask masking contiguous rows of subelements.";
// TODO(tlongeri): Why don't we just get `num_subelems` from the result type?
// Taking a parameter and allowing a mismatch is confusing.
let description = [{
The "half-sublanes", "quarter-sublanes", etc. (unit is determined by
`num_subelems`) of the mask are masked in the range specified by `from` and
`to`.
the type of `output`) of the mask are masked in the range specified by
`from` and `to`.
- If `from <= to`, the range `[from, to)` is set and the rest is unset.
- If `to <= from`, the range `[to, from)` is unset and the rest is set.
@ -532,7 +530,7 @@ def TPU_CreateSubelementMaskOp : TPU_Op<"create_subelement_mask", [Pure]> {
Example:
```mlir
%msk = tpu.create_subelement_mask 3, 9, 2 : vector<8x128x2xi1>
%msk = tpu.create_subelement_mask 3, 9 : vector<8x128x2xi1>
```
This creates a mask `%msk` where, for all `lane`s, `%msk[*][lane][*]` is:
@ -547,12 +545,11 @@ def TPU_CreateSubelementMaskOp : TPU_Op<"create_subelement_mask", [Pure]> {
}];
let arguments = (ins
I32Attr:$from, // inclusive
I32Attr:$to, // exclusive
I32Attr:$num_subelems
I32Attr:$to // exclusive
);
let results = (outs AnyType:$output); // Verify this is a vmsk with num_subelems
let assemblyFormat = [{
$from `,` $to `,` $num_subelems attr-dict `:` type($output)
$from `,` $to attr-dict `:` type($output)
}];
}

View File

@ -2669,7 +2669,7 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
// support for unpacked types in some of the invariants in
// lower_to_llo.
mask = builder.create<tpu::CreateSubelementMaskOp>(
op.getLoc(), vmask_ty, 0, operand_offset, packing);
op.getLoc(), vmask_ty, 0, operand_offset);
} else {
auto sublane_offset = operand_offset / packing;
mask = builder.create<tpu::CreateMaskOp>(