PiperOrigin-RevId: 672183133
This commit is contained in:
Keith Rush 2024-09-07 20:29:30 -07:00 committed by jax authors
parent cd782643a1
commit 265bb7bf4c

View File

@ -746,6 +746,37 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertIn('out_names', e.params)
self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},))
def test_nested_vmap_with_capture_spmd_axis_name(self):
self.skipTest('https://github.com/google/jax/issues/23476')
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
def to_map_with_capture(x, y):
# We capture x from `to_map_with_capture`'s parameters.
def with_capture(y_slice):
# Inside of all the maps, we have 'mapped everything away'--we are just
# adding two scalars, but one by fully mapping across each of the two
# dimensions, the other by mapping across one and capturing the
# resulting scalar.
self.assertEqual(x.shape, ())
self.assertEqual(y_slice.shape, ())
return x + y_slice
# This vmap i will refer to as 'inner vmap'.
vmap_with_capture = jax.vmap(with_capture)
shmap_vmap_capture = shard_map(
vmap_with_capture, mesh=mesh, in_specs=P('y'), out_specs=P('y')
)
return shmap_vmap_capture(y)
# And this one is the outer vmap.
mapped = jax.vmap(to_map_with_capture, spmd_axis_name='x')
x = jnp.arange(2).reshape(2)
y = jnp.arange(2 * 2).reshape(2, 2)
# Inner vmap inside of shard-map will be over an axis of size 1. Outer vmap
# is over an axis of size 2. This is a problem at the moment.
jax.make_jaxpr(mapped)(x, y).jaxpr
@unittest.skipIf(xla_extension_version < 281,
'Requires xla_extension_version >= 281')
def test_shard_map_abstract_mesh(self):