Adds coverage for spmd-axisname-filtering in shard_map transpose.

PiperOrigin-RevId: 699193349
This commit is contained in:
Keith Rush 2024-11-22 09:13:46 -08:00 committed by jax authors
parent 34a2f0ca4a
commit c0811c9dff

View File

@ -709,6 +709,26 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertIn('out_names', e.params)
self.assertEqual(e.params['out_names'], ({0: ('y',), 1: ('x',)},))
def test_vmap_of_grad_spmd_axis_name(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@partial(
shard_map, mesh=mesh, in_specs=P('y'), out_specs=P(), check_rep=False
)
def f(x):
return jnp.sin(jnp.sum(x))
x = jnp.arange(4 * 4, dtype=jnp.float32).reshape(4, 4)
put_x = jax.device_put(
x,
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')),
)
vmap_spmd_axisname_result = jax.vmap(jax.grad(f), spmd_axis_name='x')(put_x)
vmap_no_spmd_axisname_result = jax.vmap(jax.grad(f))(put_x)
self.assertArraysEqual(
vmap_spmd_axisname_result, vmap_no_spmd_axisname_result
)
def test_vmap_spmd_axis_name_pair(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))