Mirror of https://github.com/ROCm/jax.git
[sharding_in_types] Add more reshape sharding support
* Allow merging and splitting only if the major-most dim is sharded, since that involves no data movement. This applies only when `dimensions` is None, i.e. when the input array is in **row-major order**.
* Merging: if **only** the major-most dim of the merge block is sharded, that sharding is propagated to the merge block's output dim.
* Splitting: if the dimension being split is sharded, the sharding is propagated to the major-most dimension post-split, but only if the spec divides the new shape exactly.

PiperOrigin-RevId: 730291595
This commit is contained in:
parent 908ff49e22
commit 7d3c63eded
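A minimal sketch of the splitting rule described in the message above (standalone Python; `mesh_shape`, `spec_size`, and `split_spec` are hypothetical stand-ins for illustration, not the JAX internals changed below):

    import math

    # Hypothetical mesh: axis name -> device count (2 devices on 'x', 2 on 'y').
    mesh_shape = {'x': 2, 'y': 2}

    def spec_size(sp):
        # A spec entry can name one mesh axis ('y') or several (('x', 'y')).
        axes = sp if isinstance(sp, tuple) else (sp,)
        return math.prod(mesh_shape[a] for a in axes)

    def split_spec(split_sizes, sp):
        # Row-major split of one dim carrying spec `sp`: the sharding may
        # stay on the major-most output dim only if it divides it exactly.
        if sp is None:
            return [None] * len(split_sizes)
        if split_sizes[0] % spec_size(sp) == 0:
            return [sp] + [None] * (len(split_sizes) - 1)
        raise ValueError('unsupported; specify out_sharding explicitly')

    # A dim of size 6 sharded on 'y' split into (2, 3): 2 % 2 == 0,
    # so 'y' stays on the major-most new dim.
    assert split_spec([2, 3], 'y') == ['y', None]
    # Splitting into (3, 2) instead puts 3 major-most; 3 % 2 != 0 -> error.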
@@ -6126,17 +6126,19 @@ def _reshape_sharding_rule(operand, *, new_sizes, dimensions, sharding):
   is_split, out_split = _split_on_one_axis(operand.shape, new_sizes, 'Splitting')
   if is_split:
-    return _split_an_axis_sharding_rule(operand, out_split, new_sizes)
+    return _split_an_axis_sharding_rule(operand, out_split, new_sizes,
+                                        dimensions)
 
   is_merge, operand_merge = _merge_on_one_axis(operand, new_sizes)
   if is_merge:
-    return _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes)
+    return _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes,
+                                        dimensions)
 
   raise ValueError(
-      'This reshape is not supported. Only 4 out of the box reshapes are'
-      ' supported.Adding/removing singleton dims and splitting/merging without'
-      ' sharded split/merged axes are supported. Please specify the sharding of'
-      ' the output via the `out_sharding` argument of jax.lax.reshape.')
+      'This reshape is not supported. Please specify the sharding of'
+      ' the output via the `out_sharding` argument of jax.lax.reshape. Got'
+      f' operand shape: {operand.shape}, new sizes: {new_sizes} and'
+      f' operand spec: {operand.sharding.spec}')
 
 
 def _split_merge_singleton_dim_sharding_rule(operand, new_sizes):
   filtered_spec = [sp for sh, sp in zip(operand.shape, operand.sharding.spec)
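The new error text steers users toward the explicit escape hatch. A sketch of that fallback (hypothetical shapes and mesh axes; the exact value type `out_sharding` accepts under the experimental sharding-in-types mode may vary by JAX version):

    import jax
    from jax.sharding import PartitionSpec as P

    def reshape_with_explicit_spec(x):
        # Hypothetical: (4, 6, 8) -> (4, 3, 2, 8) splits a 'y'-sharded dim
        # with the divisible factor in the minor-most slot, which the rule
        # rejects, so name the output sharding yourself.
        return jax.lax.reshape(x, (4, 3, 2, 8),
                               out_sharding=P('x', None, 'y', None))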
@@ -6151,38 +6153,54 @@ def _split_merge_singleton_dim_sharding_rule(operand, new_sizes):
     new_spec.append(sp)
   return operand.sharding.with_spec(new_spec)
 
 
-def _split_an_axis_sharding_rule(operand, out_split, new_sizes):
+def _get_spec_size(sp, mesh):
+  tup_sp = sp if isinstance(sp, tuple) else (sp,)
+  return math.prod(mesh.shape[t] for t in tup_sp)
+
+
+def _split_an_axis_sharding_rule(operand, out_split, new_sizes, dimensions):
   new_spec = []
-  for sh, out, sp in safe_zip(operand.shape, out_split, operand.sharding.spec):
+  mesh = operand.sharding.mesh
+  for out, sp in safe_zip(out_split, operand.sharding.spec):
     if isinstance(out, list):
-      if sp is not None:
-        raise ValueError(
-            f'Split axis cannot be sharded. Got operand dim {sh} with spec'
-            f' {sp}. Please specify the sharding of the output via the'
-            ' `sharding` argument of jax.lax.reshape.')
-      new_spec.extend([None] * len(out))
+      if sp is None:
+        new_spec.extend([None] * len(out))
+      elif dimensions is None and out[0] % _get_spec_size(sp, mesh) == 0:
+        new_spec.extend([sp] + [None] * (len(out) - 1))
+      else:
+        raise ValueError(
+            'This reshape is not supported. Please specify the sharding of the'
+            ' output via the `sharding` argument of jax.lax.reshape. Got'
+            f' operand shape: {operand.shape}, new sizes: {new_sizes} and'
+            f' operand spec: {operand.sharding.spec}')
     else:
       new_spec.append(sp)
-  assert len(new_spec) == len(new_sizes)
+  assert len(new_spec) == len(new_sizes), (new_spec, new_sizes)
   return operand.sharding.with_spec(new_spec)
 
 
-def _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes):
+def _merge_an_axis_sharding_rule(operand, operand_merge, new_sizes, dimensions):
   new_spec = []
+  mesh = operand.sharding.mesh
   op_spec = iter(operand.sharding.spec)
-  for op_merge in operand_merge:
+  for new_size, op_merge in zip(new_sizes, operand_merge):
     if isinstance(op_merge, list):
       sp = [next(op_spec) for _ in op_merge]
-      if not all(s is None for s in sp):
-        raise ValueError(
-            f'Merged axis cannot be sharded. Got {sp}. Please specify the'
-            ' sharding of the output via the `sharding` argument of'
-            ' jax.lax.reshape.')
-      new_spec.append(None)
+      if all(s is None for s in sp):
+        new_spec.append(None)
+      elif (sp[0] is not None and all(s is None for s in sp[1:]) and
+            dimensions is None):
+        assert new_size % _get_spec_size(sp[0], mesh) == 0
+        new_spec.append(sp[0])
+      else:
+        raise ValueError(
+            'This reshape is not supported. Please specify the sharding of the'
+            ' output via the `sharding` argument of jax.lax.reshape. Got'
+            f' operand shape: {operand.shape}, new sizes: {new_sizes} and'
+            f' operand spec: {operand.sharding.spec}')
     else:
       new_spec.append(next(op_spec))
   assert next(op_spec, None) is None
-  assert len(new_spec) == len(new_sizes)
+  assert len(new_spec) == len(new_sizes), (new_spec, new_sizes)
   return operand.sharding.with_spec(new_spec)
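To make the new `_get_spec_size` helper concrete, here is a standalone restatement together with the divisibility check it feeds (the dict stands in for `mesh.shape`, which maps axis name to size; a sketch, not the module's code):

    import math

    def get_spec_size(sp, mesh_shape):
        # Total shard count named by one spec entry: 'x' -> 2, ('x', 'y') -> 4.
        tup_sp = sp if isinstance(sp, tuple) else (sp,)
        return math.prod(mesh_shape[t] for t in tup_sp)

    mesh_shape = {'x': 2, 'y': 2}
    assert get_spec_size('x', mesh_shape) == 2
    assert get_spec_size(('x', 'y'), mesh_shape) == 4

    # The merge check in _merge_an_axis_sharding_rule: a block merging to
    # size 6 whose major-most spec entry is 'y' passes because 6 % 2 == 0.
    assert 6 % get_spec_size('y', mesh_shape) == 0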
@@ -5383,8 +5383,14 @@ class ShardingInTypesTest(jtu.JaxTestCase):
      P('x', None, None), P('x', None, None, None, None),
      'Splitting on more than 1 axis is not supported'
     ),
-    ('split_4_error', (4, 6, 8), (4, 2, 3, 8),
-     P('x', 'y', None), None, 'Split axis cannot be sharded'
+    ('split_4', (4, 6, 8), (4, 2, 3, 8),
+     P('x', 'y', None), P('x', 'y', None, None), ''
+    ),
+    ('split_4_xy', (4, 12, 8), (4, 4, 3, 8),
+     P(None, ('x', 'y'), None), P(None, ('x', 'y'), None, None), ''
+    ),
+    ('split_4_error', (4, 6, 8), (4, 3, 2, 8),
+     P('x', 'y', None), None, 'This reshape is not supported'
     ),
     ('split_5_error', (4, 6, 8), (4, 4, 2, 6),
      P('x', None, None), None, 'This reshape is not supported'
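Reading the new `split_4` row as a worked example (the row layout — name, operand shape, new sizes, operand spec, expected output spec, expected error — is inferred from context, and mesh axis sizes of 2 are an assumption):

    from jax.sharding import PartitionSpec as P

    # split_4: (4, 6, 8) -> (4, 2, 3, 8). Dim 1 (size 6, spec 'y') splits
    # into (2, 3); with |y| == 2, 2 % 2 == 0, so 'y' stays on the
    # major-most new dim:
    expected = P('x', 'y', None, None)

    # split_4_error: (4, 6, 8) -> (4, 3, 2, 8) puts 3 major-most;
    # 3 % 2 != 0, so the rule raises 'This reshape is not supported'.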
@@ -5401,12 +5407,18 @@ class ShardingInTypesTest(jtu.JaxTestCase):
     ('merge_3', (4, 6, 2, 2, 2), (4, 6, 8),
      P('x', None, None, None, None), P('x', None, None), ''
     ),
+    ('merge_4', (4, 2, 3, 8), (4, 6, 8),
+     P(None, 'y', None, 'x'), P(None, 'y', 'x'), ''
+    ),
+    ('merge_4_xy', (4, 4, 3, 8), (4, 12, 8),
+     P(None, ('x', 'y'), None, None), P(None, ('x', 'y'), None), ''
+    ),
     ('merge_4_error', (4, 2, 3, 2, 4), (4, 6, 8),
      P('x', None, None, None, None), P('x', None, None),
      'Merging on more than 1 axis is not supported'
     ),
-    ('merge_5_error', (4, 2, 3, 8), (4, 6, 8),
-     P(None, 'y', None, 'x'), None, 'Merged axis cannot be sharded'
+    ('merge_5_error', (4, 2, 6, 8), (4, 12, 8),
+     P(None, None, 'y', 'x'), None, 'This reshape is not supported'
     ),
     ('merge_6_error', (4, 2, 3, 8), (4, 8, 6),
      P(None, 'y', None, 'x'), None, 'This reshape is not supported'
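And the new `merge_4` row, read the same way (mesh axis sizes of 2 assumed):

    from jax.sharding import PartitionSpec as P

    # merge_4: (4, 2, 3, 8) -> (4, 6, 8) merges dims (2, 3) carrying specs
    # ('y', None); only the major-most entry is sharded and 6 % |y| == 0,
    # so the merged dim keeps 'y':
    expected = P(None, 'y', 'x')

    # merge_5_error: (4, 2, 6, 8) -> (4, 12, 8) merges (2, 6) with specs
    # (None, 'y'); the sharded entry is not major-most, so the rule raises.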