[xla:cpu] Implement XLA FFI handlers for CPU Jax callbacks.

PiperOrigin-RevId: 726185954
This commit is contained in:
Daniel Suo 2025-02-12 13:52:57 -08:00 committed by jax authors
parent 9298018afa
commit 8c685be688

View File

@ -1057,6 +1057,20 @@ class PureCallbackTest(jtu.JaxTestCase):
result += fun(jnp.ones((500, 500), jnp.complex64))[1]
jax.block_until_ready(result) # doesn't deadlock
def test_non_default_stride(self):
x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4)
def callback(x):
return np.asfortranarray(x)
@jax.jit
def f(x):
return jax.pure_callback(
callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x
)
result = f(x)
np.testing.assert_array_equal(x, result)
class IOCallbackTest(jtu.JaxTestCase):