Accelerate deprecation of legacy JAX FFI calling convention.

In https://github.com/jax-ml/jax/pull/24370, `ffi_call` was updated to return a callable, and the original calling convention was deprecated. This change is part of the deprecation cycle for this calling convention.

PiperOrigin-RevId: 708424223
This commit is contained in:
Dan Foreman-Mackey 2024-12-20 14:12:50 -08:00 committed by jax authors
parent 3a35155b15
commit 4216f8fad0

View File

@ -30,6 +30,7 @@ from jax._src import abstract_arrays
from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import linear_util
from jax._src import prng
from jax._src import test_util as jtu
@ -279,8 +280,13 @@ class FfiTest(jtu.JaxTestCase):
def testBackwardCompatSyntax(self):
def fun(x):
return jex.ffi.ffi_call("test_ffi", x, x, param=0.5)
with self.assertWarns(DeprecationWarning):
jax.jit(fun).lower(jnp.ones(5))
msg = "Calling ffi_call directly with input arguments is deprecated"
if deprecations.is_accelerated("jax-ffi-call-args"):
with self.assertRaisesRegex(ValueError, msg):
jax.jit(fun).lower(jnp.ones(5))
else:
with self.assertWarnsRegex(DeprecationWarning, msg):
jax.jit(fun).lower(jnp.ones(5))
def testInputOutputAliases(self):
def fun(x):