[sharding_in_types] Add more reshape sharding support

* Allow merging and splitting only if the major-most dim is sharded, since that involves no data movement. This only happens if `dimensions` is None, i.e. if the input array is in **row-major order**.

  * Merging: If **only** the major-most dim of the merge block is sharded, then that sharding is propagated to the merge-block output.

  * Splitting: If the dimension being split is sharded, then the sharding is propagated to the major-most dimension post-split, but only if the spec divides the new shape exactly.

PiperOrigin-RevId: 730291595
This commit is contained in:
Yash Katariya 2025-02-23 21:38:40 -08:00 committed by jax authors
parent 908ff49e22
commit 7d3c63eded
2 changed files with 56 additions and 26 deletions

View File

@ -6126,17 +6126,19 @@ def _reshape_sharding_rule(operand, *, new_sizes, dimensions, sharding):
is_split, out_split = _split_on_one_axis(operand.shape, new_sizes, 'Splitting')
if is_split:
return _split_an_axis_sharding_rule(operand, out_split, new_sizes)
return _split_an_axis_sharding_rule(operand, out_split, new_sizes,
dimensions)
is_merge, operand_merge = _merge_on_one_axis(operand, new_sizes)
if is_merge:
return _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes)
return _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes,
dimensions)
raise ValueError(
'This reshape is not supported. Only 4 out of the box reshapes are'
' supported.Adding/removing singleton dims and splitting/merging without'
' sharded split/merged axes are supported. Please specify the sharding of'
' the output via the `out_sharding` argument of jax.lax.reshape.')
'This reshape is not supported. Please specify the sharding of'
' the output via the `out_sharding` argument of jax.lax.reshape. Got'
f' operand shape: {operand.shape}, new sizes: {new_sizes} and'
f' operand spec: {operand.sharding.spec}')
def _split_merge_singleton_dim_sharding_rule(operand, new_sizes):
filtered_spec = [sp for sh, sp in zip(operand.shape, operand.sharding.spec)
@ -6151,38 +6153,54 @@ def _split_merge_singleton_dim_sharding_rule(operand, new_sizes):
new_spec.append(sp)
return operand.sharding.with_spec(new_spec)
def _split_an_axis_sharding_rule(operand, out_split, new_sizes):
def _get_spec_size(sp, mesh):
tup_sp = sp if isinstance(sp, tuple) else (sp,)
return math.prod(mesh.shape[t] for t in tup_sp)
def _split_an_axis_sharding_rule(operand, out_split, new_sizes, dimensions):
new_spec = []
for sh, out, sp in safe_zip(operand.shape, out_split, operand.sharding.spec):
mesh = operand.sharding.mesh
for out, sp in safe_zip(out_split, operand.sharding.spec):
if isinstance(out, list):
if sp is not None:
if sp is None:
new_spec.extend([None] * len(out))
elif dimensions is None and out[0] % _get_spec_size(sp, mesh) == 0:
new_spec.extend([sp] + [None] * (len(out) - 1))
else:
raise ValueError(
f'Split axis cannot be sharded. Got operand dim {sh} with spec'
f' {sp}. Please specify the sharding of the output via the'
' `sharding` argument of jax.lax.reshape.')
new_spec.extend([None] * len(out))
'This reshape is not supported. Please specify the sharding of the'
' output via the `sharding` argument of jax.lax.reshape. Got'
f' operand shape: {operand.shape}, new sizes: {new_sizes} and'
f' operand spec: {operand.sharding.spec}')
else:
new_spec.append(sp)
assert len(new_spec) == len(new_sizes)
assert len(new_spec) == len(new_sizes), (new_spec, new_sizes)
return operand.sharding.with_spec(new_spec)
def _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes):
def _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions):
new_spec = []
mesh = operand.sharding.mesh
op_spec = iter(operand.sharding.spec)
for op_merge in operand_merge:
for new_size, op_merge in zip(new_sizes, operand_merge):
if isinstance(op_merge, list):
sp = [next(op_spec) for _ in op_merge]
if not all(s is None for s in sp):
if all(s is None for s in sp):
new_spec.append(None)
elif (sp[0] is not None and all(s is None for s in sp[1:]) and
dimensions is None):
assert new_size % _get_spec_size(sp[0], mesh) == 0
new_spec.append(sp[0])
else:
raise ValueError(
f'Merged axis cannot be sharded. Got {sp}. Please specify the'
' sharding of the output via the `sharding` argument of'
' jax.lax.reshape.')
new_spec.append(None)
'This reshape is not supported. Please specify the sharding of the'
' output via the `sharding` argument of jax.lax.reshape. Got'
f' operand shape: {operand.shape}, new sizes: {new_sizes} and'
f' operand spec: {operand.sharding.spec}')
else:
new_spec.append(next(op_spec))
assert next(op_spec, None) is None
assert len(new_spec) == len(new_sizes)
assert len(new_spec) == len(new_sizes), (new_spec, new_sizes)
return operand.sharding.with_spec(new_spec)

View File

@ -5383,8 +5383,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
P('x', None, None), P('x', None, None, None, None),
'Splitting on more than 1 axis is not supported'
),
('split_4_error', (4, 6, 8), (4, 2, 3, 8),
P('x', 'y', None), None, 'Split axis cannot be sharded'
('split_4', (4, 6, 8), (4, 2, 3, 8),
P('x', 'y', None), P('x', 'y', None, None), ''
),
('split_4_xy', (4, 12, 8), (4, 4, 3, 8),
P(None, ('x', 'y'), None), P(None, ('x', 'y'), None, None), ''
),
('split_4_error', (4, 6, 8), (4, 3, 2, 8),
P('x', 'y', None), None, 'This reshape is not supported'
),
('split_5_error', (4, 6, 8), (4, 4, 2, 6),
P('x', None, None), None, 'This reshape is not supported'
@ -5401,12 +5407,18 @@ class ShardingInTypesTest(jtu.JaxTestCase):
('merge_3', (4, 6, 2, 2, 2), (4, 6, 8),
P('x', None, None, None, None), P('x', None, None), ''
),
('merge_4', (4, 2, 3, 8), (4, 6, 8),
P(None, 'y', None, 'x'), P(None, 'y', 'x'), ''
),
('merge_4_xy', (4, 4, 3, 8), (4, 12, 8),
P(None, ('x', 'y'), None, None), P(None, ('x', 'y'), None), ''
),
('merge_4_error', (4, 2, 3, 2, 4), (4, 6, 8),
P('x', None, None, None, None), P('x', None, None),
'Merging on more than 1 axis is not supported'
),
('merge_5_error', (4, 2, 3, 8), (4, 6, 8),
P(None, 'y', None, 'x'), None, 'Merged axis cannot be sharded'
('merge_5_error', (4, 2, 6, 8), (4, 12, 8),
P(None, None, 'y', 'x'), None, 'This reshape is not supported'
),
('merge_6_error', (4, 2, 3, 8), (4, 8, 6),
P(None, 'y', None, 'x'), None, 'This reshape is not supported'