mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Adds failing test for https://github.com/google/jax/issues/23476.
PiperOrigin-RevId: 672183133
This commit is contained in:
parent
cd782643a1
commit
265bb7bf4c
@ -746,6 +746,37 @@ class ShardMapTest(jtu.JaxTestCase):
|
|||||||
self.assertIn('out_names', e.params)
|
self.assertIn('out_names', e.params)
|
||||||
self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},))
|
self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},))
|
||||||
|
|
||||||
|
def test_nested_vmap_with_capture_spmd_axis_name(self):
|
||||||
|
self.skipTest('https://github.com/google/jax/issues/23476')
|
||||||
|
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
|
||||||
|
|
||||||
|
def to_map_with_capture(x, y):
|
||||||
|
|
||||||
|
# We capture x from `to_map_with_capture`'s parameters.
|
||||||
|
def with_capture(y_slice):
|
||||||
|
# Inside of all the maps, we have 'mapped everything away'--we are just
|
||||||
|
# adding two scalars, but one by fully mapping across each of the two
|
||||||
|
# dimensions, the other by mapping across one and capturing the
|
||||||
|
# resulting scalar.
|
||||||
|
self.assertEqual(x.shape, ())
|
||||||
|
self.assertEqual(y_slice.shape, ())
|
||||||
|
return x + y_slice
|
||||||
|
|
||||||
|
# This vmap i will refer to as 'inner vmap'.
|
||||||
|
vmap_with_capture = jax.vmap(with_capture)
|
||||||
|
shmap_vmap_capture = shard_map(
|
||||||
|
vmap_with_capture, mesh=mesh, in_specs=P('y'), out_specs=P('y')
|
||||||
|
)
|
||||||
|
return shmap_vmap_capture(y)
|
||||||
|
|
||||||
|
# And this one is the outer vmap.
|
||||||
|
mapped = jax.vmap(to_map_with_capture, spmd_axis_name='x')
|
||||||
|
x = jnp.arange(2).reshape(2)
|
||||||
|
y = jnp.arange(2 * 2).reshape(2, 2)
|
||||||
|
# Inner vmap inside of shard-map will be over an axis of size 1. Outer vmap
|
||||||
|
# is over an axis of size 2. This is a problem at the moment.
|
||||||
|
jax.make_jaxpr(mapped)(x, y).jaxpr
|
||||||
|
|
||||||
@unittest.skipIf(xla_extension_version < 281,
|
@unittest.skipIf(xla_extension_version < 281,
|
||||||
'Requires xla_extension_version >= 281')
|
'Requires xla_extension_version >= 281')
|
||||||
def test_shard_map_abstract_mesh(self):
|
def test_shard_map_abstract_mesh(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user