commit 989a3304bf
parent e32373c3ea

Fix the creation of pmap sharding spec when sharded_dim is None.

PiperOrigin-RevId: 457045869
@@ -530,16 +530,17 @@ global_result_handlers: Dict[Tuple[Type[core.AbstractValue], OutputType], PxlaResultHandler]
 _USE_CPP_SDA = True
 
 
-def _create_pmap_sharding_spec(aval, sharded_dim=0):
+def _create_pmap_sharding_spec(aval, sharded_dim=0, sharded_dim_size=None):
   if sharded_dim is not None:
     sharded_aval = aval.update(
         shape=aval.shape[:sharded_dim] + aval.shape[sharded_dim+1:])
-    aval_shape = aval.shape[sharded_dim]
+    if sharded_dim_size is None:
+      sharded_dim_size = aval.shape[sharded_dim]
   else:
+    assert sharded_dim_size is not None
     sharded_aval = aval
-    aval_shape = aval.shape[0]
 
-  return _pmap_sharding_spec(aval_shape, aval_shape, 1, None,
+  return _pmap_sharding_spec(sharded_dim_size, sharded_dim_size, 1, None,
                              sharded_aval, sharded_dim)
 
 
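For context, a minimal sketch of how the two paths of the updated helper
might be called. This is illustrative only: it assumes pxla here means
jax.interpreters.pxla as of this commit, and the (8, 2) shape is made up.

    import numpy as np
    from jax.core import ShapedArray
    from jax.interpreters import pxla

    aval = ShapedArray((8, 2), np.int32)

    # sharded_dim given: the shard count is inferred from the aval's shape.
    spec = pxla._create_pmap_sharding_spec(aval, sharded_dim=0)

    # sharded_dim=None: the value is replicated, so no dimension carries the
    # shard count; after this fix the caller passes it explicitly instead of
    # it being read from aval.shape[0].
    spec_replicated = pxla._create_pmap_sharding_spec(
        aval, sharded_dim=None, sharded_dim_size=8)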
@@ -113,14 +113,14 @@ ignore_xmap_warning = partial(
 
 
 def create_input_array_for_pmap(input_shape, in_axes=0, input_data=None,
-                                devices=None):
+                                devices=None, sharded_dim_size=None):
   dtype = np.int32
   aval = ShapedArray(input_shape, dtype)
 
   if input_data is None:
     input_data = np.arange(prod(input_shape)).reshape(input_shape)
 
-  sharding_spec = pxla._create_pmap_sharding_spec(aval, in_axes)
+  sharding_spec = pxla._create_pmap_sharding_spec(aval, in_axes, sharded_dim_size)
 
   if devices is None:
     devices = jax.devices()
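A hedged sketch of calling the extended test helper. The tuple return and
the (device_count, 2) shape follow the test hunks below; the variable names
here are made up.

    import jax

    dc = jax.device_count()
    input_shape = (dc, 2)

    # Sharded over axis 0: sharded_dim_size can still be inferred.
    arr, input_data = create_input_array_for_pmap(input_shape, in_axes=0)

    # Fully replicated (in_axes=None): the shard count must now be explicit.
    arr_repl, _ = create_input_array_for_pmap(input_shape, in_axes=None,
                                              sharded_dim_size=dc)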
@@ -2869,7 +2869,8 @@ class ArrayPmapTest(jtu.JaxTestCase):
     dc = jax.device_count()
     input_shape = (dc, 2)
     a1, input_data = create_input_array_for_pmap(input_shape, in_axes=0)
-    a2, _ = create_input_array_for_pmap(input_shape, in_axes=None)
+    a2, _ = create_input_array_for_pmap(input_shape, in_axes=None,
+                                        sharded_dim_size=a1.shape[0])
 
     def f(x, y):
       assert x.shape == (2,)
@@ -2890,7 +2891,8 @@ class ArrayPmapTest(jtu.JaxTestCase):
 
   def test_pmap_array_sharding_mismatch(self):
     input_shape = (jax.device_count(), 2)
-    a1, _ = create_input_array_for_pmap(input_shape, in_axes=None)
+    a1, _ = create_input_array_for_pmap(input_shape, in_axes=None,
+                                        sharded_dim_size=input_shape[0])
 
     f = jax.pmap(lambda x: x, in_axes=0, out_axes=0)
     with jax._src.config.jax_array(True):