[export] Removed __gpu$xla.gpu.triton (Pallas GPU) from the list of custom calls with guaranteed compatibility.

This is because the underlying Triton IR does not guarantee compatibility.

PiperOrigin-RevId: 703127711
Author: George Necula (committed by jax authors)
Date: 2024-12-05 08:39:48 -08:00
parent 4a41aa0a46
commit 3f5f3e1c47
4 changed files with 14 additions and 3 deletions


@@ -70,6 +70,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
   return NaN for negative integer inputs, to match the behavior of SciPy from
   https://github.com/scipy/scipy/pull/21827.
 * `jax.clear_backends` was removed after being deprecated in v0.4.26.
+* We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
+  calls for which we guarantee export stability, because this custom call
+  relies on Triton IR, which is not guaranteed to be stable. If you need to
+  export code that uses this custom call, you can use the `disabled_checks`
+  parameter. See more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls).
 * New Features
 * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
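For readers who want the new workflow in one place: after this change, exporting a Pallas GPU kernel requires opting out of the stability check explicitly. A minimal sketch, assuming a JAX version that includes this commit; the kernel, shapes, and names below are illustrative and not part of this commit:

import jax
import jax.numpy as jnp
from jax import export
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    # Illustrative Pallas kernel; lowering it for GPU emits the
    # "__gpu$xla.gpu.triton" custom call.
    o_ref[...] = x_ref[...] + 1.0

def func(x):
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x)

# Without disabled_checks, export now fails the custom-call stability
# check for "__gpu$xla.gpu.triton".
exp = export.export(
    jax.jit(func),
    platforms=["cuda"],
    disabled_checks=[
        export.DisabledSafetyCheck.custom_call("__gpu$xla.gpu.triton")
    ],
)(jax.ShapeDtypeStruct((8,), jnp.float32))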


@@ -1005,7 +1005,8 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
     *_CPU_FFI_KERNELS,
     "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
     "cu_threefry2x32", "cu_threefry2x32_ffi",
-    "__gpu$xla.gpu.triton",  # Pallas call on GPU
+    # Triton IR does not guarantee stability.
+    # "__gpu$xla.gpu.triton",
     # cholesky on CPU
     "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
     # eigh on TPU
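The set edited above is a private implementation detail. Purely as an illustration (the import path jax._src.export._export is an assumption and may move between releases), the effect of this hunk can be observed as follows:

# Illustration only: this set is private and its location is an
# assumption; do not depend on it in user code.
from jax._src.export import _export

assert "__gpu$xla.gpu.triton" not in _export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE
assert "cu_threefry2x32" in _export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE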


@@ -48,9 +48,10 @@ class CompatTest(bctu.CompatTestBase):
       self.skipTest("Only works on GPUs with capability >= sm80")
     super().setUp()

-  @unittest.skip("TODO(necula): This test is checking backwards compatibility "
-                 "of Triton IR, but Triton doesn't promise backwards "
-                 "compatibility for its IR.")
+  @unittest.skip("This test is checking backwards compatibility "
+                 "of Triton IR, but Triton doesn't promise backwards "
+                 "compatibility for its IR, and we have since removed "
+                 "the corresponding custom call from the guaranteed stable list.")
   def test_triton_add_one(self):
     def func(x):
       def add_one(x_ref, o_ref):


@@ -50,6 +50,10 @@ class ExportTest(jtu.JaxTestCase):
     exp = export.export(
         add_vectors,
         platforms=["tpu", "cuda"],
+        # The Pallas GPU custom call is not enabled for export by default.
+        disabled_checks=[
+            export.DisabledSafetyCheck.custom_call("__gpu$xla.gpu.triton")
+        ]
     )(a, a)
     if (jtu.device_under_test() == "tpu" or
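As a follow-up to the sketch after the changelog hunk: once export succeeds with the check disabled, the artifact round-trips through bytes as usual. A minimal sketch, reusing the exp object from that earlier sketch; running the restored artifact assumes a CUDA device:

import jax.numpy as jnp
from jax import export

# Serialize the Exported object and restore it. Because the embedded
# Triton IR carries no compatibility guarantee, a restored artifact may
# fail under a future jaxlib; that is exactly the risk the export-time
# safety check flags.
blob = exp.serialize()
restored = export.deserialize(blob)
result = restored.call(jnp.ones((8,), dtype=jnp.float32))  # needs a CUDA device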