mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
[Mosaic GPU][NFC] Delete workaround for dialect bindings before jaxlib 0.5.1.
PiperOrigin-RevId: 732102282
This commit is contained in:
parent
7c46480eab
commit
1bc36e623b
@ -149,11 +149,6 @@ def _fragmented_array_from_ir(
|
||||
).to_layout(layouts.from_layout_attr(layout))
|
||||
|
||||
|
||||
# TODO(dasenov): Remove this when minimum jaxlib version >= 0.5.1.
|
||||
# Jaxlib doesn't contain the latest Mosaic GPU dialect bindings.
|
||||
WaitOp = getattr(mgpu, "WaitOp", None)
|
||||
ArriveExpectTxOp = getattr(mgpu, "ArriveExpectTxOp", None)
|
||||
|
||||
def _register_lowering(
|
||||
op: str | Type[ir.OpView] | None
|
||||
) -> Callable[[MlirLoweringRule], MlirLoweringRule]:
|
||||
@ -574,9 +569,9 @@ def _mgpu_wgmma_op_lowering_rule(
|
||||
return [_fragmented_array_to_ir(new_acc.value, wgmma_op.accumulator.type)]
|
||||
|
||||
|
||||
@_register_lowering(ArriveExpectTxOp)
|
||||
@_register_lowering(mgpu.ArriveExpectTxOp)
|
||||
def _mgpu_arrive_expect_tx_op_lowering_rule(
|
||||
ctx: LoweringContext, arrive_expect_tx_op: ArriveExpectTxOp
|
||||
ctx: LoweringContext, arrive_expect_tx_op: mgpu.ArriveExpectTxOp
|
||||
) -> Sequence[ir.Value]:
|
||||
|
||||
barrier = utils.BarrierRef.from_dialect_barrier_memref(arrive_expect_tx_op.barrier)
|
||||
@ -588,9 +583,9 @@ def _mgpu_arrive_expect_tx_op_lowering_rule(
|
||||
return []
|
||||
|
||||
|
||||
@_register_lowering(WaitOp)
|
||||
@_register_lowering(mgpu.WaitOp)
|
||||
def _mgpu_wait_op_lowering_rule(
|
||||
_: LoweringContext, wait_op: WaitOp
|
||||
_: LoweringContext, wait_op: mgpu.WaitOp
|
||||
) -> Sequence[ir.Value]:
|
||||
|
||||
barrier = utils.BarrierRef.from_dialect_barrier_memref(wait_op.barrier)
|
||||
|
Loading…
x
Reference in New Issue
Block a user