diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index dc622ecca..05b4c8d7c 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -1057,6 +1057,20 @@ class PureCallbackTest(jtu.JaxTestCase): result += fun(jnp.ones((500, 500), jnp.complex64))[1] jax.block_until_ready(result) # doesn't deadlock + def test_non_default_stride(self): + x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4) + def callback(x): + return np.asfortranarray(x) + + @jax.jit + def f(x): + return jax.pure_callback( + callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x + ) + + result = f(x) + np.testing.assert_array_equal(x, result) + class IOCallbackTest(jtu.JaxTestCase):