mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add test case for jit compile of vmap on gpu.
Update tests/jax_jit_test.py Co-authored-by: Jake Vanderplas <jakevdp@gmail.com> Validate expected output and allow test for all device types. Adjust variable names. use jnp.zeros for conciseness.
This commit is contained in:
parent
5a068da699
commit
1ff268e231
@ -205,6 +205,28 @@ class JaxJitTest(jtu.JaxTestCase):
|
||||
jitted_f = jax.jit(f)
|
||||
self.assertEqual(inspect.signature(f), inspect.signature(jitted_f))
|
||||
|
||||
def test_jit_compile_vmap(self):
|
||||
# Regression test for https://github.com/openxla/xla/issues/15744
|
||||
@jax.vmap
|
||||
def fn(x):
|
||||
R1 = jnp.array([[x[0], 0, 0],
|
||||
[0, x[0], 0],
|
||||
[0, 0, x[0]]])
|
||||
R2 = jnp.array([[x[0], 0, 0],
|
||||
[0, x[1], 0],
|
||||
[0, 0, x[2]]])
|
||||
H = jnp.eye(4)
|
||||
H = H.at[:3, :3].set(R2.T)
|
||||
pos = H @ jnp.concatenate([x, jnp.array([1.0])])
|
||||
return pos, R1
|
||||
jitted_fn = jax.jit(fn)
|
||||
v1, v2 = jitted_fn(jnp.zeros((2,3)))
|
||||
v1_expected = jnp.array([[0., 0., 0., 1.],
|
||||
[0., 0., 0., 1.]])
|
||||
v2_expected = jnp.zeros((2, 3, 3))
|
||||
self.assertArraysEqual(v1, v1_expected)
|
||||
self.assertArraysEqual(v2, v2_expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user