From 7eb7baa6f1ce0dbfece656eac1750ba1a460ce59 Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 30 Jan 2023 00:23:45 -0800 Subject: [PATCH] [jax2tf] Add xmap tests PiperOrigin-RevId: 505608747 --- .../jax2tf/tests/sharding_test.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 27e0a1a17..36ca0b17c 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -318,6 +318,53 @@ class PjitTest(tf_test_util.JaxToTfTestCase): self.ConvertAndCompare(func_jax, x, limitations=[skip_eager_for_partitioning]) + def test_xmap_basic(self): + local_devices = list(jax.local_devices()) + if len(local_devices) < 2: + raise unittest.SkipTest("Test requires at least 4 local devices") + def f(a, b): + return a * 2, b * 4 + devices = np.array(local_devices[:2]).reshape((1, 2)) + with Mesh(devices, ('x', 'y')): + fm = xmap(f, + in_axes=({0: 'a', 1: 'b'}, ['c', ...]), + out_axes=({0: 'a', 1: 'b'}, ['c', ...]), + axis_resources={'a': 'x', 'b': 'y', 'c': 'x'}) + ashape = (16, 8, 5) + a = jnp.arange(np.prod(ashape)).reshape(ashape) + bshape = (2, 7) + b = jnp.arange(np.prod(bshape)).reshape(bshape) + + res_jax = fm(a, b) + self.assertAllClose(res_jax, (a * 2, b * 4)) + + # xmap works only with native lowering + _log_sharding_annotations(self, fm, [a, b], + experimental_native_lowering=True) + res_tf = tf.function( + jax2tf.convert(fm, experimental_native_lowering=True), + autograph=False, jit_compile=True)(a, b) + self.assertAllClose(res_tf, res_jax) + + @jtu.with_mesh([('x', 1), ('y', 2)]) + def test_xmap_collective_reduce(self): + fm = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4), + in_axes=(['a', 'b', ...], {0: 'c'}), + out_axes=(['b', ...], {0: 'c'}), + axis_resources={'a': 'x', 'b': 'y', 'c': 'x'}) + ashape = (16, 8, 5) + a = jnp.arange(np.prod(ashape)).reshape(ashape) + bshape = (2, 7) + b = jnp.arange(np.prod(bshape)).reshape(bshape) + res_jax = fm(a, b) + self.assertAllClose(res_jax, ((a * 2).sum(0), b * 4)) + + _log_sharding_annotations(self, fm, [a, b], + experimental_native_lowering=True) + res_tf = tf.function( + jax2tf.convert(fm, experimental_native_lowering=True), + autograph=False, jit_compile=True)(a, b) + self.assertAllClose(res_tf, res_jax) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())