From 989a3304bfb046d8d2652ccb6f743d2015194ee0 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 24 Jun 2022 10:45:52 -0700 Subject: [PATCH] Fix the creation of pmap sharding spec when sharded_dim is None. PiperOrigin-RevId: 457045869 --- jax/interpreters/pxla.py | 9 +++++---- tests/pmap_test.py | 10 ++++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index d32d4f50c..93c1202e1 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -530,16 +530,17 @@ global_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaRe _USE_CPP_SDA = True -def _create_pmap_sharding_spec(aval, sharded_dim=0): +def _create_pmap_sharding_spec(aval, sharded_dim=0, sharded_dim_size=None): if sharded_dim is not None: sharded_aval = aval.update( shape=aval.shape[:sharded_dim] + aval.shape[sharded_dim+1:]) - aval_shape = aval.shape[sharded_dim] + if sharded_dim_size is None: + sharded_dim_size = aval.shape[sharded_dim] else: + assert sharded_dim_size is not None sharded_aval = aval - aval_shape = aval.shape[0] - return _pmap_sharding_spec(aval_shape, aval_shape, 1, None, + return _pmap_sharding_spec(sharded_dim_size, sharded_dim_size, 1, None, sharded_aval, sharded_dim) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index b59020114..894e24a2c 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -113,14 +113,14 @@ ignore_xmap_warning = partial( def create_input_array_for_pmap(input_shape, in_axes=0, input_data=None, - devices=None): + devices=None, sharded_dim_size=None): dtype = np.int32 aval = ShapedArray(input_shape, dtype) if input_data is None: input_data = np.arange(prod(input_shape)).reshape(input_shape) - sharding_spec = pxla._create_pmap_sharding_spec(aval, in_axes) + sharding_spec = pxla._create_pmap_sharding_spec(aval, in_axes, sharded_dim_size) if devices is None: devices = jax.devices() @@ -2869,7 +2869,8 @@ class ArrayPmapTest(jtu.JaxTestCase): dc = jax.device_count() input_shape = (dc, 2) a1, input_data = create_input_array_for_pmap(input_shape, in_axes=0) - a2, _ = create_input_array_for_pmap(input_shape, in_axes=None) + a2, _ = create_input_array_for_pmap(input_shape, in_axes=None, + sharded_dim_size=a1.shape[0]) def f(x, y): assert x.shape == (2,) @@ -2890,7 +2891,8 @@ class ArrayPmapTest(jtu.JaxTestCase): def test_pmap_array_sharding_mismatch(self): input_shape = (jax.device_count(), 2) - a1, _ = create_input_array_for_pmap(input_shape, in_axes=None) + a1, _ = create_input_array_for_pmap(input_shape, in_axes=None, + sharded_dim_size=input_shape[0]) f = jax.pmap(lambda x: x, in_axes=0, out_axes=0) with jax._src.config.jax_array(True):