diff --git a/tests/jax_jit_test.py b/tests/jax_jit_test.py index f02a56d1a..5946d557d 100644 --- a/tests/jax_jit_test.py +++ b/tests/jax_jit_test.py @@ -205,6 +205,28 @@ class JaxJitTest(jtu.JaxTestCase): jitted_f = jax.jit(f) self.assertEqual(inspect.signature(f), inspect.signature(jitted_f)) + def test_jit_compile_vmap(self): + # Regression test for https://github.com/openxla/xla/issues/15744 + @jax.vmap + def fn(x): + R1 = jnp.array([[x[0], 0, 0], + [0, x[0], 0], + [0, 0, x[0]]]) + R2 = jnp.array([[x[0], 0, 0], + [0, x[1], 0], + [0, 0, x[2]]]) + H = jnp.eye(4) + H = H.at[:3, :3].set(R2.T) + pos = H @ jnp.concatenate([x, jnp.array([1.0])]) + return pos, R1 + jitted_fn = jax.jit(fn) + v1, v2 = jitted_fn(jnp.zeros((2,3))) + v1_expected = jnp.array([[0., 0., 0., 1.], + [0., 0., 0., 1.]]) + v2_expected = jnp.zeros((2, 3, 3)) + self.assertArraysEqual(v1, v1_expected) + self.assertArraysEqual(v2, v2_expected) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())