Improve the coverage of shard_map tests when fewer than 8 devices are available. Before this change, `setUpModule` skipped the entire test module whenever fewer than 8 devices were present, so we lost a lot of coverage on the latest hardware.

PiperOrigin-RevId: 676571965
This commit is contained in:
Yash Katariya 2024-09-19 14:48:33 -07:00 committed by jax authors
parent 47b177bd03
commit e209abfb2c

View File

@ -56,9 +56,7 @@ zip, unsafe_zip = safe_zip, zip
# Helper for some tests.
def create_inputs(a_sharding, b_sharding):
x, y, z = 2, 2, 2 # pylint: disable=invalid-name
devices = np.array(jax.devices()[:x * y * z]).reshape((x, y, z))
mesh = Mesh(devices, axis_names=('x', 'y', 'z'))
mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
b, e, f = 8, 8, 8 # pylint: disable=invalid-name
m1 = jax.device_put(
jnp.arange(b * e).reshape((b, e)),
@ -74,8 +72,6 @@ _exit_stack = contextlib.ExitStack()
def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
if len(jax.devices()) < 8:
raise unittest.SkipTest("tests require 8 devices")
def tearDownModule():
_exit_stack.close()
@ -93,7 +89,7 @@ class ShardMapTest(jtu.JaxTestCase):
@jax.jit
def fwd(a):
c = shard_map(
lambda x: x,
identity,
mesh,
in_specs=(P('z', ('x', 'y')),),
out_specs=P('z', ('x', 'y')))(a)
@ -219,8 +215,7 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertAllClose(np.squeeze(c.addressable_data(2 * i + 1), -1), sums)
def test_collective_permute(self):
devices = np.array(jax.devices()[:8]) # Take up to 8 devices
mesh = Mesh(devices, axis_names=('x'))
mesh = jtu.create_mesh((8,), 'x')
a = jax.device_put(
jnp.arange(8 * 8).reshape((8, 8)),
jax.sharding.NamedSharding(mesh, P('x', None)))
@ -238,10 +233,7 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertAllClose(c[1, :], a[0, :])
def test_collective_permute_with_multiple_axis_names(self):
mesh = Mesh(
np.array(jax.devices()[:8]).reshape((2, 2, 2)),
axis_names=('x', 'y', 'z'),
)
mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
a = jax.device_put(
jnp.arange(8 * 8).reshape((4, 16)),
jax.sharding.NamedSharding(mesh, P('x', ('y', 'z'))),
@ -284,11 +276,7 @@ class ShardMapTest(jtu.JaxTestCase):
),
)
def test_all_to_all(self, axis_name, mesh_axes):
devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))])
mesh = Mesh(
devices.reshape(tuple(mesh_axes.values())),
axis_names=tuple(mesh_axes.keys()),
)
mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
a = jax.device_put(
jnp.arange(8 * 8).reshape((8, 8)),
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
@ -310,12 +298,7 @@ class ShardMapTest(jtu.JaxTestCase):
assert (c == jnp.reshape(a.T, (1, 64))).all()
def test_all_to_all_with_axis_index_groups(self):
mesh_axes = dict(x=4)
devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))])
mesh = Mesh(
devices.reshape(tuple(mesh_axes.values())),
axis_names=tuple(mesh_axes.keys()),
)
mesh = jtu.create_mesh((4,), ('x',))
a = jax.device_put(
jnp.arange(4 * 4).reshape((4, 4)),
jax.sharding.NamedSharding(mesh, P('x', None)),
@ -348,12 +331,7 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertAllClose(block, c.addressable_data(2 * i + j))
def test_all_to_all_grad(self):
mesh_axes = dict(x=4)
devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))])
mesh = Mesh(
devices.reshape(tuple(mesh_axes.values())),
axis_names=tuple(mesh_axes.keys()),
)
mesh = jtu.create_mesh((4,), 'x')
a = jax.device_put(
jnp.arange(8 * 8, dtype=jnp.float32).reshape((8, 8)),
jax.sharding.NamedSharding(mesh, P('x', None)),
@ -382,7 +360,7 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertAllClose(grad, 2 * np.ones_like(a))
def test_eager_repr(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
s = None
@partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=P('x', 'y'))
@ -396,7 +374,7 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertIn('at mesh coordinates', s)
def test_jvp_basic(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
args = np.arange(4 * 4.).reshape(4, 4),
@ -404,7 +382,7 @@ class ShardMapTest(jtu.JaxTestCase):
jtu.check_grads(jax.jit(g), args, 2, ['fwd'])
def test_linearize_basic(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
x = np.arange(4 * 4.).reshape(4, 4)
@ -418,7 +396,7 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertAllClose(y_dot, y_dot_, check_dtypes=False)
def test_linearize_basic_repres(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
g = shard_map(lambda x: jax.lax.sin(jax.lax.cos(x)), mesh,
in_specs=(P('x',),), out_specs=P('x',))
x = np.arange(4.)
@ -432,7 +410,7 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertAllClose(y_dot, y_dot_, check_dtypes=False)
def test_linearize_basic_repres_jit(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
in_specs=(P('x',),), out_specs=P('x',))
x = np.arange(4.)
@ -446,7 +424,7 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertAllClose(y_dot, y_dot_, check_dtypes=False)
def test_replication_checker_eager(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
x = np.arange(8 * 8.).reshape(8, 8)
def f(x):
@ -464,7 +442,7 @@ class ShardMapTest(jtu.JaxTestCase):
_ = g2(x) # doesn't crash
def test_replication_checker_jit(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
x = np.arange(8 * 8.).reshape(8, 8)
def f(x):
@ -494,7 +472,7 @@ class ShardMapTest(jtu.JaxTestCase):
jtu.check_grads(g, (x,), modes=['fwd'], order=2)
def test_eager_control_flow(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
x = jnp.arange(2 * 2.).reshape(2, 2)
def f(x):
@ -510,12 +488,12 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertAllClose(y, -x, check_dtypes=False)
def test_outer_jit_detects_shard_map_mesh(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
f = shard_map(lambda x: x.reshape(1, *x.shape), mesh, P(), P('x'))
_ = jax.jit(f)(jnp.array(2.0)) # doesn't crash
def test_vmap_basic(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
x = jnp.arange(8 * 8.).reshape(8, 8)
def g(x):
@ -525,7 +503,7 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertAllClose(y, 2 * x, check_dtypes=False)
def test_vmap_basic_axis_name(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
x = jnp.arange(8 * 8.).reshape(8, 8)
def g(x):
@ -535,7 +513,7 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertAllClose(y, 2 * x, check_dtypes=False)
def test_vmap_basic_axis_name_reuse_mesh_name(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
x = jnp.arange(8 * 8.).reshape(8, 8)
def g(x):
@ -545,7 +523,7 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertAllClose(y, 2 * x, check_dtypes=False)
def test_tree_prefix_error(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@partial(shard_map, mesh=mesh, in_specs=([P('x', 'y')],), out_specs=P('x', 'y'))
def f(x):
@ -556,7 +534,7 @@ class ShardMapTest(jtu.JaxTestCase):
f([x, x])
def test_rank_errors(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
def foo():
return {'hi': [3.]}
@ -577,7 +555,7 @@ class ShardMapTest(jtu.JaxTestCase):
shard_map(foo, mesh=mesh, in_specs=P(None), out_specs=())(3.)
def test_reverse_mode_ad(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@jax.jit
@partial(shard_map, mesh=mesh,
@ -591,7 +569,7 @@ class ShardMapTest(jtu.JaxTestCase):
def test_post_process(self):
# JVPTrace.post_process_shard_map and JaxprTrace.post_process_shard_map
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
def f(x):
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
@ -608,7 +586,7 @@ class ShardMapTest(jtu.JaxTestCase):
@jtu.run_on_devices('gpu', 'tpu')
def test_axis_index(self):
mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
mesh = jtu.create_mesh((4,), 'x')
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(), out_specs=P('x'))
@ -716,7 +694,7 @@ class ShardMapTest(jtu.JaxTestCase):
jax.jit(f3)()
def test_vmap_spmd_axis_name(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def f(x):
@ -731,7 +709,7 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertEqual(e.params['out_names'], ({0: ('y',), 1: ('x',)},))
def test_vmap_spmd_axis_name_pair(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
def f(x):
@ -747,7 +725,7 @@ class ShardMapTest(jtu.JaxTestCase):
def test_nested_vmap_with_capture_spmd_axis_name(self):
self.skipTest('https://github.com/google/jax/issues/23476')
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
def to_map_with_capture(x, y):
@ -1545,7 +1523,6 @@ class ShardMapTest(jtu.JaxTestCase):
mesh = jtu.create_mesh((4,), ('heads',))
def f(q, k, v):
def body(q, k, v):
return q * k[None, :] + v[None, :]
@ -1560,7 +1537,11 @@ class ShardMapTest(jtu.JaxTestCase):
k = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec))
v = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec))
jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=1e-2)
if jtu.device_under_test() == 'tpu':
rtol = 2e-2
else:
rtol = 1e-2
jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=rtol)
def test_axis_env_extension_regression(self):
def foo(x):
@ -1675,7 +1656,7 @@ class ShardMapTest(jtu.JaxTestCase):
mesh=mesh, in_specs=P('i'), out_specs=P('i'))(x) # don't crash
def test_error_for_variable_num_args(self):
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
def f(*args):
return args[0] @ args[1]
@ -1687,7 +1668,7 @@ class ShardMapTest(jtu.JaxTestCase):
shard_f(jnp.ones((8, 8)), jnp.ones((8, 8)))
def test_custom_vjp_replication_error_message_hint(self):
mesh = Mesh(np.array(jax.devices()[:4]), ('i',))
mesh = jtu.create_mesh((4,), 'i')
@jax.custom_vjp
def f(x):
@ -1710,7 +1691,7 @@ class ShardMapTest(jtu.JaxTestCase):
def test_repeated_psum_allowed(self):
# https://github.com/google/jax/issues/19175
mesh = Mesh(jax.devices()[:4], ('i',))
mesh = jtu.create_mesh((4,), 'i')
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
def g(x):