mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[pallas:triton] Temporarily reverted to the lowering using Triton IR
The new lowering caused a performance regression internally. PiperOrigin-RevId: 723934141
This commit is contained in:
parent
5d647ccfa1
commit
efbb0afd7a
@ -100,7 +100,8 @@ def pallas_call_lowering(
|
||||
buf = io.BytesIO()
|
||||
module_op.write_bytecode(buf)
|
||||
|
||||
if jaxlib_version < (0, 5, 1):
|
||||
# TODO(b/394629193): Remove True once the bug is fixed.
|
||||
if True and jaxlib_version < (0, 5, 1):
|
||||
# AOT Triton compilation is only available on jaxlib 0.5.1+.
|
||||
out_types = [
|
||||
ir.RankedTensorType.get(bm.array_shape_dtype.shape,
|
||||
|
Loading…
x
Reference in New Issue
Block a user