[Mosaic GPU] Eliminate the arrive attribute from mosaic_gpu.async_load.

We plan to explicitly issue an `expect_tx` operation all the time when using
the dialect.

PiperOrigin-RevId: 721411949
This commit is contained in:
Benjamin Chetioui 2025-01-30 09:08:02 -08:00 committed by jax authors
parent 1003ba93c3
commit d8f3b33ae4
3 changed files with 3 additions and 10 deletions

View File

@ -251,7 +251,7 @@ def _mgpu_async_load_op_lowering_rule(
src_ref=load_op.source,
dst_ref=load_op.destination,
barrier=barrier,
arrive=load_op.arrive,
arrive=False,
uniform=True,
swizzle=load_op.swizzle.value,
predicate=ctx.single_thread_per_warpgroup_predicate,

View File

@ -224,10 +224,8 @@ def MosaicGPU_AsyncLoadOp : Op<MosaicGPU_Dialect, "async_load",
the `destination` MemRef in SMEM. The `destination` MemRef in SMEM must be
contiguous.
If `arrive` is true, the `arrive.expect-tx(expect_count)` operation will be
executed on the provided `barrier` before the copy is scheduled. Upon
completion of the copy, the `complete-tx(complete-count)` operation will
always be executed on the provided `barrier`.
Upon completion of the copy, the `complete-tx(complete-count)` operation
will always be executed on the provided `barrier`.
The `indices` and `slice_lengths` inputs define what slice of the GMEM
`source` corresponds to the SMEM `destination`. Both `indices` and
@ -270,7 +268,6 @@ def MosaicGPU_AsyncLoadOp : Op<MosaicGPU_Dialect, "async_load",
DenseI64ArrayAttr:$slice_lengths,
TypedArrayAttrBase<AnyAttrOf<[TileTransformAttr, TransposeTransformAttr]>, "transforms">:$transforms,
DefaultValuedAttr<MosaicGPU_SwizzlingMode, "SwizzlingMode::kNoSwizzle">:$swizzle,
DefaultValuedAttr<BoolAttr, "true" >:$arrive,
TypedArrayAttrBase<MosaicGPU_Dimension, "dimensions">:$collective
);

View File

@ -2054,7 +2054,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
slice_lengths=shape,
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
arrive=False,
swizzle=swizzle,
)
mgpu_dialect.async_load(
@ -2065,7 +2064,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
slice_lengths=shape,
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
arrive=False,
swizzle=swizzle,
)
@ -2165,7 +2163,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
slice_lengths=shape_a,
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
arrive=False,
swizzle=swizzle,
)
mgpu_dialect.async_load(
@ -2176,7 +2173,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
slice_lengths=shape_b,
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
arrive=False,
swizzle=swizzle,
)