[jax2tf] Add xmap tests

PiperOrigin-RevId: 505608747
This commit is contained in:
George Necula 2023-01-30 00:23:45 -08:00 committed by jax authors
parent ef4c2a38d6
commit 7eb7baa6f1

View File

@ -318,6 +318,53 @@ class PjitTest(tf_test_util.JaxToTfTestCase):
self.ConvertAndCompare(func_jax, x,
limitations=[skip_eager_for_partitioning])
def test_xmap_basic(self):
local_devices = list(jax.local_devices())
if len(local_devices) < 2:
raise unittest.SkipTest("Test requires at least 4 local devices")
def f(a, b):
return a * 2, b * 4
devices = np.array(local_devices[:2]).reshape((1, 2))
with Mesh(devices, ('x', 'y')):
fm = xmap(f,
in_axes=({0: 'a', 1: 'b'}, ['c', ...]),
out_axes=({0: 'a', 1: 'b'}, ['c', ...]),
axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
ashape = (16, 8, 5)
a = jnp.arange(np.prod(ashape)).reshape(ashape)
bshape = (2, 7)
b = jnp.arange(np.prod(bshape)).reshape(bshape)
res_jax = fm(a, b)
self.assertAllClose(res_jax, (a * 2, b * 4))
# xmap works only with native lowering
_log_sharding_annotations(self, fm, [a, b],
experimental_native_lowering=True)
res_tf = tf.function(
jax2tf.convert(fm, experimental_native_lowering=True),
autograph=False, jit_compile=True)(a, b)
self.assertAllClose(res_tf, res_jax)
@jtu.with_mesh([('x', 1), ('y', 2)])
def test_xmap_collective_reduce(self):
fm = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4),
in_axes=(['a', 'b', ...], {0: 'c'}),
out_axes=(['b', ...], {0: 'c'}),
axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
ashape = (16, 8, 5)
a = jnp.arange(np.prod(ashape)).reshape(ashape)
bshape = (2, 7)
b = jnp.arange(np.prod(bshape)).reshape(bshape)
res_jax = fm(a, b)
self.assertAllClose(res_jax, ((a * 2).sum(0), b * 4))
_log_sharding_annotations(self, fm, [a, b],
experimental_native_lowering=True)
res_tf = tf.function(
jax2tf.convert(fm, experimental_native_lowering=True),
autograph=False, jit_compile=True)(a, b)
self.assertAllClose(res_tf, res_jax)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())