mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Accelerate deprecation of legacy JAX FFI calling convention.
In https://github.com/jax-ml/jax/pull/24370, `ffi_call` was updated to return a callable, and the original calling convention was deprecated. This change is part of the deprecation cycle for this calling convention. PiperOrigin-RevId: 708424223
This commit is contained in:
parent
3a35155b15
commit
4216f8fad0
@ -30,6 +30,7 @@ from jax._src import abstract_arrays
|
||||
from jax._src import api
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
from jax._src import linear_util
|
||||
from jax._src import prng
|
||||
from jax._src import test_util as jtu
|
||||
@ -279,8 +280,13 @@ class FfiTest(jtu.JaxTestCase):
|
||||
def testBackwardCompatSyntax(self):
|
||||
def fun(x):
|
||||
return jex.ffi.ffi_call("test_ffi", x, x, param=0.5)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
jax.jit(fun).lower(jnp.ones(5))
|
||||
msg = "Calling ffi_call directly with input arguments is deprecated"
|
||||
if deprecations.is_accelerated("jax-ffi-call-args"):
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
jax.jit(fun).lower(jnp.ones(5))
|
||||
else:
|
||||
with self.assertWarnsRegex(DeprecationWarning, msg):
|
||||
jax.jit(fun).lower(jnp.ones(5))
|
||||
|
||||
def testInputOutputAliases(self):
|
||||
def fun(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user