From 8c685be688429c213edf272e3726b5b9e1b45b4d Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Wed, 12 Feb 2025 13:52:57 -0800 Subject: [PATCH] [xla:cpu] Implement XLA FFI handlers for CPU Jax callbacks. PiperOrigin-RevId: 726185954 --- tests/python_callback_test.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index dc622ecca..05b4c8d7c 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -1057,6 +1057,20 @@ class PureCallbackTest(jtu.JaxTestCase): result += fun(jnp.ones((500, 500), jnp.complex64))[1] jax.block_until_ready(result) # doesn't deadlock + def test_non_default_stride(self): + x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) + class IOCallbackTest(jtu.JaxTestCase):