mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[Mosaic GPU] Eliminate the `arrive` attribute from `mosaic_gpu.async_load`.
We plan to always issue an explicit `expect_tx` operation when using the dialect.

PiperOrigin-RevId: 721411949
This commit is contained in:
parent
1003ba93c3
commit
d8f3b33ae4
@ -251,7 +251,7 @@ def _mgpu_async_load_op_lowering_rule(
|
||||
src_ref=load_op.source,
|
||||
dst_ref=load_op.destination,
|
||||
barrier=barrier,
|
||||
arrive=load_op.arrive,
|
||||
arrive=False,
|
||||
uniform=True,
|
||||
swizzle=load_op.swizzle.value,
|
||||
predicate=ctx.single_thread_per_warpgroup_predicate,
|
||||
|
@ -224,10 +224,8 @@ def MosaicGPU_AsyncLoadOp : Op<MosaicGPU_Dialect, "async_load",
|
||||
the `destination` MemRef in SMEM. The `destination` MemRef in SMEM must be
|
||||
contiguous.
|
||||
|
||||
If `arrive` is true, the `arrive.expect-tx(expect_count)` operation will be
|
||||
executed on the provided `barrier` before the copy is scheduled. Upon
|
||||
completion of the copy, the `complete-tx(complete-count)` operation will
|
||||
always be executed on the provided `barrier`.
|
||||
Upon completion of the copy, the `complete-tx(complete-count)` operation
|
||||
will always be executed on the provided `barrier`.
|
||||
|
||||
The `indices` and `slice_lengths` inputs define what slice of the GMEM
|
||||
`source` corresponds to the SMEM `destination`. Both `indices` and
|
||||
@ -270,7 +268,6 @@ def MosaicGPU_AsyncLoadOp : Op<MosaicGPU_Dialect, "async_load",
|
||||
DenseI64ArrayAttr:$slice_lengths,
|
||||
TypedArrayAttrBase<AnyAttrOf<[TileTransformAttr, TransposeTransformAttr]>, "transforms">:$transforms,
|
||||
DefaultValuedAttr<MosaicGPU_SwizzlingMode, "SwizzlingMode::kNoSwizzle">:$swizzle,
|
||||
DefaultValuedAttr<BoolAttr, "true" >:$arrive,
|
||||
TypedArrayAttrBase<MosaicGPU_Dimension, "dimensions">:$collective
|
||||
);
|
||||
|
||||
|
@ -2054,7 +2054,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
slice_lengths=shape,
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
arrive=False,
|
||||
swizzle=swizzle,
|
||||
)
|
||||
mgpu_dialect.async_load(
|
||||
@ -2065,7 +2064,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
slice_lengths=shape,
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
arrive=False,
|
||||
swizzle=swizzle,
|
||||
)
|
||||
|
||||
@ -2165,7 +2163,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
slice_lengths=shape_a,
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
arrive=False,
|
||||
swizzle=swizzle,
|
||||
)
|
||||
mgpu_dialect.async_load(
|
||||
@ -2176,7 +2173,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
slice_lengths=shape_b,
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
arrive=False,
|
||||
swizzle=swizzle,
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user