Add test case for jit compile of vmap on gpu.

Update tests/jax_jit_test.py

Co-authored-by: Jake Vanderplas <jakevdp@gmail.com>

Validate expected output and allow test for all device types.

Adjust variable names.

use jnp.zeros for conciseness.
This commit is contained in:
oyzh 2025-01-18 00:06:30 -08:00
parent 5a068da699
commit 1ff268e231

View File

@ -205,6 +205,28 @@ class JaxJitTest(jtu.JaxTestCase):
jitted_f = jax.jit(f)
self.assertEqual(inspect.signature(f), inspect.signature(jitted_f))
def test_jit_compile_vmap(self):
# Regression test for https://github.com/openxla/xla/issues/15744
@jax.vmap
def fn(x):
R1 = jnp.array([[x[0], 0, 0],
[0, x[0], 0],
[0, 0, x[0]]])
R2 = jnp.array([[x[0], 0, 0],
[0, x[1], 0],
[0, 0, x[2]]])
H = jnp.eye(4)
H = H.at[:3, :3].set(R2.T)
pos = H @ jnp.concatenate([x, jnp.array([1.0])])
return pos, R1
jitted_fn = jax.jit(fn)
v1, v2 = jitted_fn(jnp.zeros((2,3)))
v1_expected = jnp.array([[0., 0., 0., 1.],
[0., 0., 0., 1.]])
v2_expected = jnp.zeros((2, 3, 3))
self.assertArraysEqual(v1, v1_expected)
self.assertArraysEqual(v2, v2_expected)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())