[pallas:triton] Temporarily reverted to the lowering using Triton IR

The new lowering caused a performance regression internally.

PiperOrigin-RevId: 723934141
Sergei Lebedev 2025-02-06 07:52:14 -08:00 committed by jax authors
parent 5d647ccfa1
commit efbb0afd7a

@@ -100,7 +100,8 @@ def pallas_call_lowering(
buf = io.BytesIO()
module_op.write_bytecode(buf)
-if jaxlib_version < (0, 5, 1):
+# TODO(b/394629193): Remove True once the bug is fixed.
+if True or jaxlib_version < (0, 5, 1):
# AOT Triton compilation is only available on jaxlib 0.5.1+.
out_types = [
ir.RankedTensorType.get(bm.array_shape_dtype.shape,