Fix the creation of pmap sharding spec when sharded_dim is None.

PiperOrigin-RevId: 457045869
This commit is contained in:
Yash Katariya 2022-06-24 10:45:52 -07:00 committed by jax authors
parent e32373c3ea
commit 989a3304bf
2 changed files with 11 additions and 8 deletions

View File

@ -530,16 +530,17 @@ global_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaRe
_USE_CPP_SDA = True
def _create_pmap_sharding_spec(aval, sharded_dim=0):
def _create_pmap_sharding_spec(aval, sharded_dim=0, sharded_dim_size=None):
if sharded_dim is not None:
sharded_aval = aval.update(
shape=aval.shape[:sharded_dim] + aval.shape[sharded_dim+1:])
aval_shape = aval.shape[sharded_dim]
if sharded_dim_size is None:
sharded_dim_size = aval.shape[sharded_dim]
else:
assert sharded_dim_size is not None
sharded_aval = aval
aval_shape = aval.shape[0]
return _pmap_sharding_spec(aval_shape, aval_shape, 1, None,
return _pmap_sharding_spec(sharded_dim_size, sharded_dim_size, 1, None,
sharded_aval, sharded_dim)

View File

@ -113,14 +113,14 @@ ignore_xmap_warning = partial(
def create_input_array_for_pmap(input_shape, in_axes=0, input_data=None,
devices=None):
devices=None, sharded_dim_size=None):
dtype = np.int32
aval = ShapedArray(input_shape, dtype)
if input_data is None:
input_data = np.arange(prod(input_shape)).reshape(input_shape)
sharding_spec = pxla._create_pmap_sharding_spec(aval, in_axes)
sharding_spec = pxla._create_pmap_sharding_spec(aval, in_axes, sharded_dim_size)
if devices is None:
devices = jax.devices()
@ -2869,7 +2869,8 @@ class ArrayPmapTest(jtu.JaxTestCase):
dc = jax.device_count()
input_shape = (dc, 2)
a1, input_data = create_input_array_for_pmap(input_shape, in_axes=0)
a2, _ = create_input_array_for_pmap(input_shape, in_axes=None)
a2, _ = create_input_array_for_pmap(input_shape, in_axes=None,
sharded_dim_size=a1.shape[0])
def f(x, y):
assert x.shape == (2,)
@ -2890,7 +2891,8 @@ class ArrayPmapTest(jtu.JaxTestCase):
def test_pmap_array_sharding_mismatch(self):
input_shape = (jax.device_count(), 2)
a1, _ = create_input_array_for_pmap(input_shape, in_axes=None)
a1, _ = create_input_array_for_pmap(input_shape, in_axes=None,
sharded_dim_size=input_shape[0])
f = jax.pmap(lambda x: x, in_axes=0, out_axes=0)
with jax._src.config.jax_array(True):