[Mosaic GPU][NFC] Delete workaround for dialect bindings before jaxlib 0.5.1.

PiperOrigin-RevId: 732102282
This commit is contained in:
Benjamin Chetioui 2025-02-28 05:25:17 -08:00 committed by jax authors
parent 7c46480eab
commit 1bc36e623b

View File

@ -149,11 +149,6 @@ def _fragmented_array_from_ir(
).to_layout(layouts.from_layout_attr(layout))
# TODO(dasenov): Remove this when minimum jaxlib version >= 0.5.1.
# Jaxlib doesn't contain the latest Mosaic GPU dialect bindings.
WaitOp = getattr(mgpu, "WaitOp", None)
ArriveExpectTxOp = getattr(mgpu, "ArriveExpectTxOp", None)
def _register_lowering(
op: str | Type[ir.OpView] | None
) -> Callable[[MlirLoweringRule], MlirLoweringRule]:
@ -574,9 +569,9 @@ def _mgpu_wgmma_op_lowering_rule(
return [_fragmented_array_to_ir(new_acc.value, wgmma_op.accumulator.type)]
@_register_lowering(ArriveExpectTxOp)
@_register_lowering(mgpu.ArriveExpectTxOp)
def _mgpu_arrive_expect_tx_op_lowering_rule(
ctx: LoweringContext, arrive_expect_tx_op: ArriveExpectTxOp
ctx: LoweringContext, arrive_expect_tx_op: mgpu.ArriveExpectTxOp
) -> Sequence[ir.Value]:
barrier = utils.BarrierRef.from_dialect_barrier_memref(arrive_expect_tx_op.barrier)
@ -588,9 +583,9 @@ def _mgpu_arrive_expect_tx_op_lowering_rule(
return []
@_register_lowering(WaitOp)
@_register_lowering(mgpu.WaitOp)
def _mgpu_wait_op_lowering_rule(
_: LoweringContext, wait_op: WaitOp
_: LoweringContext, wait_op: mgpu.WaitOp
) -> Sequence[ir.Value]:
barrier = utils.BarrierRef.from_dialect_barrier_memref(wait_op.barrier)