mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[export] Removed __gpu$xla.gpu.triton (Pallas GPU) from the list of custom calls with guaranteed compatibility.
This is because the underlying Triton IR does not guarantee compatibility. PiperOrigin-RevId: 703127711
This commit is contained in:
parent
4a41aa0a46
commit
3f5f3e1c47
@ -70,6 +70,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
|||||||
return NaN for negative integer inputs, to match the behavior of SciPy from
|
return NaN for negative integer inputs, to match the behavior of SciPy from
|
||||||
https://github.com/scipy/scipy/pull/21827.
|
https://github.com/scipy/scipy/pull/21827.
|
||||||
* `jax.clear_backends` was removed after being deprecated in v0.4.26.
|
* `jax.clear_backends` was removed after being deprecated in v0.4.26.
|
||||||
|
* We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
|
||||||
|
call that we guarantee export stability. This is because this custom call
|
||||||
|
relies on Triton IR, which is not guaranteed to be stable. If you need
|
||||||
|
to export code that uses this custom call, you can use the `disabled_checks`
|
||||||
|
parameter. See more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls).
|
||||||
|
|
||||||
* New Features
|
* New Features
|
||||||
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
|
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
|
||||||
|
@ -1005,7 +1005,8 @@ _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = {
|
|||||||
*_CPU_FFI_KERNELS,
|
*_CPU_FFI_KERNELS,
|
||||||
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
|
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
|
||||||
"cu_threefry2x32", "cu_threefry2x32_ffi",
|
"cu_threefry2x32", "cu_threefry2x32_ffi",
|
||||||
"__gpu$xla.gpu.triton", # Pallas call on GPU
|
# Triton IR does not guarantee stability.
|
||||||
|
# "__gpu$xla.gpu.triton",
|
||||||
# cholesky on CPU
|
# cholesky on CPU
|
||||||
"lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
|
"lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
|
||||||
# eigh on TPU
|
# eigh on TPU
|
||||||
|
@ -48,9 +48,10 @@ class CompatTest(bctu.CompatTestBase):
|
|||||||
self.skipTest("Only works on GPUs with capability >= sm80")
|
self.skipTest("Only works on GPUs with capability >= sm80")
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|
||||||
@unittest.skip("TODO(necula): This test is checking backwards compatibility "
|
@unittest.skip("This test is checking backwards compatibility "
|
||||||
"of Triton IR, but Triton doesn't promise backwards "
|
"of Triton IR, but Triton doesn't promise backwards "
|
||||||
"compatibility for its IR.")
|
"compatibility for its IR, and we have since removed "
|
||||||
|
"the corresponding custom call from the guaranteed stable list.")
|
||||||
def test_triton_add_one(self):
|
def test_triton_add_one(self):
|
||||||
def func(x):
|
def func(x):
|
||||||
def add_one(x_ref, o_ref):
|
def add_one(x_ref, o_ref):
|
||||||
|
@ -50,6 +50,10 @@ class ExportTest(jtu.JaxTestCase):
|
|||||||
exp = export.export(
|
exp = export.export(
|
||||||
add_vectors,
|
add_vectors,
|
||||||
platforms=["tpu", "cuda"],
|
platforms=["tpu", "cuda"],
|
||||||
|
# The Pallas GPU custom call is not enabled for export by default.
|
||||||
|
disabled_checks=[
|
||||||
|
export.DisabledSafetyCheck.custom_call("__gpu$xla.gpu.triton")
|
||||||
|
]
|
||||||
)(a, a)
|
)(a, a)
|
||||||
|
|
||||||
if (jtu.device_under_test() == "tpu" or
|
if (jtu.device_under_test() == "tpu" or
|
||||||
|
Loading…
x
Reference in New Issue
Block a user