From 1ff268e231974cf13cc5f48ad9a1c9582440f9d8 Mon Sep 17 00:00:00 2001 From: oyzh Date: Sat, 18 Jan 2025 00:06:30 -0800 Subject: [PATCH] Add test case for jit compile of vmap on gpu. Update tests/jax_jit_test.py Co-authored-by: Jake Vanderplas Validate expected output and allow test for all device types. Adjust variable names. use jnp.zeros for conciseness. --- tests/jax_jit_test.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/jax_jit_test.py b/tests/jax_jit_test.py index f02a56d1a..5946d557d 100644 --- a/tests/jax_jit_test.py +++ b/tests/jax_jit_test.py @@ -205,6 +205,28 @@ class JaxJitTest(jtu.JaxTestCase): jitted_f = jax.jit(f) self.assertEqual(inspect.signature(f), inspect.signature(jitted_f)) + def test_jit_compile_vmap(self): + # Regression test for https://github.com/openxla/xla/issues/15744 + @jax.vmap + def fn(x): + R1 = jnp.array([[x[0], 0, 0], + [0, x[0], 0], + [0, 0, x[0]]]) + R2 = jnp.array([[x[0], 0, 0], + [0, x[1], 0], + [0, 0, x[2]]]) + H = jnp.eye(4) + H = H.at[:3, :3].set(R2.T) + pos = H @ jnp.concatenate([x, jnp.array([1.0])]) + return pos, R1 + jitted_fn = jax.jit(fn) + v1, v2 = jitted_fn(jnp.zeros((2,3))) + v1_expected = jnp.array([[0., 0., 0., 1.], + [0., 0., 0., 1.]]) + v2_expected = jnp.zeros((2, 3, 3)) + self.assertArraysEqual(v1, v1_expected) + self.assertArraysEqual(v2, v2_expected) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())