From c0811c9dffb5a6ddd6f5baf41a41651ffb7efea1 Mon Sep 17 00:00:00 2001 From: Keith Rush Date: Fri, 22 Nov 2024 09:13:46 -0800 Subject: [PATCH] Adds coverage for spmd-axisname-filtering in shard_map transpose. PiperOrigin-RevId: 699193349 --- tests/shard_map_test.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 2a343f7ba..56cf99879 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -709,6 +709,26 @@ class ShardMapTest(jtu.JaxTestCase): self.assertIn('out_names', e.params) self.assertEqual(e.params['out_names'], ({0: ('y',), 1: ('x',)},)) + def test_vmap_of_grad_spmd_axis_name(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + + @partial( + shard_map, mesh=mesh, in_specs=P('y'), out_specs=P(), check_rep=False + ) + def f(x): + return jnp.sin(jnp.sum(x)) + + x = jnp.arange(4 * 4, dtype=jnp.float32).reshape(4, 4) + put_x = jax.device_put( + x, + jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')), + ) + vmap_spmd_axisname_result = jax.vmap(jax.grad(f), spmd_axis_name='x')(put_x) + vmap_no_spmd_axisname_result = jax.vmap(jax.grad(f))(put_x) + self.assertArraysEqual( + vmap_spmd_axisname_result, vmap_no_spmd_axisname_result + ) + def test_vmap_spmd_axis_name_pair(self): mesh = jtu.create_mesh((2, 2), ('x', 'y'))