1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 05:16:06 +00:00

add test for partial-auto ppermute

PiperOrigin-RevId: 707992245
This commit is contained in:
Matthew Johnson 2024-12-19 12:25:41 -08:00 committed by jax authors
parent 9c3365fb95
commit 9f42b99a76

@ -50,6 +50,8 @@ import jax.numpy as jnp
from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.shard_map import shard_map
from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member
config.parse_flags_with_absl()
map, unsafe_map = safe_map, map
@ -2162,6 +2164,48 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertAllClose(f(), np.array(range(4), dtype=np.int32).reshape(-1, 1))
def test_partial_auto_ppermute(self):
if xla_extension_version < 302:
self.skipTest('minimum xla extension version 302')
if config.use_shardy_partitioner.value:
self.skipTest('Shardy does not support full-to-shard.')
mesh = jtu.create_mesh((4, 2), ('i', 'j'))
x = jnp.arange(8.)
def g(x):
x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('j')))
return jax.lax.ppermute(x, 'i', [(0, 1), (1, 2), (2, 3), (3, 0)])
@jax.jit
def f(x):
return shard_map(g,
mesh, in_specs=P('i'), out_specs=P('i'),
check_rep=False, auto=frozenset({'j'}))(x)
y = f(x) # don't crash
self.assertAllClose(y, jnp.array([6., 7., 0., 1., 2., 3., 4., 5.]),
check_dtypes=False)
# TODO(parkers,mattjj): get XLA to support this too
# def test_partial_auto_all_to_all(self):
# if config.use_shardy_partitioner.value:
# self.skipTest('Shardy does not support anything.')
#
# mesh = jtu.create_mesh((4, 2), ('i', 'j'))
# x = jnp.arange(128.).reshape(16, 8)
#
# def g(x):
# x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('j')))
# return jax.lax.all_to_all(x, 'i', 0, 1, tiled=True)
#
# @jax.jit
# def f(x):
# return shard_map(g,
# mesh, in_specs=P('i', None), out_specs=P(None, 'i'),
# check_rep=False, auto=frozenset({'j'}))(x)
#
# f(x) # don't crash
def test_vmap_grad_shmap_spmd_axis_name_residuals(self):
# https://github.com/jax-ml/jax/pull/21032