mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
[Mosaic TPU][NFC] Remove redundant num_subelems attribute from CreateSubelementMaskOp
PiperOrigin-RevId: 714795856
This commit is contained in:
parent
ed349e4544
commit
0930289997
@ -324,7 +324,7 @@ class TiledRectangularVregBounds : public VRegDataBounds {
|
||||
end_row = target_shape[0] * layout_.packing();
|
||||
}
|
||||
auto submask = builder.create<tpu::CreateSubelementMaskOp>(
|
||||
loc, mask_vreg_ty, start_row, end_row, layout_.packing());
|
||||
loc, mask_vreg_ty, start_row, end_row);
|
||||
tile_mask = builder.create<arith::AndIOp>(loc, tile_mask, submask);
|
||||
} else { // generation < 4
|
||||
if (num_tiles_ > 1) {
|
||||
|
@ -517,12 +517,10 @@ def TPU_CreateMaskOp : TPU_Op<"create_mask", [Pure, SameVariadicOperandSize]> {
|
||||
|
||||
def TPU_CreateSubelementMaskOp : TPU_Op<"create_subelement_mask", [Pure]> {
|
||||
let summary = "Create a mask masking contiguous rows of subelements.";
|
||||
// TODO(tlongeri): Why don't we just get `num_subelems` from the result type?
|
||||
// Taking a parameter and allowing a mismatch is confusing.
|
||||
let description = [{
|
||||
The "half-sublanes", "quarter-sublanes", etc. (unit is determined by
|
||||
`num_subelems`) of the mask are masked in the range specified by `from` and
|
||||
`to`.
|
||||
the type of `output`) of the mask are masked in the range specified by
|
||||
`from` and `to`.
|
||||
|
||||
- If `from <= to`, the range `[from, to)` is set and the rest is unset.
|
||||
- If `to <= from`, the range `[to, from)` is unset and the rest is set.
|
||||
@ -532,7 +530,7 @@ def TPU_CreateSubelementMaskOp : TPU_Op<"create_subelement_mask", [Pure]> {
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%msk = tpu.create_subelement_mask 3, 9, 2 : vector<8x128x2xi1>
|
||||
%msk = tpu.create_subelement_mask 3, 9 : vector<8x128x2xi1>
|
||||
```
|
||||
|
||||
This creates a mask `%msk` where, for all `lane`s, `%msk[*][lane][*]` is:
|
||||
@ -547,12 +545,11 @@ def TPU_CreateSubelementMaskOp : TPU_Op<"create_subelement_mask", [Pure]> {
|
||||
}];
|
||||
let arguments = (ins
|
||||
I32Attr:$from, // inclusive
|
||||
I32Attr:$to, // exclusive
|
||||
I32Attr:$num_subelems
|
||||
I32Attr:$to // exclusive
|
||||
);
|
||||
let results = (outs AnyType:$output); // Verify this is a vmsk with num_subelems
|
||||
let assemblyFormat = [{
|
||||
$from `,` $to `,` $num_subelems attr-dict `:` type($output)
|
||||
$from `,` $to attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -2669,7 +2669,7 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
|
||||
// support for unpacked types in some of the invariants in
|
||||
// lower_to_llo.
|
||||
mask = builder.create<tpu::CreateSubelementMaskOp>(
|
||||
op.getLoc(), vmask_ty, 0, operand_offset, packing);
|
||||
op.getLoc(), vmask_ty, 0, operand_offset);
|
||||
} else {
|
||||
auto sublane_offset = operand_offset / packing;
|
||||
mask = builder.create<tpu::CreateMaskOp>(
|
||||
|
Loading…
x
Reference in New Issue
Block a user