mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[xla:cpu] Implement XLA FFI handlers for CPU Jax callbacks.
PiperOrigin-RevId: 726185954
This commit is contained in:
parent
9298018afa
commit
8c685be688
@ -1057,6 +1057,20 @@ class PureCallbackTest(jtu.JaxTestCase):
|
||||
result += fun(jnp.ones((500, 500), jnp.complex64))[1]
|
||||
jax.block_until_ready(result) # doesn't deadlock
|
||||
|
||||
def test_non_default_stride(self):
|
||||
x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 4)
|
||||
def callback(x):
|
||||
return np.asfortranarray(x)
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return jax.pure_callback(
|
||||
callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x
|
||||
)
|
||||
|
||||
result = f(x)
|
||||
np.testing.assert_array_equal(x, result)
|
||||
|
||||
|
||||
class IOCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user