diff --git a/CHANGELOG.md b/CHANGELOG.md index 258fad49b..b6d0f97f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -70,6 +70,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. return NaN for negative integer inputs, to match the behavior of SciPy from https://github.com/scipy/scipy/pull/21827. * `jax.clear_backends` was removed after being deprecated in v0.4.26. + * We removed the custom call "__gpu$xla.gpu.triton" from the list of custom + call that we guarantee export stability. This is because this custom call + relies on Triton IR, which is not guaranteed to be stable. If you need + to export code that uses this custom call, you can use the `disabled_checks` + parameter. See more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls). * New Features * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index ad2c7fdac..e3508639f 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -1005,7 +1005,8 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = { *_CPU_FFI_KERNELS, "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape", "cu_threefry2x32", "cu_threefry2x32_ffi", - "__gpu$xla.gpu.triton", # Pallas call on GPU + # Triton IR does not guarantee stability. + # "__gpu$xla.gpu.triton", # cholesky on CPU "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf", # eigh on TPU diff --git a/tests/pallas/export_back_compat_pallas_test.py b/tests/pallas/export_back_compat_pallas_test.py index 1b810bcb6..462597e56 100644 --- a/tests/pallas/export_back_compat_pallas_test.py +++ b/tests/pallas/export_back_compat_pallas_test.py @@ -48,9 +48,10 @@ class CompatTest(bctu.CompatTestBase): self.skipTest("Only works on GPUs with capability >= sm80") super().setUp() - @unittest.skip("TODO(necula): This test is checking backwards compatibility " + @unittest.skip("This test is checking backwards compatibility " "of Triton IR, but Triton doesn't promise backwards " - "compatibility for its IR.") + "compatibility for its IR, and we have since removed " + "the corresponding custom call from the guaranteed stable list.") def test_triton_add_one(self): def func(x): def add_one(x_ref, o_ref): diff --git a/tests/pallas/export_pallas_test.py b/tests/pallas/export_pallas_test.py index 70e40e1f2..8b18f706a 100644 --- a/tests/pallas/export_pallas_test.py +++ b/tests/pallas/export_pallas_test.py @@ -50,6 +50,10 @@ class ExportTest(jtu.JaxTestCase): exp = export.export( add_vectors, platforms=["tpu", "cuda"], + # The Pallas GPU custom call is not enabled for export by default. + disabled_checks=[ + export.DisabledSafetyCheck.custom_call("__gpu$xla.gpu.triton") + ] )(a, a) if (jtu.device_under_test() == "tpu" or