mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Adds coverage for spmd-axisname-filtering in shard_map transpose.
PiperOrigin-RevId: 699193349
This commit is contained in:
parent
34a2f0ca4a
commit
c0811c9dff
@ -709,6 +709,26 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
self.assertIn('out_names', e.params)
|
||||
self.assertEqual(e.params['out_names'], ({0: ('y',), 1: ('x',)},))
|
||||
|
||||
def test_vmap_of_grad_spmd_axis_name(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
|
||||
@partial(
|
||||
shard_map, mesh=mesh, in_specs=P('y'), out_specs=P(), check_rep=False
|
||||
)
|
||||
def f(x):
|
||||
return jnp.sin(jnp.sum(x))
|
||||
|
||||
x = jnp.arange(4 * 4, dtype=jnp.float32).reshape(4, 4)
|
||||
put_x = jax.device_put(
|
||||
x,
|
||||
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')),
|
||||
)
|
||||
vmap_spmd_axisname_result = jax.vmap(jax.grad(f), spmd_axis_name='x')(put_x)
|
||||
vmap_no_spmd_axisname_result = jax.vmap(jax.grad(f))(put_x)
|
||||
self.assertArraysEqual(
|
||||
vmap_spmd_axisname_result, vmap_no_spmd_axisname_result
|
||||
)
|
||||
|
||||
def test_vmap_spmd_axis_name_pair(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user