diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 857c69ce3..2df477454 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -746,6 +746,37 @@ class ShardMapTest(jtu.JaxTestCase): self.assertIn('out_names', e.params) self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},)) + def test_nested_vmap_with_capture_spmd_axis_name(self): + self.skipTest('https://github.com/google/jax/issues/23476') + mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + + def to_map_with_capture(x, y): + + # We capture x from `to_map_with_capture`'s parameters. + def with_capture(y_slice): + # Inside of all the maps, we have 'mapped everything away'--we are just + # adding two scalars, but one by fully mapping across each of the two + # dimensions, the other by mapping across one and capturing the + # resulting scalar. + self.assertEqual(x.shape, ()) + self.assertEqual(y_slice.shape, ()) + return x + y_slice + + # This vmap i will refer to as 'inner vmap'. + vmap_with_capture = jax.vmap(with_capture) + shmap_vmap_capture = shard_map( + vmap_with_capture, mesh=mesh, in_specs=P('y'), out_specs=P('y') + ) + return shmap_vmap_capture(y) + + # And this one is the outer vmap. + mapped = jax.vmap(to_map_with_capture, spmd_axis_name='x') + x = jnp.arange(2).reshape(2) + y = jnp.arange(2 * 2).reshape(2, 2) + # Inner vmap inside of shard-map will be over an axis of size 1. Outer vmap + # is over an axis of size 2. This is a problem at the moment. + jax.make_jaxpr(mapped)(x, y).jaxpr + @unittest.skipIf(xla_extension_version < 281, 'Requires xla_extension_version >= 281') def test_shard_map_abstract_mesh(self):