mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Adds failing test for https://github.com/google/jax/issues/23476.
PiperOrigin-RevId: 672183133
This commit is contained in:
parent
cd782643a1
commit
265bb7bf4c
@ -746,6 +746,37 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertIn('out_names', e.params)
|
||||
self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},))
|
||||
|
||||
def test_nested_vmap_with_capture_spmd_axis_name(self):
|
||||
self.skipTest('https://github.com/google/jax/issues/23476')
|
||||
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||
|
||||
def to_map_with_capture(x, y):
|
||||
|
||||
# We capture x from `to_map_with_capture`'s parameters.
|
||||
def with_capture(y_slice):
|
||||
# Inside of all the maps, we have 'mapped everything away'--we are just
|
||||
# adding two scalars, but one by fully mapping across each of the two
|
||||
# dimensions, the other by mapping across one and capturing the
|
||||
# resulting scalar.
|
||||
self.assertEqual(x.shape, ())
|
||||
self.assertEqual(y_slice.shape, ())
|
||||
return x + y_slice
|
||||
|
||||
# This vmap i will refer to as 'inner vmap'.
|
||||
vmap_with_capture = jax.vmap(with_capture)
|
||||
shmap_vmap_capture = shard_map(
|
||||
vmap_with_capture, mesh=mesh, in_specs=P('y'), out_specs=P('y')
|
||||
)
|
||||
return shmap_vmap_capture(y)
|
||||
|
||||
# And this one is the outer vmap.
|
||||
mapped = jax.vmap(to_map_with_capture, spmd_axis_name='x')
|
||||
x = jnp.arange(2).reshape(2)
|
||||
y = jnp.arange(2 * 2).reshape(2, 2)
|
||||
# Inner vmap inside of shard-map will be over an axis of size 1. Outer vmap
|
||||
# is over an axis of size 2. This is a problem at the moment.
|
||||
jax.make_jaxpr(mapped)(x, y).jaxpr
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 281,
|
||||
'Requires xla_extension_version >= 281')
|
||||
def test_shard_map_abstract_mesh(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user