mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Improve the coverage of shard map tests for < 8 devices. Due to the skip in SetupModule before this change, we lost a lot of coverage on latest hardware.
PiperOrigin-RevId: 676571965
This commit is contained in:
parent
47b177bd03
commit
e209abfb2c
@ -56,9 +56,7 @@ zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
# Helper for some tests.
|
||||
def create_inputs(a_sharding, b_sharding):
|
||||
x, y, z = 2, 2, 2 # pylint: disable=invalid-name
|
||||
devices = np.array(jax.devices()[:x * y * z]).reshape((x, y, z))
|
||||
mesh = Mesh(devices, axis_names=('x', 'y', 'z'))
|
||||
mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
|
||||
b, e, f = 8, 8, 8 # pylint: disable=invalid-name
|
||||
m1 = jax.device_put(
|
||||
jnp.arange(b * e).reshape((b, e)),
|
||||
@ -74,8 +72,6 @@ _exit_stack = contextlib.ExitStack()
|
||||
|
||||
def setUpModule():
|
||||
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
|
||||
if len(jax.devices()) < 8:
|
||||
raise unittest.SkipTest("tests require 8 devices")
|
||||
|
||||
def tearDownModule():
|
||||
_exit_stack.close()
|
||||
@ -93,7 +89,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
def fwd(a):
|
||||
c = shard_map(
|
||||
lambda x: x,
|
||||
identity,
|
||||
mesh,
|
||||
in_specs=(P('z', ('x', 'y')),),
|
||||
out_specs=P('z', ('x', 'y')))(a)
|
||||
@ -219,8 +215,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(np.squeeze(c.addressable_data(2 * i + 1), -1), sums)
|
||||
|
||||
def test_collective_permute(self):
|
||||
devices = np.array(jax.devices()[:8]) # Take up to 8 devices
|
||||
mesh = Mesh(devices, axis_names=('x'))
|
||||
mesh = jtu.create_mesh((8,), 'x')
|
||||
a = jax.device_put(
|
||||
jnp.arange(8 * 8).reshape((8, 8)),
|
||||
jax.sharding.NamedSharding(mesh, P('x', None)))
|
||||
@ -238,10 +233,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(c[1, :], a[0, :])
|
||||
|
||||
def test_collective_permute_with_multiple_axis_names(self):
|
||||
mesh = Mesh(
|
||||
np.array(jax.devices()[:8]).reshape((2, 2, 2)),
|
||||
axis_names=('x', 'y', 'z'),
|
||||
)
|
||||
mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z'))
|
||||
a = jax.device_put(
|
||||
jnp.arange(8 * 8).reshape((4, 16)),
|
||||
jax.sharding.NamedSharding(mesh, P('x', ('y', 'z'))),
|
||||
@ -284,11 +276,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
),
|
||||
)
|
||||
def test_all_to_all(self, axis_name, mesh_axes):
|
||||
devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))])
|
||||
mesh = Mesh(
|
||||
devices.reshape(tuple(mesh_axes.values())),
|
||||
axis_names=tuple(mesh_axes.keys()),
|
||||
)
|
||||
mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys()))
|
||||
a = jax.device_put(
|
||||
jnp.arange(8 * 8).reshape((8, 8)),
|
||||
jax.sharding.NamedSharding(mesh, P(axis_name, None)),
|
||||
@ -310,12 +298,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
assert (c == jnp.reshape(a.T, (1, 64))).all()
|
||||
|
||||
def test_all_to_all_with_axis_index_groups(self):
|
||||
mesh_axes = dict(x=4)
|
||||
devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))])
|
||||
mesh = Mesh(
|
||||
devices.reshape(tuple(mesh_axes.values())),
|
||||
axis_names=tuple(mesh_axes.keys()),
|
||||
)
|
||||
mesh = jtu.create_mesh((4,), ('x',))
|
||||
a = jax.device_put(
|
||||
jnp.arange(4 * 4).reshape((4, 4)),
|
||||
jax.sharding.NamedSharding(mesh, P('x', None)),
|
||||
@ -348,12 +331,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(block, c.addressable_data(2 * i + j))
|
||||
|
||||
def test_all_to_all_grad(self):
|
||||
mesh_axes = dict(x=4)
|
||||
devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))])
|
||||
mesh = Mesh(
|
||||
devices.reshape(tuple(mesh_axes.values())),
|
||||
axis_names=tuple(mesh_axes.keys()),
|
||||
)
|
||||
mesh = jtu.create_mesh((4,), 'x')
|
||||
a = jax.device_put(
|
||||
jnp.arange(8 * 8, dtype=jnp.float32).reshape((8, 8)),
|
||||
jax.sharding.NamedSharding(mesh, P('x', None)),
|
||||
@ -382,7 +360,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(grad, 2 * np.ones_like(a))
|
||||
|
||||
def test_eager_repr(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
s = None
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=P('x', 'y'))
|
||||
@ -396,7 +374,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertIn('at mesh coordinates', s)
|
||||
|
||||
def test_jvp_basic(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
|
||||
in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
|
||||
args = np.arange(4 * 4.).reshape(4, 4),
|
||||
@ -404,7 +382,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(jax.jit(g), args, 2, ['fwd'])
|
||||
|
||||
def test_linearize_basic(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
|
||||
in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
|
||||
x = np.arange(4 * 4.).reshape(4, 4)
|
||||
@ -418,7 +396,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(y_dot, y_dot_, check_dtypes=False)
|
||||
|
||||
def test_linearize_basic_repres(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
g = shard_map(lambda x: jax.lax.sin(jax.lax.cos(x)), mesh,
|
||||
in_specs=(P('x',),), out_specs=P('x',))
|
||||
x = np.arange(4.)
|
||||
@ -432,7 +410,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(y_dot, y_dot_, check_dtypes=False)
|
||||
|
||||
def test_linearize_basic_repres_jit(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh,
|
||||
in_specs=(P('x',),), out_specs=P('x',))
|
||||
x = np.arange(4.)
|
||||
@ -446,7 +424,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(y_dot, y_dot_, check_dtypes=False)
|
||||
|
||||
def test_replication_checker_eager(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
x = np.arange(8 * 8.).reshape(8, 8)
|
||||
|
||||
def f(x):
|
||||
@ -464,7 +442,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
_ = g2(x) # doesn't crash
|
||||
|
||||
def test_replication_checker_jit(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
x = np.arange(8 * 8.).reshape(8, 8)
|
||||
|
||||
def f(x):
|
||||
@ -494,7 +472,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(g, (x,), modes=['fwd'], order=2)
|
||||
|
||||
def test_eager_control_flow(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
x = jnp.arange(2 * 2.).reshape(2, 2)
|
||||
|
||||
def f(x):
|
||||
@ -510,12 +488,12 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(y, -x, check_dtypes=False)
|
||||
|
||||
def test_outer_jit_detects_shard_map_mesh(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
f = shard_map(lambda x: x.reshape(1, *x.shape), mesh, P(), P('x'))
|
||||
_ = jax.jit(f)(jnp.array(2.0)) # doesn't crash
|
||||
|
||||
def test_vmap_basic(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
x = jnp.arange(8 * 8.).reshape(8, 8)
|
||||
|
||||
def g(x):
|
||||
@ -525,7 +503,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(y, 2 * x, check_dtypes=False)
|
||||
|
||||
def test_vmap_basic_axis_name(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
x = jnp.arange(8 * 8.).reshape(8, 8)
|
||||
|
||||
def g(x):
|
||||
@ -535,7 +513,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(y, 2 * x, check_dtypes=False)
|
||||
|
||||
def test_vmap_basic_axis_name_reuse_mesh_name(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
x = jnp.arange(8 * 8.).reshape(8, 8)
|
||||
|
||||
def g(x):
|
||||
@ -545,7 +523,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(y, 2 * x, check_dtypes=False)
|
||||
|
||||
def test_tree_prefix_error(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=([P('x', 'y')],), out_specs=P('x', 'y'))
|
||||
def f(x):
|
||||
@ -556,7 +534,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
f([x, x])
|
||||
|
||||
def test_rank_errors(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
|
||||
def foo():
|
||||
return {'hi': [3.]}
|
||||
@ -577,7 +555,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
shard_map(foo, mesh=mesh, in_specs=P(None), out_specs=())(3.)
|
||||
|
||||
def test_reverse_mode_ad(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
|
||||
@jax.jit
|
||||
@partial(shard_map, mesh=mesh,
|
||||
@ -591,7 +569,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_post_process(self):
|
||||
# JVPTrace.post_process_shard_map and JaxprTrace.post_process_shard_map
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
|
||||
def f(x):
|
||||
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
|
||||
@ -608,7 +586,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.run_on_devices('gpu', 'tpu')
|
||||
def test_axis_index(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]), ('x',))
|
||||
mesh = jtu.create_mesh((4,), 'x')
|
||||
|
||||
@jax.jit
|
||||
@partial(shard_map, mesh=mesh, in_specs=(), out_specs=P('x'))
|
||||
@ -716,7 +694,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
jax.jit(f3)()
|
||||
|
||||
def test_vmap_spmd_axis_name(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
|
||||
def f(x):
|
||||
@ -731,7 +709,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertEqual(e.params['out_names'], ({0: ('y',), 1: ('x',)},))
|
||||
|
||||
def test_vmap_spmd_axis_name_pair(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
|
||||
def f(x):
|
||||
@ -747,7 +725,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_nested_vmap_with_capture_spmd_axis_name(self):
|
||||
self.skipTest('https://github.com/google/jax/issues/23476')
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
|
||||
def to_map_with_capture(x, y):
|
||||
|
||||
@ -1545,7 +1523,6 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
mesh = jtu.create_mesh((4,), ('heads',))
|
||||
|
||||
def f(q, k, v):
|
||||
|
||||
def body(q, k, v):
|
||||
return q * k[None, :] + v[None, :]
|
||||
|
||||
@ -1560,7 +1537,11 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
k = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec))
|
||||
v = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec))
|
||||
|
||||
jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=1e-2)
|
||||
if jtu.device_under_test() == 'tpu':
|
||||
rtol = 2e-2
|
||||
else:
|
||||
rtol = 1e-2
|
||||
jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=rtol)
|
||||
|
||||
def test_axis_env_extension_regression(self):
|
||||
def foo(x):
|
||||
@ -1675,7 +1656,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
mesh=mesh, in_specs=P('i'), out_specs=P('i'))(x) # don't crash
|
||||
|
||||
def test_error_for_variable_num_args(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
|
||||
def f(*args):
|
||||
return args[0] @ args[1]
|
||||
@ -1687,7 +1668,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
shard_f(jnp.ones((8, 8)), jnp.ones((8, 8)))
|
||||
|
||||
def test_custom_vjp_replication_error_message_hint(self):
|
||||
mesh = Mesh(np.array(jax.devices()[:4]), ('i',))
|
||||
mesh = jtu.create_mesh((4,), 'i')
|
||||
|
||||
@jax.custom_vjp
|
||||
def f(x):
|
||||
@ -1710,7 +1691,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
def test_repeated_psum_allowed(self):
|
||||
# https://github.com/google/jax/issues/19175
|
||||
mesh = Mesh(jax.devices()[:4], ('i',))
|
||||
mesh = jtu.create_mesh((4,), 'i')
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
|
||||
def g(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user