mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
[jax2tf] Add xmap tests
PiperOrigin-RevId: 505608747
This commit is contained in:
parent
ef4c2a38d6
commit
7eb7baa6f1
@ -318,6 +318,53 @@ class PjitTest(tf_test_util.JaxToTfTestCase):
|
||||
self.ConvertAndCompare(func_jax, x,
|
||||
limitations=[skip_eager_for_partitioning])
|
||||
|
||||
def test_xmap_basic(self):
|
||||
local_devices = list(jax.local_devices())
|
||||
if len(local_devices) < 2:
|
||||
raise unittest.SkipTest("Test requires at least 4 local devices")
|
||||
def f(a, b):
|
||||
return a * 2, b * 4
|
||||
devices = np.array(local_devices[:2]).reshape((1, 2))
|
||||
with Mesh(devices, ('x', 'y')):
|
||||
fm = xmap(f,
|
||||
in_axes=({0: 'a', 1: 'b'}, ['c', ...]),
|
||||
out_axes=({0: 'a', 1: 'b'}, ['c', ...]),
|
||||
axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
|
||||
ashape = (16, 8, 5)
|
||||
a = jnp.arange(np.prod(ashape)).reshape(ashape)
|
||||
bshape = (2, 7)
|
||||
b = jnp.arange(np.prod(bshape)).reshape(bshape)
|
||||
|
||||
res_jax = fm(a, b)
|
||||
self.assertAllClose(res_jax, (a * 2, b * 4))
|
||||
|
||||
# xmap works only with native lowering
|
||||
_log_sharding_annotations(self, fm, [a, b],
|
||||
experimental_native_lowering=True)
|
||||
res_tf = tf.function(
|
||||
jax2tf.convert(fm, experimental_native_lowering=True),
|
||||
autograph=False, jit_compile=True)(a, b)
|
||||
self.assertAllClose(res_tf, res_jax)
|
||||
|
||||
@jtu.with_mesh([('x', 1), ('y', 2)])
|
||||
def test_xmap_collective_reduce(self):
|
||||
fm = xmap(lambda a, b: (lax.psum(a * 2, 'a'), b * 4),
|
||||
in_axes=(['a', 'b', ...], {0: 'c'}),
|
||||
out_axes=(['b', ...], {0: 'c'}),
|
||||
axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
|
||||
ashape = (16, 8, 5)
|
||||
a = jnp.arange(np.prod(ashape)).reshape(ashape)
|
||||
bshape = (2, 7)
|
||||
b = jnp.arange(np.prod(bshape)).reshape(bshape)
|
||||
res_jax = fm(a, b)
|
||||
self.assertAllClose(res_jax, ((a * 2).sum(0), b * 4))
|
||||
|
||||
_log_sharding_annotations(self, fm, [a, b],
|
||||
experimental_native_lowering=True)
|
||||
res_tf = tf.function(
|
||||
jax2tf.convert(fm, experimental_native_lowering=True),
|
||||
autograph=False, jit_compile=True)(a, b)
|
||||
self.assertAllClose(res_tf, res_jax)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user